From c843a167f638ef190062ba9b0c82e5b5e2db600c Mon Sep 17 00:00:00 2001 From: Oli Morris Date: Tue, 5 Mar 2024 09:40:26 +0000 Subject: [PATCH] inline strategy now uses adapter --- README.md | 31 +--- lua/codecompanion/adapter.lua | 6 +- lua/codecompanion/adapters/openai.lua | 35 +++-- lua/codecompanion/client.lua | 74 +++------ lua/codecompanion/config.lua | 16 -- lua/codecompanion/health.lua | 22 +-- lua/codecompanion/strategies/chat.lua | 6 +- lua/codecompanion/strategies/inline.lua | 197 ++++++++++++------------ 8 files changed, 167 insertions(+), 220 deletions(-) diff --git a/README.md b/README.md index 4bd6aca8..2f17d684 100644 --- a/README.md +++ b/README.md @@ -95,34 +95,9 @@ You only need to the call the `setup` function if you wish to change any of the ```lua require("codecompanion").setup({ - api_key = "OPENAI_API_KEY", -- Your API key - org_api_key = "OPENAI_ORG_KEY", -- Your organisation API key - base_url = "https://api.openai.com", -- The URL to use for the API requests - ai_settings = { - -- Default settings for the Completions API - -- See https://platform.openai.com/docs/api-reference/chat/create - chat = { - model = "gpt-4-0125-preview", - temperature = 1, - top_p = 1, - stop = nil, - max_tokens = nil, - presence_penalty = 0, - frequency_penalty = 0, - logit_bias = nil, - user = nil, - }, - inline = { - model = "gpt-3.5-turbo-0125", - temperature = 1, - top_p = 1, - stop = nil, - max_tokens = nil, - presence_penalty = 0, - frequency_penalty = 0, - logit_bias = nil, - user = nil, - }, + adapters = { + chat = require("codecompanion.adapters.openai"), + inline = require("codecompanion.adapters.openai"), }, saved_chats = { save_dir = vim.fn.stdpath("data") .. "/codecompanion/saved_chats", -- Path to save chats to diff --git a/lua/codecompanion/adapter.lua b/lua/codecompanion/adapter.lua index 861c519b..266cfa94 100644 --- a/lua/codecompanion/adapter.lua +++ b/lua/codecompanion/adapter.lua @@ -36,9 +36,13 @@ function Adapter:get_default_settings() return settings end ----@param settings table +---@param settings? table ---@return CodeCompanion.Adapter function Adapter:set_params(settings) + if not settings then + settings = self:get_default_settings() + end + for k, v in pairs(settings) do local mapping = self.schema[k] and self.schema[k].mapping if mapping then diff --git a/lua/codecompanion/adapters/openai.lua b/lua/codecompanion/adapters/openai.lua index 9a11e9bf..cc0cba84 100644 --- a/lua/codecompanion/adapters/openai.lua +++ b/lua/codecompanion/adapters/openai.lua @@ -11,8 +11,12 @@ local Adapter = require("codecompanion.adapter") local adapter = { name = "OpenAI", url = "https://api.openai.com/v1/chat/completions", + raw = { + "--no-buffer", + "--silent", + }, headers = { - content_type = "application/json", + ["Content-Type"] = "application/json", -- FIX: Need a way to check if the key is set Authorization = "Bearer " .. os.getenv("OPENAI_API_KEY"), }, @@ -35,24 +39,33 @@ local adapter = { return formatted_data == "[DONE]" end, - ---Format the messages from the API - ---@param data table - ---@param messages table - ---@param new_message table + ---Output the data from the API ready for insertion into the chat buffer + ---@param data table The streamed data from the API + ---@param messages table A table of all of the messages in the chat buffer + ---@param current_message table The current/latest message in the chat buffer ---@return table - format_messages = function(data, messages, new_message) + output_chat = function(data, messages, current_message) local delta = data.choices[1].delta - if delta.role and delta.role ~= new_message.role then - new_message = { role = delta.role, content = "" } - table.insert(messages, new_message) + if delta.role and delta.role ~= current_message.role then + current_message = { role = delta.role, content = "" } + table.insert(messages, current_message) end + -- Append the new message to the if delta.content then - new_message.content = new_message.content .. delta.content + current_message.content = current_message.content .. delta.content end - return new_message + return current_message + end, + + ---Output the data from the API ready for inlining into the current buffer + ---@param data table The streamed data from the API + ---@param context table Useful context about the buffer to inline to + ---@return table + output_inline = function(data, context) + return data.choices[1].delta.content end, }, schema = { diff --git a/lua/codecompanion/client.lua b/lua/codecompanion/client.lua index 6823842b..18a38a8a 100644 --- a/lua/codecompanion/client.lua +++ b/lua/codecompanion/client.lua @@ -1,4 +1,3 @@ -local config = require("codecompanion.config") local curl = require("plenary.curl") local log = require("codecompanion.utils.log") local schema = require("codecompanion.schema") @@ -155,31 +154,40 @@ function Client:stream(adapter, payload, bufnr, cb) start_request(bufnr, handler) end ----Call the OpenAI API but block the main loop until the response is received ----@param url string +---Call the API and block until the response is received +---@param adapter CodeCompanion.Adapter ---@param payload table ---@param cb fun(err: nil|string, response: nil|table) -function Client:block_request(url, payload, cb) +function Client:call(adapter, payload, cb) cb = log:wrap_cb(cb, "Response error: %s") local cmd = { "curl", - url, - "--silent", - "--no-buffer", - "-H", - "Content-Type: application/json", - "-H", - string.format("Authorization: Bearer %s", self.secret_key), + adapter.url, } - if self.organization then - table.insert(cmd, "-H") - table.insert(cmd, string.format("OpenAI-Organization: %s", self.organization)) + if adapter.raw then + for _, v in ipairs(adapter.raw) do + table.insert(cmd, v) + end + else + table.insert(cmd, "--no-buffer") + end + + if adapter.headers then + for k, v in pairs(adapter.headers) do + table.insert(cmd, "-H") + table.insert(cmd, string.format("%s: %s", k, v)) + end end table.insert(cmd, "-d") - table.insert(cmd, vim.json.encode(payload)) + table.insert( + cmd, + vim.json.encode(vim.tbl_extend("keep", adapter.parameters, { + messages = payload, + })) + ) log:trace("Request payload: %s", cmd) local result = vim.fn.system(cmd) @@ -197,40 +205,4 @@ function Client:block_request(url, payload, cb) end end ----@class CodeCompanion.ChatMessage ----@field role "system"|"user"|"assistant" ----@field content string - ----@class CodeCompanion.ChatSettings ----@field model string ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. ----@field temperature nil|number Defaults to 1. What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both. ----@field top_p nil|number Defaults to 1. An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. ----@field n nil|integer Defaults to 1. How many chat completion choices to generate for each input message. ----@field stop nil|string|string[] Defaults to nil. Up to 4 sequences where the API will stop generating further tokens. ----@field max_tokens nil|integer Defaults to nil. The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length. ----@field presence_penalty nil|number Defaults to 0. Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. ----@field frequency_penalty nil|number Defaults to 0. Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. ----@field logit_bias nil|table Modify the likelihood of specified tokens appearing in the completion. Maps tokens (specified by their token ID) to an associated bias value from -100 to 100. Use https://platform.openai.com/tokenizer to find token IDs. ----@field user nil|string A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Learn more. - ----@class CodeCompanion.ChatArgs : CodeCompanion.ChatSettings ----@field messages CodeCompanion.ChatMessage[] The messages to generate chat completions for, in the chat format. ----@field stream boolean? Whether to stream the chat output back to Neovim - ----@param args CodeCompanion.ChatArgs ----@param cb fun(err: nil|string, response: nil|table) ----@return nil -function Client:chat(args, cb) - return self:block_request(config.options.base_url .. "/v1/chat/completions", args, cb) -end - ----@class args CodeCompanion.InlineArgs ----@param bufnr integer ----@param cb fun(err: nil|string, chunk: nil|table, done: nil|boolean) Will be called multiple times until done is true ----@return nil -function Client:inline(args, bufnr, cb) - args.stream = true - return self:stream(config.options.base_url .. "/v1/chat/completions", args, bufnr, cb) -end - return Client diff --git a/lua/codecompanion/config.lua b/lua/codecompanion/config.lua index 5d92f3ff..870a7b89 100644 --- a/lua/codecompanion/config.lua +++ b/lua/codecompanion/config.lua @@ -1,26 +1,10 @@ local M = {} local defaults = { - api_key = "OPENAI_API_KEY", - org_api_key = "OPENAI_ORG_KEY", - base_url = "https://api.openai.com", adapters = { chat = require("codecompanion.adapters.openai"), inline = require("codecompanion.adapters.openai"), }, - ai_settings = { - inline = { - model = "gpt-4-0125-preview", - temperature = 1, - top_p = 1, - stop = nil, - max_tokens = nil, - presence_penalty = 0, - frequency_penalty = 0, - logit_bias = nil, - user = nil, - }, - }, saved_chats = { save_dir = vim.fn.stdpath("data") .. "/codecompanion/saved_chats", }, diff --git a/lua/codecompanion/health.lua b/lua/codecompanion/health.lua index 30bf7a43..5203496d 100644 --- a/lua/codecompanion/health.lua +++ b/lua/codecompanion/health.lua @@ -94,17 +94,17 @@ function M.check() end end - for _, env in ipairs(M.env_vars) do - if env_available(env.name) then - ok(fmt("%s key found", env.name)) - else - if env.optional then - warn(fmt("%s key not found", env.name)) - else - error(fmt("%s key not found", env.name)) - end - end - end + -- for _, env in ipairs(M.env_vars) do + -- if env_available(env.name) then + -- ok(fmt("%s key found", env.name)) + -- else + -- if env.optional then + -- warn(fmt("%s key not found", env.name)) + -- else + -- error(fmt("%s key not found", env.name)) + -- end + -- end + -- end end return M diff --git a/lua/codecompanion/strategies/chat.lua b/lua/codecompanion/strategies/chat.lua index ff1b1244..3223c908 100644 --- a/lua/codecompanion/strategies/chat.lua +++ b/lua/codecompanion/strategies/chat.lua @@ -431,9 +431,9 @@ function Chat:submit() end end - local new_message = messages[#messages] + local current_message = messages[#messages] - if new_message and new_message.role == "user" and new_message.content == "" then + if current_message and current_message.role == "user" and current_message.content == "" then return finalize() end @@ -447,7 +447,7 @@ function Chat:submit() if data then log:trace("Chat data: %s", data) - new_message = adapter.callbacks.format_messages(data, messages, new_message) + current_message = adapter.callbacks.output_chat(data, messages, current_message) render_buffer() end diff --git a/lua/codecompanion/strategies/inline.lua b/lua/codecompanion/strategies/inline.lua index 6529fb5f..1bc61088 100644 --- a/lua/codecompanion/strategies/inline.lua +++ b/lua/codecompanion/strategies/inline.lua @@ -1,4 +1,6 @@ +local client = require("codecompanion.client") local config = require("codecompanion.config") +local adapter = config.options.adapters.inline local log = require("codecompanion.utils.log") local ui = require("codecompanion.utils.ui") @@ -9,6 +11,14 @@ local function fire_autocmd(status) vim.api.nvim_exec_autocmds("User", { pattern = "CodeCompanionInline", data = { status = status } }) end +---When a user initiates an inline request, it will be possible to infer from +---their prompt, how the output should be placed in the buffer. For instance, if +---they reference the words "refactor" or "update" in their prompt, they will +---likely want the response to replace a visual selection they've made in the +---editor. However, if they include words such as "after" or "before", it may +---be insinuated that they wish the response to be placed after or before the +---cursor. In this function, we use the power of Generative AI to determine +---the user's intent and return a placement position. ---@param inline CodeCompanion.Inline ---@param prompt string ---@return string, boolean @@ -32,24 +42,17 @@ local function get_placement_position(inline, prompt) } local output - inline.client:chat( - vim.tbl_extend("keep", inline.settings, { - messages = messages, - }), - function(err, chunk) - if err then - return - end + client.new():call(adapter:set_params(), messages, function(err, data) + if err then + return + end - if chunk then - local content = chunk.choices[1].message.content - if content then - log:trace("Placement response: %s", content) - output = content - end - end + if data then + print(vim.inspect(data)) end - ) + end) + + log:trace("Placement output: %s", output) if output then local parts = vim.split(output, "|") @@ -131,17 +134,13 @@ local function get_cursor(winid) end ---@class CodeCompanion.Inline ----@field settings table ---@field context table ----@field client CodeCompanion.Client ---@field opts table ---@field prompts table local Inline = {} ---@class CodeCompanion.InlineArgs ----@field settings table ---@field context table ----@field client CodeCompanion.Client ---@field opts table ---@field pre_hook fun():number -- Assuming pre_hook returns a number for example ---@field prompts table @@ -162,9 +161,7 @@ function Inline.new(opts) end return setmetatable({ - settings = config.options.ai_settings.inline, context = opts.context, - client = opts.client, opts = opts.opts or {}, prompts = vim.deepcopy(opts.prompts), }, { __index = Inline }) @@ -176,54 +173,60 @@ function Inline:execute(user_input) local messages = get_action(self, user_input) - if not self.opts.placement and user_input then - local return_code - self.opts.placement, return_code = get_placement_position(self, user_input) - - if not return_code then - table.insert(messages, { - role = "user", - content = "Please do not return the code I have sent in the response", - }) - end - - log:debug("Setting the placement to: %s", self.opts.placement) - end - + -- if not self.opts.placement and user_input then + -- local return_code + -- self.opts.placement, return_code = get_placement_position(self, user_input) + -- + -- if not return_code then + -- table.insert(messages, { + -- role = "user", + -- content = "Please do not return the code I have sent in the response", + -- }) + -- end + -- + -- log:debug("Setting the placement to: %s", self.opts.placement) + -- end + + -- Assume the placement should be after the cursor + vim.api.nvim_buf_set_lines(self.context.bufnr, self.context.end_line, self.context.end_line, false, { "" }) + pos.line = self.context.end_line + 1 + pos.col = 0 + + --TODO: Workout how we can re-enable this -- Determine where to place the response in the buffer - if self.opts and self.opts.placement then - if self.opts.placement == "before" then - log:trace("Placing before selection") - vim.api.nvim_buf_set_lines( - self.context.bufnr, - self.context.start_line - 1, - self.context.start_line - 1, - false, - { "" } - ) - self.context.start_line = self.context.start_line + 1 - pos.line = self.context.start_line - 1 - pos.col = self.context.start_col - 1 - elseif self.opts.placement == "after" then - log:trace("Placing after selection") - vim.api.nvim_buf_set_lines(self.context.bufnr, self.context.end_line, self.context.end_line, false, { "" }) - pos.line = self.context.end_line + 1 - pos.col = 0 - elseif self.opts.placement == "replace" then - log:trace("Placing by overwriting selection") - overwrite_selection(self.context) - - pos.line, pos.col = get_cursor(self.context.winid) - elseif self.opts.placement == "new" then - log:trace("Placing in a new buffer") - self.context.bufnr = api.nvim_create_buf(true, false) - api.nvim_buf_set_option(self.context.bufnr, "filetype", self.context.filetype) - api.nvim_set_current_buf(self.context.bufnr) - - pos.line = 1 - pos.col = 0 - end - end + -- if self.opts and self.opts.placement then + -- if self.opts.placement == "before" then + -- log:trace("Placing before selection") + -- vim.api.nvim_buf_set_lines( + -- self.context.bufnr, + -- self.context.start_line - 1, + -- self.context.start_line - 1, + -- false, + -- { "" } + -- ) + -- self.context.start_line = self.context.start_line + 1 + -- pos.line = self.context.start_line - 1 + -- pos.col = self.context.start_col - 1 + -- elseif self.opts.placement == "after" then + -- log:trace("Placing after selection") + -- vim.api.nvim_buf_set_lines(self.context.bufnr, self.context.end_line, self.context.end_line, false, { "" }) + -- pos.line = self.context.end_line + 1 + -- pos.col = 0 + -- elseif self.opts.placement == "replace" then + -- log:trace("Placing by overwriting selection") + -- overwrite_selection(self.context) + -- + -- pos.line, pos.col = get_cursor(self.context.winid) + -- elseif self.opts.placement == "new" then + -- log:trace("Placing in a new buffer") + -- self.context.bufnr = api.nvim_create_buf(true, false) + -- api.nvim_buf_set_option(self.context.bufnr, "filetype", self.context.filetype) + -- api.nvim_set_current_buf(self.context.bufnr) + -- + -- pos.line = 1 + -- pos.col = 0 + -- end + -- end log:debug("Context for inline: %s", self.context) log:debug("Cursor position to use: %s", pos) @@ -268,43 +271,39 @@ function Inline:execute(user_input) fire_autocmd("started") local output = {} - self.client:stream_chat( - vim.tbl_extend("keep", self.settings, { - messages = messages, - }), - self.context.bufnr, - function(err, chunk, done) - if err then - return - end + client.new():stream(adapter:set_params(), messages, self.context.bufnr, function(err, data, done) + if err then + fire_autocmd("finished") + return + end - if chunk then - log:trace("Chat chunk: %s", chunk) - - local delta = chunk.choices[1].delta - if delta.content and not delta.role and delta.content ~= "```" and delta.content ~= self.context.filetype then - if self.context.buftype == "terminal" then - -- Don't stream to the terminal - table.insert(output, delta.content) - else - stream_buffer_text(delta.content) - if self.opts and self.opts.placement == "new" then - ui.buf_scroll_to_end(self.context.bufnr) - end + if data then + log:trace("Inline data: %s", data) + + local content = adapter.callbacks.output_inline(data, self.context) + + if self.context.buftype == "terminal" then + -- Don't stream to the terminal + table.insert(output, content) + else + if content then + stream_buffer_text(content) + if self.opts and self.opts.placement == "new" then + ui.buf_scroll_to_end(self.context.bufnr) end end end + end - if done then - api.nvim_buf_del_keymap(self.context.bufnr, "n", "q") - if self.context.buftype == "terminal" then - log:debug("Terminal output: %s", output) - api.nvim_put({ table.concat(output, "") }, "", false, true) - end - fire_autocmd("finished") + if done then + api.nvim_buf_del_keymap(self.context.bufnr, "n", "q") + if self.context.buftype == "terminal" then + log:debug("Terminal output: %s", output) + api.nvim_put({ table.concat(output, "") }, "", false, true) end + fire_autocmd("finished") end - ) + end) end ---@param user_input? string