Skip to content

Commit

Permalink
feat: preserve custom headers and allow dynamic sessions
Browse files Browse the repository at this point in the history
Preserving custom headers & dynamic session instantiation for bypassing more secure CSRF defense systems (#26). These changes also make it possible to host the MITM site in a subdomain or subdirectory. This feature was required for a project of mine and more people may benefit from its inclusion in the library.
  • Loading branch information
regulad authored Nov 26, 2023
1 parent 2e92f00 commit 307bd06
Show file tree
Hide file tree
Showing 6 changed files with 290 additions and 197 deletions.
2 changes: 1 addition & 1 deletion CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ authors:
- family-names: $${author.last#1}
given-names: $${author.first#1}
keywords: [""]
date-released: $${today.zulu}
date-released: $${today.zulu}
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ COPY . /code
# Perhaps run a command:
# CMD authcaptureproxy --my --options --etc
# or expose a port:
# EXPOSE 443/tcp
# EXPOSE 443/tcp
109 changes: 89 additions & 20 deletions authcaptureproxy/auth_capture_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
hdrs,
web,
)
from multidict import CIMultiDict
from yarl import URL

from authcaptureproxy.const import SKIP_AUTO_HEADERS
Expand Down Expand Up @@ -44,17 +45,27 @@ class AuthCaptureProxy:
"""

def __init__(
self, proxy_url: URL, host_url: URL, session: Optional[httpx.AsyncClient] = None
self,
proxy_url: URL,
host_url: URL,
session: Optional[httpx.AsyncClient] = None,
session_factory: Optional[Callable[[], httpx.AsyncClient]] = None,
preserve_headers: bool = False,
) -> None:
"""Initialize proxy object.
Args:
proxy_url (URL): url for proxy location. e.g., http://192.168.1.1/. If there is any path, the path is considered part of the base url. If no explicit port is specified, a random port will be generated. If https is passed in, ssl_context must be provided at start_proxy() or the url will be downgraded to http.
host_url (URL): original url for login, e.g., http://amazon.com
session (httpx.AsyncClient): httpx client to make queries. Optional
session_factory (lambda: httpx.AsyncClient): factory to create the aforementioned httpx client if having one fixed session is insufficient.
preserve_headers (bool): Whether to preserve headers from the backend. Useful in circumventing CSRF protection. Defaults to False.
"""
self.session: httpx.AsyncClient = session if session else httpx.AsyncClient()
self._preserve_headers = preserve_headers
self.session_factory: Callable[[], httpx.AsyncClient] = session_factory or (
lambda: httpx.AsyncClient()
)
self.session: httpx.AsyncClient = session if session else self.session_factory()
self._proxy_url: URL = proxy_url
self._host_url: URL = host_url
self._port: int = proxy_url.explicit_port if proxy_url.explicit_port else 0 # type: ignore
Expand Down Expand Up @@ -163,7 +174,7 @@ async def reset_data(self) -> None:
"""
if self.session:
await self.session.aclose()
self.session = httpx.AsyncClient()
self.session = self.session_factory()
self.last_resp = None
self.init_query = {}
self.query = {}
Expand Down Expand Up @@ -220,13 +231,55 @@ def refresh_modifiers(self, site: Optional[URL] = None) -> None:
refreshed_modifers = get_nested_dict_keys(self.modifiers)
_LOGGER.debug("Refreshed %s modifiers: %s", len(refreshed_modifers), refreshed_modifers)

async def _build_response(
    self, response: Optional[httpx.Response] = None, *args, **kwargs
) -> web.Response:
    """Build a web.Response for the browser, sanitizing backend headers.

    When ``self._preserve_headers`` is True and an upstream ``response`` is
    supplied, its headers are carried over (useful for CSRF tokens set via
    headers); otherwise a fresh header mapping is used. Either way, headers
    that aiohttp must recompute (content length/encoding, transfer encoding)
    and headers the proxy manages itself (cookies, cache control) are removed.

    Args:
        response (Optional[httpx.Response]): upstream response whose headers
            may be preserved. Defaults to None.
        *args: positional arguments forwarded to web.Response.
        **kwargs: keyword arguments forwarded to web.Response. A "headers"
            kwarg, if given, is sanitized in place.

    Returns:
        web.Response: the response to send to the browser.
    """
    if "headers" not in kwargs:
        # Bug fix: the previous guard (`and response is not None`) left
        # kwargs["headers"] unset when no upstream response was given,
        # so the lookups below raised KeyError (e.g. the "Proxy stopped."
        # path calls this with response=None). Always provide a mapping.
        kwargs["headers"] = (
            response.headers.copy()
            if self._preserve_headers and response is not None
            else CIMultiDict()
        )

    headers = kwargs["headers"]

    # aiohttp raises if both a content_type kwarg and a Content-Type header
    # are supplied; prefer the explicit kwarg.
    if "content_type" in kwargs and hdrs.CONTENT_TYPE in headers:
        del headers[hdrs.CONTENT_TYPE]

    # These describe the upstream body, which may have been modified; let
    # aiohttp recompute them for the body we actually send.
    for header in (
        hdrs.CONTENT_LENGTH,
        hdrs.CONTENT_ENCODING,
        hdrs.CONTENT_TRANSFER_ENCODING,
        hdrs.TRANSFER_ENCODING,
        "x-connection-hash",
    ):
        if header in headers:
            del headers[header]

    # Cookies are managed by the proxy's own session; never pass through
    # upstream Set-Cookie values (multidict may hold several).
    while hdrs.SET_COOKIE in headers:
        del headers[hdrs.SET_COOKIE]

    # Force no caching so the browser always re-requests through the proxy.
    if hdrs.CACHE_CONTROL in headers:
        del headers[hdrs.CACHE_CONTROL]
    headers[hdrs.CACHE_CONTROL] = "no-cache, no-store, must-revalidate"

    return web.Response(*args, **kwargs)

async def all_handler(self, request: web.Request, **kwargs) -> web.Response:
"""Handle all requests.
This handler will exit on succesful test found in self.tests or if a /stop url is seen. This handler can be used with any aiohttp webserver and disabled after registered using self.all_handler_active.
Args
request (web.Request): The request to process
**kwargs: Additional keyword arguments
access_url (URL): The access url for the proxy. Defaults to self.access_url()
host_url (URL): The host url for the proxy. Defaults to self._host_url
Returns
web.Response: The webresponse to the browser
Expand All @@ -236,6 +289,15 @@ async def all_handler(self, request: web.Request, **kwargs) -> web.Response:
web.HTTPNotFound: Return 404 when all_handler is disabled
"""
if "access_url" in kwargs:
access_url = kwargs.pop("access_url")
else:
access_url = self.access_url()

if "host_url" in kwargs:
host_url = kwargs.pop("host_url")
else:
host_url = self._host_url

async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -> None:
"""Process multipart.
Expand Down Expand Up @@ -281,21 +343,21 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -
# if not self.session:
# self.session = httpx.AsyncClient()
method = request.method.lower()
_LOGGER.debug("Received %s: %s for %s", method, str(request.url), self._host_url)
_LOGGER.debug("Received %s: %s for %s", method, str(request.url), host_url)
resp: Optional[httpx.Response] = None
old_url: URL = (
self.access_url().with_host(request.url.host)
if request.url.host and request.url.host != self.access_url().host
else self.access_url()
access_url.with_host(request.url.host)
if request.url.host and request.url.host != access_url.host
else access_url
)
if request.scheme == "http" and self.access_url().scheme == "https":
if request.scheme == "http" and access_url.scheme == "https":
# detect reverse proxy downgrade
_LOGGER.debug("Detected http while should be https; switching to https")
site: str = str(
swap_url(
ignore_query=True,
old_url=old_url.with_scheme("https"),
new_url=self._host_url.with_path("/"),
new_url=host_url.with_path("/"),
url=URL(str(request.url)).with_scheme("https"),
),
)
Expand All @@ -304,7 +366,7 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -
swap_url(
ignore_query=True,
old_url=old_url,
new_url=self._host_url.with_path("/"),
new_url=host_url.with_path("/"),
url=URL(str(request.url)),
),
)
Expand All @@ -331,7 +393,7 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -
self.all_handler_active = False
if self.active:
asyncio.create_task(self.stop_proxy(3))
return web.Response(text="Proxy stopped.")
return await self._build_response(text="Proxy stopped.")
elif (
URL(str(request.url)).path
== re.sub(r"/+", "/", self._proxy_url.with_path(f"{self._proxy_url.path}/resume").path)
Expand All @@ -349,12 +411,12 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -
),
]:
# either base path or resume without anything to resume
site = str(URL(self._host_url))
site = str(URL(host_url))
if method == "get":
self.init_query = self.query.copy()
_LOGGER.debug(
"Starting auth capture proxy for %s",
self._host_url,
host_url,
)
headers = await self.modify_headers(URL(site), request)
skip_auto_headers: List[str] = headers.get(SKIP_AUTO_HEADERS, [])
Expand Down Expand Up @@ -390,11 +452,15 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -
site, headers=headers, follow_redirects=True
)
except ClientConnectionError as ex:
return web.Response(text=f"Error connecting to {site}; please retry: {ex}")
return await self._build_response(
text=f"Error connecting to {site}; please retry: {ex}"
)
except TooManyRedirects as ex:
return web.Response(text=f"Error connecting to {site}; too may redirects: {ex}")
return await self._build_response(
text=f"Error connecting to {site}; too may redirects: {ex}"
)
if resp is None:
return web.Response(text=f"Error connecting to {site}; please retry")
return await self._build_response(text=f"Error connecting to {site}; please retry")
self.last_resp = resp
print_resp(resp)
self.check_redirects()
Expand All @@ -413,7 +479,9 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -
raise web.HTTPFound(location=result)
elif isinstance(result, str):
_LOGGER.debug("Displaying page:\n%s", result)
return web.Response(text=result, content_type="text/html")
return await self._build_response(
resp, text=result, content_type="text/html"
)
else:
_LOGGER.warning("Proxy has no tests; please set.")
content_type = get_content_type(resp)
Expand Down Expand Up @@ -449,7 +517,8 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -
_LOGGER.warning("Modifier %s is not callable: %s", name, ex)
# _LOGGER.debug("Returning modified text:\n%s", text)
if modified:
return web.Response(
return await self._build_response(
resp,
text=text,
content_type=content_type,
)
Expand All @@ -461,7 +530,7 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -
else URL(str(request.url)).path,
content_type,
)
return web.Response(body=resp.content, content_type=content_type)
return await self._build_response(resp, body=resp.content, content_type=content_type)

async def start_proxy(
self, host: Optional[Text] = None, ssl_context: Optional[SSLContext] = None
Expand Down
2 changes: 1 addition & 1 deletion codemeta.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@
"dateCreated": "2021-02-03",
"datePublished": "2021-02-03",
"programmingLanguage": "Python"
}
}
Loading

0 comments on commit 307bd06

Please sign in to comment.