From df970d5c5de3e8af04bb5f835aa9e9b2cbe6e4ff Mon Sep 17 00:00:00 2001
From: Oli Morris
Date: Thu, 29 Feb 2024 19:21:56 +0000
Subject: [PATCH 01/42] wip: first attempt at an adapter

---
 lua/codecompanion/adapter.lua           |  33 ++++++
 lua/spec/codecompanion/adapter_spec.lua | 150 ++++++++++++++++++++++
 2 files changed, 183 insertions(+)
 create mode 100644 lua/codecompanion/adapter.lua
 create mode 100644 lua/spec/codecompanion/adapter_spec.lua

diff --git a/lua/codecompanion/adapter.lua b/lua/codecompanion/adapter.lua
new file mode 100644
index 00000000..27999045
--- /dev/null
+++ b/lua/codecompanion/adapter.lua
@@ -0,0 +1,33 @@
+---@class CodeCompanion.Adapter
+---@field data table
+local Adapter = {}
+
+---@class CodeCompanion.AdapterArgs
+---@field args table
+
+---@param args CodeCompanion.AdapterArgs
+---@return CodeCompanion.Adapter
+function Adapter.new(args)
+  return setmetatable({ data = args }, { __index = Adapter })
+end
+
+function Adapter:process(settings)
+  for k, v in pairs(self.data.payload) do
+    if type(v) == "string" then
+      -- Attempt to extract the key assuming the format is always `${key}`
+      local name, _ = v:find("%${.+}")
+      if name then
+        local key = v:sub(3, -2) -- Extract the key without `${` and `}`
+        if settings[key] ~= nil then
+          self.data.payload[k] = settings[key]
+        else
+          self.data.payload[k] = nil
+        end
+      end
+    end
+  end
+
+  return self
+end
+
+return Adapter
diff --git a/lua/spec/codecompanion/adapter_spec.lua b/lua/spec/codecompanion/adapter_spec.lua
new file mode 100644
index 00000000..54c7e6df
--- /dev/null
+++ b/lua/spec/codecompanion/adapter_spec.lua
@@ -0,0 +1,150 @@
+local assert = require("luassert")
+
+local Adapter = require("codecompanion.adapter")
+
+local chat_buffer_settings = {
+  frequency_penalty = 0,
+  model = "gpt-4-0125-preview",
+  presence_penalty = 0,
+  temperature = 1,
+  top_p = 1,
+  stop = nil,
+  max_tokens = nil,
+  logit_bias = nil,
+  user = nil,
+}
+
+local openai_adapter = {
+  url = "https://api.openai.com/v1/chat/completions",
+  headers = {
+    content_type = "application/json",
+    Authorization = "Bearer ", -- ignore the API key for now
+  },
+  payload = {
+    stream = true,
+    model = "${model}",
+    temperature = "${temperature}",
+    top_p = "${top_p}",
+    stop = "${stop}",
+    max_tokens = "${max_tokens}",
+    presence_penalty = "${presence_penalty}",
+    frequency_penalty = "${frequency_penalty}",
+    logit_bias = "${logit_bias}",
+    user = "${user}",
+  },
+  schema = {
+    model = {
+      order = 1,
+      type = "enum",
+      desc = "ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
+      default = "gpt-4-0125-preview",
+      choices = {
+        "gpt-4-1106-preview",
+        "gpt-4",
+        "gpt-3.5-turbo-1106",
+        "gpt-3.5-turbo",
+      },
+    },
+    temperature = {
+      order = 2,
+      type = "number",
+      optional = true,
+      default = 1,
+      desc = "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.",
+      validate = function(n)
+        return n >= 0 and n <= 2, "Must be between 0 and 2"
+      end,
+    },
+    top_p = {
+      order = 3,
+      type = "number",
+      optional = true,
+      default = 1,
+      desc = "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both.",
+      validate = function(n)
+        return n >= 0 and n <= 1, "Must be between 0 and 1"
+      end,
+    },
+    stop = {
+      order = 4,
+      type = "list",
+      optional = true,
+      default = nil,
+      subtype = {
+        type = "string",
+      },
+      desc = "Up to 4 sequences where the API will stop generating further tokens.",
+      validate = function(l)
+        return #l >= 1 and #l <= 4, "Must have between 1 and 4 elements"
+      end,
+    },
+    max_tokens = {
+      order = 5,
+      type = "integer",
+      optional = true,
+      default = nil,
+      desc = "The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length.",
+      validate = function(n)
+        return n > 0, "Must be greater than 0"
+      end,
+    },
+    presence_penalty = {
+      order = 6,
+      type = "number",
+      optional = true,
+      default = 0,
+      desc = "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
+      validate = function(n)
+        return n >= -2 and n <= 2, "Must be between -2 and 2"
+      end,
+    },
+    frequency_penalty = {
+      order = 7,
+      type = "number",
+      optional = true,
+      default = 0,
+      desc = "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
+      validate = function(n)
+        return n >= -2 and n <= 2, "Must be between -2 and 2"
+      end,
+    },
+    logit_bias = {
+      order = 8,
+      type = "map",
+      optional = true,
+      default = nil,
+      desc = "Modify the likelihood of specified tokens appearing in the completion. Maps tokens (specified by their token ID) to an associated bias value from -100 to 100. Use https://platform.openai.com/tokenizer to find token IDs.",
+      subtype_key = {
+        type = "integer",
+      },
+      subtype = {
+        type = "integer",
+        validate = function(n)
+          return n >= -100 and n <= 100, "Must be between -100 and 100"
+        end,
+      },
+    },
+    user = {
+      order = 9,
+      type = "string",
+      optional = true,
+      default = nil,
+      desc = "A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Learn more.",
+      validate = function(u)
+        return u:len() < 100, "Cannot be longer than 100 characters"
+      end,
+    },
+  },
+}
+
+describe("Adapter", function()
+  it("can form a payload consisting of a chat buffer's settings", function()
+    local adapter = vim.deepcopy(openai_adapter)
+    local result = Adapter.new(adapter):process(chat_buffer_settings)
+
+    -- Remove the stream key from the payload as this isn't handled via the settings in the chat buffer
+    result.data.payload.stream = nil
+
+    assert.are.same(chat_buffer_settings, result.data.payload)
+  end)
+end)

From a234364a3e6cef8e1f1ecc206212698cb13f7203 Mon Sep 17 00:00:00 2001
From: Oli Morris
Date: Thu, 29 Feb 2024 19:24:46 +0000
Subject: [PATCH 02/42] add luadoc blocks

---
 lua/codecompanion/adapter.lua           | 22 +++++++++++++-------
 lua/spec/codecompanion/adapter_spec.lua |  4 ++--
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/lua/codecompanion/adapter.lua b/lua/codecompanion/adapter.lua
index 27999045..c4700430 100644
--- a/lua/codecompanion/adapter.lua
+++ b/lua/codecompanion/adapter.lua
@@ -1,27 +1,35 @@
 ---@class CodeCompanion.Adapter
----@field data table
+---@field url string
+---@field header table
+---@field payload table
+---@field schema table
 local Adapter = {}
 
 ---@class CodeCompanion.AdapterArgs
----@field args table
+---@field url string
+---@field header table
+---@field payload table
+---@field schema table
 
----@param args CodeCompanion.AdapterArgs
+---@param args table
 ---@return CodeCompanion.Adapter
 function Adapter.new(args)
-  return setmetatable({ data = args }, { __index = Adapter })
+  return setmetatable(args, { __index = Adapter })
 end
 
+---@param settings table
+---@return CodeCompanion.Adapter
 function Adapter:process(settings)
-  for k, v in pairs(self.data.payload) do
+  for k, v in pairs(self.payload) do
     if type(v) == "string" then
       -- Attempt to extract the key assuming the format is always `${key}`
       local name, _ = v:find("%${.+}")
       if name then
         local key = v:sub(3, -2) -- Extract the key without `${` and `}`
         if settings[key] ~= nil then
-          self.data.payload[k] = settings[key]
+          self.payload[k] = settings[key]
         else
-          self.data.payload[k] = nil
+          self.payload[k] = nil
         end
       end
     end
diff --git a/lua/spec/codecompanion/adapter_spec.lua b/lua/spec/codecompanion/adapter_spec.lua
index 54c7e6df..b13217db 100644
--- a/lua/spec/codecompanion/adapter_spec.lua
+++ b/lua/spec/codecompanion/adapter_spec.lua
@@ -143,8 +143,8 @@ describe("Adapter", function()
     local result = Adapter.new(adapter):process(chat_buffer_settings)
 
     -- Remove the stream key from the payload as this isn't handled via the settings in the chat buffer
-    result.data.payload.stream = nil
+    result.payload.stream = nil
 
-    assert.are.same(chat_buffer_settings, result.data.payload)
+    assert.are.same(chat_buffer_settings, result.payload)
   end)
 end)

From 5019e6986aa9aa9e3c232f7edd48d5a51aa450eb Mon Sep 17 00:00:00 2001
From: Oli Morris
Date: Thu, 29 Feb 2024 22:55:47 +0000
Subject: [PATCH 03/42] refactor adapters

---
 lua/codecompanion/adapter.lua           |  23 ++--
 lua/codecompanion/adapters/openai.lua   | 134 +++++++++++++++++++++
 lua/spec/codecompanion/adapter_spec.lua | 135 ++----------------------
 3 files changed, 147 insertions(+), 145 deletions(-)
 create mode 100644 lua/codecompanion/adapters/openai.lua

diff --git a/lua/codecompanion/adapter.lua b/lua/codecompanion/adapter.lua
index c4700430..50e3587e 100644
--- a/lua/codecompanion/adapter.lua
+++ b/lua/codecompanion/adapter.lua
@@ -1,14 +1,16 @@
 ---@class CodeCompanion.Adapter
+---@field name string
 ---@field url string
 ---@field header table
----@field payload table
+---@field parameters table
 ---@field schema table
 local Adapter = {}
 
 ---@class CodeCompanion.AdapterArgs
+---@field name string
 ---@field url string
 ---@field header table
----@field payload table
+---@field parameters table
 ---@field schema table
 
 ---@param args table
@@ -19,20 +21,9 @@ end
 
 ---@param settings table
 ---@return CodeCompanion.Adapter
-function Adapter:process(settings)
-  for k, v in pairs(self.payload) do
-    if type(v) == "string" then
-      -- Attempt to extract the key assuming the format is always `${key}`
-      local name, _ = v:find("%${.+}")
-      if name then
-        local key = v:sub(3, -2) -- Extract the key without `${` and `}`
-        if settings[key] ~= nil then
-          self.payload[k] = settings[key]
-        else
-          self.payload[k] = nil
-        end
-      end
-    end
+function Adapter:set_params(settings)
+  for k, v in pairs(settings) do
+    self.parameters[k] = v
   end
 
   return self
diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua
new file mode 100644
index 00000000..7177ec9d
--- /dev/null
+++ b/lua/codecompanion/adapters/openai.lua
@@ -0,0 +1,134 @@
+local Adapter = require("codecompanion.adapter")
+
+---@class CodeCompanion.Adapter
+---@field name string
+---@field url string
+---@field headers table
+---@field parameters table
+---@field schema table
+
+local adapter = {
+  name = "OpenAI",
+  url = "https://api.openai.com/v1/chat/completions",
+  headers = {
+    content_type = "application/json",
+    Authorization = "Bearer ", -- ignore the API key for now
+  },
+  parameters = {
+    stream = true,
+  },
+  schema = {
+    model = {
+      order = 1,
+      mapping = "parameters",
+      type = "enum",
+      desc = "ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
+      default = "gpt-4-0125-preview",
+      choices = {
+        "gpt-4-1106-preview",
+        "gpt-4",
+        "gpt-3.5-turbo-1106",
+        "gpt-3.5-turbo",
+      },
+    },
+    temperature = {
+      order = 2,
+      mapping = "parameters",
+      type = "number",
+      optional = true,
+      default = 1,
+      desc = "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.",
+      validate = function(n)
+        return n >= 0 and n <= 2, "Must be between 0 and 2"
+      end,
+    },
+    top_p = {
+      order = 3,
+      mapping = "parameters",
+      type = "number",
+      optional = true,
+      default = 1,
+      desc = "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both.",
+      validate = function(n)
+        return n >= 0 and n <= 1, "Must be between 0 and 1"
+      end,
+    },
+    stop = {
+      order = 4,
+      mapping = "parameters",
+      type = "list",
+      optional = true,
+      default = nil,
+      subtype = {
+        type = "string",
+      },
+      desc = "Up to 4 sequences where the API will stop generating further tokens.",
+      validate = function(l)
+        return #l >= 1 and #l <= 4, "Must have between 1 and 4 elements"
+      end,
+    },
+    max_tokens = {
+      order = 5,
+      mapping = "parameters",
+      type = "integer",
+      optional = true,
+      default = nil,
+      desc = "The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length.",
+      validate = function(n)
+        return n > 0, "Must be greater than 0"
+      end,
+    },
+    presence_penalty = {
+      order = 6,
+      mapping = "parameters",
+      type = "number",
+      optional = true,
+      default = 0,
+      desc = "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
+      validate = function(n)
+        return n >= -2 and n <= 2, "Must be between -2 and 2"
+      end,
+    },
+    frequency_penalty = {
+      order = 7,
+      mapping = "parameters",
+      type = "number",
+      optional = true,
+      default = 0,
+      desc = "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
+      validate = function(n)
+        return n >= -2 and n <= 2, "Must be between -2 and 2"
+      end,
+    },
+    logit_bias = {
+      order = 8,
+      mapping = "parameters",
+      type = "map",
+      optional = true,
+      default = nil,
+      desc = "Modify the likelihood of specified tokens appearing in the completion. Maps tokens (specified by their token ID) to an associated bias value from -100 to 100. Use https://platform.openai.com/tokenizer to find token IDs.",
+      subtype_key = {
+        type = "integer",
+      },
+      subtype = {
+        type = "integer",
+        validate = function(n)
+          return n >= -100 and n <= 100, "Must be between -100 and 100"
+        end,
+      },
+    },
+    user = {
+      order = 9,
+      mapping = "parameters",
+      type = "string",
+      optional = true,
+      default = nil,
+      desc = "A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Learn more.",
+      validate = function(u)
+        return u:len() < 100, "Cannot be longer than 100 characters"
+      end,
+    },
+  },
+}
+
+return Adapter.new(adapter)
diff --git a/lua/spec/codecompanion/adapter_spec.lua b/lua/spec/codecompanion/adapter_spec.lua
index b13217db..e1d7a1ad 100644
--- a/lua/spec/codecompanion/adapter_spec.lua
+++ b/lua/spec/codecompanion/adapter_spec.lua
@@ -14,137 +14,14 @@ local chat_buffer_settings = {
   user = nil,
 }
 
-local openai_adapter = {
-  url = "https://api.openai.com/v1/chat/completions",
-  headers = {
-    content_type = "application/json",
-    Authorization = "Bearer ", -- ignore the API key for now
-  },
-  payload = {
-    stream = true,
-    model = "${model}",
-    temperature = "${temperature}",
-    top_p = "${top_p}",
-    stop = "${stop}",
-    max_tokens = "${max_tokens}",
-    presence_penalty = "${presence_penalty}",
-    frequency_penalty = "${frequency_penalty}",
-    logit_bias = "${logit_bias}",
-    user = "${user}",
-  },
-  schema = {
-    model = {
-      order = 1,
-      type = "enum",
-      desc = "ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
-      default = "gpt-4-0125-preview",
-      choices = {
-        "gpt-4-1106-preview",
-        "gpt-4",
-        "gpt-3.5-turbo-1106",
-        "gpt-3.5-turbo",
-      },
-    },
-    temperature = {
-      order = 2,
-      type = "number",
-      optional = true,
-      default = 1,
-      desc = "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.",
-      validate = function(n)
-        return n >= 0 and n <= 2, "Must be between 0 and 2"
-      end,
-    },
-    top_p = {
-      order = 3,
-      type = "number",
-      optional = true,
-      default = 1,
-      desc = "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both.",
-      validate = function(n)
-        return n >= 0 and n <= 1, "Must be between 0 and 1"
-      end,
-    },
-    stop = {
-      order = 4,
-      type = "list",
-      optional = true,
-      default = nil,
-      subtype = {
-        type = "string",
-      },
-      desc = "Up to 4 sequences where the API will stop generating further tokens.",
-      validate = function(l)
-        return #l >= 1 and #l <= 4, "Must have between 1 and 4 elements"
-      end,
-    },
-    max_tokens = {
-      order = 5,
-      type = "integer",
-      optional = true,
-      default = nil,
-      desc = "The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length.",
-      validate = function(n)
-        return n > 0, "Must be greater than 0"
-      end,
-    },
-    presence_penalty = {
-      order = 6,
-      type = "number",
-      optional = true,
-      default = 0,
-      desc = "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
-      validate = function(n)
-        return n >= -2 and n <= 2, "Must be between -2 and 2"
-      end,
-    },
-    frequency_penalty = {
-      order = 7,
-      type = "number",
-      optional = true,
-      default = 0,
-      desc = "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
-      validate = function(n)
-        return n >= -2 and n <= 2, "Must be between -2 and 2"
-      end,
-    },
-    logit_bias = {
-      order = 8,
-      type = "map",
-      optional = true,
-      default = nil,
-      desc = "Modify the likelihood of specified tokens appearing in the completion. Maps tokens (specified by their token ID) to an associated bias value from -100 to 100. Use https://platform.openai.com/tokenizer to find token IDs.",
-      subtype_key = {
-        type = "integer",
-      },
-      subtype = {
-        type = "integer",
-        validate = function(n)
-          return n >= -100 and n <= 100, "Must be between -100 and 100"
-        end,
-      },
-    },
-    user = {
-      order = 9,
-      type = "string",
-      optional = true,
-      default = nil,
-      desc = "A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Learn more.",
-      validate = function(u)
-        return u:len() < 100, "Cannot be longer than 100 characters"
-      end,
-    },
-  },
-}
-
 describe("Adapter", function()
-  it("can form a payload consisting of a chat buffer's settings", function()
-    local adapter = vim.deepcopy(openai_adapter)
-    local result = Adapter.new(adapter):process(chat_buffer_settings)
+  it("can receive parameters from a chat buffer's settings", function()
+    local adapter = require("codecompanion.adapters.openai")
+    local result = adapter:set_params(chat_buffer_settings)
 
-    -- Remove the stream key from the payload as this isn't handled via the settings in the chat buffer
-    result.payload.stream = nil
+    -- The `stream` parameter is not present in the chat buffer's settings, so remove it to get the tests to pass
+    result.parameters.stream = nil
 
-    assert.are.same(chat_buffer_settings, result.payload)
+    assert.are.same(chat_buffer_settings, result.parameters)
   end)
 end)

From d8b0b1d2d935cbb99dfbbaf85038a773f061ede0 Mon Sep 17 00:00:00 2001
From: Oli Morris
Date: Thu, 29 Feb 2024 23:09:32 +0000
Subject: [PATCH 04/42] chore: clean up spec

---
 lua/spec/codecompanion/adapter_spec.lua | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/lua/spec/codecompanion/adapter_spec.lua b/lua/spec/codecompanion/adapter_spec.lua
index e1d7a1ad..fc1c44b5 100644
--- a/lua/spec/codecompanion/adapter_spec.lua
+++ b/lua/spec/codecompanion/adapter_spec.lua
@@ -1,7 +1,5 @@
 local assert = require("luassert")
 
-local Adapter = require("codecompanion.adapter")
-
 local chat_buffer_settings = {
   frequency_penalty = 0,
   model = "gpt-4-0125-preview",
@@ -19,7 +17,7 @@ describe("Adapter", function()
     local adapter = require("codecompanion.adapters.openai")
     local result = adapter:set_params(chat_buffer_settings)
 
-    -- The `stream` parameter is not present in the chat buffer's settings, so remove it to get the tests to pass
+    -- Ignore this for now
     result.parameters.stream = nil
 
     assert.are.same(chat_buffer_settings, result.parameters)

From 4cf82835fff3e5bb712c28b9b9f9188e87e6c452 Mon Sep 17 00:00:00 2001
From: Oli Morris
Date: Thu, 29 Feb 2024 23:09:39 +0000
Subject: [PATCH 05/42] wip: note on schema

---
 lua/codecompanion/adapter.lua | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lua/codecompanion/adapter.lua b/lua/codecompanion/adapter.lua
index 50e3587e..c71dc36a 100644
--- a/lua/codecompanion/adapter.lua
+++ b/lua/codecompanion/adapter.lua
@@ -22,6 +22,7 @@ end
 ---@param settings table
 ---@return CodeCompanion.Adapter
 function Adapter:set_params(settings)
+  -- TODO: Need to take into account the schema's "mapping" field
   for k, v in pairs(settings) do
     self.parameters[k] = v
   end

From fbba3a176ae91a9978df390cf6e445b725126786 Mon Sep 17 00:00:00 2001
From: Oli Morris
Date: Thu, 29 Feb 2024 23:16:07 +0000
Subject: [PATCH 06/42] chore: slight word tweak for spec

---
 lua/spec/codecompanion/adapter_spec.lua | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lua/spec/codecompanion/adapter_spec.lua b/lua/spec/codecompanion/adapter_spec.lua
index fc1c44b5..ee54a090 100644
--- a/lua/spec/codecompanion/adapter_spec.lua
+++ b/lua/spec/codecompanion/adapter_spec.lua
@@ -13,7 +13,7 @@ local chat_buffer_settings = {
 }
 
 describe("Adapter", function()
-  it("can receive parameters from a chat buffer's settings", function()
+  it("can form parameters from a chat buffer's settings", function()
     local adapter = require("codecompanion.adapters.openai")
     local result = adapter:set_params(chat_buffer_settings)

From d36dbde287b2be0c88b4084f67f831ceb1a82637 Mon Sep 17 00:00:00 2001
From: Oli Morris
Date: Fri, 1 Mar 2024 10:28:11 +0000
Subject: [PATCH 07/42] refactor: strategy folder is now strategies

---
 RECIPES.md                                                 | 2 +-
 lua/codecompanion/actions.lua                              | 4 ++--
 lua/codecompanion/init.lua                                 | 6 +++---
 lua/codecompanion/keymaps.lua                              | 4 ++--
 lua/codecompanion/{strategy => strategies}/chat.lua        | 4 +++-
 lua/codecompanion/{strategy => strategies}/inline.lua      | 0
 lua/codecompanion/{strategy => strategies}/saved_chats.lua | 2 +-
 lua/codecompanion/strategy.lua                             | 4 ++--
 8 files changed, 14 insertions(+), 12 deletions(-)
 rename lua/codecompanion/{strategy => strategies}/chat.lua (99%)
 rename lua/codecompanion/{strategy => strategies}/inline.lua (100%)
 rename lua/codecompanion/{strategy => strategies}/saved_chats.lua (98%)

diff --git a/RECIPES.md b/RECIPES.md
index cc2517c2..d9ca4a11 100644
--- a/RECIPES.md
+++ b/RECIPES.md
@@ -298,7 +298,7 @@ And to determine the visibility of actions in the palette itself:
     strategy = "saved_chats",
     description = "Load your previously saved chats",
     condition = function()
-      local saved_chats = require("codecompanion.strategy.saved_chats")
+      local saved_chats = require("codecompanion.strategies.saved_chats")
       return saved_chats:has_chats()
     end,
     picker = {
diff --git a/lua/codecompanion/actions.lua b/lua/codecompanion/actions.lua
index 74cffa58..0f482ad2 100644
--- a/lua/codecompanion/actions.lua
+++ b/lua/codecompanion/actions.lua
@@ -505,14 +505,14 @@ M.static.actions = {
     strategy = "saved_chats",
     description = "Load your previously saved chats",
     condition = function()
-      local saved_chats = require("codecompanion.strategy.saved_chats")
+      local saved_chats = require("codecompanion.strategies.saved_chats")
      return saved_chats:has_chats()
     end,
     picker = {
       prompt = "Load chats",
       items = function()
         local client = require("codecompanion").get_client()
-        local saved_chats = require("codecompanion.strategy.saved_chats")
+        local saved_chats = require("codecompanion.strategies.saved_chats")
         local items = saved_chats:list({ sort = true })
 
         local chats = {}
diff --git a/lua/codecompanion/init.lua b/lua/codecompanion/init.lua
index 414a0b27..34aa615c 100644
--- a/lua/codecompanion/init.lua
+++ b/lua/codecompanion/init.lua
@@ -31,7 +31,7 @@ end
 ---@param bufnr nil|integer
 ---@return nil|CodeCompanion.Chat
 M.buf_get_chat = function(bufnr)
-  return require("codecompanion.strategy.chat").buf_get_chat(bufnr)
+  return require("codecompanion.strategies.chat").buf_get_chat(bufnr)
 end
 
 ---@param args table
@@ -44,7 +44,7 @@ M.inline = function(args)
 
   local context = util.get_context(vim.api.nvim_get_current_buf(), args)
 
-  return require("codecompanion.strategy.inline")
+  return require("codecompanion.strategies.inline")
     .new({
       context = context,
       client = client,
@@ -71,7 +71,7 @@ M.chat = function(args)
 
   local context = util.get_context(vim.api.nvim_get_current_buf(), args)
 
-  local chat = require("codecompanion.strategy.chat").new({
+  local chat = require("codecompanion.strategies.chat").new({
     client = client,
     context = context,
   })
diff --git a/lua/codecompanion/keymaps.lua b/lua/codecompanion/keymaps.lua
index 67868b8b..3b5b258a 100644
--- a/lua/codecompanion/keymaps.lua
+++ b/lua/codecompanion/keymaps.lua
@@ -33,8 +33,8 @@ M.cancel_request = {
 M.save_chat = {
   desc = "Save the current chat",
   callback = function(args)
-    local chat = require("codecompanion.strategy.chat")
-    local saved_chat = require("codecompanion.strategy.saved_chats").new({})
+    local chat = require("codecompanion.strategies.chat")
+    local saved_chat = require("codecompanion.strategies.saved_chats").new({})
 
     if args.saved_chat then
       saved_chat.filename = args.saved_chat
diff --git a/lua/codecompanion/strategy/chat.lua b/lua/codecompanion/strategies/chat.lua
similarity index 99%
rename from lua/codecompanion/strategy/chat.lua
rename to lua/codecompanion/strategies/chat.lua
index 9935a2cb..9e381558 100644
--- a/lua/codecompanion/strategy/chat.lua
+++ b/lua/codecompanion/strategies/chat.lua
@@ -316,7 +316,9 @@ local function chat_autocmds(bufnr, args)
       _G.codecompanion_chats[bufnr] = nil
 
-      _G.codecompanion_jobs[request.data.buf].handler:shutdown()
+      if _G.codecompanion_jobs[request.data.buf] then
+        _G.codecompanion_jobs[request.data.buf].handler:shutdown()
+      end
       vim.api.nvim_exec_autocmds("User", { pattern = "CodeCompanionRequest", data = { status = "finished" } })
       vim.api.nvim_buf_delete(request.data.buf, { force = true })
     end,
diff --git a/lua/codecompanion/strategy/inline.lua b/lua/codecompanion/strategies/inline.lua
similarity index 100%
rename from lua/codecompanion/strategy/inline.lua
rename to lua/codecompanion/strategies/inline.lua
diff --git a/lua/codecompanion/strategy/saved_chats.lua b/lua/codecompanion/strategies/saved_chats.lua
similarity index 98%
rename from lua/codecompanion/strategy/saved_chats.lua
rename to lua/codecompanion/strategies/saved_chats.lua
index 018ed44a..2850cba5 100644
--- a/lua/codecompanion/strategy/saved_chats.lua
+++ b/lua/codecompanion/strategies/saved_chats.lua
@@ -1,4 +1,4 @@
-local Chat = require("codecompanion.strategy.chat")
+local Chat = require("codecompanion.strategies.chat")
 local config = require("codecompanion.config")
 local log = require("codecompanion.utils.log")
 
diff --git a/lua/codecompanion/strategy.lua b/lua/codecompanion/strategy.lua
index 6bf09c31..defd7e66 100644
--- a/lua/codecompanion/strategy.lua
+++ b/lua/codecompanion/strategy.lua
@@ -76,7 +76,7 @@ function Strategy:chat()
     })
   end
 
-  return require("codecompanion.strategy.chat").new({
+  return require("codecompanion.strategies.chat").new({
     client = self.client,
     type = self.selected.type,
     messages = messages,
@@ -99,7 +99,7 @@ function Strategy:chat()
 end
 
 function Strategy:inline()
-  return require("codecompanion.strategy.inline")
+  return require("codecompanion.strategies.inline")
     .new({
       context = self.context,
       client = self.client,

From e012ebef2c610be195469cd3a308bae62bd6358e Mon Sep 17 00:00:00 2001
From: Oli Morris
Date: Fri, 1 Mar 2024 17:18:34 +0000
Subject: [PATCH 08/42] can form parameters based on the schema definition

---
 lua/codecompanion/adapter.lua           | 28 +++++++++-
 lua/codecompanion/adapters/openai.lua   |  8 ++-
 lua/spec/codecompanion/adapter_spec.lua | 69 +++++++++++++++++++++++++
 3 files changed, 101 insertions(+), 4 deletions(-)

diff --git a/lua/codecompanion/adapter.lua b/lua/codecompanion/adapter.lua
index c71dc36a..0389ff0b 100644
--- a/lua/codecompanion/adapter.lua
+++ b/lua/codecompanion/adapter.lua
@@ -22,9 +22,33 @@ end
 ---@param settings table
 ---@return CodeCompanion.Adapter
 function Adapter:set_params(settings)
-  -- TODO: Need to take into account the schema's "mapping" field
   for k, v in pairs(settings) do
-    self.parameters[k] = v
+    local mapping = self.schema[k] and self.schema[k].mapping
+    if mapping then
+      local segments = {}
+      for segment in string.gmatch(mapping, "[^.]+") do
+        table.insert(segments, segment)
+      end
+
+      local current = self
+      for i = 1, #segments - 1 do
+        if not current[segments[i]] then
+          current[segments[i]] = {}
+        end
+        current = current[segments[i]]
+      end
+
+      -- Before setting the value, ensure the target exists or initialize it.
+      local target = segments[#segments]
+      if not current[target] then
+        current[target] = {}
+      end
+
+      -- Ensure 'target' is not nil and 'k' can be assigned to the final segment.
+      if target then
+        current[target][k] = v
+      end
+    end
   end
 
   return self
diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua
index 7177ec9d..3542e7cd 100644
--- a/lua/codecompanion/adapters/openai.lua
+++ b/lua/codecompanion/adapters/openai.lua
@@ -8,15 +8,19 @@ local Adapter = require("codecompanion.adapter")
 ---@field schema table
 
 local adapter = {
-  name = "OpenAI",
+  opts = {
+    name = "OpenAI",
+    stream = true, -- Need this to determine if we use the plenary.curl stream functionality
+  },
   url = "https://api.openai.com/v1/chat/completions",
   headers = {
     content_type = "application/json",
-    Authorization = "Bearer ", -- ignore the API key for now
+    Authorization = "Bearer " .. os.getenv("OPENAI_API_KEY"),
   },
   parameters = {
     stream = true,
   },
+  -- TODO: Need to map roles/messages based on Tree-sitter parsing of the chat buffer
   schema = {
     model = {
       order = 1,
diff --git a/lua/spec/codecompanion/adapter_spec.lua b/lua/spec/codecompanion/adapter_spec.lua
index ee54a090..3baac16c 100644
--- a/lua/spec/codecompanion/adapter_spec.lua
+++ b/lua/spec/codecompanion/adapter_spec.lua
@@ -1,5 +1,56 @@
 local assert = require("luassert")
 
+local test_adapter = {
+  opts = {
+    name = "TestAdapter",
+    stream = true,
+  },
+  url = "https://api.openai.com/v1/chat/completions",
+  headers = {
+    content_type = "application/json",
+  },
+  parameters = {
+    stream = true,
+  },
+  schema = {
+    model = {
+      order = 1,
+      mapping = "parameters.data",
+      type = "enum",
+      desc = "ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
+      default = "gpt-4-0125-preview",
+      choices = {
+        "gpt-4-1106-preview",
+        "gpt-4",
+        "gpt-3.5-turbo-1106",
+        "gpt-3.5-turbo",
+      },
+    },
+    temperature = {
+      order = 2,
+      mapping = "parameters.options",
+      type = "number",
+      optional = true,
+      default = 1,
+      desc = "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.",
+      validate = function(n)
+        return n >= 0 and n <= 2, "Must be between 0 and 2"
+      end,
+    },
+    top_p = {
+      order = 3,
+      mapping = "parameters.options",
+      type = "number",
+      optional = true,
+      default = 1,
+      desc = "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both.",
+      validate = function(n)
+        return n >= 0 and n <= 1, "Must be between 0 and 1"
+      end,
+    },
+  },
+}
+
 local chat_buffer_settings = {
   frequency_penalty = 0,
   model = "gpt-4-0125-preview",
@@ -22,4 +73,22 @@ describe("Adapter", function()
 
     assert.are.same(chat_buffer_settings, result.parameters)
   end)
+
+  it("can nest parameters based on an adapter's schema", function()
+    local adapter = require("codecompanion.adapter").new(test_adapter)
+    local result = adapter:set_params(chat_buffer_settings)
+
+    local expected = {
+      stream = true,
+      data = {
+        model = "gpt-4-0125-preview",
+      },
+      options = {
+        temperature = 1,
+        top_p = 1,
+      },
+    }
+
+    assert.are.same(expected, result.parameters)
+  end)
 end)

From 6621185f391046d61e0ed9341db98559998ad6c1 Mon Sep 17 00:00:00 2001
From: Oli Morris
Date: Fri, 1 Mar 2024 17:19:16 +0000
Subject: [PATCH 09/42] switch to using the schema from the adapter

---
 lua/codecompanion/config.lua          |  3 +++
 lua/codecompanion/strategies/chat.lua | 13 +++++++------
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/lua/codecompanion/config.lua b/lua/codecompanion/config.lua
index 3d5cafe5..d5381bb0 100644
--- a/lua/codecompanion/config.lua
+++ b/lua/codecompanion/config.lua
@@ -4,6 +4,9 @@ local defaults = {
   api_key = "OPENAI_API_KEY",
   org_api_key = "OPENAI_ORG_KEY",
   base_url = "https://api.openai.com",
+  adapters = {
+    chat = require("codecompanion.adapters.openai"),
+  },
   ai_settings = {
     chat = {
       model = "gpt-4-0125-preview",
diff --git a/lua/codecompanion/strategies/chat.lua b/lua/codecompanion/strategies/chat.lua
index 9e381558..f51da578 100644
--- a/lua/codecompanion/strategies/chat.lua
+++ b/lua/codecompanion/strategies/chat.lua
@@ -101,9 +101,8 @@ end
 local function render_messages(bufnr, settings, messages, context)
   local lines = {}
   if config.options.display.chat.show_settings then
-    -- Put the settings at the top of the buffer
     lines = { "---" }
-    local keys = schema.get_ordered_keys(schema.static.chat_settings)
+    local keys = schema.get_ordered_keys(config.options.adapters.chat.schema)
     for _, key in ipairs(keys) do
       table.insert(lines, string.format("%s: %s", key, yaml.encode(settings[key])))
     end
@@ -199,7 +198,7 @@ local function chat_autocmds(bufnr, args)
     buffer = bufnr,
     callback = function()
       local settings = parse_settings(bufnr)
-      local errors = schema.validate(schema.static.chat_settings, settings)
+      local errors = schema.validate(config.options.adapters.chat.schema, settings)
       local node = settings.__ts_node
       local items = {}
       if errors and node then
@@ -333,6 +332,7 @@ local Chat = {}
 
 ---@class CodeCompanion.ChatArgs
 ---@field client CodeCompanion.Client
+---@field adapter CodeCompanion.Adapter
 ---@field context table
 ---@field messages nil|CodeCompanion.ChatMessage[]
 ---@field show_buffer nil|boolean
@@ -364,11 +364,12 @@ function Chat.new(args)
   watch_cursor()
   chat_autocmds(bufnr, args)
 
-  local settings = args.settings or schema.get_default(schema.static.chat_settings, args.settings)
+  local settings = args.settings or schema.get_default(config.options.adapters.chat.schema, args.settings)
 
   local self = setmetatable({
     bufnr = bufnr,
     client = args.client,
+    adapter = args.adapter,
     context = args.context,
     saved_chat = args.saved_chat,
     settings = settings,
@@ -500,7 +501,7 @@ function Chat:on_cursor_moved()
     vim.diagnostic.set(config.INFO_NS, self.bufnr, {})
     return
   end
-  local key_schema = schema.static.chat_settings[key_name]
+  local key_schema = config.options.adapters.chat.schema[key_name]
   if key_schema and key_schema.desc then
     local lnum, col, end_lnum, end_col = node:range()
 
@@ -528,7 +529,7 @@ function Chat:complete(request, callback)
     return
   end
 
-  local key_schema = schema.static.chat_settings[key_name]
+  local key_schema = config.options.adapters.chat.schema[key_name]
   if key_schema.type == "enum" then
     for _, choice in ipairs(key_schema.choices) do
       table.insert(items, {

From e3fd4db26273b86094a77e4d76003d979442ebe2 Mon Sep 17 00:00:00 2001
From: Oli Morris
Date: Fri, 1 Mar 2024 18:16:25 +0000
Subject: [PATCH 10/42] clean up client

---
 lua/codecompanion/adapters/openai.lua  |   6 +-
 lua/codecompanion/client.lua           |  12 ++-
 lua/codecompanion/schema.lua           | 118 ------------------
 lua/spec/codecompanion/client_spec.lua |   6 +-
 4 files changed, 15 insertions(+), 127 deletions(-)

diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua
index 3542e7cd..8ea22d84 100644
--- a/lua/codecompanion/adapters/openai.lua
+++ b/lua/codecompanion/adapters/openai.lua
@@ -8,9 +8,9 @@ local Adapter = require("codecompanion.adapter")
 ---@field schema table
 
 local adapter = {
-  opts = {
-    name = "OpenAI",
-    stream = true, -- Need this to determine if we use the plenary.curl stream functionality
+  name = "OpenAI",
+  client = {
+    stream = true,
   },
   url = "https://api.openai.com/v1/chat/completions",
   headers = {
diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua
index 53646e54..ef4ea343 100644
--- a/lua/codecompanion/client.lua
+++ b/lua/codecompanion/client.lua
@@ -1,4 +1,5 @@
 local config = require("codecompanion.config")
+local curl = require("plenary.curl")
 local log = require("codecompanion.utils.log")
 local schema = require("codecompanion.schema")
 
@@ -72,10 +73,19 @@ end
 
 ---@class CodeCompanion.Client
+---@field static table
 ---@field secret_key string
 ---@field organization nil|string
 ---@field settings nil|table
 local Client = {}
+Client.static = {}
+
+Client.static.settings = {
+  request = { default = curl.post },
+  encode = { default = vim.json.encode },
+  decode = { default = vim.json.decode },
+  schedule = { default = vim.schedule },
+}
 
 ---@class CodeCompanion.ClientArgs
 ---@field secret_key string
@@ -88,7 +98,7 @@ function Client.new(args)
   return setmetatable({
     secret_key = args.secret_key,
     organization = args.organization,
-    settings = args.settings or schema.get_default(schema.static.client_settings, args.settings),
+    settings = args.settings or schema.get_default(Client.static.settings, args.settings),
   }, { __index = Client })
 end
 
diff --git a/lua/codecompanion/schema.lua b/lua/codecompanion/schema.lua
index a13743bf..3202457d 100644
--- a/lua/codecompanion/schema.lua
+++ b/lua/codecompanion/schema.lua
@@ -1,6 +1,3 @@
-local config = require("codecompanion.config")
-local curl = require("plenary.curl")
-
 local M = {}
 
 M.get_default = function(schema, defaults)
@@ -120,119 +117,4 @@ M.get_ordered_keys = function(schema)
   return keys
 end
 
-M.static = {}
-
-local model_choices = {
-  "gpt-4-1106-preview",
-  "gpt-4",
-  "gpt-3.5-turbo-1106",
-  "gpt-3.5-turbo",
-}
-
-M.static.chat_settings = {
-  model = {
-    order = 1,
-    type = "enum",
-    desc = "ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
-    default = config.options.ai_settings.chat.model,
-    choices = model_choices,
-  },
-  temperature = {
-    order = 2,
-    type = "number",
-    optional = true,
-    default = config.options.ai_settings.chat.temperature,
-    desc = "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.",
-    validate = function(n)
-      return n >= 0 and n <= 2, "Must be between 0 and 2"
-    end,
-  },
-  top_p = {
-    order = 3,
-    type = "number",
-    optional = true,
-    default = config.options.ai_settings.chat.top_p,
-    desc = "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both.",
-    validate = function(n)
-      return n >= 0 and n <= 1, "Must be between 0 and 1"
-    end,
-  },
-  stop = {
-    order = 4,
-    type = "list",
-    optional = true,
-    default = config.options.ai_settings.chat.stop,
-    subtype = {
-      type = "string",
-    },
-    desc = "Up to 4 sequences where the API will stop generating further tokens.",
-    validate = function(l)
-      return #l >= 1 and #l <= 4, "Must have between 1 and 4 elements"
-    end,
-  },
-  max_tokens = {
-    order = 5,
-    type = "integer",
-    optional = true,
-    default = config.options.ai_settings.chat.max_tokens,
-    desc = "The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length.",
-    validate = function(n)
-      return n > 0, "Must be greater than 0"
-    end,
-  },
-  presence_penalty = {
-    order = 6,
-    type = "number",
-    optional = true,
-    default = config.options.ai_settings.chat.presence_penalty,
-    desc = "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
-    validate = function(n)
-      return n >= -2 and n <= 2, "Must be between -2 and 2"
-    end,
-  },
-  frequency_penalty = {
-    order = 7,
-    type = "number",
-    optional = true,
-    default = config.options.ai_settings.chat.frequency_penalty,
-    desc = "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
-    validate = function(n)
-      return n >= -2 and n <= 2, "Must be between -2 and 2"
-    end,
-  },
-  logit_bias = {
-    order = 8,
-    type = "map",
-    optional = true,
-    default = config.options.ai_settings.chat.logit_bias,
-    desc = "Modify the likelihood of specified tokens appearing in the completion. Maps tokens (specified by their token ID) to an associated bias value from -100 to 100. Use https://platform.openai.com/tokenizer to find token IDs.",
-    subtype_key = {
-      type = "integer",
-    },
-    subtype = {
-      type = "integer",
-      validate = function(n)
-        return n >= -100 and n <= 100, "Must be between -100 and 100"
-      end,
-    },
-  },
-  user = {
-    order = 9,
-    type = "string",
-    optional = true,
-    default = config.options.ai_settings.chat.user,
-    desc = "A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Learn more.",
-    validate = function(u)
-      return u:len() < 100, "Cannot be longer than 100 characters"
-    end,
-  },
-}
-
-M.static.client_settings = {
-  request = { default = curl.post },
-  encode = { default = vim.json.encode },
-  decode = { default = vim.json.decode },
-  schedule = { default = vim.schedule },
-}
-
 return M
diff --git a/lua/spec/codecompanion/client_spec.lua b/lua/spec/codecompanion/client_spec.lua
index 7973ef79..deac9688 100644
--- a/lua/spec/codecompanion/client_spec.lua
+++ b/lua/spec/codecompanion/client_spec.lua
@@ -2,18 +2,15 @@ local assert = require("luassert")
 local codecompanion = require("codecompanion")
 local stub = require("luassert.stub")
 
-local schema
 local Client
 
 describe("Client", function()
   before_each(function()
     codecompanion.setup()
-    schema = require("codecompanion.schema")
     Client = require("codecompanion.client") -- Now that setup has been called, we can require the client
   end)
 
   after_each(function()
-    schema.static.client_settings = nil
     _G.codecompanion_jobs = nil
   end)
 
@@ -26,7 +23,7 @@ describe("Client", function()
     -- Mock globals
     _G.codecompanion_jobs = {}
 
-    schema.static.client_settings = {
+    Client.static.settings = {
       request = { default = mock_request },
       encode = { default = mock_encode },
       decode = { default = mock_decode },
@@ -43,6 +40,5 @@ describe("Client", function()
     client:stream_chat({}, 0, cb)
 
     assert.stub(mock_request).was_called()
-    -- assert.stub(cb).was_called()
   end)
 end)

From cf9ffffd7f6a6f77588546056fa285b913d30e31 Mon Sep 17 00:00:00 2001
From: Oli Morris
Date: Fri, 1 Mar 2024 18:28:05 +0000
Subject: [PATCH 11/42] fix adapter

---
 lua/codecompanion/adapters/openai.lua | 3 ++-
 lua/codecompanion/strategies/chat.lua | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua
index 8ea22d84..3ed0f445 100644
--- a/lua/codecompanion/adapters/openai.lua
+++ b/lua/codecompanion/adapters/openai.lua
@@ -15,7 +15,8 @@ local adapter = {
   url = "https://api.openai.com/v1/chat/completions",
   headers = {
     content_type = "application/json",
-    Authorization = "Bearer " .. os.getenv("OPENAI_API_KEY"),
+    -- FIX: Need a way to check if the key is set
os.getenv("OPENAI_API_KEY") or nil, }, parameters = { stream = true, diff --git a/lua/codecompanion/strategies/chat.lua b/lua/codecompanion/strategies/chat.lua index f51da578..f9b6ab64 100644 --- a/lua/codecompanion/strategies/chat.lua +++ b/lua/codecompanion/strategies/chat.lua @@ -369,7 +369,6 @@ function Chat.new(args) local self = setmetatable({ bufnr = bufnr, client = args.client, - adapter = args.adapter, context = args.context, saved_chat = args.saved_chat, settings = settings, @@ -441,6 +440,7 @@ function Chat:submit() end self.client:stream_chat( + config.options.adapters.chat, vim.tbl_extend("keep", settings, { messages = messages, }), From e93fe4d80b6b7595fdd6e19fd5d492f0d9cedf01 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Sat, 2 Mar 2024 10:03:11 +0000 Subject: [PATCH 12/42] client now uses adapter from config --- lua/codecompanion/adapter.lua | 2 ++ lua/codecompanion/adapters/openai.lua | 1 + lua/codecompanion/client.lua | 25 ++++++++++--------------- lua/codecompanion/strategies/chat.lua | 8 +++----- 4 files changed, 16 insertions(+), 20 deletions(-) diff --git a/lua/codecompanion/adapter.lua b/lua/codecompanion/adapter.lua index 0389ff0b..9a5597fc 100644 --- a/lua/codecompanion/adapter.lua +++ b/lua/codecompanion/adapter.lua @@ -1,6 +1,7 @@ ---@class CodeCompanion.Adapter ---@field name string ---@field url string +---@field raw table ---@field header table ---@field parameters table ---@field schema table @@ -9,6 +10,7 @@ local Adapter = {} ---@class CodeCompanion.AdapterArgs ---@field name string ---@field url string +---@field raw table ---@field header table ---@field parameters table ---@field schema table diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index 3ed0f445..2282b7a9 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -13,6 +13,7 @@ local adapter = { stream = true, }, url = "https://api.openai.com/v1/chat/completions", + raw = { "--no-buffer" }, headers = { content_type = "application/json", -- FIX: Need a way to check if the key is set diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index ef4ea343..5a6c7a44 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -144,19 +144,23 @@ function Client:block_request(url, payload, cb) end end ----@param url string +---@param adapter CodeCompanion.Adapter ---@param payload table ---@param bufnr number ---@param cb fun(err: nil|string, chunk: nil|table, done: nil|boolean) Will be called multiple times until done is true ---@return nil -function Client:stream_request(url, payload, bufnr, cb) +function Client:stream_request(adapter, payload, bufnr, cb) cb = log:wrap_cb(cb, "Response error: %s") + log:debug("Adapter: %s", { adapter.name, adapter.url, adapter.raw, adapter.headers, adapter.parameters }) + local handler = self.settings.request({ - url = url, - raw = { "--no-buffer" }, - headers = headers(self), - body = self.settings.encode(payload), + url = adapter.url, + raw = adapter.raw, + headers = adapter.headers, + body = self.settings.encode(vim.tbl_extend("keep", adapter.parameters, { + messages = payload, + })), stream = function(_, chunk) chunk = chunk:sub(7) @@ -233,15 +237,6 @@ function Client:chat(args, cb) return self:block_request(config.options.base_url .. 
"/v1/chat/completions", args, cb) end ----@param args CodeCompanion.ChatArgs ----@param bufnr integer ----@param cb fun(err: nil|string, chunk: nil|table, done: nil|boolean) Will be called multiple times until done is true ----@return nil -function Client:stream_chat(args, bufnr, cb) - args.stream = true - return self:stream_request(config.options.base_url .. "/v1/chat/completions", args, bufnr, cb) -end - ---@class args CodeCompanion.InlineArgs ---@param bufnr integer ---@param cb fun(err: nil|string, chunk: nil|table, done: nil|boolean) Will be called multiple times until done is true diff --git a/lua/codecompanion/strategies/chat.lua b/lua/codecompanion/strategies/chat.lua index f9b6ab64..a950e2f0 100644 --- a/lua/codecompanion/strategies/chat.lua +++ b/lua/codecompanion/strategies/chat.lua @@ -439,11 +439,9 @@ function Chat:submit() return finalize() end - self.client:stream_chat( - config.options.adapters.chat, - vim.tbl_extend("keep", settings, { - messages = messages, - }), + self.client:stream_request( + config.options.adapters.chat:set_params(settings), + messages, self.bufnr, function(err, chunk, done) if err then From 135829f52ab107c2b10457d503fe385de0a996a9 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Sun, 3 Mar 2024 14:18:31 +0000 Subject: [PATCH 13/42] clean up client calls throughout the plugin --- lua/codecompanion/adapters/openai.lua | 4 - lua/codecompanion/client.lua | 112 ++++++++++++------------ lua/codecompanion/init.lua | 42 --------- lua/codecompanion/strategies/chat.lua | 61 ++++++------- lua/codecompanion/strategy.lua | 5 -- lua/spec/codecompanion/adapter_spec.lua | 5 +- lua/spec/codecompanion/client_spec.lua | 16 +++- 7 files changed, 99 insertions(+), 146 deletions(-) diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index 2282b7a9..f5dcdcd8 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -9,11 +9,7 @@ local Adapter = require("codecompanion.adapter") local adapter = { name = "OpenAI", - client = { - stream = true, - }, url = "https://api.openai.com/v1/chat/completions", - raw = { "--no-buffer" }, headers = { content_type = "application/json", -- FIX: Need a way to check if the key is set diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index 5a6c7a44..dbee0efa 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -76,11 +76,11 @@ end ---@field static table ---@field secret_key string ---@field organization nil|string ----@field settings nil|table +---@field opts nil|table local Client = {} Client.static = {} -Client.static.settings = { +Client.static.opts = { request = { default = curl.post }, encode = { default = vim.json.encode }, decode = { default = vim.json.decode }, @@ -90,60 +90,18 @@ Client.static.settings = { ---@class CodeCompanion.ClientArgs ---@field secret_key string ---@field organization nil|string ----@field settings nil|table +---@field opts nil|table ----@param args CodeCompanion.ClientArgs +---@param args? 
CodeCompanion.ClientArgs ---@return CodeCompanion.Client function Client.new(args) + args = args or {} + return setmetatable({ - secret_key = args.secret_key, - organization = args.organization, - settings = args.settings or schema.get_default(Client.static.settings, args.settings), + opts = args.opts or schema.get_default(Client.static.opts, args.opts), }, { __index = Client }) end ----Call the OpenAI API but block the main loop until the response is received ----@param url string ----@param payload table ----@param cb fun(err: nil|string, response: nil|table) -function Client:block_request(url, payload, cb) - cb = log:wrap_cb(cb, "Response error: %s") - - local cmd = { - "curl", - url, - "--silent", - "--no-buffer", - "-H", - "Content-Type: application/json", - "-H", - string.format("Authorization: Bearer %s", self.secret_key), - } - - if self.organization then - table.insert(cmd, "-H") - table.insert(cmd, string.format("OpenAI-Organization: %s", self.organization)) - end - - table.insert(cmd, "-d") - table.insert(cmd, vim.json.encode(payload)) - log:trace("Request payload: %s", cmd) - - local result = vim.fn.system(cmd) - - if vim.v.shell_error ~= 0 then - log:error("Error calling curl: %s", result) - return cb("Error executing curl", nil) - else - local err, data = parse_response(0, result) - if err then - return cb(err, nil) - else - return cb(nil, data) - end - end -end - ---@param adapter CodeCompanion.Adapter ---@param payload table ---@param bufnr number @@ -154,11 +112,11 @@ function Client:stream_request(adapter, payload, bufnr, cb) log:debug("Adapter: %s", { adapter.name, adapter.url, adapter.raw, adapter.headers, adapter.parameters }) - local handler = self.settings.request({ + local handler = self.opts.request({ url = adapter.url, - raw = adapter.raw, - headers = adapter.headers, - body = self.settings.encode(vim.tbl_extend("keep", adapter.parameters, { + raw = adapter.raw or { "--no-buffer" }, + headers = adapter.headers or {}, + body = self.opts.encode(vim.tbl_extend("keep", adapter.parameters, { messages = payload, })), stream = function(_, chunk) @@ -166,18 +124,18 @@ function Client:stream_request(adapter, payload, bufnr, cb) if chunk ~= "" then if chunk == "[DONE]" then - self.settings.schedule(function() + self.opts.schedule(function() close_request(bufnr) return cb(nil, nil, true) end) else - self.settings.schedule(function() + self.opts.schedule(function() if _G.codecompanion_jobs[bufnr] and _G.codecompanion_jobs[bufnr].status == "stopping" then close_request(bufnr, { shutdown = true }) return cb(nil, nil, true) end - local ok, data = pcall(self.settings.decode, chunk, { luanil = { object = true } }) + local ok, data = pcall(self.opts.decode, chunk, { luanil = { object = true } }) if not ok then log:error("Error malformed json: %s", data) @@ -210,6 +168,48 @@ function Client:stream_request(adapter, payload, bufnr, cb) start_request(bufnr, handler) end +---Call the OpenAI API but block the main loop until the response is received +---@param url string +---@param payload table +---@param cb fun(err: nil|string, response: nil|table) +function Client:block_request(url, payload, cb) + cb = log:wrap_cb(cb, "Response error: %s") + + local cmd = { + "curl", + url, + "--silent", + "--no-buffer", + "-H", + "Content-Type: application/json", + "-H", + string.format("Authorization: Bearer %s", self.secret_key), + } + + if self.organization then + table.insert(cmd, "-H") + table.insert(cmd, string.format("OpenAI-Organization: %s", self.organization)) + end + + table.insert(cmd, 
"-d") + table.insert(cmd, vim.json.encode(payload)) + log:trace("Request payload: %s", cmd) + + local result = vim.fn.system(cmd) + + if vim.v.shell_error ~= 0 then + log:error("Error calling curl: %s", result) + return cb("Error executing curl", nil) + else + local err, data = parse_response(0, result) + if err then + return cb(err, nil) + else + return cb(nil, data) + end + end +end + ---@class CodeCompanion.ChatMessage ---@field role "system"|"user"|"assistant" ---@field content string diff --git a/lua/codecompanion/init.lua b/lua/codecompanion/init.lua index 34aa615c..371dc5d6 100644 --- a/lua/codecompanion/init.lua +++ b/lua/codecompanion/init.lua @@ -4,30 +4,6 @@ local util = require("codecompanion.utils.util") local M = {} -local _client ----@return nil|CodeCompanion.Client -function M.get_client() - if not _client then - local secret_key = os.getenv(config.options.api_key) - if not secret_key then - vim.notify( - string.format("[CodeCompanion.nvim]\nCould not find env variable: %s", config.options.api_key), - vim.log.levels.ERROR - ) - return nil - end - - local Client = require("codecompanion.client") - - _client = Client.new({ - secret_key = secret_key, - organization = os.getenv(config.options.org_api_key), - }) - end - - return _client -end - ---@param bufnr nil|integer ---@return nil|CodeCompanion.Chat M.buf_get_chat = function(bufnr) @@ -37,17 +13,11 @@ end ---@param args table ---@return nil|CodeCompanion.Inline M.inline = function(args) - local client = M.get_client() - if not client then - return - end - local context = util.get_context(vim.api.nvim_get_current_buf(), args) return require("codecompanion.strategies.inline") .new({ context = context, - client = client, prompts = { { role = "system", @@ -64,15 +34,9 @@ end ---@param args? 
table M.chat = function(args) - local client = M.get_client() - if not client then - return - end - local context = util.get_context(vim.api.nvim_get_current_buf(), args) local chat = require("codecompanion.strategies.chat").new({ - client = client, context = context, }) @@ -118,11 +82,6 @@ end local _cached_actions = {} M.actions = function(args) - local client = M.get_client() - if not client then - return - end - local actions = require("codecompanion.actions") local context = util.get_context(vim.api.nvim_get_current_buf(), args) @@ -164,7 +123,6 @@ M.actions = function(args) else local Strategy = require("codecompanion.strategy") return Strategy.new({ - client = client, context = context, selected = item, }):start(item.strategy) diff --git a/lua/codecompanion/strategies/chat.lua b/lua/codecompanion/strategies/chat.lua index a950e2f0..95439707 100644 --- a/lua/codecompanion/strategies/chat.lua +++ b/lua/codecompanion/strategies/chat.lua @@ -1,3 +1,4 @@ +local client = require("codecompanion.client") local config = require("codecompanion.config") local log = require("codecompanion.utils.log") local schema = require("codecompanion.schema") @@ -318,20 +319,18 @@ local function chat_autocmds(bufnr, args) if _G.codecompanion_jobs[request.data.buf] then _G.codecompanion_jobs[request.data.buf].handler:shutdown() end - vim.api.nvim_exec_autocmds("User", { pattern = "CodeCompanionRequest", data = { status = "finished" } }) - vim.api.nvim_buf_delete(request.data.buf, { force = true }) + api.nvim_exec_autocmds("User", { pattern = "CodeCompanionRequest", data = { status = "finished" } }) + api.nvim_buf_delete(request.data.buf, { force = true }) end, }) end ---@class CodeCompanion.Chat ----@field client CodeCompanion.Client ---@field bufnr integer ---@field settings CodeCompanion.ChatSettings local Chat = {} ---@class CodeCompanion.ChatArgs ----@field client CodeCompanion.Client ---@field adapter CodeCompanion.Adapter ---@field context table ---@field messages nil|CodeCompanion.ChatMessage[] @@ -368,7 +367,6 @@ function Chat.new(args) local self = setmetatable({ bufnr = bufnr, - client = args.client, context = args.context, saved_chat = args.saved_chat, settings = settings, @@ -439,40 +437,37 @@ function Chat:submit() return finalize() end - self.client:stream_request( - config.options.adapters.chat:set_params(settings), - messages, - self.bufnr, - function(err, chunk, done) - if err then - log:error("Error: %s", err) - vim.notify("Error: " .. err, vim.log.levels.ERROR) - return finalize() - end - - if chunk then - log:trace("Chat chunk: %s", chunk) - local delta = chunk.choices[1].delta - if delta.role and delta.role ~= new_message.role then - new_message = { role = delta.role, content = "" } - table.insert(messages, new_message) - end + local adapter = config.options.adapters.chat - if delta.content then - new_message.content = new_message.content .. delta.content - end + client.new():stream_request(adapter:set_params(settings), messages, self.bufnr, function(err, chunk, done) + if err then + log:error("Error: %s", err) + vim.notify("Error: " .. 
err, vim.log.levels.ERROR) + return finalize() + end - render_buffer() + if chunk then + log:trace("Chat chunk: %s", chunk) + local delta = chunk.choices[1].delta + if delta.role and delta.role ~= new_message.role then + new_message = { role = delta.role, content = "" } + table.insert(messages, new_message) end - if done then - table.insert(messages, { role = "user", content = "" }) - render_buffer() - display_tokens(self.bufnr) - finalize() + if delta.content then + new_message.content = new_message.content .. delta.content end + + render_buffer() end - ) + + if done then + table.insert(messages, { role = "user", content = "" }) + render_buffer() + display_tokens(self.bufnr) + finalize() + end + end) end ---@param opts nil|table diff --git a/lua/codecompanion/strategy.lua b/lua/codecompanion/strategy.lua index defd7e66..0d8f8ace 100644 --- a/lua/codecompanion/strategy.lua +++ b/lua/codecompanion/strategy.lua @@ -28,13 +28,11 @@ local function modal_prompts(prompts, context) end ---@class CodeCompanion.Strategy ----@field client CodeCompanion.Client ---@field context table ---@field selected table local Strategy = {} ---@class CodeCompanion.StrategyArgs ----@field client CodeCompanion.Client ---@field context table ---@field selected table @@ -43,7 +41,6 @@ local Strategy = {} function Strategy.new(args) log:trace("Context: %s", args.context) return setmetatable({ - client = args.client, context = args.context, selected = args.selected, }, { __index = Strategy }) @@ -77,7 +74,6 @@ function Strategy:chat() end return require("codecompanion.strategies.chat").new({ - client = self.client, type = self.selected.type, messages = messages, show_buffer = true, @@ -102,7 +98,6 @@ function Strategy:inline() return require("codecompanion.strategies.inline") .new({ context = self.context, - client = self.client, opts = self.selected.opts, pre_hook = self.selected.pre_hook, prompts = self.selected.prompts, diff --git a/lua/spec/codecompanion/adapter_spec.lua b/lua/spec/codecompanion/adapter_spec.lua index 3baac16c..fe98e925 100644 --- a/lua/spec/codecompanion/adapter_spec.lua +++ b/lua/spec/codecompanion/adapter_spec.lua @@ -1,10 +1,7 @@ local assert = require("luassert") local test_adapter = { - opts = { - name = "TestAdapter", - stream = true, - }, + name = "TestAdapter", url = "https://api.openai.com/v1/chat/completions", headers = { content_type = "application/json", diff --git a/lua/spec/codecompanion/client_spec.lua b/lua/spec/codecompanion/client_spec.lua index deac9688..f4ad58d1 100644 --- a/lua/spec/codecompanion/client_spec.lua +++ b/lua/spec/codecompanion/client_spec.lua @@ -4,6 +4,18 @@ local stub = require("luassert.stub") local Client +local adapter = { + name = "TestAdapter", + url = "https://api.openai.com/v1/chat/completions", + headers = { + content_type = "application/json", + }, + parameters = { + stream = true, + }, + schema = {}, +} + describe("Client", function() before_each(function() codecompanion.setup() @@ -23,7 +35,7 @@ describe("Client", function() -- Mock globals _G.codecompanion_jobs = {} - Client.static.settings = { + Client.static.opts = { request = { default = mock_request }, encode = { default = mock_encode }, decode = { default = mock_decode }, @@ -37,7 +49,7 @@ describe("Client", function() local cb = stub.new() - client:stream_chat({}, 0, cb) + client:stream_request(adapter, {}, 0, cb) assert.stub(mock_request).was_called() end) From f9938b3bd65a53781c6a4424a87f5d055c5ee550 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Sun, 3 Mar 2024 20:56:46 +0000 Subject: 
[PATCH 14/42] start adding callbacks to adapters --- lua/codecompanion/adapter.lua | 6 ++- lua/codecompanion/adapters/openai.lua | 17 +++++++ lua/codecompanion/client.lua | 64 ++++++++++----------------- lua/codecompanion/schema.lua | 1 + lua/codecompanion/strategies/chat.lua | 10 ++--- 5 files changed, 50 insertions(+), 48 deletions(-) diff --git a/lua/codecompanion/adapter.lua b/lua/codecompanion/adapter.lua index 9a5597fc..c6edf18a 100644 --- a/lua/codecompanion/adapter.lua +++ b/lua/codecompanion/adapter.lua @@ -1,18 +1,20 @@ ---@class CodeCompanion.Adapter ---@field name string ---@field url string ----@field raw table +---@field raw? table ---@field header table ---@field parameters table +---@field callbacks table ---@field schema table local Adapter = {} ---@class CodeCompanion.AdapterArgs ---@field name string ---@field url string ----@field raw table +---@field raw? table ---@field header table ---@field parameters table +---@field callbacks table ---@field schema table ---@param args table diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index f5dcdcd8..1b852206 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -3,6 +3,7 @@ local Adapter = require("codecompanion.adapter") ---@class CodeCompanion.Adapter ---@field name string ---@field url string +---@field raw? table ---@field headers table ---@field parameters table ---@field schema table @@ -18,6 +19,22 @@ local adapter = { parameters = { stream = true, }, + callbacks = { + ---Format any data before it's consumed by the other callbacks + ---@param data string + ---@return string + format_data = function(data) + -- Remove the "data: " prefix + return data:sub(7) + end, + + ---Has the streaming completed? 
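+ ---(The OpenAI stream prefixes each chunk with "data: " and sends "data: [DONE]" as its final chunk, which is what these two callbacks account for.)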
+ ---@param formatted_data string The table from the format_data callback + ---@return boolean + is_complete = function(formatted_data) + return formatted_data == "[DONE]" + end, + }, -- TODO: Need to map roles/messages based on Tree-sitter parsing of the chat buffer schema = { model = { diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index dbee0efa..7f71ebb5 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -84,7 +84,7 @@ Client.static.opts = { request = { default = curl.post }, encode = { default = vim.json.encode }, decode = { default = vim.json.decode }, - schedule = { default = vim.schedule }, + schedule = { default = vim.schedule_wrap }, } ---@class CodeCompanion.ClientArgs @@ -107,7 +107,7 @@ end ---@param bufnr number ---@param cb fun(err: nil|string, chunk: nil|table, done: nil|boolean) Will be called multiple times until done is true ---@return nil -function Client:stream_request(adapter, payload, bufnr, cb) +function Client:stream(adapter, payload, bufnr, cb) cb = log:wrap_cb(cb, "Response error: %s") log:debug("Adapter: %s", { adapter.name, adapter.url, adapter.raw, adapter.headers, adapter.parameters }) @@ -119,48 +119,30 @@ function Client:stream_request(adapter, payload, bufnr, cb) body = self.opts.encode(vim.tbl_extend("keep", adapter.parameters, { messages = payload, })), - stream = function(_, chunk) - chunk = chunk:sub(7) - - if chunk ~= "" then - if chunk == "[DONE]" then - self.opts.schedule(function() - close_request(bufnr) - return cb(nil, nil, true) - end) - else - self.opts.schedule(function() - if _G.codecompanion_jobs[bufnr] and _G.codecompanion_jobs[bufnr].status == "stopping" then - close_request(bufnr, { shutdown = true }) - return cb(nil, nil, true) - end - - local ok, data = pcall(self.opts.decode, chunk, { luanil = { object = true } }) - - if not ok then - log:error("Error malformed json: %s", data) - close_request(bufnr) - return cb(string.format("Error malformed json: %s", data)) - end - - if data.choices[1].finish_reason then - log:debug("Finish Reason: %s", data.choices[1].finish_reason) - end - - if data.choices[1].finish_reason == "length" then - log:debug("Token limit reached") - close_request(bufnr) - return cb("[CodeCompanion.nvim]\nThe token limit for the current chat has been reached") - end - - cb(nil, data) - end) + stream = self.opts.schedule(function(_, data) + if type(adapter.callbacks.format_data) == "function" then + data = adapter.callbacks.format_data(data) + end + + if adapter.callbacks.is_complete(data) then + close_request(bufnr) + return cb(nil, nil, true) + end + + if data ~= "" then + local ok, json = pcall(self.opts.decode, data, { luanil = { object = true } }) + + if not ok then + close_request(bufnr) + return cb(string.format("Error malformed json: %s", json)) end + + cb(nil, json) end - end, + end), on_error = function(err, _, _) - close_request(bufnr) log:error("Error: %s", err) + close_request(bufnr) end, }) @@ -243,7 +225,7 @@ end ---@return nil function Client:inline(args, bufnr, cb) args.stream = true - return self:stream_request(config.options.base_url .. "/v1/chat/completions", args, bufnr, cb) + return self:stream(config.options.base_url .. 
"/v1/chat/completions", args, bufnr, cb) end return Client diff --git a/lua/codecompanion/schema.lua b/lua/codecompanion/schema.lua index 3202457d..76093775 100644 --- a/lua/codecompanion/schema.lua +++ b/lua/codecompanion/schema.lua @@ -16,6 +16,7 @@ end ---@class CodeCompanion.SchemaParam ---@field type "string"|"number"|"integer"|"boolean"|"enum"|"list"|"map" +---@field mapping string ---@field order nil|integer ---@field optional nil|boolean ---@field choices nil|table diff --git a/lua/codecompanion/strategies/chat.lua b/lua/codecompanion/strategies/chat.lua index 95439707..bda1c3b6 100644 --- a/lua/codecompanion/strategies/chat.lua +++ b/lua/codecompanion/strategies/chat.lua @@ -439,16 +439,15 @@ function Chat:submit() local adapter = config.options.adapters.chat - client.new():stream_request(adapter:set_params(settings), messages, self.bufnr, function(err, chunk, done) + client.new():stream(adapter:set_params(settings), messages, self.bufnr, function(err, data, done) if err then - log:error("Error: %s", err) vim.notify("Error: " .. err, vim.log.levels.ERROR) return finalize() end - if chunk then - log:trace("Chat chunk: %s", chunk) - local delta = chunk.choices[1].delta + if data then + log:trace("Chat data: %s", data) + local delta = data.choices[1].delta if delta.role and delta.role ~= new_message.role then new_message = { role = delta.role, content = "" } table.insert(messages, new_message) @@ -462,6 +461,7 @@ function Chat:submit() end if done then + log:debug("Chat is done") table.insert(messages, { role = "user", content = "" }) render_buffer() display_tokens(self.bufnr) From 71b935a708800030bc0ab15fcfe90b98bc2b7580 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Sun, 3 Mar 2024 20:57:13 +0000 Subject: [PATCH 15/42] fix test --- lua/spec/codecompanion/client_spec.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lua/spec/codecompanion/client_spec.lua b/lua/spec/codecompanion/client_spec.lua index f4ad58d1..b0cb78c6 100644 --- a/lua/spec/codecompanion/client_spec.lua +++ b/lua/spec/codecompanion/client_spec.lua @@ -49,7 +49,7 @@ describe("Client", function() local cb = stub.new() - client:stream_request(adapter, {}, 0, cb) + client:stream(adapter, {}, 0, cb) assert.stub(mock_request).was_called() end) From 83ae51d1ccffe31dcd8d3b30bec4e72dedf00ec5 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Sun, 3 Mar 2024 20:57:45 +0000 Subject: [PATCH 16/42] make test more explict --- lua/spec/codecompanion/client_spec.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lua/spec/codecompanion/client_spec.lua b/lua/spec/codecompanion/client_spec.lua index b0cb78c6..c0905307 100644 --- a/lua/spec/codecompanion/client_spec.lua +++ b/lua/spec/codecompanion/client_spec.lua @@ -51,6 +51,6 @@ describe("Client", function() client:stream(adapter, {}, 0, cb) - assert.stub(mock_request).was_called() + assert.stub(mock_request).was_called(1) end) end) From 982ecf735493b8a8f44179b229f466e2b2714ffd Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Sun, 3 Mar 2024 20:58:37 +0000 Subject: [PATCH 17/42] clean up test --- lua/spec/codecompanion/client_spec.lua | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/lua/spec/codecompanion/client_spec.lua b/lua/spec/codecompanion/client_spec.lua index c0905307..5830ab0d 100644 --- a/lua/spec/codecompanion/client_spec.lua +++ b/lua/spec/codecompanion/client_spec.lua @@ -42,14 +42,9 @@ describe("Client", function() schedule = { default = mock_schedule }, } - local client = Client.new({ - secret_key = 
"fake_key", - organization = "fake_org", - }) - local cb = stub.new() - client:stream(adapter, {}, 0, cb) + Client.new():stream(adapter, {}, 0, cb) assert.stub(mock_request).was_called(1) end) From 1d58fe9961f653bcd184185c189500870a13f475 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Sun, 3 Mar 2024 22:11:36 +0000 Subject: [PATCH 18/42] chat buffer now fully moved to openai adapter --- lua/codecompanion/adapter.lua | 13 +++++++++++++ lua/codecompanion/adapters/openai.lua | 21 ++++++++++++++++++++- lua/codecompanion/client.lua | 9 +++++++-- lua/codecompanion/config.lua | 12 +----------- lua/codecompanion/strategies/chat.lua | 17 ++++------------- 5 files changed, 45 insertions(+), 27 deletions(-) diff --git a/lua/codecompanion/adapter.lua b/lua/codecompanion/adapter.lua index c6edf18a..861c519b 100644 --- a/lua/codecompanion/adapter.lua +++ b/lua/codecompanion/adapter.lua @@ -23,6 +23,19 @@ function Adapter.new(args) return setmetatable(args, { __index = Adapter }) end +---@return table +function Adapter:get_default_settings() + local settings = {} + + for key, value in pairs(self.schema) do + if value.default ~= nil then + settings[key] = value.default + end + end + + return settings +end + ---@param settings table ---@return CodeCompanion.Adapter function Adapter:set_params(settings) diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index 1b852206..fb019e87 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -34,8 +34,27 @@ local adapter = { is_complete = function(formatted_data) return formatted_data == "[DONE]" end, + + ---Format the messages from the API + ---@param data table + ---@param messages table + ---@param new_message table + ---@return table + format_messages = function(data, messages, new_message) + local delta = data.choices[1].delta + + if delta.role and delta.role ~= new_message.role then + new_message = { role = delta.role, content = "" } + table.insert(messages, new_message) + end + + if delta.content then + new_message.content = new_message.content .. 
delta.content + end + + return new_message + end, }, - -- TODO: Need to map roles/messages based on Tree-sitter parsing of the chat buffer schema = { model = { order = 1, diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index 7f71ebb5..6823842b 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -120,7 +120,12 @@ function Client:stream(adapter, payload, bufnr, cb) messages = payload, })), stream = self.opts.schedule(function(_, data) - if type(adapter.callbacks.format_data) == "function" then + if _G.codecompanion_jobs[bufnr] and _G.codecompanion_jobs[bufnr].status == "stopping" then + close_request(bufnr, { shutdown = true }) + return cb(nil, nil, true) + end + + if data and type(adapter.callbacks.format_data) == "function" then data = adapter.callbacks.format_data(data) end @@ -129,7 +134,7 @@ function Client:stream(adapter, payload, bufnr, cb) return cb(nil, nil, true) end - if data ~= "" then + if data and data ~= "" then local ok, json = pcall(self.opts.decode, data, { luanil = { object = true } }) if not ok then diff --git a/lua/codecompanion/config.lua b/lua/codecompanion/config.lua index d5381bb0..5d92f3ff 100644 --- a/lua/codecompanion/config.lua +++ b/lua/codecompanion/config.lua @@ -6,19 +6,9 @@ local defaults = { base_url = "https://api.openai.com", adapters = { chat = require("codecompanion.adapters.openai"), + inline = require("codecompanion.adapters.openai"), }, ai_settings = { - chat = { - model = "gpt-4-0125-preview", - temperature = 1, - top_p = 1, - stop = nil, - max_tokens = nil, - presence_penalty = 0, - frequency_penalty = 0, - logit_bias = nil, - user = nil, - }, inline = { model = "gpt-4-0125-preview", temperature = 1, diff --git a/lua/codecompanion/strategies/chat.lua b/lua/codecompanion/strategies/chat.lua index bda1c3b6..ff1b1244 100644 --- a/lua/codecompanion/strategies/chat.lua +++ b/lua/codecompanion/strategies/chat.lua @@ -34,7 +34,7 @@ local function parse_settings(bufnr) end if not config.options.display.chat.show_settings then - config_settings[bufnr] = vim.deepcopy(config.options.ai_settings.chat) + config_settings[bufnr] = config.options.adapters.chat:get_default_settings() log:debug("Using the settings from the user's config: %s", config_settings[bufnr]) return config_settings[bufnr] @@ -447,25 +447,16 @@ function Chat:submit() if data then log:trace("Chat data: %s", data) - local delta = data.choices[1].delta - if delta.role and delta.role ~= new_message.role then - new_message = { role = delta.role, content = "" } - table.insert(messages, new_message) - end - - if delta.content then - new_message.content = new_message.content .. 
delta.content - end - + new_message = adapter.callbacks.format_messages(data, messages, new_message) render_buffer() end if done then - log:debug("Chat is done") + log:trace("Chat streaming is done") table.insert(messages, { role = "user", content = "" }) render_buffer() display_tokens(self.bufnr) - finalize() + return finalize() end end) end From 82072d317915ab20051da244cd01b9809fe9a14b Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Sun, 3 Mar 2024 22:18:55 +0000 Subject: [PATCH 19/42] fix env var in tests --- .github/workflows/ci.yml | 2 ++ lua/codecompanion/adapters/openai.lua | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 285924ed..c1a22039 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,6 +42,8 @@ jobs: ln -s $(pwd) ~/.local/share/nvim/site/pack/vendor/start - name: Run tests + env: + OPENAI_API_KEY: abc-123 run: | export PATH="${PWD}/_neovim/bin:${PATH}" export VIM="${PWD}/_neovim/share/nvim/runtime" diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index fb019e87..9a11e9bf 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -14,7 +14,7 @@ local adapter = { headers = { content_type = "application/json", -- FIX: Need a way to check if the key is set - Authorization = "Bearer " .. os.getenv("OPENAI_API_KEY") or nil, + Authorization = "Bearer " .. os.getenv("OPENAI_API_KEY"), }, parameters = { stream = true, From c843a167f638ef190062ba9b0c82e5b5e2db600c Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Tue, 5 Mar 2024 09:40:26 +0000 Subject: [PATCH 20/42] inline strategy now uses adapter --- README.md | 31 +--- lua/codecompanion/adapter.lua | 6 +- lua/codecompanion/adapters/openai.lua | 35 +++-- lua/codecompanion/client.lua | 74 +++------ lua/codecompanion/config.lua | 16 -- lua/codecompanion/health.lua | 22 +-- lua/codecompanion/strategies/chat.lua | 6 +- lua/codecompanion/strategies/inline.lua | 197 ++++++++++++------------ 8 files changed, 167 insertions(+), 220 deletions(-) diff --git a/README.md b/README.md index 4bd6aca8..2f17d684 100644 --- a/README.md +++ b/README.md @@ -95,34 +95,9 @@ You only need to the call the `setup` function if you wish to change any of the ```lua require("codecompanion").setup({ - api_key = "OPENAI_API_KEY", -- Your API key - org_api_key = "OPENAI_ORG_KEY", -- Your organisation API key - base_url = "https://api.openai.com", -- The URL to use for the API requests - ai_settings = { - -- Default settings for the Completions API - -- See https://platform.openai.com/docs/api-reference/chat/create - chat = { - model = "gpt-4-0125-preview", - temperature = 1, - top_p = 1, - stop = nil, - max_tokens = nil, - presence_penalty = 0, - frequency_penalty = 0, - logit_bias = nil, - user = nil, - }, - inline = { - model = "gpt-3.5-turbo-0125", - temperature = 1, - top_p = 1, - stop = nil, - max_tokens = nil, - presence_penalty = 0, - frequency_penalty = 0, - logit_bias = nil, - user = nil, - }, + adapters = { + chat = require("codecompanion.adapters.openai"), + inline = require("codecompanion.adapters.openai"), }, saved_chats = { save_dir = vim.fn.stdpath("data") .. 
"/codecompanion/saved_chats", -- Path to save chats to diff --git a/lua/codecompanion/adapter.lua b/lua/codecompanion/adapter.lua index 861c519b..266cfa94 100644 --- a/lua/codecompanion/adapter.lua +++ b/lua/codecompanion/adapter.lua @@ -36,9 +36,13 @@ function Adapter:get_default_settings() return settings end ----@param settings table +---@param settings? table ---@return CodeCompanion.Adapter function Adapter:set_params(settings) + if not settings then + settings = self:get_default_settings() + end + for k, v in pairs(settings) do local mapping = self.schema[k] and self.schema[k].mapping if mapping then diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index 9a11e9bf..cc0cba84 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -11,8 +11,12 @@ local Adapter = require("codecompanion.adapter") local adapter = { name = "OpenAI", url = "https://api.openai.com/v1/chat/completions", + raw = { + "--no-buffer", + "--silent", + }, headers = { - content_type = "application/json", + ["Content-Type"] = "application/json", -- FIX: Need a way to check if the key is set Authorization = "Bearer " .. os.getenv("OPENAI_API_KEY"), }, @@ -35,24 +39,33 @@ local adapter = { return formatted_data == "[DONE]" end, - ---Format the messages from the API - ---@param data table - ---@param messages table - ---@param new_message table + ---Output the data from the API ready for insertion into the chat buffer + ---@param data table The streamed data from the API + ---@param messages table A table of all of the messages in the chat buffer + ---@param current_message table The current/latest message in the chat buffer ---@return table - format_messages = function(data, messages, new_message) + output_chat = function(data, messages, current_message) local delta = data.choices[1].delta - if delta.role and delta.role ~= new_message.role then - new_message = { role = delta.role, content = "" } - table.insert(messages, new_message) + if delta.role and delta.role ~= current_message.role then + current_message = { role = delta.role, content = "" } + table.insert(messages, current_message) end + -- Append the new message to the if delta.content then - new_message.content = new_message.content .. delta.content + current_message.content = current_message.content .. 
delta.content end - return new_message + return current_message + end, + + ---Output the data from the API ready for inlining into the current buffer + ---@param data table The streamed data from the API + ---@param context table Useful context about the buffer to inline to + ---@return table + output_inline = function(data, context) + return data.choices[1].delta.content end, }, schema = { diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index 6823842b..18a38a8a 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -1,4 +1,3 @@ -local config = require("codecompanion.config") local curl = require("plenary.curl") local log = require("codecompanion.utils.log") local schema = require("codecompanion.schema") @@ -155,31 +154,40 @@ function Client:stream(adapter, payload, bufnr, cb) start_request(bufnr, handler) end ----Call the OpenAI API but block the main loop until the response is received ----@param url string +---Call the API and block until the response is received +---@param adapter CodeCompanion.Adapter ---@param payload table ---@param cb fun(err: nil|string, response: nil|table) -function Client:block_request(url, payload, cb) +function Client:call(adapter, payload, cb) cb = log:wrap_cb(cb, "Response error: %s") local cmd = { "curl", - url, - "--silent", - "--no-buffer", - "-H", - "Content-Type: application/json", - "-H", - string.format("Authorization: Bearer %s", self.secret_key), + adapter.url, } - if self.organization then - table.insert(cmd, "-H") - table.insert(cmd, string.format("OpenAI-Organization: %s", self.organization)) + if adapter.raw then + for _, v in ipairs(adapter.raw) do + table.insert(cmd, v) + end + else + table.insert(cmd, "--no-buffer") + end + + if adapter.headers then + for k, v in pairs(adapter.headers) do + table.insert(cmd, "-H") + table.insert(cmd, string.format("%s: %s", k, v)) + end end table.insert(cmd, "-d") - table.insert(cmd, vim.json.encode(payload)) + table.insert( + cmd, + vim.json.encode(vim.tbl_extend("keep", adapter.parameters, { + messages = payload, + })) + ) log:trace("Request payload: %s", cmd) local result = vim.fn.system(cmd) @@ -197,40 +205,4 @@ function Client:block_request(url, payload, cb) end end ----@class CodeCompanion.ChatMessage ----@field role "system"|"user"|"assistant" ----@field content string - ----@class CodeCompanion.ChatSettings ----@field model string ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. ----@field temperature nil|number Defaults to 1. What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both. ----@field top_p nil|number Defaults to 1. An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. ----@field n nil|integer Defaults to 1. How many chat completion choices to generate for each input message. ----@field stop nil|string|string[] Defaults to nil. Up to 4 sequences where the API will stop generating further tokens. ----@field max_tokens nil|integer Defaults to nil. The maximum number of tokens to generate in the chat completion. 
The total length of input tokens and generated tokens is limited by the model's context length. ----@field presence_penalty nil|number Defaults to 0. Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. ----@field frequency_penalty nil|number Defaults to 0. Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. ----@field logit_bias nil|table Modify the likelihood of specified tokens appearing in the completion. Maps tokens (specified by their token ID) to an associated bias value from -100 to 100. Use https://platform.openai.com/tokenizer to find token IDs. ----@field user nil|string A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Learn more. - ----@class CodeCompanion.ChatArgs : CodeCompanion.ChatSettings ----@field messages CodeCompanion.ChatMessage[] The messages to generate chat completions for, in the chat format. ----@field stream boolean? Whether to stream the chat output back to Neovim - ----@param args CodeCompanion.ChatArgs ----@param cb fun(err: nil|string, response: nil|table) ----@return nil -function Client:chat(args, cb) - return self:block_request(config.options.base_url .. "/v1/chat/completions", args, cb) -end - ----@class args CodeCompanion.InlineArgs ----@param bufnr integer ----@param cb fun(err: nil|string, chunk: nil|table, done: nil|boolean) Will be called multiple times until done is true ----@return nil -function Client:inline(args, bufnr, cb) - args.stream = true - return self:stream(config.options.base_url .. "/v1/chat/completions", args, bufnr, cb) -end - return Client diff --git a/lua/codecompanion/config.lua b/lua/codecompanion/config.lua index 5d92f3ff..870a7b89 100644 --- a/lua/codecompanion/config.lua +++ b/lua/codecompanion/config.lua @@ -1,26 +1,10 @@ local M = {} local defaults = { - api_key = "OPENAI_API_KEY", - org_api_key = "OPENAI_ORG_KEY", - base_url = "https://api.openai.com", adapters = { chat = require("codecompanion.adapters.openai"), inline = require("codecompanion.adapters.openai"), }, - ai_settings = { - inline = { - model = "gpt-4-0125-preview", - temperature = 1, - top_p = 1, - stop = nil, - max_tokens = nil, - presence_penalty = 0, - frequency_penalty = 0, - logit_bias = nil, - user = nil, - }, - }, saved_chats = { save_dir = vim.fn.stdpath("data") .. 
"/codecompanion/saved_chats", }, diff --git a/lua/codecompanion/health.lua b/lua/codecompanion/health.lua index 30bf7a43..5203496d 100644 --- a/lua/codecompanion/health.lua +++ b/lua/codecompanion/health.lua @@ -94,17 +94,17 @@ function M.check() end end - for _, env in ipairs(M.env_vars) do - if env_available(env.name) then - ok(fmt("%s key found", env.name)) - else - if env.optional then - warn(fmt("%s key not found", env.name)) - else - error(fmt("%s key not found", env.name)) - end - end - end + -- for _, env in ipairs(M.env_vars) do + -- if env_available(env.name) then + -- ok(fmt("%s key found", env.name)) + -- else + -- if env.optional then + -- warn(fmt("%s key not found", env.name)) + -- else + -- error(fmt("%s key not found", env.name)) + -- end + -- end + -- end end return M diff --git a/lua/codecompanion/strategies/chat.lua b/lua/codecompanion/strategies/chat.lua index ff1b1244..3223c908 100644 --- a/lua/codecompanion/strategies/chat.lua +++ b/lua/codecompanion/strategies/chat.lua @@ -431,9 +431,9 @@ function Chat:submit() end end - local new_message = messages[#messages] + local current_message = messages[#messages] - if new_message and new_message.role == "user" and new_message.content == "" then + if current_message and current_message.role == "user" and current_message.content == "" then return finalize() end @@ -447,7 +447,7 @@ function Chat:submit() if data then log:trace("Chat data: %s", data) - new_message = adapter.callbacks.format_messages(data, messages, new_message) + current_message = adapter.callbacks.output_chat(data, messages, current_message) render_buffer() end diff --git a/lua/codecompanion/strategies/inline.lua b/lua/codecompanion/strategies/inline.lua index 6529fb5f..1bc61088 100644 --- a/lua/codecompanion/strategies/inline.lua +++ b/lua/codecompanion/strategies/inline.lua @@ -1,4 +1,6 @@ +local client = require("codecompanion.client") local config = require("codecompanion.config") +local adapter = config.options.adapters.inline local log = require("codecompanion.utils.log") local ui = require("codecompanion.utils.ui") @@ -9,6 +11,14 @@ local function fire_autocmd(status) vim.api.nvim_exec_autocmds("User", { pattern = "CodeCompanionInline", data = { status = status } }) end +---When a user initiates an inline request, it will be possible to infer from +---their prompt, how the output should be placed in the buffer. For instance, if +---they reference the words "refactor" or "update" in their prompt, they will +---likely want the response to replace a visual selection they've made in the +---editor. However, if they include words such as "after" or "before", it may +---be insinuated that they wish the response to be placed after or before the +---cursor. In this function, we use the power of Generative AI to determine +---the user's intent and return a placement position. 
---@param inline CodeCompanion.Inline ---@param prompt string ---@return string, boolean local function get_placement_position(inline, prompt) @@ -32,24 +42,17 @@ local function get_placement_position(inline, prompt) } local output - inline.client:chat( - vim.tbl_extend("keep", inline.settings, { - messages = messages, - }), - function(err, chunk) - if err then - return - end + client.new():call(adapter:set_params(), messages, function(err, data) + if err then + return + end - if chunk then - local content = chunk.choices[1].message.content - if content then - log:trace("Placement response: %s", content) - output = content - end - end + if data then + print(vim.inspect(data)) end - ) + end) + + log:trace("Placement output: %s", output) if output then local parts = vim.split(output, "|") @@ -131,17 +134,13 @@ local function get_cursor(winid) end ---@class CodeCompanion.Inline ----@field settings table ---@field context table ----@field client CodeCompanion.Client ---@field opts table ---@field prompts table local Inline = {} ---@class CodeCompanion.InlineArgs ----@field settings table ---@field context table ----@field client CodeCompanion.Client ---@field opts table ---@field pre_hook fun():number -- Assuming pre_hook returns a number for example ---@field prompts table @@ -162,9 +161,7 @@ function Inline.new(opts) end return setmetatable({ - settings = config.options.ai_settings.inline, context = opts.context, - client = opts.client, opts = opts.opts or {}, prompts = vim.deepcopy(opts.prompts), }, { __index = Inline }) @@ -176,54 +173,60 @@ function Inline:execute(user_input) local messages = get_action(self, user_input) - if not self.opts.placement and user_input then - local return_code - self.opts.placement, return_code = get_placement_position(self, user_input) - - if not return_code then - table.insert(messages, { - role = "user", - content = "Please do not return the code I have sent in the response", - }) - end - - log:debug("Setting the placement to: %s", self.opts.placement) - end + + -- if not self.opts.placement and user_input then + -- local return_code + -- self.opts.placement, return_code = get_placement_position(self, user_input) + -- + -- if not return_code then + -- table.insert(messages, { + -- role = "user", + -- content = "Please do not return the code I have sent in the response", + -- }) + -- end + -- + -- log:debug("Setting the placement to: %s", self.opts.placement) + -- end + + -- Assume the placement should be after the cursor + vim.api.nvim_buf_set_lines(self.context.bufnr, self.context.end_line, self.context.end_line, false, { "" }) + pos.line = self.context.end_line + 1 + pos.col = 0 + + --TODO: Work out how we can re-enable this -- Determine where to place the response in the buffer - if self.opts and self.opts.placement then - if self.opts.placement == "before" then - log:trace("Placing before selection") - vim.api.nvim_buf_set_lines( - self.context.bufnr, - self.context.start_line - 1, - self.context.start_line - 1, - false, - { "" } - ) - self.context.start_line = self.context.start_line + 1 - pos.line = self.context.start_line - 1 - pos.col = self.context.start_col - 1 - elseif self.opts.placement == "after" then - log:trace("Placing after selection") - vim.api.nvim_buf_set_lines(self.context.bufnr, self.context.end_line, self.context.end_line, false, { "" }) - pos.line = self.context.end_line + 1 - pos.col = 0 - elseif self.opts.placement == "replace" then - log:trace("Placing by overwriting selection") - overwrite_selection(self.context) - - pos.line, pos.col = get_cursor(self.context.winid) - elseif 
self.opts.placement == "new" then - log:trace("Placing in a new buffer") - self.context.bufnr = api.nvim_create_buf(true, false) - api.nvim_buf_set_option(self.context.bufnr, "filetype", self.context.filetype) - api.nvim_set_current_buf(self.context.bufnr) - - pos.line = 1 - pos.col = 0 - end - end + -- if self.opts and self.opts.placement then + -- if self.opts.placement == "before" then + -- log:trace("Placing before selection") + -- vim.api.nvim_buf_set_lines( + -- self.context.bufnr, + -- self.context.start_line - 1, + -- self.context.start_line - 1, + -- false, + -- { "" } + -- ) + -- self.context.start_line = self.context.start_line + 1 + -- pos.line = self.context.start_line - 1 + -- pos.col = self.context.start_col - 1 + -- elseif self.opts.placement == "after" then + -- log:trace("Placing after selection") + -- vim.api.nvim_buf_set_lines(self.context.bufnr, self.context.end_line, self.context.end_line, false, { "" }) + -- pos.line = self.context.end_line + 1 + -- pos.col = 0 + -- elseif self.opts.placement == "replace" then + -- log:trace("Placing by overwriting selection") + -- overwrite_selection(self.context) + -- + -- pos.line, pos.col = get_cursor(self.context.winid) + -- elseif self.opts.placement == "new" then + -- log:trace("Placing in a new buffer") + -- self.context.bufnr = api.nvim_create_buf(true, false) + -- api.nvim_buf_set_option(self.context.bufnr, "filetype", self.context.filetype) + -- api.nvim_set_current_buf(self.context.bufnr) + -- + -- pos.line = 1 + -- pos.col = 0 + -- end + -- end log:debug("Context for inline: %s", self.context) log:debug("Cursor position to use: %s", pos) @@ -268,43 +271,39 @@ function Inline:execute(user_input) fire_autocmd("started") local output = {} - self.client:stream_chat( - vim.tbl_extend("keep", self.settings, { - messages = messages, - }), - self.context.bufnr, - function(err, chunk, done) - if err then - return - end + client.new():stream(adapter:set_params(), messages, self.context.bufnr, function(err, data, done) + if err then + fire_autocmd("finished") + return + end - if chunk then - log:trace("Chat chunk: %s", chunk) - - local delta = chunk.choices[1].delta - if delta.content and not delta.role and delta.content ~= "```" and delta.content ~= self.context.filetype then - if self.context.buftype == "terminal" then - -- Don't stream to the terminal - table.insert(output, delta.content) - else - stream_buffer_text(delta.content) - if self.opts and self.opts.placement == "new" then - ui.buf_scroll_to_end(self.context.bufnr) - end + if data then + log:trace("Inline data: %s", data) + + local content = adapter.callbacks.output_inline(data, self.context) + + if self.context.buftype == "terminal" then + -- Don't stream to the terminal + table.insert(output, content) + else + if content then + stream_buffer_text(content) + if self.opts and self.opts.placement == "new" then + ui.buf_scroll_to_end(self.context.bufnr) end end end + end - if done then - api.nvim_buf_del_keymap(self.context.bufnr, "n", "q") - if self.context.buftype == "terminal" then - log:debug("Terminal output: %s", output) - api.nvim_put({ table.concat(output, "") }, "", false, true) - end - fire_autocmd("finished") + if done then + api.nvim_buf_del_keymap(self.context.bufnr, "n", "q") + if self.context.buftype == "terminal" then + log:debug("Terminal output: %s", output) + api.nvim_put({ table.concat(output, "") }, "", false, true) end + fire_autocmd("finished") end - ) + end) end ---@param user_input? 
string From 6cc48022c39dacedb2571a42a28b4fcd497dd1ab Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Tue, 5 Mar 2024 13:01:35 +0000 Subject: [PATCH 21/42] env vars in headers can be swapped in --- lua/codecompanion/adapter.lua | 23 +++++++++++++++++++++++ lua/codecompanion/adapters/openai.lua | 6 ++++-- lua/codecompanion/client.lua | 23 ++++++----------------- 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/lua/codecompanion/adapter.lua b/lua/codecompanion/adapter.lua index 266cfa94..69b6e585 100644 --- a/lua/codecompanion/adapter.lua +++ b/lua/codecompanion/adapter.lua @@ -1,6 +1,9 @@ +local log = require("codecompanion.utils.log") + ---@class CodeCompanion.Adapter ---@field name string ---@field url string +---@field env? table ---@field raw? table ---@field header table ---@field parameters table @@ -11,6 +14,7 @@ local Adapter = {} ---@class CodeCompanion.AdapterArgs ---@field name string ---@field url string +---@field env? table ---@field raw? table ---@field header table ---@field parameters table @@ -75,4 +79,23 @@ function Adapter:set_params(settings) return self end +---@return CodeCompanion.Adapter +function Adapter:replace_header_vars() + for k, v in pairs(self.headers) do + self.headers[k] = v:gsub("${(.-)}", function(var) + local env_var = os.getenv(self.env[var]) + if not env_var then + log:error("Error: Could not find env var: %s", self.env[var]) + return vim.notify( + string.format("[CodeCompanion.nvim]\nCould not find env var: %s", self.env[var]), + vim.log.levels.ERROR + ) + end + return env_var + end) + end + + return self +end + return Adapter diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index cc0cba84..5ab7ab56 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -11,14 +11,16 @@ local Adapter = require("codecompanion.adapter") local adapter = { name = "OpenAI", url = "https://api.openai.com/v1/chat/completions", + env = { + openai_api_key = "OPENAI_API_KEY", + }, raw = { "--no-buffer", "--silent", }, headers = { ["Content-Type"] = "application/json", - -- FIX: Need a way to check if the key is set - Authorization = "Bearer " .. os.getenv("OPENAI_API_KEY"), + Authorization = "Bearer ${openai_api_key}", }, parameters = { stream = true, diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index 18a38a8a..994b5de1 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -33,20 +33,6 @@ local function close_request(bufnr, opts) fire_autocmd("finished") end ----@param client CodeCompanion.Client ----@return table -local function headers(client) - local group = { - content_type = "application/json", - Authorization = "Bearer " .. 
client.secret_key, - OpenAI_Organization = client.organization, - } - - log:debug("Request Headers: %s", group) - - return group -end - ---@param code integer ---@param stdout string ---@return nil|string @@ -102,22 +88,25 @@ function Client.new(args) end ---@param adapter CodeCompanion.Adapter ----@param payload table +---@param payload table the messages to send to the API ---@param bufnr number ---@param cb fun(err: nil|string, chunk: nil|table, done: nil|boolean) Will be called multiple times until done is true ---@return nil function Client:stream(adapter, payload, bufnr, cb) cb = log:wrap_cb(cb, "Response error: %s") - log:debug("Adapter: %s", { adapter.name, adapter.url, adapter.raw, adapter.headers, adapter.parameters }) + --TODO: Check for any errors env variables + local headers = adapter:replace_header_vars().headers + + log:debug("Adapter: %s", { adapter.name, adapter.url, adapter.raw, headers, adapter.parameters }) local handler = self.opts.request({ url = adapter.url, raw = adapter.raw or { "--no-buffer" }, - headers = adapter.headers or {}, body = self.opts.encode(vim.tbl_extend("keep", adapter.parameters, { messages = payload, })), + headers = headers, stream = self.opts.schedule(function(_, data) if _G.codecompanion_jobs[bufnr] and _G.codecompanion_jobs[bufnr].status == "stopping" then close_request(bufnr, { shutdown = true }) From e46a0655e510d6a5955eeaca19f134c415da3144 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Tue, 5 Mar 2024 13:01:57 +0000 Subject: [PATCH 22/42] adapter call to format input messages to api --- lua/codecompanion/adapters/openai.lua | 18 ++++++++++++------ lua/codecompanion/client.lua | 4 +--- lua/codecompanion/strategies/chat.lua | 12 ++++++------ 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index 5ab7ab56..6f670746 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -26,6 +26,13 @@ local adapter = { stream = true, }, callbacks = { + ---Set the format of the role and content for the messages from the chat buffer + ---@param messages table Format is: { { role = "user", content = "Your prompt here" } } + ---@return table + form_messages = function(messages) + return { messages = messages } + end, + ---Format any data before it's consumed by the other callbacks ---@param data string ---@return string @@ -35,14 +42,14 @@ local adapter = { end, ---Has the streaming completed? - ---@param formatted_data string The table from the format_data callback + ---@param data string The data from the format_data callback ---@return boolean - is_complete = function(formatted_data) - return formatted_data == "[DONE]" + is_complete = function(data) + return data == "[DONE]" end, ---Output the data from the API ready for insertion into the chat buffer - ---@param data table The streamed data from the API + ---@param data table The streamed data from the API, also formatted by the format_data callback ---@param messages table A table of all of the messages in the chat buffer ---@param current_message table The current/latest message in the chat buffer ---@return table @@ -54,7 +61,6 @@ local adapter = { table.insert(messages, current_message) end - -- Append the new message to the if delta.content then current_message.content = current_message.content .. 
delta.content end @@ -63,7 +69,7 @@ local adapter = { end, ---Output the data from the API ready for inlining into the current buffer - ---@param data table The streamed data from the API + ---@param data table The streamed data from the API, also formatted by the format_data callback ---@param context table Useful context about the buffer to inline to ---@return table output_inline = function(data, context) diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index 994b5de1..cc73fa1c 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -103,10 +103,8 @@ function Client:stream(adapter, payload, bufnr, cb) local handler = self.opts.request({ url = adapter.url, raw = adapter.raw or { "--no-buffer" }, - body = self.opts.encode(vim.tbl_extend("keep", adapter.parameters, { - messages = payload, - })), headers = headers, + body = self.opts.encode(vim.tbl_extend("keep", adapter.parameters, adapter.callbacks.form_messages(payload))), stream = self.opts.schedule(function(_, data) if _G.codecompanion_jobs[bufnr] and _G.codecompanion_jobs[bufnr].status == "stopping" then close_request(bufnr, { shutdown = true }) diff --git a/lua/codecompanion/strategies/chat.lua b/lua/codecompanion/strategies/chat.lua index 3223c908..7743711c 100644 --- a/lua/codecompanion/strategies/chat.lua +++ b/lua/codecompanion/strategies/chat.lua @@ -58,7 +58,7 @@ end ---@param bufnr integer ---@return table ----@return CodeCompanion.ChatMessage[] +---@return table local function parse_messages_buffer(bufnr) local ret = {} @@ -96,8 +96,8 @@ local function parse_messages_buffer(bufnr) end ---@param bufnr integer ----@param settings CodeCompanion.ChatSettings ----@param messages CodeCompanion.ChatMessage[] +---@param settings table +---@param messages table ---@param context table local function render_messages(bufnr, settings, messages, context) local lines = {} @@ -327,16 +327,16 @@ end ---@class CodeCompanion.Chat ---@field bufnr integer ----@field settings CodeCompanion.ChatSettings +---@field settings table local Chat = {} ---@class CodeCompanion.ChatArgs ---@field adapter CodeCompanion.Adapter ---@field context table ----@field messages nil|CodeCompanion.ChatMessage[] +---@field messages nil|table ---@field show_buffer nil|boolean ---@field auto_submit nil|boolean ----@field settings nil|CodeCompanion.ChatSettings +---@field settings nil|table ---@field type nil|string ---@field saved_chat nil|string From 1c8bab4bab649ced13786a4b207db644318a9485 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Tue, 5 Mar 2024 13:20:30 +0000 Subject: [PATCH 23/42] fix tests --- lua/codecompanion/adapter.lua | 30 +++++++++++++++----------- lua/spec/codecompanion/client_spec.lua | 7 ++++++ 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/lua/codecompanion/adapter.lua b/lua/codecompanion/adapter.lua index 69b6e585..143b4971 100644 --- a/lua/codecompanion/adapter.lua +++ b/lua/codecompanion/adapter.lua @@ -81,18 +81,24 @@ end ---@return CodeCompanion.Adapter function Adapter:replace_header_vars() - for k, v in pairs(self.headers) do - self.headers[k] = v:gsub("${(.-)}", function(var) - local env_var = os.getenv(self.env[var]) - if not env_var then - log:error("Error: Could not find env var: %s", self.env[var]) - return vim.notify( - string.format("[CodeCompanion.nvim]\nCould not find env var: %s", self.env[var]), - vim.log.levels.ERROR - ) - end - return env_var - end) + if self.headers then + for k, v in pairs(self.headers) do + self.headers[k] = v:gsub("${(.-)}", function(var) + 
local env_var = self.env[var] + + if env_var then + env_var = os.getenv(env_var) + if not env_var then + log:error("Error: Could not find env var: %s", self.env[var]) + return vim.notify( + string.format("[CodeCompanion.nvim]\nCould not find env var: %s", self.env[var]), + vim.log.levels.ERROR + ) + end + return env_var + end + end) + end end return self diff --git a/lua/spec/codecompanion/client_spec.lua b/lua/spec/codecompanion/client_spec.lua index 5830ab0d..f4808279 100644 --- a/lua/spec/codecompanion/client_spec.lua +++ b/lua/spec/codecompanion/client_spec.lua @@ -13,6 +13,11 @@ local adapter = { parameters = { stream = true, }, + callbacks = { + form_messages = function() + return {} + end, + }, schema = {}, } @@ -44,6 +49,8 @@ describe("Client", function() local cb = stub.new() + adapter = require("codecompanion.adapter").new(adapter) + Client.new():stream(adapter, {}, 0, cb) assert.stub(mock_request).was_called(1) From 1f150b3f205462a9d92472baaf2501fd349094ef Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Tue, 5 Mar 2024 14:11:51 +0000 Subject: [PATCH 24/42] add ollama support --- lua/codecompanion/adapters/ollama.lua | 84 +++++++++++++++++++++++++++ lua/codecompanion/client.lua | 6 +- lua/codecompanion/config.lua | 4 +- 3 files changed, 90 insertions(+), 4 deletions(-) create mode 100644 lua/codecompanion/adapters/ollama.lua diff --git a/lua/codecompanion/adapters/ollama.lua b/lua/codecompanion/adapters/ollama.lua new file mode 100644 index 00000000..48355c6a --- /dev/null +++ b/lua/codecompanion/adapters/ollama.lua @@ -0,0 +1,84 @@ +local Adapter = require("codecompanion.adapter") + +---@class CodeCompanion.Adapter +---@field name string +---@field url string +---@field raw? table +---@field headers table +---@field parameters table +---@field schema table + +local adapter = { + name = "Ollama", + url = "http://localhost:11434/api/chat", + raw = { + "--no-buffer", + }, + callbacks = { + ---Set the format of the role and content for the messages from the chat buffer + ---@param messages table Format is: { { role = "user", content = "Your prompt here" } } + ---@return table + form_messages = function(messages) + return { messages = messages } + end, + + ---Format any data before it's consumed by the other callbacks + ---@param data table + ---@return table + format_data = function(data) + return data + end, + + ---Has the streaming completed? + ---@param data table The data from the format_data callback + ---@return boolean + is_complete = function(data) + return data.done == true + end, + + ---Output the data from the API ready for insertion into the chat buffer + ---@param data table The streamed data from the API, also formatted by the format_data callback + ---@param messages table A table of all of the messages in the chat buffer + ---@param current_message table The current/latest message in the chat buffer + ---@return table + output_chat = function(data, messages, current_message) + local delta = data.message + + if delta.role and delta.role ~= current_message.role then + current_message = { role = delta.role, content = "" } + table.insert(messages, current_message) + end + + if delta.content then + current_message.content = current_message.content .. 
delta.content + end + + return current_message + end, + + ---Output the data from the API ready for inlining into the current buffer + ---@param data table The streamed data from the API, also formatted by the format_data callback + ---@param context table Useful context about the buffer to inline to + ---@return table + output_inline = function(data, context) + return data.message.content + end, + }, + schema = { + model = { + order = 1, + mapping = "parameters", + type = "enum", + desc = "ID of the model to use.", + default = "llama2", + choices = { + "llama2", + "mistral", + "dolphin-phi", + "phi", + }, + }, + }, +} + +return Adapter.new(adapter) diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index cc73fa1c..615a11d3 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -97,14 +97,16 @@ function Client:stream(adapter, payload, bufnr, cb) --TODO: Check for any errors env variables local headers = adapter:replace_header_vars().headers + local body = + self.opts.encode(vim.tbl_extend("keep", adapter.parameters or {}, adapter.callbacks.form_messages(payload))) - log:debug("Adapter: %s", { adapter.name, adapter.url, adapter.raw, headers, adapter.parameters }) + log:debug("Adapter: %s", { adapter.name, adapter.url, adapter.raw, headers, body }) local handler = self.opts.request({ url = adapter.url, raw = adapter.raw or { "--no-buffer" }, headers = headers, - body = self.opts.encode(vim.tbl_extend("keep", adapter.parameters, adapter.callbacks.form_messages(payload))), + body = body, stream = self.opts.schedule(function(_, data) if _G.codecompanion_jobs[bufnr] and _G.codecompanion_jobs[bufnr].status == "stopping" then close_request(bufnr, { shutdown = true }) diff --git a/lua/codecompanion/config.lua b/lua/codecompanion/config.lua index 870a7b89..2231934a 100644 --- a/lua/codecompanion/config.lua +++ b/lua/codecompanion/config.lua @@ -2,8 +2,8 @@ local M = {} local defaults = { adapters = { - chat = require("codecompanion.adapters.openai"), - inline = require("codecompanion.adapters.openai"), + chat = require("codecompanion.adapters.ollama"), + inline = require("codecompanion.adapters.ollama"), }, saved_chats = { save_dir = vim.fn.stdpath("data") .. 
"/codecompanion/saved_chats", From 9775fda97cf3ddad833fa45be5e1f12a9c1b951b Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Tue, 5 Mar 2024 14:19:01 +0000 Subject: [PATCH 25/42] fix ollama `is_done` method --- lua/codecompanion/adapters/ollama.lua | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lua/codecompanion/adapters/ollama.lua b/lua/codecompanion/adapters/ollama.lua index 48355c6a..33726cdd 100644 --- a/lua/codecompanion/adapters/ollama.lua +++ b/lua/codecompanion/adapters/ollama.lua @@ -33,7 +33,8 @@ local adapter = { ---@param data table The data from the format_data callback ---@return boolean is_complete = function(data) - return data.done == true + data = vim.fn.json_decode(data) + return data.done end, ---Output the data from the API ready for insertion into the chat buffer From eb000d7c2d92beaa500c457a77976b4cea776b0d Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Tue, 5 Mar 2024 18:53:21 +0000 Subject: [PATCH 26/42] fix displaying decimal places in settings --- lua/codecompanion/utils/yaml.lua | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lua/codecompanion/utils/yaml.lua b/lua/codecompanion/utils/yaml.lua index 9f0022c8..b8775909 100644 --- a/lua/codecompanion/utils/yaml.lua +++ b/lua/codecompanion/utils/yaml.lua @@ -7,7 +7,11 @@ M.encode = function(data) if data == nil then return "null" elseif dt == "number" then - return string.format("%d", data) + if data % 1 == 0 then + return string.format("%d", data) + else + return string.format("%.1f", data) + end elseif dt == "boolean" then return string.format("%s", data) elseif dt == "string" then From 051058bfef96c3230840cc887265e7f73fd456c2 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Tue, 5 Mar 2024 18:54:47 +0000 Subject: [PATCH 27/42] tweak adapters --- lua/codecompanion/adapters/ollama.lua | 26 +++++++++++++++++--------- lua/codecompanion/adapters/openai.lua | 13 ++++++------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/lua/codecompanion/adapters/ollama.lua b/lua/codecompanion/adapters/ollama.lua index 33726cdd..5fa5ca13 100644 --- a/lua/codecompanion/adapters/ollama.lua +++ b/lua/codecompanion/adapters/ollama.lua @@ -11,9 +11,6 @@ local Adapter = require("codecompanion.adapter") local adapter = { name = "Ollama", url = "http://localhost:11434/api/chat", - raw = { - "--no-buffer", - }, callbacks = { ---Set the format of the role and content for the messages from the chat buffer ---@param messages table Format is: { { role = "user", content = "Your prompt here" } } @@ -38,12 +35,12 @@ local adapter = { end, ---Output the data from the API ready for insertion into the chat buffer - ---@param data table The streamed data from the API, also formatted by the format_data callback + ---@param json_data table The streamed JSON data from the API, also formatted by the format_data callback ---@param messages table A table of all of the messages in the chat buffer ---@param current_message table The current/latest message in the chat buffer ---@return table - output_chat = function(data, messages, current_message) - local delta = data.message + output_chat = function(json_data, messages, current_message) + local delta = json_data.message if delta.role and delta.role ~= current_message.role then current_message = { role = delta.role, content = "" } @@ -58,11 +55,11 @@ local adapter = { end, ---Output the data from the API ready for inlining into the current buffer - ---@param data table The streamed data from the API, also formatted by the format_data callback + 
---@param json_data table The streamed JSON data from the API, also formatted by the format_data callback ---@param context table Useful context about the buffer to inline to ---@return table - output_inline = function(data, context) - return data.message.content + output_inline = function(json_data, context) + return json_data.message.content end, }, schema = { @@ -79,6 +76,17 @@ local adapter = { "phi", }, }, + temperature = { + order = 2, + mapping = "parameters.options", + type = "number", + optional = true, + default = 0.8, + desc = "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.", + validate = function(n) + return n >= 0 and n <= 2, "Must be between 0 and 2" + end, + }, }, } diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index 6f670746..a499cc9b 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -7,7 +7,6 @@ local Adapter = require("codecompanion.adapter") ---@field headers table ---@field parameters table ---@field schema table - local adapter = { name = "OpenAI", url = "https://api.openai.com/v1/chat/completions", @@ -49,12 +48,12 @@ local adapter = { end, ---Output the data from the API ready for insertion into the chat buffer - ---@param data table The streamed data from the API, also formatted by the format_data callback + ---@param json_data table The streamed JSON data from the API, also formatted by the format_data callback ---@param messages table A table of all of the messages in the chat buffer ---@param current_message table The current/latest message in the chat buffer ---@return table - output_chat = function(data, messages, current_message) - local delta = data.choices[1].delta + output_chat = function(json_data, messages, current_message) + local delta = json_data.choices[1].delta if delta.role and delta.role ~= current_message.role then current_message = { role = delta.role, content = "" } @@ -69,11 +68,11 @@ local adapter = { end, ---Output the data from the API ready for inlining into the current buffer - ---@param data table The streamed data from the API, also formatted by the format_data callback + ---@param json_data table The streamed JSON data from the API, also formatted by the format_data callback ---@param context table Useful context about the buffer to inline to ---@return table - output_inline = function(data, context) - return data.choices[1].delta.content + output_inline = function(json_data, context) + return json_data.choices[1].delta.content end, }, schema = { From eb70e27f974c5a0b121e102d40ac439deb5357fd Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Tue, 5 Mar 2024 18:56:52 +0000 Subject: [PATCH 28/42] update README.md --- README.md | 54 ++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 2f17d684..5517f26c 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@

-CodeCompanion.nvim +CodeCompanion.nvim

CodeCompanion.nvim

@@ -14,7 +14,8 @@

-Use the OpenAI APIs directly in Neovim. Use it to chat, author and advise you on your code.
+Use the power of Generative AI in Neovim. Use it to chat about your code, author new code and get advice on it.
+Currently supports OpenAI and Ollama.

> [!IMPORTANT]
@@ -29,9 +30,10 @@ Use the
> [!WARNING]
-> For some users, the sending of any code to OpenAI may not be an option. In those instances, you can set `send_code = false` in your config.
+> Depending on your chosen adapter, you may need to set up environment variables within your shell. See the adapters section below for specific information.
+
+### Adapters
+
+The plugin uses adapters to bridge between Generative AI services and the plugin itself. More information can be found in the [ADAPTERS](ADAPTERS.md) guide. Below are the configuration requirements for each adapter:
+
+- **OpenAI** - Set the `OPENAI_API_KEY` variable within your shell
+- **Ollama** - None
+
+#### Modifying Adapters
+
+It may be necessary for you to modify the in-built adapters. This can be done by calling the `setup` method:
+
+```lua
+require("codecompanion").setup({
+  adapters = {
+    chat = require("codecompanion.adapters.openai").setup({
+
+    }),
+  },
+})
+```

### Edgy.nvim Configuration

The author recommends pairing with [edgy.nvim](https://github.com/folke/edgy.nvim) for a Co-Pilot Chat-like experience:

@@ -219,13 +241,13 @@ The Action Palette, opened via `:CodeCompanionActions`, contains all of the acti

chat buffer

-The chat buffer is where you can converse with the OpenAI APIs, directly from Neovim. It behaves as a regular markdown buffer with some clever additions. When the buffer is written (or "saved"), autocmds trigger the sending of its content to OpenAI, in the form of prompts. These prompts are segmented by H1 headers: `user` and `assistant` (see OpenAI's [Chat Completions API](https://platform.openai.com/docs/guides/text-generation/chat-completions-api) for more on this). When a response is received, it is then streamed back into the buffer. The result is that you experience the feel of conversing with ChatGPT from within Neovim.
+The chat buffer is where you can converse with the Generative AI service, directly from Neovim. It behaves as a regular markdown buffer with some clever additions. When the buffer is written (or "saved"), autocmds trigger the sending of its content to the Generative AI service, in the form of prompts. These prompts are segmented by H1 headers: `user` and `assistant`. When a response is received, it is then streamed back into the buffer. The result is that you experience the feel of conversing with ChatGPT from within Neovim.

#### Keymaps

When in the chat buffer, there are a number of keymaps available to you (which can be changed in the config):

-- `` - Save the buffer and trigger a response from the OpenAI API
+- `` - Save the buffer and trigger a response from the Generative AI service
- `` - Close the buffer
- `q` - Cancel the stream from the API
- `gc` - Clear the buffer's contents
@@ -240,7 +262,7 @@ Chat buffers are not saved to disk by default, but can be by pressing `gs` in th

#### Settings

-If `display.chat.show_settings` is set to `true`, at the very top of the chat buffer will be the OpenAI parameters which can be changed to tweak the response back to you. This enables fine-tuning and parameter tweaking throughout the chat. You can find more detail about them by moving the cursor over them or referring to the [OpenAI Chat Completions reference guide](https://platform.openai.com/docs/api-reference/chat).
+If `display.chat.show_settings` is set to `true`, at the very top of the chat buffer will be the adapter parameters which can be changed to tweak the responses you receive. This enables fine-tuning and parameter tweaking throughout the chat. You can find more detail about them by moving the cursor over them.

### Inline Code

You can use the plugin to create inline code directly into a Neovim buffer. This

> [!NOTE]
> The command can detect if you've made a visual selection and send any code as context to the API alongside the filetype of the buffer.

One of the challenges with inline editing is determining how the API's response should be handled in the buffer. If you've prompted the API to _"create a table of 5 fruits"_ then you may wish for the response to be placed after the cursor in the buffer. However, if you asked the API to _"refactor this function"_ then you'd expect the response to overwrite a visual selection.
If this _placement_ isn't specified then the plugin will use Generative AI itself to determine if the response should follow any of the placements below:

- _after_ - after the visual selection
- _before_ - before the visual selection
@@ -271,11 +293,11 @@ As a final example, specifying a prompt like _"create a test for this code in a

### In-Built Actions

-The plugin comes with a number of [in-built actions](https://github.com/olimorris/codecompanion.nvim/blob/main/lua/codecompanion/actions.lua) which aim to improve your Neovim workflow. Actions make use of either a _chat_ or an _inline_ strategy, which are essentially bridges between Neovim and OpenAI. The chat strategy opens up a chat buffer whilst an inline strategy will write output from OpenAI into the Neovim buffer.
+The plugin comes with a number of [in-built actions](https://github.com/olimorris/codecompanion.nvim/blob/main/lua/codecompanion/actions.lua) which aim to improve your Neovim workflow. Actions make use of either a _chat_ or an _inline_ strategy. The chat strategy opens up a chat buffer whilst an inline strategy will write output from the Generative AI service into the Neovim buffer.

#### Chat and Chat as

-Both of these actions utilise the `chat` strategy. The `Chat` action opens up a fresh chat buffer. The `Chat as` action allows for persona based context to be set in the chat buffer allowing for better and more detailed responses from OpenAI.
+Both of these actions utilise the `chat` strategy. The `Chat` action opens up a fresh chat buffer. The `Chat as` action allows for persona-based context to be set in the chat buffer, allowing for better and more detailed responses from the Generative AI service.

> [!TIP]
> Both of these actions allow for visually selected code to be sent to the chat buffer as code blocks.

@@ -303,7 +325,7 @@ As the name suggests, this action provides advice on a visual selection of code

#### LSP assistant

-Taken from the fantastic [Wtf.nvim](https://github.com/piersolenski/wtf.nvim) plugin, this action provides advice on how to correct any LSP diagnostics which are present on the visually selected lines. Again, the `send_code = false` value can be set in your config to prevent the code itself being sent to OpenAI.
+Taken from the fantastic [Wtf.nvim](https://github.com/piersolenski/wtf.nvim) plugin, this action provides advice on how to correct any LSP diagnostics which are present on the visually selected lines. Again, the `send_code = false` value can be set in your config to prevent the code itself being sent to the Generative AI service.
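For users who would rather never send code with any of these actions, a minimal sketch of that configuration (assuming `send_code` sits at the top level of the `setup` table; its exact location may differ in your version of the plugin) is:

```lua
-- A sketch only: stop the plugin from sending any code to the Generative AI service.
-- `send_code` is assumed to be a top-level option of `setup` here.
require("codecompanion").setup({
  send_code = false,
})
```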
## :rainbow: Helpers @@ -338,10 +360,10 @@ vim.api.nvim_create_autocmd({ "User" }, { ### Heirline.nvim -If you're using the fantastic [Heirline.nvim](https://github.com/rebelot/heirline.nvim) plugin, consider the following snippet to display an icon in the statusline whilst CodeCompanion is conversing with OpenAI: +If you're using the fantastic [Heirline.nvim](https://github.com/rebelot/heirline.nvim) plugin, consider the following snippet to display an icon in the statusline whilst CodeCompanion is conversing with the Generative AI service: ```lua -local OpenAI = { +local CodeCompanion = { static = { processing = false, }, From 3b1d5daeeee2cd07256d08c0934253142b6c2a15 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Wed, 6 Mar 2024 16:47:50 +0000 Subject: [PATCH 29/42] feat: add anthropic adapter --- lua/codecompanion/adapters/anthropic.lua | 135 +++++++++++++++++++++++ lua/codecompanion/adapters/ollama.lua | 14 +++ lua/codecompanion/adapters/openai.lua | 14 +++ lua/codecompanion/client.lua | 40 ++++--- lua/codecompanion/strategies/chat.lua | 6 +- 5 files changed, 194 insertions(+), 15 deletions(-) create mode 100644 lua/codecompanion/adapters/anthropic.lua diff --git a/lua/codecompanion/adapters/anthropic.lua b/lua/codecompanion/adapters/anthropic.lua new file mode 100644 index 00000000..014d3598 --- /dev/null +++ b/lua/codecompanion/adapters/anthropic.lua @@ -0,0 +1,135 @@ +local Adapter = require("codecompanion.adapter") + +---@class CodeCompanion.Adapter +---@field name string +---@field url string +---@field raw? table +---@field headers table +---@field parameters table +---@field schema table +local adapter = { + name = "Anthropic", + url = "https://api.anthropic.com/v1/messages", + env = { + anthropic_api_key = "ANTHROPIC_API_KEY", + }, + headers = { + ["anthropic-version"] = "2023-06-01", + -- ["anthropic-beta"] = "messages-2023-12-15", + ["content-type"] = "application/json", + ["x-api-key"] = "${anthropic_api_key}", + }, + parameters = { + stream = true, + }, + callbacks = { + ---Set the format of the role and content for the messages from the chat buffer + ---@param messages table Format is: { { role = "user", content = "Your prompt here" } } + ---@return table + form_messages = function(messages) + return { messages = messages } + end, + + ---Event based responses sometimes include data that shouldn't be processed + ---@param data table + ---@return boolean + should_skip = function(data) + if type(data) == "string" then + return string.sub(data, 1, 6) == "event:" + end + return false + end, + + ---Format any data before it's consumed by the other callbacks + ---@param data string + ---@return string + format_data = function(data) + return data:sub(6) + end, + + ---Handle any errors from the API + ---@param data string + ---@return boolean + should_handle_errors = function(data) + if type(data) == "string" then + return string.sub(data, 1, 12) == "event: error" + end + return false + end, + + ---Has the streaming completed? 
+ ---@param data string The data from the format_data callback + ---@return boolean + is_complete = function(data) + local ok, data = pcall(vim.fn.json_decode, data) + if ok and data.type then + return data.type == "message_stop" + end + return false + end, + + ---Output the data from the API ready for insertion into the chat buffer + ---@param json_data table The streamed JSON data from the API, also formatted by the format_data callback + ---@param messages table A table of all of the messages in the chat buffer + ---@param current_message table The current/latest message in the chat buffer + ---@return table + output_chat = function(json_data, messages, current_message) + if json_data.type == "message_start" then + current_message = { role = json_data.message.role, content = "" } + table.insert(messages, current_message) + end + + if json_data.type == "content_block_delta" then + current_message.content = current_message.content .. json_data.delta.text + end + + return current_message + end, + + ---Output the data from the API ready for inlining into the current buffer + ---@param json_data table The streamed JSON data from the API, also formatted by the format_data callback + ---@param context table Useful context about the buffer to inline to + ---@return table + output_inline = function(json_data, context) + return json_data.choices[1].delta.content + end, + }, + schema = { + model = { + order = 1, + mapping = "parameters", + type = "enum", + desc = "ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", + default = "claude-3-opus-20240229", + choices = { + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", + "claude-2.1", + }, + }, + max_tokens = { + order = 2, + mapping = "parameters", + type = "number", + optional = true, + default = 1024, + desc = "The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length.", + validate = function(n) + return n > 0, "Must be greater than 0" + end, + }, + temperature = { + order = 3, + mapping = "parameters", + type = "number", + optional = true, + default = 0, + desc = "What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.", + validate = function(n) + return n >= 0 and n <= 1, "Must be between 0 and 1" + end, + }, + }, +} + +return Adapter.new(adapter) diff --git a/lua/codecompanion/adapters/ollama.lua b/lua/codecompanion/adapters/ollama.lua index 5fa5ca13..da1f4499 100644 --- a/lua/codecompanion/adapters/ollama.lua +++ b/lua/codecompanion/adapters/ollama.lua @@ -19,6 +19,13 @@ local adapter = { return { messages = messages } end, + ---Event based responses sometimes include data that shouldn't be processed + ---@param data table + ---@return boolean + should_skip = function(data) + return false + end, + ---Format any data before it's consumed by the other callbacks ---@param data table ---@return table @@ -26,6 +33,13 @@ local adapter = { return data end, + ---Handle any errors from the API + ---@param data string + ---@return boolean + should_handle_errors = function(data) + return false + end, + ---Has the streaming completed? 
---@param data table The data from the format_data callback ---@return boolean diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index a499cc9b..8129e1df 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -32,6 +32,13 @@ local adapter = { return { messages = messages } end, + ---Event based responses sometimes include data that shouldn't be processed + ---@param data table + ---@return boolean + should_skip = function(data) + return false + end, + ---Format any data before it's consumed by the other callbacks ---@param data string ---@return string @@ -40,6 +47,13 @@ local adapter = { return data:sub(7) end, + ---Handle any errors from the API + ---@param data string + ---@return boolean + should_handle_errors = function(data) + return false + end, + ---Has the streaming completed? ---@param data string The data from the format_data callback ---@return boolean diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index 615a11d3..a8bab991 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -104,33 +104,47 @@ function Client:stream(adapter, payload, bufnr, cb) local handler = self.opts.request({ url = adapter.url, + timeout = 1000, raw = adapter.raw or { "--no-buffer" }, headers = headers, body = body, stream = self.opts.schedule(function(_, data) + log:trace("Chat data: %s", data) + -- log:trace("----- For Adapter test creation -----\nRequest: %s\n ---------- // END ----------", data) + if _G.codecompanion_jobs[bufnr] and _G.codecompanion_jobs[bufnr].status == "stopping" then close_request(bufnr, { shutdown = true }) return cb(nil, nil, true) end - if data and type(adapter.callbacks.format_data) == "function" then - data = adapter.callbacks.format_data(data) - end + if not adapter.callbacks.should_skip(data) then + if data and type(adapter.callbacks.format_data) == "function" then + data = adapter.callbacks.format_data(data) + end - if adapter.callbacks.is_complete(data) then - close_request(bufnr) - return cb(nil, nil, true) - end + if adapter.callbacks.is_complete(data) then + log:trace("Chat completed") + close_request(bufnr) + return cb(nil, nil, true) + end + + if data and data ~= "" then + local ok, json = pcall(self.opts.decode, data, { luanil = { object = true } }) - if data and data ~= "" then - local ok, json = pcall(self.opts.decode, data, { luanil = { object = true } }) + if not ok then + close_request(bufnr) + log:error("Decoding error: %s", json) + log:error("Data trace: %s", data) + return cb(string.format("Error decoding data: %s", json)) + end - if not ok then + cb(nil, json) + end + else + if adapter.callbacks.should_handle_errors(data) then close_request(bufnr) - return cb(string.format("Error malformed json: %s", json)) + return cb(string.format("There was an error from API: %s: ", data)) end - - cb(nil, json) end end), on_error = function(err, _, _) diff --git a/lua/codecompanion/strategies/chat.lua b/lua/codecompanion/strategies/chat.lua index 7743711c..d0b7f68c 100644 --- a/lua/codecompanion/strategies/chat.lua +++ b/lua/codecompanion/strategies/chat.lua @@ -437,6 +437,9 @@ function Chat:submit() return finalize() end + -- log:trace("----- For Adapter test creation -----\nMessages: %s\n ---------- // END ----------", messages) + log:trace("Settings: %s", settings) + local adapter = config.options.adapters.chat client.new():stream(adapter:set_params(settings), messages, self.bufnr, function(err, data, done) @@ -446,13 +449,12 @@ 
function Chat:submit() end if data then - log:trace("Chat data: %s", data) current_message = adapter.callbacks.output_chat(data, messages, current_message) + -- log:trace("----- For Adapter test creation -----\nOutput: %s\n ---------- // END ----------", current_message) render_buffer() end if done then - log:trace("Chat streaming is done") table.insert(messages, { role = "user", content = "" }) render_buffer() display_tokens(self.bufnr) From a47b569e3fa5f0cdda1e4467c4b5031e0fa94568 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Wed, 6 Mar 2024 16:50:40 +0000 Subject: [PATCH 30/42] refactor name of error callback --- lua/codecompanion/adapters/anthropic.lua | 2 +- lua/codecompanion/adapters/ollama.lua | 2 +- lua/codecompanion/adapters/openai.lua | 2 +- lua/codecompanion/client.lua | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lua/codecompanion/adapters/anthropic.lua b/lua/codecompanion/adapters/anthropic.lua index 014d3598..3c97e01d 100644 --- a/lua/codecompanion/adapters/anthropic.lua +++ b/lua/codecompanion/adapters/anthropic.lua @@ -50,7 +50,7 @@ local adapter = { ---Handle any errors from the API ---@param data string ---@return boolean - should_handle_errors = function(data) + has_error = function(data) if type(data) == "string" then return string.sub(data, 1, 12) == "event: error" end diff --git a/lua/codecompanion/adapters/ollama.lua b/lua/codecompanion/adapters/ollama.lua index da1f4499..ea26881a 100644 --- a/lua/codecompanion/adapters/ollama.lua +++ b/lua/codecompanion/adapters/ollama.lua @@ -36,7 +36,7 @@ local adapter = { ---Handle any errors from the API ---@param data string ---@return boolean - should_handle_errors = function(data) + has_error = function(data) return false end, diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index 8129e1df..0cad751f 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -50,7 +50,7 @@ local adapter = { ---Handle any errors from the API ---@param data string ---@return boolean - should_handle_errors = function(data) + has_error = function(data) return false end, diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index a8bab991..788a7c51 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -141,7 +141,7 @@ function Client:stream(adapter, payload, bufnr, cb) cb(nil, json) end else - if adapter.callbacks.should_handle_errors(data) then + if adapter.callbacks.has_error(data) then close_request(bufnr) return cb(string.format("There was an error from API: %s: ", data)) end From cd1db089f1c00c6f091b6f61e4f6ebec8852119d Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Wed, 6 Mar 2024 17:55:34 +0000 Subject: [PATCH 31/42] better handling of errors --- lua/codecompanion/adapters/anthropic.lua | 21 ++++++++++++--------- lua/codecompanion/adapters/ollama.lua | 6 +++--- lua/codecompanion/adapters/openai.lua | 5 +++-- lua/codecompanion/client.lua | 13 +++++++++++-- 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/lua/codecompanion/adapters/anthropic.lua b/lua/codecompanion/adapters/anthropic.lua index 3c97e01d..13438e4d 100644 --- a/lua/codecompanion/adapters/anthropic.lua +++ b/lua/codecompanion/adapters/anthropic.lua @@ -6,6 +6,7 @@ local Adapter = require("codecompanion.adapter") ---@field raw? 
table ---@field headers table ---@field parameters table +---@field callbacks table ---@field schema table local adapter = { name = "Anthropic", @@ -30,7 +31,7 @@ local adapter = { return { messages = messages } end, - ---Event based responses sometimes include data that shouldn't be processed + ---Does this streamed data need to be skipped? ---@param data table ---@return boolean should_skip = function(data) @@ -47,21 +48,20 @@ local adapter = { return data:sub(6) end, - ---Handle any errors from the API + ---Does the data contain an error? ---@param data string ---@return boolean has_error = function(data) - if type(data) == "string" then - return string.sub(data, 1, 12) == "event: error" - end - return false + local msg = "event: error" + return string.sub(data, 1, string.len(msg)) == msg end, ---Has the streaming completed? ---@param data string The data from the format_data callback ---@return boolean is_complete = function(data) - local ok, data = pcall(vim.fn.json_decode, data) + local ok + ok, data = pcall(vim.fn.json_decode, data) if ok and data.type then return data.type == "message_stop" end @@ -89,9 +89,12 @@ local adapter = { ---Output the data from the API ready for inlining into the current buffer ---@param json_data table The streamed JSON data from the API, also formatted by the format_data callback ---@param context table Useful context about the buffer to inline to - ---@return table + ---@return table|nil output_inline = function(json_data, context) - return json_data.choices[1].delta.content + if json_data.type == "content_block_delta" then + return json_data.delta.text + end + return nil end, }, schema = { diff --git a/lua/codecompanion/adapters/ollama.lua b/lua/codecompanion/adapters/ollama.lua index ea26881a..29fa9049 100644 --- a/lua/codecompanion/adapters/ollama.lua +++ b/lua/codecompanion/adapters/ollama.lua @@ -6,8 +6,8 @@ local Adapter = require("codecompanion.adapter") ---@field raw? table ---@field headers table ---@field parameters table +---@field callbacks table ---@field schema table - local adapter = { name = "Ollama", url = "http://localhost:11434/api/chat", @@ -19,7 +19,7 @@ local adapter = { return { messages = messages } end, - ---Event based responses sometimes include data that shouldn't be processed + ---Does this streamed data need to be skipped? ---@param data table ---@return boolean should_skip = function(data) @@ -33,7 +33,7 @@ local adapter = { return data end, - ---Handle any errors from the API + ---Does the data contain an error? ---@param data string ---@return boolean has_error = function(data) diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index 0cad751f..e15dcbbc 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -6,6 +6,7 @@ local Adapter = require("codecompanion.adapter") ---@field raw? table ---@field headers table ---@field parameters table +---@field callbacks table ---@field schema table local adapter = { name = "OpenAI", @@ -32,7 +33,7 @@ local adapter = { return { messages = messages } end, - ---Event based responses sometimes include data that shouldn't be processed + ---Does this streamed data need to be skipped? ---@param data table ---@return boolean should_skip = function(data) @@ -47,7 +48,7 @@ local adapter = { return data:sub(7) end, - ---Handle any errors from the API + ---Does the data contain an error? 
---@param data string ---@return boolean has_error = function(data) diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index 788a7c51..cd9fb65e 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -102,6 +102,12 @@ function Client:stream(adapter, payload, bufnr, cb) log:debug("Adapter: %s", { adapter.name, adapter.url, adapter.raw, headers, body }) + local function handle_error(data) + log:error("Error: %s", data) + close_request(bufnr) + return cb(string.format("There was an error from API: %s: ", data)) + end + local handler = self.opts.request({ url = adapter.url, timeout = 1000, @@ -118,6 +124,10 @@ function Client:stream(adapter, payload, bufnr, cb) end if not adapter.callbacks.should_skip(data) then + if adapter.callbacks.has_error(data) then + return handle_error(data) + end + if data and type(adapter.callbacks.format_data) == "function" then data = adapter.callbacks.format_data(data) end @@ -142,8 +152,7 @@ function Client:stream(adapter, payload, bufnr, cb) end else if adapter.callbacks.has_error(data) then - close_request(bufnr) - return cb(string.format("There was an error from API: %s: ", data)) + return handle_error(data) end end end), From d7386b45b55ab82887a4966943a95d7d8458d118 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Wed, 6 Mar 2024 18:16:43 +0000 Subject: [PATCH 32/42] start adding adapter tests --- lua/spec/codecompanion/adapters/helpers.lua | 16 +++++ .../codecompanion/adapters/openai_spec.lua | 59 +++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 lua/spec/codecompanion/adapters/helpers.lua create mode 100644 lua/spec/codecompanion/adapters/openai_spec.lua diff --git a/lua/spec/codecompanion/adapters/helpers.lua b/lua/spec/codecompanion/adapters/helpers.lua new file mode 100644 index 00000000..9122cec5 --- /dev/null +++ b/lua/spec/codecompanion/adapters/helpers.lua @@ -0,0 +1,16 @@ +local M = {} + +function M.chat_buffer_output(stream_response, adapter, messages) + local output = {} + + for _, data in ipairs(stream_response) do + data = adapter.callbacks.format_data(data.request) + data = vim.json.decode(data, { luanil = { object = true } }) + + output = adapter.callbacks.output_chat(data, messages, output) + end + + return output +end + +return M diff --git a/lua/spec/codecompanion/adapters/openai_spec.lua b/lua/spec/codecompanion/adapters/openai_spec.lua new file mode 100644 index 00000000..ed478a14 --- /dev/null +++ b/lua/spec/codecompanion/adapters/openai_spec.lua @@ -0,0 +1,59 @@ +local adapter = require("codecompanion.adapters.openai") +local assert = require("luassert") +local helpers = require("spec.codecompanion.adapters.helpers") + +--------------------------------------------------- OUTPUT FROM THE CHAT BUFFER +local messages = { { + content = "Explain Ruby in two words", + role = "user", +} } + +local stream_response = { + { + request = 'data: {"id":"chatcmpl-8zlFGE8bEaXPG43tedyauJkw1EiMQ","object":"chat.completion.chunk","created":1709730310,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}', + output = { + content = "", + role = "assistant", + }, + }, + { + request = 'data: {"id":"chatcmpl-8zlFGE8bEaXPG43tedyauJkw1EiMQ","object":"chat.completion.chunk","created":1709730310,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"content":"Programming"},"logprobs":null,"finish_reason":null}]}', + output = { + content = 
"Programming", + role = "assistant", + }, + }, + { + request = 'data: {"id":"chatcmpl-8zlFGE8bEaXPG43tedyauJkw1EiMQ","object":"chat.completion.chunk","created":1709730310,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"content":" language"},"logprobs":null,"finish_reason":null}]}', + output = { + content = "Programming language", + role = "assistant", + }, + }, +} + +local done_response = "data: [DONE]" +------------------------------------------------------------------------ // END + +describe("OpenAI adapter", function() + it("can form messages to be sent to the API", function() + assert.are.same({ messages = messages }, adapter.callbacks.form_messages(messages)) + end) + + it("can format the data from the API", function() + assert.are.same("[DONE]", adapter.callbacks.format_data(done_response)) + end) + + it("can check if the streaming is complete", function() + local data = adapter.callbacks.format_data(done_response) + + assert.is_true(adapter.callbacks.is_complete(data)) + end) + + it("can output streamed data into a format for the chat buffer", function() + assert.are.same( + stream_response[#stream_response].output, + helpers.chat_buffer_output(stream_response, adapter, messages) + ) + end) +end) From 8f25246bfc1937ad66baf48425c7390722d43889 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Wed, 6 Mar 2024 19:03:48 +0000 Subject: [PATCH 33/42] add ollama adapter test --- .../codecompanion/adapters/ollama_spec.lua | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 lua/spec/codecompanion/adapters/ollama_spec.lua diff --git a/lua/spec/codecompanion/adapters/ollama_spec.lua b/lua/spec/codecompanion/adapters/ollama_spec.lua new file mode 100644 index 00000000..b069d4a6 --- /dev/null +++ b/lua/spec/codecompanion/adapters/ollama_spec.lua @@ -0,0 +1,84 @@ +local adapter = require("codecompanion.adapters.ollama") +local assert = require("luassert") +local helpers = require("spec.codecompanion.adapters.helpers") + +--------------------------------------------------- OUTPUT FROM THE CHAT BUFFER +local messages = { { + content = "Explain Ruby in two words", + role = "user", +} } + +local stream_response = { + { + request = [[{"model":"llama2","created_at":"2024-03-06T18:35:15.715665Z","message":{"role":"assistant","content":"\n"},"done":false}]], + output = { + content = "\n", + role = "assistant", + }, + }, + { + request = [[{"model":"llama2","created_at":"2024-03-06T18:35:15.745213Z","message":{"role":"assistant","content":"\""},"done":false}]], + output = { + content = '\n"', + role = "assistant", + }, + }, + { + request = [[{"model":"llama2","created_at":"2024-03-06T18:35:15.77473Z","message":{"role":"assistant","content":"E"},"done":false}]], + output = { + content = '\n"E', + role = "assistant", + }, + }, + { + request = [[{"model":"llama2","created_at":"2024-03-06T18:35:15.803753Z","message":{"role":"assistant","content":"asy"},"done":false}]], + output = { + content = '\n"Easy', + role = "assistant", + }, + }, + { + request = [[{"model":"llama2","created_at":"2024-03-06T18:35:15.833925Z","message":{"role":"assistant","content":" Program"},"done":false}]], + output = { + content = '\n"Easy Program', + role = "assistant", + }, + }, + { + request = [[{"model":"llama2","created_at":"2024-03-06T18:35:15.862917Z","message":{"role":"assistant","content":"ming"},"done":false}]], + output = { + content = '\n"Easy Programming', + role = "assistant", + }, + }, + { + request = 
[[{"model":"llama2","created_at":"2024-03-06T18:35:15.892319Z","message":{"role":"assistant","content":"\""},"done":false}]], + output = { + content = '\n"Easy Programming"', + role = "assistant", + }, + }, +} + +local done_response = + [[{"model":"llama2","created_at":"2024-03-06T18:35:15.921631Z","message":{"role":"assistant","content":""},"done":true,"total_duration":6035327208,"load_duration":5654490167,"prompt_eval_count":26,"prompt_eval_duration":173338000,"eval_count":8,"eval_duration":205986000}]] +------------------------------------------------------------------------ // END + +describe("Ollama adapter", function() + it("can form messages to be sent to the API", function() + assert.are.same({ messages = messages }, adapter.callbacks.form_messages(messages)) + end) + + it("can check if the streaming is complete", function() + local data = adapter.callbacks.format_data(done_response) + + assert.is_true(adapter.callbacks.is_complete(data)) + end) + + it("can output streamed data into a format for the chat buffer", function() + assert.are.same( + stream_response[#stream_response].output, + helpers.chat_buffer_output(stream_response, adapter, messages) + ) + end) +end) From c9647c5c974e932c266c8f139b98fffbfd3f845d Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Thu, 7 Mar 2024 08:08:15 +0000 Subject: [PATCH 34/42] allow adapters to be customised from the config --- README.md | 35 ++++++++++++++++++------ lua/codecompanion/adapters/anthropic.lua | 10 ++----- lua/codecompanion/adapters/init.lua | 25 +++++++++++++++++ lua/codecompanion/adapters/ollama.lua | 6 +--- lua/codecompanion/adapters/openai.lua | 10 ++----- lua/codecompanion/config.lua | 4 +-- 6 files changed, 60 insertions(+), 30 deletions(-) create mode 100644 lua/codecompanion/adapters/init.lua diff --git a/README.md b/README.md index 5517f26c..f5b65dc5 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Currently supports OpenAI and Ollama. - The `curl` library installed - Neovim 0.9.2 or greater -- _(Optional)_ An API key from your chosen Generative AI service +- _(Optional)_ An API key to be set in your shell for your chosen Generative AI service ## :package: Installation @@ -97,8 +97,8 @@ You only need to the call the `setup` function if you wish to change any of the ```lua require("codecompanion").setup({ adapters = { - chat = require("codecompanion.adapters.openai"), - inline = require("codecompanion.adapters.openai"), + chat = require("codecompanion.adapters").use("openai"), + inline = require("codecompanion.adapters").use("openai"), }, saved_chats = { save_dir = vim.fn.stdpath("data") .. "/codecompanion/saved_chats", -- Path to save chats to @@ -158,25 +158,42 @@ require("codecompanion").setup({ ### Adapters -The plugin uses adapters to bridge between Generative AI services and the plugin itself. More information can be found in the [ADAPTERS](ADAPTERS.md) guide. Below are the configuration requirements for each adapter: +The plugin uses adapters to bridge between Generative AI services and the plugin. 
Currently the plugin supports: -- **OpenAI** - Set the `OPENAI_API_KEY` variable within your shell -- **Ollama** - None +- Anthropic (`anthropic`) - Requires `ANTHROPIC_API_KEY` to be set in your shell +- Ollama (`ollama`) +- OpenAI (`openai`) - Requires `OPENAI_API_KEY` to be set in your shell + +You can specify an adapter for each of the strategies in the plugin: + +```lua +require("codecompanion").setup({ + adapters = { + chat = require("codecompanion.adapters").use("openai"), + inline = require("codecompanion.adapters").use("openai"), + }, +}) +``` #### Modifying Adapters -It may be neccessary for you to modify in-built adapters. This can be done by calling the `setup` method: +It may be necessary to modify certain parameters of an adapter. In the example below, we're changing the name of the API key that the OpenAI adapter uses by passing in a table to the `use` method: ```lua require("codecompanion").setup({ adapters = { - chat = require("codecompanion.adapters.openai").setup({ - + chat = require("codecompanion.adapters").use("openai", { + env = { + api_key = "DIFFERENT_OPENAI_KEY", + }, }), }, }) ``` +> [!TIP] +> To create your own adapter please refer to the [ADAPTERS](ADAPTERS.md) guide + ### Edgy.nvim Configuration The author recommends pairing with [edgy.nvim](https://github.com/folke/edgy.nvim) for a Co-Pilot Chat-like experience: diff --git a/lua/codecompanion/adapters/anthropic.lua b/lua/codecompanion/adapters/anthropic.lua index 13438e4d..f5d7ecc9 100644 --- a/lua/codecompanion/adapters/anthropic.lua +++ b/lua/codecompanion/adapters/anthropic.lua @@ -1,5 +1,3 @@ -local Adapter = require("codecompanion.adapter") - ---@class CodeCompanion.Adapter ---@field name string ---@field url string @@ -8,17 +6,17 @@ local Adapter = require("codecompanion.adapter") ---@field parameters table ---@field callbacks table ---@field schema table -local adapter = { +return { name = "Anthropic", url = "https://api.anthropic.com/v1/messages", env = { - anthropic_api_key = "ANTHROPIC_API_KEY", + api_key = "ANTHROPIC_API_KEY", }, headers = { ["anthropic-version"] = "2023-06-01", -- ["anthropic-beta"] = "messages-2023-12-15", ["content-type"] = "application/json", - ["x-api-key"] = "${anthropic_api_key}", + ["x-api-key"] = "${api_key}", }, parameters = { stream = true, @@ -134,5 +132,3 @@ local adapter = { }, }, } - -return Adapter.new(adapter) diff --git a/lua/codecompanion/adapters/init.lua b/lua/codecompanion/adapters/init.lua new file mode 100644 index 00000000..6ae2c8d1 --- /dev/null +++ b/lua/codecompanion/adapters/init.lua @@ -0,0 +1,25 @@ +local Adapter = require("codecompanion.adapter") + +local M = {} + +---@param adapter table +---@param opts table +---@return table +local function setup(adapter, opts) + return vim.tbl_deep_extend("force", {}, adapter, opts or {}) +end + +---@param adapter string +---@param opts? table +---@return CodeCompanion.Adapter +function M.use(adapter, opts) + local adapter_config = require("codecompanion.adapters." .. 
adapter) + + if opts then + adapter_config = setup(adapter_config, opts) + end + + return Adapter.new(adapter_config) +end + +return M diff --git a/lua/codecompanion/adapters/ollama.lua b/lua/codecompanion/adapters/ollama.lua index 29fa9049..34ebbf87 100644 --- a/lua/codecompanion/adapters/ollama.lua +++ b/lua/codecompanion/adapters/ollama.lua @@ -1,5 +1,3 @@ -local Adapter = require("codecompanion.adapter") - ---@class CodeCompanion.Adapter ---@field name string ---@field url string @@ -8,7 +6,7 @@ local Adapter = require("codecompanion.adapter") ---@field parameters table ---@field callbacks table ---@field schema table -local adapter = { +return { name = "Ollama", url = "http://localhost:11434/api/chat", callbacks = { @@ -103,5 +101,3 @@ local adapter = { }, }, } - -return Adapter.new(adapter) diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index e15dcbbc..fbf4537b 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -1,5 +1,3 @@ -local Adapter = require("codecompanion.adapter") - ---@class CodeCompanion.Adapter ---@field name string ---@field url string @@ -8,11 +6,11 @@ local Adapter = require("codecompanion.adapter") ---@field parameters table ---@field callbacks table ---@field schema table -local adapter = { +return { name = "OpenAI", url = "https://api.openai.com/v1/chat/completions", env = { - openai_api_key = "OPENAI_API_KEY", + api_key = "OPENAI_API_KEY", }, raw = { "--no-buffer", @@ -20,7 +18,7 @@ local adapter = { }, headers = { ["Content-Type"] = "application/json", - Authorization = "Bearer ${openai_api_key}", + Authorization = "Bearer ${api_key}", }, parameters = { stream = true, @@ -203,5 +201,3 @@ local adapter = { }, }, } - -return Adapter.new(adapter) diff --git a/lua/codecompanion/config.lua b/lua/codecompanion/config.lua index 2231934a..e09fa2c5 100644 --- a/lua/codecompanion/config.lua +++ b/lua/codecompanion/config.lua @@ -2,8 +2,8 @@ local M = {} local defaults = { adapters = { - chat = require("codecompanion.adapters.ollama"), - inline = require("codecompanion.adapters.ollama"), + chat = require("codecompanion.adapters").use("ollama"), + inline = require("codecompanion.adapters").use("ollama"), }, saved_chats = { save_dir = vim.fn.stdpath("data") .. "/codecompanion/saved_chats", From f9214e881098051dba7b46f7c8c6d680258776dc Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Thu, 7 Mar 2024 08:19:24 +0000 Subject: [PATCH 35/42] fix tests --- lua/codecompanion/adapters/init.lua | 15 ++++++++++++--- lua/spec/codecompanion/adapter_spec.lua | 6 +++--- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/lua/codecompanion/adapters/init.lua b/lua/codecompanion/adapters/init.lua index 6ae2c8d1..fbe0f5d7 100644 --- a/lua/codecompanion/adapters/init.lua +++ b/lua/codecompanion/adapters/init.lua @@ -9,11 +9,20 @@ local function setup(adapter, opts) return vim.tbl_deep_extend("force", {}, adapter, opts or {}) end ----@param adapter string +---@param adapter string|table ---@param opts? table ----@return CodeCompanion.Adapter +---@return CodeCompanion.Adapter|nil function M.use(adapter, opts) - local adapter_config = require("codecompanion.adapters." .. adapter) + local adapter_config + + if type(adapter) == "string" then + adapter_config = require("codecompanion.adapters." .. 
adapter) + elseif type(adapter) == "table" then + adapter_config = adapter + else + error("Adapter must be a string or a table") + return + end if opts then adapter_config = setup(adapter_config, opts) diff --git a/lua/spec/codecompanion/adapter_spec.lua b/lua/spec/codecompanion/adapter_spec.lua index fe98e925..3a93ae7e 100644 --- a/lua/spec/codecompanion/adapter_spec.lua +++ b/lua/spec/codecompanion/adapter_spec.lua @@ -2,7 +2,7 @@ local assert = require("luassert") local test_adapter = { name = "TestAdapter", - url = "https://api.openai.com/v1/chat/completions", + url = "https://api.testgenai.com/v1/chat/completions", headers = { content_type = "application/json", }, @@ -62,7 +62,7 @@ local chat_buffer_settings = { describe("Adapter", function() it("can form parameters from a chat buffer's settings", function() - local adapter = require("codecompanion.adapters.openai") + local adapter = require("codecompanion.adapters").use("openai") local result = adapter:set_params(chat_buffer_settings) -- Ignore this for now @@ -72,7 +72,7 @@ describe("Adapter", function() end) it("can nest parameters based on an adapter's schema", function() - local adapter = require("codecompanion.adapter").new(test_adapter) + local adapter = require("codecompanion.adapters").use(test_adapter) local result = adapter:set_params(chat_buffer_settings) local expected = { From 1447bca99611f2a5749da0be993c7049a680ecd7 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Thu, 7 Mar 2024 20:43:41 +0000 Subject: [PATCH 36/42] make chat buffer more efficient and streamline adapters --- lua/codecompanion/adapters/ollama.lua | 70 +++++---- lua/codecompanion/adapters/openai.lua | 72 +++++---- lua/codecompanion/client.lua | 142 ++++-------------- lua/codecompanion/strategies/chat.lua | 78 +++++++--- lua/spec/codecompanion/adapters/helpers.lua | 9 +- .../codecompanion/adapters/ollama_spec.lua | 35 ++--- .../codecompanion/adapters/openai_spec.lua | 23 +-- 7 files changed, 179 insertions(+), 250 deletions(-) diff --git a/lua/codecompanion/adapters/ollama.lua b/lua/codecompanion/adapters/ollama.lua index 34ebbf87..5ae2d0f1 100644 --- a/lua/codecompanion/adapters/ollama.lua +++ b/lua/codecompanion/adapters/ollama.lua @@ -1,3 +1,5 @@ +local log = require("codecompanion.utils.log") + ---@class CodeCompanion.Adapter ---@field name string ---@field url string @@ -17,53 +19,49 @@ return { return { messages = messages } end, - ---Does this streamed data need to be skipped? - ---@param data table - ---@return boolean - should_skip = function(data) - return false - end, - - ---Format any data before it's consumed by the other callbacks - ---@param data table - ---@return table - format_data = function(data) - return data - end, - - ---Does the data contain an error? - ---@param data string - ---@return boolean - has_error = function(data) - return false - end, - ---Has the streaming completed? 
---@param data table The data from the format_data callback ---@return boolean is_complete = function(data) - data = vim.fn.json_decode(data) - return data.done + if data then + data = vim.fn.json_decode(data) + return data.done + end + return false end, ---Output the data from the API ready for insertion into the chat buffer - ---@param json_data table The streamed JSON data from the API, also formatted by the format_data callback - ---@param messages table A table of all of the messages in the chat buffer - ---@param current_message table The current/latest message in the chat buffer - ---@return table - output_chat = function(json_data, messages, current_message) - local delta = json_data.message + ---@param data table The streamed JSON data from the API, also formatted by the format_data callback + ---@return table|nil + chat_output = function(data) + local output = {} - if delta.role and delta.role ~= current_message.role then - current_message = { role = delta.role, content = "" } - table.insert(messages, current_message) - end + if data and data ~= "" then + local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } }) + + if not ok then + return { + status = "error", + output = string.format("Error malformed json: %s", json), + } + end + + local message = json.message + + if message.content then + output.content = message.content + output.role = message.role or nil + end + + log:trace("----- For Adapter test creation -----\nOutput: %s\n ---------- // END ----------", output) - if delta.content then - current_message.content = current_message.content .. delta.content + return { + status = "success", + output = output, + } end - return current_message + return nil end, ---Output the data from the API ready for inlining into the current buffer diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index fbf4537b..15ed3a46 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -1,3 +1,5 @@ +local log = require("codecompanion.utils.log") + ---@class CodeCompanion.Adapter ---@field name string ---@field url string @@ -31,53 +33,49 @@ return { return { messages = messages } end, - ---Does this streamed data need to be skipped? - ---@param data table + ---Has the streaming completed? + ---@param data string The streamed data from the API ---@return boolean - should_skip = function(data) + is_complete = function(data) + if data then + data = data:sub(7) + return data == "[DONE]" + end return false end, - ---Format any data before it's consumed by the other callbacks - ---@param data string - ---@return string - format_data = function(data) - -- Remove the "data: " prefix - return data:sub(7) - end, + ---Output the data from the API ready for insertion into the chat buffer + ---@param data table The streamed JSON data from the API, also formatted by the format_data callback + ---@return table|nil + chat_output = function(data) + local output = {} - ---Does the data contain an error? - ---@param data string - ---@return boolean - has_error = function(data) - return false - end, + if data and data ~= "" then + data = data:sub(7) + local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } }) - ---Has the streaming completed? 
- ---@param data string The data from the format_data callback - ---@return boolean - is_complete = function(data) - return data == "[DONE]" - end, + if not ok then + return { + status = "error", + output = string.format("Error malformed json: %s", json), + } + end - ---Output the data from the API ready for insertion into the chat buffer - ---@param json_data table The streamed JSON data from the API, also formatted by the format_data callback - ---@param messages table A table of all of the messages in the chat buffer - ---@param current_message table The current/latest message in the chat buffer - ---@return table - output_chat = function(json_data, messages, current_message) - local delta = json_data.choices[1].delta + local delta = json.choices[1].delta - if delta.role and delta.role ~= current_message.role then - current_message = { role = delta.role, content = "" } - table.insert(messages, current_message) - end + if delta.content then + output.content = delta.content + output.role = delta.role or nil + end - if delta.content then - current_message.content = current_message.content .. delta.content - end + log:trace("----- For Adapter test creation -----\nOutput: %s\n ---------- // END ----------", output) - return current_message + return { + status = "success", + output = output, + } + end + return nil end, ---Output the data from the API ready for inlining into the current buffer diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index cd9fb65e..b2bacc7a 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -2,11 +2,13 @@ local curl = require("plenary.curl") local log = require("codecompanion.utils.log") local schema = require("codecompanion.schema") +local api = vim.api + _G.codecompanion_jobs = {} ---@param status string local function fire_autocmd(status) - vim.api.nvim_exec_autocmds("User", { pattern = "CodeCompanionRequest", data = { status = status } }) + api.nvim_exec_autocmds("User", { pattern = "CodeCompanionRequest", data = { status = status } }) end ---@param bufnr? number @@ -23,7 +25,7 @@ end ---@param bufnr? number ---@param opts? 
table -local function close_request(bufnr, opts) +local function stop_request(bufnr, opts) if bufnr then if opts and opts.shutdown then _G.codecompanion_jobs[bufnr].handler:shutdown() @@ -33,30 +35,6 @@ local function close_request(bufnr, opts) fire_autocmd("finished") end ----@param code integer ----@param stdout string ----@return nil|string ----@return nil|any -local function parse_response(code, stdout) - if code ~= 0 then - log:error("Error: %s", stdout) - return string.format("Error: %s", stdout) - end - - local ok, data = pcall(vim.json.decode, stdout, { luanil = { object = true } }) - if not ok then - log:error("Error malformed json: %s", data) - return string.format("Error malformed json: %s", data) - end - - if data.error then - log:error("API Error: %s", data.error.message) - return string.format("API Error: %s", data.error.message) - end - - return nil, data -end - ---@class CodeCompanion.Client ---@field static table ---@field secret_key string @@ -102,11 +80,17 @@ function Client:stream(adapter, payload, bufnr, cb) log:debug("Adapter: %s", { adapter.name, adapter.url, adapter.raw, headers, body }) - local function handle_error(data) - log:error("Error: %s", data) - close_request(bufnr) - return cb(string.format("There was an error from API: %s: ", data)) - end + local stop_request_cmd = api.nvim_create_autocmd("User", { + desc = "Stop the current request", + pattern = "CodeCompanionRequest", + callback = function(request) + if request.data.buf ~= bufnr or request.data.action ~= "stop_request" then + return + end + + return stop_request(bufnr, { shutdown = true }) + end, + }) local handler = self.opts.request({ url = adapter.url, @@ -116,49 +100,26 @@ function Client:stream(adapter, payload, bufnr, cb) body = body, stream = self.opts.schedule(function(_, data) log:trace("Chat data: %s", data) - -- log:trace("----- For Adapter test creation -----\nRequest: %s\n ---------- // END ----------", data) + log:trace("----- For Adapter test creation -----\nRequest: %s\n ---------- // END ----------", data) if _G.codecompanion_jobs[bufnr] and _G.codecompanion_jobs[bufnr].status == "stopping" then - close_request(bufnr, { shutdown = true }) + stop_request(bufnr, { shutdown = true }) return cb(nil, nil, true) end - if not adapter.callbacks.should_skip(data) then - if adapter.callbacks.has_error(data) then - return handle_error(data) - end - - if data and type(adapter.callbacks.format_data) == "function" then - data = adapter.callbacks.format_data(data) - end - - if adapter.callbacks.is_complete(data) then - log:trace("Chat completed") - close_request(bufnr) - return cb(nil, nil, true) - end - - if data and data ~= "" then - local ok, json = pcall(self.opts.decode, data, { luanil = { object = true } }) - - if not ok then - close_request(bufnr) - log:error("Decoding error: %s", json) - log:error("Data trace: %s", data) - return cb(string.format("Error decoding data: %s", json)) - end - - cb(nil, json) - end - else - if adapter.callbacks.has_error(data) then - return handle_error(data) - end + if adapter.callbacks.is_complete(data) then + log:trace("Chat completed") + stop_request(bufnr) + api.nvim_del_autocmd(stop_request_cmd) + return cb(nil, nil, true) end + + cb(nil, data) end), on_error = function(err, _, _) log:error("Error: %s", err) - close_request(bufnr) + stop_request(bufnr) + api.nvim_del_autocmd(stop_request_cmd) end, }) @@ -166,55 +127,4 @@ function Client:stream(adapter, payload, bufnr, cb) start_request(bufnr, handler) end ----Call the API and block until the response is 
received ----@param adapter CodeCompanion.Adapter ----@param payload table ----@param cb fun(err: nil|string, response: nil|table) -function Client:call(adapter, payload, cb) - cb = log:wrap_cb(cb, "Response error: %s") - - local cmd = { - "curl", - adapter.url, - } - - if adapter.raw then - for _, v in ipairs(adapter.raw) do - table.insert(cmd, v) - end - else - table.insert(cmd, "--no-buffer") - end - - if adapter.headers then - for k, v in pairs(adapter.headers) do - table.insert(cmd, "-H") - table.insert(cmd, string.format("%s: %s", k, v)) - end - end - - table.insert(cmd, "-d") - table.insert( - cmd, - vim.json.encode(vim.tbl_extend("keep", adapter.parameters, { - messages = payload, - })) - ) - log:trace("Request payload: %s", cmd) - - local result = vim.fn.system(cmd) - - if vim.v.shell_error ~= 0 then - log:error("Error calling curl: %s", result) - return cb("Error executing curl", nil) - else - local err, data = parse_response(0, result) - if err then - return cb(err, nil) - else - return cb(nil, data) - end - end -end - return Client diff --git a/lua/codecompanion/strategies/chat.lua b/lua/codecompanion/strategies/chat.lua index d0b7f68c..6f4623fe 100644 --- a/lua/codecompanion/strategies/chat.lua +++ b/lua/codecompanion/strategies/chat.lua @@ -112,6 +112,13 @@ local function render_messages(bufnr, settings, messages, context) table.insert(lines, "") end + -- Start with the user heading + if #messages == 0 then + table.insert(lines, "# user") + table.insert(lines, "") + table.insert(lines, "") + end + -- Put the messages in the buffer for i, message in ipairs(messages) do if i > 1 then @@ -379,7 +386,10 @@ function Chat.new(args) keys.set_keymaps(config.options.keymaps, bufnr, self) render_messages(bufnr, settings, args.messages or {}, args.context or {}) - display_tokens(bufnr) + + if args.saved_chat then + display_tokens(bufnr) + end if config.options.display.chat.type == "float" then winid = ui.open_float(bufnr, { @@ -417,17 +427,41 @@ function Chat:submit() vim.bo[self.bufnr].modifiable = true end - local function render_buffer() - local line_count = api.nvim_buf_line_count(self.bufnr) + local role = "" + local function render_new_messages(data) + local total_lines = api.nvim_buf_line_count(self.bufnr) local current_line = api.nvim_win_get_cursor(0)[1] - local cursor_moved = current_line == line_count + local cursor_moved = current_line == total_lines - render_messages(self.bufnr, settings, messages, {}) + local lines = {} + if data.role and data.role ~= role then + role = data.role + table.insert(lines, "") + table.insert(lines, "") + table.insert(lines, string.format("# %s", data.role)) + table.insert(lines, "") + end + + if data.content then + for _, text in ipairs(vim.split(data.content, "\n", { plain = true, trimempty = false })) do + table.insert(lines, text) + end - if cursor_moved and ui.buf_is_active(self.bufnr) then - ui.buf_scroll_to_end(self.bufnr) - elseif not ui.buf_is_active(self.bufnr) then - ui.buf_scroll_to_end(self.bufnr) + local modifiable = vim.bo[self.bufnr].modifiable + vim.bo[self.bufnr].modifiable = true + + local last_line = api.nvim_buf_get_lines(self.bufnr, total_lines - 1, total_lines, false)[1] + local last_col = last_line and #last_line or 0 + api.nvim_buf_set_text(self.bufnr, total_lines - 1, last_col, total_lines - 1, last_col, lines) + + vim.bo[self.bufnr].modified = false + vim.bo[self.bufnr].modifiable = modifiable + + if cursor_moved and ui.buf_is_active(self.bufnr) then + ui.buf_scroll_to_end(self.bufnr) + elseif not 
ui.buf_is_active(self.bufnr) then + ui.buf_scroll_to_end(self.bufnr) + end end end @@ -438,7 +472,7 @@ function Chat:submit() end -- log:trace("----- For Adapter test creation -----\nMessages: %s\n ---------- // END ----------", messages) - log:trace("Settings: %s", settings) + -- log:trace("Settings: %s", settings) local adapter = config.options.adapters.chat @@ -448,18 +482,26 @@ function Chat:submit() return finalize() end - if data then - current_message = adapter.callbacks.output_chat(data, messages, current_message) - -- log:trace("----- For Adapter test creation -----\nOutput: %s\n ---------- // END ----------", current_message) - render_buffer() - end - if done then - table.insert(messages, { role = "user", content = "" }) - render_buffer() + render_new_messages({ role = "user", content = "" }) display_tokens(self.bufnr) return finalize() end + + if data then + local result = adapter.callbacks.chat_output(data) + + if result and result.status == "success" then + render_new_messages(result.output) + elseif result and result.status == "error" then + vim.api.nvim_exec_autocmds( + "User", + { pattern = "CodeCompanionRequest", data = { buf = self.bufnr, action = "stop_request" } } + ) + vim.notify("Error: " .. result.output, vim.log.levels.ERROR) + return finalize() + end + end end) end diff --git a/lua/spec/codecompanion/adapters/helpers.lua b/lua/spec/codecompanion/adapters/helpers.lua index 9122cec5..7cb1e967 100644 --- a/lua/spec/codecompanion/adapters/helpers.lua +++ b/lua/spec/codecompanion/adapters/helpers.lua @@ -1,16 +1,13 @@ local M = {} -function M.chat_buffer_output(stream_response, adapter, messages) +function M.chat_buffer_output(stream_response, adapter) local output = {} for _, data in ipairs(stream_response) do - data = adapter.callbacks.format_data(data.request) - data = vim.json.decode(data, { luanil = { object = true } }) - - output = adapter.callbacks.output_chat(data, messages, output) + output = adapter.callbacks.chat_output(data.request) end - return output + return output.output end return M diff --git a/lua/spec/codecompanion/adapters/ollama_spec.lua b/lua/spec/codecompanion/adapters/ollama_spec.lua index b069d4a6..a856bf7b 100644 --- a/lua/spec/codecompanion/adapters/ollama_spec.lua +++ b/lua/spec/codecompanion/adapters/ollama_spec.lua @@ -10,51 +10,51 @@ local messages = { { local stream_response = { { - request = [[{"model":"llama2","created_at":"2024-03-06T18:35:15.715665Z","message":{"role":"assistant","content":"\n"},"done":false}]], + request = [[{"model":"llama2","created_at":"2024-03-07T20:02:30.622386Z","message":{"role":"assistant","content":"\n"},"done":false}]], output = { content = "\n", role = "assistant", }, }, { - request = [[{"model":"llama2","created_at":"2024-03-06T18:35:15.745213Z","message":{"role":"assistant","content":"\""},"done":false}]], + request = [[{"model":"llama2","created_at":"2024-03-07T20:02:30.652682Z","message":{"role":"assistant","content":"\""},"done":false}]], output = { - content = '\n"', + content = '"', role = "assistant", }, }, { - request = [[{"model":"llama2","created_at":"2024-03-06T18:35:15.77473Z","message":{"role":"assistant","content":"E"},"done":false}]], + request = [[{"model":"llama2","created_at":"2024-03-07T20:02:30.681756Z","message":{"role":"assistant","content":"Be"},"done":false}]], output = { - content = '\n"E', + content = "Be", role = "assistant", }, }, { - request = [[{"model":"llama2","created_at":"2024-03-06T18:35:15.803753Z","message":{"role":"assistant","content":"asy"},"done":false}]], + 
request = [[{"model":"llama2","created_at":"2024-03-07T20:02:30.710758Z","message":{"role":"assistant","content":"aut"},"done":false}]], output = { - content = '\n"Easy', + content = "aut", role = "assistant", }, }, { - request = [[{"model":"llama2","created_at":"2024-03-06T18:35:15.833925Z","message":{"role":"assistant","content":" Program"},"done":false}]], + request = [[{"model":"llama2","created_at":"2024-03-07T20:02:30.739508Z","message":{"role":"assistant","content":"iful"},"done":false}]], output = { - content = '\n"Easy Program', + content = "iful", role = "assistant", }, }, { - request = [[{"model":"llama2","created_at":"2024-03-06T18:35:15.862917Z","message":{"role":"assistant","content":"ming"},"done":false}]], + request = [[{"model":"llama2","created_at":"2024-03-07T20:02:30.770345Z","message":{"role":"assistant","content":" Language"},"done":false}]], output = { - content = '\n"Easy Programming', + content = " Language", role = "assistant", }, }, { - request = [[{"model":"llama2","created_at":"2024-03-06T18:35:15.892319Z","message":{"role":"assistant","content":"\""},"done":false}]], + request = [[{"model":"llama2","created_at":"2024-03-07T20:02:30.7994Z","message":{"role":"assistant","content":"\""},"done":false}]], output = { - content = '\n"Easy Programming"', + content = '"', role = "assistant", }, }, @@ -70,15 +70,10 @@ describe("Ollama adapter", function() end) it("can check if the streaming is complete", function() - local data = adapter.callbacks.format_data(done_response) - - assert.is_true(adapter.callbacks.is_complete(data)) + assert.is_true(adapter.callbacks.is_complete(done_response)) end) it("can output streamed data into a format for the chat buffer", function() - assert.are.same( - stream_response[#stream_response].output, - helpers.chat_buffer_output(stream_response, adapter, messages) - ) + assert.are.same(stream_response[#stream_response].output, helpers.chat_buffer_output(stream_response, adapter)) end) end) diff --git a/lua/spec/codecompanion/adapters/openai_spec.lua b/lua/spec/codecompanion/adapters/openai_spec.lua index ed478a14..266cbf1e 100644 --- a/lua/spec/codecompanion/adapters/openai_spec.lua +++ b/lua/spec/codecompanion/adapters/openai_spec.lua @@ -10,24 +10,22 @@ local messages = { { local stream_response = { { - request = 'data: {"id":"chatcmpl-8zlFGE8bEaXPG43tedyauJkw1EiMQ","object":"chat.completion.chunk","created":1709730310,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}', + request = [[data: {"id":"chatcmpl-90DdmqMKOKpqFemxX0OhTVdH042gu","object":"chat.completion.chunk","created":1709839462,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}]], output = { content = "", role = "assistant", }, }, { - request = 'data: {"id":"chatcmpl-8zlFGE8bEaXPG43tedyauJkw1EiMQ","object":"chat.completion.chunk","created":1709730310,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"content":"Programming"},"logprobs":null,"finish_reason":null}]}', + request = [[data: {"id":"chatcmpl-90DdmqMKOKpqFemxX0OhTVdH042gu","object":"chat.completion.chunk","created":1709839462,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"content":"Programming"},"logprobs":null,"finish_reason":null}]}]], output = { content = "Programming", - role = 
"assistant", }, }, { - request = 'data: {"id":"chatcmpl-8zlFGE8bEaXPG43tedyauJkw1EiMQ","object":"chat.completion.chunk","created":1709730310,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"content":" language"},"logprobs":null,"finish_reason":null}]}', + request = [[data: {"id":"chatcmpl-90DdmqMKOKpqFemxX0OhTVdH042gu","object":"chat.completion.chunk","created":1709839462,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"content":" language"},"logprobs":null,"finish_reason":null}]}]], output = { - content = "Programming language", - role = "assistant", + content = " language", }, }, } @@ -40,20 +38,11 @@ describe("OpenAI adapter", function() assert.are.same({ messages = messages }, adapter.callbacks.form_messages(messages)) end) - it("can format the data from the API", function() - assert.are.same("[DONE]", adapter.callbacks.format_data(done_response)) - end) - it("can check if the streaming is complete", function() - local data = adapter.callbacks.format_data(done_response) - - assert.is_true(adapter.callbacks.is_complete(data)) + assert.is_true(adapter.callbacks.is_complete(done_response)) end) it("can output streamed data into a format for the chat buffer", function() - assert.are.same( - stream_response[#stream_response].output, - helpers.chat_buffer_output(stream_response, adapter, messages) - ) + assert.are.same(stream_response[#stream_response].output, helpers.chat_buffer_output(stream_response, adapter)) end) end) From 66773c424d31ff9d35482ac7541cd80608a281cf Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Thu, 7 Mar 2024 20:43:46 +0000 Subject: [PATCH 37/42] update README.md --- README.md | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index f5b65dc5..8362eb7e 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,8 @@

-Use the power of Generative AI in Neovim. Use it to chat, author and advise you on your code.
-Currently supports OpenAI and Ollama.
+Use the power of generative AI in Neovim. Use it to chat, author and advise you on your code.
+Supports Anthropic, Ollama and OpenAI.

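Switching services is a single `setup()` call, since each strategy resolves its own adapter. The sketch below mirrors the `use()` calls from `lua/codecompanion/config.lua` later in this series; pairing `anthropic` for chat with `ollama` for inline is purely illustrative and assumes both names are registered:

```lua
-- A minimal sketch of per-strategy adapter selection (user config, not part
-- of this patch); the adapter names are assumed to be registered via use()
require("codecompanion").setup({
  adapters = {
    chat = require("codecompanion.adapters").use("anthropic"),
    inline = require("codecompanion.adapters").use("ollama"),
  },
})
```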
 > [!IMPORTANT]
@@ -30,10 +30,10 @@ Currently supports OpenAI and Ollama.
 ## :sparkles: Features
 
 - :speech_balloon: A Copilot Chat experience from within Neovim
-- :electric_plug: Adapter support for many Generative AI services such as OpenAI and Ollama
+- :electric_plug: Adapter support for many generative AI services
 - :rocket: Inline code creation and modification
 - :sparkles: Built-in actions for specific language prompts, LSP error fixes and code advice
-- :building_construction: Create your own custom actions for Neovim which hook into Generative AI APIs
+- :building_construction: Create your own custom actions for Neovim which hook into generative AI APIs
 - :floppy_disk: Save and restore your chats
 - :muscle: Async execution for improved performance
@@ -52,7 +52,7 @@ Currently supports OpenAI and Ollama.
 - The `curl` library installed
 - Neovim 0.9.2 or greater
-- _(Optional)_ An API key to be set in your shell for your chosen Generative AI service
+- _(Optional)_ An API key to be set in your shell for your chosen generative AI service
 
 ## :package: Installation
@@ -145,7 +145,7 @@ require("codecompanion").setup({
     ["["] = "keymaps.previous", -- Move to the previous header in the chat
   },
   log_level = "ERROR", -- TRACE|DEBUG|ERROR
-  send_code = true, -- Send code context to the Generative AI service? Disable to prevent leaking code outside of Neovim
+  send_code = true, -- Send code context to the generative AI service? Disable to prevent leaking code outside of Neovim
   silence_notifications = false, -- Silence notifications for actions like saving chats?
   use_default_actions = true, -- Use the default actions in the action palette?
 })
@@ -158,7 +158,7 @@ require("codecompanion").setup({
 ### Adapters
 
-The plugin uses adapters to bridge between Generative AI services and the plugin. Currently the plugin supports:
+The plugin uses adapters to bridge between generative AI services and the plugin. Currently the plugin supports:
 
 - Anthropic (`anthropic`) - Requires `ANTHROPIC_API_KEY` to be set in your shell
 - Ollama (`ollama`)
@@ -258,13 +258,13 @@ The Action Palette, opened via `:CodeCompanionActions`, contains all of the actions

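The palette itself is exposed as the `:CodeCompanionActions` user command, so it can be bound to any key. A possible mapping (a user-config sketch, not part of this patch, with an arbitrary choice of `lhs`):

```lua
-- Open the Action Palette from normal or visual mode, so a visual
-- selection can be sent along as context
vim.keymap.set({ "n", "v" }, "<LocalLeader>cc", "<cmd>CodeCompanionActions<cr>", { desc = "CodeCompanion actions" })
```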
chat buffer

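As a textual stand-in for the screenshot above, a chat buffer rendered by this series might look like the sketch below; the exchange is borrowed from the adapter specs and the exact rendering is an assumption:

```markdown
# user

Explain Ruby in two words

# assistant

Dynamic, elegant.
```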
-The chat buffer is where you can converse with the Generative AI service, directly from Neovim. It behaves as a regular markdown buffer with some clever additions. When the buffer is written (or "saved"), autocmds trigger the sending of its content to the Generative AI service, in the form of prompts. These prompts are segmented by H1 headers: `user` and `assistant`. When a response is received, it is then streamed back into the buffer. The result is that you experience the feel of conversing with ChatGPT from within Neovim.
+The chat buffer is where you can converse with the generative AI service, directly from Neovim. It behaves as a regular markdown buffer with some clever additions. When the buffer is written (or "saved"), autocmds trigger the sending of its content to the generative AI service, in the form of prompts. These prompts are segmented by H1 headers: `user` and `assistant`. When a response is received, it is then streamed back into the buffer. The result is that you experience the feel of conversing with ChatGPT from within Neovim.
 
 #### Keymaps
 
 When in the chat buffer, there are a number of keymaps available to you (which can be changed in the config):
 
-- `` - Save the buffer and trigger a response from the Generative AI service
+- `` - Save the buffer and trigger a response from the generative AI service
 - `` - Close the buffer
 - `q` - Cancel the stream from the API
 - `gc` - Clear the buffer's contents
@@ -298,7 +298,7 @@ You can use the plugin to create inline code directly into a Neovim buffer. This
 > [!NOTE]
 > The command can detect if you've made a visual selection and send any code as context to the API alongside the filetype of the buffer.
 
-One of the challenges with inline editing is determining how the API's response should be handled in the buffer. If you've prompted the API to _"create a table of 5 fruits"_ then you may wish for the response to be placed after the cursor in the buffer. However, if you asked the API to _"refactor this function"_ then you'd expect the response to overwrite a visual selection. If this _placement_ isn't specified then the plugin will use Generative AI itself to determine if the response should follow any of the placements below:
+One of the challenges with inline editing is determining how the API's response should be handled in the buffer. If you've prompted the API to _"create a table of 5 fruits"_ then you may wish for the response to be placed after the cursor in the buffer. However, if you asked the API to _"refactor this function"_ then you'd expect the response to overwrite a visual selection. If this _placement_ isn't specified then the plugin will use generative AI itself to determine if the response should follow any of the placements below:
 
 - _after_ - after the visual selection
 - _before_ - before the visual selection
@@ -310,11 +310,11 @@ As a final example, specifying a prompt like _"create a test for this code in a
 ### In-Built Actions
 
-The plugin comes with a number of [in-built actions](https://github.com/olimorris/codecompanion.nvim/blob/main/lua/codecompanion/actions.lua) which aim to improve your Neovim workflow. 
Actions make use of either a _chat_ or an _inline_ strategy. The chat strategy opens up a chat buffer whilst an inline strategy will write output from the generative AI service into the Neovim buffer. #### Chat and Chat as -Both of these actions utilise the `chat` strategy. The `Chat` action opens up a fresh chat buffer. The `Chat as` action allows for persona based context to be set in the chat buffer allowing for better and more detailed responses from the Generative AI service. +Both of these actions utilise the `chat` strategy. The `Chat` action opens up a fresh chat buffer. The `Chat as` action allows for persona based context to be set in the chat buffer allowing for better and more detailed responses from the generative AI service. > [!TIP] > Both of these actions allow for visually selected code to be sent to the chat buffer as code blocks. @@ -342,7 +342,7 @@ As the name suggests, this action provides advice on a visual selection of code #### LSP assistant -Taken from the fantastic [Wtf.nvim](https://github.com/piersolenski/wtf.nvim) plugin, this action provides advice on how to correct any LSP diagnostics which are present on the visually selected lines. Again, the `send_code = false` value can be set in your config to prevent the code itself being sent to the Generative AI service. +Taken from the fantastic [Wtf.nvim](https://github.com/piersolenski/wtf.nvim) plugin, this action provides advice on how to correct any LSP diagnostics which are present on the visually selected lines. Again, the `send_code = false` value can be set in your config to prevent the code itself being sent to the generative AI service. ## :rainbow: Helpers @@ -377,7 +377,7 @@ vim.api.nvim_create_autocmd({ "User" }, { ### Heirline.nvim -If you're using the fantastic [Heirline.nvim](https://github.com/rebelot/heirline.nvim) plugin, consider the following snippet to display an icon in the statusline whilst CodeCompanion is conversing with the Generative AI service: +If you're using the fantastic [Heirline.nvim](https://github.com/rebelot/heirline.nvim) plugin, consider the following snippet to display an icon in the statusline whilst CodeCompanion is conversing with the generative AI service: ```lua local CodeCompanion = { From 90f9d0ef63ecb794d0cd1834eafeff00428004ce Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Thu, 7 Mar 2024 21:14:03 +0000 Subject: [PATCH 38/42] fix inline for ollama and openai --- lua/codecompanion/adapters/ollama.lua | 16 +++++++++++----- lua/codecompanion/adapters/openai.lua | 18 +++++++++++++----- lua/codecompanion/client.lua | 2 +- lua/codecompanion/strategies/inline.lua | 2 +- 4 files changed, 26 insertions(+), 12 deletions(-) diff --git a/lua/codecompanion/adapters/ollama.lua b/lua/codecompanion/adapters/ollama.lua index 5ae2d0f1..9f2c9689 100644 --- a/lua/codecompanion/adapters/ollama.lua +++ b/lua/codecompanion/adapters/ollama.lua @@ -53,7 +53,7 @@ return { output.role = message.role or nil end - log:trace("----- For Adapter test creation -----\nOutput: %s\n ---------- // END ----------", output) + -- log:trace("----- For Adapter test creation -----\nOutput: %s\n ---------- // END ----------", output) return { status = "success", @@ -65,11 +65,17 @@ return { end, ---Output the data from the API ready for inlining into the current buffer - ---@param json_data table The streamed JSON data from the API, also formatted by the format_data callback + ---@param data table The streamed JSON data from the API, also formatted by the format_data callback ---@param context table Useful 
context about the buffer to inline to - ---@return table - output_inline = function(json_data, context) - return json_data.message.content + ---@return table|nil + inline_output = function(data, context) + local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } }) + + if not ok then + return + end + + return json.message.content end, }, schema = { diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index 15ed3a46..f6f9113e 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -68,22 +68,30 @@ return { output.role = delta.role or nil end - log:trace("----- For Adapter test creation -----\nOutput: %s\n ---------- // END ----------", output) + -- log:trace("----- For Adapter test creation -----\nOutput: %s\n ---------- // END ----------", output) return { status = "success", output = output, } end + return nil end, ---Output the data from the API ready for inlining into the current buffer - ---@param json_data table The streamed JSON data from the API, also formatted by the format_data callback + ---@param data table The streamed JSON data from the API, also formatted by the format_data callback ---@param context table Useful context about the buffer to inline to - ---@return table - output_inline = function(json_data, context) - return json_data.choices[1].delta.content + ---@return table|nil + inline_output = function(data, context) + data = data:sub(7) + local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } }) + + if not ok then + return + end + + return json.choices[1].delta.content end, }, schema = { diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index b2bacc7a..fe99ef93 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -100,7 +100,7 @@ function Client:stream(adapter, payload, bufnr, cb) body = body, stream = self.opts.schedule(function(_, data) log:trace("Chat data: %s", data) - log:trace("----- For Adapter test creation -----\nRequest: %s\n ---------- // END ----------", data) + -- log:trace("----- For Adapter test creation -----\nRequest: %s\n ---------- // END ----------", data) if _G.codecompanion_jobs[bufnr] and _G.codecompanion_jobs[bufnr].status == "stopping" then stop_request(bufnr, { shutdown = true }) diff --git a/lua/codecompanion/strategies/inline.lua b/lua/codecompanion/strategies/inline.lua index 1bc61088..1bc48588 100644 --- a/lua/codecompanion/strategies/inline.lua +++ b/lua/codecompanion/strategies/inline.lua @@ -280,7 +280,7 @@ function Inline:execute(user_input) if data then log:trace("Inline data: %s", data) - local content = adapter.callbacks.output_inline(data, self.context) + local content = adapter.callbacks.inline_output(data, self.context) if self.context.buftype == "terminal" then -- Don't stream to the terminal From d82bccfd632598cc59454df9ad6e3661af8a4b37 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Thu, 7 Mar 2024 22:08:44 +0000 Subject: [PATCH 39/42] fix anthropic adapter --- lua/codecompanion/adapters/anthropic.lua | 108 +++++++++++++---------- lua/codecompanion/client.lua | 5 +- 2 files changed, 63 insertions(+), 50 deletions(-) diff --git a/lua/codecompanion/adapters/anthropic.lua b/lua/codecompanion/adapters/anthropic.lua index f5d7ecc9..49ef32e9 100644 --- a/lua/codecompanion/adapters/anthropic.lua +++ b/lua/codecompanion/adapters/anthropic.lua @@ -1,3 +1,5 @@ +local log = require("codecompanion.utils.log") + ---@class CodeCompanion.Adapter ---@field name string ---@field 
url string
@@ -29,70 +31,84 @@ return {
       return { messages = messages }
     end,
 
-    ---Does this streamed data need to be skipped?
-    ---@param data table
-    ---@return boolean
-    should_skip = function(data)
-      if type(data) == "string" then
-        return string.sub(data, 1, 6) == "event:"
-      end
-      return false
-    end,
-
-    ---Format any data before it's consumed by the other callbacks
-    ---@param data string
-    ---@return string
-    format_data = function(data)
-      return data:sub(6)
-    end,
-
-    ---Does the data contain an error?
-    ---@param data string
-    ---@return boolean
-    has_error = function(data)
-      local msg = "event: error"
-      return string.sub(data, 1, string.len(msg)) == msg
-    end,
-
     ---Has the streaming completed?
     ---@param data string The data from the format_data callback
     ---@return boolean
     is_complete = function(data)
-      local ok
-      ok, data = pcall(vim.fn.json_decode, data)
-      if ok and data.type then
-        return data.type == "message_stop"
+      if data then
+        data = data:sub(6)
+
+        local ok
+        ok, data = pcall(vim.fn.json_decode, data)
+        if ok and data.type then
+          return data.type == "message_stop"
+        end
+        if ok and data.delta and data.delta.stop_reason then
+          return data.delta.stop_reason == "end_turn"
+        end
       end
       return false
     end,
 
     ---Output the data from the API ready for insertion into the chat buffer
-    ---@param json_data table The streamed JSON data from the API, also formatted by the format_data callback
-    ---@param messages table A table of all of the messages in the chat buffer
-    ---@param current_message table The current/latest message in the chat buffer
-    ---@return table
-    output_chat = function(json_data, messages, current_message)
-      if json_data.type == "message_start" then
-        current_message = { role = json_data.message.role, content = "" }
-        table.insert(messages, current_message)
-      end
+    ---@param data string The streamed JSON data from the API, also formatted by the format_data callback
+    ---@return table|nil
+    chat_output = function(data)
+      local output = {}
 
-      if json_data.type == "content_block_delta" then
-        current_message.content = current_message.content .. 
json_data.delta.text + -- Skip the event messages + if type(data) == "string" and string.sub(data, 1, 6) == "event:" then + return end - return current_message + if data and data ~= "" then + data = data:sub(6) + local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } }) + + if not ok then + return { + status = "error", + output = string.format("Error malformed json: %s", json), + } + end + + if json.type == "message_start" then + output.role = json.message.role + output.content = "" + end + + if json.type == "content_block_delta" then + output.role = nil + output.content = json.delta.text + end + + -- log:trace("----- For Adapter test creation -----\nOutput: %s\n ---------- // END ----------", output) + + return { + status = "success", + output = output, + } + end end, ---Output the data from the API ready for inlining into the current buffer - ---@param json_data table The streamed JSON data from the API, also formatted by the format_data callback + ---@param data table The streamed JSON data from the API, also formatted by the format_data callback ---@param context table Useful context about the buffer to inline to ---@return table|nil - output_inline = function(json_data, context) - if json_data.type == "content_block_delta" then - return json_data.delta.text + inline_output = function(data, context) + data = data:sub(6) + local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } }) + + if not ok then + return end - return nil + + log:trace("INLINE JSON: %s", json) + if json.type == "content_block_delta" then + return json.delta.text + end + + return end, }, schema = { diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index fe99ef93..856a34cb 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -78,8 +78,6 @@ function Client:stream(adapter, payload, bufnr, cb) local body = self.opts.encode(vim.tbl_extend("keep", adapter.parameters or {}, adapter.callbacks.form_messages(payload))) - log:debug("Adapter: %s", { adapter.name, adapter.url, adapter.raw, headers, body }) - local stop_request_cmd = api.nvim_create_autocmd("User", { desc = "Stop the current request", pattern = "CodeCompanionRequest", @@ -94,12 +92,11 @@ function Client:stream(adapter, payload, bufnr, cb) local handler = self.opts.request({ url = adapter.url, - timeout = 1000, raw = adapter.raw or { "--no-buffer" }, headers = headers, body = body, stream = self.opts.schedule(function(_, data) - log:trace("Chat data: %s", data) + -- log:trace("Chat data: %s", data) -- log:trace("----- For Adapter test creation -----\nRequest: %s\n ---------- // END ----------", data) if _G.codecompanion_jobs[bufnr] and _G.codecompanion_jobs[bufnr].status == "stopping" then From 6c9003174f52ada31701e87736ecd51b6a7ffbd1 Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Thu, 7 Mar 2024 22:31:57 +0000 Subject: [PATCH 40/42] tweaks --- lua/codecompanion/client.lua | 3 +- lua/codecompanion/config.lua | 4 +- lua/codecompanion/strategies/inline.lua | 50 ------------------------- 3 files changed, 3 insertions(+), 54 deletions(-) diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index 856a34cb..f6788157 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -46,7 +46,6 @@ Client.static = {} Client.static.opts = { request = { default = curl.post }, encode = { default = vim.json.encode }, - decode = { default = vim.json.decode }, schedule = { default = vim.schedule_wrap }, } @@ -96,7 +95,7 @@ function Client:stream(adapter, 
payload, bufnr, cb) headers = headers, body = body, stream = self.opts.schedule(function(_, data) - -- log:trace("Chat data: %s", data) + log:trace("Chat data: %s", data) -- log:trace("----- For Adapter test creation -----\nRequest: %s\n ---------- // END ----------", data) if _G.codecompanion_jobs[bufnr] and _G.codecompanion_jobs[bufnr].status == "stopping" then diff --git a/lua/codecompanion/config.lua b/lua/codecompanion/config.lua index e09fa2c5..e59690e2 100644 --- a/lua/codecompanion/config.lua +++ b/lua/codecompanion/config.lua @@ -2,8 +2,8 @@ local M = {} local defaults = { adapters = { - chat = require("codecompanion.adapters").use("ollama"), - inline = require("codecompanion.adapters").use("ollama"), + chat = require("codecompanion.adapters").use("openai"), + inline = require("codecompanion.adapters").use("openai"), }, saved_chats = { save_dir = vim.fn.stdpath("data") .. "/codecompanion/saved_chats", diff --git a/lua/codecompanion/strategies/inline.lua b/lua/codecompanion/strategies/inline.lua index 1bc48588..b4e921d9 100644 --- a/lua/codecompanion/strategies/inline.lua +++ b/lua/codecompanion/strategies/inline.lua @@ -173,61 +173,11 @@ function Inline:execute(user_input) local messages = get_action(self, user_input) - -- if not self.opts.placement and user_input then - -- local return_code - -- self.opts.placement, return_code = get_placement_position(self, user_input) - -- - -- if not return_code then - -- table.insert(messages, { - -- role = "user", - -- content = "Please do not return the code I have sent in the response", - -- }) - -- end - -- - -- log:debug("Setting the placement to: %s", self.opts.placement) - -- end - -- Assume the placement should be after the cursor vim.api.nvim_buf_set_lines(self.context.bufnr, self.context.end_line, self.context.end_line, false, { "" }) pos.line = self.context.end_line + 1 pos.col = 0 - --TODO: Workout how we can re-enable this - -- Determine where to place the response in the buffer - -- if self.opts and self.opts.placement then - -- if self.opts.placement == "before" then - -- log:trace("Placing before selection") - -- vim.api.nvim_buf_set_lines( - -- self.context.bufnr, - -- self.context.start_line - 1, - -- self.context.start_line - 1, - -- false, - -- { "" } - -- ) - -- self.context.start_line = self.context.start_line + 1 - -- pos.line = self.context.start_line - 1 - -- pos.col = self.context.start_col - 1 - -- elseif self.opts.placement == "after" then - -- log:trace("Placing after selection") - -- vim.api.nvim_buf_set_lines(self.context.bufnr, self.context.end_line, self.context.end_line, false, { "" }) - -- pos.line = self.context.end_line + 1 - -- pos.col = 0 - -- elseif self.opts.placement == "replace" then - -- log:trace("Placing by overwriting selection") - -- overwrite_selection(self.context) - -- - -- pos.line, pos.col = get_cursor(self.context.winid) - -- elseif self.opts.placement == "new" then - -- log:trace("Placing in a new buffer") - -- self.context.bufnr = api.nvim_create_buf(true, false) - -- api.nvim_buf_set_option(self.context.bufnr, "filetype", self.context.filetype) - -- api.nvim_set_current_buf(self.context.bufnr) - -- - -- pos.line = 1 - -- pos.col = 0 - -- end - -- end - log:debug("Context for inline: %s", self.context) log:debug("Cursor position to use: %s", pos) From 580a5d6ad39dafad56592b4e2737dfb91541c03a Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Thu, 7 Mar 2024 22:32:07 +0000 Subject: [PATCH 41/42] add adapters.md --- ADAPTERS.md | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 
ADAPTERS.md

diff --git a/ADAPTERS.md b/ADAPTERS.md
new file mode 100644
index 00000000..fab8f569
--- /dev/null
+++ b/ADAPTERS.md
@@ -0,0 +1,8 @@
+# Adapters
+
+The purpose of this guide is to showcase how you can extend the functionality of CodeCompanion by adding your own adapters.
+
+
+## Testing your adapters
+
+- Uncomment the two "For Adapter test creation" trace lines in client.lua and in the adapter itself to capture real request and output data for your specs

From e26c3e08d78fa8cad1a3a539a319a8d91f69c512 Mon Sep 17 00:00:00 2001
From: Oli Morris
Date: Thu, 7 Mar 2024 22:32:11 +0000
Subject: [PATCH 42/42] add anthropic spec

---
 .../codecompanion/adapters/anthropic_spec.lua | 60 +++++++++++++++++++
 1 file changed, 60 insertions(+)
 create mode 100644 lua/spec/codecompanion/adapters/anthropic_spec.lua

diff --git a/lua/spec/codecompanion/adapters/anthropic_spec.lua b/lua/spec/codecompanion/adapters/anthropic_spec.lua
new file mode 100644
index 00000000..6bdb56e9
--- /dev/null
+++ b/lua/spec/codecompanion/adapters/anthropic_spec.lua
@@ -0,0 +1,60 @@
+local adapter = require("codecompanion.adapters.anthropic")
+local assert = require("luassert")
+local helpers = require("spec.codecompanion.adapters.helpers")
+
+--------------------------------------------------- OUTPUT FROM THE CHAT BUFFER
+local messages = { {
+  content = "Explain Ruby in two words",
+  role = "user",
+} }
+
+local stream_response = {
+  {
+    request = [[data: {"type":"message_start","message":{"id":"msg_01Ngmyfn49udNhWaojMVKiR6","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":13,"output_tokens":1}}}]],
+    output = {
+      content = "",
+      role = "assistant",
+    },
+  },
+  {
+    request = [[data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Dynamic"}}]],
+    output = {
+      content = "Dynamic",
+    },
+  },
+  {
+    request = [[data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":","}}]],
+    output = {
+      content = ",",
+    },
+  },
+  {
+    request = [[data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" elegant"}}]],
+    output = {
+      content = " elegant",
+    },
+  },
+  {
+    request = [[data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"."}}]],
+    output = {
+      content = ".",
+    },
+  },
+}
+
+local done_response = [[data: {"type":"message_stop"}]]
+------------------------------------------------------------------------ // END
+
+describe("Anthropic adapter", function()
+  it("can form messages to be sent to the API", function()
+    assert.are.same({ messages = messages }, adapter.callbacks.form_messages(messages))
+  end)
+
+  it("can check if the streaming is complete", function()
+    assert.is_true(adapter.callbacks.is_complete(done_response))
+  end)
+
+  it("can output streamed data into a format for the chat buffer", function()
+    assert.are.same(stream_response[#stream_response].output, helpers.chat_buffer_output(stream_response, adapter))
+  end)
+end)
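Taken together, these specs pin down the callback contract a new adapter has to satisfy: `chat_output` receives one raw streamed line and returns either `nil` (skip) or a table with a `status` of `"success"` or `"error"` plus the `output` destined for the chat buffer. A minimal conforming sketch, modelled on the Ollama adapter earlier in this series (the `json.message` shape is an assumption for a hypothetical service):

```lua
-- A sketch of the chat_output contract exercised by the specs above; the
-- decode options mirror client.lua, while the message shape is assumed
local function chat_output(data)
  if not data or data == "" then
    return -- nothing to render yet; nil results are simply skipped
  end

  local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } })
  if not ok then
    return { status = "error", output = string.format("Error malformed json: %s", json) }
  end

  return {
    status = "success",
    output = {
      role = json.message and json.message.role or nil,
      content = json.message and json.message.content,
    },
  }
end
```

The specs themselves can be run with any busted-style harness; plenary.nvim's `:PlenaryBustedDirectory lua/spec` is one assumed option, as the repository's actual test command is not shown in this series.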