diff --git a/.github/labeler.yml b/.github/labeler.yml index 38a50436f354..d40e0799a351 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -102,6 +102,14 @@ plugins/ai-prompt-template: - changed-files: - any-glob-to-any-file: kong/plugins/ai-prompt-template/**/* +plugins/ai-request-transformer: +- changed-files: + - any-glob-to-any-file: ['kong/plugins/ai-request-transformer/**/*', 'kong/llm/**/*'] + +plugins/ai-response-transformer: +- changed-files: + - any-glob-to-any-file: ['kong/plugins/ai-response-transformer/**/*', 'kong/llm/**/*'] + plugins/aws-lambda: - changed-files: - any-glob-to-any-file: kong/plugins/aws-lambda/**/* diff --git a/changelog/unreleased/kong/add-ai-request-transformer-plugin.yml b/changelog/unreleased/kong/add-ai-request-transformer-plugin.yml new file mode 100644 index 000000000000..2a54c5d548df --- /dev/null +++ b/changelog/unreleased/kong/add-ai-request-transformer-plugin.yml @@ -0,0 +1,3 @@ +message: Introduced the new **AI Request Transformer** plugin that enables passing mid-flight consumer requests to an LLM for transformation or sanitization. +type: feature +scope: Plugin diff --git a/changelog/unreleased/kong/add-ai-response-transformer-plugin.yml b/changelog/unreleased/kong/add-ai-response-transformer-plugin.yml new file mode 100644 index 000000000000..0b7f5742de42 --- /dev/null +++ b/changelog/unreleased/kong/add-ai-response-transformer-plugin.yml @@ -0,0 +1,3 @@ +message: Introduced the new **AI Response Transformer** plugin that enables passing mid-flight upstream responses to an LLM for transformation or sanitization. +type: feature +scope: Plugin diff --git a/kong-3.6.0-0.rockspec b/kong-3.6.0-0.rockspec index 8bfc5c08b164..c06a24019e35 100644 --- a/kong-3.6.0-0.rockspec +++ b/kong-3.6.0-0.rockspec @@ -565,6 +565,12 @@ build = { ["kong.plugins.ai-proxy.handler"] = "kong/plugins/ai-proxy/handler.lua", ["kong.plugins.ai-proxy.schema"] = "kong/plugins/ai-proxy/schema.lua", + ["kong.plugins.ai-request-transformer.handler"] = "kong/plugins/ai-request-transformer/handler.lua", + ["kong.plugins.ai-request-transformer.schema"] = "kong/plugins/ai-request-transformer/schema.lua", + + ["kong.plugins.ai-response-transformer.handler"] = "kong/plugins/ai-response-transformer/handler.lua", + ["kong.plugins.ai-response-transformer.schema"] = "kong/plugins/ai-response-transformer/schema.lua", + ["kong.llm"] = "kong/llm/init.lua", ["kong.llm.drivers.shared"] = "kong/llm/drivers/shared.lua", ["kong.llm.drivers.openai"] = "kong/llm/drivers/openai.lua", diff --git a/kong/constants.lua b/kong/constants.lua index 8dedd3145553..251637350167 100644 --- a/kong/constants.lua +++ b/kong/constants.lua @@ -39,6 +39,8 @@ local plugins = { "ai-proxy", "ai-prompt-decorator", "ai-prompt-template", + "ai-request-transformer", + "ai-response-transformer", } local plugin_map = {} diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index 668e035d5715..811eb638722a 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -209,7 +209,9 @@ function _M.subrequest(body, conf, http_opts, return_res_table) error("body must be table or string") end - local url = fmt( + -- may be overridden + local url = (conf.model.options and conf.model.options.upstream_url) + or fmt( "%s%s", ai_shared.upstream_url_format[DRIVER_NAME], ai_shared.operation_map[DRIVER_NAME][conf.route_type].path @@ -241,7 +243,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table) local body = res.body if status > 299 then - return body, res.status, 
"status code not 2xx" + return body, res.status, "status code " .. status end return body, res.status, nil diff --git a/kong/llm/drivers/azure.lua b/kong/llm/drivers/azure.lua index 684dce7afab7..9207b20a54d7 100644 --- a/kong/llm/drivers/azure.lua +++ b/kong/llm/drivers/azure.lua @@ -40,10 +40,12 @@ function _M.subrequest(body, conf, http_opts, return_res_table) end -- azure has non-standard URL format - local url = fmt( - "%s%s", + local url = (conf.model.options and conf.model.options.upstream_url) + or fmt( + "%s%s?api-version=%s", ai_shared.upstream_url_format[DRIVER_NAME]:format(conf.model.options.azure_instance, conf.model.options.azure_deployment_id), - ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + ai_shared.operation_map[DRIVER_NAME][conf.route_type].path, + conf.model.options.azure_api_version or "2023-05-15" ) local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method @@ -71,7 +73,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table) local body = res.body if status > 299 then - return body, res.status, "status code not 2xx" + return body, res.status, "status code " .. status end return body, res.status, nil @@ -111,7 +113,9 @@ function _M.configure_request(conf) end local query_table = kong.request.get_query() - query_table["api-version"] = conf.model.options.azure_api_version + + -- technically min supported version + query_table["api-version"] = conf.model.options and conf.model.options.azure_api_version or "2023-05-15" if auth_param_name and auth_param_value and auth_param_location == "query" then query_table[auth_param_name] = auth_param_value diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index 87b8a87d309d..46bde9bc3e1a 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -5,7 +5,6 @@ local cjson = require("cjson.safe") local fmt = string.format local ai_shared = require("kong.llm.drivers.shared") local socket_url = require "socket.url" -local http = require("resty.http") local table_new = require("table.new") -- @@ -290,52 +289,6 @@ function _M.to_format(request_table, model_info, route_type) return response_object, content_type, nil end -function _M.subrequest(body_table, route_type, auth) - local body_string, err = cjson.encode(body_table) - if err then - return nil, nil, "failed to parse body to json: " .. err - end - - local httpc = http.new() - - local request_url = fmt( - "%s%s", - ai_shared.upstream_url_format[DRIVER_NAME], - ai_shared.operation_map[DRIVER_NAME][route_type].path - ) - - local headers = { - ["Accept"] = "application/json", - ["Content-Type"] = "application/json", - } - - if auth and auth.header_name then - headers[auth.header_name] = auth.header_value - end - - local res, err = httpc:request_uri( - request_url, - { - method = "POST", - body = body_string, - headers = headers, - }) - if not res then - return nil, "request failed: " .. err - end - - -- At this point, the entire request / response is complete and the connection - -- will be closed or back on the connection pool. 
- local status = res.status - local body = res.body - - if status ~= 200 then - return body, "status code not 200" - end - - return body, res.status, nil -end - function _M.header_filter_hooks(body) -- nothing to parse in header_filter phase end @@ -372,7 +325,9 @@ function _M.subrequest(body, conf, http_opts, return_res_table) return nil, nil, "body must be table or string" end - local url = fmt( + -- may be overridden + local url = (conf.model.options and conf.model.options.upstream_url) + or fmt( "%s%s", ai_shared.upstream_url_format[DRIVER_NAME], ai_shared.operation_map[DRIVER_NAME][conf.route_type].path @@ -403,7 +358,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table) local body = res.body if status > 299 then - return body, res.status, "status code not 2xx" + return body, res.status, "status code " .. status end return body, res.status, nil diff --git a/kong/llm/drivers/llama2.lua b/kong/llm/drivers/llama2.lua index d4da6d7be0f8..7e965e2c5530 100644 --- a/kong/llm/drivers/llama2.lua +++ b/kong/llm/drivers/llama2.lua @@ -183,7 +183,7 @@ function _M.to_format(request_table, model_info, route_type) model_info ) if err or (not ok) then - return nil, nil, fmt("error transforming to %s://%s", model_info.provider, route_type) + return nil, nil, fmt("error transforming to %s://%s/%s", model_info.provider, route_type, model_info.options.llama2_format) end return response_object, content_type, nil @@ -231,7 +231,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table) local body = res.body if status > 299 then - return body, res.status, "status code not 2xx" + return body, res.status, "status code " .. status end return body, res.status, nil diff --git a/kong/llm/drivers/mistral.lua b/kong/llm/drivers/mistral.lua index ba7dd94d1e24..84f961782955 100644 --- a/kong/llm/drivers/mistral.lua +++ b/kong/llm/drivers/mistral.lua @@ -93,11 +93,11 @@ function _M.subrequest(body, conf, http_opts, return_res_table) local url = conf.model.options.upstream_url - local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method + local method = "POST" local headers = { ["Accept"] = "application/json", - ["Content-Type"] = "application/json", + ["Content-Type"] = "application/json" } if conf.auth and conf.auth.header_name then @@ -118,7 +118,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table) local body = res.body if status > 299 then - return body, res.status, "status code not 2xx" + return body, res.status, "status code " .. status end return body, res.status, nil diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index 8983c46a7b00..5d6120552367 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -169,7 +169,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table) local body = res.body if status > 299 then - return body, res.status, "status code not 2xx" + return body, res.status, "status code " .. 
status end return body, res.status, nil diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index ab244d9fda2d..dcc996c80857 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -178,7 +178,7 @@ function _M.pre_request(conf, request_table) end -- if enabled AND request type is compatible, capture the input for analytics - if conf.logging.log_payloads then + if conf.logging and conf.logging.log_payloads then kong.log.set_serialize_value(log_entry_keys.REQUEST_BODY, kong.request.get_raw_body()) end @@ -186,12 +186,12 @@ function _M.pre_request(conf, request_table) end function _M.post_request(conf, response_string) - if conf.logging.log_payloads then + if conf.logging and conf.logging.log_payloads then kong.log.set_serialize_value(log_entry_keys.RESPONSE_BODY, response_string) end -- analytics and logging - if conf.logging.log_statistics then + if conf.logging and conf.logging.log_statistics then -- check if we already have analytics in this context local request_analytics = kong.ctx.shared.analytics @@ -253,7 +253,7 @@ function _M.http_request(url, body, method, headers, http_opts) method = method, body = body, headers = headers, - ssl_verify = http_opts.https_verify or true, + ssl_verify = http_opts.https_verify, }) if not res then return nil, "request failed: " .. err diff --git a/kong/llm/init.lua b/kong/llm/init.lua index c5c73ae8bdb3..489af760ccea 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -301,6 +301,7 @@ function _M:ai_introspect_body(request, system_prompt, http_opts, response_regex local new_request_body = ai_response.choices and #ai_response.choices > 0 and ai_response.choices[1] + and ai_response.choices[1].message and ai_response.choices[1].message.content if not new_request_body then return nil, "no response choices received from upstream AI service" @@ -327,16 +328,23 @@ function _M:ai_introspect_body(request, system_prompt, http_opts, response_regex return new_request_body end -function _M:parse_json_instructions(body_string) - local instructions, err = cjson.decode(body_string) - if err then - return nil, nil, nil, err +function _M:parse_json_instructions(in_body) + local err + if type(in_body) == "string" then + in_body, err = cjson.decode(in_body) + if err then + return nil, nil, nil, err + end + end + + if type(in_body) ~= "table" then + return nil, nil, nil, "input not table or string" end return - instructions.headers, - instructions.body or body_string, - instructions.status or 200 + in_body.headers, + in_body.body or in_body, + in_body.status or 200 end function _M:new(conf, http_opts) diff --git a/kong/plugins/ai-request-transformer/handler.lua b/kong/plugins/ai-request-transformer/handler.lua new file mode 100644 index 000000000000..7efd0e0c72ef --- /dev/null +++ b/kong/plugins/ai-request-transformer/handler.lua @@ -0,0 +1,74 @@ +local _M = {} + +-- imports +local kong_meta = require "kong.meta" +local fmt = string.format +local llm = require("kong.llm") +-- + +_M.PRIORITY = 777 +_M.VERSION = kong_meta.version + +local function bad_request(msg) + kong.log.info(msg) + return kong.response.exit(400, { error = { message = msg } }) +end + +local function internal_server_error(msg) + kong.log.err(msg) + return kong.response.exit(500, { error = { message = msg } }) +end + +local function create_http_opts(conf) + local http_opts = {} + + if conf.http_proxy_host then -- port WILL be set via schema constraint + http_opts.proxy_opts = http_opts.proxy_opts or {} + http_opts.proxy_opts.http_proxy = fmt("http://%s:%d", 
conf.http_proxy_host, conf.http_proxy_port) + end + + if conf.https_proxy_host then + http_opts.proxy_opts = http_opts.proxy_opts or {} + http_opts.proxy_opts.https_proxy = fmt("http://%s:%d", conf.https_proxy_host, conf.https_proxy_port) + end + + http_opts.http_timeout = conf.http_timeout + http_opts.https_verify = conf.https_verify + + return http_opts +end + +function _M:access(conf) + kong.service.request.enable_buffering() + kong.ctx.shared.skip_response_transformer = true + + -- first find the configured LLM interface and driver + local http_opts = create_http_opts(conf) + local ai_driver, err = llm:new(conf.llm, http_opts) + + if not ai_driver then + return internal_server_error(err) + end + + -- if asked, introspect the request before proxying + kong.log.debug("introspecting request with LLM") + local new_request_body, err = llm:ai_introspect_body( + kong.request.get_raw_body(), + conf.prompt, + http_opts, + conf.transformation_extract_pattern + ) + + if err then + return bad_request(err) + end + + -- set the body for later plugins + kong.service.request.set_raw_body(new_request_body) + + -- continue into other plugins including ai-response-transformer, + -- which may exit early with a sub-request +end + + +return _M diff --git a/kong/plugins/ai-request-transformer/schema.lua b/kong/plugins/ai-request-transformer/schema.lua new file mode 100644 index 000000000000..c7ce498ba68e --- /dev/null +++ b/kong/plugins/ai-request-transformer/schema.lua @@ -0,0 +1,68 @@ +local typedefs = require("kong.db.schema.typedefs") +local llm = require("kong.llm") + + + +return { + name = "ai-request-transformer", + fields = { + { protocols = typedefs.protocols_http }, + { consumer = typedefs.no_consumer }, + { config = { + type = "record", + fields = { + { prompt = { + description = "Use this prompt to tune the LLM system/assistant message for the incoming " + .. "proxy request (from the client), and what you are expecting in return.", + type = "string", + required = true, + }}, + { transformation_extract_pattern = { + description = "Defines the regular expression that must match to indicate a successful AI transformation " + .. "at the request phase. The first match will be set as the outgoing body. " + .. 
"If the AI service's response doesn't match this pattern, it is marked as a failure.", + type = "string", + required = false, + }}, + { http_timeout = { + description = "Timeout in milliseconds for the AI upstream service.", + type = "integer", + required = true, + default = 60000, + }}, + { https_verify = { + description = "Verify the TLS certificate of the AI upstream service.", + type = "boolean", + required = true, + default = true, + }}, + + -- from forward-proxy + { http_proxy_host = typedefs.host }, + { http_proxy_port = typedefs.port }, + { https_proxy_host = typedefs.host }, + { https_proxy_port = typedefs.port }, + + { llm = llm.config_schema }, + }, + }}, + + }, + entity_checks = { + { + conditional = { + if_field = "config.llm.route_type", + if_match = { + not_one_of = { + "llm/v1/chat", + } + }, + then_field = "config.llm.route_type", + then_match = { eq = "llm/v1/chat" }, + then_err = "'config.llm.route_type' must be 'llm/v1/chat' for AI transformer plugins", + }, + }, + { mutually_required = { "config.http_proxy_host", "config.http_proxy_port" } }, + { mutually_required = { "config.https_proxy_host", "config.https_proxy_port" } }, + }, +} diff --git a/kong/plugins/ai-response-transformer/handler.lua b/kong/plugins/ai-response-transformer/handler.lua new file mode 100644 index 000000000000..b5cde6fc0daa --- /dev/null +++ b/kong/plugins/ai-response-transformer/handler.lua @@ -0,0 +1,165 @@ +local _M = {} + +-- imports +local kong_meta = require "kong.meta" +local http = require("resty.http") +local fmt = string.format +local kong_utils = require("kong.tools.utils") +local llm = require("kong.llm") +-- + +_M.PRIORITY = 769 +_M.VERSION = kong_meta.version + +local function bad_request(msg) + kong.log.info(msg) + return kong.response.exit(400, { error = { message = msg } }) +end + +local function internal_server_error(msg) + kong.log.err(msg) + return kong.response.exit(500, { error = { message = msg } }) +end + +local function subrequest(httpc, request_body, http_opts) + httpc:set_timeouts(http_opts.http_timeout or 60000) + + local upstream_uri = ngx.var.upstream_uri + if ngx.var.is_args == "?" or string.sub(ngx.var.request_uri, -1) == "?" then + ngx.var.upstream_uri = upstream_uri .. "?" .. (ngx.var.args or "") + end + + local ok, err = httpc:connect { + scheme = ngx.var.upstream_scheme, + host = ngx.ctx.balancer_data.host, + port = ngx.ctx.balancer_data.port, + proxy_opts = http_opts.proxy_opts, + ssl_verify = http_opts.https_verify, + ssl_server_name = ngx.ctx.balancer_data.host, + } + + if not ok then + return nil, "failed to connect to upstream: " .. err + end + + local headers = kong.request.get_headers() + headers["transfer-encoding"] = nil -- transfer-encoding is hop-by-hop, strip + -- it out + headers["content-length"] = nil -- clear content-length - it will be set + -- later on by resty-http (if not found); + -- further, if we leave it here it will + -- cause issues if the value varies (if may + -- happen, e.g., due to a different transfer + -- encoding being used subsequently) + + if ngx.var.upstream_host == "" then + headers["host"] = nil + else + headers["host"] = ngx.var.upstream_host + end + + local res, err = httpc:request({ + method = kong.request.get_method(), + path = ngx.var.upstream_uri, + headers = headers, + body = request_body, + }) + + if not res then + return nil, "subrequest failed: " .. 
err + end + + return res +end + +local function create_http_opts(conf) + local http_opts = {} + + if conf.http_proxy_host then -- port WILL be set via schema constraint + http_opts.proxy_opts = http_opts.proxy_opts or {} + http_opts.proxy_opts.http_proxy = fmt("http://%s:%d", conf.http_proxy_host, conf.http_proxy_port) + end + + if conf.https_proxy_host then + http_opts.proxy_opts = http_opts.proxy_opts or {} + http_opts.proxy_opts.https_proxy = fmt("http://%s:%d", conf.https_proxy_host, conf.https_proxy_port) + end + + http_opts.http_timeout = conf.http_timeout + http_opts.https_verify = conf.https_verify + + return http_opts +end + +function _M:access(conf) + kong.service.request.enable_buffering() + kong.ctx.shared.skip_response_transformer = true + + -- first find the configured LLM interface and driver + local http_opts = create_http_opts(conf) + local ai_driver, err = llm:new(conf.llm, http_opts) + + if not ai_driver then + return internal_server_error(err) + end + + kong.log.debug("intercepting plugin flow with one-shot request") + local httpc = http.new() + local res, err = subrequest(httpc, kong.request.get_raw_body(), http_opts) + if err then + return internal_server_error(err) + end + + local res_body = res:read_body() + local is_gzip = res.headers["Content-Encoding"] == "gzip" + if is_gzip then + res_body = kong_utils.inflate_gzip(res_body) + end + + -- if asked, introspect the request before proxying + kong.log.debug("introspecting response with LLM") + + local new_response_body, err = llm:ai_introspect_body( + res_body, + conf.prompt, + http_opts, + conf.transformation_extract_pattern + ) + + if err then + return bad_request(err) + end + + if res.headers then + res.headers["content-length"] = nil + res.headers["content-encoding"] = nil + res.headers["transfer-encoding"] = nil + end + + local headers, body, status + if conf.parse_llm_response_json_instructions then + headers, body, status, err = llm:parse_json_instructions(new_response_body) + if err then + return internal_server_error("failed to parse JSON response instructions from AI backend: " .. err) + end + + if headers then + for k, v in pairs(headers) do + res.headers[k] = v -- override e.g. ['content-type'] + end + end + + headers = res.headers + else + + headers = res.headers -- headers from upstream + body = new_response_body -- replacement body from AI + status = res.status -- status from upstream + end + + return kong.response.exit(status, body, headers) + +end + + +return _M diff --git a/kong/plugins/ai-response-transformer/schema.lua b/kong/plugins/ai-response-transformer/schema.lua new file mode 100644 index 000000000000..c4eb6fe25ac1 --- /dev/null +++ b/kong/plugins/ai-response-transformer/schema.lua @@ -0,0 +1,76 @@ +local typedefs = require("kong.db.schema.typedefs") +local llm = require("kong.llm") + + + +return { + name = "ai-response-transformer", + fields = { + { protocols = typedefs.protocols_http }, + { consumer = typedefs.no_consumer }, + { config = { + type = "record", + fields = { + { prompt = { + description = "Use this prompt to tune the LLM system/assistant message for the returning " + .. "proxy response (from the upstream), and what response format you are expecting.", + type = "string", + required = true, + }}, + { transformation_extract_pattern = { + description = "Defines the regular expression that must match to indicate a successful AI transformation " + .. "at the response phase. The first match will be set as the returning body. " + .. 
"If the AI service's response doesn't match this pattern, a failure is returned to the client.", + type = "string", + required = false, + }}, + { parse_llm_response_json_instructions = { + description = "Set true to read specific response format from the LLM, " + .. "and accordingly set the status code / body / headers that proxy back to the client. " + .. "You need to engineer your LLM prompt to return the correct format, " + .. "see plugin docs 'Overview' page for usage instructions.", + type = "boolean", + required = true, + default = false, + }}, + { http_timeout = { + description = "Timeout in milliseconds for the AI upstream service.", + type = "integer", + required = true, + default = 60000, + }}, + { https_verify = { + description = "Verify the TLS certificate of the AI upstream service.", + type = "boolean", + required = true, + default = true, + }}, + + -- from forward-proxy + { http_proxy_host = typedefs.host }, + { http_proxy_port = typedefs.port }, + { https_proxy_host = typedefs.host }, + { https_proxy_port = typedefs.port }, + + { llm = llm.config_schema }, + }, + }}, + }, + entity_checks = { + { + conditional = { + if_field = "config.llm.route_type", + if_match = { + not_one_of = { + "llm/v1/chat", + } + }, + then_field = "config.llm.route_type", + then_match = { eq = "llm/v1/chat" }, + then_err = "'config.llm.route_type' must be 'llm/v1/chat' for AI transformer plugins", + }, + }, + { mutually_required = { "config.http_proxy_host", "config.http_proxy_port" } }, + { mutually_required = { "config.https_proxy_host", "config.https_proxy_port" } }, + }, +} diff --git a/spec/01-unit/12-plugins_order_spec.lua b/spec/01-unit/12-plugins_order_spec.lua index 2f24d6348678..8189d05e9925 100644 --- a/spec/01-unit/12-plugins_order_spec.lua +++ b/spec/01-unit/12-plugins_order_spec.lua @@ -72,9 +72,11 @@ describe("Plugins", function() "response-ratelimiting", "request-transformer", "response-transformer", + "ai-request-transformer", "ai-prompt-template", "ai-prompt-decorator", "ai-proxy", + "ai-response-transformer", "aws-lambda", "azure-functions", "proxy-cache", diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index dc5b59a53400..61f9cb5da270 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -118,7 +118,7 @@ local FORMATS = { options = { max_tokens = 512, temperature = 0.5, - llama2_format = "raw", + llama2_format = "ollama", }, }, ["llm/v1/completions"] = { @@ -199,6 +199,13 @@ describe(PLUGIN_NAME .. ": (unit)", function() assert.same("request matches multiple LLM request formats", err) end) + it("double-format message is denied", function() + local compatible, err = llm.is_compatible(SAMPLE_DOUBLE_FORMAT, "llm/v1/completions") + + assert.is_falsy(compatible) + assert.same("request matches multiple LLM request formats", err) + end) + for i, j in pairs(FORMATS) do describe(i .. 
" format tests", function() diff --git a/spec/03-plugins/39-ai-request-transformer/00-config_spec.lua b/spec/03-plugins/39-ai-request-transformer/00-config_spec.lua new file mode 100644 index 000000000000..bf5e3ae3b42a --- /dev/null +++ b/spec/03-plugins/39-ai-request-transformer/00-config_spec.lua @@ -0,0 +1,120 @@ +local PLUGIN_NAME = "ai-request-transformer" + + +-- helper function to validate data against a schema +local validate do + local validate_entity = require("spec.helpers").validate_plugin_config_schema + local plugin_schema = require("kong.plugins."..PLUGIN_NAME..".schema") + + function validate(data) + return validate_entity(data, plugin_schema) + end +end + +describe(PLUGIN_NAME .. ": (schema)", function() + it("must be 'llm/v1/chat' route type", function() + local config = { + llm = { + route_type = "llm/v1/completions", + auth = { + header_name = "Authorization", + header_value = "Bearer token", + }, + model = { + name = "llama-2-7b-chat-hf", + provider = "llama2", + options = { + max_tokens = 256, + temperature = 1.0, + llama2_format = "raw", + upstream_url = "http://kong" + }, + }, + }, + } + + local ok, err = validate(config) + + assert.not_nil(err) + + assert.same({ + ["@entity"] = { + [1] = "'config.llm.route_type' must be 'llm/v1/chat' for AI transformer plugins" + }, + config = { + llm = { + route_type = "value must be llm/v1/chat", + }, + prompt = "required field missing", + }}, err) + assert.is_falsy(ok) + end) + + it("requires 'https_proxy_host' and 'https_proxy_port' to be set together", function() + local config = { + prompt = "anything", + https_proxy_host = "kong.local", + llm = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer token", + }, + model = { + name = "llama-2-7b-chat-hf", + provider = "llama2", + options = { + max_tokens = 256, + temperature = 1.0, + llama2_format = "raw", + upstream_url = "http://kong" + }, + }, + }, + } + + local ok, err = validate(config) + + assert.not_nil(err) + + assert.same({ + ["@entity"] = { + [1] = "all or none of these fields must be set: 'config.https_proxy_host', 'config.https_proxy_port'" + }}, err) + assert.is_falsy(ok) + end) + + it("requires 'http_proxy_host' and 'http_proxy_port' to be set together", function() + local config = { + prompt = "anything", + http_proxy_host = "kong.local", + llm = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer token", + }, + model = { + name = "llama-2-7b-chat-hf", + provider = "llama2", + options = { + max_tokens = 256, + temperature = 1.0, + llama2_format = "raw", + upstream_url = "http://kong" + }, + }, + }, + } + + local ok, err = validate(config) + + assert.not_nil(err) + + assert.same({ + ["@entity"] = { + [1] = "all or none of these fields must be set: 'config.http_proxy_host', 'config.http_proxy_port'" + }}, err) + assert.is_falsy(ok) + end) +end) diff --git a/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua b/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua new file mode 100644 index 000000000000..5f4bd4cdc5db --- /dev/null +++ b/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua @@ -0,0 +1,307 @@ +local llm_class = require("kong.llm") +local helpers = require "spec.helpers" +local cjson = require "cjson" + +local MOCK_PORT = 62349 +local PLUGIN_NAME = "ai-request-transformer" + +local FORMATS = { + openai = { + route_type = "llm/v1/chat", + model = { + name = "gpt-4", + provider = "openai", + options = { + max_tokens = 512, 
+ temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/openai" + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, + }, + cohere = { + route_type = "llm/v1/chat", + model = { + name = "command", + provider = "cohere", + options = { + max_tokens = 512, + temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/cohere" + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer cohere-key", + }, + }, + authropic = { + route_type = "llm/v1/chat", + model = { + name = "claude-2", + provider = "anthropic", + options = { + max_tokens = 512, + temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/anthropic" + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer anthropic-key", + }, + }, + azure = { + route_type = "llm/v1/chat", + model = { + name = "gpt-4", + provider = "azure", + options = { + max_tokens = 512, + temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/azure" + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer azure-key", + }, + }, + llama2 = { + route_type = "llm/v1/chat", + model = { + name = "llama2", + provider = "llama2", + options = { + max_tokens = 512, + temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/llama2", + llama2_format = "raw", + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer llama2-key", + }, + }, + mistral = { + route_type = "llm/v1/chat", + model = { + name = "mistral", + provider = "mistral", + options = { + max_tokens = 512, + temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/mistral", + mistral_format = "ollama", + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer mistral-key", + }, + }, +} + +local OPENAI_NOT_JSON = { + route_type = "llm/v1/chat", + model = { + name = "gpt-4", + provider = "openai", + options = { + max_tokens = 512, + temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/not-json" + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, +} + +local REQUEST_BODY = [[ + { + "persons": [ + { + "name": "Kong A", + "age": 31 + }, + { + "name": "Kong B", + "age": 42 + } + ] + } +]] + +local EXPECTED_RESULT = { + persons = { + [1] = { + age = 62, + name = "Kong A" + }, + [2] = { + age = 84, + name = "Kong B" + }, + } +} + +local SYSTEM_PROMPT = "You are a mathematician. " + .. "Multiply all numbers in my JSON request, by 2. Return me the JSON output only" + + +local client + + +for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then + + describe(PLUGIN_NAME .. ": (unit)", function() + + lazy_setup(function() + -- set up provider fixtures + local fixtures = { + http_mock = {}, + } + + fixtures.http_mock.openai = [[ + server { + server_name llm; + listen ]]..MOCK_PORT..[[; + + default_type 'application/json'; + + location ~/chat/(?[a-z0-9]+) { + content_by_lua_block { + local pl_file = require "pl.file" + local json = require("cjson.safe") + + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local token = ngx.req.get_headers()["authorization"] + local token_query = ngx.req.get_uri_args()["apikey"] + + if token == "Bearer " .. ngx.var.provider .. 
"-key" or token_query == "$1-key" or body.apikey == "$1-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if err or (body.messages == ngx.null) then + ngx.status = 400 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/" .. ngx.var.provider .. "/llm-v1-chat/responses/bad_request.json")) + else + ngx.status = 200 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/" .. ngx.var.provider .. "/request-transformer/response-in-json.json")) + end + else + ngx.status = 401 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/" .. ngx.var.provider .. "/llm-v1-chat/responses/unauthorized.json")) + end + } + } + + location ~/not-json { + content_by_lua_block { + local pl_file = require "pl.file" + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/request-transformer/response-not-json.json")) + } + } + } + ]] + + -- start kong + assert(helpers.start_kong({ + -- set the strategy + database = strategy, + -- use the custom test template to create a local mock server + nginx_conf = "spec/fixtures/custom_nginx.template", + -- make sure our plugin gets loaded + plugins = "bundled," .. PLUGIN_NAME, + -- write & load declarative config, only if 'strategy=off' + declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, + }, nil, nil, fixtures)) + end) + + lazy_teardown(function() + helpers.stop_kong(nil, true) + end) + + before_each(function() + client = helpers.proxy_client() + end) + + after_each(function() + if client then client:close() end + end) + + for name, format_options in pairs(FORMATS) do + + describe(name .. " transformer tests, exact json response", function() + + it("transforms request based on LLM instructions", function() + local llm = llm_class:new(format_options, {}) + assert.truthy(llm) + + local result, err = llm:ai_introspect_body( + REQUEST_BODY, -- request body + SYSTEM_PROMPT, -- conf.prompt + {}, -- http opts + nil -- transformation extraction pattern + ) + + assert.is_nil(err) + + result, err = cjson.decode(result) + assert.is_nil(err) + + assert.same(EXPECTED_RESULT, result) + end) + end) + + + end + + describe("openai transformer tests, pattern matchers", function() + it("transforms request based on LLM instructions, with json extraction pattern", function() + local llm = llm_class:new(OPENAI_NOT_JSON, {}) + assert.truthy(llm) + + local result, err = llm:ai_introspect_body( + REQUEST_BODY, -- request body + SYSTEM_PROMPT, -- conf.prompt + {}, -- http opts + "\\{((.|\n)*)\\}" -- transformation extraction pattern (loose json) + ) + + assert.is_nil(err) + + result, err = cjson.decode(result) + assert.is_nil(err) + + assert.same(EXPECTED_RESULT, result) + end) + + it("transforms request based on LLM instructions, but fails to match pattern", function() + local llm = llm_class:new(OPENAI_NOT_JSON, {}) + assert.truthy(llm) + + local result, err = llm:ai_introspect_body( + REQUEST_BODY, -- request body + SYSTEM_PROMPT, -- conf.prompt + {}, -- http opts + "\\#*\\=" -- transformation extraction pattern (loose json) + ) + + assert.is_nil(result) + assert.is_not_nil(err) + assert.same("AI response did not match specified regular expression", err) + end) + end) + end) +end end diff --git a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua new file mode 100644 index 000000000000..1d0ff2a00ba7 --- /dev/null +++ b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua @@ -0,0 +1,253 @@ +local helpers = require "spec.helpers" 
+local cjson = require "cjson" + +local MOCK_PORT = 62349 +local PLUGIN_NAME = "ai-request-transformer" + +local OPENAI_FLAT_RESPONSE = { + route_type = "llm/v1/chat", + model = { + name = "gpt-4", + provider = "openai", + options = { + max_tokens = 512, + temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/flat" + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, +} + +local OPENAI_BAD_REQUEST = { + route_type = "llm/v1/chat", + model = { + name = "gpt-4", + provider = "openai", + options = { + max_tokens = 512, + temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/badrequest" + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, +} + +local OPENAI_INTERNAL_SERVER_ERROR = { + route_type = "llm/v1/chat", + model = { + name = "gpt-4", + provider = "openai", + options = { + max_tokens = 512, + temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/internalservererror" + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, +} + + +local REQUEST_BODY = [[ + { + "persons": [ + { + "name": "Kong A", + "age": 31 + }, + { + "name": "Kong B", + "age": 42 + } + ] + } +]] + +local EXPECTED_RESULT_FLAT = { + persons = { + [1] = { + age = 62, + name = "Kong A" + }, + [2] = { + age = 84, + name = "Kong B" + }, + } +} + +local SYSTEM_PROMPT = "You are a mathematician. " + .. "Multiply all numbers in my JSON request, by 2." + + +local client + +for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then + describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function() + + lazy_setup(function() + local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME }) + + -- set up provider fixtures + local fixtures = { + http_mock = {}, + } + + fixtures.http_mock.openai = [[ + server { + server_name llm; + listen ]]..MOCK_PORT..[[; + + default_type 'application/json'; + + location ~/flat { + content_by_lua_block { + local pl_file = require "pl.file" + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/request-transformer/response-in-json.json")) + } + } + + location = "/badrequest" { + content_by_lua_block { + local pl_file = require "pl.file" + + ngx.status = 400 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) + } + } + + location = "/internalservererror" { + content_by_lua_block { + local pl_file = require "pl.file" + + ngx.status = 500 + ngx.header["content-type"] = "text/html" + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/internal_server_error.html")) + } + } + } + ]] + + -- echo server via 'openai' LLM + local without_response_instructions = assert(bp.routes:insert { + paths = { "/echo-flat" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = without_response_instructions.id }, + config = { + prompt = SYSTEM_PROMPT, + llm = OPENAI_FLAT_RESPONSE, + }, + } + + local bad_request = assert(bp.routes:insert { + paths = { "/echo-bad-request" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = bad_request.id }, + config = { + prompt = SYSTEM_PROMPT, + llm = OPENAI_BAD_REQUEST, + }, + } + + local internal_server_error = assert(bp.routes:insert { + paths = { "/echo-internal-server-error" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = internal_server_error.id }, + config = { + 
prompt = SYSTEM_PROMPT, + llm = OPENAI_INTERNAL_SERVER_ERROR, + }, + } + -- + + -- start kong + assert(helpers.start_kong({ + -- set the strategy + database = strategy, + -- use the custom test template to create a local mock server + nginx_conf = "spec/fixtures/custom_nginx.template", + -- make sure our plugin gets loaded + plugins = "bundled," .. PLUGIN_NAME, + -- write & load declarative config, only if 'strategy=off' + declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, + }, nil, nil, fixtures)) + end) + + lazy_teardown(function() + helpers.stop_kong(nil, true) + end) + + before_each(function() + client = helpers.proxy_client() + end) + + after_each(function() + if client then client:close() end + end) + + describe("openai response transformer integration", function() + it("transforms properly from LLM", function() + local r = client:get("/echo-flat", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = REQUEST_BODY, + }) + + local body = assert.res_status(200 , r) + local body_table, err = cjson.decode(body) + + assert.is_nil(err) + assert.same(EXPECTED_RESULT_FLAT, body_table.post_data.params) + end) + + it("bad request from LLM", function() + local r = client:get("/echo-bad-request", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = REQUEST_BODY, + }) + + local body = assert.res_status(400 , r) + local body_table, err = cjson.decode(body) + + assert.is_nil(err) + assert.same({ error = { message = "failed to introspect request with AI service: status code 400" }}, body_table) + end) + + it("internal server error from LLM", function() + local r = client:get("/echo-internal-server-error", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = REQUEST_BODY, + }) + + local body = assert.res_status(400 , r) + local body_table, err = cjson.decode(body) + + assert.is_nil(err) + assert.same({ error = { message = "failed to introspect request with AI service: status code 500" }}, body_table) + end) + end) + end) +end +end diff --git a/spec/03-plugins/40-ai-response-transformer/00-config_spec.lua b/spec/03-plugins/40-ai-response-transformer/00-config_spec.lua new file mode 100644 index 000000000000..bf5e3ae3b42a --- /dev/null +++ b/spec/03-plugins/40-ai-response-transformer/00-config_spec.lua @@ -0,0 +1,120 @@ +local PLUGIN_NAME = "ai-request-transformer" + + +-- helper function to validate data against a schema +local validate do + local validate_entity = require("spec.helpers").validate_plugin_config_schema + local plugin_schema = require("kong.plugins."..PLUGIN_NAME..".schema") + + function validate(data) + return validate_entity(data, plugin_schema) + end +end + +describe(PLUGIN_NAME .. 
": (schema)", function() + it("must be 'llm/v1/chat' route type", function() + local config = { + llm = { + route_type = "llm/v1/completions", + auth = { + header_name = "Authorization", + header_value = "Bearer token", + }, + model = { + name = "llama-2-7b-chat-hf", + provider = "llama2", + options = { + max_tokens = 256, + temperature = 1.0, + llama2_format = "raw", + upstream_url = "http://kong" + }, + }, + }, + } + + local ok, err = validate(config) + + assert.not_nil(err) + + assert.same({ + ["@entity"] = { + [1] = "'config.llm.route_type' must be 'llm/v1/chat' for AI transformer plugins" + }, + config = { + llm = { + route_type = "value must be llm/v1/chat", + }, + prompt = "required field missing", + }}, err) + assert.is_falsy(ok) + end) + + it("requires 'https_proxy_host' and 'https_proxy_port' to be set together", function() + local config = { + prompt = "anything", + https_proxy_host = "kong.local", + llm = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer token", + }, + model = { + name = "llama-2-7b-chat-hf", + provider = "llama2", + options = { + max_tokens = 256, + temperature = 1.0, + llama2_format = "raw", + upstream_url = "http://kong" + }, + }, + }, + } + + local ok, err = validate(config) + + assert.not_nil(err) + + assert.same({ + ["@entity"] = { + [1] = "all or none of these fields must be set: 'config.https_proxy_host', 'config.https_proxy_port'" + }}, err) + assert.is_falsy(ok) + end) + + it("requires 'http_proxy_host' and 'http_proxy_port' to be set together", function() + local config = { + prompt = "anything", + http_proxy_host = "kong.local", + llm = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer token", + }, + model = { + name = "llama-2-7b-chat-hf", + provider = "llama2", + options = { + max_tokens = 256, + temperature = 1.0, + llama2_format = "raw", + upstream_url = "http://kong" + }, + }, + }, + } + + local ok, err = validate(config) + + assert.not_nil(err) + + assert.same({ + ["@entity"] = { + [1] = "all or none of these fields must be set: 'config.http_proxy_host', 'config.http_proxy_port'" + }}, err) + assert.is_falsy(ok) + end) +end) diff --git a/spec/03-plugins/40-ai-response-transformer/01-transformer_spec.lua b/spec/03-plugins/40-ai-response-transformer/01-transformer_spec.lua new file mode 100644 index 000000000000..c13f9dc27eda --- /dev/null +++ b/spec/03-plugins/40-ai-response-transformer/01-transformer_spec.lua @@ -0,0 +1,152 @@ +local llm_class = require("kong.llm") +local helpers = require "spec.helpers" +local cjson = require "cjson" + +local MOCK_PORT = 62349 +local PLUGIN_NAME = "ai-response-transformer" + +local OPENAI_INSTRUCTIONAL_RESPONSE = { + route_type = "llm/v1/chat", + model = { + name = "gpt-4", + provider = "openai", + options = { + max_tokens = 512, + temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/instructions" + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, +} + +local REQUEST_BODY = [[ + { + "persons": [ + { + "name": "Kong A", + "age": 31 + }, + { + "name": "Kong B", + "age": 42 + } + ] + } +]] + +local EXPECTED_RESULT = { + body = [[ + + Kong A + 62 + + + Kong B + 84 + +]], + status = 209, + headers = { + ["content-type"] = "application/xml", + }, +} + +local SYSTEM_PROMPT = "You are a mathematician. " + .. "Multiply all numbers in my JSON request, by 2. Return me this message: " + .. 
"{\"status\": 400, \"headers: {\"content-type\": \"application/xml\"}, \"body\": \"OUTPUT\"} " + .. "where 'OUTPUT' is the result but transformed into XML format." + + +local client + + +for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then + + describe(PLUGIN_NAME .. ": (unit)", function() + + lazy_setup(function() + -- set up provider fixtures + local fixtures = { + http_mock = {}, + } + + fixtures.http_mock.openai = [[ + server { + server_name llm; + listen ]]..MOCK_PORT..[[; + + default_type 'application/json'; + + location ~/instructions { + content_by_lua_block { + local pl_file = require "pl.file" + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/request-transformer/response-with-instructions.json")) + } + } + } + ]] + + -- start kong + assert(helpers.start_kong({ + -- set the strategy + database = strategy, + -- use the custom test template to create a local mock server + nginx_conf = "spec/fixtures/custom_nginx.template", + -- make sure our plugin gets loaded + plugins = "bundled," .. PLUGIN_NAME, + -- write & load declarative config, only if 'strategy=off' + declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, + }, nil, nil, fixtures)) + end) + + lazy_teardown(function() + helpers.stop_kong(nil, true) + end) + + before_each(function() + client = helpers.proxy_client() + end) + + after_each(function() + if client then client:close() end + end) + + describe("openai transformer tests, specific response", function() + it("transforms request based on LLM instructions, with response transformation instructions format", function() + local llm = llm_class:new(OPENAI_INSTRUCTIONAL_RESPONSE, {}) + assert.truthy(llm) + + local result, err = llm:ai_introspect_body( + REQUEST_BODY, -- request body + SYSTEM_PROMPT, -- conf.prompt + {}, -- http opts + nil -- transformation extraction pattern (loose json) + ) + + assert.is_nil(err) + + local table_result, err = cjson.decode(result) + assert.is_nil(err) + assert.same(EXPECTED_RESULT, table_result) + + -- parse in response string format + local headers, body, status, err = llm:parse_json_instructions(result) + assert.is_nil(err) + assert.same({ ["content-type"] = "application/xml"}, headers) + assert.same(209, status) + assert.same(EXPECTED_RESULT.body, body) + + -- parse in response table format + headers, body, status, err = llm:parse_json_instructions(table_result) + assert.is_nil(err) + assert.same({ ["content-type"] = "application/xml"}, headers) + assert.same(209, status) + assert.same(EXPECTED_RESULT.body, body) + end) + + end) + end) +end end diff --git a/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua new file mode 100644 index 000000000000..9f724629da95 --- /dev/null +++ b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua @@ -0,0 +1,411 @@ +local helpers = require "spec.helpers" +local cjson = require "cjson" + +local MOCK_PORT = 62349 +local PLUGIN_NAME = "ai-response-transformer" + +local OPENAI_INSTRUCTIONAL_RESPONSE = { + route_type = "llm/v1/chat", + model = { + name = "gpt-4", + provider = "openai", + options = { + max_tokens = 512, + temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/instructions" + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, +} + +local OPENAI_FLAT_RESPONSE = { + route_type = "llm/v1/chat", + model = { + name = "gpt-4", + provider = "openai", + options = { + 
max_tokens = 512, + temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/flat" + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, +} + +local OPENAI_BAD_INSTRUCTIONS = { + route_type = "llm/v1/chat", + model = { + name = "gpt-4", + provider = "openai", + options = { + max_tokens = 512, + temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/badinstructions" + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, +} + +local OPENAI_BAD_REQUEST = { + route_type = "llm/v1/chat", + model = { + name = "gpt-4", + provider = "openai", + options = { + max_tokens = 512, + temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/badrequest" + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, +} + +local OPENAI_INTERNAL_SERVER_ERROR = { + route_type = "llm/v1/chat", + model = { + name = "gpt-4", + provider = "openai", + options = { + max_tokens = 512, + temperature = 0.5, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/internalservererror" + }, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, +} + + +local REQUEST_BODY = [[ + { + "persons": [ + { + "name": "Kong A", + "age": 31 + }, + { + "name": "Kong B", + "age": 42 + } + ] + } +]] + +local EXPECTED_RESULT_FLAT = { + persons = { + [1] = { + age = 62, + name = "Kong A" + }, + [2] = { + age = 84, + name = "Kong B" + }, + } +} + +local EXPECTED_BAD_INSTRUCTIONS_ERROR = { + error = { + message = "failed to parse JSON response instructions from AI backend: Expected value but found invalid token at character 1" + } +} + +local EXPECTED_RESULT = { + body = [[ + + Kong A + 62 + + + Kong B + 84 + +]], + status = 209, + headers = { + ["content-type"] = "application/xml", + }, +} + +local SYSTEM_PROMPT = "You are a mathematician. " + .. "Multiply all numbers in my JSON request, by 2. Return me this message: " + .. "{\"status\": 400, \"headers: {\"content-type\": \"application/xml\"}, \"body\": \"OUTPUT\"} " + .. "where 'OUTPUT' is the result but transformed into XML format." + + +local client + +for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then + describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. 
"]", function() + + lazy_setup(function() + local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME }) + + -- set up provider fixtures + local fixtures = { + http_mock = {}, + } + + fixtures.http_mock.openai = [[ + server { + server_name llm; + listen ]]..MOCK_PORT..[[; + + default_type 'application/json'; + + location ~/instructions { + content_by_lua_block { + local pl_file = require "pl.file" + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/request-transformer/response-with-instructions.json")) + } + } + + location ~/flat { + content_by_lua_block { + local pl_file = require "pl.file" + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/request-transformer/response-in-json.json")) + } + } + + location ~/badinstructions { + content_by_lua_block { + local pl_file = require "pl.file" + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/request-transformer/response-with-bad-instructions.json")) + } + } + + location = "/badrequest" { + content_by_lua_block { + local pl_file = require "pl.file" + + ngx.status = 400 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) + } + } + + location = "/internalservererror" { + content_by_lua_block { + local pl_file = require "pl.file" + + ngx.status = 500 + ngx.header["content-type"] = "text/html" + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/internal_server_error.html")) + } + } + } + ]] + + -- echo server via 'openai' LLM + local with_response_instructions = assert(bp.routes:insert { + paths = { "/echo-parse-instructions" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = with_response_instructions.id }, + config = { + prompt = SYSTEM_PROMPT, + parse_llm_response_json_instructions = true, + llm = OPENAI_INSTRUCTIONAL_RESPONSE, + }, + } + + local without_response_instructions = assert(bp.routes:insert { + paths = { "/echo-flat" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = without_response_instructions.id }, + config = { + prompt = SYSTEM_PROMPT, + parse_llm_response_json_instructions = false, + llm = OPENAI_FLAT_RESPONSE, + }, + } + + local bad_instructions = assert(bp.routes:insert { + paths = { "/echo-bad-instructions" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = bad_instructions.id }, + config = { + prompt = SYSTEM_PROMPT, + parse_llm_response_json_instructions = true, + llm = OPENAI_BAD_INSTRUCTIONS, + }, + } + + local bad_instructions_parse_out = assert(bp.routes:insert { + paths = { "/echo-bad-instructions-parse-out" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = bad_instructions_parse_out.id }, + config = { + prompt = SYSTEM_PROMPT, + parse_llm_response_json_instructions = true, + llm = OPENAI_BAD_INSTRUCTIONS, + transformation_extract_pattern = "\\{((.|\n)*)\\}", + }, + } + + local bad_request = assert(bp.routes:insert { + paths = { "/echo-bad-request" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = bad_request.id }, + config = { + prompt = SYSTEM_PROMPT, + parse_llm_response_json_instructions = false, + llm = OPENAI_BAD_REQUEST, + }, + } + + local internal_server_error = assert(bp.routes:insert { + paths = { "/echo-internal-server-error" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = internal_server_error.id }, + config = { + prompt = SYSTEM_PROMPT, + parse_llm_response_json_instructions = false, + llm = OPENAI_INTERNAL_SERVER_ERROR, + }, + } + -- + + -- start kong + 
assert(helpers.start_kong({ + -- set the strategy + database = strategy, + -- use the custom test template to create a local mock server + nginx_conf = "spec/fixtures/custom_nginx.template", + -- make sure our plugin gets loaded + plugins = "bundled," .. PLUGIN_NAME, + -- write & load declarative config, only if 'strategy=off' + declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, + }, nil, nil, fixtures)) + end) + + lazy_teardown(function() + helpers.stop_kong(nil, true) + end) + + before_each(function() + client = helpers.proxy_client() + end) + + after_each(function() + if client then client:close() end + end) + + describe("openai response transformer integration", function() + it("transforms request based on LLM instructions, with response transformation instructions format", function() + local r = client:get("/echo-parse-instructions", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = REQUEST_BODY, + }) + + local body = assert.res_status(209 , r) + assert.same(EXPECTED_RESULT.body, body) + assert.same(r.headers["content-type"], "application/xml") + end) + + it("transforms request based on LLM instructions, without response instructions", function() + local r = client:get("/echo-flat", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = REQUEST_BODY, + }) + + local body = assert.res_status(200 , r) + local body_table, err = cjson.decode(body) + assert.is_nil(err) + assert.same(EXPECTED_RESULT_FLAT, body_table) + end) + + it("fails properly when json instructions are bad", function() + local r = client:get("/echo-bad-instructions", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = REQUEST_BODY, + }) + + local body = assert.res_status(500 , r) + local body_table, err = cjson.decode(body) + assert.is_nil(err) + assert.same(EXPECTED_BAD_INSTRUCTIONS_ERROR, body_table) + end) + + it("succeeds extracting json instructions when bad", function() + local r = client:get("/echo-bad-instructions-parse-out", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = REQUEST_BODY, + }) + + local body = assert.res_status(209 , r) + assert.same(EXPECTED_RESULT.body, body) + assert.same(r.headers["content-type"], "application/xml") + end) + + it("bad request from LLM", function() + local r = client:get("/echo-bad-request", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = REQUEST_BODY, + }) + + local body = assert.res_status(400 , r) + local body_table, err = cjson.decode(body) + + assert.is_nil(err) + assert.same({ error = { message = "failed to introspect request with AI service: status code 400" }}, body_table) + end) + + it("internal server error from LLM", function() + local r = client:get("/echo-internal-server-error", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = REQUEST_BODY, + }) + + local body = assert.res_status(400 , r) + local body_table, err = cjson.decode(body) + + assert.is_nil(err) + assert.same({ error = { message = "failed to introspect request with AI service: status code 500" }}, body_table) + end) + end) + end) +end +end diff --git a/spec/fixtures/ai-proxy/anthropic/request-transformer/response-in-json.json b/spec/fixtures/ai-proxy/anthropic/request-transformer/response-in-json.json new file mode 100644 index 000000000000..cca0d6e595b1 
--- /dev/null +++ b/spec/fixtures/ai-proxy/anthropic/request-transformer/response-in-json.json @@ -0,0 +1,5 @@ +{ + "completion": "{\n \"persons\": [\n {\n \"name\": \"Kong A\",\n \"age\": 62\n },\n {\n \"name\": \"Kong B\",\n \"age\": 84\n }\n ]\n }\n", + "stop_reason": "stop_sequence", + "model": "claude-2" +} diff --git a/spec/fixtures/ai-proxy/azure/request-transformer/response-in-json.json b/spec/fixtures/ai-proxy/azure/request-transformer/response-in-json.json new file mode 100644 index 000000000000..cc8f792cb387 --- /dev/null +++ b/spec/fixtures/ai-proxy/azure/request-transformer/response-in-json.json @@ -0,0 +1,22 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": " {\n \"persons\": [\n {\n \"name\": \"Kong A\",\n \"age\": 62\n },\n {\n \"name\": \"Kong B\",\n \"age\": 84\n }\n ]\n }\n", + "role": "assistant" + } + } + ], + "created": 1701947430, + "id": "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2", + "model": "gpt-3.5-turbo-0613", + "object": "chat.completion", + "system_fingerprint": null, + "usage": { + "completion_tokens": 12, + "prompt_tokens": 25, + "total_tokens": 37 + } +} diff --git a/spec/fixtures/ai-proxy/cohere/request-transformer/response-in-json.json b/spec/fixtures/ai-proxy/cohere/request-transformer/response-in-json.json new file mode 100644 index 000000000000..beda83d6264d --- /dev/null +++ b/spec/fixtures/ai-proxy/cohere/request-transformer/response-in-json.json @@ -0,0 +1,19 @@ +{ + "text": "{\n \"persons\": [\n {\n \"name\": \"Kong A\",\n \"age\": 62\n },\n {\n \"name\": \"Kong B\",\n \"age\": 84\n }\n ]\n }\n", + "generation_id": "3fa85f64-5717-4562-b3fc-2c963f66afa6", + "token_count": { + "billed_tokens": 339, + "prompt_tokens": 102, + "response_tokens": 258, + "total_tokens": 360 + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 81, + "output_tokens": 258 + } + } + } diff --git a/spec/fixtures/ai-proxy/llama2/request-transformer/response-in-json.json b/spec/fixtures/ai-proxy/llama2/request-transformer/response-in-json.json new file mode 100644 index 000000000000..7a433236f2de --- /dev/null +++ b/spec/fixtures/ai-proxy/llama2/request-transformer/response-in-json.json @@ -0,0 +1,7 @@ +{ + "data": [ + { + "generated_text": "[INST]\nWhat is Sans? 
?\n[/INST]\n\n{\n \"persons\": [\n {\n \"name\": \"Kong A\",\n \"age\": 62\n },\n {\n \"name\": \"Kong B\",\n \"age\": 84\n }\n ]\n }\n" + } + ] +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/mistral/request-transformer/response-in-json.json b/spec/fixtures/ai-proxy/mistral/request-transformer/response-in-json.json new file mode 100644 index 000000000000..754883eb0bd7 --- /dev/null +++ b/spec/fixtures/ai-proxy/mistral/request-transformer/response-in-json.json @@ -0,0 +1,16 @@ +{ + "model": "mistral-tiny", + "created_at": "2024-01-15T08:13:38.876196Z", + "message": { + "role": "assistant", + "content": " {\n \"persons\": [\n {\n \"name\": \"Kong A\",\n \"age\": 62\n },\n {\n \"name\": \"Kong B\",\n \"age\": 84\n }\n ]\n }\n" + }, + "done": true, + "total_duration": 4062418334, + "load_duration": 1229365792, + "prompt_eval_count": 26, + "prompt_eval_duration": 167969000, + "eval_count": 100, + "eval_duration": 2658646000 + } + \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/openai/request-transformer/response-in-json.json b/spec/fixtures/ai-proxy/openai/request-transformer/response-in-json.json new file mode 100644 index 000000000000..cc8f792cb387 --- /dev/null +++ b/spec/fixtures/ai-proxy/openai/request-transformer/response-in-json.json @@ -0,0 +1,22 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": " {\n \"persons\": [\n {\n \"name\": \"Kong A\",\n \"age\": 62\n },\n {\n \"name\": \"Kong B\",\n \"age\": 84\n }\n ]\n }\n", + "role": "assistant" + } + } + ], + "created": 1701947430, + "id": "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2", + "model": "gpt-3.5-turbo-0613", + "object": "chat.completion", + "system_fingerprint": null, + "usage": { + "completion_tokens": 12, + "prompt_tokens": 25, + "total_tokens": 37 + } +} diff --git a/spec/fixtures/ai-proxy/openai/request-transformer/response-not-json.json b/spec/fixtures/ai-proxy/openai/request-transformer/response-not-json.json new file mode 100644 index 000000000000..35c96e723c57 --- /dev/null +++ b/spec/fixtures/ai-proxy/openai/request-transformer/response-not-json.json @@ -0,0 +1,22 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "Sure! Here is your JSON: {\n \"persons\": [\n {\n \"name\": \"Kong A\",\n \"age\": 62\n },\n {\n \"name\": \"Kong B\",\n \"age\": 84\n }\n ]\n }.\n Can I do anything else for you?", + "role": "assistant" + } + } + ], + "created": 1701947430, + "id": "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2", + "model": "gpt-3.5-turbo-0613", + "object": "chat.completion", + "system_fingerprint": null, + "usage": { + "completion_tokens": 12, + "prompt_tokens": 25, + "total_tokens": 37 + } +} diff --git a/spec/fixtures/ai-proxy/openai/request-transformer/response-with-bad-instructions.json b/spec/fixtures/ai-proxy/openai/request-transformer/response-with-bad-instructions.json new file mode 100644 index 000000000000..b2f1083419b5 --- /dev/null +++ b/spec/fixtures/ai-proxy/openai/request-transformer/response-with-bad-instructions.json @@ -0,0 +1,22 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "Sure! 
Here's your response: {\n \"status\": 209,\n \"headers\": {\n \"content-type\": \"application/xml\"\n },\n \"body\": \"\n \n Kong A\n 62\n \n \n Kong B\n 84\n \n\"\n}.\nCan I help with anything else?", + "role": "assistant" + } + } + ], + "created": 1701947430, + "id": "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2", + "model": "gpt-3.5-turbo-0613", + "object": "chat.completion", + "system_fingerprint": null, + "usage": { + "completion_tokens": 12, + "prompt_tokens": 25, + "total_tokens": 37 + } +} diff --git a/spec/fixtures/ai-proxy/openai/request-transformer/response-with-instructions.json b/spec/fixtures/ai-proxy/openai/request-transformer/response-with-instructions.json new file mode 100644 index 000000000000..29445e6afbdc --- /dev/null +++ b/spec/fixtures/ai-proxy/openai/request-transformer/response-with-instructions.json @@ -0,0 +1,22 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "{\n \"status\": 209,\n \"headers\": {\n \"content-type\": \"application/xml\"\n },\n \"body\": \"\n \n Kong A\n 62\n \n \n Kong B\n 84\n \n\"\n}\n", + "role": "assistant" + } + } + ], + "created": 1701947430, + "id": "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2", + "model": "gpt-3.5-turbo-0613", + "object": "chat.completion", + "system_fingerprint": null, + "usage": { + "completion_tokens": 12, + "prompt_tokens": 25, + "total_tokens": 37 + } +}
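Note on the fixtures above: the `response-with-instructions` and `response-with-bad-instructions` fixtures are what the mock LLM returns to the ai-response-transformer tests. Each carries a JSON "instructions" object with `status`, `headers`, and `body` fields that the plugin is expected to apply to the client response, which is why the spec asserts a 209 status and an `application/xml` content type. The following is a minimal, illustrative Lua sketch of applying such an instructions object; it is not the plugin's actual handler code, and `apply_instructions` is a hypothetical helper name.

local cjson = require("cjson.safe")

-- llm_content: the "content" string returned by the LLM, e.g. the
-- response-with-instructions fixture above.
local function apply_instructions(llm_content)
  -- expected shape: { status = 209, headers = { ... }, body = "<xml>..." }
  local instructions, err = cjson.decode(llm_content)
  if not instructions then
    return nil, "failed to parse JSON instructions: " .. tostring(err)
  end

  -- copy any headers the LLM asked for onto the client response
  for name, value in pairs(instructions.headers or {}) do
    kong.response.set_header(name, value)
  end

  -- short-circuit with the LLM-provided status and body
  return kong.response.exit(instructions.status or 200, instructions.body)
end

The "bad instructions" fixture wraps the same JSON object in conversational filler ("Sure! Here's your response: ... Can I help with anything else?"), so the `/echo-bad-instructions-parse-out` test exercises the path where the JSON object must first be extracted from surrounding text before being decoded as above.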