From 9f6fdbeedcf051fd579789b1b064baef1074c39d Mon Sep 17 00:00:00 2001
From: Antoine Jacquemin
Date: Wed, 13 Nov 2024 15:18:39 +0100
Subject: [PATCH] fix(ai-proxy): (Cohere-Anthropic)(AG-154): fix cohere and
 anthropic function calls coming back empty

---
 .../ai-anthropic-fix-function-calling.yml          |  3 +
 .../kong/ai-cohere-fix-function-calling.yml        |  3 +
 kong/llm/drivers/anthropic.lua                     | 90 ++++++++++++++++---
 kong/llm/drivers/bedrock.lua                       | 32 +++----
 kong/llm/drivers/cohere.lua                        | 38 +++++++-
 kong/llm/drivers/gemini.lua                        |  1 +
 kong/llm/drivers/shared.lua                        |  6 +-
 .../03-anthropic_integration_spec.lua              |  8 +-
 8 files changed, 150 insertions(+), 31 deletions(-)
 create mode 100644 changelog/unreleased/kong/ai-anthropic-fix-function-calling.yml
 create mode 100644 changelog/unreleased/kong/ai-cohere-fix-function-calling.yml

diff --git a/changelog/unreleased/kong/ai-anthropic-fix-function-calling.yml b/changelog/unreleased/kong/ai-anthropic-fix-function-calling.yml
new file mode 100644
index 000000000000..41d2592f46d3
--- /dev/null
+++ b/changelog/unreleased/kong/ai-anthropic-fix-function-calling.yml
@@ -0,0 +1,3 @@
+message: "**ai-proxy**: Fixed a bug where tools (function) calls to Anthropic would return empty results."
+type: bugfix
+scope: Plugin
diff --git a/changelog/unreleased/kong/ai-cohere-fix-function-calling.yml b/changelog/unreleased/kong/ai-cohere-fix-function-calling.yml
new file mode 100644
index 000000000000..6e4885a2a43c
--- /dev/null
+++ b/changelog/unreleased/kong/ai-cohere-fix-function-calling.yml
@@ -0,0 +1,3 @@
+message: "**ai-proxy**: Fixed a bug where tools (function) calls to Cohere would return empty results."
+type: bugfix
+scope: Plugin
diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua
index 508b62c4851a..d45f20e3ff3d 100644
--- a/kong/llm/drivers/anthropic.lua
+++ b/kong/llm/drivers/anthropic.lua
@@ -44,17 +44,48 @@ local function kong_messages_to_claude_prompt(messages)
   return buf:get()
 end
 
+local inject_tool_calls = function(tool_calls)
+  local tools
+  for _, n in ipairs(tool_calls) do
+    tools = tools or {}
+    table.insert(tools, {
+      type = "tool_use",
+      id = n.id,
+      name = n["function"].name,
+      input = cjson.decode(n["function"].arguments)
+    })
+  end
+
+  return tools
+end
+
 -- reuse the messages structure of prompt
 -- extract messages and system from kong request
 local function kong_messages_to_claude_messages(messages)
   local msgs, system, n = {}, nil, 1
 
   for _, v in ipairs(messages) do
-    if v.role ~= "assistant" and v.role ~= "user" then
+    if v.role ~= "assistant" and v.role ~= "user" and v.role ~= "tool" then
       system = v.content
 
     else
-      msgs[n] = v
+      if v.role == "assistant" and v.tool_calls then
+        msgs[n] = {
+          role = v.role,
+          content = inject_tool_calls(v.tool_calls),
+        }
+      elseif v.role == "tool" then
+        msgs[n] = {
+          role = "user",
+          content = {{
+            type = "tool_result",
+            tool_use_id = v.tool_call_id,
+            content = v.content
+          }},
+        }
+      else
+        msgs[n] = v
+      end
       n = n + 1
     end
   end
@@ -62,7 +92,6 @@ local function kong_messages_to_claude_messages(messages)
   return msgs, system
 end
 
-
 local function to_claude_prompt(req)
   if req.prompt then
     return kong_prompt_to_claude_prompt(req.prompt)
@@ -83,6 +112,21 @@ local function to_claude_messages(req)
   return nil, nil, "request is missing .messages command"
 end
 
+local function to_tools(in_tools)
+  local out_tools = {}
+
+  for _, v in ipairs(in_tools) do
+    if v['function'] then
+      v['function'].input_schema = v['function'].parameters
+      v['function'].parameters = nil
+
+      table.insert(out_tools, v['function'])
+    end
+  end
+
+  return out_tools
+end
+
 local transformers_to = {
   ["llm/v1/chat"] = function(request_table, model)
     local messages = {}
@@ -98,6 +142,10 @@ local transformers_to = {
     messages.model = model.name or request_table.model
     messages.stream = request_table.stream or false -- explicitly set this if nil
 
+    -- handle function calling translation from OpenAI format
+    messages.tools = request_table.tools and to_tools(request_table.tools)
+    messages.tool_choice = request_table.tool_choice
+
     return messages, "application/json", nil
   end,
 
@@ -243,16 +291,37 @@ local transformers_from = {
     local function extract_text_from_content(content)
       local buf = buffer.new()
       for i, v in ipairs(content) do
-        if i ~= 1 then
-          buf:put("\n")
+        if v.text then
+          if i ~= 1 then
+            buf:put("\n")
+          end
+          buf:put(v.text)
         end
-
-        buf:put(v.text)
       end
 
       return buf:tostring()
     end
 
+    local function extract_tools_from_content(content)
+      local tools
+      for _, v in ipairs(content) do
+        if v.type == "tool_use" then
+          tools = tools or {}
+
+          table.insert(tools, {
+            id = v.id,
+            type = "function",
+            ['function'] = {
+              name = v.name,
+              arguments = cjson.encode(v.input),
+            }
+          })
+        end
+      end
+
+      return tools
+    end
+
     if response_table.content then
       local usage = response_table.usage
 
@@ -275,13 +344,14 @@ local transformers_from = {
             message = {
               role = "assistant",
               content = extract_text_from_content(response_table.content),
+              tool_calls = extract_tools_from_content(response_table.content)
             },
             finish_reason = response_table.stop_reason,
           },
         },
         usage = usage,
         model = response_table.model,
-        object = "chat.content",
+        object = "chat.completion",
       }
 
       return cjson.encode(res)
@@ -488,7 +558,7 @@ function _M.configure_request(conf)
     end
   end
 
-  -- if auth_param_location is "form", it will have already been set in a pre-request hook
+  -- if auth_param_location is "body", it will have already been set in a pre-request hook
   return true, nil
 end
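To illustrate what the anthropic.lua changes above do: an OpenAI-format assistant message carrying `tool_calls` is rewritten into Claude "tool_use" content blocks, and an OpenAI "tool" role message is folded into a user message carrying a "tool_result" block. A minimal sketch of that round trip, with invented message values (only the field names come from the diff):

    -- OpenAI-format chat history as ai-proxy receives it (values are illustrative):
    local messages = {
      { role = "user", content = "What is the weather in Lyon?" },
      {
        role = "assistant",
        tool_calls = {
          { id = "call_1", type = "function",
            ["function"] = { name = "get_weather", arguments = '{"city": "Lyon"}' } },
        },
      },
      { role = "tool", tool_call_id = "call_1", content = '{"celsius": 21}' },
    }

    -- kong_messages_to_claude_messages(messages) should now produce:
    --   msgs[2] = { role = "assistant",
    --               content = {{ type = "tool_use", id = "call_1",
    --                            name = "get_weather", input = { city = "Lyon" } }} }
    --   msgs[3] = { role = "user",
    --               content = {{ type = "tool_result", tool_use_id = "call_1",
    --                            content = '{"celsius": 21}' }} }

On the response side, extract_tools_from_content() performs the inverse mapping, so a Claude "tool_use" block reaches the client as a populated choices[1].message.tool_calls array instead of the empty result this patch fixes.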
diff --git a/kong/llm/drivers/bedrock.lua b/kong/llm/drivers/bedrock.lua
index 09c5fea5c1e3..5806d7576929 100644
--- a/kong/llm/drivers/bedrock.lua
+++ b/kong/llm/drivers/bedrock.lua
@@ -232,7 +232,7 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type)
       elseif v.role and v.role == "tool" then
         local tool_execution_content, err = cjson.decode(v.content)
         if err then
-          return nil, nil, "failed to decode function response arguments: " .. err
+          return nil, nil, "failed to decode function response arguments, not JSON format"
         end
 
         local content = {
@@ -261,20 +261,22 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type)
         content = v.content
 
       elseif v.tool_calls and (type(v.tool_calls) == "table") then
-        local inputs, err = cjson.decode(v.tool_calls[1]['function'].arguments)
-        if err then
-          return nil, nil, "failed to decode function response arguments from assistant: " .. err
-        end
-
-        content = {
-          {
-            toolUse = {
-              toolUseId = v.tool_calls[1].id,
-              name = v.tool_calls[1]['function'].name,
-              input = inputs,
-            },
-          },
-        }
+        content = {}
+        for _, tool in ipairs(v.tool_calls) do
+          local inputs, err = cjson.decode(tool['function'].arguments)
+          if err then
+            return nil, nil, "failed to decode function response arguments from assistant's message, not JSON format"
+          end
+
+          -- accumulate every tool call, not just the first one
+          table.insert(content, {
+            toolUse = {
+              toolUseId = tool.id,
+              name = tool['function'].name,
+              input = inputs,
+            },
+          })
+        end
 
       else
         content = {
@@ -282,7 +285,6 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type)
             text = v.content or ""
           },
         }
-
       end
 
       -- for any other role, just construct the chat history as 'parts.text' type
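The bedrock.lua loop above performs the same translation for the Converse API, emitting one `toolUse` block per call rather than only the first. A sketch of the accumulation, again with invented values:

    -- OpenAI-format assistant turn with two parallel function calls (illustrative):
    local v = {
      role = "assistant",
      tool_calls = {
        { id = "t1", ["function"] = { name = "get_weather", arguments = '{"city": "Lyon"}' } },
        { id = "t2", ["function"] = { name = "get_time",    arguments = '{"tz": "CET"}' } },
      },
    }

    -- After the loop, content should hold one entry per tool call:
    --   { { toolUse = { toolUseId = "t1", name = "get_weather", input = { city = "Lyon" } } },
    --     { toolUse = { toolUseId = "t2", name = "get_time",    input = { tz = "CET" } } } }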
diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua
index 5f29a928bb0d..a51726f4b76b 100644
--- a/kong/llm/drivers/cohere.lua
+++ b/kong/llm/drivers/cohere.lua
@@ -4,6 +4,7 @@ local _M = {}
 local cjson = require("cjson.safe")
 local fmt = string.format
 local ai_shared = require("kong.llm.drivers.shared")
+local openai_driver = require("kong.llm.drivers.openai")
 local socket_url = require "socket.url"
 local table_new = require("table.new")
 local string_gsub = string.gsub
@@ -260,6 +261,37 @@ local transformers_from = {
             and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens),
       }
       messages.usage = stats
+
+    elseif response_table.message then
+      -- this is a "co.chat"
+
+      messages.choices[1] = {
+        index = 0,
+        message = {
+          role = "assistant",
+          content = response_table.message.tool_plan or response_table.message.content,
+          tool_calls = response_table.message.tool_calls
+        },
+        finish_reason = response_table.finish_reason,
+      }
+      messages.object = "chat.completion"
+      messages.model = model_info.name
+      messages.id = response_table.id
+
+      local stats = {
+        completion_tokens = response_table.usage
+                        and response_table.usage.billed_units
+                        and response_table.usage.billed_units.output_tokens,
+
+        prompt_tokens = response_table.usage
+                        and response_table.usage.billed_units
+                        and response_table.usage.billed_units.input_tokens,
+
+        total_tokens = response_table.usage
+                        and response_table.usage.billed_units
+                        and (response_table.usage.billed_units.output_tokens + response_table.usage.billed_units.input_tokens),
+      }
+      messages.usage = stats
     else -- probably a fault
       return nil, "'text' or 'generations' missing from cohere response body"
@@ -357,6 +389,10 @@ end
 function _M.to_format(request_table, model_info, route_type)
   ngx.log(ngx.DEBUG, "converting from kong type to ", model_info.provider, "/", route_type)
 
+  if request_table.tools then
+    return openai_driver.to_format(request_table, model_info, route_type)
+  end
+
   if route_type == "preserve" then
     -- do nothing
     return request_table, nil, nil
@@ -497,7 +533,7 @@ function _M.configure_request(conf)
     end
   end
 
-  -- if auth_param_location is "form", it will have already been set in a pre-request hook
+  -- if auth_param_location is "body", it will have already been set in a pre-request hook
   return true, nil
 end
diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua
index d9bafd2e083d..d0d0b26d507b 100644
--- a/kong/llm/drivers/gemini.lua
+++ b/kong/llm/drivers/gemini.lua
@@ -220,6 +220,7 @@ local function to_gemini_chat_openai(request_table, model_info, route_type)
 
     -- handle function calling translation from OpenAI format
     new_r.tools = request_table.tools and to_tools(request_table.tools)
+    new_r.tool_config = request_table.tool_config
   end
 
   return new_r, "application/json", nil
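The cohere.lua change takes a shortcut on the request side: when `tools` are present, to_format() hands the whole request to the OpenAI driver, presumably because the upstream Cohere chat endpoint accepts the OpenAI-shaped tools payload unchanged. The new `response_table.message` branch then normalizes the native reply. A sketch with an invented response (only the field names come from the branch above):

    -- Cohere chat reply carrying tool calls (values are illustrative):
    local response_table = {
      id = "resp_1",
      finish_reason = "TOOL_CALL",
      message = {
        tool_plan = "I will look up the weather first.",
        tool_calls = {
          { id = "t1", type = "function",
            ["function"] = { name = "get_weather", arguments = '{"city": "Lyon"}' } },
        },
      },
      usage = { billed_units = { input_tokens = 10, output_tokens = 5 } },
    }

    -- The branch above maps this onto the kong/OpenAI response shape:
    --   { object = "chat.completion", id = "resp_1", model = <configured model>,
    --     choices = {{ index = 0, finish_reason = "TOOL_CALL",
    --                  message = { role = "assistant",
    --                              content = "I will look up the weather first.",
    --                              tool_calls = <passed through unchanged> } }},
    --     usage = { prompt_tokens = 10, completion_tokens = 5, total_tokens = 15 } }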
"application/json", nil diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 829e10988185..a3b74ed81f56 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -447,6 +447,10 @@ function _M.to_ollama(request_table, model) input.stream = request_table.stream or false -- for future capability input.model = model.name or request_table.name + -- handle function calling translation from Ollama format + input.tools = request_table.tools + input.tool_choice = request_table.tool_choice + if model.options then input.options = {} @@ -509,7 +513,7 @@ function _M.from_ollama(response_string, model_info, route_type) output.object = "chat.completion" output.choices = { { - finish_reason = stop_reason, + finish_reason = response_table.finish_reason or stop_reason, index = 0, message = response_table.message, } diff --git a/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua b/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua index 51d7e23ae311..43ab765aeead 100644 --- a/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua @@ -570,7 +570,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then -- check this is in the 'kong' response format -- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") assert.equals(json.model, "claude-2.1") - assert.equals(json.object, "chat.content") + assert.equals(json.object, "chat.completion") assert.equals(r.headers["X-Kong-LLM-Model"], "anthropic/claude-2.1") assert.is_table(json.choices) @@ -597,7 +597,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then -- check this is in the 'kong' response format -- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") assert.equals(json.model, "claude-2.1") - assert.equals(json.object, "chat.content") + assert.equals(json.object, "chat.completion") assert.equals(r.headers["X-Kong-LLM-Model"], "anthropic/claude-2.1") assert.is_table(json.choices) @@ -642,7 +642,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then -- check this is in the 'kong' response format -- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") assert.equals(json.model, "claude-2.1") - assert.equals(json.object, "chat.content") + assert.equals(json.object, "chat.completion") assert.equals(r.headers["X-Kong-LLM-Model"], "anthropic/claude-2.1") assert.is_table(json.choices) @@ -669,7 +669,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then -- check this is in the 'kong' response format -- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") assert.equals(json.model, "claude-2.1") - assert.equals(json.object, "chat.content") + assert.equals(json.object, "chat.completion") assert.equals(r.headers["X-Kong-LLM-Model"], "anthropic/claude-2.1") assert.is_table(json.choices)