Commit

fix(ai-proxy): (Gemini)(AG-154) fixed tools-functions calls coming back empty
tysoekong committed Oct 25, 2024
1 parent 963a09c commit f6ef879
Showing 3 changed files with 150 additions and 74 deletions.
3 changes: 3 additions & 0 deletions changelog/unreleased/kong/ai-gemini-fix-function-calling.yml
@@ -0,0 +1,3 @@
message: "**ai-proxy**: Fixed a bug where tools (function) calls to Gemini (or via Vertex) would return empty results."
type: bugfix
scope: Plugin
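For context, a minimal sketch (not taken from this repository; names and values invented) of the OpenAI-compatible tool-call message that ai-proxy normalises provider responses into. Before this fix, the tool_calls array came back empty for Gemini and Vertex requests:

-- Illustrative only: hypothetical function name and arguments.
local example_choice = {
  index = 0,
  message = {
    role = "assistant",
    tool_calls = {
      {
        id = "call_1",                        -- provider-assigned call id
        type = "function",
        ["function"] = {
          name = "get_weather",               -- hypothetical tool name
          arguments = '{"city": "London"}',   -- arguments as a JSON-encoded string
        },
      },
    },
  },
}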
53 changes: 27 additions & 26 deletions kong/llm/drivers/bedrock.lua
@@ -79,24 +79,33 @@ local function to_tools(in_tools)
return out_tools
end

local function from_tool_call_response(tool_use)
local arguments
local function from_tool_call_response(content)
if not content then return nil end

if tool_use['input'] and next(tool_use['input']) then
arguments = cjson.encode(tool_use['input'])
local tools_used

for _, t in ipairs(content) do
if t.toolUse then
tools_used = tools_used or {}

local arguments
if t.toolUse['input'] and next(t.toolUse['input']) then
arguments = cjson.encode(t.toolUse['input'])
end

tools_used[#tools_used+1] = {
-- set explicit numbering to ensure ordering in later modifications
['function'] = {
arguments = arguments,
name = t.toolUse.name,
},
id = t.toolUse.toolUseId,
type = "function",
}
end
end

return {
-- set explicit numbering to ensure ordering in later modifications
[1] = {
['function'] = {
arguments = arguments,
name = tool_use.name,
},
id = tool_use.toolUseId,
type = "function",
},
}
return tools_used
end

local function handle_stream_event(event_t, model_info, route_type)
@@ -326,23 +335,15 @@ local function from_bedrock_chat_openai(response, model_info, route_type)
if response.output
and response.output.message
and response.output.message.content
and #response.output.message.content > 0
and response.output.message.content[1].text then
and #response.output.message.content > 0 then

local tool_use, err
if #response.output.message.content > 1 and response.output.message.content[2].toolUse then
tool_use, err = from_tool_call_response(response.output.message.content[2].toolUse)

if err then
return nil, fmt("unable to process function call response arguments: %s", err)
end
end
local tool_use, err = from_tool_call_response(response.output.message.content)

client_response.choices[1] = {
index = 0,
message = {
role = "assistant",
content = response.output.message.content[1].text,
content = response.output.message.content[1].text, -- may be nil
tool_calls = tool_use,
},
finish_reason = _OPENAI_STOP_REASON_MAPPING[response.stopReason] or "stop",
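A hedged sketch of the Bedrock change above: the old code only inspected response.output.message.content[2].toolUse (and required content[1].text), so a tool call in any other position, or multiple tool calls, produced no tool_calls; the reworked from_tool_call_response() now walks every content part. The payload below is invented, assuming the Bedrock Converse content shape the code already reads:

-- Hypothetical Converse-style content array (values invented):
local content = {
  { text = "Let me look that up." },
  {
    toolUse = {
      toolUseId = "tooluse_abc123",
      name = "get_weather",
      input = { city = "London" },
    },
  },
}

-- Expected result of from_tool_call_response(content), sketched from the loop above:
-- {
--   {
--     id = "tooluse_abc123",
--     type = "function",
--     ["function"] = { name = "get_weather", arguments = '{"city":"London"}' },
--   },
-- }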
168 changes: 120 additions & 48 deletions kong/llm/drivers/gemini.lua
@@ -42,6 +42,25 @@ local function is_response_content(content)
and content.candidates[1].content.parts[1].text
end

local function is_tool_content(content)
return content
and content.candidates
and #content.candidates > 0
and content.candidates[1].content
and content.candidates[1].content.parts
and #content.candidates[1].content.parts > 0
and content.candidates[1].content.parts[1].functionCall
end

local function is_function_call_message(message)
return message
and message.role
and message.role == "assistant"
and message.tool_calls
and type(message.tool_calls) == "table"
and #message.tool_calls > 0
end
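
For illustration (all values invented), the kind of OpenAI-format assistant message the new is_function_call_message() predicate is intended to match when a client replays earlier tool calls in the chat history:

local history_message = {
  role = "assistant",
  tool_calls = {
    {
      id = "call_1",
      type = "function",
      ["function"] = { name = "get_weather", arguments = '{"city": "London"}' },
    },
  },
}
-- is_function_call_message(history_message) evaluates to true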

local function handle_stream_event(event_t, model_info, route_type)
-- discard empty frames, it should either be a random new line, or comment
if (not event_t.data) or (#event_t.data < 1) then
@@ -83,10 +102,28 @@ local function handle_stream_event(event_t, model_info, route_type)
end
end

local function to_tools(in_tools)
local out_tools

for i, v in ipairs(in_tools) do
if v['function'] then
out_tools = out_tools or {
[1] = {
function_declarations = {}
}
}

out_tools[1].function_declarations[i] = v['function']
end
end

return out_tools
end
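
A sketch of the translation the new to_tools() performs, assuming the standard OpenAI tools request format; the declaration below is invented for illustration:

-- OpenAI-style request fragment (invented example):
local openai_tools = {
  {
    type = "function",
    ["function"] = {
      name = "get_weather",
      description = "Look up the current weather for a city",
      parameters = {
        type = "object",
        properties = { city = { type = "string" } },
        required = { "city" },
      },
    },
  },
}

-- Expected Gemini shape after to_tools(openai_tools), per the loop above:
-- { { function_declarations = { { name = "get_weather", description = ..., parameters = ... } } } }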

local function to_gemini_chat_openai(request_table, model_info, route_type)
if request_table then -- try-catch type mechanism
local new_r = {}
local new_r = {}

if request_table then
if request_table.messages and #request_table.messages > 0 then
local system_prompt

@@ -96,18 +133,60 @@ local function to_gemini_chat_openai(request_table, model_info, route_type)
if v.role and v.role == "system" then
system_prompt = system_prompt or buffer.new()
system_prompt:put(v.content or "")

elseif v.role and v.role == "tool" then
-- handle tool execution output
table_insert(new_r.contents, {
role = "function",
parts = {
{
function_response = {
response = {
content = {
v.content,
},
},
name = "get_product_info",
},
},
},
})

elseif is_function_call_message(v) then
-- treat specific 'assistant function call' tool execution input message
local function_calls = {}
for i, t in ipairs(v.tool_calls) do
function_calls[i] = {
function_call = {
name = t['function'].name,
},
}
end

table_insert(new_r.contents, {
role = "function",
parts = function_calls,
})

else
-- for any other role, just construct the chat history as 'parts.text' type
new_r.contents = new_r.contents or {}

local part = v.content
if type(v.content) == "string" then
part = {
text = v.content
}
end

table_insert(new_r.contents, {
role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user'
parts = {
{
text = v.content or ""
},
part,
},
})
end

end

-- This was only added in Gemini 1.5
@@ -127,36 +206,11 @@ local function to_gemini_chat_openai(request_table, model_info, route_type)

new_r.generationConfig = to_gemini_generation_config(request_table)

return new_r, "application/json", nil
end

local new_r = {}

if request_table.messages and #request_table.messages > 0 then
local system_prompt

for i, v in ipairs(request_table.messages) do

-- for 'system', we just concat them all into one Gemini instruction
if v.role and v.role == "system" then
system_prompt = system_prompt or buffer.new()
system_prompt:put(v.content or "")
else
-- for any other role, just construct the chat history as 'parts.text' type
new_r.contents = new_r.contents or {}
table_insert(new_r.contents, {
role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user'
parts = {
{
text = v.content or ""
},
},
})
end
end
-- handle function calling translation from OpenAI format
new_r.tools = request_table.tools and to_tools(request_table.tools)
end

new_r.generationConfig = to_gemini_generation_config(request_table)
kong.log.warn(cjson.encode(new_r))

return new_r, "application/json", nil
end
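
For illustration only (values invented), the sort of OpenAI-format tool-result message that the new role == "tool" branch above converts into a Gemini function_response part:

local tool_result_message = {
  role = "tool",
  tool_call_id = "call_1",                            -- hypothetical call id
  content = '{"temperature": 11, "unit": "celsius"}', -- tool output as a string
}
-- Per the branch above, this is appended to new_r.contents as an entry with
-- role = "function" and a parts[1].function_response.response.content wrapper
-- around the message content.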
@@ -174,20 +228,38 @@ local function from_gemini_chat_openai(response, model_info, route_type)
local messages = {}
messages.choices = {}

if response.candidates
and #response.candidates > 0
and is_response_content(response) then

messages.choices[1] = {
index = 0,
message = {
role = "assistant",
content = response.candidates[1].content.parts[1].text,
},
finish_reason = string_lower(response.candidates[1].finishReason),
}
messages.object = "chat.completion"
messages.model = model_info.name
if response.candidates and #response.candidates > 0 then
if is_response_content(response) then
messages.choices[1] = {
index = 0,
message = {
role = "assistant",
content = response.candidates[1].content.parts[1].text,
},
finish_reason = string_lower(response.candidates[1].finishReason),
}
messages.object = "chat.completion"
messages.model = model_info.name

elseif is_tool_content(response) then
local function_call_responses = response.candidates[1].content.parts
for i, v in ipairs(function_call_responses) do
messages.choices[i] = {
index = 0,
message = {
role = "assistant",
tool_calls = {
{
['function'] = {
name = v.functionCall.name,
arguments = cjson.encode(v.functionCall.args),
},
},
},
},
}
end
end

-- process analytics
if response.usageMetadata then
@@ -206,7 +278,7 @@ local function from_gemini_chat_openai(response, model_info, route_type)
ngx.log(ngx.ERR, err)
return nil, err

else-- probably a server fault or other unexpected response
else -- probably a server fault or other unexpected response
local err = "no generation candidates received from Gemini, or max_tokens too short"
ngx.log(ngx.ERR, err)
return nil, err
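To round out the Gemini response side, a hedged example (all values invented) of a candidate carrying a functionCall part, the case the new is_tool_content() branch now converts instead of returning an empty result:

local gemini_response = {
  candidates = {
    {
      content = {
        role = "model",
        parts = {
          { functionCall = { name = "get_weather", args = { city = "London" } } },
        },
      },
      finishReason = "STOP",
    },
  },
}
-- Sketch of the choice produced for that part by from_gemini_chat_openai():
-- {
--   index = 0,
--   message = {
--     role = "assistant",
--     tool_calls = {
--       { ["function"] = { name = "get_weather", arguments = '{"city":"London"}' } },
--     },
--   },
-- }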
