fix(ai-proxy-bedrock): add toolConfig and additional fields to Bedrock; fix bedrock converse system prompts
tysoekong committed Jul 8, 2024
1 parent 2dd382b commit f5c477c
Showing 5 changed files with 87 additions and 23 deletions.
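Before the per-file diffs, a sketch of the kind of client request this commit enables. It is illustrative only: the bedrock.toolConfig and bedrock.additionalModelRequestFields keys are what the new driver code below reads, but the tool definition itself is a hypothetical example, not taken from the commit.

-- hypothetical inbound chat request, shown as a Lua table; only the
-- "bedrock" sub-table is new in this commit -- its two keys are picked up
-- by the to_tool_config() and to_additional_request_fields() helpers below
local example_request = {
  messages = {
    { role = "system", content = "You are a weather bot." },
    { role = "user",   content = "What's the weather in Berlin?" },
  },
  bedrock = {
    toolConfig = {
      -- shape assumed from the Bedrock Converse API; the tool is made up
      tools = {
        { toolSpec = { name = "get_weather", description = "Look up weather." } },
      },
    },
    additionalModelRequestFields = {},  -- model-specific extras, forwarded by the driver
  },
}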
2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,4 +1,4 @@
-FROM kong:3.6.1
+FROM kong:3.7.1

 USER root

37 changes: 31 additions & 6 deletions kong/llm/drivers/bedrock.lua
@@ -32,6 +32,18 @@ local function to_bedrock_generation_config(request_table)
   }
 end

+local function to_additional_request_fields(request_table)
+  return {
+    request_table.bedrock.additionalModelRequestFields
+  }
+end
+
+local function to_tool_config(request_table)
+  return {
+    request_table.bedrock.toolConfig
+  }
+end
+
 local function handle_stream_event(event_t, model_info, route_type)
   local new_event, metadata

@@ -139,13 +151,12 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type)
     or "bedrock-2023-05-31"

   if request_table.messages and #request_table.messages > 0 then
-    local system_prompt
+    local system_prompts = {}

     for i, v in ipairs(request_table.messages) do
       -- for 'system', we just concat them all into one Gemini instruction
       if v.role and v.role == "system" then
-        system_prompt = system_prompt or buffer.new()
-        system_prompt:put(v.content or "")
+        system_prompts[#system_prompts+1] = { text = v.content }

       else
         -- for any other role, just construct the chat history as 'parts.text' type
@@ -161,14 +172,28 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type)
       end
     end

-    -- only works for
-    if system_prompt then
-      new_r.system = system_prompt:get()
+    -- only works for some models
+    if #system_prompts > 0 then
+      for _, p in ipairs(ai_shared.bedrock_unsupported_system_role_patterns) do
+        if model_info.name:find(p) then
+          return nil, nil, "system prompts are unsupported for model '" .. model_info.name
+        end
+      end
+
+      new_r.system = system_prompts
     end
   end

   new_r.inferenceConfig = to_bedrock_generation_config(request_table)

+  new_r.toolConfig = request_table.bedrock
+    and request_table.bedrock.toolConfig
+    and to_tool_config(request_table)
+
+  new_r.additionalModelRequestFields = request_table.bedrock
+    and request_table.bedrock.additionalModelRequestFields
+    and to_additional_request_fields(request_table)
+
   return new_r, "application/json", nil
 end

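Taken together, the hunks above mean to_bedrock_chat_openai() now emits a Converse-style body in which system prompts are a list of text objects rather than one concatenated string buffer. A minimal sketch of the resulting new_r table, assuming one system prompt, one user message, and both optional bedrock.* fields present; the field names follow the diff, all values are illustrative.

-- illustrative output shape of the converted request
local new_r = {
  anthropic_version = "bedrock-2023-05-31",  -- default seen in the diff context
  messages = {
    -- chat history built from the non-system messages (content shape assumed)
    { role = "user", content = { { text = "What's the weather in Berlin?" } } },
  },
  system = {
    { text = "You are a weather bot." },  -- one entry per OpenAI "system" message
  },
  inferenceConfig = { maxTokens = 256 },  -- built by to_bedrock_generation_config()
  toolConfig = { tools = {} },            -- present only when the client sent one
  additionalModelRequestFields = {},      -- likewise optional
}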
8 changes: 8 additions & 0 deletions kong/llm/drivers/shared.lua
@@ -60,6 +60,14 @@ _M.streaming_has_token_counts = {
   ["bedrock"] = true,
 }

+_M.bedrock_unsupported_system_role_patterns = {
+  "amazon.titan.-.*",
+  "cohere.command.-text.-.*",
+  "cohere.command.-light.-text.-.*",
+  "mistral.mistral.-7b.-instruct.-.*",
+  "mistral.mixtral.-8x7b.-instruct.-.*",
+}
+
 _M.upstream_url_format = {
   openai = fmt("%s://api.openai.com:%s", (openai_override and "http") or "https", (openai_override) or "443"),
   anthropic = "https://api.anthropic.com:443",
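The new entries are Lua patterns rather than literal model names, so the find() checks in bedrock.lua and init.lua match whole families of Bedrock model IDs. A quick standalone sanity check; the pattern list mirrors shared.lua, and the model IDs are real-looking Bedrock identifiers used for illustration:

-- demonstrates the pattern matching the drivers perform against model names
local patterns = {
  "amazon.titan.-.*",
  "mistral.mistral.-7b.-instruct.-.*",
}

local function system_role_unsupported(model)
  for _, p in ipairs(patterns) do
    if model:find(p) then return true end
  end
  return false
end

print(system_role_unsupported("amazon.titan-text-express-v1"))       --> true
print(system_role_unsupported("mistral.mistral-7b-instruct-v0:2"))   --> true
print(system_role_unsupported("anthropic.claude-3-sonnet-20240229")) --> false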
61 changes: 46 additions & 15 deletions kong/llm/init.lua
@@ -91,20 +91,51 @@ do
   function LLM:ai_introspect_body(request, system_prompt, http_opts, response_regex_match)
     local err, _

-    -- set up the request
-    local ai_request = {
-      messages = {
-        [1] = {
-          role = "system",
-          content = system_prompt,
-        },
-        [2] = {
-          role = "user",
-          content = request,
-        }
-      },
-      stream = false,
-    }
+    -- set up the LLM request for transformation instructions
+    local ai_request
+
+    -- mistral, cohere, titan (with Bedrock) don't support system commands
+    if self.driver == "bedrock" then
+      for _, p in ipairs(ai_shared.bedrock_unsupported_system_role_patterns) do
+        if request.model:find(p) then
+          ai_request = {
+            messages = {
+              [1] = {
+                role = "user",
+                content = system_prompt,
+              },
+              [2] = {
+                role = "assistant",
+                content = "What is the message?",
+              },
+              [3] = {
+                role = "user",
+                content = request,
+              }
+            },
+            stream = false,
+          }
+          break
+        end
+      end
+    end
+
+    -- not Bedrock, or didn't match banned pattern - continue as normal
+    if not ai_request then
+      ai_request = {
+        messages = {
+          [1] = {
+            role = "system",
+            content = system_prompt,
+          },
+          [2] = {
+            role = "user",
+            content = request,
+          }
+        },
+        stream = false,
+      }
+    end

     -- convert it to the specified driver format
     ai_request, _, err = self.driver.to_format(ai_request, self.conf.model, "llm/v1/chat")
@@ -204,8 +235,8 @@
   }
   setmetatable(self, LLM)

-  local provider = (self.conf.model or {}).provider or "NONE_SET"
-  local driver_module = "kong.llm.drivers." .. provider
+  self.provider = (self.conf.model or {}).provider or "NONE_SET"
+  local driver_module = "kong.llm.drivers." .. self.provider
   local ok
   ok, self.driver = pcall(require, driver_module)
   if not ok then
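The effect of the first hunk: when the configured Bedrock model matches one of the banned patterns, the system prompt is emulated by replaying it as a user turn followed by a canned assistant turn. A condensed standalone sketch of that branch; the function name is invented for illustration, while the two message layouts mirror the patch:

-- standalone distillation of the branching above
local function build_introspect_messages(system_supported, system_prompt, request)
  if system_supported then
    return {
      { role = "system",    content = system_prompt },
      { role = "user",      content = request },
    }
  end
  -- models that reject a "system" role get the prompt replayed as a user turn
  return {
    { role = "user",      content = system_prompt },
    { role = "assistant", content = "What is the message?" },
    { role = "user",      content = request },
  }
end

print(#build_introspect_messages(false, "Extract the intent.", "hello"))  --> 3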
2 changes: 1 addition & 1 deletion kong/llm/schemas/init.lua
@@ -185,7 +185,7 @@ local model_schema = {
        type = "string", description = "AI provider request format - Kong translates "
          .. "requests to and from the specified backend compatible formats.",
        required = true,
-       one_of = { "openai", "azure", "anthropic", "cohere", "mistral", "llama2", "gemini" }}},
+       one_of = { "openai", "azure", "anthropic", "cohere", "mistral", "llama2", "gemini", "bedrock" }}},
      { name = {
        type = "string",
        description = "Model name to execute.",
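With "bedrock" added to the one_of list, a model block like the following now passes schema validation. Shown as a Lua table for consistency; the model name is an illustrative Bedrock identifier, and other schema fields are omitted:

-- hypothetical model config now accepted by model_schema
local model = {
  provider = "bedrock",                              -- newly permitted value
  name = "anthropic.claude-3-sonnet-20240229-v1:0",  -- illustrative Bedrock model ID
}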
