diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua
index f171b24..730385b 100644
--- a/lua/avante/llm.lua
+++ b/lua/avante/llm.lua
@@ -114,6 +114,10 @@ M.stream = function(opts)
     if data_match then Provider.parse_response(data_match, current_event_state, handler_opts) end
   end
 
+  local function parse_response_without_stream(data)
+    Provider.parse_response_without_stream(data, current_event_state, handler_opts)
+  end
+
   local completed = false
 
   local active_job
@@ -170,6 +174,14 @@ M.stream = function(opts)
           end
         end)
       end
+
+      -- If streaming is not enabled, handle the whole response body here
+      if spec.body.stream == false and result.status == 200 then
+        vim.schedule(function()
+          completed = true
+          parse_response_without_stream(result.body)
+        end)
+      end
     end,
   })
 
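Note on the hunk above: the non-streaming branch only fires when the request spec built by `parse_curl_args` (see the provider changes below) already has `stream = false`, in which case curl buffers the entire JSON body instead of delivering SSE chunks. A minimal sketch of the dispatch, plain Lua and not part of the patch — `spec` and `result` are hypothetical stand-ins for the request spec and the finished curl response:

    -- Sketch: how the callback in llm.lua picks a path.
    local spec = { body = { stream = false } }
    local result = { status = 200, body = '{"choices":[{"message":{"content":"hi"}}]}' }

    if spec.body.stream == false and result.status == 200 then
      -- o1 path: curl buffered the entire JSON body, so parse it once
      print("non-streaming body: " .. result.body)
    else
      -- default path: the body arrives incrementally as SSE "data: ..." lines
      print("streaming chunks")
    end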
diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua
index 52e62b1..06a5f7b 100644
--- a/lua/avante/providers/openai.lua
+++ b/lua/avante/providers/openai.lua
@@ -9,7 +9,7 @@ local P = require("avante.providers")
 ---@field created integer
 ---@field model string
 ---@field system_fingerprint string
----@field choices? OpenAIResponseChoice[]
+---@field choices? OpenAIResponseChoice[] | OpenAIResponseChoiceComplete[]
 ---@field usage {prompt_tokens: integer, completion_tokens: integer, total_tokens: integer}
 ---
 ---@class OpenAIResponseChoice
@@ -18,6 +18,12 @@ local P = require("avante.providers")
 ---@field logprobs? integer
 ---@field finish_reason? "stop" | "length"
 ---
+---@class OpenAIResponseChoiceComplete
+---@field message OpenAIMessage
+---@field finish_reason "stop" | "length"
+---@field index integer
+---@field logprobs integer
+---
 ---@class OpenAIMessage
 ---@field role? "user" | "system" | "assistant"
 ---@field content string
@@ -50,10 +56,22 @@ M.parse_message = function(opts)
     end)
   end
 
-  return {
-    { role = "system", content = opts.system_prompt },
-    { role = "user", content = user_content },
-  }
+  local messages = {}
+  local provider = P[Config.provider]
+  local base, _ = P.parse_config(provider)
+
+  -- NOTE: Handle the case where the selected model is an `o1` model
+  -- "o1" models are "smart" enough to treat the user prompt as a system prompt in this context
+  if base.model and string.find(base.model, "o1") then
+    table.insert(messages, { role = "user", content = opts.system_prompt })
+  else
+    table.insert(messages, { role = "system", content = opts.system_prompt })
+  end
+
+  -- The user message follows the prompt
+  table.insert(messages, { role = "user", content = user_content })
+
+  return messages
 end
 
 M.parse_response = function(data_stream, _, opts)
@@ -75,6 +93,18 @@
   end
 end
 
+M.parse_response_without_stream = function(data, _, opts)
+  ---@type OpenAIChatResponse
+  local json = vim.json.decode(data)
+  if json.choices and json.choices[1] then
+    local choice = json.choices[1]
+    if choice.message and choice.message.content then
+      opts.on_chunk(choice.message.content)
+      vim.schedule(function() opts.on_complete(nil) end)
+    end
+  end
+end
+
 M.parse_curl_args = function(provider, code_opts)
   local base, body_opts = P.parse_config(provider)
 
@@ -83,6 +113,14 @@
   local headers = {
   }
   if not P.env.is_local("openai") then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end
+  -- NOTE: "o1" models do not support streaming or max_tokens, and require temperature = 1
+  local stream = true
+  if base.model and string.find(base.model, "o1") then
+    stream = false
+    body_opts.max_tokens = nil
+    body_opts.temperature = 1
+  end
+
   return {
     url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat/completions",
     proxy = base.proxy,
@@ -91,7 +129,7 @@
     body = vim.tbl_deep_extend("force", {
       model = base.model,
      messages = M.parse_message(code_opts),
-      stream = true,
+      stream = stream,
     }, body_opts),
   }
 end
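Two notes on the provider changes, neither part of the patch. First, `string.find(base.model, "o1")` is a plain substring match, so any configured model name containing "o1" takes the non-streaming branch. Second, a sketch of the payload shape `M.parse_response_without_stream` expects, written as the Lua table `vim.json.decode` would produce — field values are illustrative, not captured API output:

    -- Sketch: a non-streaming chat.completion response, post-decode.
    local json = {
      choices = {
        {
          index = 0,
          message = { role = "assistant", content = "Hello!" },
          finish_reason = "stop",
        },
      },
    }

    -- Mirrors the new parser: emit the first choice's content as a single chunk.
    local choice = json.choices[1]
    if choice.message and choice.message.content then
      print(choice.message.content) --> Hello!  (what on_chunk receives)
    end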