diff --git a/CITATION.cff b/CITATION.cff
index 3422384..acee2f5 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -7,4 +7,4 @@ authors:
 - family-names: $${author.last#1}
   given-names: $${author.first#1}
 keywords: [""]
-date-released: $${today.zulu}
+date-released: $${today.zulu}
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index 7f8fa03..53137b6 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -42,4 +42,4 @@ COPY . /code
 # Perhaps run a command:
 # CMD authcaptureproxy --my --options --etc
 # or expose a port:
-# EXPOSE 443/tcp
+# EXPOSE 443/tcp
\ No newline at end of file
diff --git a/authcaptureproxy/auth_capture_proxy.py b/authcaptureproxy/auth_capture_proxy.py
index 6f941d0..367ec80 100644
--- a/authcaptureproxy/auth_capture_proxy.py
+++ b/authcaptureproxy/auth_capture_proxy.py
@@ -16,6 +16,7 @@
     hdrs,
     web,
 )
+from multidict import CIMultiDict
 from yarl import URL
 
 from authcaptureproxy.const import SKIP_AUTO_HEADERS
@@ -44,7 +45,12 @@ class AuthCaptureProxy:
     """
 
     def __init__(
-        self, proxy_url: URL, host_url: URL, session: Optional[httpx.AsyncClient] = None
+        self,
+        proxy_url: URL,
+        host_url: URL,
+        session: Optional[httpx.AsyncClient] = None,
+        session_factory: Optional[Callable[[], httpx.AsyncClient]] = None,
+        preserve_headers: bool = False,
     ) -> None:
         """Initialize proxy object.
 
@@ -52,9 +58,14 @@ def __init__(
             proxy_url (URL): url for proxy location. e.g., http://192.168.1.1/. If there is any path, the path is considered part of the base url. If no explicit port is specified, a random port will be generated. If https is passed in, ssl_context must be provided at start_proxy() or the url will be downgraded to http.
             host_url (URL): original url for login, e.g., http://amazon.com
             session (httpx.AsyncClient): httpx client to make queries. Optional
-
+            session_factory (Callable[[], httpx.AsyncClient]): factory to create the aforementioned httpx client if having one fixed session is insufficient.
+            preserve_headers (bool): Whether to preserve headers from the backend. Useful for circumventing CSRF protection. Defaults to False.
         """
-        self.session: httpx.AsyncClient = session if session else httpx.AsyncClient()
+        self._preserve_headers = preserve_headers
+        self.session_factory: Callable[[], httpx.AsyncClient] = session_factory or (
+            lambda: httpx.AsyncClient()
+        )
+        self.session: httpx.AsyncClient = session if session else self.session_factory()
         self._proxy_url: URL = proxy_url
         self._host_url: URL = host_url
         self._port: int = proxy_url.explicit_port if proxy_url.explicit_port else 0  # type: ignore
@@ -163,7 +174,7 @@ async def reset_data(self) -> None:
         """
         if self.session:
             await self.session.aclose()
-        self.session = httpx.AsyncClient()
+        self.session = self.session_factory()
         self.last_resp = None
         self.init_query = {}
         self.query = {}
@@ -220,6 +231,45 @@ def refresh_modifiers(self, site: Optional[URL] = None) -> None:
         refreshed_modifers = get_nested_dict_keys(self.modifiers)
         _LOGGER.debug("Refreshed %s modifiers: %s", len(refreshed_modifers), refreshed_modifers)
 
+    async def _build_response(
+        self, response: Optional[httpx.Response] = None, *args, **kwargs
+    ) -> web.Response:
+        """Build a response.
+ """ + if "headers" not in kwargs and response is not None: + kwargs["headers"] = response.headers.copy() if self._preserve_headers else CIMultiDict() + + if hdrs.CONTENT_TYPE in kwargs["headers"] and "content_type" in kwargs: + del kwargs["headers"][hdrs.CONTENT_TYPE] + + if hdrs.CONTENT_LENGTH in kwargs["headers"]: + del kwargs["headers"][hdrs.CONTENT_LENGTH] + + if hdrs.CONTENT_ENCODING in kwargs["headers"]: + del kwargs["headers"][hdrs.CONTENT_ENCODING] + + if hdrs.CONTENT_TRANSFER_ENCODING in kwargs["headers"]: + del kwargs["headers"][hdrs.CONTENT_TRANSFER_ENCODING] + + if hdrs.TRANSFER_ENCODING in kwargs["headers"]: + del kwargs["headers"][hdrs.TRANSFER_ENCODING] + + if "x-connection-hash" in kwargs["headers"]: + del kwargs["headers"]["x-connection-hash"] + + while hdrs.SET_COOKIE in kwargs["headers"]: + del kwargs["headers"][hdrs.SET_COOKIE] + + # cache control + + if hdrs.CACHE_CONTROL in kwargs["headers"]: + del kwargs["headers"][hdrs.CACHE_CONTROL] + + kwargs["headers"][hdrs.CACHE_CONTROL] = "no-cache, no-store, must-revalidate" + + return web.Response(*args, **kwargs) + async def all_handler(self, request: web.Request, **kwargs) -> web.Response: """Handle all requests. @@ -227,6 +277,9 @@ async def all_handler(self, request: web.Request, **kwargs) -> web.Response: Args request (web.Request): The request to process + **kwargs: Additional keyword arguments + access_url (URL): The access url for the proxy. Defaults to self.access_url() + host_url (URL): The host url for the proxy. Defaults to self._host_url Returns web.Response: The webresponse to the browser @@ -236,6 +289,15 @@ async def all_handler(self, request: web.Request, **kwargs) -> web.Response: web.HTTPNotFound: Return 404 when all_handler is disabled """ + if "access_url" in kwargs: + access_url = kwargs.pop("access_url") + else: + access_url = self.access_url() + + if "host_url" in kwargs: + host_url = kwargs.pop("host_url") + else: + host_url = self._host_url async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -> None: """Process multipart. 
@@ -281,21 +343,21 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -> None:
         # if not self.session:
         #     self.session = httpx.AsyncClient()
         method = request.method.lower()
-        _LOGGER.debug("Received %s: %s for %s", method, str(request.url), self._host_url)
+        _LOGGER.debug("Received %s: %s for %s", method, str(request.url), host_url)
         resp: Optional[httpx.Response] = None
         old_url: URL = (
-            self.access_url().with_host(request.url.host)
-            if request.url.host and request.url.host != self.access_url().host
-            else self.access_url()
+            access_url.with_host(request.url.host)
+            if request.url.host and request.url.host != access_url.host
+            else access_url
         )
-        if request.scheme == "http" and self.access_url().scheme == "https":
+        if request.scheme == "http" and access_url.scheme == "https":
             # detect reverse proxy downgrade
             _LOGGER.debug("Detected http while should be https; switching to https")
             site: str = str(
                 swap_url(
                     ignore_query=True,
                     old_url=old_url.with_scheme("https"),
-                    new_url=self._host_url.with_path("/"),
+                    new_url=host_url.with_path("/"),
                     url=URL(str(request.url)).with_scheme("https"),
                 ),
             )
@@ -304,7 +366,7 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -> None:
                 swap_url(
                     ignore_query=True,
                     old_url=old_url,
-                    new_url=self._host_url.with_path("/"),
+                    new_url=host_url.with_path("/"),
                     url=URL(str(request.url)),
                 ),
             )
@@ -331,7 +393,7 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -> None:
             self.all_handler_active = False
             if self.active:
                 asyncio.create_task(self.stop_proxy(3))
-            return web.Response(text="Proxy stopped.")
+            return await self._build_response(text="Proxy stopped.")
         elif (
             URL(str(request.url)).path
             == re.sub(r"/+", "/", self._proxy_url.with_path(f"{self._proxy_url.path}/resume").path)
@@ -349,12 +411,12 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -> None:
             ),
         ]:
             # either base path or resume without anything to resume
-            site = str(URL(self._host_url))
+            site = str(URL(host_url))
             if method == "get":
                 self.init_query = self.query.copy()
                 _LOGGER.debug(
                     "Starting auth capture proxy for %s",
-                    self._host_url,
+                    host_url,
                 )
         headers = await self.modify_headers(URL(site), request)
         skip_auto_headers: List[str] = headers.get(SKIP_AUTO_HEADERS, [])
@@ -390,11 +452,15 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -> None:
                     site, headers=headers, follow_redirects=True
                 )
         except ClientConnectionError as ex:
-            return web.Response(text=f"Error connecting to {site}; please retry: {ex}")
+            return await self._build_response(
+                text=f"Error connecting to {site}; please retry: {ex}"
+            )
         except TooManyRedirects as ex:
-            return web.Response(text=f"Error connecting to {site}; too may redirects: {ex}")
+            return await self._build_response(
+                text=f"Error connecting to {site}; too many redirects: {ex}"
+            )
         if resp is None:
-            return web.Response(text=f"Error connecting to {site}; please retry")
+            return await self._build_response(text=f"Error connecting to {site}; please retry")
         self.last_resp = resp
         print_resp(resp)
         self.check_redirects()
@@ -413,7 +479,9 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -> None:
                 raise web.HTTPFound(location=result)
             elif isinstance(result, str):
                 _LOGGER.debug("Displaying page:\n%s", result)
-                return web.Response(text=result, content_type="text/html")
+                return await self._build_response(
+                    resp, text=result, content_type="text/html"
+                )
         else:
             _LOGGER.warning("Proxy has no tests; please set.")
         content_type = get_content_type(resp)
@@ -449,7 +517,8 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -> None:
                     _LOGGER.warning("Modifier %s is not callable: %s", name, ex)
         # _LOGGER.debug("Returning modified text:\n%s", text)
         if modified:
-            return web.Response(
+            return await self._build_response(
+                resp,
                 text=text,
                 content_type=content_type,
             )
@@ -461,7 +530,7 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -> None:
             else URL(str(request.url)).path,
             content_type,
         )
-        return web.Response(body=resp.content, content_type=content_type)
+        return await self._build_response(resp, body=resp.content, content_type=content_type)
 
     async def start_proxy(
         self, host: Optional[Text] = None, ssl_context: Optional[SSLContext] = None
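[usage sketch — editorial, not part of the diff] A minimal example of the new constructor options. Assumptions: AuthCaptureProxy is importable from the package root (as in the project README); the URLs and the make_client helper are illustrative placeholders.

import httpx
from yarl import URL
from authcaptureproxy import AuthCaptureProxy

def make_client() -> httpx.AsyncClient:
    # reset_data() now closes the old client and calls this factory for a
    # fresh one, so per-client settings (timeouts, proxies) survive resets.
    return httpx.AsyncClient(timeout=httpx.Timeout(30.0))

proxy = AuthCaptureProxy(
    proxy_url=URL("http://127.0.0.1:8000/"),    # placeholder proxy address
    host_url=URL("https://example.com/login"),  # placeholder login page
    session_factory=make_client,
    # Copy backend response headers into proxied responses; cookie and
    # caching headers are still stripped by _build_response.
    preserve_headers=True,
)
[end sketch]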
diff --git a/codemeta.json b/codemeta.json
index ee3a996..03a7965 100644
--- a/codemeta.json
+++ b/codemeta.json
@@ -28,4 +28,4 @@
     "dateCreated": "2021-02-03",
     "datePublished": "2021-02-03",
     "programmingLanguage": "Python"
-}
+}
\ No newline at end of file
diff --git a/cov.xml b/cov.xml
index 5a97ec9..c06395a 100644
--- a/cov.xml
+++ b/cov.xml
[cov.xml hunks omitted: the XML markup was stripped during extraction, leaving only bare +/- markers. The one legible change replaces the absolute source path "/Users/alandtse/auth_capture_proxy/authcaptureproxy" with the relative "authcaptureproxy"; the rest is regenerated coverage data.]
diff --git a/pyproject.toml b/pyproject.toml
index ac791fe..f4ddef7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -225,7 +225,7 @@
 long_description = "tool.poetry.description"
 # "<> >,orcid:<>>"
 #]
 authors = ["Alan D. Tse "]
-contributors = []
+contributors = ["Alan D. Tse ", "Parker Wahle "]
 # Turn this into a literal list if it is different than the authors
 maintainers = "tool.tyrannosaurus.sources.authors"
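[usage sketch — editorial, not part of the diff] The new all_handler kwargs make the access and host URLs overridable per request. A sketch reusing the proxy object from the previous note; the alternate backend URL and the catch-all route are illustrative placeholders.

from aiohttp import web
from yarl import URL

async def handler(request: web.Request) -> web.Response:
    # Override host_url for this request only; access_url works the same way.
    # Omitted kwargs fall back to the proxy's configured values.
    return await proxy.all_handler(
        request,
        host_url=URL("https://alt.example.com/login"),  # placeholder backend
    )

app = web.Application()
app.router.add_route("*", "/{tail:.*}", handler)
[end sketch]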