From 901e1caa918be0b9620b09466fb2ffa8c57e1de9 Mon Sep 17 00:00:00 2001 From: Peter Cardenas <16930781+PeterCardenas@users.noreply.github.com> Date: Mon, 24 Feb 2025 21:22:36 -0800 Subject: [PATCH] feat: run_command is run asynchronously (#1377) --- lua/avante/llm.lua | 36 ++++++++++++++----- lua/avante/llm_tools.lua | 73 +++++++++++++++++++++++++++------------ lua/avante/types.lua | 6 +++- lua/avante/utils/init.lua | 39 +++++++++++++++++---- 4 files changed, 116 insertions(+), 38 deletions(-) diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 6baac85..5548767 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -164,21 +164,39 @@ function M._stream(opts) on_start = opts.on_start, on_chunk = opts.on_chunk, on_stop = function(stop_opts) - if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then - local old_tool_histories = vim.deepcopy(opts.tool_histories) or {} - for _, tool_use in vim.spairs(stop_opts.tool_use_list) do - local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log) + ---@param tool_use_list AvanteLLMToolUse[] + ---@param tool_use_index integer + ---@param tool_histories AvanteLLMToolHistory[] + local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories) + if tool_use_index > #tool_use_list then + local new_opts = vim.tbl_deep_extend("force", opts, { + tool_histories = tool_histories, + }) + return M._stream(new_opts) + end + local tool_use = tool_use_list[tool_use_index] + ---@param result string | nil + ---@param error string | nil + local function handle_tool_result(result, error) local tool_result = { tool_use_id = tool_use.id, content = error ~= nil and error or result, is_error = error ~= nil, } - table.insert(old_tool_histories, { tool_result = tool_result, tool_use = tool_use }) + table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use }) + return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories) end - local new_opts = vim.tbl_deep_extend("force", opts, { - tool_histories = old_tool_histories, - }) - return M._stream(new_opts) + -- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil + local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result) + if result ~= nil or error ~= nil then return handle_tool_result(result, error) end + end + if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then + local old_tool_histories = vim.deepcopy(opts.tool_histories) or {} + local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[] + for _, tool_use in vim.spairs(stop_opts.tool_use_list) do + table.insert(sorted_tool_use_list, tool_use) + end + return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories) end return opts.on_stop(stop_opts) end, diff --git a/lua/avante/llm_tools.lua b/lua/avante/llm_tools.lua index 27207c6..e994403 100644 --- a/lua/avante/llm_tools.lua +++ b/lua/avante/llm_tools.lua @@ -238,7 +238,7 @@ function M.delete_dir(opts, on_log) end ---@type AvanteLLMToolFunc<{ rel_path: string, command: string }> -function M.run_command(opts, on_log) +function M.run_command(opts, on_log, on_complete) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end @@ -251,13 +251,27 @@ function M.run_command(opts, on_log) ---change cwd to abs_path local old_cwd = vim.fn.getcwd() vim.fn.chdir(abs_path) - local res = Utils.shell_run(opts.command, Config.run_command.shell_cmd) - vim.fn.chdir(old_cwd) - if res.code ~= 0 then - if res.stdout then return false, "Error: " .. res.stdout .. "; Error code: " .. tostring(res.code) end - return false, "Error code: " .. tostring(res.code) + ---@param output string + ---@param exit_code integer + ---@return string | boolean | nil result + ---@return string | nil error + local function handle_result(output, exit_code) + vim.fn.chdir(old_cwd) + if exit_code ~= 0 then + if output then return false, "Error: " .. output .. "; Error code: " .. tostring(exit_code) end + return false, "Error code: " .. tostring(exit_code) + end + return output, nil end - return res.stdout, nil + if on_complete then + Utils.shell_run_async(opts.command, Config.run_command.shell_cmd, function(output, exit_code) + local result, err = handle_result(output, exit_code) + on_complete(result, err) + end) + return nil, nil + end + local res = Utils.shell_run(opts.command, Config.run_command.shell_cmd) + return handle_result(res.stdout, res.code) end ---@type AvanteLLMToolFunc<{ query: string }> @@ -1100,9 +1114,10 @@ M._tools = { ---@param tools AvanteLLMTool[] ---@param tool_use AvanteLLMToolUse ---@param on_log? fun(tool_name: string, log: string): nil +---@param on_complete? fun(result: string | nil, error: string | nil): nil ---@return string | nil result ---@return string | nil error -function M.process_tool_use(tools, tool_use, on_log) +function M.process_tool_use(tools, tool_use, on_log, on_complete) Utils.debug("use tool", tool_use.name, tool_use.input_json) ---@type AvanteLLMTool? local tool = vim.iter(tools):find(function(tool) return tool.name == tool_use.name end) ---@param tool AvanteLLMTool @@ -1110,22 +1125,36 @@ function M.process_tool_use(tools, tool_use, on_log) local input_json = vim.json.decode(tool_use.input_json) local func = tool.func or M[tool.name] if on_log then on_log(tool.name, "running tool") end - local result, error = func(input_json, function(log) + ---@param result string | nil | boolean + ---@param err string | nil + local function handle_result(result, err) + if on_log then on_log(tool.name, "tool finished") end + -- Utils.debug("result", result) + -- Utils.debug("error", error) + if err ~= nil then + if on_log then on_log(tool.name, "Error: " .. err) end + end + local result_str ---@type string? + if type(result) == "string" then + result_str = result + elseif result ~= nil then + result_str = vim.json.encode(result) + end + return result_str, err + end + local result, err = func(input_json, function(log) if on_log then on_log(tool.name, log) end + end, function(result, err) + result, err = handle_result(result, err) + if on_complete == nil then + Utils.error("asynchronous tool " .. tool.name .. " result not handled") + return + end + on_complete(result, err) end) - if on_log then on_log(tool.name, "tool finished") end - -- Utils.debug("result", result) - -- Utils.debug("error", error) - if error ~= nil then - if on_log then on_log(tool.name, "Error: " .. error) end - end - local result_str ---@type string? - if type(result) == "string" then - result_str = result - elseif result ~= nil then - result_str = vim.json.encode(result) - end - return result_str, error + -- Result and error being nil means that the tool was executed asynchronously + if result == nil and err == nil and on_complete then return end + return handle_result(result, err) end ---@param tool_use AvanteLLMToolUse diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 04ed5dc..a5debcb 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -323,7 +323,11 @@ vim.g.avante_login = vim.g.avante_login ---@field on_stop AvanteLLMStopCallback ---@field on_tool_log? function(tool_name: string, log: string): nil --- ----@alias AvanteLLMToolFunc fun(input: T, on_log?: (fun(log: string): nil) | nil): (boolean | string | nil, string | nil) +---@alias AvanteLLMToolFunc fun( +--- input: T, +--- on_log?: (fun(log: string): nil) | nil, +--- on_complete?: (fun(result: boolean | string | nil, error: string | nil): nil) | nil) +--- : (boolean | string | nil, string | nil) --- ---@class AvanteLLMTool ---@field name string diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index b2ee2aa..11a7f22 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -71,14 +71,11 @@ function M.get_system_info() return res end ---- This function will run given shell command synchronously. ---@param input_cmd string ---@param shell_cmd string? ----@return vim.SystemCompleted -function M.shell_run(input_cmd, shell_cmd) +local function get_cmd_for_shell(input_cmd, shell_cmd) local shell = vim.o.shell:lower() - ---@type string - local cmd + local cmd ---@type string -- powershell then we can just run the cmd if shell:match("powershell") or shell:match("pwsh") then @@ -89,17 +86,47 @@ function M.shell_run(input_cmd, shell_cmd) elseif fn.has("win32") > 0 then cmd = 'powershell.exe -NoProfile -Command "' .. input_cmd:gsub('"', "'") .. '"' else - -- linux and macos we wil just do sh -c + -- linux and macos we will just do sh -c shell_cmd = shell_cmd or "sh -c" cmd = shell_cmd .. " " .. fn.shellescape(input_cmd) end + return cmd +end + +--- This function will run given shell command synchronously. +---@param input_cmd string +---@param shell_cmd string? +---@return vim.SystemCompleted +function M.shell_run(input_cmd, shell_cmd) + local cmd = get_cmd_for_shell(input_cmd, shell_cmd) + local output = fn.system(cmd) local code = vim.v.shell_error return { stdout = output, code = code } end +---@param input_cmd string +---@param shell_cmd string? +---@param on_complete fun(output: string, code: integer) +function M.shell_run_async(input_cmd, shell_cmd, on_complete) + local cmd = get_cmd_for_shell(input_cmd, shell_cmd) + ---@type string[] + local output = {} + fn.jobstart(cmd, { + on_stdout = function(_, data) + if not data then return end + vim.list_extend(output, data) + end, + on_stderr = function(_, data) + if not data then return end + vim.list_extend(output, data) + end, + on_exit = function(_, exit_code) on_complete(table.concat(output, "\n"), exit_code) end, + }) +end + ---@see https://github.com/LazyVim/LazyVim/blob/main/lua/lazyvim/util/toggle.lua --- ---@alias _ToggleSet fun(state: boolean): nil