From 434f7f4c7e185809ca61210881b66c8cb79eaaed Mon Sep 17 00:00:00 2001 From: Shane Utt Date: Tue, 18 Jun 2024 07:09:16 -0400 Subject: [PATCH] feat(llm): add vectordb, embeddings and semantic caching libraries --- .gitignore | 5 + kong-3.8.0-0.rockspec | 11 + kong/ai/embeddings/drivers/mistralai.lua | 96 ++++++ kong/ai/embeddings/drivers/openai.lua | 99 ++++++ kong/ai/embeddings/init.lua | 49 +++ kong/ai/semantic_cache/drivers/redis.lua | 113 +++++++ kong/ai/semantic_cache/init.lua | 47 +++ kong/ai/semantic_cache/utils.lua | 53 ++++ kong/ai/typedefs.lua | 156 ++++++++++ .../vector_databases/drivers/redis/client.lua | 78 +++++ .../vector_databases/drivers/redis/index.lua | 65 ++++ .../drivers/redis/vectors.lua | 147 +++++++++ .../30-ai/01-embeddings/01-openai_spec.lua | 47 +++ .../30-ai/01-embeddings/02-mistralai_spec.lua | 47 +++ .../02-vector_databases/01-redis_spec.lua | 188 ++++++++++++ .../30-ai/03-semantic_cache/01-utils_spec.lua | 45 +++ .../30-ai/03-semantic_cache/02-redis_spec.lua | 193 ++++++++++++ spec/helpers/ai/embeddings_mock.lua | 93 ++++++ spec/helpers/ai/mistralai_mock.lua | 64 ++++ spec/helpers/ai/openai_mock.lua | 64 ++++ spec/helpers/ai/redis_mock.lua | 289 ++++++++++++++++++ 21 files changed, 1949 insertions(+) create mode 100644 kong/ai/embeddings/drivers/mistralai.lua create mode 100644 kong/ai/embeddings/drivers/openai.lua create mode 100644 kong/ai/embeddings/init.lua create mode 100644 kong/ai/semantic_cache/drivers/redis.lua create mode 100644 kong/ai/semantic_cache/init.lua create mode 100644 kong/ai/semantic_cache/utils.lua create mode 100644 kong/ai/typedefs.lua create mode 100644 kong/ai/vector_databases/drivers/redis/client.lua create mode 100644 kong/ai/vector_databases/drivers/redis/index.lua create mode 100644 kong/ai/vector_databases/drivers/redis/vectors.lua create mode 100644 spec/01-unit/30-ai/01-embeddings/01-openai_spec.lua create mode 100644 spec/01-unit/30-ai/01-embeddings/02-mistralai_spec.lua create mode 100644 spec/01-unit/30-ai/02-vector_databases/01-redis_spec.lua create mode 100644 spec/01-unit/30-ai/03-semantic_cache/01-utils_spec.lua create mode 100644 spec/01-unit/30-ai/03-semantic_cache/02-redis_spec.lua create mode 100644 spec/helpers/ai/embeddings_mock.lua create mode 100644 spec/helpers/ai/mistralai_mock.lua create mode 100644 spec/helpers/ai/openai_mock.lua create mode 100644 spec/helpers/ai/redis_mock.lua diff --git a/.gitignore b/.gitignore index 0918890cb713..26691f1b371d 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,8 @@ mockserver # kong nginx_tmp/ kong*.yml +kong*.yaml +kong*.yaml.bak* # luacov luacov.* @@ -49,3 +51,6 @@ bin/h2client *.wasm spec/fixtures/proxy_wasm_filters/build spec/fixtures/proxy_wasm_filters/target + +# python +__pycache__ diff --git a/kong-3.8.0-0.rockspec b/kong-3.8.0-0.rockspec index d79ccd7c1025..7a63c78eb05b 100644 --- a/kong-3.8.0-0.rockspec +++ b/kong-3.8.0-0.rockspec @@ -882,6 +882,17 @@ build = { ["kong.plugins.ai-response-transformer.handler"] = "kong/plugins/ai-response-transformer/handler.lua", ["kong.plugins.ai-response-transformer.schema"] = "kong/plugins/ai-response-transformer/schema.lua", + ["kong.ai.typedefs"] = "kong/ai/typedefs.lua", + ["kong.ai.embeddings"] = "kong/ai/embeddings/init.lua", + ["kong.ai.embeddings.drivers.openai"] = "kong/ai/embeddings/drivers/openai.lua", + ["kong.ai.embeddings.drivers.mistralai"] = "kong/ai/embeddings/drivers/mistralai.lua", + ["kong.ai.semantic_cache"] = "kong/ai/semantic_cache/init.lua", + ["kong.ai.semantic_cache.utils"] = "kong/ai/semantic_cache/utils.lua", + ["kong.ai.semantic_cache.drivers.redis"] = "kong/ai/semantic_cache/drivers/redis.lua", + ["kong.ai.vector_databases.drivers.redis.client"] = "kong/ai/vector_databases/drivers/redis/client.lua", + ["kong.ai.vector_databases.drivers.redis.index"] = "kong/ai/vector_databases/drivers/redis/index.lua", + ["kong.ai.vector_databases.drivers.redis.vectors"] = "kong/ai/vector_databases/drivers/redis/vectors.lua", + ["kong.llm"] = "kong/llm/init.lua", ["kong.llm.schemas"] = "kong/llm/schemas/init.lua", ["kong.llm.drivers.shared"] = "kong/llm/drivers/shared.lua", diff --git a/kong/ai/embeddings/drivers/mistralai.lua b/kong/ai/embeddings/drivers/mistralai.lua new file mode 100644 index 000000000000..dd7f1c7fc451 --- /dev/null +++ b/kong/ai/embeddings/drivers/mistralai.lua @@ -0,0 +1,96 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- imports +-- + +local cjson = require("cjson.safe") +local http = require("resty.http") + +local deep_copy = require("kong.tools.table").deep_copy +local gzip = require("kong.tools.gzip") + +-- +-- vars +-- + +local embeddings_url = "https://api.mistral.ai/v1/embeddings" + +-- +-- driver object +-- + +-- Driver is an interface for a mistralai embeddings driver. +local Driver = {} +Driver.__index = Driver + +-- Constructs a new Driver +-- +-- @param provided_embeddings_config embeddings driver configuration +-- @param dimensions the number of dimensions for generating embeddings +-- @return the Driver object +function Driver:new(provided_embeddings_config, dimensions) + local driver_config = deep_copy(provided_embeddings_config) + driver_config.dimensions = dimensions + return setmetatable(driver_config, Driver) +end + +-- Generates the embeddings (vectors) for a given prompt +-- +-- @param prompt the prompt to generate embeddings for +-- @return the API response containing the embeddings +-- @return nothing. throws an error if any +function Driver:generate(prompt) + local body, err = cjson.encode({ + input = prompt, + model = self.model, + encoding_format = "float", + }) + if err then + return nil, err + end + + kong.log.debug("[mistralai] generating embeddings for prompt") + local httpc = http.new({ + ssl_verify = true, + ssl_cafile = kong.configuration.lua_ssl_trusted_certificate_combined, + }) + local res, err = httpc:request_uri(embeddings_url, { + method = "POST", + headers = { + ["Content-Type"] = "application/json", + ["Accept-Encoding"] = "gzip", + ["Authorization"] = "Bearer " .. self.auth.token, + }, + body = body, + }) + if not res then + return nil, string.format("failed to generate embeddings (%s): %s", embeddings_url, err) + end + if res.status ~= 200 then + return nil, string.format("unexpected embeddings response (%s): %s", embeddings_url, res.status) + end + + local inflated_body = gzip.inflate_gzip(res.body) + local embedding_response, err = cjson.decode(inflated_body) + if err then + return nil, err + end + + if not embedding_response.data or #embedding_response.data == 0 then + return nil, "no embeddings found in response" + end + + return embedding_response.data[1].embedding, nil +end + +-- +-- module +-- + +return Driver diff --git a/kong/ai/embeddings/drivers/openai.lua b/kong/ai/embeddings/drivers/openai.lua new file mode 100644 index 000000000000..e08a27700f5f --- /dev/null +++ b/kong/ai/embeddings/drivers/openai.lua @@ -0,0 +1,99 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- imports +-- + +local cjson = require("cjson.safe") +local http = require("resty.http") + +local deep_copy = require("kong.tools.table").deep_copy +local gzip = require("kong.tools.gzip") + +-- +-- vars +-- + +local embeddings_url = "https://api.openai.com/v1/embeddings" + +-- +-- driver object +-- + +-- Driver is an interface for a openai embeddings driver. +local Driver = {} +Driver.__index = Driver + +-- Constructs a new Driver +-- +-- @param provided_embeddings_config embeddings driver configuration +-- @param dimensions the number of dimensions for generating embeddings +-- @return the Driver object +function Driver:new(provided_embeddings_config, dimensions) + local driver_config = deep_copy(provided_embeddings_config) + driver_config.dimensions = dimensions + return setmetatable(driver_config, Driver) +end + +-- Generates the embeddings (vectors) for a given prompt +-- +-- @param prompt the prompt to generate embeddings for +-- @return the API response containing the embeddings +-- @return nothing. throws an error if any +function Driver:generate(prompt) + -- prepare prompt for embedding generation + local body, err = cjson.encode({ + input = prompt, + dimensions = self.dimensions, + model = self.model, + }) + if err then + return nil, err + end + + kong.log.debug("[openai] generating embeddings for prompt") + local httpc = http.new({ + ssl_verify = true, + ssl_cafile = kong.configuration.lua_ssl_trusted_certificate_combined, + }) + local res, err = httpc:request_uri(embeddings_url, { + method = "POST", + headers = { + ["Content-Type"] = "application/json", + ["Accept-Encoding"] = "gzip", -- explicitly set because OpenAI likes to change this + ["Authorization"] = self.auth.token, + }, + body = body, + }) + if not res then + return nil, string.format("failed to generate embeddings (%s): %s", embeddings_url, err) + end + if res.status ~= 200 then + return nil, string.format("unexpected embeddings response (%s): %s", embeddings_url, res.status) + end + + -- decompress the embeddings response + local inflated_body = gzip.inflate_gzip(res.body) + local embedding_response, err = cjson.decode(inflated_body) + if err then + return nil, err + end + + -- validate if there are embeddings in the response + if #embedding_response.data == 0 then + return nil, "no embeddings found in response" + end + + return embedding_response.data[1].embedding, nil +end + +-- +-- module +-- + +return Driver diff --git a/kong/ai/embeddings/init.lua b/kong/ai/embeddings/init.lua new file mode 100644 index 000000000000..d42bfbdc7f53 --- /dev/null +++ b/kong/ai/embeddings/init.lua @@ -0,0 +1,49 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- private vars +-- + +local supported_embeddings = { + openai = "kong.ai.embeddings.drivers.openai", + mistralai = "kong.ai.embeddings.drivers.mistralai", +} + +-- +-- public functions +-- + +-- Initializes the appropriate embedding driver given its name. +-- +-- @param embeddings_config the configuration for embeddings +-- @param dimensions the number of dimensions for generating embeddings +-- @return the driver module +-- @return nothing. throws an error if any +local function new(embeddings_config, dimensions) + local driver_name = embeddings_config.driver + if not driver_name then + return nil, "empty name provided for embeddings driver" + end + + local driver_modname = supported_embeddings[driver_name] + if not driver_modname then + return nil, string.format("unsupported embeddings driver: %s", driver_name) + end + + local driver_mod = require(driver_modname) + return driver_mod:new(embeddings_config, dimensions), nil +end + +-- +-- module +-- + +return { + -- functions + new = new +} diff --git a/kong/ai/semantic_cache/drivers/redis.lua b/kong/ai/semantic_cache/drivers/redis.lua new file mode 100644 index 000000000000..c9043450e450 --- /dev/null +++ b/kong/ai/semantic_cache/drivers/redis.lua @@ -0,0 +1,113 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- imports +-- + +local deep_copy = require("kong.tools.table").deep_copy + +local index = require("kong.ai.vector_databases.drivers.redis.index") +local redis = require("kong.ai.vector_databases.drivers.redis.client") +local vectors = require("kong.ai.vector_databases.drivers.redis.vectors") +local utils = require("kong.ai.semantic_cache.utils") + +--- +--- private functions +--- + +-- Performs setup of the Redis database, including things like creating +-- indexes needed for vector search. +-- +-- @param driver_config the configuration for the driver +-- @return boolean indicating success +-- @return nothing. throws an error if any +local function database_setup(driver_config) + kong.log.debug("[redis] creating index") + local index_name = utils.full_index_name(driver_config.index) + local prefix = driver_config.index + local succeded, err = index.create( + driver_config.red, + index_name, + prefix, + driver_config.dimensions, + driver_config.distance_metric + ) + if err then + return false, err + end + + if not succeded then + return false, "failed to create index" + end + + return true, nil +end + +--- +--- driver object +--- + +-- Driver is an interface for a redis database. +local Driver = {} +Driver.__index = Driver + +-- Constructs a new Driver +-- +-- @param provided_driver_config the configuration for the driver +-- @return the Driver object +-- @return nothing. throws an error if any +function Driver:new(provided_driver_config) + local driver_config = deep_copy(provided_driver_config) + + local red, err = redis.create(driver_config) + if err then + return false, err + end + driver_config.red = red + + local _, err = database_setup(driver_config) + if err then + return nil, err + end + + return setmetatable(driver_config, Driver), nil +end + +-- Retrieves a cache entry for a given vector. +-- +-- @param vector the vector to search +-- @param threshold the proximity threshold for results +-- @return the cache payload, if any +-- @return nothing. throws an error if any +function Driver:get_cache(vector, threshold) + if not threshold then + threshold = self.default_threshold + end + + local index_name = utils.full_index_name(self.index) + + return vectors.search(self.red, index_name, vector, threshold) +end + +-- Insert a cache entry for a given vector and payload. +-- Generates a unique cache key is the format of :. +-- +-- @param vector the vector to search +-- @param payload the payload to be cached +-- @return boolean indicating success +-- @return nothing. throws an error if any +function Driver:set_cache(vector, payload) + local key = utils.cache_key(self.index) + return vectors.create(self.red, key, vector, payload) +end + +-- +-- module +-- + +return Driver diff --git a/kong/ai/semantic_cache/init.lua b/kong/ai/semantic_cache/init.lua new file mode 100644 index 000000000000..02ac8e43b2b4 --- /dev/null +++ b/kong/ai/semantic_cache/init.lua @@ -0,0 +1,47 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- private vars +-- + +local supported_vector_databases = { + redis = "kong.ai.semantic_cache.drivers.redis", +} + +-- +-- public functions +-- + +-- Initializes the appropriate vector database driver given its name. +-- +-- @param vectordb_config the configuration for the vector database driver +-- @return the driver module +-- @return nothing. throws an error if any +local function new(vectordb_config) + local driver_name = vectordb_config and vectordb_config.driver + if not driver_name then + return nil, "empty name provided for vector database driver" + end + + local driver_modname = supported_vector_databases[driver_name] + if not driver_modname then + return nil, string.format("unsupported vector database driver: %s", driver_name) + end + + local driver_mod = require(driver_modname) + return driver_mod:new(vectordb_config) +end + +-- +-- module +-- + +return { + -- functions + new = new +} diff --git a/kong/ai/semantic_cache/utils.lua b/kong/ai/semantic_cache/utils.lua new file mode 100644 index 000000000000..d6d80c2072d1 --- /dev/null +++ b/kong/ai/semantic_cache/utils.lua @@ -0,0 +1,53 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- imports +-- + +local utils = require("kong.tools.utils") + +-- +-- public functions +-- + +-- Given a simple name generates the full and opinionated name that we use as +-- a standard for all indexes managed by this driver. +-- +-- @param index the name of the index +-- @return the full index name +local function full_index_name(index) + return "idx:" .. index .. "_vss" +end + +-- Returns a cache key for a given index. This is our opinioned way to store +-- semantic caching keys in the cache and to make them unique. +-- +-- e.g. "kong_aigateway_semantic_cache:609594e6-9dee-410a-a9ea-a87745da8160" +-- +-- The UUID can be provided for formatting an existing known key. +-- +-- @param index the name of the index +-- @param uuid (optional) a known UUID to format the key +-- @return the unique cache key +local function cache_key(index, uuid) + if not uuid then + return index .. ":" .. utils.uuid() + end + + return index .. ":" .. uuid +end + +-- +-- module +-- + +return { + -- functions + full_index_name = full_index_name, + cache_key = cache_key, +} diff --git a/kong/ai/typedefs.lua b/kong/ai/typedefs.lua new file mode 100644 index 000000000000..c29031534f1d --- /dev/null +++ b/kong/ai/typedefs.lua @@ -0,0 +1,156 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- typedefs +-- + +-- the authentication configuration for the vector database. +local auth = { + type = "record", + required = false, + fields = { + { + password = { + type = "string", + description = "authentication password", + required = false, + }, + }, + { + token = { + type = "string", + description = "authentication token", + required = false, + }, + }, + }, +} + +-- the configuration for embeddings, which are the vector representations of +-- inference prompts. +local embeddings = { + type = "record", + required = true, + fields = { + { auth = auth }, + { + driver = { + type = "string", + description = "which driver to use for embeddings", + required = true, + one_of = { + "mistralai", + "openai", + }, + }, + }, + { + model = { + type = "string", + description = "which AI model to use for generating embeddings", + required = true, + one_of = { + -- openai + "text-embedding-3-large", + "text-embedding-3-small", + -- mistralai + "mistral-embed", + }, + }, + }, + }, +} + +-- the TLS configuration for the vector database +local tls = { + type = "record", + required = false, + fields = { + { + ssl = { + type = "boolean", + description = "require TLS communication", + required = false, + default = true, + }, + }, + { + ssl_verify = { + type = "boolean", + description = "verify SSL certificates during TLS", + required = false, + default = true, + }, + }, + } +} + +-- the Vector Database configuration +local vectordb = { + type = "record", + required = true, + fields = { + { auth = auth }, + { tls = tls }, + { + driver = { + type = "string", + description = "which vector database driver to use", + required = true, + one_of = { "redis" }, + }, + }, + { + url = { + type = "string", + description = "the URL endpoint to reach the vector database", + required = true, + }, + }, + { + index = { + type = "string", + description = "the name of the index by which vectors can be searched (relevant for redis)", + required = false, + default = "kong_aigateway", + }, + }, + { + dimensions = { + type = "integer", + description = "the desired dimensionality for the vectors", + required = true, + }, + }, + { + default_threshold = { + type = "number", + description = "the default similarity threshold for accepting semantic search results (float)", + required = true, + }, + }, + { + distance_metric = { + type = "string", + description = "the distance metric to use for vector searches", + required = true, + one_of = { "COSINE", "EUCLIDEAN" }, + }, + }, + }, +} + +-- +-- module +-- + +return { + -- typedefs + embeddings = embeddings, + vectordb = vectordb, +} diff --git a/kong/ai/vector_databases/drivers/redis/client.lua b/kong/ai/vector_databases/drivers/redis/client.lua new file mode 100644 index 000000000000..b235a8a276ac --- /dev/null +++ b/kong/ai/vector_databases/drivers/redis/client.lua @@ -0,0 +1,78 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- imports +-- + +local redis = require("resty.redis") +local urls = require("socket.url") + +-- +-- private vars +-- + +local REDIS_TIMEOUT = 1000 + +-- +-- public functions +-- + +-- Initialize a Redis client and verify connectivity. +-- +-- @param opts an options table including things like the URL, auth, and TLS configuration +-- @return the Redis client +-- @return nothing. throws an error if any +local function create(opts) + local url = opts.url + if not url then + return nil, "missing URL" + end + + local url, err = urls.parse(url) + if err then + return nil, err + end + + local red = redis:new() + red:set_timeouts(REDIS_TIMEOUT, REDIS_TIMEOUT, REDIS_TIMEOUT) + local redis_options = { + ssl = true, + ssl_verify = true, + } + if opts.tls then + redis_options.ssl = opts.tls.ssl + redis_options.ssl_verify = opts.tls.ssl_verify + end + + local _, err = red:connect(url.host, tonumber(url.port), redis_options) + if err then + return red, err + end + + if opts.auth and opts.auth.password then + local ok, err = red:auth(opts.auth.password) + if err then + return red, err + end + if not ok then + return red, "failed to authenticate" + end + end + + local _, err = red:ping() + return red, err +end + +-- +-- module +-- + +return { + -- functions + create = create +} diff --git a/kong/ai/vector_databases/drivers/redis/index.lua b/kong/ai/vector_databases/drivers/redis/index.lua new file mode 100644 index 000000000000..ba945407986f --- /dev/null +++ b/kong/ai/vector_databases/drivers/redis/index.lua @@ -0,0 +1,65 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- public functions +-- + +-- Creates a new opinionated index in Redis for vector search, unless it already exists. +-- +-- This will specifically create an index on a field called $.vector which +-- will need to be present in any cache entry searched via this index. +-- +-- @param red the initialized Redis client +-- @param index the name of the index to create +-- @param prefix the prefix to use for the index +-- @param dimensions the number of dimensions in the vector +-- @param metric the distance metric to use for vector search +-- @return boolean indicating success +-- @return nothing. throws an error if any +local function create(red, index, prefix, dimensions, metric) + local res, err = red["FT.CREATE"](red, + index, "ON", "JSON", + "PREFIX", "1", prefix .. ":", "SCORE", "1.0", + "SCHEMA", "$.vector", "AS", "vector", + "VECTOR", "FLAT", "6", "TYPE", "FLOAT32", + "DIM", dimensions, + "DISTANCE_METRIC", metric + ) + if err and err ~= "Index already exists" then + return false, err + end + + kong.log.debug("[redis] index " .. (res and "created" or "already exists")) + return true, nil +end + +-- Deletes an index in Redis for vector search. +-- +-- @param red the initialized Redis client +-- @param index the name of the index to delete +-- @return boolean indicating success +-- @return nothing. throws an error if any +local function delete(red, index) + kong.log.debug("[redis] deleting index") + local _, err = red["FT.DROPINDEX"](red, index) + if err then + return false, err + end + + return true, nil +end + +-- +-- module +-- + +return { + -- functions + create = create, + delete = delete, +} diff --git a/kong/ai/vector_databases/drivers/redis/vectors.lua b/kong/ai/vector_databases/drivers/redis/vectors.lua new file mode 100644 index 000000000000..f0dc26714869 --- /dev/null +++ b/kong/ai/vector_databases/drivers/redis/vectors.lua @@ -0,0 +1,147 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- imports +-- + +local cjson = require("cjson.safe") +local ffi = require("ffi") + +-- +-- private functions +-- + +-- Converts a given vector into a byte string. +-- +-- It is currently required by Redis that vectors sent in with FT.SEARCH need +-- to be in a byte string format. We have to use their commands interface +-- directly (since Lua client support for Redis is limited at the time of +-- writing). They do this in their Python client by storing the vector as a +-- numpy array with float32 precision and then converting it to a byte string, +-- e.g.: +-- +-- vector = [0.1, 0.2, 0.3] +-- array = numpy.array(vector, dtype=numpy.float32) +-- bytes = array.tobytes() +-- +-- This function produces equivalent output, and is a bit of a hack. Ideally in +-- the future a higher level vector search API will be available in Redis so +-- we don't have to do this. +-- +-- @param vector the vector to encode to bytes +-- @return the byte string representation of the vector +local function convert_vector_to_bytes(vector) + local float_array = ffi.new("float[?]", #vector, unpack(vector)) + return ffi.string(float_array, ffi.sizeof(float_array)) +end + +-- Sets a cache entry in Redis. +-- +-- @param red the initialized Redis client +-- @param key the cache key to set +-- @param payload the cache payload to set, as a table +-- @return boolean indicating success +-- @return nothing. throws an error if any +local function json_set(red, key, payload) + local json_payload, err = cjson.encode(payload) + if err then + return err + end + return red["JSON.SET"](red, key, "$", json_payload) +end + +-- +-- public functions +-- + +-- Inserts a cache payload into Redis with an associated vector. +-- +-- @param red the initialized Redis client +-- @param key the cache key to use +-- @param vector the vector to associate with the cache +-- @param payload the cache payload to insert +-- @return boolean indicating success +-- @return nothing. throws an error if any +local function create(red, key, vector, payload) + local decoded_payload, err = cjson.decode(payload) + if err then + return false, err + end + decoded_payload.vector = vector -- inserting the vector into the payload is required by redis + + local _, err = json_set(red, key, decoded_payload) + if err then + return false, err + end + + return true, nil +end + +-- Performs a vector search on the Redis cache. +-- +-- @param red the initialized Redis client +-- @param index the name of the index to search +-- @param vector the vector to search +-- @param threshold the proximity threshold for results +-- @return the search results, if any +-- @return an error message, if any +local function search(red, index, vector, threshold) + kong.log.debug("[redis] performing vector search with threshold ", threshold) + local res, err = red["FT.SEARCH"](red, index, + "@vector:[VECTOR_RANGE $range $query_vector]=>{$YIELD_DISTANCE_AS: vector_score}", + "SORTBY", "vector_score", "DIALECT", "2", "LIMIT", "0", "4", "PARAMS", "4", "query_vector", + convert_vector_to_bytes(vector), + "range", threshold + ) + if err then + return nil, err + end + + -- Redis will return nothing when there are no keys in the prefix + if #res == 0 then + return + end + + -- Redis will return a 0 when keys were found in the index prefix, but none matched + if res[1] == 0 then + return + end + + local nested_table = res[3] + if not nested_table then + return nil, "unexpected search response: no value found in result set" + end + + local json_payload = nested_table[4] + if not json_payload then + return nil, "unexpected search response: no JSON payload found in result set" + end + + local decoded_payload, err = cjson.decode(json_payload) + if err then + return nil, err + end + + -- redis requires that the vector be stored in the cache, but we don't want to return that to the user. + -- we might consider later whether we would store the cache payload nested and adjacent and use another + -- mechanism in the search to retrieve it without the vector. + decoded_payload.vector = nil + + kong.log.debug("[redis] result found with score ", nested_table[2]) + return decoded_payload, nil +end + +-- +-- module +-- + +return { + -- functions + create = create, + search = search, +} diff --git a/spec/01-unit/30-ai/01-embeddings/01-openai_spec.lua b/spec/01-unit/30-ai/01-embeddings/01-openai_spec.lua new file mode 100644 index 000000000000..e3cd81a4e452 --- /dev/null +++ b/spec/01-unit/30-ai/01-embeddings/01-openai_spec.lua @@ -0,0 +1,47 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- imports +-- + +local openai_mock = require("spec.helpers.ai.openai_mock") + +local known_text_embeddings = require("spec.helpers.ai.embeddings_mock").known_text_embeddings + +-- +-- test setup +-- + +-- initialize kong.global (so logging works, e.t.c.) +local kong_global = require "kong.global" +_G.kong = kong_global.new() +kong_global.init_pdk(kong, nil) + +-- +-- tests +-- + +describe("[openai]", function() + describe("embeddings:", function() + it("can generate embeddings", function() + openai_mock.setup(finally) + local embeddings, err = require("kong.ai.embeddings").new({ + driver = "openai", + model = "text-embedding-3-small", + auth = { token = "fake" }, + }, 4) + assert.is_nil(err) + + for prompt, embedding in pairs(known_text_embeddings) do + local found_embedding, err = embeddings:generate(prompt) + assert.is_nil(err) + assert.are.same(embedding, found_embedding) + end + end) + end) +end) diff --git a/spec/01-unit/30-ai/01-embeddings/02-mistralai_spec.lua b/spec/01-unit/30-ai/01-embeddings/02-mistralai_spec.lua new file mode 100644 index 000000000000..5e8d25b7dfbd --- /dev/null +++ b/spec/01-unit/30-ai/01-embeddings/02-mistralai_spec.lua @@ -0,0 +1,47 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- imports +-- + +local mistralai_mock = require("spec.helpers.ai.mistralai_mock") + +local known_text_embeddings = require("spec.helpers.ai.embeddings_mock").known_text_embeddings + +-- +-- test setup +-- + +-- initialize kong.global (so logging works, e.t.c.) +local kong_global = require "kong.global" +_G.kong = kong_global.new() +kong_global.init_pdk(kong, nil) + +-- +-- tests +-- + +describe("[mistralai]", function() + describe("embeddings:", function() + it("can generate embeddings", function() + mistralai_mock.setup(finally) + local embeddings, err = require("kong.ai.embeddings").new({ + driver = "mistralai", + model = "mistral-embed", + auth = { token = "fake" }, + }, 4) + assert.is_nil(err) + + for prompt, embedding in pairs(known_text_embeddings) do + local found_embedding, err = embeddings:generate(prompt) + assert.is_nil(err) + assert.are.same(embedding, found_embedding) + end + end) + end) +end) diff --git a/spec/01-unit/30-ai/02-vector_databases/01-redis_spec.lua b/spec/01-unit/30-ai/02-vector_databases/01-redis_spec.lua new file mode 100644 index 000000000000..d4cb554ab03a --- /dev/null +++ b/spec/01-unit/30-ai/02-vector_databases/01-redis_spec.lua @@ -0,0 +1,188 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- imports +-- + +local cjson = require("cjson") + +local redis_mock = require("spec.helpers.ai.redis_mock") + +-- +-- test setup +-- + +-- initialize kong.global (so logging works, e.t.c.) +local kong_global = require "kong.global" +_G.kong = kong_global.new() +kong_global.init_pdk(kong, nil) + +-- +-- test data +-- + +local fake_redis_url = "redis://localhost:6379" +local default_distance_metric = "EUCLIDEAN" +local default_threshold = 0.3 +local test_prefix = "test_prefix" +local test_indexes = { + "test_index1", + "test_index2", + "test_index3", +} +local test_vectors = { + { 1.0, 1.1, -1.1, 3.4 }, + { 1.1, 1.2, -1.1, -0.5 }, + { 5.6, -5.5, -1.6, -0.2 }, +} +local test_vectors_for_search = { + { 1.1, 1.2, -1.0, 3.4 }, -- is in close proximity to test_vectors[1] (threshold 0.3 will hit) + { 100.6, 88.4, -20.5, -5.5 }, -- no close proximity (threshold 0.3 will miss, but 500.0 will hit) + { 99.6, 42.0, -10.5, -128.9 }, -- no close proximity (threshold 0.3 will miss, but 500.0 will hit) +} +local test_payloads = { + [[{"message":"test_payload1"}]], + [[{"message":"test_payload2"}]], + [[{"message":"test_payload3"}]], +} + +-- +-- tests +-- + +describe("[redis vectordb]", function() + describe("client:", function() + it("initializes", function() + redis_mock.setup(finally) + local red, err = require("kong.ai.vector_databases.drivers.redis.client").create({ url = fake_redis_url }) + assert.is_nil(err) + + assert.not_nil(red.indexes) + assert.not_nil(red.cache) + assert.equal(0, #red.indexes) + assert.equal(0, #red.cache) + assert.equal(0, red.key_count) + end) + + it("fails to initialize if the server connection can't be made", function() + redis_mock.setup(finally) + local client = require("kong.ai.vector_databases.drivers.redis.client") + local redis = require("resty.redis") + local err_msg = "connection refused" + redis.forced_failure(err_msg) + + local _, err = client.create({ url = fake_redis_url }) + assert.equal(err_msg, err) + + redis.forced_failure(nil) + end) + end) + + describe("indexes:", function() + it("can manage indexes", function() + redis_mock.setup(finally) + local indexes = require("kong.ai.vector_databases.drivers.redis.index") + local red, err = require("kong.ai.vector_databases.drivers.redis.client").create({ url = fake_redis_url }) + assert.is_nil(err) + + -- creating indexes + for i = 1, #test_indexes do + local succeeded, err = indexes.create(red, test_indexes[i], test_prefix, #test_vectors[1], + default_distance_metric) + assert.is_nil(err) + assert.is_true(succeeded) + end + + -- it should not fail for duplicate indexes + for i = 1, #test_indexes do + local succeeded, err = indexes.create(red, test_indexes[i], test_prefix, #test_vectors[1], + default_distance_metric) + assert.is_nil(err) + assert.is_true(succeeded) + end + + -- deleting indexes + for i = 1, #test_indexes do + local succeeded, err = indexes.delete(red, test_indexes[i]) + assert.is_nil(err) + assert.is_true(succeeded) + end + + -- can't delete non-existent indexes + for i = 1, #test_indexes do + local succeeded, err = indexes.delete(red, test_indexes[i]) + assert.equal("Index not found", err) + assert.is_false(succeeded) + end + end) + end) + + describe("vectors:", function() + it("can manage vectors", function() + redis_mock.setup(finally) + local indexes = require("kong.ai.vector_databases.drivers.redis.index") + local vectors = require("kong.ai.vector_databases.drivers.redis.vectors") + local red, err = require("kong.ai.vector_databases.drivers.redis.client").create({ url = fake_redis_url }) + assert.is_nil(err) + + -- create vectors + for i = 1, #test_indexes do + local succeeded, err = vectors.create(red, test_indexes[i], test_vectors[i], test_payloads[i]) + assert.is_nil(err) + assert.is_true(succeeded) + end + + -- disallow duplicates + for i = 1, #test_indexes do + local succeeded, err = vectors.create(red, test_indexes[i], test_vectors[i], test_payloads[i]) + assert.equal("Already exists", err) + assert.is_false(succeeded) + end + + -- fails on non-existent indexes + local results, err = vectors.search(red, "non_existent_index", test_vectors[1], default_threshold) + assert.is_nil(results) + assert.equal("Index not found", err) + + -- search for vectors that have immediate matches + local succeeded, err = indexes.create(red, test_indexes[1], test_prefix, #test_vectors[1], default_distance_metric) + assert.is_nil(err) + assert.is_true(succeeded) + for i = 1, #test_vectors do + local results, err = vectors.search(red, test_indexes[1], test_vectors[i], default_threshold) + assert.is_nil(err) + assert.equal(default_threshold, red.last_threshold_received) + assert.is_not_nil(results) + assert.equal(test_payloads[i], cjson.encode(results)) + end + + -- search for vectors in close proximity + local vector_known_to_have_another_close_vector = test_vectors_for_search[1] + local results, err = vectors.search(red, test_indexes[1], vector_known_to_have_another_close_vector, + default_threshold) + assert.is_nil(err) + assert.is_not_nil(results) + assert.equal(test_payloads[1], cjson.encode(results)) + + -- cache miss when there are no vectors in close proximity + for i = 2, 3 do + local results, err = vectors.search(red, test_indexes[1], test_vectors_for_search[i], default_threshold) + assert.is_nil(err) + assert.is_nil(results) + end + + -- cache hit for distant vectors if you crank up the threshold + local crazy_threshold = 500.0 + for i = 2, 3 do + local results, err = vectors.search(red, test_indexes[1], test_vectors_for_search[i], crazy_threshold) + assert.is_nil(err) + assert.is_not_nil(results) + end + end) + end) +end) diff --git a/spec/01-unit/30-ai/03-semantic_cache/01-utils_spec.lua b/spec/01-unit/30-ai/03-semantic_cache/01-utils_spec.lua new file mode 100644 index 000000000000..8ca1e8dd06f8 --- /dev/null +++ b/spec/01-unit/30-ai/03-semantic_cache/01-utils_spec.lua @@ -0,0 +1,45 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- imports +-- + +local utils = require("kong.ai.semantic_cache.utils") + +-- +-- private vars +-- + +local uuid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" +local test_uuid = "ac82f632-1475-449a-a670-8e36a3df2014" + +-- +-- tests +-- + +describe("[utils]", function() + describe("generators:", function() + it("generates full index names", function() + assert.is.equal("idx:test_index1_vss", utils.full_index_name("test_index1")) + assert.is.equal("idx:test_index2_vss", utils.full_index_name("test_index2")) + assert.is.equal("idx:test_index3_vss", utils.full_index_name("test_index3")) + end) + + it("generates cache keys", function() + assert.is.truthy(ngx.re.find(utils.cache_key("test_index1"), "test_index1:" .. uuid_regexp .. "$")) + assert.is.truthy(ngx.re.find(utils.cache_key("test_index2"), "test_index2:" .. uuid_regexp .. "$")) + assert.is.truthy(ngx.re.find(utils.cache_key("test_index3"), "test_index3:" .. uuid_regexp .. "$")) + end) + + it("can be optionally given an existing uuid", function() + assert.equal("test_index1:" .. test_uuid, utils.cache_key("test_index1", test_uuid)) + assert.equal("test_index2:" .. test_uuid, utils.cache_key("test_index2", test_uuid)) + assert.equal("test_index3:" .. test_uuid, utils.cache_key("test_index3", test_uuid)) + end) + end) +end) diff --git a/spec/01-unit/30-ai/03-semantic_cache/02-redis_spec.lua b/spec/01-unit/30-ai/03-semantic_cache/02-redis_spec.lua new file mode 100644 index 000000000000..b0d07569a77a --- /dev/null +++ b/spec/01-unit/30-ai/03-semantic_cache/02-redis_spec.lua @@ -0,0 +1,193 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- imports +-- + +local cjson = require("cjson") + +local redis_mock = require("spec.helpers.ai.redis_mock") + +local redis_vectordb_utils = require("kong.ai.semantic_cache.utils") + +-- +-- test setup +-- + +-- initialize kong.global (so logging works, e.t.c.) +local kong_global = require "kong.global" +_G.kong = kong_global.new() +kong_global.init_pdk(kong, nil) + +-- +-- test data +-- + +local driver_name = "redis" +local fake_redis_url = "redis://localhost:6379" +local default_distance_metric = "EUCLIDEAN" +local default_threshold = 0.3 +local test_indexes = { + "test_index1", + "test_index2", +} +local test_vectors = { + { 1.0, 1.1, -1.1, 3.4 }, + { 1.1, 1.2, -1.1, -0.5 }, + { 5.6, -5.5, -1.6, -0.2 }, +} +local test_vectors_for_search = { + { 1.1, 1.2, -1.0, 3.4 }, -- is in close proximity to test_vectors[1] (threshold 0.3 will hit) + { 100.6, 88.4, -20.5, -5.5 }, -- no close proximity (threshold 0.3 will miss, but 500.0 will hit) + { 99.6, 42.0, -10.5, -128.9 }, -- no close proximity (threshold 0.3 will miss, but 500.0 will hit) +} +local test_payloads = { + [[{"message":"test_payload1"}]], + [[{"message":"test_payload2"}]], + [[{"message":"test_payload3"}]], +} + +-- +-- tests +-- + +describe("[redis semantic cache]", function() + describe("driver:", function() + it("initializes", function() + redis_mock.setup(finally) + local driver, err = require("kong.ai.semantic_cache").new({ + driver = driver_name, + url = fake_redis_url, + index = test_indexes[1], + dimensions = #test_vectors[1], + distance_metric = default_distance_metric, + default_threshold = default_threshold, + }) + assert.is_nil(err) + + -- check driver initialization + assert.is_not_nil(driver) + assert.equal(driver_name, driver.driver) + assert.equal(fake_redis_url, driver.url) + assert.equal(test_indexes[1], driver.index) + assert.equal(#test_vectors[1], driver.dimensions) + assert.equal(default_distance_metric, driver.distance_metric) + assert.equal(default_distance_metric, driver.red.indexes[redis_vectordb_utils.full_index_name(test_indexes[1])]) + end) + + it("fails to initialze a driver without a valid index", function() + redis_mock.setup(finally) + local driver, err = require("kong.ai.semantic_cache").new({ + driver = driver_name, + url = fake_redis_url, + index = "", -- invalid index + dimensions = #test_vectors[1], + distance_metric = default_distance_metric, + default_threshold = default_threshold, + }) + + -- driver should fail to initialize + assert.equal("Invalid index name", err) + assert.is_nil(driver) + end) + + + + it("can manage cache", function() + redis_mock.setup(finally) + local driver, err = require("kong.ai.semantic_cache").new({ + driver = driver_name, + url = fake_redis_url, + index = test_indexes[1], + dimensions = #test_vectors[1], + distance_metric = default_distance_metric, + default_threshold = default_threshold, + }) + assert.is_nil(err) + assert.is_not_nil(driver) + + -- insert several cache entries + for i = 1, #test_vectors do + local succeeded, err = driver:set_cache(test_vectors[i], test_payloads[i]) + assert.is_nil(err) + assert.is_true(succeeded) + end + assert.equal(3, driver.red.key_count) + + -- should tolerate redundant cache entries + for i = 1, #test_vectors do + local succeeded, err = driver:set_cache(test_vectors[i], test_payloads[i]) + assert.is_nil(err) + assert.is_true(succeeded) + end + assert.equal(6, driver.red.key_count) + + -- should use the default threshold for cache searches when not otherwise prompted + assert.equal(0.0, driver.red.last_threshold_received) + for i = 1, #test_vectors do + local results, err = driver:get_cache(test_vectors[i]) + assert.is_nil(err) + assert.equal(default_threshold, driver.red.last_threshold_received) + assert.is_not_nil(results) + assert.equal(test_payloads[i], cjson.encode(results)) + end + + -- allows a threshold override + local threshold = 0.1 + assert.equal(default_threshold, driver.red.last_threshold_received) + for i = 1, #test_vectors do + local results, err = driver:get_cache(test_vectors[i], threshold) + assert.is_nil(err) + assert.equal(threshold, driver.red.last_threshold_received) + assert.is_not_nil(results) + assert.equal(test_payloads[i], cjson.encode(results)) + end + + -- can search and find a close proximity vector when there's no direct match + local vector_known_to_have_another_close_vector = test_vectors_for_search[1] + local results, err = driver:get_cache(vector_known_to_have_another_close_vector) + assert.is_nil(err) + assert.is_not_nil(results) + assert.equal(test_payloads[1], cjson.encode(results)) + + -- will receive a cache miss for vector searches where no other vectors are even close + for i = 2, 3 do + local results, err = driver:get_cache(test_vectors_for_search[i]) + assert.is_nil(err) + assert.is_nil(results) + end + + -- will receive a cache hit for very distant vectors if you crank up the threshold + local crazy_threshold = 500.0 + for i = 2, 3 do + local results, err = driver:get_cache(test_vectors_for_search[i], crazy_threshold) + assert.is_nil(err) + assert.is_not_nil(results) + end + + -- will return an error if there are connection issues + local redis = require("resty.redis") + local err_msg = "connection refused" + redis.forced_failure(err_msg) + + for i = 1, #test_vectors do + local succeeded, err = driver:set_cache(test_vectors[i], test_payloads[i]) + assert.equal(err_msg, err) + assert.is_false(succeeded) + end + + for i = 1, #test_vectors do + local results, err = driver:get_cache(test_vectors[i]) + assert.is_nil(results) + assert.equal(err_msg, err) + end + + redis.forced_failure(nil) + end) + end) +end) diff --git a/spec/helpers/ai/embeddings_mock.lua b/spec/helpers/ai/embeddings_mock.lua new file mode 100644 index 000000000000..1118177055fd --- /dev/null +++ b/spec/helpers/ai/embeddings_mock.lua @@ -0,0 +1,93 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- imports +-- + +local cjson = require("cjson") +local gzip = require("kong.tools.gzip") + +-- +-- public vars +-- + +-- some previously generated text embeddings for mocking, using OpenAI's +-- text-embedding-3-small model and 4 dimensions. +local known_text_embeddings = { + ["dog"] = { 0.56267416, -0.20551957, -0.047182854, 0.79933304 }, + ["cat"] = { 0.4653789, -0.42677408, -0.29335415, 0.717795 }, + ["capacitor"] = { 0.350534, -0.025470039, -0.9204002, -0.17129119 }, + ["smell"] = { 0.23342973, -0.08322083, -0.8492907, -0.46614397 }, + ["Non-Perturbative Quantum Field Theory and Resurgence in Supersymmetric Gauge Theories"] = { + -0.6826024, -0.08655233, -0.72073454, -0.084287055, + }, + ["taco"] = { -0.4407651, -0.85174876, -0.27901474, -0.048999753 }, +} + +-- +-- public functions +-- + +local function mock_embeddings(opts) + if opts.method ~= "POST" then + return nil, "Only POST method is supported" + end + + if opts.headers["Content-Type"] ~= "application/json" then + return nil, "Only application/json content type is supported" + end + + if opts.headers["Accept-Encoding"] ~= "gzip" then + return nil, "Only gzip encoding is supported" + end + + if not opts.headers["Authorization"] then + return nil, "Authorization header is required" + end + + local request_body = cjson.decode(opts.body) + + if not request_body.dimensions then + request_body.dimensions = 4 + end + if request_body.dimensions ~= 4 then + return nil, "Only 4 dimensions are supported" + end + + local prompt = request_body.input + local embedding = known_text_embeddings[prompt] + if not embedding then + return nil, "Invalid prompt" + end + + local response_body = { + data = { + { embedding = embedding }, + }, + } + + local encoded_response_body = cjson.encode(response_body) + local gzipped_response_body = gzip.deflate_gzip(encoded_response_body) + + return { + status = 200, + body = gzipped_response_body, + } +end + +-- +-- module +-- + +return { + -- vars + known_text_embeddings = known_text_embeddings, + + -- functions + mock_embeddings = mock_embeddings, +} diff --git a/spec/helpers/ai/mistralai_mock.lua b/spec/helpers/ai/mistralai_mock.lua new file mode 100644 index 000000000000..724c8739574d --- /dev/null +++ b/spec/helpers/ai/mistralai_mock.lua @@ -0,0 +1,64 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- imports +-- + +local mocker = require("spec.fixtures.mocker") + +local mock_embeddings = require("spec.helpers.ai.embeddings_mock").mock_embeddings + +-- +-- private vars +-- + +local api = "https://api.mistral.ai" +local embeddings_url = api .. "/v1/embeddings" + +-- +-- private functions +-- + +local mock_request_router = function(_self, url, opts) + if not string.find("^" .. url, api) then + return nil, "what are you doing?" + end + + if url == embeddings_url then + return mock_embeddings(opts) + end + + return nil, "URL " .. url .. " is not supported by mocking" +end + +-- +-- public functions +-- + +local function setup(finally) + mocker.setup(finally, { + modules = { + { "resty.http", { + new = function() + return { + request_uri = mock_request_router, + } + end, + } }, + } + }) +end + +-- +-- module +-- + +return { + -- functions + setup = setup, +} diff --git a/spec/helpers/ai/openai_mock.lua b/spec/helpers/ai/openai_mock.lua new file mode 100644 index 000000000000..1366c9d6f73a --- /dev/null +++ b/spec/helpers/ai/openai_mock.lua @@ -0,0 +1,64 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- imports +-- + +local mocker = require("spec.fixtures.mocker") + +local mock_embeddings = require("spec.helpers.ai.embeddings_mock").mock_embeddings + +-- +-- private vars +-- + +local api = "https://api.openai.com" +local embeddings_url = api .. "/v1/embeddings" + +-- +-- private functions +-- + +local mock_request_router = function(_self, url, opts) + if not string.find("^" .. url, api) then + return nil, "what are you doing?" + end + + if url == embeddings_url then + return mock_embeddings(opts) + end + + return nil, "URL " .. url .. " is not supported by mocking" +end + +-- +-- public functions +-- + +local function setup(finally) + mocker.setup(finally, { + modules = { + { "resty.http", { + new = function() + return { + request_uri = mock_request_router, + } + end, + } }, + } + }) +end + +-- +-- module +-- + +return { + -- functions + setup = setup, +} diff --git a/spec/helpers/ai/redis_mock.lua b/spec/helpers/ai/redis_mock.lua new file mode 100644 index 000000000000..464bc930dee7 --- /dev/null +++ b/spec/helpers/ai/redis_mock.lua @@ -0,0 +1,289 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +-- +-- imports +-- + +local cjson = require("cjson.safe") +local ffi = require("ffi") + +local mocker = require("spec.fixtures.mocker") + +-- +-- private vars +-- + +-- the error message to force on the next Redis call +local forced_error_msg = nil + +-- +-- private functions +-- + +-- the default precision to round to during conversion +local default_precision = 1e-6 + +-- Redis requires a vector to be converted to a byte string, this function reverses +-- that process so that we can compare vectors. +-- +-- @param bytes the byte string to convert +-- @param precision the precision to round to (optional) +-- @return the vector +local function convert_bytes_to_vector(bytes, precision) + precision = precision or default_precision + local float_size = ffi.sizeof("float") + local num_floats = #bytes / float_size + local float_array = ffi.cast("float*", bytes) + local vector = {} + for i = 0, num_floats - 1 do + local value = float_array[i] + value = math.floor(value / precision + 0.5) * precision -- round to precision + table.insert(vector, value) + end + return vector +end + +-- Searches for the cosine distance between two vectors, and compares it +-- against a threshold. +-- +-- @param v1 the first vector +-- @param v2 the second vector +-- @param threshold the threshold to compare against +-- @return true if the vectors are within the threshold, false otherwise +-- @return the distance between the vectors +local function cosine_distance(v1, v2, threshold) + local dot_product = 0.0 + local magnitude_v1 = 0.0 + local magnitude_v2 = 0.0 + + for i = 1, #v1 do + dot_product = dot_product + v1[i] * v2[i] + magnitude_v1 = magnitude_v1 + v1[i] ^ 2 + magnitude_v2 = magnitude_v2 + v2[i] ^ 2 + end + + magnitude_v1 = math.sqrt(magnitude_v1) + magnitude_v2 = math.sqrt(magnitude_v2) + + local cosine_similarity = dot_product / (magnitude_v1 * magnitude_v2) + local cosine_distance = 1 - cosine_similarity + + return cosine_distance <= threshold, cosine_distance +end + +-- Searches for the euclidean distance between two vectors, and compares it +-- against a threshold. +-- +-- @param v1 the first vector +-- @param v2 the second vector +-- @param threshold the threshold to compare against +-- @return true if the vectors are within the threshold, false otherwise +-- @return the distance between the vectors +local function euclidean_distance(v1, v2, threshold) + local distance = 0.0 + for i = 1, #v1 do + distance = distance + (v1[i] - v2[i]) ^ 2 + end + + distance = math.sqrt(distance) + + return distance <= threshold, distance +end + +-- +-- public functions +-- + +local function setup(finally) + mocker.setup(finally, { + modules = { + { "resty.redis", { + new = function() + return { + -- function mocks + set_timeouts = function() end, + connect = function() + if forced_error_msg then + return false, forced_error_msg + end + end, + auth = function() + if forced_error_msg then + return false, forced_error_msg + end + end, + ping = function() + if forced_error_msg then + return false, forced_error_msg + end + end, + + -- raw command mocks + ["FT.CREATE"] = function(red, index, ...) + if forced_error_msg then + return false, forced_error_msg + end + + if not index or index == "idx:_vss" then + return false, "Invalid index name" + end + + -- gather the distance metric + local args = { ... } + local distance_metric = args[#args] + if distance_metric ~= "EUCLIDEAN" and distance_metric ~= "COSINE" then + return false, "Invalid distance metric" + end + + red.indexes[index] = distance_metric + return true, nil + end, + ["FT.DROPINDEX"] = function(red, index, ...) + if forced_error_msg then + return false, forced_error_msg + end + + if not red.indexes[index] then + return false, "Index not found" + end + + red.indexes[index] = nil + return true, nil + end, + ["FT.SEARCH"] = function(red, index, ...) + if forced_error_msg then + return nil, forced_error_msg + end + + -- verify whether the index for the search is valid, + -- and determine whether the index was configured + -- with euclidean or cosine distance + local distance_metric = red.indexes[index] + if not distance_metric then + return nil, "Index not found" + end + + -- determine the threshold, and record + local num_args = select("#", ...) + local threshold = select(num_args, ...) + red.last_threshold_received = threshold + + -- determine the vector + local vector_bytes = select(num_args - 2, ...) + local search_vector = convert_bytes_to_vector(vector_bytes) + + -- The caller can override the response with mock_next_search to set this next_response_key + -- and that will force a specific payload to be returned, if desired. + local payload = red.cache[red.next_response_key] + if payload then + -- reset the override + red.next_response_key = nil + + -- the structure Redis would respond with, but we only care about the proximity and payload + return { {}, {}, { {}, "1.0", {}, payload } } + end + + -- if the payload wasn't forced with an override, we'll do a vector search. + -- we won't try to fully emulate Redis' vector search but we can do a simple + -- distance comparison to emulate it. + local payloads = {} + for _key, value in pairs(red.cache) do + local decoded_payload, err = cjson.decode(value) + if err then + return nil, err + end + + -- check the proximity of the found vector + local found_vector = decoded_payload.vector + local proximity_match, distance + if distance_metric == "COSINE" then + proximity_match, distance = cosine_distance(search_vector, found_vector, threshold) + elseif distance_metric == "EUCLIDEAN" then + proximity_match, distance = euclidean_distance(search_vector, found_vector, threshold) + end + if proximity_match then + table.insert(payloads, { {}, tostring(distance), {}, value }) + end + end + + -- sort the payloads by distance + table.sort(payloads, function(a, b) + return tonumber(a[2]) < tonumber(b[2]) + end) + + -- if no payloads were found, just return an empty table to emulate cache miss + if #payloads < 1 then + return {} + end + + -- the structure Redis would respond with, but we only care about the proximity and payload + local res = { {}, {} } -- filler response information from Redis we don't use + for i = 1, #payloads do + table.insert(res, payloads[i]) + end + return res, nil + end, + ["JSON.GET"] = function(red, key) + if forced_error_msg then + return nil, forced_error_msg + end + + return red.cache[key], nil + end, + ["JSON.SET"] = function(red, key, _path, payload) -- currently, path is not used because we only set cache at root + if forced_error_msg then + return false, forced_error_msg + end + + if red.cache[key] ~= nil then + return false, "Already exists" + end + + red.key_count = red.key_count + 1 + red.cache[key] = payload + + return true, nil + end, + ["JSON.DEL"] = function(red, key, path) + if forced_error_msg then + return false, forced_error_msg + end + + red.key_count = red.key_count - 1 + red.cache[key] = nil + + return true, nil + end, + + -- internal tracking + indexes = {}, + key_count = 0, + cache = {}, + next_response_key = nil, + last_threshold_received = 0.0, + } + end, + mock_next_search = function(red, key) + red.next_response_key = key + end, + forced_failure = function(err_msg) + forced_error_msg = err_msg + end, + } }, + } + }) +end + +-- +-- module +-- + +return { + -- functions + setup = setup, +}