From f10b8383e323f08551942fa66dbabae6c34f0350 Mon Sep 17 00:00:00 2001 From: yetone Date: Wed, 30 Apr 2025 03:07:18 +0800 Subject: [PATCH] refactor: history messages (#1934) --- README.md | 25 +- lua/avante/api.lua | 20 +- lua/avante/config.lua | 4 +- lua/avante/highlights.lua | 6 + lua/avante/history_message.lua | 28 + lua/avante/history_selector.lua | 5 +- lua/avante/llm.lua | 381 ++++-- lua/avante/llm_tools/create.lua | 2 +- lua/avante/llm_tools/dispatch_agent.lua | 19 +- lua/avante/llm_tools/init.lua | 6 +- lua/avante/llm_tools/insert.lua | 2 +- lua/avante/llm_tools/replace_in_file.lua | 481 +++++++ lua/avante/llm_tools/str_replace.lua | 21 +- lua/avante/llm_tools/undo_edit.lua | 2 +- lua/avante/llm_tools/view.lua | 11 - lua/avante/path.lua | 20 +- lua/avante/providers/bedrock.lua | 1 + lua/avante/providers/claude.lua | 247 ++-- lua/avante/providers/copilot.lua | 6 +- lua/avante/providers/gemini.lua | 101 +- lua/avante/providers/init.lua | 17 +- lua/avante/providers/ollama.lua | 8 +- lua/avante/providers/openai.lua | 235 ++-- lua/avante/sidebar.lua | 1168 +++++------------ lua/avante/templates/agentic.avanterules | 10 + .../claude-text-editor-tool.avanterules | 6 - .../templates/cursor-applying.avanterules | 4 - .../templates/cursor-planning.avanterules | 46 - ...lanning.avanterules => legacy.avanterules} | 0 lua/avante/types.lua | 57 +- .../ui/selector/providers/telescope.lua | 2 +- lua/avante/utils/history.lua | 87 -- lua/avante/utils/init.lua | 98 +- lua/avante/utils/streaming_json_parser.lua | 18 +- plugin/avante.lua | 8 +- tests/utils/streaming_json_parser_spec.lua | 9 +- 36 files changed, 1699 insertions(+), 1462 deletions(-) create mode 100644 lua/avante/history_message.lua create mode 100644 lua/avante/llm_tools/replace_in_file.lua create mode 100644 lua/avante/templates/agentic.avanterules delete mode 100644 lua/avante/templates/claude-text-editor-tool.avanterules delete mode 100644 lua/avante/templates/cursor-applying.avanterules delete mode 100644 lua/avante/templates/cursor-planning.avanterules rename lua/avante/templates/{planning.avanterules => legacy.avanterules} (100%) delete mode 100644 lua/avante/utils/history.lua diff --git a/README.md b/README.md index 0207305..4faa917 100644 --- a/README.md +++ b/README.md @@ -305,6 +305,8 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_ { ---@alias Provider "claude" | "openai" | "azure" | "gemini" | "cohere" | "copilot" | string provider = "claude", -- The provider used in Aider mode or in the planning phase of Cursor Planning Mode + ---@alias Mode "agentic" | "legacy" + mode = "agentic", -- The default mode for interaction. "agentic" uses tools to automatically generate code, "legacy" uses the old planning method to generate code. -- WARNING: Since auto-suggestions are a high-frequency operation and therefore expensive, -- currently designating it as `copilot` provider is dangerous because: https://github.com/yetone/avante.nvim/issues/1048 -- Of course, you can reduce the request frequency by increasing `suggestion.debounce`. @@ -340,8 +342,6 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_ support_paste_from_clipboard = false, minimize_diff = true, -- Whether to remove unchanged lines when applying a code block enable_token_counting = true, -- Whether to enable token counting. Default to true. - enable_cursor_planning_mode = false, -- Whether to enable Cursor Planning Mode. Default to false. - enable_claude_text_editor_tool_mode = false, -- Whether to enable Claude Text Editor Tool Mode. }, mappings = { --- @class AvanteConflictMappings @@ -734,12 +734,6 @@ Avante provides a set of default providers, but users can also create their own For more information, see [Custom Providers](https://github.com/yetone/avante.nvim/wiki/Custom-providers) -## Cursor planning mode - -Because avante.nvim has always used Aider’s method for planning applying, but its prompts are very picky with models and require ones like claude-3.5-sonnet or gpt-4o to work properly. - -Therefore, I have adopted Cursor’s method to implement planning applying. For details on the implementation, please refer to [cursor-planning-mode.md](./cursor-planning-mode.md) - ## RAG Service Avante provides a RAG service, which is a tool for obtaining the required context for the AI to generate the codes. By default, it is not enabled. You can enable it this way: @@ -890,21 +884,6 @@ Avante allows you to define custom tools that can be used by the AI during code Now you can integrate MCP functionality for Avante through `mcphub.nvim`. For detailed documentation, please refer to [mcphub.nvim](https://github.com/ravitemer/mcphub.nvim#avante-integration) -## Claude Text Editor Tool Mode - -Avante leverages [Claude Text Editor Tool](https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool) to provide a more elegant code editing experience. You can now enable this feature by setting `enable_claude_text_editor_tool_mode` to `true` in the `behaviour` configuration: - -```lua -{ - behaviour = { - enable_claude_text_editor_tool_mode = true, - }, -} -``` - -> [!NOTE] -> To enable **Claude Text Editor Tool Mode**, you must use the `claude-3-5-sonnet-*` or `claude-3-7-sonnet-*` model with the `claude` provider! This feature is not supported by any other models! - ## Custom prompts By default, `avante.nvim` provides three different modes to interact with: `planning`, `editing`, and `suggesting`, followed with three different prompts per mode. diff --git a/lua/avante/api.lua b/lua/avante/api.lua index 70d5089..6711668 100644 --- a/lua/avante/api.lua +++ b/lua/avante/api.lua @@ -226,16 +226,16 @@ end function M.select_model() require("avante.model_selector").open() end function M.select_history() - require("avante.history_selector").open(vim.api.nvim_get_current_buf(), function(filename) - local Path = require("avante.path") - Path.history.save_latest_filename(vim.api.nvim_get_current_buf(), filename) - local sidebar = require("avante").get() - if not sidebar then - require("avante.api").ask() - sidebar = require("avante").get() - end - sidebar:update_content_with_history() - if not sidebar:is_open() then sidebar:open({}) end + local buf = vim.api.nvim_get_current_buf() + require("avante.history_selector").open(buf, function(filename) + vim.api.nvim_buf_call(buf, function() + if not require("avante").is_sidebar_open() then require("avante").open_sidebar({}) end + local Path = require("avante.path") + Path.history.save_latest_filename(buf, filename) + local sidebar = require("avante").get() + sidebar:update_content_with_history() + vim.schedule(function() sidebar:focus_input() end) + end) end) end diff --git a/lua/avante/config.lua b/lua/avante/config.lua index a7c7be7..33d4b1c 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -19,6 +19,8 @@ local M = {} ---@class avante.Config M._defaults = { debug = false, + ---@alias avante.Mode "agentic" | "legacy" + mode = "agentic", ---@alias avante.ProviderName "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | "bedrock" | "ollama" | string provider = "claude", -- WARNING: Since auto-suggestions are a high-frequency operation and therefore expensive, @@ -380,8 +382,6 @@ M._defaults = { support_paste_from_clipboard = false, minimize_diff = true, enable_token_counting = true, - enable_cursor_planning_mode = false, - enable_claude_text_editor_tool_mode = false, use_cwd_as_project_root = false, auto_focus_on_diff_view = false, }, diff --git a/lua/avante/highlights.lua b/lua/avante/highlights.lua index f0ccdbd..99e382f 100644 --- a/lua/avante/highlights.lua +++ b/lua/avante/highlights.lua @@ -40,6 +40,12 @@ local Highlights = { AVANTE_SIDEBAR_NORMAL = { name = "AvanteSidebarNormal", link = "NormalFloat" }, AVANTE_COMMENT_FG = { name = "AvanteCommentFg", fg_link = "Comment" }, AVANTE_REVERSED_NORMAL = { name = "AvanteReversedNormal", fg_link_bg = "Normal", bg_link_fg = "Normal" }, + AVANTE_STATE_SPINNER_GENERATING = { name = "AvanteStateSpinnerGenerating", fg = "#1e222a", bg = "#ab9df2" }, + AVANTE_STATE_SPINNER_TOOL_CALLING = { name = "AvanteStateSpinnerToolCalling", fg = "#1e222a", bg = "#56b6c2" }, + AVANTE_STATE_SPINNER_FAILED = { name = "AvanteStateSpinnerFailed", fg = "#1e222a", bg = "#e06c75" }, + AVANTE_STATE_SPINNER_SUCCEEDED = { name = "AvanteStateSpinnerSucceeded", fg = "#1e222a", bg = "#98c379" }, + AVANTE_STATE_SPINNER_SEARCHING = { name = "AvanteStateSpinnerSearching", fg = "#1e222a", bg = "#c678dd" }, + AVANTE_STATE_SPINNER_THINKING = { name = "AvanteStateSpinnerThinking", fg = "#1e222a", bg = "#c678dd" }, } Highlights.conflict = { diff --git a/lua/avante/history_message.lua b/lua/avante/history_message.lua new file mode 100644 index 0000000..9c630c2 --- /dev/null +++ b/lua/avante/history_message.lua @@ -0,0 +1,28 @@ +local Utils = require("avante.utils") + +---@class avante.HistoryMessage +local M = {} +M.__index = M + +---@param message AvanteLLMMessage +---@param opts? {is_user_submission?: boolean, visible?: boolean, displayed_content?: string, state?: avante.HistoryMessageState, uuid?: string, selected_filepaths?: string[], selected_code?: AvanteSelectedCode, just_for_display?: boolean} +---@return avante.HistoryMessage +function M:new(message, opts) + opts = opts or {} + local obj = setmetatable({}, M) + obj.message = message + obj.uuid = opts.uuid or Utils.uuid() + obj.state = opts.state or "generated" + obj.timestamp = Utils.get_timestamp() + obj.is_user_submission = false + obj.visible = true + if opts.is_user_submission ~= nil then obj.is_user_submission = opts.is_user_submission end + if opts.visible ~= nil then obj.visible = opts.visible end + if opts.displayed_content ~= nil then obj.displayed_content = opts.displayed_content end + if opts.selected_filepaths ~= nil then obj.selected_filepaths = opts.selected_filepaths end + if opts.selected_code ~= nil then obj.selected_code = opts.selected_code end + if opts.just_for_display ~= nil then obj.just_for_display = opts.just_for_display end + return obj +end + +return M diff --git a/lua/avante/history_selector.lua b/lua/avante/history_selector.lua index 9272e6a..b3d693a 100644 --- a/lua/avante/history_selector.lua +++ b/lua/avante/history_selector.lua @@ -9,8 +9,9 @@ local M = {} ---@param history avante.ChatHistory ---@return table? local function to_selector_item(history) - local timestamp = #history.entries > 0 and history.entries[#history.entries].timestamp or history.timestamp - local name = history.title .. " - " .. timestamp .. " (" .. #history.entries .. ")" + local messages = Utils.get_history_messages(history) + local timestamp = #messages > 0 and messages[#messages].timestamp or history.timestamp + local name = history.title .. " - " .. timestamp .. " (" .. #messages .. ")" name = name:gsub("\n", "\\n") return { name = name, diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 5eb744d..c8be271 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -10,6 +10,7 @@ local Path = require("avante.path") local Providers = require("avante.providers") local LLMToolHelpers = require("avante.llm_tools.helpers") local LLMTools = require("avante.llm_tools") +local HistoryMessage = require("avante.history_message") ---@class avante.LLM local M = {} @@ -26,7 +27,7 @@ function M.summarize_chat_thread_title(content, cb) local system_prompt = [[Summarize the content as a title for the chat thread. The title should be a concise and informative summary of the conversation, capturing the main points and key takeaways. It should be no longer than 100 words and should be written in a clear and engaging style. The title should be suitable for use as the title of a chat thread on a messaging platform or other communication medium.]] local response_content = "" - local provider = Providers[Config.memory_summary_provider or Config.provider] + local provider = Providers.get_memory_summary_provider() M.curl({ provider = provider, prompt_opts = { @@ -58,73 +59,49 @@ function M.summarize_chat_thread_title(content, cb) }) end ----@param messages AvanteLLMMessage[] ----@return AvanteLLMMessage[] -local function filter_out_tool_use_messages(messages) - local filtered_messages = {} - for _, message in ipairs(messages) do - local content = message.content - if type(content) == "table" then - local new_content = {} - for _, item in ipairs(content) do - if item.type == "tool_use" or item.type == "tool_result" then goto continue end - table.insert(new_content, item) - ::continue:: - end - content = new_content - end - if type(content) == "table" then - if #content > 0 then table.insert(filtered_messages, { role = message.role, content = content }) end - else - table.insert(filtered_messages, { role = message.role, content = content }) - end - end - return filtered_messages -end - ---@param bufnr integer ---@param history avante.ChatHistory ----@param entries? avante.ChatHistoryEntry[] +---@param history_messages avante.HistoryMessage[] ---@param cb fun(memory: avante.ChatMemory | nil): nil -function M.summarize_memory(bufnr, history, entries, cb) - local system_prompt = [[You are a helpful AI assistant tasked with summarizing conversations.]] - if not entries then entries = Utils.history.filter_active_entries(history.entries) end - if #entries == 0 then - cb(nil) - return - end - if history.memory then - entries = vim - .iter(entries) - :filter(function(entry) return entry.timestamp > history.memory.last_summarized_timestamp end) - :totable() - end - if #entries == 0 then - cb(history.memory) - return - end - local history_messages = Utils.history.entries_to_llm_messages(entries) - history_messages = filter_out_tool_use_messages(history_messages) - history_messages = vim.list_slice(history_messages, 1, 4) +function M.summarize_memory(bufnr, history, history_messages, cb) + local system_prompt = + [[You are an expert coding assistant. Your goal is to generate a concise, structured summary of the conversation below that captures all essential information needed to continue development after context replacement. Include tasks performed, code areas modified or reviewed, key decisions or assumptions, test results or errors, and outstanding tasks or next steps.]] if #history_messages == 0 then cb(history.memory) return end - Utils.debug("summarize memory", #history_messages, history_messages[#history_messages].content) - local user_prompt = - [[Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.]] + local latest_timestamp = history_messages[#history_messages].timestamp + local latest_message_uuid = history_messages[#history_messages].uuid + local conversation_items = vim + .iter(history_messages) + :filter(function(msg) + if msg.just_for_display then return false end + if msg.message.role ~= "assistant" and msg.message.role ~= "user" then return false end + local content = msg.message.content + if type(content) == "table" and content[1].type == "tool_result" then return false end + if type(content) == "table" and content[1].type == "tool_use" then return false end + return true + end) + :map(function(msg) return msg.message.role .. ": " .. Utils.message_to_text(msg) end) + :totable() + local conversation_text = table.concat(conversation_items, "\n") + local user_prompt = "Here is the conversation so far:\n" + .. conversation_text + .. "\n\nPlease summarize this conversation, covering:\n1. Tasks performed and outcomes\n2. Code files, modules, or functions modified or examined\n3. Important decisions or assumptions made\n4. Errors encountered and test or build results\n5. Remaining tasks, open questions, or next steps\nProvide the summary in a clear, concise format." if history.memory then user_prompt = user_prompt .. "\n\nThe previous summary is:\n\n" .. history.memory.content end - table.insert(history_messages, { - role = "user", - content = user_prompt, - }) + local messages = { + { + role = "user", + content = user_prompt, + }, + } local response_content = "" - local provider = Providers[Config.memory_summary_provider or Config.provider] + local provider = Providers.get_memory_summary_provider() M.curl({ provider = provider, prompt_opts = { system_prompt = system_prompt, - messages = history_messages, + messages = messages, }, handler_opts = { on_start = function(_) end, @@ -141,11 +118,14 @@ function M.summarize_memory(bufnr, history, entries, cb) response_content = Utils.trim_think_content(response_content) local memory = { content = response_content, - last_summarized_timestamp = entries[#entries].timestamp, + last_summarized_timestamp = latest_timestamp, + last_message_uuid = latest_message_uuid, } history.memory = memory Path.history.save(bufnr, history) cb(memory) + else + cb(history.memory) end end, }, @@ -156,7 +136,7 @@ end ---@return AvantePromptOptions function M.generate_prompts(opts) local provider = opts.provider or Providers[Config.provider] - local mode = opts.mode or "planning" + local mode = opts.mode or Config.mode ---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor local _, request_body = Providers.parse_config(provider) local max_tokens = request_body.max_tokens or 4096 @@ -166,22 +146,58 @@ function M.generate_prompts(opts) if opts.prompt_opts and opts.prompt_opts.image_paths then image_paths = vim.list_extend(image_paths, opts.prompt_opts.image_paths) end - local instructions = opts.instructions - if instructions and instructions:match("image: ") then - local lines = vim.split(opts.instructions, "\n") - for i, line in ipairs(lines) do - if line:match("^image: ") then - local image_path = line:gsub("^image: ", "") - table.insert(image_paths, image_path) - table.remove(lines, i) - end - end - instructions = table.concat(lines, "\n") - end local project_root = Utils.root.get() Path.prompts.initialize(Path.prompts.get_templates_dir(project_root)) + local tool_id_to_tool_name = {} + local tool_id_to_path = {} + local viewed_files = {} + if opts.history_messages then + for _, message in ipairs(opts.history_messages) do + local content = message.message.content + if type(content) ~= "table" then goto continue end + for _, item in ipairs(content) do + if type(item) ~= "table" then goto continue1 end + if item.type ~= "tool_use" then goto continue1 end + local tool_name = item.name + if tool_name ~= "view" then goto continue1 end + local path = item.input.path + tool_id_to_tool_name[item.id] = tool_name + if path then + local uniform_path = Utils.uniform_path(path) + tool_id_to_path[item.id] = uniform_path + viewed_files[uniform_path] = item.id + end + ::continue1:: + end + ::continue:: + end + for _, message in ipairs(opts.history_messages) do + local content = message.message.content + if type(content) == "table" then + for _, item in ipairs(content) do + if type(item) ~= "table" then goto continue end + if item.type ~= "tool_result" then goto continue end + local tool_name = tool_id_to_tool_name[item.tool_use_id] + if tool_name ~= "view" then goto continue end + local path = tool_id_to_path[item.tool_use_id] + local latest_tool_id = viewed_files[path] + if not latest_tool_id then goto continue end + if latest_tool_id ~= item.tool_use_id then + item.content = string.format("The file %s has been updated. Please use the latest view tool result!", path) + else + local lines, error = Utils.read_file_from_buf_or_disk(path) + if error ~= nil then Utils.error("error reading file: " .. error) end + lines = lines or {} + item.content = table.concat(lines, "\n") + end + ::continue:: + end + end + end + end + local system_info = Utils.get_system_info() local selected_files = opts.selected_files or {} @@ -200,6 +216,8 @@ function M.generate_prompts(opts) end end + selected_files = vim.iter(selected_files):filter(function(file) return viewed_files[file.path] == nil end):totable() + local template_opts = { ask = opts.ask, -- TODO: add mode without ask instruction code_lang = opts.code_lang, @@ -229,36 +247,42 @@ function M.generate_prompts(opts) end ---@type AvanteLLMMessage[] - local messages = {} + local context_messages = {} if opts.prompt_opts and opts.prompt_opts.messages then - messages = vim.list_extend(messages, opts.prompt_opts.messages) + context_messages = vim.list_extend(context_messages, opts.prompt_opts.messages) end if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then local project_context = Path.prompts.render_file("_project.avanterules", template_opts) - if project_context ~= "" then table.insert(messages, { role = "user", content = project_context }) end + if project_context ~= "" then + table.insert(context_messages, { role = "user", content = project_context, visible = false, is_context = true }) + end end if opts.diagnostics ~= nil and opts.diagnostics ~= "" and opts.diagnostics ~= "null" then local diagnostics = Path.prompts.render_file("_diagnostics.avanterules", template_opts) - if diagnostics ~= "" then table.insert(messages, { role = "user", content = diagnostics }) end + if diagnostics ~= "" then + table.insert(context_messages, { role = "user", content = diagnostics, visible = false, is_context = true }) + end end if #selected_files > 0 or opts.selected_code ~= nil then local code_context = Path.prompts.render_file("_context.avanterules", template_opts) - if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end + if code_context ~= "" then + table.insert(context_messages, { role = "user", content = code_context, visible = false, is_context = true }) + end end if opts.memory ~= nil and opts.memory ~= "" and opts.memory ~= "null" then local memory = Path.prompts.render_file("_memory.avanterules", template_opts) - if memory ~= "" then table.insert(messages, { role = "user", content = memory }) end + if memory ~= "" then + table.insert(context_messages, { role = "user", content = memory, visible = false, is_context = true }) + end end - if instructions then table.insert(messages, { role = "user", content = instructions }) end - local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt) - for _, message in ipairs(messages) do + for _, message in ipairs(context_messages) do remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content) end @@ -267,47 +291,49 @@ function M.generate_prompts(opts) dropped_history_messages = vim.list_extend(dropped_history_messages, opts.prompt_opts.dropped_history_messages) end + local final_history_messages = {} if opts.history_messages then if Config.history.max_tokens > 0 then remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens) end -- Traverse the history in reverse, keeping only the latest history until the remaining tokens are exhausted and the first message role is "user" local history_messages = {} for i = #opts.history_messages, 1, -1 do local message = opts.history_messages[i] - if Config.history.carried_entry_count ~= nil then - if #history_messages > Config.history.carried_entry_count then break end - table.insert(history_messages, message) + local tokens = Utils.tokens.calculate_tokens(message.message.content) + remaining_tokens = remaining_tokens - tokens + if remaining_tokens > 0 then + table.insert(history_messages, 1, message) else - local tokens = Utils.tokens.calculate_tokens(message.content) - remaining_tokens = remaining_tokens - tokens - if remaining_tokens > 0 then - table.insert(history_messages, message) - else - break - end + break end end + if #history_messages == 0 then + history_messages = vim.list_slice(opts.history_messages, #opts.history_messages - 1, #opts.history_messages) + end + dropped_history_messages = vim.list_slice(opts.history_messages, 1, #opts.history_messages - #history_messages) -- prepend the history messages to the messages table - vim.iter(history_messages):each(function(msg) table.insert(messages, 1, msg) end) - if #messages > 0 and messages[1].role == "assistant" then table.remove(messages, 1) end + vim.iter(history_messages):each(function(msg) table.insert(final_history_messages, msg) end) end - if opts.mode == "cursor-applying" then - local user_prompt = [[ -Merge all changes from the snippet into the below. -- Preserve the code's structure, order, comments, and indentation exactly. -- Output only the updated code, enclosed within and tags. -- Do not include any additional text, explanations, placeholders, ellipses, or code fences. + -- Utils.debug("opts.history_messages", opts.history_messages) + -- Utils.debug("final_history_messages", final_history_messages) -]] - user_prompt = user_prompt .. string.format("\n%s\n\n", opts.original_code) - for _, snippet in ipairs(opts.update_snippets) do - user_prompt = user_prompt .. string.format("\n%s\n\n", snippet) - end - user_prompt = user_prompt .. "Provide the complete updated code." - table.insert(messages, { role = "user", content = user_prompt }) + ---@type AvanteLLMMessage[] + local messages = vim.deepcopy(context_messages) + for _, msg in ipairs(final_history_messages) do + local message = msg.message + table.insert(messages, message) + end + + messages = vim + .iter(messages) + :filter(function(msg) return type(msg.content) ~= "string" or msg.content ~= "" end) + :totable() + + if opts.instructions ~= nil and opts.instructions ~= "" then + messages = vim.list_extend(messages, { { role = "user", content = opts.instructions } }) end opts.session_ctx = opts.session_ctx or {} @@ -318,19 +344,12 @@ Merge all changes from the snippet into the below. if opts.tools then tools = vim.list_extend(tools, opts.tools) end if opts.prompt_opts and opts.prompt_opts.tools then tools = vim.list_extend(tools, opts.prompt_opts.tools) end - local tool_histories = {} - if opts.tool_histories then tool_histories = vim.list_extend(tool_histories, opts.tool_histories) end - if opts.prompt_opts and opts.prompt_opts.tool_histories then - tool_histories = vim.list_extend(tool_histories, opts.prompt_opts.tool_histories) - end - ---@type AvantePromptOptions return { system_prompt = system_prompt, messages = messages, image_paths = image_paths, tools = tools, - tool_histories = tool_histories, dropped_history_messages = dropped_history_messages, } end @@ -372,7 +391,9 @@ function M.curl(opts) ---@type string local current_event_state = nil local resp_ctx = {} + resp_ctx.session_id = Utils.uuid() + local response_body = "" ---@param line string local function parse_stream_data(line) local event = line:match("^event:%s*(.+)$") @@ -381,7 +402,16 @@ function M.curl(opts) return end local data_match = line:match("^data:%s*(.+)$") - if data_match then provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts) end + if data_match then + provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts) + else + response_body = response_body .. line + local ok, jsn = pcall(vim.json.decode, response_body) + if ok then + response_body = "" + if jsn.error then handler_opts.on_stop({ reason = "error", error = jsn.error }) end + end + end end local function parse_response_without_stream(data) @@ -394,10 +424,13 @@ function M.curl(opts) local temp_file = fn.tempname() local curl_body_file = temp_file .. "-request-body.json" + local resp_body_file = temp_file .. "-response-body.json" local json_content = vim.json.encode(spec.body) fn.writefile(vim.split(json_content, "\n"), curl_body_file) - Utils.debug("curl body file:", curl_body_file) + Utils.debug("curl request body file:", curl_body_file) + + Utils.debug("curl response body file:", resp_body_file) local headers_file = temp_file .. "-headers.txt" @@ -407,6 +440,7 @@ function M.curl(opts) if Config.debug then return end vim.schedule(function() fn.delete(curl_body_file) + pcall(fn.delete, resp_body_file) fn.delete(headers_file) end) end @@ -431,6 +465,15 @@ function M.curl(opts) return end if not data then return end + if Config.debug then + if type(data) == "string" then + local file = io.open(resp_body_file, "a") + if file then + file:write(data .. "\n") + file:close() + end + end + end vim.schedule(function() if Config[Config.provider] == nil and provider.parse_stream_data ~= nil then if provider.parse_response ~= nil then @@ -495,6 +538,17 @@ function M.curl(opts) Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" }) end if result.status == 429 then + Utils.debug("result", result) + if result.body then + local ok, jsn = pcall(vim.json.decode, result.body) + if ok then + if jsn.error and jsn.error.message then + handler_opts.on_stop({ reason = "error", error = jsn.error.message }) + return + end + end + end + Utils.debug("result", result) local retry_after = 10 if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after }) @@ -585,17 +639,34 @@ function M._stream(opts) ---@type AvanteHandlerOptions local handler_opts = { - on_partial_tool_use = opts.on_partial_tool_use, + on_messages_add = opts.on_messages_add, + on_state_change = opts.on_state_change, on_start = opts.on_start, on_chunk = opts.on_chunk, on_stop = function(stop_opts) ---@param tool_use_list AvanteLLMToolUse[] ---@param tool_use_index integer - ---@param tool_histories AvanteLLMToolHistory[] - local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories) + ---@param tool_results AvanteLLMToolResult[] + local function handle_next_tool_use(tool_use_list, tool_use_index, tool_results) if tool_use_index > #tool_use_list then + ---@type avante.HistoryMessage[] + local messages = {} + for _, tool_result in ipairs(tool_results) do + messages[#messages + 1] = HistoryMessage:new({ + role = "user", + content = { + { + type = "tool_result", + tool_use_id = tool_result.tool_use_id, + content = tool_result.content, + is_error = tool_result.is_error, + }, + }, + }) + end + opts.on_messages_add(messages) local new_opts = vim.tbl_deep_extend("force", opts, { - tool_histories = tool_histories, + history_messages = opts.get_history_messages(), }) if provider.get_rate_limit_sleep_time then local sleep_time = provider:get_rate_limit_sleep_time(resp_headers) @@ -616,7 +687,7 @@ function M._stream(opts) if error == LLMToolHelpers.CANCEL_TOKEN then Utils.debug("Tool execution was cancelled by user") opts.on_chunk("\n*[Request cancelled by user during tool execution.]*\n") - return opts.on_stop({ reason = "cancelled", tool_histories = tool_histories }) + return opts.on_stop({ reason = "cancelled" }) end local tool_result = { @@ -624,8 +695,8 @@ function M._stream(opts) content = error ~= nil and error or result, is_error = error ~= nil, } - table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use }) - return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories) + table.insert(tool_results, tool_result) + return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_results) end -- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil local result, error = LLMTools.process_tool_use( @@ -638,20 +709,53 @@ function M._stream(opts) if result ~= nil or error ~= nil then return handle_tool_result(result, error) end end if stop_opts.reason == "cancelled" then - opts.on_chunk("\n*[Request cancelled by user.]*\n") - return opts.on_stop({ reason = "cancelled", tool_histories = opts.tool_histories }) + if opts.on_chunk then opts.on_chunk("\n*[Request cancelled by user.]*\n") end + if opts.on_messages_add then + local message = HistoryMessage:new({ + role = "user", + content = "[Request cancelled by user.]", + }) + opts.on_messages_add({ message }) + end + return opts.on_stop({ reason = "cancelled" }) end - if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then - local old_tool_histories = vim.deepcopy(opts.tool_histories) or {} + if stop_opts.reason == "tool_use" then + local tool_use_list = {} ---@type AvanteLLMToolUse[] + local tool_result_seen = {} + local history_messages = opts.get_history_messages and opts.get_history_messages() or {} + for idx = #history_messages, 1, -1 do + local message = history_messages[idx] + local content = message.message.content + if type(content) ~= "table" or #content == 0 then goto continue end + if content[1].type == "tool_use" then + if not tool_result_seen[content[1].id] then + table.insert(tool_use_list, 1, content[1]) + else + break + end + end + if content[1].type == "tool_result" then tool_result_seen[content[1].tool_use_id] = true end + ::continue:: + end local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[] - for _, tool_use in vim.spairs(stop_opts.tool_use_list) do + for _, tool_use in vim.spairs(tool_use_list) do table.insert(sorted_tool_use_list, tool_use) end - return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories) + return handle_next_tool_use(sorted_tool_use_list, 1, {}) end if stop_opts.reason == "rate_limit" then - local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..." - opts.on_chunk("\n*[" .. msg .. "]*\n") + local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*" + if opts.on_chunk then opts.on_chunk("\n" .. msg_content .. "\n") end + local message + if opts.on_messages_add then + message = HistoryMessage:new({ + role = "assistant", + content = "\n\n" .. msg_content, + }, { + just_for_display = true, + }) + opts.on_messages_add({ message }) + end local timer = vim.loop.new_timer() if timer then local retry_after = stop_opts.retry_after @@ -661,8 +765,12 @@ function M._stream(opts) 0, vim.schedule_wrap(function() if retry_after > 0 then retry_after = retry_after - 1 end - local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..." - opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n") + local msg_content_ = "*[Rate limit reached. Retrying in " .. retry_after .. " seconds ...]*" + if opts.on_chunk then opts.on_chunk([[\033[1A\033[K]] .. "\n" .. msg_content_ .. "\n") end + if opts.on_messages_add and message then + message.message.content = "\n\n" .. msg_content_ + opts.on_messages_add({ message }) + end countdown() end) ) @@ -676,7 +784,6 @@ function M._stream(opts) end, stop_opts.retry_after * 1000) return end - stop_opts.tool_histories = opts.tool_histories return opts.on_stop(stop_opts) end, } @@ -697,6 +804,8 @@ local function _merge_response(first_response, second_response, opts) prompt = prompt .. "\n" + if opts.instructions == nil then opts.instructions = "" end + -- append this reference prompt to the prompt_opts messages at last opts.instructions = opts.instructions .. prompt @@ -802,20 +911,12 @@ function M.stream(opts) return original_on_stop(stop_opts) end) end - if opts.on_partial_tool_use ~= nil then - local original_on_partial_tool_use = opts.on_partial_tool_use - opts.on_partial_tool_use = vim.schedule_wrap(function(tool_use) - if is_completed then return end - return original_on_partial_tool_use(tool_use) - end) - end local valid_dual_boost_modes = { - planning = true, - ["cursor-planning"] = true, + legacy = true, } - opts.mode = opts.mode or "planning" + opts.mode = opts.mode or Config.mode if Config.dual_boost.enabled and valid_dual_boost_modes[opts.mode] then M._dual_boost_stream( diff --git a/lua/avante/llm_tools/create.lua b/lua/avante/llm_tools/create.lua index 4653268..4c912ae 100644 --- a/lua/avante/llm_tools/create.lua +++ b/lua/avante/llm_tools/create.lua @@ -10,7 +10,7 @@ M.name = "create" M.description = "The create tool allows you to create a new file with specified content." -function M.enabled() return require("avante.config").behaviour.enable_claude_text_editor_tool_mode end +function M.enabled() return require("avante.config").mode == "agentic" end ---@type AvanteLLMToolParam M.param = { diff --git a/lua/avante/llm_tools/dispatch_agent.lua b/lua/avante/llm_tools/dispatch_agent.lua index bd14900..70ffb7d 100644 --- a/lua/avante/llm_tools/dispatch_agent.lua +++ b/lua/avante/llm_tools/dispatch_agent.lua @@ -71,7 +71,7 @@ function M.func(opts, on_log, on_complete, session_ctx) if not on_complete then return false, "on_complete not provided" end local prompt = opts.prompt local tools = get_available_tools() - local start_time = os.date("%Y-%m-%d %H:%M:%S") + local start_time = Utils.get_timestamp() if on_log then on_log("prompt: " .. prompt) end @@ -84,6 +84,8 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub messages = messages or {} table.insert(messages, { role = "user", content = prompt }) + local tool_use_messages = {} + local total_tokens = 0 local final_response = "" Llm._stream({ @@ -93,6 +95,15 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub on_tool_log = function(tool_id, tool_name, log, state) if on_log then on_log(string.format("[%s] %s", tool_name, log)) end end, + on_messages_add = function(msgs) + msgs = vim.is_list(msgs) and msgs or { msgs } + for _, msg in ipairs(msgs) do + local content = msg.message.content + if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then + tool_use_messages[msg.uuid] = true + end + end + end, session_ctx = session_ctx, prompt_opts = { system_prompt = system_prompt, @@ -111,9 +122,9 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub on_complete(err, nil) return end - local end_time = os.date("%Y-%m-%d %H:%M:%S") - local elapsed_time = Utils.datetime_diff(tostring(start_time), tostring(end_time)) - local tool_use_count = stop_opts.tool_histories and #stop_opts.tool_histories or 0 + local end_time = Utils.get_timestamp() + local elapsed_time = Utils.datetime_diff(start_time, end_time) + local tool_use_count = vim.tbl_count(tool_use_messages) local summary = "Done (" .. (tool_use_count <= 1 and "1 tool use" or tool_use_count .. " tool uses") .. " · " diff --git a/lua/avante/llm_tools/init.lua b/lua/avante/llm_tools/init.lua index e833697..f3b9bec 100644 --- a/lua/avante/llm_tools/init.lua +++ b/lua/avante/llm_tools/init.lua @@ -598,6 +598,7 @@ end ---@type AvanteLLMTool[] M._tools = { + require("avante.llm_tools.replace_in_file"), require("avante.llm_tools.dispatch_agent"), require("avante.llm_tools.glob"), { @@ -1104,7 +1105,7 @@ M._tools = { ---@return string | nil result ---@return string | nil error function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx) - Utils.debug("use tool", tool_use.name, tool_use.input_json) + -- Utils.debug("use tool", tool_use.name, tool_use.input_json) -- Check if execution is already cancelled if Helpers.is_cancelled then @@ -1125,8 +1126,7 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx) if tool == nil then return nil, "This tool is not provided: " .. tool_use.name end func = tool.func or M[tool.name] end - local ok, input_json = pcall(vim.json.decode, tool_use.input_json) - if not ok then return nil, "Failed to decode tool input json: " .. vim.inspect(input_json) end + local input_json = tool_use.input if not func then return nil, "Tool not found: " .. tool_use.name end if on_log then on_log(tool_use.id, tool_use.name, "running tool", "running") end diff --git a/lua/avante/llm_tools/insert.lua b/lua/avante/llm_tools/insert.lua index 5d599fd..60fc3c2 100644 --- a/lua/avante/llm_tools/insert.lua +++ b/lua/avante/llm_tools/insert.lua @@ -10,7 +10,7 @@ M.name = "insert" M.description = "The insert tool allows you to insert text at a specific location in a file." -function M.enabled() return require("avante.config").behaviour.enable_claude_text_editor_tool_mode end +function M.enabled() return require("avante.config").mode == "agentic" end ---@type AvanteLLMToolParam M.param = { diff --git a/lua/avante/llm_tools/replace_in_file.lua b/lua/avante/llm_tools/replace_in_file.lua new file mode 100644 index 0000000..f15bb4b --- /dev/null +++ b/lua/avante/llm_tools/replace_in_file.lua @@ -0,0 +1,481 @@ +local Base = require("avante.llm_tools.base") +local Helpers = require("avante.llm_tools.helpers") +local Utils = require("avante.utils") +local Highlights = require("avante.highlights") +local Config = require("avante.config") + +local PRIORITY = (vim.hl or vim.highlight).priorities.user +local NAMESPACE = vim.api.nvim_create_namespace("avante-diff") +local KEYBINDING_NAMESPACE = vim.api.nvim_create_namespace("avante-diff-keybinding") + +---@class AvanteLLMTool +local M = setmetatable({}, Base) + +M.name = "replace_in_file" + +M.description = + "Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file. This tool should be used when you need to make targeted changes to specific parts of a file." + +---@type AvanteLLMToolParam +M.param = { + type = "table", + fields = { + { + name = "path", + description = "The path to the file in the current project scope", + type = "string", + }, + { + name = "diff", + description = [[ +One or more SEARCH/REPLACE blocks following this exact format: + \`\`\` + <<<<<<< SEARCH + [exact content to find] + ======= + [new content to replace with] + >>>>>>> REPLACE + \`\`\` + Critical rules: + 1. SEARCH content must match the associated file section to find EXACTLY: + * Match character-for-character including whitespace, indentation, line endings + * Include all comments, docstrings, etc. + 2. SEARCH/REPLACE blocks will ONLY replace the first match occurrence. + * Including multiple unique SEARCH/REPLACE blocks if you need to make multiple changes. + * Include *just* enough lines in each SEARCH section to uniquely match each set of lines that need to change. + * When using multiple SEARCH/REPLACE blocks, list them in the order they appear in the file. + 3. Keep SEARCH/REPLACE blocks concise: + * Break large SEARCH/REPLACE blocks into a series of smaller blocks that each change a small portion of the file. + * Include just the changing lines, and a few surrounding lines if needed for uniqueness. + * Do not include long runs of unchanging lines in SEARCH/REPLACE blocks. + * Each line must be complete. Never truncate lines mid-way through as this can cause matching failures. + 4. Special operations: + * To move code: Use two SEARCH/REPLACE blocks (one to delete from original + one to insert at new location) + * To delete code: Use empty REPLACE section + ]], + type = "string", + }, + }, +} + +---@type AvanteLLMToolReturn[] +M.returns = { + { + name = "success", + description = "True if the replacement was successful, false otherwise", + type = "boolean", + }, + { + name = "error", + description = "Error message if the replacement failed", + type = "string", + optional = true, + }, +} + +---@type AvanteLLMToolFunc<{ path: string, diff: string }> +function M.func(opts, on_log, on_complete, session_ctx) + if not opts.path or not opts.diff then return false, "path and diff are required" end + if on_log then on_log("path: " .. opts.path) end + local abs_path = Helpers.get_abs_path(opts.path) + if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end + + local diff_lines = vim.split(opts.diff, "\n") + local is_searching = false + local is_replacing = false + local current_search = {} + local current_replace = {} + local rough_diff_blocks = {} + + for _, line in ipairs(diff_lines) do + if line:match("^%s*<<<<<<< SEARCH") then + is_searching = true + is_replacing = false + current_search = {} + elseif line:match("^%s*=======") and is_searching then + is_searching = false + is_replacing = true + current_replace = {} + elseif line:match("^%s*>>>>>>> REPLACE") and is_replacing then + is_replacing = false + table.insert( + rough_diff_blocks, + { search = table.concat(current_search, "\n"), replace = table.concat(current_replace, "\n") } + ) + elseif is_searching then + table.insert(current_search, line) + elseif is_replacing then + table.insert(current_replace, line) + end + end + + if #rough_diff_blocks == 0 then return false, "No diff blocks found" end + + local bufnr, err = Helpers.get_bufnr(abs_path) + if err then return false, err end + local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local sidebar = require("avante").get() + if not sidebar then return false, "Avante sidebar not found" end + + local function parse_rough_diff_block(rough_diff_block, current_lines) + local old_lines = vim.split(rough_diff_block.search, "\n") + local new_lines = vim.split(rough_diff_block.replace, "\n") + local start_line, end_line + for i = 1, #current_lines - #old_lines + 1 do + local match = true + for j = 1, #old_lines do + if Utils.remove_indentation(current_lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then + match = false + break + end + end + if match then + start_line = i + end_line = i + #old_lines - 1 + break + end + end + if start_line == nil or end_line == nil then + return "Failed to find the old string:\n" .. rough_diff_block.search + end + local old_str = rough_diff_block.search + local new_str = rough_diff_block.replace + local original_indentation = Utils.get_indentation(current_lines[start_line]) + if original_indentation ~= Utils.get_indentation(old_lines[1]) then + old_lines = vim.tbl_map(function(line) return original_indentation .. line end, old_lines) + new_lines = vim.tbl_map(function(line) return original_indentation .. line end, new_lines) + old_str = table.concat(old_lines, "\n") + new_str = table.concat(new_lines, "\n") + end + rough_diff_block.old_lines = old_lines + rough_diff_block.new_lines = new_lines + rough_diff_block.search = old_str + rough_diff_block.replace = new_str + rough_diff_block.start_line = start_line + rough_diff_block.end_line = end_line + return nil + end + + local function rough_diff_blocks_to_diff_blocks(rough_diff_blocks_) + local res = {} + local base_line_ = 0 + for _, rough_diff_block in ipairs(rough_diff_blocks_) do + ---@diagnostic disable-next-line: assign-type-mismatch, missing-fields + local patch = vim.diff(rough_diff_block.search, rough_diff_block.replace, { ---@type integer[][] + algorithm = "histogram", + result_type = "indices", + ctxlen = vim.o.scrolloff, + }) + for _, hunk in ipairs(patch) do + local start_a, count_a, start_b, count_b = unpack(hunk) + local diff_block = {} + if count_a > 0 then + diff_block.old_lines = vim.list_slice(rough_diff_block.old_lines, start_a, start_a + count_a - 1) + else + diff_block.old_lines = {} + end + if count_b > 0 then + diff_block.new_lines = vim.list_slice(rough_diff_block.new_lines, start_b, start_b + count_b - 1) + else + diff_block.new_lines = {} + end + if count_a > 0 then + diff_block.start_line = base_line_ + rough_diff_block.start_line + start_a - 1 + else + diff_block.start_line = base_line_ + rough_diff_block.start_line + start_a + end + diff_block.end_line = base_line_ + rough_diff_block.start_line + start_a + math.max(count_a, 1) - 2 + diff_block.search = table.concat(diff_block.old_lines, "\n") + diff_block.replace = table.concat(diff_block.new_lines, "\n") + table.insert(res, diff_block) + end + + local distance = 0 + for _, hunk in ipairs(patch) do + local _, count_a, _, count_b = unpack(hunk) + distance = distance + count_b - count_a + end + + local old_distance = #rough_diff_block.new_lines - #rough_diff_block.old_lines + + base_line_ = base_line_ + distance - old_distance + end + return res + end + + for _, rough_diff_block in ipairs(rough_diff_blocks) do + local error = parse_rough_diff_block(rough_diff_block, original_lines) + if error then + on_complete(false, error) + return + end + end + + local diff_blocks = rough_diff_blocks_to_diff_blocks(rough_diff_blocks) + + table.sort(diff_blocks, function(a, b) return a.start_line < b.start_line end) + + local base_line = 0 + for _, diff_block in ipairs(diff_blocks) do + diff_block.new_start_line = diff_block.start_line + base_line + diff_block.new_end_line = diff_block.new_start_line + #diff_block.new_lines - 1 + base_line = base_line + #diff_block.new_lines - #diff_block.old_lines + end + + local function remove_diff_block(removed_idx, use_new_lines) + local new_diff_blocks = {} + local distance = 0 + for idx, diff_block in ipairs(diff_blocks) do + if idx == removed_idx then + if not use_new_lines then distance = #diff_block.old_lines - #diff_block.new_lines end + goto continue + end + if idx > removed_idx then + diff_block.new_start_line = diff_block.new_start_line + distance + diff_block.new_end_line = diff_block.new_end_line + distance + end + table.insert(new_diff_blocks, diff_block) + ::continue:: + end + + diff_blocks = new_diff_blocks + end + + local function get_current_diff_block() + local winid = Utils.get_winid(bufnr) + local cursor_line = Utils.get_cursor_pos(winid) + for idx, diff_block in ipairs(diff_blocks) do + if cursor_line >= diff_block.new_start_line and cursor_line <= diff_block.new_end_line then + return diff_block, idx + end + end + return nil, nil + end + + local function get_prev_diff_block() + local winid = Utils.get_winid(bufnr) + local cursor_line = Utils.get_cursor_pos(winid) + local distance = nil + local idx = nil + for i, diff_block in ipairs(diff_blocks) do + if cursor_line >= diff_block.new_start_line and cursor_line <= diff_block.new_end_line then + local new_i = i - 1 + if new_i < 1 then return diff_blocks[#diff_blocks] end + return diff_blocks[new_i] + end + if diff_block.new_start_line < cursor_line then + local distance_ = cursor_line - diff_block.new_start_line + if distance == nil or distance_ < distance then + distance = distance_ + idx = i + end + end + end + if idx ~= nil then return diff_blocks[idx] end + if #diff_blocks > 0 then return diff_blocks[#diff_blocks] end + return nil + end + + local function get_next_diff_block() + local winid = Utils.get_winid(bufnr) + local cursor_line = Utils.get_cursor_pos(winid) + local distance = nil + local idx = nil + for i, diff_block in ipairs(diff_blocks) do + if cursor_line >= diff_block.new_start_line and cursor_line <= diff_block.new_end_line then + local new_i = i + 1 + if new_i > #diff_blocks then return diff_blocks[1] end + return diff_blocks[new_i] + end + if diff_block.new_start_line > cursor_line then + local distance_ = diff_block.new_start_line - cursor_line + if distance == nil or distance_ < distance then + distance = distance_ + idx = i + end + end + end + if idx ~= nil then return diff_blocks[idx] end + if #diff_blocks > 0 then return diff_blocks[1] end + return nil + end + + local show_keybinding_hint_extmark_id = nil + local function register_cursor_move_events() + local function show_keybinding_hint(lnum) + if show_keybinding_hint_extmark_id then + vim.api.nvim_buf_del_extmark(bufnr, KEYBINDING_NAMESPACE, show_keybinding_hint_extmark_id) + end + + local hint = string.format( + "[<%s>: OURS, <%s>: THEIRS, <%s>: PREV, <%s>: NEXT]", + Config.mappings.diff.ours, + Config.mappings.diff.theirs, + Config.mappings.diff.prev, + Config.mappings.diff.next + ) + + show_keybinding_hint_extmark_id = vim.api.nvim_buf_set_extmark(bufnr, KEYBINDING_NAMESPACE, lnum - 1, -1, { + hl_group = "AvanteInlineHint", + virt_text = { { hint, "AvanteInlineHint" } }, + virt_text_pos = "right_align", + priority = PRIORITY, + }) + end + + vim.api.nvim_create_autocmd({ "CursorMoved", "CursorMovedI", "WinLeave" }, { + buffer = bufnr, + callback = function(event) + local diff_block = get_current_diff_block() + if (event.event == "CursorMoved" or event.event == "CursorMovedI") and diff_block then + show_keybinding_hint(diff_block.new_start_line) + else + vim.api.nvim_buf_clear_namespace(bufnr, KEYBINDING_NAMESPACE, 0, -1) + end + end, + }) + end + + local function register_keybinding_events() + vim.keymap.set({ "n", "v" }, Config.mappings.diff.ours, function() + if vim.api.nvim_get_current_buf() ~= bufnr then return end + local diff_block, idx = get_current_diff_block() + if not diff_block then return end + pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.delete_extmark_id) + pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.incoming_extmark_id) + vim.api.nvim_buf_set_lines( + bufnr, + diff_block.new_start_line - 1, + diff_block.new_end_line, + false, + diff_block.old_lines + ) + diff_block.incoming_extmark_id = nil + diff_block.delete_extmark_id = nil + remove_diff_block(idx, false) + local next_diff_block = get_next_diff_block() + if next_diff_block then + local winnr = Utils.get_winid(bufnr) + vim.api.nvim_win_set_cursor(winnr, { next_diff_block.new_start_line, 0 }) + vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end) + end + end) + + vim.keymap.set({ "n", "v" }, Config.mappings.diff.theirs, function() + if vim.api.nvim_get_current_buf() ~= bufnr then return end + local diff_block, idx = get_current_diff_block() + if not diff_block then return end + pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.incoming_extmark_id) + pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.delete_extmark_id) + diff_block.incoming_extmark_id = nil + diff_block.delete_extmark_id = nil + remove_diff_block(idx, true) + local next_diff_block = get_next_diff_block() + if next_diff_block then + local winnr = Utils.get_winid(bufnr) + vim.api.nvim_win_set_cursor(winnr, { next_diff_block.new_start_line, 0 }) + vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end) + end + end) + + vim.keymap.set({ "n", "v" }, Config.mappings.diff.next, function() + if vim.api.nvim_get_current_buf() ~= bufnr then return end + local diff_block = get_next_diff_block() + if not diff_block then return end + local winnr = Utils.get_winid(bufnr) + vim.api.nvim_win_set_cursor(winnr, { diff_block.new_start_line, 0 }) + vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end) + end) + + vim.keymap.set({ "n", "v" }, Config.mappings.diff.prev, function() + if vim.api.nvim_get_current_buf() ~= bufnr then return end + local diff_block = get_prev_diff_block() + if not diff_block then return end + local winnr = Utils.get_winid(bufnr) + vim.api.nvim_win_set_cursor(winnr, { diff_block.new_start_line, 0 }) + vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end) + end) + end + + local function unregister_keybinding_events() + pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.ours) + pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.theirs) + pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.next) + pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.prev) + pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.ours) + pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.theirs) + pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.next) + pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.prev) + end + + local augroup = vim.api.nvim_create_augroup("avante_replace_in_file", { clear = true }) + local function clear() + if bufnr and not vim.api.nvim_buf_is_valid(bufnr) then return end + vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1) + vim.api.nvim_buf_clear_namespace(bufnr, KEYBINDING_NAMESPACE, 0, -1) + unregister_keybinding_events() + pcall(vim.api.nvim_del_augroup_by_id, augroup) + end + + local function insert_diff_blocks_new_lines() + local base_line_ = 0 + for _, diff_block in ipairs(diff_blocks) do + local start_line = diff_block.start_line + base_line_ + local end_line = diff_block.end_line + base_line_ + base_line_ = base_line_ + #diff_block.new_lines - #diff_block.old_lines + vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, diff_block.new_lines) + end + end + + local function highlight_diff_blocks() + vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1) + local base_line_ = 0 + local max_col = vim.o.columns + for _, diff_block in ipairs(diff_blocks) do + local start_line = diff_block.start_line + base_line_ + base_line_ = base_line_ + #diff_block.new_lines - #diff_block.old_lines + local deleted_virt_lines = vim + .iter(diff_block.old_lines) + :map(function(line) + --- append spaces to the end of the line + local line_ = line .. string.rep(" ", max_col - #line) + return { { line_, Highlights.TO_BE_DELETED_WITHOUT_STRIKETHROUGH } } + end) + :totable() + local extmark_line = math.max(0, start_line - 2) + local delete_extmark_id = vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, extmark_line, 0, { + virt_lines = deleted_virt_lines, + hl_eol = true, + hl_mode = "combine", + }) + local end_row = start_line + #diff_block.new_lines - 1 + local incoming_extmark_id = vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, start_line - 1, 0, { + hl_group = Highlights.INCOMING, + hl_eol = true, + hl_mode = "combine", + end_row = end_row, + }) + diff_block.delete_extmark_id = delete_extmark_id + diff_block.incoming_extmark_id = incoming_extmark_id + end + end + + insert_diff_blocks_new_lines() + highlight_diff_blocks() + register_cursor_move_events() + register_keybinding_events() + + Helpers.confirm("Are you sure you want to apply this modification?", function(ok, reason) + clear() + if not ok then + vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, original_lines) + on_complete(false, "User declined, reason: " .. (reason or "unknown")) + return + end + vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end) + if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end + on_complete(true, nil) + end, { focus = not Config.behaviour.auto_focus_on_diff_view }, session_ctx) +end + +return M diff --git a/lua/avante/llm_tools/str_replace.lua b/lua/avante/llm_tools/str_replace.lua index e593590..8a9a92b 100644 --- a/lua/avante/llm_tools/str_replace.lua +++ b/lua/avante/llm_tools/str_replace.lua @@ -13,7 +13,7 @@ M.name = "str_replace" M.description = "The str_replace tool allows you to replace a specific string in a file with a new string. This is used for making precise edits." -function M.enabled() return require("avante.config").behaviour.enable_claude_text_editor_tool_mode end +function M.enabled() return false end ---@type AvanteLLMToolParam M.param = { @@ -65,8 +65,8 @@ function M.func(opts, on_log, on_complete, session_ctx) if not file then return false, "file not found: " .. abs_path end if opts.old_str == nil then return false, "old_str not provided" end if opts.new_str == nil then return false, "new_str not provided" end - Utils.debug("old_str", opts.old_str) - Utils.debug("new_str", opts.new_str) + -- Utils.debug("old_str", opts.old_str) + -- Utils.debug("new_str", opts.new_str) local bufnr, err = Helpers.get_bufnr(abs_path) if err then return false, err end local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) @@ -77,7 +77,7 @@ function M.func(opts, on_log, on_complete, session_ctx) for i = 1, #lines - #old_lines + 1 do local match = true for j = 1, #old_lines do - if lines[i + j - 1] ~= old_lines[j] then + if Utils.remove_indentation(lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then match = false break end @@ -89,11 +89,20 @@ function M.func(opts, on_log, on_complete, session_ctx) end end if start_line == nil or end_line == nil then - on_complete(false, "Failed to find the old string: " .. opts.old_str) + on_complete(false, "Failed to find the old string:\n" .. opts.old_str) return end + local old_str = opts.old_str + local new_str = opts.new_str + local original_indentation = Utils.get_indentation(lines[start_line]) + if original_indentation ~= Utils.get_indentation(old_lines[1]) then + old_lines = vim.tbl_map(function(line) return original_indentation .. line end, old_lines) + new_lines = vim.tbl_map(function(line) return original_indentation .. line end, new_lines) + old_str = table.concat(old_lines, "\n") + new_str = table.concat(new_lines, "\n") + end ---@diagnostic disable-next-line: assign-type-mismatch, missing-fields - local patch = vim.diff(opts.old_str, opts.new_str, { ---@type integer[][] + local patch = vim.diff(old_str, new_str, { ---@type integer[][] algorithm = "histogram", result_type = "indices", ctxlen = vim.o.scrolloff, diff --git a/lua/avante/llm_tools/undo_edit.lua b/lua/avante/llm_tools/undo_edit.lua index 75ce5b6..919b862 100644 --- a/lua/avante/llm_tools/undo_edit.lua +++ b/lua/avante/llm_tools/undo_edit.lua @@ -10,7 +10,7 @@ M.name = "undo_edit" M.description = "The undo_edit tool allows you to revert the last edit made to a file." -function M.enabled() return require("avante.config").behaviour.enable_claude_text_editor_tool_mode end +function M.enabled() return require("avante.config").mode == "agentic" end ---@type AvanteLLMToolParam M.param = { diff --git a/lua/avante/llm_tools/view.lua b/lua/avante/llm_tools/view.lua index 6d0f909..55b52bd 100644 --- a/lua/avante/llm_tools/view.lua +++ b/lua/avante/llm_tools/view.lua @@ -76,17 +76,6 @@ M.returns = { function M.func(opts, on_log, on_complete, session_ctx) if not on_complete then return false, "on_complete not provided" end if on_log then on_log("path: " .. opts.path) end - if Helpers.already_in_context(opts.path) then - on_complete(nil, "Ooooops! This file is already in the context! Why you are trying to read it again?") - return - end - if session_ctx then - if Helpers.already_viewed(opts.path, session_ctx) then - on_complete(nil, "Ooooops! You have already viewed this file! Why you are trying to read it again?") - return - end - Helpers.mark_as_viewed(opts.path, session_ctx) - end local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end diff --git a/lua/avante/path.lua b/lua/avante/path.lua index 06da432..4e1d053 100644 --- a/lua/avante/path.lua +++ b/lua/avante/path.lua @@ -16,9 +16,9 @@ local function generate_project_dirname_in_storage(bufnr) buf = bufnr, }) -- Replace path separators with double underscores - local path_with_separators = fn.substitute(project_root, "/", "__", "g") + local path_with_separators = string.gsub(project_root, "/", "__") -- Replace other non-alphanumeric characters with single underscores - local dirname = fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") + local dirname = string.gsub(path_with_separators, "[^A-Za-z0-9._]", "_") return tostring(Path:new("projects"):joinpath(dirname)) end @@ -34,6 +34,7 @@ function History.get_history_dir(bufnr) return history_dir end +---@return avante.ChatHistory[] function History.list(bufnr) local history_dir = History.get_history_dir(bufnr) local files = vim.fn.glob(tostring(history_dir:joinpath("*.json")), true, true) @@ -53,8 +54,10 @@ function History.list(bufnr) table.sort(res, function(a, b) if a.filename == latest_filename then return true end if b.filename == latest_filename then return false end - local timestamp_a = #a.entries > 0 and a.entries[#a.entries].timestamp or a.timestamp - local timestamp_b = #b.entries > 0 and b.entries[#b.entries].timestamp or b.timestamp + local a_messages = Utils.get_history_messages(a) + local b_messages = Utils.get_history_messages(b) + local timestamp_a = #a_messages > 0 and a_messages[#a_messages].timestamp or a.timestamp + local timestamp_b = #b_messages > 0 and b_messages[#b_messages].timestamp or b.timestamp return timestamp_a > timestamp_b end) return res @@ -117,8 +120,8 @@ function History.new(bufnr) ---@type avante.ChatHistory local history = { title = "untitled", - timestamp = tostring(os.date("%Y-%m-%d %H:%M:%S")), - entries = {}, + timestamp = Utils.get_timestamp(), + messages = {}, filename = filepath_to_filename(filepath), } return history @@ -169,11 +172,10 @@ function Prompt.get_builtin_prompts_filepath(mode) return string.format("%s.avan local _templates_lib = nil Prompt.custom_modes = { - planning = true, + agentic = true, + legacy = true, editing = true, suggesting = true, - ["cursor-planning"] = true, - ["cursor-applying"] = true, } Prompt.custom_prompts_contents = {} diff --git a/lua/avante/providers/bedrock.lua b/lua/avante/providers/bedrock.lua index 31e4f4c..f820fc2 100644 --- a/lua/avante/providers/bedrock.lua +++ b/lua/avante/providers/bedrock.lua @@ -58,6 +58,7 @@ function M:parse_stream_data(ctx, data, opts) end function M:parse_response_without_stream(data, event_state, opts) + if opts.on_chunk == nil then return end local bedrock_match = data:gmatch("exception(%b{})") opts.on_chunk("\n**Exception caught**\n\n") for bedrock_data_match in bedrock_match do diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index 7c64e6f..87c8845 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -2,7 +2,7 @@ local Utils = require("avante.utils") local Clipboard = require("avante.clipboard") local P = require("avante.providers") local Config = require("avante.config") -local StreamingJsonParser = require("avante.utils.streaming_json_parser") +local HistoryMessage = require("avante.history_message") ---@class AvanteProviderFunctor local M = {} @@ -139,63 +139,6 @@ function M:parse_messages(opts) messages[#messages].content = message_content end - if opts.tool_histories then - for _, tool_history in ipairs(opts.tool_histories) do - if tool_history.tool_use then - local msg = { - role = "assistant", - content = {}, - } - if tool_history.tool_use.thinking_blocks then - for _, thinking_block in ipairs(tool_history.tool_use.thinking_blocks) do - msg.content[#msg.content + 1] = { - type = "thinking", - thinking = thinking_block.thinking, - signature = thinking_block.signature, - } - end - end - if tool_history.tool_use.redacted_thinking_blocks then - for _, redacted_thinking_block in ipairs(tool_history.tool_use.redacted_thinking_blocks) do - msg.content[#msg.content + 1] = { - type = "redacted_thinking", - data = redacted_thinking_block.data, - } - end - end - if tool_history.tool_use.response_contents then - for _, response_content in ipairs(tool_history.tool_use.response_contents) do - msg.content[#msg.content + 1] = { - type = "text", - text = response_content, - } - end - end - msg.content[#msg.content + 1] = { - type = "tool_use", - id = tool_history.tool_use.id, - name = tool_history.tool_use.name, - input = vim.json.decode(tool_history.tool_use.input_json), - } - messages[#messages + 1] = msg - end - - if tool_history.tool_result then - messages[#messages + 1] = { - role = "user", - content = { - { - type = "tool_result", - tool_use_id = tool_history.tool_result.tool_use_id, - content = tool_history.tool_result.content, - is_error = tool_history.tool_result.is_error, - }, - }, - } - end - end - end - return messages end @@ -226,14 +169,51 @@ function M:parse_response(ctx, data_stream, event_state, opts) local content_block = jsn.content_block content_block.stoppped = false ctx.content_blocks[jsn.index + 1] = content_block - if content_block.type == "thinking" then opts.on_chunk("\n") end - if content_block.type == "tool_use" and opts.on_partial_tool_use then - opts.on_partial_tool_use({ - name = content_block.name, - id = content_block.id, - partial_json = {}, + if content_block.type == "text" then + local msg = HistoryMessage:new({ + role = "assistant", + content = content_block.text, + }, { state = "generating", }) + content_block.uuid = msg.uuid + if opts.on_messages_add then opts.on_messages_add({ msg }) end + end + if content_block.type == "thinking" then + if opts.on_chunk then opts.on_chunk("\n") end + if opts.on_messages_add then + local msg = HistoryMessage:new({ + role = "assistant", + content = { + { + type = "thinking", + thinking = content_block.thinking, + signature = content_block.signature, + }, + }, + }, { + state = "generating", + }) + content_block.uuid = msg.uuid + opts.on_messages_add({ msg }) + end + end + if content_block.type == "tool_use" and opts.on_messages_add then + local msg = HistoryMessage:new({ + role = "assistant", + content = { + { + type = "tool_use", + name = content_block.name, + id = content_block.id, + input = {}, + }, + }, + }, { + state = "generating", + }) + content_block.uuid = msg.uuid + opts.on_messages_add({ msg }) end elseif event_state == "content_block_delta" then local ok, jsn = pcall(vim.json.decode, data_stream) @@ -242,23 +222,35 @@ function M:parse_response(ctx, data_stream, event_state, opts) if jsn.delta.type == "input_json_delta" then if not content_block.input_json then content_block.input_json = "" end content_block.input_json = content_block.input_json .. jsn.delta.partial_json - if opts.on_partial_tool_use then - local streaming_json_parser = StreamingJsonParser:new() - local partial_json = streaming_json_parser:parse(content_block.input_json) - opts.on_partial_tool_use({ - name = content_block.name, - id = content_block.id, - partial_json = partial_json or {}, - state = "generating", - }) - end return elseif jsn.delta.type == "thinking_delta" then content_block.thinking = content_block.thinking .. jsn.delta.thinking - opts.on_chunk(jsn.delta.thinking) + if opts.on_chunk then opts.on_chunk(jsn.delta.thinking) end + local msg = HistoryMessage:new({ + role = "assistant", + content = { + { + type = "thinking", + thinking = content_block.thinking, + signature = content_block.signature, + }, + }, + }, { + state = "generating", + uuid = content_block.uuid, + }) + if opts.on_messages_add then opts.on_messages_add({ msg }) end elseif jsn.delta.type == "text_delta" then content_block.text = content_block.text .. jsn.delta.text - opts.on_chunk(jsn.delta.text) + if opts.on_chunk then opts.on_chunk(jsn.delta.text) end + local msg = HistoryMessage:new({ + role = "assistant", + content = content_block.text, + }, { + state = "generating", + uuid = content_block.uuid, + }) + if opts.on_messages_add then opts.on_messages_add({ msg }) end elseif jsn.delta.type == "signature_delta" then if ctx.content_blocks[jsn.index + 1].signature == nil then ctx.content_blocks[jsn.index + 1].signature = "" end ctx.content_blocks[jsn.index + 1].signature = ctx.content_blocks[jsn.index + 1].signature .. jsn.delta.signature @@ -268,12 +260,56 @@ function M:parse_response(ctx, data_stream, event_state, opts) if not ok then return end local content_block = ctx.content_blocks[jsn.index + 1] content_block.stoppped = true + if content_block.type == "text" then + local msg = HistoryMessage:new({ + role = "assistant", + content = content_block.text, + }, { + state = "generated", + uuid = content_block.uuid, + }) + if opts.on_messages_add then opts.on_messages_add({ msg }) end + end + if content_block.type == "tool_use" then + local complete_json = vim.json.decode(content_block.input_json) + local msg = HistoryMessage:new({ + role = "assistant", + content = { + { + type = "tool_use", + name = content_block.name, + id = content_block.id, + input = complete_json or {}, + }, + }, + }, { + state = "generated", + uuid = content_block.uuid, + }) + if opts.on_messages_add then opts.on_messages_add({ msg }) end + end if content_block.type == "thinking" then - if content_block.thinking and content_block.thinking ~= vim.NIL and content_block.thinking:sub(-1) ~= "\n" then - opts.on_chunk("\n\n\n") - else - opts.on_chunk("\n\n") + if opts.on_chunk then + if content_block.thinking and content_block.thinking ~= vim.NIL and content_block.thinking:sub(-1) ~= "\n" then + opts.on_chunk("\n\n\n") + else + opts.on_chunk("\n\n") + end end + local msg = HistoryMessage:new({ + role = "assistant", + content = { + { + type = "thinking", + thinking = content_block.thinking, + signature = content_block.signature, + }, + }, + }, { + state = "generated", + uuid = content_block.uuid, + }) + if opts.on_messages_add then opts.on_messages_add({ msg }) end end elseif event_state == "message_delta" then local ok, jsn = pcall(vim.json.decode, data_stream) @@ -281,49 +317,20 @@ function M:parse_response(ctx, data_stream, event_state, opts) if jsn.delta.stop_reason == "end_turn" then opts.on_stop({ reason = "complete", usage = jsn.usage }) elseif jsn.delta.stop_reason == "tool_use" then - ---@type AvanteLLMToolUse[] - local tool_use_list = vim - .iter(ctx.content_blocks) - :filter(function(content_block) return content_block.stoppped and content_block.type == "tool_use" end) - :map(function(content_block) - local response_contents = vim - .iter(ctx.content_blocks) - :filter(function(content_block_) return content_block_.stoppped and content_block_.type == "text" end) - :map(function(content_block_) return content_block_.text end) - :totable() - local thinking_blocks = vim - .iter(ctx.content_blocks) - :filter(function(content_block_) return content_block_.stoppped and content_block_.type == "thinking" end) - :map(function(content_block_) - ---@type AvanteLLMThinkingBlock - return { thinking = content_block_.thinking, signature = content_block_.signature } - end) - :totable() - local redacted_thinking_blocks = vim - .iter(ctx.content_blocks) - :filter( - function(content_block_) return content_block_.stoppped and content_block_.type == "redacted_thinking" end - ) - :map(function(content_block_) - ---@type AvanteLLMRedactedThinkingBlock - return { data = content_block_.data } - end) - :totable() - ---@type AvanteLLMToolUse - return { - name = content_block.name, + local tool_use_list = {} + for _, content_block in ipairs(ctx.content_blocks) do + if content_block.type == "tool_use" then + table.insert(tool_use_list, { id = content_block.id, + name = content_block.name, input_json = content_block.input_json, - response_contents = response_contents, - thinking_blocks = thinking_blocks, - redacted_thinking_blocks = redacted_thinking_blocks, - } - end) - :totable() + }) + end + end opts.on_stop({ reason = "tool_use", + -- tool_use_list = tool_use_list, usage = jsn.usage, - tool_use_list = tool_use_list, }) end return @@ -351,7 +358,7 @@ function M:parse_curl_args(prompt_opts) local tools = {} if not disable_tools and prompt_opts.tools then for _, tool in ipairs(prompt_opts.tools) do - if Config.behaviour.enable_claude_text_editor_tool_mode then + if Config.mode == "agentic" then if tool.name == "create_file" then goto continue end if tool.name == "view" then goto continue end if tool.name == "str_replace" then goto continue end @@ -364,7 +371,7 @@ function M:parse_curl_args(prompt_opts) end end - if prompt_opts.tools and #prompt_opts.tools > 0 and Config.behaviour.enable_claude_text_editor_tool_mode then + if prompt_opts.tools and #prompt_opts.tools > 0 and Config.mode == "agentic" then if provider_conf.model:match("claude%-3%-7%-sonnet") then table.insert(tools, { type = "text_editor_20250124", diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index 5f56dc0..caf8f91 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -211,11 +211,7 @@ M.role_map = { function M:is_disable_stream() return false end -M.parse_messages = OpenAI.parse_messages - -M.parse_response = OpenAI.parse_response - -M.is_reasoning_model = OpenAI.is_reasoning_model +setmetatable(M, { __index = OpenAI }) function M:parse_curl_args(prompt_opts) -- refresh token synchronously, only if it has expired diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index 4718426..67c2aed 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -1,6 +1,7 @@ local Utils = require("avante.utils") -local P = require("avante.providers") +local Providers = require("avante.providers") local Clipboard = require("avante.clipboard") +local OpenAI = require("avante.providers").openai ---@class AvanteProviderFunctor local M = {} @@ -10,14 +11,32 @@ M.role_map = { user = "user", assistant = "model", } --- M.tokenizer_id = "google/gemma-2b" function M:is_disable_stream() return false end +---@param tool AvanteLLMTool +function M:transform_to_function_declaration(tool) + local input_schema_properties, required = Utils.llm_tool_param_fields_to_json_schema(tool.param.fields) + local parameters = nil + if not vim.tbl_isempty(input_schema_properties) then + parameters = { + type = "object", + properties = input_schema_properties, + required = required, + } + end + return { + name = tool.name, + description = tool.get_description and tool.get_description() or tool.description, + parameters = parameters, + } +end + function M:parse_messages(opts) local contents = {} local prev_role = nil + local tool_id_to_name = {} vim.iter(opts.messages):each(function(message) local role = message.role if role == prev_role then @@ -54,9 +73,27 @@ function M:parse_messages(opts) }, }) elseif type(item) == "table" and item.type == "tool_use" then - table.insert(parts, { text = item.name }) + tool_id_to_name[item.id] = item.name + role = "model" + table.insert(parts, { + functionCall = { + name = item.name, + args = item.input, + }, + }) elseif type(item) == "table" and item.type == "tool_result" then - table.insert(parts, { text = item.content }) + role = "function" + local ok, content = pcall(vim.json.decode, item.content) + if not ok then content = item.content end + table.insert(parts, { + functionResponse = { + name = tool_id_to_name[item.tool_use_id], + response = { + name = tool_id_to_name[item.tool_use_id], + content = content, + }, + }, + }) elseif type(item) == "table" and item.type == "thinking" then table.insert(parts, { text = item.thinking }) elseif type(item) == "table" and item.type == "redacted_thinking" then @@ -96,22 +133,43 @@ end function M:parse_response(ctx, data_stream, _, opts) local ok, json = pcall(vim.json.decode, data_stream) if not ok then opts.on_stop({ reason = "error", error = json }) end - if json.candidates then - if #json.candidates > 0 then - if json.candidates[1].finishReason and json.candidates[1].finishReason == "STOP" then - opts.on_chunk(json.candidates[1].content.parts[1].text) - opts.on_stop({ reason = "complete" }) - else - opts.on_chunk(json.candidates[1].content.parts[1].text) + if json.candidates and #json.candidates > 0 then + local candidate = json.candidates[1] + ---@type AvanteLLMToolUse[] + local tool_use_list = {} + for _, part in ipairs(candidate.content.parts) do + if part.text then + if opts.on_chunk then opts.on_chunk(part.text) end + OpenAI:add_text_message(ctx, part.text, "generating", opts) + elseif part.functionCall then + if not ctx.function_call_id then ctx.function_call_id = 0 end + ctx.function_call_id = ctx.function_call_id + 1 + local tool_use = { + id = ctx.session_id .. "-" .. tostring(ctx.function_call_id), + name = part.functionCall.name, + input_json = vim.json.encode(part.functionCall.args), + } + table.insert(tool_use_list, tool_use) + OpenAI:add_tool_use_message(tool_use, "generated", opts) end - else - opts.on_stop({ reason = "complete" }) end + if candidate.finishReason and candidate.finishReason == "STOP" then + OpenAI:finish_pending_messages(ctx, opts) + if #tool_use_list > 0 then + opts.on_stop({ reason = "tool_use", tool_use_list = tool_use_list }) + else + opts.on_stop({ reason = "complete" }) + end + end + else + OpenAI:finish_pending_messages(ctx, opts) + opts.on_stop({ reason = "complete" }) end end function M:parse_curl_args(prompt_opts) - local provider_conf, request_body = P.parse_config(self) + local provider_conf, request_body = Providers.parse_config(self) + local disable_tools = provider_conf.disable_tools or false request_body = vim.tbl_deep_extend("force", request_body, { generationConfig = { @@ -125,6 +183,21 @@ function M:parse_curl_args(prompt_opts) local api_key = self.parse_api_key() if api_key == nil then error("Cannot get the gemini api key!") end + local function_declarations = {} + if not disable_tools and prompt_opts.tools then + for _, tool in ipairs(prompt_opts.tools) do + table.insert(function_declarations, self:transform_to_function_declaration(tool)) + end + end + + if #function_declarations > 0 then + request_body.tools = { + { + functionDeclarations = function_declarations, + }, + } + end + return { url = Utils.url_join( provider_conf.endpoint, diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index 01c3c4f..55791a6 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -215,14 +215,6 @@ function M.setup() E.setup({ provider = auto_suggestions_provider }) end - if Config.behaviour.enable_cursor_planning_mode then - local cursor_applying_provider_name = Config.cursor_applying_provider or Config.provider - local cursor_applying_provider = M[cursor_applying_provider_name] - if cursor_applying_provider and cursor_applying_provider ~= provider then - E.setup({ provider = cursor_applying_provider }) - end - end - if Config.memory_summary_provider then local memory_summary_provider = M[Config.memory_summary_provider] if memory_summary_provider and memory_summary_provider ~= provider then @@ -277,4 +269,13 @@ function M.get_config(provider_name) return type(cur) == "function" and cur() or cur end +function M.get_memory_summary_provider() + local provider_name = Config.memory_summary_provider + if provider_name == nil then + if M.openai.is_env_set() then provider_name = "openai-gpt-4o-mini" end + end + if provider_name == nil then provider_name = Config.provider end + return M[provider_name] +end + return M diff --git a/lua/avante/providers/ollama.lua b/lua/avante/providers/ollama.lua index 5a84b50..950aec3 100644 --- a/lua/avante/providers/ollama.lua +++ b/lua/avante/providers/ollama.lua @@ -16,7 +16,7 @@ M.is_reasoning_model = P.openai.is_reasoning_model function M:is_disable_stream() return false end -function M:parse_stream_data(ctx, data, handler_opts) +function M:parse_stream_data(ctx, data, opts) local ok, json_data = pcall(vim.json.decode, data) if not ok or not json_data then -- Add debug logging @@ -26,11 +26,13 @@ function M:parse_stream_data(ctx, data, handler_opts) if json_data.message and json_data.message.content then local content = json_data.message.content - if content and content ~= "" then handler_opts.on_chunk(content) end + P.openai:add_text_message(ctx, content, "generating", opts) + if content and content ~= "" and opts.on_chunk then opts.on_chunk(content) end end if json_data.done then - handler_opts.on_stop({ reason = "complete" }) + P.openai:finish_pending_messages(ctx, opts) + opts.on_stop({ reason = "complete" }) return end end diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 7351f13..42f207e 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -2,7 +2,7 @@ local Utils = require("avante.utils") local Config = require("avante.config") local Clipboard = require("avante.clipboard") local Providers = require("avante.providers") -local StreamingJsonParser = require("avante.utils.streaming_json_parser") +local HistoryMessage = require("avante.history_message") ---@class AvanteProviderFunctor local M = {} @@ -164,117 +164,154 @@ function M:parse_messages(opts) table.insert(final_messages, message) end) - if opts.tool_histories then - for _, tool_history in ipairs(opts.tool_histories) do - table.insert(final_messages, { - role = self.role_map["assistant"], - tool_calls = { - { - id = tool_history.tool_use.id, - type = "function", - ["function"] = { - name = tool_history.tool_use.name, - arguments = tool_history.tool_use.input_json, - }, - }, - }, - }) - local result_content = tool_history.tool_result.content or "" - table.insert(final_messages, { - role = "tool", - tool_call_id = tool_history.tool_result.tool_use_id, - content = tool_history.tool_result.is_error and "Error: " .. result_content or result_content, - }) + return final_messages +end + +function M:finish_pending_messages(ctx, opts) + if ctx.content ~= nil and ctx.content ~= "" then self:add_text_message(ctx, "", "generated", opts) end + if ctx.tool_use_list then + for _, tool_use in ipairs(ctx.tool_use_list) do + if tool_use.state == "generating" then self:add_tool_use_message(tool_use, "generated", opts) end end end +end - return final_messages +function M:add_text_message(ctx, text, state, opts) + if ctx.content == nil then ctx.content = "" end + ctx.content = ctx.content .. text + local msg = HistoryMessage:new({ + role = "assistant", + content = ctx.content, + }, { + state = state, + uuid = ctx.content_uuid, + }) + ctx.content_uuid = msg.uuid + if opts.on_messages_add then opts.on_messages_add({ msg }) end +end + +function M:add_thinking_message(ctx, text, state, opts) + if ctx.reasonging_content == nil then ctx.reasonging_content = "" end + ctx.reasonging_content = ctx.reasonging_content .. text + local msg = HistoryMessage:new({ + role = "assistant", + content = { + { + type = "thinking", + thinking = ctx.reasonging_content, + signature = "", + }, + }, + }, { + state = state, + uuid = ctx.reasonging_content_uuid, + }) + ctx.reasonging_content_uuid = msg.uuid + if opts.on_messages_add then opts.on_messages_add({ msg }) end +end + +function M:add_tool_use_message(tool_use, state, opts) + local jsn = nil + if state == "generated" then jsn = vim.json.decode(tool_use.input_json) end + local msg = HistoryMessage:new({ + role = "assistant", + content = { + { + type = "tool_use", + name = tool_use.name, + id = tool_use.id, + input = jsn or {}, + }, + }, + }, { + state = state, + uuid = tool_use.uuid, + }) + tool_use.uuid = msg.uuid + tool_use.state = state + if opts.on_messages_add then opts.on_messages_add({ msg }) end end function M:parse_response(ctx, data_stream, _, opts) if data_stream:match('"%[DONE%]":') then + self:finish_pending_messages(ctx, opts) opts.on_stop({ reason = "complete" }) return end - if data_stream:match('"delta":') then - ---@type AvanteOpenAIChatResponse - local jsn = vim.json.decode(data_stream) - if jsn.choices and jsn.choices[1] then - local choice = jsn.choices[1] - if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then - if choice.delta.content and choice.delta.content ~= vim.NIL then opts.on_chunk(choice.delta.content) end - opts.on_stop({ reason = "complete" }) - elseif choice.finish_reason == "tool_calls" then - opts.on_stop({ - reason = "tool_use", - usage = jsn.usage, - tool_use_list = ctx.tool_use_list, - }) - elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then - if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then - ctx.returned_think_start_tag = true - opts.on_chunk("\n") + if not data_stream:match('"delta":') then return end + ---@type AvanteOpenAIChatResponse + local jsn = vim.json.decode(data_stream) + if not jsn.choices or not jsn.choices[1] then return end + local choice = jsn.choices[1] + if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then + if choice.delta.content and choice.delta.content ~= vim.NIL then + self:add_text_message(ctx, choice.delta.content, "generated", opts) + opts.on_chunk(choice.delta.content) + end + self:finish_pending_messages(ctx, opts) + opts.on_stop({ reason = "complete" }) + elseif choice.finish_reason == "tool_calls" then + self:finish_pending_messages(ctx, opts) + opts.on_stop({ + reason = "tool_use", + -- tool_use_list = ctx.tool_use_list, + usage = jsn.usage, + }) + elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then + if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then + ctx.returned_think_start_tag = true + if opts.on_chunk then opts.on_chunk("\n") end + end + ctx.last_think_content = choice.delta.reasoning_content + self:add_thinking_message(ctx, choice.delta.reasoning_content, "generating", opts) + if opts.on_chunk then opts.on_chunk(choice.delta.reasoning_content) end + elseif choice.delta.reasoning and choice.delta.reasoning ~= vim.NIL then + if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then + ctx.returned_think_start_tag = true + if opts.on_chunk then opts.on_chunk("\n") end + end + ctx.last_think_content = choice.delta.reasoning + self:add_thinking_message(ctx, choice.delta.reasoning, "generating", opts) + if opts.on_chunk then opts.on_chunk(choice.delta.reasoning) end + elseif choice.delta.tool_calls and choice.delta.tool_calls ~= vim.NIL then + for _, tool_call in ipairs(choice.delta.tool_calls) do + if not ctx.tool_use_list then ctx.tool_use_list = {} end + if not ctx.tool_use_list[tool_call.index + 1] then + if tool_call.index > 0 and ctx.tool_use_list[tool_call.index] then + local prev_tool_use = ctx.tool_use_list[tool_call.index] + self:add_tool_use_message(prev_tool_use, "generated", opts) end - ctx.last_think_content = choice.delta.reasoning_content - opts.on_chunk(choice.delta.reasoning_content) - elseif choice.delta.reasoning and choice.delta.reasoning ~= vim.NIL then - if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then - ctx.returned_think_start_tag = true - opts.on_chunk("\n") - end - ctx.last_think_content = choice.delta.reasoning - opts.on_chunk(choice.delta.reasoning) - elseif choice.delta.tool_calls and choice.delta.tool_calls ~= vim.NIL then - for _, tool_call in ipairs(choice.delta.tool_calls) do - if not ctx.tool_use_list then ctx.tool_use_list = {} end - if not ctx.tool_use_list[tool_call.index + 1] then - local tool_use = { - name = tool_call["function"].name, - id = tool_call.id, - input_json = "", - } - ctx.tool_use_list[tool_call.index + 1] = tool_use - if opts.on_partial_tool_use then - opts.on_partial_tool_use({ - name = tool_call["function"].name, - id = tool_call.id, - partial_json = {}, - state = "generating", - }) - end - else - local tool_use = ctx.tool_use_list[tool_call.index + 1] - tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments - if opts.on_partial_tool_use then - local parser = StreamingJsonParser:new() - local partial_json = parser:parse(tool_use.input_json) - opts.on_partial_tool_use({ - name = tool_call["function"].name, - id = tool_call.id, - partial_json = partial_json or {}, - state = "generating", - }) - end - end - end - elseif choice.delta.content then - if - ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag) - then - ctx.returned_think_end_tag = true - if - ctx.last_think_content - and ctx.last_think_content ~= vim.NIL - and ctx.last_think_content:sub(-1) ~= "\n" - then - opts.on_chunk("\n\n") - else - opts.on_chunk("\n") - end - end - if choice.delta.content ~= vim.NIL then opts.on_chunk(choice.delta.content) end + local tool_use = { + name = tool_call["function"].name, + id = tool_call.id, + input_json = "", + } + ctx.tool_use_list[tool_call.index + 1] = tool_use + self:add_tool_use_message(tool_use, "generating", opts) + else + local tool_use = ctx.tool_use_list[tool_call.index + 1] + tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments + self:add_tool_use_message(tool_use, "generating", opts) end end + elseif choice.delta.content then + if + ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag) + then + ctx.returned_think_end_tag = true + if opts.on_chunk then + if ctx.last_think_content and ctx.last_think_content ~= vim.NIL and ctx.last_think_content:sub(-1) ~= "\n" then + opts.on_chunk("\n\n") + else + opts.on_chunk("\n") + end + end + self:add_thinking_message(ctx, "", "generated", opts) + end + if choice.delta.content ~= vim.NIL then + if opts.on_chunk then opts.on_chunk(choice.delta.content) end + self:add_text_message(ctx, choice.delta.content, "generating", opts) + end end end diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index c91ad35..77e7baf 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -15,6 +15,7 @@ local Highlights = require("avante.highlights") local RepoMap = require("avante.repo_map") local FileSelector = require("avante.file_selector") local LLMTools = require("avante.llm_tools") +local HistoryMessage = require("avante.history_message") local RESULT_BUF_NAME = "AVANTE_RESULT" local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated" @@ -46,6 +47,13 @@ Sidebar.__index = Sidebar ---@field input_container NuiSplit | nil ---@field file_selector FileSelector ---@field chat_history avante.ChatHistory | nil +---@field current_state avante.GenerateState | nil +---@field state_timer table | nil +---@field state_spinner_chars string[] +---@field state_spinner_idx integer +---@field state_ns_id integer +---@field state_extmark_id integer | nil +---@field scroll boolean ---@param id integer the tabpage id retrieved from api.nvim_get_current_tabpage() function Sidebar:new(id) @@ -65,6 +73,13 @@ function Sidebar:new(id) file_selector = FileSelector:new(id), is_generating = false, chat_history = nil, + current_state = nil, + state_timer = nil, + state_spinner_chars = { "·", "✢", "✳", "∗", "✻", "✽" }, + state_spinner_idx = 1, + state_ns_id = api.nvim_create_namespace("avante_generate_state"), + state_extmark_id = nil, + scroll = true, }, Sidebar) end @@ -99,6 +114,7 @@ function Sidebar:reset() self.selected_code_container = nil self.selected_files_container = nil self.input_container = nil + self.scroll = true end ---@class SidebarOpenOptions: AskOptions @@ -156,6 +172,8 @@ function Sidebar:setup_colors() end function Sidebar:set_code_winhl() + if not self.code.winid or not api.nvim_win_is_valid(self.code.winid) then return end + if Utils.should_hidden_border(self.code.winid, self.winids.result_container) then Utils.debug("setting winhl") local old_winhl = vim.wo[self.code.winid].winhl @@ -247,11 +265,10 @@ end ---@field last_think_tag_start_line integer ---@field last_think_tag_end_line integer ----@param selected_files {path: string, content: string, file_type: string | nil}[] ---@param result_content string ---@param prev_filepath string ---@return AvanteReplacementResult -local function transform_result_content(selected_files, result_content, prev_filepath) +local function transform_result_content(result_content, prev_filepath) local transformed_lines = {} local result_lines = vim.split(result_content, "\n") @@ -338,42 +355,11 @@ local function transform_result_content(selected_files, result_content, prev_fil local end_line = 0 local match_filetype = nil local filepath = current_filepath or prev_filepath or "" - ---@type {path: string, content: string, file_type: string | nil} | nil - local the_matched_file = nil - for _, file in ipairs(selected_files) do - if Utils.is_same_file(file.path, filepath) then - the_matched_file = file - break - end - end - if not the_matched_file then - if not PPath:new(filepath):exists() then - the_matched_file = { - filepath = filepath, - content = "", - file_type = nil, - } - else - if not PPath:new(filepath):is_file() then - Utils.warn("Not a file: " .. filepath) - goto continue - end - local lines = Utils.read_file_from_buf_or_disk(filepath) - if lines == nil then - Utils.warn("Failed to read file: " .. filepath) - goto continue - end - local content = table.concat(lines, "\n") - the_matched_file = { - filepath = filepath, - content = content, - file_type = nil, - } - end - end + if filepath == "" then goto continue end - local file_content = vim.split(the_matched_file.content, "\n") + local file_content = Utils.read_file_from_buf_or_disk(filepath) or {} + local file_type = Utils.get_filetype(filepath) if start_line ~= 0 or end_line ~= 0 then break end for j = 1, #file_content - (search_end - search_start) + 1 do local match = true @@ -388,7 +374,7 @@ local function transform_result_content(selected_files, result_content, prev_fil if match then start_line = j end_line = j + (search_end - search_start) - 1 - match_filetype = the_matched_file.file_type + match_filetype = file_type break end end @@ -479,74 +465,6 @@ local function transform_result_content(selected_files, result_content, prev_fil } end -local spinner_chars = { - "⡀", - "⠄", - "⠂", - "⠁", - "⠈", - "⠐", - "⠠", - "⢀", - "⣀", - "⢄", - "⢂", - "⢁", - "⢈", - "⢐", - "⢠", - "⣠", - "⢤", - "⢢", - "⢡", - "⢨", - "⢰", - "⣰", - "⢴", - "⢲", - "⢱", - "⢸", - "⣸", - "⢼", - "⢺", - "⢹", - "⣹", - "⢽", - "⢻", - "⣻", - "⢿", - "⣿", - "⣶", - "⣤", - "⣀", -} -local spinner_index = 1 - -local function get_searching_hint() - spinner_index = (spinner_index % #spinner_chars) + 1 - local spinner = spinner_chars[spinner_index] - return "\n" .. spinner .. " Searching..." -end - -local thinking_spinner_chars = { - Utils.icon("🤯", "?"), - Utils.icon("🙄", "¿"), -} -local thinking_spinner_index = 1 - -local function get_thinking_spinner() - thinking_spinner_index = thinking_spinner_index + 1 - if thinking_spinner_index > #thinking_spinner_chars then thinking_spinner_index = 1 end - local spinner = thinking_spinner_chars[thinking_spinner_index] - return "\n\n" .. spinner .. " Thinking..." -end - -local function get_display_content_suffix(replacement) - if replacement.is_searching then return get_searching_hint() end - if replacement.is_thinking then return get_thinking_spinner() end - return "" -end - ---@param replacement AvanteReplacementResult ---@return string local function generate_display_content(replacement) @@ -944,14 +862,8 @@ local function parse_codeblocks(buf, current_filepath, current_filetype) start_line, _ = node:start() elseif node:type() == "fenced_code_block_delimiter" and start_line ~= nil and node:start() >= start_line then local end_line, _ = node:start() - if Config.behaviour.enable_cursor_planning_mode then - local filepath = obtain_filepath_from_codeblock(lines, start_line) - if not filepath and lang == current_filetype then filepath = current_filepath end - valid = filepath ~= nil - else - valid = lines[start_line - 1]:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ines:?%s*(%d+)%-(%d+)") - ~= nil - end + valid = lines[start_line - 1]:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ines:?%s*(%d+)%-(%d+)") + ~= nil if valid then table.insert(codeblocks, { start_line = start_line, end_line = end_line + 1, lang = lang }) end end end @@ -1042,16 +954,10 @@ end ---@param current_cursor boolean function Sidebar:apply(current_cursor) local buf_path = api.nvim_buf_get_name(self.code.bufnr) - local current_filepath = Utils.file.is_in_cwd(buf_path) and Utils.relative_path(buf_path) or buf_path - local current_filetype = Utils.get_filetype(current_filepath) local response, response_start_line = self:get_content_between_separators() - local all_snippets_map = Config.behaviour.enable_cursor_planning_mode - and extract_cursor_planning_code_snippets_map(response, current_filepath, current_filetype) - or extract_code_snippets_map(response) - if not Config.behaviour.enable_cursor_planning_mode then - all_snippets_map = ensure_snippets_no_overlap(all_snippets_map) - end + local all_snippets_map = extract_code_snippets_map(response) + all_snippets_map = ensure_snippets_no_overlap(all_snippets_map) local selected_snippets_map = {} if current_cursor then if self.result_container and self.result_container.winid then @@ -1072,342 +978,6 @@ function Sidebar:apply(current_cursor) selected_snippets_map = all_snippets_map end - if Config.behaviour.enable_cursor_planning_mode then - for filepath, snippets in pairs(selected_snippets_map) do - local original_code_lines = Utils.read_file_from_buf_or_disk(filepath) - if not original_code_lines then - Utils.error("Failed to read file: " .. filepath) - return - end - local formated_snippets = vim.iter(snippets):map(function(snippet) return snippet.content end):totable() - local original_code = table.concat(original_code_lines, "\n") - local resp_content = "" - local filetype = Utils.get_filetype(filepath) - local cursor_applying_provider_name = Config.cursor_applying_provider or Config.provider - Utils.debug(string.format("Use %s for cursor applying", cursor_applying_provider_name)) - local cursor_applying_provider = Provider[cursor_applying_provider_name] - if not cursor_applying_provider then - Utils.error("Failed to find cursor_applying_provider provider: " .. cursor_applying_provider_name, { - once = true, - title = "Avante", - }) - end - if self.code.winid ~= nil and api.nvim_win_is_valid(self.code.winid) then - api.nvim_set_current_win(self.code.winid) - end - local bufnr = Utils.get_or_create_buffer_with_filepath(filepath) - local path_ = PPath:new(filepath) - path_:parent():mkdir({ parents = true, exists_ok = true }) - - local ns_id = api.nvim_create_namespace("avante_live_diff") - - local function clear_highlights() api.nvim_buf_clear_namespace(bufnr, ns_id, 0, -1) end - - -- Create loading indicator float window - local loading_buf = nil - local loading_win = nil - local spinner_frames = { "⣾", "⣽", "⣻", "⢿", "⡿", "⣟", "⣯", "⣷" } - local spinner_idx = 1 - local loading_timer = nil - - local function update_loading_indicator() - if not loading_win or not loading_buf or not api.nvim_win_is_valid(loading_win) then return end - spinner_idx = (spinner_idx % #spinner_frames) + 1 - local text = spinner_frames[spinner_idx] .. " Applying changes..." - api.nvim_buf_set_lines(loading_buf, 0, -1, false, { text }) - end - - local function create_loading_window() - local winid = self.input_container.winid - local win_height = api.nvim_win_get_height(winid) - local win_width = api.nvim_win_get_width(winid) - - -- Calculate position for center of window - local width = 30 - local height = 1 - local row = win_height - height - 1 - local col = win_width - width - - local opts = { - relative = "win", - win = winid, - width = width, - height = height, - row = row, - col = col, - anchor = "NW", - style = "minimal", - border = "none", - focusable = false, - zindex = 101, - } - - loading_buf = api.nvim_create_buf(false, true) - loading_win = api.nvim_open_win(loading_buf, false, opts) - - -- Start timer to update spinner - loading_timer = vim.loop.new_timer() - if loading_timer then loading_timer:start(0, 100, vim.schedule_wrap(update_loading_indicator)) end - end - - local function close_loading_window() - if loading_timer then - loading_timer:stop() - loading_timer:close() - loading_timer = nil - end - if loading_win and api.nvim_win_is_valid(loading_win) then - api.nvim_win_close(loading_win, true) - loading_win = nil - end - - if loading_buf then - api.nvim_buf_delete(loading_buf, { force = true }) - loading_buf = nil - end - end - - clear_highlights() - create_loading_window() - - local last_processed_line = 0 - local last_orig_diff_end_line = 1 - local last_resp_diff_end_line = 1 - local cleaned = false - local prev_patch = {} - - local function get_stable_patch(patch) - local new_patch = {} - for _, hunk in ipairs(patch) do - local start_a, count_a, start_b, count_b = unpack(hunk) - start_a = start_a + last_orig_diff_end_line - 1 - start_b = start_b + last_resp_diff_end_line - 1 - local has = vim.iter(prev_patch):find(function(hunk_) - local start_a_, count_a_, start_b_, count_b_ = unpack(hunk_) - return start_a == start_a_ and start_b == start_b_ and count_a == count_a_ and count_b == count_b_ - end) - if has ~= nil then table.insert(new_patch, hunk) end - end - return new_patch - end - - local extmark_id_map = {} - local virt_lines_map = {} - - Llm.stream({ - ask = true, - provider = cursor_applying_provider, - code_lang = filetype, - mode = "cursor-applying", - original_code = original_code, - update_snippets = formated_snippets, - on_start = function(_) end, - on_chunk = function(chunk) - if not chunk then return end - - resp_content = resp_content .. chunk - - if not cleaned then - resp_content = resp_content:gsub("\n*", ""):gsub("\n*", "") - resp_content = resp_content:gsub(".*```%w+\n", ""):gsub("\n```\n.*", "") - end - - local resp_lines = vim.split(resp_content, "\n") - - local complete_lines_count = #resp_lines - 1 - if complete_lines_count > 2 then cleaned = true end - - if complete_lines_count <= last_processed_line then return end - - local original_lines_to_process = - vim.list_slice(original_code_lines, last_orig_diff_end_line, complete_lines_count) - local resp_lines_to_process = vim.list_slice(resp_lines, last_resp_diff_end_line, complete_lines_count) - - local resp_lines_content = table.concat(resp_lines_to_process, "\n") - local original_lines_content = table.concat(original_lines_to_process, "\n") - - ---@diagnostic disable-next-line: assign-type-mismatch, missing-fields - local patch = vim.diff(original_lines_content, resp_lines_content, { ---@type integer[][] - algorithm = "histogram", - result_type = "indices", - ctxlen = vim.o.scrolloff, - }) - - local stable_patch = get_stable_patch(patch) - - for _, hunk in ipairs(stable_patch) do - local start_a, count_a, start_b, count_b = unpack(hunk) - - start_a = last_orig_diff_end_line + start_a - 1 - - if count_a > 0 then - api.nvim_buf_set_extmark(bufnr, ns_id, start_a - 1, 0, { - hl_group = Highlights.TO_BE_DELETED_WITHOUT_STRIKETHROUGH, - hl_eol = true, - hl_mode = "combine", - end_row = start_a + count_a - 1, - }) - end - - if count_b == 0 then goto continue end - - local new_lines = vim.list_slice(resp_lines_to_process, start_b, start_b + count_b - 1) - local max_col = vim.o.columns - local virt_lines = vim - .iter(new_lines) - :map(function(line) - --- append spaces to the end of the line - local line_ = line .. string.rep(" ", max_col - #line) - return { { line_, Highlights.INCOMING } } - end) - :totable() - local extmark_line - if count_a > 0 then - extmark_line = math.max(0, start_a + count_a - 2) - else - extmark_line = math.max(0, start_a + count_a - 1) - end - local old_extmark_id = extmark_id_map[extmark_line] - if old_extmark_id ~= nil then - local old_virt_lines = virt_lines_map[old_extmark_id] or {} - virt_lines = vim.list_extend(old_virt_lines, virt_lines) - api.nvim_buf_del_extmark(bufnr, ns_id, old_extmark_id) - end - local extmark_id = api.nvim_buf_set_extmark(bufnr, ns_id, extmark_line, 0, { - virt_lines = virt_lines, - hl_eol = true, - hl_mode = "combine", - }) - extmark_id_map[extmark_line] = extmark_id - virt_lines_map[extmark_id] = virt_lines - ::continue:: - end - - prev_patch = vim - .iter(patch) - :map(function(hunk) - local start_a, count_a, start_b, count_b = unpack(hunk) - return { last_orig_diff_end_line + start_a - 1, count_a, last_resp_diff_end_line + start_b - 1, count_b } - end) - :totable() - - if #stable_patch > 0 then - local start_a, count_a, start_b, count_b = unpack(stable_patch[#stable_patch]) - last_orig_diff_end_line = last_orig_diff_end_line + start_a + math.max(count_a, 1) - 1 - last_resp_diff_end_line = last_resp_diff_end_line + start_b + math.max(count_b, 1) - 1 - end - - if #patch == 0 then - last_orig_diff_end_line = complete_lines_count + 1 - last_resp_diff_end_line = complete_lines_count + 1 - end - - last_processed_line = complete_lines_count - - local winid = Utils.get_winid(bufnr) - - if winid == nil then return end - - --- goto window winid - api.nvim_set_current_win(winid) - --- goto the last line - pcall(function() api.nvim_win_set_cursor(winid, { complete_lines_count, 0 }) end) - vim.cmd("normal! zz") - end, - on_stop = function(stop_opts) - clear_highlights() - close_loading_window() - - if stop_opts.error ~= nil then - Utils.error(string.format("applying failed: %s", vim.inspect(stop_opts.error))) - return - end - - resp_content = resp_content:gsub("\n*", ""):gsub("\n*", "") - - resp_content = resp_content:gsub(".*```%w+\n", ""):gsub("\n```\n.*", ""):gsub("\n```$", "") - - local resp_lines = vim.split(resp_content, "\n") - - if #resp_lines > 0 and resp_lines[#resp_lines] == "" then - resp_lines = vim.list_slice(resp_lines, 0, #resp_lines - 1) - resp_content = table.concat(resp_lines, "\n") - end - - if require("avante.config").debug then - local resp_content_file = fn.tempname() .. ".txt" - fn.writefile(vim.split(resp_content, "\n"), resp_content_file) - Utils.debug("cursor applying response content written to: " .. resp_content_file) - end - - if resp_content == original_code then return end - - ---@diagnostic disable-next-line: assign-type-mismatch, missing-fields - local patch = vim.diff(original_code, resp_content, { ---@type integer[][] - algorithm = "histogram", - result_type = "indices", - ctxlen = vim.o.scrolloff, - }) - - local new_lines = {} - local prev_start_a = 1 - for _, hunk in ipairs(patch) do - local start_a, count_a, start_b, count_b = unpack(hunk) - if count_a > 0 then - vim.list_extend(new_lines, vim.list_slice(original_code_lines, prev_start_a, start_a - 1)) - else - vim.list_extend(new_lines, vim.list_slice(original_code_lines, prev_start_a, start_a)) - end - prev_start_a = start_a + count_a - if count_a == 0 then prev_start_a = prev_start_a + 1 end - table.insert(new_lines, "<<<<<<< HEAD") - if count_a > 0 then - vim.list_extend(new_lines, vim.list_slice(original_code_lines, start_a, start_a + count_a - 1)) - end - table.insert(new_lines, "=======") - if count_b > 0 then - vim.list_extend(new_lines, vim.list_slice(resp_lines, start_b, start_b + count_b - 1)) - end - table.insert(new_lines, ">>>>>>> Snippet") - end - - local remaining_lines = vim.list_slice(original_code_lines, prev_start_a, #original_code_lines) - new_lines = vim.list_extend(new_lines, remaining_lines) - - api.nvim_buf_set_lines(bufnr, 0, -1, false, new_lines) - - local function process(winid) - api.nvim_set_current_win(winid) - vim.cmd("noautocmd stopinsert") - Diff.add_visited_buffer(bufnr) - Diff.process(bufnr) - api.nvim_win_set_cursor(winid, { 1, 0 }) - vim.defer_fn(function() - Diff.find_next(Config.windows.ask.focus_on_apply) - vim.cmd("normal! zz") - end, 100) - end - - local winid = Utils.get_winid(bufnr) - if winid then - process(winid) - else - api.nvim_create_autocmd("BufWinEnter", { - group = self.augroup, - buffer = bufnr, - once = true, - callback = function() - local winid_ = Utils.get_winid(bufnr) - if winid_ then process(winid_) end - end, - }) - end - end, - }) - end - return - end - vim.defer_fn(function() api.nvim_set_current_win(self.code.winid) for filepath, snippets in pairs(selected_snippets_map) do @@ -1815,28 +1385,30 @@ function Sidebar:on_mount(opts) end, }) - local buf_path = api.nvim_buf_get_name(self.code.bufnr) - local current_filepath = Utils.file.is_in_cwd(buf_path) and Utils.relative_path(buf_path) or buf_path - local current_filetype = Utils.get_filetype(current_filepath) + if self.code.bufnr and api.nvim_buf_is_valid(self.code.bufnr) then + local buf_path = api.nvim_buf_get_name(self.code.bufnr) + local current_filepath = Utils.file.is_in_cwd(buf_path) and Utils.relative_path(buf_path) or buf_path + local current_filetype = Utils.get_filetype(current_filepath) - api.nvim_create_autocmd({ "BufEnter", "BufWritePost" }, { - group = self.augroup, - buffer = self.result_container.bufnr, - callback = function(ev) - codeblocks = parse_codeblocks(ev.buf, current_filepath, current_filetype) - self:bind_sidebar_keys(codeblocks) - end, - }) + api.nvim_create_autocmd({ "BufEnter", "BufWritePost" }, { + group = self.augroup, + buffer = self.result_container.bufnr, + callback = function(ev) + codeblocks = parse_codeblocks(ev.buf, current_filepath, current_filetype) + self:bind_sidebar_keys(codeblocks) + end, + }) - api.nvim_create_autocmd("User", { - group = self.augroup, - pattern = VIEW_BUFFER_UPDATED_PATTERN, - callback = function() - if not Utils.is_valid_container(self.result_container) then return end - codeblocks = parse_codeblocks(self.result_container.bufnr, current_filepath, current_filetype) - self:bind_sidebar_keys(codeblocks) - end, - }) + api.nvim_create_autocmd("User", { + group = self.augroup, + pattern = VIEW_BUFFER_UPDATED_PATTERN, + callback = function() + if not Utils.is_valid_container(self.result_container) then return end + codeblocks = parse_codeblocks(self.result_container.bufnr, current_filepath, current_filetype) + self:bind_sidebar_keys(codeblocks) + end, + }) + end api.nvim_create_autocmd("BufLeave", { group = self.augroup, @@ -1848,8 +1420,6 @@ function Sidebar:on_mount(opts) self:render_input(opts.ask) self:render_selected_code() - local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr }) - if self.selected_code_container ~= nil then local selected_code_buf = self.selected_code_container.bufnr if selected_code_buf ~= nil then @@ -1859,7 +1429,10 @@ function Sidebar:on_mount(opts) api.nvim_buf_set_lines(selected_code_buf, 0, -1, false, lines) Utils.lock_buf(selected_code_buf) end - api.nvim_set_option_value("filetype", filetype, { buf = selected_code_buf }) + if self.code.bufnr and api.nvim_buf_is_valid(self.code.bufnr) then + local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr }) + api.nvim_set_option_value("filetype", filetype, { buf = selected_code_buf }) + end end end @@ -2013,81 +1586,25 @@ function Sidebar:is_sidebar_winid(winid) return false end -local function delete_last_n_chars(bufnr, n) - bufnr = bufnr or api.nvim_get_current_buf() - - local line_count = api.nvim_buf_line_count(bufnr) - - while n > 0 and line_count > 0 do - local last_line = api.nvim_buf_get_lines(bufnr, line_count - 1, line_count, false)[1] - - local total_chars_in_line = #last_line + 1 - - if total_chars_in_line > n then - local chars_to_keep = total_chars_in_line - n - 1 - 1 - local new_last_line = last_line:sub(1, chars_to_keep) - if new_last_line == "" then - api.nvim_buf_set_lines(bufnr, line_count - 1, line_count, false, {}) - line_count = line_count - 1 - else - api.nvim_buf_set_lines(bufnr, line_count - 1, line_count, false, { new_last_line }) - end - n = 0 - else - n = n - total_chars_in_line - api.nvim_buf_set_lines(bufnr, line_count - 1, line_count, false, {}) - line_count = line_count - 1 - end - end -end - ---@param content string concatenated content of the buffer ----@param opts? {focus?: boolean, scroll?: boolean, backspace?: integer, ignore_history?: boolean, callback?: fun(): nil} whether to focus the result view +---@param opts? {focus?: boolean, scroll?: boolean, backspace?: integer, callback?: fun(): nil} whether to focus the result view function Sidebar:update_content(content, opts) if not self.result_container or not self.result_container.bufnr then return end - opts = vim.tbl_deep_extend("force", { focus = false, scroll = true, stream = false, callback = nil }, opts or {}) - if not opts.ignore_history then - local chat_history = Path.history.load(self.code.bufnr) - content = self.render_history_content(chat_history) .. "-------\n\n" .. content - end - if opts.stream then - local function scroll_to_bottom() - local last_line = api.nvim_buf_line_count(self.result_container.bufnr) - - local current_lines = Utils.get_buf_lines(last_line - 1, last_line, self.result_container.bufnr) - - if #current_lines > 0 then - local last_line_content = current_lines[1] - local last_col = #last_line_content - xpcall( - function() api.nvim_win_set_cursor(self.result_container.winid, { last_line, last_col }) end, - function(err) return err end - ) - end - end - - vim.schedule(function() - if not Utils.is_valid_container(self.result_container) then return end - Utils.unlock_buf(self.result_container.bufnr) - if opts.backspace ~= nil and opts.backspace > 0 then - delete_last_n_chars(self.result_container.bufnr, opts.backspace) - end - scroll_to_bottom() - local lines = vim.split(content, "\n") - api.nvim_buf_call(self.result_container.bufnr, function() api.nvim_put(lines, "c", true, true) end) - Utils.lock_buf(self.result_container.bufnr) - api.nvim_set_option_value("filetype", "Avante", { buf = self.result_container.bufnr }) - if opts.scroll then scroll_to_bottom() end - if opts.callback ~= nil then opts.callback() end - end) - else - vim.defer_fn(function() + opts = vim.tbl_deep_extend("force", { focus = false, scroll = self.scroll, callback = nil }, opts or {}) + local history_content = self.render_history_content(self.chat_history) + local contents = { history_content, content } + contents = vim.iter(contents):filter(function(item) return item ~= nil and item ~= "" end):totable() + content = table.concat(contents, "\n\n") + vim.defer_fn(function() + self:clear_state() + local f = function() if not Utils.is_valid_container(self.result_container) then return end local lines = vim.split(content, "\n") Utils.unlock_buf(self.result_container.bufnr) Utils.update_buffer_content(self.result_container.bufnr, lines) Utils.lock_buf(self.result_container.bufnr) api.nvim_set_option_value("filetype", "Avante", { buf = self.result_container.bufnr }) + vim.schedule(function() vim.cmd("redraw") end) if opts.focus and not self:is_focused_on_result() then --- set cursor to bottom of result view xpcall(function() api.nvim_set_current_win(self.result_container.winid) end, function(err) return err end) @@ -2096,14 +1613,13 @@ function Sidebar:update_content(content, opts) if opts.scroll then Utils.buf_scroll_to_end(self.result_container.bufnr) end if opts.callback ~= nil then opts.callback() end - end, 0) - end + end + f() + self:render_state() + end, 0) return self end --- Function to get current timestamp -local function get_timestamp() return os.date("%Y-%m-%d %H:%M:%S") end - ---@param timestamp string|osdate ---@param provider string ---@param model string @@ -2132,7 +1648,7 @@ local function render_chat_record_prefix(timestamp, provider, model, request, se .. "\n```" end - return res .. "\n\n> " .. request:gsub("\n", "\n> "):gsub("([%w-_]+)%b[]", "`%0`") .. "\n\n" + return res .. "\n\n> " .. request:gsub("\n", "\n> "):gsub("([%w-_]+)%b[]", "`%0`") end local function calculate_config_window_position() @@ -2157,58 +1673,71 @@ function Sidebar:get_layout() return vim.tbl_contains({ "left", "right" }, calculate_config_window_position()) and "vertical" or "horizontal" end +---@param message avante.HistoryMessage +---@param ctx table +---@return string | nil +local function render_message(message, ctx) + if message.visible == false then return nil end + local text = Utils.message_to_text(message) + if text == "" then return nil end + if message.is_user_submission then + ctx.selected_filepaths = message.selected_filepaths + local prefix = render_chat_record_prefix( + message.timestamp, + message.provider, + message.model, + text, + message.selected_filepaths, + message.selected_code + ) + return prefix + end + if message.message.role == "user" then + local lines = vim.split(text, "\n") + lines = vim.iter(lines):map(function(line) return "> " .. line end):totable() + text = table.concat(lines, "\n") + return text + end + if message.message.role == "assistant" then + local transformed = transform_result_content(text, ctx.prev_filepath) + ctx.prev_filepath = transformed.current_filepath + local displayed_content = generate_display_content(transformed) + return displayed_content + end + return "" +end + ---@param history avante.ChatHistory ---@return string function Sidebar.render_history_content(history) - local added_breakline = false - local content = "" - for idx, entry in ipairs(history.entries) do - if entry.visible == false then goto continue end - if entry.reset_memory then - content = content .. "***MEMORY RESET***\n\n" - if idx < #history.entries and not added_breakline then - added_breakline = true - content = content .. "-------\n\n" - end - goto continue + local history_messages = Utils.get_history_messages(history) + local ctx = {} + local group = {} + for _, message in ipairs(history_messages) do + local text = render_message(message, ctx) + if text == nil then goto continue end + if message.is_user_submission then table.insert(group, {}) end + local last_item = group[#group] + if last_item == nil then + table.insert(group, {}) + last_item = group[#group] end - local selected_filepaths = entry.selected_filepaths - if not selected_filepaths and entry.selected_file ~= nil then - selected_filepaths = { entry.selected_file.filepath } - end - if entry.request and entry.request ~= "" then - if idx ~= 1 and not added_breakline then - added_breakline = true - content = content .. "-------\n\n" - end - local prefix = render_chat_record_prefix( - entry.timestamp, - entry.provider, - entry.model, - entry.request or "", - selected_filepaths or {}, - entry.selected_code - ) - content = content .. prefix - end - if entry.response and entry.response ~= "" then - content = content .. entry.response .. "\n\n" - if idx < #history.entries then - added_breakline = true - content = content .. "-------\n\n" - end - else - added_breakline = false + if message.message.role == "assistant" and not message.just_for_display and text:sub(1, 2) ~= "\n\n" then + text = "\n\n" .. text end + table.insert(last_item, text) ::continue:: end - return content + local pieces = {} + for _, item in ipairs(group) do + table.insert(pieces, table.concat(item, "")) + end + return table.concat(pieces, "\n\n" .. RESP_SEPARATOR .. "\n\n") .. "\n\n" end function Sidebar:update_content_with_history() self:reload_chat_history() - local content = self.render_history_content(self.chat_history) - self:update_content(content, { ignore_history = true }) + self:update_content("") end ---@return string, integer @@ -2249,11 +1778,11 @@ end function Sidebar:clear_history(args, cb) local chat_history = Path.history.load(self.code.bufnr) if next(chat_history) ~= nil then - chat_history.entries = {} + chat_history.messages = {} Path.history.save(self.code.bufnr, chat_history) self:update_content( "Chat history cleared", - { ignore_history = true, focus = false, scroll = false, callback = function() self:focus_input() end } + { focus = false, scroll = false, callback = function() self:focus_input() end } ) if cb then cb(args) end else @@ -2264,24 +1793,115 @@ function Sidebar:clear_history(args, cb) end end +function Sidebar:clear_state() + if self.state_extmark_id then + pcall(api.nvim_buf_del_extmark, self.result_container.bufnr, self.state_ns_id, self.state_extmark_id) + end + self.state_extmark_id = nil + self.state_spinner_idx = 1 + if self.state_timer then self.state_timer:stop() end +end + +function Sidebar:render_state() + if not Utils.is_valid_container(self.result_container) then return end + if not self.current_state then return end + local lines = vim.api.nvim_buf_get_lines(self.result_container.bufnr, 0, -1, false) + if self.state_extmark_id then + api.nvim_buf_del_extmark(self.result_container.bufnr, self.state_ns_id, self.state_extmark_id) + end + local spinner_char = self.state_spinner_chars[self.state_spinner_idx] + self.state_spinner_idx = (self.state_spinner_idx % #self.state_spinner_chars) + 1 + local hl = "AvanteStateSpinnerGenerating" + if self.current_state == "tool calling" then hl = "AvanteStateSpinnerToolCalling" end + if self.current_state == "failed" then hl = "AvanteStateSpinnerFailed" end + if self.current_state == "succeeded" then hl = "AvanteStateSpinnerSucceeded" end + if self.current_state == "searching" then hl = "AvanteStateSpinnerSearching" end + if self.current_state == "thinking" then hl = "AvanteStateSpinnerThinking" end + if self.current_state ~= "generating" and self.current_state ~= "tool calling" then spinner_char = "" end + local virt_line + if spinner_char == "" then + virt_line = " " .. self.current_state .. " " + else + virt_line = " " .. spinner_char .. " " .. self.current_state .. " " + end + + local win_width = api.nvim_win_get_width(self.result_container.winid) + local padding = math.floor((win_width - vim.fn.strdisplaywidth(virt_line)) / 2) + local centered_virt_lines = { + { { string.rep(" ", padding) }, { virt_line, hl } }, + } + + self.state_extmark_id = api.nvim_buf_set_extmark(self.result_container.bufnr, self.state_ns_id, #lines - 2, 0, { + virt_lines = centered_virt_lines, + hl_eol = true, + hl_mode = "combine", + }) + self.state_timer = vim.defer_fn(function() self:render_state() end, 160) +end + function Sidebar:new_chat(args, cb) local history = Path.history.new(self.code.bufnr) Path.history.save(self.code.bufnr, history) self:reload_chat_history() - self:update_content( - "New chat", - { ignore_history = true, focus = false, scroll = false, callback = function() self:focus_input() end } - ) + self:update_content("New chat", { focus = false, scroll = false, callback = function() self:focus_input() end }) if cb then cb(args) end end +---@param messages avante.HistoryMessage | avante.HistoryMessage[] +function Sidebar:add_history_messages(messages) + local history_messages = Utils.get_history_messages(self.chat_history) + messages = vim.islist(messages) and messages or { messages } + for _, message in ipairs(messages) do + if message.is_user_submission then + message.provider = Config.provider + message.model = Config.get_provider_config(Config.provider).model + end + local idx = nil + for idx_, message_ in ipairs(history_messages) do + if message_.uuid == message.uuid then + idx = idx_ + break + end + end + if idx ~= nil then + history_messages[idx] = message + else + table.insert(history_messages, message) + end + end + self.chat_history.messages = history_messages + Path.history.save(self.code.bufnr, self.chat_history) + if self.chat_history.title == "untitled" and #messages > 0 then + Llm.summarize_chat_thread_title(messages[1].message.content, function(title) + self:reload_chat_history() + if title then self.chat_history.title = title end + Path.history.save(self.code.bufnr, self.chat_history) + end) + end + local last_message = messages[#messages] + if last_message then + local content = last_message.message.content + if type(content) == "table" and content[1].type == "tool_use" then + self.current_state = "tool calling" + elseif type(content) == "table" and content[1].type == "thinking" then + self.current_state = "thinking" + elseif type(content) == "table" and content[1].type == "redacted_thinking" then + self.current_state = "thinking" + else + self.current_state = "generating" + end + end + self:update_content("") +end + ---@param messages AvanteLLMMessage | AvanteLLMMessage[] ---@param options {visible?: boolean} function Sidebar:add_chat_history(messages, options) options = options or {} - local timestamp = get_timestamp() messages = vim.islist(messages) and messages or { messages } self:reload_chat_history() + local is_first_user = true + local history_messages = {} for _, message in ipairs(messages) do local content = message.content if message.role == "system" and type(content) == "string" then @@ -2289,61 +1909,22 @@ function Sidebar:add_chat_history(messages, options) self.chat_history.system_prompt = content goto continue end - table.insert(self.chat_history.entries, { - timestamp = timestamp, - provider = Config.provider, - model = Config.get_provider_config(Config.provider).model, - request = message.role == "user" and message.content or "", - response = message.role == "assistant" and message.content or "", - original_response = message.role == "assistant" and message.content or "", - selected_filepaths = nil, - selected_code = nil, - reset_memory = false, - visible = options.visible, - }) + local history_message = HistoryMessage:new(message) + if message.role == "user" and is_first_user then + is_first_user = false + history_message.is_user_submission = true + history_message.provider = Config.provider + history_message.model = Config.get_provider_config(Config.provider).model + end + table.insert(history_messages, history_message) ::continue:: end - Path.history.save(self.code.bufnr, self.chat_history) - if options.visible then self:update_content_with_history() end - if self.chat_history.title == "untitled" and #messages > 0 then - Llm.summarize_chat_thread_title(messages[1].content, function(title) - self:reload_chat_history() - if title then self.chat_history.title = title end - Path.history.save(self.code.bufnr, self.chat_history) - end) - end -end - -function Sidebar:reset_memory(args, cb) - local chat_history = Path.history.load(self.code.bufnr) - if next(chat_history) ~= nil then - table.insert(chat_history, { - timestamp = get_timestamp(), - provider = Config.provider, - model = Config.get_provider_config(Config.provider).model, - request = "", - response = "", - original_response = "", - selected_file = nil, - selected_code = nil, - reset_memory = true, - }) - Path.history.save(self.code.bufnr, chat_history) - self:reload_chat_history() - local history_content = self.render_history_content(chat_history) - self:update_content(history_content, { - focus = false, - scroll = true, - callback = function() self:focus_input() end, - }) - if cb then cb(args) end - else - self:reload_chat_history() - self:update_content( - "Chat history is already empty", - { focus = false, scroll = false, callback = function() self:focus_input() end } - ) + if options.visible ~= nil then + for _, history_message in ipairs(history_messages) do + history_message.visible = options.visible + end end + self:add_history_messages(history_messages) end function Sidebar:create_selected_code_container() @@ -2382,8 +1963,6 @@ function Sidebar:create_selected_code_container() end end -local generating_text = "**Generating response ...**\n" - local hint_window = nil function Sidebar:reload_chat_history() @@ -2391,6 +1970,22 @@ function Sidebar:reload_chat_history() self.chat_history = Path.history.load(self.code.bufnr) end +---@return avante.HistoryMessage[] +function Sidebar:get_history_messages_for_api() + local history_messages = Utils.get_history_messages(self.chat_history) + self.chat_history.messages = history_messages + + if self.chat_history.memory then + history_messages = {} + for i = #self.chat_history.messages, 1, -1 do + local message = self.chat_history.messages[i] + if message.uuid == self.chat_history.memory.last_message_uuid then break end + table.insert(history_messages, 1, message) + end + end + return vim.iter(history_messages):filter(function(message) return not message.just_for_display end):totable() +end + ---@param opts AskOptions function Sidebar:create_input_container(opts) if self.input_container then self.input_container:unmount() end @@ -2400,9 +1995,8 @@ function Sidebar:create_input_container(opts) if self.chat_history == nil then self:reload_chat_history() end ---@param request string - ---@param summarize_memory boolean ---@param cb fun(opts: AvanteGeneratePromptsOptions): nil - local function get_generate_prompts_options(request, summarize_memory, cb) + local function get_generate_prompts_options(request, cb) local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr }) local file_ext = nil @@ -2434,16 +2028,7 @@ function Sidebar:create_input_container(opts) end end - local entries = Utils.history.filter_active_entries(self.chat_history.entries) - - if self.chat_history.memory then - entries = vim - .iter(entries) - :filter(function(entry) return entry.timestamp > self.chat_history.memory.last_summarized_timestamp end) - :totable() - end - - local history_messages = Utils.history.entries_to_llm_messages(entries) + local history_messages = self:get_history_messages_for_api() local tools = vim.deepcopy(LLMTools.get_tools(request, history_messages)) table.insert(tools, { @@ -2476,11 +2061,6 @@ function Sidebar:create_input_container(opts) returns = {}, }) - local mode = "planning" - if Config.behaviour.enable_cursor_planning_mode then mode = "cursor-planning" end - - if Config.behaviour.enable_claude_text_editor_tool_mode then mode = "claude-text-editor-tool" end - local selected_filepaths = self.file_selector.selected_filepaths or {} ---@type AvanteGeneratePromptsOptions @@ -2493,35 +2073,30 @@ function Sidebar:create_input_container(opts) history_messages = history_messages, code_lang = filetype, selected_code = selected_code, - instructions = request, - mode = mode, + -- instructions = request, tools = tools, } if self.chat_history.system_prompt then prompts_opts.prompt_opts = { system_prompt = self.chat_history.system_prompt, - messages = {}, + messages = history_messages, } end if self.chat_history.memory then prompts_opts.memory = self.chat_history.memory.content end - if not summarize_memory or #history_messages < 8 then - cb(prompts_opts) - return - end - - prompts_opts.history_messages = vim.list_slice(prompts_opts.history_messages, 5) - - Llm.summarize_memory(self.code.bufnr, self.chat_history, nil, function(memory) - if memory then prompts_opts.memory = memory.content end - cb(prompts_opts) - end) + cb(prompts_opts) end ---@param request string local function handle_submit(request) + if self.is_generating then + self:add_history_messages({ + HistoryMessage:new({ role = "user", content = request }), + }) + return + end if request:match("@codebase") and not vim.fn.expand("%:e") then self:update_content("Please open a file first before using @codebase", { focus = false, scroll = false }) return @@ -2557,7 +2132,7 @@ function Sidebar:create_input_container(opts) local model = Config.has_provider(Config.provider) and Config.get_provider_config(Config.provider).model or "default" - local timestamp = get_timestamp() + local timestamp = Utils.get_timestamp() local selected_filepaths = self.file_selector:get_selected_filepaths() @@ -2578,32 +2153,23 @@ function Sidebar:create_input_container(opts) --- prevent the cursor from jumping to the bottom of the --- buffer at the beginning self:update_content("", { focus = true, scroll = false }) - self:update_content(content_prefix .. generating_text) - - local original_response = "" - local waiting_for_breakline = false - local transformed_response = "" - local displayed_response = "" - local current_path = "" - - local is_first_chunk = true - local scroll = true + self:update_content(content_prefix) ---stop scroll when user presses j/k keys local function on_j() - scroll = false + self.scroll = false ---perform scroll vim.cmd("normal! j") end local function on_k() - scroll = false + self.scroll = false ---perform scroll vim.cmd("normal! k") end local function on_G() - scroll = true + self.scroll = true ---perform scroll vim.cmd("normal! G") end @@ -2615,82 +2181,42 @@ function Sidebar:create_input_container(opts) ---@type AvanteLLMStartCallback local function on_start(_) end - ---@type AvanteLLMChunkCallback - local function on_chunk(chunk) - self.is_generating = true + ---@param messages avante.HistoryMessage[] + local function on_messages_add(messages) self:add_history_messages(messages) end - local remove_line = [[\033[1A\033[K]] - if chunk:sub(1, #remove_line) == remove_line then - chunk = chunk:sub(#remove_line + 1) - local lines = vim.split(transformed_response, "\n") - local idx = #lines - while idx > 0 and lines[idx] == "" do - idx = idx - 1 - end - if idx == 1 then - lines = {} - else - lines = vim.list_slice(lines, 1, idx - 1) - end - transformed_response = table.concat(lines, "\n") - else - original_response = original_response .. chunk - end - - local selected_files = self.file_selector:get_selected_files_contents() - - local transformed_response_ - if waiting_for_breakline and chunk and chunk:sub(1, 1) ~= "\n" then - transformed_response_ = transformed_response .. "\n" .. chunk - else - transformed_response_ = transformed_response .. chunk - end - - local transformed = transform_result_content(selected_files, transformed_response_, current_path) - waiting_for_breakline = transformed.waiting_for_breakline - transformed_response = transformed.content - if transformed.current_filepath and transformed.current_filepath ~= "" then - current_path = transformed.current_filepath - end - local cur_displayed_response = generate_display_content(transformed) - if is_first_chunk then - is_first_chunk = false - self:update_content(content_prefix .. chunk, { scroll = scroll }) - displayed_response = cur_displayed_response - return - end - local suffix = get_display_content_suffix(transformed) - self:update_content(content_prefix .. cur_displayed_response .. suffix, { scroll = scroll }) - vim.schedule(function() vim.cmd("redraw") end) - displayed_response = cur_displayed_response + ---@param state avante.GenerateState + local function on_state_change(state) + self:clear_state() + self.current_state = state + self:render_state() end - local tool_use_log_history = {} + local save_history = Utils.debounce(function() Path.history.save(self.code.bufnr, self.chat_history) end, 3000) ---@param tool_id string ---@param tool_name string ---@param log string ---@param state AvanteLLMToolUseState local function on_tool_log(tool_id, tool_name, log, state) - if state == "generating" then - if tool_use_log_history[tool_id] then return end - tool_use_log_history[tool_id] = true + if state == "generating" then on_state_change("tool calling") end + local tool_message = vim.iter(self.chat_history.messages):find(function(message) + if message.message.role ~= "assistant" then return false end + local content = message.message.content + if type(content) ~= "table" then return false end + if content[1].type ~= "tool_use" then return false end + if content[1].id ~= tool_id then return false end + return true + end) + if not tool_message then + Utils.debug("tool_message not found", tool_id, tool_name) + return end - if transformed_response:sub(-1) ~= "\n" then transformed_response = transformed_response .. "\n" end - transformed_response = transformed_response .. "[" .. tool_name .. "]: " .. log .. "\n" - local breakline = "" - if displayed_response:sub(-1) ~= "\n" then breakline = "\n" end - displayed_response = displayed_response .. breakline .. "[" .. tool_name .. "]: " .. log .. "\n" - self:update_content(content_prefix .. displayed_response, { - scroll = scroll, - }) - end - - ---@param tool_use AvantePartialLLMToolUse - local function on_partial_tool_use(tool_use) - if not tool_use.name then return end - if not tool_use.id then return end - on_tool_log(tool_use.id, tool_use.name, "calling...", tool_use.state) + local tool_use_logs = tool_message.tool_use_logs or {} + local content = string.format("[%s]: %s", tool_name, log) + table.insert(tool_use_logs, content) + tool_message.tool_use_logs = tool_use_logs + save_history() + self:update_content("") end ---@type AvanteLLMStopCallback @@ -2705,22 +2231,25 @@ function Sidebar:create_input_container(opts) end) if stop_opts.error ~= nil then - self:update_content( - content_prefix .. displayed_response .. "\n\nError: " .. vim.inspect(stop_opts.error), - { scroll = scroll } - ) + local msg_content = stop_opts.error + if type(msg_content) ~= "string" then msg_content = vim.inspect(msg_content) end + self:add_history_messages({ + HistoryMessage:new({ + role = "assistant", + content = "\n\nError: " .. msg_content, + }, { + just_for_display = true, + }), + }) + on_state_change("failed") return end - self:update_content( - content_prefix - .. displayed_response - .. "\n\n**Generation complete!** Please review the code suggestions above.\n", - { - scroll = scroll, - callback = function() api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN }) end, - } - ) + on_state_change("succeeded") + + self:update_content("", { + callback = function() api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN }) end, + }) vim.defer_fn(function() if Utils.is_valid_container(self.result_container, true) and Config.behaviour.jump_result_buffer_on_finish then @@ -2729,19 +2258,6 @@ function Sidebar:create_input_container(opts) if Config.behaviour.auto_apply_diff_after_generation then self:apply(false) end end, 0) - -- Save chat history - self.chat_history.entries = self.chat_history.entries or {} - table.insert(self.chat_history.entries, { - timestamp = timestamp, - provider = Config.provider, - model = model, - request = request, - response = displayed_response, - original_response = original_response, - selected_filepaths = selected_filepaths, - selected_code = selected_code, - tool_histories = stop_opts.tool_histories, - }) if self.chat_history.title == "untitled" then Llm.summarize_chat_thread_title(request, function(title) if title then self.chat_history.title = title end @@ -2752,40 +2268,46 @@ function Sidebar:create_input_container(opts) end end - get_generate_prompts_options(request, true, function(generate_prompts_options) + if request and request ~= "" then + self:add_history_messages({ + HistoryMessage:new({ + role = "user", + content = request, + }, { + is_user_submission = true, + selected_filepaths = selected_filepaths, + selected_code = selected_code, + }), + }) + end + + get_generate_prompts_options(request, function(generate_prompts_options) ---@type AvanteLLMStreamOptions ---@diagnostic disable-next-line: assign-type-mismatch local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, { on_start = on_start, - on_chunk = on_chunk, on_stop = on_stop, on_tool_log = on_tool_log, - on_partial_tool_use = on_partial_tool_use, + on_messages_add = on_messages_add, + on_state_change = on_state_change, + get_history_messages = function() return self:get_history_messages_for_api() end, session_ctx = {}, }) + ---@param dropped_history_messages avante.HistoryMessage[] local function on_memory_summarize(dropped_history_messages) - local entries = Utils.history.filter_active_entries(self.chat_history.entries) - - if self.chat_history.memory then - entries = vim - .iter(entries) - :filter(function(entry) return entry.timestamp > self.chat_history.memory.last_summarized_timestamp end) - :totable() - end - - entries = vim.list_slice(entries, 1, #dropped_history_messages) - - Llm.summarize_memory(self.code.bufnr, self.chat_history, entries, function(memory) + Llm.summarize_memory(self.code.bufnr, self.chat_history, dropped_history_messages, function(memory) if memory then stream_options.memory = memory.content end - stream_options.history_messages = - vim.list_slice(stream_options.history_messages, #dropped_history_messages + 1) + stream_options.history_messages = self:get_history_messages_for_api() + -- Utils.debug("dropping history messages", dropped_history_messages) + -- Utils.debug("history messages", stream_options.history_messages) Llm.stream(stream_options) end) end stream_options.on_memory_summarize = on_memory_summarize + on_state_change("generating") Llm.stream(stream_options) end) end @@ -2935,8 +2457,8 @@ function Sidebar:create_input_container(opts) if Config.behaviour.enable_token_counting then local input_value = table.concat(api.nvim_buf_get_lines(self.input_container.bufnr, 0, -1, false), "\n") - get_generate_prompts_options(input_value, false, function(generate_prompts_options) - local tokens = Llm.calculate_tokens(generate_prompts_options) + get_generate_prompts_options(input_value, function(generate_prompts_options) + local tokens = Llm.calculate_tokens(generate_prompts_options) + Utils.tokens.calculate_tokens(input_value) hint_text = "Tokens: " .. tostring(tokens) .. "; " .. hint_text show() end) @@ -3058,8 +2580,8 @@ function Sidebar:get_selected_files_size() local selected_files_max_lines_count = 10 - local selected_files = self.file_selector:get_selected_filepaths() - local selected_files_size = #selected_files + local selected_filepaths = self.file_selector:get_selected_filepaths() + local selected_files_size = #selected_filepaths selected_files_size = math.min(selected_files_size, selected_files_max_lines_count) return selected_files_size @@ -3129,16 +2651,18 @@ function Sidebar:render(opts) self:update_content_with_history() - -- reset states when buffer is closed - api.nvim_buf_attach(self.code.bufnr, false, { - on_detach = function(_, _) - vim.schedule(function() - local bufnr = api.nvim_win_get_buf(self.code.winid) - self.code.bufnr = bufnr - self:reload_chat_history() - end) - end, - }) + if self.code.bufnr and api.nvim_buf_is_valid(self.code.bufnr) then + -- reset states when buffer is closed + api.nvim_buf_attach(self.code.bufnr, false, { + on_detach = function(_, _) + vim.schedule(function() + local bufnr = api.nvim_win_get_buf(self.code.winid) + self.code.bufnr = bufnr + self:reload_chat_history() + end) + end, + }) + end self:create_selected_code_container() diff --git a/lua/avante/templates/agentic.avanterules b/lua/avante/templates/agentic.avanterules new file mode 100644 index 0000000..0290953 --- /dev/null +++ b/lua/avante/templates/agentic.avanterules @@ -0,0 +1,10 @@ +{% extends "base.avanterules" %} +{% block extra_prompt %} +Always reply to the user in the same language they are using. + +Don't just provide code suggestions, use the `replace_in_file` tool to help users fulfill their needs. + +After the tool call is complete, please do not output the entire file content. + +Before calling the tool, be sure to explain the reason for calling the tool. +{% endblock %} diff --git a/lua/avante/templates/claude-text-editor-tool.avanterules b/lua/avante/templates/claude-text-editor-tool.avanterules deleted file mode 100644 index a95c245..0000000 --- a/lua/avante/templates/claude-text-editor-tool.avanterules +++ /dev/null @@ -1,6 +0,0 @@ -{% extends "base.avanterules" %} -{% block extra_prompt %} -Always reply to the user in the same language they are using. - -Don't just provide code suggestions, use the `str_replace` tool to help users fulfill their needs. -{% endblock %} diff --git a/lua/avante/templates/cursor-applying.avanterules b/lua/avante/templates/cursor-applying.avanterules deleted file mode 100644 index 64be9b6..0000000 --- a/lua/avante/templates/cursor-applying.avanterules +++ /dev/null @@ -1,4 +0,0 @@ -You are a coding assistant that helps merge code updates, ensuring every modification is fully integrated. - -{% block custom_prompt %} -{% endblock %} diff --git a/lua/avante/templates/cursor-planning.avanterules b/lua/avante/templates/cursor-planning.avanterules deleted file mode 100644 index e610499..0000000 --- a/lua/avante/templates/cursor-planning.avanterules +++ /dev/null @@ -1,46 +0,0 @@ -{% extends "base.avanterules" %} -{%- if ask %} -{% block extra_prompt %} -You are an intelligent programmer, powered by {{ model_name }}. You are happy to help answer any questions that the user has (usually they will be about coding). - -1. When the user is asking for edits to their code, please output a simplified version of the code block that highlights the changes necessary and adds comments to indicate where unchanged code has been skipped. For example: -```language:path/to/file -// ... existing code ... -{% raw -%} -{{ edit_1 }} -{%- endraw %} -// ... existing code ... -{% raw -%} -{{ edit_2 }} -{%- endraw %} -// ... existing code ... -``` -The user can see the entire file, so they prefer to only read the updates to the code. Often this will mean that the start/end of the file will be skipped, but that's okay! Rewrite the entire file only if specifically requested. Always provide a brief explanation of the updates, unless the user specifically requests only the code. - -These edit codeblocks are also read by a less intelligent language model, colloquially called the apply model, to update the file. To help specify the edit to the apply model, you will be very careful when generating the codeblock to not introduce ambiguity. You will specify all unchanged regions (code and comments) of the file with "// … existing code …" comment markers. This will ensure the apply model will not delete existing unchanged code or comments when editing the file. You will not mention the apply model. - -2. Do not lie or make up facts. - -3. If a user messages you in a foreign language, please respond in that language. - -4. Format your response in markdown. - -5. When writing out new code blocks, please specify the language ID after the initial backticks, like so: -```python -{% raw -%} -{{ code }} -{%- endraw %} -``` - -6. When writing out code blocks for an existing file, please also specify the file path after the initial backticks and restate the method / class your codeblock belongs to, like so: -```language:some/other/file -function AIChatHistory() { - ... - {% raw -%} - {{ code }} - {%- endraw %} - ... -} -``` -{% endblock %} -{%- endif %} diff --git a/lua/avante/templates/planning.avanterules b/lua/avante/templates/legacy.avanterules similarity index 100% rename from lua/avante/templates/planning.avanterules rename to lua/avante/templates/legacy.avanterules diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 45c4b91..dd1b5ce 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -75,15 +75,32 @@ vim.g.avante_login = vim.g.avante_login ---@field on_start AvanteLLMStartCallback ---@field on_chunk AvanteLLMChunkCallback ---@field on_stop AvanteLLMStopCallback ----@field on_partial_tool_use? fun(tool_use: AvantePartialLLMToolUse): nil +---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil +---@field on_state_change? fun(state: avante.GenerateState): nil --- ---@alias AvanteLLMMessageContentItem string | { type: "text", text: string } | { type: "image", source: { type: "base64", media_type: string, data: string } } | { type: "tool_use", name: string, id: string, input: any } | { type: "tool_result", tool_use_id: string, content: string, is_error?: boolean } | { type: "thinking", thinking: string, signature: string } | { type: "redacted_thinking", data: string } ---- + ---@alias AvanteLLMMessageContent AvanteLLMMessageContentItem[] | string ---- + ---@class AvanteLLMMessage ---@field role "user" | "assistant" ---@field content AvanteLLMMessageContent + +---@class avante.HistoryMessage +---@field message AvanteLLMMessage +---@field timestamp string +---@field state avante.HistoryMessageState +---@field uuid string | nil +---@field displayed_content string | nil +---@field visible boolean | nil +---@field is_context boolean | nil +---@field is_user_submission boolean | nil +---@field provider string | nil +---@field model string | nil +---@field selected_code AvanteSelectedCode | nil +---@field selected_filepaths string[] | nil +---@field tool_use_logs string[] | nil +---@field just_for_display boolean | nil --- ---@class AvanteLLMToolResult ---@field tool_name string @@ -96,8 +113,7 @@ vim.g.avante_login = vim.g.avante_login ---@field messages AvanteLLMMessage[] ---@field image_paths? string[] ---@field tools? AvanteLLMTool[] ----@field tool_histories? AvanteLLMToolHistory[] ----@field dropped_history_messages? AvanteLLMMessage[] +---@field dropped_history_messages? avante.HistoryMessage[] --- ---@class AvanteGeminiMessage ---@field role "user" @@ -236,19 +252,18 @@ vim.g.avante_login = vim.g.avante_login ---@class AvanteLLMRedactedThinkingBlock ---@field data string --- +---@alias avante.HistoryMessageState "generating" | "generated" +--- ---@class AvantePartialLLMToolUse ---@field name string ---@field id string ---@field partial_json table ----@field state "generating" | "generated" +---@field state avante.HistoryMessageState --- ---@class AvanteLLMToolUse ---@field name string ---@field id string ----@field input_json string ----@field response_contents? string[] ----@field thinking_blocks? AvanteLLMThinkingBlock[] ----@field redacted_thinking_blocks? AvanteLLMRedactedThinkingBlock[] +---@field input any --- ---@class AvanteLLMStartCallbackOptions ---@field usage? AvanteLLMUsage @@ -257,10 +272,8 @@ vim.g.avante_login = vim.g.avante_login ---@field reason "complete" | "tool_use" | "error" | "rate_limit" | "cancelled" ---@field error? string | table ---@field usage? AvanteLLMUsage ----@field tool_use_list? AvanteLLMToolUse[] ---@field retry_after? integer ---@field headers? table ----@field tool_histories? AvanteLLMToolHistory[] --- ---@alias AvanteStreamParser fun(self: AvanteProviderFunctor, ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil ---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil @@ -303,7 +316,7 @@ vim.g.avante_login = vim.g.avante_login ---@field parse_response AvanteResponseParser ---@field build_bedrock_payload AvanteBedrockPayloadBuilder --- ----@alias AvanteLlmMode "planning" | "editing" | "suggesting" | "cursor-planning" | "cursor-applying" | "claude-text-editor-tool" +---@alias AvanteLlmMode avante.Mode | "editing" | "suggesting" --- ---@class AvanteSelectedCode ---@field path string @@ -324,7 +337,7 @@ vim.g.avante_login = vim.g.avante_login ---@field selected_files AvanteSelectedFile[] | nil ---@field selected_filepaths string[] | nil ---@field diagnostics string | nil ----@field history_messages AvanteLLMMessage[] | nil +---@field history_messages avante.HistoryMessage[] | nil ---@field memory string | nil --- ---@class AvanteGeneratePromptsOptions: AvanteTemplateOptions @@ -332,7 +345,6 @@ vim.g.avante_login = vim.g.avante_login ---@field mode? AvanteLlmMode ---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil ---@field tools? AvanteLLMTool[] ----@field tool_histories? AvanteLLMToolHistory[] ---@field original_code? string ---@field update_snippets? string[] ---@field prompt_opts? AvantePromptOptions @@ -342,9 +354,10 @@ vim.g.avante_login = vim.g.avante_login ---@field tool_result? AvanteLLMToolResult ---@field tool_use? AvanteLLMToolUse --- ----@alias AvanteLLMMemorySummarizeCallback fun(dropped_history_messages: AvanteLLMMessage[]): nil +---@alias AvanteLLMMemorySummarizeCallback fun(dropped_history_messages: avante.HistoryMessage[]): nil --- ---@alias AvanteLLMToolUseState "generating" | "generated" | "running" | "succeeded" | "failed" +---@alias avante.GenerateState "generating" | "tool calling" | "failed" | "succeeded" | "cancelled" | "searching" | "thinking" --- ---@class AvanteLLMStreamOptions: AvanteGeneratePromptsOptions ---@field on_start AvanteLLMStartCallback @@ -352,7 +365,9 @@ vim.g.avante_login = vim.g.avante_login ---@field on_stop AvanteLLMStopCallback ---@field on_memory_summarize? AvanteLLMMemorySummarizeCallback ---@field on_tool_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil ----@field on_partial_tool_use? fun(tool_use: AvantePartialLLMToolUse): nil +---@field get_history_messages? fun(): avante.HistoryMessage[] +---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil +---@field on_state_change? fun(state: avante.GenerateState): nil --- ---@alias AvanteLLMToolFunc fun( --- input: T, @@ -400,15 +415,14 @@ vim.g.avante_login = vim.g.avante_login ---@field original_response string ---@field selected_file {filepath: string}? ---@field selected_code AvanteSelectedCode | nil ----@field reset_memory boolean? ---@field selected_filepaths string[] | nil ---@field visible boolean? ----@field tool_histories? AvanteLLMToolHistory[] --- ---@class avante.ChatHistory ---@field title string ---@field timestamp string ----@field entries avante.ChatHistoryEntry[] +---@field messages avante.HistoryMessage[] | nil +---@field entries avante.ChatHistoryEntry[] | nil ---@field memory avante.ChatMemory | nil ---@field filename string ---@field system_prompt string | nil @@ -416,6 +430,7 @@ vim.g.avante_login = vim.g.avante_login ---@class avante.ChatMemory ---@field content string ---@field last_summarized_timestamp string +---@field last_message_uuid string | nil --- ---@class avante.CurlOpts ---@field provider AvanteProviderFunctor @@ -427,7 +442,7 @@ vim.g.avante_login = vim.g.avante_login ---@field content string ---@field uri string --- ----@alias AvanteSlashCommandBuiltInName "clear" | "help" | "lines" | "reset" | "commit" | "new" +---@alias AvanteSlashCommandBuiltInName "clear" | "help" | "lines" | "commit" | "new" ---@alias AvanteSlashCommandCallback fun(self: avante.Sidebar, args: string, cb?: fun(args: string): nil): nil ---@class AvanteSlashCommand ---@field name AvanteSlashCommandBuiltInName | string diff --git a/lua/avante/ui/selector/providers/telescope.lua b/lua/avante/ui/selector/providers/telescope.lua index 7b7373e..c259d02 100644 --- a/lua/avante/ui/selector/providers/telescope.lua +++ b/lua/avante/ui/selector/providers/telescope.lua @@ -80,7 +80,7 @@ function M.show(selector) selector.on_select(selected_item_ids) - actions.close(prompt_bufnr) + pcall(actions.close, prompt_bufnr) end) return true end, diff --git a/lua/avante/utils/history.lua b/lua/avante/utils/history.lua deleted file mode 100644 index dea77c3..0000000 --- a/lua/avante/utils/history.lua +++ /dev/null @@ -1,87 +0,0 @@ -local Utils = require("avante.utils") -local Config = require("avante.config") - ----@class avante.utils.history -local M = {} - ----@param entries avante.ChatHistoryEntry[] ----@return avante.ChatHistoryEntry[] -function M.filter_active_entries(entries) - local entries_ = {} - - for i = #entries, 1, -1 do - local entry = entries[i] - if entry.reset_memory then break end - table.insert(entries_, 1, entry) - end - - return entries_ -end - ----@param entries avante.ChatHistoryEntry[] ----@return AvanteLLMMessage[] -function M.entries_to_llm_messages(entries) - local current_provider_name = Config.provider - local messages = {} - for _, entry in ipairs(entries) do - if entry.selected_filepaths ~= nil and #entry.selected_filepaths > 0 then - local user_content = "SELECTED FILES:\n\n" - for _, filepath in ipairs(entry.selected_filepaths) do - user_content = user_content .. filepath .. "\n" - end - table.insert(messages, { role = "user", content = user_content }) - end - if entry.selected_code ~= nil then - local user_content_ = "SELECTED CODE:\n\n```" - .. (entry.selected_code.file_type or "") - .. (entry.selected_code.path and ":" .. entry.selected_code.path or "") - .. "\n" - .. entry.selected_code.content - .. "\n```\n\n" - table.insert(messages, { role = "user", content = user_content_ }) - end - if entry.request ~= nil and entry.request ~= "" then - table.insert(messages, { role = "user", content = entry.request }) - end - if entry.tool_histories ~= nil and #entry.tool_histories > 0 and entry.provider == current_provider_name then - for _, tool_history in ipairs(entry.tool_histories) do - local assistant_content = {} - if tool_history.tool_use ~= nil then - if tool_history.tool_use.response_contents ~= nil then - for _, response_content in ipairs(tool_history.tool_use.response_contents) do - table.insert(assistant_content, { type = "text", text = response_content }) - end - end - table.insert(assistant_content, { - type = "tool_use", - name = tool_history.tool_use.name, - id = tool_history.tool_use.id, - input = vim.json.decode(tool_history.tool_use.input_json), - }) - end - table.insert(messages, { - role = "assistant", - content = assistant_content, - }) - local user_content = {} - if tool_history.tool_result ~= nil and tool_history.tool_result.content ~= nil then - table.insert(user_content, { - type = "tool_result", - tool_use_id = tool_history.tool_result.tool_use_id, - content = tool_history.tool_result.content, - is_error = tool_history.tool_result.is_error, - }) - end - table.insert(messages, { - role = "user", - content = user_content, - }) - end - end - local assistant_content = Utils.trim_think_content(entry.original_response or "") - if assistant_content ~= "" then table.insert(messages, { role = "assistant", content = assistant_content }) end - end - return messages -end - -return M diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 82081d9..ee7e76c 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -6,7 +6,6 @@ local lsp = vim.lsp ---@field tokens avante.utils.tokens ---@field root avante.utils.root ---@field file avante.utils.file ----@field history avante.utils.history ---@field environment avante.utils.environment ---@field lsp avante.utils.lsp local M = {} @@ -415,7 +414,7 @@ function M.debug(...) local caller_source = info.source:match("@(.+)$") or "unknown" local caller_module = caller_source:gsub("^.*/lua/", ""):gsub("%.lua$", ""):gsub("/", ".") - local timestamp = os.date("%Y-%m-%d %H:%M:%S") + local timestamp = M.get_timestamp() local formated_args = { "[" .. timestamp .. "] [AVANTE] [DEBUG] [" .. caller_module .. ":" .. info.currentline .. "]", } @@ -1263,7 +1262,6 @@ function M.get_commands() local builtin_items = { { description = "Show help message", name = "help" }, { description = "Clear chat history", name = "clear" }, - { description = "Reset memory", name = "reset" }, { description = "New chat", name = "new" }, { shorthelp = "Ask a question about specific lines", @@ -1281,7 +1279,6 @@ function M.get_commands() if cb then cb(args) end end, clear = function(sidebar, args, cb) sidebar:clear_history(args, cb) end, - reset = function(sidebar, args, cb) sidebar:reset_memory(args, cb) end, new = function(sidebar, args, cb) sidebar:new_chat(args, cb) end, lines = function(_, args, cb) if cb then cb(args) end @@ -1310,4 +1307,97 @@ function M.get_commands() return vim.list_extend(builtin_commands, Config.slash_commands) end +---@param history avante.ChatHistory +---@return avante.HistoryMessage[] +function M.get_history_messages(history) + local HistoryMessage = require("avante.history_message") + if history.messages then return history.messages end + local messages = {} + for _, entry in ipairs(history.entries or {}) do + if entry.request and entry.request ~= "" then + local message = HistoryMessage:new({ + role = "user", + content = entry.request, + }, { + timestamp = entry.timestamp, + is_user_submission = true, + visible = entry.visible, + selected_filepaths = entry.selected_filepaths, + selected_code = entry.selected_code, + }) + table.insert(messages, message) + end + if entry.response and entry.response ~= "" then + local message = HistoryMessage:new({ + role = "assistant", + content = entry.response, + }, { + timestamp = entry.timestamp, + visible = entry.visible, + }) + table.insert(messages, message) + end + end + history.messages = messages + return messages +end + +function M.get_timestamp() return tostring(os.date("%Y-%m-%d %H:%M:%S")) end + +---@param history_messages avante.HistoryMessage[] +---@return AvanteLLMMessage[] +function M.history_messages_to_messages(history_messages) + local messages = {} + for _, history_message in ipairs(history_messages) do + if history_message.just_for_display then goto continue end + table.insert(messages, history_message.message) + ::continue:: + end + return messages +end + +function M.uuid() + local template = "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx" + return string.gsub(template, "[xy]", function(c) + local v = (c == "x") and math.random(0, 0xf) or math.random(8, 0xb) + return string.format("%x", v) + end) +end + +---@param item AvanteLLMMessageContentItem +---@param message avante.HistoryMessage +---@return string +function M.message_content_item_to_text(item, message) + if type(item) == "string" then return item end + if type(item) == "table" then + if item.type == "text" then return item.text end + if item.type == "image" then return "![image](" .. item.source.media_type .. ": " .. item.source.data .. ")" end + if item.type == "tool_use" then + local pieces = {} + table.insert(pieces, string.format("[%s]: calling", item.name)) + for _, log in ipairs(message.tool_use_logs or {}) do + table.insert(pieces, log) + end + return table.concat(pieces, "\n") + end + end + return "" +end + +---@param message avante.HistoryMessage +---@return string +function M.message_to_text(message) + local content = message.message.content + if type(content) == "string" then return content end + if vim.islist(content) then + local pieces = {} + for _, item in ipairs(content) do + local text = M.message_content_item_to_text(item, message) + if text ~= "" then table.insert(pieces, text) end + end + return table.concat(pieces, "\n") + end + return "" +end + return M diff --git a/lua/avante/utils/streaming_json_parser.lua b/lua/avante/utils/streaming_json_parser.lua index a2e6be4..20e51ec 100644 --- a/lua/avante/utils/streaming_json_parser.lua +++ b/lua/avante/utils/streaming_json_parser.lua @@ -77,10 +77,24 @@ function StreamingJSONParser:parse(chunk) -- Handle strings specially (they can contain JSON control characters) if self.state.inString then if self.state.escaping then - self.state.stringBuffer = self.state.stringBuffer .. char + local escapeMap = { + ['"'] = '"', + ["\\"] = "\\", + ["/"] = "/", + ["b"] = "\b", + ["f"] = "\f", + ["n"] = "\n", + ["r"] = "\r", + ["t"] = "\t", + } + local escapedChar = escapeMap[char] + if escapedChar then + self.state.stringBuffer = self.state.stringBuffer .. escapedChar + else + self.state.stringBuffer = self.state.stringBuffer .. char + end self.state.escaping = false elseif char == "\\" then - self.state.stringBuffer = self.state.stringBuffer .. char self.state.escaping = true elseif char == '"' then -- End of string diff --git a/plugin/avante.lua b/plugin/avante.lua index 44b8ca5..869637f 100644 --- a/plugin/avante.lua +++ b/plugin/avante.lua @@ -132,17 +132,13 @@ cmd( cmd("Clear", function(opts) local arg = vim.trim(opts.args or "") arg = arg == "" and "history" or arg - if arg == "history" or arg == "memory" then + if arg == "history" then local sidebar = require("avante").get() if not sidebar then Utils.error("No sidebar found") return end - if arg == "history" then - sidebar:clear_history() - else - sidebar:reset_memory() - end + sidebar:clear_history() elseif arg == "cache" then local P = require("avante.path") local history_path = P.history_path:absolute() diff --git a/tests/utils/streaming_json_parser_spec.lua b/tests/utils/streaming_json_parser_spec.lua index 1cc61d5..02da8f5 100644 --- a/tests/utils/streaming_json_parser_spec.lua +++ b/tests/utils/streaming_json_parser_spec.lua @@ -29,6 +29,13 @@ describe("StreamingJSONParser", function() assert.equals("value", result.key) end) + it("should parse breaklines", function() + local result, complete = parser:parse('{"key": "value\nv"}') + assert.is_true(complete) + assert.is_table(result) + assert.equals("value\nv", result.key) + end) + it("should parse a complete simple JSON array", function() local result, complete = parser:parse("[1, 2, 3]") assert.is_true(complete) @@ -119,7 +126,7 @@ describe("StreamingJSONParser", function() local result, complete = parser:parse('{"text": "line1\\nline2\\t\\"quoted\\""}') assert.is_true(complete) assert.is_table(result) - assert.equals('line1\\nline2\\t\\"quoted\\"', result.text) + assert.equals('line1\nline2\t"quoted"', result.text) end) it("should handle numbers correctly", function()