Skip to content

Commit

Permalink
fix(ai-proxy): (Cohere-Anthropic)(AG-154): fix cohere and anthropic function calls coming back empty
Browse files Browse the repository at this point in the history
  • Loading branch information
AntoineJac authored and tysoekong committed Nov 13, 2024
1 parent f8d379b commit 0142f6f
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message: "**ai-proxy**: Fixed a bug where tools (function) calls to Anthropic would return empty results."
type: bugfix
scope: Plugin
3 changes: 3 additions & 0 deletions changelog/unreleased/kong/ai-cohere-fix-function-calling.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message: "**ai-proxy**: Fixed a bug where tools (function) calls to Cohere would return empty results."
type: bugfix
scope: Plugin
90 changes: 80 additions & 10 deletions kong/llm/drivers/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,54 @@ local function kong_messages_to_claude_prompt(messages)
return buf:get()
end

-- Convert OpenAI-format assistant tool_calls into Claude "tool_use"
-- content blocks (Anthropic Messages API).
-- Returns nil (rather than an empty table) when tool_calls yields no
-- entries, so the caller can omit the field entirely.
local inject_tool_calls = function(tool_calls)
  if tool_calls[1] == nil then
    return nil
  end

  local tools = {}
  for idx, call in ipairs(tool_calls) do
    tools[idx] = {
      type = "tool_use",
      id = call.id,
      name = call["function"].name,
      -- OpenAI carries arguments as a JSON-encoded string; Claude wants
      -- the decoded object under "input"
      input = cjson.decode(call["function"].arguments),
    }
  end

  return tools
end

-- Reuse the messages structure of the prompt:
-- extract the chat messages and the system prompt from the Kong
-- (OpenAI-format) request, translating them for the Anthropic Messages API.
--   * any role other than assistant/user/tool (i.e. "system") becomes the
--     separate top-level "system" field;
--   * assistant messages carrying tool_calls are converted to Claude
--     "tool_use" content blocks via inject_tool_calls;
--   * "tool" role results are sent back as role "user" with a
--     "tool_result" content block, as Anthropic expects.
-- NOTE(review): the pasted diff interleaved both sides of this hunk (a
-- stale old-side condition and a stale `msgs[n] = v`), which made the
-- function syntactically invalid; this is the reconstructed post-patch body.
local function kong_messages_to_claude_messages(messages)
  local msgs, system, n = {}, nil, 1

  for _, v in ipairs(messages) do
    if v.role ~= "assistant" and v.role ~= "user" and v.role ~= "tool" then
      -- system prompt: Claude takes this as a separate field, not a message
      system = v.content

    else
      if v.role == "assistant" and v.tool_calls then
        msgs[n] = {
          role = v.role,
          content = inject_tool_calls(v.tool_calls),
        }
      elseif v.role == "tool" then
        msgs[n] = {
          role = "user",
          content = {{
            type = "tool_result",
            tool_use_id = v.tool_call_id,
            content = v.content
          }},
        }
      else
        msgs[n] = v
      end
      n = n + 1
    end
  end

  return msgs, system
end


local function to_claude_prompt(req)
if req.prompt then
return kong_prompt_to_claude_prompt(req.prompt)
Expand All @@ -83,6 +112,21 @@ local function to_claude_messages(req)
return nil, nil, "request is missing .messages command"
end

-- Translate an OpenAI "tools" array into Anthropic's tool format:
-- the JSON schema moves from "parameters" to "input_schema", and the
-- {type = "function", ["function"] = {...}} wrapper is flattened away.
-- Entries without a "function" field are dropped.
-- NOTE(review): this mutates the nested "function" tables of the input
-- in place (renames parameters -> input_schema) — confirm callers do not
-- reuse the original request table afterwards.
local function to_tools(in_tools)
  local out_tools = {}
  local count = 0

  for _, entry in ipairs(in_tools) do
    local fn = entry["function"]
    if fn then
      fn.input_schema = fn.parameters
      fn.parameters = nil

      count = count + 1
      out_tools[count] = fn
    end
  end

  return out_tools
end

local transformers_to = {
["llm/v1/chat"] = function(request_table, model)
local messages = {}
Expand All @@ -98,6 +142,10 @@ local transformers_to = {
messages.model = model.name or request_table.model
messages.stream = request_table.stream or false -- explicitly set this if nil

-- handle function calling translation from OpenAI format
messages.tools = request_table.tools and to_tools(request_table.tools)
messages.tool_choice = request_table.tool_choice

return messages, "application/json", nil
end,

Expand Down Expand Up @@ -243,16 +291,37 @@ local transformers_from = {
-- Concatenate the text of all "text" content blocks in a Claude response,
-- separated by newlines. Blocks without a .text field (e.g. tool_use
-- blocks) are skipped and contribute no separator.
-- NOTE(review): the pasted diff interleaved both sides of this hunk
-- (a stale old-side `if`/`buf:put` pair), leaving it unbalanced; this is
-- the reconstructed post-patch body. It also replaces the `i ~= 1`
-- separator test with a "have we emitted text yet" flag, so a leading
-- non-text block no longer produces a spurious leading newline.
local function extract_text_from_content(content)
  local buf = buffer.new()
  local have_text = false

  for _, v in ipairs(content) do
    if v.text then
      if have_text then
        buf:put("\n")
      end
      buf:put(v.text)
      have_text = true
    end
  end

  return buf:tostring()
end

-- Collect Claude "tool_use" content blocks from a response and re-encode
-- them as OpenAI-format tool_calls entries.
-- Returns nil (not an empty table) when the content carries no tool_use
-- blocks, so the field is omitted from the translated response.
local function extract_tools_from_content(content)
  local tools = nil

  for _, block in ipairs(content) do
    if block.type == "tool_use" then
      if not tools then
        tools = {}
      end

      tools[#tools + 1] = {
        id = block.id,
        type = "function",
        ["function"] = {
          name = block.name,
          -- Claude returns a decoded object; OpenAI format carries the
          -- arguments as a JSON-encoded string
          arguments = cjson.encode(block.input),
        },
      }
    end
  end

  return tools
end

if response_table.content then
local usage = response_table.usage

Expand All @@ -275,13 +344,14 @@ local transformers_from = {
message = {
role = "assistant",
content = extract_text_from_content(response_table.content),
tool_calls = extract_tools_from_content(response_table.content)
},
finish_reason = response_table.stop_reason,
},
},
usage = usage,
model = response_table.model,
object = "chat.content",
object = "chat.completion",
}

return cjson.encode(res)
Expand Down Expand Up @@ -488,7 +558,7 @@ function _M.configure_request(conf)
end
end

-- if auth_param_location is "form", it will have already been set in a pre-request hook
-- if auth_param_location is "body", it will have already been set in a pre-request hook
return true, nil
end

Expand Down
32 changes: 17 additions & 15 deletions kong/llm/drivers/bedrock.lua
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type)
elseif v.role and v.role == "tool" then
local tool_execution_content, err = cjson.decode(v.content)
if err then
return nil, nil, "failed to decode function response arguments: " .. err
return nil, nil, "failed to decode function response arguments, not JSON format"
end

local content = {
Expand Down Expand Up @@ -261,28 +261,30 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type)
content = v.content

elseif v.tool_calls and (type(v.tool_calls) == "table") then
local inputs, err = cjson.decode(v.tool_calls[1]['function'].arguments)
if err then
return nil, nil, "failed to decode function response arguments from assistant: " .. err
end

content = {
{
toolUse = {
toolUseId = v.tool_calls[1].id,
name = v.tool_calls[1]['function'].name,
input = inputs,
for k, tool in ipairs(v.tool_calls) do
local inputs, err = cjson.decode(tool['function'].arguments)
if err then
return nil, nil, "failed to decode function response arguments from assistant's message, not JSON format"
end

content = {
{
toolUse = {
toolUseId = tool.id,
name = tool['function'].name,
input = inputs,
},
},
},
}
}

end

else
content = {
{
text = v.content or ""
},
}

end

-- for any other role, just construct the chat history as 'parts.text' type
Expand Down
38 changes: 37 additions & 1 deletion kong/llm/drivers/cohere.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ local _M = {}
local cjson = require("cjson.safe")
local fmt = string.format
local ai_shared = require("kong.llm.drivers.shared")
local openai_driver = require("kong.llm.drivers.openai")
local socket_url = require "socket.url"
local table_new = require("table.new")
local string_gsub = string.gsub
Expand Down Expand Up @@ -260,6 +261,37 @@ local transformers_from = {
and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens),
}
messages.usage = stats

elseif response_table.message then
-- this is a "co.chat"

messages.choices[1] = {
index = 0,
message = {
role = "assistant",
content = response_table.message.tool_plan or response_table.message.content,
tool_calls = response_table.message.tool_calls
},
finish_reason = response_table.finish_reason,
}
messages.object = "chat.completion"
messages.model = model_info.name
messages.id = response_table.id

local stats = {
completion_tokens = response_table.usage
and response_table.usage.billed_units
and response_table.usage.billed_units.output_tokens,

prompt_tokens = response_table.usage
and response_table.usage.billed_units
and response_table.usage.billed_units.input_tokens,

total_tokens = response_table.usage
and response_table.usage.billed_units
and (response_table.usage.billed_units.output_tokens + response_table.usage.billed_units.input_tokens),
}
messages.usage = stats

else -- probably a fault
return nil, "'text' or 'generations' missing from cohere response body"
Expand Down Expand Up @@ -357,6 +389,10 @@ end
function _M.to_format(request_table, model_info, route_type)
ngx.log(ngx.DEBUG, "converting from kong type to ", model_info.provider, "/", route_type)

if request_table.tools then
return openai_driver.to_format(request_table, model_info, route_type)
end

if route_type == "preserve" then
-- do nothing
return request_table, nil, nil
Expand Down Expand Up @@ -497,7 +533,7 @@ function _M.configure_request(conf)
end
end

-- if auth_param_location is "form", it will have already been set in a pre-request hook
-- if auth_param_location is "body", it will have already been set in a pre-request hook
return true, nil
end

Expand Down
1 change: 1 addition & 0 deletions kong/llm/drivers/gemini.lua
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ local function to_gemini_chat_openai(request_table, model_info, route_type)

-- handle function calling translation from OpenAI format
new_r.tools = request_table.tools and to_tools(request_table.tools)
new_r.tool_config = request_table.tool_config
end

return new_r, "application/json", nil
Expand Down
6 changes: 5 additions & 1 deletion kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,10 @@ function _M.to_ollama(request_table, model)
input.stream = request_table.stream or false -- for future capability
input.model = model.name or request_table.name

-- handle function calling translation from Ollama format
input.tools = request_table.tools
input.tool_choice = request_table.tool_choice

if model.options then
input.options = {}

Expand Down Expand Up @@ -509,7 +513,7 @@ function _M.from_ollama(response_string, model_info, route_type)
output.object = "chat.completion"
output.choices = {
{
finish_reason = stop_reason,
finish_reason = response_table.finish_reason or stop_reason,
index = 0,
message = response_table.message,
}
Expand Down
8 changes: 4 additions & 4 deletions spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
-- check this is in the 'kong' response format
-- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
assert.equals(json.model, "claude-2.1")
assert.equals(json.object, "chat.content")
assert.equals(json.object, "chat.completion")
assert.equals(r.headers["X-Kong-LLM-Model"], "anthropic/claude-2.1")

assert.is_table(json.choices)
Expand All @@ -597,7 +597,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
-- check this is in the 'kong' response format
-- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
assert.equals(json.model, "claude-2.1")
assert.equals(json.object, "chat.content")
assert.equals(json.object, "chat.completion")
assert.equals(r.headers["X-Kong-LLM-Model"], "anthropic/claude-2.1")

assert.is_table(json.choices)
Expand Down Expand Up @@ -642,7 +642,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
-- check this is in the 'kong' response format
-- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
assert.equals(json.model, "claude-2.1")
assert.equals(json.object, "chat.content")
assert.equals(json.object, "chat.completion")
assert.equals(r.headers["X-Kong-LLM-Model"], "anthropic/claude-2.1")

assert.is_table(json.choices)
Expand All @@ -669,7 +669,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
-- check this is in the 'kong' response format
-- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
assert.equals(json.model, "claude-2.1")
assert.equals(json.object, "chat.content")
assert.equals(json.object, "chat.completion")
assert.equals(r.headers["X-Kong-LLM-Model"], "anthropic/claude-2.1")

assert.is_table(json.choices)
Expand Down

0 comments on commit 0142f6f

Please sign in to comment.