diff --git a/lyrebird/mitm/mitm_script.py b/lyrebird/mitm/mitm_script.py index e68fe779..5892a660 100644 --- a/lyrebird/mitm/mitm_script.py +++ b/lyrebird/mitm/mitm_script.py @@ -36,7 +36,7 @@ def to_mock_server(flow: http.HTTPFlow): flow.request.headers['Lyrebird-Client-Address'] = address flow.request.headers['Mitmproxy-Proxy'] = address flow.request.headers['Proxy-Raw-Headers'] = json.dumps({name: flow.request.headers[name] - for name in flow.request.headers}, ensure_ascii=False) + for name in flow.request.headers if name.lower() not in ('host', 'proxy-raw-headers')}, ensure_ascii=False) def request(flow: http.HTTPFlow): diff --git a/lyrebird/mock/extra_mock_server/server.py b/lyrebird/mock/extra_mock_server/server.py index ff2a8410..20962c63 100644 --- a/lyrebird/mock/extra_mock_server/server.py +++ b/lyrebird/mock/extra_mock_server/server.py @@ -7,7 +7,8 @@ import sys import json -from typing import List, Set, Optional +from typing import Set +from multidict import CIMultiDict from lyrebird.mock.extra_mock_server.lyrebird_proxy_protocol import LyrebirdProxyContext from lyrebird import log @@ -37,25 +38,25 @@ def make_raw_headers_line(request: web.Request): for k, v in request.raw_headers: raw_header_name = k.decode() raw_header_value = v.decode() - if raw_header_name.lower() in ['cache-control', 'host', 'transfer-encoding']: + if raw_header_name.lower() in ['cache-control', 'host', 'transfer-encoding', 'proxy-raw-headers']: continue raw_headers[raw_header_name] = raw_header_value return json.dumps(raw_headers, ensure_ascii=False) async def make_response_header(proxy_resp_headers: dict, context: LyrebirdProxyContext, data=None): - response_headers = {} + response_headers = CIMultiDict() for k, v in proxy_resp_headers.items(): if k.lower() == 'content-length': if data is not None: - response_headers[k] = str(len(data)) + response_headers.add(k, str(len(data))) elif k.lower() == 'host': - response_headers['Host'] = context.netloc + response_headers.add('Host', context.netloc) elif k.lower() == 'location': - response_headers['Host'] = context.netloc - response_headers[k] = v + response_headers.add('Host', context.netloc) + response_headers.add(k, v) else: - response_headers[k] = v + response_headers.add(k, v) return response_headers @@ -64,8 +65,10 @@ async def send_request(context: LyrebirdProxyContext, target_url): request: web.Request = context.request headers = {k: v for k, v in request.headers.items() if k.lower() not in [ 'cache-control', 'host', 'transfer-encoding']} - headers['Proxy-Raw-Headers'] = make_raw_headers_line(request) - headers['Lyrebird-Client-Address'] = request.remote + if 'Proxy-Raw-Headers' not in request.headers: + headers['Proxy-Raw-Headers'] = make_raw_headers_line(request) + if 'Lyrebird-Client-Address' not in request.headers: + headers['Lyrebird-Client-Address'] = request.remote request_body = None if request.body_exists: request_body = request.content