refactor: llm tool parameters (#2449)

This commit is contained in:
yetone
2025-07-15 16:40:25 +08:00
committed by GitHub
parent 0c6a8f5688
commit b8bb0fd969
25 changed files with 627 additions and 381 deletions

View File

@@ -47,7 +47,9 @@ local Highlights = {
AVANTE_STATE_SPINNER_SEARCHING = { name = "AvanteStateSpinnerSearching", fg = "#1e222a", bg = "#c678dd" },
AVANTE_STATE_SPINNER_THINKING = { name = "AvanteStateSpinnerThinking", fg = "#1e222a", bg = "#c678dd" },
AVANTE_STATE_SPINNER_COMPACTING = { name = "AvanteStateSpinnerCompacting", fg = "#1e222a", bg = "#c678dd" },
AVANTE_TASK_RUNNING = { name = "AvanteTaskRunning", fg = "#c678dd", bg_link = "Normal" },
AVANTE_TASK_COMPLETED = { name = "AvanteTaskCompleted", fg = "#98c379", bg_link = "Normal" },
AVANTE_TASK_FAILED = { name = "AvanteTaskFailed", fg = "#e06c75", bg_link = "Normal" },
AVANTE_THINKING = { name = "AvanteThinking", fg = "#c678dd", bg_link = "Normal" },
}

View File

@@ -150,8 +150,11 @@ function M.generate_todos(user_input, cb)
local uncalled_tool_uses = Utils.get_uncalled_tool_uses(history_messages)
for _, partial_tool_use in ipairs(uncalled_tool_uses) do
if partial_tool_use.state == "generated" and partial_tool_use.name == "add_todos" then
LLMTools.process_tool_use(tools, partial_tool_use, function() end, function() cb() end, {})
cb()
local result = LLMTools.process_tool_use(tools, partial_tool_use, {
session_ctx = {},
on_complete = function() cb() end,
})
if result ~= nil then cb() end
end
end
else
@@ -206,6 +209,7 @@ function M.agent_loop(opts)
table.insert(history_messages, msg)
end
end
if opts.on_messages_add then opts.on_messages_add(msgs) end
end,
session_ctx = session_ctx,
prompt_opts = {
@@ -331,8 +335,9 @@ function M.generate_prompts(opts)
end
if Config.system_prompt ~= nil then
local custom_system_prompt = Config.system_prompt
if type(custom_system_prompt) == "function" then custom_system_prompt = custom_system_prompt() end
local custom_system_prompt
if type(Config.system_prompt) == "function" then custom_system_prompt = Config.system_prompt() end
if type(Config.system_prompt) == "string" then custom_system_prompt = Config.system_prompt end
if custom_system_prompt ~= nil and custom_system_prompt ~= "" and custom_system_prompt ~= "null" then
system_prompt = system_prompt .. "\n\n" .. custom_system_prompt
end
@@ -841,13 +846,9 @@ function M._stream(opts)
if partial_tool_use.state == "generating" then
if type(partial_tool_use.input) == "table" then
partial_tool_use.input.streaming = true
LLMTools.process_tool_use(
prompt_opts.tools,
partial_tool_use,
function() end,
function() end,
opts.session_ctx
)
LLMTools.process_tool_use(prompt_opts.tools, partial_tool_use, {
session_ctx = opts.session_ctx,
})
end
return
else
@@ -856,13 +857,12 @@ function M._stream(opts)
partial_tool_use_message.is_calling = true
if opts.on_messages_add then opts.on_messages_add({ partial_tool_use_message }) end
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
local result, error = LLMTools.process_tool_use(
prompt_opts.tools,
partial_tool_use,
opts.on_tool_log,
handle_tool_result,
opts.session_ctx
)
local result, error = LLMTools.process_tool_use(prompt_opts.tools, partial_tool_use, {
session_ctx = opts.session_ctx,
on_log = opts.on_tool_log,
set_tool_use_store = opts.set_tool_use_store,
on_complete = handle_tool_result,
})
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
end
if stop_opts.reason == "cancelled" then
@@ -1100,6 +1100,13 @@ function M.stream(opts)
return original_on_tool_log(...)
end)
end
if opts.set_tool_use_store ~= nil then
local original_set_tool_use_store = opts.set_tool_use_store
opts.set_tool_use_store = vim.schedule_wrap(function(...)
if not original_set_tool_use_store then return end
return original_set_tool_use_store(...)
end)
end
if opts.on_chunk ~= nil then
local original_on_chunk = opts.on_chunk
opts.on_chunk = vim.schedule_wrap(function(chunk)

View File

@@ -65,10 +65,11 @@ M.returns = {
M.on_render = function() return {} end
---@type AvanteLLMToolFunc<{ todos: avante.TODO[] }>
function M.func(opts, on_log, on_complete, session_ctx)
function M.func(input, opts)
local on_complete = opts.on_complete
local sidebar = require("avante").get()
if not sidebar then return false, "Avante sidebar not found" end
local todos = opts.todos
local todos = input.todos
if not todos or #todos == 0 then return false, "No todos provided" end
sidebar:update_todos(todos)
if on_complete then

View File

@@ -57,11 +57,11 @@ M.returns = {
}
---@type avante.LLMToolOnRender<AttemptCompletionInput>
function M.on_render(opts)
function M.on_render(input)
local lines = {}
table.insert(lines, Line:new({ { "✓ Task Completed", Highlights.AVANTE_TASK_COMPLETED } }))
table.insert(lines, Line:new({ { "" } }))
local result = opts.result or ""
local result = input.result or ""
local text_lines = vim.split(result, "\n")
for _, text_line in ipairs(text_lines) do
table.insert(lines, Line:new({ { text_line } }))
@@ -70,24 +70,29 @@ function M.on_render(opts)
end
---@type AvanteLLMToolFunc<AttemptCompletionInput>
function M.func(opts, on_log, on_complete, session_ctx)
if not on_complete then return false, "on_complete not provided" end
function M.func(input, opts)
if not opts.on_complete then return false, "on_complete not provided" end
local sidebar = require("avante").get()
if not sidebar then return false, "Avante sidebar not found" end
local is_streaming = opts.streaming or false
local is_streaming = input.streaming or false
if is_streaming then
-- wait for stream completion as command may not be complete yet
return
end
session_ctx.attempt_completion_is_called = true
opts.session_ctx.attempt_completion_is_called = true
if opts.command and opts.command ~= vim.NIL and opts.command ~= "" and not vim.startswith(opts.command, "open ") then
session_ctx.always_yes = false
require("avante.llm_tools.bash").func({ command = opts.command }, on_log, on_complete, session_ctx)
if
input.command
and input.command ~= vim.NIL
and input.command ~= ""
and not vim.startswith(input.command, "open ")
then
opts.session_ctx.always_yes = false
require("avante.llm_tools.bash").func({ command = input.command }, opts)
else
on_complete(true, nil)
opts.on_complete(true, nil)
end
end

View File

@@ -215,18 +215,18 @@ M.returns = {
}
---@type AvanteLLMToolFunc<{ path: string, command: string, streaming?: boolean }>
function M.func(opts, on_log, on_complete, session_ctx)
local is_streaming = opts.streaming or false
function M.func(input, opts)
local is_streaming = input.streaming or false
if is_streaming then
-- wait for stream completion as command may not be complete yet
return
end
local abs_path = Helpers.get_abs_path(opts.path)
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
if not opts.command then return false, "Command is required" end
if on_log then on_log("command: " .. opts.command) end
if not input.command then return false, "Command is required" end
if opts.on_log then opts.on_log("command: " .. input.command) end
---change cwd to abs_path
---@param output string
@@ -240,21 +240,21 @@ function M.func(opts, on_log, on_complete, session_ctx)
end
return output, nil
end
if not on_complete then return false, "on_complete not provided" end
if not opts.on_complete then return false, "on_complete not provided" end
Helpers.confirm(
"Are you sure you want to run the command: `" .. opts.command .. "` in the directory: " .. abs_path,
"Are you sure you want to run the command: `" .. input.command .. "` in the directory: " .. abs_path,
function(ok, reason)
if not ok then
on_complete(false, "User declined, reason: " .. (reason and reason or "unknown"))
opts.on_complete(false, "User declined, reason: " .. (reason and reason or "unknown"))
return
end
Utils.shell_run_async(opts.command, "bash -c", function(output, exit_code)
Utils.shell_run_async(input.command, "bash -c", function(output, exit_code)
local result, err = handle_result(output, exit_code)
on_complete(result, err)
opts.on_complete(result, err)
end, abs_path)
end,
{ focus = true },
session_ctx,
opts.session_ctx,
M.name -- Pass the tool name for permission checking
)
end

View File

@@ -45,20 +45,23 @@ M.returns = {
}
---@type AvanteLLMToolFunc<{ path: string, file_text: string }>
function M.func(opts, on_log, on_complete, session_ctx)
function M.func(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local session_ctx = opts.session_ctx
if not on_complete then return false, "on_complete not provided" end
if on_log then on_log("path: " .. opts.path) end
if Helpers.already_in_context(opts.path) then
if on_log then on_log("path: " .. input.path) end
if Helpers.already_in_context(input.path) then
on_complete(nil, "Ooooops! This file is already in the context! Why you are trying to create it again?")
return
end
local abs_path = Helpers.get_abs_path(opts.path)
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if opts.file_text == nil then return false, "file_text not provided" end
if input.file_text == nil then return false, "file_text not provided" end
if Path:new(abs_path):exists() then return false, "File already exists: " .. abs_path end
local lines = vim.split(opts.file_text, "\n")
if #lines == 1 and opts.file_text:match("\\n") then
local text = Utils.trim_slashes(opts.file_text)
local lines = vim.split(input.file_text, "\n")
if #lines == 1 and input.file_text:match("\\n") then
local text = Utils.trim_slashes(input.file_text)
lines = vim.split(text, "\n")
end
local bufnr, err = Helpers.get_bufnr(abs_path)

View File

@@ -40,7 +40,7 @@ M.returns = {
}
---@type AvanteLLMToolFunc<{ tool_use_id: string }>
function M.func(opts, on_log, on_complete, session_ctx)
function M.func(input, opts)
local sidebar = require("avante").get()
if not sidebar then return false, "Avante sidebar not found" end
local history_messages = Utils.get_history_messages(sidebar.chat_history)
@@ -50,7 +50,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
local content = msg.message.content
if type(content) == "table" then
for _, item in ipairs(content) do
if item.id == opts.tool_use_id then table.insert(the_deleted_message_uuids, msg.uuid) end
if item.id == input.tool_use_id then table.insert(the_deleted_message_uuids, msg.uuid) end
end
end
end

View File

@@ -3,6 +3,8 @@ local Config = require("avante.config")
local Utils = require("avante.utils")
local Base = require("avante.llm_tools.base")
local HistoryMessage = require("avante.history_message")
local Line = require("avante.ui.line")
local Highlights = require("avante.highlights")
---@class AvanteLLMTool
local M = setmetatable({}, Base)
@@ -79,12 +81,118 @@ local function get_available_tools()
}
end
---@type AvanteLLMToolFunc<{ prompt: string }>
function M.func(opts, on_log, on_complete, session_ctx)
---@class avante.DispatchAgentInput
---@field prompt string
---@type avante.LLMToolOnRender<avante.DispatchAgentInput>
function M.on_render(input, opts)
local result_message = opts.result_message
local store = opts.store or {}
local messages = store.messages or {}
local tool_use_summary = {}
for _, msg in ipairs(messages) do
local content = msg.message.content
local summary
if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then
local tool_result_message = Utils.get_tool_result_message(msg, messages)
if tool_result_message then
local tool_name = msg.message.content[1].name
if tool_name == "ls" then
local path = msg.message.content[1].input.path
if tool_result_message.message.content[1].is_error then
summary = string.format("Ls %s: failed", path)
else
local ok, filepaths = pcall(vim.json.decode, tool_result_message.message.content[1].content)
if ok then summary = string.format("Ls %s: %d paths", path, #filepaths) end
end
elseif tool_name == "grep" then
local path = msg.message.content[1].input.path
local query = msg.message.content[1].input.query
if tool_result_message.message.content[1].is_error then
summary = string.format("Grep %s in %s: failed", query, path)
else
local ok, filepaths = pcall(vim.json.decode, tool_result_message.message.content[1].content)
if ok then summary = string.format("Grep %s in %s: %d paths", query, path, #filepaths) end
end
elseif tool_name == "glob" then
local path = msg.message.content[1].input.path
local pattern = msg.message.content[1].input.pattern
if tool_result_message.message.content[1].is_error then
summary = string.format("Glob %s in %s: failed", pattern, path)
else
local ok, result = pcall(vim.json.decode, tool_result_message.message.content[1].content)
if ok then
local matches = result.matches
if matches then summary = string.format("Glob %s in %s: %d matches", pattern, path, #matches) end
end
end
elseif tool_name == "view" then
local path = msg.message.content[1].input.path
if tool_result_message.message.content[1].is_error then
summary = string.format("View %s: failed", path)
else
local ok, result = pcall(vim.json.decode, tool_result_message.message.content[1].content)
if ok then
local content_ = result.content
local lines = vim.split(content_, "\n")
summary = string.format("View %s: %d lines", path, #lines)
end
end
end
end
if summary then summary = " " .. Utils.icon("🛠️ ") .. summary end
elseif type(content) == "table" and #content > 0 and type(content[1]) == "table" and content[1].type == "text" then
summary = content[1].content
elseif type(content) == "table" and #content > 0 and type(content[1]) == "string" then
summary = content[1]
elseif type(content) == "string" then
summary = content
end
if summary then table.insert(tool_use_summary, summary) end
end
local state = "running"
local icon = Utils.icon("🔄 ")
local hl = Highlights.AVANTE_TASK_RUNNING
if result_message then
if result_message.message.content[1].is_error then
state = "failed"
icon = Utils.icon("")
hl = Highlights.AVANTE_TASK_FAILED
else
state = "completed"
icon = Utils.icon("")
hl = Highlights.AVANTE_TASK_COMPLETED
end
end
local lines = {}
table.insert(lines, Line:new({ { icon .. "Subtask " .. state, hl } }))
table.insert(lines, Line:new({ { "" } }))
table.insert(lines, Line:new({ { " Task:" } }))
local prompt_lines = vim.split(input.prompt or "", "\n")
for _, line in ipairs(prompt_lines) do
table.insert(lines, Line:new({ { " " .. line } }))
end
table.insert(lines, Line:new({ { "" } }))
table.insert(lines, Line:new({ { " Task summary:" } }))
for _, summary in ipairs(tool_use_summary) do
local summary_lines = vim.split(summary, "\n")
for _, line in ipairs(summary_lines) do
table.insert(lines, Line:new({ { " " .. line } }))
end
end
return lines
end
---@type AvanteLLMToolFunc<avante.DispatchAgentInput>
function M.func(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local session_ctx = opts.session_ctx
local Llm = require("avante.llm")
if not on_complete then return false, "on_complete not provided" end
local prompt = opts.prompt
local prompt = input.prompt
local tools = get_available_tools()
local start_time = Utils.get_timestamp()
@@ -95,10 +203,11 @@ Your task is to help the user with their request: "${prompt}"
Be thorough and use the tools available to you to find the most relevant information.
When you're done, provide a clear and concise summary of what you found.]]):gsub("${prompt}", prompt)
local history_messages = {}
local tool_use_messages = {}
local total_tokens = 0
local final_response = ""
local result = ""
---@type avante.AgentLoopOptions
local agent_loop_options = {
@@ -108,19 +217,37 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
on_tool_log = session_ctx.on_tool_log,
on_messages_add = function(msgs)
msgs = vim.islist(msgs) and msgs or { msgs }
for _, msg in ipairs(msgs) do
local idx = nil
for i, m in ipairs(history_messages) do
if m.uuid == msg.uuid then
idx = i
break
end
end
if idx ~= nil then
history_messages[idx] = msg
else
table.insert(history_messages, msg)
end
end
if opts.set_store then opts.set_store("messages", history_messages) end
for _, msg in ipairs(msgs) do
local content = msg.message.content
if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then
tool_use_messages[msg.uuid] = true
if content[1].name == "attempt_completion" then
local input_ = content[1].input
if input_ and input_.result then result = input_.result end
end
end
end
if session_ctx.on_messages_add then session_ctx.on_messages_add(msgs) end
-- if session_ctx.on_messages_add then session_ctx.on_messages_add(msgs) end
end,
session_ctx = session_ctx,
on_start = session_ctx.on_start,
on_chunk = function(chunk)
if not chunk then return end
final_response = final_response .. chunk
total_tokens = total_tokens + (#vim.split(chunk, " ") * 1.3)
end,
on_complete = function(err)
@@ -148,8 +275,7 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
})
session_ctx.on_messages_add({ message })
end
local response = string.format("Final response:\n%s\n\nSummary:\n%s", summary, final_response)
on_complete(response, nil)
on_complete(result, nil)
end,
}

View File

@@ -40,10 +40,12 @@ M.returns = {
}
---@type AvanteLLMToolFunc<{ path: string, diff: string }>
function M.func(opts, on_log, on_complete, session_ctx)
if not opts.path then return false, "pathf are required" end
if on_log then on_log("path: " .. opts.path) end
local abs_path = Helpers.get_abs_path(opts.path)
function M.func(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
if not input.path then return false, "pathf are required" end
if on_log then on_log("path: " .. input.path) end
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not on_complete then return false, "on_complete is required" end
local diagnostics = Utils.lsp.get_diagnostics_from_filepath(abs_path)

View File

@@ -45,12 +45,14 @@ M.returns = {
}
---@type AvanteLLMToolFunc<{ path: string, pattern: string }>
function M.func(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.path)
function M.func(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
if on_log then on_log("pattern: " .. opts.pattern) end
local files = vim.fn.glob(abs_path .. "/" .. opts.pattern, true, true)
if on_log then on_log("pattern: " .. input.pattern) end
local files = vim.fn.glob(abs_path .. "/" .. input.pattern, true, true)
local truncated_files = {}
local is_truncated = false
local size = 0

View File

@@ -69,8 +69,10 @@ M.returns = {
}
---@type AvanteLLMToolFunc<{ path: string, query: string, case_sensitive?: boolean, include_pattern?: string, exclude_pattern?: string }>
function M.func(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.path)
function M.func(input, opts)
local on_log = opts.on_log
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return "", "No such file or directory: " .. abs_path end
@@ -85,31 +87,31 @@ function M.func(opts, on_log, on_complete, session_ctx)
local cmd = ""
if search_cmd:find("rg") then
cmd = string.format("%s --files-with-matches --hidden", search_cmd)
if opts.case_sensitive then
if input.case_sensitive then
cmd = string.format("%s --case-sensitive", cmd)
else
cmd = string.format("%s --ignore-case", cmd)
end
if opts.include_pattern then cmd = string.format("%s --glob '%s'", cmd, opts.include_pattern) end
if opts.exclude_pattern then cmd = string.format("%s --glob '!%s'", cmd, opts.exclude_pattern) end
cmd = string.format("%s '%s' %s", cmd, opts.query, abs_path)
if input.include_pattern then cmd = string.format("%s --glob '%s'", cmd, input.include_pattern) end
if input.exclude_pattern then cmd = string.format("%s --glob '!%s'", cmd, input.exclude_pattern) end
cmd = string.format("%s '%s' %s", cmd, input.query, abs_path)
elseif search_cmd:find("ag") then
cmd = string.format("%s --nocolor --nogroup --hidden", search_cmd)
if opts.case_sensitive then cmd = string.format("%s --case-sensitive", cmd) end
if opts.include_pattern then cmd = string.format("%s --ignore '!%s'", cmd, opts.include_pattern) end
if opts.exclude_pattern then cmd = string.format("%s --ignore '%s'", cmd, opts.exclude_pattern) end
cmd = string.format("%s '%s' %s", cmd, opts.query, abs_path)
if input.case_sensitive then cmd = string.format("%s --case-sensitive", cmd) end
if input.include_pattern then cmd = string.format("%s --ignore '!%s'", cmd, input.include_pattern) end
if input.exclude_pattern then cmd = string.format("%s --ignore '%s'", cmd, input.exclude_pattern) end
cmd = string.format("%s '%s' %s", cmd, input.query, abs_path)
elseif search_cmd:find("ack") then
cmd = string.format("%s --nocolor --nogroup --hidden", search_cmd)
if opts.case_sensitive then cmd = string.format("%s --smart-case", cmd) end
if opts.exclude_pattern then cmd = string.format("%s --ignore-dir '%s'", cmd, opts.exclude_pattern) end
cmd = string.format("%s '%s' %s", cmd, opts.query, abs_path)
if input.case_sensitive then cmd = string.format("%s --smart-case", cmd) end
if input.exclude_pattern then cmd = string.format("%s --ignore-dir '%s'", cmd, input.exclude_pattern) end
cmd = string.format("%s '%s' %s", cmd, input.query, abs_path)
elseif search_cmd:find("grep") then
cmd = string.format("cd %s && git ls-files -co --exclude-standard | xargs %s -rH", abs_path, search_cmd, abs_path)
if not opts.case_sensitive then cmd = string.format("%s -i", cmd) end
if opts.include_pattern then cmd = string.format("%s --include '%s'", cmd, opts.include_pattern) end
if opts.exclude_pattern then cmd = string.format("%s --exclude '%s'", cmd, opts.exclude_pattern) end
cmd = string.format("%s '%s'", cmd, opts.query)
if not input.case_sensitive then cmd = string.format("%s -i", cmd) end
if input.include_pattern then cmd = string.format("%s --include '%s'", cmd, input.include_pattern) end
if input.exclude_pattern then cmd = string.format("%s --exclude '%s'", cmd, input.exclude_pattern) end
cmd = string.format("%s '%s'", cmd, input.query)
end
Utils.debug("cmd", cmd)

View File

@@ -8,9 +8,10 @@ local Helpers = require("avante.llm_tools.helpers")
local M = {}
---@type AvanteLLMToolFunc<{ path: string }>
function M.read_file_toplevel_symbols(opts, on_log, on_complete, session_ctx)
function M.read_file_toplevel_symbols(input, opts)
local on_log = opts.on_log
local RepoMap = require("avante.repo_map")
local abs_path = Helpers.get_abs_path(opts.path)
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
if not Path:new(abs_path):exists() then return "", "File does not exists: " .. abs_path end
@@ -24,50 +25,47 @@ function M.read_file_toplevel_symbols(opts, on_log, on_complete, session_ctx)
end
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert" | "undo_edit", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[] }>
function M.str_replace_editor(opts, on_log, on_complete, session_ctx)
if opts.command == "undo_edit" then
return require("avante.llm_tools.undo_edit").func(opts, on_log, on_complete, session_ctx)
end
---@cast opts any
return M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx)
function M.str_replace_editor(input, opts)
if input.command == "undo_edit" then return require("avante.llm_tools.undo_edit").func(input, opts) end
---@cast input any
return M.str_replace_based_edit_tool(input, opts)
end
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[], streaming?: boolean }>
function M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx)
if not opts.command then return false, "command not provided" end
if on_log then on_log("command: " .. opts.command) end
function M.str_replace_based_edit_tool(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
if not input.command then return false, "command not provided" end
if on_log then on_log("command: " .. input.command) end
if not on_complete then return false, "on_complete not provided" end
local abs_path = Helpers.get_abs_path(opts.path)
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if opts.command == "view" then
if input.command == "view" then
local view = require("avante.llm_tools.view")
local opts_ = { path = opts.path }
if opts.view_range then
local start_line, end_line = unpack(opts.view_range)
opts_.start_line = start_line
opts_.end_line = end_line
local input_ = { path = input.path }
if input.view_range then
local start_line, end_line = unpack(input.view_range)
input_.start_line = start_line
input_.end_line = end_line
end
return view(opts_, on_log, on_complete, session_ctx)
return view(input_, opts)
end
if opts.command == "str_replace" then
if opts.new_str == nil and opts.file_text ~= nil then
opts.new_str = opts.file_text
opts.file_text = nil
if input.command == "str_replace" then
if input.new_str == nil and input.file_text ~= nil then
input.new_str = input.file_text
input.file_text = nil
end
return require("avante.llm_tools.str_replace").func(opts, on_log, on_complete, session_ctx)
return require("avante.llm_tools.str_replace").func(input, opts)
end
if opts.command == "create" then
return require("avante.llm_tools.create").func(opts, on_log, on_complete, session_ctx)
end
if opts.command == "insert" then
return require("avante.llm_tools.insert").func(opts, on_log, on_complete, session_ctx)
end
return false, "Unknown command: " .. opts.command
if input.command == "create" then return require("avante.llm_tools.create").func(input, opts) end
if input.command == "insert" then return require("avante.llm_tools.insert").func(input, opts) end
return false, "Unknown command: " .. input.command
end
---@type AvanteLLMToolFunc<{ abs_path: string }>
function M.read_global_file(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.abs_path)
function M.read_global_file(input, opts)
local on_log = opts.on_log
local abs_path = Helpers.get_abs_path(input.abs_path)
if Helpers.is_ignored(abs_path) then return "", "This file is ignored: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
local file = io.open(abs_path, "r")
@@ -78,11 +76,13 @@ function M.read_global_file(opts, on_log, on_complete, session_ctx)
end
---@type AvanteLLMToolFunc<{ abs_path: string, content: string }>
function M.write_global_file(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.abs_path)
function M.write_global_file(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.abs_path)
if Helpers.is_ignored(abs_path) then return false, "This file is ignored: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
if on_log then on_log("content: " .. opts.content) end
if on_log then on_log("content: " .. input.content) end
if not on_complete then return false, "on_complete not provided" end
Helpers.confirm("Are you sure you want to write to the file: " .. abs_path, function(ok)
if not ok then
@@ -94,18 +94,20 @@ function M.write_global_file(opts, on_log, on_complete, session_ctx)
on_complete(false, "file not found: " .. abs_path)
return
end
file:write(opts.content)
file:write(input.content)
file:close()
on_complete(true, nil)
end, nil, session_ctx, "write_global_file")
end, nil, opts.session_ctx, "write_global_file")
end
---@type AvanteLLMToolFunc<{ source_path: string, destination_path: string }>
function M.move_path(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.source_path)
function M.move_path(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.source_path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "The source path not found: " .. abs_path end
local new_abs_path = Helpers.get_abs_path(opts.destination_path)
local new_abs_path = Helpers.get_abs_path(input.destination_path)
if on_log then on_log(abs_path .. " -> " .. new_abs_path) end
if not Helpers.has_permission_to_access(new_abs_path) then
return false, "No permission to access path: " .. new_abs_path
@@ -123,17 +125,19 @@ function M.move_path(opts, on_log, on_complete, session_ctx)
on_complete(true, nil)
end,
nil,
session_ctx,
opts.session_ctx,
"move_path"
)
end
---@type AvanteLLMToolFunc<{ source_path: string, destination_path: string }>
function M.copy_path(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.source_path)
function M.copy_path(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.source_path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "The source path not found: " .. abs_path end
local new_abs_path = Helpers.get_abs_path(opts.destination_path)
local new_abs_path = Helpers.get_abs_path(input.destination_path)
if not Helpers.has_permission_to_access(new_abs_path) then
return false, "No permission to access path: " .. new_abs_path
end
@@ -169,14 +173,16 @@ function M.copy_path(opts, on_log, on_complete, session_ctx)
on_complete(true, nil)
end,
nil,
session_ctx,
opts.session_ctx,
"copy_path"
)
end
---@type AvanteLLMToolFunc<{ path: string }>
function M.delete_path(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.path)
function M.delete_path(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
if not on_complete then return false, "on_complete not provided" end
@@ -188,12 +194,14 @@ function M.delete_path(opts, on_log, on_complete, session_ctx)
if on_log then on_log("Deleting path: " .. abs_path) end
os.remove(abs_path)
on_complete(true, nil)
end, nil, session_ctx, "delete_path")
end, nil, opts.session_ctx, "delete_path")
end
---@type AvanteLLMToolFunc<{ path: string }>
function M.create_dir(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.path)
function M.create_dir(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end
if not on_complete then return false, "on_complete not provided" end
@@ -205,16 +213,17 @@ function M.create_dir(opts, on_log, on_complete, session_ctx)
if on_log then on_log("Creating directory: " .. abs_path) end
Path:new(abs_path):mkdir({ parents = true })
on_complete(true, nil)
end, nil, session_ctx, "create_dir")
end, nil, opts.session_ctx, "create_dir")
end
---@type AvanteLLMToolFunc<{ query: string }>
function M.web_search(opts, on_log, on_complete, session_ctx)
function M.web_search(input, opts)
local on_log = opts.on_log
local provider_type = Config.web_search_engine.provider
local proxy = Config.web_search_engine.proxy
if provider_type == nil then return nil, "Search engine provider is not set" end
if on_log then on_log("provider: " .. provider_type) end
if on_log then on_log("query: " .. opts.query) end
if on_log then on_log("query: " .. input.query) end
local search_engine = Config.web_search_engine.providers[provider_type]
if search_engine == nil then return nil, "No search engine found: " .. provider_type end
if provider_type ~= "searxng" and search_engine.api_key_name == "" then return nil, "No API key provided" end
@@ -229,7 +238,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
["Authorization"] = "Bearer " .. api_key,
},
body = vim.json.encode(vim.tbl_deep_extend("force", {
query = opts.query,
query = input.query,
}, search_engine.extra_request_body)),
}
if proxy then curl_opts.proxy = proxy end
@@ -240,7 +249,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
elseif provider_type == "serpapi" then
local query_params = vim.tbl_deep_extend("force", {
api_key = api_key,
q = opts.query,
q = input.query,
}, search_engine.extra_request_body)
local query_string = ""
for key, value in pairs(query_params) do
@@ -259,7 +268,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
elseif provider_type == "searchapi" then
local query_params = vim.tbl_deep_extend("force", {
api_key = api_key,
q = opts.query,
q = input.query,
}, search_engine.extra_request_body)
local query_string = ""
for key, value in pairs(query_params) do
@@ -283,7 +292,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
local query_params = vim.tbl_deep_extend("force", {
key = api_key,
cx = engine_id,
q = opts.query,
q = input.query,
}, search_engine.extra_request_body)
local query_string = ""
for key, value in pairs(query_params) do
@@ -301,7 +310,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
return search_engine.format_response_body(jsn)
elseif provider_type == "kagi" then
local query_params = vim.tbl_deep_extend("force", {
q = opts.query,
q = input.query,
}, search_engine.extra_request_body)
local query_string = ""
for key, value in pairs(query_params) do
@@ -320,7 +329,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
return search_engine.format_response_body(jsn)
elseif provider_type == "brave" then
local query_params = vim.tbl_deep_extend("force", {
q = opts.query,
q = input.query,
}, search_engine.extra_request_body)
local query_string = ""
for key, value in pairs(query_params) do
@@ -343,7 +352,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
return nil, "Environment variable " .. search_engine.api_url_name .. " is not set"
end
local query_params = vim.tbl_deep_extend("force", {
q = opts.query,
q = input.query,
}, search_engine.extra_request_body)
local query_string = ""
for key, value in pairs(query_params) do
@@ -362,16 +371,18 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
end
---@type AvanteLLMToolFunc<{ url: string }>
function M.fetch(opts, on_log, on_complete, session_ctx)
if on_log then on_log("url: " .. opts.url) end
function M.fetch(input, opts)
local on_log = opts.on_log
if on_log then on_log("url: " .. input.url) end
local Html2Md = require("avante.html2md")
local res, err = Html2Md.fetch_md(opts.url)
local res, err = Html2Md.fetch_md(input.url)
if err then return nil, err end
return res, nil
end
---@type AvanteLLMToolFunc<{ scope?: string }>
function M.git_diff(opts, on_log, on_complete, session_ctx)
function M.git_diff(input, opts)
local on_log = opts.on_log
local git_cmd = vim.fn.exepath("git")
if git_cmd == "" then return nil, "Git command not found" end
local project_root = Utils.get_project_root()
@@ -382,7 +393,7 @@ function M.git_diff(opts, on_log, on_complete, session_ctx)
if git_dir == "" then return nil, "Not a git repository" end
-- Get the diff
local scope = opts.scope or ""
local scope = input.scope or ""
local cmd = string.format("git diff --cached %s", scope)
if on_log then on_log("Running command: " .. cmd) end
local diff = vim.fn.system(cmd)
@@ -400,7 +411,10 @@ function M.git_diff(opts, on_log, on_complete, session_ctx)
end
---@type AvanteLLMToolFunc<{ message: string, scope?: string }>
function M.git_commit(opts, on_log, on_complete, session_ctx)
function M.git_commit(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local git_cmd = vim.fn.exepath("git")
if git_cmd == "" then return false, "Git command not found" end
local project_root = Utils.get_project_root()
@@ -444,7 +458,7 @@ function M.git_commit(opts, on_log, on_complete, session_ctx)
-- Prepare commit message
local commit_msg_lines = {}
for line in opts.message:gmatch("[^\r\n]+") do
for line in input.message:gmatch("[^\r\n]+") do
commit_msg_lines[#commit_msg_lines + 1] = line:gsub('"', '\\"')
end
commit_msg_lines[#commit_msg_lines + 1] = ""
@@ -464,8 +478,8 @@ function M.git_commit(opts, on_log, on_complete, session_ctx)
return
end
-- Stage changes if scope is provided
if opts.scope then
local stage_cmd = string.format("git add %s", opts.scope)
if input.scope then
local stage_cmd = string.format("git add %s", input.scope)
if on_log then on_log("Staging files: " .. stage_cmd) end
local stage_result = vim.fn.system(stage_cmd)
if vim.v.shell_error ~= 0 then
@@ -494,20 +508,23 @@ function M.git_commit(opts, on_log, on_complete, session_ctx)
end
on_complete(true, nil)
end, nil, session_ctx, "git_commit")
end, nil, opts.session_ctx, "git_commit")
end
---@type AvanteLLMToolFunc<{ query: string }>
function M.rag_search(opts, on_log, on_complete, session_ctx)
function M.rag_search(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
if not on_complete then return nil, "on_complete not provided" end
if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end
if not opts.query then return nil, "No query provided" end
if on_log then on_log("query: " .. opts.query) end
if not input.query then return nil, "No query provided" end
if on_log then on_log("query: " .. input.query) end
local root = Utils.get_project_root()
local uri = "file://" .. root
if uri:sub(-1) ~= "/" then uri = uri .. "/" end
RagService.retrieve(
uri,
opts.query,
input.query,
vim.schedule_wrap(function(resp, err)
if err then
on_complete(nil, err)
@@ -519,13 +536,15 @@ function M.rag_search(opts, on_log, on_complete, session_ctx)
end
---@type AvanteLLMToolFunc<{ code: string, path: string, container_image?: string }>
function M.python(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.path)
function M.python(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return nil, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return nil, "Path not found: " .. abs_path end
if on_log then on_log("cwd: " .. abs_path) end
if on_log then on_log("code:\n" .. opts.code) end
local container_image = opts.container_image or "python:3.11-slim-bookworm"
if on_log then on_log("code:\n" .. input.code) end
local container_image = input.container_image or "python:3.11-slim-bookworm"
if not on_complete then return nil, "on_complete not provided" end
Helpers.confirm(
"Are you sure you want to run the following python code in the `"
@@ -533,7 +552,7 @@ function M.python(opts, on_log, on_complete, session_ctx)
.. "` container, in the directory: `"
.. abs_path
.. "`?\n"
.. opts.code,
.. input.code,
function(ok, reason)
if not ok then
on_complete(nil, "User declined, reason: " .. (reason or "unknown"))
@@ -562,7 +581,7 @@ function M.python(opts, on_log, on_complete, session_ctx)
container_image,
"python",
"-c",
opts.code,
input.code,
},
{
text = true,
@@ -576,7 +595,7 @@ function M.python(opts, on_log, on_complete, session_ctx)
)
end,
nil,
session_ctx,
opts.session_ctx,
"python"
)
end
@@ -1189,14 +1208,19 @@ You can delete the first file by providing a path of "directory1/a/something.txt
--- compatibility alias for old calls & tests
M.run_python = M.python
---@class avante.ProcessToolUseOpts
---@field session_ctx table
---@field on_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
---@field set_tool_use_store? fun(tool_id: string, key: string, value: any): nil
---@field on_complete? fun(result: string | nil, error: string | nil): nil
---@param tools AvanteLLMTool[]
---@param tool_use AvanteLLMToolUse
---@param on_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
---@param on_complete? fun(result: string | nil, error: string | nil): nil
---@param session_ctx? table
---@return string | nil result
---@return string | nil error
function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx)
function M.process_tool_use(tools, tool_use, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
-- Check if execution is already cancelled
if Helpers.is_cancelled then
Utils.debug("Tool execution cancelled before starting: " .. tool_use.name)
@@ -1274,25 +1298,32 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx)
return result_str, err
end
local result, err = func(input_json, function(log)
-- Check for cancellation during logging
if Helpers.is_cancelled then return end
if on_log then on_log(tool_use.id, tool_use.name, log, "running") end
end, function(result, err)
-- Check for cancellation before completing
if Helpers.is_cancelled then
Helpers.is_cancelled = false
if on_complete then on_complete(nil, Helpers.CANCEL_TOKEN) end
return
end
local result, err = func(input_json, {
session_ctx = opts.session_ctx or {},
on_log = function(log)
-- Check for cancellation during logging
if Helpers.is_cancelled then return end
if on_log then on_log(tool_use.id, tool_use.name, log, "running") end
end,
set_store = function(key, value)
if opts.set_tool_use_store then opts.set_tool_use_store(tool_use.id, key, value) end
end,
on_complete = function(result, err)
-- Check for cancellation before completing
if Helpers.is_cancelled then
Helpers.is_cancelled = false
if on_complete then on_complete(nil, Helpers.CANCEL_TOKEN) end
return
end
result, err = handle_result(result, err)
if on_complete == nil then
Utils.error("asynchronous tool " .. tool_use.name .. " result not handled")
return
end
on_complete(result, err)
end, session_ctx)
result, err = handle_result(result, err)
if on_complete == nil then
Utils.error("asynchronous tool " .. tool_use.name .. " result not handled")
return
end
on_complete(result, err)
end,
})
-- Result and error being nil means that the tool was executed asynchronously
if result == nil and err == nil and on_complete then return end

View File

@@ -55,19 +55,24 @@ M.returns = {
}
---@type AvanteLLMToolFunc<{ path: string, insert_line: integer, new_str: string }>
function M.func(opts, on_log, on_complete, session_ctx)
if on_log then on_log("path: " .. opts.path) end
local abs_path = Helpers.get_abs_path(opts.path)
function M.func(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local session_ctx = opts.session_ctx
if not on_complete then return false, "on_complete not provided" end
if on_log then on_log("path: " .. input.path) end
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
if opts.insert_line == nil then return false, "insert_line not provided" end
if opts.new_str == nil then return false, "new_str not provided" end
if input.insert_line == nil then return false, "insert_line not provided" end
if input.new_str == nil then return false, "new_str not provided" end
local ns_id = vim.api.nvim_create_namespace("avante_insert_diff")
local bufnr, err = Helpers.get_bufnr(abs_path)
if err then return false, err end
local function clear_highlights() vim.api.nvim_buf_clear_namespace(bufnr, ns_id, 0, -1) end
local new_lines = vim.split(opts.new_str, "\n")
local new_lines = vim.split(input.new_str, "\n")
local max_col = vim.o.columns
local virt_lines = vim
.iter(new_lines)
@@ -78,8 +83,8 @@ function M.func(opts, on_log, on_complete, session_ctx)
end)
:totable()
local line_count = vim.api.nvim_buf_line_count(bufnr)
if opts.insert_line > line_count - 1 then opts.insert_line = line_count - 1 end
vim.api.nvim_buf_set_extmark(bufnr, ns_id, opts.insert_line, 0, {
if input.insert_line > line_count - 1 then input.insert_line = line_count - 1 end
vim.api.nvim_buf_set_extmark(bufnr, ns_id, input.insert_line, 0, {
virt_lines = virt_lines,
hl_eol = true,
hl_mode = "combine",
@@ -90,9 +95,9 @@ function M.func(opts, on_log, on_complete, session_ctx)
on_complete(false, "User declined, reason: " .. (reason or "unknown"))
return
end
vim.api.nvim_buf_set_lines(bufnr, opts.insert_line, opts.insert_line, false, new_lines)
vim.api.nvim_buf_set_lines(bufnr, input.insert_line, input.insert_line, false, new_lines)
vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end)
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end
if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end
on_complete(true, nil)
end, { focus = true }, session_ctx, M.name)
end

View File

@@ -46,15 +46,16 @@ M.returns = {
}
---@type AvanteLLMToolFunc<{ path: string, max_depth?: integer }>
function M.func(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.path)
function M.func(input, opts)
local on_log = opts.on_log
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
if on_log then on_log("max depth: " .. tostring(opts.max_depth)) end
if on_log then on_log("max depth: " .. tostring(input.max_depth)) end
local files = Utils.scan_directory({
directory = abs_path,
add_dirs = true,
max_depth = opts.max_depth,
max_depth = input.max_depth,
})
local filepaths = {}
for _, file in ipairs(files) do

View File

@@ -125,39 +125,44 @@ end
--- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view.
---@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string, streaming?: boolean, tool_use_id?: string }>
function M.func(opts, on_log, on_complete, session_ctx)
if opts.the_diff ~= nil then
opts.diff = opts.the_diff
opts.the_diff = nil
function M.func(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local session_ctx = opts.session_ctx
if not on_complete then return false, "on_complete not provided" end
if input.the_diff ~= nil then
input.diff = input.the_diff
input.the_diff = nil
end
if not opts.path or not opts.diff then return false, "path and diff are required " .. vim.inspect(opts) end
if on_log then on_log("path: " .. opts.path) end
local abs_path = Helpers.get_abs_path(opts.path)
if not input.path or not input.diff then return false, "path and diff are required " .. vim.inspect(input) end
if on_log then on_log("path: " .. input.path) end
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
local is_streaming = opts.streaming or false
local is_streaming = input.streaming or false
session_ctx.prev_streaming_diff_timestamp_map = session_ctx.prev_streaming_diff_timestamp_map or {}
local current_timestamp = os.time()
if is_streaming then
local prev_streaming_diff_timestamp = session_ctx.prev_streaming_diff_timestamp_map[opts.tool_use_id]
local prev_streaming_diff_timestamp = session_ctx.prev_streaming_diff_timestamp_map[input.tool_use_id]
if prev_streaming_diff_timestamp ~= nil then
if current_timestamp - prev_streaming_diff_timestamp < 2 then
return false, "Diff hasn't changed in the last 2 seconds"
end
end
local streaming_diff_lines_count = Utils.count_lines(opts.diff)
local streaming_diff_lines_count = Utils.count_lines(input.diff)
session_ctx.streaming_diff_lines_count_history = session_ctx.streaming_diff_lines_count_history or {}
local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[opts.tool_use_id]
local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[input.tool_use_id]
if streaming_diff_lines_count == prev_streaming_diff_lines_count then
return false, "Diff lines count hasn't changed"
end
session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] = streaming_diff_lines_count
session_ctx.streaming_diff_lines_count_history[input.tool_use_id] = streaming_diff_lines_count
end
local diff = fix_diff(opts.diff)
local diff = fix_diff(input.diff)
if on_log and diff ~= opts.diff then on_log("diff fixed") end
if on_log and diff ~= input.diff then on_log("diff fixed") end
local diff_lines = vim.split(diff, "\n")
@@ -203,16 +208,16 @@ function M.func(opts, on_log, on_complete, session_ctx)
return false, "No diff blocks found"
end
session_ctx.prev_streaming_diff_timestamp_map[opts.tool_use_id] = current_timestamp
session_ctx.prev_streaming_diff_timestamp_map[input.tool_use_id] = current_timestamp
local bufnr, err = Helpers.get_bufnr(abs_path)
if err then return false, err end
session_ctx.undo_joined = session_ctx.undo_joined or {}
local undo_joined = session_ctx.undo_joined[opts.tool_use_id]
local undo_joined = session_ctx.undo_joined[input.tool_use_id]
if not undo_joined then
pcall(vim.cmd.undojoin)
session_ctx.undo_joined[opts.tool_use_id] = true
session_ctx.undo_joined[input.tool_use_id] = true
end
local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
@@ -242,10 +247,10 @@ function M.func(opts, on_log, on_complete, session_ctx)
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map = session_ctx.rough_diff_blocks_to_diff_blocks_cache_map or {}
local rough_diff_blocks_to_diff_blocks_cache =
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id]
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[input.tool_use_id]
if not rough_diff_blocks_to_diff_blocks_cache then
rough_diff_blocks_to_diff_blocks_cache = {}
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[input.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache
end
local function rough_diff_blocks_to_diff_blocks(rough_diff_blocks_)
@@ -472,7 +477,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
on_complete(false, "User canceled")
return
end
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end
if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end
on_complete(true, nil)
end,
})
@@ -615,35 +620,35 @@ function M.func(opts, on_log, on_complete, session_ctx)
end
session_ctx.extmark_id_map = session_ctx.extmark_id_map or {}
local extmark_id_map = session_ctx.extmark_id_map[opts.tool_use_id]
local extmark_id_map = session_ctx.extmark_id_map[input.tool_use_id]
if not extmark_id_map then
extmark_id_map = {}
session_ctx.extmark_id_map[opts.tool_use_id] = extmark_id_map
session_ctx.extmark_id_map[input.tool_use_id] = extmark_id_map
end
session_ctx.virt_lines_map = session_ctx.virt_lines_map or {}
local virt_lines_map = session_ctx.virt_lines_map[opts.tool_use_id]
local virt_lines_map = session_ctx.virt_lines_map[input.tool_use_id]
if not virt_lines_map then
virt_lines_map = {}
session_ctx.virt_lines_map[opts.tool_use_id] = virt_lines_map
session_ctx.virt_lines_map[input.tool_use_id] = virt_lines_map
end
session_ctx.last_orig_diff_end_line_map = session_ctx.last_orig_diff_end_line_map or {}
local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[opts.tool_use_id]
local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[input.tool_use_id]
if not last_orig_diff_end_line then
last_orig_diff_end_line = 1
session_ctx.last_orig_diff_end_line_map[opts.tool_use_id] = last_orig_diff_end_line
session_ctx.last_orig_diff_end_line_map[input.tool_use_id] = last_orig_diff_end_line
end
session_ctx.last_resp_diff_end_line_map = session_ctx.last_resp_diff_end_line_map or {}
local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[opts.tool_use_id]
local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[input.tool_use_id]
if not last_resp_diff_end_line then
last_resp_diff_end_line = 1
session_ctx.last_resp_diff_end_line_map[opts.tool_use_id] = last_resp_diff_end_line
session_ctx.last_resp_diff_end_line_map[input.tool_use_id] = last_resp_diff_end_line
end
session_ctx.prev_diff_blocks_map = session_ctx.prev_diff_blocks_map or {}
local prev_diff_blocks = session_ctx.prev_diff_blocks_map[opts.tool_use_id]
local prev_diff_blocks = session_ctx.prev_diff_blocks_map[input.tool_use_id]
if not prev_diff_blocks then
prev_diff_blocks = {}
session_ctx.prev_diff_blocks_map[opts.tool_use_id] = prev_diff_blocks
session_ctx.prev_diff_blocks_map[input.tool_use_id] = prev_diff_blocks
end
local function get_unstable_diff_blocks(diff_blocks_)
@@ -663,7 +668,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
local function highlight_streaming_diff_blocks()
local unstable_diff_blocks = get_unstable_diff_blocks(diff_blocks)
session_ctx.prev_diff_blocks_map[opts.tool_use_id] = diff_blocks
session_ctx.prev_diff_blocks_map[input.tool_use_id] = diff_blocks
local max_col = vim.o.columns
for _, diff_block in ipairs(unstable_diff_blocks) do
local new_lines = diff_block.new_lines
@@ -747,7 +752,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
--- check if the parent dir is exists, if not, create it
if vim.fn.isdirectory(parent_dir) == 0 then vim.fn.mkdir(parent_dir, "p") end
vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end)
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end
if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end
on_complete(true, nil)
end, { focus = not Config.behaviour.auto_focus_on_diff_view }, session_ctx, M.name)
end

View File

@@ -1,5 +1,4 @@
local Base = require("avante.llm_tools.base")
local Config = require("avante.config")
---@class AvanteLLMTool
local M = setmetatable({}, Base)
@@ -55,17 +54,17 @@ M.returns = {
}
---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string, streaming?: boolean, tool_use_id?: string }>
function M.func(opts, on_log, on_complete, session_ctx)
function M.func(input, opts)
local replace_in_file = require("avante.llm_tools.replace_in_file")
local diff = "------- SEARCH\n" .. opts.old_str .. "\n=======\n" .. opts.new_str
if not opts.streaming then diff = diff .. "\n+++++++ REPLACE" end
local new_opts = {
path = opts.path,
local diff = "------- SEARCH\n" .. input.old_str .. "\n=======\n" .. input.new_str
if not input.streaming then diff = diff .. "\n+++++++ REPLACE" end
local new_input = {
path = input.path,
diff = diff,
streaming = opts.streaming,
tool_use_id = opts.tool_use_id,
streaming = input.streaming,
tool_use_id = input.tool_use_id,
}
return replace_in_file.func(new_opts, on_log, on_complete, session_ctx)
return replace_in_file.func(new_input, opts)
end
return M

View File

@@ -47,12 +47,13 @@ M.returns = {
---@field thought string
---@type avante.LLMToolOnRender<ThinkingInput>
function M.on_render(opts, _, state)
function M.on_render(input, opts)
local state = opts.state
local lines = {}
local text = state == "generating" and "Thinking" or "Thoughts"
table.insert(lines, Line:new({ { Utils.icon("🤔 ") .. text, Highlights.AVANTE_THINKING } }))
table.insert(lines, Line:new({ { "" } }))
local content = opts.thought or ""
local content = input.thought or ""
local text_lines = vim.split(content, "\n")
for _, text_line in ipairs(text_lines) do
table.insert(lines, Line:new({ { "> " .. text_line } }))
@@ -61,7 +62,8 @@ function M.on_render(opts, _, state)
end
---@type AvanteLLMToolFunc<ThinkingInput>
function M.func(opts, on_log, on_complete, session_ctx)
function M.func(input, opts)
local on_complete = opts.on_complete
if not on_complete then return false, "on_complete not provided" end
on_complete(true, nil)
end

View File

@@ -43,9 +43,14 @@ M.returns = {
}
---@type AvanteLLMToolFunc<{ path: string }>
function M.func(opts, on_log, on_complete, session_ctx)
if on_log then on_log("path: " .. opts.path) end
local abs_path = Helpers.get_abs_path(opts.path)
function M.func(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local session_ctx = opts.session_ctx
if not on_complete then return false, "on_complete not provided" end
if on_log then on_log("path: " .. input.path) end
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
@@ -59,7 +64,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
end
vim.api.nvim_win_call(winid, function() vim.cmd("noautocmd undo") end)
vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end)
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end
if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end
on_complete(true, nil)
end, { focus = true }, session_ctx, M.name)
end

View File

@@ -43,14 +43,15 @@ M.returns = {
M.on_render = function() return {} end
---@type AvanteLLMToolFunc<{ id: string, status: string }>
function M.func(opts, on_log, on_complete, session_ctx)
function M.func(input, opts)
local on_complete = opts.on_complete
local sidebar = require("avante").get()
if not sidebar then return false, "Avante sidebar not found" end
local todos = sidebar.chat_history.todos
if not todos or #todos == 0 then return false, "No todos found" end
for _, todo in ipairs(todos) do
if todo.id == opts.id then
todo.status = opts.status
if todo.id == input.id then
todo.status = input.status
break
end
end

View File

@@ -87,18 +87,20 @@ M.returns = {
}
---@type AvanteLLMToolFunc<{ path: string, start_line?: integer, end_line?: integer }>
function M.func(opts, on_log, on_complete, session_ctx)
if not opts.path then return false, "path is required" end
if on_log then on_log("path: " .. opts.path) end
local abs_path = Helpers.get_abs_path(opts.path)
function M.func(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
if not input.path then return false, "path is required" end
if on_log then on_log("path: " .. input.path) end
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
if Path:new(abs_path):is_dir() then return false, "Path is a directory: " .. abs_path end
local file = io.open(abs_path, "r")
if not file then return false, "file not found: " .. abs_path end
local lines = Utils.read_file_from_buf_or_disk(abs_path)
local start_line = opts.start_line
local end_line = opts.end_line
local start_line = input.start_line
local end_line = input.end_line
if start_line and end_line and lines then lines = vim.list_slice(lines, start_line, end_line) end
local truncated_lines = {}
local is_truncated = false

View File

@@ -57,27 +57,26 @@ M.returns = {
--- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view.
---@type AvanteLLMToolFunc<{ path: string, content: string, the_content?: string, streaming?: boolean, tool_use_id?: string }>
function M.func(opts, on_log, on_complete, session_ctx)
if opts.the_content ~= nil then
opts.content = opts.the_content
opts.the_content = nil
function M.func(input, opts)
if input.the_content ~= nil then
input.content = input.the_content
input.the_content = nil
end
if not on_complete then return false, "on_complete not provided" end
local abs_path = Helpers.get_abs_path(opts.path)
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if opts.content == nil then return false, "content not provided" end
if type(opts.content) ~= "string" then opts.content = vim.json.encode(opts.content) end
if input.content == nil then return false, "content not provided" end
if type(input.content) ~= "string" then input.content = vim.json.encode(input.content) end
local old_lines = Utils.read_file_from_buf_or_disk(abs_path)
local old_content = table.concat(old_lines or {}, "\n")
local str_replace = require("avante.llm_tools.str_replace")
local new_opts = {
path = opts.path,
local new_input = {
path = input.path,
old_str = old_content,
new_str = opts.content,
streaming = opts.streaming,
tool_use_id = opts.tool_use_id,
new_str = input.content,
streaming = input.streaming,
tool_use_id = input.tool_use_id,
}
return str_replace.func(new_opts, on_log, on_complete, session_ctx)
return str_replace.func(new_input, opts)
end
return M

View File

@@ -2318,7 +2318,7 @@ function Sidebar:get_history_messages_for_api(opts)
--- For models like gpt-4o, the input parameter of replace_in_file is treated as the latest file content, so here we need to insert a fake view tool call to ensure it uses the latest file content
if is_edit_func_call and path and not message.message.content[1].is_error then
local uniformed_path = Utils.uniform_path(path)
local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, nil, nil, nil)
local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, {})
if view_error then view_result = "Error: " .. view_error end
local get_diagnostics_tool_use_id = Utils.uuid()
local view_tool_use_id = Utils.uuid()
@@ -2451,9 +2451,7 @@ function Sidebar:get_history_messages_for_api(opts)
local end_line = tool_id_to_end_line[item.tool_use_id]
local view_result, view_error = require("avante.llm_tools.view").func(
{ path = path, start_line = start_line, end_line = end_line },
nil,
nil,
nil
{}
)
if view_error then view_result = "Error: " .. view_error end
item.content = view_result
@@ -2773,6 +2771,23 @@ function Sidebar:create_input_container()
self:save_history()
end
local function set_tool_use_store(tool_id, key, value)
local tool_use_message = nil
for idx = #self.chat_history.messages, 1, -1 do
local message = self.chat_history.messages[idx]
local content = message.message.content
if type(content) == "table" and content[1].type == "tool_use" and content[1].id == tool_id then
tool_use_message = message
break
end
end
if not tool_use_message then return end
local tool_use_store = tool_use_message.tool_use_store or {}
tool_use_store[key] = value
tool_use_message.tool_use_store = tool_use_store
self:save_history()
end
---@type AvanteLLMStopCallback
local function on_stop(stop_opts)
self.is_generating = false
@@ -2837,6 +2852,7 @@ function Sidebar:create_input_container()
on_tool_log = on_tool_log,
on_messages_add = on_messages_add,
on_state_change = on_state_change,
set_tool_use_store = set_tool_use_store,
get_history_messages = function(opts) return self:get_history_messages_for_api(opts) end,
get_todos = function()
local history = Path.history.load(self.code.bufnr)

View File

@@ -107,6 +107,7 @@ vim.g.avante_login = vim.g.avante_login
---@field selected_code AvanteSelectedCode | nil
---@field selected_filepaths string[] | nil
---@field tool_use_logs string[] | nil
---@field tool_use_store table | nil
---@field just_for_display boolean | nil
---@field is_dummy boolean | nil
---@field is_compacted boolean | nil
@@ -407,19 +408,30 @@ vim.g.avante_login = vim.g.avante_login
---@field on_stop AvanteLLMStopCallback
---@field on_memory_summarize? AvanteLLMMemorySummarizeCallback
---@field on_tool_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
---@field set_tool_use_store? fun(tool_id: string, key: string, value: any): nil
---@field get_history_messages? fun(opts?: { all?: boolean }): avante.HistoryMessage[]
---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil
---@field on_state_change? fun(state: avante.GenerateState): nil
---@field update_tokens_usage? fun(usage: avante.LLMTokenUsage): nil
---
---@class AvanteLLMToolFuncOpts
---@field session_ctx table
---@field on_complete? fun(result: boolean | string | nil, error: string | nil): nil
---@field on_log? fun(log: string): nil
---@field set_store? fun(key: string, value: any): nil
---
---@alias AvanteLLMToolFunc<T> fun(
--- input: T,
--- on_log?: (fun(log: string): nil),
--- on_complete?: (fun(result: boolean | string | nil, error: string | nil): nil),
--- session_ctx?: table)
--- opts: AvanteLLMToolFuncOpts)
--- : (boolean | string | nil, string | nil)
---
--- @alias avante.LLMToolOnRender<T> fun(input: T, logs: string[], state: avante.HistoryMessageState | nil): avante.ui.Line[]
---@class avante.LLMToolOnRenderOpts
---@field logs string[]
---@field state avante.HistoryMessageState
---@field store table | nil
---@field result_message avante.HistoryMessage | nil
---
--- @alias avante.LLMToolOnRender<T> fun(input: T, opts: avante.LLMToolOnRenderOpts): avante.ui.Line[]
---
---@class AvanteLLMTool
---@field name string

View File

@@ -1674,14 +1674,22 @@ function M.message_content_item_to_lines(item, message, messages)
return { Line:new({ { "![image](" .. item.source.media_type .. ": " .. item.source.data .. ")" } }) }
end
if item.type == "tool_use" then
local tool_result_message = M.get_tool_result_message(message, messages)
local lines = {}
local state = "generating"
local hl = "AvanteStateSpinnerToolCalling"
local ok, llm_tool = pcall(require, "avante.llm_tools." .. item.name)
if ok then
if llm_tool.on_render then return llm_tool.on_render(item.input, message.tool_use_logs, message.state) end
---@cast llm_tool AvanteLLMTool
if llm_tool.on_render then
return llm_tool.on_render(item.input, {
logs = message.tool_use_logs,
state = message.state,
store = message.tool_use_store,
result_message = tool_result_message,
})
end
end
local tool_result_message = M.get_tool_result_message(message, messages)
if tool_result_message then
local tool_result = tool_result_message.message.content[1]
if tool_result.is_error then