diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index a0c5f0e..fb2ed72 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -139,6 +139,8 @@ M._stream = function(opts) ---@type AvanteCurlOutput local spec = Provider.parse_curl_args(Provider, code_opts) + local resp_ctx = {} + ---@param line string local function parse_stream_data(line) local event = line:match("^event: (.+)$") @@ -147,7 +149,7 @@ M._stream = function(opts) return end local data_match = line:match("^data: (.+)$") - if data_match then Provider.parse_response(data_match, current_event_state, handler_opts) end + if data_match then Provider.parse_response(resp_ctx, data_match, current_event_state, handler_opts) end end local function parse_response_without_stream(data) diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index 469c93b..a59bb6d 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -77,7 +77,7 @@ M.parse_messages = function(opts) return messages end -M.parse_response = function(data_stream, event_state, opts) +M.parse_response = function(ctx, data_stream, event_state, opts) if event_state == nil then if data_stream:match('"content_block_delta"') then event_state = "content_block_delta" diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index 11a1660..9e29e4e 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -64,7 +64,7 @@ M.parse_messages = function(opts) } end -M.parse_response = function(data_stream, _, opts) +M.parse_response = function(ctx, data_stream, _, opts) local ok, json = pcall(vim.json.decode, data_stream) if not ok then opts.on_complete(json) end if json.candidates then diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index 343437b..375c458 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -37,7 +37,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@class ResponseParser ---@field on_chunk fun(chunk: string): any ---@field on_complete fun(err: string|nil): any ----@alias AvanteResponseParser fun(data_stream: string, event_state: string, opts: ResponseParser): nil +---@alias AvanteResponseParser fun(ctx: any, data_stream: string, event_state: string, opts: ResponseParser): nil --- ---@class AvanteDefaultBaseProvider: table ---@field endpoint? string diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index b5d57e4..fe3eddd 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -26,7 +26,8 @@ local P = require("avante.providers") --- ---@class OpenAIMessage ---@field role? "user" | "system" | "assistant" ----@field content string +---@field content? string +---@field reasoning_content? string --- ---@class AvanteProviderFunctor local M = {} @@ -106,19 +107,30 @@ M.parse_messages = function(opts) return final_messages end -M.parse_response = function(data_stream, _, opts) +M.parse_response = function(ctx, data_stream, _, opts) if data_stream:match('"%[DONE%]":') then opts.on_complete(nil) return end if data_stream:match('"delta":') then ---@type OpenAIChatResponse - local json = vim.json.decode(data_stream) - if json.choices and json.choices[1] then - local choice = json.choices[1] + local jsn = vim.json.decode(data_stream) + Utils.debug("jsn", jsn) + if jsn.choices and jsn.choices[1] then + local choice = jsn.choices[1] if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then opts.on_complete(nil) + elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then + if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then + ctx.returned_think_start_tag = true + opts.on_chunk("\n") + end + opts.on_chunk(choice.delta.reasoning_content) elseif choice.delta.content then + if ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag then + ctx.returned_think_end_tag = true + opts.on_chunk("\n\n\n") + end if choice.delta.content ~= vim.NIL then opts.on_chunk(choice.delta.content) end end end