From 17a920fee8a8b74f8c955223b35a819f2db5afe3 Mon Sep 17 00:00:00 2001 From: Flavio Garcia Date: Mon, 8 Apr 2024 02:27:58 -0400 Subject: [PATCH] feat(transport): add missing methods to transport tornado Fixes: #6 --- peasant/client/transport_tornado.py | 109 ++++++++++++++++++++++++- tests/fixtures/bastiontest/handlers.py | 2 +- tests/transport_requests_test.py | 2 +- tests/transport_tornado_test.py | 44 ++++++++++ 4 files changed, 151 insertions(+), 6 deletions(-) diff --git a/peasant/client/transport_tornado.py b/peasant/client/transport_tornado.py index ae5409c..d03625d 100644 --- a/peasant/client/transport_tornado.py +++ b/peasant/client/transport_tornado.py @@ -43,8 +43,61 @@ def get_tornado_request(url, **kwargs): :return HTTPRequest: """ method = kwargs.get("method", "GET") + + auth_username = kwargs.get("auth_username") + auth_password = kwargs.get("auth_password") + auth_mode = kwargs.get("auth_mode") + connect_timeout = kwargs.get("connect_timeout") + request_timeout = kwargs.get("request_timeout") + if_modified_since = kwargs.get("if_modified_since") + follow_redirects = kwargs.get("follow_redirects") + max_redirects = kwargs.get("max_redirects") + user_agent = kwargs.get("user_agent") + use_gzip = kwargs.get("use_gzip") + network_interface = kwargs.get("network_interface") + streaming_callback = kwargs.get("streaming_callback") + header_callback = kwargs.get("header_callback") + prepare_curl_callback = kwargs.get("prepare_curl_callback") + proxy_host = kwargs.get("proxy_host") + proxy_port = kwargs.get("proxy_port") + proxy_username = kwargs.get("proxy_username") + proxy_password = kwargs.get("proxy_password") + proxy_auth_mode = kwargs.get("proxy_auth_mode") + allow_nonstandard_methods = kwargs.get("allow_nonstandard_methods") + validate_cert = kwargs.get("validate_cert") + ca_certs = kwargs.get("ca_certs") + allow_ipv6 = kwargs.get("allow_ipv6") + client_key = kwargs.get("client_key") + client_cert = kwargs.get("client_cert") + body_producer = kwargs.get("body_producer") + expect_100_continue = kwargs.get("expect_100_continue") + decompress_response = kwargs.get("decompress_response") + ssl_options = kwargs.get("ssl_options") + form_urlencoded = kwargs.get("form_urlencoded", False) - request = HTTPRequest(url, method=method) + request = HTTPRequest( + url, method=method, headers=None, body=None, + auth_username=auth_username, auth_password=auth_password, + auth_mode=auth_mode, connect_timeout=connect_timeout, + request_timeout=request_timeout, + if_modified_since=if_modified_since, + follow_redirects=follow_redirects, max_redirects=max_redirects, + user_agent=user_agent, use_gzip=use_gzip, + network_interface=network_interface, + streaming_callback=streaming_callback, + header_callback=header_callback, + prepare_curl_callback=prepare_curl_callback, + proxy_host=proxy_host, proxy_port=proxy_port, + proxy_username=proxy_username, proxy_password=proxy_password, + proxy_auth_mode=proxy_auth_mode, + allow_nonstandard_methods=allow_nonstandard_methods, + validate_cert=validate_cert, ca_certs=ca_certs, + allow_ipv6=allow_ipv6, client_key=client_key, + client_cert=client_cert, body_producer=body_producer, + expect_100_continue=expect_100_continue, + decompress_response=decompress_response, + ssl_options=ssl_options, + ) body = kwargs.get("body", None) if body: request.body = body @@ -83,8 +136,20 @@ def get_headers(self, **kwargs): headers.update(_headers) return headers + async def delete(self, path: str, **kwargs: dict): + url = self.get_url(path, **kwargs) + kwargs["method"] = "DELETE" + request = get_tornado_request(url, **kwargs) + headers = self.get_headers(**kwargs) + request.headers.update(headers) + try: + result = await self._client.fetch(request) + except HTTPClientError as error: + raise error + return result + async def get(self, path: str, **kwargs: dict): - url = concat_url(self._bastion_address, path, **kwargs) + url = self.get_url(path, **kwargs) request = get_tornado_request(url, **kwargs) headers = self.get_headers(**kwargs) request.headers.update(headers) @@ -95,7 +160,7 @@ async def get(self, path: str, **kwargs: dict): return result async def head(self, path: str, **kwargs: dict): - url = concat_url(self._bastion_address, path, **kwargs) + url = self.get_url(path, **kwargs) kwargs["method"] = "HEAD" request = get_tornado_request(url, **kwargs) headers = self.get_headers(**kwargs) @@ -106,8 +171,32 @@ async def head(self, path: str, **kwargs: dict): raise error return result + async def options(self, path: str, **kwargs: dict): + url = self.get_url(path, **kwargs) + kwargs["method"] = "OPTIONS" + request = get_tornado_request(url, **kwargs) + headers = self.get_headers(**kwargs) + request.headers.update(headers) + try: + result = await self._client.fetch(request) + except HTTPClientError as error: + raise error + return result + + async def patch(self, path: str, **kwargs: dict): + url = self.get_url(path, **kwargs) + kwargs["method"] = "PATCH" + request = get_tornado_request(url, **kwargs) + headers = self.get_headers(**kwargs) + request.headers.update(headers) + try: + result = await self._client.fetch(request) + except HTTPClientError as error: + raise error + return result + async def post(self, path: str, **kwargs: dict): - url = concat_url(self._bastion_address, path, **kwargs) + url = self.get_url(path, **kwargs) kwargs["method"] = "POST" request = get_tornado_request(url, **kwargs) headers = self.get_headers(**kwargs) @@ -117,3 +206,15 @@ async def post(self, path: str, **kwargs: dict): except HTTPClientError as error: raise error return result + + async def put(self, path: str, **kwargs: dict): + url = self.get_url(path, **kwargs) + kwargs["method"] = "PUT" + request = get_tornado_request(url, **kwargs) + headers = self.get_headers(**kwargs) + request.headers.update(headers) + try: + result = await self._client.fetch(request) + except HTTPClientError as error: + raise error + return result diff --git a/tests/fixtures/bastiontest/handlers.py b/tests/fixtures/bastiontest/handlers.py index b2fd03d..f332272 100644 --- a/tests/fixtures/bastiontest/handlers.py +++ b/tests/fixtures/bastiontest/handlers.py @@ -19,7 +19,7 @@ def head(self): class DeleteHandler(tornadoweb.TornadoHandler): def delete(self): - body = self.request.body + body = "da body" self.add_header("request-body", body) self.write("Delete method output") diff --git a/tests/transport_requests_test.py b/tests/transport_requests_test.py index bb99415..2fd5f0d 100644 --- a/tests/transport_requests_test.py +++ b/tests/transport_requests_test.py @@ -46,7 +46,7 @@ async def test_delete(self): expected_body = "da body" expected_content = b"Delete method output" try: - response = self.transport.delete("/delete", data="da body") + response = self.transport.delete("/delete") except Exception as e: raise e self.assertEqual(expected_body, response.headers.get("request-body")) diff --git a/tests/transport_tornado_test.py b/tests/transport_tornado_test.py index 19c3d01..aaf6be4 100644 --- a/tests/transport_tornado_test.py +++ b/tests/transport_tornado_test.py @@ -42,6 +42,17 @@ async def test_head(self): self.assertEqual(response.headers.get("user-agent"), self.transport.user_agent) + @gen_test + async def test_delete(self): + expected_body = "da body" + expected_content = b"Delete method output" + try: + response = await self.transport.delete("/delete") + except Exception as e: + raise e + self.assertEqual(expected_body, response.headers.get("request-body")) + self.assertEqual(expected_content, response.body) + @gen_test async def test_get(self): try: @@ -50,6 +61,28 @@ async def test_get(self): raise e self.assertEqual(response.body, b"Get method output") + @gen_test + async def test_options(self): + expected_body = "da body" + expected_content = b"Options method output" + try: + response = await self.transport.options("/options") + except Exception as e: + raise e + self.assertEqual(expected_body, response.headers.get("request-body")) + self.assertEqual(expected_content, response.body) + + @gen_test + async def test_patch(self): + expected_body = "da body" + expected_content = b"Patch method output" + try: + response = await self.transport.patch("/patch", body="da body") + except Exception as e: + raise e + self.assertEqual(expected_body, response.headers.get("request-body")) + self.assertEqual(expected_content, response.body) + @gen_test async def test_post(self): expected_body = "da body" @@ -59,3 +92,14 @@ async def test_post(self): raise e self.assertEqual(expected_body, response.headers.get("request-body")) self.assertEqual(response.body, b"Post method output") + + @gen_test + async def test_put(self): + expected_body = "da body" + expected_content = b"Put method output" + try: + response = await self.transport.put("/put", body="da body") + except Exception as e: + raise e + self.assertEqual(expected_body, response.headers.get("request-body")) + self.assertEqual(expected_content, response.body)