From d0e0fa1d58943d338928ad9961b70bb76fc8bdad Mon Sep 17 00:00:00 2001
From: Jack Tysoe
Date: Thu, 2 May 2024 15:04:49 +0100
Subject: [PATCH] fix(ai-proxy): fix tests

---
 kong/plugins/ai-proxy/handler.lua                  | 12 ++++++------
 .../38-ai-proxy/02-openai_integration_spec.lua     |  5 ++---
 .../02-integration_spec.lua                        |  1 -
 .../02-integration_spec.lua                        |  1 -
 4 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua
index 59feb0c9431a..739c33f0667d 100644
--- a/kong/plugins/ai-proxy/handler.lua
+++ b/kong/plugins/ai-proxy/handler.lua
@@ -79,28 +79,28 @@ local function handle_streaming_frame(conf)
   if (not finished) and (is_gzip) then
     chunk = kong_utils.inflate_gzip(chunk)
   end
-  
+
   local events = ai_shared.frame_to_events(chunk)
-  
+
   for _, event in ipairs(events) do
     local formatted, _, metadata = ai_driver.from_format(event, conf.model, "stream/" .. conf.route_type)
 
     local event_t = nil
     local token_t = nil
     local err
-    
+
     if formatted then  -- only stream relevant frames back to the user
       if conf.logging and conf.logging.log_payloads and (formatted ~= "[DONE]") then
         -- append the "choice" to the buffer, for logging later. this actually works!
         if not event_t then
           event_t, err = cjson.decode(formatted)
         end
-        
+
         if not err then
           if not token_t then
             token_t = get_token_text(event_t)
           end
-          
+
           kong.ctx.plugin.ai_stream_log_buffer:put(token_t)
         end
       end
@@ -112,7 +112,7 @@
         if not event_t then
           event_t, err = cjson.decode(formatted)
         end
-        
+
         if not err then
           if not token_t then
             token_t = get_token_text(event_t)
diff --git a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
index b7a55183dca3..e9fb74c3114a 100644
--- a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
+++ b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
@@ -42,7 +42,6 @@ local _EXPECTED_CHAT_STATS = {
       request_model = 'gpt-3.5-turbo',
       response_model = 'gpt-3.5-turbo-0613',
     },
-    payload = {},
     usage = {
       completion_token = 12,
       prompt_token = 25,
@@ -775,8 +774,8 @@ for _, strategy in helpers.all_strategies() do
       if strategy ~= "cassandra" then
         assert.is_number(log_message.response.size)
 
         -- test request bodies
-        assert.matches('"content": "What is 1 + 1?"', log_message.ai.payload.request, nil, true)
-        assert.matches('"role": "user"', log_message.ai.payload.request, nil, true)
+        assert.matches('"content": "What is 1 + 1?"', log_message.ai['ai-proxy'].payload.request, nil, true)
+        assert.matches('"role": "user"', log_message.ai['ai-proxy'].payload.request, nil, true)
 
         -- test response bodies
         assert.matches('"content": "The sum of 1 + 1 is 2.",', log_message.ai["ai-proxy"].payload.response, nil, true)
diff --git a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
index 2711f4aa393f..00b0391d7499 100644
--- a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
+++ b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
@@ -122,7 +122,6 @@ local _EXPECTED_CHAT_STATS = {
       request_model = 'gpt-4',
       response_model = 'gpt-3.5-turbo-0613',
     },
-    payload = {},
     usage = {
       completion_token = 12,
       prompt_token = 25,
diff --git a/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
index 13e4b558a3ef..800100c9a67c 100644
--- a/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
+++ b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
@@ -179,7 +179,6 @@ local _EXPECTED_CHAT_STATS = {
       request_model = 'gpt-4',
       response_model = 'gpt-3.5-turbo-0613',
     },
-    payload = {},
     usage = {
       completion_token = 12,
       prompt_token = 25,
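Note (after the patch, ignored by git am): the spec changes above move the
payload assertions from log_message.ai.payload to the per-plugin subtable
log_message.ai["ai-proxy"].payload, and drop the now-unused empty
payload = {} entries from the _EXPECTED_CHAT_STATS fixtures. Below is a
minimal, self-contained Lua sketch of the nested shape those assertions
now expect; the table literal and its string values are illustrative
stand-ins taken from the asserted substrings, not the plugin's actual
full log entry:

  -- sketch.lua: log shape implied by the updated spec assertions
  local log_message = {
    ai = {
      ["ai-proxy"] = {
        payload = {
          -- serialized request/response bodies; contents are examples
          request  = '{"messages": [{"role": "user", "content": "What is 1 + 1?"}]}',
          response = '{"content": "The sum of 1 + 1 is 2.", "role": "assistant"}',
        },
      },
    },
  }

  -- the old top-level accessor is gone:
  assert(log_message.ai.payload == nil)
  -- payloads are keyed by plugin name first (plain-text find, as in the specs):
  assert(log_message.ai["ai-proxy"].payload.request:find('"role": "user"', 1, true))
  assert(log_message.ai["ai-proxy"].payload.response:find('The sum of 1 + 1 is 2.', 1, true))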