From 4eaf25644bbab3f8790ff2cc36b33c23471d0f91 Mon Sep 17 00:00:00 2001
From: Wangchong Zhou
Date: Tue, 19 Nov 2024 23:23:03 +0800
Subject: [PATCH] refactor(ai-prompt-template): migrate ai-prompt-template to
 new framework

---
 kong-3.9.0-0.rockspec                         |   9 +-
 .../filters/render-prompt-template.lua        | 122 ++++++++++++++++++
 kong/plugins/ai-prompt-template/handler.lua   | 113 +---------------
 .../02-integration_spec.lua                   |   4 +-
 4 files changed, 135 insertions(+), 113 deletions(-)
 create mode 100644 kong/plugins/ai-prompt-template/filters/render-prompt-template.lua

diff --git a/kong-3.9.0-0.rockspec b/kong-3.9.0-0.rockspec
index 61374c496b83f..e0d0d9bcc1c69 100644
--- a/kong-3.9.0-0.rockspec
+++ b/kong-3.9.0-0.rockspec
@@ -644,14 +644,15 @@ build = {
     ["kong.llm.plugin.shared-filters.parse-sse-chunk"] = "kong/llm/plugin/shared-filters/parse-sse-chunk.lua",
     ["kong.llm.plugin.shared-filters.serialize-analytics"] = "kong/llm/plugin/shared-filters/serialize-analytics.lua",
 
-    ["kong.plugins.ai-prompt-decorator.handler"] = "kong/plugins/ai-prompt-decorator/handler.lua",
-    ["kong.plugins.ai-prompt-decorator.filters.decorate-prompt"] = "kong/plugins/ai-prompt-decorator/filters/decorate-prompt.lua",
-    ["kong.plugins.ai-prompt-decorator.schema"] = "kong/plugins/ai-prompt-decorator/schema.lua",
-
     ["kong.plugins.ai-prompt-template.handler"] = "kong/plugins/ai-prompt-template/handler.lua",
+    ["kong.plugins.ai-prompt-template.filters.render-prompt-template"] = "kong/plugins/ai-prompt-template/filters/render-prompt-template.lua",
     ["kong.plugins.ai-prompt-template.schema"] = "kong/plugins/ai-prompt-template/schema.lua",
     ["kong.plugins.ai-prompt-template.templater"] = "kong/plugins/ai-prompt-template/templater.lua",
+    ["kong.plugins.ai-prompt-decorator.handler"] = "kong/plugins/ai-prompt-decorator/handler.lua",
+    ["kong.plugins.ai-prompt-decorator.filters.decorate-prompt"] = "kong/plugins/ai-prompt-decorator/filters/decorate-prompt.lua",
+    ["kong.plugins.ai-prompt-decorator.schema"] = "kong/plugins/ai-prompt-decorator/schema.lua",
+
 
     ["kong.plugins.ai-prompt-guard.filters.guard-prompt"] = "kong/plugins/ai-prompt-guard/filters/guard-prompt.lua",
     ["kong.plugins.ai-prompt-guard.handler"] = "kong/plugins/ai-prompt-guard/handler.lua",
     ["kong.plugins.ai-prompt-guard.schema"] = "kong/plugins/ai-prompt-guard/schema.lua",
diff --git a/kong/plugins/ai-prompt-template/filters/render-prompt-template.lua b/kong/plugins/ai-prompt-template/filters/render-prompt-template.lua
new file mode 100644
index 0000000000000..b8f5eea570ab6
--- /dev/null
+++ b/kong/plugins/ai-prompt-template/filters/render-prompt-template.lua
@@ -0,0 +1,122 @@
+-- This software is copyright Kong Inc. and its licensors.
+-- Use of the software is subject to the agreement between your organization
+-- and Kong Inc. If there is no such agreement, use is governed by and
+-- subject to the terms of the Kong Master Software License Agreement found
+-- at https://konghq.com/enterprisesoftwarelicense/.
+-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ]
+
+local ai_plugin_ctx = require("kong.llm.plugin.ctx")
+local templater = require("kong.plugins.ai-prompt-template.templater")
+
+local ipairs = ipairs
+local type = type
+
+local _M = {
+  NAME = "render-prompt-template",
+  STAGE = "REQ_TRANSFORMATION",
+  }
+
+local FILTER_OUTPUT_SCHEMA = {
+  transformed = "boolean",
+}
+
+local _, set_ctx = ai_plugin_ctx.get_namespaced_accesors(_M.NAME, FILTER_OUTPUT_SCHEMA)
+
+
+local LOG_ENTRY_KEYS = {
+  REQUEST_BODY = "ai.payload.original_request",
+}
+
+
+local function bad_request(msg)
+  kong.log.debug(msg)
+  return kong.response.exit(400, { error = { message = msg } })
+end
+
+
+
+-- Checks if the passed in reference looks like a reference, and returns the template name.
+-- Valid references start with '{template://' and end with '}'.
+-- @tparam string reference reference to check
+-- @treturn string the reference template name or nil if it's not a reference
+local function extract_template_name(reference)
+  if type(reference) ~= "string" then
+    return nil
+  end
+
+  if not (reference:sub(1, 12) == "{template://" and reference:sub(-1) == "}") then
+    return nil
+  end
+
+  return reference:sub(13, -2)
+end
+
+
+
+--- Find a template by name in the list of templates.
+-- @tparam string reference_name the name of the template to find
+-- @tparam table templates the list of templates to search
+-- @treturn string the template if found, or nil + error message if not found
+local function find_template(reference_name, templates)
+  for _, v in ipairs(templates) do
+    if v.name == reference_name then
+      return v, nil
+    end
+  end
+
+  return nil, "could not find template name [" .. reference_name .. "]"
+end
+
+
+
+function _M:run(conf)
+  if conf.log_original_request then
+    kong.log.set_serialize_value(LOG_ENTRY_KEYS.REQUEST_BODY, kong.request.get_raw_body(conf.max_request_body_size))
+  end
+
+  -- if plugin ordering was altered, receive the "decorated" request
+  local request_body_table = kong.request.get_body("application/json", nil, conf.max_request_body_size)
+  if type(request_body_table) ~= "table" then
+    return bad_request("this LLM route only supports application/json requests")
+  end
+
+  local messages = request_body_table.messages
+  local prompt = request_body_table.prompt
+
+  if messages and prompt then
+    return bad_request("cannot run 'messages' and 'prompt' templates at the same time")
+  end
+
+  local reference = messages or prompt
+  if not reference then
+    return bad_request("only 'llm/v1/chat' and 'llm/v1/completions' formats are supported for templating")
+  end
+
+  local template_name = extract_template_name(reference)
+  if not template_name then
+    if conf.allow_untemplated_requests then
+      return true -- not a reference, do nothing
+    end
+
+    return bad_request("this LLM route only supports templated requests")
+  end
+
+  local requested_template, err = find_template(template_name, conf.templates)
+  if not requested_template then
+    return bad_request(err)
+  end
+
+  -- try to render the replacement request
+  local rendered_template, err = templater.render(requested_template, request_body_table.properties or {})
+  if err then
+    return bad_request(err)
+  end
+
+  kong.service.request.set_raw_body(rendered_template)
+
+  set_ctx("transformed", true)
+  return true
+end
+
+
+return _M
diff --git a/kong/plugins/ai-prompt-template/handler.lua b/kong/plugins/ai-prompt-template/handler.lua
index 1167471700981..16173c3871d47 100644
--- a/kong/plugins/ai-prompt-template/handler.lua
+++ b/kong/plugins/ai-prompt-template/handler.lua
@@ -1,112 +1,11 @@
-local templater = require("kong.plugins.ai-prompt-template.templater")
-local llm_state = require("kong.llm.state")
-local ipairs = ipairs
-local type = type
+local ai_plugin_base = require("kong.llm.plugin.base")
 
+local NAME = "ai-prompt-template"
+local PRIORITY = 773
 
+local AIPlugin = ai_plugin_base.define(NAME, PRIORITY)
 
-local AIPromptTemplateHandler = {
-  PRIORITY = 773,
-  VERSION = require("kong.meta").version,
-}
+AIPlugin:enable(AIPlugin.register_filter(require("kong.plugins." .. NAME .. ".filters.render-prompt-template")))
 
-
-
-local LOG_ENTRY_KEYS = {
-  REQUEST_BODY = "ai.payload.original_request",
-}
-
-
-
-local function bad_request(msg)
-  kong.log.debug(msg)
-  return kong.response.exit(400, { error = { message = msg } })
-end
-
-
-
--- Checks if the passed in reference looks like a reference, and returns the template name.
--- Valid references start with '{template://' and end with '}'.
--- @tparam string reference reference to check
--- @treturn string the reference template name or nil if it's not a reference
-local function extract_template_name(reference)
-  if type(reference) ~= "string" then
-    return nil
-  end
-
-  if not (reference:sub(1, 12) == "{template://" and reference:sub(-1) == "}") then
-    return nil
-  end
-
-  return reference:sub(13, -2)
-end
-
-
-
---- Find a template by name in the list of templates.
--- @tparam string reference_name the name of the template to find
--- @tparam table templates the list of templates to search
--- @treturn string the template if found, or nil + error message if not found
-local function find_template(reference_name, templates)
-  for _, v in ipairs(templates) do
-    if v.name == reference_name then
-      return v, nil
-    end
-  end
-
-  return nil, "could not find template name [" .. reference_name .. "]"
-end
-
-
-
-function AIPromptTemplateHandler:access(conf)
-  kong.service.request.enable_buffering()
-  llm_state.set_prompt_templated()
-
-  if conf.log_original_request then
-    kong.log.set_serialize_value(LOG_ENTRY_KEYS.REQUEST_BODY, kong.request.get_raw_body(conf.max_request_body_size))
-  end
-
-  local request = kong.request.get_body("application/json", nil, conf.max_request_body_size)
-  if type(request) ~= "table" then
-    return bad_request("this LLM route only supports application/json requests")
-  end
-
-  local messages = request.messages
-  local prompt = request.prompt
-
-  if messages and prompt then
-    return bad_request("cannot run 'messages' and 'prompt' templates at the same time")
-  end
-
-  local reference = messages or prompt
-  if not reference then
-    return bad_request("only 'llm/v1/chat' and 'llm/v1/completions' formats are supported for templating")
-  end
-
-  local template_name = extract_template_name(reference)
-  if not template_name then
-    if conf.allow_untemplated_requests then
-      return -- not a reference, do nothing
-    end
-
-    return bad_request("this LLM route only supports templated requests")
-  end
-
-  local requested_template, err = find_template(template_name, conf.templates)
-  if not requested_template then
-    return bad_request(err)
-  end
-
-  -- try to render the replacement request
-  local rendered_template, err = templater.render(requested_template, request.properties or {})
-  if err then
-    return bad_request(err)
-  end
-
-  kong.service.request.set_raw_body(rendered_template)
-end
-
-
 
-return AIPromptTemplateHandler
+return AIPlugin:as_kong_plugin()
\ No newline at end of file
diff --git a/spec/03-plugins/43-ai-prompt-template/02-integration_spec.lua b/spec/03-plugins/43-ai-prompt-template/02-integration_spec.lua
index a834a520b615c..318879eac4833 100644
--- a/spec/03-plugins/43-ai-prompt-template/02-integration_spec.lua
+++ b/spec/03-plugins/43-ai-prompt-template/02-integration_spec.lua
@@ -5,7 +5,7 @@ local PLUGIN_NAME = "ai-prompt-template"
 
 
 
-for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
+for _, strategy in helpers.all_strategies() do
   describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function()
     local client
 
@@ -426,4 +426,4 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
   end)
 
 
-end end
+end