Skip to content

Commit

Permalink
Improve jsonlines dump to avoid ensuring ascii characters. (#439)
Browse files Browse the repository at this point in the history
  • Loading branch information
superhindupur authored Mar 14, 2024
1 parent 1eb600a commit a7535e7
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 4 deletions.
8 changes: 4 additions & 4 deletions code/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def stream_with_data(body, headers, endpoint):
if line:
lineJson = json.loads(line.lstrip(b"data:").decode("utf-8"))
if "error" in lineJson:
yield json.dumps(lineJson).replace("\n", "\\n") + "\n"
yield json.dumps(lineJson, ensure_ascii=False) + "\n"
response["id"] = lineJson["id"]
response["model"] = lineJson["model"]
response["created"] = lineJson["created"]
Expand All @@ -158,9 +158,9 @@ def stream_with_data(body, headers, endpoint):
"content"
] += deltaText

yield json.dumps(response).replace("\n", "\\n") + "\n"
yield json.dumps(response, ensure_ascii=False) + "\n"
except Exception as e:
yield json.dumps({"error": str(e)}).replace("\n", "\\n") + "\n"
yield json.dumps({"error": str(e)}, ensure_ascii=False) + "\n"


def conversation_with_data(request):
Expand All @@ -172,7 +172,7 @@ def conversation_with_data(request):
status_code = r.status_code
r = r.json()

return Response(json.dumps(r).replace("\n", "\\n"), status=status_code)
return Response(json.dumps(r, ensure_ascii=False), status=status_code)
else:
if request.method == "POST":
return Response(
Expand Down
107 changes: 107 additions & 0 deletions code/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,110 @@ def test_converstation_custom_allows_multiple_messages_from_user(
conversation_id=body["conversation_id"],
orchestrator=self.orchestrator_config,
)


class MockResponse:

def __init__(self, include_error=False):
self.include_error = include_error

def __enter__(self):
return self

def __exit__(self, *args):
return True

def iter_lines(self, chunk_size=512):
message = b'[{"delta": {"content": "A question\\n?", "end_turn": false, "role": "tool"}}]'
line = (
b'{"choices": [{"messages":'
+ message
+ b'}],"created": "response.created","id": "response.id","model": "some-model","object": "response.object"'
)

if self.include_error:
line += b',"error": "An error occurred\\n"'

line += b"}"
return [b"data:" + line]


class TestConversationAzureByod:
def setup_method(self):
self.body = {
"conversation_id": "123",
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help?"},
{"role": "user", "content": "What is the meaning of life?"},
],
}

@patch("app.requests.Session")
@patch("app.env_helper")
def test_converstation_azure_byod_returns_correct_response_when_streaming_with_data(
self, env_helper_mock, get_requests_session_mock
):
# given
mock_session = get_requests_session_mock.return_value
response_mock = MockResponse()
mock_session.post = Mock(return_value=response_mock)
env_helper_mock.should_use_data.return_value = True

# when
response = app.test_client().post(
"/api/conversation/azure_byod",
headers={"content-type": "application/json"},
json=self.body,
)

# then
assert response.status_code == 200
assert (
response.data
== b'{"id": "response.id", "model": "some-model", "created": "response.created",'
+ b' "object": "response.object", "choices": [{"messages": [{"content": "A question\\n?", "end_turn": false, "role": "tool"}]}]}\n'
)

@patch("app.requests.Session")
@patch("app.env_helper")
def test_converstation_azure_byod_receives_error_from_search_when_streaming_with_data(
self, env_helper_mock, get_requests_session_mock
):
# given
mock_session = get_requests_session_mock.return_value
response_mock = MockResponse(include_error=True)
mock_session.post = Mock(return_value=response_mock)
env_helper_mock.should_use_data.return_value = True

# when
response = app.test_client().post(
"/api/conversation/azure_byod",
headers={"content-type": "application/json"},
json=self.body,
)

# then
assert response.status_code == 200
assert b'"error": "An error occurred\\n"' in response.data

@patch("app.requests.Session")
@patch("app.env_helper")
def test_converstation_azure_byod_throws_exception_when_streaming_with_data(
self, env_helper_mock, get_requests_session_mock
):
# given
mock_session = get_requests_session_mock.return_value
mock_session.post.side_effect = ValueError("Test exception")
env_helper_mock.should_use_data.return_value = True

# when
response = app.test_client().post(
"/api/conversation/azure_byod",
headers={"content-type": "application/json"},
json=self.body,
)

# then
assert response.status_code == 200
assert b'{"error": "Test exception"}\n' in response.data

0 comments on commit a7535e7

Please sign in to comment.