From 8038a8808129b05714d0bd878a2cb910b8eae7fd Mon Sep 17 00:00:00 2001
From: Jack Tysoe
Date: Sun, 17 Nov 2024 23:20:23 +0000
Subject: [PATCH] fix(ai-proxy): (Bedrock)(AG-166) properly map guardrails between request and response

Forward the request's `guardrailConfig` to the Bedrock request body, add
`guardrail_intervened` to the stop-reason mapping, and copy the response
`trace` through to the client response. `to_bedrock_chat_openai` is also
exported under `_G._TEST` so the request mapping can be unit-tested.
---
 .../kong/ai-bedrock-fix-guardrails.yml       |  3 ++
 kong/llm/drivers/bedrock.lua                 |  9 +++++
 spec/03-plugins/38-ai-proxy/01-unit_spec.lua | 36 ++++++++++++++++++-
 3 files changed, 47 insertions(+), 1 deletion(-)
 create mode 100644 changelog/unreleased/kong/ai-bedrock-fix-guardrails.yml

diff --git a/changelog/unreleased/kong/ai-bedrock-fix-guardrails.yml b/changelog/unreleased/kong/ai-bedrock-fix-guardrails.yml
new file mode 100644
index 0000000000000..d29cd7bab36d0
--- /dev/null
+++ b/changelog/unreleased/kong/ai-bedrock-fix-guardrails.yml
@@ -0,0 +1,3 @@
+message: "**ai-proxy**: Fixed a bug where Bedrock Guardrail config was ignored."
+type: bugfix
+scope: Plugin
diff --git a/kong/llm/drivers/bedrock.lua b/kong/llm/drivers/bedrock.lua
index e52faa97877ed..4fa0bc7858029 100644
--- a/kong/llm/drivers/bedrock.lua
+++ b/kong/llm/drivers/bedrock.lua
@@ -27,6 +27,7 @@ local _OPENAI_STOP_REASON_MAPPING = {
   ["max_tokens"] = "length",
   ["end_turn"] = "stop",
   ["tool_use"] = "tool_calls",
+  ["guardrail_intervened"] = "guardrail_intervened",
 }
 
 _M.bedrock_unsupported_system_role_patterns = {
@@ -46,6 +47,10 @@ local function to_bedrock_generation_config(request_table)
   }
 end
 
+local function to_bedrock_guardrail_config(guardrail_config)
+  return guardrail_config  -- passed through unchanged; a nil result leaves the caller's field unset
+end
+
 -- this is a placeholder and is archaic now,
 -- leave it in for backwards compatibility
 local function to_additional_request_fields(request_table)
@@ -310,6 +315,7 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type)
   end
 
   new_r.inferenceConfig = to_bedrock_generation_config(request_table)
+  new_r.guardrailConfig = to_bedrock_guardrail_config(request_table.guardrailConfig)
 
   -- backwards compatibility
   new_r.toolConfig = request_table.bedrock
@@ -375,6 +381,8 @@ local function from_bedrock_chat_openai(response, model_info, route_type)
     }
   end
 
+  client_response.trace = response.trace  -- may be nil, **do not** map to cjson.null
+
   return cjson.encode(client_response)
 end
 
@@ -598,6 +606,7 @@ end
 if _G._TEST then
   -- export locals for testing
   _M._to_tools = to_tools
+  _M._to_bedrock_chat_openai = to_bedrock_chat_openai
   _M._from_tool_call_response = from_tool_call_response
 end
 
diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua
index 55880d69ebee4..7c0814ad7c44f 100644
--- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua
+++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua
@@ -37,6 +37,24 @@ local SAMPLE_LLM_V1_CHAT_WITH_SOME_OPTS = {
   another_extra_param = 0.5,
 }
 
+local SAMPLE_LLM_V1_CHAT_WITH_GUARDRAILS = {
+  messages = {
+    [1] = {
+      role = "system",
+      content = "You are a mathematician."
+    },
+    [2] = {
+      role = "assistant",
+      content = "What is 1 + 1?"
+    },
+  },
+  guardrailConfig = {
+    guardrailIdentifier = "yu5xwvfp4sud",
+    guardrailVersion = "1",
+    trace = "enabled",
+  },
+}
+
 local SAMPLE_DOUBLE_FORMAT = {
   messages = {
     [1] = {
@@ -976,6 +994,22 @@ describe(PLUGIN_NAME .. ": (unit)", function()
         arguments = "{\"areas\":[121,212,313]}"
       })
     end)
-  end)
 
+    it("transforms guardrails into bedrock guardrail config", function()
+      local model_info = {
+        route_type = "llm/v1/chat",
+        name = "some-model",
+        provider = "bedrock",
+      }
+      local bedrock_guardrails = bedrock_driver._to_bedrock_chat_openai(SAMPLE_LLM_V1_CHAT_WITH_GUARDRAILS, model_info, "llm/v1/chat")
+
+      assert.not_nil(bedrock_guardrails)
+      -- expected value first, so a failure reports expected/actual the right way round
+      assert.same({
+        ['guardrailIdentifier'] = 'yu5xwvfp4sud',
+        ['guardrailVersion'] = '1',
+        ['trace'] = 'enabled',
+      }, bedrock_guardrails.guardrailConfig)
+    end)
+  end)
 end)