From c5199ff0c386bc8091273c8a99d6760eb382f102 Mon Sep 17 00:00:00 2001
From: Stephen Brown
Date: Mon, 25 Nov 2024 06:54:03 +0000
Subject: [PATCH] feat(llm): add huggingface provider (#13484)

AG-70
---
 .../kong/feat-add-huggingface-llm-driver.yml  |   6 +
 kong-3.9.0-0.rockspec                         |   1 +
 kong/llm/drivers/huggingface.lua              | 328 +++++++++++++
 kong/llm/drivers/shared.lua                   |  12 +-
 kong/llm/schemas/init.lua                     |  17 +-
 .../10-huggingface_integration_spec.lua       | 451 ++++++++++++++++++
 .../llm-v1-chat/requests/good.json            |  13 +
 .../llm-v1-chat/responses/bad_request.json    |   4 +
 .../responses/bad_response_model_load.json    |   4 +
 .../responses/bad_response_timeout.json       |   3 +
 .../llm-v1-chat/responses/good.json           |  23 +
 .../llm-v1-chat/responses/unauthorized.json   |   3 +
 .../llm-v1-completions/requests/good.json     |   3 +
 .../responses/bad_request.json                |   3 +
 .../llm-v1-completions/responses/good.json    |   5 +
 .../responses/unauthorized.json               |   3 +
 16 files changed, 877 insertions(+), 2 deletions(-)
 create mode 100644 changelog/unreleased/kong/feat-add-huggingface-llm-driver.yml
 create mode 100644 kong/llm/drivers/huggingface.lua
 create mode 100644 spec/03-plugins/38-ai-proxy/10-huggingface_integration_spec.lua
 create mode 100644 spec/fixtures/ai-proxy/huggingface/llm-v1-chat/requests/good.json
 create mode 100644 spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/bad_request.json
 create mode 100644 spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/bad_response_model_load.json
 create mode 100644 spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/bad_response_timeout.json
 create mode 100644 spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/good.json
 create mode 100644 spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/unauthorized.json
 create mode 100644 spec/fixtures/ai-proxy/huggingface/llm-v1-completions/requests/good.json
 create mode 100644 spec/fixtures/ai-proxy/huggingface/llm-v1-completions/responses/bad_request.json
 create mode 100644 spec/fixtures/ai-proxy/huggingface/llm-v1-completions/responses/good.json
 create mode 100644 spec/fixtures/ai-proxy/huggingface/llm-v1-completions/responses/unauthorized.json

diff --git a/changelog/unreleased/kong/feat-add-huggingface-llm-driver.yml b/changelog/unreleased/kong/feat-add-huggingface-llm-driver.yml
new file mode 100644
index 000000000000..e5dca183563c
--- /dev/null
+++ b/changelog/unreleased/kong/feat-add-huggingface-llm-driver.yml
@@ -0,0 +1,6 @@
+message: |
+  Added a new LLM driver for interfacing with the Hugging Face inference API.
+  The driver supports both serverless and dedicated LLM instances hosted by
+  Hugging Face for conversational and text generation tasks.
+type: feature
+scope: Core

diff --git a/kong-3.9.0-0.rockspec b/kong-3.9.0-0.rockspec
index 3a7ddf689b58..f8150208195f 100644
--- a/kong-3.9.0-0.rockspec
+++ b/kong-3.9.0-0.rockspec
@@ -633,6 +633,7 @@ build = {
     ["kong.llm.drivers.llama2"] = "kong/llm/drivers/llama2.lua",
     ["kong.llm.drivers.gemini"] = "kong/llm/drivers/gemini.lua",
     ["kong.llm.drivers.bedrock"] = "kong/llm/drivers/bedrock.lua",
+    ["kong.llm.drivers.huggingface"] = "kong/llm/drivers/huggingface.lua",

     ["kong.llm.plugin.base"] = "kong/llm/plugin/base.lua",

diff --git a/kong/llm/drivers/huggingface.lua b/kong/llm/drivers/huggingface.lua
new file mode 100644
index 000000000000..88b0a2ca2d7b
--- /dev/null
+++ b/kong/llm/drivers/huggingface.lua
@@ -0,0 +1,328 @@
+local _M = {}
+
+-- imports
+local cjson = require("cjson.safe")
+local fmt = string.format
+local ai_shared = require("kong.llm.drivers.shared")
+local socket_url = require("socket.url")
+--
+
+local DRIVER_NAME = "huggingface"
+
+function _M.pre_request(conf, body)
+  return true, nil
+end
+
+local function from_huggingface(response_string, model_info, route_type)
+  local response_table, err = cjson.decode(response_string)
+  if not response_table then
+    ngx.log(ngx.ERR, "Failed to decode JSON response from HuggingFace API: ", err)
+    return nil, "Failed to decode response"
+  end
+
+  if response_table.error or response_table.message then
+    local error_msg = response_table.error or response_table.message
+    ngx.log(ngx.ERR, "Error from HuggingFace API: ", error_msg)
+    return nil, "API error: " .. error_msg
+  end
+
+  local transformed_response = {
+    model = model_info.name,
+    object = response_table.object or route_type,
+    choices = {},
+    usage = {},
+  }
+
+  -- Chat reports usage, generation does not
+  transformed_response.usage = response_table.usage or {}
+
+  response_table.generated_text = response_table[1] and response_table[1].generated_text or nil
+  if response_table.generated_text then
+    table.insert(transformed_response.choices, {
+      message = { content = response_table.generated_text },
+      index = 0,
+      finish_reason = "complete",
+    })
+  elseif response_table.choices then
+    for i, choice in ipairs(response_table.choices) do
+      local content = choice.message and choice.message.content or ""
+      table.insert(transformed_response.choices, {
+        message = { content = content },
+        index = i - 1,
+        finish_reason = "complete",
+      })
+    end
+  else
+    ngx.log(ngx.ERR, "Unexpected response format from Hugging Face API")
+    return nil, "Invalid response format"
+  end
+
+  local result_string, err = cjson.encode(transformed_response)
+  if not result_string then
+    ngx.log(ngx.ERR, "Failed to encode transformed response: ", err)
+    return nil, "Failed to encode response"
+  end
+  return result_string, nil
+end
+
+local function set_huggingface_options(model_info)
+  local use_cache = false
+  local wait_for_model = false
+
+  if model_info and model_info.options and model_info.options.huggingface then
+    use_cache = model_info.options.huggingface.use_cache or false
+    wait_for_model = model_info.options.huggingface.wait_for_model or false
+  end
+
+  return {
+    use_cache = use_cache,
+    wait_for_model = wait_for_model,
+  }
+end
+
+local function set_default_parameters(request_table)
+  local parameters = request_table.parameters or {}
+  if parameters.top_k == nil then
+    parameters.top_k = request_table.top_k
+  end
+  if parameters.top_p == nil then
+    parameters.top_p = request_table.top_p
+  end
+  if parameters.temperature == nil then
+    parameters.temperature = request_table.temperature
+  end
+  if parameters.max_tokens == nil then
+    if request_table.messages then
+      -- conversational models use the max_length param
+      -- https://huggingface.co/docs/api-inference/en/detailed_parameters?code=curl#conversational-task
+      parameters.max_length = request_table.max_tokens
+    else
+      parameters.max_new_tokens = request_table.max_tokens
+    end
+  end
+  request_table.top_k = nil
+  request_table.top_p = nil
+  request_table.temperature = nil
+  request_table.max_tokens = nil
+
+  return parameters
+end
+
+local function to_huggingface(task, request_table, model_info)
+  local parameters = set_default_parameters(request_table)
+  local options = set_huggingface_options(model_info)
+  if task == "llm/v1/completions" then
+    request_table.inputs = request_table.prompt
+    request_table.prompt = nil
+  end
+  request_table.options = options
+  request_table.parameters = parameters
+  request_table.model = model_info.name or request_table.model
+
+  return request_table, "application/json", nil
+end
+
+local function safe_access(tbl, ...)
+  local value = tbl
+  for _, key in ipairs({ ... }) do
+    value = value and value[key]
+    if not value then
+      return nil
+    end
+  end
+  return value
+end
+
+local function handle_huggingface_stream(event_t, model_info, route_type)
+  -- discard empty frames, it should either be a random new line, or comment
+  if (not event_t.data) or (#event_t.data < 1) then
+    return
+  end
+  local event, err = cjson.decode(event_t.data)
+
+  if err then
+    ngx.log(ngx.WARN, "failed to decode stream event frame from Hugging Face: " .. err)
+    return nil, "failed to decode stream event frame from Hugging Face", nil
+  end
+
+  local new_event
+  if route_type == "stream/llm/v1/chat" then
+    local content = safe_access(event, "choices", 1, "delta", "content") or ""
+    new_event = {
+      choices = {
+        [1] = {
+          delta = {
+            content = content,
+            role = "assistant",
+          },
+          index = 0,
+        },
+      },
+      model = event.model or model_info.name,
+      object = "chat.completion.chunk",
+    }
+  else
+    local text = safe_access(event, "token", "text") or ""
+    new_event = {
+      choices = {
+        [1] = {
+          text = text,
+          index = 0,
+        },
+      },
+      model = model_info.name,
+      object = "text_completion",
+    }
+  end
+  return cjson.encode(new_event), nil, nil
+end
+
+local transformers_from = {
+  ["llm/v1/chat"] = from_huggingface,
+  ["llm/v1/completions"] = from_huggingface,
+  ["stream/llm/v1/chat"] = handle_huggingface_stream,
+  ["stream/llm/v1/completions"] = handle_huggingface_stream,
+}
+
+function _M.from_format(response_string, model_info, route_type)
+  ngx.log(ngx.DEBUG, "converting from ", model_info.provider, "://", route_type, " type to kong")
+
+  -- MUST return a string, set as the response body
+  if not transformers_from[route_type] then
+    return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type)
+  end
+
+  local ok, response_string, err, metadata =
+    pcall(transformers_from[route_type], response_string, model_info, route_type)
+  if not ok or err then
+    return nil,
+      fmt("transformation failed from type %s://%s: %s", model_info.provider, route_type, err or "unexpected_error")
+  end
+
+  return response_string, nil, metadata
+end
+
+local transformers_to = {
+  ["llm/v1/chat"] = to_huggingface,
+  ["llm/v1/completions"] = to_huggingface,
+}
+
+function _M.to_format(request_table, model_info, route_type)
+  if not transformers_to[route_type] then
+    return nil, nil, fmt("no transformer for %s://%s", model_info.provider, route_type)
+  end
+
+  request_table = ai_shared.merge_config_defaults(request_table,
+    model_info.options, model_info.route_type)
+
+  local ok, response_object, content_type, err =
+    pcall(transformers_to[route_type], route_type, request_table, model_info)
+  if err or not ok then
+    return nil, nil, fmt("error transforming to %s://%s", model_info.provider, route_type)
+  end
+
+  return response_object, content_type, nil
+end
+
+local function build_url(base_url, route_type)
+  return (route_type == "llm/v1/completions") and base_url or (base_url .. "/v1/chat/completions")
+end
+
+local function huggingface_endpoint(conf)
+  local parsed_url
+
+  local base_url
+  if conf.model.options and conf.model.options.upstream_url then
+    base_url = conf.model.options.upstream_url
+  elseif conf.model.name then
+    base_url = fmt(ai_shared.upstream_url_format[DRIVER_NAME], conf.model.name)
+  else
+    return nil
+  end
+
+  local url = build_url(base_url, conf.route_type)
+  parsed_url = socket_url.parse(url)
+
+  return parsed_url
+end
+
+function _M.configure_request(conf)
+  local parsed_url = huggingface_endpoint(conf)
+  if not parsed_url then
+    return kong.response.exit(400, "Could not parse the Hugging Face model endpoint")
+  end
+  if parsed_url.path then
+    kong.service.request.set_path(parsed_url.path)
+  end
+  kong.service.request.set_scheme(parsed_url.scheme)
+  kong.service.set_target(parsed_url.host, tonumber(parsed_url.port) or 443)
+
+  local auth_header_name = conf.auth and conf.auth.header_name
+  local auth_header_value = conf.auth and conf.auth.header_value
+
+  if auth_header_name and auth_header_value then
+    kong.service.request.set_header(auth_header_name, auth_header_value)
+  end
+  return true, nil
+end
+
+function _M.post_request(conf)
+  -- Clear any response headers if needed
+  if ai_shared.clear_response_headers[DRIVER_NAME] then
+    for i, v in ipairs(ai_shared.clear_response_headers[DRIVER_NAME]) do
+      kong.response.clear_header(v)
+    end
+  end
+end
+
+function _M.subrequest(body, conf, http_opts, return_res_table)
+  -- Encode the request body as JSON
+  local body_string, err = cjson.encode(body)
+  if not body_string then
+    return nil, nil, "Failed to encode body to JSON: " .. (err or "unknown error")
+  end
+
+  -- Construct the Hugging Face API URL
+  local url = huggingface_endpoint(conf)
+  if not url then
+    return nil, nil, "Could not parse the Hugging Face model endpoint"
+  end
+  local url_string = url.scheme .. "://" .. url.host .. (url.path or "")
+
+  local headers = {
+    ["Accept"] = "application/json",
+    ["Content-Type"] = "application/json",
+  }
+
+  if conf.auth and conf.auth.header_name then
+    headers[conf.auth.header_name] = conf.auth.header_value
+  end
+
+  local method = "POST"
+
+  local res, err, httpc = ai_shared.http_request(url_string, body_string, method, headers, http_opts, return_res_table)
+
+  -- Handle the response
+  if not res then
+    return nil, nil, "Request to Hugging Face API failed: " .. (err or "unknown error")
+  end
+
+  -- Check if the response should be returned as a table
+  if return_res_table then
+    return {
+      status = res.status,
+      headers = res.headers,
+      body = res.body,
+    },
+    res.status,
+    nil,
+    httpc
+  else
+    if res.status >= 200 and res.status < 300 then
+      return res.body, res.status, nil
+    else
+      return res.body, res.status, "Hugging Face API returned status " .. res.status
+    end
+  end
+end
+
+return _M

diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua
index 55057fceb984..19c17537287c 100644
--- a/kong/llm/drivers/shared.lua
+++ b/kong/llm/drivers/shared.lua
@@ -94,7 +94,8 @@ _M.upstream_url_format = {
   gemini = "https://generativelanguage.googleapis.com",
   gemini_vertex = "https://%s",
   bedrock = "https://bedrock-runtime.%s.amazonaws.com",
-  mistral = "https://api.mistral.ai:443"
+  mistral = "https://api.mistral.ai:443",
+  huggingface = "https://api-inference.huggingface.co/models/%s",
 }

 _M.operation_map = {
@@ -147,6 +148,15 @@ _M.operation_map = {
   gemini_vertex = {
     ["llm/v1/chat"] = {
       path = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
+    },
+  },
+  huggingface = {
+    ["llm/v1/completions"] = {
+      path = "/models/%s",
+      method = "POST",
+    },
+    ["llm/v1/chat"] = {
+      path = "/models/%s",
       method = "POST",
     },
   },

diff --git a/kong/llm/schemas/init.lua b/kong/llm/schemas/init.lua
index 985e7dbe0f6b..127350756c7f 100644
--- a/kong/llm/schemas/init.lua
+++ b/kong/llm/schemas/init.lua
@@ -37,6 +37,20 @@ local gemini_options_schema = {
   },
 }

+local huggingface_options_schema = {
+  type = "record",
+  required = false,
+  fields = {
+    { use_cache = {
+        type = "boolean",
+        description = "Use the cache layer on the inference API",
+        required = false }},
+    { wait_for_model = {
+        type = "boolean",
+        description = "Wait for the model if it is not ready",
+        required = false }},
+  },
+}

 local auth_schema = {
   type = "record",
@@ -179,6 +193,7 @@ local model_options_schema = {
         required = false }},
     { gemini = gemini_options_schema },
     { bedrock = bedrock_options_schema },
+    { huggingface = huggingface_options_schema },
   }
 }

@@ -192,7 +207,7 @@ local model_schema = {
         type = "string", description = "AI provider request format - Kong translates "
                       .. "requests to and from the specified backend compatible formats.",
         required = true,
-        one_of = { "openai", "azure", "anthropic", "cohere", "mistral", "llama2", "gemini", "bedrock" }}},
+        one_of = { "openai", "azure", "anthropic", "cohere", "mistral", "llama2", "gemini", "bedrock", "huggingface" }}},
     { name = {
         type = "string",
         description = "Model name to execute.",

diff --git a/spec/03-plugins/38-ai-proxy/10-huggingface_integration_spec.lua b/spec/03-plugins/38-ai-proxy/10-huggingface_integration_spec.lua
new file mode 100644
index 000000000000..5cec28bd3f2b
--- /dev/null
+++ b/spec/03-plugins/38-ai-proxy/10-huggingface_integration_spec.lua
@@ -0,0 +1,451 @@
+local helpers = require("spec.helpers")
+local cjson = require("cjson")
+local pl_file = require("pl.file")
+
+local PLUGIN_NAME = "ai-proxy"
+local MOCK_PORT = helpers.get_available_port()
+
+for _, strategy in helpers.all_strategies() do
+  if strategy ~= "cassandra" then
+    describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function()
+      local client
+
+      lazy_setup(function()
+        local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME })
+
+        -- set up huggingface mock fixtures
+        local fixtures = {
+          http_mock = {},
+        }
+
+        fixtures.http_mock.huggingface = [[
+          server {
+            server_name huggingface;
+            listen ]] .. MOCK_PORT .. [[;
+
+            default_type 'application/json';
+
+            location = "/v1/chat/completions" {
+              content_by_lua_block {
+                local pl_file = require "pl.file"
+                local json = require("cjson.safe")
+
+                local token = ngx.req.get_headers()["authorization"]
+                if token == "Bearer huggingface-key" then
+                  ngx.req.read_body()
+                  local body, err = ngx.req.get_body_data()
+                  body, err = json.decode(body)
+
+                  if err or (body.messages == ngx.null) then
+                    ngx.status = 400
+                    ngx.print(pl_file.read("spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/bad_request.json"))
+                  else
+                    ngx.status = 200
+                    ngx.print(pl_file.read("spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/good.json"))
+                  end
+                else
+                  ngx.status = 401
+                  ngx.print(pl_file.read("spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/unauthorized.json"))
+                end
+              }
+            }
+
+            # completions is on the root of huggingface models
+            location = "/" {
+              content_by_lua_block {
+                local pl_file = require "pl.file"
+                local json = require("cjson.safe")
+
+                local token = ngx.req.get_headers()["authorization"]
+                if token == "Bearer huggingface-key" then
+                  ngx.req.read_body()
+                  local body, err = ngx.req.get_body_data()
+                  body, err = json.decode(body)
+
+                  if err or (body.prompt == ngx.null) then
+                    ngx.status = 400
+                    ngx.print(pl_file.read("spec/fixtures/ai-proxy/huggingface/llm-v1-completions/responses/bad_request.json"))
+                  else
+                    ngx.status = 200
+                    ngx.print(pl_file.read("spec/fixtures/ai-proxy/huggingface/llm-v1-completions/responses/good.json"))
+                  end
+                else
+                  ngx.status = 401
+                  ngx.print(pl_file.read("spec/fixtures/ai-proxy/huggingface/llm-v1-completions/responses/unauthorized.json"))
+                end
+              }
+            }
+            location = "/model-loading/v1/chat/completions" {
+              content_by_lua_block {
+                local pl_file = require "pl.file"
+
+                ngx.status = 503
+                ngx.print(pl_file.read("spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/bad_response_model_load.json"))
+              }
+            }
+            location = "/model-timeout/v1/chat/completions" {
+              content_by_lua_block {
+                local pl_file = require "pl.file"
+
+                ngx.status = 504
+                ngx.print(pl_file.read("spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/bad_response_timeout.json"))
+              }
+            }
+          }
+        ]]
+
+        local empty_service = assert(bp.services:insert({
+          name = "empty_service",
+          host = "localhost", --helpers.mock_upstream_host,
+          port = 8080, --MOCK_PORT,
+          path = "/",
+        }))
+
+        -- 200 chat good with one option
+        local chat_good = assert(bp.routes:insert({
+          service = empty_service,
+          protocols = { "http" },
+          strip_path = true,
+          paths = { "/huggingface/llm/v1/chat/good" },
+        }))
+        bp.plugins:insert({
+          name = PLUGIN_NAME,
+          route = { id = chat_good.id },
+          config = {
+            route_type = "llm/v1/chat",
+            auth = {
+              header_name = "Authorization",
+              header_value = "Bearer huggingface-key",
+            },
+            model = {
+              name = "mistralai/Mistral-7B-Instruct-v0.2",
+              provider = "huggingface",
+              options = {
+                max_tokens = 256,
+                temperature = 1.0,
+                huggingface = {
+                  use_cache = false,
+                  wait_for_model = true,
+                },
+                upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT,
+              },
+            },
+          },
+        })
+        local completions_good = assert(bp.routes:insert({
+          service = empty_service,
+          protocols = { "http" },
+          strip_path = true,
+          paths = { "/huggingface/llm/v1/completions/good" },
+        }))
+        bp.plugins:insert({
+          name = PLUGIN_NAME,
+          route = { id = completions_good.id },
+          config = {
+            route_type = "llm/v1/completions",
+            auth = {
+              header_name = "Authorization",
+              header_value = "Bearer huggingface-key",
+            },
+            model = {
+              name = "mistralai/Mistral-7B-Instruct-v0.2",
+              provider = "huggingface",
+              options = {
+                max_tokens = 256,
+                temperature = 1.0,
+                huggingface = {
+                  use_cache = false,
+                  wait_for_model = true,
+                },
+                upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT,
+              },
+            },
+          },
+        })
+        -- 401 unauthorized
+        local chat_401 = assert(bp.routes:insert({
+          service = empty_service,
+          protocols = { "http" },
+          strip_path = true,
+          paths = { "/huggingface/llm/v1/chat/unauthorized" },
+        }))
+        bp.plugins:insert({
+          name = PLUGIN_NAME,
+          route = { id = chat_401.id },
+          config = {
+            route_type = "llm/v1/chat",
+            auth = {
+              header_name = "api-key",
+              header_value = "wrong-key",
+            },
+            model = {
+              name = "mistralai/Mistral-7B-Instruct-v0.2",
+              provider = "huggingface",
+              options = {
+                max_tokens = 256,
+                temperature = 1.0,
+                huggingface = {
+                  use_cache = false,
+                  wait_for_model = true,
+                },
+                upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT,
+              },
+            },
+          },
+        })
+        -- 401 unauthorized
+        local completions_401 = assert(bp.routes:insert({
+          service = empty_service,
+          protocols = { "http" },
+          strip_path = true,
+          paths = { "/huggingface/llm/v1/completions/unauthorized" },
+        }))
+        bp.plugins:insert({
+          name = PLUGIN_NAME,
+          route = { id = completions_401.id },
+          config = {
+            route_type = "llm/v1/completions",
+            auth = {
+              header_name = "api-key",
+              header_value = "wrong-key",
+            },
+            model = {
+              name = "mistralai/Mistral-7B-Instruct-v0.2",
+              provider = "huggingface",
+              options = {
+                max_tokens = 256,
+                temperature = 1.0,
+                huggingface = {
+                  use_cache = false,
+                  wait_for_model = true,
+                },
+                upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT,
+              },
+            },
+          },
+        })
+        -- 503 Service Temporarily Unavailable
+        local chat_503 = assert(bp.routes:insert({
+          service = empty_service,
+          protocols = { "http" },
+          strip_path = true,
+          paths = { "/huggingface/llm/v1/chat/bad-response/model-loading" },
+        }))
+        bp.plugins:insert({
+          name = PLUGIN_NAME,
+          route = { id = chat_503.id },
+          config = {
+            route_type = "llm/v1/chat",
+            auth = {
+              header_name = "api-key",
+              header_value = "huggingface-key",
+            },
+            model = {
+              name = "mistralai/Mistral-7B-Instruct-v0.2",
+              provider = "huggingface",
+              options = {
+                max_tokens = 256,
+                temperature = 1.0,
+                huggingface = {
+                  use_cache = false,
+                  wait_for_model = false,
+                },
+                upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/model-loading",
+              },
+            },
+          },
+        })
+        -- 504 Gateway Timeout
+        local chat_503_to = assert(bp.routes:insert({
+          service = empty_service,
+          protocols = { "http" },
+          strip_path = true,
+          paths = { "/huggingface/llm/v1/chat/bad-response/model-timeout" },
+        }))
+        bp.plugins:insert({
+          name = PLUGIN_NAME,
+          route = { id = chat_503_to.id },
+          config = {
+            route_type = "llm/v1/chat",
+            auth = {
+              header_name = "api-key",
+              header_value = "huggingface-key",
+            },
+            model = {
+              name = "mistralai/Mistral-7B-Instruct-v0.2",
+              provider = "huggingface",
+              options = {
+                max_tokens = 256,
+                temperature = 1.0,
+                huggingface = {
+                  use_cache = false,
+                  wait_for_model = false,
+                },
+                upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/model-timeout",
+              },
+            },
+          },
+        })
+
+        -- start kong
+        assert(helpers.start_kong({
+          -- set the strategy
+          database = strategy,
+          -- use the custom test template to create a local mock server
+          nginx_conf = "spec/fixtures/custom_nginx.template",
+          -- make sure our plugin gets loaded
+          plugins = "bundled," .. PLUGIN_NAME,
+          -- write & load declarative config, only if 'strategy=off'
+          declarative_config = strategy == "off" and helpers.make_yaml_file() or nil,
+        }, nil, nil, fixtures))
+      end)
+
+      lazy_teardown(function()
+        helpers.stop_kong()
+      end)
+
+      before_each(function()
+        client = helpers.proxy_client()
+      end)
+
+      after_each(function()
+        if client then
+          client:close()
+        end
+      end)
+
+      describe("huggingface llm/v1/chat", function()
+        it("good request", function()
+          local r = client:get("/huggingface/llm/v1/chat/good", {
+            headers = {
+              ["content-type"] = "application/json",
+              ["accept"] = "application/json",
+            },
+            body = pl_file.read("spec/fixtures/ai-proxy/huggingface/llm-v1-chat/requests/good.json"),
+          })
+          -- validate that the request succeeded, response status 200
+          local body = assert.res_status(200, r)
+          local json = cjson.decode(body)
+
+          -- check this is in the 'kong' response format
+          assert.equals(json.model, "mistralai/Mistral-7B-Instruct-v0.2")
+          assert.equals(json.object, "chat.completion")
+
+          assert.is_table(json.choices)
+          --print("json: ", inspect(json))
+          assert.is_string(json.choices[1].message.content)
+          assert.same(
+            " The sum of 1 + 1 is 2. This is a basic arithmetic operation and the answer is always the same: adding one to one results in having two in total.",
+            json.choices[1].message.content
+          )
+        end)
+      end)
+      describe("huggingface llm/v1/completions", function()
+        it("good request", function()
+          local r = client:get("/huggingface/llm/v1/completions/good", {
+            headers = {
+              ["content-type"] = "application/json",
+              ["accept"] = "application/json",
+            },
+            body = pl_file.read("spec/fixtures/ai-proxy/huggingface/llm-v1-completions/requests/good.json"),
+          })
+
+          -- validate that the request succeeded, response status 200
+          local body = assert.res_status(200, r)
+          local json = cjson.decode(body)
+
+          -- check this is in the 'kong' response format
+          assert.equals("mistralai/Mistral-7B-Instruct-v0.2", json.model)
+          assert.equals("llm/v1/completions", json.object)
+
+          assert.is_table(json.choices)
+          assert.is_table(json.choices[1])
+          assert.same("I am a language model AI created by Mistral AI", json.choices[1].message.content)
+        end)
+      end)
+      describe("huggingface no auth", function()
+        it("unauthorized request chat", function()
+          local r = client:get("/huggingface/llm/v1/chat/unauthorized", {
+            headers = {
+              ["content-type"] = "application/json",
+              ["accept"] = "application/json",
+            },
+            body = pl_file.read("spec/fixtures/ai-proxy/huggingface/llm-v1-chat/requests/good.json"),
+          })
+
+          local body = assert.res_status(401, r)
+          local json = cjson.decode(body)
+          assert.equals(json.error, "Authorization header is correct, but the token seems invalid")
+        end)
+        it("unauthorized request completions", function()
+          local r = client:get("/huggingface/llm/v1/completions/unauthorized", {
+            headers = {
+              ["content-type"] = "application/json",
+              ["accept"] = "application/json",
+            },
+            body = pl_file.read("spec/fixtures/ai-proxy/huggingface/llm-v1-completions/requests/good.json"),
+          })
+
+          local body = assert.res_status(401, r)
+          local json = cjson.decode(body)
+          assert.equals(json.error, "Authorization header is correct, but the token seems invalid")
+        end)
+      end)
+      describe("huggingface bad request", function()
+        it("bad chat request", function()
+          local r = client:get("/huggingface/llm/v1/chat/good", {
+            headers = {
+              ["content-type"] = "application/json",
+              ["accept"] = "application/json",
+            },
+            body = { messages = ngx.null },
+          })
+
+          local body = assert.res_status(400, r)
+          local json = cjson.decode(body)
+          assert.equals(json.error.message, "request body doesn't contain valid prompts")
+        end)
+        it("bad completions request", function()
+          local r = client:get("/huggingface/llm/v1/completions/good", {
+            headers = {
+              ["content-type"] = "application/json",
+              ["accept"] = "application/json",
+            },
+            body = { prompt = ngx.null },
+          })
+
+          local body = assert.res_status(400, r)
+          local json = cjson.decode(body)
+          assert.equals(json.error.message, "request body doesn't contain valid prompts")
+        end)
+      end)
+      describe("huggingface bad response", function()
+        it("bad chat response", function()
+          local r = client:get("/huggingface/llm/v1/chat/bad-response/model-loading", {
+            headers = {
+              ["content-type"] = "application/json",
+              ["accept"] = "application/json",
+            },
+            body = pl_file.read("spec/fixtures/ai-proxy/huggingface/llm-v1-chat/requests/good.json"),
+          })
+
+          local body = assert.res_status(503, r)
+          local json = cjson.decode(body)
+          assert.equals(json.error, "Model mistralai/Mistral-7B-Instruct-v0.2 is currently loading")
+        end)
+        it("bad chat response - model timeout", function()
+          local r = client:get("/huggingface/llm/v1/chat/bad-response/model-timeout", {
+            headers = {
+              ["content-type"] = "application/json",
+              ["accept"] = "application/json",
+            },
+            body = pl_file.read("spec/fixtures/ai-proxy/huggingface/llm-v1-chat/requests/good.json"),
+          })
+          local body = assert.res_status(504, r)
+          local json = cjson.decode(body)
+          assert.equals(json.error, "Model mistralai/Mistral-7B-Instruct-v0.2 time out")
+        end)
+      end)
+    end)
+  end
+end

diff --git a/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/requests/good.json b/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/requests/good.json
new file mode 100644
index 000000000000..542b64f8065a
--- /dev/null
+++ b/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/requests/good.json
@@ -0,0 +1,13 @@
+{
+  "messages":[
+    {
+      "role": "system",
+      "content": "You are a helpful assistant."
+    },
+    {
+      "role": "user",
+      "content": "What is 1 + 1?"
+    }
+  ],
+  "stream": false
+}

diff --git a/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/bad_request.json b/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/bad_request.json
new file mode 100644
index 000000000000..9e4f4b5fd29e
--- /dev/null
+++ b/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/bad_request.json
@@ -0,0 +1,4 @@
+{
+  "error": "Template error: undefined value (in :1)",
+  "error_type": "template_error"
+}

diff --git a/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/bad_response_model_load.json b/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/bad_response_model_load.json
new file mode 100644
index 000000000000..4de86f6ccccb
--- /dev/null
+++ b/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/bad_response_model_load.json
@@ -0,0 +1,4 @@
+{
+  "error": "Model mistralai/Mistral-7B-Instruct-v0.2 is currently loading",
+  "estimated_time": 305.6863708496094
+}

diff --git a/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/bad_response_timeout.json b/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/bad_response_timeout.json
new file mode 100644
index 000000000000..4ce2a2cc65ca
--- /dev/null
+++ b/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/bad_response_timeout.json
@@ -0,0 +1,3 @@
+{
+  "error": "Model mistralai/Mistral-7B-Instruct-v0.2 time out"
+}

diff --git a/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/good.json b/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/good.json
new file mode 100644
index 000000000000..2c4c6981d976
--- /dev/null
+++ b/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/good.json
@@ -0,0 +1,23 @@
+{
+  "object": "chat.completion",
+  "id": "",
+  "created": 1722866733,
+  "model": "mistralai/Mistral-7B-Instruct-v0.2",
+  "system_fingerprint": "2.1.1-sha-4dfdb48",
+  "choices": [
+    {
+      "index": 0,
+      "message": {
+        "role": "assistant",
+        "content": " The sum of 1 + 1 is 2. This is a basic arithmetic operation and the answer is always the same: adding one to one results in having two in total."
+      },
+      "logprobs": null,
+      "finish_reason": "eos_token"
+    }
+  ],
+  "usage": {
+    "prompt_tokens": 26,
+    "completion_tokens": 40,
+    "total_tokens": 66
+  }
+}

diff --git a/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/unauthorized.json b/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/unauthorized.json
new file mode 100644
index 000000000000..8a50264ccf93
--- /dev/null
+++ b/spec/fixtures/ai-proxy/huggingface/llm-v1-chat/responses/unauthorized.json
@@ -0,0 +1,3 @@
+{
+  "error": "Authorization header is correct, but the token seems invalid"
+}

diff --git a/spec/fixtures/ai-proxy/huggingface/llm-v1-completions/requests/good.json b/spec/fixtures/ai-proxy/huggingface/llm-v1-completions/requests/good.json
new file mode 100644
index 000000000000..d66bba88c77c
--- /dev/null
+++ b/spec/fixtures/ai-proxy/huggingface/llm-v1-completions/requests/good.json
@@ -0,0 +1,3 @@
+{
+  "prompt": "What are you?"
+}

diff --git a/spec/fixtures/ai-proxy/huggingface/llm-v1-completions/responses/bad_request.json b/spec/fixtures/ai-proxy/huggingface/llm-v1-completions/responses/bad_request.json
new file mode 100644
index 000000000000..c8e6e72fd9fd
--- /dev/null
+++ b/spec/fixtures/ai-proxy/huggingface/llm-v1-completions/responses/bad_request.json
@@ -0,0 +1,3 @@
+{
+  "message": "Failed to deserialize the JSON body into the target type: missing field `inputs` at line 9 column 7"
+}

diff --git a/spec/fixtures/ai-proxy/huggingface/llm-v1-completions/responses/good.json b/spec/fixtures/ai-proxy/huggingface/llm-v1-completions/responses/good.json
new file mode 100644
index 000000000000..d145445ec6e1
--- /dev/null
+++ b/spec/fixtures/ai-proxy/huggingface/llm-v1-completions/responses/good.json
@@ -0,0 +1,5 @@
+[
+  {
+    "generated_text": "I am a language model AI created by Mistral AI"
+  }
+]

diff --git a/spec/fixtures/ai-proxy/huggingface/llm-v1-completions/responses/unauthorized.json b/spec/fixtures/ai-proxy/huggingface/llm-v1-completions/responses/unauthorized.json
new file mode 100644
index 000000000000..8a50264ccf93
--- /dev/null
+++ b/spec/fixtures/ai-proxy/huggingface/llm-v1-completions/responses/unauthorized.json
@@ -0,0 +1,3 @@
+{
+  "error": "Authorization header is correct, but the token seems invalid"
+}
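For illustration only, the following is a minimal sketch of an ai-proxy plugin configuration using this driver, assembled from the integration spec above. The model name and bearer token are placeholders, not values required by the driver, and upstream_url can be omitted, in which case the driver builds the serverless endpoint from the model name.

-- Minimal sketch of an ai-proxy config for the Hugging Face driver,
-- mirroring the plugin config used in the integration spec above.
local huggingface_chat_config = {
  route_type = "llm/v1/chat",
  auth = {
    header_name = "Authorization",
    header_value = "Bearer <hugging-face-token>",  -- placeholder, supply a real token
  },
  model = {
    name = "mistralai/Mistral-7B-Instruct-v0.2",   -- any Hugging Face hosted model id
    provider = "huggingface",
    options = {
      max_tokens = 256,
      temperature = 1.0,
      huggingface = {
        use_cache = false,      -- toggle the inference API cache layer
        wait_for_model = true,  -- wait for the model to load instead of failing
      },
      -- upstream_url is optional: when omitted, the driver formats
      -- https://api-inference.huggingface.co/models/<name> from the model name.
    },
  },
}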