diff --git a/lyrebird/mock/extra_mock_server/lyrebird_proxy_protocol.py b/lyrebird/mock/extra_mock_server/lyrebird_proxy_protocol.py index 8f27389a..4f20ab98 100644 --- a/lyrebird/mock/extra_mock_server/lyrebird_proxy_protocol.py +++ b/lyrebird/mock/extra_mock_server/lyrebird_proxy_protocol.py @@ -16,6 +16,7 @@ def __init__(self): self.origin_url = None self.forward_url = None self.init = False + self.discarded_data = dict() self.protocol_parsers = [ self.protocol_read_from_path, self.protocol_read_from_header, @@ -43,6 +44,9 @@ def protocol_read_from_path(self, request: web.Request, lb_config): self.netloc = url.netloc self.request = request + self.discarded_data.update({ + 'Lyrebird_Protocol': 'from_path' + }) # Set init success self.init = True @@ -80,6 +84,9 @@ def protocol_read_from_header(self, request: web.Request, lb_config): self.forward_url = f'http://127.0.0.1:{port}/mock/{origin_full_url}' self.netloc = netloc self.request = request + self.discarded_data.update({ + 'Lyrebird_Protocol': 'from_header' + }) # Set init success self.init = True @@ -98,6 +105,9 @@ def protocol_read_from_query(self, request: web.Request, lb_config): query_list = str(request.url).split('proxy=') origin_url = query_list[1] if len(query_list) > 1 else '' + if not origin_url: + return + unquote_origin_url = urlparse.unquote(origin_url) url_obj = urlparse.urlparse(unquote_origin_url) @@ -108,6 +118,16 @@ def protocol_read_from_query(self, request: web.Request, lb_config): self.netloc = url_obj.netloc self.request = request + self.discarded_data.update({ + 'Lyrebird_Protocol': 'from_query' + }) + lyrebird_url_split = query_list[0].split('?') + lyrebird_url_query = lyrebird_url_split[1] if len(lyrebird_url_split)>1 else '' + if len(lyrebird_url_query) > 1: + self.discarded_data['Lyrebird-Protocol-Discarded-Query'] = lyrebird_url_query + if len(request.path) > 1: + self.discarded_data['Lyrebird-Protocol-Discarded-Path'] = request.path + # Set init success self.init = True @@ -147,6 +167,12 @@ def protocol_read_from_query_2(self, request: web.Request, lb_config): self.netloc = url.netloc self.request = request + self.discarded_data.update({ + 'Lyrebird_Protocol': 'from_query_2' + }) + if len(request.path) > 1: + self.discarded_data['Lyrebird-Protocol-Discarded-Path'] = request.path + # Set init success self.init = True diff --git a/lyrebird/mock/extra_mock_server/server.py b/lyrebird/mock/extra_mock_server/server.py index 20962c63..f1a78538 100644 --- a/lyrebird/mock/extra_mock_server/server.py +++ b/lyrebird/mock/extra_mock_server/server.py @@ -63,12 +63,13 @@ async def make_response_header(proxy_resp_headers: dict, context: LyrebirdProxyC async def send_request(context: LyrebirdProxyContext, target_url): async with client.ClientSession(auto_decompress=False) as session: request: web.Request = context.request - headers = {k: v for k, v in request.headers.items() if k.lower() not in [ - 'cache-control', 'host', 'transfer-encoding']} + headers = {k: v for k, v in request.headers.items() if k.lower() not in ('cache-control', 'host', 'transfer-encoding')} if 'Proxy-Raw-Headers' not in request.headers: headers['Proxy-Raw-Headers'] = make_raw_headers_line(request) if 'Lyrebird-Client-Address' not in request.headers: headers['Lyrebird-Client-Address'] = request.remote + if context.discarded_data: + headers.update(context.discarded_data) request_body = None if request.body_exists: request_body = request.content diff --git a/lyrebird/mock/handlers/handler_context.py b/lyrebird/mock/handlers/handler_context.py index 850e317e..4766ced7 100644 --- a/lyrebird/mock/handlers/handler_context.py +++ b/lyrebird/mock/handlers/handler_context.py @@ -67,6 +67,12 @@ def _parse_request(self): # Read raw headers if 'Proxy-Raw-Headers' in self.request.headers: raw_headers = json.loads(self.request.headers['Proxy-Raw-Headers']) + if 'Lyrebird-Protocol' in self.request.headers: + raw_headers['Lyrebird-Protocol'] = self.request.headers['Lyrebird-Protocol'] + if 'Lyrebird-Protocol-Discarded-Path' in self.request.headers: + raw_headers['Lyrebird-Protocol-Discarded-Path'] = self.request.headers['Lyrebird-Protocol-Discarded-Path'] + if 'Lyrebird-Protocol-Discarded-Query' in self.request.headers: + raw_headers['Lyrebird-Protocol-Discarded-Query'] = self.request.headers['Lyrebird-Protocol-Discarded-Query'] # parse path request_info = self._read_origin_request_info_from_url() @@ -211,6 +217,8 @@ def get_request_headers(self): continue if name in unproxy_headers and unproxy_headers[name] in value: continue + if 'lyrebird' in name.lower(): + continue headers[name] = value return headers