diff --git a/changelog/unreleased/kong/ai-gemini-fix-function-calling.yml b/changelog/unreleased/kong/ai-gemini-fix-function-calling.yml
new file mode 100644
index 0000000000000..59e6f5baa27e0
--- /dev/null
+++ b/changelog/unreleased/kong/ai-gemini-fix-function-calling.yml
@@ -0,0 +1,3 @@
+message: "**ai-proxy**: Fixed a bug where tools (function) calls to Gemini (or via Vertex) would return empty results."
+type: bugfix
+scope: Plugin
diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua
index 16f5b25c36f4c..34db97339d7ce 100644
--- a/kong/llm/drivers/gemini.lua
+++ b/kong/llm/drivers/gemini.lua
@@ -42,6 +42,25 @@ local function is_response_content(content)
          and content.candidates[1].content.parts[1].text
 end
 
+local function is_tool_content(content)
+  return content
+         and content.candidates
+         and #content.candidates > 0
+         and content.candidates[1].content
+         and content.candidates[1].content.parts
+         and #content.candidates[1].content.parts > 0
+         and content.candidates[1].content.parts[1].functionCall
+end
+
+local function is_function_call_message(message)
+  return message
+         and message.role
+         and message.role == "assistant"
+         and message.tool_calls
+         and type(message.tool_calls) == "table"
+         and #message.tool_calls > 0
+end
+
 local function handle_stream_event(event_t, model_info, route_type)
   -- discard empty frames, it should either be a random new line, or comment
   if (not event_t.data) or (#event_t.data < 1) then
@@ -83,10 +102,32 @@ local function handle_stream_event(event_t, model_info, route_type)
   end
 end
 
+local function to_tools(in_tools)
+  if not in_tools then
+    return nil
+  end
+
+  local out_tools
+
+  for i, v in ipairs(in_tools) do
+    if v['function'] then
+      out_tools = out_tools or {
+        [1] = {
+          function_declarations = {}
+        }
+      }
+
+      out_tools[1].function_declarations[i] = v['function']
+    end
+  end
+
+  return out_tools
+end
+
 local function to_gemini_chat_openai(request_table, model_info, route_type)
-  if request_table then  -- try-catch type mechanism
-    local new_r = {}
+  local new_r = {}
 
+  if request_table then
     if request_table.messages and #request_table.messages > 0 then
       local system_prompt
 
@@ -96,18 +137,60 @@ local function to_gemini_chat_openai(request_table, model_info, route_type)
         if v.role and v.role == "system" then
           system_prompt = system_prompt or buffer.new()
           system_prompt:put(v.content or "")
+
+        elseif v.role and v.role == "tool" then
+          -- handle tool execution output
+          table_insert(new_r.contents, {
+            role = "function",
+            parts = {
+              {
+                function_response = {
+                  response = {
+                    content = {
+                      v.content,
+                    },
+                  },
+                  name = "get_product_info",
+                },
+              },
+            },
+          })
+
+        elseif is_function_call_message(v) then
+          -- treat specific 'assistant function call' tool execution input message
+          local function_calls = {}
+          for i, t in ipairs(v.tool_calls) do
+            function_calls[i] = {
+              function_call = {
+                name = t['function'].name,
+              },
+            }
+          end
+
+          table_insert(new_r.contents, {
+            role = "function",
+            parts = function_calls,
+          })
+
         else
           -- for any other role, just construct the chat history as 'parts.text' type
           new_r.contents = new_r.contents or {}
+
+          local part = v.content
+          if type(v.content) == "string" then
+            part = {
+              text = v.content
+            }
+          end
+
           table_insert(new_r.contents, {
             role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user'
             parts = {
-              {
-                text = v.content or ""
-              },
+              part,
             },
           })
         end
+
       end
 
       -- This was only added in Gemini 1.5
@@ -127,42 +210,20 @@ local function to_gemini_chat_openai(request_table, model_info, route_type)
 
     new_r.generationConfig = to_gemini_generation_config(request_table)
 
-    return new_r, "application/json", nil
-  end
-
-  local new_r = {}
-
-  if request_table.messages and #request_table.messages > 0 then
-    local system_prompt
-
-    for i, v in ipairs(request_table.messages) do
-
-      -- for 'system', we just concat them all into one Gemini instruction
-      if v.role and v.role == "system" then
-        system_prompt = system_prompt or buffer.new()
-        system_prompt:put(v.content or "")
-      else
-        -- for any other role, just construct the chat history as 'parts.text' type
-        new_r.contents = new_r.contents or {}
-        table_insert(new_r.contents, {
-          role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user'
-          parts = {
-            {
-              text = v.content or ""
-            },
-          },
-        })
-      end
-    end
+    -- handle function calling translation from OpenAI format
+    new_r.tools = request_table.tools and to_tools(request_table.tools)
   end
 
-  new_r.generationConfig = to_gemini_generation_config(request_table)
+  kong.log.warn(cjson.encode(new_r))
 
   return new_r, "application/json", nil
 end
 
 local function from_gemini_chat_openai(response, model_info, route_type)
-  local response, err = cjson.decode(response)
+  local err
+  if response and (type(response) == "string") then
+    response, err = cjson.decode(response)
+  end
 
   if err then
     local err_client = "failed to decode response from Gemini"
@@ -174,20 +235,38 @@
   local messages = {}
   messages.choices = {}
 
-  if response.candidates
-        and #response.candidates > 0
-        and is_response_content(response) then
+  if response.candidates and #response.candidates > 0 then
+    if is_response_content(response) then
+      messages.choices[1] = {
+        index = 0,
+        message = {
+          role = "assistant",
+          content = response.candidates[1].content.parts[1].text,
+        },
+        finish_reason = string_lower(response.candidates[1].finishReason),
+      }
+      messages.object = "chat.completion"
+      messages.model = model_info.name
 
-    messages.choices[1] = {
-      index = 0,
-      message = {
-        role = "assistant",
-        content = response.candidates[1].content.parts[1].text,
-      },
-      finish_reason = string_lower(response.candidates[1].finishReason),
-    }
-    messages.object = "chat.completion"
-    messages.model = model_info.name
+    elseif is_tool_content(response) then
+      local function_call_responses = response.candidates[1].content.parts
+      for i, v in ipairs(function_call_responses) do
+        messages.choices[i] = {
+          index = 0,
+          message = {
+            role = "assistant",
+            tool_calls = {
+              {
+                ['function'] = {
+                  name = v.functionCall.name,
+                  arguments = cjson.encode(v.functionCall.args),
+                },
+              },
+            },
+          },
+        }
+      end
+    end
 
   -- process analytics
   if response.usageMetadata then
@@ -206,7 +285,7 @@ local function from_gemini_chat_openai(response, model_info, route_type)
     ngx.log(ngx.ERR, err)
     return nil, err
 
-  else-- probably a server fault or other unexpected response
+  else -- probably a server fault or other unexpected response
     local err = "no generation candidates received from Gemini, or max_tokens too short"
     ngx.log(ngx.ERR, err)
     return nil, err
@@ -471,4 +550,12 @@ function _M.configure_request(conf, identity_interface)
   return true
 end
 
+
+if _G._TEST then
+  -- export locals for testing
+  _M._to_tools = to_tools
+  _M._from_gemini_chat_openai = from_gemini_chat_openai
+end
+
+
 return _M
diff --git a/kong/llm/proxy/handler.lua b/kong/llm/proxy/handler.lua
index 17d0593756e2c..b2c54012bec2a 100644
--- a/kong/llm/proxy/handler.lua
+++ b/kong/llm/proxy/handler.lua
@@ -466,7 +466,7 @@ function _M:access(conf)
   local identity_interface = _KEYBASTION[conf]
 
   if identity_interface and identity_interface.error then
-    llm_state.set_response_transformer_skipped()
+    llm_state.disable_ai_proxy_response_transform()
     kong.log.err("error authenticating with cloud-provider, ", identity_interface.error)
     return bail(500, "LLM request failed before proxying")
   end
diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua
index ce4b1c6242fd8..0655865b17384 100644
--- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua
+++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua
@@ -80,6 +80,23 @@ local SAMPLE_OPENAI_TOOLS_REQUEST = {
   },
 }
 
+local SAMPLE_GEMINI_TOOLS_RESPONSE = {
+  candidates = { {
+    content = {
+      role = "model",
+      parts = { {
+        functionCall = {
+          name = "sql_execute",
+          args = {
+            product_name = "NewPhone"
+          }
+        }
+      } }
+    },
+    finishReason = "STOP",
+  } },
+}
+
 local SAMPLE_BEDROCK_TOOLS_RESPONSE = {
   metrics = {
     latencyMs = 3781
@@ -832,6 +849,71 @@ describe(PLUGIN_NAME .. ": (unit)", function()
     end)
   end)
 
+  describe("gemini tools", function()
+    local gemini_driver
+
+    setup(function()
+      _G._TEST = true
+      package.loaded["kong.llm.drivers.gemini"] = nil
+      gemini_driver = require("kong.llm.drivers.gemini")
+    end)
+
+    teardown(function()
+      _G._TEST = nil
+    end)
+
+    it("transforms openai tools to gemini tools GOOD", function()
+      local gemini_tools = gemini_driver._to_tools(SAMPLE_OPENAI_TOOLS_REQUEST.tools)
+
+      assert.not_nil(gemini_tools)
+      assert.same(gemini_tools, {
+        {
+          function_declarations = {
+            {
+              description = "Check a product is in stock.",
+              name = "check_stock",
+              parameters = {
+                properties = {
+                  product_name = {
+                    type = "string"
+                  }
+                },
+                required = {
+                  "product_name"
+                },
+                type = "object"
+              }
+            }
+          }
+        }
+      })
+    end)
+
+    it("transforms openai tools to gemini tools NO_TOOLS", function()
+      local gemini_tools = gemini_driver._to_tools(SAMPLE_LLM_V1_CHAT)
+
+      assert.is_nil(gemini_tools)
+    end)
+
+    it("transforms openai tools to gemini tools NIL", function()
+      local gemini_tools = gemini_driver._to_tools(nil)
+
+      assert.is_nil(gemini_tools)
+    end)
+
+    it("transforms gemini tools to openai tools GOOD", function()
+      local openai_tools = gemini_driver._from_gemini_chat_openai(SAMPLE_GEMINI_TOOLS_RESPONSE, {}, "llm/v1/chat")
+
+      assert.not_nil(openai_tools)
+
+      openai_tools = cjson.decode(openai_tools)
+      assert.same(openai_tools.choices[1].message.tool_calls[1]['function'], {
+        name = "sql_execute",
+        arguments = "{\"product_name\":\"NewPhone\"}"
+      })
+    end)
+  end)
+
   describe("bedrock tools", function()
     local bedrock_driver
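
Illustration (not part of the patch): a minimal Lua sketch of the OpenAI-to-Gemini tool translation that the new to_tools() helper performs, driven through the _G._TEST export added above. The "check_stock" tool is a hypothetical input mirroring the SAMPLE_OPENAI_TOOLS_REQUEST fixture, and the resulting shape matches the assert.same() in the "transforms openai tools to gemini tools GOOD" test.

-- sketch only; assumes a Kong test environment where the gemini driver is loadable
_G._TEST = true
local gemini_driver = require("kong.llm.drivers.gemini")

-- hypothetical OpenAI-style tools array with a single function definition
local openai_tools = {
  {
    type = "function",
    ["function"] = {
      name = "check_stock",
      description = "Check a product is in stock.",
      parameters = {
        type = "object",
        properties = { product_name = { type = "string" } },
        required = { "product_name" },
      },
    },
  },
}

-- to_tools() collects each entry's 'function' field into a single Gemini
-- 'function_declarations' list:
-- { { function_declarations = { { name = "check_stock", description = ..., parameters = ... } } } }
local gemini_tools = gemini_driver._to_tools(openai_tools)
assert(gemini_tools[1].function_declarations[1].name == "check_stock")

_G._TEST = nil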