Commit
refactor(plugins): move shared ctx usage of ai plugins to use a proper API

To make typos more obvious and easier to catch.
fffonion committed Aug 7, 2024
1 parent dd37fac commit 5a41d3a
Showing 11 changed files with 153 additions and 43 deletions.
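In short, the commit replaces ad-hoc reads and writes of kong.ctx.shared keys with named accessor functions on a new kong.llm.state module, so a misspelled name fails loudly (calling a nil value) instead of silently reading nil. A minimal sketch of the before/after pattern, using the streaming-mode flag from this diff (the surrounding code is illustrative only and assumes a Kong request context where kong.ctx is available):

local llm_state = require("kong.llm.state")

-- before: a typo in the key name silently writes/reads a different field and yields nil
kong.ctx.shared.ai_proxy_streaming_mode = true
local streaming = kong.ctx.shared.ai_proxy_streaming_mode

-- after: a typo such as llm_state.is_streming_mode() raises
-- "attempt to call a nil value" at the call site
llm_state.set_streaming_mode()
local streaming = llm_state.is_streaming_mode()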
2 changes: 2 additions & 0 deletions kong-3.8.0-0.rockspec
@@ -614,6 +614,8 @@ build = {
["kong.llm.drivers.anthropic"] = "kong/llm/drivers/anthropic.lua",
["kong.llm.drivers.mistral"] = "kong/llm/drivers/mistral.lua",
["kong.llm.drivers.llama2"] = "kong/llm/drivers/llama2.lua",
["kong.llm.state"] = "kong/llm/state.lua",

["kong.llm.drivers.gemini"] = "kong/llm/drivers/gemini.lua",
["kong.llm.drivers.bedrock"] = "kong/llm/drivers/bedrock.lua",

23 changes: 16 additions & 7 deletions kong/llm/drivers/shared.lua
@@ -1,12 +1,13 @@
local _M = {}

-- imports
local cjson = require("cjson.safe")
local http = require("resty.http")
local fmt = string.format
local os = os
local parse_url = require("socket.url").parse
local aws_stream = require("kong.tools.aws_stream")
local llm_state = require("kong.llm.state")
--

-- static
@@ -341,7 +350,7 @@ function _M.frame_to_events(frame, provider)
end -- if
end
end

return events
end

@@ -500,7 +509,7 @@ function _M.resolve_plugin_conf(kong_request, conf)
if #splitted ~= 2 then
return nil, "cannot parse expression for field '" .. v .. "'"
end

-- find the request parameter, with the configured name
prop_m, err = _M.conf_from_request(kong_request, splitted[1], splitted[2])
if err then
@@ -524,7 +533,7 @@ function _M.pre_request(conf, request_table)
local auth_param_name = conf.auth and conf.auth.param_name
local auth_param_value = conf.auth and conf.auth.param_value
local auth_param_location = conf.auth and conf.auth.param_location

if auth_param_name and auth_param_value and auth_param_location == "body" and request_table then
request_table[auth_param_name] = auth_param_value
end
@@ -547,7 +556,7 @@ function _M.pre_request(conf, request_table)
kong.log.warn("failed calculating cost for prompt tokens: ", err)
prompt_tokens = 0
end
kong.ctx.shared.ai_prompt_tokens = (kong.ctx.shared.ai_prompt_tokens or 0) + prompt_tokens
llm_state.increase_prompt_tokens_count(prompt_tokens)
end

local start_time_key = "ai_request_start_time_" .. plugin_name
@@ -586,7 +595,7 @@ function _M.post_request(conf, response_object)
end

-- check if we already have analytics in this context
local request_analytics = kong.ctx.shared.analytics
local request_analytics = llm_state.get_request_analytics()

-- create a new structure if not
if not request_analytics then
@@ -657,7 +666,7 @@ function _M.post_request(conf, response_object)
[log_entry_keys.RESPONSE_BODY] = body_string,
}
request_analytics[plugin_name] = request_analytics_plugin
kong.ctx.shared.analytics = request_analytics
llm_state.set_request_analytics(request_analytics)

if conf.logging and conf.logging.log_statistics then
-- Log meta data
@@ -679,7 +688,7 @@ function _M.post_request(conf, response_object)
kong.log.warn("failed calculating cost for response tokens: ", err)
response_tokens = 0
end
kong.ctx.shared.ai_response_tokens = (kong.ctx.shared.ai_response_tokens or 0) + response_tokens
llm_state.increase_response_tokens_count(response_tokens)

return nil
end
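The two token counters keep the accumulate-across-plugins behaviour of the kong.ctx.shared fields they replace. A rough usage sketch (illustrative only; assumes a Kong request context, and the token values are made up for the example):

local llm_state = require("kong.llm.state")

local prompt_tokens, response_tokens = 42, 128  -- example values; normally computed per request

-- old: kong.ctx.shared.ai_prompt_tokens = (kong.ctx.shared.ai_prompt_tokens or 0) + prompt_tokens
llm_state.increase_prompt_tokens_count(prompt_tokens)

-- old: kong.ctx.shared.ai_response_tokens = (kong.ctx.shared.ai_response_tokens or 0) + response_tokens
llm_state.increase_response_tokens_count(response_tokens)

-- later phases or other AI plugins can read the running totals
local total_prompt = llm_state.get_prompt_tokens_count()
local total_response = llm_state.get_response_tokens_count()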
46 changes: 21 additions & 25 deletions kong/llm/proxy/handler.lua
@@ -7,6 +7,7 @@

local ai_shared = require("kong.llm.drivers.shared")
local llm = require("kong.llm")
local llm_state = require("kong.llm.state")
local cjson = require("cjson.safe")
local kong_utils = require("kong.tools.gzip")
local buffer = require "string.buffer"
@@ -265,9 +266,8 @@ function _M:header_filter(conf)
kong.ctx.shared.ai_request_body = nil

local kong_ctx_plugin = kong.ctx.plugin
local kong_ctx_shared = kong.ctx.shared

if kong_ctx_shared.skip_response_transformer then
if llm_state.is_response_transformer_skipped() then
return
end

@@ -282,7 +282,7 @@ end
end

-- we use openai's streaming mode (SSE)
if kong_ctx_shared.ai_proxy_streaming_mode then
if llm_state.is_streaming_mode() then
-- we are going to send plaintext event-stream frames for ALL models
kong.response.set_header("Content-Type", "text/event-stream")
return
@@ -299,7 +299,7 @@ end
-- if this is a 'streaming' request, we can't know the final
-- result of the response body, so we just proceed to body_filter
-- to translate each SSE event frame
if not kong_ctx_shared.ai_proxy_streaming_mode then
if not llm_state.is_streaming_mode() then
local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
if is_gzip then
response_body = kong_utils.inflate_gzip(response_body)
@@ -329,22 +329,18 @@ end

function _M:body_filter(conf)
local kong_ctx_plugin = kong.ctx.plugin
local kong_ctx_shared = kong.ctx.shared

-- if body_filter is called twice, then return
if kong_ctx_plugin.body_called and not kong_ctx_shared.ai_proxy_streaming_mode then
if kong_ctx_plugin.body_called and not llm_state.is_streaming_mode() then
return
end

local route_type = conf.route_type

if kong_ctx_shared.skip_response_transformer and (route_type ~= "preserve") then
local response_body
if llm_state.is_response_transformer_skipped() and (route_type ~= "preserve") then
local response_body = llm_state.get_parsed_response()

if kong_ctx_shared.parsed_response then
response_body = kong_ctx_shared.parsed_response

elseif kong.response.get_status() == 200 then
if not response_body and kong.response.get_status() == 200 then
response_body = kong.service.response.get_raw_body()
if not response_body then
kong.log.warn("issue when retrieve the response body for analytics in the body filter phase.",
@@ -355,6 +351,8 @@ function _M:body_filter(conf)
response_body = kong_utils.inflate_gzip(response_body)
end
end
else
kong.response.exit(500, "no response body found")
end

local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
@@ -368,13 +366,13 @@ function _M:body_filter(conf)
end
end

if not kong_ctx_shared.skip_response_transformer then
if not llm_state.is_response_transformer_skipped() then
if (kong.response.get_status() ~= 200) and (not kong_ctx_plugin.ai_parser_error) then
return
end

if route_type ~= "preserve" then
if kong_ctx_shared.ai_proxy_streaming_mode then
if llm_state.is_streaming_mode() then
handle_streaming_frame(conf)
else
-- all errors MUST be checked and returned in header_filter
@@ -406,37 +404,35 @@ end

function _M:access(conf)
local kong_ctx_plugin = kong.ctx.plugin
local kong_ctx_shared = kong.ctx.shared

-- store the route_type in ctx for use in response parsing
local route_type = conf.route_type

kong_ctx_plugin.operation = route_type

local request_table
local multipart = false

-- TODO: the access phase may be called multiple times also in the balancer phase
-- Refactor this function a bit so that we don't mess them in the same function
local balancer_phase = ngx.get_phase() == "balancer"

-- we may have received a replacement / decorated request body from another AI plugin
if kong_ctx_shared.replacement_request then
local request_table = llm_state.get_replacement_response() -- not used
if request_table then
kong.log.debug("replacement request body received from another AI plugin")
request_table = kong_ctx_shared.replacement_request

else
-- first, calculate the coordinates of the request
local content_type = kong.request.get_header("Content-Type") or "application/json"

request_table = kong_ctx_shared.ai_request_body
request_table = llm_state.get_request_body_table()
if not request_table then
if balancer_phase then
error("Too late to read body", 2)
end

request_table = kong.request.get_body(content_type, nil, conf.max_request_body_size)
kong_ctx_shared.ai_request_body = request_table
llm_state.set_request_body_table(request_table)
end

if not request_table then
@@ -481,7 +477,7 @@ function _M:access(conf)
if not multipart then
local compatible, err = llm.is_compatible(request_table, route_type)
if not compatible then
kong_ctx_shared.skip_response_transformer = true
llm_state.set_response_transformer_skipped()
return bail(400, err)
end
end
@@ -511,7 +507,7 @@ function _M:access(conf)
end

-- specific actions need to skip later for this to work
kong_ctx_shared.ai_proxy_streaming_mode = true
llm_state.set_streaming_mode()

else
kong.service.request.enable_buffering()
@@ -531,7 +527,7 @@ function _M:access(conf)
-- transform the body to Kong-format for this provider/model
parsed_request_body, content_type, err = ai_driver.to_format(request_table, conf_m.model, route_type)
if err then
kong_ctx_shared.skip_response_transformer = true
llm_state.set_response_transformer_skipped()
return bail(400, err)
end
end
@@ -549,7 +545,7 @@ function _M:access(conf)
-- get the provider's cached identity interface - nil may come back, which is fine
local identity_interface = _KEYBASTION[conf]
if identity_interface and identity_interface.error then
kong.ctx.shared.skip_response_transformer = true
llm_state.set_response_transformer_skipped()
kong.log.err("error authenticating with cloud-provider, ", identity_interface.error)
return bail(500, "LLM request failed before proxying")
end
@@ -558,7 +554,7 @@ function _M:access(conf)
local ok, err = ai_driver.configure_request(conf_m,
identity_interface and identity_interface.interface)
if not ok then
kong_ctx_shared.skip_response_transformer = true
llm_state.set_response_transformer_skipped()
kong.log.err("failed to configure request for AI service: ", err)
return bail(500)
end
95 changes: 95 additions & 0 deletions kong/llm/state.lua
@@ -0,0 +1,95 @@
local _M = {}

function _M.disable_ai_proxy_response_transform()
kong.ctx.shared.llm_disable_ai_proxy_response_transform = true
end

function _M.should_disable_ai_proxy_response_transform()
return kong.ctx.shared.llm_disable_ai_proxy_response_transform == true
end

function _M.set_prompt_decorated()
kong.ctx.shared.llm_prompt_decorated = true
end

function _M.is_prompt_decorated()
return kong.ctx.shared.llm_prompt_decorated == true
end

function _M.set_prompt_guarded()
kong.ctx.shared.llm_prompt_guarded = true
end

function _M.is_prompt_guarded()
return kong.ctx.shared.llm_prompt_guarded == true
end

function _M.set_prompt_templated()
kong.ctx.shared.llm_prompt_templated = true
end

function _M.is_prompt_templated()
return kong.ctx.shared.llm_prompt_templated == true
end

function _M.set_streaming_mode()
kong.ctx.shared.llm_streaming_mode = true
end

function _M.is_streaming_mode()
return kong.ctx.shared.llm_streaming_mode == true
end

function _M.set_parsed_response(response)
kong.ctx.shared.llm_parsed_response = response
end

function _M.get_parsed_response()
return kong.ctx.shared.llm_parsed_response
end

function _M.set_request_body_table(body_t)
kong.ctx.shared.llm_request_body_t = body_t
end

function _M.get_request_body_table()
return kong.ctx.shared.llm_request_body_t
end

function _M.set_replacement_response(response)
kong.ctx.shared.llm_replacement_response = response
end

function _M.get_replacement_response()
return kong.ctx.shared.llm_replacement_response
end

function _M.set_request_analytics(tbl)
kong.ctx.shared.llm_request_analytics = tbl
end

function _M.get_request_analytics()
return kong.ctx.shared.llm_request_analytics
end

function _M.increase_prompt_tokens_count(by)
local count = (kong.ctx.shared.llm_prompt_tokens_count or 0) + by
kong.ctx.shared.llm_prompt_tokens_count = count
return count
end

function _M.get_prompt_tokens_count()
return kong.ctx.shared.llm_prompt_tokens_count
end

function _M.increase_response_tokens_count(by)
local count = (kong.ctx.shared.llm_response_tokens_count or 0) + by
kong.ctx.shared.llm_response_tokens_count = count
return count
end

function _M.get_response_tokens_count()
return kong.ctx.shared.llm_response_tokens_count
end

return _M
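As a rough illustration of how the plugins in this commit consume the module, a condensed access/log handler might look like the following. The plugin table, priority, and version are hypothetical placeholders; only the llm_state calls and the Kong PDK calls mirror the handlers changed above:

local llm_state = require("kong.llm.state")

-- hypothetical plugin skeleton; only the llm_state usage reflects this commit
local ExamplePlugin = { PRIORITY = 770, VERSION = "0.1.0" }

function ExamplePlugin:access(conf)
  kong.service.request.enable_buffering()
  llm_state.set_prompt_decorated()  -- flag read by other AI plugins in later phases

  -- cache the parsed request body once, so later phases don't re-read it
  local body = kong.request.get_body("application/json", nil, conf.max_request_body_size)
  llm_state.set_request_body_table(body)
end

function ExamplePlugin:log(conf)
  kong.log.info("prompt tokens so far: ", llm_state.get_prompt_tokens_count() or 0)
end

return ExamplePlugin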
3 changes: 2 additions & 1 deletion kong/plugins/ai-prompt-decorator/handler.lua
@@ -1,4 +1,5 @@
local new_tab = require("table.new")
local llm_state = require("kong.llm.state")
local EMPTY = {}


@@ -52,7 +53,7 @@ end

function plugin:access(conf)
kong.service.request.enable_buffering()
kong.ctx.shared.ai_prompt_decorated = true -- future use
llm_state.set_prompt_decorated() -- future use

-- if plugin ordering was altered, receive the "decorated" request
local request = kong.request.get_body("application/json", nil, conf.max_request_body_size)
3 changes: 2 additions & 1 deletion kong/plugins/ai-prompt-guard/handler.lua
@@ -1,4 +1,5 @@
local buffer = require("string.buffer")
local llm_state = require("kong.llm.state")
local ngx_re_find = ngx.re.find
local EMPTY = {}

@@ -116,7 +117,7 @@ end

function plugin:access(conf)
kong.service.request.enable_buffering()
kong.ctx.shared.ai_prompt_guarded = true -- future use
llm_state.set_prompt_guarded() -- future use

-- if plugin ordering was altered, receive the "decorated" request
local request = kong.request.get_body("application/json", nil, conf.max_request_body_size)