From af1a40917a738f74ab7e33a905b46d0bb643a87a Mon Sep 17 00:00:00 2001 From: qtnx <123870525+qtnx@users.noreply.github.com> Date: Fri, 13 Sep 2024 20:50:44 +0700 Subject: [PATCH] feat: handle gpt o1-preview, o1-mini models --- lua/gp/dispatcher.lua | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/lua/gp/dispatcher.lua b/lua/gp/dispatcher.lua index b6f9432..26fc76a 100644 --- a/lua/gp/dispatcher.lua +++ b/lua/gp/dispatcher.lua @@ -165,7 +165,7 @@ D.prepare_payload = function(messages, model, provider) model.model = "gpt-4o-2024-05-13" end - return { + local output = { model = model.model, stream = true, messages = messages, @@ -173,6 +173,21 @@ D.prepare_payload = function(messages, model, provider) temperature = math.max(0, math.min(2, model.temperature or 1)), top_p = math.max(0, math.min(1, model.top_p or 1)), } + + if provider == "openai" and model.model:sub(1, 2) == "o1" then + for i = #messages, 1, -1 do + if messages[i].role == "system" then + table.remove(messages, i) + end + end + -- remove max_tokens, top_p, temperature for o1 models. https://platform.openai.com/docs/guides/reasoning/beta-limitations + output.max_tokens = nil + output.temperature = nil + output.top_p = nil + output.stream = false + end + + return output end -- gpt query @@ -249,6 +264,7 @@ local query = function(buf, provider, payload, handler, on_exit, callback) end end + if content and type(content) == "string" then qt.response = qt.response .. content handler(qid, content) @@ -282,6 +298,19 @@ local query = function(buf, provider, payload, handler, on_exit, callback) if #buffer > 0 then process_lines(buffer) end + local raw_response = qt.raw_response + local content = qt.response + if qt.provider == 'openai' and content == "" and raw_response:match('choices') and raw_response:match("content") then + local response = vim.json.decode(raw_response) + if response.choices and response.choices[1] and response.choices[1].message and response.choices[1].message.content then + content = response.choices[1].message.content + end + if content and type(content) == "string" then + qt.response = qt.response .. content + handler(qid, content) + end + end + if qt.response == "" then logger.error(qt.provider .. " response is empty: \n" .. vim.inspect(qt.raw_response)) @@ -363,7 +392,8 @@ local query = function(buf, provider, payload, handler, on_exit, callback) } end - local temp_file = D.query_dir .. "/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json" + local temp_file = D.query_dir .. + "/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json" helpers.table_to_file(payload, temp_file) local curl_params = vim.deepcopy(D.config.curl_params or {})