From cd610db060d11dc5b6eb870267b503dacd9019e6 Mon Sep 17 00:00:00 2001 From: Oli Date: Thu, 7 Mar 2024 22:33:30 +0000 Subject: [PATCH] feat: add support for anthropic and ollama via adapters --- .github/workflows/ci.yml | 2 + ADAPTERS.md | 8 + README.md | 102 ++++---- RECIPES.md | 2 +- lua/codecompanion/actions.lua | 4 +- lua/codecompanion/adapter.lua | 107 ++++++++ lua/codecompanion/adapters/anthropic.lua | 150 +++++++++++ lua/codecompanion/adapters/init.lua | 34 +++ lua/codecompanion/adapters/ollama.lua | 107 ++++++++ lua/codecompanion/adapters/openai.lua | 209 +++++++++++++++ lua/codecompanion/client.lua | 242 +++++------------- lua/codecompanion/config.lua | 29 +-- lua/codecompanion/health.lua | 22 +- lua/codecompanion/init.lua | 48 +--- lua/codecompanion/keymaps.lua | 4 +- lua/codecompanion/schema.lua | 119 +-------- .../{strategy => strategies}/chat.lua | 149 ++++++----- .../{strategy => strategies}/inline.lua | 149 ++++------- .../{strategy => strategies}/saved_chats.lua | 2 +- lua/codecompanion/strategy.lua | 9 +- lua/codecompanion/utils/yaml.lua | 6 +- lua/spec/codecompanion/adapter_spec.lua | 91 +++++++ .../codecompanion/adapters/anthropic_spec.lua | 60 +++++ lua/spec/codecompanion/adapters/helpers.lua | 13 + .../codecompanion/adapters/ollama_spec.lua | 79 ++++++ .../codecompanion/adapters/openai_spec.lua | 48 ++++ lua/spec/codecompanion/client_spec.lua | 34 ++- 27 files changed, 1220 insertions(+), 609 deletions(-) create mode 100644 ADAPTERS.md create mode 100644 lua/codecompanion/adapter.lua create mode 100644 lua/codecompanion/adapters/anthropic.lua create mode 100644 lua/codecompanion/adapters/init.lua create mode 100644 lua/codecompanion/adapters/ollama.lua create mode 100644 lua/codecompanion/adapters/openai.lua rename lua/codecompanion/{strategy => strategies}/chat.lua (78%) rename lua/codecompanion/{strategy => strategies}/inline.lua (67%) rename lua/codecompanion/{strategy => strategies}/saved_chats.lua (98%) create mode 100644 lua/spec/codecompanion/adapter_spec.lua create mode 100644 lua/spec/codecompanion/adapters/anthropic_spec.lua create mode 100644 lua/spec/codecompanion/adapters/helpers.lua create mode 100644 lua/spec/codecompanion/adapters/ollama_spec.lua create mode 100644 lua/spec/codecompanion/adapters/openai_spec.lua diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 285924ed..c1a22039 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,6 +42,8 @@ jobs: ln -s $(pwd) ~/.local/share/nvim/site/pack/vendor/start - name: Run tests + env: + OPENAI_API_KEY: abc-123 run: | export PATH="${PWD}/_neovim/bin:${PATH}" export VIM="${PWD}/_neovim/share/nvim/runtime" diff --git a/ADAPTERS.md b/ADAPTERS.md new file mode 100644 index 00000000..fab8f569 --- /dev/null +++ b/ADAPTERS.md @@ -0,0 +1,8 @@ +# Adapters + +The purpose of this guide is to showcase how you can extend the functionality of CodeCompanion by adding your own actions to the _Action Palette_. + + +## Testing your adapters + +- Two commented out lines in client.lua and the adapter itself diff --git a/README.md b/README.md index 4bd6aca8..8362eb7e 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@

-CodeCompanion.nvim
+CodeCompanion.nvim

CodeCompanion.nvim

@@ -14,7 +14,8 @@

-Use the OpenAI APIs directly in Neovim. Use it to chat, author and advise you on your code.
+Use the power of generative AI in Neovim. Use it to chat about, author and get advice on your code.

+Supports Anthropic, Ollama and OpenAI.

> [!IMPORTANT]
@@ -29,9 +30,10 @@ Use the
> [!WARNING]
-> For some users, the sending of any code to OpenAI may not be an option. In those instances, you can set `send_code = false` in your config.
+> Depending on your chosen adapter, you may need to set up environment variables within your shell. See the adapters section below for specific information.
+
+### Adapters
+
+The plugin uses adapters to act as the bridge between itself and a generative AI service. Currently the plugin supports:
+
+- Anthropic (`anthropic`) - Requires `ANTHROPIC_API_KEY` to be set in your shell
+- Ollama (`ollama`)
+- OpenAI (`openai`) - Requires `OPENAI_API_KEY` to be set in your shell
+
+You can specify an adapter for each of the strategies in the plugin:
+
+```lua
+require("codecompanion").setup({
+  adapters = {
+    chat = require("codecompanion.adapters").use("openai"),
+    inline = require("codecompanion.adapters").use("openai"),
+  },
+})
+```
+
+#### Modifying Adapters
+
+It may be necessary to modify certain parameters of an adapter. In the example below, we're changing the environment variable that the OpenAI adapter reads the API key from, by passing a table to the `use` method:
+
+```lua
+require("codecompanion").setup({
+  adapters = {
+    chat = require("codecompanion.adapters").use("openai", {
+      env = {
+        api_key = "DIFFERENT_OPENAI_KEY",
+      },
+    }),
+  },
+})
+```
+
+> [!TIP]
+> To create your own adapter please refer to the [ADAPTERS](ADAPTERS.md) guide

### Edgy.nvim Configuration
@@ -244,13 +258,13 @@ The Action Palette, opened via `:CodeCompanionActions`, contains all of the acti

chat buffer

-The chat buffer is where you can converse with the OpenAI APIs, directly from Neovim. It behaves as a regular markdown buffer with some clever additions. When the buffer is written (or "saved"), autocmds trigger the sending of its content to OpenAI, in the form of prompts. These prompts are segmented by H1 headers: `user` and `assistant` (see OpenAI's [Chat Completions API](https://platform.openai.com/docs/guides/text-generation/chat-completions-api) for more on this). When a response is received, it is then streamed back into the buffer. The result is that you experience the feel of conversing with ChatGPT from within Neovim. +The chat buffer is where you can converse with the generative AI service, directly from Neovim. It behaves as a regular markdown buffer with some clever additions. When the buffer is written (or "saved"), autocmds trigger the sending of its content to the generative AI service, in the form of prompts. These prompts are segmented by H1 headers: `user` and `assistant`. When a response is received, it is then streamed back into the buffer. The result is that you experience the feel of conversing with ChatGPT from within Neovim. #### Keymaps When in the chat buffer, there are number of keymaps available to you (which can be changed in the config): -- `` - Save the buffer and trigger a response from the OpenAI API +- `` - Save the buffer and trigger a response from the generative AI service - `` - Close the buffer - `q` - Cancel the stream from the API - `gc` - Clear the buffer's contents @@ -265,7 +279,7 @@ Chat buffers are not saved to disk by default, but can be by pressing `gs` in th #### Settings -If `display.chat.show_settings` is set to `true`, at the very top of the chat buffer will be the OpenAI parameters which can be changed to tweak the response back to you. This enables fine-tuning and parameter tweaking throughout the chat. You can find more detail about them by moving the cursor over them or referring to the [OpenAI Chat Completions reference guide](https://platform.openai.com/docs/api-reference/chat). +If `display.chat.show_settings` is set to `true`, at the very top of the chat buffer will be the adapter parameters which can be changed to tweak the response back to you. This enables fine-tuning and parameter tweaking throughout the chat. You can find more detail about them by moving the cursor over them. ### Inline Code @@ -284,7 +298,7 @@ You can use the plugin to create inline code directly into a Neovim buffer. This > [!NOTE] > The command can detect if you've made a visual selection and send any code as context to the API alongside the filetype of the buffer. -One of the challenges with inline editing is determining how the API's response should be handled in the buffer. If you've prompted the API to _"create a table of 5 fruits"_ then you may wish for the response to be placed after the cursor in the buffer. However, if you asked the API to _"refactor this function"_ then you'd expect the response to overwrite a visual selection. If this _placement_ isn't specified then the plugin will use OpenAI itself to determine if the response should follow any of the placements below: +One of the challenges with inline editing is determining how the API's response should be handled in the buffer. If you've prompted the API to _"create a table of 5 fruits"_ then you may wish for the response to be placed after the cursor in the buffer. However, if you asked the API to _"refactor this function"_ then you'd expect the response to overwrite a visual selection. 
If this _placement_ isn't specified then the plugin will use generative AI itself to determine if the response should follow any of the placements below: - _after_ - after the visual selection - _before_ - before the visual selection @@ -296,11 +310,11 @@ As a final example, specifying a prompt like _"create a test for this code in a ### In-Built Actions -The plugin comes with a number of [in-built actions](https://github.com/olimorris/codecompanion.nvim/blob/main/lua/codecompanion/actions.lua) which aim to improve your Neovim workflow. Actions make use of either a _chat_ or an _inline_ strategy, which are essentially bridges between Neovim and OpenAI. The chat strategy opens up a chat buffer whilst an inline strategy will write output from OpenAI into the Neovim buffer. +The plugin comes with a number of [in-built actions](https://github.com/olimorris/codecompanion.nvim/blob/main/lua/codecompanion/actions.lua) which aim to improve your Neovim workflow. Actions make use of either a _chat_ or an _inline_ strategy. The chat strategy opens up a chat buffer whilst an inline strategy will write output from the generative AI service into the Neovim buffer. #### Chat and Chat as -Both of these actions utilise the `chat` strategy. The `Chat` action opens up a fresh chat buffer. The `Chat as` action allows for persona based context to be set in the chat buffer allowing for better and more detailed responses from OpenAI. +Both of these actions utilise the `chat` strategy. The `Chat` action opens up a fresh chat buffer. The `Chat as` action allows for persona based context to be set in the chat buffer allowing for better and more detailed responses from the generative AI service. > [!TIP] > Both of these actions allow for visually selected code to be sent to the chat buffer as code blocks. @@ -328,7 +342,7 @@ As the name suggests, this action provides advice on a visual selection of code #### LSP assistant -Taken from the fantastic [Wtf.nvim](https://github.com/piersolenski/wtf.nvim) plugin, this action provides advice on how to correct any LSP diagnostics which are present on the visually selected lines. Again, the `send_code = false` value can be set in your config to prevent the code itself being sent to OpenAI. +Taken from the fantastic [Wtf.nvim](https://github.com/piersolenski/wtf.nvim) plugin, this action provides advice on how to correct any LSP diagnostics which are present on the visually selected lines. Again, the `send_code = false` value can be set in your config to prevent the code itself being sent to the generative AI service. 
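For reference, a minimal sketch of a setup call with code sharing disabled might look like the snippet below. It assumes `send_code` is accepted as a top-level option of the setup table, alongside the `adapters` table shown earlier, and is illustrative rather than a definitive configuration.

```lua
require("codecompanion").setup({
  -- Assumed top-level option: prevents code being sent to the generative AI service
  send_code = false,
  adapters = {
    chat = require("codecompanion.adapters").use("openai"),
    inline = require("codecompanion.adapters").use("openai"),
  },
})
```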
## :rainbow: Helpers @@ -363,10 +377,10 @@ vim.api.nvim_create_autocmd({ "User" }, { ### Heirline.nvim -If you're using the fantastic [Heirline.nvim](https://github.com/rebelot/heirline.nvim) plugin, consider the following snippet to display an icon in the statusline whilst CodeCompanion is conversing with OpenAI: +If you're using the fantastic [Heirline.nvim](https://github.com/rebelot/heirline.nvim) plugin, consider the following snippet to display an icon in the statusline whilst CodeCompanion is conversing with the generative AI service: ```lua -local OpenAI = { +local CodeCompanion = { static = { processing = false, }, diff --git a/RECIPES.md b/RECIPES.md index cc2517c2..d9ca4a11 100644 --- a/RECIPES.md +++ b/RECIPES.md @@ -298,7 +298,7 @@ And to determine the visibility of actions in the palette itself: strategy = "saved_chats", description = "Load your previously saved chats", condition = function() - local saved_chats = require("codecompanion.strategy.saved_chats") + local saved_chats = require("codecompanion.strategies.saved_chats") return saved_chats:has_chats() end, picker = { diff --git a/lua/codecompanion/actions.lua b/lua/codecompanion/actions.lua index 74cffa58..0f482ad2 100644 --- a/lua/codecompanion/actions.lua +++ b/lua/codecompanion/actions.lua @@ -505,14 +505,14 @@ M.static.actions = { strategy = "saved_chats", description = "Load your previously saved chats", condition = function() - local saved_chats = require("codecompanion.strategy.saved_chats") + local saved_chats = require("codecompanion.strategies.saved_chats") return saved_chats:has_chats() end, picker = { prompt = "Load chats", items = function() local client = require("codecompanion").get_client() - local saved_chats = require("codecompanion.strategy.saved_chats") + local saved_chats = require("codecompanion.strategies.saved_chats") local items = saved_chats:list({ sort = true }) local chats = {} diff --git a/lua/codecompanion/adapter.lua b/lua/codecompanion/adapter.lua new file mode 100644 index 00000000..143b4971 --- /dev/null +++ b/lua/codecompanion/adapter.lua @@ -0,0 +1,107 @@ +local log = require("codecompanion.utils.log") + +---@class CodeCompanion.Adapter +---@field name string +---@field url string +---@field env? table +---@field raw? table +---@field header table +---@field parameters table +---@field callbacks table +---@field schema table +local Adapter = {} + +---@class CodeCompanion.AdapterArgs +---@field name string +---@field url string +---@field env? table +---@field raw? table +---@field header table +---@field parameters table +---@field callbacks table +---@field schema table + +---@param args table +---@return CodeCompanion.Adapter +function Adapter.new(args) + return setmetatable(args, { __index = Adapter }) +end + +---@return table +function Adapter:get_default_settings() + local settings = {} + + for key, value in pairs(self.schema) do + if value.default ~= nil then + settings[key] = value.default + end + end + + return settings +end + +---@param settings? 
table +---@return CodeCompanion.Adapter +function Adapter:set_params(settings) + if not settings then + settings = self:get_default_settings() + end + + for k, v in pairs(settings) do + local mapping = self.schema[k] and self.schema[k].mapping + if mapping then + local segments = {} + for segment in string.gmatch(mapping, "[^.]+") do + table.insert(segments, segment) + end + + local current = self + for i = 1, #segments - 1 do + if not current[segments[i]] then + current[segments[i]] = {} + end + current = current[segments[i]] + end + + -- Before setting the value, ensure the target exists or initialize it. + local target = segments[#segments] + if not current[target] then + current[target] = {} + end + + -- Ensure 'target' is not nil and 'k' can be assigned to the final segment. + if target then + current[target][k] = v + end + end + end + + return self +end + +---@return CodeCompanion.Adapter +function Adapter:replace_header_vars() + if self.headers then + for k, v in pairs(self.headers) do + self.headers[k] = v:gsub("${(.-)}", function(var) + local env_var = self.env[var] + + if env_var then + env_var = os.getenv(env_var) + if not env_var then + log:error("Error: Could not find env var: %s", self.env[var]) + return vim.notify( + string.format("[CodeCompanion.nvim]\nCould not find env var: %s", self.env[var]), + vim.log.levels.ERROR + ) + end + return env_var + end + end) + end + end + + return self +end + +return Adapter diff --git a/lua/codecompanion/adapters/anthropic.lua b/lua/codecompanion/adapters/anthropic.lua new file mode 100644 index 00000000..49ef32e9 --- /dev/null +++ b/lua/codecompanion/adapters/anthropic.lua @@ -0,0 +1,150 @@ +local log = require("codecompanion.utils.log") + +---@class CodeCompanion.Adapter +---@field name string +---@field url string +---@field raw? table +---@field headers table +---@field parameters table +---@field callbacks table +---@field schema table +return { + name = "Anthropic", + url = "https://api.anthropic.com/v1/messages", + env = { + api_key = "ANTHROPIC_API_KEY", + }, + headers = { + ["anthropic-version"] = "2023-06-01", + -- ["anthropic-beta"] = "messages-2023-12-15", + ["content-type"] = "application/json", + ["x-api-key"] = "${api_key}", + }, + parameters = { + stream = true, + }, + callbacks = { + ---Set the format of the role and content for the messages from the chat buffer + ---@param messages table Format is: { { role = "user", content = "Your prompt here" } } + ---@return table + form_messages = function(messages) + return { messages = messages } + end, + + ---Has the streaming completed? 
+ ---@param data string The data from the format_data callback + ---@return boolean + is_complete = function(data) + if data then + data = data:sub(6) + + local ok + ok, data = pcall(vim.fn.json_decode, data) + if ok and data.type then + return data.type == "message_stop" + end + if ok and data.delta.stop_reason then + return data.delta.stop_reason == "end_turn" + end + end + return false + end, + + ---Output the data from the API ready for insertion into the chat buffer + ---@param data string The streamed JSON data from the API, also formatted by the format_data callback + ---@return table|nil + chat_output = function(data) + local output = {} + + -- Skip the event messages + if type(data) == "string" and string.sub(data, 1, 6) == "event:" then + return + end + + if data and data ~= "" then + data = data:sub(6) + local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } }) + + if not ok then + return { + status = "error", + output = string.format("Error malformed json: %s", json), + } + end + + if json.type == "message_start" then + output.role = json.message.role + output.content = "" + end + + if json.type == "content_block_delta" then + output.role = nil + output.content = json.delta.text + end + + -- log:trace("----- For Adapter test creation -----\nOutput: %s\n ---------- // END ----------", output) + + return { + status = "success", + output = output, + } + end + end, + + ---Output the data from the API ready for inlining into the current buffer + ---@param data table The streamed JSON data from the API, also formatted by the format_data callback + ---@param context table Useful context about the buffer to inline to + ---@return table|nil + inline_output = function(data, context) + data = data:sub(6) + local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } }) + + if not ok then + return + end + + log:trace("INLINE JSON: %s", json) + if json.type == "content_block_delta" then + return json.delta.text + end + + return + end, + }, + schema = { + model = { + order = 1, + mapping = "parameters", + type = "enum", + desc = "ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", + default = "claude-3-opus-20240229", + choices = { + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", + "claude-2.1", + }, + }, + max_tokens = { + order = 2, + mapping = "parameters", + type = "number", + optional = true, + default = 1024, + desc = "The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length.", + validate = function(n) + return n > 0, "Must be greater than 0" + end, + }, + temperature = { + order = 3, + mapping = "parameters", + type = "number", + optional = true, + default = 0, + desc = "What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. 
We generally recommend altering this or top_p but not both.", + validate = function(n) + return n >= 0 and n <= 1, "Must be between 0 and 1" + end, + }, + }, +} diff --git a/lua/codecompanion/adapters/init.lua b/lua/codecompanion/adapters/init.lua new file mode 100644 index 00000000..fbe0f5d7 --- /dev/null +++ b/lua/codecompanion/adapters/init.lua @@ -0,0 +1,34 @@ +local Adapter = require("codecompanion.adapter") + +local M = {} + +---@param adapter table +---@param opts table +---@return table +local function setup(adapter, opts) + return vim.tbl_deep_extend("force", {}, adapter, opts or {}) +end + +---@param adapter string|table +---@param opts? table +---@return CodeCompanion.Adapter|nil +function M.use(adapter, opts) + local adapter_config + + if type(adapter) == "string" then + adapter_config = require("codecompanion.adapters." .. adapter) + elseif type(adapter) == "table" then + adapter_config = adapter + else + error("Adapter must be a string or a table") + return + end + + if opts then + adapter_config = setup(adapter_config, opts) + end + + return Adapter.new(adapter_config) +end + +return M diff --git a/lua/codecompanion/adapters/ollama.lua b/lua/codecompanion/adapters/ollama.lua new file mode 100644 index 00000000..9f2c9689 --- /dev/null +++ b/lua/codecompanion/adapters/ollama.lua @@ -0,0 +1,107 @@ +local log = require("codecompanion.utils.log") + +---@class CodeCompanion.Adapter +---@field name string +---@field url string +---@field raw? table +---@field headers table +---@field parameters table +---@field callbacks table +---@field schema table +return { + name = "Ollama", + url = "http://localhost:11434/api/chat", + callbacks = { + ---Set the format of the role and content for the messages from the chat buffer + ---@param messages table Format is: { { role = "user", content = "Your prompt here" } } + ---@return table + form_messages = function(messages) + return { messages = messages } + end, + + ---Has the streaming completed? 
+ ---@param data table The data from the format_data callback + ---@return boolean + is_complete = function(data) + if data then + data = vim.fn.json_decode(data) + return data.done + end + return false + end, + + ---Output the data from the API ready for insertion into the chat buffer + ---@param data table The streamed JSON data from the API, also formatted by the format_data callback + ---@return table|nil + chat_output = function(data) + local output = {} + + if data and data ~= "" then + local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } }) + + if not ok then + return { + status = "error", + output = string.format("Error malformed json: %s", json), + } + end + + local message = json.message + + if message.content then + output.content = message.content + output.role = message.role or nil + end + + -- log:trace("----- For Adapter test creation -----\nOutput: %s\n ---------- // END ----------", output) + + return { + status = "success", + output = output, + } + end + + return nil + end, + + ---Output the data from the API ready for inlining into the current buffer + ---@param data table The streamed JSON data from the API, also formatted by the format_data callback + ---@param context table Useful context about the buffer to inline to + ---@return table|nil + inline_output = function(data, context) + local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } }) + + if not ok then + return + end + + return json.message.content + end, + }, + schema = { + model = { + order = 1, + mapping = "parameters", + type = "enum", + desc = "ID of the model to use.", + default = "llama2", + choices = { + "llama2", + "mistral", + "dolphin-phi", + "phi", + }, + }, + temperature = { + order = 2, + mapping = "parameters.options", + type = "number", + optional = true, + default = 0.8, + desc = "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.", + validate = function(n) + return n >= 0 and n <= 2, "Must be between 0 and 2" + end, + }, + }, +} diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua new file mode 100644 index 00000000..f6f9113e --- /dev/null +++ b/lua/codecompanion/adapters/openai.lua @@ -0,0 +1,209 @@ +local log = require("codecompanion.utils.log") + +---@class CodeCompanion.Adapter +---@field name string +---@field url string +---@field raw? table +---@field headers table +---@field parameters table +---@field callbacks table +---@field schema table +return { + name = "OpenAI", + url = "https://api.openai.com/v1/chat/completions", + env = { + api_key = "OPENAI_API_KEY", + }, + raw = { + "--no-buffer", + "--silent", + }, + headers = { + ["Content-Type"] = "application/json", + Authorization = "Bearer ${api_key}", + }, + parameters = { + stream = true, + }, + callbacks = { + ---Set the format of the role and content for the messages from the chat buffer + ---@param messages table Format is: { { role = "user", content = "Your prompt here" } } + ---@return table + form_messages = function(messages) + return { messages = messages } + end, + + ---Has the streaming completed? 
+ ---@param data string The streamed data from the API + ---@return boolean + is_complete = function(data) + if data then + data = data:sub(7) + return data == "[DONE]" + end + return false + end, + + ---Output the data from the API ready for insertion into the chat buffer + ---@param data table The streamed JSON data from the API, also formatted by the format_data callback + ---@return table|nil + chat_output = function(data) + local output = {} + + if data and data ~= "" then + data = data:sub(7) + local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } }) + + if not ok then + return { + status = "error", + output = string.format("Error malformed json: %s", json), + } + end + + local delta = json.choices[1].delta + + if delta.content then + output.content = delta.content + output.role = delta.role or nil + end + + -- log:trace("----- For Adapter test creation -----\nOutput: %s\n ---------- // END ----------", output) + + return { + status = "success", + output = output, + } + end + + return nil + end, + + ---Output the data from the API ready for inlining into the current buffer + ---@param data table The streamed JSON data from the API, also formatted by the format_data callback + ---@param context table Useful context about the buffer to inline to + ---@return table|nil + inline_output = function(data, context) + data = data:sub(7) + local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } }) + + if not ok then + return + end + + return json.choices[1].delta.content + end, + }, + schema = { + model = { + order = 1, + mapping = "parameters", + type = "enum", + desc = "ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", + default = "gpt-4-0125-preview", + choices = { + "gpt-4-1106-preview", + "gpt-4", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + }, + }, + temperature = { + order = 2, + mapping = "parameters", + type = "number", + optional = true, + default = 1, + desc = "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.", + validate = function(n) + return n >= 0 and n <= 2, "Must be between 0 and 2" + end, + }, + top_p = { + order = 3, + mapping = "parameters", + type = "number", + optional = true, + default = 1, + desc = "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both.", + validate = function(n) + return n >= 0 and n <= 1, "Must be between 0 and 1" + end, + }, + stop = { + order = 4, + mapping = "parameters", + type = "list", + optional = true, + default = nil, + subtype = { + type = "string", + }, + desc = "Up to 4 sequences where the API will stop generating further tokens.", + validate = function(l) + return #l >= 1 and #l <= 4, "Must have between 1 and 4 elements" + end, + }, + max_tokens = { + order = 5, + mapping = "parameters", + type = "integer", + optional = true, + default = nil, + desc = "The maximum number of tokens to generate in the chat completion. 
The total length of input tokens and generated tokens is limited by the model's context length.", + validate = function(n) + return n > 0, "Must be greater than 0" + end, + }, + presence_penalty = { + order = 6, + mapping = "parameters", + type = "number", + optional = true, + default = 0, + desc = "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.", + validate = function(n) + return n >= -2 and n <= 2, "Must be between -2 and 2" + end, + }, + frequency_penalty = { + order = 7, + mapping = "parameters", + type = "number", + optional = true, + default = 0, + desc = "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.", + validate = function(n) + return n >= -2 and n <= 2, "Must be between -2 and 2" + end, + }, + logit_bias = { + order = 8, + mapping = "parameters", + type = "map", + optional = true, + default = nil, + desc = "Modify the likelihood of specified tokens appearing in the completion. Maps tokens (specified by their token ID) to an associated bias value from -100 to 100. Use https://platform.openai.com/tokenizer to find token IDs.", + subtype_key = { + type = "integer", + }, + subtype = { + type = "integer", + validate = function(n) + return n >= -100 and n <= 100, "Must be between -100 and 100" + end, + }, + }, + user = { + order = 9, + mapping = "parameters", + type = "string", + optional = true, + default = nil, + desc = "A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Learn more.", + validate = function(u) + return u:len() < 100, "Cannot be longer than 100 characters" + end, + }, + }, +} diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index 53646e54..f6788157 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -1,12 +1,14 @@ -local config = require("codecompanion.config") +local curl = require("plenary.curl") local log = require("codecompanion.utils.log") local schema = require("codecompanion.schema") +local api = vim.api + _G.codecompanion_jobs = {} ---@param status string local function fire_autocmd(status) - vim.api.nvim_exec_autocmds("User", { pattern = "CodeCompanionRequest", data = { status = status } }) + api.nvim_exec_autocmds("User", { pattern = "CodeCompanionRequest", data = { status = status } }) end ---@param bufnr? number @@ -23,7 +25,7 @@ end ---@param bufnr? number ---@param opts? table -local function close_request(bufnr, opts) +local function stop_request(bufnr, opts) if bufnr then if opts and opts.shutdown then _G.codecompanion_jobs[bufnr].handler:shutdown() @@ -33,162 +35,87 @@ local function close_request(bufnr, opts) fire_autocmd("finished") end ----@param client CodeCompanion.Client ----@return table -local function headers(client) - local group = { - content_type = "application/json", - Authorization = "Bearer " .. 
client.secret_key, - OpenAI_Organization = client.organization, - } - - log:debug("Request Headers: %s", group) - - return group -end - ----@param code integer ----@param stdout string ----@return nil|string ----@return nil|any -local function parse_response(code, stdout) - if code ~= 0 then - log:error("Error: %s", stdout) - return string.format("Error: %s", stdout) - end - - local ok, data = pcall(vim.json.decode, stdout, { luanil = { object = true } }) - if not ok then - log:error("Error malformed json: %s", data) - return string.format("Error malformed json: %s", data) - end - - if data.error then - log:error("API Error: %s", data.error.message) - return string.format("API Error: %s", data.error.message) - end - - return nil, data -end - ---@class CodeCompanion.Client +---@field static table ---@field secret_key string ---@field organization nil|string ----@field settings nil|table +---@field opts nil|table local Client = {} +Client.static = {} + +Client.static.opts = { + request = { default = curl.post }, + encode = { default = vim.json.encode }, + schedule = { default = vim.schedule_wrap }, +} ---@class CodeCompanion.ClientArgs ---@field secret_key string ---@field organization nil|string ----@field settings nil|table +---@field opts nil|table ----@param args CodeCompanion.ClientArgs +---@param args? CodeCompanion.ClientArgs ---@return CodeCompanion.Client function Client.new(args) + args = args or {} + return setmetatable({ - secret_key = args.secret_key, - organization = args.organization, - settings = args.settings or schema.get_default(schema.static.client_settings, args.settings), + opts = args.opts or schema.get_default(Client.static.opts, args.opts), }, { __index = Client }) end ----Call the OpenAI API but block the main loop until the response is received ----@param url string ----@param payload table ----@param cb fun(err: nil|string, response: nil|table) -function Client:block_request(url, payload, cb) - cb = log:wrap_cb(cb, "Response error: %s") - - local cmd = { - "curl", - url, - "--silent", - "--no-buffer", - "-H", - "Content-Type: application/json", - "-H", - string.format("Authorization: Bearer %s", self.secret_key), - } - - if self.organization then - table.insert(cmd, "-H") - table.insert(cmd, string.format("OpenAI-Organization: %s", self.organization)) - end - - table.insert(cmd, "-d") - table.insert(cmd, vim.json.encode(payload)) - log:trace("Request payload: %s", cmd) - - local result = vim.fn.system(cmd) - - if vim.v.shell_error ~= 0 then - log:error("Error calling curl: %s", result) - return cb("Error executing curl", nil) - else - local err, data = parse_response(0, result) - if err then - return cb(err, nil) - else - return cb(nil, data) - end - end -end - ----@param url string ----@param payload table +---@param adapter CodeCompanion.Adapter +---@param payload table the messages to send to the API ---@param bufnr number ---@param cb fun(err: nil|string, chunk: nil|table, done: nil|boolean) Will be called multiple times until done is true ---@return nil -function Client:stream_request(url, payload, bufnr, cb) +function Client:stream(adapter, payload, bufnr, cb) cb = log:wrap_cb(cb, "Response error: %s") - local handler = self.settings.request({ - url = url, - raw = { "--no-buffer" }, - headers = headers(self), - body = self.settings.encode(payload), - stream = function(_, chunk) - chunk = chunk:sub(7) - - if chunk ~= "" then - if chunk == "[DONE]" then - self.settings.schedule(function() - close_request(bufnr) - return cb(nil, nil, true) - end) - else - 
self.settings.schedule(function() - if _G.codecompanion_jobs[bufnr] and _G.codecompanion_jobs[bufnr].status == "stopping" then - close_request(bufnr, { shutdown = true }) - return cb(nil, nil, true) - end - - local ok, data = pcall(self.settings.decode, chunk, { luanil = { object = true } }) - - if not ok then - log:error("Error malformed json: %s", data) - close_request(bufnr) - return cb(string.format("Error malformed json: %s", data)) - end - - if data.choices[1].finish_reason then - log:debug("Finish Reason: %s", data.choices[1].finish_reason) - end - - if data.choices[1].finish_reason == "length" then - log:debug("Token limit reached") - close_request(bufnr) - return cb("[CodeCompanion.nvim]\nThe token limit for the current chat has been reached") - end - - cb(nil, data) - end) - end + --TODO: Check for any errors env variables + local headers = adapter:replace_header_vars().headers + local body = + self.opts.encode(vim.tbl_extend("keep", adapter.parameters or {}, adapter.callbacks.form_messages(payload))) + + local stop_request_cmd = api.nvim_create_autocmd("User", { + desc = "Stop the current request", + pattern = "CodeCompanionRequest", + callback = function(request) + if request.data.buf ~= bufnr or request.data.action ~= "stop_request" then + return end + + return stop_request(bufnr, { shutdown = true }) end, + }) + + local handler = self.opts.request({ + url = adapter.url, + raw = adapter.raw or { "--no-buffer" }, + headers = headers, + body = body, + stream = self.opts.schedule(function(_, data) + log:trace("Chat data: %s", data) + -- log:trace("----- For Adapter test creation -----\nRequest: %s\n ---------- // END ----------", data) + + if _G.codecompanion_jobs[bufnr] and _G.codecompanion_jobs[bufnr].status == "stopping" then + stop_request(bufnr, { shutdown = true }) + return cb(nil, nil, true) + end + + if adapter.callbacks.is_complete(data) then + log:trace("Chat completed") + stop_request(bufnr) + api.nvim_del_autocmd(stop_request_cmd) + return cb(nil, nil, true) + end + + cb(nil, data) + end), on_error = function(err, _, _) - close_request(bufnr) log:error("Error: %s", err) + stop_request(bufnr) + api.nvim_del_autocmd(stop_request_cmd) end, }) @@ -196,49 +123,4 @@ function Client:stream_request(url, payload, bufnr, cb) start_request(bufnr, handler) end ----@class CodeCompanion.ChatMessage ----@field role "system"|"user"|"assistant" ----@field content string - ----@class CodeCompanion.ChatSettings ----@field model string ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. ----@field temperature nil|number Defaults to 1. What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both. ----@field top_p nil|number Defaults to 1. An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. ----@field n nil|integer Defaults to 1. How many chat completion choices to generate for each input message. ----@field stop nil|string|string[] Defaults to nil. Up to 4 sequences where the API will stop generating further tokens. ----@field max_tokens nil|integer Defaults to nil. 
The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length. ----@field presence_penalty nil|number Defaults to 0. Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. ----@field frequency_penalty nil|number Defaults to 0. Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. ----@field logit_bias nil|table Modify the likelihood of specified tokens appearing in the completion. Maps tokens (specified by their token ID) to an associated bias value from -100 to 100. Use https://platform.openai.com/tokenizer to find token IDs. ----@field user nil|string A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Learn more. - ----@class CodeCompanion.ChatArgs : CodeCompanion.ChatSettings ----@field messages CodeCompanion.ChatMessage[] The messages to generate chat completions for, in the chat format. ----@field stream boolean? Whether to stream the chat output back to Neovim - ----@param args CodeCompanion.ChatArgs ----@param cb fun(err: nil|string, response: nil|table) ----@return nil -function Client:chat(args, cb) - return self:block_request(config.options.base_url .. "/v1/chat/completions", args, cb) -end - ----@param args CodeCompanion.ChatArgs ----@param bufnr integer ----@param cb fun(err: nil|string, chunk: nil|table, done: nil|boolean) Will be called multiple times until done is true ----@return nil -function Client:stream_chat(args, bufnr, cb) - args.stream = true - return self:stream_request(config.options.base_url .. "/v1/chat/completions", args, bufnr, cb) -end - ----@class args CodeCompanion.InlineArgs ----@param bufnr integer ----@param cb fun(err: nil|string, chunk: nil|table, done: nil|boolean) Will be called multiple times until done is true ----@return nil -function Client:inline(args, bufnr, cb) - args.stream = true - return self:stream_request(config.options.base_url .. "/v1/chat/completions", args, bufnr, cb) -end - return Client diff --git a/lua/codecompanion/config.lua b/lua/codecompanion/config.lua index 3d5cafe5..e59690e2 100644 --- a/lua/codecompanion/config.lua +++ b/lua/codecompanion/config.lua @@ -1,32 +1,9 @@ local M = {} local defaults = { - api_key = "OPENAI_API_KEY", - org_api_key = "OPENAI_ORG_KEY", - base_url = "https://api.openai.com", - ai_settings = { - chat = { - model = "gpt-4-0125-preview", - temperature = 1, - top_p = 1, - stop = nil, - max_tokens = nil, - presence_penalty = 0, - frequency_penalty = 0, - logit_bias = nil, - user = nil, - }, - inline = { - model = "gpt-4-0125-preview", - temperature = 1, - top_p = 1, - stop = nil, - max_tokens = nil, - presence_penalty = 0, - frequency_penalty = 0, - logit_bias = nil, - user = nil, - }, + adapters = { + chat = require("codecompanion.adapters").use("openai"), + inline = require("codecompanion.adapters").use("openai"), }, saved_chats = { save_dir = vim.fn.stdpath("data") .. 
"/codecompanion/saved_chats", diff --git a/lua/codecompanion/health.lua b/lua/codecompanion/health.lua index 30bf7a43..5203496d 100644 --- a/lua/codecompanion/health.lua +++ b/lua/codecompanion/health.lua @@ -94,17 +94,17 @@ function M.check() end end - for _, env in ipairs(M.env_vars) do - if env_available(env.name) then - ok(fmt("%s key found", env.name)) - else - if env.optional then - warn(fmt("%s key not found", env.name)) - else - error(fmt("%s key not found", env.name)) - end - end - end + -- for _, env in ipairs(M.env_vars) do + -- if env_available(env.name) then + -- ok(fmt("%s key found", env.name)) + -- else + -- if env.optional then + -- warn(fmt("%s key not found", env.name)) + -- else + -- error(fmt("%s key not found", env.name)) + -- end + -- end + -- end end return M diff --git a/lua/codecompanion/init.lua b/lua/codecompanion/init.lua index 414a0b27..371dc5d6 100644 --- a/lua/codecompanion/init.lua +++ b/lua/codecompanion/init.lua @@ -4,50 +4,20 @@ local util = require("codecompanion.utils.util") local M = {} -local _client ----@return nil|CodeCompanion.Client -function M.get_client() - if not _client then - local secret_key = os.getenv(config.options.api_key) - if not secret_key then - vim.notify( - string.format("[CodeCompanion.nvim]\nCould not find env variable: %s", config.options.api_key), - vim.log.levels.ERROR - ) - return nil - end - - local Client = require("codecompanion.client") - - _client = Client.new({ - secret_key = secret_key, - organization = os.getenv(config.options.org_api_key), - }) - end - - return _client -end - ---@param bufnr nil|integer ---@return nil|CodeCompanion.Chat M.buf_get_chat = function(bufnr) - return require("codecompanion.strategy.chat").buf_get_chat(bufnr) + return require("codecompanion.strategies.chat").buf_get_chat(bufnr) end ---@param args table ---@return nil|CodeCompanion.Inline M.inline = function(args) - local client = M.get_client() - if not client then - return - end - local context = util.get_context(vim.api.nvim_get_current_buf(), args) - return require("codecompanion.strategy.inline") + return require("codecompanion.strategies.inline") .new({ context = context, - client = client, prompts = { { role = "system", @@ -64,15 +34,9 @@ end ---@param args? 
table M.chat = function(args) - local client = M.get_client() - if not client then - return - end - local context = util.get_context(vim.api.nvim_get_current_buf(), args) - local chat = require("codecompanion.strategy.chat").new({ - client = client, + local chat = require("codecompanion.strategies.chat").new({ context = context, }) @@ -118,11 +82,6 @@ end local _cached_actions = {} M.actions = function(args) - local client = M.get_client() - if not client then - return - end - local actions = require("codecompanion.actions") local context = util.get_context(vim.api.nvim_get_current_buf(), args) @@ -164,7 +123,6 @@ M.actions = function(args) else local Strategy = require("codecompanion.strategy") return Strategy.new({ - client = client, context = context, selected = item, }):start(item.strategy) diff --git a/lua/codecompanion/keymaps.lua b/lua/codecompanion/keymaps.lua index 67868b8b..3b5b258a 100644 --- a/lua/codecompanion/keymaps.lua +++ b/lua/codecompanion/keymaps.lua @@ -33,8 +33,8 @@ M.cancel_request = { M.save_chat = { desc = "Save the current chat", callback = function(args) - local chat = require("codecompanion.strategy.chat") - local saved_chat = require("codecompanion.strategy.saved_chats").new({}) + local chat = require("codecompanion.strategies.chat") + local saved_chat = require("codecompanion.strategies.saved_chats").new({}) if args.saved_chat then saved_chat.filename = args.saved_chat diff --git a/lua/codecompanion/schema.lua b/lua/codecompanion/schema.lua index a13743bf..76093775 100644 --- a/lua/codecompanion/schema.lua +++ b/lua/codecompanion/schema.lua @@ -1,6 +1,3 @@ -local config = require("codecompanion.config") -local curl = require("plenary.curl") - local M = {} M.get_default = function(schema, defaults) @@ -19,6 +16,7 @@ end ---@class CodeCompanion.SchemaParam ---@field type "string"|"number"|"integer"|"boolean"|"enum"|"list"|"map" +---@field mapping string ---@field order nil|integer ---@field optional nil|boolean ---@field choices nil|table @@ -120,119 +118,4 @@ M.get_ordered_keys = function(schema) return keys end -M.static = {} - -local model_choices = { - "gpt-4-1106-preview", - "gpt-4", - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo", -} - -M.static.chat_settings = { - model = { - order = 1, - type = "enum", - desc = "ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", - default = config.options.ai_settings.chat.model, - choices = model_choices, - }, - temperature = { - order = 2, - type = "number", - optional = true, - default = config.options.ai_settings.chat.temperature, - desc = "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.", - validate = function(n) - return n >= 0 and n <= 2, "Must be between 0 and 2" - end, - }, - top_p = { - order = 3, - type = "number", - optional = true, - default = config.options.ai_settings.chat.top_p, - desc = "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. 
We generally recommend altering this or temperature but not both.", - validate = function(n) - return n >= 0 and n <= 1, "Must be between 0 and 1" - end, - }, - stop = { - order = 4, - type = "list", - optional = true, - default = config.options.ai_settings.chat.stop, - subtype = { - type = "string", - }, - desc = "Up to 4 sequences where the API will stop generating further tokens.", - validate = function(l) - return #l >= 1 and #l <= 4, "Must have between 1 and 4 elements" - end, - }, - max_tokens = { - order = 5, - type = "integer", - optional = true, - default = config.options.ai_settings.chat.max_tokens, - desc = "The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length.", - validate = function(n) - return n > 0, "Must be greater than 0" - end, - }, - presence_penalty = { - order = 6, - type = "number", - optional = true, - default = config.options.ai_settings.chat.presence_penalty, - desc = "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.", - validate = function(n) - return n >= -2 and n <= 2, "Must be between -2 and 2" - end, - }, - frequency_penalty = { - order = 7, - type = "number", - optional = true, - default = config.options.ai_settings.chat.frequency_penalty, - desc = "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.", - validate = function(n) - return n >= -2 and n <= 2, "Must be between -2 and 2" - end, - }, - logit_bias = { - order = 8, - type = "map", - optional = true, - default = config.options.ai_settings.chat.logit_bias, - desc = "Modify the likelihood of specified tokens appearing in the completion. Maps tokens (specified by their token ID) to an associated bias value from -100 to 100. Use https://platform.openai.com/tokenizer to find token IDs.", - subtype_key = { - type = "integer", - }, - subtype = { - type = "integer", - validate = function(n) - return n >= -100 and n <= 100, "Must be between -100 and 100" - end, - }, - }, - user = { - order = 9, - type = "string", - optional = true, - default = config.options.ai_settings.chat.user, - desc = "A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. 
Learn more.", - validate = function(u) - return u:len() < 100, "Cannot be longer than 100 characters" - end, - }, -} - -M.static.client_settings = { - request = { default = curl.post }, - encode = { default = vim.json.encode }, - decode = { default = vim.json.decode }, - schedule = { default = vim.schedule }, -} - return M diff --git a/lua/codecompanion/strategy/chat.lua b/lua/codecompanion/strategies/chat.lua similarity index 78% rename from lua/codecompanion/strategy/chat.lua rename to lua/codecompanion/strategies/chat.lua index 9935a2cb..6f4623fe 100644 --- a/lua/codecompanion/strategy/chat.lua +++ b/lua/codecompanion/strategies/chat.lua @@ -1,3 +1,4 @@ +local client = require("codecompanion.client") local config = require("codecompanion.config") local log = require("codecompanion.utils.log") local schema = require("codecompanion.schema") @@ -33,7 +34,7 @@ local function parse_settings(bufnr) end if not config.options.display.chat.show_settings then - config_settings[bufnr] = vim.deepcopy(config.options.ai_settings.chat) + config_settings[bufnr] = config.options.adapters.chat:get_default_settings() log:debug("Using the settings from the user's config: %s", config_settings[bufnr]) return config_settings[bufnr] @@ -57,7 +58,7 @@ end ---@param bufnr integer ---@return table ----@return CodeCompanion.ChatMessage[] +---@return table local function parse_messages_buffer(bufnr) local ret = {} @@ -95,15 +96,14 @@ local function parse_messages_buffer(bufnr) end ---@param bufnr integer ----@param settings CodeCompanion.ChatSettings ----@param messages CodeCompanion.ChatMessage[] +---@param settings table +---@param messages table ---@param context table local function render_messages(bufnr, settings, messages, context) local lines = {} if config.options.display.chat.show_settings then - -- Put the settings at the top of the buffer lines = { "---" } - local keys = schema.get_ordered_keys(schema.static.chat_settings) + local keys = schema.get_ordered_keys(config.options.adapters.chat.schema) for _, key in ipairs(keys) do table.insert(lines, string.format("%s: %s", key, yaml.encode(settings[key]))) end @@ -112,6 +112,13 @@ local function render_messages(bufnr, settings, messages, context) table.insert(lines, "") end + -- Start with the user heading + if #messages == 0 then + table.insert(lines, "# user") + table.insert(lines, "") + table.insert(lines, "") + end + -- Put the messages in the buffer for i, message in ipairs(messages) do if i > 1 then @@ -199,7 +206,7 @@ local function chat_autocmds(bufnr, args) buffer = bufnr, callback = function() local settings = parse_settings(bufnr) - local errors = schema.validate(schema.static.chat_settings, settings) + local errors = schema.validate(config.options.adapters.chat.schema, settings) local node = settings.__ts_node local items = {} if errors and node then @@ -316,26 +323,27 @@ local function chat_autocmds(bufnr, args) _G.codecompanion_chats[bufnr] = nil - _G.codecompanion_jobs[request.data.buf].handler:shutdown() - vim.api.nvim_exec_autocmds("User", { pattern = "CodeCompanionRequest", data = { status = "finished" } }) - vim.api.nvim_buf_delete(request.data.buf, { force = true }) + if _G.codecompanion_jobs[request.data.buf] then + _G.codecompanion_jobs[request.data.buf].handler:shutdown() + end + api.nvim_exec_autocmds("User", { pattern = "CodeCompanionRequest", data = { status = "finished" } }) + api.nvim_buf_delete(request.data.buf, { force = true }) end, }) end ---@class CodeCompanion.Chat ----@field client CodeCompanion.Client ---@field bufnr 
integer ----@field settings CodeCompanion.ChatSettings +---@field settings table local Chat = {} ---@class CodeCompanion.ChatArgs ----@field client CodeCompanion.Client +---@field adapter CodeCompanion.Adapter ---@field context table ----@field messages nil|CodeCompanion.ChatMessage[] +---@field messages nil|table ---@field show_buffer nil|boolean ---@field auto_submit nil|boolean ----@field settings nil|CodeCompanion.ChatSettings +---@field settings nil|table ---@field type nil|string ---@field saved_chat nil|string @@ -362,11 +370,10 @@ function Chat.new(args) watch_cursor() chat_autocmds(bufnr, args) - local settings = args.settings or schema.get_default(schema.static.chat_settings, args.settings) + local settings = args.settings or schema.get_default(config.options.adapters.chat.schema, args.settings) local self = setmetatable({ bufnr = bufnr, - client = args.client, context = args.context, saved_chat = args.saved_chat, settings = settings, @@ -379,7 +386,10 @@ function Chat.new(args) keys.set_keymaps(config.options.keymaps, bufnr, self) render_messages(bufnr, settings, args.messages or {}, args.context or {}) - display_tokens(bufnr) + + if args.saved_chat then + display_tokens(bufnr) + end if config.options.display.chat.type == "float" then winid = ui.open_float(bufnr, { @@ -417,61 +427,82 @@ function Chat:submit() vim.bo[self.bufnr].modifiable = true end - local function render_buffer() - local line_count = api.nvim_buf_line_count(self.bufnr) + local role = "" + local function render_new_messages(data) + local total_lines = api.nvim_buf_line_count(self.bufnr) local current_line = api.nvim_win_get_cursor(0)[1] - local cursor_moved = current_line == line_count + local cursor_moved = current_line == total_lines - render_messages(self.bufnr, settings, messages, {}) + local lines = {} + if data.role and data.role ~= role then + role = data.role + table.insert(lines, "") + table.insert(lines, "") + table.insert(lines, string.format("# %s", data.role)) + table.insert(lines, "") + end + + if data.content then + for _, text in ipairs(vim.split(data.content, "\n", { plain = true, trimempty = false })) do + table.insert(lines, text) + end - if cursor_moved and ui.buf_is_active(self.bufnr) then - ui.buf_scroll_to_end(self.bufnr) - elseif not ui.buf_is_active(self.bufnr) then - ui.buf_scroll_to_end(self.bufnr) + local modifiable = vim.bo[self.bufnr].modifiable + vim.bo[self.bufnr].modifiable = true + + local last_line = api.nvim_buf_get_lines(self.bufnr, total_lines - 1, total_lines, false)[1] + local last_col = last_line and #last_line or 0 + api.nvim_buf_set_text(self.bufnr, total_lines - 1, last_col, total_lines - 1, last_col, lines) + + vim.bo[self.bufnr].modified = false + vim.bo[self.bufnr].modifiable = modifiable + + if cursor_moved and ui.buf_is_active(self.bufnr) then + ui.buf_scroll_to_end(self.bufnr) + elseif not ui.buf_is_active(self.bufnr) then + ui.buf_scroll_to_end(self.bufnr) + end end end - local new_message = messages[#messages] + local current_message = messages[#messages] - if new_message and new_message.role == "user" and new_message.content == "" then + if current_message and current_message.role == "user" and current_message.content == "" then return finalize() end - self.client:stream_chat( - vim.tbl_extend("keep", settings, { - messages = messages, - }), - self.bufnr, - function(err, chunk, done) - if err then - log:error("Error: %s", err) - vim.notify("Error: " .. 
err, vim.log.levels.ERROR) - return finalize() - end + -- log:trace("----- For Adapter test creation -----\nMessages: %s\n ---------- // END ----------", messages) + -- log:trace("Settings: %s", settings) - if chunk then - log:trace("Chat chunk: %s", chunk) - local delta = chunk.choices[1].delta - if delta.role and delta.role ~= new_message.role then - new_message = { role = delta.role, content = "" } - table.insert(messages, new_message) - end + local adapter = config.options.adapters.chat - if delta.content then - new_message.content = new_message.content .. delta.content - end + client.new():stream(adapter:set_params(settings), messages, self.bufnr, function(err, data, done) + if err then + vim.notify("Error: " .. err, vim.log.levels.ERROR) + return finalize() + end - render_buffer() - end + if done then + render_new_messages({ role = "user", content = "" }) + display_tokens(self.bufnr) + return finalize() + end - if done then - table.insert(messages, { role = "user", content = "" }) - render_buffer() - display_tokens(self.bufnr) - finalize() + if data then + local result = adapter.callbacks.chat_output(data) + + if result and result.status == "success" then + render_new_messages(result.output) + elseif result and result.status == "error" then + vim.api.nvim_exec_autocmds( + "User", + { pattern = "CodeCompanionRequest", data = { buf = self.bufnr, action = "stop_request" } } + ) + vim.notify("Error: " .. result.output, vim.log.levels.ERROR) + return finalize() end end - ) + end) end ---@param opts nil|table @@ -498,7 +529,7 @@ function Chat:on_cursor_moved() vim.diagnostic.set(config.INFO_NS, self.bufnr, {}) return end - local key_schema = schema.static.chat_settings[key_name] + local key_schema = config.options.adapters.chat.schema[key_name] if key_schema and key_schema.desc then local lnum, col, end_lnum, end_col = node:range() @@ -526,7 +557,7 @@ function Chat:complete(request, callback) return end - local key_schema = schema.static.chat_settings[key_name] + local key_schema = config.options.adapters.chat.schema[key_name] if key_schema.type == "enum" then for _, choice in ipairs(key_schema.choices) do table.insert(items, { diff --git a/lua/codecompanion/strategy/inline.lua b/lua/codecompanion/strategies/inline.lua similarity index 67% rename from lua/codecompanion/strategy/inline.lua rename to lua/codecompanion/strategies/inline.lua index 6529fb5f..b4e921d9 100644 --- a/lua/codecompanion/strategy/inline.lua +++ b/lua/codecompanion/strategies/inline.lua @@ -1,4 +1,6 @@ +local client = require("codecompanion.client") local config = require("codecompanion.config") +local adapter = config.options.adapters.inline local log = require("codecompanion.utils.log") local ui = require("codecompanion.utils.ui") @@ -9,6 +11,14 @@ local function fire_autocmd(status) vim.api.nvim_exec_autocmds("User", { pattern = "CodeCompanionInline", data = { status = status } }) end +---When a user initiates an inline request, it will be possible to infer from +---their prompt, how the output should be placed in the buffer. For instance, if +---they reference the words "refactor" or "update" in their prompt, they will +---likely want the response to replace a visual selection they've made in the +---editor. However, if they include words such as "after" or "before", it may +---be insinuated that they wish the response to be placed after or before the +---cursor. In this function, we use the power of Generative AI to determine +---the user's intent and return a placement position. 
---@param inline CodeCompanion.Inline ---@param prompt string ---@return string, boolean @@ -32,24 +42,17 @@ local function get_placement_position(inline, prompt) } local output - inline.client:chat( - vim.tbl_extend("keep", inline.settings, { - messages = messages, - }), - function(err, chunk) - if err then - return - end + client.new():call(adapter:set_params(), messages, function(err, data) + if err then + return + end - if chunk then - local content = chunk.choices[1].message.content - if content then - log:trace("Placement response: %s", content) - output = content - end - end + if data then + print(vim.inspect(data)) end - ) + end) + + log:trace("Placement output: %s", output) if output then local parts = vim.split(output, "|") @@ -131,17 +134,13 @@ local function get_cursor(winid) end ---@class CodeCompanion.Inline ----@field settings table ---@field context table ----@field client CodeCompanion.Client ---@field opts table ---@field prompts table local Inline = {} ---@class CodeCompanion.InlineArgs ----@field settings table ---@field context table ----@field client CodeCompanion.Client ---@field opts table ---@field pre_hook fun():number -- Assuming pre_hook returns a number for example ---@field prompts table @@ -162,9 +161,7 @@ function Inline.new(opts) end return setmetatable({ - settings = config.options.ai_settings.inline, context = opts.context, - client = opts.client, opts = opts.opts or {}, prompts = vim.deepcopy(opts.prompts), }, { __index = Inline }) @@ -176,54 +173,10 @@ function Inline:execute(user_input) local messages = get_action(self, user_input) - if not self.opts.placement and user_input then - local return_code - self.opts.placement, return_code = get_placement_position(self, user_input) - - if not return_code then - table.insert(messages, { - role = "user", - content = "Please do not return the code I have sent in the response", - }) - end - - log:debug("Setting the placement to: %s", self.opts.placement) - end - - -- Determine where to place the response in the buffer - if self.opts and self.opts.placement then - if self.opts.placement == "before" then - log:trace("Placing before selection") - vim.api.nvim_buf_set_lines( - self.context.bufnr, - self.context.start_line - 1, - self.context.start_line - 1, - false, - { "" } - ) - self.context.start_line = self.context.start_line + 1 - pos.line = self.context.start_line - 1 - pos.col = self.context.start_col - 1 - elseif self.opts.placement == "after" then - log:trace("Placing after selection") - vim.api.nvim_buf_set_lines(self.context.bufnr, self.context.end_line, self.context.end_line, false, { "" }) - pos.line = self.context.end_line + 1 - pos.col = 0 - elseif self.opts.placement == "replace" then - log:trace("Placing by overwriting selection") - overwrite_selection(self.context) - - pos.line, pos.col = get_cursor(self.context.winid) - elseif self.opts.placement == "new" then - log:trace("Placing in a new buffer") - self.context.bufnr = api.nvim_create_buf(true, false) - api.nvim_buf_set_option(self.context.bufnr, "filetype", self.context.filetype) - api.nvim_set_current_buf(self.context.bufnr) - - pos.line = 1 - pos.col = 0 - end - end + -- Assume the placement should be after the cursor + vim.api.nvim_buf_set_lines(self.context.bufnr, self.context.end_line, self.context.end_line, false, { "" }) + pos.line = self.context.end_line + 1 + pos.col = 0 log:debug("Context for inline: %s", self.context) log:debug("Cursor position to use: %s", pos) @@ -268,43 +221,39 @@ function Inline:execute(user_input) 
fire_autocmd("started") local output = {} - self.client:stream_chat( - vim.tbl_extend("keep", self.settings, { - messages = messages, - }), - self.context.bufnr, - function(err, chunk, done) - if err then - return - end + client.new():stream(adapter:set_params(), messages, self.context.bufnr, function(err, data, done) + if err then + fire_autocmd("finished") + return + end - if chunk then - log:trace("Chat chunk: %s", chunk) - - local delta = chunk.choices[1].delta - if delta.content and not delta.role and delta.content ~= "```" and delta.content ~= self.context.filetype then - if self.context.buftype == "terminal" then - -- Don't stream to the terminal - table.insert(output, delta.content) - else - stream_buffer_text(delta.content) - if self.opts and self.opts.placement == "new" then - ui.buf_scroll_to_end(self.context.bufnr) - end + if data then + log:trace("Inline data: %s", data) + + local content = adapter.callbacks.inline_output(data, self.context) + + if self.context.buftype == "terminal" then + -- Don't stream to the terminal + table.insert(output, content) + else + if content then + stream_buffer_text(content) + if self.opts and self.opts.placement == "new" then + ui.buf_scroll_to_end(self.context.bufnr) end end end + end - if done then - api.nvim_buf_del_keymap(self.context.bufnr, "n", "q") - if self.context.buftype == "terminal" then - log:debug("Terminal output: %s", output) - api.nvim_put({ table.concat(output, "") }, "", false, true) - end - fire_autocmd("finished") + if done then + api.nvim_buf_del_keymap(self.context.bufnr, "n", "q") + if self.context.buftype == "terminal" then + log:debug("Terminal output: %s", output) + api.nvim_put({ table.concat(output, "") }, "", false, true) end + fire_autocmd("finished") end - ) + end) end ---@param user_input? 
string diff --git a/lua/codecompanion/strategy/saved_chats.lua b/lua/codecompanion/strategies/saved_chats.lua similarity index 98% rename from lua/codecompanion/strategy/saved_chats.lua rename to lua/codecompanion/strategies/saved_chats.lua index 018ed44a..2850cba5 100644 --- a/lua/codecompanion/strategy/saved_chats.lua +++ b/lua/codecompanion/strategies/saved_chats.lua @@ -1,4 +1,4 @@ -local Chat = require("codecompanion.strategy.chat") +local Chat = require("codecompanion.strategies.chat") local config = require("codecompanion.config") local log = require("codecompanion.utils.log") diff --git a/lua/codecompanion/strategy.lua b/lua/codecompanion/strategy.lua index 6bf09c31..0d8f8ace 100644 --- a/lua/codecompanion/strategy.lua +++ b/lua/codecompanion/strategy.lua @@ -28,13 +28,11 @@ local function modal_prompts(prompts, context) end ---@class CodeCompanion.Strategy ----@field client CodeCompanion.Client ---@field context table ---@field selected table local Strategy = {} ---@class CodeCompanion.StrategyArgs ----@field client CodeCompanion.Client ---@field context table ---@field selected table @@ -43,7 +41,6 @@ local Strategy = {} function Strategy.new(args) log:trace("Context: %s", args.context) return setmetatable({ - client = args.client, context = args.context, selected = args.selected, }, { __index = Strategy }) @@ -76,8 +73,7 @@ function Strategy:chat() }) end - return require("codecompanion.strategy.chat").new({ - client = self.client, + return require("codecompanion.strategies.chat").new({ type = self.selected.type, messages = messages, show_buffer = true, @@ -99,10 +95,9 @@ function Strategy:chat() end function Strategy:inline() - return require("codecompanion.strategy.inline") + return require("codecompanion.strategies.inline") .new({ context = self.context, - client = self.client, opts = self.selected.opts, pre_hook = self.selected.pre_hook, prompts = self.selected.prompts, diff --git a/lua/codecompanion/utils/yaml.lua b/lua/codecompanion/utils/yaml.lua index 9f0022c8..b8775909 100644 --- a/lua/codecompanion/utils/yaml.lua +++ b/lua/codecompanion/utils/yaml.lua @@ -7,7 +7,11 @@ M.encode = function(data) if data == nil then return "null" elseif dt == "number" then - return string.format("%d", data) + if data % 1 == 0 then + return string.format("%d", data) + else + return string.format("%.1f", data) + end elseif dt == "boolean" then return string.format("%s", data) elseif dt == "string" then diff --git a/lua/spec/codecompanion/adapter_spec.lua b/lua/spec/codecompanion/adapter_spec.lua new file mode 100644 index 00000000..3a93ae7e --- /dev/null +++ b/lua/spec/codecompanion/adapter_spec.lua @@ -0,0 +1,91 @@ +local assert = require("luassert") + +local test_adapter = { + name = "TestAdapter", + url = "https://api.testgenai.com/v1/chat/completions", + headers = { + content_type = "application/json", + }, + parameters = { + stream = true, + }, + schema = { + model = { + order = 1, + mapping = "parameters.data", + type = "enum", + desc = "ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", + default = "gpt-4-0125-preview", + choices = { + "gpt-4-1106-preview", + "gpt-4", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + }, + }, + temperature = { + order = 2, + mapping = "parameters.options", + type = "number", + optional = true, + default = 1, + desc = "What sampling temperature to use, between 0 and 2. 
Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.", + validate = function(n) + return n >= 0 and n <= 2, "Must be between 0 and 2" + end, + }, + top_p = { + order = 3, + mapping = "parameters.options", + type = "number", + optional = true, + default = 1, + desc = "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both.", + validate = function(n) + return n >= 0 and n <= 1, "Must be between 0 and 1" + end, + }, + }, +} + +local chat_buffer_settings = { + frequency_penalty = 0, + model = "gpt-4-0125-preview", + presence_penalty = 0, + temperature = 1, + top_p = 1, + stop = nil, + max_tokens = nil, + logit_bias = nil, + user = nil, +} + +describe("Adapter", function() + it("can form parameters from a chat buffer's settings", function() + local adapter = require("codecompanion.adapters").use("openai") + local result = adapter:set_params(chat_buffer_settings) + + -- Ignore this for now + result.parameters.stream = nil + + assert.are.same(chat_buffer_settings, result.parameters) + end) + + it("can nest parameters based on an adapter's schema", function() + local adapter = require("codecompanion.adapters").use(test_adapter) + local result = adapter:set_params(chat_buffer_settings) + + local expected = { + stream = true, + data = { + model = "gpt-4-0125-preview", + }, + options = { + temperature = 1, + top_p = 1, + }, + } + + assert.are.same(expected, result.parameters) + end) +end) diff --git a/lua/spec/codecompanion/adapters/anthropic_spec.lua b/lua/spec/codecompanion/adapters/anthropic_spec.lua new file mode 100644 index 00000000..6bdb56e9 --- /dev/null +++ b/lua/spec/codecompanion/adapters/anthropic_spec.lua @@ -0,0 +1,60 @@ +local adapter = require("codecompanion.adapters.anthropic") +local assert = require("luassert") +local helpers = require("spec.codecompanion.adapters.helpers") + +--------------------------------------------------- OUTPUT FROM THE CHAT BUFFER +local messages = { { + content = "Explain Ruby in two words", + role = "user", +} } + +local stream_response = { + { + request = [[data: {"type":"message_start","message":{"id":"msg_01Ngmyfn49udNhWaojMVKiR6","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":13,"output_tokens":1}}}]], + output = { + content = "", + role = "assistant", + }, + }, + { + request = [[data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Dynamic"}}]], + output = { + content = "Dynamic", + }, + }, + { + request = [[data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":","}}]], + output = { + content = ",", + }, + }, + { + request = [[data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" elegant"}}]], + output = { + content = " elegant", + }, + }, + { + request = [[data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"."}}]], + output = { + content = ".", + }, + }, +} + +local done_response = [[data: {"type":"message_stop"}]] +------------------------------------------------------------------------ // END + +describe("Anthropic adapter", function() + it("can form messages 
to be sent to the API", function() + assert.are.same({ messages = messages }, adapter.callbacks.form_messages(messages)) + end) + + it("can check if the streaming is complete", function() + assert.is_true(adapter.callbacks.is_complete(done_response)) + end) + + it("can output streamed data into a format for the chat buffer", function() + assert.are.same(stream_response[#stream_response].output, helpers.chat_buffer_output(stream_response, adapter)) + end) +end) diff --git a/lua/spec/codecompanion/adapters/helpers.lua b/lua/spec/codecompanion/adapters/helpers.lua new file mode 100644 index 00000000..7cb1e967 --- /dev/null +++ b/lua/spec/codecompanion/adapters/helpers.lua @@ -0,0 +1,13 @@ +local M = {} + +function M.chat_buffer_output(stream_response, adapter) + local output = {} + + for _, data in ipairs(stream_response) do + output = adapter.callbacks.chat_output(data.request) + end + + return output.output +end + +return M diff --git a/lua/spec/codecompanion/adapters/ollama_spec.lua b/lua/spec/codecompanion/adapters/ollama_spec.lua new file mode 100644 index 00000000..a856bf7b --- /dev/null +++ b/lua/spec/codecompanion/adapters/ollama_spec.lua @@ -0,0 +1,79 @@ +local adapter = require("codecompanion.adapters.ollama") +local assert = require("luassert") +local helpers = require("spec.codecompanion.adapters.helpers") + +--------------------------------------------------- OUTPUT FROM THE CHAT BUFFER +local messages = { { + content = "Explain Ruby in two words", + role = "user", +} } + +local stream_response = { + { + request = [[{"model":"llama2","created_at":"2024-03-07T20:02:30.622386Z","message":{"role":"assistant","content":"\n"},"done":false}]], + output = { + content = "\n", + role = "assistant", + }, + }, + { + request = [[{"model":"llama2","created_at":"2024-03-07T20:02:30.652682Z","message":{"role":"assistant","content":"\""},"done":false}]], + output = { + content = '"', + role = "assistant", + }, + }, + { + request = [[{"model":"llama2","created_at":"2024-03-07T20:02:30.681756Z","message":{"role":"assistant","content":"Be"},"done":false}]], + output = { + content = "Be", + role = "assistant", + }, + }, + { + request = [[{"model":"llama2","created_at":"2024-03-07T20:02:30.710758Z","message":{"role":"assistant","content":"aut"},"done":false}]], + output = { + content = "aut", + role = "assistant", + }, + }, + { + request = [[{"model":"llama2","created_at":"2024-03-07T20:02:30.739508Z","message":{"role":"assistant","content":"iful"},"done":false}]], + output = { + content = "iful", + role = "assistant", + }, + }, + { + request = [[{"model":"llama2","created_at":"2024-03-07T20:02:30.770345Z","message":{"role":"assistant","content":" Language"},"done":false}]], + output = { + content = " Language", + role = "assistant", + }, + }, + { + request = [[{"model":"llama2","created_at":"2024-03-07T20:02:30.7994Z","message":{"role":"assistant","content":"\""},"done":false}]], + output = { + content = '"', + role = "assistant", + }, + }, +} + +local done_response = + [[{"model":"llama2","created_at":"2024-03-06T18:35:15.921631Z","message":{"role":"assistant","content":""},"done":true,"total_duration":6035327208,"load_duration":5654490167,"prompt_eval_count":26,"prompt_eval_duration":173338000,"eval_count":8,"eval_duration":205986000}]] +------------------------------------------------------------------------ // END + +describe("Ollama adapter", function() + it("can form messages to be sent to the API", function() + assert.are.same({ messages = messages }, 
adapter.callbacks.form_messages(messages)) + end) + + it("can check if the streaming is complete", function() + assert.is_true(adapter.callbacks.is_complete(done_response)) + end) + + it("can output streamed data into a format for the chat buffer", function() + assert.are.same(stream_response[#stream_response].output, helpers.chat_buffer_output(stream_response, adapter)) + end) +end) diff --git a/lua/spec/codecompanion/adapters/openai_spec.lua b/lua/spec/codecompanion/adapters/openai_spec.lua new file mode 100644 index 00000000..266cbf1e --- /dev/null +++ b/lua/spec/codecompanion/adapters/openai_spec.lua @@ -0,0 +1,48 @@ +local adapter = require("codecompanion.adapters.openai") +local assert = require("luassert") +local helpers = require("spec.codecompanion.adapters.helpers") + +--------------------------------------------------- OUTPUT FROM THE CHAT BUFFER +local messages = { { + content = "Explain Ruby in two words", + role = "user", +} } + +local stream_response = { + { + request = [[data: {"id":"chatcmpl-90DdmqMKOKpqFemxX0OhTVdH042gu","object":"chat.completion.chunk","created":1709839462,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}]], + output = { + content = "", + role = "assistant", + }, + }, + { + request = [[data: {"id":"chatcmpl-90DdmqMKOKpqFemxX0OhTVdH042gu","object":"chat.completion.chunk","created":1709839462,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"content":"Programming"},"logprobs":null,"finish_reason":null}]}]], + output = { + content = "Programming", + }, + }, + { + request = [[data: {"id":"chatcmpl-90DdmqMKOKpqFemxX0OhTVdH042gu","object":"chat.completion.chunk","created":1709839462,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"content":" language"},"logprobs":null,"finish_reason":null}]}]], + output = { + content = " language", + }, + }, +} + +local done_response = "data: [DONE]" +------------------------------------------------------------------------ // END + +describe("OpenAI adapter", function() + it("can form messages to be sent to the API", function() + assert.are.same({ messages = messages }, adapter.callbacks.form_messages(messages)) + end) + + it("can check if the streaming is complete", function() + assert.is_true(adapter.callbacks.is_complete(done_response)) + end) + + it("can output streamed data into a format for the chat buffer", function() + assert.are.same(stream_response[#stream_response].output, helpers.chat_buffer_output(stream_response, adapter)) + end) +end) diff --git a/lua/spec/codecompanion/client_spec.lua b/lua/spec/codecompanion/client_spec.lua index 7973ef79..f4808279 100644 --- a/lua/spec/codecompanion/client_spec.lua +++ b/lua/spec/codecompanion/client_spec.lua @@ -2,18 +2,32 @@ local assert = require("luassert") local codecompanion = require("codecompanion") local stub = require("luassert.stub") -local schema local Client +local adapter = { + name = "TestAdapter", + url = "https://api.openai.com/v1/chat/completions", + headers = { + content_type = "application/json", + }, + parameters = { + stream = true, + }, + callbacks = { + form_messages = function() + return {} + end, + }, + schema = {}, +} + describe("Client", function() before_each(function() codecompanion.setup() - schema = require("codecompanion.schema") Client = require("codecompanion.client") -- Now that setup has been called, we can require the client 
end) after_each(function() - schema.static.client_settings = nil _G.codecompanion_jobs = nil end) @@ -26,23 +40,19 @@ describe("Client", function() -- Mock globals _G.codecompanion_jobs = {} - schema.static.client_settings = { + Client.static.opts = { request = { default = mock_request }, encode = { default = mock_encode }, decode = { default = mock_decode }, schedule = { default = mock_schedule }, } - local client = Client.new({ - secret_key = "fake_key", - organization = "fake_org", - }) - local cb = stub.new() - client:stream_chat({}, 0, cb) + adapter = require("codecompanion.adapter").new(adapter) + + Client.new():stream(adapter, {}, 0, cb) - assert.stub(mock_request).was_called() - -- assert.stub(cb).was_called() + assert.stub(mock_request).was_called(1) end) end)