Skip to content

Commit

Permalink
feat(plugins): ai-transformer plugins (#12341)
Browse files Browse the repository at this point in the history
* feat(plugins): ai-transformer plugins
fix(ai-transformers): use correct http opts variables
fix(spec): ai-transformer plugin tests
fix(ai-transformer): PR comments

* Update kong/plugins/ai-response-transformer/schema.lua

Co-authored-by: Michael Martin <[email protected]>

* fix(azure-llm): missing api_version query param

* Update spec/03-plugins/38-ai-proxy/01-unit_spec.lua

Co-authored-by: Michael Martin <[email protected]>

---------

Co-authored-by: Michael Martin <[email protected]>
(cherry picked from commit 3ef9235)
  • Loading branch information
tysoekong authored and github-actions[bot] committed Jan 25, 2024
1 parent afae9d0 commit 35ae4ba
Show file tree
Hide file tree
Showing 34 changed files with 1,977 additions and 74 deletions.
8 changes: 8 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ plugins/ai-prompt-template:
- changed-files:
- any-glob-to-any-file: kong/plugins/ai-prompt-template/**/*

plugins/ai-request-transformer:
- changed-files:
- any-glob-to-any-file: ['kong/plugins/ai-request-transformer/**/*', 'kong/llm/**/*']

plugins/ai-response-transformer:
- changed-files:
- any-glob-to-any-file: ['kong/plugins/ai-response-transformer/**/*', 'kong/llm/**/*']

plugins/aws-lambda:
- changed-files:
- any-glob-to-any-file: kong/plugins/aws-lambda/**/*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message: Introduced the new **AI Request Transformer** plugin that enables passing mid-flight consumer requests to an LLM for transformation or sanitization.
type: feature
scope: Plugin
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message: Introduced the new **AI Response Transformer** plugin that enables passing mid-flight upstream responses to an LLM for transformation or sanitization.
type: feature
scope: Plugin
6 changes: 6 additions & 0 deletions kong-3.6.0-0.rockspec
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,12 @@ build = {
["kong.plugins.ai-proxy.handler"] = "kong/plugins/ai-proxy/handler.lua",
["kong.plugins.ai-proxy.schema"] = "kong/plugins/ai-proxy/schema.lua",

["kong.plugins.ai-request-transformer.handler"] = "kong/plugins/ai-request-transformer/handler.lua",
["kong.plugins.ai-request-transformer.schema"] = "kong/plugins/ai-request-transformer/schema.lua",

["kong.plugins.ai-response-transformer.handler"] = "kong/plugins/ai-response-transformer/handler.lua",
["kong.plugins.ai-response-transformer.schema"] = "kong/plugins/ai-response-transformer/schema.lua",

["kong.llm"] = "kong/llm/init.lua",
["kong.llm.drivers.shared"] = "kong/llm/drivers/shared.lua",
["kong.llm.drivers.openai"] = "kong/llm/drivers/openai.lua",
Expand Down
2 changes: 2 additions & 0 deletions kong/constants.lua
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ local plugins = {
"ai-proxy",
"ai-prompt-decorator",
"ai-prompt-template",
"ai-request-transformer",
"ai-response-transformer",
}

local plugin_map = {}
Expand Down
6 changes: 4 additions & 2 deletions kong/llm/drivers/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
error("body must be table or string")
end

local url = fmt(
-- may be overridden
local url = (conf.model.options and conf.model.options.upstream_url)
or fmt(
"%s%s",
ai_shared.upstream_url_format[DRIVER_NAME],
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
Expand Down Expand Up @@ -241,7 +243,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
local body = res.body

if status > 299 then
return body, res.status, "status code not 2xx"
return body, res.status, "status code " .. status
end

return body, res.status, nil
Expand Down
14 changes: 9 additions & 5 deletions kong/llm/drivers/azure.lua
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
end

-- azure has non-standard URL format
local url = fmt(
"%s%s",
local url = (conf.model.options and conf.model.options.upstream_url)
or fmt(
"%s%s?api-version=%s",
ai_shared.upstream_url_format[DRIVER_NAME]:format(conf.model.options.azure_instance, conf.model.options.azure_deployment_id),
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path,
conf.model.options.azure_api_version or "2023-05-15"
)

local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method
Expand Down Expand Up @@ -71,7 +73,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
local body = res.body

if status > 299 then
return body, res.status, "status code not 2xx"
return body, res.status, "status code " .. status
end

return body, res.status, nil
Expand Down Expand Up @@ -111,7 +113,9 @@ function _M.configure_request(conf)
end

local query_table = kong.request.get_query()
query_table["api-version"] = conf.model.options.azure_api_version

-- technically min supported version
query_table["api-version"] = conf.model.options and conf.model.options.azure_api_version or "2023-05-15"

if auth_param_name and auth_param_value and auth_param_location == "query" then
query_table[auth_param_name] = auth_param_value
Expand Down
53 changes: 4 additions & 49 deletions kong/llm/drivers/cohere.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ local cjson = require("cjson.safe")
local fmt = string.format
local ai_shared = require("kong.llm.drivers.shared")
local socket_url = require "socket.url"
local http = require("resty.http")
local table_new = require("table.new")
--

Expand Down Expand Up @@ -290,52 +289,6 @@ function _M.to_format(request_table, model_info, route_type)
return response_object, content_type, nil
end

function _M.subrequest(body_table, route_type, auth)
local body_string, err = cjson.encode(body_table)
if err then
return nil, nil, "failed to parse body to json: " .. err
end

local httpc = http.new()

local request_url = fmt(
"%s%s",
ai_shared.upstream_url_format[DRIVER_NAME],
ai_shared.operation_map[DRIVER_NAME][route_type].path
)

local headers = {
["Accept"] = "application/json",
["Content-Type"] = "application/json",
}

if auth and auth.header_name then
headers[auth.header_name] = auth.header_value
end

local res, err = httpc:request_uri(
request_url,
{
method = "POST",
body = body_string,
headers = headers,
})
if not res then
return nil, "request failed: " .. err
end

-- At this point, the entire request / response is complete and the connection
-- will be closed or back on the connection pool.
local status = res.status
local body = res.body

if status ~= 200 then
return body, "status code not 200"
end

return body, res.status, nil
end

function _M.header_filter_hooks(body)
-- nothing to parse in header_filter phase
end
Expand Down Expand Up @@ -372,7 +325,9 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
return nil, nil, "body must be table or string"
end

local url = fmt(
-- may be overridden
local url = (conf.model.options and conf.model.options.upstream_url)
or fmt(
"%s%s",
ai_shared.upstream_url_format[DRIVER_NAME],
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
Expand Down Expand Up @@ -403,7 +358,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
local body = res.body

if status > 299 then
return body, res.status, "status code not 2xx"
return body, res.status, "status code " .. status
end

return body, res.status, nil
Expand Down
4 changes: 2 additions & 2 deletions kong/llm/drivers/llama2.lua
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ function _M.to_format(request_table, model_info, route_type)
model_info
)
if err or (not ok) then
return nil, nil, fmt("error transforming to %s://%s", model_info.provider, route_type)
return nil, nil, fmt("error transforming to %s://%s/%s", model_info.provider, route_type, model_info.options.llama2_format)
end

return response_object, content_type, nil
Expand Down Expand Up @@ -231,7 +231,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
local body = res.body

if status > 299 then
return body, res.status, "status code not 2xx"
return body, res.status, "status code " .. status
end

return body, res.status, nil
Expand Down
6 changes: 3 additions & 3 deletions kong/llm/drivers/mistral.lua
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ function _M.subrequest(body, conf, http_opts, return_res_table)

local url = conf.model.options.upstream_url

local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method
local method = "POST"

local headers = {
["Accept"] = "application/json",
["Content-Type"] = "application/json",
["Content-Type"] = "application/json"
}

if conf.auth and conf.auth.header_name then
Expand All @@ -118,7 +118,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
local body = res.body

if status > 299 then
return body, res.status, "status code not 2xx"
return body, res.status, "status code " .. status
end

return body, res.status, nil
Expand Down
2 changes: 1 addition & 1 deletion kong/llm/drivers/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
local body = res.body

if status > 299 then
return body, res.status, "status code not 2xx"
return body, res.status, "status code " .. status
end

return body, res.status, nil
Expand Down
8 changes: 4 additions & 4 deletions kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -178,20 +178,20 @@ function _M.pre_request(conf, request_table)
end

-- if enabled AND request type is compatible, capture the input for analytics
if conf.logging.log_payloads then
if conf.logging and conf.logging.log_payloads then
kong.log.set_serialize_value(log_entry_keys.REQUEST_BODY, kong.request.get_raw_body())
end

return true, nil
end

function _M.post_request(conf, response_string)
if conf.logging.log_payloads then
if conf.logging and conf.logging.log_payloads then
kong.log.set_serialize_value(log_entry_keys.RESPONSE_BODY, response_string)
end

-- analytics and logging
if conf.logging.log_statistics then
if conf.logging and conf.logging.log_statistics then
-- check if we already have analytics in this context
local request_analytics = kong.ctx.shared.analytics

Expand Down Expand Up @@ -253,7 +253,7 @@ function _M.http_request(url, body, method, headers, http_opts)
method = method,
body = body,
headers = headers,
ssl_verify = http_opts.https_verify or true,
ssl_verify = http_opts.https_verify,
})
if not res then
return nil, "request failed: " .. err
Expand Down
22 changes: 15 additions & 7 deletions kong/llm/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ function _M:ai_introspect_body(request, system_prompt, http_opts, response_regex
local new_request_body = ai_response.choices
and #ai_response.choices > 0
and ai_response.choices[1]
and ai_response.choices[1].message
and ai_response.choices[1].message.content
if not new_request_body then
return nil, "no response choices received from upstream AI service"
Expand All @@ -327,16 +328,23 @@ function _M:ai_introspect_body(request, system_prompt, http_opts, response_regex
return new_request_body
end

function _M:parse_json_instructions(body_string)
local instructions, err = cjson.decode(body_string)
if err then
return nil, nil, nil, err
function _M:parse_json_instructions(in_body)
local err
if type(in_body) == "string" then
in_body, err = cjson.decode(in_body)
if err then
return nil, nil, nil, err
end
end

if type(in_body) ~= "table" then
return nil, nil, nil, "input not table or string"
end

return
instructions.headers,
instructions.body or body_string,
instructions.status or 200
in_body.headers,
in_body.body or in_body,
in_body.status or 200
end

function _M:new(conf, http_opts)
Expand Down
74 changes: 74 additions & 0 deletions kong/plugins/ai-request-transformer/handler.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
local _M = {}

-- imports
local kong_meta = require "kong.meta"
local fmt = string.format
local llm = require("kong.llm")
--

_M.PRIORITY = 777
_M.VERSION = kong_meta.version

local function bad_request(msg)
kong.log.info(msg)
return kong.response.exit(400, { error = { message = msg } })
end

local function internal_server_error(msg)
kong.log.err(msg)
return kong.response.exit(500, { error = { message = msg } })
end

local function create_http_opts(conf)
local http_opts = {}

if conf.http_proxy_host then -- port WILL be set via schema constraint
http_opts.proxy_opts = http_opts.proxy_opts or {}
http_opts.proxy_opts.http_proxy = fmt("http://%s:%d", conf.http_proxy_host, conf.http_proxy_port)
end

if conf.https_proxy_host then
http_opts.proxy_opts = http_opts.proxy_opts or {}
http_opts.proxy_opts.https_proxy = fmt("http://%s:%d", conf.https_proxy_host, conf.https_proxy_port)
end

http_opts.http_timeout = conf.http_timeout
http_opts.https_verify = conf.https_verify

return http_opts
end

function _M:access(conf)
kong.service.request.enable_buffering()
kong.ctx.shared.skip_response_transformer = true

-- first find the configured LLM interface and driver
local http_opts = create_http_opts(conf)
local ai_driver, err = llm:new(conf.llm, http_opts)

if not ai_driver then
return internal_server_error(err)
end

-- if asked, introspect the request before proxying
kong.log.debug("introspecting request with LLM")
local new_request_body, err = llm:ai_introspect_body(
kong.request.get_raw_body(),
conf.prompt,
http_opts,
conf.transformation_extract_pattern
)

if err then
return bad_request(err)
end

-- set the body for later plugins
kong.service.request.set_raw_body(new_request_body)

-- continue into other plugins including ai-response-transformer,
-- which may exit early with a sub-request
end


return _M
Loading

0 comments on commit 35ae4ba

Please sign in to comment.