diff --git a/lyrebird/mock/extra_mock_server/server.py b/lyrebird/mock/extra_mock_server/server.py index 994e1f44..8c3bdc0b 100644 --- a/lyrebird/mock/extra_mock_server/server.py +++ b/lyrebird/mock/extra_mock_server/server.py @@ -43,6 +43,22 @@ def make_raw_headers_line(request: web.Request): return json.dumps(raw_headers, ensure_ascii=False) +async def make_response_header(proxy_resp_headers: dict, context: LyrebirdProxyContext, data=None): + response_headers = {} + for k, v in proxy_resp_headers.items(): + if k.lower() == 'content-length': + if data is not None: + response_headers[k] = str(len(data)) + elif k.lower() == 'host': + response_headers['Host'] = context.netloc + elif k.lower() == 'location': + response_headers['Host'] = context.netloc + response_headers[k] = v + else: + response_headers[k] = v + return response_headers + + async def send_request(context: LyrebirdProxyContext, target_url): async with client.ClientSession(auto_decompress=False) as session: request: web.Request = context.request @@ -64,24 +80,19 @@ async def send_request(context: LyrebirdProxyContext, target_url): ) as _resp: proxy_resp_status = _resp.status proxy_resp_headers = _resp.headers - # TODO support stream response - proxy_resp_data = await _resp.read() - - response_headers = {} - for k, v in proxy_resp_headers.items(): - if k.lower() in ['transfer-encoding']: - continue - elif k.lower() == 'content-length': - response_headers[k] = str(len(proxy_resp_data)) - elif k.lower() == 'host': - response_headers['Host'] = context.netloc - elif k.lower() == 'location': - response_headers['Host'] = context.netloc - response_headers[k] = v - else: - response_headers[k] = v - - resp = web.Response(status=proxy_resp_status, body=proxy_resp_data, headers=response_headers) + if 'Transfer-Encoding' in proxy_resp_headers and proxy_resp_headers.get('Transfer-Encoding') == 'chunked': + response_headers = await make_response_header(proxy_resp_headers, context) + resp = web.StreamResponse(status=proxy_resp_status, headers=response_headers) + await resp.prepare(request) + async for data in _resp.content.iter_any(): + await resp.write(data) + await resp.write_eof() + logger.info(f'Stream Request finished: {proxy_resp_status} {context.origin_url}') + else: + proxy_resp_data = await _resp.read() + response_headers = await make_response_header(proxy_resp_headers, context, proxy_resp_data) + resp = web.Response(status=proxy_resp_status, body=proxy_resp_data, headers=response_headers) + logger.info(f'Bytes Response finished: {proxy_resp_status} {context.origin_url}') return resp