fix(ai): llm model moved to shared state to guarantee consistency #13627

Merged (2 commits), Sep 6, 2024
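Summary of the change: the requested LLM model name used to live in `kong.ctx.plugin`, which is namespaced to the plugin instance that wrote it, so other AI plugins (for example the semantic cache) could never see it. This PR moves the value into `kong.ctx.shared`, which every plugin on the request can read, behind a pair of accessors in `kong/llm/state.lua`. A minimal sketch of the distinction, using only the standard Kong PDK context tables (the function names below are illustrative, not from this PR):

-- Illustrative sketch (not part of the diff): why per-plugin context
-- cannot carry a value between two different plugins.

-- Inside the ai-proxy plugin:
local function remember_model_per_plugin()
  -- Namespaced to ai-proxy; invisible to every other plugin.
  kong.ctx.plugin.llm_model_requested = "gpt-3.5-turbo"
end

-- Inside the ai-semantic-cache plugin, same request:
local function read_model_from_other_plugin()
  return kong.ctx.plugin.llm_model_requested  -- nil: different namespace
end

-- Shared request state, visible to every plugin in every phase:
local function remember_model_shared()
  kong.ctx.shared.llm_model_requested = "gpt-3.5-turbo"
end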
4 changes: 4 additions & 0 deletions changelog/unreleased/kong/fix-ai-semantic-cache-model.yml
@@ -0,0 +1,4 @@
+message: "Fixed a bug where the AI semantic cache could not use request-provided models"
+type: bugfix
+scope: Plugin

2 changes: 1 addition & 1 deletion kong/llm/drivers/shared.lua
@@ -709,7 +709,7 @@ function _M.post_request(conf, response_object)
-- Set the model, response, and provider names in the current try context
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PLUGIN_ID] = conf.__plugin_id
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PROVIDER_NAME] = provider_name
-  request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = kong.ctx.plugin.llm_model_requested or conf.model.name
+  request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = llm_state.get_request_model()
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.RESPONSE_MODEL] = response_object.model or conf.model.name

-- Set the llm latency meta, and time per token usage
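Note the fallback change here: the `request_model` meta field previously fell back to `conf.model.name` at read time, while `get_request_model()` falls back to the string `NOT_SPECIFIED`. The configured model name is still honoured because the access phase now records it via `set_request_model()` before this code runs, so the fallback effectively moves from read time to write time.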
4 changes: 2 additions & 2 deletions kong/llm/proxy/handler.lua
@@ -269,7 +269,7 @@ function _M:header_filter(conf)
end

if ngx.var.http_kong_debug or conf.model_name_header then
-    local name = conf.model.provider .. "/" .. (kong.ctx.plugin.llm_model_requested or conf.model.name)
+    local name = conf.model.provider .. "/" .. (llm_state.get_request_model())
kong.response.set_header("X-Kong-LLM-Model", name)
end

@@ -386,7 +386,7 @@ function _M:access(conf)
return bail(400, "model parameter not found in request, nor in gateway configuration")
end

-  kong_ctx_plugin.llm_model_requested = conf_m.model.name
+  llm_state.set_request_model(conf_m.model.name)

-- check the incoming format is the same as the configured LLM format
local compatible, err = llm.is_compatible(request_table, route_type)
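Taken together with the shared.lua change above, both read sites now go through the same accessor: `post_request` stamps the analytics `request_model` meta field and `header_filter` builds the `X-Kong-LLM-Model` response header (e.g. `openai/gpt-3.5-turbo`) from one shared value, so the log entry and the header can no longer disagree, even when the model name comes from the request body rather than from the plugin config.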
8 changes: 8 additions & 0 deletions kong/llm/state.lua
@@ -104,4 +104,12 @@ function _M.get_metrics(key)
return (kong.ctx.shared.llm_metrics or {})[key]
end

+function _M.set_request_model(model)
+  kong.ctx.shared.llm_model_requested = model
+end
+
+function _M.get_request_model()
+  return kong.ctx.shared.llm_model_requested or "NOT_SPECIFIED"
+end
+
return _M
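Usage of the two new accessors, as a sketch (the require path matches the module's location in the diff; the model string is illustrative):

local llm_state = require("kong.llm.state")

-- access phase: record the model actually resolved for this request,
-- whether it came from the plugin config or from the request body.
llm_state.set_request_model("gpt-4o")

-- any later phase, any plugin on the same request: read it back.
local model = llm_state.get_request_model()
-- "gpt-4o", or "NOT_SPECIFIED" if nothing was recorded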
1 change: 1 addition & 0 deletions kong/plugins/ai-request-transformer/handler.lua
@@ -46,6 +46,7 @@ local function create_http_opts(conf)
end

function _M:access(conf)
+  llm_state.set_request_model(conf.llm.model and conf.llm.model.name)
local kong_ctx_shared = kong.ctx.shared

kong.service.request.enable_buffering()
1 change: 1 addition & 0 deletions kong/plugins/ai-response-transformer/handler.lua
@@ -105,6 +105,7 @@ end


function _M:access(conf)
+  llm_state.set_request_model(conf.llm.model and conf.llm.model.name)
local kong_ctx_shared = kong.ctx.shared

kong.service.request.enable_buffering()
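The spec changes below exercise the new shared state end to end using Kong's `ctx-checker` and `ctx-checker-last` test fixture plugins. As configured here, `ctx-checker-last` runs after the AI plugins, reads `llm_model_requested` out of `kong.ctx.shared`, and echoes the value it observed in a `ctx-checker-last-llm-model-requested` response header, which the test assertions then read back from outside the gateway.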
68 changes: 61 additions & 7 deletions spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
@@ -67,7 +67,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
local client

lazy_setup(function()
-    local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME })
+    local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME, "ctx-checker-last", "ctx-checker" })

-- set up openai mock fixtures
local fixtures = {
@@ -274,6 +274,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
path = FILE_LOG_PATH_STATS_ONLY,
},
}
+      bp.plugins:insert {
+        name = "ctx-checker-last",
+        route = { id = chat_good.id },
+        config = {
+          ctx_kind = "kong.ctx.shared",
+          ctx_check_field = "llm_model_requested",
+          ctx_check_value = "gpt-3.5-turbo",
+        }
+      }

-- 200 chat good with one option
local chat_good_no_allow_override = assert(bp.routes:insert {
@@ -544,16 +553,16 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
}
--

-      -- 200 chat good but no model set
-      local chat_good = assert(bp.routes:insert {
+      -- 200 chat good but no model set in plugin config
+      local chat_good_no_model = assert(bp.routes:insert {
service = empty_service,
protocols = { "http" },
strip_path = true,
paths = { "/openai/llm/v1/chat/good-no-model-param" }
})
bp.plugins:insert {
name = PLUGIN_NAME,
-        route = { id = chat_good.id },
+        route = { id = chat_good_no_model.id },
config = {
route_type = "llm/v1/chat",
auth = {
@@ -572,11 +581,20 @@
}
bp.plugins:insert {
name = "file-log",
-        route = { id = chat_good.id },
+        route = { id = chat_good_no_model.id },
config = {
path = "/dev/stdout",
},
}
+      bp.plugins:insert {
+        name = "ctx-checker-last",
+        route = { id = chat_good_no_model.id },
+        config = {
+          ctx_kind = "kong.ctx.shared",
+          ctx_check_field = "llm_model_requested",
+          ctx_check_value = "try-to-override-the-model",
+        }
+      }
--

-- 200 completions good using post body key
@@ -755,7 +773,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
},
},
}
-      --
+

-- start kong
assert(helpers.start_kong({
-- use the custom test template to create a local mock server
nginx_conf = "spec/fixtures/custom_nginx.template",
-- make sure our plugin gets loaded
plugins = "bundled," .. PLUGIN_NAME,
plugins = "bundled,ctx-checker-last,ctx-checker," .. PLUGIN_NAME,
-- write & load declarative config, only if 'strategy=off'
declarative_config = strategy == "off" and helpers.make_yaml_file() or nil,
}, nil, nil, fixtures))
@@ -835,6 +853,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
assert.same(first_expected, first_got)
assert.is_true(actual_llm_latency >= 0)
assert.same(actual_time_per_token, time_per_token)
+        assert.same(first_got.meta.request_model, "gpt-3.5-turbo")
end)

it("does not log statistics", function()
@@ -1030,6 +1049,9 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
content = "The sum of 1 + 1 is 2.",
role = "assistant",
}, json.choices[1].message)

+        -- from ctx-checker-last plugin
+        assert.equals(r.headers["ctx-checker-last-llm-model-requested"], "gpt-3.5-turbo")
end)

it("good request, parses model of cjson.null", function()
@@ -1110,6 +1132,38 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
assert.is_truthy(json.error)
assert.equals(json.error.message, "request format not recognised")
end)

+      -- check that kong.ctx.shared.llm_model_requested is set
+      it("good request setting model from client body", function()
+        local r = client:get("/openai/llm/v1/chat/good-no-model-param", {
+          headers = {
+            ["content-type"] = "application/json",
+            ["accept"] = "application/json",
+          },
+          body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_own_model.json"),
+        })
+
+        -- validate that the request succeeded, response status 200
+        local body = assert.res_status(200, r)
+        local json = cjson.decode(body)
+
+        -- check this is in the 'kong' response format
+        assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
+        assert.equals(json.model, "gpt-3.5-turbo-0613")
+        assert.equals(json.object, "chat.completion")
+        assert.equals(r.headers["X-Kong-LLM-Model"], "openai/try-to-override-the-model")
+
+        assert.is_table(json.choices)
+        assert.is_table(json.choices[1].message)
+        assert.same({
+          content = "The sum of 1 + 1 is 2.",
+          role = "assistant",
+        }, json.choices[1].message)
+
+        -- from ctx-checker-last plugin
+        assert.equals(r.headers["ctx-checker-last-llm-model-requested"], "try-to-override-the-model")
+      end)
+
end)

describe("openai llm/v1/completions", function()