diff --git a/changelog/unreleased/kong/ai-bedrock-fix-function-calling.yml b/changelog/unreleased/kong/ai-bedrock-fix-function-calling.yml
new file mode 100644
index 000000000000..622e0532f1da
--- /dev/null
+++ b/changelog/unreleased/kong/ai-bedrock-fix-function-calling.yml
@@ -0,0 +1,3 @@
+message: "**ai-proxy**: Fixed a bug where tools (function) calls to Bedrock would return empty results."
+type: bugfix
+scope: Plugin
diff --git a/kong/llm/drivers/bedrock.lua b/kong/llm/drivers/bedrock.lua
index 7ff646586d71..fa7d695d7f29 100644
--- a/kong/llm/drivers/bedrock.lua
+++ b/kong/llm/drivers/bedrock.lua
@@ -20,6 +20,14 @@ local _OPENAI_ROLE_MAPPING = {
   ["system"] = "assistant",
   ["user"] = "user",
   ["assistant"] = "assistant",
+  ["tool"] = "user",
+}
+
+-- Bedrock Converse `stopReason` values mapped to OpenAI `finish_reason` values
+local _OPENAI_STOP_REASON_MAPPING = {
+  ["max_tokens"] = "length",
+  ["end_turn"] = "stop",
+  ["tool_use"] = "tool_calls",
 }
 
 _M.bedrock_unsupported_system_role_patterns = {
@@ -51,6 +59,56 @@ local function to_tool_config(request_table)
   }
 end
 
+-- Converts OpenAI-format `tools` into Bedrock Converse `toolSpec` entries.
+-- Entries without a `function` definition are skipped; returns nil when
+-- nothing was converted.
+local function to_tools(in_tools)
+  local out_tools
+
+  for _, v in ipairs(in_tools) do
+    if v['function'] then
+      out_tools = out_tools or {}
+
+      -- append densely: indexing by the input position would leave holes
+      -- (a sparse array) whenever an entry is skipped
+      out_tools[#out_tools + 1] = {
+        toolSpec = {
+          name = v['function'].name,
+          description = v['function'].description,
+          inputSchema = {
+            json = v['function'].parameters,
+          },
+        },
+      }
+    end
+  end
+
+  return out_tools
+end
+
+-- Wraps a Bedrock `toolUse` block as an OpenAI `tool_calls` list holding one
+-- call. `function.arguments` is the JSON-encoded tool input, and is left
+-- unset when the input is missing or empty.
+local function from_tool_call_response(tool_use)
+  local arguments
+
+  if tool_use['input'] and next(tool_use['input']) then
+    arguments = cjson.encode(tool_use['input'])
+  end
+
+  return {
+    -- set explicit numbering to ensure ordering in later modifications
+    [1] = {
+      ['function'] = {
+        arguments = arguments,
+        name = tool_use.name,
+      },
+      id = tool_use.toolUseId,
+      type = "function",
+    },
+  }
+end
+
 local function handle_stream_event(event_t, model_info, route_type)
   local new_event, metadata
 
@@ -113,7 +171,7 @@ local function handle_stream_event(event_t, model_info, route_type)
         [1] = {
           delta = {},
           index = 0,
-          finish_reason = body.stopReason,
+          finish_reason = _OPENAI_STOP_REASON_MAPPING[body.stopReason] or "stop",
           logprobs = cjson.null,
         },
       },
@@ -144,7 +202,7 @@ local function handle_stream_event(event_t, model_info, route_type)
 end
 
 local function to_bedrock_chat_openai(request_table, model_info, route_type)
-  if not request_table then  -- try-catch type mechanism
+  if not request_table then
     local err = "empty request table received for transformation"
     ngx.log(ngx.ERR, "[bedrock] ", err)
     return nil, nil, err
@@ -164,16 +222,74 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type)
       if v.role and v.role == "system" then
         system_prompts[#system_prompts+1] = { text = v.content }
 
+      elseif v.role and v.role == "tool" then
+        if type(v.content) ~= "string" then
+          return nil, nil, "failed to decode function response arguments: content is not a string"
+        end
+
+        -- OpenAI tool results are free-form strings: forward decoded JSON
+        -- tables as 'json' and anything else (e.g. plain text) as 'text'
+        local tool_execution_content = cjson.decode(v.content)
+        local result_block
+        if type(tool_execution_content) == "table" then
+          result_block = { json = tool_execution_content }
+        else
+          result_block = { text = v.content }
+        end
+
+        local content = {
+          {
+            toolResult = {
+              toolUseId = v.tool_call_id,
+              content = { result_block },
+              status = v.status,
+            },
+          },
+        }
+
+        new_r.messages = new_r.messages or {}
+        table_insert(new_r.messages, {
+          role = _OPENAI_ROLE_MAPPING[v.role or "user"],  -- default to 'user'
+          content = content,
+        })
+
       else
         local content
         if type(v.content) == "table" then
           content = v.content
+
+        elseif type(v.tool_calls) == "table" and #v.tool_calls > 0 then
+          -- one 'toolUse' block per call, so parallel tool calls all survive
+          content = {}
+          for _, tool_call in ipairs(v.tool_calls) do
+            local fn = tool_call['function'] or {}
+            local inputs, err
+
+            -- calls without arguments are sent with an empty input
+            if fn.arguments == nil or fn.arguments == "" then
+              inputs = {}
+            else
+              inputs, err = cjson.decode(fn.arguments)
+              if err then
+                return nil, nil, "failed to decode function response arguments from assistant: " .. err
+              end
+            end
+
+            content[#content + 1] = {
+              toolUse = {
+                toolUseId = tool_call.id,
+                name = fn.name,
+                input = inputs,
+              },
+            }
+          end
+
         else
           content = {
             {
               text = v.content or ""
             },
           }
         end
 
         -- for any other role, just construct the chat history as 'parts.text' type
@@ -199,9 +315,21 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type)
 
   new_r.inferenceConfig = to_bedrock_generation_config(request_table)
 
+  -- backwards compatibility
   new_r.toolConfig = request_table.bedrock
                  and request_table.bedrock.toolConfig
                  and to_tool_config(request_table)
+
+  if request_table.tools
+    and type(request_table.tools) == "table"
+    and #request_table.tools > 0 then
+
+    local tools = to_tools(request_table.tools)
+    if tools then
+      new_r.toolConfig = new_r.toolConfig or {}
+      new_r.toolConfig.tools = tools
+    end
+  end
 
   new_r.additionalModelRequestFields = request_table.bedrock
                                    and request_table.bedrock.additionalModelRequestFields
@@ -219,7 +347,6 @@ local function from_bedrock_chat_openai(response, model_info, route_type)
     return nil, err_client
   end
 
-  -- messages/choices table is only 1 size, so don't need to static allocate
   local client_response = {}
   client_response.choices = {}
 
@@ -229,13 +356,24 @@ local function from_bedrock_chat_openai(response, model_info, route_type)
         and #response.output.message.content > 0
         and response.output.message.content[1].text then
 
-    client_response.choices[1] = {
+    -- gather every toolUse block that follows the leading text block
+    local content_blocks = response.output.message.content
+    local tool_use
+    for i = 2, #content_blocks do
+      if content_blocks[i].toolUse then
+        tool_use = tool_use or {}
+        tool_use[#tool_use + 1] = from_tool_call_response(content_blocks[i].toolUse)[1]
+      end
+    end
+
+    client_response.choices[1] = {
       index = 0,
       message = {
         role = "assistant",
         content = response.output.message.content[1].text,
+        tool_calls = tool_use,
       },
-      finish_reason = string_lower(response.stopReason),
+      finish_reason = _OPENAI_STOP_REASON_MAPPING[response.stopReason] or "stop",
     }
     client_response.object = "chat.completion"
     client_response.model = model_info.name
@@ -294,7 +432,7 @@ function _M.to_format(request_table, model_info, route_type)
     -- do nothing
     return request_table, nil, nil
   end
-  
+
   if not transformers_to[route_type] then
     return nil, nil, fmt("no transformer for %s://%s", model_info.provider, route_type)
   end