diff --git a/lua/avante/api.lua b/lua/avante/api.lua index 4cfc054..355ab27 100644 --- a/lua/avante/api.lua +++ b/lua/avante/api.lua @@ -244,6 +244,7 @@ function M.select_history() Path.history.save_latest_filename(buf, filename) local sidebar = require("avante").get() sidebar:update_content_with_history() + sidebar:create_todos_container() vim.schedule(function() sidebar:focus_input() end) end) end) diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 0412505..57d42a1 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -272,9 +272,10 @@ M._defaults = { endpoint = "https://api.anthropic.com", model = "claude-sonnet-4-20250514", timeout = 30000, -- Timeout in milliseconds + context_window = 200000, extra_request_body = { temperature = 0.75, - max_tokens = 20480, + max_tokens = 64000, }, }, ---@type AvanteSupportedProvider diff --git a/lua/avante/history_message.lua b/lua/avante/history_message.lua index 4d1e6d0..afb0cb4 100644 --- a/lua/avante/history_message.lua +++ b/lua/avante/history_message.lua @@ -5,7 +5,7 @@ local M = {} M.__index = M ---@param message AvanteLLMMessage ----@param opts? {is_user_submission?: boolean, visible?: boolean, displayed_content?: string, state?: avante.HistoryMessageState, uuid?: string, selected_filepaths?: string[], selected_code?: AvanteSelectedCode, just_for_display?: boolean, is_dummy?: boolean, session_id?: string} +---@param opts? {is_user_submission?: boolean, visible?: boolean, displayed_content?: string, state?: avante.HistoryMessageState, uuid?: string, selected_filepaths?: string[], selected_code?: AvanteSelectedCode, just_for_display?: boolean, is_dummy?: boolean, turn_id?: string} ---@return avante.HistoryMessage function M:new(message, opts) opts = opts or {} @@ -23,7 +23,7 @@ function M:new(message, opts) if opts.selected_code ~= nil then obj.selected_code = opts.selected_code end if opts.just_for_display ~= nil then obj.just_for_display = opts.just_for_display end if opts.is_dummy ~= nil then obj.is_dummy = opts.is_dummy end - obj.session_id = opts.session_id + obj.turn_id = opts.turn_id return obj end diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index cf14b56..5b70da4 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -416,13 +416,24 @@ function M.curl(opts) local prompt_opts = opts.prompt_opts local handler_opts = opts.handler_opts + local orig_on_stop = handler_opts.on_stop + local stopped = false + ---@param stop_opts AvanteLLMStopCallbackOptions + handler_opts.on_stop = function(stop_opts) + if stop_opts and not stop_opts.streaming_tool_use then + if stopped then return end + stopped = true + end + if orig_on_stop then return orig_on_stop(stop_opts) end + end + ---@type AvanteCurlOutput local spec = provider:parse_curl_args(prompt_opts) ---@type string local current_event_state = nil - local resp_ctx = {} - resp_ctx.session_id = Utils.uuid() + local turn_ctx = {} + turn_ctx.turn_id = Utils.uuid() local response_body = "" ---@param line string @@ -435,7 +446,7 @@ function M.curl(opts) local data_match = line:match("^data:%s*(.+)$") if data_match then response_body = "" - provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts) + provider:parse_response(turn_ctx, data_match, current_event_state, handler_opts) else response_body = response_body .. line local ok, jsn = pcall(vim.json.decode, response_body) @@ -443,7 +454,7 @@ function M.curl(opts) if jsn.error then handler_opts.on_stop({ reason = "error", error = jsn.error }) else - provider:parse_response(resp_ctx, response_body, current_event_state, handler_opts) + provider:parse_response(turn_ctx, response_body, current_event_state, handler_opts) end response_body = "" end @@ -509,7 +520,7 @@ function M.curl(opts) end vim.schedule(function() if provider.parse_stream_data ~= nil then - provider:parse_stream_data(resp_ctx, data, handler_opts) + provider:parse_stream_data(turn_ctx, data, handler_opts) else parse_stream_data(data) end @@ -843,6 +854,7 @@ function M._stream(opts) end end if stop_opts.reason == "tool_use" then + opts.session_ctx.user_reminder_count = 0 return handle_next_tool_use(uncalled_tool_uses, 1, {}, stop_opts.streaming_tool_use) end if stop_opts.reason == "rate_limit" then diff --git a/lua/avante/llm_tools/replace_in_file.lua b/lua/avante/llm_tools/replace_in_file.lua index 546c47d..996ec8c 100644 --- a/lua/avante/llm_tools/replace_in_file.lua +++ b/lua/avante/llm_tools/replace_in_file.lua @@ -133,6 +133,8 @@ function M.func(opts, on_log, on_complete, session_ctx) local is_streaming = opts.streaming or false + if is_streaming then return end + session_ctx.prev_streaming_diff_timestamp_map = session_ctx.prev_streaming_diff_timestamp_map or {} local current_timestamp = os.time() if is_streaming then diff --git a/lua/avante/llm_tools/write_to_file.lua b/lua/avante/llm_tools/write_to_file.lua index af0a288..fbaf3dc 100644 --- a/lua/avante/llm_tools/write_to_file.lua +++ b/lua/avante/llm_tools/write_to_file.lua @@ -58,7 +58,10 @@ M.returns = { --- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view. ---@type AvanteLLMToolFunc<{ path: string, content: string, the_content?: string, streaming?: boolean, tool_use_id?: string }> function M.func(opts, on_log, on_complete, session_ctx) - if opts.the_content ~= nil then opts.content = opts.the_content end + if opts.the_content ~= nil then + opts.content = opts.the_content + opts.the_content = nil + end if not on_complete then return false, "on_complete not provided" end local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index df04054..1b3f69b 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -166,7 +166,7 @@ function M:parse_response(ctx, data_stream, event_state, opts) content = content_block.text, }, { state = "generating", - session_id = ctx.session_id, + turn_id = ctx.turn_id, }) content_block.uuid = msg.uuid if opts.on_messages_add then opts.on_messages_add({ msg }) end @@ -185,7 +185,7 @@ function M:parse_response(ctx, data_stream, event_state, opts) }, }, { state = "generating", - session_id = ctx.session_id, + turn_id = ctx.turn_id, }) content_block.uuid = msg.uuid opts.on_messages_add({ msg }) @@ -205,11 +205,11 @@ function M:parse_response(ctx, data_stream, event_state, opts) }, }, { state = "generating", - session_id = ctx.session_id, + turn_id = ctx.turn_id, }) content_block.uuid = msg.uuid opts.on_messages_add({ msg }) - opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) + -- opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end elseif event_state == "content_block_delta" then local ok, jsn = pcall(vim.json.decode, data_stream) @@ -234,7 +234,7 @@ function M:parse_response(ctx, data_stream, event_state, opts) }, { state = "generating", uuid = content_block.uuid, - session_id = ctx.session_id, + turn_id = ctx.turn_id, }) if opts.on_messages_add then opts.on_messages_add({ msg }) end elseif jsn.delta.type == "text_delta" then @@ -246,7 +246,7 @@ function M:parse_response(ctx, data_stream, event_state, opts) }, { state = "generating", uuid = content_block.uuid, - session_id = ctx.session_id, + turn_id = ctx.turn_id, }) if opts.on_messages_add then opts.on_messages_add({ msg }) end elseif jsn.delta.type == "signature_delta" then @@ -265,7 +265,7 @@ function M:parse_response(ctx, data_stream, event_state, opts) }, { state = "generated", uuid = content_block.uuid, - session_id = ctx.session_id, + turn_id = ctx.turn_id, }) if opts.on_messages_add then opts.on_messages_add({ msg }) end end @@ -284,7 +284,7 @@ function M:parse_response(ctx, data_stream, event_state, opts) }, { state = "generated", uuid = content_block.uuid, - session_id = ctx.session_id, + turn_id = ctx.turn_id, }) if opts.on_messages_add then opts.on_messages_add({ msg }) end end @@ -308,7 +308,7 @@ function M:parse_response(ctx, data_stream, event_state, opts) }, { state = "generated", uuid = content_block.uuid, - session_id = ctx.session_id, + turn_id = ctx.turn_id, }) if opts.on_messages_add then opts.on_messages_add({ msg }) end end @@ -317,6 +317,8 @@ function M:parse_response(ctx, data_stream, event_state, opts) if not ok then return end if jsn.delta.stop_reason == "end_turn" then opts.on_stop({ reason = "complete", usage = jsn.usage }) + elseif jsn.delta.stop_reason == "max_tokens" then + opts.on_stop({ reason = "max_tokens", usage = jsn.usage }) elseif jsn.delta.stop_reason == "tool_use" then opts.on_stop({ reason = "tool_use", diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index 13fbcf5..9799d63 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -243,7 +243,7 @@ function M:parse_response(ctx, data_stream, _, opts) if not ctx.function_call_id then ctx.function_call_id = 0 end ctx.function_call_id = ctx.function_call_id + 1 local tool_use = { - id = ctx.session_id .. "-" .. tostring(ctx.function_call_id), + id = ctx.turn_id .. "-" .. tostring(ctx.function_call_id), name = part.functionCall.name, input_json = vim.json.encode(part.functionCall.args), } diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index b79dcac..d4b7c64 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -264,7 +264,6 @@ function M:add_text_message(ctx, text, state, opts) local cleaned_xml_content = table.concat(cleaned_xml_lines, "\n") local stream_parser = XMLParser.createStreamParser() stream_parser:addData(cleaned_xml_content) - local has_tool_use = false local xml = stream_parser:getAllElements() if xml then local new_content_list = {} @@ -318,7 +317,7 @@ function M:add_text_message(ctx, text, state, opts) }, { state = state, uuid = msg_uuid, - session_id = ctx.session_id, + turn_id = ctx.turn_id, }) msgs[#msgs + 1] = msg_ ctx.tool_use_list = ctx.tool_use_list or {} @@ -327,14 +326,13 @@ function M:add_text_message(ctx, text, state, opts) name = item._name, input_json = input, } - has_tool_use = true end if #new_content_list > 0 then msg.message.content = table.concat(new_content_list, "\n") end ::continue:: end end if opts.on_messages_add then opts.on_messages_add(msgs) end - if has_tool_use and state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end + -- if has_tool_use and state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end end function M:add_thinking_message(ctx, text, state, opts) @@ -352,7 +350,7 @@ function M:add_thinking_message(ctx, text, state, opts) }, { state = state, uuid = ctx.reasonging_content_uuid, - session_id = ctx.session_id, + turn_id = ctx.turn_id, }) ctx.reasonging_content_uuid = msg.uuid if opts.on_messages_add then opts.on_messages_add({ msg }) end @@ -373,38 +371,25 @@ function M:add_tool_use_message(ctx, tool_use, state, opts) }, { state = state, uuid = tool_use.uuid, - session_id = ctx.session_id, + turn_id = ctx.turn_id, }) tool_use.uuid = msg.uuid tool_use.state = state if opts.on_messages_add then opts.on_messages_add({ msg }) end - if state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end + -- if state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end end function M:parse_response(ctx, data_stream, _, opts) - local orig_on_stop = opts.on_stop - local stopped = false - ---@param stop_opts AvanteLLMStopCallbackOptions - opts.on_stop = function(stop_opts) - if stop_opts and not stop_opts.streaming_tool_use then - if stopped then return end - stopped = true - end - return orig_on_stop(stop_opts) - end - if data_stream:match('"%[DONE%]":') then + if data_stream:match('"%[DONE%]":') or data_stream == "[DONE]" then self:finish_pending_messages(ctx, opts) if ctx.tool_use_list and #ctx.tool_use_list > 0 then + ctx.tool_use_list = {} opts.on_stop({ reason = "tool_use" }) else opts.on_stop({ reason = "complete" }) end return end - if data_stream == "[DONE]" then - opts.on_stop({ reason = "complete" }) - return - end local jsn = vim.json.decode(data_stream) ---@cast jsn AvanteOpenAIChatResponse if not jsn.choices then return end @@ -453,7 +438,7 @@ function M:parse_response(ctx, data_stream, _, opts) else local tool_use = ctx.tool_use_list[tool_call.index + 1] tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments - self:add_tool_use_message(ctx, tool_use, "generating", opts) + -- self:add_tool_use_message(ctx, tool_use, "generating", opts) end end elseif delta.content then diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index b48d5aa..3a49446 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -2003,7 +2003,6 @@ end function Sidebar:add_chat_history(messages, options) options = options or {} messages = vim.islist(messages) and messages or { messages } - self:reload_chat_history() local is_first_user = true local history_messages = {} for _, message in ipairs(messages) do @@ -2191,44 +2190,6 @@ function Sidebar:get_history_messages_for_api(opts) history_messages0 = picked_messages end - local picked_messages = {} - local max_tool_use_count = 15 - local tool_use_count = 0 - for idx = #history_messages0, 1, -1 do - local msg = history_messages0[idx] - if tool_use_count > max_tool_use_count then - if Utils.is_tool_result_message(msg) then - local tool_use_message = Utils.get_tool_use_message(msg, history_messages0) - if tool_use_message then - local msg_content = {} - table.insert( - msg_content, - string.format( - "Tool use %s(%s)", - tool_use_message.message.content[1].name, - vim.json.encode(tool_use_message.message.content[1].input) - ) - ) - table.insert(msg_content, string.format("Result: %s", msg.message.content[1].content)) - table.insert( - picked_messages, - 1, - HistoryMessage:new({ role = "user", content = msg_content }, { is_dummy = true }) - ) - end - elseif Utils.is_tool_use_message(history_messages0[idx]) then - tool_use_count = tool_use_count + 1 - goto continue - else - table.insert(picked_messages, 1, msg) - end - else - if Utils.is_tool_use_message(history_messages0[idx]) then tool_use_count = tool_use_count + 1 end - table.insert(picked_messages, 1, msg) - end - ::continue:: - end - local tool_id_to_tool_name = {} local tool_id_to_path = {} local tool_id_to_start_line = {} @@ -2241,6 +2202,7 @@ function Sidebar:get_history_messages_for_api(opts) for idx, message in ipairs(history_messages0) do if Utils.is_tool_result_message(message) then local tool_use_message = Utils.get_tool_use_message(message, history_messages0) + local is_edit_func_call, _, _, path = Utils.is_edit_func_call_message(tool_use_message) local tool_result = message.message.content[1] @@ -2264,8 +2226,8 @@ function Sidebar:get_history_messages_for_api(opts) end for idx, message in ipairs(history_messages0) do - if Utils.is_tool_use_message(message) and failed_edit_tool_ids[message.message.content[1].id] then - goto continue + if Utils.is_tool_use_message(message) then + if failed_edit_tool_ids[message.message.content[1].id] then goto continue end end table.insert(history_messages, message) if Utils.is_tool_result_message(message) then @@ -2422,6 +2384,62 @@ function Sidebar:get_history_messages_for_api(opts) end end + local picked_messages = {} + local max_tool_use_count = 10 + local tool_use_count = 0 + for idx = #history_messages, 1, -1 do + local msg = history_messages[idx] + if tool_use_count > max_tool_use_count then + if Utils.is_tool_result_message(msg) then + local tool_use_message = Utils.get_tool_use_message(msg, history_messages) + if tool_use_message then + table.insert( + picked_messages, + 1, + HistoryMessage:new({ + role = "user", + content = { + { + type = "text", + text = string.format( + "Tool use [%s] is successful: %s", + tool_use_message.message.content[1].name, + tostring(not msg.message.content[1].is_error) + ), + }, + }, + }, { is_dummy = true }) + ) + local msg_content = {} + table.insert(msg_content, { + type = "text", + text = string.format( + "Tool use %s(%s)", + tool_use_message.message.content[1].name, + vim.json.encode(tool_use_message.message.content[1].input) + ), + }) + table.insert( + picked_messages, + 1, + HistoryMessage:new({ role = "assistant", content = msg_content }, { is_dummy = true }) + ) + end + elseif Utils.is_tool_use_message(msg) then + tool_use_count = tool_use_count + 1 + goto continue + else + table.insert(picked_messages, 1, msg) + end + else + if Utils.is_tool_use_message(msg) then tool_use_count = tool_use_count + 1 end + table.insert(picked_messages, 1, msg) + end + ::continue:: + end + + history_messages = picked_messages + local final_history_messages = {} for _, msg in ipairs(history_messages) do local tool_result_message @@ -2704,9 +2722,6 @@ function Sidebar:create_input_container() Path.history.save(self.code.bufnr, self.chat_history) end - local history_messages = Utils.get_history_messages(self.chat_history) - local is_first_request = #history_messages == 0 - if request and request ~= "" then self:add_history_messages({ HistoryMessage:new({ diff --git a/lua/avante/types.lua b/lua/avante/types.lua index d94308d..db90670 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -110,7 +110,7 @@ vim.g.avante_login = vim.g.avante_login ---@field is_dummy boolean | nil ---@field is_compacted boolean | nil ---@field is_deleted boolean | nil ----@field session_id string | nil +---@field turn_id string | nil --- ---@class AvanteLLMToolResult ---@field tool_name string @@ -278,7 +278,7 @@ vim.g.avante_login = vim.g.avante_login ---@field usage? AvanteLLMUsage --- ---@class AvanteLLMStopCallbackOptions ----@field reason "complete" | "tool_use" | "error" | "rate_limit" | "cancelled" +---@field reason "complete" | "tool_use" | "error" | "rate_limit" | "cancelled" | "max_tokens" ---@field error? string | table ---@field usage? AvanteLLMUsage ---@field retry_after? integer diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index d92b793..95987e1 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -993,10 +993,14 @@ function M.open_buffer(path, set_current_buf) local abs_path = M.join_paths(M.get_project_root(), path) - local bufnr = vim.fn.bufnr(abs_path, true) - vim.fn.bufload(bufnr) - - if set_current_buf then vim.api.nvim_set_current_buf(bufnr) end + local bufnr + if set_current_buf then + vim.cmd("noautocmd edit " .. abs_path) + bufnr = vim.api.nvim_get_current_buf() + else + bufnr = vim.fn.bufnr(abs_path, true) + pcall(vim.fn.bufload, bufnr) + end vim.cmd("filetype detect") @@ -1480,10 +1484,6 @@ function M.is_edit_func_call_tool_use(tool_use) local is_str_replace_editor_func_call = false local is_str_replace_based_edit_tool_func_call = false local path = nil - if tool_use.name == "write_to_file" then - is_replace_func_call = true - path = tool_use.input.path - end if tool_use.name == "replace_in_file" then is_replace_func_call = true path = tool_use.input.path @@ -1711,10 +1711,17 @@ end ---@param history_messages avante.HistoryMessage[] ---@return AvantePartialLLMToolUse[] function M.get_uncalled_tool_uses(history_messages) - local partial_tool_use_list = {} ---@type AvantePartialLLMToolUse[] + local last_turn_id = nil + if #history_messages > 0 then last_turn_id = history_messages[#history_messages].turn_id end + local uncalled_tool_use_list = {} ---@type AvantePartialLLMToolUse[] local tool_result_seen = {} for idx = #history_messages, 1, -1 do local message = history_messages[idx] + if last_turn_id then + if message.turn_id ~= last_turn_id then break end + else + if not M.is_tool_use_message(message) and not M.is_tool_result_message(message) then break end + end local content = message.message.content if type(content) ~= "table" or #content == 0 then goto continue end local is_break = false @@ -1727,7 +1734,7 @@ function M.get_uncalled_tool_uses(history_messages) input = item.input, state = message.state, } - table.insert(partial_tool_use_list, 1, partial_tool_use) + table.insert(uncalled_tool_use_list, 1, partial_tool_use) else is_break = true break @@ -1738,7 +1745,7 @@ function M.get_uncalled_tool_uses(history_messages) if is_break then break end ::continue:: end - return partial_tool_use_list + return uncalled_tool_use_list end function M.call_once(func)