refactor: history messages (#1934)

This commit is contained in:
yetone
2025-04-30 03:07:18 +08:00
committed by GitHub
parent f9aa75459d
commit f10b8383e3
36 changed files with 1699 additions and 1462 deletions

View File

@@ -305,6 +305,8 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_
{ {
---@alias Provider "claude" | "openai" | "azure" | "gemini" | "cohere" | "copilot" | string ---@alias Provider "claude" | "openai" | "azure" | "gemini" | "cohere" | "copilot" | string
provider = "claude", -- The provider used in Aider mode or in the planning phase of Cursor Planning Mode provider = "claude", -- The provider used in Aider mode or in the planning phase of Cursor Planning Mode
---@alias Mode "agentic" | "legacy"
mode = "agentic", -- The default mode for interaction. "agentic" uses tools to automatically generate code, "legacy" uses the old planning method to generate code.
-- WARNING: Since auto-suggestions are a high-frequency operation and therefore expensive, -- WARNING: Since auto-suggestions are a high-frequency operation and therefore expensive,
-- currently designating it as `copilot` provider is dangerous because: https://github.com/yetone/avante.nvim/issues/1048 -- currently designating it as `copilot` provider is dangerous because: https://github.com/yetone/avante.nvim/issues/1048
-- Of course, you can reduce the request frequency by increasing `suggestion.debounce`. -- Of course, you can reduce the request frequency by increasing `suggestion.debounce`.
@@ -340,8 +342,6 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_
support_paste_from_clipboard = false, support_paste_from_clipboard = false,
minimize_diff = true, -- Whether to remove unchanged lines when applying a code block minimize_diff = true, -- Whether to remove unchanged lines when applying a code block
enable_token_counting = true, -- Whether to enable token counting. Default to true. enable_token_counting = true, -- Whether to enable token counting. Default to true.
enable_cursor_planning_mode = false, -- Whether to enable Cursor Planning Mode. Default to false.
enable_claude_text_editor_tool_mode = false, -- Whether to enable Claude Text Editor Tool Mode.
}, },
mappings = { mappings = {
--- @class AvanteConflictMappings --- @class AvanteConflictMappings
@@ -734,12 +734,6 @@ Avante provides a set of default providers, but users can also create their own
For more information, see [Custom Providers](https://github.com/yetone/avante.nvim/wiki/Custom-providers) For more information, see [Custom Providers](https://github.com/yetone/avante.nvim/wiki/Custom-providers)
## Cursor planning mode
Because avante.nvim has always used Aiders method for planning applying, but its prompts are very picky with models and require ones like claude-3.5-sonnet or gpt-4o to work properly.
Therefore, I have adopted Cursors method to implement planning applying. For details on the implementation, please refer to [cursor-planning-mode.md](./cursor-planning-mode.md)
## RAG Service ## RAG Service
Avante provides a RAG service, which is a tool for obtaining the required context for the AI to generate the codes. By default, it is not enabled. You can enable it this way: Avante provides a RAG service, which is a tool for obtaining the required context for the AI to generate the codes. By default, it is not enabled. You can enable it this way:
@@ -890,21 +884,6 @@ Avante allows you to define custom tools that can be used by the AI during code
Now you can integrate MCP functionality for Avante through `mcphub.nvim`. For detailed documentation, please refer to [mcphub.nvim](https://github.com/ravitemer/mcphub.nvim#avante-integration) Now you can integrate MCP functionality for Avante through `mcphub.nvim`. For detailed documentation, please refer to [mcphub.nvim](https://github.com/ravitemer/mcphub.nvim#avante-integration)
## Claude Text Editor Tool Mode
Avante leverages [Claude Text Editor Tool](https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool) to provide a more elegant code editing experience. You can now enable this feature by setting `enable_claude_text_editor_tool_mode` to `true` in the `behaviour` configuration:
```lua
{
behaviour = {
enable_claude_text_editor_tool_mode = true,
},
}
```
> [!NOTE]
> To enable **Claude Text Editor Tool Mode**, you must use the `claude-3-5-sonnet-*` or `claude-3-7-sonnet-*` model with the `claude` provider! This feature is not supported by any other models!
## Custom prompts ## Custom prompts
By default, `avante.nvim` provides three different modes to interact with: `planning`, `editing`, and `suggesting`, followed with three different prompts per mode. By default, `avante.nvim` provides three different modes to interact with: `planning`, `editing`, and `suggesting`, followed with three different prompts per mode.

View File

@@ -226,16 +226,16 @@ end
function M.select_model() require("avante.model_selector").open() end function M.select_model() require("avante.model_selector").open() end
function M.select_history() function M.select_history()
require("avante.history_selector").open(vim.api.nvim_get_current_buf(), function(filename) local buf = vim.api.nvim_get_current_buf()
require("avante.history_selector").open(buf, function(filename)
vim.api.nvim_buf_call(buf, function()
if not require("avante").is_sidebar_open() then require("avante").open_sidebar({}) end
local Path = require("avante.path") local Path = require("avante.path")
Path.history.save_latest_filename(vim.api.nvim_get_current_buf(), filename) Path.history.save_latest_filename(buf, filename)
local sidebar = require("avante").get() local sidebar = require("avante").get()
if not sidebar then
require("avante.api").ask()
sidebar = require("avante").get()
end
sidebar:update_content_with_history() sidebar:update_content_with_history()
if not sidebar:is_open() then sidebar:open({}) end vim.schedule(function() sidebar:focus_input() end)
end)
end) end)
end end

View File

@@ -19,6 +19,8 @@ local M = {}
---@class avante.Config ---@class avante.Config
M._defaults = { M._defaults = {
debug = false, debug = false,
---@alias avante.Mode "agentic" | "legacy"
mode = "agentic",
---@alias avante.ProviderName "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | "bedrock" | "ollama" | string ---@alias avante.ProviderName "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | "bedrock" | "ollama" | string
provider = "claude", provider = "claude",
-- WARNING: Since auto-suggestions are a high-frequency operation and therefore expensive, -- WARNING: Since auto-suggestions are a high-frequency operation and therefore expensive,
@@ -380,8 +382,6 @@ M._defaults = {
support_paste_from_clipboard = false, support_paste_from_clipboard = false,
minimize_diff = true, minimize_diff = true,
enable_token_counting = true, enable_token_counting = true,
enable_cursor_planning_mode = false,
enable_claude_text_editor_tool_mode = false,
use_cwd_as_project_root = false, use_cwd_as_project_root = false,
auto_focus_on_diff_view = false, auto_focus_on_diff_view = false,
}, },

View File

@@ -40,6 +40,12 @@ local Highlights = {
AVANTE_SIDEBAR_NORMAL = { name = "AvanteSidebarNormal", link = "NormalFloat" }, AVANTE_SIDEBAR_NORMAL = { name = "AvanteSidebarNormal", link = "NormalFloat" },
AVANTE_COMMENT_FG = { name = "AvanteCommentFg", fg_link = "Comment" }, AVANTE_COMMENT_FG = { name = "AvanteCommentFg", fg_link = "Comment" },
AVANTE_REVERSED_NORMAL = { name = "AvanteReversedNormal", fg_link_bg = "Normal", bg_link_fg = "Normal" }, AVANTE_REVERSED_NORMAL = { name = "AvanteReversedNormal", fg_link_bg = "Normal", bg_link_fg = "Normal" },
AVANTE_STATE_SPINNER_GENERATING = { name = "AvanteStateSpinnerGenerating", fg = "#1e222a", bg = "#ab9df2" },
AVANTE_STATE_SPINNER_TOOL_CALLING = { name = "AvanteStateSpinnerToolCalling", fg = "#1e222a", bg = "#56b6c2" },
AVANTE_STATE_SPINNER_FAILED = { name = "AvanteStateSpinnerFailed", fg = "#1e222a", bg = "#e06c75" },
AVANTE_STATE_SPINNER_SUCCEEDED = { name = "AvanteStateSpinnerSucceeded", fg = "#1e222a", bg = "#98c379" },
AVANTE_STATE_SPINNER_SEARCHING = { name = "AvanteStateSpinnerSearching", fg = "#1e222a", bg = "#c678dd" },
AVANTE_STATE_SPINNER_THINKING = { name = "AvanteStateSpinnerThinking", fg = "#1e222a", bg = "#c678dd" },
} }
Highlights.conflict = { Highlights.conflict = {

View File

@@ -0,0 +1,28 @@
local Utils = require("avante.utils")
---@class avante.HistoryMessage
local M = {}
M.__index = M
---@param message AvanteLLMMessage
---@param opts? {is_user_submission?: boolean, visible?: boolean, displayed_content?: string, state?: avante.HistoryMessageState, uuid?: string, selected_filepaths?: string[], selected_code?: AvanteSelectedCode, just_for_display?: boolean}
---@return avante.HistoryMessage
function M:new(message, opts)
opts = opts or {}
local obj = setmetatable({}, M)
obj.message = message
obj.uuid = opts.uuid or Utils.uuid()
obj.state = opts.state or "generated"
obj.timestamp = Utils.get_timestamp()
obj.is_user_submission = false
obj.visible = true
if opts.is_user_submission ~= nil then obj.is_user_submission = opts.is_user_submission end
if opts.visible ~= nil then obj.visible = opts.visible end
if opts.displayed_content ~= nil then obj.displayed_content = opts.displayed_content end
if opts.selected_filepaths ~= nil then obj.selected_filepaths = opts.selected_filepaths end
if opts.selected_code ~= nil then obj.selected_code = opts.selected_code end
if opts.just_for_display ~= nil then obj.just_for_display = opts.just_for_display end
return obj
end
return M

View File

@@ -9,8 +9,9 @@ local M = {}
---@param history avante.ChatHistory ---@param history avante.ChatHistory
---@return table? ---@return table?
local function to_selector_item(history) local function to_selector_item(history)
local timestamp = #history.entries > 0 and history.entries[#history.entries].timestamp or history.timestamp local messages = Utils.get_history_messages(history)
local name = history.title .. " - " .. timestamp .. " (" .. #history.entries .. ")" local timestamp = #messages > 0 and messages[#messages].timestamp or history.timestamp
local name = history.title .. " - " .. timestamp .. " (" .. #messages .. ")"
name = name:gsub("\n", "\\n") name = name:gsub("\n", "\\n")
return { return {
name = name, name = name,

View File

@@ -10,6 +10,7 @@ local Path = require("avante.path")
local Providers = require("avante.providers") local Providers = require("avante.providers")
local LLMToolHelpers = require("avante.llm_tools.helpers") local LLMToolHelpers = require("avante.llm_tools.helpers")
local LLMTools = require("avante.llm_tools") local LLMTools = require("avante.llm_tools")
local HistoryMessage = require("avante.history_message")
---@class avante.LLM ---@class avante.LLM
local M = {} local M = {}
@@ -26,7 +27,7 @@ function M.summarize_chat_thread_title(content, cb)
local system_prompt = local system_prompt =
[[Summarize the content as a title for the chat thread. The title should be a concise and informative summary of the conversation, capturing the main points and key takeaways. It should be no longer than 100 words and should be written in a clear and engaging style. The title should be suitable for use as the title of a chat thread on a messaging platform or other communication medium.]] [[Summarize the content as a title for the chat thread. The title should be a concise and informative summary of the conversation, capturing the main points and key takeaways. It should be no longer than 100 words and should be written in a clear and engaging style. The title should be suitable for use as the title of a chat thread on a messaging platform or other communication medium.]]
local response_content = "" local response_content = ""
local provider = Providers[Config.memory_summary_provider or Config.provider] local provider = Providers.get_memory_summary_provider()
M.curl({ M.curl({
provider = provider, provider = provider,
prompt_opts = { prompt_opts = {
@@ -58,73 +59,49 @@ function M.summarize_chat_thread_title(content, cb)
}) })
end end
---@param messages AvanteLLMMessage[]
---@return AvanteLLMMessage[]
local function filter_out_tool_use_messages(messages)
local filtered_messages = {}
for _, message in ipairs(messages) do
local content = message.content
if type(content) == "table" then
local new_content = {}
for _, item in ipairs(content) do
if item.type == "tool_use" or item.type == "tool_result" then goto continue end
table.insert(new_content, item)
::continue::
end
content = new_content
end
if type(content) == "table" then
if #content > 0 then table.insert(filtered_messages, { role = message.role, content = content }) end
else
table.insert(filtered_messages, { role = message.role, content = content })
end
end
return filtered_messages
end
---@param bufnr integer ---@param bufnr integer
---@param history avante.ChatHistory ---@param history avante.ChatHistory
---@param entries? avante.ChatHistoryEntry[] ---@param history_messages avante.HistoryMessage[]
---@param cb fun(memory: avante.ChatMemory | nil): nil ---@param cb fun(memory: avante.ChatMemory | nil): nil
function M.summarize_memory(bufnr, history, entries, cb) function M.summarize_memory(bufnr, history, history_messages, cb)
local system_prompt = [[You are a helpful AI assistant tasked with summarizing conversations.]] local system_prompt =
if not entries then entries = Utils.history.filter_active_entries(history.entries) end [[You are an expert coding assistant. Your goal is to generate a concise, structured summary of the conversation below that captures all essential information needed to continue development after context replacement. Include tasks performed, code areas modified or reviewed, key decisions or assumptions, test results or errors, and outstanding tasks or next steps.]]
if #entries == 0 then
cb(nil)
return
end
if history.memory then
entries = vim
.iter(entries)
:filter(function(entry) return entry.timestamp > history.memory.last_summarized_timestamp end)
:totable()
end
if #entries == 0 then
cb(history.memory)
return
end
local history_messages = Utils.history.entries_to_llm_messages(entries)
history_messages = filter_out_tool_use_messages(history_messages)
history_messages = vim.list_slice(history_messages, 1, 4)
if #history_messages == 0 then if #history_messages == 0 then
cb(history.memory) cb(history.memory)
return return
end end
Utils.debug("summarize memory", #history_messages, history_messages[#history_messages].content) local latest_timestamp = history_messages[#history_messages].timestamp
local user_prompt = local latest_message_uuid = history_messages[#history_messages].uuid
[[Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.]] local conversation_items = vim
.iter(history_messages)
:filter(function(msg)
if msg.just_for_display then return false end
if msg.message.role ~= "assistant" and msg.message.role ~= "user" then return false end
local content = msg.message.content
if type(content) == "table" and content[1].type == "tool_result" then return false end
if type(content) == "table" and content[1].type == "tool_use" then return false end
return true
end)
:map(function(msg) return msg.message.role .. ": " .. Utils.message_to_text(msg) end)
:totable()
local conversation_text = table.concat(conversation_items, "\n")
local user_prompt = "Here is the conversation so far:\n"
.. conversation_text
.. "\n\nPlease summarize this conversation, covering:\n1. Tasks performed and outcomes\n2. Code files, modules, or functions modified or examined\n3. Important decisions or assumptions made\n4. Errors encountered and test or build results\n5. Remaining tasks, open questions, or next steps\nProvide the summary in a clear, concise format."
if history.memory then user_prompt = user_prompt .. "\n\nThe previous summary is:\n\n" .. history.memory.content end if history.memory then user_prompt = user_prompt .. "\n\nThe previous summary is:\n\n" .. history.memory.content end
table.insert(history_messages, { local messages = {
{
role = "user", role = "user",
content = user_prompt, content = user_prompt,
}) },
}
local response_content = "" local response_content = ""
local provider = Providers[Config.memory_summary_provider or Config.provider] local provider = Providers.get_memory_summary_provider()
M.curl({ M.curl({
provider = provider, provider = provider,
prompt_opts = { prompt_opts = {
system_prompt = system_prompt, system_prompt = system_prompt,
messages = history_messages, messages = messages,
}, },
handler_opts = { handler_opts = {
on_start = function(_) end, on_start = function(_) end,
@@ -141,11 +118,14 @@ function M.summarize_memory(bufnr, history, entries, cb)
response_content = Utils.trim_think_content(response_content) response_content = Utils.trim_think_content(response_content)
local memory = { local memory = {
content = response_content, content = response_content,
last_summarized_timestamp = entries[#entries].timestamp, last_summarized_timestamp = latest_timestamp,
last_message_uuid = latest_message_uuid,
} }
history.memory = memory history.memory = memory
Path.history.save(bufnr, history) Path.history.save(bufnr, history)
cb(memory) cb(memory)
else
cb(history.memory)
end end
end, end,
}, },
@@ -156,7 +136,7 @@ end
---@return AvantePromptOptions ---@return AvantePromptOptions
function M.generate_prompts(opts) function M.generate_prompts(opts)
local provider = opts.provider or Providers[Config.provider] local provider = opts.provider or Providers[Config.provider]
local mode = opts.mode or "planning" local mode = opts.mode or Config.mode
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor ---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
local _, request_body = Providers.parse_config(provider) local _, request_body = Providers.parse_config(provider)
local max_tokens = request_body.max_tokens or 4096 local max_tokens = request_body.max_tokens or 4096
@@ -166,22 +146,58 @@ function M.generate_prompts(opts)
if opts.prompt_opts and opts.prompt_opts.image_paths then if opts.prompt_opts and opts.prompt_opts.image_paths then
image_paths = vim.list_extend(image_paths, opts.prompt_opts.image_paths) image_paths = vim.list_extend(image_paths, opts.prompt_opts.image_paths)
end end
local instructions = opts.instructions
if instructions and instructions:match("image: ") then
local lines = vim.split(opts.instructions, "\n")
for i, line in ipairs(lines) do
if line:match("^image: ") then
local image_path = line:gsub("^image: ", "")
table.insert(image_paths, image_path)
table.remove(lines, i)
end
end
instructions = table.concat(lines, "\n")
end
local project_root = Utils.root.get() local project_root = Utils.root.get()
Path.prompts.initialize(Path.prompts.get_templates_dir(project_root)) Path.prompts.initialize(Path.prompts.get_templates_dir(project_root))
local tool_id_to_tool_name = {}
local tool_id_to_path = {}
local viewed_files = {}
if opts.history_messages then
for _, message in ipairs(opts.history_messages) do
local content = message.message.content
if type(content) ~= "table" then goto continue end
for _, item in ipairs(content) do
if type(item) ~= "table" then goto continue1 end
if item.type ~= "tool_use" then goto continue1 end
local tool_name = item.name
if tool_name ~= "view" then goto continue1 end
local path = item.input.path
tool_id_to_tool_name[item.id] = tool_name
if path then
local uniform_path = Utils.uniform_path(path)
tool_id_to_path[item.id] = uniform_path
viewed_files[uniform_path] = item.id
end
::continue1::
end
::continue::
end
for _, message in ipairs(opts.history_messages) do
local content = message.message.content
if type(content) == "table" then
for _, item in ipairs(content) do
if type(item) ~= "table" then goto continue end
if item.type ~= "tool_result" then goto continue end
local tool_name = tool_id_to_tool_name[item.tool_use_id]
if tool_name ~= "view" then goto continue end
local path = tool_id_to_path[item.tool_use_id]
local latest_tool_id = viewed_files[path]
if not latest_tool_id then goto continue end
if latest_tool_id ~= item.tool_use_id then
item.content = string.format("The file %s has been updated. Please use the latest view tool result!", path)
else
local lines, error = Utils.read_file_from_buf_or_disk(path)
if error ~= nil then Utils.error("error reading file: " .. error) end
lines = lines or {}
item.content = table.concat(lines, "\n")
end
::continue::
end
end
end
end
local system_info = Utils.get_system_info() local system_info = Utils.get_system_info()
local selected_files = opts.selected_files or {} local selected_files = opts.selected_files or {}
@@ -200,6 +216,8 @@ function M.generate_prompts(opts)
end end
end end
selected_files = vim.iter(selected_files):filter(function(file) return viewed_files[file.path] == nil end):totable()
local template_opts = { local template_opts = {
ask = opts.ask, -- TODO: add mode without ask instruction ask = opts.ask, -- TODO: add mode without ask instruction
code_lang = opts.code_lang, code_lang = opts.code_lang,
@@ -229,36 +247,42 @@ function M.generate_prompts(opts)
end end
---@type AvanteLLMMessage[] ---@type AvanteLLMMessage[]
local messages = {} local context_messages = {}
if opts.prompt_opts and opts.prompt_opts.messages then if opts.prompt_opts and opts.prompt_opts.messages then
messages = vim.list_extend(messages, opts.prompt_opts.messages) context_messages = vim.list_extend(context_messages, opts.prompt_opts.messages)
end end
if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then
local project_context = Path.prompts.render_file("_project.avanterules", template_opts) local project_context = Path.prompts.render_file("_project.avanterules", template_opts)
if project_context ~= "" then table.insert(messages, { role = "user", content = project_context }) end if project_context ~= "" then
table.insert(context_messages, { role = "user", content = project_context, visible = false, is_context = true })
end
end end
if opts.diagnostics ~= nil and opts.diagnostics ~= "" and opts.diagnostics ~= "null" then if opts.diagnostics ~= nil and opts.diagnostics ~= "" and opts.diagnostics ~= "null" then
local diagnostics = Path.prompts.render_file("_diagnostics.avanterules", template_opts) local diagnostics = Path.prompts.render_file("_diagnostics.avanterules", template_opts)
if diagnostics ~= "" then table.insert(messages, { role = "user", content = diagnostics }) end if diagnostics ~= "" then
table.insert(context_messages, { role = "user", content = diagnostics, visible = false, is_context = true })
end
end end
if #selected_files > 0 or opts.selected_code ~= nil then if #selected_files > 0 or opts.selected_code ~= nil then
local code_context = Path.prompts.render_file("_context.avanterules", template_opts) local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end if code_context ~= "" then
table.insert(context_messages, { role = "user", content = code_context, visible = false, is_context = true })
end
end end
if opts.memory ~= nil and opts.memory ~= "" and opts.memory ~= "null" then if opts.memory ~= nil and opts.memory ~= "" and opts.memory ~= "null" then
local memory = Path.prompts.render_file("_memory.avanterules", template_opts) local memory = Path.prompts.render_file("_memory.avanterules", template_opts)
if memory ~= "" then table.insert(messages, { role = "user", content = memory }) end if memory ~= "" then
table.insert(context_messages, { role = "user", content = memory, visible = false, is_context = true })
end
end end
if instructions then table.insert(messages, { role = "user", content = instructions }) end
local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt) local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)
for _, message in ipairs(messages) do for _, message in ipairs(context_messages) do
remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content) remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
end end
@@ -267,47 +291,49 @@ function M.generate_prompts(opts)
dropped_history_messages = vim.list_extend(dropped_history_messages, opts.prompt_opts.dropped_history_messages) dropped_history_messages = vim.list_extend(dropped_history_messages, opts.prompt_opts.dropped_history_messages)
end end
local final_history_messages = {}
if opts.history_messages then if opts.history_messages then
if Config.history.max_tokens > 0 then remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens) end if Config.history.max_tokens > 0 then remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens) end
-- Traverse the history in reverse, keeping only the latest history until the remaining tokens are exhausted and the first message role is "user" -- Traverse the history in reverse, keeping only the latest history until the remaining tokens are exhausted and the first message role is "user"
local history_messages = {} local history_messages = {}
for i = #opts.history_messages, 1, -1 do for i = #opts.history_messages, 1, -1 do
local message = opts.history_messages[i] local message = opts.history_messages[i]
if Config.history.carried_entry_count ~= nil then local tokens = Utils.tokens.calculate_tokens(message.message.content)
if #history_messages > Config.history.carried_entry_count then break end
table.insert(history_messages, message)
else
local tokens = Utils.tokens.calculate_tokens(message.content)
remaining_tokens = remaining_tokens - tokens remaining_tokens = remaining_tokens - tokens
if remaining_tokens > 0 then if remaining_tokens > 0 then
table.insert(history_messages, message) table.insert(history_messages, 1, message)
else else
break break
end end
end end
if #history_messages == 0 then
history_messages = vim.list_slice(opts.history_messages, #opts.history_messages - 1, #opts.history_messages)
end end
dropped_history_messages = vim.list_slice(opts.history_messages, 1, #opts.history_messages - #history_messages) dropped_history_messages = vim.list_slice(opts.history_messages, 1, #opts.history_messages - #history_messages)
-- prepend the history messages to the messages table -- prepend the history messages to the messages table
vim.iter(history_messages):each(function(msg) table.insert(messages, 1, msg) end) vim.iter(history_messages):each(function(msg) table.insert(final_history_messages, msg) end)
if #messages > 0 and messages[1].role == "assistant" then table.remove(messages, 1) end
end end
if opts.mode == "cursor-applying" then -- Utils.debug("opts.history_messages", opts.history_messages)
local user_prompt = [[ -- Utils.debug("final_history_messages", final_history_messages)
Merge all changes from the <update> snippet into the <code> below.
- Preserve the code's structure, order, comments, and indentation exactly.
- Output only the updated code, enclosed within <updated-code> and </updated-code> tags.
- Do not include any additional text, explanations, placeholders, ellipses, or code fences.
]] ---@type AvanteLLMMessage[]
user_prompt = user_prompt .. string.format("<code>\n%s\n</code>\n", opts.original_code) local messages = vim.deepcopy(context_messages)
for _, snippet in ipairs(opts.update_snippets) do for _, msg in ipairs(final_history_messages) do
user_prompt = user_prompt .. string.format("<update>\n%s\n</update>\n", snippet) local message = msg.message
table.insert(messages, message)
end end
user_prompt = user_prompt .. "Provide the complete updated code."
table.insert(messages, { role = "user", content = user_prompt }) messages = vim
.iter(messages)
:filter(function(msg) return type(msg.content) ~= "string" or msg.content ~= "" end)
:totable()
if opts.instructions ~= nil and opts.instructions ~= "" then
messages = vim.list_extend(messages, { { role = "user", content = opts.instructions } })
end end
opts.session_ctx = opts.session_ctx or {} opts.session_ctx = opts.session_ctx or {}
@@ -318,19 +344,12 @@ Merge all changes from the <update> snippet into the <code> below.
if opts.tools then tools = vim.list_extend(tools, opts.tools) end if opts.tools then tools = vim.list_extend(tools, opts.tools) end
if opts.prompt_opts and opts.prompt_opts.tools then tools = vim.list_extend(tools, opts.prompt_opts.tools) end if opts.prompt_opts and opts.prompt_opts.tools then tools = vim.list_extend(tools, opts.prompt_opts.tools) end
local tool_histories = {}
if opts.tool_histories then tool_histories = vim.list_extend(tool_histories, opts.tool_histories) end
if opts.prompt_opts and opts.prompt_opts.tool_histories then
tool_histories = vim.list_extend(tool_histories, opts.prompt_opts.tool_histories)
end
---@type AvantePromptOptions ---@type AvantePromptOptions
return { return {
system_prompt = system_prompt, system_prompt = system_prompt,
messages = messages, messages = messages,
image_paths = image_paths, image_paths = image_paths,
tools = tools, tools = tools,
tool_histories = tool_histories,
dropped_history_messages = dropped_history_messages, dropped_history_messages = dropped_history_messages,
} }
end end
@@ -372,7 +391,9 @@ function M.curl(opts)
---@type string ---@type string
local current_event_state = nil local current_event_state = nil
local resp_ctx = {} local resp_ctx = {}
resp_ctx.session_id = Utils.uuid()
local response_body = ""
---@param line string ---@param line string
local function parse_stream_data(line) local function parse_stream_data(line)
local event = line:match("^event:%s*(.+)$") local event = line:match("^event:%s*(.+)$")
@@ -381,7 +402,16 @@ function M.curl(opts)
return return
end end
local data_match = line:match("^data:%s*(.+)$") local data_match = line:match("^data:%s*(.+)$")
if data_match then provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts) end if data_match then
provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts)
else
response_body = response_body .. line
local ok, jsn = pcall(vim.json.decode, response_body)
if ok then
response_body = ""
if jsn.error then handler_opts.on_stop({ reason = "error", error = jsn.error }) end
end
end
end end
local function parse_response_without_stream(data) local function parse_response_without_stream(data)
@@ -394,10 +424,13 @@ function M.curl(opts)
local temp_file = fn.tempname() local temp_file = fn.tempname()
local curl_body_file = temp_file .. "-request-body.json" local curl_body_file = temp_file .. "-request-body.json"
local resp_body_file = temp_file .. "-response-body.json"
local json_content = vim.json.encode(spec.body) local json_content = vim.json.encode(spec.body)
fn.writefile(vim.split(json_content, "\n"), curl_body_file) fn.writefile(vim.split(json_content, "\n"), curl_body_file)
Utils.debug("curl body file:", curl_body_file) Utils.debug("curl request body file:", curl_body_file)
Utils.debug("curl response body file:", resp_body_file)
local headers_file = temp_file .. "-headers.txt" local headers_file = temp_file .. "-headers.txt"
@@ -407,6 +440,7 @@ function M.curl(opts)
if Config.debug then return end if Config.debug then return end
vim.schedule(function() vim.schedule(function()
fn.delete(curl_body_file) fn.delete(curl_body_file)
pcall(fn.delete, resp_body_file)
fn.delete(headers_file) fn.delete(headers_file)
end) end)
end end
@@ -431,6 +465,15 @@ function M.curl(opts)
return return
end end
if not data then return end if not data then return end
if Config.debug then
if type(data) == "string" then
local file = io.open(resp_body_file, "a")
if file then
file:write(data .. "\n")
file:close()
end
end
end
vim.schedule(function() vim.schedule(function()
if Config[Config.provider] == nil and provider.parse_stream_data ~= nil then if Config[Config.provider] == nil and provider.parse_stream_data ~= nil then
if provider.parse_response ~= nil then if provider.parse_response ~= nil then
@@ -495,6 +538,17 @@ function M.curl(opts)
Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" }) Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" })
end end
if result.status == 429 then if result.status == 429 then
Utils.debug("result", result)
if result.body then
local ok, jsn = pcall(vim.json.decode, result.body)
if ok then
if jsn.error and jsn.error.message then
handler_opts.on_stop({ reason = "error", error = jsn.error.message })
return
end
end
end
Utils.debug("result", result)
local retry_after = 10 local retry_after = 10
if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end
handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after }) handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after })
@@ -585,17 +639,34 @@ function M._stream(opts)
---@type AvanteHandlerOptions ---@type AvanteHandlerOptions
local handler_opts = { local handler_opts = {
on_partial_tool_use = opts.on_partial_tool_use, on_messages_add = opts.on_messages_add,
on_state_change = opts.on_state_change,
on_start = opts.on_start, on_start = opts.on_start,
on_chunk = opts.on_chunk, on_chunk = opts.on_chunk,
on_stop = function(stop_opts) on_stop = function(stop_opts)
---@param tool_use_list AvanteLLMToolUse[] ---@param tool_use_list AvanteLLMToolUse[]
---@param tool_use_index integer ---@param tool_use_index integer
---@param tool_histories AvanteLLMToolHistory[] ---@param tool_results AvanteLLMToolResult[]
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories) local function handle_next_tool_use(tool_use_list, tool_use_index, tool_results)
if tool_use_index > #tool_use_list then if tool_use_index > #tool_use_list then
---@type avante.HistoryMessage[]
local messages = {}
for _, tool_result in ipairs(tool_results) do
messages[#messages + 1] = HistoryMessage:new({
role = "user",
content = {
{
type = "tool_result",
tool_use_id = tool_result.tool_use_id,
content = tool_result.content,
is_error = tool_result.is_error,
},
},
})
end
opts.on_messages_add(messages)
local new_opts = vim.tbl_deep_extend("force", opts, { local new_opts = vim.tbl_deep_extend("force", opts, {
tool_histories = tool_histories, history_messages = opts.get_history_messages(),
}) })
if provider.get_rate_limit_sleep_time then if provider.get_rate_limit_sleep_time then
local sleep_time = provider:get_rate_limit_sleep_time(resp_headers) local sleep_time = provider:get_rate_limit_sleep_time(resp_headers)
@@ -616,7 +687,7 @@ function M._stream(opts)
if error == LLMToolHelpers.CANCEL_TOKEN then if error == LLMToolHelpers.CANCEL_TOKEN then
Utils.debug("Tool execution was cancelled by user") Utils.debug("Tool execution was cancelled by user")
opts.on_chunk("\n*[Request cancelled by user during tool execution.]*\n") opts.on_chunk("\n*[Request cancelled by user during tool execution.]*\n")
return opts.on_stop({ reason = "cancelled", tool_histories = tool_histories }) return opts.on_stop({ reason = "cancelled" })
end end
local tool_result = { local tool_result = {
@@ -624,8 +695,8 @@ function M._stream(opts)
content = error ~= nil and error or result, content = error ~= nil and error or result,
is_error = error ~= nil, is_error = error ~= nil,
} }
table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use }) table.insert(tool_results, tool_result)
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories) return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_results)
end end
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil -- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
local result, error = LLMTools.process_tool_use( local result, error = LLMTools.process_tool_use(
@@ -638,20 +709,53 @@ function M._stream(opts)
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
end end
if stop_opts.reason == "cancelled" then if stop_opts.reason == "cancelled" then
opts.on_chunk("\n*[Request cancelled by user.]*\n") if opts.on_chunk then opts.on_chunk("\n*[Request cancelled by user.]*\n") end
return opts.on_stop({ reason = "cancelled", tool_histories = opts.tool_histories }) if opts.on_messages_add then
local message = HistoryMessage:new({
role = "user",
content = "[Request cancelled by user.]",
})
opts.on_messages_add({ message })
end
return opts.on_stop({ reason = "cancelled" })
end
if stop_opts.reason == "tool_use" then
local tool_use_list = {} ---@type AvanteLLMToolUse[]
local tool_result_seen = {}
local history_messages = opts.get_history_messages and opts.get_history_messages() or {}
for idx = #history_messages, 1, -1 do
local message = history_messages[idx]
local content = message.message.content
if type(content) ~= "table" or #content == 0 then goto continue end
if content[1].type == "tool_use" then
if not tool_result_seen[content[1].id] then
table.insert(tool_use_list, 1, content[1])
else
break
end
end
if content[1].type == "tool_result" then tool_result_seen[content[1].tool_use_id] = true end
::continue::
end end
if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[] local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
for _, tool_use in vim.spairs(stop_opts.tool_use_list) do for _, tool_use in vim.spairs(tool_use_list) do
table.insert(sorted_tool_use_list, tool_use) table.insert(sorted_tool_use_list, tool_use)
end end
return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories) return handle_next_tool_use(sorted_tool_use_list, 1, {})
end end
if stop_opts.reason == "rate_limit" then if stop_opts.reason == "rate_limit" then
local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..." local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*"
opts.on_chunk("\n*[" .. msg .. "]*\n") if opts.on_chunk then opts.on_chunk("\n" .. msg_content .. "\n") end
local message
if opts.on_messages_add then
message = HistoryMessage:new({
role = "assistant",
content = "\n\n" .. msg_content,
}, {
just_for_display = true,
})
opts.on_messages_add({ message })
end
local timer = vim.loop.new_timer() local timer = vim.loop.new_timer()
if timer then if timer then
local retry_after = stop_opts.retry_after local retry_after = stop_opts.retry_after
@@ -661,8 +765,12 @@ function M._stream(opts)
0, 0,
vim.schedule_wrap(function() vim.schedule_wrap(function()
if retry_after > 0 then retry_after = retry_after - 1 end if retry_after > 0 then retry_after = retry_after - 1 end
local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..." local msg_content_ = "*[Rate limit reached. Retrying in " .. retry_after .. " seconds ...]*"
opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n") if opts.on_chunk then opts.on_chunk([[\033[1A\033[K]] .. "\n" .. msg_content_ .. "\n") end
if opts.on_messages_add and message then
message.message.content = "\n\n" .. msg_content_
opts.on_messages_add({ message })
end
countdown() countdown()
end) end)
) )
@@ -676,7 +784,6 @@ function M._stream(opts)
end, stop_opts.retry_after * 1000) end, stop_opts.retry_after * 1000)
return return
end end
stop_opts.tool_histories = opts.tool_histories
return opts.on_stop(stop_opts) return opts.on_stop(stop_opts)
end, end,
} }
@@ -697,6 +804,8 @@ local function _merge_response(first_response, second_response, opts)
prompt = prompt .. "\n" prompt = prompt .. "\n"
if opts.instructions == nil then opts.instructions = "" end
-- append this reference prompt to the prompt_opts messages at last -- append this reference prompt to the prompt_opts messages at last
opts.instructions = opts.instructions .. prompt opts.instructions = opts.instructions .. prompt
@@ -802,20 +911,12 @@ function M.stream(opts)
return original_on_stop(stop_opts) return original_on_stop(stop_opts)
end) end)
end end
if opts.on_partial_tool_use ~= nil then
local original_on_partial_tool_use = opts.on_partial_tool_use
opts.on_partial_tool_use = vim.schedule_wrap(function(tool_use)
if is_completed then return end
return original_on_partial_tool_use(tool_use)
end)
end
local valid_dual_boost_modes = { local valid_dual_boost_modes = {
planning = true, legacy = true,
["cursor-planning"] = true,
} }
opts.mode = opts.mode or "planning" opts.mode = opts.mode or Config.mode
if Config.dual_boost.enabled and valid_dual_boost_modes[opts.mode] then if Config.dual_boost.enabled and valid_dual_boost_modes[opts.mode] then
M._dual_boost_stream( M._dual_boost_stream(

View File

@@ -10,7 +10,7 @@ M.name = "create"
M.description = "The create tool allows you to create a new file with specified content." M.description = "The create tool allows you to create a new file with specified content."
function M.enabled() return require("avante.config").behaviour.enable_claude_text_editor_tool_mode end function M.enabled() return require("avante.config").mode == "agentic" end
---@type AvanteLLMToolParam ---@type AvanteLLMToolParam
M.param = { M.param = {

View File

@@ -71,7 +71,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
if not on_complete then return false, "on_complete not provided" end if not on_complete then return false, "on_complete not provided" end
local prompt = opts.prompt local prompt = opts.prompt
local tools = get_available_tools() local tools = get_available_tools()
local start_time = os.date("%Y-%m-%d %H:%M:%S") local start_time = Utils.get_timestamp()
if on_log then on_log("prompt: " .. prompt) end if on_log then on_log("prompt: " .. prompt) end
@@ -84,6 +84,8 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
messages = messages or {} messages = messages or {}
table.insert(messages, { role = "user", content = prompt }) table.insert(messages, { role = "user", content = prompt })
local tool_use_messages = {}
local total_tokens = 0 local total_tokens = 0
local final_response = "" local final_response = ""
Llm._stream({ Llm._stream({
@@ -93,6 +95,15 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
on_tool_log = function(tool_id, tool_name, log, state) on_tool_log = function(tool_id, tool_name, log, state)
if on_log then on_log(string.format("[%s] %s", tool_name, log)) end if on_log then on_log(string.format("[%s] %s", tool_name, log)) end
end, end,
on_messages_add = function(msgs)
msgs = vim.is_list(msgs) and msgs or { msgs }
for _, msg in ipairs(msgs) do
local content = msg.message.content
if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then
tool_use_messages[msg.uuid] = true
end
end
end,
session_ctx = session_ctx, session_ctx = session_ctx,
prompt_opts = { prompt_opts = {
system_prompt = system_prompt, system_prompt = system_prompt,
@@ -111,9 +122,9 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
on_complete(err, nil) on_complete(err, nil)
return return
end end
local end_time = os.date("%Y-%m-%d %H:%M:%S") local end_time = Utils.get_timestamp()
local elapsed_time = Utils.datetime_diff(tostring(start_time), tostring(end_time)) local elapsed_time = Utils.datetime_diff(start_time, end_time)
local tool_use_count = stop_opts.tool_histories and #stop_opts.tool_histories or 0 local tool_use_count = vim.tbl_count(tool_use_messages)
local summary = "Done (" local summary = "Done ("
.. (tool_use_count <= 1 and "1 tool use" or tool_use_count .. " tool uses") .. (tool_use_count <= 1 and "1 tool use" or tool_use_count .. " tool uses")
.. " · " .. " · "

View File

@@ -598,6 +598,7 @@ end
---@type AvanteLLMTool[] ---@type AvanteLLMTool[]
M._tools = { M._tools = {
require("avante.llm_tools.replace_in_file"),
require("avante.llm_tools.dispatch_agent"), require("avante.llm_tools.dispatch_agent"),
require("avante.llm_tools.glob"), require("avante.llm_tools.glob"),
{ {
@@ -1104,7 +1105,7 @@ M._tools = {
---@return string | nil result ---@return string | nil result
---@return string | nil error ---@return string | nil error
function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx) function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx)
Utils.debug("use tool", tool_use.name, tool_use.input_json) -- Utils.debug("use tool", tool_use.name, tool_use.input_json)
-- Check if execution is already cancelled -- Check if execution is already cancelled
if Helpers.is_cancelled then if Helpers.is_cancelled then
@@ -1125,8 +1126,7 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx)
if tool == nil then return nil, "This tool is not provided: " .. tool_use.name end if tool == nil then return nil, "This tool is not provided: " .. tool_use.name end
func = tool.func or M[tool.name] func = tool.func or M[tool.name]
end end
local ok, input_json = pcall(vim.json.decode, tool_use.input_json) local input_json = tool_use.input
if not ok then return nil, "Failed to decode tool input json: " .. vim.inspect(input_json) end
if not func then return nil, "Tool not found: " .. tool_use.name end if not func then return nil, "Tool not found: " .. tool_use.name end
if on_log then on_log(tool_use.id, tool_use.name, "running tool", "running") end if on_log then on_log(tool_use.id, tool_use.name, "running tool", "running") end

View File

@@ -10,7 +10,7 @@ M.name = "insert"
M.description = "The insert tool allows you to insert text at a specific location in a file." M.description = "The insert tool allows you to insert text at a specific location in a file."
function M.enabled() return require("avante.config").behaviour.enable_claude_text_editor_tool_mode end function M.enabled() return require("avante.config").mode == "agentic" end
---@type AvanteLLMToolParam ---@type AvanteLLMToolParam
M.param = { M.param = {

View File

@@ -0,0 +1,481 @@
local Base = require("avante.llm_tools.base")
local Helpers = require("avante.llm_tools.helpers")
local Utils = require("avante.utils")
local Highlights = require("avante.highlights")
local Config = require("avante.config")
local PRIORITY = (vim.hl or vim.highlight).priorities.user
local NAMESPACE = vim.api.nvim_create_namespace("avante-diff")
local KEYBINDING_NAMESPACE = vim.api.nvim_create_namespace("avante-diff-keybinding")
---@class AvanteLLMTool
local M = setmetatable({}, Base)
M.name = "replace_in_file"
M.description =
"Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file. This tool should be used when you need to make targeted changes to specific parts of a file."
---@type AvanteLLMToolParam
M.param = {
type = "table",
fields = {
{
name = "path",
description = "The path to the file in the current project scope",
type = "string",
},
{
name = "diff",
description = [[
One or more SEARCH/REPLACE blocks following this exact format:
\`\`\`
<<<<<<< SEARCH
[exact content to find]
=======
[new content to replace with]
>>>>>>> REPLACE
\`\`\`
Critical rules:
1. SEARCH content must match the associated file section to find EXACTLY:
* Match character-for-character including whitespace, indentation, line endings
* Include all comments, docstrings, etc.
2. SEARCH/REPLACE blocks will ONLY replace the first match occurrence.
* Including multiple unique SEARCH/REPLACE blocks if you need to make multiple changes.
* Include *just* enough lines in each SEARCH section to uniquely match each set of lines that need to change.
* When using multiple SEARCH/REPLACE blocks, list them in the order they appear in the file.
3. Keep SEARCH/REPLACE blocks concise:
* Break large SEARCH/REPLACE blocks into a series of smaller blocks that each change a small portion of the file.
* Include just the changing lines, and a few surrounding lines if needed for uniqueness.
* Do not include long runs of unchanging lines in SEARCH/REPLACE blocks.
* Each line must be complete. Never truncate lines mid-way through as this can cause matching failures.
4. Special operations:
* To move code: Use two SEARCH/REPLACE blocks (one to delete from original + one to insert at new location)
* To delete code: Use empty REPLACE section
]],
type = "string",
},
},
}
---@type AvanteLLMToolReturn[]
M.returns = {
{
name = "success",
description = "True if the replacement was successful, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the replacement failed",
type = "string",
optional = true,
},
}
---@type AvanteLLMToolFunc<{ path: string, diff: string }>
function M.func(opts, on_log, on_complete, session_ctx)
if not opts.path or not opts.diff then return false, "path and diff are required" end
if on_log then on_log("path: " .. opts.path) end
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
local diff_lines = vim.split(opts.diff, "\n")
local is_searching = false
local is_replacing = false
local current_search = {}
local current_replace = {}
local rough_diff_blocks = {}
for _, line in ipairs(diff_lines) do
if line:match("^%s*<<<<<<< SEARCH") then
is_searching = true
is_replacing = false
current_search = {}
elseif line:match("^%s*=======") and is_searching then
is_searching = false
is_replacing = true
current_replace = {}
elseif line:match("^%s*>>>>>>> REPLACE") and is_replacing then
is_replacing = false
table.insert(
rough_diff_blocks,
{ search = table.concat(current_search, "\n"), replace = table.concat(current_replace, "\n") }
)
elseif is_searching then
table.insert(current_search, line)
elseif is_replacing then
table.insert(current_replace, line)
end
end
if #rough_diff_blocks == 0 then return false, "No diff blocks found" end
local bufnr, err = Helpers.get_bufnr(abs_path)
if err then return false, err end
local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local sidebar = require("avante").get()
if not sidebar then return false, "Avante sidebar not found" end
local function parse_rough_diff_block(rough_diff_block, current_lines)
local old_lines = vim.split(rough_diff_block.search, "\n")
local new_lines = vim.split(rough_diff_block.replace, "\n")
local start_line, end_line
for i = 1, #current_lines - #old_lines + 1 do
local match = true
for j = 1, #old_lines do
if Utils.remove_indentation(current_lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then
match = false
break
end
end
if match then
start_line = i
end_line = i + #old_lines - 1
break
end
end
if start_line == nil or end_line == nil then
return "Failed to find the old string:\n" .. rough_diff_block.search
end
local old_str = rough_diff_block.search
local new_str = rough_diff_block.replace
local original_indentation = Utils.get_indentation(current_lines[start_line])
if original_indentation ~= Utils.get_indentation(old_lines[1]) then
old_lines = vim.tbl_map(function(line) return original_indentation .. line end, old_lines)
new_lines = vim.tbl_map(function(line) return original_indentation .. line end, new_lines)
old_str = table.concat(old_lines, "\n")
new_str = table.concat(new_lines, "\n")
end
rough_diff_block.old_lines = old_lines
rough_diff_block.new_lines = new_lines
rough_diff_block.search = old_str
rough_diff_block.replace = new_str
rough_diff_block.start_line = start_line
rough_diff_block.end_line = end_line
return nil
end
local function rough_diff_blocks_to_diff_blocks(rough_diff_blocks_)
local res = {}
local base_line_ = 0
for _, rough_diff_block in ipairs(rough_diff_blocks_) do
---@diagnostic disable-next-line: assign-type-mismatch, missing-fields
local patch = vim.diff(rough_diff_block.search, rough_diff_block.replace, { ---@type integer[][]
algorithm = "histogram",
result_type = "indices",
ctxlen = vim.o.scrolloff,
})
for _, hunk in ipairs(patch) do
local start_a, count_a, start_b, count_b = unpack(hunk)
local diff_block = {}
if count_a > 0 then
diff_block.old_lines = vim.list_slice(rough_diff_block.old_lines, start_a, start_a + count_a - 1)
else
diff_block.old_lines = {}
end
if count_b > 0 then
diff_block.new_lines = vim.list_slice(rough_diff_block.new_lines, start_b, start_b + count_b - 1)
else
diff_block.new_lines = {}
end
if count_a > 0 then
diff_block.start_line = base_line_ + rough_diff_block.start_line + start_a - 1
else
diff_block.start_line = base_line_ + rough_diff_block.start_line + start_a
end
diff_block.end_line = base_line_ + rough_diff_block.start_line + start_a + math.max(count_a, 1) - 2
diff_block.search = table.concat(diff_block.old_lines, "\n")
diff_block.replace = table.concat(diff_block.new_lines, "\n")
table.insert(res, diff_block)
end
local distance = 0
for _, hunk in ipairs(patch) do
local _, count_a, _, count_b = unpack(hunk)
distance = distance + count_b - count_a
end
local old_distance = #rough_diff_block.new_lines - #rough_diff_block.old_lines
base_line_ = base_line_ + distance - old_distance
end
return res
end
for _, rough_diff_block in ipairs(rough_diff_blocks) do
local error = parse_rough_diff_block(rough_diff_block, original_lines)
if error then
on_complete(false, error)
return
end
end
local diff_blocks = rough_diff_blocks_to_diff_blocks(rough_diff_blocks)
table.sort(diff_blocks, function(a, b) return a.start_line < b.start_line end)
local base_line = 0
for _, diff_block in ipairs(diff_blocks) do
diff_block.new_start_line = diff_block.start_line + base_line
diff_block.new_end_line = diff_block.new_start_line + #diff_block.new_lines - 1
base_line = base_line + #diff_block.new_lines - #diff_block.old_lines
end
local function remove_diff_block(removed_idx, use_new_lines)
local new_diff_blocks = {}
local distance = 0
for idx, diff_block in ipairs(diff_blocks) do
if idx == removed_idx then
if not use_new_lines then distance = #diff_block.old_lines - #diff_block.new_lines end
goto continue
end
if idx > removed_idx then
diff_block.new_start_line = diff_block.new_start_line + distance
diff_block.new_end_line = diff_block.new_end_line + distance
end
table.insert(new_diff_blocks, diff_block)
::continue::
end
diff_blocks = new_diff_blocks
end
local function get_current_diff_block()
local winid = Utils.get_winid(bufnr)
local cursor_line = Utils.get_cursor_pos(winid)
for idx, diff_block in ipairs(diff_blocks) do
if cursor_line >= diff_block.new_start_line and cursor_line <= diff_block.new_end_line then
return diff_block, idx
end
end
return nil, nil
end
local function get_prev_diff_block()
local winid = Utils.get_winid(bufnr)
local cursor_line = Utils.get_cursor_pos(winid)
local distance = nil
local idx = nil
for i, diff_block in ipairs(diff_blocks) do
if cursor_line >= diff_block.new_start_line and cursor_line <= diff_block.new_end_line then
local new_i = i - 1
if new_i < 1 then return diff_blocks[#diff_blocks] end
return diff_blocks[new_i]
end
if diff_block.new_start_line < cursor_line then
local distance_ = cursor_line - diff_block.new_start_line
if distance == nil or distance_ < distance then
distance = distance_
idx = i
end
end
end
if idx ~= nil then return diff_blocks[idx] end
if #diff_blocks > 0 then return diff_blocks[#diff_blocks] end
return nil
end
local function get_next_diff_block()
local winid = Utils.get_winid(bufnr)
local cursor_line = Utils.get_cursor_pos(winid)
local distance = nil
local idx = nil
for i, diff_block in ipairs(diff_blocks) do
if cursor_line >= diff_block.new_start_line and cursor_line <= diff_block.new_end_line then
local new_i = i + 1
if new_i > #diff_blocks then return diff_blocks[1] end
return diff_blocks[new_i]
end
if diff_block.new_start_line > cursor_line then
local distance_ = diff_block.new_start_line - cursor_line
if distance == nil or distance_ < distance then
distance = distance_
idx = i
end
end
end
if idx ~= nil then return diff_blocks[idx] end
if #diff_blocks > 0 then return diff_blocks[1] end
return nil
end
local show_keybinding_hint_extmark_id = nil
local function register_cursor_move_events()
local function show_keybinding_hint(lnum)
if show_keybinding_hint_extmark_id then
vim.api.nvim_buf_del_extmark(bufnr, KEYBINDING_NAMESPACE, show_keybinding_hint_extmark_id)
end
local hint = string.format(
"[<%s>: OURS, <%s>: THEIRS, <%s>: PREV, <%s>: NEXT]",
Config.mappings.diff.ours,
Config.mappings.diff.theirs,
Config.mappings.diff.prev,
Config.mappings.diff.next
)
show_keybinding_hint_extmark_id = vim.api.nvim_buf_set_extmark(bufnr, KEYBINDING_NAMESPACE, lnum - 1, -1, {
hl_group = "AvanteInlineHint",
virt_text = { { hint, "AvanteInlineHint" } },
virt_text_pos = "right_align",
priority = PRIORITY,
})
end
vim.api.nvim_create_autocmd({ "CursorMoved", "CursorMovedI", "WinLeave" }, {
buffer = bufnr,
callback = function(event)
local diff_block = get_current_diff_block()
if (event.event == "CursorMoved" or event.event == "CursorMovedI") and diff_block then
show_keybinding_hint(diff_block.new_start_line)
else
vim.api.nvim_buf_clear_namespace(bufnr, KEYBINDING_NAMESPACE, 0, -1)
end
end,
})
end
local function register_keybinding_events()
vim.keymap.set({ "n", "v" }, Config.mappings.diff.ours, function()
if vim.api.nvim_get_current_buf() ~= bufnr then return end
local diff_block, idx = get_current_diff_block()
if not diff_block then return end
pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.delete_extmark_id)
pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.incoming_extmark_id)
vim.api.nvim_buf_set_lines(
bufnr,
diff_block.new_start_line - 1,
diff_block.new_end_line,
false,
diff_block.old_lines
)
diff_block.incoming_extmark_id = nil
diff_block.delete_extmark_id = nil
remove_diff_block(idx, false)
local next_diff_block = get_next_diff_block()
if next_diff_block then
local winnr = Utils.get_winid(bufnr)
vim.api.nvim_win_set_cursor(winnr, { next_diff_block.new_start_line, 0 })
vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end)
end
end)
vim.keymap.set({ "n", "v" }, Config.mappings.diff.theirs, function()
if vim.api.nvim_get_current_buf() ~= bufnr then return end
local diff_block, idx = get_current_diff_block()
if not diff_block then return end
pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.incoming_extmark_id)
pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.delete_extmark_id)
diff_block.incoming_extmark_id = nil
diff_block.delete_extmark_id = nil
remove_diff_block(idx, true)
local next_diff_block = get_next_diff_block()
if next_diff_block then
local winnr = Utils.get_winid(bufnr)
vim.api.nvim_win_set_cursor(winnr, { next_diff_block.new_start_line, 0 })
vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end)
end
end)
vim.keymap.set({ "n", "v" }, Config.mappings.diff.next, function()
if vim.api.nvim_get_current_buf() ~= bufnr then return end
local diff_block = get_next_diff_block()
if not diff_block then return end
local winnr = Utils.get_winid(bufnr)
vim.api.nvim_win_set_cursor(winnr, { diff_block.new_start_line, 0 })
vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end)
end)
vim.keymap.set({ "n", "v" }, Config.mappings.diff.prev, function()
if vim.api.nvim_get_current_buf() ~= bufnr then return end
local diff_block = get_prev_diff_block()
if not diff_block then return end
local winnr = Utils.get_winid(bufnr)
vim.api.nvim_win_set_cursor(winnr, { diff_block.new_start_line, 0 })
vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end)
end)
end
local function unregister_keybinding_events()
pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.ours)
pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.theirs)
pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.next)
pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.prev)
pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.ours)
pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.theirs)
pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.next)
pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.prev)
end
local augroup = vim.api.nvim_create_augroup("avante_replace_in_file", { clear = true })
local function clear()
if bufnr and not vim.api.nvim_buf_is_valid(bufnr) then return end
vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1)
vim.api.nvim_buf_clear_namespace(bufnr, KEYBINDING_NAMESPACE, 0, -1)
unregister_keybinding_events()
pcall(vim.api.nvim_del_augroup_by_id, augroup)
end
local function insert_diff_blocks_new_lines()
local base_line_ = 0
for _, diff_block in ipairs(diff_blocks) do
local start_line = diff_block.start_line + base_line_
local end_line = diff_block.end_line + base_line_
base_line_ = base_line_ + #diff_block.new_lines - #diff_block.old_lines
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, diff_block.new_lines)
end
end
local function highlight_diff_blocks()
vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1)
local base_line_ = 0
local max_col = vim.o.columns
for _, diff_block in ipairs(diff_blocks) do
local start_line = diff_block.start_line + base_line_
base_line_ = base_line_ + #diff_block.new_lines - #diff_block.old_lines
local deleted_virt_lines = vim
.iter(diff_block.old_lines)
:map(function(line)
--- append spaces to the end of the line
local line_ = line .. string.rep(" ", max_col - #line)
return { { line_, Highlights.TO_BE_DELETED_WITHOUT_STRIKETHROUGH } }
end)
:totable()
local extmark_line = math.max(0, start_line - 2)
local delete_extmark_id = vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, extmark_line, 0, {
virt_lines = deleted_virt_lines,
hl_eol = true,
hl_mode = "combine",
})
local end_row = start_line + #diff_block.new_lines - 1
local incoming_extmark_id = vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, start_line - 1, 0, {
hl_group = Highlights.INCOMING,
hl_eol = true,
hl_mode = "combine",
end_row = end_row,
})
diff_block.delete_extmark_id = delete_extmark_id
diff_block.incoming_extmark_id = incoming_extmark_id
end
end
insert_diff_blocks_new_lines()
highlight_diff_blocks()
register_cursor_move_events()
register_keybinding_events()
Helpers.confirm("Are you sure you want to apply this modification?", function(ok, reason)
clear()
if not ok then
vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, original_lines)
on_complete(false, "User declined, reason: " .. (reason or "unknown"))
return
end
vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end)
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end
on_complete(true, nil)
end, { focus = not Config.behaviour.auto_focus_on_diff_view }, session_ctx)
end
return M

View File

@@ -13,7 +13,7 @@ M.name = "str_replace"
M.description = M.description =
"The str_replace tool allows you to replace a specific string in a file with a new string. This is used for making precise edits." "The str_replace tool allows you to replace a specific string in a file with a new string. This is used for making precise edits."
function M.enabled() return require("avante.config").behaviour.enable_claude_text_editor_tool_mode end function M.enabled() return false end
---@type AvanteLLMToolParam ---@type AvanteLLMToolParam
M.param = { M.param = {
@@ -65,8 +65,8 @@ function M.func(opts, on_log, on_complete, session_ctx)
if not file then return false, "file not found: " .. abs_path end if not file then return false, "file not found: " .. abs_path end
if opts.old_str == nil then return false, "old_str not provided" end if opts.old_str == nil then return false, "old_str not provided" end
if opts.new_str == nil then return false, "new_str not provided" end if opts.new_str == nil then return false, "new_str not provided" end
Utils.debug("old_str", opts.old_str) -- Utils.debug("old_str", opts.old_str)
Utils.debug("new_str", opts.new_str) -- Utils.debug("new_str", opts.new_str)
local bufnr, err = Helpers.get_bufnr(abs_path) local bufnr, err = Helpers.get_bufnr(abs_path)
if err then return false, err end if err then return false, err end
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
@@ -77,7 +77,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
for i = 1, #lines - #old_lines + 1 do for i = 1, #lines - #old_lines + 1 do
local match = true local match = true
for j = 1, #old_lines do for j = 1, #old_lines do
if lines[i + j - 1] ~= old_lines[j] then if Utils.remove_indentation(lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then
match = false match = false
break break
end end
@@ -89,11 +89,20 @@ function M.func(opts, on_log, on_complete, session_ctx)
end end
end end
if start_line == nil or end_line == nil then if start_line == nil or end_line == nil then
on_complete(false, "Failed to find the old string: " .. opts.old_str) on_complete(false, "Failed to find the old string:\n" .. opts.old_str)
return return
end end
local old_str = opts.old_str
local new_str = opts.new_str
local original_indentation = Utils.get_indentation(lines[start_line])
if original_indentation ~= Utils.get_indentation(old_lines[1]) then
old_lines = vim.tbl_map(function(line) return original_indentation .. line end, old_lines)
new_lines = vim.tbl_map(function(line) return original_indentation .. line end, new_lines)
old_str = table.concat(old_lines, "\n")
new_str = table.concat(new_lines, "\n")
end
---@diagnostic disable-next-line: assign-type-mismatch, missing-fields ---@diagnostic disable-next-line: assign-type-mismatch, missing-fields
local patch = vim.diff(opts.old_str, opts.new_str, { ---@type integer[][] local patch = vim.diff(old_str, new_str, { ---@type integer[][]
algorithm = "histogram", algorithm = "histogram",
result_type = "indices", result_type = "indices",
ctxlen = vim.o.scrolloff, ctxlen = vim.o.scrolloff,

View File

@@ -10,7 +10,7 @@ M.name = "undo_edit"
M.description = "The undo_edit tool allows you to revert the last edit made to a file." M.description = "The undo_edit tool allows you to revert the last edit made to a file."
function M.enabled() return require("avante.config").behaviour.enable_claude_text_editor_tool_mode end function M.enabled() return require("avante.config").mode == "agentic" end
---@type AvanteLLMToolParam ---@type AvanteLLMToolParam
M.param = { M.param = {

View File

@@ -76,17 +76,6 @@ M.returns = {
function M.func(opts, on_log, on_complete, session_ctx) function M.func(opts, on_log, on_complete, session_ctx)
if not on_complete then return false, "on_complete not provided" end if not on_complete then return false, "on_complete not provided" end
if on_log then on_log("path: " .. opts.path) end if on_log then on_log("path: " .. opts.path) end
if Helpers.already_in_context(opts.path) then
on_complete(nil, "Ooooops! This file is already in the context! Why you are trying to read it again?")
return
end
if session_ctx then
if Helpers.already_viewed(opts.path, session_ctx) then
on_complete(nil, "Ooooops! You have already viewed this file! Why you are trying to read it again?")
return
end
Helpers.mark_as_viewed(opts.path, session_ctx)
end
local abs_path = Helpers.get_abs_path(opts.path) local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end

View File

@@ -16,9 +16,9 @@ local function generate_project_dirname_in_storage(bufnr)
buf = bufnr, buf = bufnr,
}) })
-- Replace path separators with double underscores -- Replace path separators with double underscores
local path_with_separators = fn.substitute(project_root, "/", "__", "g") local path_with_separators = string.gsub(project_root, "/", "__")
-- Replace other non-alphanumeric characters with single underscores -- Replace other non-alphanumeric characters with single underscores
local dirname = fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") local dirname = string.gsub(path_with_separators, "[^A-Za-z0-9._]", "_")
return tostring(Path:new("projects"):joinpath(dirname)) return tostring(Path:new("projects"):joinpath(dirname))
end end
@@ -34,6 +34,7 @@ function History.get_history_dir(bufnr)
return history_dir return history_dir
end end
---@return avante.ChatHistory[]
function History.list(bufnr) function History.list(bufnr)
local history_dir = History.get_history_dir(bufnr) local history_dir = History.get_history_dir(bufnr)
local files = vim.fn.glob(tostring(history_dir:joinpath("*.json")), true, true) local files = vim.fn.glob(tostring(history_dir:joinpath("*.json")), true, true)
@@ -53,8 +54,10 @@ function History.list(bufnr)
table.sort(res, function(a, b) table.sort(res, function(a, b)
if a.filename == latest_filename then return true end if a.filename == latest_filename then return true end
if b.filename == latest_filename then return false end if b.filename == latest_filename then return false end
local timestamp_a = #a.entries > 0 and a.entries[#a.entries].timestamp or a.timestamp local a_messages = Utils.get_history_messages(a)
local timestamp_b = #b.entries > 0 and b.entries[#b.entries].timestamp or b.timestamp local b_messages = Utils.get_history_messages(b)
local timestamp_a = #a_messages > 0 and a_messages[#a_messages].timestamp or a.timestamp
local timestamp_b = #b_messages > 0 and b_messages[#b_messages].timestamp or b.timestamp
return timestamp_a > timestamp_b return timestamp_a > timestamp_b
end) end)
return res return res
@@ -117,8 +120,8 @@ function History.new(bufnr)
---@type avante.ChatHistory ---@type avante.ChatHistory
local history = { local history = {
title = "untitled", title = "untitled",
timestamp = tostring(os.date("%Y-%m-%d %H:%M:%S")), timestamp = Utils.get_timestamp(),
entries = {}, messages = {},
filename = filepath_to_filename(filepath), filename = filepath_to_filename(filepath),
} }
return history return history
@@ -169,11 +172,10 @@ function Prompt.get_builtin_prompts_filepath(mode) return string.format("%s.avan
local _templates_lib = nil local _templates_lib = nil
Prompt.custom_modes = { Prompt.custom_modes = {
planning = true, agentic = true,
legacy = true,
editing = true, editing = true,
suggesting = true, suggesting = true,
["cursor-planning"] = true,
["cursor-applying"] = true,
} }
Prompt.custom_prompts_contents = {} Prompt.custom_prompts_contents = {}

View File

@@ -58,6 +58,7 @@ function M:parse_stream_data(ctx, data, opts)
end end
function M:parse_response_without_stream(data, event_state, opts) function M:parse_response_without_stream(data, event_state, opts)
if opts.on_chunk == nil then return end
local bedrock_match = data:gmatch("exception(%b{})") local bedrock_match = data:gmatch("exception(%b{})")
opts.on_chunk("\n**Exception caught**\n\n") opts.on_chunk("\n**Exception caught**\n\n")
for bedrock_data_match in bedrock_match do for bedrock_data_match in bedrock_match do

View File

@@ -2,7 +2,7 @@ local Utils = require("avante.utils")
local Clipboard = require("avante.clipboard") local Clipboard = require("avante.clipboard")
local P = require("avante.providers") local P = require("avante.providers")
local Config = require("avante.config") local Config = require("avante.config")
local StreamingJsonParser = require("avante.utils.streaming_json_parser") local HistoryMessage = require("avante.history_message")
---@class AvanteProviderFunctor ---@class AvanteProviderFunctor
local M = {} local M = {}
@@ -139,63 +139,6 @@ function M:parse_messages(opts)
messages[#messages].content = message_content messages[#messages].content = message_content
end end
if opts.tool_histories then
for _, tool_history in ipairs(opts.tool_histories) do
if tool_history.tool_use then
local msg = {
role = "assistant",
content = {},
}
if tool_history.tool_use.thinking_blocks then
for _, thinking_block in ipairs(tool_history.tool_use.thinking_blocks) do
msg.content[#msg.content + 1] = {
type = "thinking",
thinking = thinking_block.thinking,
signature = thinking_block.signature,
}
end
end
if tool_history.tool_use.redacted_thinking_blocks then
for _, redacted_thinking_block in ipairs(tool_history.tool_use.redacted_thinking_blocks) do
msg.content[#msg.content + 1] = {
type = "redacted_thinking",
data = redacted_thinking_block.data,
}
end
end
if tool_history.tool_use.response_contents then
for _, response_content in ipairs(tool_history.tool_use.response_contents) do
msg.content[#msg.content + 1] = {
type = "text",
text = response_content,
}
end
end
msg.content[#msg.content + 1] = {
type = "tool_use",
id = tool_history.tool_use.id,
name = tool_history.tool_use.name,
input = vim.json.decode(tool_history.tool_use.input_json),
}
messages[#messages + 1] = msg
end
if tool_history.tool_result then
messages[#messages + 1] = {
role = "user",
content = {
{
type = "tool_result",
tool_use_id = tool_history.tool_result.tool_use_id,
content = tool_history.tool_result.content,
is_error = tool_history.tool_result.is_error,
},
},
}
end
end
end
return messages return messages
end end
@@ -226,14 +169,51 @@ function M:parse_response(ctx, data_stream, event_state, opts)
local content_block = jsn.content_block local content_block = jsn.content_block
content_block.stoppped = false content_block.stoppped = false
ctx.content_blocks[jsn.index + 1] = content_block ctx.content_blocks[jsn.index + 1] = content_block
if content_block.type == "thinking" then opts.on_chunk("<think>\n") end if content_block.type == "text" then
if content_block.type == "tool_use" and opts.on_partial_tool_use then local msg = HistoryMessage:new({
opts.on_partial_tool_use({ role = "assistant",
name = content_block.name, content = content_block.text,
id = content_block.id, }, {
partial_json = {},
state = "generating", state = "generating",
}) })
content_block.uuid = msg.uuid
if opts.on_messages_add then opts.on_messages_add({ msg }) end
end
if content_block.type == "thinking" then
if opts.on_chunk then opts.on_chunk("<think>\n") end
if opts.on_messages_add then
local msg = HistoryMessage:new({
role = "assistant",
content = {
{
type = "thinking",
thinking = content_block.thinking,
signature = content_block.signature,
},
},
}, {
state = "generating",
})
content_block.uuid = msg.uuid
opts.on_messages_add({ msg })
end
end
if content_block.type == "tool_use" and opts.on_messages_add then
local msg = HistoryMessage:new({
role = "assistant",
content = {
{
type = "tool_use",
name = content_block.name,
id = content_block.id,
input = {},
},
},
}, {
state = "generating",
})
content_block.uuid = msg.uuid
opts.on_messages_add({ msg })
end end
elseif event_state == "content_block_delta" then elseif event_state == "content_block_delta" then
local ok, jsn = pcall(vim.json.decode, data_stream) local ok, jsn = pcall(vim.json.decode, data_stream)
@@ -242,23 +222,35 @@ function M:parse_response(ctx, data_stream, event_state, opts)
if jsn.delta.type == "input_json_delta" then if jsn.delta.type == "input_json_delta" then
if not content_block.input_json then content_block.input_json = "" end if not content_block.input_json then content_block.input_json = "" end
content_block.input_json = content_block.input_json .. jsn.delta.partial_json content_block.input_json = content_block.input_json .. jsn.delta.partial_json
if opts.on_partial_tool_use then
local streaming_json_parser = StreamingJsonParser:new()
local partial_json = streaming_json_parser:parse(content_block.input_json)
opts.on_partial_tool_use({
name = content_block.name,
id = content_block.id,
partial_json = partial_json or {},
state = "generating",
})
end
return return
elseif jsn.delta.type == "thinking_delta" then elseif jsn.delta.type == "thinking_delta" then
content_block.thinking = content_block.thinking .. jsn.delta.thinking content_block.thinking = content_block.thinking .. jsn.delta.thinking
opts.on_chunk(jsn.delta.thinking) if opts.on_chunk then opts.on_chunk(jsn.delta.thinking) end
local msg = HistoryMessage:new({
role = "assistant",
content = {
{
type = "thinking",
thinking = content_block.thinking,
signature = content_block.signature,
},
},
}, {
state = "generating",
uuid = content_block.uuid,
})
if opts.on_messages_add then opts.on_messages_add({ msg }) end
elseif jsn.delta.type == "text_delta" then elseif jsn.delta.type == "text_delta" then
content_block.text = content_block.text .. jsn.delta.text content_block.text = content_block.text .. jsn.delta.text
opts.on_chunk(jsn.delta.text) if opts.on_chunk then opts.on_chunk(jsn.delta.text) end
local msg = HistoryMessage:new({
role = "assistant",
content = content_block.text,
}, {
state = "generating",
uuid = content_block.uuid,
})
if opts.on_messages_add then opts.on_messages_add({ msg }) end
elseif jsn.delta.type == "signature_delta" then elseif jsn.delta.type == "signature_delta" then
if ctx.content_blocks[jsn.index + 1].signature == nil then ctx.content_blocks[jsn.index + 1].signature = "" end if ctx.content_blocks[jsn.index + 1].signature == nil then ctx.content_blocks[jsn.index + 1].signature = "" end
ctx.content_blocks[jsn.index + 1].signature = ctx.content_blocks[jsn.index + 1].signature .. jsn.delta.signature ctx.content_blocks[jsn.index + 1].signature = ctx.content_blocks[jsn.index + 1].signature .. jsn.delta.signature
@@ -268,62 +260,77 @@ function M:parse_response(ctx, data_stream, event_state, opts)
if not ok then return end if not ok then return end
local content_block = ctx.content_blocks[jsn.index + 1] local content_block = ctx.content_blocks[jsn.index + 1]
content_block.stoppped = true content_block.stoppped = true
if content_block.type == "text" then
local msg = HistoryMessage:new({
role = "assistant",
content = content_block.text,
}, {
state = "generated",
uuid = content_block.uuid,
})
if opts.on_messages_add then opts.on_messages_add({ msg }) end
end
if content_block.type == "tool_use" then
local complete_json = vim.json.decode(content_block.input_json)
local msg = HistoryMessage:new({
role = "assistant",
content = {
{
type = "tool_use",
name = content_block.name,
id = content_block.id,
input = complete_json or {},
},
},
}, {
state = "generated",
uuid = content_block.uuid,
})
if opts.on_messages_add then opts.on_messages_add({ msg }) end
end
if content_block.type == "thinking" then if content_block.type == "thinking" then
if opts.on_chunk then
if content_block.thinking and content_block.thinking ~= vim.NIL and content_block.thinking:sub(-1) ~= "\n" then if content_block.thinking and content_block.thinking ~= vim.NIL and content_block.thinking:sub(-1) ~= "\n" then
opts.on_chunk("\n</think>\n\n") opts.on_chunk("\n</think>\n\n")
else else
opts.on_chunk("</think>\n\n") opts.on_chunk("</think>\n\n")
end end
end end
local msg = HistoryMessage:new({
role = "assistant",
content = {
{
type = "thinking",
thinking = content_block.thinking,
signature = content_block.signature,
},
},
}, {
state = "generated",
uuid = content_block.uuid,
})
if opts.on_messages_add then opts.on_messages_add({ msg }) end
end
elseif event_state == "message_delta" then elseif event_state == "message_delta" then
local ok, jsn = pcall(vim.json.decode, data_stream) local ok, jsn = pcall(vim.json.decode, data_stream)
if not ok then return end if not ok then return end
if jsn.delta.stop_reason == "end_turn" then if jsn.delta.stop_reason == "end_turn" then
opts.on_stop({ reason = "complete", usage = jsn.usage }) opts.on_stop({ reason = "complete", usage = jsn.usage })
elseif jsn.delta.stop_reason == "tool_use" then elseif jsn.delta.stop_reason == "tool_use" then
---@type AvanteLLMToolUse[] local tool_use_list = {}
local tool_use_list = vim for _, content_block in ipairs(ctx.content_blocks) do
.iter(ctx.content_blocks) if content_block.type == "tool_use" then
:filter(function(content_block) return content_block.stoppped and content_block.type == "tool_use" end) table.insert(tool_use_list, {
:map(function(content_block)
local response_contents = vim
.iter(ctx.content_blocks)
:filter(function(content_block_) return content_block_.stoppped and content_block_.type == "text" end)
:map(function(content_block_) return content_block_.text end)
:totable()
local thinking_blocks = vim
.iter(ctx.content_blocks)
:filter(function(content_block_) return content_block_.stoppped and content_block_.type == "thinking" end)
:map(function(content_block_)
---@type AvanteLLMThinkingBlock
return { thinking = content_block_.thinking, signature = content_block_.signature }
end)
:totable()
local redacted_thinking_blocks = vim
.iter(ctx.content_blocks)
:filter(
function(content_block_) return content_block_.stoppped and content_block_.type == "redacted_thinking" end
)
:map(function(content_block_)
---@type AvanteLLMRedactedThinkingBlock
return { data = content_block_.data }
end)
:totable()
---@type AvanteLLMToolUse
return {
name = content_block.name,
id = content_block.id, id = content_block.id,
name = content_block.name,
input_json = content_block.input_json, input_json = content_block.input_json,
response_contents = response_contents, })
thinking_blocks = thinking_blocks, end
redacted_thinking_blocks = redacted_thinking_blocks, end
}
end)
:totable()
opts.on_stop({ opts.on_stop({
reason = "tool_use", reason = "tool_use",
-- tool_use_list = tool_use_list,
usage = jsn.usage, usage = jsn.usage,
tool_use_list = tool_use_list,
}) })
end end
return return
@@ -351,7 +358,7 @@ function M:parse_curl_args(prompt_opts)
local tools = {} local tools = {}
if not disable_tools and prompt_opts.tools then if not disable_tools and prompt_opts.tools then
for _, tool in ipairs(prompt_opts.tools) do for _, tool in ipairs(prompt_opts.tools) do
if Config.behaviour.enable_claude_text_editor_tool_mode then if Config.mode == "agentic" then
if tool.name == "create_file" then goto continue end if tool.name == "create_file" then goto continue end
if tool.name == "view" then goto continue end if tool.name == "view" then goto continue end
if tool.name == "str_replace" then goto continue end if tool.name == "str_replace" then goto continue end
@@ -364,7 +371,7 @@ function M:parse_curl_args(prompt_opts)
end end
end end
if prompt_opts.tools and #prompt_opts.tools > 0 and Config.behaviour.enable_claude_text_editor_tool_mode then if prompt_opts.tools and #prompt_opts.tools > 0 and Config.mode == "agentic" then
if provider_conf.model:match("claude%-3%-7%-sonnet") then if provider_conf.model:match("claude%-3%-7%-sonnet") then
table.insert(tools, { table.insert(tools, {
type = "text_editor_20250124", type = "text_editor_20250124",

View File

@@ -211,11 +211,7 @@ M.role_map = {
function M:is_disable_stream() return false end function M:is_disable_stream() return false end
M.parse_messages = OpenAI.parse_messages setmetatable(M, { __index = OpenAI })
M.parse_response = OpenAI.parse_response
M.is_reasoning_model = OpenAI.is_reasoning_model
function M:parse_curl_args(prompt_opts) function M:parse_curl_args(prompt_opts)
-- refresh token synchronously, only if it has expired -- refresh token synchronously, only if it has expired

View File

@@ -1,6 +1,7 @@
local Utils = require("avante.utils") local Utils = require("avante.utils")
local P = require("avante.providers") local Providers = require("avante.providers")
local Clipboard = require("avante.clipboard") local Clipboard = require("avante.clipboard")
local OpenAI = require("avante.providers").openai
---@class AvanteProviderFunctor ---@class AvanteProviderFunctor
local M = {} local M = {}
@@ -10,14 +11,32 @@ M.role_map = {
user = "user", user = "user",
assistant = "model", assistant = "model",
} }
-- M.tokenizer_id = "google/gemma-2b"
function M:is_disable_stream() return false end function M:is_disable_stream() return false end
---@param tool AvanteLLMTool
function M:transform_to_function_declaration(tool)
local input_schema_properties, required = Utils.llm_tool_param_fields_to_json_schema(tool.param.fields)
local parameters = nil
if not vim.tbl_isempty(input_schema_properties) then
parameters = {
type = "object",
properties = input_schema_properties,
required = required,
}
end
return {
name = tool.name,
description = tool.get_description and tool.get_description() or tool.description,
parameters = parameters,
}
end
function M:parse_messages(opts) function M:parse_messages(opts)
local contents = {} local contents = {}
local prev_role = nil local prev_role = nil
local tool_id_to_name = {}
vim.iter(opts.messages):each(function(message) vim.iter(opts.messages):each(function(message)
local role = message.role local role = message.role
if role == prev_role then if role == prev_role then
@@ -54,9 +73,27 @@ function M:parse_messages(opts)
}, },
}) })
elseif type(item) == "table" and item.type == "tool_use" then elseif type(item) == "table" and item.type == "tool_use" then
table.insert(parts, { text = item.name }) tool_id_to_name[item.id] = item.name
role = "model"
table.insert(parts, {
functionCall = {
name = item.name,
args = item.input,
},
})
elseif type(item) == "table" and item.type == "tool_result" then elseif type(item) == "table" and item.type == "tool_result" then
table.insert(parts, { text = item.content }) role = "function"
local ok, content = pcall(vim.json.decode, item.content)
if not ok then content = item.content end
table.insert(parts, {
functionResponse = {
name = tool_id_to_name[item.tool_use_id],
response = {
name = tool_id_to_name[item.tool_use_id],
content = content,
},
},
})
elseif type(item) == "table" and item.type == "thinking" then elseif type(item) == "table" and item.type == "thinking" then
table.insert(parts, { text = item.thinking }) table.insert(parts, { text = item.thinking })
elseif type(item) == "table" and item.type == "redacted_thinking" then elseif type(item) == "table" and item.type == "redacted_thinking" then
@@ -96,22 +133,43 @@ end
function M:parse_response(ctx, data_stream, _, opts) function M:parse_response(ctx, data_stream, _, opts)
local ok, json = pcall(vim.json.decode, data_stream) local ok, json = pcall(vim.json.decode, data_stream)
if not ok then opts.on_stop({ reason = "error", error = json }) end if not ok then opts.on_stop({ reason = "error", error = json }) end
if json.candidates then if json.candidates and #json.candidates > 0 then
if #json.candidates > 0 then local candidate = json.candidates[1]
if json.candidates[1].finishReason and json.candidates[1].finishReason == "STOP" then ---@type AvanteLLMToolUse[]
opts.on_chunk(json.candidates[1].content.parts[1].text) local tool_use_list = {}
opts.on_stop({ reason = "complete" }) for _, part in ipairs(candidate.content.parts) do
else if part.text then
opts.on_chunk(json.candidates[1].content.parts[1].text) if opts.on_chunk then opts.on_chunk(part.text) end
OpenAI:add_text_message(ctx, part.text, "generating", opts)
elseif part.functionCall then
if not ctx.function_call_id then ctx.function_call_id = 0 end
ctx.function_call_id = ctx.function_call_id + 1
local tool_use = {
id = ctx.session_id .. "-" .. tostring(ctx.function_call_id),
name = part.functionCall.name,
input_json = vim.json.encode(part.functionCall.args),
}
table.insert(tool_use_list, tool_use)
OpenAI:add_tool_use_message(tool_use, "generated", opts)
end end
end
if candidate.finishReason and candidate.finishReason == "STOP" then
OpenAI:finish_pending_messages(ctx, opts)
if #tool_use_list > 0 then
opts.on_stop({ reason = "tool_use", tool_use_list = tool_use_list })
else else
opts.on_stop({ reason = "complete" }) opts.on_stop({ reason = "complete" })
end end
end end
else
OpenAI:finish_pending_messages(ctx, opts)
opts.on_stop({ reason = "complete" })
end
end end
function M:parse_curl_args(prompt_opts) function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(self) local provider_conf, request_body = Providers.parse_config(self)
local disable_tools = provider_conf.disable_tools or false
request_body = vim.tbl_deep_extend("force", request_body, { request_body = vim.tbl_deep_extend("force", request_body, {
generationConfig = { generationConfig = {
@@ -125,6 +183,21 @@ function M:parse_curl_args(prompt_opts)
local api_key = self.parse_api_key() local api_key = self.parse_api_key()
if api_key == nil then error("Cannot get the gemini api key!") end if api_key == nil then error("Cannot get the gemini api key!") end
local function_declarations = {}
if not disable_tools and prompt_opts.tools then
for _, tool in ipairs(prompt_opts.tools) do
table.insert(function_declarations, self:transform_to_function_declaration(tool))
end
end
if #function_declarations > 0 then
request_body.tools = {
{
functionDeclarations = function_declarations,
},
}
end
return { return {
url = Utils.url_join( url = Utils.url_join(
provider_conf.endpoint, provider_conf.endpoint,

View File

@@ -215,14 +215,6 @@ function M.setup()
E.setup({ provider = auto_suggestions_provider }) E.setup({ provider = auto_suggestions_provider })
end end
if Config.behaviour.enable_cursor_planning_mode then
local cursor_applying_provider_name = Config.cursor_applying_provider or Config.provider
local cursor_applying_provider = M[cursor_applying_provider_name]
if cursor_applying_provider and cursor_applying_provider ~= provider then
E.setup({ provider = cursor_applying_provider })
end
end
if Config.memory_summary_provider then if Config.memory_summary_provider then
local memory_summary_provider = M[Config.memory_summary_provider] local memory_summary_provider = M[Config.memory_summary_provider]
if memory_summary_provider and memory_summary_provider ~= provider then if memory_summary_provider and memory_summary_provider ~= provider then
@@ -277,4 +269,13 @@ function M.get_config(provider_name)
return type(cur) == "function" and cur() or cur return type(cur) == "function" and cur() or cur
end end
function M.get_memory_summary_provider()
local provider_name = Config.memory_summary_provider
if provider_name == nil then
if M.openai.is_env_set() then provider_name = "openai-gpt-4o-mini" end
end
if provider_name == nil then provider_name = Config.provider end
return M[provider_name]
end
return M return M

View File

@@ -16,7 +16,7 @@ M.is_reasoning_model = P.openai.is_reasoning_model
function M:is_disable_stream() return false end function M:is_disable_stream() return false end
function M:parse_stream_data(ctx, data, handler_opts) function M:parse_stream_data(ctx, data, opts)
local ok, json_data = pcall(vim.json.decode, data) local ok, json_data = pcall(vim.json.decode, data)
if not ok or not json_data then if not ok or not json_data then
-- Add debug logging -- Add debug logging
@@ -26,11 +26,13 @@ function M:parse_stream_data(ctx, data, handler_opts)
if json_data.message and json_data.message.content then if json_data.message and json_data.message.content then
local content = json_data.message.content local content = json_data.message.content
if content and content ~= "" then handler_opts.on_chunk(content) end P.openai:add_text_message(ctx, content, "generating", opts)
if content and content ~= "" and opts.on_chunk then opts.on_chunk(content) end
end end
if json_data.done then if json_data.done then
handler_opts.on_stop({ reason = "complete" }) P.openai:finish_pending_messages(ctx, opts)
opts.on_stop({ reason = "complete" })
return return
end end
end end

View File

@@ -2,7 +2,7 @@ local Utils = require("avante.utils")
local Config = require("avante.config") local Config = require("avante.config")
local Clipboard = require("avante.clipboard") local Clipboard = require("avante.clipboard")
local Providers = require("avante.providers") local Providers = require("avante.providers")
local StreamingJsonParser = require("avante.utils.streaming_json_parser") local HistoryMessage = require("avante.history_message")
---@class AvanteProviderFunctor ---@class AvanteProviderFunctor
local M = {} local M = {}
@@ -164,97 +164,134 @@ function M:parse_messages(opts)
table.insert(final_messages, message) table.insert(final_messages, message)
end) end)
if opts.tool_histories then
for _, tool_history in ipairs(opts.tool_histories) do
table.insert(final_messages, {
role = self.role_map["assistant"],
tool_calls = {
{
id = tool_history.tool_use.id,
type = "function",
["function"] = {
name = tool_history.tool_use.name,
arguments = tool_history.tool_use.input_json,
},
},
},
})
local result_content = tool_history.tool_result.content or ""
table.insert(final_messages, {
role = "tool",
tool_call_id = tool_history.tool_result.tool_use_id,
content = tool_history.tool_result.is_error and "Error: " .. result_content or result_content,
})
end
end
return final_messages return final_messages
end end
function M:finish_pending_messages(ctx, opts)
if ctx.content ~= nil and ctx.content ~= "" then self:add_text_message(ctx, "", "generated", opts) end
if ctx.tool_use_list then
for _, tool_use in ipairs(ctx.tool_use_list) do
if tool_use.state == "generating" then self:add_tool_use_message(tool_use, "generated", opts) end
end
end
end
function M:add_text_message(ctx, text, state, opts)
if ctx.content == nil then ctx.content = "" end
ctx.content = ctx.content .. text
local msg = HistoryMessage:new({
role = "assistant",
content = ctx.content,
}, {
state = state,
uuid = ctx.content_uuid,
})
ctx.content_uuid = msg.uuid
if opts.on_messages_add then opts.on_messages_add({ msg }) end
end
function M:add_thinking_message(ctx, text, state, opts)
if ctx.reasonging_content == nil then ctx.reasonging_content = "" end
ctx.reasonging_content = ctx.reasonging_content .. text
local msg = HistoryMessage:new({
role = "assistant",
content = {
{
type = "thinking",
thinking = ctx.reasonging_content,
signature = "",
},
},
}, {
state = state,
uuid = ctx.reasonging_content_uuid,
})
ctx.reasonging_content_uuid = msg.uuid
if opts.on_messages_add then opts.on_messages_add({ msg }) end
end
function M:add_tool_use_message(tool_use, state, opts)
local jsn = nil
if state == "generated" then jsn = vim.json.decode(tool_use.input_json) end
local msg = HistoryMessage:new({
role = "assistant",
content = {
{
type = "tool_use",
name = tool_use.name,
id = tool_use.id,
input = jsn or {},
},
},
}, {
state = state,
uuid = tool_use.uuid,
})
tool_use.uuid = msg.uuid
tool_use.state = state
if opts.on_messages_add then opts.on_messages_add({ msg }) end
end
function M:parse_response(ctx, data_stream, _, opts) function M:parse_response(ctx, data_stream, _, opts)
if data_stream:match('"%[DONE%]":') then if data_stream:match('"%[DONE%]":') then
self:finish_pending_messages(ctx, opts)
opts.on_stop({ reason = "complete" }) opts.on_stop({ reason = "complete" })
return return
end end
if data_stream:match('"delta":') then if not data_stream:match('"delta":') then return end
---@type AvanteOpenAIChatResponse ---@type AvanteOpenAIChatResponse
local jsn = vim.json.decode(data_stream) local jsn = vim.json.decode(data_stream)
if jsn.choices and jsn.choices[1] then if not jsn.choices or not jsn.choices[1] then return end
local choice = jsn.choices[1] local choice = jsn.choices[1]
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then
if choice.delta.content and choice.delta.content ~= vim.NIL then opts.on_chunk(choice.delta.content) end if choice.delta.content and choice.delta.content ~= vim.NIL then
self:add_text_message(ctx, choice.delta.content, "generated", opts)
opts.on_chunk(choice.delta.content)
end
self:finish_pending_messages(ctx, opts)
opts.on_stop({ reason = "complete" }) opts.on_stop({ reason = "complete" })
elseif choice.finish_reason == "tool_calls" then elseif choice.finish_reason == "tool_calls" then
self:finish_pending_messages(ctx, opts)
opts.on_stop({ opts.on_stop({
reason = "tool_use", reason = "tool_use",
-- tool_use_list = ctx.tool_use_list,
usage = jsn.usage, usage = jsn.usage,
tool_use_list = ctx.tool_use_list,
}) })
elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
ctx.returned_think_start_tag = true ctx.returned_think_start_tag = true
opts.on_chunk("<think>\n") if opts.on_chunk then opts.on_chunk("<think>\n") end
end end
ctx.last_think_content = choice.delta.reasoning_content ctx.last_think_content = choice.delta.reasoning_content
opts.on_chunk(choice.delta.reasoning_content) self:add_thinking_message(ctx, choice.delta.reasoning_content, "generating", opts)
if opts.on_chunk then opts.on_chunk(choice.delta.reasoning_content) end
elseif choice.delta.reasoning and choice.delta.reasoning ~= vim.NIL then elseif choice.delta.reasoning and choice.delta.reasoning ~= vim.NIL then
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
ctx.returned_think_start_tag = true ctx.returned_think_start_tag = true
opts.on_chunk("<think>\n") if opts.on_chunk then opts.on_chunk("<think>\n") end
end end
ctx.last_think_content = choice.delta.reasoning ctx.last_think_content = choice.delta.reasoning
opts.on_chunk(choice.delta.reasoning) self:add_thinking_message(ctx, choice.delta.reasoning, "generating", opts)
if opts.on_chunk then opts.on_chunk(choice.delta.reasoning) end
elseif choice.delta.tool_calls and choice.delta.tool_calls ~= vim.NIL then elseif choice.delta.tool_calls and choice.delta.tool_calls ~= vim.NIL then
for _, tool_call in ipairs(choice.delta.tool_calls) do for _, tool_call in ipairs(choice.delta.tool_calls) do
if not ctx.tool_use_list then ctx.tool_use_list = {} end if not ctx.tool_use_list then ctx.tool_use_list = {} end
if not ctx.tool_use_list[tool_call.index + 1] then if not ctx.tool_use_list[tool_call.index + 1] then
if tool_call.index > 0 and ctx.tool_use_list[tool_call.index] then
local prev_tool_use = ctx.tool_use_list[tool_call.index]
self:add_tool_use_message(prev_tool_use, "generated", opts)
end
local tool_use = { local tool_use = {
name = tool_call["function"].name, name = tool_call["function"].name,
id = tool_call.id, id = tool_call.id,
input_json = "", input_json = "",
} }
ctx.tool_use_list[tool_call.index + 1] = tool_use ctx.tool_use_list[tool_call.index + 1] = tool_use
if opts.on_partial_tool_use then self:add_tool_use_message(tool_use, "generating", opts)
opts.on_partial_tool_use({
name = tool_call["function"].name,
id = tool_call.id,
partial_json = {},
state = "generating",
})
end
else else
local tool_use = ctx.tool_use_list[tool_call.index + 1] local tool_use = ctx.tool_use_list[tool_call.index + 1]
tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
if opts.on_partial_tool_use then self:add_tool_use_message(tool_use, "generating", opts)
local parser = StreamingJsonParser:new()
local partial_json = parser:parse(tool_use.input_json)
opts.on_partial_tool_use({
name = tool_call["function"].name,
id = tool_call.id,
partial_json = partial_json or {},
state = "generating",
})
end
end end
end end
elseif choice.delta.content then elseif choice.delta.content then
@@ -262,18 +299,18 @@ function M:parse_response(ctx, data_stream, _, opts)
ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag) ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag)
then then
ctx.returned_think_end_tag = true ctx.returned_think_end_tag = true
if if opts.on_chunk then
ctx.last_think_content if ctx.last_think_content and ctx.last_think_content ~= vim.NIL and ctx.last_think_content:sub(-1) ~= "\n" then
and ctx.last_think_content ~= vim.NIL
and ctx.last_think_content:sub(-1) ~= "\n"
then
opts.on_chunk("\n</think>\n") opts.on_chunk("\n</think>\n")
else else
opts.on_chunk("</think>\n") opts.on_chunk("</think>\n")
end end
end end
if choice.delta.content ~= vim.NIL then opts.on_chunk(choice.delta.content) end self:add_thinking_message(ctx, "", "generated", opts)
end end
if choice.delta.content ~= vim.NIL then
if opts.on_chunk then opts.on_chunk(choice.delta.content) end
self:add_text_message(ctx, choice.delta.content, "generating", opts)
end end
end end
end end

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,10 @@
{% extends "base.avanterules" %}
{% block extra_prompt %}
Always reply to the user in the same language they are using.
Don't just provide code suggestions, use the `replace_in_file` tool to help users fulfill their needs.
After the tool call is complete, please do not output the entire file content.
Before calling the tool, be sure to explain the reason for calling the tool.
{% endblock %}

View File

@@ -1,6 +0,0 @@
{% extends "base.avanterules" %}
{% block extra_prompt %}
Always reply to the user in the same language they are using.
Don't just provide code suggestions, use the `str_replace` tool to help users fulfill their needs.
{% endblock %}

View File

@@ -1,4 +0,0 @@
You are a coding assistant that helps merge code updates, ensuring every modification is fully integrated.
{% block custom_prompt %}
{% endblock %}

View File

@@ -1,46 +0,0 @@
{% extends "base.avanterules" %}
{%- if ask %}
{% block extra_prompt %}
You are an intelligent programmer, powered by {{ model_name }}. You are happy to help answer any questions that the user has (usually they will be about coding).
1. When the user is asking for edits to their code, please output a simplified version of the code block that highlights the changes necessary and adds comments to indicate where unchanged code has been skipped. For example:
```language:path/to/file
// ... existing code ...
{% raw -%}
{{ edit_1 }}
{%- endraw %}
// ... existing code ...
{% raw -%}
{{ edit_2 }}
{%- endraw %}
// ... existing code ...
```
The user can see the entire file, so they prefer to only read the updates to the code. Often this will mean that the start/end of the file will be skipped, but that's okay! Rewrite the entire file only if specifically requested. Always provide a brief explanation of the updates, unless the user specifically requests only the code.
These edit codeblocks are also read by a less intelligent language model, colloquially called the apply model, to update the file. To help specify the edit to the apply model, you will be very careful when generating the codeblock to not introduce ambiguity. You will specify all unchanged regions (code and comments) of the file with "// … existing code …" comment markers. This will ensure the apply model will not delete existing unchanged code or comments when editing the file. You will not mention the apply model.
2. Do not lie or make up facts.
3. If a user messages you in a foreign language, please respond in that language.
4. Format your response in markdown.
5. When writing out new code blocks, please specify the language ID after the initial backticks, like so:
```python
{% raw -%}
{{ code }}
{%- endraw %}
```
6. When writing out code blocks for an existing file, please also specify the file path after the initial backticks and restate the method / class your codeblock belongs to, like so:
```language:some/other/file
function AIChatHistory() {
...
{% raw -%}
{{ code }}
{%- endraw %}
...
}
```
{% endblock %}
{%- endif %}

View File

@@ -75,15 +75,32 @@ vim.g.avante_login = vim.g.avante_login
---@field on_start AvanteLLMStartCallback ---@field on_start AvanteLLMStartCallback
---@field on_chunk AvanteLLMChunkCallback ---@field on_chunk AvanteLLMChunkCallback
---@field on_stop AvanteLLMStopCallback ---@field on_stop AvanteLLMStopCallback
---@field on_partial_tool_use? fun(tool_use: AvantePartialLLMToolUse): nil ---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil
---@field on_state_change? fun(state: avante.GenerateState): nil
--- ---
---@alias AvanteLLMMessageContentItem string | { type: "text", text: string } | { type: "image", source: { type: "base64", media_type: string, data: string } } | { type: "tool_use", name: string, id: string, input: any } | { type: "tool_result", tool_use_id: string, content: string, is_error?: boolean } | { type: "thinking", thinking: string, signature: string } | { type: "redacted_thinking", data: string } ---@alias AvanteLLMMessageContentItem string | { type: "text", text: string } | { type: "image", source: { type: "base64", media_type: string, data: string } } | { type: "tool_use", name: string, id: string, input: any } | { type: "tool_result", tool_use_id: string, content: string, is_error?: boolean } | { type: "thinking", thinking: string, signature: string } | { type: "redacted_thinking", data: string }
---
---@alias AvanteLLMMessageContent AvanteLLMMessageContentItem[] | string ---@alias AvanteLLMMessageContent AvanteLLMMessageContentItem[] | string
---
---@class AvanteLLMMessage ---@class AvanteLLMMessage
---@field role "user" | "assistant" ---@field role "user" | "assistant"
---@field content AvanteLLMMessageContent ---@field content AvanteLLMMessageContent
---@class avante.HistoryMessage
---@field message AvanteLLMMessage
---@field timestamp string
---@field state avante.HistoryMessageState
---@field uuid string | nil
---@field displayed_content string | nil
---@field visible boolean | nil
---@field is_context boolean | nil
---@field is_user_submission boolean | nil
---@field provider string | nil
---@field model string | nil
---@field selected_code AvanteSelectedCode | nil
---@field selected_filepaths string[] | nil
---@field tool_use_logs string[] | nil
---@field just_for_display boolean | nil
--- ---
---@class AvanteLLMToolResult ---@class AvanteLLMToolResult
---@field tool_name string ---@field tool_name string
@@ -96,8 +113,7 @@ vim.g.avante_login = vim.g.avante_login
---@field messages AvanteLLMMessage[] ---@field messages AvanteLLMMessage[]
---@field image_paths? string[] ---@field image_paths? string[]
---@field tools? AvanteLLMTool[] ---@field tools? AvanteLLMTool[]
---@field tool_histories? AvanteLLMToolHistory[] ---@field dropped_history_messages? avante.HistoryMessage[]
---@field dropped_history_messages? AvanteLLMMessage[]
--- ---
---@class AvanteGeminiMessage ---@class AvanteGeminiMessage
---@field role "user" ---@field role "user"
@@ -236,19 +252,18 @@ vim.g.avante_login = vim.g.avante_login
---@class AvanteLLMRedactedThinkingBlock ---@class AvanteLLMRedactedThinkingBlock
---@field data string ---@field data string
--- ---
---@alias avante.HistoryMessageState "generating" | "generated"
---
---@class AvantePartialLLMToolUse ---@class AvantePartialLLMToolUse
---@field name string ---@field name string
---@field id string ---@field id string
---@field partial_json table ---@field partial_json table
---@field state "generating" | "generated" ---@field state avante.HistoryMessageState
--- ---
---@class AvanteLLMToolUse ---@class AvanteLLMToolUse
---@field name string ---@field name string
---@field id string ---@field id string
---@field input_json string ---@field input any
---@field response_contents? string[]
---@field thinking_blocks? AvanteLLMThinkingBlock[]
---@field redacted_thinking_blocks? AvanteLLMRedactedThinkingBlock[]
--- ---
---@class AvanteLLMStartCallbackOptions ---@class AvanteLLMStartCallbackOptions
---@field usage? AvanteLLMUsage ---@field usage? AvanteLLMUsage
@@ -257,10 +272,8 @@ vim.g.avante_login = vim.g.avante_login
---@field reason "complete" | "tool_use" | "error" | "rate_limit" | "cancelled" ---@field reason "complete" | "tool_use" | "error" | "rate_limit" | "cancelled"
---@field error? string | table ---@field error? string | table
---@field usage? AvanteLLMUsage ---@field usage? AvanteLLMUsage
---@field tool_use_list? AvanteLLMToolUse[]
---@field retry_after? integer ---@field retry_after? integer
---@field headers? table<string, string> ---@field headers? table<string, string>
---@field tool_histories? AvanteLLMToolHistory[]
--- ---
---@alias AvanteStreamParser fun(self: AvanteProviderFunctor, ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil ---@alias AvanteStreamParser fun(self: AvanteProviderFunctor, ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil
---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil ---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil
@@ -303,7 +316,7 @@ vim.g.avante_login = vim.g.avante_login
---@field parse_response AvanteResponseParser ---@field parse_response AvanteResponseParser
---@field build_bedrock_payload AvanteBedrockPayloadBuilder ---@field build_bedrock_payload AvanteBedrockPayloadBuilder
--- ---
---@alias AvanteLlmMode "planning" | "editing" | "suggesting" | "cursor-planning" | "cursor-applying" | "claude-text-editor-tool" ---@alias AvanteLlmMode avante.Mode | "editing" | "suggesting"
--- ---
---@class AvanteSelectedCode ---@class AvanteSelectedCode
---@field path string ---@field path string
@@ -324,7 +337,7 @@ vim.g.avante_login = vim.g.avante_login
---@field selected_files AvanteSelectedFile[] | nil ---@field selected_files AvanteSelectedFile[] | nil
---@field selected_filepaths string[] | nil ---@field selected_filepaths string[] | nil
---@field diagnostics string | nil ---@field diagnostics string | nil
---@field history_messages AvanteLLMMessage[] | nil ---@field history_messages avante.HistoryMessage[] | nil
---@field memory string | nil ---@field memory string | nil
--- ---
---@class AvanteGeneratePromptsOptions: AvanteTemplateOptions ---@class AvanteGeneratePromptsOptions: AvanteTemplateOptions
@@ -332,7 +345,6 @@ vim.g.avante_login = vim.g.avante_login
---@field mode? AvanteLlmMode ---@field mode? AvanteLlmMode
---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil ---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil
---@field tools? AvanteLLMTool[] ---@field tools? AvanteLLMTool[]
---@field tool_histories? AvanteLLMToolHistory[]
---@field original_code? string ---@field original_code? string
---@field update_snippets? string[] ---@field update_snippets? string[]
---@field prompt_opts? AvantePromptOptions ---@field prompt_opts? AvantePromptOptions
@@ -342,9 +354,10 @@ vim.g.avante_login = vim.g.avante_login
---@field tool_result? AvanteLLMToolResult ---@field tool_result? AvanteLLMToolResult
---@field tool_use? AvanteLLMToolUse ---@field tool_use? AvanteLLMToolUse
--- ---
---@alias AvanteLLMMemorySummarizeCallback fun(dropped_history_messages: AvanteLLMMessage[]): nil ---@alias AvanteLLMMemorySummarizeCallback fun(dropped_history_messages: avante.HistoryMessage[]): nil
--- ---
---@alias AvanteLLMToolUseState "generating" | "generated" | "running" | "succeeded" | "failed" ---@alias AvanteLLMToolUseState "generating" | "generated" | "running" | "succeeded" | "failed"
---@alias avante.GenerateState "generating" | "tool calling" | "failed" | "succeeded" | "cancelled" | "searching" | "thinking"
--- ---
---@class AvanteLLMStreamOptions: AvanteGeneratePromptsOptions ---@class AvanteLLMStreamOptions: AvanteGeneratePromptsOptions
---@field on_start AvanteLLMStartCallback ---@field on_start AvanteLLMStartCallback
@@ -352,7 +365,9 @@ vim.g.avante_login = vim.g.avante_login
---@field on_stop AvanteLLMStopCallback ---@field on_stop AvanteLLMStopCallback
---@field on_memory_summarize? AvanteLLMMemorySummarizeCallback ---@field on_memory_summarize? AvanteLLMMemorySummarizeCallback
---@field on_tool_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil ---@field on_tool_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
---@field on_partial_tool_use? fun(tool_use: AvantePartialLLMToolUse): nil ---@field get_history_messages? fun(): avante.HistoryMessage[]
---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil
---@field on_state_change? fun(state: avante.GenerateState): nil
--- ---
---@alias AvanteLLMToolFunc<T> fun( ---@alias AvanteLLMToolFunc<T> fun(
--- input: T, --- input: T,
@@ -400,15 +415,14 @@ vim.g.avante_login = vim.g.avante_login
---@field original_response string ---@field original_response string
---@field selected_file {filepath: string}? ---@field selected_file {filepath: string}?
---@field selected_code AvanteSelectedCode | nil ---@field selected_code AvanteSelectedCode | nil
---@field reset_memory boolean?
---@field selected_filepaths string[] | nil ---@field selected_filepaths string[] | nil
---@field visible boolean? ---@field visible boolean?
---@field tool_histories? AvanteLLMToolHistory[]
--- ---
---@class avante.ChatHistory ---@class avante.ChatHistory
---@field title string ---@field title string
---@field timestamp string ---@field timestamp string
---@field entries avante.ChatHistoryEntry[] ---@field messages avante.HistoryMessage[] | nil
---@field entries avante.ChatHistoryEntry[] | nil
---@field memory avante.ChatMemory | nil ---@field memory avante.ChatMemory | nil
---@field filename string ---@field filename string
---@field system_prompt string | nil ---@field system_prompt string | nil
@@ -416,6 +430,7 @@ vim.g.avante_login = vim.g.avante_login
---@class avante.ChatMemory ---@class avante.ChatMemory
---@field content string ---@field content string
---@field last_summarized_timestamp string ---@field last_summarized_timestamp string
---@field last_message_uuid string | nil
--- ---
---@class avante.CurlOpts ---@class avante.CurlOpts
---@field provider AvanteProviderFunctor ---@field provider AvanteProviderFunctor
@@ -427,7 +442,7 @@ vim.g.avante_login = vim.g.avante_login
---@field content string ---@field content string
---@field uri string ---@field uri string
--- ---
---@alias AvanteSlashCommandBuiltInName "clear" | "help" | "lines" | "reset" | "commit" | "new" ---@alias AvanteSlashCommandBuiltInName "clear" | "help" | "lines" | "commit" | "new"
---@alias AvanteSlashCommandCallback fun(self: avante.Sidebar, args: string, cb?: fun(args: string): nil): nil ---@alias AvanteSlashCommandCallback fun(self: avante.Sidebar, args: string, cb?: fun(args: string): nil): nil
---@class AvanteSlashCommand ---@class AvanteSlashCommand
---@field name AvanteSlashCommandBuiltInName | string ---@field name AvanteSlashCommandBuiltInName | string

View File

@@ -80,7 +80,7 @@ function M.show(selector)
selector.on_select(selected_item_ids) selector.on_select(selected_item_ids)
actions.close(prompt_bufnr) pcall(actions.close, prompt_bufnr)
end) end)
return true return true
end, end,

View File

@@ -1,87 +0,0 @@
local Utils = require("avante.utils")
local Config = require("avante.config")
---@class avante.utils.history
local M = {}
---@param entries avante.ChatHistoryEntry[]
---@return avante.ChatHistoryEntry[]
function M.filter_active_entries(entries)
local entries_ = {}
for i = #entries, 1, -1 do
local entry = entries[i]
if entry.reset_memory then break end
table.insert(entries_, 1, entry)
end
return entries_
end
---@param entries avante.ChatHistoryEntry[]
---@return AvanteLLMMessage[]
function M.entries_to_llm_messages(entries)
local current_provider_name = Config.provider
local messages = {}
for _, entry in ipairs(entries) do
if entry.selected_filepaths ~= nil and #entry.selected_filepaths > 0 then
local user_content = "SELECTED FILES:\n\n"
for _, filepath in ipairs(entry.selected_filepaths) do
user_content = user_content .. filepath .. "\n"
end
table.insert(messages, { role = "user", content = user_content })
end
if entry.selected_code ~= nil then
local user_content_ = "SELECTED CODE:\n\n```"
.. (entry.selected_code.file_type or "")
.. (entry.selected_code.path and ":" .. entry.selected_code.path or "")
.. "\n"
.. entry.selected_code.content
.. "\n```\n\n"
table.insert(messages, { role = "user", content = user_content_ })
end
if entry.request ~= nil and entry.request ~= "" then
table.insert(messages, { role = "user", content = entry.request })
end
if entry.tool_histories ~= nil and #entry.tool_histories > 0 and entry.provider == current_provider_name then
for _, tool_history in ipairs(entry.tool_histories) do
local assistant_content = {}
if tool_history.tool_use ~= nil then
if tool_history.tool_use.response_contents ~= nil then
for _, response_content in ipairs(tool_history.tool_use.response_contents) do
table.insert(assistant_content, { type = "text", text = response_content })
end
end
table.insert(assistant_content, {
type = "tool_use",
name = tool_history.tool_use.name,
id = tool_history.tool_use.id,
input = vim.json.decode(tool_history.tool_use.input_json),
})
end
table.insert(messages, {
role = "assistant",
content = assistant_content,
})
local user_content = {}
if tool_history.tool_result ~= nil and tool_history.tool_result.content ~= nil then
table.insert(user_content, {
type = "tool_result",
tool_use_id = tool_history.tool_result.tool_use_id,
content = tool_history.tool_result.content,
is_error = tool_history.tool_result.is_error,
})
end
table.insert(messages, {
role = "user",
content = user_content,
})
end
end
local assistant_content = Utils.trim_think_content(entry.original_response or "")
if assistant_content ~= "" then table.insert(messages, { role = "assistant", content = assistant_content }) end
end
return messages
end
return M

View File

@@ -6,7 +6,6 @@ local lsp = vim.lsp
---@field tokens avante.utils.tokens ---@field tokens avante.utils.tokens
---@field root avante.utils.root ---@field root avante.utils.root
---@field file avante.utils.file ---@field file avante.utils.file
---@field history avante.utils.history
---@field environment avante.utils.environment ---@field environment avante.utils.environment
---@field lsp avante.utils.lsp ---@field lsp avante.utils.lsp
local M = {} local M = {}
@@ -415,7 +414,7 @@ function M.debug(...)
local caller_source = info.source:match("@(.+)$") or "unknown" local caller_source = info.source:match("@(.+)$") or "unknown"
local caller_module = caller_source:gsub("^.*/lua/", ""):gsub("%.lua$", ""):gsub("/", ".") local caller_module = caller_source:gsub("^.*/lua/", ""):gsub("%.lua$", ""):gsub("/", ".")
local timestamp = os.date("%Y-%m-%d %H:%M:%S") local timestamp = M.get_timestamp()
local formated_args = { local formated_args = {
"[" .. timestamp .. "] [AVANTE] [DEBUG] [" .. caller_module .. ":" .. info.currentline .. "]", "[" .. timestamp .. "] [AVANTE] [DEBUG] [" .. caller_module .. ":" .. info.currentline .. "]",
} }
@@ -1263,7 +1262,6 @@ function M.get_commands()
local builtin_items = { local builtin_items = {
{ description = "Show help message", name = "help" }, { description = "Show help message", name = "help" },
{ description = "Clear chat history", name = "clear" }, { description = "Clear chat history", name = "clear" },
{ description = "Reset memory", name = "reset" },
{ description = "New chat", name = "new" }, { description = "New chat", name = "new" },
{ {
shorthelp = "Ask a question about specific lines", shorthelp = "Ask a question about specific lines",
@@ -1281,7 +1279,6 @@ function M.get_commands()
if cb then cb(args) end if cb then cb(args) end
end, end,
clear = function(sidebar, args, cb) sidebar:clear_history(args, cb) end, clear = function(sidebar, args, cb) sidebar:clear_history(args, cb) end,
reset = function(sidebar, args, cb) sidebar:reset_memory(args, cb) end,
new = function(sidebar, args, cb) sidebar:new_chat(args, cb) end, new = function(sidebar, args, cb) sidebar:new_chat(args, cb) end,
lines = function(_, args, cb) lines = function(_, args, cb)
if cb then cb(args) end if cb then cb(args) end
@@ -1310,4 +1307,97 @@ function M.get_commands()
return vim.list_extend(builtin_commands, Config.slash_commands) return vim.list_extend(builtin_commands, Config.slash_commands)
end end
---@param history avante.ChatHistory
---@return avante.HistoryMessage[]
function M.get_history_messages(history)
local HistoryMessage = require("avante.history_message")
if history.messages then return history.messages end
local messages = {}
for _, entry in ipairs(history.entries or {}) do
if entry.request and entry.request ~= "" then
local message = HistoryMessage:new({
role = "user",
content = entry.request,
}, {
timestamp = entry.timestamp,
is_user_submission = true,
visible = entry.visible,
selected_filepaths = entry.selected_filepaths,
selected_code = entry.selected_code,
})
table.insert(messages, message)
end
if entry.response and entry.response ~= "" then
local message = HistoryMessage:new({
role = "assistant",
content = entry.response,
}, {
timestamp = entry.timestamp,
visible = entry.visible,
})
table.insert(messages, message)
end
end
history.messages = messages
return messages
end
function M.get_timestamp() return tostring(os.date("%Y-%m-%d %H:%M:%S")) end
---@param history_messages avante.HistoryMessage[]
---@return AvanteLLMMessage[]
function M.history_messages_to_messages(history_messages)
local messages = {}
for _, history_message in ipairs(history_messages) do
if history_message.just_for_display then goto continue end
table.insert(messages, history_message.message)
::continue::
end
return messages
end
function M.uuid()
local template = "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx"
return string.gsub(template, "[xy]", function(c)
local v = (c == "x") and math.random(0, 0xf) or math.random(8, 0xb)
return string.format("%x", v)
end)
end
---@param item AvanteLLMMessageContentItem
---@param message avante.HistoryMessage
---@return string
function M.message_content_item_to_text(item, message)
if type(item) == "string" then return item end
if type(item) == "table" then
if item.type == "text" then return item.text end
if item.type == "image" then return "![image](" .. item.source.media_type .. ": " .. item.source.data .. ")" end
if item.type == "tool_use" then
local pieces = {}
table.insert(pieces, string.format("[%s]: calling", item.name))
for _, log in ipairs(message.tool_use_logs or {}) do
table.insert(pieces, log)
end
return table.concat(pieces, "\n")
end
end
return ""
end
---@param message avante.HistoryMessage
---@return string
function M.message_to_text(message)
local content = message.message.content
if type(content) == "string" then return content end
if vim.islist(content) then
local pieces = {}
for _, item in ipairs(content) do
local text = M.message_content_item_to_text(item, message)
if text ~= "" then table.insert(pieces, text) end
end
return table.concat(pieces, "\n")
end
return ""
end
return M return M

View File

@@ -77,10 +77,24 @@ function StreamingJSONParser:parse(chunk)
-- Handle strings specially (they can contain JSON control characters) -- Handle strings specially (they can contain JSON control characters)
if self.state.inString then if self.state.inString then
if self.state.escaping then if self.state.escaping then
local escapeMap = {
['"'] = '"',
["\\"] = "\\",
["/"] = "/",
["b"] = "\b",
["f"] = "\f",
["n"] = "\n",
["r"] = "\r",
["t"] = "\t",
}
local escapedChar = escapeMap[char]
if escapedChar then
self.state.stringBuffer = self.state.stringBuffer .. escapedChar
else
self.state.stringBuffer = self.state.stringBuffer .. char self.state.stringBuffer = self.state.stringBuffer .. char
end
self.state.escaping = false self.state.escaping = false
elseif char == "\\" then elseif char == "\\" then
self.state.stringBuffer = self.state.stringBuffer .. char
self.state.escaping = true self.state.escaping = true
elseif char == '"' then elseif char == '"' then
-- End of string -- End of string

View File

@@ -132,17 +132,13 @@ cmd(
cmd("Clear", function(opts) cmd("Clear", function(opts)
local arg = vim.trim(opts.args or "") local arg = vim.trim(opts.args or "")
arg = arg == "" and "history" or arg arg = arg == "" and "history" or arg
if arg == "history" or arg == "memory" then if arg == "history" then
local sidebar = require("avante").get() local sidebar = require("avante").get()
if not sidebar then if not sidebar then
Utils.error("No sidebar found") Utils.error("No sidebar found")
return return
end end
if arg == "history" then
sidebar:clear_history() sidebar:clear_history()
else
sidebar:reset_memory()
end
elseif arg == "cache" then elseif arg == "cache" then
local P = require("avante.path") local P = require("avante.path")
local history_path = P.history_path:absolute() local history_path = P.history_path:absolute()

View File

@@ -29,6 +29,13 @@ describe("StreamingJSONParser", function()
assert.equals("value", result.key) assert.equals("value", result.key)
end) end)
it("should parse breaklines", function()
local result, complete = parser:parse('{"key": "value\nv"}')
assert.is_true(complete)
assert.is_table(result)
assert.equals("value\nv", result.key)
end)
it("should parse a complete simple JSON array", function() it("should parse a complete simple JSON array", function()
local result, complete = parser:parse("[1, 2, 3]") local result, complete = parser:parse("[1, 2, 3]")
assert.is_true(complete) assert.is_true(complete)
@@ -119,7 +126,7 @@ describe("StreamingJSONParser", function()
local result, complete = parser:parse('{"text": "line1\\nline2\\t\\"quoted\\""}') local result, complete = parser:parse('{"text": "line1\\nline2\\t\\"quoted\\""}')
assert.is_true(complete) assert.is_true(complete)
assert.is_table(result) assert.is_table(result)
assert.equals('line1\\nline2\\t\\"quoted\\"', result.text) assert.equals('line1\nline2\t"quoted"', result.text)
end) end)
it("should handle numbers correctly", function() it("should handle numbers correctly", function()