refactor: llm tool parameters (#2449)

This commit is contained in:
yetone
2025-07-15 16:40:25 +08:00
committed by GitHub
parent 0c6a8f5688
commit b8bb0fd969
25 changed files with 627 additions and 381 deletions

View File

@@ -47,7 +47,9 @@ local Highlights = {
AVANTE_STATE_SPINNER_SEARCHING = { name = "AvanteStateSpinnerSearching", fg = "#1e222a", bg = "#c678dd" }, AVANTE_STATE_SPINNER_SEARCHING = { name = "AvanteStateSpinnerSearching", fg = "#1e222a", bg = "#c678dd" },
AVANTE_STATE_SPINNER_THINKING = { name = "AvanteStateSpinnerThinking", fg = "#1e222a", bg = "#c678dd" }, AVANTE_STATE_SPINNER_THINKING = { name = "AvanteStateSpinnerThinking", fg = "#1e222a", bg = "#c678dd" },
AVANTE_STATE_SPINNER_COMPACTING = { name = "AvanteStateSpinnerCompacting", fg = "#1e222a", bg = "#c678dd" }, AVANTE_STATE_SPINNER_COMPACTING = { name = "AvanteStateSpinnerCompacting", fg = "#1e222a", bg = "#c678dd" },
AVANTE_TASK_RUNNING = { name = "AvanteTaskRunning", fg = "#c678dd", bg_link = "Normal" },
AVANTE_TASK_COMPLETED = { name = "AvanteTaskCompleted", fg = "#98c379", bg_link = "Normal" }, AVANTE_TASK_COMPLETED = { name = "AvanteTaskCompleted", fg = "#98c379", bg_link = "Normal" },
AVANTE_TASK_FAILED = { name = "AvanteTaskFailed", fg = "#e06c75", bg_link = "Normal" },
AVANTE_THINKING = { name = "AvanteThinking", fg = "#c678dd", bg_link = "Normal" }, AVANTE_THINKING = { name = "AvanteThinking", fg = "#c678dd", bg_link = "Normal" },
} }

View File

@@ -150,8 +150,11 @@ function M.generate_todos(user_input, cb)
local uncalled_tool_uses = Utils.get_uncalled_tool_uses(history_messages) local uncalled_tool_uses = Utils.get_uncalled_tool_uses(history_messages)
for _, partial_tool_use in ipairs(uncalled_tool_uses) do for _, partial_tool_use in ipairs(uncalled_tool_uses) do
if partial_tool_use.state == "generated" and partial_tool_use.name == "add_todos" then if partial_tool_use.state == "generated" and partial_tool_use.name == "add_todos" then
LLMTools.process_tool_use(tools, partial_tool_use, function() end, function() cb() end, {}) local result = LLMTools.process_tool_use(tools, partial_tool_use, {
cb() session_ctx = {},
on_complete = function() cb() end,
})
if result ~= nil then cb() end
end end
end end
else else
@@ -206,6 +209,7 @@ function M.agent_loop(opts)
table.insert(history_messages, msg) table.insert(history_messages, msg)
end end
end end
if opts.on_messages_add then opts.on_messages_add(msgs) end
end, end,
session_ctx = session_ctx, session_ctx = session_ctx,
prompt_opts = { prompt_opts = {
@@ -331,8 +335,9 @@ function M.generate_prompts(opts)
end end
if Config.system_prompt ~= nil then if Config.system_prompt ~= nil then
local custom_system_prompt = Config.system_prompt local custom_system_prompt
if type(custom_system_prompt) == "function" then custom_system_prompt = custom_system_prompt() end if type(Config.system_prompt) == "function" then custom_system_prompt = Config.system_prompt() end
if type(Config.system_prompt) == "string" then custom_system_prompt = Config.system_prompt end
if custom_system_prompt ~= nil and custom_system_prompt ~= "" and custom_system_prompt ~= "null" then if custom_system_prompt ~= nil and custom_system_prompt ~= "" and custom_system_prompt ~= "null" then
system_prompt = system_prompt .. "\n\n" .. custom_system_prompt system_prompt = system_prompt .. "\n\n" .. custom_system_prompt
end end
@@ -841,13 +846,9 @@ function M._stream(opts)
if partial_tool_use.state == "generating" then if partial_tool_use.state == "generating" then
if type(partial_tool_use.input) == "table" then if type(partial_tool_use.input) == "table" then
partial_tool_use.input.streaming = true partial_tool_use.input.streaming = true
LLMTools.process_tool_use( LLMTools.process_tool_use(prompt_opts.tools, partial_tool_use, {
prompt_opts.tools, session_ctx = opts.session_ctx,
partial_tool_use, })
function() end,
function() end,
opts.session_ctx
)
end end
return return
else else
@@ -856,13 +857,12 @@ function M._stream(opts)
partial_tool_use_message.is_calling = true partial_tool_use_message.is_calling = true
if opts.on_messages_add then opts.on_messages_add({ partial_tool_use_message }) end if opts.on_messages_add then opts.on_messages_add({ partial_tool_use_message }) end
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil -- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
local result, error = LLMTools.process_tool_use( local result, error = LLMTools.process_tool_use(prompt_opts.tools, partial_tool_use, {
prompt_opts.tools, session_ctx = opts.session_ctx,
partial_tool_use, on_log = opts.on_tool_log,
opts.on_tool_log, set_tool_use_store = opts.set_tool_use_store,
handle_tool_result, on_complete = handle_tool_result,
opts.session_ctx })
)
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
end end
if stop_opts.reason == "cancelled" then if stop_opts.reason == "cancelled" then
@@ -1100,6 +1100,13 @@ function M.stream(opts)
return original_on_tool_log(...) return original_on_tool_log(...)
end) end)
end end
if opts.set_tool_use_store ~= nil then
local original_set_tool_use_store = opts.set_tool_use_store
opts.set_tool_use_store = vim.schedule_wrap(function(...)
if not original_set_tool_use_store then return end
return original_set_tool_use_store(...)
end)
end
if opts.on_chunk ~= nil then if opts.on_chunk ~= nil then
local original_on_chunk = opts.on_chunk local original_on_chunk = opts.on_chunk
opts.on_chunk = vim.schedule_wrap(function(chunk) opts.on_chunk = vim.schedule_wrap(function(chunk)

View File

@@ -65,10 +65,11 @@ M.returns = {
M.on_render = function() return {} end M.on_render = function() return {} end
---@type AvanteLLMToolFunc<{ todos: avante.TODO[] }> ---@type AvanteLLMToolFunc<{ todos: avante.TODO[] }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
local on_complete = opts.on_complete
local sidebar = require("avante").get() local sidebar = require("avante").get()
if not sidebar then return false, "Avante sidebar not found" end if not sidebar then return false, "Avante sidebar not found" end
local todos = opts.todos local todos = input.todos
if not todos or #todos == 0 then return false, "No todos provided" end if not todos or #todos == 0 then return false, "No todos provided" end
sidebar:update_todos(todos) sidebar:update_todos(todos)
if on_complete then if on_complete then

View File

@@ -57,11 +57,11 @@ M.returns = {
} }
---@type avante.LLMToolOnRender<AttemptCompletionInput> ---@type avante.LLMToolOnRender<AttemptCompletionInput>
function M.on_render(opts) function M.on_render(input)
local lines = {} local lines = {}
table.insert(lines, Line:new({ { "✓ Task Completed", Highlights.AVANTE_TASK_COMPLETED } })) table.insert(lines, Line:new({ { "✓ Task Completed", Highlights.AVANTE_TASK_COMPLETED } }))
table.insert(lines, Line:new({ { "" } })) table.insert(lines, Line:new({ { "" } }))
local result = opts.result or "" local result = input.result or ""
local text_lines = vim.split(result, "\n") local text_lines = vim.split(result, "\n")
for _, text_line in ipairs(text_lines) do for _, text_line in ipairs(text_lines) do
table.insert(lines, Line:new({ { text_line } })) table.insert(lines, Line:new({ { text_line } }))
@@ -70,24 +70,29 @@ function M.on_render(opts)
end end
---@type AvanteLLMToolFunc<AttemptCompletionInput> ---@type AvanteLLMToolFunc<AttemptCompletionInput>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
if not on_complete then return false, "on_complete not provided" end if not opts.on_complete then return false, "on_complete not provided" end
local sidebar = require("avante").get() local sidebar = require("avante").get()
if not sidebar then return false, "Avante sidebar not found" end if not sidebar then return false, "Avante sidebar not found" end
local is_streaming = opts.streaming or false local is_streaming = input.streaming or false
if is_streaming then if is_streaming then
-- wait for stream completion as command may not be complete yet -- wait for stream completion as command may not be complete yet
return return
end end
session_ctx.attempt_completion_is_called = true opts.session_ctx.attempt_completion_is_called = true
if opts.command and opts.command ~= vim.NIL and opts.command ~= "" and not vim.startswith(opts.command, "open ") then if
session_ctx.always_yes = false input.command
require("avante.llm_tools.bash").func({ command = opts.command }, on_log, on_complete, session_ctx) and input.command ~= vim.NIL
and input.command ~= ""
and not vim.startswith(input.command, "open ")
then
opts.session_ctx.always_yes = false
require("avante.llm_tools.bash").func({ command = input.command }, opts)
else else
on_complete(true, nil) opts.on_complete(true, nil)
end end
end end

View File

@@ -215,18 +215,18 @@ M.returns = {
} }
---@type AvanteLLMToolFunc<{ path: string, command: string, streaming?: boolean }> ---@type AvanteLLMToolFunc<{ path: string, command: string, streaming?: boolean }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
local is_streaming = opts.streaming or false local is_streaming = input.streaming or false
if is_streaming then if is_streaming then
-- wait for stream completion as command may not be complete yet -- wait for stream completion as command may not be complete yet
return return
end end
local abs_path = Helpers.get_abs_path(opts.path) local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
if not opts.command then return false, "Command is required" end if not input.command then return false, "Command is required" end
if on_log then on_log("command: " .. opts.command) end if opts.on_log then opts.on_log("command: " .. input.command) end
---change cwd to abs_path ---change cwd to abs_path
---@param output string ---@param output string
@@ -240,21 +240,21 @@ function M.func(opts, on_log, on_complete, session_ctx)
end end
return output, nil return output, nil
end end
if not on_complete then return false, "on_complete not provided" end if not opts.on_complete then return false, "on_complete not provided" end
Helpers.confirm( Helpers.confirm(
"Are you sure you want to run the command: `" .. opts.command .. "` in the directory: " .. abs_path, "Are you sure you want to run the command: `" .. input.command .. "` in the directory: " .. abs_path,
function(ok, reason) function(ok, reason)
if not ok then if not ok then
on_complete(false, "User declined, reason: " .. (reason and reason or "unknown")) opts.on_complete(false, "User declined, reason: " .. (reason and reason or "unknown"))
return return
end end
Utils.shell_run_async(opts.command, "bash -c", function(output, exit_code) Utils.shell_run_async(input.command, "bash -c", function(output, exit_code)
local result, err = handle_result(output, exit_code) local result, err = handle_result(output, exit_code)
on_complete(result, err) opts.on_complete(result, err)
end, abs_path) end, abs_path)
end, end,
{ focus = true }, { focus = true },
session_ctx, opts.session_ctx,
M.name -- Pass the tool name for permission checking M.name -- Pass the tool name for permission checking
) )
end end

View File

@@ -45,20 +45,23 @@ M.returns = {
} }
---@type AvanteLLMToolFunc<{ path: string, file_text: string }> ---@type AvanteLLMToolFunc<{ path: string, file_text: string }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local session_ctx = opts.session_ctx
if not on_complete then return false, "on_complete not provided" end if not on_complete then return false, "on_complete not provided" end
if on_log then on_log("path: " .. opts.path) end if on_log then on_log("path: " .. input.path) end
if Helpers.already_in_context(opts.path) then if Helpers.already_in_context(input.path) then
on_complete(nil, "Ooooops! This file is already in the context! Why you are trying to create it again?") on_complete(nil, "Ooooops! This file is already in the context! Why you are trying to create it again?")
return return
end end
local abs_path = Helpers.get_abs_path(opts.path) local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if opts.file_text == nil then return false, "file_text not provided" end if input.file_text == nil then return false, "file_text not provided" end
if Path:new(abs_path):exists() then return false, "File already exists: " .. abs_path end if Path:new(abs_path):exists() then return false, "File already exists: " .. abs_path end
local lines = vim.split(opts.file_text, "\n") local lines = vim.split(input.file_text, "\n")
if #lines == 1 and opts.file_text:match("\\n") then if #lines == 1 and input.file_text:match("\\n") then
local text = Utils.trim_slashes(opts.file_text) local text = Utils.trim_slashes(input.file_text)
lines = vim.split(text, "\n") lines = vim.split(text, "\n")
end end
local bufnr, err = Helpers.get_bufnr(abs_path) local bufnr, err = Helpers.get_bufnr(abs_path)

View File

@@ -40,7 +40,7 @@ M.returns = {
} }
---@type AvanteLLMToolFunc<{ tool_use_id: string }> ---@type AvanteLLMToolFunc<{ tool_use_id: string }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
local sidebar = require("avante").get() local sidebar = require("avante").get()
if not sidebar then return false, "Avante sidebar not found" end if not sidebar then return false, "Avante sidebar not found" end
local history_messages = Utils.get_history_messages(sidebar.chat_history) local history_messages = Utils.get_history_messages(sidebar.chat_history)
@@ -50,7 +50,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
local content = msg.message.content local content = msg.message.content
if type(content) == "table" then if type(content) == "table" then
for _, item in ipairs(content) do for _, item in ipairs(content) do
if item.id == opts.tool_use_id then table.insert(the_deleted_message_uuids, msg.uuid) end if item.id == input.tool_use_id then table.insert(the_deleted_message_uuids, msg.uuid) end
end end
end end
end end

View File

@@ -3,6 +3,8 @@ local Config = require("avante.config")
local Utils = require("avante.utils") local Utils = require("avante.utils")
local Base = require("avante.llm_tools.base") local Base = require("avante.llm_tools.base")
local HistoryMessage = require("avante.history_message") local HistoryMessage = require("avante.history_message")
local Line = require("avante.ui.line")
local Highlights = require("avante.highlights")
---@class AvanteLLMTool ---@class AvanteLLMTool
local M = setmetatable({}, Base) local M = setmetatable({}, Base)
@@ -79,12 +81,118 @@ local function get_available_tools()
} }
end end
---@type AvanteLLMToolFunc<{ prompt: string }> ---@class avante.DispatchAgentInput
function M.func(opts, on_log, on_complete, session_ctx) ---@field prompt string
---@type avante.LLMToolOnRender<avante.DispatchAgentInput>
function M.on_render(input, opts)
local result_message = opts.result_message
local store = opts.store or {}
local messages = store.messages or {}
local tool_use_summary = {}
for _, msg in ipairs(messages) do
local content = msg.message.content
local summary
if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then
local tool_result_message = Utils.get_tool_result_message(msg, messages)
if tool_result_message then
local tool_name = msg.message.content[1].name
if tool_name == "ls" then
local path = msg.message.content[1].input.path
if tool_result_message.message.content[1].is_error then
summary = string.format("Ls %s: failed", path)
else
local ok, filepaths = pcall(vim.json.decode, tool_result_message.message.content[1].content)
if ok then summary = string.format("Ls %s: %d paths", path, #filepaths) end
end
elseif tool_name == "grep" then
local path = msg.message.content[1].input.path
local query = msg.message.content[1].input.query
if tool_result_message.message.content[1].is_error then
summary = string.format("Grep %s in %s: failed", query, path)
else
local ok, filepaths = pcall(vim.json.decode, tool_result_message.message.content[1].content)
if ok then summary = string.format("Grep %s in %s: %d paths", query, path, #filepaths) end
end
elseif tool_name == "glob" then
local path = msg.message.content[1].input.path
local pattern = msg.message.content[1].input.pattern
if tool_result_message.message.content[1].is_error then
summary = string.format("Glob %s in %s: failed", pattern, path)
else
local ok, result = pcall(vim.json.decode, tool_result_message.message.content[1].content)
if ok then
local matches = result.matches
if matches then summary = string.format("Glob %s in %s: %d matches", pattern, path, #matches) end
end
end
elseif tool_name == "view" then
local path = msg.message.content[1].input.path
if tool_result_message.message.content[1].is_error then
summary = string.format("View %s: failed", path)
else
local ok, result = pcall(vim.json.decode, tool_result_message.message.content[1].content)
if ok then
local content_ = result.content
local lines = vim.split(content_, "\n")
summary = string.format("View %s: %d lines", path, #lines)
end
end
end
end
if summary then summary = " " .. Utils.icon("🛠️ ") .. summary end
elseif type(content) == "table" and #content > 0 and type(content[1]) == "table" and content[1].type == "text" then
summary = content[1].content
elseif type(content) == "table" and #content > 0 and type(content[1]) == "string" then
summary = content[1]
elseif type(content) == "string" then
summary = content
end
if summary then table.insert(tool_use_summary, summary) end
end
local state = "running"
local icon = Utils.icon("🔄 ")
local hl = Highlights.AVANTE_TASK_RUNNING
if result_message then
if result_message.message.content[1].is_error then
state = "failed"
icon = Utils.icon("")
hl = Highlights.AVANTE_TASK_FAILED
else
state = "completed"
icon = Utils.icon("")
hl = Highlights.AVANTE_TASK_COMPLETED
end
end
local lines = {}
table.insert(lines, Line:new({ { icon .. "Subtask " .. state, hl } }))
table.insert(lines, Line:new({ { "" } }))
table.insert(lines, Line:new({ { " Task:" } }))
local prompt_lines = vim.split(input.prompt or "", "\n")
for _, line in ipairs(prompt_lines) do
table.insert(lines, Line:new({ { " " .. line } }))
end
table.insert(lines, Line:new({ { "" } }))
table.insert(lines, Line:new({ { " Task summary:" } }))
for _, summary in ipairs(tool_use_summary) do
local summary_lines = vim.split(summary, "\n")
for _, line in ipairs(summary_lines) do
table.insert(lines, Line:new({ { " " .. line } }))
end
end
return lines
end
---@type AvanteLLMToolFunc<avante.DispatchAgentInput>
function M.func(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local session_ctx = opts.session_ctx
local Llm = require("avante.llm") local Llm = require("avante.llm")
if not on_complete then return false, "on_complete not provided" end if not on_complete then return false, "on_complete not provided" end
local prompt = opts.prompt local prompt = input.prompt
local tools = get_available_tools() local tools = get_available_tools()
local start_time = Utils.get_timestamp() local start_time = Utils.get_timestamp()
@@ -95,10 +203,11 @@ Your task is to help the user with their request: "${prompt}"
Be thorough and use the tools available to you to find the most relevant information. Be thorough and use the tools available to you to find the most relevant information.
When you're done, provide a clear and concise summary of what you found.]]):gsub("${prompt}", prompt) When you're done, provide a clear and concise summary of what you found.]]):gsub("${prompt}", prompt)
local history_messages = {}
local tool_use_messages = {} local tool_use_messages = {}
local total_tokens = 0 local total_tokens = 0
local final_response = "" local result = ""
---@type avante.AgentLoopOptions ---@type avante.AgentLoopOptions
local agent_loop_options = { local agent_loop_options = {
@@ -108,19 +217,37 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
on_tool_log = session_ctx.on_tool_log, on_tool_log = session_ctx.on_tool_log,
on_messages_add = function(msgs) on_messages_add = function(msgs)
msgs = vim.islist(msgs) and msgs or { msgs } msgs = vim.islist(msgs) and msgs or { msgs }
for _, msg in ipairs(msgs) do
local idx = nil
for i, m in ipairs(history_messages) do
if m.uuid == msg.uuid then
idx = i
break
end
end
if idx ~= nil then
history_messages[idx] = msg
else
table.insert(history_messages, msg)
end
end
if opts.set_store then opts.set_store("messages", history_messages) end
for _, msg in ipairs(msgs) do for _, msg in ipairs(msgs) do
local content = msg.message.content local content = msg.message.content
if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then
tool_use_messages[msg.uuid] = true tool_use_messages[msg.uuid] = true
if content[1].name == "attempt_completion" then
local input_ = content[1].input
if input_ and input_.result then result = input_.result end
end
end end
end end
if session_ctx.on_messages_add then session_ctx.on_messages_add(msgs) end -- if session_ctx.on_messages_add then session_ctx.on_messages_add(msgs) end
end, end,
session_ctx = session_ctx, session_ctx = session_ctx,
on_start = session_ctx.on_start, on_start = session_ctx.on_start,
on_chunk = function(chunk) on_chunk = function(chunk)
if not chunk then return end if not chunk then return end
final_response = final_response .. chunk
total_tokens = total_tokens + (#vim.split(chunk, " ") * 1.3) total_tokens = total_tokens + (#vim.split(chunk, " ") * 1.3)
end, end,
on_complete = function(err) on_complete = function(err)
@@ -148,8 +275,7 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
}) })
session_ctx.on_messages_add({ message }) session_ctx.on_messages_add({ message })
end end
local response = string.format("Final response:\n%s\n\nSummary:\n%s", summary, final_response) on_complete(result, nil)
on_complete(response, nil)
end, end,
} }

View File

@@ -40,10 +40,12 @@ M.returns = {
} }
---@type AvanteLLMToolFunc<{ path: string, diff: string }> ---@type AvanteLLMToolFunc<{ path: string, diff: string }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
if not opts.path then return false, "pathf are required" end local on_log = opts.on_log
if on_log then on_log("path: " .. opts.path) end local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(opts.path) if not input.path then return false, "pathf are required" end
if on_log then on_log("path: " .. input.path) end
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not on_complete then return false, "on_complete is required" end if not on_complete then return false, "on_complete is required" end
local diagnostics = Utils.lsp.get_diagnostics_from_filepath(abs_path) local diagnostics = Utils.lsp.get_diagnostics_from_filepath(abs_path)

View File

@@ -45,12 +45,14 @@ M.returns = {
} }
---@type AvanteLLMToolFunc<{ path: string, pattern: string }> ---@type AvanteLLMToolFunc<{ path: string, pattern: string }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
local abs_path = Helpers.get_abs_path(opts.path) local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end if on_log then on_log("path: " .. abs_path) end
if on_log then on_log("pattern: " .. opts.pattern) end if on_log then on_log("pattern: " .. input.pattern) end
local files = vim.fn.glob(abs_path .. "/" .. opts.pattern, true, true) local files = vim.fn.glob(abs_path .. "/" .. input.pattern, true, true)
local truncated_files = {} local truncated_files = {}
local is_truncated = false local is_truncated = false
local size = 0 local size = 0

View File

@@ -69,8 +69,10 @@ M.returns = {
} }
---@type AvanteLLMToolFunc<{ path: string, query: string, case_sensitive?: boolean, include_pattern?: string, exclude_pattern?: string }> ---@type AvanteLLMToolFunc<{ path: string, query: string, case_sensitive?: boolean, include_pattern?: string, exclude_pattern?: string }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
local abs_path = Helpers.get_abs_path(opts.path) local on_log = opts.on_log
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return "", "No such file or directory: " .. abs_path end if not Path:new(abs_path):exists() then return "", "No such file or directory: " .. abs_path end
@@ -85,31 +87,31 @@ function M.func(opts, on_log, on_complete, session_ctx)
local cmd = "" local cmd = ""
if search_cmd:find("rg") then if search_cmd:find("rg") then
cmd = string.format("%s --files-with-matches --hidden", search_cmd) cmd = string.format("%s --files-with-matches --hidden", search_cmd)
if opts.case_sensitive then if input.case_sensitive then
cmd = string.format("%s --case-sensitive", cmd) cmd = string.format("%s --case-sensitive", cmd)
else else
cmd = string.format("%s --ignore-case", cmd) cmd = string.format("%s --ignore-case", cmd)
end end
if opts.include_pattern then cmd = string.format("%s --glob '%s'", cmd, opts.include_pattern) end if input.include_pattern then cmd = string.format("%s --glob '%s'", cmd, input.include_pattern) end
if opts.exclude_pattern then cmd = string.format("%s --glob '!%s'", cmd, opts.exclude_pattern) end if input.exclude_pattern then cmd = string.format("%s --glob '!%s'", cmd, input.exclude_pattern) end
cmd = string.format("%s '%s' %s", cmd, opts.query, abs_path) cmd = string.format("%s '%s' %s", cmd, input.query, abs_path)
elseif search_cmd:find("ag") then elseif search_cmd:find("ag") then
cmd = string.format("%s --nocolor --nogroup --hidden", search_cmd) cmd = string.format("%s --nocolor --nogroup --hidden", search_cmd)
if opts.case_sensitive then cmd = string.format("%s --case-sensitive", cmd) end if input.case_sensitive then cmd = string.format("%s --case-sensitive", cmd) end
if opts.include_pattern then cmd = string.format("%s --ignore '!%s'", cmd, opts.include_pattern) end if input.include_pattern then cmd = string.format("%s --ignore '!%s'", cmd, input.include_pattern) end
if opts.exclude_pattern then cmd = string.format("%s --ignore '%s'", cmd, opts.exclude_pattern) end if input.exclude_pattern then cmd = string.format("%s --ignore '%s'", cmd, input.exclude_pattern) end
cmd = string.format("%s '%s' %s", cmd, opts.query, abs_path) cmd = string.format("%s '%s' %s", cmd, input.query, abs_path)
elseif search_cmd:find("ack") then elseif search_cmd:find("ack") then
cmd = string.format("%s --nocolor --nogroup --hidden", search_cmd) cmd = string.format("%s --nocolor --nogroup --hidden", search_cmd)
if opts.case_sensitive then cmd = string.format("%s --smart-case", cmd) end if input.case_sensitive then cmd = string.format("%s --smart-case", cmd) end
if opts.exclude_pattern then cmd = string.format("%s --ignore-dir '%s'", cmd, opts.exclude_pattern) end if input.exclude_pattern then cmd = string.format("%s --ignore-dir '%s'", cmd, input.exclude_pattern) end
cmd = string.format("%s '%s' %s", cmd, opts.query, abs_path) cmd = string.format("%s '%s' %s", cmd, input.query, abs_path)
elseif search_cmd:find("grep") then elseif search_cmd:find("grep") then
cmd = string.format("cd %s && git ls-files -co --exclude-standard | xargs %s -rH", abs_path, search_cmd, abs_path) cmd = string.format("cd %s && git ls-files -co --exclude-standard | xargs %s -rH", abs_path, search_cmd, abs_path)
if not opts.case_sensitive then cmd = string.format("%s -i", cmd) end if not input.case_sensitive then cmd = string.format("%s -i", cmd) end
if opts.include_pattern then cmd = string.format("%s --include '%s'", cmd, opts.include_pattern) end if input.include_pattern then cmd = string.format("%s --include '%s'", cmd, input.include_pattern) end
if opts.exclude_pattern then cmd = string.format("%s --exclude '%s'", cmd, opts.exclude_pattern) end if input.exclude_pattern then cmd = string.format("%s --exclude '%s'", cmd, input.exclude_pattern) end
cmd = string.format("%s '%s'", cmd, opts.query) cmd = string.format("%s '%s'", cmd, input.query)
end end
Utils.debug("cmd", cmd) Utils.debug("cmd", cmd)

View File

@@ -8,9 +8,10 @@ local Helpers = require("avante.llm_tools.helpers")
local M = {} local M = {}
---@type AvanteLLMToolFunc<{ path: string }> ---@type AvanteLLMToolFunc<{ path: string }>
function M.read_file_toplevel_symbols(opts, on_log, on_complete, session_ctx) function M.read_file_toplevel_symbols(input, opts)
local on_log = opts.on_log
local RepoMap = require("avante.repo_map") local RepoMap = require("avante.repo_map")
local abs_path = Helpers.get_abs_path(opts.path) local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end if on_log then on_log("path: " .. abs_path) end
if not Path:new(abs_path):exists() then return "", "File does not exists: " .. abs_path end if not Path:new(abs_path):exists() then return "", "File does not exists: " .. abs_path end
@@ -24,50 +25,47 @@ function M.read_file_toplevel_symbols(opts, on_log, on_complete, session_ctx)
end end
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert" | "undo_edit", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[] }> ---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert" | "undo_edit", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[] }>
function M.str_replace_editor(opts, on_log, on_complete, session_ctx) function M.str_replace_editor(input, opts)
if opts.command == "undo_edit" then if input.command == "undo_edit" then return require("avante.llm_tools.undo_edit").func(input, opts) end
return require("avante.llm_tools.undo_edit").func(opts, on_log, on_complete, session_ctx) ---@cast input any
end return M.str_replace_based_edit_tool(input, opts)
---@cast opts any
return M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx)
end end
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[], streaming?: boolean }> ---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[], streaming?: boolean }>
function M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx) function M.str_replace_based_edit_tool(input, opts)
if not opts.command then return false, "command not provided" end local on_log = opts.on_log
if on_log then on_log("command: " .. opts.command) end local on_complete = opts.on_complete
if not input.command then return false, "command not provided" end
if on_log then on_log("command: " .. input.command) end
if not on_complete then return false, "on_complete not provided" end if not on_complete then return false, "on_complete not provided" end
local abs_path = Helpers.get_abs_path(opts.path) local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if opts.command == "view" then if input.command == "view" then
local view = require("avante.llm_tools.view") local view = require("avante.llm_tools.view")
local opts_ = { path = opts.path } local input_ = { path = input.path }
if opts.view_range then if input.view_range then
local start_line, end_line = unpack(opts.view_range) local start_line, end_line = unpack(input.view_range)
opts_.start_line = start_line input_.start_line = start_line
opts_.end_line = end_line input_.end_line = end_line
end end
return view(opts_, on_log, on_complete, session_ctx) return view(input_, opts)
end end
if opts.command == "str_replace" then if input.command == "str_replace" then
if opts.new_str == nil and opts.file_text ~= nil then if input.new_str == nil and input.file_text ~= nil then
opts.new_str = opts.file_text input.new_str = input.file_text
opts.file_text = nil input.file_text = nil
end end
return require("avante.llm_tools.str_replace").func(opts, on_log, on_complete, session_ctx) return require("avante.llm_tools.str_replace").func(input, opts)
end end
if opts.command == "create" then if input.command == "create" then return require("avante.llm_tools.create").func(input, opts) end
return require("avante.llm_tools.create").func(opts, on_log, on_complete, session_ctx) if input.command == "insert" then return require("avante.llm_tools.insert").func(input, opts) end
end return false, "Unknown command: " .. input.command
if opts.command == "insert" then
return require("avante.llm_tools.insert").func(opts, on_log, on_complete, session_ctx)
end
return false, "Unknown command: " .. opts.command
end end
---@type AvanteLLMToolFunc<{ abs_path: string }> ---@type AvanteLLMToolFunc<{ abs_path: string }>
function M.read_global_file(opts, on_log, on_complete, session_ctx) function M.read_global_file(input, opts)
local abs_path = Helpers.get_abs_path(opts.abs_path) local on_log = opts.on_log
local abs_path = Helpers.get_abs_path(input.abs_path)
if Helpers.is_ignored(abs_path) then return "", "This file is ignored: " .. abs_path end if Helpers.is_ignored(abs_path) then return "", "This file is ignored: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end if on_log then on_log("path: " .. abs_path) end
local file = io.open(abs_path, "r") local file = io.open(abs_path, "r")
@@ -78,11 +76,13 @@ function M.read_global_file(opts, on_log, on_complete, session_ctx)
end end
---@type AvanteLLMToolFunc<{ abs_path: string, content: string }> ---@type AvanteLLMToolFunc<{ abs_path: string, content: string }>
function M.write_global_file(opts, on_log, on_complete, session_ctx) function M.write_global_file(input, opts)
local abs_path = Helpers.get_abs_path(opts.abs_path) local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.abs_path)
if Helpers.is_ignored(abs_path) then return false, "This file is ignored: " .. abs_path end if Helpers.is_ignored(abs_path) then return false, "This file is ignored: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end if on_log then on_log("path: " .. abs_path) end
if on_log then on_log("content: " .. opts.content) end if on_log then on_log("content: " .. input.content) end
if not on_complete then return false, "on_complete not provided" end if not on_complete then return false, "on_complete not provided" end
Helpers.confirm("Are you sure you want to write to the file: " .. abs_path, function(ok) Helpers.confirm("Are you sure you want to write to the file: " .. abs_path, function(ok)
if not ok then if not ok then
@@ -94,18 +94,20 @@ function M.write_global_file(opts, on_log, on_complete, session_ctx)
on_complete(false, "file not found: " .. abs_path) on_complete(false, "file not found: " .. abs_path)
return return
end end
file:write(opts.content) file:write(input.content)
file:close() file:close()
on_complete(true, nil) on_complete(true, nil)
end, nil, session_ctx, "write_global_file") end, nil, opts.session_ctx, "write_global_file")
end end
---@type AvanteLLMToolFunc<{ source_path: string, destination_path: string }> ---@type AvanteLLMToolFunc<{ source_path: string, destination_path: string }>
function M.move_path(opts, on_log, on_complete, session_ctx) function M.move_path(input, opts)
local abs_path = Helpers.get_abs_path(opts.source_path) local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.source_path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "The source path not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "The source path not found: " .. abs_path end
local new_abs_path = Helpers.get_abs_path(opts.destination_path) local new_abs_path = Helpers.get_abs_path(input.destination_path)
if on_log then on_log(abs_path .. " -> " .. new_abs_path) end if on_log then on_log(abs_path .. " -> " .. new_abs_path) end
if not Helpers.has_permission_to_access(new_abs_path) then if not Helpers.has_permission_to_access(new_abs_path) then
return false, "No permission to access path: " .. new_abs_path return false, "No permission to access path: " .. new_abs_path
@@ -123,17 +125,19 @@ function M.move_path(opts, on_log, on_complete, session_ctx)
on_complete(true, nil) on_complete(true, nil)
end, end,
nil, nil,
session_ctx, opts.session_ctx,
"move_path" "move_path"
) )
end end
---@type AvanteLLMToolFunc<{ source_path: string, destination_path: string }> ---@type AvanteLLMToolFunc<{ source_path: string, destination_path: string }>
function M.copy_path(opts, on_log, on_complete, session_ctx) function M.copy_path(input, opts)
local abs_path = Helpers.get_abs_path(opts.source_path) local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.source_path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "The source path not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "The source path not found: " .. abs_path end
local new_abs_path = Helpers.get_abs_path(opts.destination_path) local new_abs_path = Helpers.get_abs_path(input.destination_path)
if not Helpers.has_permission_to_access(new_abs_path) then if not Helpers.has_permission_to_access(new_abs_path) then
return false, "No permission to access path: " .. new_abs_path return false, "No permission to access path: " .. new_abs_path
end end
@@ -169,14 +173,16 @@ function M.copy_path(opts, on_log, on_complete, session_ctx)
on_complete(true, nil) on_complete(true, nil)
end, end,
nil, nil,
session_ctx, opts.session_ctx,
"copy_path" "copy_path"
) )
end end
---@type AvanteLLMToolFunc<{ path: string }> ---@type AvanteLLMToolFunc<{ path: string }>
function M.delete_path(opts, on_log, on_complete, session_ctx) function M.delete_path(input, opts)
local abs_path = Helpers.get_abs_path(opts.path) local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
if not on_complete then return false, "on_complete not provided" end if not on_complete then return false, "on_complete not provided" end
@@ -188,12 +194,14 @@ function M.delete_path(opts, on_log, on_complete, session_ctx)
if on_log then on_log("Deleting path: " .. abs_path) end if on_log then on_log("Deleting path: " .. abs_path) end
os.remove(abs_path) os.remove(abs_path)
on_complete(true, nil) on_complete(true, nil)
end, nil, session_ctx, "delete_path") end, nil, opts.session_ctx, "delete_path")
end end
---@type AvanteLLMToolFunc<{ path: string }> ---@type AvanteLLMToolFunc<{ path: string }>
function M.create_dir(opts, on_log, on_complete, session_ctx) function M.create_dir(input, opts)
local abs_path = Helpers.get_abs_path(opts.path) local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end
if not on_complete then return false, "on_complete not provided" end if not on_complete then return false, "on_complete not provided" end
@@ -205,16 +213,17 @@ function M.create_dir(opts, on_log, on_complete, session_ctx)
if on_log then on_log("Creating directory: " .. abs_path) end if on_log then on_log("Creating directory: " .. abs_path) end
Path:new(abs_path):mkdir({ parents = true }) Path:new(abs_path):mkdir({ parents = true })
on_complete(true, nil) on_complete(true, nil)
end, nil, session_ctx, "create_dir") end, nil, opts.session_ctx, "create_dir")
end end
---@type AvanteLLMToolFunc<{ query: string }> ---@type AvanteLLMToolFunc<{ query: string }>
function M.web_search(opts, on_log, on_complete, session_ctx) function M.web_search(input, opts)
local on_log = opts.on_log
local provider_type = Config.web_search_engine.provider local provider_type = Config.web_search_engine.provider
local proxy = Config.web_search_engine.proxy local proxy = Config.web_search_engine.proxy
if provider_type == nil then return nil, "Search engine provider is not set" end if provider_type == nil then return nil, "Search engine provider is not set" end
if on_log then on_log("provider: " .. provider_type) end if on_log then on_log("provider: " .. provider_type) end
if on_log then on_log("query: " .. opts.query) end if on_log then on_log("query: " .. input.query) end
local search_engine = Config.web_search_engine.providers[provider_type] local search_engine = Config.web_search_engine.providers[provider_type]
if search_engine == nil then return nil, "No search engine found: " .. provider_type end if search_engine == nil then return nil, "No search engine found: " .. provider_type end
if provider_type ~= "searxng" and search_engine.api_key_name == "" then return nil, "No API key provided" end if provider_type ~= "searxng" and search_engine.api_key_name == "" then return nil, "No API key provided" end
@@ -229,7 +238,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
["Authorization"] = "Bearer " .. api_key, ["Authorization"] = "Bearer " .. api_key,
}, },
body = vim.json.encode(vim.tbl_deep_extend("force", { body = vim.json.encode(vim.tbl_deep_extend("force", {
query = opts.query, query = input.query,
}, search_engine.extra_request_body)), }, search_engine.extra_request_body)),
} }
if proxy then curl_opts.proxy = proxy end if proxy then curl_opts.proxy = proxy end
@@ -240,7 +249,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
elseif provider_type == "serpapi" then elseif provider_type == "serpapi" then
local query_params = vim.tbl_deep_extend("force", { local query_params = vim.tbl_deep_extend("force", {
api_key = api_key, api_key = api_key,
q = opts.query, q = input.query,
}, search_engine.extra_request_body) }, search_engine.extra_request_body)
local query_string = "" local query_string = ""
for key, value in pairs(query_params) do for key, value in pairs(query_params) do
@@ -259,7 +268,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
elseif provider_type == "searchapi" then elseif provider_type == "searchapi" then
local query_params = vim.tbl_deep_extend("force", { local query_params = vim.tbl_deep_extend("force", {
api_key = api_key, api_key = api_key,
q = opts.query, q = input.query,
}, search_engine.extra_request_body) }, search_engine.extra_request_body)
local query_string = "" local query_string = ""
for key, value in pairs(query_params) do for key, value in pairs(query_params) do
@@ -283,7 +292,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
local query_params = vim.tbl_deep_extend("force", { local query_params = vim.tbl_deep_extend("force", {
key = api_key, key = api_key,
cx = engine_id, cx = engine_id,
q = opts.query, q = input.query,
}, search_engine.extra_request_body) }, search_engine.extra_request_body)
local query_string = "" local query_string = ""
for key, value in pairs(query_params) do for key, value in pairs(query_params) do
@@ -301,7 +310,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
return search_engine.format_response_body(jsn) return search_engine.format_response_body(jsn)
elseif provider_type == "kagi" then elseif provider_type == "kagi" then
local query_params = vim.tbl_deep_extend("force", { local query_params = vim.tbl_deep_extend("force", {
q = opts.query, q = input.query,
}, search_engine.extra_request_body) }, search_engine.extra_request_body)
local query_string = "" local query_string = ""
for key, value in pairs(query_params) do for key, value in pairs(query_params) do
@@ -320,7 +329,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
return search_engine.format_response_body(jsn) return search_engine.format_response_body(jsn)
elseif provider_type == "brave" then elseif provider_type == "brave" then
local query_params = vim.tbl_deep_extend("force", { local query_params = vim.tbl_deep_extend("force", {
q = opts.query, q = input.query,
}, search_engine.extra_request_body) }, search_engine.extra_request_body)
local query_string = "" local query_string = ""
for key, value in pairs(query_params) do for key, value in pairs(query_params) do
@@ -343,7 +352,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
return nil, "Environment variable " .. search_engine.api_url_name .. " is not set" return nil, "Environment variable " .. search_engine.api_url_name .. " is not set"
end end
local query_params = vim.tbl_deep_extend("force", { local query_params = vim.tbl_deep_extend("force", {
q = opts.query, q = input.query,
}, search_engine.extra_request_body) }, search_engine.extra_request_body)
local query_string = "" local query_string = ""
for key, value in pairs(query_params) do for key, value in pairs(query_params) do
@@ -362,16 +371,18 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
end end
---@type AvanteLLMToolFunc<{ url: string }> ---@type AvanteLLMToolFunc<{ url: string }>
function M.fetch(opts, on_log, on_complete, session_ctx) function M.fetch(input, opts)
if on_log then on_log("url: " .. opts.url) end local on_log = opts.on_log
if on_log then on_log("url: " .. input.url) end
local Html2Md = require("avante.html2md") local Html2Md = require("avante.html2md")
local res, err = Html2Md.fetch_md(opts.url) local res, err = Html2Md.fetch_md(input.url)
if err then return nil, err end if err then return nil, err end
return res, nil return res, nil
end end
---@type AvanteLLMToolFunc<{ scope?: string }> ---@type AvanteLLMToolFunc<{ scope?: string }>
function M.git_diff(opts, on_log, on_complete, session_ctx) function M.git_diff(input, opts)
local on_log = opts.on_log
local git_cmd = vim.fn.exepath("git") local git_cmd = vim.fn.exepath("git")
if git_cmd == "" then return nil, "Git command not found" end if git_cmd == "" then return nil, "Git command not found" end
local project_root = Utils.get_project_root() local project_root = Utils.get_project_root()
@@ -382,7 +393,7 @@ function M.git_diff(opts, on_log, on_complete, session_ctx)
if git_dir == "" then return nil, "Not a git repository" end if git_dir == "" then return nil, "Not a git repository" end
-- Get the diff -- Get the diff
local scope = opts.scope or "" local scope = input.scope or ""
local cmd = string.format("git diff --cached %s", scope) local cmd = string.format("git diff --cached %s", scope)
if on_log then on_log("Running command: " .. cmd) end if on_log then on_log("Running command: " .. cmd) end
local diff = vim.fn.system(cmd) local diff = vim.fn.system(cmd)
@@ -400,7 +411,10 @@ function M.git_diff(opts, on_log, on_complete, session_ctx)
end end
---@type AvanteLLMToolFunc<{ message: string, scope?: string }> ---@type AvanteLLMToolFunc<{ message: string, scope?: string }>
function M.git_commit(opts, on_log, on_complete, session_ctx) function M.git_commit(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local git_cmd = vim.fn.exepath("git") local git_cmd = vim.fn.exepath("git")
if git_cmd == "" then return false, "Git command not found" end if git_cmd == "" then return false, "Git command not found" end
local project_root = Utils.get_project_root() local project_root = Utils.get_project_root()
@@ -444,7 +458,7 @@ function M.git_commit(opts, on_log, on_complete, session_ctx)
-- Prepare commit message -- Prepare commit message
local commit_msg_lines = {} local commit_msg_lines = {}
for line in opts.message:gmatch("[^\r\n]+") do for line in input.message:gmatch("[^\r\n]+") do
commit_msg_lines[#commit_msg_lines + 1] = line:gsub('"', '\\"') commit_msg_lines[#commit_msg_lines + 1] = line:gsub('"', '\\"')
end end
commit_msg_lines[#commit_msg_lines + 1] = "" commit_msg_lines[#commit_msg_lines + 1] = ""
@@ -464,8 +478,8 @@ function M.git_commit(opts, on_log, on_complete, session_ctx)
return return
end end
-- Stage changes if scope is provided -- Stage changes if scope is provided
if opts.scope then if input.scope then
local stage_cmd = string.format("git add %s", opts.scope) local stage_cmd = string.format("git add %s", input.scope)
if on_log then on_log("Staging files: " .. stage_cmd) end if on_log then on_log("Staging files: " .. stage_cmd) end
local stage_result = vim.fn.system(stage_cmd) local stage_result = vim.fn.system(stage_cmd)
if vim.v.shell_error ~= 0 then if vim.v.shell_error ~= 0 then
@@ -494,20 +508,23 @@ function M.git_commit(opts, on_log, on_complete, session_ctx)
end end
on_complete(true, nil) on_complete(true, nil)
end, nil, session_ctx, "git_commit") end, nil, opts.session_ctx, "git_commit")
end end
---@type AvanteLLMToolFunc<{ query: string }> ---@type AvanteLLMToolFunc<{ query: string }>
function M.rag_search(opts, on_log, on_complete, session_ctx) function M.rag_search(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
if not on_complete then return nil, "on_complete not provided" end
if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end
if not opts.query then return nil, "No query provided" end if not input.query then return nil, "No query provided" end
if on_log then on_log("query: " .. opts.query) end if on_log then on_log("query: " .. input.query) end
local root = Utils.get_project_root() local root = Utils.get_project_root()
local uri = "file://" .. root local uri = "file://" .. root
if uri:sub(-1) ~= "/" then uri = uri .. "/" end if uri:sub(-1) ~= "/" then uri = uri .. "/" end
RagService.retrieve( RagService.retrieve(
uri, uri,
opts.query, input.query,
vim.schedule_wrap(function(resp, err) vim.schedule_wrap(function(resp, err)
if err then if err then
on_complete(nil, err) on_complete(nil, err)
@@ -519,13 +536,15 @@ function M.rag_search(opts, on_log, on_complete, session_ctx)
end end
---@type AvanteLLMToolFunc<{ code: string, path: string, container_image?: string }> ---@type AvanteLLMToolFunc<{ code: string, path: string, container_image?: string }>
function M.python(opts, on_log, on_complete, session_ctx) function M.python(input, opts)
local abs_path = Helpers.get_abs_path(opts.path) local on_log = opts.on_log
local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return nil, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return nil, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return nil, "Path not found: " .. abs_path end if not Path:new(abs_path):exists() then return nil, "Path not found: " .. abs_path end
if on_log then on_log("cwd: " .. abs_path) end if on_log then on_log("cwd: " .. abs_path) end
if on_log then on_log("code:\n" .. opts.code) end if on_log then on_log("code:\n" .. input.code) end
local container_image = opts.container_image or "python:3.11-slim-bookworm" local container_image = input.container_image or "python:3.11-slim-bookworm"
if not on_complete then return nil, "on_complete not provided" end if not on_complete then return nil, "on_complete not provided" end
Helpers.confirm( Helpers.confirm(
"Are you sure you want to run the following python code in the `" "Are you sure you want to run the following python code in the `"
@@ -533,7 +552,7 @@ function M.python(opts, on_log, on_complete, session_ctx)
.. "` container, in the directory: `" .. "` container, in the directory: `"
.. abs_path .. abs_path
.. "`?\n" .. "`?\n"
.. opts.code, .. input.code,
function(ok, reason) function(ok, reason)
if not ok then if not ok then
on_complete(nil, "User declined, reason: " .. (reason or "unknown")) on_complete(nil, "User declined, reason: " .. (reason or "unknown"))
@@ -562,7 +581,7 @@ function M.python(opts, on_log, on_complete, session_ctx)
container_image, container_image,
"python", "python",
"-c", "-c",
opts.code, input.code,
}, },
{ {
text = true, text = true,
@@ -576,7 +595,7 @@ function M.python(opts, on_log, on_complete, session_ctx)
) )
end, end,
nil, nil,
session_ctx, opts.session_ctx,
"python" "python"
) )
end end
@@ -1189,14 +1208,19 @@ You can delete the first file by providing a path of "directory1/a/something.txt
--- compatibility alias for old calls & tests --- compatibility alias for old calls & tests
M.run_python = M.python M.run_python = M.python
---@class avante.ProcessToolUseOpts
---@field session_ctx table
---@field on_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
---@field set_tool_use_store? fun(tool_id: string, key: string, value: any): nil
---@field on_complete? fun(result: string | nil, error: string | nil): nil
---@param tools AvanteLLMTool[] ---@param tools AvanteLLMTool[]
---@param tool_use AvanteLLMToolUse ---@param tool_use AvanteLLMToolUse
---@param on_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
---@param on_complete? fun(result: string | nil, error: string | nil): nil
---@param session_ctx? table
---@return string | nil result ---@return string | nil result
---@return string | nil error ---@return string | nil error
function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx) function M.process_tool_use(tools, tool_use, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
-- Check if execution is already cancelled -- Check if execution is already cancelled
if Helpers.is_cancelled then if Helpers.is_cancelled then
Utils.debug("Tool execution cancelled before starting: " .. tool_use.name) Utils.debug("Tool execution cancelled before starting: " .. tool_use.name)
@@ -1274,25 +1298,32 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx)
return result_str, err return result_str, err
end end
local result, err = func(input_json, function(log) local result, err = func(input_json, {
-- Check for cancellation during logging session_ctx = opts.session_ctx or {},
if Helpers.is_cancelled then return end on_log = function(log)
if on_log then on_log(tool_use.id, tool_use.name, log, "running") end -- Check for cancellation during logging
end, function(result, err) if Helpers.is_cancelled then return end
-- Check for cancellation before completing if on_log then on_log(tool_use.id, tool_use.name, log, "running") end
if Helpers.is_cancelled then end,
Helpers.is_cancelled = false set_store = function(key, value)
if on_complete then on_complete(nil, Helpers.CANCEL_TOKEN) end if opts.set_tool_use_store then opts.set_tool_use_store(tool_use.id, key, value) end
return end,
end on_complete = function(result, err)
-- Check for cancellation before completing
if Helpers.is_cancelled then
Helpers.is_cancelled = false
if on_complete then on_complete(nil, Helpers.CANCEL_TOKEN) end
return
end
result, err = handle_result(result, err) result, err = handle_result(result, err)
if on_complete == nil then if on_complete == nil then
Utils.error("asynchronous tool " .. tool_use.name .. " result not handled") Utils.error("asynchronous tool " .. tool_use.name .. " result not handled")
return return
end end
on_complete(result, err) on_complete(result, err)
end, session_ctx) end,
})
-- Result and error being nil means that the tool was executed asynchronously -- Result and error being nil means that the tool was executed asynchronously
if result == nil and err == nil and on_complete then return end if result == nil and err == nil and on_complete then return end

View File

@@ -55,19 +55,24 @@ M.returns = {
} }
---@type AvanteLLMToolFunc<{ path: string, insert_line: integer, new_str: string }> ---@type AvanteLLMToolFunc<{ path: string, insert_line: integer, new_str: string }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
if on_log then on_log("path: " .. opts.path) end local on_log = opts.on_log
local abs_path = Helpers.get_abs_path(opts.path) local on_complete = opts.on_complete
local session_ctx = opts.session_ctx
if not on_complete then return false, "on_complete not provided" end
if on_log then on_log("path: " .. input.path) end
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
if opts.insert_line == nil then return false, "insert_line not provided" end if input.insert_line == nil then return false, "insert_line not provided" end
if opts.new_str == nil then return false, "new_str not provided" end if input.new_str == nil then return false, "new_str not provided" end
local ns_id = vim.api.nvim_create_namespace("avante_insert_diff") local ns_id = vim.api.nvim_create_namespace("avante_insert_diff")
local bufnr, err = Helpers.get_bufnr(abs_path) local bufnr, err = Helpers.get_bufnr(abs_path)
if err then return false, err end if err then return false, err end
local function clear_highlights() vim.api.nvim_buf_clear_namespace(bufnr, ns_id, 0, -1) end local function clear_highlights() vim.api.nvim_buf_clear_namespace(bufnr, ns_id, 0, -1) end
local new_lines = vim.split(opts.new_str, "\n") local new_lines = vim.split(input.new_str, "\n")
local max_col = vim.o.columns local max_col = vim.o.columns
local virt_lines = vim local virt_lines = vim
.iter(new_lines) .iter(new_lines)
@@ -78,8 +83,8 @@ function M.func(opts, on_log, on_complete, session_ctx)
end) end)
:totable() :totable()
local line_count = vim.api.nvim_buf_line_count(bufnr) local line_count = vim.api.nvim_buf_line_count(bufnr)
if opts.insert_line > line_count - 1 then opts.insert_line = line_count - 1 end if input.insert_line > line_count - 1 then input.insert_line = line_count - 1 end
vim.api.nvim_buf_set_extmark(bufnr, ns_id, opts.insert_line, 0, { vim.api.nvim_buf_set_extmark(bufnr, ns_id, input.insert_line, 0, {
virt_lines = virt_lines, virt_lines = virt_lines,
hl_eol = true, hl_eol = true,
hl_mode = "combine", hl_mode = "combine",
@@ -90,9 +95,9 @@ function M.func(opts, on_log, on_complete, session_ctx)
on_complete(false, "User declined, reason: " .. (reason or "unknown")) on_complete(false, "User declined, reason: " .. (reason or "unknown"))
return return
end end
vim.api.nvim_buf_set_lines(bufnr, opts.insert_line, opts.insert_line, false, new_lines) vim.api.nvim_buf_set_lines(bufnr, input.insert_line, input.insert_line, false, new_lines)
vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end) vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end)
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end
on_complete(true, nil) on_complete(true, nil)
end, { focus = true }, session_ctx, M.name) end, { focus = true }, session_ctx, M.name)
end end

View File

@@ -46,15 +46,16 @@ M.returns = {
} }
---@type AvanteLLMToolFunc<{ path: string, max_depth?: integer }> ---@type AvanteLLMToolFunc<{ path: string, max_depth?: integer }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
local abs_path = Helpers.get_abs_path(opts.path) local on_log = opts.on_log
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end if on_log then on_log("path: " .. abs_path) end
if on_log then on_log("max depth: " .. tostring(opts.max_depth)) end if on_log then on_log("max depth: " .. tostring(input.max_depth)) end
local files = Utils.scan_directory({ local files = Utils.scan_directory({
directory = abs_path, directory = abs_path,
add_dirs = true, add_dirs = true,
max_depth = opts.max_depth, max_depth = input.max_depth,
}) })
local filepaths = {} local filepaths = {}
for _, file in ipairs(files) do for _, file in ipairs(files) do

View File

@@ -125,39 +125,44 @@ end
--- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view. --- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view.
---@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string, streaming?: boolean, tool_use_id?: string }> ---@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string, streaming?: boolean, tool_use_id?: string }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
if opts.the_diff ~= nil then local on_log = opts.on_log
opts.diff = opts.the_diff local on_complete = opts.on_complete
opts.the_diff = nil local session_ctx = opts.session_ctx
if not on_complete then return false, "on_complete not provided" end
if input.the_diff ~= nil then
input.diff = input.the_diff
input.the_diff = nil
end end
if not opts.path or not opts.diff then return false, "path and diff are required " .. vim.inspect(opts) end if not input.path or not input.diff then return false, "path and diff are required " .. vim.inspect(input) end
if on_log then on_log("path: " .. opts.path) end if on_log then on_log("path: " .. input.path) end
local abs_path = Helpers.get_abs_path(opts.path) local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
local is_streaming = opts.streaming or false local is_streaming = input.streaming or false
session_ctx.prev_streaming_diff_timestamp_map = session_ctx.prev_streaming_diff_timestamp_map or {} session_ctx.prev_streaming_diff_timestamp_map = session_ctx.prev_streaming_diff_timestamp_map or {}
local current_timestamp = os.time() local current_timestamp = os.time()
if is_streaming then if is_streaming then
local prev_streaming_diff_timestamp = session_ctx.prev_streaming_diff_timestamp_map[opts.tool_use_id] local prev_streaming_diff_timestamp = session_ctx.prev_streaming_diff_timestamp_map[input.tool_use_id]
if prev_streaming_diff_timestamp ~= nil then if prev_streaming_diff_timestamp ~= nil then
if current_timestamp - prev_streaming_diff_timestamp < 2 then if current_timestamp - prev_streaming_diff_timestamp < 2 then
return false, "Diff hasn't changed in the last 2 seconds" return false, "Diff hasn't changed in the last 2 seconds"
end end
end end
local streaming_diff_lines_count = Utils.count_lines(opts.diff) local streaming_diff_lines_count = Utils.count_lines(input.diff)
session_ctx.streaming_diff_lines_count_history = session_ctx.streaming_diff_lines_count_history or {} session_ctx.streaming_diff_lines_count_history = session_ctx.streaming_diff_lines_count_history or {}
local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[input.tool_use_id]
if streaming_diff_lines_count == prev_streaming_diff_lines_count then if streaming_diff_lines_count == prev_streaming_diff_lines_count then
return false, "Diff lines count hasn't changed" return false, "Diff lines count hasn't changed"
end end
session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] = streaming_diff_lines_count session_ctx.streaming_diff_lines_count_history[input.tool_use_id] = streaming_diff_lines_count
end end
local diff = fix_diff(opts.diff) local diff = fix_diff(input.diff)
if on_log and diff ~= opts.diff then on_log("diff fixed") end if on_log and diff ~= input.diff then on_log("diff fixed") end
local diff_lines = vim.split(diff, "\n") local diff_lines = vim.split(diff, "\n")
@@ -203,16 +208,16 @@ function M.func(opts, on_log, on_complete, session_ctx)
return false, "No diff blocks found" return false, "No diff blocks found"
end end
session_ctx.prev_streaming_diff_timestamp_map[opts.tool_use_id] = current_timestamp session_ctx.prev_streaming_diff_timestamp_map[input.tool_use_id] = current_timestamp
local bufnr, err = Helpers.get_bufnr(abs_path) local bufnr, err = Helpers.get_bufnr(abs_path)
if err then return false, err end if err then return false, err end
session_ctx.undo_joined = session_ctx.undo_joined or {} session_ctx.undo_joined = session_ctx.undo_joined or {}
local undo_joined = session_ctx.undo_joined[opts.tool_use_id] local undo_joined = session_ctx.undo_joined[input.tool_use_id]
if not undo_joined then if not undo_joined then
pcall(vim.cmd.undojoin) pcall(vim.cmd.undojoin)
session_ctx.undo_joined[opts.tool_use_id] = true session_ctx.undo_joined[input.tool_use_id] = true
end end
local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
@@ -242,10 +247,10 @@ function M.func(opts, on_log, on_complete, session_ctx)
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map = session_ctx.rough_diff_blocks_to_diff_blocks_cache_map or {} session_ctx.rough_diff_blocks_to_diff_blocks_cache_map = session_ctx.rough_diff_blocks_to_diff_blocks_cache_map or {}
local rough_diff_blocks_to_diff_blocks_cache = local rough_diff_blocks_to_diff_blocks_cache =
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id] session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[input.tool_use_id]
if not rough_diff_blocks_to_diff_blocks_cache then if not rough_diff_blocks_to_diff_blocks_cache then
rough_diff_blocks_to_diff_blocks_cache = {} rough_diff_blocks_to_diff_blocks_cache = {}
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[input.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache
end end
local function rough_diff_blocks_to_diff_blocks(rough_diff_blocks_) local function rough_diff_blocks_to_diff_blocks(rough_diff_blocks_)
@@ -472,7 +477,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
on_complete(false, "User canceled") on_complete(false, "User canceled")
return return
end end
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end
on_complete(true, nil) on_complete(true, nil)
end, end,
}) })
@@ -615,35 +620,35 @@ function M.func(opts, on_log, on_complete, session_ctx)
end end
session_ctx.extmark_id_map = session_ctx.extmark_id_map or {} session_ctx.extmark_id_map = session_ctx.extmark_id_map or {}
local extmark_id_map = session_ctx.extmark_id_map[opts.tool_use_id] local extmark_id_map = session_ctx.extmark_id_map[input.tool_use_id]
if not extmark_id_map then if not extmark_id_map then
extmark_id_map = {} extmark_id_map = {}
session_ctx.extmark_id_map[opts.tool_use_id] = extmark_id_map session_ctx.extmark_id_map[input.tool_use_id] = extmark_id_map
end end
session_ctx.virt_lines_map = session_ctx.virt_lines_map or {} session_ctx.virt_lines_map = session_ctx.virt_lines_map or {}
local virt_lines_map = session_ctx.virt_lines_map[opts.tool_use_id] local virt_lines_map = session_ctx.virt_lines_map[input.tool_use_id]
if not virt_lines_map then if not virt_lines_map then
virt_lines_map = {} virt_lines_map = {}
session_ctx.virt_lines_map[opts.tool_use_id] = virt_lines_map session_ctx.virt_lines_map[input.tool_use_id] = virt_lines_map
end end
session_ctx.last_orig_diff_end_line_map = session_ctx.last_orig_diff_end_line_map or {} session_ctx.last_orig_diff_end_line_map = session_ctx.last_orig_diff_end_line_map or {}
local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[opts.tool_use_id] local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[input.tool_use_id]
if not last_orig_diff_end_line then if not last_orig_diff_end_line then
last_orig_diff_end_line = 1 last_orig_diff_end_line = 1
session_ctx.last_orig_diff_end_line_map[opts.tool_use_id] = last_orig_diff_end_line session_ctx.last_orig_diff_end_line_map[input.tool_use_id] = last_orig_diff_end_line
end end
session_ctx.last_resp_diff_end_line_map = session_ctx.last_resp_diff_end_line_map or {} session_ctx.last_resp_diff_end_line_map = session_ctx.last_resp_diff_end_line_map or {}
local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[opts.tool_use_id] local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[input.tool_use_id]
if not last_resp_diff_end_line then if not last_resp_diff_end_line then
last_resp_diff_end_line = 1 last_resp_diff_end_line = 1
session_ctx.last_resp_diff_end_line_map[opts.tool_use_id] = last_resp_diff_end_line session_ctx.last_resp_diff_end_line_map[input.tool_use_id] = last_resp_diff_end_line
end end
session_ctx.prev_diff_blocks_map = session_ctx.prev_diff_blocks_map or {} session_ctx.prev_diff_blocks_map = session_ctx.prev_diff_blocks_map or {}
local prev_diff_blocks = session_ctx.prev_diff_blocks_map[opts.tool_use_id] local prev_diff_blocks = session_ctx.prev_diff_blocks_map[input.tool_use_id]
if not prev_diff_blocks then if not prev_diff_blocks then
prev_diff_blocks = {} prev_diff_blocks = {}
session_ctx.prev_diff_blocks_map[opts.tool_use_id] = prev_diff_blocks session_ctx.prev_diff_blocks_map[input.tool_use_id] = prev_diff_blocks
end end
local function get_unstable_diff_blocks(diff_blocks_) local function get_unstable_diff_blocks(diff_blocks_)
@@ -663,7 +668,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
local function highlight_streaming_diff_blocks() local function highlight_streaming_diff_blocks()
local unstable_diff_blocks = get_unstable_diff_blocks(diff_blocks) local unstable_diff_blocks = get_unstable_diff_blocks(diff_blocks)
session_ctx.prev_diff_blocks_map[opts.tool_use_id] = diff_blocks session_ctx.prev_diff_blocks_map[input.tool_use_id] = diff_blocks
local max_col = vim.o.columns local max_col = vim.o.columns
for _, diff_block in ipairs(unstable_diff_blocks) do for _, diff_block in ipairs(unstable_diff_blocks) do
local new_lines = diff_block.new_lines local new_lines = diff_block.new_lines
@@ -747,7 +752,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
--- check if the parent dir is exists, if not, create it --- check if the parent dir is exists, if not, create it
if vim.fn.isdirectory(parent_dir) == 0 then vim.fn.mkdir(parent_dir, "p") end if vim.fn.isdirectory(parent_dir) == 0 then vim.fn.mkdir(parent_dir, "p") end
vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end) vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end)
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end
on_complete(true, nil) on_complete(true, nil)
end, { focus = not Config.behaviour.auto_focus_on_diff_view }, session_ctx, M.name) end, { focus = not Config.behaviour.auto_focus_on_diff_view }, session_ctx, M.name)
end end

View File

@@ -1,5 +1,4 @@
local Base = require("avante.llm_tools.base") local Base = require("avante.llm_tools.base")
local Config = require("avante.config")
---@class AvanteLLMTool ---@class AvanteLLMTool
local M = setmetatable({}, Base) local M = setmetatable({}, Base)
@@ -55,17 +54,17 @@ M.returns = {
} }
---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string, streaming?: boolean, tool_use_id?: string }> ---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string, streaming?: boolean, tool_use_id?: string }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
local replace_in_file = require("avante.llm_tools.replace_in_file") local replace_in_file = require("avante.llm_tools.replace_in_file")
local diff = "------- SEARCH\n" .. opts.old_str .. "\n=======\n" .. opts.new_str local diff = "------- SEARCH\n" .. input.old_str .. "\n=======\n" .. input.new_str
if not opts.streaming then diff = diff .. "\n+++++++ REPLACE" end if not input.streaming then diff = diff .. "\n+++++++ REPLACE" end
local new_opts = { local new_input = {
path = opts.path, path = input.path,
diff = diff, diff = diff,
streaming = opts.streaming, streaming = input.streaming,
tool_use_id = opts.tool_use_id, tool_use_id = input.tool_use_id,
} }
return replace_in_file.func(new_opts, on_log, on_complete, session_ctx) return replace_in_file.func(new_input, opts)
end end
return M return M

View File

@@ -47,12 +47,13 @@ M.returns = {
---@field thought string ---@field thought string
---@type avante.LLMToolOnRender<ThinkingInput> ---@type avante.LLMToolOnRender<ThinkingInput>
function M.on_render(opts, _, state) function M.on_render(input, opts)
local state = opts.state
local lines = {} local lines = {}
local text = state == "generating" and "Thinking" or "Thoughts" local text = state == "generating" and "Thinking" or "Thoughts"
table.insert(lines, Line:new({ { Utils.icon("🤔 ") .. text, Highlights.AVANTE_THINKING } })) table.insert(lines, Line:new({ { Utils.icon("🤔 ") .. text, Highlights.AVANTE_THINKING } }))
table.insert(lines, Line:new({ { "" } })) table.insert(lines, Line:new({ { "" } }))
local content = opts.thought or "" local content = input.thought or ""
local text_lines = vim.split(content, "\n") local text_lines = vim.split(content, "\n")
for _, text_line in ipairs(text_lines) do for _, text_line in ipairs(text_lines) do
table.insert(lines, Line:new({ { "> " .. text_line } })) table.insert(lines, Line:new({ { "> " .. text_line } }))
@@ -61,7 +62,8 @@ function M.on_render(opts, _, state)
end end
---@type AvanteLLMToolFunc<ThinkingInput> ---@type AvanteLLMToolFunc<ThinkingInput>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
local on_complete = opts.on_complete
if not on_complete then return false, "on_complete not provided" end if not on_complete then return false, "on_complete not provided" end
on_complete(true, nil) on_complete(true, nil)
end end

View File

@@ -43,9 +43,14 @@ M.returns = {
} }
---@type AvanteLLMToolFunc<{ path: string }> ---@type AvanteLLMToolFunc<{ path: string }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
if on_log then on_log("path: " .. opts.path) end local on_log = opts.on_log
local abs_path = Helpers.get_abs_path(opts.path) local on_complete = opts.on_complete
local session_ctx = opts.session_ctx
if not on_complete then return false, "on_complete not provided" end
if on_log then on_log("path: " .. input.path) end
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
@@ -59,7 +64,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
end end
vim.api.nvim_win_call(winid, function() vim.cmd("noautocmd undo") end) vim.api.nvim_win_call(winid, function() vim.cmd("noautocmd undo") end)
vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end) vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end)
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end
on_complete(true, nil) on_complete(true, nil)
end, { focus = true }, session_ctx, M.name) end, { focus = true }, session_ctx, M.name)
end end

View File

@@ -43,14 +43,15 @@ M.returns = {
M.on_render = function() return {} end M.on_render = function() return {} end
---@type AvanteLLMToolFunc<{ id: string, status: string }> ---@type AvanteLLMToolFunc<{ id: string, status: string }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
local on_complete = opts.on_complete
local sidebar = require("avante").get() local sidebar = require("avante").get()
if not sidebar then return false, "Avante sidebar not found" end if not sidebar then return false, "Avante sidebar not found" end
local todos = sidebar.chat_history.todos local todos = sidebar.chat_history.todos
if not todos or #todos == 0 then return false, "No todos found" end if not todos or #todos == 0 then return false, "No todos found" end
for _, todo in ipairs(todos) do for _, todo in ipairs(todos) do
if todo.id == opts.id then if todo.id == input.id then
todo.status = opts.status todo.status = input.status
break break
end end
end end

View File

@@ -87,18 +87,20 @@ M.returns = {
} }
---@type AvanteLLMToolFunc<{ path: string, start_line?: integer, end_line?: integer }> ---@type AvanteLLMToolFunc<{ path: string, start_line?: integer, end_line?: integer }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
if not opts.path then return false, "path is required" end local on_log = opts.on_log
if on_log then on_log("path: " .. opts.path) end local on_complete = opts.on_complete
local abs_path = Helpers.get_abs_path(opts.path) if not input.path then return false, "path is required" end
if on_log then on_log("path: " .. input.path) end
local abs_path = Helpers.get_abs_path(input.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
if Path:new(abs_path):is_dir() then return false, "Path is a directory: " .. abs_path end if Path:new(abs_path):is_dir() then return false, "Path is a directory: " .. abs_path end
local file = io.open(abs_path, "r") local file = io.open(abs_path, "r")
if not file then return false, "file not found: " .. abs_path end if not file then return false, "file not found: " .. abs_path end
local lines = Utils.read_file_from_buf_or_disk(abs_path) local lines = Utils.read_file_from_buf_or_disk(abs_path)
local start_line = opts.start_line local start_line = input.start_line
local end_line = opts.end_line local end_line = input.end_line
if start_line and end_line and lines then lines = vim.list_slice(lines, start_line, end_line) end if start_line and end_line and lines then lines = vim.list_slice(lines, start_line, end_line) end
local truncated_lines = {} local truncated_lines = {}
local is_truncated = false local is_truncated = false

View File

@@ -57,27 +57,26 @@ M.returns = {
--- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view. --- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view.
---@type AvanteLLMToolFunc<{ path: string, content: string, the_content?: string, streaming?: boolean, tool_use_id?: string }> ---@type AvanteLLMToolFunc<{ path: string, content: string, the_content?: string, streaming?: boolean, tool_use_id?: string }>
function M.func(opts, on_log, on_complete, session_ctx) function M.func(input, opts)
if opts.the_content ~= nil then if input.the_content ~= nil then
opts.content = opts.the_content input.content = input.the_content
opts.the_content = nil input.the_content = nil
end end
if not on_complete then return false, "on_complete not provided" end local abs_path = Helpers.get_abs_path(input.path)
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if opts.content == nil then return false, "content not provided" end if input.content == nil then return false, "content not provided" end
if type(opts.content) ~= "string" then opts.content = vim.json.encode(opts.content) end if type(input.content) ~= "string" then input.content = vim.json.encode(input.content) end
local old_lines = Utils.read_file_from_buf_or_disk(abs_path) local old_lines = Utils.read_file_from_buf_or_disk(abs_path)
local old_content = table.concat(old_lines or {}, "\n") local old_content = table.concat(old_lines or {}, "\n")
local str_replace = require("avante.llm_tools.str_replace") local str_replace = require("avante.llm_tools.str_replace")
local new_opts = { local new_input = {
path = opts.path, path = input.path,
old_str = old_content, old_str = old_content,
new_str = opts.content, new_str = input.content,
streaming = opts.streaming, streaming = input.streaming,
tool_use_id = opts.tool_use_id, tool_use_id = input.tool_use_id,
} }
return str_replace.func(new_opts, on_log, on_complete, session_ctx) return str_replace.func(new_input, opts)
end end
return M return M

View File

@@ -2318,7 +2318,7 @@ function Sidebar:get_history_messages_for_api(opts)
--- For models like gpt-4o, the input parameter of replace_in_file is treated as the latest file content, so here we need to insert a fake view tool call to ensure it uses the latest file content --- For models like gpt-4o, the input parameter of replace_in_file is treated as the latest file content, so here we need to insert a fake view tool call to ensure it uses the latest file content
if is_edit_func_call and path and not message.message.content[1].is_error then if is_edit_func_call and path and not message.message.content[1].is_error then
local uniformed_path = Utils.uniform_path(path) local uniformed_path = Utils.uniform_path(path)
local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, nil, nil, nil) local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, {})
if view_error then view_result = "Error: " .. view_error end if view_error then view_result = "Error: " .. view_error end
local get_diagnostics_tool_use_id = Utils.uuid() local get_diagnostics_tool_use_id = Utils.uuid()
local view_tool_use_id = Utils.uuid() local view_tool_use_id = Utils.uuid()
@@ -2451,9 +2451,7 @@ function Sidebar:get_history_messages_for_api(opts)
local end_line = tool_id_to_end_line[item.tool_use_id] local end_line = tool_id_to_end_line[item.tool_use_id]
local view_result, view_error = require("avante.llm_tools.view").func( local view_result, view_error = require("avante.llm_tools.view").func(
{ path = path, start_line = start_line, end_line = end_line }, { path = path, start_line = start_line, end_line = end_line },
nil, {}
nil,
nil
) )
if view_error then view_result = "Error: " .. view_error end if view_error then view_result = "Error: " .. view_error end
item.content = view_result item.content = view_result
@@ -2773,6 +2771,23 @@ function Sidebar:create_input_container()
self:save_history() self:save_history()
end end
local function set_tool_use_store(tool_id, key, value)
local tool_use_message = nil
for idx = #self.chat_history.messages, 1, -1 do
local message = self.chat_history.messages[idx]
local content = message.message.content
if type(content) == "table" and content[1].type == "tool_use" and content[1].id == tool_id then
tool_use_message = message
break
end
end
if not tool_use_message then return end
local tool_use_store = tool_use_message.tool_use_store or {}
tool_use_store[key] = value
tool_use_message.tool_use_store = tool_use_store
self:save_history()
end
---@type AvanteLLMStopCallback ---@type AvanteLLMStopCallback
local function on_stop(stop_opts) local function on_stop(stop_opts)
self.is_generating = false self.is_generating = false
@@ -2837,6 +2852,7 @@ function Sidebar:create_input_container()
on_tool_log = on_tool_log, on_tool_log = on_tool_log,
on_messages_add = on_messages_add, on_messages_add = on_messages_add,
on_state_change = on_state_change, on_state_change = on_state_change,
set_tool_use_store = set_tool_use_store,
get_history_messages = function(opts) return self:get_history_messages_for_api(opts) end, get_history_messages = function(opts) return self:get_history_messages_for_api(opts) end,
get_todos = function() get_todos = function()
local history = Path.history.load(self.code.bufnr) local history = Path.history.load(self.code.bufnr)

View File

@@ -107,6 +107,7 @@ vim.g.avante_login = vim.g.avante_login
---@field selected_code AvanteSelectedCode | nil ---@field selected_code AvanteSelectedCode | nil
---@field selected_filepaths string[] | nil ---@field selected_filepaths string[] | nil
---@field tool_use_logs string[] | nil ---@field tool_use_logs string[] | nil
---@field tool_use_store table | nil
---@field just_for_display boolean | nil ---@field just_for_display boolean | nil
---@field is_dummy boolean | nil ---@field is_dummy boolean | nil
---@field is_compacted boolean | nil ---@field is_compacted boolean | nil
@@ -407,19 +408,30 @@ vim.g.avante_login = vim.g.avante_login
---@field on_stop AvanteLLMStopCallback ---@field on_stop AvanteLLMStopCallback
---@field on_memory_summarize? AvanteLLMMemorySummarizeCallback ---@field on_memory_summarize? AvanteLLMMemorySummarizeCallback
---@field on_tool_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil ---@field on_tool_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
---@field set_tool_use_store? fun(tool_id: string, key: string, value: any): nil
---@field get_history_messages? fun(opts?: { all?: boolean }): avante.HistoryMessage[] ---@field get_history_messages? fun(opts?: { all?: boolean }): avante.HistoryMessage[]
---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil ---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil
---@field on_state_change? fun(state: avante.GenerateState): nil ---@field on_state_change? fun(state: avante.GenerateState): nil
---@field update_tokens_usage? fun(usage: avante.LLMTokenUsage): nil ---@field update_tokens_usage? fun(usage: avante.LLMTokenUsage): nil
--- ---
---@class AvanteLLMToolFuncOpts
---@field session_ctx table
---@field on_complete? fun(result: boolean | string | nil, error: string | nil): nil
---@field on_log? fun(log: string): nil
---@field set_store? fun(key: string, value: any): nil
---
---@alias AvanteLLMToolFunc<T> fun( ---@alias AvanteLLMToolFunc<T> fun(
--- input: T, --- input: T,
--- on_log?: (fun(log: string): nil), --- opts: AvanteLLMToolFuncOpts)
--- on_complete?: (fun(result: boolean | string | nil, error: string | nil): nil),
--- session_ctx?: table)
--- : (boolean | string | nil, string | nil) --- : (boolean | string | nil, string | nil)
--- ---
--- @alias avante.LLMToolOnRender<T> fun(input: T, logs: string[], state: avante.HistoryMessageState | nil): avante.ui.Line[] ---@class avante.LLMToolOnRenderOpts
---@field logs string[]
---@field state avante.HistoryMessageState
---@field store table | nil
---@field result_message avante.HistoryMessage | nil
---
--- @alias avante.LLMToolOnRender<T> fun(input: T, opts: avante.LLMToolOnRenderOpts): avante.ui.Line[]
--- ---
---@class AvanteLLMTool ---@class AvanteLLMTool
---@field name string ---@field name string

View File

@@ -1674,14 +1674,22 @@ function M.message_content_item_to_lines(item, message, messages)
return { Line:new({ { "![image](" .. item.source.media_type .. ": " .. item.source.data .. ")" } }) } return { Line:new({ { "![image](" .. item.source.media_type .. ": " .. item.source.data .. ")" } }) }
end end
if item.type == "tool_use" then if item.type == "tool_use" then
local tool_result_message = M.get_tool_result_message(message, messages)
local lines = {} local lines = {}
local state = "generating" local state = "generating"
local hl = "AvanteStateSpinnerToolCalling" local hl = "AvanteStateSpinnerToolCalling"
local ok, llm_tool = pcall(require, "avante.llm_tools." .. item.name) local ok, llm_tool = pcall(require, "avante.llm_tools." .. item.name)
if ok then if ok then
if llm_tool.on_render then return llm_tool.on_render(item.input, message.tool_use_logs, message.state) end ---@cast llm_tool AvanteLLMTool
if llm_tool.on_render then
return llm_tool.on_render(item.input, {
logs = message.tool_use_logs,
state = message.state,
store = message.tool_use_store,
result_message = tool_result_message,
})
end
end end
local tool_result_message = M.get_tool_result_message(message, messages)
if tool_result_message then if tool_result_message then
local tool_result = tool_result_message.message.content[1] local tool_result = tool_result_message.message.content[1]
if tool_result.is_error then if tool_result.is_error then

View File

@@ -53,21 +53,21 @@ describe("llm_tools", function()
describe("ls", function() describe("ls", function()
it("should list files in directory", function() it("should list files in directory", function()
local result, err = ls({ path = ".", max_depth = 1 }) local result, err = ls({ path = ".", max_depth = 1 }, {})
assert.is_nil(err) assert.is_nil(err)
assert.falsy(result:find("avante.nvim")) assert.falsy(result:find("avante.nvim"))
assert.truthy(result:find("test.txt")) assert.truthy(result:find("test.txt"))
assert.falsy(result:find("test1.txt")) assert.falsy(result:find("test1.txt"))
end) end)
it("should list files in directory with depth", function() it("should list files in directory with depth", function()
local result, err = ls({ path = ".", max_depth = 2 }) local result, err = ls({ path = ".", max_depth = 2 }, {})
assert.is_nil(err) assert.is_nil(err)
assert.falsy(result:find("avante.nvim")) assert.falsy(result:find("avante.nvim"))
assert.truthy(result:find("test.txt")) assert.truthy(result:find("test.txt"))
assert.truthy(result:find("test1.txt")) assert.truthy(result:find("test1.txt"))
end) end)
it("should list files respecting gitignore", function() it("should list files respecting gitignore", function()
local result, err = ls({ path = ".", max_depth = 2 }) local result, err = ls({ path = ".", max_depth = 2 }, {})
assert.is_nil(err) assert.is_nil(err)
assert.falsy(result:find("avante.nvim")) assert.falsy(result:find("avante.nvim"))
assert.truthy(result:find("test.txt")) assert.truthy(result:find("test.txt"))
@@ -78,49 +78,61 @@ describe("llm_tools", function()
describe("view", function() describe("view", function()
it("should read file content", function() it("should read file content", function()
view({ path = "test.txt" }, nil, function(content, err) view({ path = "test.txt" }, {
assert.is_nil(err) on_complete = function(content, err)
assert.equals("test content", vim.json.decode(content).content) assert.is_nil(err)
end) assert.equals("test content", vim.json.decode(content).content)
end,
})
end) end)
it("should return error for non-existent file", function() it("should return error for non-existent file", function()
view({ path = "non_existent.txt" }, nil, function(content, err) view({ path = "non_existent.txt" }, {
assert.truthy(err) on_complete = function(content, err)
assert.equals("", content) assert.truthy(err)
end) assert.equals("", content)
end,
})
end) end)
it("should read directory content", function() it("should read directory content", function()
view({ path = test_dir }, nil, function(content, err) view({ path = test_dir }, {
assert.is_nil(err) on_complete = function(content, err)
assert.truthy(content:find("test.txt")) assert.is_nil(err)
assert.truthy(content:find("test content")) assert.truthy(content:find("test.txt"))
end) assert.truthy(content:find("test content"))
end,
})
end) end)
end) end)
describe("create_dir", function() describe("create_dir", function()
it("should create new directory", function() it("should create new directory", function()
LlmTools.create_dir({ path = "new_dir" }, nil, function(success, err) LlmTools.create_dir({ path = "new_dir" }, {
assert.is_nil(err) session_ctx = {},
assert.is_true(success) on_complete = function(success, err)
assert.is_nil(err)
assert.is_true(success)
local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil
assert.is_true(dir_exists) assert.is_true(dir_exists)
end) end,
})
end) end)
end) end)
describe("delete_path", function() describe("delete_path", function()
it("should delete existing file", function() it("should delete existing file", function()
LlmTools.delete_path({ path = "test.txt" }, nil, function(success, err) LlmTools.delete_path({ path = "test.txt" }, {
assert.is_nil(err) session_ctx = {},
assert.is_true(success) on_complete = function(success, err)
assert.is_nil(err)
assert.is_true(success)
local file_exists = io.open(test_file, "r") ~= nil local file_exists = io.open(test_file, "r") ~= nil
assert.is_false(file_exists) assert.is_false(file_exists)
end) end,
})
end) end)
end) end)
@@ -147,22 +159,22 @@ describe("llm_tools", function()
file:write("this is nothing") file:write("this is nothing")
file:close() file:close()
local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false }) local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false }, {})
assert.is_nil(err) assert.is_nil(err)
assert.truthy(result:find("searchable.txt")) assert.truthy(result:find("searchable.txt"))
assert.falsy(result:find("nothing.txt")) assert.falsy(result:find("nothing.txt"))
local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true }) local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true }, {})
assert.is_nil(err2) assert.is_nil(err2)
assert.truthy(result2:find("searchable.txt")) assert.truthy(result2:find("searchable.txt"))
assert.falsy(result2:find("nothing.txt")) assert.falsy(result2:find("nothing.txt"))
local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true }) local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true }, {})
assert.is_nil(err3) assert.is_nil(err3)
assert.falsy(result3:find("searchable.txt")) assert.falsy(result3:find("searchable.txt"))
assert.falsy(result3:find("nothing.txt")) assert.falsy(result3:find("nothing.txt"))
local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false }) local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false }, {})
assert.is_nil(err4) assert.is_nil(err4)
assert.truthy(result4:find("searchable.txt")) assert.truthy(result4:find("searchable.txt"))
assert.falsy(result4:find("nothing.txt")) assert.falsy(result4:find("nothing.txt"))
@@ -172,7 +184,7 @@ describe("llm_tools", function()
query = "searchable", query = "searchable",
case_sensitive = false, case_sensitive = false,
exclude_pattern = "search*", exclude_pattern = "search*",
}) }, {})
assert.is_nil(err5) assert.is_nil(err5)
assert.falsy(result5:find("searchable.txt")) assert.falsy(result5:find("searchable.txt"))
assert.falsy(result5:find("nothing.txt")) assert.falsy(result5:find("nothing.txt"))
@@ -191,7 +203,7 @@ describe("llm_tools", function()
file:write("content for ag test") file:write("content for ag test")
file:close() file:close()
local result, err = grep({ path = ".", query = "ag test" }) local result, err = grep({ path = ".", query = "ag test" }, {})
assert.is_nil(err) assert.is_nil(err)
assert.is_string(result) assert.is_string(result)
assert.truthy(result:find("ag_test.txt")) assert.truthy(result:find("ag_test.txt"))
@@ -215,22 +227,22 @@ describe("llm_tools", function()
file:write("this is nothing") file:write("this is nothing")
file:close() file:close()
local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false }) local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false }, {})
assert.is_nil(err) assert.is_nil(err)
assert.truthy(result:find("searchable.txt")) assert.truthy(result:find("searchable.txt"))
assert.falsy(result:find("nothing.txt")) assert.falsy(result:find("nothing.txt"))
local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true }) local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true }, {})
assert.is_nil(err2) assert.is_nil(err2)
assert.truthy(result2:find("searchable.txt")) assert.truthy(result2:find("searchable.txt"))
assert.falsy(result2:find("nothing.txt")) assert.falsy(result2:find("nothing.txt"))
local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true }) local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true }, {})
assert.is_nil(err3) assert.is_nil(err3)
assert.falsy(result3:find("searchable.txt")) assert.falsy(result3:find("searchable.txt"))
assert.falsy(result3:find("nothing.txt")) assert.falsy(result3:find("nothing.txt"))
local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false }) local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false }, {})
assert.is_nil(err4) assert.is_nil(err4)
assert.truthy(result4:find("searchable.txt")) assert.truthy(result4:find("searchable.txt"))
assert.falsy(result4:find("nothing.txt")) assert.falsy(result4:find("nothing.txt"))
@@ -240,7 +252,7 @@ describe("llm_tools", function()
query = "searchable", query = "searchable",
case_sensitive = false, case_sensitive = false,
exclude_pattern = "search*", exclude_pattern = "search*",
}) }, {})
assert.is_nil(err5) assert.is_nil(err5)
assert.falsy(result5:find("searchable.txt")) assert.falsy(result5:find("searchable.txt"))
assert.falsy(result5:find("nothing.txt")) assert.falsy(result5:find("nothing.txt"))
@@ -250,18 +262,18 @@ describe("llm_tools", function()
-- Mock exepath to return nothing -- Mock exepath to return nothing
vim.fn.exepath = function() return "" end vim.fn.exepath = function() return "" end
local result, err = grep({ path = ".", query = "test" }) local result, err = grep({ path = ".", query = "test" }, {})
assert.equals("", result) assert.equals("", result)
assert.equals("No search command found", err) assert.equals("No search command found", err)
end) end)
it("should respect path permissions", function() it("should respect path permissions", function()
local result, err = grep({ path = "../outside_project", query = "test" }) local result, err = grep({ path = "../outside_project", query = "test" }, {})
assert.truthy(err:find("No permission to access path")) assert.truthy(err:find("No permission to access path"))
end) end)
it("should handle non-existent paths", function() it("should handle non-existent paths", function()
local result, err = grep({ path = "non_existent_dir", query = "test" }) local result, err = grep({ path = "non_existent_dir", query = "test" }, {})
assert.equals("", result) assert.equals("", result)
assert.truthy(err) assert.truthy(err)
assert.truthy(err:find("No such file or directory")) assert.truthy(err:find("No such file or directory"))
@@ -277,86 +289,84 @@ describe("llm_tools", function()
-- end) -- end)
it("should return error when running outside current directory", function() it("should return error when running outside current directory", function()
bash({ path = "../outside_project", command = "echo 'test'" }, nil, function(result, err) bash({ path = "../outside_project", command = "echo 'test'" }, {
assert.is_false(result) session_ctx = {},
assert.truthy(err) on_complete = function(result, err)
assert.truthy(err:find("No permission to access path")) assert.is_false(result)
end) assert.truthy(err)
assert.truthy(err:find("No permission to access path"))
end,
})
end) end)
end) end)
describe("python", function() describe("python", function()
it("should execute Python code and return output", function() it("should execute Python code and return output", function()
LlmTools.python( LlmTools.python({
{ path = ".",
path = ".", code = "print('Hello from Python')",
code = "print('Hello from Python')", }, {
}, session_ctx = {},
nil, on_complete = function(result, err)
function(result, err)
assert.is_nil(err) assert.is_nil(err)
assert.equals("Hello from Python\n", result) assert.equals("Hello from Python\n", result)
end end,
) })
end) end)
it("should handle Python errors", function() it("should handle Python errors", function()
LlmTools.python( LlmTools.python({
{ path = ".",
path = ".", code = "print(undefined_variable)",
code = "print(undefined_variable)", }, {
}, session_ctx = {},
nil, on_complete = function(result, err)
function(result, err)
assert.is_nil(result) assert.is_nil(result)
assert.truthy(err) assert.truthy(err)
assert.truthy(err:find("Error")) assert.truthy(err:find("Error"))
end end,
) })
end) end)
it("should respect path permissions", function() it("should respect path permissions", function()
LlmTools.python( LlmTools.python({
{ path = "../outside_project",
path = "../outside_project", code = "print('test')",
code = "print('test')", }, {
}, session_ctx = {},
nil, on_complete = function(result, err)
function(result, err)
assert.is_nil(result) assert.is_nil(result)
assert.truthy(err:find("No permission to access path")) assert.truthy(err:find("No permission to access path"))
end end,
) })
end) end)
it("should handle non-existent paths", function() it("should handle non-existent paths", function()
LlmTools.python( LlmTools.python({
{ path = "non_existent_dir",
path = "non_existent_dir", code = "print('test')",
code = "print('test')", }, {
}, session_ctx = {},
nil, on_complete = function(result, err)
function(result, err)
assert.is_nil(result) assert.is_nil(result)
assert.truthy(err:find("Path not found")) assert.truthy(err:find("Path not found"))
end end,
) })
end) end)
it("should support custom container image", function() it("should support custom container image", function()
os.execute("docker image rm python:3.12-slim") os.execute("docker image rm python:3.12-slim")
LlmTools.python( LlmTools.python({
{ path = ".",
path = ".", code = "print('Hello from custom container')",
code = "print('Hello from custom container')", container_image = "python:3.12-slim",
container_image = "python:3.12-slim", }, {
}, session_ctx = {},
nil, on_complete = function(result, err)
function(result, err)
assert.is_nil(err) assert.is_nil(err)
assert.equals("Hello from custom container\n", result) assert.equals("Hello from custom container\n", result)
end end,
) })
end) end)
end) end)
@@ -370,7 +380,7 @@ describe("llm_tools", function()
os.execute("touch " .. test_dir .. "/nested/file4.lua") os.execute("touch " .. test_dir .. "/nested/file4.lua")
-- Test for lua files in the root -- Test for lua files in the root
local result, err = glob({ path = ".", pattern = "*.lua" }) local result, err = glob({ path = ".", pattern = "*.lua" }, {})
assert.is_nil(err) assert.is_nil(err)
local files = vim.json.decode(result).matches local files = vim.json.decode(result).matches
assert.equals(2, #files) assert.equals(2, #files)
@@ -380,7 +390,7 @@ describe("llm_tools", function()
assert.falsy(vim.tbl_contains(files, test_dir .. "/nested/file4.lua")) assert.falsy(vim.tbl_contains(files, test_dir .. "/nested/file4.lua"))
-- Test with recursive pattern -- Test with recursive pattern
local result2, err2 = glob({ path = ".", pattern = "**/*.lua" }) local result2, err2 = glob({ path = ".", pattern = "**/*.lua" }, {})
assert.is_nil(err2) assert.is_nil(err2)
local files2 = vim.json.decode(result2).matches local files2 = vim.json.decode(result2).matches
assert.equals(3, #files2) assert.equals(3, #files2)
@@ -390,13 +400,13 @@ describe("llm_tools", function()
end) end)
it("should respect path permissions", function() it("should respect path permissions", function()
local result, err = glob({ path = "../outside_project", pattern = "*.txt" }) local result, err = glob({ path = "../outside_project", pattern = "*.txt" }, {})
assert.equals("", result) assert.equals("", result)
assert.truthy(err:find("No permission to access path")) assert.truthy(err:find("No permission to access path"))
end) end)
it("should handle patterns without matches", function() it("should handle patterns without matches", function()
local result, err = glob({ path = ".", pattern = "*.nonexistent" }) local result, err = glob({ path = ".", pattern = "*.nonexistent" }, {})
assert.is_nil(err) assert.is_nil(err)
local files = vim.json.decode(result).matches local files = vim.json.decode(result).matches
assert.equals(0, #files) assert.equals(0, #files)
@@ -411,7 +421,7 @@ describe("llm_tools", function()
os.execute("touch " .. test_dir .. "/test_dir1/notignored1.lua") os.execute("touch " .. test_dir .. "/test_dir1/notignored1.lua")
os.execute("touch " .. test_dir .. "/test_dir1/notignored2.lua") os.execute("touch " .. test_dir .. "/test_dir1/notignored2.lua")
local result, err = glob({ path = ".", pattern = "**/*.lua" }) local result, err = glob({ path = ".", pattern = "**/*.lua" }, {})
assert.is_nil(err) assert.is_nil(err)
local files = vim.json.decode(result).matches local files = vim.json.decode(result).matches