From 18eccd665a4f7fdddb824b314e7305e42d98029d Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Wed, 16 Oct 2024 16:01:26 +0100 Subject: [PATCH] fix(ai-proxy): (Gemini)(AG-154) fixed tools-functions calls coming back empty --- .../kong/ai-gemini-fix-function-calling.yml | 3 + kong/llm/drivers/gemini.lua | 168 +++++++++++++----- 2 files changed, 123 insertions(+), 48 deletions(-) create mode 100644 changelog/unreleased/kong/ai-gemini-fix-function-calling.yml diff --git a/changelog/unreleased/kong/ai-gemini-fix-function-calling.yml b/changelog/unreleased/kong/ai-gemini-fix-function-calling.yml new file mode 100644 index 0000000000000..59e6f5baa27e0 --- /dev/null +++ b/changelog/unreleased/kong/ai-gemini-fix-function-calling.yml @@ -0,0 +1,3 @@ +message: "**ai-proxy**: Fixed a bug where tools (function) calls to Gemini (or via Vertex) would return empty results." +type: bugfix +scope: Plugin diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index 16f5b25c36f4c..bfab0b15743e0 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -42,6 +42,25 @@ local function is_response_content(content) and content.candidates[1].content.parts[1].text end +local function is_tool_content(content) + return content + and content.candidates + and #content.candidates > 0 + and content.candidates[1].content + and content.candidates[1].content.parts + and #content.candidates[1].content.parts > 0 + and content.candidates[1].content.parts[1].functionCall +end + +local function is_function_call_message(message) + return message + and message.role + and message.role == "assistant" + and message.tool_calls + and type(message.tool_calls) == "table" + and #message.tool_calls > 0 +end + local function handle_stream_event(event_t, model_info, route_type) -- discard empty frames, it should either be a random new line, or comment if (not event_t.data) or (#event_t.data < 1) then @@ -83,10 +102,28 @@ local function handle_stream_event(event_t, model_info, route_type) end end +local function to_tools(in_tools) + local out_tools + + for i, v in ipairs(in_tools) do + if v['function'] then + out_tools = out_tools or { + [1] = { + function_declarations = {} + } + } + + out_tools[1].function_declarations[i] = v['function'] + end + end + + return out_tools +end + local function to_gemini_chat_openai(request_table, model_info, route_type) - if request_table then -- try-catch type mechanism - local new_r = {} + local new_r = {} + if request_table then if request_table.messages and #request_table.messages > 0 then local system_prompt @@ -96,18 +133,60 @@ local function to_gemini_chat_openai(request_table, model_info, route_type) if v.role and v.role == "system" then system_prompt = system_prompt or buffer.new() system_prompt:put(v.content or "") + + elseif v.role and v.role == "tool" then + -- handle tool execution output + table_insert(new_r.contents, { + role = "function", + parts = { + { + function_response = { + response = { + content = { + v.content, + }, + }, + name = "get_product_info", + }, + }, + }, + }) + + elseif is_function_call_message(v) then + -- treat specific 'assistant function call' tool execution input message + local function_calls = {} + for i, t in ipairs(v.tool_calls) do + function_calls[i] = { + function_call = { + name = t['function'].name, + }, + } + end + + table_insert(new_r.contents, { + role = "function", + parts = function_calls, + }) + else -- for any other role, just construct the chat history as 'parts.text' type new_r.contents = new_r.contents or {} + + local part = v.content + if type(v.content) == "string" then + part = { + text = v.content + } + end + table_insert(new_r.contents, { role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user' parts = { - { - text = v.content or "" - }, + part, }, }) end + end -- This was only added in Gemini 1.5 @@ -127,36 +206,11 @@ local function to_gemini_chat_openai(request_table, model_info, route_type) new_r.generationConfig = to_gemini_generation_config(request_table) - return new_r, "application/json", nil - end - - local new_r = {} - - if request_table.messages and #request_table.messages > 0 then - local system_prompt - - for i, v in ipairs(request_table.messages) do - - -- for 'system', we just concat them all into one Gemini instruction - if v.role and v.role == "system" then - system_prompt = system_prompt or buffer.new() - system_prompt:put(v.content or "") - else - -- for any other role, just construct the chat history as 'parts.text' type - new_r.contents = new_r.contents or {} - table_insert(new_r.contents, { - role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user' - parts = { - { - text = v.content or "" - }, - }, - }) - end - end + -- handle function calling translation from OpenAI format + new_r.tools = request_table.tools and to_tools(request_table.tools) end - new_r.generationConfig = to_gemini_generation_config(request_table) + kong.log.warn(cjson.encode(new_r)) return new_r, "application/json", nil end @@ -174,20 +228,38 @@ local function from_gemini_chat_openai(response, model_info, route_type) local messages = {} messages.choices = {} - if response.candidates - and #response.candidates > 0 - and is_response_content(response) then - - messages.choices[1] = { - index = 0, - message = { - role = "assistant", - content = response.candidates[1].content.parts[1].text, - }, - finish_reason = string_lower(response.candidates[1].finishReason), - } - messages.object = "chat.completion" - messages.model = model_info.name + if response.candidates and #response.candidates > 0 then + if is_response_content(response) then + messages.choices[1] = { + index = 0, + message = { + role = "assistant", + content = response.candidates[1].content.parts[1].text, + }, + finish_reason = string_lower(response.candidates[1].finishReason), + } + messages.object = "chat.completion" + messages.model = model_info.name + + elseif is_tool_content(response) then + local function_call_responses = response.candidates[1].content.parts + for i, v in ipairs(function_call_responses) do + messages.choices[i] = { + index = 0, + message = { + role = "assistant", + tool_calls = { + { + ['function'] = { + name = v.functionCall.name, + arguments = cjson.encode(v.functionCall.args), + }, + }, + }, + }, + } + end + end -- process analytics if response.usageMetadata then @@ -206,7 +278,7 @@ local function from_gemini_chat_openai(response, model_info, route_type) ngx.log(ngx.ERR, err) return nil, err - else-- probably a server fault or other unexpected response + else -- probably a server fault or other unexpected response local err = "no generation candidates received from Gemini, or max_tokens too short" ngx.log(ngx.ERR, err) return nil, err