Skip to content

Commit

Permalink
inline strategy now uses adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
olimorris committed Mar 5, 2024
1 parent 82072d3 commit c843a16
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 220 deletions.
31 changes: 3 additions & 28 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,34 +95,9 @@ You only need to call the `setup` function if you wish to change any of the

```lua
require("codecompanion").setup({
api_key = "OPENAI_API_KEY", -- Your API key
org_api_key = "OPENAI_ORG_KEY", -- Your organisation API key
base_url = "https://api.openai.com", -- The URL to use for the API requests
ai_settings = {
-- Default settings for the Completions API
-- See https://platform.openai.com/docs/api-reference/chat/create
chat = {
model = "gpt-4-0125-preview",
temperature = 1,
top_p = 1,
stop = nil,
max_tokens = nil,
presence_penalty = 0,
frequency_penalty = 0,
logit_bias = nil,
user = nil,
},
inline = {
model = "gpt-3.5-turbo-0125",
temperature = 1,
top_p = 1,
stop = nil,
max_tokens = nil,
presence_penalty = 0,
frequency_penalty = 0,
logit_bias = nil,
user = nil,
},
adapters = {
chat = require("codecompanion.adapters.openai"),
inline = require("codecompanion.adapters.openai"),
},
saved_chats = {
save_dir = vim.fn.stdpath("data") .. "/codecompanion/saved_chats", -- Path to save chats to
Expand Down
6 changes: 5 additions & 1 deletion lua/codecompanion/adapter.lua
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,13 @@ function Adapter:get_default_settings()
return settings
end

---@param settings table
---@param settings? table
---@return CodeCompanion.Adapter
function Adapter:set_params(settings)
if not settings then
settings = self:get_default_settings()
end

for k, v in pairs(settings) do
local mapping = self.schema[k] and self.schema[k].mapping
if mapping then
Expand Down
35 changes: 24 additions & 11 deletions lua/codecompanion/adapters/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ local Adapter = require("codecompanion.adapter")
local adapter = {
name = "OpenAI",
url = "https://api.openai.com/v1/chat/completions",
raw = {
"--no-buffer",
"--silent",
},
headers = {
content_type = "application/json",
["Content-Type"] = "application/json",
-- FIX: Need a way to check if the key is set
Authorization = "Bearer " .. os.getenv("OPENAI_API_KEY"),
},
Expand All @@ -35,24 +39,33 @@ local adapter = {
return formatted_data == "[DONE]"
end,

---Format the messages from the API
---@param data table
---@param messages table
---@param new_message table
---Output the data from the API ready for insertion into the chat buffer
---@param data table The streamed data from the API
---@param messages table A table of all of the messages in the chat buffer
---@param current_message table The current/latest message in the chat buffer
---@return table
format_messages = function(data, messages, new_message)
output_chat = function(data, messages, current_message)
local delta = data.choices[1].delta

if delta.role and delta.role ~= new_message.role then
new_message = { role = delta.role, content = "" }
table.insert(messages, new_message)
if delta.role and delta.role ~= current_message.role then
current_message = { role = delta.role, content = "" }
table.insert(messages, current_message)
end

-- Append any streamed content from this delta to the current message
if delta.content then
new_message.content = new_message.content .. delta.content
current_message.content = current_message.content .. delta.content
end

return new_message
return current_message
end,

---Output the data from the API ready for inlining into the current buffer
---@param data table The streamed (decoded) chunk from the API; expected to carry `choices[1].delta`
---@param context table Useful context about the buffer to inline to (not referenced by this implementation)
---@return string|nil content The text of the streamed delta; nil when the chunk carries no `content` field
output_inline = function(data, context)
return data.choices[1].delta.content
end,
},
schema = {
Expand Down
74 changes: 23 additions & 51 deletions lua/codecompanion/client.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
local config = require("codecompanion.config")
local curl = require("plenary.curl")
local log = require("codecompanion.utils.log")
local schema = require("codecompanion.schema")
Expand Down Expand Up @@ -155,31 +154,40 @@ function Client:stream(adapter, payload, bufnr, cb)
start_request(bufnr, handler)
end

---Call the OpenAI API but block the main loop until the response is received
---@param url string
---Call the API and block until the response is received
---@param adapter CodeCompanion.Adapter
---@param payload table
---@param cb fun(err: nil|string, response: nil|table)
function Client:block_request(url, payload, cb)
function Client:call(adapter, payload, cb)
cb = log:wrap_cb(cb, "Response error: %s")

local cmd = {
"curl",
url,
"--silent",
"--no-buffer",
"-H",
"Content-Type: application/json",
"-H",
string.format("Authorization: Bearer %s", self.secret_key),
adapter.url,
}

if self.organization then
table.insert(cmd, "-H")
table.insert(cmd, string.format("OpenAI-Organization: %s", self.organization))
if adapter.raw then
for _, v in ipairs(adapter.raw) do
table.insert(cmd, v)
end
else
table.insert(cmd, "--no-buffer")
end

if adapter.headers then
for k, v in pairs(adapter.headers) do
table.insert(cmd, "-H")
table.insert(cmd, string.format("%s: %s", k, v))
end
end

table.insert(cmd, "-d")
table.insert(cmd, vim.json.encode(payload))
table.insert(
cmd,
vim.json.encode(vim.tbl_extend("keep", adapter.parameters, {
messages = payload,
}))
)
log:trace("Request payload: %s", cmd)

local result = vim.fn.system(cmd)
Expand All @@ -197,40 +205,4 @@ function Client:block_request(url, payload, cb)
end
end

---@class CodeCompanion.ChatMessage
---@field role "system"|"user"|"assistant"
---@field content string

---@class CodeCompanion.ChatSettings
---@field model string ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
---@field temperature nil|number Defaults to 1. What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.
---@field top_p nil|number Defaults to 1. An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both.
---@field n nil|integer Defaults to 1. How many chat completion choices to generate for each input message.
---@field stop nil|string|string[] Defaults to nil. Up to 4 sequences where the API will stop generating further tokens.
---@field max_tokens nil|integer Defaults to nil. The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length.
---@field presence_penalty nil|number Defaults to 0. Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
---@field frequency_penalty nil|number Defaults to 0. Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
---@field logit_bias nil|table<integer, integer> Modify the likelihood of specified tokens appearing in the completion. Maps tokens (specified by their token ID) to an associated bias value from -100 to 100. Use https://platform.openai.com/tokenizer to find token IDs.
---@field user nil|string A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Learn more.

---@class CodeCompanion.ChatArgs : CodeCompanion.ChatSettings
---@field messages CodeCompanion.ChatMessage[] The messages to generate chat completions for, in the chat format.
---@field stream boolean? Whether to stream the chat output back to Neovim

---Send a chat completion request to the API and block until the response is received.
---@param args CodeCompanion.ChatArgs The settings and messages for the completion request
---@param cb fun(err: nil|string, response: nil|table) Invoked with an error message or the decoded response
---@return nil
function Client:chat(args, cb)
  -- Build the completions endpoint from the configured base URL, then delegate
  -- to the blocking request helper.
  local endpoint = config.options.base_url .. "/v1/chat/completions"
  return self:block_request(endpoint, args, cb)
end

---Stream a completion for inlining into a buffer; forces streaming mode on the request.
---@param args CodeCompanion.InlineArgs The request settings/messages (NOTE: mutated — `stream` is set to true)
---@param bufnr integer The buffer number that the streamed response targets
---@param cb fun(err: nil|string, chunk: nil|table, done: nil|boolean) Will be called multiple times until done is true
---@return nil
function Client:inline(args, bufnr, cb)
args.stream = true
return self:stream(config.options.base_url .. "/v1/chat/completions", args, bufnr, cb)
end

return Client
16 changes: 0 additions & 16 deletions lua/codecompanion/config.lua
Original file line number Diff line number Diff line change
@@ -1,26 +1,10 @@
local M = {}

local defaults = {
api_key = "OPENAI_API_KEY",
org_api_key = "OPENAI_ORG_KEY",
base_url = "https://api.openai.com",
adapters = {
chat = require("codecompanion.adapters.openai"),
inline = require("codecompanion.adapters.openai"),
},
ai_settings = {
inline = {
model = "gpt-4-0125-preview",
temperature = 1,
top_p = 1,
stop = nil,
max_tokens = nil,
presence_penalty = 0,
frequency_penalty = 0,
logit_bias = nil,
user = nil,
},
},
saved_chats = {
save_dir = vim.fn.stdpath("data") .. "/codecompanion/saved_chats",
},
Expand Down
22 changes: 11 additions & 11 deletions lua/codecompanion/health.lua
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,17 @@ function M.check()
end
end

for _, env in ipairs(M.env_vars) do
if env_available(env.name) then
ok(fmt("%s key found", env.name))
else
if env.optional then
warn(fmt("%s key not found", env.name))
else
error(fmt("%s key not found", env.name))
end
end
end
-- for _, env in ipairs(M.env_vars) do
-- if env_available(env.name) then
-- ok(fmt("%s key found", env.name))
-- else
-- if env.optional then
-- warn(fmt("%s key not found", env.name))
-- else
-- error(fmt("%s key not found", env.name))
-- end
-- end
-- end
end

return M
6 changes: 3 additions & 3 deletions lua/codecompanion/strategies/chat.lua
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,9 @@ function Chat:submit()
end
end

local new_message = messages[#messages]
local current_message = messages[#messages]

if new_message and new_message.role == "user" and new_message.content == "" then
if current_message and current_message.role == "user" and current_message.content == "" then
return finalize()
end

Expand All @@ -447,7 +447,7 @@ function Chat:submit()

if data then
log:trace("Chat data: %s", data)
new_message = adapter.callbacks.format_messages(data, messages, new_message)
current_message = adapter.callbacks.output_chat(data, messages, current_message)
render_buffer()
end

Expand Down
Loading

0 comments on commit c843a16

Please sign in to comment.