refactor: llm tool parameters (#2449)
This commit is contained in:
@@ -47,7 +47,9 @@ local Highlights = {
|
||||
AVANTE_STATE_SPINNER_SEARCHING = { name = "AvanteStateSpinnerSearching", fg = "#1e222a", bg = "#c678dd" },
|
||||
AVANTE_STATE_SPINNER_THINKING = { name = "AvanteStateSpinnerThinking", fg = "#1e222a", bg = "#c678dd" },
|
||||
AVANTE_STATE_SPINNER_COMPACTING = { name = "AvanteStateSpinnerCompacting", fg = "#1e222a", bg = "#c678dd" },
|
||||
AVANTE_TASK_RUNNING = { name = "AvanteTaskRunning", fg = "#c678dd", bg_link = "Normal" },
|
||||
AVANTE_TASK_COMPLETED = { name = "AvanteTaskCompleted", fg = "#98c379", bg_link = "Normal" },
|
||||
AVANTE_TASK_FAILED = { name = "AvanteTaskFailed", fg = "#e06c75", bg_link = "Normal" },
|
||||
AVANTE_THINKING = { name = "AvanteThinking", fg = "#c678dd", bg_link = "Normal" },
|
||||
}
|
||||
|
||||
|
||||
@@ -150,8 +150,11 @@ function M.generate_todos(user_input, cb)
|
||||
local uncalled_tool_uses = Utils.get_uncalled_tool_uses(history_messages)
|
||||
for _, partial_tool_use in ipairs(uncalled_tool_uses) do
|
||||
if partial_tool_use.state == "generated" and partial_tool_use.name == "add_todos" then
|
||||
LLMTools.process_tool_use(tools, partial_tool_use, function() end, function() cb() end, {})
|
||||
cb()
|
||||
local result = LLMTools.process_tool_use(tools, partial_tool_use, {
|
||||
session_ctx = {},
|
||||
on_complete = function() cb() end,
|
||||
})
|
||||
if result ~= nil then cb() end
|
||||
end
|
||||
end
|
||||
else
|
||||
@@ -206,6 +209,7 @@ function M.agent_loop(opts)
|
||||
table.insert(history_messages, msg)
|
||||
end
|
||||
end
|
||||
if opts.on_messages_add then opts.on_messages_add(msgs) end
|
||||
end,
|
||||
session_ctx = session_ctx,
|
||||
prompt_opts = {
|
||||
@@ -331,8 +335,9 @@ function M.generate_prompts(opts)
|
||||
end
|
||||
|
||||
if Config.system_prompt ~= nil then
|
||||
local custom_system_prompt = Config.system_prompt
|
||||
if type(custom_system_prompt) == "function" then custom_system_prompt = custom_system_prompt() end
|
||||
local custom_system_prompt
|
||||
if type(Config.system_prompt) == "function" then custom_system_prompt = Config.system_prompt() end
|
||||
if type(Config.system_prompt) == "string" then custom_system_prompt = Config.system_prompt end
|
||||
if custom_system_prompt ~= nil and custom_system_prompt ~= "" and custom_system_prompt ~= "null" then
|
||||
system_prompt = system_prompt .. "\n\n" .. custom_system_prompt
|
||||
end
|
||||
@@ -841,13 +846,9 @@ function M._stream(opts)
|
||||
if partial_tool_use.state == "generating" then
|
||||
if type(partial_tool_use.input) == "table" then
|
||||
partial_tool_use.input.streaming = true
|
||||
LLMTools.process_tool_use(
|
||||
prompt_opts.tools,
|
||||
partial_tool_use,
|
||||
function() end,
|
||||
function() end,
|
||||
opts.session_ctx
|
||||
)
|
||||
LLMTools.process_tool_use(prompt_opts.tools, partial_tool_use, {
|
||||
session_ctx = opts.session_ctx,
|
||||
})
|
||||
end
|
||||
return
|
||||
else
|
||||
@@ -856,13 +857,12 @@ function M._stream(opts)
|
||||
partial_tool_use_message.is_calling = true
|
||||
if opts.on_messages_add then opts.on_messages_add({ partial_tool_use_message }) end
|
||||
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
|
||||
local result, error = LLMTools.process_tool_use(
|
||||
prompt_opts.tools,
|
||||
partial_tool_use,
|
||||
opts.on_tool_log,
|
||||
handle_tool_result,
|
||||
opts.session_ctx
|
||||
)
|
||||
local result, error = LLMTools.process_tool_use(prompt_opts.tools, partial_tool_use, {
|
||||
session_ctx = opts.session_ctx,
|
||||
on_log = opts.on_tool_log,
|
||||
set_tool_use_store = opts.set_tool_use_store,
|
||||
on_complete = handle_tool_result,
|
||||
})
|
||||
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
|
||||
end
|
||||
if stop_opts.reason == "cancelled" then
|
||||
@@ -1100,6 +1100,13 @@ function M.stream(opts)
|
||||
return original_on_tool_log(...)
|
||||
end)
|
||||
end
|
||||
if opts.set_tool_use_store ~= nil then
|
||||
local original_set_tool_use_store = opts.set_tool_use_store
|
||||
opts.set_tool_use_store = vim.schedule_wrap(function(...)
|
||||
if not original_set_tool_use_store then return end
|
||||
return original_set_tool_use_store(...)
|
||||
end)
|
||||
end
|
||||
if opts.on_chunk ~= nil then
|
||||
local original_on_chunk = opts.on_chunk
|
||||
opts.on_chunk = vim.schedule_wrap(function(chunk)
|
||||
|
||||
@@ -65,10 +65,11 @@ M.returns = {
|
||||
M.on_render = function() return {} end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ todos: avante.TODO[] }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
function M.func(input, opts)
|
||||
local on_complete = opts.on_complete
|
||||
local sidebar = require("avante").get()
|
||||
if not sidebar then return false, "Avante sidebar not found" end
|
||||
local todos = opts.todos
|
||||
local todos = input.todos
|
||||
if not todos or #todos == 0 then return false, "No todos provided" end
|
||||
sidebar:update_todos(todos)
|
||||
if on_complete then
|
||||
|
||||
@@ -57,11 +57,11 @@ M.returns = {
|
||||
}
|
||||
|
||||
---@type avante.LLMToolOnRender<AttemptCompletionInput>
|
||||
function M.on_render(opts)
|
||||
function M.on_render(input)
|
||||
local lines = {}
|
||||
table.insert(lines, Line:new({ { "✓ Task Completed", Highlights.AVANTE_TASK_COMPLETED } }))
|
||||
table.insert(lines, Line:new({ { "" } }))
|
||||
local result = opts.result or ""
|
||||
local result = input.result or ""
|
||||
local text_lines = vim.split(result, "\n")
|
||||
for _, text_line in ipairs(text_lines) do
|
||||
table.insert(lines, Line:new({ { text_line } }))
|
||||
@@ -70,24 +70,29 @@ function M.on_render(opts)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<AttemptCompletionInput>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
function M.func(input, opts)
|
||||
if not opts.on_complete then return false, "on_complete not provided" end
|
||||
local sidebar = require("avante").get()
|
||||
if not sidebar then return false, "Avante sidebar not found" end
|
||||
|
||||
local is_streaming = opts.streaming or false
|
||||
local is_streaming = input.streaming or false
|
||||
if is_streaming then
|
||||
-- wait for stream completion as command may not be complete yet
|
||||
return
|
||||
end
|
||||
|
||||
session_ctx.attempt_completion_is_called = true
|
||||
opts.session_ctx.attempt_completion_is_called = true
|
||||
|
||||
if opts.command and opts.command ~= vim.NIL and opts.command ~= "" and not vim.startswith(opts.command, "open ") then
|
||||
session_ctx.always_yes = false
|
||||
require("avante.llm_tools.bash").func({ command = opts.command }, on_log, on_complete, session_ctx)
|
||||
if
|
||||
input.command
|
||||
and input.command ~= vim.NIL
|
||||
and input.command ~= ""
|
||||
and not vim.startswith(input.command, "open ")
|
||||
then
|
||||
opts.session_ctx.always_yes = false
|
||||
require("avante.llm_tools.bash").func({ command = input.command }, opts)
|
||||
else
|
||||
on_complete(true, nil)
|
||||
opts.on_complete(true, nil)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
@@ -215,18 +215,18 @@ M.returns = {
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string, command: string, streaming?: boolean }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
local is_streaming = opts.streaming or false
|
||||
function M.func(input, opts)
|
||||
local is_streaming = input.streaming or false
|
||||
if is_streaming then
|
||||
-- wait for stream completion as command may not be complete yet
|
||||
return
|
||||
end
|
||||
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
|
||||
if not opts.command then return false, "Command is required" end
|
||||
if on_log then on_log("command: " .. opts.command) end
|
||||
if not input.command then return false, "Command is required" end
|
||||
if opts.on_log then opts.on_log("command: " .. input.command) end
|
||||
|
||||
---change cwd to abs_path
|
||||
---@param output string
|
||||
@@ -240,21 +240,21 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
return output, nil
|
||||
end
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
if not opts.on_complete then return false, "on_complete not provided" end
|
||||
Helpers.confirm(
|
||||
"Are you sure you want to run the command: `" .. opts.command .. "` in the directory: " .. abs_path,
|
||||
"Are you sure you want to run the command: `" .. input.command .. "` in the directory: " .. abs_path,
|
||||
function(ok, reason)
|
||||
if not ok then
|
||||
on_complete(false, "User declined, reason: " .. (reason and reason or "unknown"))
|
||||
opts.on_complete(false, "User declined, reason: " .. (reason and reason or "unknown"))
|
||||
return
|
||||
end
|
||||
Utils.shell_run_async(opts.command, "bash -c", function(output, exit_code)
|
||||
Utils.shell_run_async(input.command, "bash -c", function(output, exit_code)
|
||||
local result, err = handle_result(output, exit_code)
|
||||
on_complete(result, err)
|
||||
opts.on_complete(result, err)
|
||||
end, abs_path)
|
||||
end,
|
||||
{ focus = true },
|
||||
session_ctx,
|
||||
opts.session_ctx,
|
||||
M.name -- Pass the tool name for permission checking
|
||||
)
|
||||
end
|
||||
|
||||
@@ -45,20 +45,23 @@ M.returns = {
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string, file_text: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
function M.func(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local session_ctx = opts.session_ctx
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
if on_log then on_log("path: " .. opts.path) end
|
||||
if Helpers.already_in_context(opts.path) then
|
||||
if on_log then on_log("path: " .. input.path) end
|
||||
if Helpers.already_in_context(input.path) then
|
||||
on_complete(nil, "Ooooops! This file is already in the context! Why you are trying to create it again?")
|
||||
return
|
||||
end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if opts.file_text == nil then return false, "file_text not provided" end
|
||||
if input.file_text == nil then return false, "file_text not provided" end
|
||||
if Path:new(abs_path):exists() then return false, "File already exists: " .. abs_path end
|
||||
local lines = vim.split(opts.file_text, "\n")
|
||||
if #lines == 1 and opts.file_text:match("\\n") then
|
||||
local text = Utils.trim_slashes(opts.file_text)
|
||||
local lines = vim.split(input.file_text, "\n")
|
||||
if #lines == 1 and input.file_text:match("\\n") then
|
||||
local text = Utils.trim_slashes(input.file_text)
|
||||
lines = vim.split(text, "\n")
|
||||
end
|
||||
local bufnr, err = Helpers.get_bufnr(abs_path)
|
||||
|
||||
@@ -40,7 +40,7 @@ M.returns = {
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ tool_use_id: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
function M.func(input, opts)
|
||||
local sidebar = require("avante").get()
|
||||
if not sidebar then return false, "Avante sidebar not found" end
|
||||
local history_messages = Utils.get_history_messages(sidebar.chat_history)
|
||||
@@ -50,7 +50,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
local content = msg.message.content
|
||||
if type(content) == "table" then
|
||||
for _, item in ipairs(content) do
|
||||
if item.id == opts.tool_use_id then table.insert(the_deleted_message_uuids, msg.uuid) end
|
||||
if item.id == input.tool_use_id then table.insert(the_deleted_message_uuids, msg.uuid) end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@@ -3,6 +3,8 @@ local Config = require("avante.config")
|
||||
local Utils = require("avante.utils")
|
||||
local Base = require("avante.llm_tools.base")
|
||||
local HistoryMessage = require("avante.history_message")
|
||||
local Line = require("avante.ui.line")
|
||||
local Highlights = require("avante.highlights")
|
||||
|
||||
---@class AvanteLLMTool
|
||||
local M = setmetatable({}, Base)
|
||||
@@ -79,12 +81,118 @@ local function get_available_tools()
|
||||
}
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ prompt: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
---@class avante.DispatchAgentInput
|
||||
---@field prompt string
|
||||
|
||||
---@type avante.LLMToolOnRender<avante.DispatchAgentInput>
|
||||
function M.on_render(input, opts)
|
||||
local result_message = opts.result_message
|
||||
local store = opts.store or {}
|
||||
local messages = store.messages or {}
|
||||
local tool_use_summary = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
local content = msg.message.content
|
||||
local summary
|
||||
if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then
|
||||
local tool_result_message = Utils.get_tool_result_message(msg, messages)
|
||||
if tool_result_message then
|
||||
local tool_name = msg.message.content[1].name
|
||||
if tool_name == "ls" then
|
||||
local path = msg.message.content[1].input.path
|
||||
if tool_result_message.message.content[1].is_error then
|
||||
summary = string.format("Ls %s: failed", path)
|
||||
else
|
||||
local ok, filepaths = pcall(vim.json.decode, tool_result_message.message.content[1].content)
|
||||
if ok then summary = string.format("Ls %s: %d paths", path, #filepaths) end
|
||||
end
|
||||
elseif tool_name == "grep" then
|
||||
local path = msg.message.content[1].input.path
|
||||
local query = msg.message.content[1].input.query
|
||||
if tool_result_message.message.content[1].is_error then
|
||||
summary = string.format("Grep %s in %s: failed", query, path)
|
||||
else
|
||||
local ok, filepaths = pcall(vim.json.decode, tool_result_message.message.content[1].content)
|
||||
if ok then summary = string.format("Grep %s in %s: %d paths", query, path, #filepaths) end
|
||||
end
|
||||
elseif tool_name == "glob" then
|
||||
local path = msg.message.content[1].input.path
|
||||
local pattern = msg.message.content[1].input.pattern
|
||||
if tool_result_message.message.content[1].is_error then
|
||||
summary = string.format("Glob %s in %s: failed", pattern, path)
|
||||
else
|
||||
local ok, result = pcall(vim.json.decode, tool_result_message.message.content[1].content)
|
||||
if ok then
|
||||
local matches = result.matches
|
||||
if matches then summary = string.format("Glob %s in %s: %d matches", pattern, path, #matches) end
|
||||
end
|
||||
end
|
||||
elseif tool_name == "view" then
|
||||
local path = msg.message.content[1].input.path
|
||||
if tool_result_message.message.content[1].is_error then
|
||||
summary = string.format("View %s: failed", path)
|
||||
else
|
||||
local ok, result = pcall(vim.json.decode, tool_result_message.message.content[1].content)
|
||||
if ok then
|
||||
local content_ = result.content
|
||||
local lines = vim.split(content_, "\n")
|
||||
summary = string.format("View %s: %d lines", path, #lines)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
if summary then summary = " " .. Utils.icon("🛠️ ") .. summary end
|
||||
elseif type(content) == "table" and #content > 0 and type(content[1]) == "table" and content[1].type == "text" then
|
||||
summary = content[1].content
|
||||
elseif type(content) == "table" and #content > 0 and type(content[1]) == "string" then
|
||||
summary = content[1]
|
||||
elseif type(content) == "string" then
|
||||
summary = content
|
||||
end
|
||||
if summary then table.insert(tool_use_summary, summary) end
|
||||
end
|
||||
local state = "running"
|
||||
local icon = Utils.icon("🔄 ")
|
||||
local hl = Highlights.AVANTE_TASK_RUNNING
|
||||
if result_message then
|
||||
if result_message.message.content[1].is_error then
|
||||
state = "failed"
|
||||
icon = Utils.icon("❌ ")
|
||||
hl = Highlights.AVANTE_TASK_FAILED
|
||||
else
|
||||
state = "completed"
|
||||
icon = Utils.icon("✅ ")
|
||||
hl = Highlights.AVANTE_TASK_COMPLETED
|
||||
end
|
||||
end
|
||||
local lines = {}
|
||||
table.insert(lines, Line:new({ { icon .. "Subtask " .. state, hl } }))
|
||||
table.insert(lines, Line:new({ { "" } }))
|
||||
table.insert(lines, Line:new({ { " Task:" } }))
|
||||
local prompt_lines = vim.split(input.prompt or "", "\n")
|
||||
for _, line in ipairs(prompt_lines) do
|
||||
table.insert(lines, Line:new({ { " " .. line } }))
|
||||
end
|
||||
table.insert(lines, Line:new({ { "" } }))
|
||||
table.insert(lines, Line:new({ { " Task summary:" } }))
|
||||
for _, summary in ipairs(tool_use_summary) do
|
||||
local summary_lines = vim.split(summary, "\n")
|
||||
for _, line in ipairs(summary_lines) do
|
||||
table.insert(lines, Line:new({ { " " .. line } }))
|
||||
end
|
||||
end
|
||||
return lines
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<avante.DispatchAgentInput>
|
||||
function M.func(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local session_ctx = opts.session_ctx
|
||||
|
||||
local Llm = require("avante.llm")
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
|
||||
local prompt = opts.prompt
|
||||
local prompt = input.prompt
|
||||
local tools = get_available_tools()
|
||||
local start_time = Utils.get_timestamp()
|
||||
|
||||
@@ -95,10 +203,11 @@ Your task is to help the user with their request: "${prompt}"
|
||||
Be thorough and use the tools available to you to find the most relevant information.
|
||||
When you're done, provide a clear and concise summary of what you found.]]):gsub("${prompt}", prompt)
|
||||
|
||||
local history_messages = {}
|
||||
local tool_use_messages = {}
|
||||
|
||||
local total_tokens = 0
|
||||
local final_response = ""
|
||||
local result = ""
|
||||
|
||||
---@type avante.AgentLoopOptions
|
||||
local agent_loop_options = {
|
||||
@@ -108,19 +217,37 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
|
||||
on_tool_log = session_ctx.on_tool_log,
|
||||
on_messages_add = function(msgs)
|
||||
msgs = vim.islist(msgs) and msgs or { msgs }
|
||||
for _, msg in ipairs(msgs) do
|
||||
local idx = nil
|
||||
for i, m in ipairs(history_messages) do
|
||||
if m.uuid == msg.uuid then
|
||||
idx = i
|
||||
break
|
||||
end
|
||||
end
|
||||
if idx ~= nil then
|
||||
history_messages[idx] = msg
|
||||
else
|
||||
table.insert(history_messages, msg)
|
||||
end
|
||||
end
|
||||
if opts.set_store then opts.set_store("messages", history_messages) end
|
||||
for _, msg in ipairs(msgs) do
|
||||
local content = msg.message.content
|
||||
if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then
|
||||
tool_use_messages[msg.uuid] = true
|
||||
if content[1].name == "attempt_completion" then
|
||||
local input_ = content[1].input
|
||||
if input_ and input_.result then result = input_.result end
|
||||
end
|
||||
end
|
||||
end
|
||||
if session_ctx.on_messages_add then session_ctx.on_messages_add(msgs) end
|
||||
-- if session_ctx.on_messages_add then session_ctx.on_messages_add(msgs) end
|
||||
end,
|
||||
session_ctx = session_ctx,
|
||||
on_start = session_ctx.on_start,
|
||||
on_chunk = function(chunk)
|
||||
if not chunk then return end
|
||||
final_response = final_response .. chunk
|
||||
total_tokens = total_tokens + (#vim.split(chunk, " ") * 1.3)
|
||||
end,
|
||||
on_complete = function(err)
|
||||
@@ -148,8 +275,7 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
|
||||
})
|
||||
session_ctx.on_messages_add({ message })
|
||||
end
|
||||
local response = string.format("Final response:\n%s\n\nSummary:\n%s", summary, final_response)
|
||||
on_complete(response, nil)
|
||||
on_complete(result, nil)
|
||||
end,
|
||||
}
|
||||
|
||||
|
||||
@@ -40,10 +40,12 @@ M.returns = {
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string, diff: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
if not opts.path then return false, "pathf are required" end
|
||||
if on_log then on_log("path: " .. opts.path) end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
function M.func(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
if not input.path then return false, "pathf are required" end
|
||||
if on_log then on_log("path: " .. input.path) end
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not on_complete then return false, "on_complete is required" end
|
||||
local diagnostics = Utils.lsp.get_diagnostics_from_filepath(abs_path)
|
||||
|
||||
@@ -45,12 +45,14 @@ M.returns = {
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string, pattern: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
function M.func(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
|
||||
if on_log then on_log("path: " .. abs_path) end
|
||||
if on_log then on_log("pattern: " .. opts.pattern) end
|
||||
local files = vim.fn.glob(abs_path .. "/" .. opts.pattern, true, true)
|
||||
if on_log then on_log("pattern: " .. input.pattern) end
|
||||
local files = vim.fn.glob(abs_path .. "/" .. input.pattern, true, true)
|
||||
local truncated_files = {}
|
||||
local is_truncated = false
|
||||
local size = 0
|
||||
|
||||
@@ -69,8 +69,10 @@ M.returns = {
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string, query: string, case_sensitive?: boolean, include_pattern?: string, exclude_pattern?: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
function M.func(input, opts)
|
||||
local on_log = opts.on_log
|
||||
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return "", "No such file or directory: " .. abs_path end
|
||||
|
||||
@@ -85,31 +87,31 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
local cmd = ""
|
||||
if search_cmd:find("rg") then
|
||||
cmd = string.format("%s --files-with-matches --hidden", search_cmd)
|
||||
if opts.case_sensitive then
|
||||
if input.case_sensitive then
|
||||
cmd = string.format("%s --case-sensitive", cmd)
|
||||
else
|
||||
cmd = string.format("%s --ignore-case", cmd)
|
||||
end
|
||||
if opts.include_pattern then cmd = string.format("%s --glob '%s'", cmd, opts.include_pattern) end
|
||||
if opts.exclude_pattern then cmd = string.format("%s --glob '!%s'", cmd, opts.exclude_pattern) end
|
||||
cmd = string.format("%s '%s' %s", cmd, opts.query, abs_path)
|
||||
if input.include_pattern then cmd = string.format("%s --glob '%s'", cmd, input.include_pattern) end
|
||||
if input.exclude_pattern then cmd = string.format("%s --glob '!%s'", cmd, input.exclude_pattern) end
|
||||
cmd = string.format("%s '%s' %s", cmd, input.query, abs_path)
|
||||
elseif search_cmd:find("ag") then
|
||||
cmd = string.format("%s --nocolor --nogroup --hidden", search_cmd)
|
||||
if opts.case_sensitive then cmd = string.format("%s --case-sensitive", cmd) end
|
||||
if opts.include_pattern then cmd = string.format("%s --ignore '!%s'", cmd, opts.include_pattern) end
|
||||
if opts.exclude_pattern then cmd = string.format("%s --ignore '%s'", cmd, opts.exclude_pattern) end
|
||||
cmd = string.format("%s '%s' %s", cmd, opts.query, abs_path)
|
||||
if input.case_sensitive then cmd = string.format("%s --case-sensitive", cmd) end
|
||||
if input.include_pattern then cmd = string.format("%s --ignore '!%s'", cmd, input.include_pattern) end
|
||||
if input.exclude_pattern then cmd = string.format("%s --ignore '%s'", cmd, input.exclude_pattern) end
|
||||
cmd = string.format("%s '%s' %s", cmd, input.query, abs_path)
|
||||
elseif search_cmd:find("ack") then
|
||||
cmd = string.format("%s --nocolor --nogroup --hidden", search_cmd)
|
||||
if opts.case_sensitive then cmd = string.format("%s --smart-case", cmd) end
|
||||
if opts.exclude_pattern then cmd = string.format("%s --ignore-dir '%s'", cmd, opts.exclude_pattern) end
|
||||
cmd = string.format("%s '%s' %s", cmd, opts.query, abs_path)
|
||||
if input.case_sensitive then cmd = string.format("%s --smart-case", cmd) end
|
||||
if input.exclude_pattern then cmd = string.format("%s --ignore-dir '%s'", cmd, input.exclude_pattern) end
|
||||
cmd = string.format("%s '%s' %s", cmd, input.query, abs_path)
|
||||
elseif search_cmd:find("grep") then
|
||||
cmd = string.format("cd %s && git ls-files -co --exclude-standard | xargs %s -rH", abs_path, search_cmd, abs_path)
|
||||
if not opts.case_sensitive then cmd = string.format("%s -i", cmd) end
|
||||
if opts.include_pattern then cmd = string.format("%s --include '%s'", cmd, opts.include_pattern) end
|
||||
if opts.exclude_pattern then cmd = string.format("%s --exclude '%s'", cmd, opts.exclude_pattern) end
|
||||
cmd = string.format("%s '%s'", cmd, opts.query)
|
||||
if not input.case_sensitive then cmd = string.format("%s -i", cmd) end
|
||||
if input.include_pattern then cmd = string.format("%s --include '%s'", cmd, input.include_pattern) end
|
||||
if input.exclude_pattern then cmd = string.format("%s --exclude '%s'", cmd, input.exclude_pattern) end
|
||||
cmd = string.format("%s '%s'", cmd, input.query)
|
||||
end
|
||||
|
||||
Utils.debug("cmd", cmd)
|
||||
|
||||
@@ -8,9 +8,10 @@ local Helpers = require("avante.llm_tools.helpers")
|
||||
local M = {}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string }>
|
||||
function M.read_file_toplevel_symbols(opts, on_log, on_complete, session_ctx)
|
||||
function M.read_file_toplevel_symbols(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local RepoMap = require("avante.repo_map")
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
|
||||
if on_log then on_log("path: " .. abs_path) end
|
||||
if not Path:new(abs_path):exists() then return "", "File does not exists: " .. abs_path end
|
||||
@@ -24,50 +25,47 @@ function M.read_file_toplevel_symbols(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert" | "undo_edit", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[] }>
|
||||
function M.str_replace_editor(opts, on_log, on_complete, session_ctx)
|
||||
if opts.command == "undo_edit" then
|
||||
return require("avante.llm_tools.undo_edit").func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
---@cast opts any
|
||||
return M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx)
|
||||
function M.str_replace_editor(input, opts)
|
||||
if input.command == "undo_edit" then return require("avante.llm_tools.undo_edit").func(input, opts) end
|
||||
---@cast input any
|
||||
return M.str_replace_based_edit_tool(input, opts)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[], streaming?: boolean }>
|
||||
function M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx)
|
||||
if not opts.command then return false, "command not provided" end
|
||||
if on_log then on_log("command: " .. opts.command) end
|
||||
function M.str_replace_based_edit_tool(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
if not input.command then return false, "command not provided" end
|
||||
if on_log then on_log("command: " .. input.command) end
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if opts.command == "view" then
|
||||
if input.command == "view" then
|
||||
local view = require("avante.llm_tools.view")
|
||||
local opts_ = { path = opts.path }
|
||||
if opts.view_range then
|
||||
local start_line, end_line = unpack(opts.view_range)
|
||||
opts_.start_line = start_line
|
||||
opts_.end_line = end_line
|
||||
local input_ = { path = input.path }
|
||||
if input.view_range then
|
||||
local start_line, end_line = unpack(input.view_range)
|
||||
input_.start_line = start_line
|
||||
input_.end_line = end_line
|
||||
end
|
||||
return view(opts_, on_log, on_complete, session_ctx)
|
||||
return view(input_, opts)
|
||||
end
|
||||
if opts.command == "str_replace" then
|
||||
if opts.new_str == nil and opts.file_text ~= nil then
|
||||
opts.new_str = opts.file_text
|
||||
opts.file_text = nil
|
||||
if input.command == "str_replace" then
|
||||
if input.new_str == nil and input.file_text ~= nil then
|
||||
input.new_str = input.file_text
|
||||
input.file_text = nil
|
||||
end
|
||||
return require("avante.llm_tools.str_replace").func(opts, on_log, on_complete, session_ctx)
|
||||
return require("avante.llm_tools.str_replace").func(input, opts)
|
||||
end
|
||||
if opts.command == "create" then
|
||||
return require("avante.llm_tools.create").func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
if opts.command == "insert" then
|
||||
return require("avante.llm_tools.insert").func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
return false, "Unknown command: " .. opts.command
|
||||
if input.command == "create" then return require("avante.llm_tools.create").func(input, opts) end
|
||||
if input.command == "insert" then return require("avante.llm_tools.insert").func(input, opts) end
|
||||
return false, "Unknown command: " .. input.command
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ abs_path: string }>
|
||||
function M.read_global_file(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.abs_path)
|
||||
function M.read_global_file(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local abs_path = Helpers.get_abs_path(input.abs_path)
|
||||
if Helpers.is_ignored(abs_path) then return "", "This file is ignored: " .. abs_path end
|
||||
if on_log then on_log("path: " .. abs_path) end
|
||||
local file = io.open(abs_path, "r")
|
||||
@@ -78,11 +76,13 @@ function M.read_global_file(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ abs_path: string, content: string }>
|
||||
function M.write_global_file(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.abs_path)
|
||||
function M.write_global_file(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local abs_path = Helpers.get_abs_path(input.abs_path)
|
||||
if Helpers.is_ignored(abs_path) then return false, "This file is ignored: " .. abs_path end
|
||||
if on_log then on_log("path: " .. abs_path) end
|
||||
if on_log then on_log("content: " .. opts.content) end
|
||||
if on_log then on_log("content: " .. input.content) end
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
Helpers.confirm("Are you sure you want to write to the file: " .. abs_path, function(ok)
|
||||
if not ok then
|
||||
@@ -94,18 +94,20 @@ function M.write_global_file(opts, on_log, on_complete, session_ctx)
|
||||
on_complete(false, "file not found: " .. abs_path)
|
||||
return
|
||||
end
|
||||
file:write(opts.content)
|
||||
file:write(input.content)
|
||||
file:close()
|
||||
on_complete(true, nil)
|
||||
end, nil, session_ctx, "write_global_file")
|
||||
end, nil, opts.session_ctx, "write_global_file")
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ source_path: string, destination_path: string }>
|
||||
function M.move_path(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.source_path)
|
||||
function M.move_path(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local abs_path = Helpers.get_abs_path(input.source_path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "The source path not found: " .. abs_path end
|
||||
local new_abs_path = Helpers.get_abs_path(opts.destination_path)
|
||||
local new_abs_path = Helpers.get_abs_path(input.destination_path)
|
||||
if on_log then on_log(abs_path .. " -> " .. new_abs_path) end
|
||||
if not Helpers.has_permission_to_access(new_abs_path) then
|
||||
return false, "No permission to access path: " .. new_abs_path
|
||||
@@ -123,17 +125,19 @@ function M.move_path(opts, on_log, on_complete, session_ctx)
|
||||
on_complete(true, nil)
|
||||
end,
|
||||
nil,
|
||||
session_ctx,
|
||||
opts.session_ctx,
|
||||
"move_path"
|
||||
)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ source_path: string, destination_path: string }>
|
||||
function M.copy_path(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.source_path)
|
||||
function M.copy_path(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local abs_path = Helpers.get_abs_path(input.source_path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "The source path not found: " .. abs_path end
|
||||
local new_abs_path = Helpers.get_abs_path(opts.destination_path)
|
||||
local new_abs_path = Helpers.get_abs_path(input.destination_path)
|
||||
if not Helpers.has_permission_to_access(new_abs_path) then
|
||||
return false, "No permission to access path: " .. new_abs_path
|
||||
end
|
||||
@@ -169,14 +173,16 @@ function M.copy_path(opts, on_log, on_complete, session_ctx)
|
||||
on_complete(true, nil)
|
||||
end,
|
||||
nil,
|
||||
session_ctx,
|
||||
opts.session_ctx,
|
||||
"copy_path"
|
||||
)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string }>
|
||||
function M.delete_path(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
function M.delete_path(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
@@ -188,12 +194,14 @@ function M.delete_path(opts, on_log, on_complete, session_ctx)
|
||||
if on_log then on_log("Deleting path: " .. abs_path) end
|
||||
os.remove(abs_path)
|
||||
on_complete(true, nil)
|
||||
end, nil, session_ctx, "delete_path")
|
||||
end, nil, opts.session_ctx, "delete_path")
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string }>
|
||||
function M.create_dir(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
function M.create_dir(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
@@ -205,16 +213,17 @@ function M.create_dir(opts, on_log, on_complete, session_ctx)
|
||||
if on_log then on_log("Creating directory: " .. abs_path) end
|
||||
Path:new(abs_path):mkdir({ parents = true })
|
||||
on_complete(true, nil)
|
||||
end, nil, session_ctx, "create_dir")
|
||||
end, nil, opts.session_ctx, "create_dir")
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ query: string }>
|
||||
function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
function M.web_search(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local provider_type = Config.web_search_engine.provider
|
||||
local proxy = Config.web_search_engine.proxy
|
||||
if provider_type == nil then return nil, "Search engine provider is not set" end
|
||||
if on_log then on_log("provider: " .. provider_type) end
|
||||
if on_log then on_log("query: " .. opts.query) end
|
||||
if on_log then on_log("query: " .. input.query) end
|
||||
local search_engine = Config.web_search_engine.providers[provider_type]
|
||||
if search_engine == nil then return nil, "No search engine found: " .. provider_type end
|
||||
if provider_type ~= "searxng" and search_engine.api_key_name == "" then return nil, "No API key provided" end
|
||||
@@ -229,7 +238,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
["Authorization"] = "Bearer " .. api_key,
|
||||
},
|
||||
body = vim.json.encode(vim.tbl_deep_extend("force", {
|
||||
query = opts.query,
|
||||
query = input.query,
|
||||
}, search_engine.extra_request_body)),
|
||||
}
|
||||
if proxy then curl_opts.proxy = proxy end
|
||||
@@ -240,7 +249,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
elseif provider_type == "serpapi" then
|
||||
local query_params = vim.tbl_deep_extend("force", {
|
||||
api_key = api_key,
|
||||
q = opts.query,
|
||||
q = input.query,
|
||||
}, search_engine.extra_request_body)
|
||||
local query_string = ""
|
||||
for key, value in pairs(query_params) do
|
||||
@@ -259,7 +268,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
elseif provider_type == "searchapi" then
|
||||
local query_params = vim.tbl_deep_extend("force", {
|
||||
api_key = api_key,
|
||||
q = opts.query,
|
||||
q = input.query,
|
||||
}, search_engine.extra_request_body)
|
||||
local query_string = ""
|
||||
for key, value in pairs(query_params) do
|
||||
@@ -283,7 +292,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
local query_params = vim.tbl_deep_extend("force", {
|
||||
key = api_key,
|
||||
cx = engine_id,
|
||||
q = opts.query,
|
||||
q = input.query,
|
||||
}, search_engine.extra_request_body)
|
||||
local query_string = ""
|
||||
for key, value in pairs(query_params) do
|
||||
@@ -301,7 +310,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
return search_engine.format_response_body(jsn)
|
||||
elseif provider_type == "kagi" then
|
||||
local query_params = vim.tbl_deep_extend("force", {
|
||||
q = opts.query,
|
||||
q = input.query,
|
||||
}, search_engine.extra_request_body)
|
||||
local query_string = ""
|
||||
for key, value in pairs(query_params) do
|
||||
@@ -320,7 +329,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
return search_engine.format_response_body(jsn)
|
||||
elseif provider_type == "brave" then
|
||||
local query_params = vim.tbl_deep_extend("force", {
|
||||
q = opts.query,
|
||||
q = input.query,
|
||||
}, search_engine.extra_request_body)
|
||||
local query_string = ""
|
||||
for key, value in pairs(query_params) do
|
||||
@@ -343,7 +352,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
return nil, "Environment variable " .. search_engine.api_url_name .. " is not set"
|
||||
end
|
||||
local query_params = vim.tbl_deep_extend("force", {
|
||||
q = opts.query,
|
||||
q = input.query,
|
||||
}, search_engine.extra_request_body)
|
||||
local query_string = ""
|
||||
for key, value in pairs(query_params) do
|
||||
@@ -362,16 +371,18 @@ function M.web_search(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ url: string }>
|
||||
function M.fetch(opts, on_log, on_complete, session_ctx)
|
||||
if on_log then on_log("url: " .. opts.url) end
|
||||
function M.fetch(input, opts)
|
||||
local on_log = opts.on_log
|
||||
if on_log then on_log("url: " .. input.url) end
|
||||
local Html2Md = require("avante.html2md")
|
||||
local res, err = Html2Md.fetch_md(opts.url)
|
||||
local res, err = Html2Md.fetch_md(input.url)
|
||||
if err then return nil, err end
|
||||
return res, nil
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ scope?: string }>
|
||||
function M.git_diff(opts, on_log, on_complete, session_ctx)
|
||||
function M.git_diff(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local git_cmd = vim.fn.exepath("git")
|
||||
if git_cmd == "" then return nil, "Git command not found" end
|
||||
local project_root = Utils.get_project_root()
|
||||
@@ -382,7 +393,7 @@ function M.git_diff(opts, on_log, on_complete, session_ctx)
|
||||
if git_dir == "" then return nil, "Not a git repository" end
|
||||
|
||||
-- Get the diff
|
||||
local scope = opts.scope or ""
|
||||
local scope = input.scope or ""
|
||||
local cmd = string.format("git diff --cached %s", scope)
|
||||
if on_log then on_log("Running command: " .. cmd) end
|
||||
local diff = vim.fn.system(cmd)
|
||||
@@ -400,7 +411,10 @@ function M.git_diff(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ message: string, scope?: string }>
|
||||
function M.git_commit(opts, on_log, on_complete, session_ctx)
|
||||
function M.git_commit(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
|
||||
local git_cmd = vim.fn.exepath("git")
|
||||
if git_cmd == "" then return false, "Git command not found" end
|
||||
local project_root = Utils.get_project_root()
|
||||
@@ -444,7 +458,7 @@ function M.git_commit(opts, on_log, on_complete, session_ctx)
|
||||
|
||||
-- Prepare commit message
|
||||
local commit_msg_lines = {}
|
||||
for line in opts.message:gmatch("[^\r\n]+") do
|
||||
for line in input.message:gmatch("[^\r\n]+") do
|
||||
commit_msg_lines[#commit_msg_lines + 1] = line:gsub('"', '\\"')
|
||||
end
|
||||
commit_msg_lines[#commit_msg_lines + 1] = ""
|
||||
@@ -464,8 +478,8 @@ function M.git_commit(opts, on_log, on_complete, session_ctx)
|
||||
return
|
||||
end
|
||||
-- Stage changes if scope is provided
|
||||
if opts.scope then
|
||||
local stage_cmd = string.format("git add %s", opts.scope)
|
||||
if input.scope then
|
||||
local stage_cmd = string.format("git add %s", input.scope)
|
||||
if on_log then on_log("Staging files: " .. stage_cmd) end
|
||||
local stage_result = vim.fn.system(stage_cmd)
|
||||
if vim.v.shell_error ~= 0 then
|
||||
@@ -494,20 +508,23 @@ function M.git_commit(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
on_complete(true, nil)
|
||||
end, nil, session_ctx, "git_commit")
|
||||
end, nil, opts.session_ctx, "git_commit")
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ query: string }>
|
||||
function M.rag_search(opts, on_log, on_complete, session_ctx)
|
||||
function M.rag_search(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
if not on_complete then return nil, "on_complete not provided" end
|
||||
if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end
|
||||
if not opts.query then return nil, "No query provided" end
|
||||
if on_log then on_log("query: " .. opts.query) end
|
||||
if not input.query then return nil, "No query provided" end
|
||||
if on_log then on_log("query: " .. input.query) end
|
||||
local root = Utils.get_project_root()
|
||||
local uri = "file://" .. root
|
||||
if uri:sub(-1) ~= "/" then uri = uri .. "/" end
|
||||
RagService.retrieve(
|
||||
uri,
|
||||
opts.query,
|
||||
input.query,
|
||||
vim.schedule_wrap(function(resp, err)
|
||||
if err then
|
||||
on_complete(nil, err)
|
||||
@@ -519,13 +536,15 @@ function M.rag_search(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ code: string, path: string, container_image?: string }>
|
||||
function M.python(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
function M.python(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return nil, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return nil, "Path not found: " .. abs_path end
|
||||
if on_log then on_log("cwd: " .. abs_path) end
|
||||
if on_log then on_log("code:\n" .. opts.code) end
|
||||
local container_image = opts.container_image or "python:3.11-slim-bookworm"
|
||||
if on_log then on_log("code:\n" .. input.code) end
|
||||
local container_image = input.container_image or "python:3.11-slim-bookworm"
|
||||
if not on_complete then return nil, "on_complete not provided" end
|
||||
Helpers.confirm(
|
||||
"Are you sure you want to run the following python code in the `"
|
||||
@@ -533,7 +552,7 @@ function M.python(opts, on_log, on_complete, session_ctx)
|
||||
.. "` container, in the directory: `"
|
||||
.. abs_path
|
||||
.. "`?\n"
|
||||
.. opts.code,
|
||||
.. input.code,
|
||||
function(ok, reason)
|
||||
if not ok then
|
||||
on_complete(nil, "User declined, reason: " .. (reason or "unknown"))
|
||||
@@ -562,7 +581,7 @@ function M.python(opts, on_log, on_complete, session_ctx)
|
||||
container_image,
|
||||
"python",
|
||||
"-c",
|
||||
opts.code,
|
||||
input.code,
|
||||
},
|
||||
{
|
||||
text = true,
|
||||
@@ -576,7 +595,7 @@ function M.python(opts, on_log, on_complete, session_ctx)
|
||||
)
|
||||
end,
|
||||
nil,
|
||||
session_ctx,
|
||||
opts.session_ctx,
|
||||
"python"
|
||||
)
|
||||
end
|
||||
@@ -1189,14 +1208,19 @@ You can delete the first file by providing a path of "directory1/a/something.txt
|
||||
--- compatibility alias for old calls & tests
|
||||
M.run_python = M.python
|
||||
|
||||
---@class avante.ProcessToolUseOpts
|
||||
---@field session_ctx table
|
||||
---@field on_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
|
||||
---@field set_tool_use_store? fun(tool_id: string, key: string, value: any): nil
|
||||
---@field on_complete? fun(result: string | nil, error: string | nil): nil
|
||||
|
||||
---@param tools AvanteLLMTool[]
|
||||
---@param tool_use AvanteLLMToolUse
|
||||
---@param on_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
|
||||
---@param on_complete? fun(result: string | nil, error: string | nil): nil
|
||||
---@param session_ctx? table
|
||||
---@return string | nil result
|
||||
---@return string | nil error
|
||||
function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx)
|
||||
function M.process_tool_use(tools, tool_use, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
-- Check if execution is already cancelled
|
||||
if Helpers.is_cancelled then
|
||||
Utils.debug("Tool execution cancelled before starting: " .. tool_use.name)
|
||||
@@ -1274,25 +1298,32 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx)
|
||||
return result_str, err
|
||||
end
|
||||
|
||||
local result, err = func(input_json, function(log)
|
||||
-- Check for cancellation during logging
|
||||
if Helpers.is_cancelled then return end
|
||||
if on_log then on_log(tool_use.id, tool_use.name, log, "running") end
|
||||
end, function(result, err)
|
||||
-- Check for cancellation before completing
|
||||
if Helpers.is_cancelled then
|
||||
Helpers.is_cancelled = false
|
||||
if on_complete then on_complete(nil, Helpers.CANCEL_TOKEN) end
|
||||
return
|
||||
end
|
||||
local result, err = func(input_json, {
|
||||
session_ctx = opts.session_ctx or {},
|
||||
on_log = function(log)
|
||||
-- Check for cancellation during logging
|
||||
if Helpers.is_cancelled then return end
|
||||
if on_log then on_log(tool_use.id, tool_use.name, log, "running") end
|
||||
end,
|
||||
set_store = function(key, value)
|
||||
if opts.set_tool_use_store then opts.set_tool_use_store(tool_use.id, key, value) end
|
||||
end,
|
||||
on_complete = function(result, err)
|
||||
-- Check for cancellation before completing
|
||||
if Helpers.is_cancelled then
|
||||
Helpers.is_cancelled = false
|
||||
if on_complete then on_complete(nil, Helpers.CANCEL_TOKEN) end
|
||||
return
|
||||
end
|
||||
|
||||
result, err = handle_result(result, err)
|
||||
if on_complete == nil then
|
||||
Utils.error("asynchronous tool " .. tool_use.name .. " result not handled")
|
||||
return
|
||||
end
|
||||
on_complete(result, err)
|
||||
end, session_ctx)
|
||||
result, err = handle_result(result, err)
|
||||
if on_complete == nil then
|
||||
Utils.error("asynchronous tool " .. tool_use.name .. " result not handled")
|
||||
return
|
||||
end
|
||||
on_complete(result, err)
|
||||
end,
|
||||
})
|
||||
|
||||
-- Result and error being nil means that the tool was executed asynchronously
|
||||
if result == nil and err == nil and on_complete then return end
|
||||
|
||||
@@ -55,19 +55,24 @@ M.returns = {
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string, insert_line: integer, new_str: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
if on_log then on_log("path: " .. opts.path) end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
function M.func(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local session_ctx = opts.session_ctx
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
|
||||
if on_log then on_log("path: " .. input.path) end
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
|
||||
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
|
||||
if opts.insert_line == nil then return false, "insert_line not provided" end
|
||||
if opts.new_str == nil then return false, "new_str not provided" end
|
||||
if input.insert_line == nil then return false, "insert_line not provided" end
|
||||
if input.new_str == nil then return false, "new_str not provided" end
|
||||
local ns_id = vim.api.nvim_create_namespace("avante_insert_diff")
|
||||
local bufnr, err = Helpers.get_bufnr(abs_path)
|
||||
if err then return false, err end
|
||||
local function clear_highlights() vim.api.nvim_buf_clear_namespace(bufnr, ns_id, 0, -1) end
|
||||
local new_lines = vim.split(opts.new_str, "\n")
|
||||
local new_lines = vim.split(input.new_str, "\n")
|
||||
local max_col = vim.o.columns
|
||||
local virt_lines = vim
|
||||
.iter(new_lines)
|
||||
@@ -78,8 +83,8 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
end)
|
||||
:totable()
|
||||
local line_count = vim.api.nvim_buf_line_count(bufnr)
|
||||
if opts.insert_line > line_count - 1 then opts.insert_line = line_count - 1 end
|
||||
vim.api.nvim_buf_set_extmark(bufnr, ns_id, opts.insert_line, 0, {
|
||||
if input.insert_line > line_count - 1 then input.insert_line = line_count - 1 end
|
||||
vim.api.nvim_buf_set_extmark(bufnr, ns_id, input.insert_line, 0, {
|
||||
virt_lines = virt_lines,
|
||||
hl_eol = true,
|
||||
hl_mode = "combine",
|
||||
@@ -90,9 +95,9 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
on_complete(false, "User declined, reason: " .. (reason or "unknown"))
|
||||
return
|
||||
end
|
||||
vim.api.nvim_buf_set_lines(bufnr, opts.insert_line, opts.insert_line, false, new_lines)
|
||||
vim.api.nvim_buf_set_lines(bufnr, input.insert_line, input.insert_line, false, new_lines)
|
||||
vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end)
|
||||
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end
|
||||
if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end
|
||||
on_complete(true, nil)
|
||||
end, { focus = true }, session_ctx, M.name)
|
||||
end
|
||||
|
||||
@@ -46,15 +46,16 @@ M.returns = {
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string, max_depth?: integer }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
function M.func(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
|
||||
if on_log then on_log("path: " .. abs_path) end
|
||||
if on_log then on_log("max depth: " .. tostring(opts.max_depth)) end
|
||||
if on_log then on_log("max depth: " .. tostring(input.max_depth)) end
|
||||
local files = Utils.scan_directory({
|
||||
directory = abs_path,
|
||||
add_dirs = true,
|
||||
max_depth = opts.max_depth,
|
||||
max_depth = input.max_depth,
|
||||
})
|
||||
local filepaths = {}
|
||||
for _, file in ipairs(files) do
|
||||
|
||||
@@ -125,39 +125,44 @@ end
|
||||
|
||||
--- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view.
|
||||
---@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string, streaming?: boolean, tool_use_id?: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
if opts.the_diff ~= nil then
|
||||
opts.diff = opts.the_diff
|
||||
opts.the_diff = nil
|
||||
function M.func(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local session_ctx = opts.session_ctx
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
|
||||
if input.the_diff ~= nil then
|
||||
input.diff = input.the_diff
|
||||
input.the_diff = nil
|
||||
end
|
||||
if not opts.path or not opts.diff then return false, "path and diff are required " .. vim.inspect(opts) end
|
||||
if on_log then on_log("path: " .. opts.path) end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not input.path or not input.diff then return false, "path and diff are required " .. vim.inspect(input) end
|
||||
if on_log then on_log("path: " .. input.path) end
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
|
||||
local is_streaming = opts.streaming or false
|
||||
local is_streaming = input.streaming or false
|
||||
|
||||
session_ctx.prev_streaming_diff_timestamp_map = session_ctx.prev_streaming_diff_timestamp_map or {}
|
||||
local current_timestamp = os.time()
|
||||
if is_streaming then
|
||||
local prev_streaming_diff_timestamp = session_ctx.prev_streaming_diff_timestamp_map[opts.tool_use_id]
|
||||
local prev_streaming_diff_timestamp = session_ctx.prev_streaming_diff_timestamp_map[input.tool_use_id]
|
||||
if prev_streaming_diff_timestamp ~= nil then
|
||||
if current_timestamp - prev_streaming_diff_timestamp < 2 then
|
||||
return false, "Diff hasn't changed in the last 2 seconds"
|
||||
end
|
||||
end
|
||||
local streaming_diff_lines_count = Utils.count_lines(opts.diff)
|
||||
local streaming_diff_lines_count = Utils.count_lines(input.diff)
|
||||
session_ctx.streaming_diff_lines_count_history = session_ctx.streaming_diff_lines_count_history or {}
|
||||
local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[opts.tool_use_id]
|
||||
local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[input.tool_use_id]
|
||||
if streaming_diff_lines_count == prev_streaming_diff_lines_count then
|
||||
return false, "Diff lines count hasn't changed"
|
||||
end
|
||||
session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] = streaming_diff_lines_count
|
||||
session_ctx.streaming_diff_lines_count_history[input.tool_use_id] = streaming_diff_lines_count
|
||||
end
|
||||
|
||||
local diff = fix_diff(opts.diff)
|
||||
local diff = fix_diff(input.diff)
|
||||
|
||||
if on_log and diff ~= opts.diff then on_log("diff fixed") end
|
||||
if on_log and diff ~= input.diff then on_log("diff fixed") end
|
||||
|
||||
local diff_lines = vim.split(diff, "\n")
|
||||
|
||||
@@ -203,16 +208,16 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
return false, "No diff blocks found"
|
||||
end
|
||||
|
||||
session_ctx.prev_streaming_diff_timestamp_map[opts.tool_use_id] = current_timestamp
|
||||
session_ctx.prev_streaming_diff_timestamp_map[input.tool_use_id] = current_timestamp
|
||||
|
||||
local bufnr, err = Helpers.get_bufnr(abs_path)
|
||||
if err then return false, err end
|
||||
|
||||
session_ctx.undo_joined = session_ctx.undo_joined or {}
|
||||
local undo_joined = session_ctx.undo_joined[opts.tool_use_id]
|
||||
local undo_joined = session_ctx.undo_joined[input.tool_use_id]
|
||||
if not undo_joined then
|
||||
pcall(vim.cmd.undojoin)
|
||||
session_ctx.undo_joined[opts.tool_use_id] = true
|
||||
session_ctx.undo_joined[input.tool_use_id] = true
|
||||
end
|
||||
|
||||
local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
@@ -242,10 +247,10 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
|
||||
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map = session_ctx.rough_diff_blocks_to_diff_blocks_cache_map or {}
|
||||
local rough_diff_blocks_to_diff_blocks_cache =
|
||||
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id]
|
||||
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[input.tool_use_id]
|
||||
if not rough_diff_blocks_to_diff_blocks_cache then
|
||||
rough_diff_blocks_to_diff_blocks_cache = {}
|
||||
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache
|
||||
session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[input.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache
|
||||
end
|
||||
|
||||
local function rough_diff_blocks_to_diff_blocks(rough_diff_blocks_)
|
||||
@@ -472,7 +477,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
on_complete(false, "User canceled")
|
||||
return
|
||||
end
|
||||
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end
|
||||
if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end
|
||||
on_complete(true, nil)
|
||||
end,
|
||||
})
|
||||
@@ -615,35 +620,35 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
session_ctx.extmark_id_map = session_ctx.extmark_id_map or {}
|
||||
local extmark_id_map = session_ctx.extmark_id_map[opts.tool_use_id]
|
||||
local extmark_id_map = session_ctx.extmark_id_map[input.tool_use_id]
|
||||
if not extmark_id_map then
|
||||
extmark_id_map = {}
|
||||
session_ctx.extmark_id_map[opts.tool_use_id] = extmark_id_map
|
||||
session_ctx.extmark_id_map[input.tool_use_id] = extmark_id_map
|
||||
end
|
||||
session_ctx.virt_lines_map = session_ctx.virt_lines_map or {}
|
||||
local virt_lines_map = session_ctx.virt_lines_map[opts.tool_use_id]
|
||||
local virt_lines_map = session_ctx.virt_lines_map[input.tool_use_id]
|
||||
if not virt_lines_map then
|
||||
virt_lines_map = {}
|
||||
session_ctx.virt_lines_map[opts.tool_use_id] = virt_lines_map
|
||||
session_ctx.virt_lines_map[input.tool_use_id] = virt_lines_map
|
||||
end
|
||||
|
||||
session_ctx.last_orig_diff_end_line_map = session_ctx.last_orig_diff_end_line_map or {}
|
||||
local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[opts.tool_use_id]
|
||||
local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[input.tool_use_id]
|
||||
if not last_orig_diff_end_line then
|
||||
last_orig_diff_end_line = 1
|
||||
session_ctx.last_orig_diff_end_line_map[opts.tool_use_id] = last_orig_diff_end_line
|
||||
session_ctx.last_orig_diff_end_line_map[input.tool_use_id] = last_orig_diff_end_line
|
||||
end
|
||||
session_ctx.last_resp_diff_end_line_map = session_ctx.last_resp_diff_end_line_map or {}
|
||||
local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[opts.tool_use_id]
|
||||
local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[input.tool_use_id]
|
||||
if not last_resp_diff_end_line then
|
||||
last_resp_diff_end_line = 1
|
||||
session_ctx.last_resp_diff_end_line_map[opts.tool_use_id] = last_resp_diff_end_line
|
||||
session_ctx.last_resp_diff_end_line_map[input.tool_use_id] = last_resp_diff_end_line
|
||||
end
|
||||
session_ctx.prev_diff_blocks_map = session_ctx.prev_diff_blocks_map or {}
|
||||
local prev_diff_blocks = session_ctx.prev_diff_blocks_map[opts.tool_use_id]
|
||||
local prev_diff_blocks = session_ctx.prev_diff_blocks_map[input.tool_use_id]
|
||||
if not prev_diff_blocks then
|
||||
prev_diff_blocks = {}
|
||||
session_ctx.prev_diff_blocks_map[opts.tool_use_id] = prev_diff_blocks
|
||||
session_ctx.prev_diff_blocks_map[input.tool_use_id] = prev_diff_blocks
|
||||
end
|
||||
|
||||
local function get_unstable_diff_blocks(diff_blocks_)
|
||||
@@ -663,7 +668,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
|
||||
local function highlight_streaming_diff_blocks()
|
||||
local unstable_diff_blocks = get_unstable_diff_blocks(diff_blocks)
|
||||
session_ctx.prev_diff_blocks_map[opts.tool_use_id] = diff_blocks
|
||||
session_ctx.prev_diff_blocks_map[input.tool_use_id] = diff_blocks
|
||||
local max_col = vim.o.columns
|
||||
for _, diff_block in ipairs(unstable_diff_blocks) do
|
||||
local new_lines = diff_block.new_lines
|
||||
@@ -747,7 +752,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
--- check if the parent dir is exists, if not, create it
|
||||
if vim.fn.isdirectory(parent_dir) == 0 then vim.fn.mkdir(parent_dir, "p") end
|
||||
vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end)
|
||||
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end
|
||||
if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end
|
||||
on_complete(true, nil)
|
||||
end, { focus = not Config.behaviour.auto_focus_on_diff_view }, session_ctx, M.name)
|
||||
end
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
local Base = require("avante.llm_tools.base")
|
||||
local Config = require("avante.config")
|
||||
|
||||
---@class AvanteLLMTool
|
||||
local M = setmetatable({}, Base)
|
||||
@@ -55,17 +54,17 @@ M.returns = {
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string, streaming?: boolean, tool_use_id?: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
function M.func(input, opts)
|
||||
local replace_in_file = require("avante.llm_tools.replace_in_file")
|
||||
local diff = "------- SEARCH\n" .. opts.old_str .. "\n=======\n" .. opts.new_str
|
||||
if not opts.streaming then diff = diff .. "\n+++++++ REPLACE" end
|
||||
local new_opts = {
|
||||
path = opts.path,
|
||||
local diff = "------- SEARCH\n" .. input.old_str .. "\n=======\n" .. input.new_str
|
||||
if not input.streaming then diff = diff .. "\n+++++++ REPLACE" end
|
||||
local new_input = {
|
||||
path = input.path,
|
||||
diff = diff,
|
||||
streaming = opts.streaming,
|
||||
tool_use_id = opts.tool_use_id,
|
||||
streaming = input.streaming,
|
||||
tool_use_id = input.tool_use_id,
|
||||
}
|
||||
return replace_in_file.func(new_opts, on_log, on_complete, session_ctx)
|
||||
return replace_in_file.func(new_input, opts)
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
@@ -47,12 +47,13 @@ M.returns = {
|
||||
---@field thought string
|
||||
|
||||
---@type avante.LLMToolOnRender<ThinkingInput>
|
||||
function M.on_render(opts, _, state)
|
||||
function M.on_render(input, opts)
|
||||
local state = opts.state
|
||||
local lines = {}
|
||||
local text = state == "generating" and "Thinking" or "Thoughts"
|
||||
table.insert(lines, Line:new({ { Utils.icon("🤔 ") .. text, Highlights.AVANTE_THINKING } }))
|
||||
table.insert(lines, Line:new({ { "" } }))
|
||||
local content = opts.thought or ""
|
||||
local content = input.thought or ""
|
||||
local text_lines = vim.split(content, "\n")
|
||||
for _, text_line in ipairs(text_lines) do
|
||||
table.insert(lines, Line:new({ { "> " .. text_line } }))
|
||||
@@ -61,7 +62,8 @@ function M.on_render(opts, _, state)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<ThinkingInput>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
function M.func(input, opts)
|
||||
local on_complete = opts.on_complete
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
on_complete(true, nil)
|
||||
end
|
||||
|
||||
@@ -43,9 +43,14 @@ M.returns = {
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
if on_log then on_log("path: " .. opts.path) end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
function M.func(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local session_ctx = opts.session_ctx
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
|
||||
if on_log then on_log("path: " .. input.path) end
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
|
||||
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
|
||||
@@ -59,7 +64,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
vim.api.nvim_win_call(winid, function() vim.cmd("noautocmd undo") end)
|
||||
vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end)
|
||||
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end
|
||||
if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end
|
||||
on_complete(true, nil)
|
||||
end, { focus = true }, session_ctx, M.name)
|
||||
end
|
||||
|
||||
@@ -43,14 +43,15 @@ M.returns = {
|
||||
M.on_render = function() return {} end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ id: string, status: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
function M.func(input, opts)
|
||||
local on_complete = opts.on_complete
|
||||
local sidebar = require("avante").get()
|
||||
if not sidebar then return false, "Avante sidebar not found" end
|
||||
local todos = sidebar.chat_history.todos
|
||||
if not todos or #todos == 0 then return false, "No todos found" end
|
||||
for _, todo in ipairs(todos) do
|
||||
if todo.id == opts.id then
|
||||
todo.status = opts.status
|
||||
if todo.id == input.id then
|
||||
todo.status = input.status
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
@@ -87,18 +87,20 @@ M.returns = {
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string, start_line?: integer, end_line?: integer }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
if not opts.path then return false, "path is required" end
|
||||
if on_log then on_log("path: " .. opts.path) end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
function M.func(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
if not input.path then return false, "path is required" end
|
||||
if on_log then on_log("path: " .. input.path) end
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
|
||||
if Path:new(abs_path):is_dir() then return false, "Path is a directory: " .. abs_path end
|
||||
local file = io.open(abs_path, "r")
|
||||
if not file then return false, "file not found: " .. abs_path end
|
||||
local lines = Utils.read_file_from_buf_or_disk(abs_path)
|
||||
local start_line = opts.start_line
|
||||
local end_line = opts.end_line
|
||||
local start_line = input.start_line
|
||||
local end_line = input.end_line
|
||||
if start_line and end_line and lines then lines = vim.list_slice(lines, start_line, end_line) end
|
||||
local truncated_lines = {}
|
||||
local is_truncated = false
|
||||
|
||||
@@ -57,27 +57,26 @@ M.returns = {
|
||||
|
||||
--- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view.
|
||||
---@type AvanteLLMToolFunc<{ path: string, content: string, the_content?: string, streaming?: boolean, tool_use_id?: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
if opts.the_content ~= nil then
|
||||
opts.content = opts.the_content
|
||||
opts.the_content = nil
|
||||
function M.func(input, opts)
|
||||
if input.the_content ~= nil then
|
||||
input.content = input.the_content
|
||||
input.the_content = nil
|
||||
end
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
local abs_path = Helpers.get_abs_path(input.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if opts.content == nil then return false, "content not provided" end
|
||||
if type(opts.content) ~= "string" then opts.content = vim.json.encode(opts.content) end
|
||||
if input.content == nil then return false, "content not provided" end
|
||||
if type(input.content) ~= "string" then input.content = vim.json.encode(input.content) end
|
||||
local old_lines = Utils.read_file_from_buf_or_disk(abs_path)
|
||||
local old_content = table.concat(old_lines or {}, "\n")
|
||||
local str_replace = require("avante.llm_tools.str_replace")
|
||||
local new_opts = {
|
||||
path = opts.path,
|
||||
local new_input = {
|
||||
path = input.path,
|
||||
old_str = old_content,
|
||||
new_str = opts.content,
|
||||
streaming = opts.streaming,
|
||||
tool_use_id = opts.tool_use_id,
|
||||
new_str = input.content,
|
||||
streaming = input.streaming,
|
||||
tool_use_id = input.tool_use_id,
|
||||
}
|
||||
return str_replace.func(new_opts, on_log, on_complete, session_ctx)
|
||||
return str_replace.func(new_input, opts)
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
@@ -2318,7 +2318,7 @@ function Sidebar:get_history_messages_for_api(opts)
|
||||
--- For models like gpt-4o, the input parameter of replace_in_file is treated as the latest file content, so here we need to insert a fake view tool call to ensure it uses the latest file content
|
||||
if is_edit_func_call and path and not message.message.content[1].is_error then
|
||||
local uniformed_path = Utils.uniform_path(path)
|
||||
local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, nil, nil, nil)
|
||||
local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, {})
|
||||
if view_error then view_result = "Error: " .. view_error end
|
||||
local get_diagnostics_tool_use_id = Utils.uuid()
|
||||
local view_tool_use_id = Utils.uuid()
|
||||
@@ -2451,9 +2451,7 @@ function Sidebar:get_history_messages_for_api(opts)
|
||||
local end_line = tool_id_to_end_line[item.tool_use_id]
|
||||
local view_result, view_error = require("avante.llm_tools.view").func(
|
||||
{ path = path, start_line = start_line, end_line = end_line },
|
||||
nil,
|
||||
nil,
|
||||
nil
|
||||
{}
|
||||
)
|
||||
if view_error then view_result = "Error: " .. view_error end
|
||||
item.content = view_result
|
||||
@@ -2773,6 +2771,23 @@ function Sidebar:create_input_container()
|
||||
self:save_history()
|
||||
end
|
||||
|
||||
local function set_tool_use_store(tool_id, key, value)
|
||||
local tool_use_message = nil
|
||||
for idx = #self.chat_history.messages, 1, -1 do
|
||||
local message = self.chat_history.messages[idx]
|
||||
local content = message.message.content
|
||||
if type(content) == "table" and content[1].type == "tool_use" and content[1].id == tool_id then
|
||||
tool_use_message = message
|
||||
break
|
||||
end
|
||||
end
|
||||
if not tool_use_message then return end
|
||||
local tool_use_store = tool_use_message.tool_use_store or {}
|
||||
tool_use_store[key] = value
|
||||
tool_use_message.tool_use_store = tool_use_store
|
||||
self:save_history()
|
||||
end
|
||||
|
||||
---@type AvanteLLMStopCallback
|
||||
local function on_stop(stop_opts)
|
||||
self.is_generating = false
|
||||
@@ -2837,6 +2852,7 @@ function Sidebar:create_input_container()
|
||||
on_tool_log = on_tool_log,
|
||||
on_messages_add = on_messages_add,
|
||||
on_state_change = on_state_change,
|
||||
set_tool_use_store = set_tool_use_store,
|
||||
get_history_messages = function(opts) return self:get_history_messages_for_api(opts) end,
|
||||
get_todos = function()
|
||||
local history = Path.history.load(self.code.bufnr)
|
||||
|
||||
@@ -107,6 +107,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field selected_code AvanteSelectedCode | nil
|
||||
---@field selected_filepaths string[] | nil
|
||||
---@field tool_use_logs string[] | nil
|
||||
---@field tool_use_store table | nil
|
||||
---@field just_for_display boolean | nil
|
||||
---@field is_dummy boolean | nil
|
||||
---@field is_compacted boolean | nil
|
||||
@@ -407,19 +408,30 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field on_stop AvanteLLMStopCallback
|
||||
---@field on_memory_summarize? AvanteLLMMemorySummarizeCallback
|
||||
---@field on_tool_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
|
||||
---@field set_tool_use_store? fun(tool_id: string, key: string, value: any): nil
|
||||
---@field get_history_messages? fun(opts?: { all?: boolean }): avante.HistoryMessage[]
|
||||
---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil
|
||||
---@field on_state_change? fun(state: avante.GenerateState): nil
|
||||
---@field update_tokens_usage? fun(usage: avante.LLMTokenUsage): nil
|
||||
---
|
||||
---@class AvanteLLMToolFuncOpts
|
||||
---@field session_ctx table
|
||||
---@field on_complete? fun(result: boolean | string | nil, error: string | nil): nil
|
||||
---@field on_log? fun(log: string): nil
|
||||
---@field set_store? fun(key: string, value: any): nil
|
||||
---
|
||||
---@alias AvanteLLMToolFunc<T> fun(
|
||||
--- input: T,
|
||||
--- on_log?: (fun(log: string): nil),
|
||||
--- on_complete?: (fun(result: boolean | string | nil, error: string | nil): nil),
|
||||
--- session_ctx?: table)
|
||||
--- opts: AvanteLLMToolFuncOpts)
|
||||
--- : (boolean | string | nil, string | nil)
|
||||
---
|
||||
--- @alias avante.LLMToolOnRender<T> fun(input: T, logs: string[], state: avante.HistoryMessageState | nil): avante.ui.Line[]
|
||||
---@class avante.LLMToolOnRenderOpts
|
||||
---@field logs string[]
|
||||
---@field state avante.HistoryMessageState
|
||||
---@field store table | nil
|
||||
---@field result_message avante.HistoryMessage | nil
|
||||
---
|
||||
--- @alias avante.LLMToolOnRender<T> fun(input: T, opts: avante.LLMToolOnRenderOpts): avante.ui.Line[]
|
||||
---
|
||||
---@class AvanteLLMTool
|
||||
---@field name string
|
||||
|
||||
@@ -1674,14 +1674,22 @@ function M.message_content_item_to_lines(item, message, messages)
|
||||
return { Line:new({ { "" } }) }
|
||||
end
|
||||
if item.type == "tool_use" then
|
||||
local tool_result_message = M.get_tool_result_message(message, messages)
|
||||
local lines = {}
|
||||
local state = "generating"
|
||||
local hl = "AvanteStateSpinnerToolCalling"
|
||||
local ok, llm_tool = pcall(require, "avante.llm_tools." .. item.name)
|
||||
if ok then
|
||||
if llm_tool.on_render then return llm_tool.on_render(item.input, message.tool_use_logs, message.state) end
|
||||
---@cast llm_tool AvanteLLMTool
|
||||
if llm_tool.on_render then
|
||||
return llm_tool.on_render(item.input, {
|
||||
logs = message.tool_use_logs,
|
||||
state = message.state,
|
||||
store = message.tool_use_store,
|
||||
result_message = tool_result_message,
|
||||
})
|
||||
end
|
||||
end
|
||||
local tool_result_message = M.get_tool_result_message(message, messages)
|
||||
if tool_result_message then
|
||||
local tool_result = tool_result_message.message.content[1]
|
||||
if tool_result.is_error then
|
||||
|
||||
@@ -53,21 +53,21 @@ describe("llm_tools", function()
|
||||
|
||||
describe("ls", function()
|
||||
it("should list files in directory", function()
|
||||
local result, err = ls({ path = ".", max_depth = 1 })
|
||||
local result, err = ls({ path = ".", max_depth = 1 }, {})
|
||||
assert.is_nil(err)
|
||||
assert.falsy(result:find("avante.nvim"))
|
||||
assert.truthy(result:find("test.txt"))
|
||||
assert.falsy(result:find("test1.txt"))
|
||||
end)
|
||||
it("should list files in directory with depth", function()
|
||||
local result, err = ls({ path = ".", max_depth = 2 })
|
||||
local result, err = ls({ path = ".", max_depth = 2 }, {})
|
||||
assert.is_nil(err)
|
||||
assert.falsy(result:find("avante.nvim"))
|
||||
assert.truthy(result:find("test.txt"))
|
||||
assert.truthy(result:find("test1.txt"))
|
||||
end)
|
||||
it("should list files respecting gitignore", function()
|
||||
local result, err = ls({ path = ".", max_depth = 2 })
|
||||
local result, err = ls({ path = ".", max_depth = 2 }, {})
|
||||
assert.is_nil(err)
|
||||
assert.falsy(result:find("avante.nvim"))
|
||||
assert.truthy(result:find("test.txt"))
|
||||
@@ -78,49 +78,61 @@ describe("llm_tools", function()
|
||||
|
||||
describe("view", function()
|
||||
it("should read file content", function()
|
||||
view({ path = "test.txt" }, nil, function(content, err)
|
||||
assert.is_nil(err)
|
||||
assert.equals("test content", vim.json.decode(content).content)
|
||||
end)
|
||||
view({ path = "test.txt" }, {
|
||||
on_complete = function(content, err)
|
||||
assert.is_nil(err)
|
||||
assert.equals("test content", vim.json.decode(content).content)
|
||||
end,
|
||||
})
|
||||
end)
|
||||
|
||||
it("should return error for non-existent file", function()
|
||||
view({ path = "non_existent.txt" }, nil, function(content, err)
|
||||
assert.truthy(err)
|
||||
assert.equals("", content)
|
||||
end)
|
||||
view({ path = "non_existent.txt" }, {
|
||||
on_complete = function(content, err)
|
||||
assert.truthy(err)
|
||||
assert.equals("", content)
|
||||
end,
|
||||
})
|
||||
end)
|
||||
|
||||
it("should read directory content", function()
|
||||
view({ path = test_dir }, nil, function(content, err)
|
||||
assert.is_nil(err)
|
||||
assert.truthy(content:find("test.txt"))
|
||||
assert.truthy(content:find("test content"))
|
||||
end)
|
||||
view({ path = test_dir }, {
|
||||
on_complete = function(content, err)
|
||||
assert.is_nil(err)
|
||||
assert.truthy(content:find("test.txt"))
|
||||
assert.truthy(content:find("test content"))
|
||||
end,
|
||||
})
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("create_dir", function()
|
||||
it("should create new directory", function()
|
||||
LlmTools.create_dir({ path = "new_dir" }, nil, function(success, err)
|
||||
assert.is_nil(err)
|
||||
assert.is_true(success)
|
||||
LlmTools.create_dir({ path = "new_dir" }, {
|
||||
session_ctx = {},
|
||||
on_complete = function(success, err)
|
||||
assert.is_nil(err)
|
||||
assert.is_true(success)
|
||||
|
||||
local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil
|
||||
assert.is_true(dir_exists)
|
||||
end)
|
||||
local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil
|
||||
assert.is_true(dir_exists)
|
||||
end,
|
||||
})
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("delete_path", function()
|
||||
it("should delete existing file", function()
|
||||
LlmTools.delete_path({ path = "test.txt" }, nil, function(success, err)
|
||||
assert.is_nil(err)
|
||||
assert.is_true(success)
|
||||
LlmTools.delete_path({ path = "test.txt" }, {
|
||||
session_ctx = {},
|
||||
on_complete = function(success, err)
|
||||
assert.is_nil(err)
|
||||
assert.is_true(success)
|
||||
|
||||
local file_exists = io.open(test_file, "r") ~= nil
|
||||
assert.is_false(file_exists)
|
||||
end)
|
||||
local file_exists = io.open(test_file, "r") ~= nil
|
||||
assert.is_false(file_exists)
|
||||
end,
|
||||
})
|
||||
end)
|
||||
end)
|
||||
|
||||
@@ -147,22 +159,22 @@ describe("llm_tools", function()
|
||||
file:write("this is nothing")
|
||||
file:close()
|
||||
|
||||
local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false })
|
||||
local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false }, {})
|
||||
assert.is_nil(err)
|
||||
assert.truthy(result:find("searchable.txt"))
|
||||
assert.falsy(result:find("nothing.txt"))
|
||||
|
||||
local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true })
|
||||
local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true }, {})
|
||||
assert.is_nil(err2)
|
||||
assert.truthy(result2:find("searchable.txt"))
|
||||
assert.falsy(result2:find("nothing.txt"))
|
||||
|
||||
local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true })
|
||||
local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true }, {})
|
||||
assert.is_nil(err3)
|
||||
assert.falsy(result3:find("searchable.txt"))
|
||||
assert.falsy(result3:find("nothing.txt"))
|
||||
|
||||
local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false })
|
||||
local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false }, {})
|
||||
assert.is_nil(err4)
|
||||
assert.truthy(result4:find("searchable.txt"))
|
||||
assert.falsy(result4:find("nothing.txt"))
|
||||
@@ -172,7 +184,7 @@ describe("llm_tools", function()
|
||||
query = "searchable",
|
||||
case_sensitive = false,
|
||||
exclude_pattern = "search*",
|
||||
})
|
||||
}, {})
|
||||
assert.is_nil(err5)
|
||||
assert.falsy(result5:find("searchable.txt"))
|
||||
assert.falsy(result5:find("nothing.txt"))
|
||||
@@ -191,7 +203,7 @@ describe("llm_tools", function()
|
||||
file:write("content for ag test")
|
||||
file:close()
|
||||
|
||||
local result, err = grep({ path = ".", query = "ag test" })
|
||||
local result, err = grep({ path = ".", query = "ag test" }, {})
|
||||
assert.is_nil(err)
|
||||
assert.is_string(result)
|
||||
assert.truthy(result:find("ag_test.txt"))
|
||||
@@ -215,22 +227,22 @@ describe("llm_tools", function()
|
||||
file:write("this is nothing")
|
||||
file:close()
|
||||
|
||||
local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false })
|
||||
local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false }, {})
|
||||
assert.is_nil(err)
|
||||
assert.truthy(result:find("searchable.txt"))
|
||||
assert.falsy(result:find("nothing.txt"))
|
||||
|
||||
local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true })
|
||||
local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true }, {})
|
||||
assert.is_nil(err2)
|
||||
assert.truthy(result2:find("searchable.txt"))
|
||||
assert.falsy(result2:find("nothing.txt"))
|
||||
|
||||
local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true })
|
||||
local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true }, {})
|
||||
assert.is_nil(err3)
|
||||
assert.falsy(result3:find("searchable.txt"))
|
||||
assert.falsy(result3:find("nothing.txt"))
|
||||
|
||||
local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false })
|
||||
local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false }, {})
|
||||
assert.is_nil(err4)
|
||||
assert.truthy(result4:find("searchable.txt"))
|
||||
assert.falsy(result4:find("nothing.txt"))
|
||||
@@ -240,7 +252,7 @@ describe("llm_tools", function()
|
||||
query = "searchable",
|
||||
case_sensitive = false,
|
||||
exclude_pattern = "search*",
|
||||
})
|
||||
}, {})
|
||||
assert.is_nil(err5)
|
||||
assert.falsy(result5:find("searchable.txt"))
|
||||
assert.falsy(result5:find("nothing.txt"))
|
||||
@@ -250,18 +262,18 @@ describe("llm_tools", function()
|
||||
-- Mock exepath to return nothing
|
||||
vim.fn.exepath = function() return "" end
|
||||
|
||||
local result, err = grep({ path = ".", query = "test" })
|
||||
local result, err = grep({ path = ".", query = "test" }, {})
|
||||
assert.equals("", result)
|
||||
assert.equals("No search command found", err)
|
||||
end)
|
||||
|
||||
it("should respect path permissions", function()
|
||||
local result, err = grep({ path = "../outside_project", query = "test" })
|
||||
local result, err = grep({ path = "../outside_project", query = "test" }, {})
|
||||
assert.truthy(err:find("No permission to access path"))
|
||||
end)
|
||||
|
||||
it("should handle non-existent paths", function()
|
||||
local result, err = grep({ path = "non_existent_dir", query = "test" })
|
||||
local result, err = grep({ path = "non_existent_dir", query = "test" }, {})
|
||||
assert.equals("", result)
|
||||
assert.truthy(err)
|
||||
assert.truthy(err:find("No such file or directory"))
|
||||
@@ -277,86 +289,84 @@ describe("llm_tools", function()
|
||||
-- end)
|
||||
|
||||
it("should return error when running outside current directory", function()
|
||||
bash({ path = "../outside_project", command = "echo 'test'" }, nil, function(result, err)
|
||||
assert.is_false(result)
|
||||
assert.truthy(err)
|
||||
assert.truthy(err:find("No permission to access path"))
|
||||
end)
|
||||
bash({ path = "../outside_project", command = "echo 'test'" }, {
|
||||
session_ctx = {},
|
||||
on_complete = function(result, err)
|
||||
assert.is_false(result)
|
||||
assert.truthy(err)
|
||||
assert.truthy(err:find("No permission to access path"))
|
||||
end,
|
||||
})
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("python", function()
|
||||
it("should execute Python code and return output", function()
|
||||
LlmTools.python(
|
||||
{
|
||||
path = ".",
|
||||
code = "print('Hello from Python')",
|
||||
},
|
||||
nil,
|
||||
function(result, err)
|
||||
LlmTools.python({
|
||||
path = ".",
|
||||
code = "print('Hello from Python')",
|
||||
}, {
|
||||
session_ctx = {},
|
||||
on_complete = function(result, err)
|
||||
assert.is_nil(err)
|
||||
assert.equals("Hello from Python\n", result)
|
||||
end
|
||||
)
|
||||
end,
|
||||
})
|
||||
end)
|
||||
|
||||
it("should handle Python errors", function()
|
||||
LlmTools.python(
|
||||
{
|
||||
path = ".",
|
||||
code = "print(undefined_variable)",
|
||||
},
|
||||
nil,
|
||||
function(result, err)
|
||||
LlmTools.python({
|
||||
path = ".",
|
||||
code = "print(undefined_variable)",
|
||||
}, {
|
||||
session_ctx = {},
|
||||
on_complete = function(result, err)
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err)
|
||||
assert.truthy(err:find("Error"))
|
||||
end
|
||||
)
|
||||
end,
|
||||
})
|
||||
end)
|
||||
|
||||
it("should respect path permissions", function()
|
||||
LlmTools.python(
|
||||
{
|
||||
path = "../outside_project",
|
||||
code = "print('test')",
|
||||
},
|
||||
nil,
|
||||
function(result, err)
|
||||
LlmTools.python({
|
||||
path = "../outside_project",
|
||||
code = "print('test')",
|
||||
}, {
|
||||
session_ctx = {},
|
||||
on_complete = function(result, err)
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:find("No permission to access path"))
|
||||
end
|
||||
)
|
||||
end,
|
||||
})
|
||||
end)
|
||||
|
||||
it("should handle non-existent paths", function()
|
||||
LlmTools.python(
|
||||
{
|
||||
path = "non_existent_dir",
|
||||
code = "print('test')",
|
||||
},
|
||||
nil,
|
||||
function(result, err)
|
||||
LlmTools.python({
|
||||
path = "non_existent_dir",
|
||||
code = "print('test')",
|
||||
}, {
|
||||
session_ctx = {},
|
||||
on_complete = function(result, err)
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:find("Path not found"))
|
||||
end
|
||||
)
|
||||
end,
|
||||
})
|
||||
end)
|
||||
|
||||
it("should support custom container image", function()
|
||||
os.execute("docker image rm python:3.12-slim")
|
||||
LlmTools.python(
|
||||
{
|
||||
path = ".",
|
||||
code = "print('Hello from custom container')",
|
||||
container_image = "python:3.12-slim",
|
||||
},
|
||||
nil,
|
||||
function(result, err)
|
||||
LlmTools.python({
|
||||
path = ".",
|
||||
code = "print('Hello from custom container')",
|
||||
container_image = "python:3.12-slim",
|
||||
}, {
|
||||
session_ctx = {},
|
||||
on_complete = function(result, err)
|
||||
assert.is_nil(err)
|
||||
assert.equals("Hello from custom container\n", result)
|
||||
end
|
||||
)
|
||||
end,
|
||||
})
|
||||
end)
|
||||
end)
|
||||
|
||||
@@ -370,7 +380,7 @@ describe("llm_tools", function()
|
||||
os.execute("touch " .. test_dir .. "/nested/file4.lua")
|
||||
|
||||
-- Test for lua files in the root
|
||||
local result, err = glob({ path = ".", pattern = "*.lua" })
|
||||
local result, err = glob({ path = ".", pattern = "*.lua" }, {})
|
||||
assert.is_nil(err)
|
||||
local files = vim.json.decode(result).matches
|
||||
assert.equals(2, #files)
|
||||
@@ -380,7 +390,7 @@ describe("llm_tools", function()
|
||||
assert.falsy(vim.tbl_contains(files, test_dir .. "/nested/file4.lua"))
|
||||
|
||||
-- Test with recursive pattern
|
||||
local result2, err2 = glob({ path = ".", pattern = "**/*.lua" })
|
||||
local result2, err2 = glob({ path = ".", pattern = "**/*.lua" }, {})
|
||||
assert.is_nil(err2)
|
||||
local files2 = vim.json.decode(result2).matches
|
||||
assert.equals(3, #files2)
|
||||
@@ -390,13 +400,13 @@ describe("llm_tools", function()
|
||||
end)
|
||||
|
||||
it("should respect path permissions", function()
|
||||
local result, err = glob({ path = "../outside_project", pattern = "*.txt" })
|
||||
local result, err = glob({ path = "../outside_project", pattern = "*.txt" }, {})
|
||||
assert.equals("", result)
|
||||
assert.truthy(err:find("No permission to access path"))
|
||||
end)
|
||||
|
||||
it("should handle patterns without matches", function()
|
||||
local result, err = glob({ path = ".", pattern = "*.nonexistent" })
|
||||
local result, err = glob({ path = ".", pattern = "*.nonexistent" }, {})
|
||||
assert.is_nil(err)
|
||||
local files = vim.json.decode(result).matches
|
||||
assert.equals(0, #files)
|
||||
@@ -411,7 +421,7 @@ describe("llm_tools", function()
|
||||
os.execute("touch " .. test_dir .. "/test_dir1/notignored1.lua")
|
||||
os.execute("touch " .. test_dir .. "/test_dir1/notignored2.lua")
|
||||
|
||||
local result, err = glob({ path = ".", pattern = "**/*.lua" })
|
||||
local result, err = glob({ path = ".", pattern = "**/*.lua" }, {})
|
||||
assert.is_nil(err)
|
||||
local files = vim.json.decode(result).matches
|
||||
|
||||
|
||||
Reference in New Issue
Block a user