Skip to content

Commit

Permalink
feat(llm): add vectordb, embeddings and semantic caching libraries
Browse files Browse the repository at this point in the history
  • Loading branch information
shaneutt authored Jun 18, 2024
1 parent 382f497 commit 434f7f4
Show file tree
Hide file tree
Showing 21 changed files with 1,949 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ mockserver
# kong
nginx_tmp/
kong*.yml
kong*.yaml
kong*.yaml.bak*

# luacov
luacov.*
Expand Down Expand Up @@ -49,3 +51,6 @@ bin/h2client
*.wasm
spec/fixtures/proxy_wasm_filters/build
spec/fixtures/proxy_wasm_filters/target

# python
__pycache__
11 changes: 11 additions & 0 deletions kong-3.8.0-0.rockspec
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,17 @@ build = {
["kong.plugins.ai-response-transformer.handler"] = "kong/plugins/ai-response-transformer/handler.lua",
["kong.plugins.ai-response-transformer.schema"] = "kong/plugins/ai-response-transformer/schema.lua",

["kong.ai.typedefs"] = "kong/ai/typedefs.lua",
["kong.ai.embeddings"] = "kong/ai/embeddings/init.lua",
["kong.ai.embeddings.drivers.openai"] = "kong/ai/embeddings/drivers/openai.lua",
["kong.ai.embeddings.drivers.mistralai"] = "kong/ai/embeddings/drivers/mistralai.lua",
["kong.ai.semantic_cache"] = "kong/ai/semantic_cache/init.lua",
["kong.ai.semantic_cache.utils"] = "kong/ai/semantic_cache/utils.lua",
["kong.ai.semantic_cache.drivers.redis"] = "kong/ai/semantic_cache/drivers/redis.lua",
["kong.ai.vector_databases.drivers.redis.client"] = "kong/ai/vector_databases/drivers/redis/client.lua",
["kong.ai.vector_databases.drivers.redis.index"] = "kong/ai/vector_databases/drivers/redis/index.lua",
["kong.ai.vector_databases.drivers.redis.vectors"] = "kong/ai/vector_databases/drivers/redis/vectors.lua",

["kong.llm"] = "kong/llm/init.lua",
["kong.llm.schemas"] = "kong/llm/schemas/init.lua",
["kong.llm.drivers.shared"] = "kong/llm/drivers/shared.lua",
Expand Down
96 changes: 96 additions & 0 deletions kong/ai/embeddings/drivers/mistralai.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
-- This software is copyright Kong Inc. and its licensors.
-- Use of the software is subject to the agreement between your organization
-- and Kong Inc. If there is no such agreement, use is governed by and
-- subject to the terms of the Kong Master Software License Agreement found
-- at https://konghq.com/enterprisesoftwarelicense/.
-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ]

--
-- imports
--

local cjson = require("cjson.safe")
local http = require("resty.http")

local deep_copy = require("kong.tools.table").deep_copy
local gzip = require("kong.tools.gzip")

--
-- vars
--

local embeddings_url = "https://api.mistral.ai/v1/embeddings"

--
-- driver object
--

-- Driver is an interface for a mistralai embeddings driver.
local Driver = {}
Driver.__index = Driver

-- Constructs a new Driver
--
-- @param provided_embeddings_config embeddings driver configuration
-- @param dimensions the number of dimensions for generating embeddings
-- @return the Driver object
function Driver:new(provided_embeddings_config, dimensions)
local driver_config = deep_copy(provided_embeddings_config)
driver_config.dimensions = dimensions
return setmetatable(driver_config, Driver)
end

-- Generates the embeddings (vectors) for a given prompt
--
-- @param prompt the prompt to generate embeddings for
-- @return the API response containing the embeddings
-- @return nothing. throws an error if any
function Driver:generate(prompt)
local body, err = cjson.encode({
input = prompt,
model = self.model,
encoding_format = "float",
})
if err then
return nil, err
end

kong.log.debug("[mistralai] generating embeddings for prompt")
local httpc = http.new({
ssl_verify = true,
ssl_cafile = kong.configuration.lua_ssl_trusted_certificate_combined,
})
local res, err = httpc:request_uri(embeddings_url, {
method = "POST",
headers = {
["Content-Type"] = "application/json",
["Accept-Encoding"] = "gzip",
["Authorization"] = "Bearer " .. self.auth.token,
},
body = body,
})
if not res then
return nil, string.format("failed to generate embeddings (%s): %s", embeddings_url, err)
end
if res.status ~= 200 then
return nil, string.format("unexpected embeddings response (%s): %s", embeddings_url, res.status)
end

local inflated_body = gzip.inflate_gzip(res.body)
local embedding_response, err = cjson.decode(inflated_body)
if err then
return nil, err
end

if not embedding_response.data or #embedding_response.data == 0 then
return nil, "no embeddings found in response"
end

return embedding_response.data[1].embedding, nil
end

--
-- module
--

return Driver
99 changes: 99 additions & 0 deletions kong/ai/embeddings/drivers/openai.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
-- This software is copyright Kong Inc. and its licensors.
-- Use of the software is subject to the agreement between your organization
-- and Kong Inc. If there is no such agreement, use is governed by and
-- subject to the terms of the Kong Master Software License Agreement found
-- at https://konghq.com/enterprisesoftwarelicense/.
-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ]

--
-- imports
--

local cjson = require("cjson.safe")
local http = require("resty.http")

local deep_copy = require("kong.tools.table").deep_copy
local gzip = require("kong.tools.gzip")

--
-- vars
--

local embeddings_url = "https://api.openai.com/v1/embeddings"

--
-- driver object
--

-- Driver is an interface for a openai embeddings driver.
local Driver = {}
Driver.__index = Driver

-- Constructs a new Driver
--
-- @param provided_embeddings_config embeddings driver configuration
-- @param dimensions the number of dimensions for generating embeddings
-- @return the Driver object
function Driver:new(provided_embeddings_config, dimensions)
local driver_config = deep_copy(provided_embeddings_config)
driver_config.dimensions = dimensions
return setmetatable(driver_config, Driver)
end

-- Generates the embeddings (vectors) for a given prompt
--
-- @param prompt the prompt to generate embeddings for
-- @return the API response containing the embeddings
-- @return nothing. throws an error if any
function Driver:generate(prompt)
-- prepare prompt for embedding generation
local body, err = cjson.encode({
input = prompt,
dimensions = self.dimensions,
model = self.model,
})
if err then
return nil, err
end

kong.log.debug("[openai] generating embeddings for prompt")
local httpc = http.new({
ssl_verify = true,
ssl_cafile = kong.configuration.lua_ssl_trusted_certificate_combined,
})
local res, err = httpc:request_uri(embeddings_url, {
method = "POST",
headers = {
["Content-Type"] = "application/json",
["Accept-Encoding"] = "gzip", -- explicitly set because OpenAI likes to change this
["Authorization"] = self.auth.token,
},
body = body,
})
if not res then
return nil, string.format("failed to generate embeddings (%s): %s", embeddings_url, err)
end
if res.status ~= 200 then
return nil, string.format("unexpected embeddings response (%s): %s", embeddings_url, res.status)
end

-- decompress the embeddings response
local inflated_body = gzip.inflate_gzip(res.body)
local embedding_response, err = cjson.decode(inflated_body)
if err then
return nil, err
end

-- validate if there are embeddings in the response
if #embedding_response.data == 0 then
return nil, "no embeddings found in response"
end

return embedding_response.data[1].embedding, nil
end

--
-- module
--

return Driver
49 changes: 49 additions & 0 deletions kong/ai/embeddings/init.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
-- This software is copyright Kong Inc. and its licensors.
-- Use of the software is subject to the agreement between your organization
-- and Kong Inc. If there is no such agreement, use is governed by and
-- subject to the terms of the Kong Master Software License Agreement found
-- at https://konghq.com/enterprisesoftwarelicense/.
-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ]

--
-- private vars
--

local supported_embeddings = {
openai = "kong.ai.embeddings.drivers.openai",
mistralai = "kong.ai.embeddings.drivers.mistralai",
}

--
-- public functions
--

-- Initializes the appropriate embedding driver given its name.
--
-- @param embeddings_config the configuration for embeddings
-- @param dimensions the number of dimensions for generating embeddings
-- @return the driver module
-- @return nothing. throws an error if any
local function new(embeddings_config, dimensions)
local driver_name = embeddings_config.driver
if not driver_name then
return nil, "empty name provided for embeddings driver"
end

local driver_modname = supported_embeddings[driver_name]
if not driver_modname then
return nil, string.format("unsupported embeddings driver: %s", driver_name)
end

local driver_mod = require(driver_modname)
return driver_mod:new(embeddings_config, dimensions), nil
end

--
-- module
--

return {
-- functions
new = new
}
113 changes: 113 additions & 0 deletions kong/ai/semantic_cache/drivers/redis.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
-- This software is copyright Kong Inc. and its licensors.
-- Use of the software is subject to the agreement between your organization
-- and Kong Inc. If there is no such agreement, use is governed by and
-- subject to the terms of the Kong Master Software License Agreement found
-- at https://konghq.com/enterprisesoftwarelicense/.
-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ]

--
-- imports
--

local deep_copy = require("kong.tools.table").deep_copy

local index = require("kong.ai.vector_databases.drivers.redis.index")
local redis = require("kong.ai.vector_databases.drivers.redis.client")
local vectors = require("kong.ai.vector_databases.drivers.redis.vectors")
local utils = require("kong.ai.semantic_cache.utils")

---
--- private functions
---

-- Performs setup of the Redis database, including things like creating
-- indexes needed for vector search.
--
-- @param driver_config the configuration for the driver
-- @return boolean indicating success
-- @return nothing. throws an error if any
local function database_setup(driver_config)
kong.log.debug("[redis] creating index")
local index_name = utils.full_index_name(driver_config.index)
local prefix = driver_config.index
local succeded, err = index.create(
driver_config.red,
index_name,
prefix,
driver_config.dimensions,
driver_config.distance_metric
)
if err then
return false, err
end

if not succeded then
return false, "failed to create index"
end

return true, nil
end

---
--- driver object
---

-- Driver is an interface for a redis database.
local Driver = {}
Driver.__index = Driver

-- Constructs a new Driver
--
-- @param provided_driver_config the configuration for the driver
-- @return the Driver object
-- @return nothing. throws an error if any
function Driver:new(provided_driver_config)
local driver_config = deep_copy(provided_driver_config)

local red, err = redis.create(driver_config)
if err then
return false, err
end
driver_config.red = red

local _, err = database_setup(driver_config)
if err then
return nil, err
end

return setmetatable(driver_config, Driver), nil
end

-- Retrieves a cache entry for a given vector.
--
-- @param vector the vector to search
-- @param threshold the proximity threshold for results
-- @return the cache payload, if any
-- @return nothing. throws an error if any
function Driver:get_cache(vector, threshold)
if not threshold then
threshold = self.default_threshold
end

local index_name = utils.full_index_name(self.index)

return vectors.search(self.red, index_name, vector, threshold)
end

-- Insert a cache entry for a given vector and payload.
-- Generates a unique cache key is the format of <index>:<vector>.
--
-- @param vector the vector to search
-- @param payload the payload to be cached
-- @return boolean indicating success
-- @return nothing. throws an error if any
function Driver:set_cache(vector, payload)
local key = utils.cache_key(self.index)
return vectors.create(self.red, key, vector, payload)
end

--
-- module
--

return Driver
Loading

0 comments on commit 434f7f4

Please sign in to comment.