diff --git a/peasant/client/transport.py b/peasant/client/transport.py index 9049b85..e1f8b90 100644 --- a/peasant/client/transport.py +++ b/peasant/client/transport.py @@ -16,6 +16,7 @@ import logging import typing +from urllib.parse import urlparse if typing.TYPE_CHECKING: from peasant.client.protocol import Peasant @@ -23,6 +24,13 @@ logger = logging.getLogger(__name__) +def fix_address(address): + parsed_address = urlparse(address) + if parsed_address.path.endswith("/"): + parsed_address = parsed_address._replace(path=parsed_address.path[:-1]) + return parsed_address.geturl() + + class Transport: _peasant: Peasant diff --git a/tests/runtests.py b/tests/runtests.py index 1e089de..ac67bca 100644 --- a/tests/runtests.py +++ b/tests/runtests.py @@ -15,13 +15,14 @@ # limitations under the License. import unittest -from tests import tornado_test +from tests import tornado_test, transport_test def suite(): testLoader = unittest.TestLoader() alltests = unittest.TestSuite() alltests.addTests(testLoader.loadTestsFromModule(tornado_test)) + alltests.addTests(testLoader.loadTestsFromModule(transport_test)) return alltests diff --git a/tests/transport_test.py b/tests/transport_test.py new file mode 100644 index 0000000..6b1da21 --- /dev/null +++ b/tests/transport_test.py @@ -0,0 +1,35 @@ +# Copyright 2020-2024 Flavio Garcia +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from peasant.client.transport import fix_address +from unittest import TestCase + + +class TransportTestCase(TestCase): + + def test_fix_address(self): + address = "http://localhost" + expected_address = "http://localhost" + fixed_address = fix_address(address) + self.assertEqual(expected_address, fixed_address) + + address = "https://localhost/" + expected_address = "https://localhost" + fixed_address = fix_address(address) + self.assertEqual(expected_address, fixed_address) + + address = "http://localhost/a/path/" + expected_address = "http://localhost/a/path" + fixed_address = fix_address(address) + self.assertEqual(expected_address, fixed_address)