Skip to content

Commit

Permalink
add request discarded header during lyrebird protocol convert
Browse files Browse the repository at this point in the history
  • Loading branch information
noO0ob committed Dec 9, 2024
1 parent f3aea14 commit 04b7c9f
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
26 changes: 26 additions & 0 deletions lyrebird/mock/extra_mock_server/lyrebird_proxy_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self):
self.origin_url = None
self.forward_url = None
self.init = False
self.discarded_data = dict()
self.protocol_parsers = [
self.protocol_read_from_path,
self.protocol_read_from_header,
Expand Down Expand Up @@ -43,6 +44,9 @@ def protocol_read_from_path(self, request: web.Request, lb_config):

self.netloc = url.netloc
self.request = request
self.discarded_data.update({
'Lyrebird_Protocol': 'from_path'
})

# Set init success
self.init = True
Expand Down Expand Up @@ -80,6 +84,9 @@ def protocol_read_from_header(self, request: web.Request, lb_config):
self.forward_url = f'http://127.0.0.1:{port}/mock/{origin_full_url}'
self.netloc = netloc
self.request = request
self.discarded_data.update({
'Lyrebird_Protocol': 'from_header'
})

# Set init success
self.init = True
Expand All @@ -98,6 +105,9 @@ def protocol_read_from_query(self, request: web.Request, lb_config):
query_list = str(request.url).split('proxy=')
origin_url = query_list[1] if len(query_list) > 1 else ''

if not origin_url:
return

unquote_origin_url = urlparse.unquote(origin_url)
url_obj = urlparse.urlparse(unquote_origin_url)

Expand All @@ -108,6 +118,16 @@ def protocol_read_from_query(self, request: web.Request, lb_config):
self.netloc = url_obj.netloc
self.request = request

self.discarded_data.update({
'Lyrebird_Protocol': 'from_query'
})
lyrebird_url_split = query_list[0].split('?')
lyrebird_url_query = lyrebird_url_split[1] if len(lyrebird_url_split)>1 else ''
if len(lyrebird_url_query) > 1:
self.discarded_data['Lyrebird-Protocol-Discarded-Query'] = lyrebird_url_query
if len(request.path) > 1:
self.discarded_data['Lyrebird-Protocol-Discarded-Path'] = request.path

# Set init success
self.init = True

Expand Down Expand Up @@ -147,6 +167,12 @@ def protocol_read_from_query_2(self, request: web.Request, lb_config):
self.netloc = url.netloc
self.request = request

self.discarded_data.update({
'Lyrebird_Protocol': 'from_query_2'
})
if len(request.path) > 1:
self.discarded_data['Lyrebird-Protocol-Discarded-Path'] = request.path

# Set init success
self.init = True

Expand Down
5 changes: 3 additions & 2 deletions lyrebird/mock/extra_mock_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ async def make_response_header(proxy_resp_headers: dict, context: LyrebirdProxyC
async def send_request(context: LyrebirdProxyContext, target_url):
async with client.ClientSession(auto_decompress=False) as session:
request: web.Request = context.request
headers = {k: v for k, v in request.headers.items() if k.lower() not in [
'cache-control', 'host', 'transfer-encoding']}
headers = {k: v for k, v in request.headers.items() if k.lower() not in ('cache-control', 'host', 'transfer-encoding')}
if 'Proxy-Raw-Headers' not in request.headers:
headers['Proxy-Raw-Headers'] = make_raw_headers_line(request)
if 'Lyrebird-Client-Address' not in request.headers:
headers['Lyrebird-Client-Address'] = request.remote
if context.discarded_data:
headers.update(context.discarded_data)
request_body = None
if request.body_exists:
request_body = request.content
Expand Down
8 changes: 8 additions & 0 deletions lyrebird/mock/handlers/handler_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ def _parse_request(self):
# Read raw headers
if 'Proxy-Raw-Headers' in self.request.headers:
raw_headers = json.loads(self.request.headers['Proxy-Raw-Headers'])
if 'Lyrebird-Protocol' in self.request.headers:
raw_headers['Lyrebird-Protocol'] = self.request.headers['Lyrebird-Protocol']
if 'Lyrebird-Protocol-Discarded-Path' in self.request.headers:
raw_headers['Lyrebird-Protocol-Discarded-Path'] = self.request.headers['Lyrebird-Protocol-Discarded-Path']
if 'Lyrebird-Protocol-Discarded-Query' in self.request.headers:
raw_headers['Lyrebird-Protocol-Discarded-Query'] = self.request.headers['Lyrebird-Protocol-Discarded-Query']

# parse path
request_info = self._read_origin_request_info_from_url()
Expand Down Expand Up @@ -211,6 +217,8 @@ def get_request_headers(self):
continue
if name in unproxy_headers and unproxy_headers[name] in value:
continue
if 'lyrebird' in name.lower():
continue
headers[name] = value
return headers

Expand Down

0 comments on commit 04b7c9f

Please sign in to comment.