# Copyright 2016 OVH SAS
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import mock

from neutron.agent.linux import tc_lib
from neutron.tests import base

DEVICE_NAME = "tap_device"
KERNEL_HZ_VALUE = 1000
BW_LIMIT = 2000  # [kbps]
BURST = 100  # [kbit]
LATENCY = 50  # [ms]

# Canned ``tc`` output that TestTcCommand feeds to get_bw_limits() through the
# mocked execute().  The rate, burst and latency fields are all derived from
# the constants above so the fixture cannot drift out of sync with them
# (latency is rendered with one decimal place, e.g. "50.0ms").
TC_OUTPUT = (
    'qdisc tbf 8011: root refcnt 2 rate %(bw)skbit burst %(burst)skbit '
    'lat %(latency).1fms \n') % {'bw': BW_LIMIT,
                                 'burst': BURST,
                                 'latency': LATENCY}


class BaseUnitConversionTest(object):
    """Conversion tests shared by the SI and IEC test cases.

    This is a mixin: concrete subclasses combine it with a TestCase class and
    set ``base_unit``, which is passed through to every tc_lib helper call.
    """

    def _check_kilobits(self, expected, raw_value):
        # Shared assertion: converting raw_value with this class's base unit
        # must yield `expected` kilobits.
        actual = tc_lib.convert_to_kilobits(raw_value, self.base_unit)
        self.assertEqual(expected, actual)

    def test_convert_to_kilobits_bare_value(self):
        # A unit-less number is expected to convert exactly like "1000b".
        self._check_kilobits(8, "1000")

    def test_convert_to_kilobits_bytes_value(self):
        self._check_kilobits(8, "1000b")

    def test_convert_to_kilobits_bits_value(self):
        wanted = tc_lib.bits_to_kilobits(1000, self.base_unit)
        self._check_kilobits(wanted, "1000bit")

    def test_convert_to_kilobits_megabytes_value(self):
        # One megabyte expressed in bits: base^2 bytes * 8.
        mega_bits = self.base_unit ** 2 * 8
        wanted = tc_lib.bits_to_kilobits(mega_bits, self.base_unit)
        self._check_kilobits(wanted, "1m")

    def test_convert_to_kilobits_megabits_value(self):
        mega_bits = self.base_unit ** 2
        wanted = tc_lib.bits_to_kilobits(mega_bits, self.base_unit)
        self._check_kilobits(wanted, "1mbit")

    def test_convert_to_bytes_wrong_unit(self):
        # Despite its name, this exercises convert_to_kilobits with an
        # unknown unit suffix.
        self.assertRaises(
            tc_lib.InvalidUnit,
            tc_lib.convert_to_kilobits, "1Zbit", self.base_unit
        )

    def test_bytes_to_bits(self):
        # (bytes, expected bits)
        for n_bytes, n_bits in ((0, 0), (1, 8)):
            self.assertEqual(n_bits, tc_lib.bytes_to_bits(n_bytes))


class TestSIUnitConversions(BaseUnitConversionTest, base.BaseTestCase):
    """Run the shared conversion tests with tc_lib.SI_BASE as the base unit."""

    base_unit = tc_lib.SI_BASE

    def test_bits_to_kilobits(self):
        # (bits, expected kilobits): the expectations encode rounding up,
        # so any non-zero remainder costs one extra kilobit.
        cases = (
            (0, 0),
            (1, 1),
            (999, 1),
            (1000, 1),
            (1001, 2),
        )
        for bits, kilobits in cases:
            result = tc_lib.bits_to_kilobits(bits, self.base_unit)
            self.assertEqual(kilobits, result)


class TestIECUnitConversions(BaseUnitConversionTest, base.BaseTestCase):
    """Run the shared conversion tests with tc_lib.IEC_BASE as the base unit."""

    base_unit = tc_lib.IEC_BASE

    def test_bits_to_kilobits(self):
        # (bits, expected kilobits): the expectations encode rounding up,
        # so any non-zero remainder costs one extra kilobit.
        cases = (
            (0, 0),
            (1, 1),
            (1023, 1),
            (1024, 1),
            (1025, 2),
        )
        for bits, kilobits in cases:
            result = tc_lib.bits_to_kilobits(bits, self.base_unit)
            self.assertEqual(kilobits, result)


class TestTcCommand(base.BaseTestCase):
    """Tests for tc_lib.TcCommand with utils.execute patched by a mock."""

    def setUp(self):
        super(TestTcCommand, self).setUp()
        self.tc = tc_lib.TcCommand(DEVICE_NAME, KERNEL_HZ_VALUE)
        # Expected command-line forms of the limits: each value followed
        # directly by the unit suffix tc_lib defines for it.
        self.bw_limit = "{}{}".format(BW_LIMIT, tc_lib.BW_LIMIT_UNIT)
        self.burst = "{}{}".format(BURST, tc_lib.BURST_UNIT)
        self.latency = "{}{}".format(LATENCY, tc_lib.LATENCY_UNIT)
        execute_patcher = mock.patch('neutron.agent.common.utils.execute')
        self.execute = execute_patcher.start()

    def _assert_tbf_replaced(self):
        # set_bw_limit and update_bw_limit are both expected to issue this
        # single "tc qdisc replace ... tbf" invocation.
        expected_cmd = ["tc", "qdisc", "replace", "dev", DEVICE_NAME,
                        "root", "tbf", "rate", self.bw_limit,
                        "latency", self.latency,
                        "burst", self.burst]
        self.execute.assert_called_once_with(
            expected_cmd,
            run_as_root=True,
            check_exit_code=True,
            log_fail_as_error=True,
            extra_ok_codes=None
        )

    def test_check_kernel_hz_lower_then_zero(self):
        # Zero is rejected as well as negative values.
        for bad_hz in (0, -100):
            self.assertRaises(
                tc_lib.InvalidKernelHzValue,
                tc_lib.TcCommand, DEVICE_NAME, bad_hz
            )

    def test_get_bw_limits(self):
        self.execute.return_value = TC_OUTPUT
        rate, burst = self.tc.get_bw_limits()
        self.assertEqual(BW_LIMIT, rate)
        self.assertEqual(BURST, burst)

    def test_get_bw_limits_when_wrong_qdisc(self):
        # Output for a non-tbf qdisc should yield no limits at all.
        self.execute.return_value = TC_OUTPUT.replace(
            "tbf", "different_qdisc")
        rate, burst = self.tc.get_bw_limits()
        self.assertIsNone(rate)
        self.assertIsNone(burst)

    def test_get_bw_limits_when_wrong_units(self):
        self.execute.return_value = TC_OUTPUT.replace("kbit", "Xbit")
        self.assertRaises(tc_lib.InvalidUnit, self.tc.get_bw_limits)

    def test_set_bw_limit(self):
        self.tc.set_bw_limit(BW_LIMIT, BURST, LATENCY)
        self._assert_tbf_replaced()

    def test_update_bw_limit(self):
        self.tc.update_bw_limit(BW_LIMIT, BURST, LATENCY)
        self._assert_tbf_replaced()

    def test_delete_bw_limit(self):
        self.tc.delete_bw_limit()
        # Exit code 2 is passed as acceptable to execute().
        self.execute.assert_called_once_with(
            ["tc", "qdisc", "del", "dev", DEVICE_NAME, "root"],
            run_as_root=True, check_exit_code=True,
            log_fail_as_error=True, extra_ok_codes=[2])

    def test_burst_value_when_burst_bigger_then_minimal(self):
        # A requested burst above the minimum is returned unchanged.
        self.assertEqual(BURST, self.tc.get_burst_value(BW_LIMIT, BURST))

    def test_burst_value_when_burst_smaller_then_minimal(self):
        # A burst of 0 is expected to be raised to the 2 kbit minimum.
        self.assertEqual(2, self.tc.get_burst_value(BW_LIMIT, 0))

    def test__get_min_burst_value_in_bits(self):
        # With a 2000 kbit rate and kernel HZ of 1000, the minimal burst
        # should be 2 kbit.
        self.assertEqual(2, self.tc._get_min_burst_value(BW_LIMIT))
