fix(ai-proxy): (Gemini)(AG-154) fixed tools-functions calls coming back empty
tysoekong committed Nov 13, 2024
1 parent f2e9874 commit 7a580eb
Showing 4 changed files with 221 additions and 49 deletions.
3 changes: 3 additions & 0 deletions changelog/unreleased/kong/ai-gemini-fix-function-calling.yml
@@ -0,0 +1,3 @@
message: "**ai-proxy**: Fixed a bug where tools (function) calls to Gemini (or via Vertex) would return empty results."
type: bugfix
scope: Plugin
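
For context, the heart of the fix below is translating the OpenAI-style "tools" array into Gemini's "function_declarations" format on the request path (and mapping "functionCall" parts back to "tool_calls" on the response path). A rough Lua sketch of the request-side mapping, illustrative only and not part of the diff; the "check_stock" fixture comes from the unit spec further down:

    -- OpenAI-format tools array, as sent by a client:
    local openai_tools = {
      {
        type = "function",
        ["function"] = {
          name = "check_stock",
          description = "Check a product is in stock.",
          parameters = {
            type = "object",
            properties = { product_name = { type = "string" } },
            required = { "product_name" },
          },
        },
      },
    }

    -- Shape the new to_tools() helper produces for Gemini:
    -- {
    --   {
    --     function_declarations = {
    --       { name = "check_stock", description = "...", parameters = { ... } },
    --     },
    --   },
    -- }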
183 changes: 135 additions & 48 deletions kong/llm/drivers/gemini.lua
@@ -42,6 +42,25 @@ local function is_response_content(content)
and content.candidates[1].content.parts[1].text
end

local function is_tool_content(content)
return content
and content.candidates
and #content.candidates > 0
and content.candidates[1].content
and content.candidates[1].content.parts
and #content.candidates[1].content.parts > 0
and content.candidates[1].content.parts[1].functionCall
end

local function is_function_call_message(message)
return message
and message.role
and message.role == "assistant"
and message.tool_calls
and type(message.tool_calls) == "table"
and #message.tool_calls > 0
end

local function handle_stream_event(event_t, model_info, route_type)
-- discard empty frames, it should either be a random new line, or comment
if (not event_t.data) or (#event_t.data < 1) then
@@ -83,10 +102,32 @@ local function handle_stream_event(event_t, model_info, route_type)
end
end

local function to_tools(in_tools)
if not in_tools then
return nil
end

local out_tools

for i, v in ipairs(in_tools) do
if v['function'] then
out_tools = out_tools or {
[1] = {
function_declarations = {}
}
}

out_tools[1].function_declarations[i] = v['function']
end
end

return out_tools
end

local function to_gemini_chat_openai(request_table, model_info, route_type)
if request_table then -- try-catch type mechanism
local new_r = {}
local new_r = {}

if request_table then
if request_table.messages and #request_table.messages > 0 then
local system_prompt

@@ -96,18 +137,60 @@ local function to_gemini_chat_openai(request_table, model_info, route_type)
if v.role and v.role == "system" then
system_prompt = system_prompt or buffer.new()
system_prompt:put(v.content or "")

elseif v.role and v.role == "tool" then
-- handle tool execution output
table_insert(new_r.contents, {
role = "function",
parts = {
{
function_response = {
response = {
content = {
v.content,
},
},
name = "get_product_info",
},
},
},
})

elseif is_function_call_message(v) then
-- treat specific 'assistant function call' tool execution input message
local function_calls = {}
for i, t in ipairs(v.tool_calls) do
function_calls[i] = {
function_call = {
name = t['function'].name,
},
}
end

table_insert(new_r.contents, {
role = "function",
parts = function_calls,
})

else
-- for any other role, just construct the chat history as 'parts.text' type
new_r.contents = new_r.contents or {}

local part = v.content
if type(v.content) == "string" then
part = {
text = v.content
}
end

table_insert(new_r.contents, {
role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user'
parts = {
{
text = v.content or ""
},
part,
},
})
end

end

-- This was only added in Gemini 1.5
@@ -127,42 +210,20 @@ local function to_gemini_chat_openai(request_table, model_info, route_type)

new_r.generationConfig = to_gemini_generation_config(request_table)

return new_r, "application/json", nil
end

local new_r = {}

if request_table.messages and #request_table.messages > 0 then
local system_prompt

for i, v in ipairs(request_table.messages) do

-- for 'system', we just concat them all into one Gemini instruction
if v.role and v.role == "system" then
system_prompt = system_prompt or buffer.new()
system_prompt:put(v.content or "")
else
-- for any other role, just construct the chat history as 'parts.text' type
new_r.contents = new_r.contents or {}
table_insert(new_r.contents, {
role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user'
parts = {
{
text = v.content or ""
},
},
})
end
end
-- handle function calling translation from OpenAI format
new_r.tools = request_table.tools and to_tools(request_table.tools)
end

new_r.generationConfig = to_gemini_generation_config(request_table)
kong.log.warn(cjson.encode(new_r))

return new_r, "application/json", nil
end

local function from_gemini_chat_openai(response, model_info, route_type)
local response, err = cjson.decode(response)
local err
if response and (type(response) == "string") then
response, err = cjson.decode(response)
end

if err then
local err_client = "failed to decode response from Gemini"
@@ -174,20 +235,38 @@ local function from_gemini_chat_openai(response, model_info, route_type)
local messages = {}
messages.choices = {}

if response.candidates
and #response.candidates > 0
and is_response_content(response) then
if response.candidates and #response.candidates > 0 then
if is_response_content(response) then
messages.choices[1] = {
index = 0,
message = {
role = "assistant",
content = response.candidates[1].content.parts[1].text,
},
finish_reason = string_lower(response.candidates[1].finishReason),
}
messages.object = "chat.completion"
messages.model = model_info.name

messages.choices[1] = {
index = 0,
message = {
role = "assistant",
content = response.candidates[1].content.parts[1].text,
},
finish_reason = string_lower(response.candidates[1].finishReason),
}
messages.object = "chat.completion"
messages.model = model_info.name
elseif is_tool_content(response) then
local function_call_responses = response.candidates[1].content.parts
for i, v in ipairs(function_call_responses) do
messages.choices[i] = {
index = 0,
message = {
role = "assistant",
tool_calls = {
{
['function'] = {
name = v.functionCall.name,
arguments = cjson.encode(v.functionCall.args),
},
},
},
},
}
end
end

-- process analytics
if response.usageMetadata then
@@ -206,7 +285,7 @@ local function from_gemini_chat_openai(response, model_info, route_type)
ngx.log(ngx.ERR, err)
return nil, err

else-- probably a server fault or other unexpected response
else -- probably a server fault or other unexpected response
local err = "no generation candidates received from Gemini, or max_tokens too short"
ngx.log(ngx.ERR, err)
return nil, err
@@ -471,4 +550,12 @@ function _M.configure_request(conf, identity_interface)
return true
end


if _G._TEST then
-- export locals for testing
_M._to_tools = to_tools
_M._from_gemini_chat_openai = from_gemini_chat_openai
end


return _M
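
The `_G._TEST` export above is what lets the unit spec call the private helpers directly. A minimal usage sketch, assuming the Kong test environment where `kong.llm.drivers.gemini` is on the package path; it mirrors the setup/teardown pattern in the spec below:

    -- enable the test-only exports before loading the driver
    _G._TEST = true
    package.loaded["kong.llm.drivers.gemini"] = nil   -- force a fresh require
    local gemini = require("kong.llm.drivers.gemini")

    local gemini_tools = gemini._to_tools({
      { type = "function", ["function"] = { name = "check_stock" } },
    })
    -- gemini_tools[1].function_declarations[1].name == "check_stock"

    _G._TEST = nil                                     -- clean up after the run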
2 changes: 1 addition & 1 deletion kong/llm/proxy/handler.lua
@@ -466,7 +466,7 @@ function _M:access(conf)
local identity_interface = _KEYBASTION[conf]

if identity_interface and identity_interface.error then
llm_state.set_response_transformer_skipped()
llm_state.disable_ai_proxy_response_transform()
kong.log.err("error authenticating with cloud-provider, ", identity_interface.error)
return bail(500, "LLM request failed before proxying")
end
82 changes: 82 additions & 0 deletions spec/03-plugins/38-ai-proxy/01-unit_spec.lua
@@ -80,6 +80,23 @@ local SAMPLE_OPENAI_TOOLS_REQUEST = {
},
}

local SAMPLE_GEMINI_TOOLS_RESPONSE = {
candidates = { {
content = {
role = "model",
parts = { {
functionCall = {
name = "sql_execute",
args = {
product_name = "NewPhone"
}
}
} }
},
finishReason = "STOP",
} },
}

local SAMPLE_BEDROCK_TOOLS_RESPONSE = {
metrics = {
latencyMs = 3781
@@ -832,6 +849,71 @@ describe(PLUGIN_NAME .. ": (unit)", function()
end)
end)

describe("gemini tools", function()
local gemini_driver

setup(function()
_G._TEST = true
package.loaded["kong.llm.drivers.gemini"] = nil
gemini_driver = require("kong.llm.drivers.gemini")
end)

teardown(function()
_G._TEST = nil
end)

it("transforms openai tools to gemini tools GOOD", function()
local gemini_tools = gemini_driver._to_tools(SAMPLE_OPENAI_TOOLS_REQUEST.tools)

assert.not_nil(gemini_tools)
assert.same(gemini_tools, {
{
function_declarations = {
{
description = "Check a product is in stock.",
name = "check_stock",
parameters = {
properties = {
product_name = {
type = "string"
}
},
required = {
"product_name"
},
type = "object"
}
}
}
}
})
end)

it("transforms openai tools to gemini tools NO_TOOLS", function()
local gemini_tools = gemini_driver._to_tools(SAMPLE_LLM_V1_CHAT)

assert.is_nil(gemini_tools)
end)

it("transforms openai tools to gemini tools NIL", function()
local gemini_tools = gemini_driver._to_tools(nil)

assert.is_nil(gemini_tools)
end)

it("transforms gemini tools to openai tools GOOD", function()
local openai_tools = gemini_driver._from_gemini_chat_openai(SAMPLE_GEMINI_TOOLS_RESPONSE, {}, "llm/v1/chat")

assert.not_nil(openai_tools)

openai_tools = cjson.decode(openai_tools)
assert.same(openai_tools.choices[1].message.tool_calls[1]['function'], {
name = "sql_execute",
arguments = "{\"product_name\":\"NewPhone\"}"
})
end)
end)

describe("bedrock tools", function()
local bedrock_driver

