diff --git a/lua/avante/highlights.lua b/lua/avante/highlights.lua index e862733..c896da6 100644 --- a/lua/avante/highlights.lua +++ b/lua/avante/highlights.lua @@ -47,7 +47,9 @@ local Highlights = { AVANTE_STATE_SPINNER_SEARCHING = { name = "AvanteStateSpinnerSearching", fg = "#1e222a", bg = "#c678dd" }, AVANTE_STATE_SPINNER_THINKING = { name = "AvanteStateSpinnerThinking", fg = "#1e222a", bg = "#c678dd" }, AVANTE_STATE_SPINNER_COMPACTING = { name = "AvanteStateSpinnerCompacting", fg = "#1e222a", bg = "#c678dd" }, + AVANTE_TASK_RUNNING = { name = "AvanteTaskRunning", fg = "#c678dd", bg_link = "Normal" }, AVANTE_TASK_COMPLETED = { name = "AvanteTaskCompleted", fg = "#98c379", bg_link = "Normal" }, + AVANTE_TASK_FAILED = { name = "AvanteTaskFailed", fg = "#e06c75", bg_link = "Normal" }, AVANTE_THINKING = { name = "AvanteThinking", fg = "#c678dd", bg_link = "Normal" }, } diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 5c10394..d817fdf 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -150,8 +150,11 @@ function M.generate_todos(user_input, cb) local uncalled_tool_uses = Utils.get_uncalled_tool_uses(history_messages) for _, partial_tool_use in ipairs(uncalled_tool_uses) do if partial_tool_use.state == "generated" and partial_tool_use.name == "add_todos" then - LLMTools.process_tool_use(tools, partial_tool_use, function() end, function() cb() end, {}) - cb() + local result = LLMTools.process_tool_use(tools, partial_tool_use, { + session_ctx = {}, + on_complete = function() cb() end, + }) + if result ~= nil then cb() end end end else @@ -206,6 +209,7 @@ function M.agent_loop(opts) table.insert(history_messages, msg) end end + if opts.on_messages_add then opts.on_messages_add(msgs) end end, session_ctx = session_ctx, prompt_opts = { @@ -331,8 +335,9 @@ function M.generate_prompts(opts) end if Config.system_prompt ~= nil then - local custom_system_prompt = Config.system_prompt - if type(custom_system_prompt) == "function" then custom_system_prompt = custom_system_prompt() end + local custom_system_prompt + if type(Config.system_prompt) == "function" then custom_system_prompt = Config.system_prompt() end + if type(Config.system_prompt) == "string" then custom_system_prompt = Config.system_prompt end if custom_system_prompt ~= nil and custom_system_prompt ~= "" and custom_system_prompt ~= "null" then system_prompt = system_prompt .. "\n\n" .. custom_system_prompt end @@ -841,13 +846,9 @@ function M._stream(opts) if partial_tool_use.state == "generating" then if type(partial_tool_use.input) == "table" then partial_tool_use.input.streaming = true - LLMTools.process_tool_use( - prompt_opts.tools, - partial_tool_use, - function() end, - function() end, - opts.session_ctx - ) + LLMTools.process_tool_use(prompt_opts.tools, partial_tool_use, { + session_ctx = opts.session_ctx, + }) end return else @@ -856,13 +857,12 @@ function M._stream(opts) partial_tool_use_message.is_calling = true if opts.on_messages_add then opts.on_messages_add({ partial_tool_use_message }) end -- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil - local result, error = LLMTools.process_tool_use( - prompt_opts.tools, - partial_tool_use, - opts.on_tool_log, - handle_tool_result, - opts.session_ctx - ) + local result, error = LLMTools.process_tool_use(prompt_opts.tools, partial_tool_use, { + session_ctx = opts.session_ctx, + on_log = opts.on_tool_log, + set_tool_use_store = opts.set_tool_use_store, + on_complete = handle_tool_result, + }) if result ~= nil or error ~= nil then return handle_tool_result(result, error) end end if stop_opts.reason == "cancelled" then @@ -1100,6 +1100,13 @@ function M.stream(opts) return original_on_tool_log(...) end) end + if opts.set_tool_use_store ~= nil then + local original_set_tool_use_store = opts.set_tool_use_store + opts.set_tool_use_store = vim.schedule_wrap(function(...) + if not original_set_tool_use_store then return end + return original_set_tool_use_store(...) + end) + end if opts.on_chunk ~= nil then local original_on_chunk = opts.on_chunk opts.on_chunk = vim.schedule_wrap(function(chunk) diff --git a/lua/avante/llm_tools/add_todos.lua b/lua/avante/llm_tools/add_todos.lua index 280b520..83ec520 100644 --- a/lua/avante/llm_tools/add_todos.lua +++ b/lua/avante/llm_tools/add_todos.lua @@ -65,10 +65,11 @@ M.returns = { M.on_render = function() return {} end ---@type AvanteLLMToolFunc<{ todos: avante.TODO[] }> -function M.func(opts, on_log, on_complete, session_ctx) +function M.func(input, opts) + local on_complete = opts.on_complete local sidebar = require("avante").get() if not sidebar then return false, "Avante sidebar not found" end - local todos = opts.todos + local todos = input.todos if not todos or #todos == 0 then return false, "No todos provided" end sidebar:update_todos(todos) if on_complete then diff --git a/lua/avante/llm_tools/attempt_completion.lua b/lua/avante/llm_tools/attempt_completion.lua index a672c3f..c72ad2e 100644 --- a/lua/avante/llm_tools/attempt_completion.lua +++ b/lua/avante/llm_tools/attempt_completion.lua @@ -57,11 +57,11 @@ M.returns = { } ---@type avante.LLMToolOnRender -function M.on_render(opts) +function M.on_render(input) local lines = {} table.insert(lines, Line:new({ { "✓ Task Completed", Highlights.AVANTE_TASK_COMPLETED } })) table.insert(lines, Line:new({ { "" } })) - local result = opts.result or "" + local result = input.result or "" local text_lines = vim.split(result, "\n") for _, text_line in ipairs(text_lines) do table.insert(lines, Line:new({ { text_line } })) @@ -70,24 +70,29 @@ function M.on_render(opts) end ---@type AvanteLLMToolFunc -function M.func(opts, on_log, on_complete, session_ctx) - if not on_complete then return false, "on_complete not provided" end +function M.func(input, opts) + if not opts.on_complete then return false, "on_complete not provided" end local sidebar = require("avante").get() if not sidebar then return false, "Avante sidebar not found" end - local is_streaming = opts.streaming or false + local is_streaming = input.streaming or false if is_streaming then -- wait for stream completion as command may not be complete yet return end - session_ctx.attempt_completion_is_called = true + opts.session_ctx.attempt_completion_is_called = true - if opts.command and opts.command ~= vim.NIL and opts.command ~= "" and not vim.startswith(opts.command, "open ") then - session_ctx.always_yes = false - require("avante.llm_tools.bash").func({ command = opts.command }, on_log, on_complete, session_ctx) + if + input.command + and input.command ~= vim.NIL + and input.command ~= "" + and not vim.startswith(input.command, "open ") + then + opts.session_ctx.always_yes = false + require("avante.llm_tools.bash").func({ command = input.command }, opts) else - on_complete(true, nil) + opts.on_complete(true, nil) end end diff --git a/lua/avante/llm_tools/bash.lua b/lua/avante/llm_tools/bash.lua index 2a270ec..8bf7195 100644 --- a/lua/avante/llm_tools/bash.lua +++ b/lua/avante/llm_tools/bash.lua @@ -215,18 +215,18 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ path: string, command: string, streaming?: boolean }> -function M.func(opts, on_log, on_complete, session_ctx) - local is_streaming = opts.streaming or false +function M.func(input, opts) + local is_streaming = input.streaming or false if is_streaming then -- wait for stream completion as command may not be complete yet return end - local abs_path = Helpers.get_abs_path(opts.path) + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end - if not opts.command then return false, "Command is required" end - if on_log then on_log("command: " .. opts.command) end + if not input.command then return false, "Command is required" end + if opts.on_log then opts.on_log("command: " .. input.command) end ---change cwd to abs_path ---@param output string @@ -240,21 +240,21 @@ function M.func(opts, on_log, on_complete, session_ctx) end return output, nil end - if not on_complete then return false, "on_complete not provided" end + if not opts.on_complete then return false, "on_complete not provided" end Helpers.confirm( - "Are you sure you want to run the command: `" .. opts.command .. "` in the directory: " .. abs_path, + "Are you sure you want to run the command: `" .. input.command .. "` in the directory: " .. abs_path, function(ok, reason) if not ok then - on_complete(false, "User declined, reason: " .. (reason and reason or "unknown")) + opts.on_complete(false, "User declined, reason: " .. (reason and reason or "unknown")) return end - Utils.shell_run_async(opts.command, "bash -c", function(output, exit_code) + Utils.shell_run_async(input.command, "bash -c", function(output, exit_code) local result, err = handle_result(output, exit_code) - on_complete(result, err) + opts.on_complete(result, err) end, abs_path) end, { focus = true }, - session_ctx, + opts.session_ctx, M.name -- Pass the tool name for permission checking ) end diff --git a/lua/avante/llm_tools/create.lua b/lua/avante/llm_tools/create.lua index f2f0a5d..a9a91a7 100644 --- a/lua/avante/llm_tools/create.lua +++ b/lua/avante/llm_tools/create.lua @@ -45,20 +45,23 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ path: string, file_text: string }> -function M.func(opts, on_log, on_complete, session_ctx) +function M.func(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + local session_ctx = opts.session_ctx if not on_complete then return false, "on_complete not provided" end - if on_log then on_log("path: " .. opts.path) end - if Helpers.already_in_context(opts.path) then + if on_log then on_log("path: " .. input.path) end + if Helpers.already_in_context(input.path) then on_complete(nil, "Ooooops! This file is already in the context! Why you are trying to create it again?") return end - local abs_path = Helpers.get_abs_path(opts.path) + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end - if opts.file_text == nil then return false, "file_text not provided" end + if input.file_text == nil then return false, "file_text not provided" end if Path:new(abs_path):exists() then return false, "File already exists: " .. abs_path end - local lines = vim.split(opts.file_text, "\n") - if #lines == 1 and opts.file_text:match("\\n") then - local text = Utils.trim_slashes(opts.file_text) + local lines = vim.split(input.file_text, "\n") + if #lines == 1 and input.file_text:match("\\n") then + local text = Utils.trim_slashes(input.file_text) lines = vim.split(text, "\n") end local bufnr, err = Helpers.get_bufnr(abs_path) diff --git a/lua/avante/llm_tools/delete_tool_use_messages.lua b/lua/avante/llm_tools/delete_tool_use_messages.lua index e443150..afcbbde 100644 --- a/lua/avante/llm_tools/delete_tool_use_messages.lua +++ b/lua/avante/llm_tools/delete_tool_use_messages.lua @@ -40,7 +40,7 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ tool_use_id: string }> -function M.func(opts, on_log, on_complete, session_ctx) +function M.func(input, opts) local sidebar = require("avante").get() if not sidebar then return false, "Avante sidebar not found" end local history_messages = Utils.get_history_messages(sidebar.chat_history) @@ -50,7 +50,7 @@ function M.func(opts, on_log, on_complete, session_ctx) local content = msg.message.content if type(content) == "table" then for _, item in ipairs(content) do - if item.id == opts.tool_use_id then table.insert(the_deleted_message_uuids, msg.uuid) end + if item.id == input.tool_use_id then table.insert(the_deleted_message_uuids, msg.uuid) end end end end diff --git a/lua/avante/llm_tools/dispatch_agent.lua b/lua/avante/llm_tools/dispatch_agent.lua index 3d404a1..b1e9c27 100644 --- a/lua/avante/llm_tools/dispatch_agent.lua +++ b/lua/avante/llm_tools/dispatch_agent.lua @@ -3,6 +3,8 @@ local Config = require("avante.config") local Utils = require("avante.utils") local Base = require("avante.llm_tools.base") local HistoryMessage = require("avante.history_message") +local Line = require("avante.ui.line") +local Highlights = require("avante.highlights") ---@class AvanteLLMTool local M = setmetatable({}, Base) @@ -79,12 +81,118 @@ local function get_available_tools() } end ----@type AvanteLLMToolFunc<{ prompt: string }> -function M.func(opts, on_log, on_complete, session_ctx) +---@class avante.DispatchAgentInput +---@field prompt string + +---@type avante.LLMToolOnRender +function M.on_render(input, opts) + local result_message = opts.result_message + local store = opts.store or {} + local messages = store.messages or {} + local tool_use_summary = {} + for _, msg in ipairs(messages) do + local content = msg.message.content + local summary + if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then + local tool_result_message = Utils.get_tool_result_message(msg, messages) + if tool_result_message then + local tool_name = msg.message.content[1].name + if tool_name == "ls" then + local path = msg.message.content[1].input.path + if tool_result_message.message.content[1].is_error then + summary = string.format("Ls %s: failed", path) + else + local ok, filepaths = pcall(vim.json.decode, tool_result_message.message.content[1].content) + if ok then summary = string.format("Ls %s: %d paths", path, #filepaths) end + end + elseif tool_name == "grep" then + local path = msg.message.content[1].input.path + local query = msg.message.content[1].input.query + if tool_result_message.message.content[1].is_error then + summary = string.format("Grep %s in %s: failed", query, path) + else + local ok, filepaths = pcall(vim.json.decode, tool_result_message.message.content[1].content) + if ok then summary = string.format("Grep %s in %s: %d paths", query, path, #filepaths) end + end + elseif tool_name == "glob" then + local path = msg.message.content[1].input.path + local pattern = msg.message.content[1].input.pattern + if tool_result_message.message.content[1].is_error then + summary = string.format("Glob %s in %s: failed", pattern, path) + else + local ok, result = pcall(vim.json.decode, tool_result_message.message.content[1].content) + if ok then + local matches = result.matches + if matches then summary = string.format("Glob %s in %s: %d matches", pattern, path, #matches) end + end + end + elseif tool_name == "view" then + local path = msg.message.content[1].input.path + if tool_result_message.message.content[1].is_error then + summary = string.format("View %s: failed", path) + else + local ok, result = pcall(vim.json.decode, tool_result_message.message.content[1].content) + if ok then + local content_ = result.content + local lines = vim.split(content_, "\n") + summary = string.format("View %s: %d lines", path, #lines) + end + end + end + end + if summary then summary = " " .. Utils.icon("🛠️ ") .. summary end + elseif type(content) == "table" and #content > 0 and type(content[1]) == "table" and content[1].type == "text" then + summary = content[1].content + elseif type(content) == "table" and #content > 0 and type(content[1]) == "string" then + summary = content[1] + elseif type(content) == "string" then + summary = content + end + if summary then table.insert(tool_use_summary, summary) end + end + local state = "running" + local icon = Utils.icon("🔄 ") + local hl = Highlights.AVANTE_TASK_RUNNING + if result_message then + if result_message.message.content[1].is_error then + state = "failed" + icon = Utils.icon("❌ ") + hl = Highlights.AVANTE_TASK_FAILED + else + state = "completed" + icon = Utils.icon("✅ ") + hl = Highlights.AVANTE_TASK_COMPLETED + end + end + local lines = {} + table.insert(lines, Line:new({ { icon .. "Subtask " .. state, hl } })) + table.insert(lines, Line:new({ { "" } })) + table.insert(lines, Line:new({ { " Task:" } })) + local prompt_lines = vim.split(input.prompt or "", "\n") + for _, line in ipairs(prompt_lines) do + table.insert(lines, Line:new({ { " " .. line } })) + end + table.insert(lines, Line:new({ { "" } })) + table.insert(lines, Line:new({ { " Task summary:" } })) + for _, summary in ipairs(tool_use_summary) do + local summary_lines = vim.split(summary, "\n") + for _, line in ipairs(summary_lines) do + table.insert(lines, Line:new({ { " " .. line } })) + end + end + return lines +end + +---@type AvanteLLMToolFunc +function M.func(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + local session_ctx = opts.session_ctx + local Llm = require("avante.llm") if not on_complete then return false, "on_complete not provided" end - local prompt = opts.prompt + local prompt = input.prompt local tools = get_available_tools() local start_time = Utils.get_timestamp() @@ -95,10 +203,11 @@ Your task is to help the user with their request: "${prompt}" Be thorough and use the tools available to you to find the most relevant information. When you're done, provide a clear and concise summary of what you found.]]):gsub("${prompt}", prompt) + local history_messages = {} local tool_use_messages = {} local total_tokens = 0 - local final_response = "" + local result = "" ---@type avante.AgentLoopOptions local agent_loop_options = { @@ -108,19 +217,37 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub on_tool_log = session_ctx.on_tool_log, on_messages_add = function(msgs) msgs = vim.islist(msgs) and msgs or { msgs } + for _, msg in ipairs(msgs) do + local idx = nil + for i, m in ipairs(history_messages) do + if m.uuid == msg.uuid then + idx = i + break + end + end + if idx ~= nil then + history_messages[idx] = msg + else + table.insert(history_messages, msg) + end + end + if opts.set_store then opts.set_store("messages", history_messages) end for _, msg in ipairs(msgs) do local content = msg.message.content if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then tool_use_messages[msg.uuid] = true + if content[1].name == "attempt_completion" then + local input_ = content[1].input + if input_ and input_.result then result = input_.result end + end end end - if session_ctx.on_messages_add then session_ctx.on_messages_add(msgs) end + -- if session_ctx.on_messages_add then session_ctx.on_messages_add(msgs) end end, session_ctx = session_ctx, on_start = session_ctx.on_start, on_chunk = function(chunk) if not chunk then return end - final_response = final_response .. chunk total_tokens = total_tokens + (#vim.split(chunk, " ") * 1.3) end, on_complete = function(err) @@ -148,8 +275,7 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub }) session_ctx.on_messages_add({ message }) end - local response = string.format("Final response:\n%s\n\nSummary:\n%s", summary, final_response) - on_complete(response, nil) + on_complete(result, nil) end, } diff --git a/lua/avante/llm_tools/get_diagnostics.lua b/lua/avante/llm_tools/get_diagnostics.lua index 010945d..322d7e5 100644 --- a/lua/avante/llm_tools/get_diagnostics.lua +++ b/lua/avante/llm_tools/get_diagnostics.lua @@ -40,10 +40,12 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ path: string, diff: string }> -function M.func(opts, on_log, on_complete, session_ctx) - if not opts.path then return false, "pathf are required" end - if on_log then on_log("path: " .. opts.path) end - local abs_path = Helpers.get_abs_path(opts.path) +function M.func(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + if not input.path then return false, "pathf are required" end + if on_log then on_log("path: " .. input.path) end + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not on_complete then return false, "on_complete is required" end local diagnostics = Utils.lsp.get_diagnostics_from_filepath(abs_path) diff --git a/lua/avante/llm_tools/glob.lua b/lua/avante/llm_tools/glob.lua index 5e00478..6bc3c2c 100644 --- a/lua/avante/llm_tools/glob.lua +++ b/lua/avante/llm_tools/glob.lua @@ -45,12 +45,14 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ path: string, pattern: string }> -function M.func(opts, on_log, on_complete, session_ctx) - local abs_path = Helpers.get_abs_path(opts.path) +function M.func(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end if on_log then on_log("path: " .. abs_path) end - if on_log then on_log("pattern: " .. opts.pattern) end - local files = vim.fn.glob(abs_path .. "/" .. opts.pattern, true, true) + if on_log then on_log("pattern: " .. input.pattern) end + local files = vim.fn.glob(abs_path .. "/" .. input.pattern, true, true) local truncated_files = {} local is_truncated = false local size = 0 diff --git a/lua/avante/llm_tools/grep.lua b/lua/avante/llm_tools/grep.lua index 2442454..db7d220 100644 --- a/lua/avante/llm_tools/grep.lua +++ b/lua/avante/llm_tools/grep.lua @@ -69,8 +69,10 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ path: string, query: string, case_sensitive?: boolean, include_pattern?: string, exclude_pattern?: string }> -function M.func(opts, on_log, on_complete, session_ctx) - local abs_path = Helpers.get_abs_path(opts.path) +function M.func(input, opts) + local on_log = opts.on_log + + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return "", "No such file or directory: " .. abs_path end @@ -85,31 +87,31 @@ function M.func(opts, on_log, on_complete, session_ctx) local cmd = "" if search_cmd:find("rg") then cmd = string.format("%s --files-with-matches --hidden", search_cmd) - if opts.case_sensitive then + if input.case_sensitive then cmd = string.format("%s --case-sensitive", cmd) else cmd = string.format("%s --ignore-case", cmd) end - if opts.include_pattern then cmd = string.format("%s --glob '%s'", cmd, opts.include_pattern) end - if opts.exclude_pattern then cmd = string.format("%s --glob '!%s'", cmd, opts.exclude_pattern) end - cmd = string.format("%s '%s' %s", cmd, opts.query, abs_path) + if input.include_pattern then cmd = string.format("%s --glob '%s'", cmd, input.include_pattern) end + if input.exclude_pattern then cmd = string.format("%s --glob '!%s'", cmd, input.exclude_pattern) end + cmd = string.format("%s '%s' %s", cmd, input.query, abs_path) elseif search_cmd:find("ag") then cmd = string.format("%s --nocolor --nogroup --hidden", search_cmd) - if opts.case_sensitive then cmd = string.format("%s --case-sensitive", cmd) end - if opts.include_pattern then cmd = string.format("%s --ignore '!%s'", cmd, opts.include_pattern) end - if opts.exclude_pattern then cmd = string.format("%s --ignore '%s'", cmd, opts.exclude_pattern) end - cmd = string.format("%s '%s' %s", cmd, opts.query, abs_path) + if input.case_sensitive then cmd = string.format("%s --case-sensitive", cmd) end + if input.include_pattern then cmd = string.format("%s --ignore '!%s'", cmd, input.include_pattern) end + if input.exclude_pattern then cmd = string.format("%s --ignore '%s'", cmd, input.exclude_pattern) end + cmd = string.format("%s '%s' %s", cmd, input.query, abs_path) elseif search_cmd:find("ack") then cmd = string.format("%s --nocolor --nogroup --hidden", search_cmd) - if opts.case_sensitive then cmd = string.format("%s --smart-case", cmd) end - if opts.exclude_pattern then cmd = string.format("%s --ignore-dir '%s'", cmd, opts.exclude_pattern) end - cmd = string.format("%s '%s' %s", cmd, opts.query, abs_path) + if input.case_sensitive then cmd = string.format("%s --smart-case", cmd) end + if input.exclude_pattern then cmd = string.format("%s --ignore-dir '%s'", cmd, input.exclude_pattern) end + cmd = string.format("%s '%s' %s", cmd, input.query, abs_path) elseif search_cmd:find("grep") then cmd = string.format("cd %s && git ls-files -co --exclude-standard | xargs %s -rH", abs_path, search_cmd, abs_path) - if not opts.case_sensitive then cmd = string.format("%s -i", cmd) end - if opts.include_pattern then cmd = string.format("%s --include '%s'", cmd, opts.include_pattern) end - if opts.exclude_pattern then cmd = string.format("%s --exclude '%s'", cmd, opts.exclude_pattern) end - cmd = string.format("%s '%s'", cmd, opts.query) + if not input.case_sensitive then cmd = string.format("%s -i", cmd) end + if input.include_pattern then cmd = string.format("%s --include '%s'", cmd, input.include_pattern) end + if input.exclude_pattern then cmd = string.format("%s --exclude '%s'", cmd, input.exclude_pattern) end + cmd = string.format("%s '%s'", cmd, input.query) end Utils.debug("cmd", cmd) diff --git a/lua/avante/llm_tools/init.lua b/lua/avante/llm_tools/init.lua index 3c71d81..825577f 100644 --- a/lua/avante/llm_tools/init.lua +++ b/lua/avante/llm_tools/init.lua @@ -8,9 +8,10 @@ local Helpers = require("avante.llm_tools.helpers") local M = {} ---@type AvanteLLMToolFunc<{ path: string }> -function M.read_file_toplevel_symbols(opts, on_log, on_complete, session_ctx) +function M.read_file_toplevel_symbols(input, opts) + local on_log = opts.on_log local RepoMap = require("avante.repo_map") - local abs_path = Helpers.get_abs_path(opts.path) + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end if on_log then on_log("path: " .. abs_path) end if not Path:new(abs_path):exists() then return "", "File does not exists: " .. abs_path end @@ -24,50 +25,47 @@ function M.read_file_toplevel_symbols(opts, on_log, on_complete, session_ctx) end ---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert" | "undo_edit", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[] }> -function M.str_replace_editor(opts, on_log, on_complete, session_ctx) - if opts.command == "undo_edit" then - return require("avante.llm_tools.undo_edit").func(opts, on_log, on_complete, session_ctx) - end - ---@cast opts any - return M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx) +function M.str_replace_editor(input, opts) + if input.command == "undo_edit" then return require("avante.llm_tools.undo_edit").func(input, opts) end + ---@cast input any + return M.str_replace_based_edit_tool(input, opts) end ---@type AvanteLLMToolFunc<{ command: "view" | "str_replace" | "create" | "insert", path: string, old_str?: string, new_str?: string, file_text?: string, insert_line?: integer, new_str?: string, view_range?: integer[], streaming?: boolean }> -function M.str_replace_based_edit_tool(opts, on_log, on_complete, session_ctx) - if not opts.command then return false, "command not provided" end - if on_log then on_log("command: " .. opts.command) end +function M.str_replace_based_edit_tool(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + if not input.command then return false, "command not provided" end + if on_log then on_log("command: " .. input.command) end if not on_complete then return false, "on_complete not provided" end - local abs_path = Helpers.get_abs_path(opts.path) + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end - if opts.command == "view" then + if input.command == "view" then local view = require("avante.llm_tools.view") - local opts_ = { path = opts.path } - if opts.view_range then - local start_line, end_line = unpack(opts.view_range) - opts_.start_line = start_line - opts_.end_line = end_line + local input_ = { path = input.path } + if input.view_range then + local start_line, end_line = unpack(input.view_range) + input_.start_line = start_line + input_.end_line = end_line end - return view(opts_, on_log, on_complete, session_ctx) + return view(input_, opts) end - if opts.command == "str_replace" then - if opts.new_str == nil and opts.file_text ~= nil then - opts.new_str = opts.file_text - opts.file_text = nil + if input.command == "str_replace" then + if input.new_str == nil and input.file_text ~= nil then + input.new_str = input.file_text + input.file_text = nil end - return require("avante.llm_tools.str_replace").func(opts, on_log, on_complete, session_ctx) + return require("avante.llm_tools.str_replace").func(input, opts) end - if opts.command == "create" then - return require("avante.llm_tools.create").func(opts, on_log, on_complete, session_ctx) - end - if opts.command == "insert" then - return require("avante.llm_tools.insert").func(opts, on_log, on_complete, session_ctx) - end - return false, "Unknown command: " .. opts.command + if input.command == "create" then return require("avante.llm_tools.create").func(input, opts) end + if input.command == "insert" then return require("avante.llm_tools.insert").func(input, opts) end + return false, "Unknown command: " .. input.command end ---@type AvanteLLMToolFunc<{ abs_path: string }> -function M.read_global_file(opts, on_log, on_complete, session_ctx) - local abs_path = Helpers.get_abs_path(opts.abs_path) +function M.read_global_file(input, opts) + local on_log = opts.on_log + local abs_path = Helpers.get_abs_path(input.abs_path) if Helpers.is_ignored(abs_path) then return "", "This file is ignored: " .. abs_path end if on_log then on_log("path: " .. abs_path) end local file = io.open(abs_path, "r") @@ -78,11 +76,13 @@ function M.read_global_file(opts, on_log, on_complete, session_ctx) end ---@type AvanteLLMToolFunc<{ abs_path: string, content: string }> -function M.write_global_file(opts, on_log, on_complete, session_ctx) - local abs_path = Helpers.get_abs_path(opts.abs_path) +function M.write_global_file(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + local abs_path = Helpers.get_abs_path(input.abs_path) if Helpers.is_ignored(abs_path) then return false, "This file is ignored: " .. abs_path end if on_log then on_log("path: " .. abs_path) end - if on_log then on_log("content: " .. opts.content) end + if on_log then on_log("content: " .. input.content) end if not on_complete then return false, "on_complete not provided" end Helpers.confirm("Are you sure you want to write to the file: " .. abs_path, function(ok) if not ok then @@ -94,18 +94,20 @@ function M.write_global_file(opts, on_log, on_complete, session_ctx) on_complete(false, "file not found: " .. abs_path) return end - file:write(opts.content) + file:write(input.content) file:close() on_complete(true, nil) - end, nil, session_ctx, "write_global_file") + end, nil, opts.session_ctx, "write_global_file") end ---@type AvanteLLMToolFunc<{ source_path: string, destination_path: string }> -function M.move_path(opts, on_log, on_complete, session_ctx) - local abs_path = Helpers.get_abs_path(opts.source_path) +function M.move_path(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + local abs_path = Helpers.get_abs_path(input.source_path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "The source path not found: " .. abs_path end - local new_abs_path = Helpers.get_abs_path(opts.destination_path) + local new_abs_path = Helpers.get_abs_path(input.destination_path) if on_log then on_log(abs_path .. " -> " .. new_abs_path) end if not Helpers.has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path @@ -123,17 +125,19 @@ function M.move_path(opts, on_log, on_complete, session_ctx) on_complete(true, nil) end, nil, - session_ctx, + opts.session_ctx, "move_path" ) end ---@type AvanteLLMToolFunc<{ source_path: string, destination_path: string }> -function M.copy_path(opts, on_log, on_complete, session_ctx) - local abs_path = Helpers.get_abs_path(opts.source_path) +function M.copy_path(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + local abs_path = Helpers.get_abs_path(input.source_path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "The source path not found: " .. abs_path end - local new_abs_path = Helpers.get_abs_path(opts.destination_path) + local new_abs_path = Helpers.get_abs_path(input.destination_path) if not Helpers.has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end @@ -169,14 +173,16 @@ function M.copy_path(opts, on_log, on_complete, session_ctx) on_complete(true, nil) end, nil, - session_ctx, + opts.session_ctx, "copy_path" ) end ---@type AvanteLLMToolFunc<{ path: string }> -function M.delete_path(opts, on_log, on_complete, session_ctx) - local abs_path = Helpers.get_abs_path(opts.path) +function M.delete_path(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end if not on_complete then return false, "on_complete not provided" end @@ -188,12 +194,14 @@ function M.delete_path(opts, on_log, on_complete, session_ctx) if on_log then on_log("Deleting path: " .. abs_path) end os.remove(abs_path) on_complete(true, nil) - end, nil, session_ctx, "delete_path") + end, nil, opts.session_ctx, "delete_path") end ---@type AvanteLLMToolFunc<{ path: string }> -function M.create_dir(opts, on_log, on_complete, session_ctx) - local abs_path = Helpers.get_abs_path(opts.path) +function M.create_dir(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end if not on_complete then return false, "on_complete not provided" end @@ -205,16 +213,17 @@ function M.create_dir(opts, on_log, on_complete, session_ctx) if on_log then on_log("Creating directory: " .. abs_path) end Path:new(abs_path):mkdir({ parents = true }) on_complete(true, nil) - end, nil, session_ctx, "create_dir") + end, nil, opts.session_ctx, "create_dir") end ---@type AvanteLLMToolFunc<{ query: string }> -function M.web_search(opts, on_log, on_complete, session_ctx) +function M.web_search(input, opts) + local on_log = opts.on_log local provider_type = Config.web_search_engine.provider local proxy = Config.web_search_engine.proxy if provider_type == nil then return nil, "Search engine provider is not set" end if on_log then on_log("provider: " .. provider_type) end - if on_log then on_log("query: " .. opts.query) end + if on_log then on_log("query: " .. input.query) end local search_engine = Config.web_search_engine.providers[provider_type] if search_engine == nil then return nil, "No search engine found: " .. provider_type end if provider_type ~= "searxng" and search_engine.api_key_name == "" then return nil, "No API key provided" end @@ -229,7 +238,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx) ["Authorization"] = "Bearer " .. api_key, }, body = vim.json.encode(vim.tbl_deep_extend("force", { - query = opts.query, + query = input.query, }, search_engine.extra_request_body)), } if proxy then curl_opts.proxy = proxy end @@ -240,7 +249,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx) elseif provider_type == "serpapi" then local query_params = vim.tbl_deep_extend("force", { api_key = api_key, - q = opts.query, + q = input.query, }, search_engine.extra_request_body) local query_string = "" for key, value in pairs(query_params) do @@ -259,7 +268,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx) elseif provider_type == "searchapi" then local query_params = vim.tbl_deep_extend("force", { api_key = api_key, - q = opts.query, + q = input.query, }, search_engine.extra_request_body) local query_string = "" for key, value in pairs(query_params) do @@ -283,7 +292,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx) local query_params = vim.tbl_deep_extend("force", { key = api_key, cx = engine_id, - q = opts.query, + q = input.query, }, search_engine.extra_request_body) local query_string = "" for key, value in pairs(query_params) do @@ -301,7 +310,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx) return search_engine.format_response_body(jsn) elseif provider_type == "kagi" then local query_params = vim.tbl_deep_extend("force", { - q = opts.query, + q = input.query, }, search_engine.extra_request_body) local query_string = "" for key, value in pairs(query_params) do @@ -320,7 +329,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx) return search_engine.format_response_body(jsn) elseif provider_type == "brave" then local query_params = vim.tbl_deep_extend("force", { - q = opts.query, + q = input.query, }, search_engine.extra_request_body) local query_string = "" for key, value in pairs(query_params) do @@ -343,7 +352,7 @@ function M.web_search(opts, on_log, on_complete, session_ctx) return nil, "Environment variable " .. search_engine.api_url_name .. " is not set" end local query_params = vim.tbl_deep_extend("force", { - q = opts.query, + q = input.query, }, search_engine.extra_request_body) local query_string = "" for key, value in pairs(query_params) do @@ -362,16 +371,18 @@ function M.web_search(opts, on_log, on_complete, session_ctx) end ---@type AvanteLLMToolFunc<{ url: string }> -function M.fetch(opts, on_log, on_complete, session_ctx) - if on_log then on_log("url: " .. opts.url) end +function M.fetch(input, opts) + local on_log = opts.on_log + if on_log then on_log("url: " .. input.url) end local Html2Md = require("avante.html2md") - local res, err = Html2Md.fetch_md(opts.url) + local res, err = Html2Md.fetch_md(input.url) if err then return nil, err end return res, nil end ---@type AvanteLLMToolFunc<{ scope?: string }> -function M.git_diff(opts, on_log, on_complete, session_ctx) +function M.git_diff(input, opts) + local on_log = opts.on_log local git_cmd = vim.fn.exepath("git") if git_cmd == "" then return nil, "Git command not found" end local project_root = Utils.get_project_root() @@ -382,7 +393,7 @@ function M.git_diff(opts, on_log, on_complete, session_ctx) if git_dir == "" then return nil, "Not a git repository" end -- Get the diff - local scope = opts.scope or "" + local scope = input.scope or "" local cmd = string.format("git diff --cached %s", scope) if on_log then on_log("Running command: " .. cmd) end local diff = vim.fn.system(cmd) @@ -400,7 +411,10 @@ function M.git_diff(opts, on_log, on_complete, session_ctx) end ---@type AvanteLLMToolFunc<{ message: string, scope?: string }> -function M.git_commit(opts, on_log, on_complete, session_ctx) +function M.git_commit(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + local git_cmd = vim.fn.exepath("git") if git_cmd == "" then return false, "Git command not found" end local project_root = Utils.get_project_root() @@ -444,7 +458,7 @@ function M.git_commit(opts, on_log, on_complete, session_ctx) -- Prepare commit message local commit_msg_lines = {} - for line in opts.message:gmatch("[^\r\n]+") do + for line in input.message:gmatch("[^\r\n]+") do commit_msg_lines[#commit_msg_lines + 1] = line:gsub('"', '\\"') end commit_msg_lines[#commit_msg_lines + 1] = "" @@ -464,8 +478,8 @@ function M.git_commit(opts, on_log, on_complete, session_ctx) return end -- Stage changes if scope is provided - if opts.scope then - local stage_cmd = string.format("git add %s", opts.scope) + if input.scope then + local stage_cmd = string.format("git add %s", input.scope) if on_log then on_log("Staging files: " .. stage_cmd) end local stage_result = vim.fn.system(stage_cmd) if vim.v.shell_error ~= 0 then @@ -494,20 +508,23 @@ function M.git_commit(opts, on_log, on_complete, session_ctx) end on_complete(true, nil) - end, nil, session_ctx, "git_commit") + end, nil, opts.session_ctx, "git_commit") end ---@type AvanteLLMToolFunc<{ query: string }> -function M.rag_search(opts, on_log, on_complete, session_ctx) +function M.rag_search(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + if not on_complete then return nil, "on_complete not provided" end if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end - if not opts.query then return nil, "No query provided" end - if on_log then on_log("query: " .. opts.query) end + if not input.query then return nil, "No query provided" end + if on_log then on_log("query: " .. input.query) end local root = Utils.get_project_root() local uri = "file://" .. root if uri:sub(-1) ~= "/" then uri = uri .. "/" end RagService.retrieve( uri, - opts.query, + input.query, vim.schedule_wrap(function(resp, err) if err then on_complete(nil, err) @@ -519,13 +536,15 @@ function M.rag_search(opts, on_log, on_complete, session_ctx) end ---@type AvanteLLMToolFunc<{ code: string, path: string, container_image?: string }> -function M.python(opts, on_log, on_complete, session_ctx) - local abs_path = Helpers.get_abs_path(opts.path) +function M.python(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return nil, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return nil, "Path not found: " .. abs_path end if on_log then on_log("cwd: " .. abs_path) end - if on_log then on_log("code:\n" .. opts.code) end - local container_image = opts.container_image or "python:3.11-slim-bookworm" + if on_log then on_log("code:\n" .. input.code) end + local container_image = input.container_image or "python:3.11-slim-bookworm" if not on_complete then return nil, "on_complete not provided" end Helpers.confirm( "Are you sure you want to run the following python code in the `" @@ -533,7 +552,7 @@ function M.python(opts, on_log, on_complete, session_ctx) .. "` container, in the directory: `" .. abs_path .. "`?\n" - .. opts.code, + .. input.code, function(ok, reason) if not ok then on_complete(nil, "User declined, reason: " .. (reason or "unknown")) @@ -562,7 +581,7 @@ function M.python(opts, on_log, on_complete, session_ctx) container_image, "python", "-c", - opts.code, + input.code, }, { text = true, @@ -576,7 +595,7 @@ function M.python(opts, on_log, on_complete, session_ctx) ) end, nil, - session_ctx, + opts.session_ctx, "python" ) end @@ -1189,14 +1208,19 @@ You can delete the first file by providing a path of "directory1/a/something.txt --- compatibility alias for old calls & tests M.run_python = M.python +---@class avante.ProcessToolUseOpts +---@field session_ctx table +---@field on_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil +---@field set_tool_use_store? fun(tool_id: string, key: string, value: any): nil +---@field on_complete? fun(result: string | nil, error: string | nil): nil + ---@param tools AvanteLLMTool[] ---@param tool_use AvanteLLMToolUse ----@param on_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil ----@param on_complete? fun(result: string | nil, error: string | nil): nil ----@param session_ctx? table ---@return string | nil result ---@return string | nil error -function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx) +function M.process_tool_use(tools, tool_use, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete -- Check if execution is already cancelled if Helpers.is_cancelled then Utils.debug("Tool execution cancelled before starting: " .. tool_use.name) @@ -1274,25 +1298,32 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx) return result_str, err end - local result, err = func(input_json, function(log) - -- Check for cancellation during logging - if Helpers.is_cancelled then return end - if on_log then on_log(tool_use.id, tool_use.name, log, "running") end - end, function(result, err) - -- Check for cancellation before completing - if Helpers.is_cancelled then - Helpers.is_cancelled = false - if on_complete then on_complete(nil, Helpers.CANCEL_TOKEN) end - return - end + local result, err = func(input_json, { + session_ctx = opts.session_ctx or {}, + on_log = function(log) + -- Check for cancellation during logging + if Helpers.is_cancelled then return end + if on_log then on_log(tool_use.id, tool_use.name, log, "running") end + end, + set_store = function(key, value) + if opts.set_tool_use_store then opts.set_tool_use_store(tool_use.id, key, value) end + end, + on_complete = function(result, err) + -- Check for cancellation before completing + if Helpers.is_cancelled then + Helpers.is_cancelled = false + if on_complete then on_complete(nil, Helpers.CANCEL_TOKEN) end + return + end - result, err = handle_result(result, err) - if on_complete == nil then - Utils.error("asynchronous tool " .. tool_use.name .. " result not handled") - return - end - on_complete(result, err) - end, session_ctx) + result, err = handle_result(result, err) + if on_complete == nil then + Utils.error("asynchronous tool " .. tool_use.name .. " result not handled") + return + end + on_complete(result, err) + end, + }) -- Result and error being nil means that the tool was executed asynchronously if result == nil and err == nil and on_complete then return end diff --git a/lua/avante/llm_tools/insert.lua b/lua/avante/llm_tools/insert.lua index be17cdf..c82b9f7 100644 --- a/lua/avante/llm_tools/insert.lua +++ b/lua/avante/llm_tools/insert.lua @@ -55,19 +55,24 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ path: string, insert_line: integer, new_str: string }> -function M.func(opts, on_log, on_complete, session_ctx) - if on_log then on_log("path: " .. opts.path) end - local abs_path = Helpers.get_abs_path(opts.path) +function M.func(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + local session_ctx = opts.session_ctx + if not on_complete then return false, "on_complete not provided" end + + if on_log then on_log("path: " .. input.path) end + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end - if opts.insert_line == nil then return false, "insert_line not provided" end - if opts.new_str == nil then return false, "new_str not provided" end + if input.insert_line == nil then return false, "insert_line not provided" end + if input.new_str == nil then return false, "new_str not provided" end local ns_id = vim.api.nvim_create_namespace("avante_insert_diff") local bufnr, err = Helpers.get_bufnr(abs_path) if err then return false, err end local function clear_highlights() vim.api.nvim_buf_clear_namespace(bufnr, ns_id, 0, -1) end - local new_lines = vim.split(opts.new_str, "\n") + local new_lines = vim.split(input.new_str, "\n") local max_col = vim.o.columns local virt_lines = vim .iter(new_lines) @@ -78,8 +83,8 @@ function M.func(opts, on_log, on_complete, session_ctx) end) :totable() local line_count = vim.api.nvim_buf_line_count(bufnr) - if opts.insert_line > line_count - 1 then opts.insert_line = line_count - 1 end - vim.api.nvim_buf_set_extmark(bufnr, ns_id, opts.insert_line, 0, { + if input.insert_line > line_count - 1 then input.insert_line = line_count - 1 end + vim.api.nvim_buf_set_extmark(bufnr, ns_id, input.insert_line, 0, { virt_lines = virt_lines, hl_eol = true, hl_mode = "combine", @@ -90,9 +95,9 @@ function M.func(opts, on_log, on_complete, session_ctx) on_complete(false, "User declined, reason: " .. (reason or "unknown")) return end - vim.api.nvim_buf_set_lines(bufnr, opts.insert_line, opts.insert_line, false, new_lines) + vim.api.nvim_buf_set_lines(bufnr, input.insert_line, input.insert_line, false, new_lines) vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end) - if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end + if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end on_complete(true, nil) end, { focus = true }, session_ctx, M.name) end diff --git a/lua/avante/llm_tools/ls.lua b/lua/avante/llm_tools/ls.lua index 30b072e..20a47e3 100644 --- a/lua/avante/llm_tools/ls.lua +++ b/lua/avante/llm_tools/ls.lua @@ -46,15 +46,16 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ path: string, max_depth?: integer }> -function M.func(opts, on_log, on_complete, session_ctx) - local abs_path = Helpers.get_abs_path(opts.path) +function M.func(input, opts) + local on_log = opts.on_log + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end if on_log then on_log("path: " .. abs_path) end - if on_log then on_log("max depth: " .. tostring(opts.max_depth)) end + if on_log then on_log("max depth: " .. tostring(input.max_depth)) end local files = Utils.scan_directory({ directory = abs_path, add_dirs = true, - max_depth = opts.max_depth, + max_depth = input.max_depth, }) local filepaths = {} for _, file in ipairs(files) do diff --git a/lua/avante/llm_tools/replace_in_file.lua b/lua/avante/llm_tools/replace_in_file.lua index cf82c7f..b094b23 100644 --- a/lua/avante/llm_tools/replace_in_file.lua +++ b/lua/avante/llm_tools/replace_in_file.lua @@ -125,39 +125,44 @@ end --- IMPORTANT: Using "the_diff" instead of "diff" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "diff", making it impossible to achieve a streaming diff view. ---@type AvanteLLMToolFunc<{ path: string, diff: string, the_diff?: string, streaming?: boolean, tool_use_id?: string }> -function M.func(opts, on_log, on_complete, session_ctx) - if opts.the_diff ~= nil then - opts.diff = opts.the_diff - opts.the_diff = nil +function M.func(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + local session_ctx = opts.session_ctx + if not on_complete then return false, "on_complete not provided" end + + if input.the_diff ~= nil then + input.diff = input.the_diff + input.the_diff = nil end - if not opts.path or not opts.diff then return false, "path and diff are required " .. vim.inspect(opts) end - if on_log then on_log("path: " .. opts.path) end - local abs_path = Helpers.get_abs_path(opts.path) + if not input.path or not input.diff then return false, "path and diff are required " .. vim.inspect(input) end + if on_log then on_log("path: " .. input.path) end + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end - local is_streaming = opts.streaming or false + local is_streaming = input.streaming or false session_ctx.prev_streaming_diff_timestamp_map = session_ctx.prev_streaming_diff_timestamp_map or {} local current_timestamp = os.time() if is_streaming then - local prev_streaming_diff_timestamp = session_ctx.prev_streaming_diff_timestamp_map[opts.tool_use_id] + local prev_streaming_diff_timestamp = session_ctx.prev_streaming_diff_timestamp_map[input.tool_use_id] if prev_streaming_diff_timestamp ~= nil then if current_timestamp - prev_streaming_diff_timestamp < 2 then return false, "Diff hasn't changed in the last 2 seconds" end end - local streaming_diff_lines_count = Utils.count_lines(opts.diff) + local streaming_diff_lines_count = Utils.count_lines(input.diff) session_ctx.streaming_diff_lines_count_history = session_ctx.streaming_diff_lines_count_history or {} - local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] + local prev_streaming_diff_lines_count = session_ctx.streaming_diff_lines_count_history[input.tool_use_id] if streaming_diff_lines_count == prev_streaming_diff_lines_count then return false, "Diff lines count hasn't changed" end - session_ctx.streaming_diff_lines_count_history[opts.tool_use_id] = streaming_diff_lines_count + session_ctx.streaming_diff_lines_count_history[input.tool_use_id] = streaming_diff_lines_count end - local diff = fix_diff(opts.diff) + local diff = fix_diff(input.diff) - if on_log and diff ~= opts.diff then on_log("diff fixed") end + if on_log and diff ~= input.diff then on_log("diff fixed") end local diff_lines = vim.split(diff, "\n") @@ -203,16 +208,16 @@ function M.func(opts, on_log, on_complete, session_ctx) return false, "No diff blocks found" end - session_ctx.prev_streaming_diff_timestamp_map[opts.tool_use_id] = current_timestamp + session_ctx.prev_streaming_diff_timestamp_map[input.tool_use_id] = current_timestamp local bufnr, err = Helpers.get_bufnr(abs_path) if err then return false, err end session_ctx.undo_joined = session_ctx.undo_joined or {} - local undo_joined = session_ctx.undo_joined[opts.tool_use_id] + local undo_joined = session_ctx.undo_joined[input.tool_use_id] if not undo_joined then pcall(vim.cmd.undojoin) - session_ctx.undo_joined[opts.tool_use_id] = true + session_ctx.undo_joined[input.tool_use_id] = true end local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) @@ -242,10 +247,10 @@ function M.func(opts, on_log, on_complete, session_ctx) session_ctx.rough_diff_blocks_to_diff_blocks_cache_map = session_ctx.rough_diff_blocks_to_diff_blocks_cache_map or {} local rough_diff_blocks_to_diff_blocks_cache = - session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id] + session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[input.tool_use_id] if not rough_diff_blocks_to_diff_blocks_cache then rough_diff_blocks_to_diff_blocks_cache = {} - session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[opts.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache + session_ctx.rough_diff_blocks_to_diff_blocks_cache_map[input.tool_use_id] = rough_diff_blocks_to_diff_blocks_cache end local function rough_diff_blocks_to_diff_blocks(rough_diff_blocks_) @@ -472,7 +477,7 @@ function M.func(opts, on_log, on_complete, session_ctx) on_complete(false, "User canceled") return end - if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end + if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end on_complete(true, nil) end, }) @@ -615,35 +620,35 @@ function M.func(opts, on_log, on_complete, session_ctx) end session_ctx.extmark_id_map = session_ctx.extmark_id_map or {} - local extmark_id_map = session_ctx.extmark_id_map[opts.tool_use_id] + local extmark_id_map = session_ctx.extmark_id_map[input.tool_use_id] if not extmark_id_map then extmark_id_map = {} - session_ctx.extmark_id_map[opts.tool_use_id] = extmark_id_map + session_ctx.extmark_id_map[input.tool_use_id] = extmark_id_map end session_ctx.virt_lines_map = session_ctx.virt_lines_map or {} - local virt_lines_map = session_ctx.virt_lines_map[opts.tool_use_id] + local virt_lines_map = session_ctx.virt_lines_map[input.tool_use_id] if not virt_lines_map then virt_lines_map = {} - session_ctx.virt_lines_map[opts.tool_use_id] = virt_lines_map + session_ctx.virt_lines_map[input.tool_use_id] = virt_lines_map end session_ctx.last_orig_diff_end_line_map = session_ctx.last_orig_diff_end_line_map or {} - local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[opts.tool_use_id] + local last_orig_diff_end_line = session_ctx.last_orig_diff_end_line_map[input.tool_use_id] if not last_orig_diff_end_line then last_orig_diff_end_line = 1 - session_ctx.last_orig_diff_end_line_map[opts.tool_use_id] = last_orig_diff_end_line + session_ctx.last_orig_diff_end_line_map[input.tool_use_id] = last_orig_diff_end_line end session_ctx.last_resp_diff_end_line_map = session_ctx.last_resp_diff_end_line_map or {} - local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[opts.tool_use_id] + local last_resp_diff_end_line = session_ctx.last_resp_diff_end_line_map[input.tool_use_id] if not last_resp_diff_end_line then last_resp_diff_end_line = 1 - session_ctx.last_resp_diff_end_line_map[opts.tool_use_id] = last_resp_diff_end_line + session_ctx.last_resp_diff_end_line_map[input.tool_use_id] = last_resp_diff_end_line end session_ctx.prev_diff_blocks_map = session_ctx.prev_diff_blocks_map or {} - local prev_diff_blocks = session_ctx.prev_diff_blocks_map[opts.tool_use_id] + local prev_diff_blocks = session_ctx.prev_diff_blocks_map[input.tool_use_id] if not prev_diff_blocks then prev_diff_blocks = {} - session_ctx.prev_diff_blocks_map[opts.tool_use_id] = prev_diff_blocks + session_ctx.prev_diff_blocks_map[input.tool_use_id] = prev_diff_blocks end local function get_unstable_diff_blocks(diff_blocks_) @@ -663,7 +668,7 @@ function M.func(opts, on_log, on_complete, session_ctx) local function highlight_streaming_diff_blocks() local unstable_diff_blocks = get_unstable_diff_blocks(diff_blocks) - session_ctx.prev_diff_blocks_map[opts.tool_use_id] = diff_blocks + session_ctx.prev_diff_blocks_map[input.tool_use_id] = diff_blocks local max_col = vim.o.columns for _, diff_block in ipairs(unstable_diff_blocks) do local new_lines = diff_block.new_lines @@ -747,7 +752,7 @@ function M.func(opts, on_log, on_complete, session_ctx) --- check if the parent dir is exists, if not, create it if vim.fn.isdirectory(parent_dir) == 0 then vim.fn.mkdir(parent_dir, "p") end vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end) - if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end + if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end on_complete(true, nil) end, { focus = not Config.behaviour.auto_focus_on_diff_view }, session_ctx, M.name) end diff --git a/lua/avante/llm_tools/str_replace.lua b/lua/avante/llm_tools/str_replace.lua index d744d3d..1d0daf6 100644 --- a/lua/avante/llm_tools/str_replace.lua +++ b/lua/avante/llm_tools/str_replace.lua @@ -1,5 +1,4 @@ local Base = require("avante.llm_tools.base") -local Config = require("avante.config") ---@class AvanteLLMTool local M = setmetatable({}, Base) @@ -55,17 +54,17 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string, streaming?: boolean, tool_use_id?: string }> -function M.func(opts, on_log, on_complete, session_ctx) +function M.func(input, opts) local replace_in_file = require("avante.llm_tools.replace_in_file") - local diff = "------- SEARCH\n" .. opts.old_str .. "\n=======\n" .. opts.new_str - if not opts.streaming then diff = diff .. "\n+++++++ REPLACE" end - local new_opts = { - path = opts.path, + local diff = "------- SEARCH\n" .. input.old_str .. "\n=======\n" .. input.new_str + if not input.streaming then diff = diff .. "\n+++++++ REPLACE" end + local new_input = { + path = input.path, diff = diff, - streaming = opts.streaming, - tool_use_id = opts.tool_use_id, + streaming = input.streaming, + tool_use_id = input.tool_use_id, } - return replace_in_file.func(new_opts, on_log, on_complete, session_ctx) + return replace_in_file.func(new_input, opts) end return M diff --git a/lua/avante/llm_tools/think.lua b/lua/avante/llm_tools/think.lua index 8103335..51fb8b2 100644 --- a/lua/avante/llm_tools/think.lua +++ b/lua/avante/llm_tools/think.lua @@ -47,12 +47,13 @@ M.returns = { ---@field thought string ---@type avante.LLMToolOnRender -function M.on_render(opts, _, state) +function M.on_render(input, opts) + local state = opts.state local lines = {} local text = state == "generating" and "Thinking" or "Thoughts" table.insert(lines, Line:new({ { Utils.icon("🤔 ") .. text, Highlights.AVANTE_THINKING } })) table.insert(lines, Line:new({ { "" } })) - local content = opts.thought or "" + local content = input.thought or "" local text_lines = vim.split(content, "\n") for _, text_line in ipairs(text_lines) do table.insert(lines, Line:new({ { "> " .. text_line } })) @@ -61,7 +62,8 @@ function M.on_render(opts, _, state) end ---@type AvanteLLMToolFunc -function M.func(opts, on_log, on_complete, session_ctx) +function M.func(input, opts) + local on_complete = opts.on_complete if not on_complete then return false, "on_complete not provided" end on_complete(true, nil) end diff --git a/lua/avante/llm_tools/undo_edit.lua b/lua/avante/llm_tools/undo_edit.lua index afaba1d..8190394 100644 --- a/lua/avante/llm_tools/undo_edit.lua +++ b/lua/avante/llm_tools/undo_edit.lua @@ -43,9 +43,14 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ path: string }> -function M.func(opts, on_log, on_complete, session_ctx) - if on_log then on_log("path: " .. opts.path) end - local abs_path = Helpers.get_abs_path(opts.path) +function M.func(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + local session_ctx = opts.session_ctx + if not on_complete then return false, "on_complete not provided" end + + if on_log then on_log("path: " .. input.path) end + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end @@ -59,7 +64,7 @@ function M.func(opts, on_log, on_complete, session_ctx) end vim.api.nvim_win_call(winid, function() vim.cmd("noautocmd undo") end) vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end) - if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end + if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end on_complete(true, nil) end, { focus = true }, session_ctx, M.name) end diff --git a/lua/avante/llm_tools/update_todo_status.lua b/lua/avante/llm_tools/update_todo_status.lua index d6b49c2..84dc61f 100644 --- a/lua/avante/llm_tools/update_todo_status.lua +++ b/lua/avante/llm_tools/update_todo_status.lua @@ -43,14 +43,15 @@ M.returns = { M.on_render = function() return {} end ---@type AvanteLLMToolFunc<{ id: string, status: string }> -function M.func(opts, on_log, on_complete, session_ctx) +function M.func(input, opts) + local on_complete = opts.on_complete local sidebar = require("avante").get() if not sidebar then return false, "Avante sidebar not found" end local todos = sidebar.chat_history.todos if not todos or #todos == 0 then return false, "No todos found" end for _, todo in ipairs(todos) do - if todo.id == opts.id then - todo.status = opts.status + if todo.id == input.id then + todo.status = input.status break end end diff --git a/lua/avante/llm_tools/view.lua b/lua/avante/llm_tools/view.lua index 2fec92b..b0a74de 100644 --- a/lua/avante/llm_tools/view.lua +++ b/lua/avante/llm_tools/view.lua @@ -87,18 +87,20 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ path: string, start_line?: integer, end_line?: integer }> -function M.func(opts, on_log, on_complete, session_ctx) - if not opts.path then return false, "path is required" end - if on_log then on_log("path: " .. opts.path) end - local abs_path = Helpers.get_abs_path(opts.path) +function M.func(input, opts) + local on_log = opts.on_log + local on_complete = opts.on_complete + if not input.path then return false, "path is required" end + if on_log then on_log("path: " .. input.path) end + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end if Path:new(abs_path):is_dir() then return false, "Path is a directory: " .. abs_path end local file = io.open(abs_path, "r") if not file then return false, "file not found: " .. abs_path end local lines = Utils.read_file_from_buf_or_disk(abs_path) - local start_line = opts.start_line - local end_line = opts.end_line + local start_line = input.start_line + local end_line = input.end_line if start_line and end_line and lines then lines = vim.list_slice(lines, start_line, end_line) end local truncated_lines = {} local is_truncated = false diff --git a/lua/avante/llm_tools/write_to_file.lua b/lua/avante/llm_tools/write_to_file.lua index 4df1f73..de2b978 100644 --- a/lua/avante/llm_tools/write_to_file.lua +++ b/lua/avante/llm_tools/write_to_file.lua @@ -57,27 +57,26 @@ M.returns = { --- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view. ---@type AvanteLLMToolFunc<{ path: string, content: string, the_content?: string, streaming?: boolean, tool_use_id?: string }> -function M.func(opts, on_log, on_complete, session_ctx) - if opts.the_content ~= nil then - opts.content = opts.the_content - opts.the_content = nil +function M.func(input, opts) + if input.the_content ~= nil then + input.content = input.the_content + input.the_content = nil end - if not on_complete then return false, "on_complete not provided" end - local abs_path = Helpers.get_abs_path(opts.path) + local abs_path = Helpers.get_abs_path(input.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end - if opts.content == nil then return false, "content not provided" end - if type(opts.content) ~= "string" then opts.content = vim.json.encode(opts.content) end + if input.content == nil then return false, "content not provided" end + if type(input.content) ~= "string" then input.content = vim.json.encode(input.content) end local old_lines = Utils.read_file_from_buf_or_disk(abs_path) local old_content = table.concat(old_lines or {}, "\n") local str_replace = require("avante.llm_tools.str_replace") - local new_opts = { - path = opts.path, + local new_input = { + path = input.path, old_str = old_content, - new_str = opts.content, - streaming = opts.streaming, - tool_use_id = opts.tool_use_id, + new_str = input.content, + streaming = input.streaming, + tool_use_id = input.tool_use_id, } - return str_replace.func(new_opts, on_log, on_complete, session_ctx) + return str_replace.func(new_input, opts) end return M diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 0ee397f..296109f 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -2318,7 +2318,7 @@ function Sidebar:get_history_messages_for_api(opts) --- For models like gpt-4o, the input parameter of replace_in_file is treated as the latest file content, so here we need to insert a fake view tool call to ensure it uses the latest file content if is_edit_func_call and path and not message.message.content[1].is_error then local uniformed_path = Utils.uniform_path(path) - local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, nil, nil, nil) + local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, {}) if view_error then view_result = "Error: " .. view_error end local get_diagnostics_tool_use_id = Utils.uuid() local view_tool_use_id = Utils.uuid() @@ -2451,9 +2451,7 @@ function Sidebar:get_history_messages_for_api(opts) local end_line = tool_id_to_end_line[item.tool_use_id] local view_result, view_error = require("avante.llm_tools.view").func( { path = path, start_line = start_line, end_line = end_line }, - nil, - nil, - nil + {} ) if view_error then view_result = "Error: " .. view_error end item.content = view_result @@ -2773,6 +2771,23 @@ function Sidebar:create_input_container() self:save_history() end + local function set_tool_use_store(tool_id, key, value) + local tool_use_message = nil + for idx = #self.chat_history.messages, 1, -1 do + local message = self.chat_history.messages[idx] + local content = message.message.content + if type(content) == "table" and content[1].type == "tool_use" and content[1].id == tool_id then + tool_use_message = message + break + end + end + if not tool_use_message then return end + local tool_use_store = tool_use_message.tool_use_store or {} + tool_use_store[key] = value + tool_use_message.tool_use_store = tool_use_store + self:save_history() + end + ---@type AvanteLLMStopCallback local function on_stop(stop_opts) self.is_generating = false @@ -2837,6 +2852,7 @@ function Sidebar:create_input_container() on_tool_log = on_tool_log, on_messages_add = on_messages_add, on_state_change = on_state_change, + set_tool_use_store = set_tool_use_store, get_history_messages = function(opts) return self:get_history_messages_for_api(opts) end, get_todos = function() local history = Path.history.load(self.code.bufnr) diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 43e150e..3d90256 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -107,6 +107,7 @@ vim.g.avante_login = vim.g.avante_login ---@field selected_code AvanteSelectedCode | nil ---@field selected_filepaths string[] | nil ---@field tool_use_logs string[] | nil +---@field tool_use_store table | nil ---@field just_for_display boolean | nil ---@field is_dummy boolean | nil ---@field is_compacted boolean | nil @@ -407,19 +408,30 @@ vim.g.avante_login = vim.g.avante_login ---@field on_stop AvanteLLMStopCallback ---@field on_memory_summarize? AvanteLLMMemorySummarizeCallback ---@field on_tool_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil +---@field set_tool_use_store? fun(tool_id: string, key: string, value: any): nil ---@field get_history_messages? fun(opts?: { all?: boolean }): avante.HistoryMessage[] ---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil ---@field on_state_change? fun(state: avante.GenerateState): nil ---@field update_tokens_usage? fun(usage: avante.LLMTokenUsage): nil --- +---@class AvanteLLMToolFuncOpts +---@field session_ctx table +---@field on_complete? fun(result: boolean | string | nil, error: string | nil): nil +---@field on_log? fun(log: string): nil +---@field set_store? fun(key: string, value: any): nil +--- ---@alias AvanteLLMToolFunc fun( --- input: T, ---- on_log?: (fun(log: string): nil), ---- on_complete?: (fun(result: boolean | string | nil, error: string | nil): nil), ---- session_ctx?: table) +--- opts: AvanteLLMToolFuncOpts) --- : (boolean | string | nil, string | nil) --- ---- @alias avante.LLMToolOnRender fun(input: T, logs: string[], state: avante.HistoryMessageState | nil): avante.ui.Line[] +---@class avante.LLMToolOnRenderOpts +---@field logs string[] +---@field state avante.HistoryMessageState +---@field store table | nil +---@field result_message avante.HistoryMessage | nil +--- +--- @alias avante.LLMToolOnRender fun(input: T, opts: avante.LLMToolOnRenderOpts): avante.ui.Line[] --- ---@class AvanteLLMTool ---@field name string diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 050167d..5a54ee6 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -1674,14 +1674,22 @@ function M.message_content_item_to_lines(item, message, messages) return { Line:new({ { "![image](" .. item.source.media_type .. ": " .. item.source.data .. ")" } }) } end if item.type == "tool_use" then + local tool_result_message = M.get_tool_result_message(message, messages) local lines = {} local state = "generating" local hl = "AvanteStateSpinnerToolCalling" local ok, llm_tool = pcall(require, "avante.llm_tools." .. item.name) if ok then - if llm_tool.on_render then return llm_tool.on_render(item.input, message.tool_use_logs, message.state) end + ---@cast llm_tool AvanteLLMTool + if llm_tool.on_render then + return llm_tool.on_render(item.input, { + logs = message.tool_use_logs, + state = message.state, + store = message.tool_use_store, + result_message = tool_result_message, + }) + end end - local tool_result_message = M.get_tool_result_message(message, messages) if tool_result_message then local tool_result = tool_result_message.message.content[1] if tool_result.is_error then diff --git a/tests/llm_tools_spec.lua b/tests/llm_tools_spec.lua index 82747c5..7f8c2ba 100644 --- a/tests/llm_tools_spec.lua +++ b/tests/llm_tools_spec.lua @@ -53,21 +53,21 @@ describe("llm_tools", function() describe("ls", function() it("should list files in directory", function() - local result, err = ls({ path = ".", max_depth = 1 }) + local result, err = ls({ path = ".", max_depth = 1 }, {}) assert.is_nil(err) assert.falsy(result:find("avante.nvim")) assert.truthy(result:find("test.txt")) assert.falsy(result:find("test1.txt")) end) it("should list files in directory with depth", function() - local result, err = ls({ path = ".", max_depth = 2 }) + local result, err = ls({ path = ".", max_depth = 2 }, {}) assert.is_nil(err) assert.falsy(result:find("avante.nvim")) assert.truthy(result:find("test.txt")) assert.truthy(result:find("test1.txt")) end) it("should list files respecting gitignore", function() - local result, err = ls({ path = ".", max_depth = 2 }) + local result, err = ls({ path = ".", max_depth = 2 }, {}) assert.is_nil(err) assert.falsy(result:find("avante.nvim")) assert.truthy(result:find("test.txt")) @@ -78,49 +78,61 @@ describe("llm_tools", function() describe("view", function() it("should read file content", function() - view({ path = "test.txt" }, nil, function(content, err) - assert.is_nil(err) - assert.equals("test content", vim.json.decode(content).content) - end) + view({ path = "test.txt" }, { + on_complete = function(content, err) + assert.is_nil(err) + assert.equals("test content", vim.json.decode(content).content) + end, + }) end) it("should return error for non-existent file", function() - view({ path = "non_existent.txt" }, nil, function(content, err) - assert.truthy(err) - assert.equals("", content) - end) + view({ path = "non_existent.txt" }, { + on_complete = function(content, err) + assert.truthy(err) + assert.equals("", content) + end, + }) end) it("should read directory content", function() - view({ path = test_dir }, nil, function(content, err) - assert.is_nil(err) - assert.truthy(content:find("test.txt")) - assert.truthy(content:find("test content")) - end) + view({ path = test_dir }, { + on_complete = function(content, err) + assert.is_nil(err) + assert.truthy(content:find("test.txt")) + assert.truthy(content:find("test content")) + end, + }) end) end) describe("create_dir", function() it("should create new directory", function() - LlmTools.create_dir({ path = "new_dir" }, nil, function(success, err) - assert.is_nil(err) - assert.is_true(success) + LlmTools.create_dir({ path = "new_dir" }, { + session_ctx = {}, + on_complete = function(success, err) + assert.is_nil(err) + assert.is_true(success) - local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil - assert.is_true(dir_exists) - end) + local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil + assert.is_true(dir_exists) + end, + }) end) end) describe("delete_path", function() it("should delete existing file", function() - LlmTools.delete_path({ path = "test.txt" }, nil, function(success, err) - assert.is_nil(err) - assert.is_true(success) + LlmTools.delete_path({ path = "test.txt" }, { + session_ctx = {}, + on_complete = function(success, err) + assert.is_nil(err) + assert.is_true(success) - local file_exists = io.open(test_file, "r") ~= nil - assert.is_false(file_exists) - end) + local file_exists = io.open(test_file, "r") ~= nil + assert.is_false(file_exists) + end, + }) end) end) @@ -147,22 +159,22 @@ describe("llm_tools", function() file:write("this is nothing") file:close() - local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false }) + local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false }, {}) assert.is_nil(err) assert.truthy(result:find("searchable.txt")) assert.falsy(result:find("nothing.txt")) - local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true }) + local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true }, {}) assert.is_nil(err2) assert.truthy(result2:find("searchable.txt")) assert.falsy(result2:find("nothing.txt")) - local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true }) + local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true }, {}) assert.is_nil(err3) assert.falsy(result3:find("searchable.txt")) assert.falsy(result3:find("nothing.txt")) - local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false }) + local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false }, {}) assert.is_nil(err4) assert.truthy(result4:find("searchable.txt")) assert.falsy(result4:find("nothing.txt")) @@ -172,7 +184,7 @@ describe("llm_tools", function() query = "searchable", case_sensitive = false, exclude_pattern = "search*", - }) + }, {}) assert.is_nil(err5) assert.falsy(result5:find("searchable.txt")) assert.falsy(result5:find("nothing.txt")) @@ -191,7 +203,7 @@ describe("llm_tools", function() file:write("content for ag test") file:close() - local result, err = grep({ path = ".", query = "ag test" }) + local result, err = grep({ path = ".", query = "ag test" }, {}) assert.is_nil(err) assert.is_string(result) assert.truthy(result:find("ag_test.txt")) @@ -215,22 +227,22 @@ describe("llm_tools", function() file:write("this is nothing") file:close() - local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false }) + local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false }, {}) assert.is_nil(err) assert.truthy(result:find("searchable.txt")) assert.falsy(result:find("nothing.txt")) - local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true }) + local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true }, {}) assert.is_nil(err2) assert.truthy(result2:find("searchable.txt")) assert.falsy(result2:find("nothing.txt")) - local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true }) + local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true }, {}) assert.is_nil(err3) assert.falsy(result3:find("searchable.txt")) assert.falsy(result3:find("nothing.txt")) - local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false }) + local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false }, {}) assert.is_nil(err4) assert.truthy(result4:find("searchable.txt")) assert.falsy(result4:find("nothing.txt")) @@ -240,7 +252,7 @@ describe("llm_tools", function() query = "searchable", case_sensitive = false, exclude_pattern = "search*", - }) + }, {}) assert.is_nil(err5) assert.falsy(result5:find("searchable.txt")) assert.falsy(result5:find("nothing.txt")) @@ -250,18 +262,18 @@ describe("llm_tools", function() -- Mock exepath to return nothing vim.fn.exepath = function() return "" end - local result, err = grep({ path = ".", query = "test" }) + local result, err = grep({ path = ".", query = "test" }, {}) assert.equals("", result) assert.equals("No search command found", err) end) it("should respect path permissions", function() - local result, err = grep({ path = "../outside_project", query = "test" }) + local result, err = grep({ path = "../outside_project", query = "test" }, {}) assert.truthy(err:find("No permission to access path")) end) it("should handle non-existent paths", function() - local result, err = grep({ path = "non_existent_dir", query = "test" }) + local result, err = grep({ path = "non_existent_dir", query = "test" }, {}) assert.equals("", result) assert.truthy(err) assert.truthy(err:find("No such file or directory")) @@ -277,86 +289,84 @@ describe("llm_tools", function() -- end) it("should return error when running outside current directory", function() - bash({ path = "../outside_project", command = "echo 'test'" }, nil, function(result, err) - assert.is_false(result) - assert.truthy(err) - assert.truthy(err:find("No permission to access path")) - end) + bash({ path = "../outside_project", command = "echo 'test'" }, { + session_ctx = {}, + on_complete = function(result, err) + assert.is_false(result) + assert.truthy(err) + assert.truthy(err:find("No permission to access path")) + end, + }) end) end) describe("python", function() it("should execute Python code and return output", function() - LlmTools.python( - { - path = ".", - code = "print('Hello from Python')", - }, - nil, - function(result, err) + LlmTools.python({ + path = ".", + code = "print('Hello from Python')", + }, { + session_ctx = {}, + on_complete = function(result, err) assert.is_nil(err) assert.equals("Hello from Python\n", result) - end - ) + end, + }) end) it("should handle Python errors", function() - LlmTools.python( - { - path = ".", - code = "print(undefined_variable)", - }, - nil, - function(result, err) + LlmTools.python({ + path = ".", + code = "print(undefined_variable)", + }, { + session_ctx = {}, + on_complete = function(result, err) assert.is_nil(result) assert.truthy(err) assert.truthy(err:find("Error")) - end - ) + end, + }) end) it("should respect path permissions", function() - LlmTools.python( - { - path = "../outside_project", - code = "print('test')", - }, - nil, - function(result, err) + LlmTools.python({ + path = "../outside_project", + code = "print('test')", + }, { + session_ctx = {}, + on_complete = function(result, err) assert.is_nil(result) assert.truthy(err:find("No permission to access path")) - end - ) + end, + }) end) it("should handle non-existent paths", function() - LlmTools.python( - { - path = "non_existent_dir", - code = "print('test')", - }, - nil, - function(result, err) + LlmTools.python({ + path = "non_existent_dir", + code = "print('test')", + }, { + session_ctx = {}, + on_complete = function(result, err) assert.is_nil(result) assert.truthy(err:find("Path not found")) - end - ) + end, + }) end) it("should support custom container image", function() os.execute("docker image rm python:3.12-slim") - LlmTools.python( - { - path = ".", - code = "print('Hello from custom container')", - container_image = "python:3.12-slim", - }, - nil, - function(result, err) + LlmTools.python({ + path = ".", + code = "print('Hello from custom container')", + container_image = "python:3.12-slim", + }, { + session_ctx = {}, + on_complete = function(result, err) assert.is_nil(err) assert.equals("Hello from custom container\n", result) - end - ) + end, + }) end) end) @@ -370,7 +380,7 @@ describe("llm_tools", function() os.execute("touch " .. test_dir .. "/nested/file4.lua") -- Test for lua files in the root - local result, err = glob({ path = ".", pattern = "*.lua" }) + local result, err = glob({ path = ".", pattern = "*.lua" }, {}) assert.is_nil(err) local files = vim.json.decode(result).matches assert.equals(2, #files) @@ -380,7 +390,7 @@ describe("llm_tools", function() assert.falsy(vim.tbl_contains(files, test_dir .. "/nested/file4.lua")) -- Test with recursive pattern - local result2, err2 = glob({ path = ".", pattern = "**/*.lua" }) + local result2, err2 = glob({ path = ".", pattern = "**/*.lua" }, {}) assert.is_nil(err2) local files2 = vim.json.decode(result2).matches assert.equals(3, #files2) @@ -390,13 +400,13 @@ describe("llm_tools", function() end) it("should respect path permissions", function() - local result, err = glob({ path = "../outside_project", pattern = "*.txt" }) + local result, err = glob({ path = "../outside_project", pattern = "*.txt" }, {}) assert.equals("", result) assert.truthy(err:find("No permission to access path")) end) it("should handle patterns without matches", function() - local result, err = glob({ path = ".", pattern = "*.nonexistent" }) + local result, err = glob({ path = ".", pattern = "*.nonexistent" }, {}) assert.is_nil(err) local files = vim.json.decode(result).matches assert.equals(0, #files) @@ -411,7 +421,7 @@ describe("llm_tools", function() os.execute("touch " .. test_dir .. "/test_dir1/notignored1.lua") os.execute("touch " .. test_dir .. "/test_dir1/notignored2.lua") - local result, err = glob({ path = ".", pattern = "**/*.lua" }) + local result, err = glob({ path = ".", pattern = "**/*.lua" }, {}) assert.is_nil(err) local files = vim.json.decode(result).matches