From ebadba7420a5f9b85829273b8c9cd2be56d9b074 Mon Sep 17 00:00:00 2001 From: yetone Date: Thu, 27 Feb 2025 15:21:00 +0800 Subject: [PATCH] fix: claude extended thinking (#1419) --- lua/avante/llm.lua | 22 ++++- lua/avante/providers/bedrock/claude.lua | 21 +++-- lua/avante/providers/claude.lua | 113 +++++++++++++++--------- lua/avante/sidebar.lua | 18 +++- lua/avante/types.lua | 3 +- 5 files changed, 126 insertions(+), 51 deletions(-) diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index c9a79ab..2594a6d 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -201,8 +201,28 @@ function M._stream(opts) if stop_opts.reason == "rate_limit" then local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..." opts.on_chunk("\n*[" .. msg .. "]*\n") + local timer = vim.loop.new_timer() + if timer then + local retry_after = stop_opts.retry_after + local function countdown() + timer:start( + 1000, + 0, + vim.schedule_wrap(function() + if retry_after > 0 then retry_after = retry_after - 1 end + local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..." + opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n") + countdown() + end) + ) + end + countdown() + end Utils.info("Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds", { title = "Avante" }) - vim.defer_fn(function() M._stream(opts) end, stop_opts.retry_after * 1000) + vim.defer_fn(function() + if timer then timer:stop() end + M._stream(opts) + end, stop_opts.retry_after * 1000) return end return opts.on_stop(stop_opts) diff --git a/lua/avante/providers/bedrock/claude.lua b/lua/avante/providers/bedrock/claude.lua index 531e655..d018780 100644 --- a/lua/avante/providers/bedrock/claude.lua +++ b/lua/avante/providers/bedrock/claude.lua @@ -39,11 +39,22 @@ function M.parse_messages(opts) role = "assistant", content = {}, } - if tool_history.tool_use.response_content then - msg.content[#msg.content + 1] = { - type = "text", - text = tool_history.tool_use.response_content, - } + if tool_history.tool_use.thinking_contents then + for _, thinking_content in ipairs(tool_history.tool_use.thinking_contents) do + msg.content[#msg.content + 1] = { + type = "thinking", + thinking = thinking_content.content, + signature = thinking_content.signature, + } + end + end + if tool_history.tool_use.response_contents then + for _, response_content in ipairs(tool_history.tool_use.response_contents) do + msg.content[#msg.content + 1] = { + type = "text", + text = response_content, + } + end end msg.content[#msg.content + 1] = { type = "tool_use", diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index a862fc3..5c0fcbe 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -89,11 +89,22 @@ function M.parse_messages(opts) role = "assistant", content = {}, } - if tool_history.tool_use.response_content then - msg.content[#msg.content + 1] = { - type = "text", - text = tool_history.tool_use.response_content, - } + if tool_history.tool_use.thinking_contents then + for _, thinking_content in ipairs(tool_history.tool_use.thinking_contents) do + msg.content[#msg.content + 1] = { + type = "thinking", + thinking = thinking_content.content, + signature = thinking_content.signature, + } + end + end + if tool_history.tool_use.response_contents then + for _, response_content in ipairs(tool_history.tool_use.response_contents) do + msg.content[#msg.content + 1] = { + type = "text", + text = response_content, + } + end end msg.content[#msg.content + 1] = { type = "tool_use", @@ -139,6 +150,7 @@ function M.parse_response(ctx, data_stream, event_state, opts) event_state = "content_block_stop" end end + if ctx.content_blocks == nil then ctx.content_blocks = {} end if event_state == "message_start" then local ok, jsn = pcall(vim.json.decode, data_stream) if not ok then return end @@ -146,52 +158,39 @@ function M.parse_response(ctx, data_stream, event_state, opts) elseif event_state == "content_block_start" then local ok, jsn = pcall(vim.json.decode, data_stream) if not ok then return end - if jsn.content_block.type == "tool_use" then - if not ctx.tool_use_list then ctx.tool_use_list = {} end - local tool_use = { - name = jsn.content_block.name, - id = jsn.content_block.id, - input_json = "", - response_content = nil, - } - table.insert(ctx.tool_use_list, tool_use) - elseif jsn.content_block.type == "text" then - ctx.response_content = "" - end + local content_block = jsn.content_block + content_block.stoppped = false + ctx.content_blocks[jsn.index + 1] = content_block + if content_block.type == "thinking" then opts.on_chunk("\n") end elseif event_state == "content_block_delta" then local ok, jsn = pcall(vim.json.decode, data_stream) if not ok then return end - if ctx.tool_use_list and jsn.delta.type == "input_json_delta" then - local tool_use = ctx.tool_use_list[#ctx.tool_use_list] - tool_use.input_json = tool_use.input_json .. jsn.delta.partial_json + local content_block = ctx.content_blocks[jsn.index + 1] + if jsn.delta.type == "input_json_delta" then + if not content_block.input_json then content_block.input_json = "" end + content_block.input_json = content_block.input_json .. jsn.delta.partial_json return - elseif ctx.response_content and jsn.delta.type == "text_delta" then - ctx.response_content = ctx.response_content .. jsn.delta.text - end - if jsn.delta.type == "thinking_delta" then - if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then - ctx.returned_think_start_tag = true - opts.on_chunk("\n") - end - ctx.last_think_content = jsn.delta.thinking + elseif jsn.delta.type == "thinking_delta" then + content_block.thinking = content_block.thinking .. jsn.delta.thinking opts.on_chunk(jsn.delta.thinking) elseif jsn.delta.type == "text_delta" then - if - ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag) - then - ctx.returned_think_end_tag = true - if ctx.last_think_content and ctx.last_think_content ~= vim.NIL and ctx.last_think_content:sub(-1) ~= "\n" then - opts.on_chunk("\n\n\n") - else - opts.on_chunk("\n\n") - end - end + content_block.text = content_block.text .. jsn.delta.text opts.on_chunk(jsn.delta.text) + elseif jsn.delta.type == "signature_delta" then + if ctx.content_blocks[jsn.index + 1].signature == nil then ctx.content_blocks[jsn.index + 1].signature = "" end + ctx.content_blocks[jsn.index + 1].signature = ctx.content_blocks[jsn.index + 1].signature .. jsn.delta.signature end elseif event_state == "content_block_stop" then - if ctx.tool_use_list then - local tool_use = ctx.tool_use_list[#ctx.tool_use_list] - if tool_use.response_content == nil then tool_use.response_content = ctx.response_content end + local ok, jsn = pcall(vim.json.decode, data_stream) + if not ok then return end + local content_block = ctx.content_blocks[jsn.index + 1] + content_block.stoppped = true + if content_block.type == "thinking" then + if content_block.thinking and content_block.thinking ~= vim.NIL and content_block.thinking:sub(-1) ~= "\n" then + opts.on_chunk("\n\n\n") + else + opts.on_chunk("\n\n") + end end elseif event_state == "message_delta" then local ok, jsn = pcall(vim.json.decode, data_stream) @@ -199,10 +198,38 @@ function M.parse_response(ctx, data_stream, event_state, opts) if jsn.delta.stop_reason == "end_turn" then opts.on_stop({ reason = "complete", usage = jsn.usage }) elseif jsn.delta.stop_reason == "tool_use" then + ---@type AvanteLLMToolUse[] + local tool_use_list = vim + .iter(ctx.content_blocks) + :filter(function(content_block) return content_block.stoppped and content_block.type == "tool_use" end) + :map(function(content_block) + local response_contents = vim + .iter(ctx.content_blocks) + :filter(function(content_block_) return content_block_.stoppped and content_block_.type == "text" end) + :map(function(content_block_) return content_block_.text end) + :totable() + local thinking_contents = vim + .iter(ctx.content_blocks) + :filter(function(content_block_) return content_block_.stoppped and content_block_.type == "thinking" end) + :map( + function(content_block_) + return { content = content_block_.thinking, signature = content_block_.signature } + end + ) + :totable() + return { + name = content_block.name, + id = content_block.id, + input_json = content_block.input_json, + response_contents = response_contents, + thinking_contents = thinking_contents, + } + end) + :totable() opts.on_stop({ reason = "tool_use", usage = jsn.usage, - tool_use_list = ctx.tool_use_list, + tool_use_list = tool_use_list, }) end return diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 0952e93..c373f62 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -2495,7 +2495,23 @@ function Sidebar:create_input_container(opts) local function on_chunk(chunk) self.is_generating = true - original_response = original_response .. chunk + local remove_line = [[\033[1A\033[K]] + if chunk:sub(1, #remove_line) == remove_line then + chunk = chunk:sub(#remove_line + 1) + local lines = vim.split(transformed_response, "\n") + local idx = #lines + while idx > 0 and lines[idx] == "" do + idx = idx - 1 + end + if idx == 1 then + lines = {} + else + lines = vim.list_slice(lines, 1, idx - 1) + end + transformed_response = table.concat(lines, "\n") + else + original_response = original_response .. chunk + end local selected_files = self.file_selector:get_selected_files_contents() diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 3e8ac75..fc214b4 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -226,7 +226,8 @@ vim.g.avante_login = vim.g.avante_login ---@field name string ---@field id string ---@field input_json string ----@field response_content? string +---@field response_contents? string[] +---@field thinking_contents? { content: string, signature: string }[] --- ---@class AvanteLLMStartCallbackOptions ---@field usage? AvanteLLMUsage