Skip to content

Commit

Permalink
fix(ai): llm model moved to share state to guarantee consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
tysoekong committed Sep 6, 2024
1 parent b861d1f commit 54cd557
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 9 deletions.
4 changes: 4 additions & 0 deletions changelog/unreleased/kong/fix-ai-semantic-cache-model.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
message: "Fixed an bug that AI semantic cache can't use request provided models"
type: bugfix
scope: Plugin

2 changes: 1 addition & 1 deletion kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ function _M.post_request(conf, response_object)
-- Set the model, response, and provider names in the current try context
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PLUGIN_ID] = conf.__plugin_id
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PROVIDER_NAME] = provider_name
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = kong.ctx.plugin.llm_model_requested or conf.model.name
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = llm_state.get_request_model() or conf.model.name
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.RESPONSE_MODEL] = response_object.model or conf.model.name

-- Set the llm latency meta, and time per token usage
Expand Down
4 changes: 2 additions & 2 deletions kong/llm/proxy/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ function _M:header_filter(conf)
end

if ngx.var.http_kong_debug or conf.model_name_header then
local name = conf.model.provider .. "/" .. (kong.ctx.plugin.llm_model_requested or conf.model.name)
local name = conf.model.provider .. "/" .. (llm_state.get_request_model() or conf.model.name)
kong.response.set_header("X-Kong-LLM-Model", name)
end

Expand Down Expand Up @@ -386,7 +386,7 @@ function _M:access(conf)
return bail(400, "model parameter not found in request, nor in gateway configuration")
end

kong_ctx_plugin.llm_model_requested = conf_m.model.name
llm_state.set_request_model(conf_m.model.name)

-- check the incoming format is the same as the configured LLM format
local compatible, err = llm.is_compatible(request_table, route_type)
Expand Down
8 changes: 8 additions & 0 deletions kong/llm/state.lua
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,12 @@ function _M.get_metrics(key)
return (kong.ctx.shared.llm_metrics or {})[key]
end

function _M.set_request_model(model)
kong.ctx.shared.llm_model_requested = model
end

function _M.get_request_model()
return kong.ctx.shared.llm_model_requested or "NOT_SPECIFIED"
end

return _M
49 changes: 43 additions & 6 deletions spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
local client

lazy_setup(function()
local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME })
local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME, "ctx-checker-last" })

-- set up openai mock fixtures
local fixtures = {
Expand Down Expand Up @@ -544,16 +544,16 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
}
--

-- 200 chat good but no model set
local chat_good = assert(bp.routes:insert {
-- 200 chat good but no model set in plugin config
local chat_good_no_model = assert(bp.routes:insert {
service = empty_service,
protocols = { "http" },
strip_path = true,
paths = { "/openai/llm/v1/chat/good-no-model-param" }
})
bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = chat_good.id },
route = { id = chat_good_no_model.id },
config = {
route_type = "llm/v1/chat",
auth = {
Expand All @@ -572,11 +572,18 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
}
bp.plugins:insert {
name = "file-log",
route = { id = chat_good.id },
route = { id = chat_good_no_model.id },
config = {
path = "/dev/stdout",
},
}
bp.plugins:insert {
name = "ctx-checker-last",
route = { id = chat_good_no_model.id },
config = {
ctx_check_field = "kong.ctx.shared.llm_model_requested",
}
}
--

-- 200 completions good using post body key
Expand Down Expand Up @@ -764,7 +771,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
-- use the custom test template to create a local mock server
nginx_conf = "spec/fixtures/custom_nginx.template",
-- make sure our plugin gets loaded
plugins = "bundled," .. PLUGIN_NAME,
plugins = "bundled," .. PLUGIN_NAME .. ",ctx-checker-last",
-- write & load declarative config, only if 'strategy=off'
declarative_config = strategy == "off" and helpers.make_yaml_file() or nil,
}, nil, nil, fixtures))
Expand Down Expand Up @@ -835,6 +842,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
assert.same(first_expected, first_got)
assert.is_true(actual_llm_latency >= 0)
assert.same(actual_time_per_token, time_per_token)
assert.same(first_got.meta.request_model, "gpt-3.5-turbo")
end)

it("does not log statistics", function()
Expand Down Expand Up @@ -1110,6 +1118,35 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
assert.is_truthy(json.error)
assert.equals(json.error.message, "request format not recognised")
end)

-- check that kong.ctx.shared.llm_model_requested is set
it("good request setting model from client body", function()
local r = client:get("/openai/llm/v1/chat/good-no-model-param", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
},
body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_own_model.json"),
})

-- validate that the request succeeded, response status 200
local body = assert.res_status(200 , r)
local json = cjson.decode(body)

-- check this is in the 'kong' response format
assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
assert.equals(json.model, "gpt-3.5-turbo-0613")
assert.equals(json.object, "chat.completion")
assert.equals(r.headers["X-Kong-LLM-Model"], "openai/try-to-override-the-model")

assert.is_table(json.choices)
assert.is_table(json.choices[1].message)
assert.same({
content = "The sum of 1 + 1 is 2.",
role = "assistant",
}, json.choices[1].message)
end)

end)

describe("openai llm/v1/completions", function()
Expand Down

0 comments on commit 54cd557

Please sign in to comment.