From 6e77da83c15578ed5d10153a6d91dfcfc1f23859 Mon Sep 17 00:00:00 2001 From: yetone Date: Mon, 17 Mar 2025 01:40:05 +0800 Subject: [PATCH] fix: better sidebar (#1603) * fix: better sidebar * feat: better msg history * fix: tests --- lua/avante/api.lua | 9 + lua/avante/config.lua | 24 +- lua/avante/file_selector.lua | 1 + lua/avante/highlights.lua | 7 + lua/avante/llm.lua | 1 + lua/avante/llm_tools.lua | 508 ++++++++++++------ lua/avante/providers/claude.lua | 47 +- lua/avante/providers/gemini.lua | 33 +- lua/avante/providers/openai.lua | 51 +- lua/avante/sidebar.lua | 122 +++-- .../templates/_tools-guidelines.avanterules | 1 + lua/avante/templates/base.avanterules | 8 - lua/avante/types.lua | 11 +- lua/avante/ui.lua | 139 +++++ lua/avante/utils/history.lua | 61 ++- lua/avante/utils/tokens.lua | 24 +- tests/llm_tools_spec.lua | 142 +++-- 17 files changed, 870 insertions(+), 319 deletions(-) create mode 100644 lua/avante/ui.lua diff --git a/lua/avante/api.lua b/lua/avante/api.lua index beea54a..e7abbe9 100644 --- a/lua/avante/api.lua +++ b/lua/avante/api.lua @@ -99,6 +99,8 @@ end ---@field win? table windows options similar to |nvim_open_win()| ---@field ask? boolean ---@field floating? boolean whether to open a floating input to enter the question +---@field new_chat? boolean whether to open a new chat +---@field without_selection? boolean whether to open a new chat without selection ---@param opts? AskOptions function M.ask(opts) @@ -117,6 +119,7 @@ function M.ask(opts) opts = vim.tbl_extend("force", { selection = Utils.get_visual_selection_and_range() }, opts) + ---@param input string | nil local function ask(input) if input == nil or input == "" then input = opts.question end local sidebar = require("avante").get() @@ -124,6 +127,12 @@ function M.ask(opts) sidebar:close({ goto_code_win = false }) end require("avante").open_sidebar(opts) + if opts.new_chat then sidebar:new_chat() end + if opts.without_selection then + sidebar.code.selection = nil + sidebar.file_selector:reset() + if sidebar.selected_files_container then sidebar.selected_files_container:unmount() end + end if input == nil or input == "" then return true end vim.api.nvim_exec_autocmds("User", { pattern = "AvanteInputSubmitted", data = { request = input } }) return true diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 6ba3cc1..f3661aa 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -195,7 +195,7 @@ M._defaults = { model = "gpt-4o", timeout = 30000, -- Timeout in milliseconds temperature = 0, - max_tokens = 4096, + max_tokens = 8192, }, ---@type AvanteSupportedProvider copilot = { @@ -205,7 +205,7 @@ M._defaults = { allow_insecure = false, -- Allow insecure server connections timeout = 30000, -- Timeout in milliseconds temperature = 0, - max_tokens = 4096, + max_tokens = 8192, }, ---@type AvanteAzureProvider azure = { @@ -214,7 +214,7 @@ M._defaults = { api_version = "2024-06-01", timeout = 30000, -- Timeout in milliseconds temperature = 0, - max_tokens = 4096, + max_tokens = 8192, }, ---@type AvanteSupportedProvider claude = { @@ -222,14 +222,14 @@ M._defaults = { model = "claude-3-7-sonnet-20250219", timeout = 30000, -- Timeout in milliseconds temperature = 0, - max_tokens = 8000, + max_tokens = 8192, }, ---@type AvanteSupportedProvider bedrock = { model = "anthropic.claude-3-5-sonnet-20241022-v2:0", timeout = 30000, -- Timeout in milliseconds temperature = 0, - max_tokens = 8000, + max_tokens = 8192, }, ---@type AvanteSupportedProvider gemini = { @@ -237,7 +237,7 @@ M._defaults = { model = "gemini-1.5-flash-latest", timeout = 30000, -- Timeout in milliseconds temperature = 0, - max_tokens = 4096, + max_tokens = 8192, }, ---@type AvanteSupportedProvider vertex = { @@ -245,7 +245,7 @@ M._defaults = { model = "gemini-1.5-flash-002", timeout = 30000, -- Timeout in milliseconds temperature = 0, - max_tokens = 4096, + max_tokens = 8192, }, ---@type AvanteSupportedProvider cohere = { @@ -253,7 +253,7 @@ M._defaults = { model = "command-r-plus-08-2024", timeout = 30000, -- Timeout in milliseconds temperature = 0, - max_tokens = 4096, + max_tokens = 8192, }, ---@type AvanteSupportedProvider ollama = { @@ -261,7 +261,7 @@ M._defaults = { timeout = 30000, -- Timeout in milliseconds options = { temperature = 0, - num_ctx = 4096, + num_ctx = 8192, }, }, ---@type AvanteSupportedProvider @@ -270,7 +270,7 @@ M._defaults = { model = "claude-3-5-sonnet-v2@20241022", timeout = 30000, -- Timeout in milliseconds temperature = 0, - max_tokens = 4096, + max_tokens = 8192, }, ---To add support for custom provider, follow the format below ---See https://github.com/yetone/avante.nvim/wiki#custom-providers for more details @@ -282,7 +282,7 @@ M._defaults = { model = "claude-3-5-haiku-20241022", timeout = 30000, -- Timeout in milliseconds temperature = 0, - max_tokens = 8000, + max_tokens = 8192, }, ---@type AvanteSupportedProvider ["claude-opus"] = { @@ -290,7 +290,7 @@ M._defaults = { model = "claude-3-opus-20240229", timeout = 30000, -- Timeout in milliseconds temperature = 0, - max_tokens = 8000, + max_tokens = 8192, }, ["openai-gpt-4o-mini"] = { __inherited_from = "openai", diff --git a/lua/avante/file_selector.lua b/lua/avante/file_selector.lua index 1cbc2cf..99440b8 100644 --- a/lua/avante/file_selector.lua +++ b/lua/avante/file_selector.lua @@ -85,6 +85,7 @@ end function FileSelector:reset() self.selected_filepaths = {} self.event_handlers = {} + self:emit("update") end function FileSelector:add_selected_file(filepath) diff --git a/lua/avante/highlights.lua b/lua/avante/highlights.lua index 35c4228..0fac9ff 100644 --- a/lua/avante/highlights.lua +++ b/lua/avante/highlights.lua @@ -18,6 +18,13 @@ local Highlights = { INLINE_HINT = { name = "AvanteInlineHint", link = "Keyword" }, TO_BE_DELETED = { name = "AvanteToBeDeleted", bg = "#ffcccc", strikethrough = true }, TO_BE_DELETED_WITHOUT_STRIKETHROUGH = { name = "AvanteToBeDeletedWOStrikethrough", bg = "#562C30" }, + CONFIRM_TITLE = { name = "AvanteConfirmTitle", fg = "#1e222a", bg = "#e06c75" }, + BUTTON_DEFAULT = { name = "AvanteButtonDefault", fg = "#1e222a", bg = "#ABB2BF" }, + BUTTON_DEFAULT_HOVER = { name = "AvanteButtonDefaultHover", fg = "#1e222a", bg = "#a9cf8a" }, + BUTTON_PRIMARY = { name = "AvanteButtonPrimary", fg = "#1e222a", bg = "#ABB2BF" }, + BUTTON_PRIMARY_HOVER = { name = "AvanteButtonPrimaryHover", fg = "#1e222a", bg = "#56b6c2" }, + BUTTON_DANGER = { name = "AvanteButtonDanger", fg = "#1e222a", bg = "#ABB2BF" }, + BUTTON_DANGER_HOVER = { name = "AvanteButtonDangerHover", fg = "#1e222a", bg = "#e06c75" }, } Highlights.conflict = { diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index ee19c36..d5bb742 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -509,6 +509,7 @@ function M._stream(opts) end, stop_opts.retry_after * 1000) return end + stop_opts.tool_histories = opts.tool_histories return opts.on_stop(stop_opts) end, } diff --git a/lua/avante/llm_tools.lua b/lua/avante/llm_tools.lua index 8311299..10a56f0 100644 --- a/lua/avante/llm_tools.lua +++ b/lua/avante/llm_tools.lua @@ -17,9 +17,23 @@ local function get_abs_path(rel_path) return p end -function M.confirm(msg) - local ok = vim.fn.confirm(msg, "&Yes\n&No", 2) - return ok == 1 +function M.confirm(message, callback) + local UI = require("avante.ui") + UI.confirm(message, callback) +end + +---@param abs_path string +---@return boolean +local function is_ignored(abs_path) + local project_root = Utils.get_project_root() + local gitignore_path = project_root .. "/.gitignore" + local gitignore_patterns, gitignore_negate_patterns = Utils.parse_gitignore(gitignore_path) + -- The checker should only take care of the path inside the project root + -- Specifically, it should not check the project root itself + -- Otherwise if the binary is named the same as the project root (such as Go binary), any paths + -- insde the project root will be ignored + local rel_path = Utils.make_relative_path(abs_path, project_root) + return Utils.is_ignored(rel_path, gitignore_patterns, gitignore_negate_patterns) end ---@param abs_path string @@ -28,14 +42,7 @@ local function has_permission_to_access(abs_path) if not Path:new(abs_path):is_absolute() then return false end local project_root = Utils.get_project_root() if abs_path:sub(1, #project_root) ~= project_root then return false end - local gitignore_path = project_root .. "/.gitignore" - local gitignore_patterns, gitignore_negate_patterns = Utils.parse_gitignore(gitignore_path) - -- The checker should only take care of the path inside the project root - -- Specifically, it should not check the project root itself - -- Otherwise if the binary is named the same as the project root (such as Go binary), any paths - -- insde the project root will be ignored - local rel_path = Utils.make_relative_path(abs_path, project_root) - return not Utils.is_ignored(rel_path, gitignore_patterns, gitignore_negate_patterns) + return not is_ignored(abs_path) end ---@type AvanteLLMToolFunc<{ rel_path: string, pattern: string }> @@ -164,6 +171,41 @@ function M.read_file(opts, on_log) return content, nil end +---@type AvanteLLMToolFunc<{ abs_path: string }> +function M.read_global_file(opts, on_log) + local abs_path = get_abs_path(opts.abs_path) + if is_ignored(abs_path) then return "", "This file is ignored: " .. abs_path end + if on_log then on_log("path: " .. abs_path) end + local file = io.open(abs_path, "r") + if not file then return "", "file not found: " .. abs_path end + local content = file:read("*a") + file:close() + return content, nil +end + +---@type AvanteLLMToolFunc<{ abs_path: string, content: string }> +function M.write_global_file(opts, on_log, on_complete) + local abs_path = get_abs_path(opts.abs_path) + if is_ignored(abs_path) then return false, "This file is ignored: " .. abs_path end + if on_log then on_log("path: " .. abs_path) end + if on_log then on_log("content: " .. opts.content) end + if not on_complete then return false, "on_complete not provided" end + M.confirm("Are you sure you want to write to the file: " .. abs_path, function(ok) + if not ok then + on_complete(false, "User canceled") + return + end + local file = io.open(abs_path, "w") + if not file then + on_complete(false, "file not found: " .. abs_path) + return + end + file:write(opts.content) + file:close() + on_complete(true, nil) + end) +end + ---@type AvanteLLMToolFunc<{ rel_path: string }> function M.create_file(opts, on_log) local abs_path = get_abs_path(opts.rel_path) @@ -183,7 +225,7 @@ function M.create_file(opts, on_log) end ---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }> -function M.rename_file(opts, on_log) +function M.rename_file(opts, on_log, on_complete) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end @@ -192,11 +234,15 @@ function M.rename_file(opts, on_log) if on_log then on_log(abs_path .. " -> " .. new_abs_path) end if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end - if not M.confirm("Are you sure you want to rename the file: " .. abs_path .. " to: " .. new_abs_path) then - return false, "User canceled" - end - os.rename(abs_path, new_abs_path) - return true, nil + if not on_complete then return false, "on_complete not provided" end + M.confirm("Are you sure you want to rename the file: " .. abs_path .. " to: " .. new_abs_path, function(ok) + if not ok then + on_complete(false, "User canceled") + return + end + os.rename(abs_path, new_abs_path) + on_complete(true, nil) + end) end ---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }> @@ -214,32 +260,42 @@ function M.copy_file(opts, on_log) end ---@type AvanteLLMToolFunc<{ rel_path: string }> -function M.delete_file(opts, on_log) +function M.delete_file(opts, on_log, on_complete) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end - if not M.confirm("Are you sure you want to delete the file: " .. abs_path) then return false, "User canceled" end - if on_log then on_log("Deleting file: " .. abs_path) end - os.remove(abs_path) - return true, nil + if not on_complete then return false, "on_complete not provided" end + M.confirm("Are you sure you want to delete the file: " .. abs_path, function(ok) + if not ok then + on_complete(false, "User canceled") + return + end + if on_log then on_log("Deleting file: " .. abs_path) end + os.remove(abs_path) + on_complete(true, nil) + end) end ---@type AvanteLLMToolFunc<{ rel_path: string }> -function M.create_dir(opts, on_log) +function M.create_dir(opts, on_log, on_complete) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end - if not M.confirm("Are you sure you want to create the directory: " .. abs_path) then - return false, "User canceled" - end - if on_log then on_log("Creating directory: " .. abs_path) end - Path:new(abs_path):mkdir({ parents = true }) - return true, nil + if not on_complete then return false, "on_complete not provided" end + M.confirm("Are you sure you want to create the directory: " .. abs_path, function(ok) + if not ok then + on_complete(false, "User canceled") + return + end + if on_log then on_log("Creating directory: " .. abs_path) end + Path:new(abs_path):mkdir({ parents = true }) + on_complete(true, nil) + end) end ---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }> -function M.rename_dir(opts, on_log) +function M.rename_dir(opts, on_log, on_complete) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end @@ -247,26 +303,34 @@ function M.rename_dir(opts, on_log) local new_abs_path = get_abs_path(opts.new_rel_path) if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end if Path:new(new_abs_path):exists() then return false, "Directory already exists: " .. new_abs_path end - if not M.confirm("Are you sure you want to rename directory " .. abs_path .. " to " .. new_abs_path .. "?") then - return false, "User canceled" - end - if on_log then on_log("Renaming directory: " .. abs_path .. " to " .. new_abs_path) end - os.rename(abs_path, new_abs_path) - return true, nil + if not on_complete then return false, "on_complete not provided" end + M.confirm("Are you sure you want to rename directory " .. abs_path .. " to " .. new_abs_path .. "?", function(ok) + if not ok then + on_complete(false, "User canceled") + return + end + if on_log then on_log("Renaming directory: " .. abs_path .. " to " .. new_abs_path) end + os.rename(abs_path, new_abs_path) + on_complete(true, nil) + end) end ---@type AvanteLLMToolFunc<{ rel_path: string }> -function M.delete_dir(opts, on_log) +function M.delete_dir(opts, on_log, on_complete) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end - if not M.confirm("Are you sure you want to delete the directory: " .. abs_path) then - return false, "User canceled" - end - if on_log then on_log("Deleting directory: " .. abs_path) end - os.remove(abs_path) - return true, nil + if not on_complete then return false, "on_complete not provided" end + M.confirm("Are you sure you want to delete the directory: " .. abs_path, function(ok) + if not ok then + on_complete(false, "User canceled") + return + end + if on_log then on_log("Deleting directory: " .. abs_path) end + os.remove(abs_path) + on_complete(true, nil) + end) end ---@type AvanteLLMToolFunc<{ rel_path: string, command: string }> @@ -275,11 +339,6 @@ function M.bash(opts, on_log, on_complete) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end if on_log then on_log("command: " .. opts.command) end - if - not M.confirm("Are you sure you want to run the command: `" .. opts.command .. "` in the directory: " .. abs_path) - then - return false, "User canceled" - end ---change cwd to abs_path ---@param output string ---@param exit_code integer @@ -292,18 +351,20 @@ function M.bash(opts, on_log, on_complete) end return output, nil end - if on_complete then - Utils.shell_run_async(opts.command, "bash -c", function(output, exit_code) - local result, err = handle_result(output, exit_code) - on_complete(result, err) - end, abs_path) - return nil, nil - end - local old_cwd = vim.fn.getcwd() - vim.fn.chdir(abs_path) - local res = Utils.shell_run(opts.command, "bash -c") - vim.fn.chdir(old_cwd) - return handle_result(res.stdout, res.code) + if not on_complete then return false, "on_complete not provided" end + M.confirm( + "Are you sure you want to run the command: `" .. opts.command .. "` in the directory: " .. abs_path, + function(ok) + if not ok then + on_complete(false, "User canceled") + return + end + Utils.shell_run_async(opts.command, "bash -c", function(output, exit_code) + local result, err = handle_result(output, exit_code) + on_complete(result, err) + end, abs_path) + end + ) end ---@type AvanteLLMToolFunc<{ query: string }> @@ -464,7 +525,7 @@ function M.git_diff(opts, on_log) end ---@type AvanteLLMToolFunc<{ message: string, scope?: string }> -function M.git_commit(opts, on_log) +function M.git_commit(opts, on_log, on_complete) local git_cmd = vim.fn.exepath("git") if git_cmd == "" then return false, "Git command not found" end local project_root = Utils.get_project_root() @@ -518,36 +579,46 @@ function M.git_commit(opts, on_log) -- Construct full commit message for confirmation local full_commit_msg = table.concat(commit_msg_lines, "\n") + if not on_complete then return false, "on_complete not provided" end + -- Confirm with user - if not M.confirm("Are you sure you want to commit with message:\n" .. full_commit_msg) then - return false, "User canceled" - end + M.confirm("Are you sure you want to commit with message:\n" .. full_commit_msg, function(ok) + if not ok then + on_complete(false, "User canceled") + return + end + -- Stage changes if scope is provided + if opts.scope then + local stage_cmd = string.format("git add %s", opts.scope) + if on_log then on_log("Staging files: " .. stage_cmd) end + local stage_result = vim.fn.system(stage_cmd) + if vim.v.shell_error ~= 0 then + on_complete(false, "Failed to stage files: " .. stage_result) + return + end + end - -- Stage changes if scope is provided - if opts.scope then - local stage_cmd = string.format("git add %s", opts.scope) - if on_log then on_log("Staging files: " .. stage_cmd) end - local stage_result = vim.fn.system(stage_cmd) - if vim.v.shell_error ~= 0 then return false, "Failed to stage files: " .. stage_result end - end + -- Construct git commit command + local cmd_parts = { "git", "commit" } + -- Only add -S flag if GPG is available + if has_gpg then table.insert(cmd_parts, "-S") end + for _, line in ipairs(commit_msg_lines) do + table.insert(cmd_parts, "-m") + table.insert(cmd_parts, '"' .. line .. '"') + end + local cmd = table.concat(cmd_parts, " ") - -- Construct git commit command - local cmd_parts = { "git", "commit" } - -- Only add -S flag if GPG is available - if has_gpg then table.insert(cmd_parts, "-S") end - for _, line in ipairs(commit_msg_lines) do - table.insert(cmd_parts, "-m") - table.insert(cmd_parts, '"' .. line .. '"') - end - local cmd = table.concat(cmd_parts, " ") + -- Execute git commit + if on_log then on_log("Running command: " .. cmd) end + local result = vim.fn.system(cmd) - -- Execute git commit - if on_log then on_log("Running command: " .. cmd) end - local result = vim.fn.system(cmd) + if vim.v.shell_error ~= 0 then + on_complete(false, "Failed to commit: " .. result) + return + end - if vim.v.shell_error ~= 0 then return false, "Failed to commit: " .. result end - - return true, nil + on_complete(true, nil) + end) end ---@type AvanteLLMToolFunc<{ query: string }> @@ -571,57 +642,62 @@ function M.python(opts, on_log, on_complete) if on_log then on_log("cwd: " .. abs_path) end if on_log then on_log("code:\n" .. opts.code) end local container_image = opts.container_image or "python:3.11-slim-bookworm" - if - not M.confirm( - "Are you sure you want to run the following python code in the `" - .. container_image - .. "` container, in the directory: `" - .. abs_path - .. "`?\n" - .. opts.code - ) - then - return nil, "User canceled" - end - if vim.fn.executable("docker") == 0 then return nil, "Python tool is not available to execute any code" end + if not on_complete then return nil, "on_complete not provided" end + M.confirm( + "Are you sure you want to run the following python code in the `" + .. container_image + .. "` container, in the directory: `" + .. abs_path + .. "`?\n" + .. opts.code, + function(ok) + if not ok then + on_complete(nil, "User canceled") + return + end + if vim.fn.executable("docker") == 0 then + on_complete(nil, "Python tool is not available to execute any code") + return + end - local function handle_result(result) ---@param result vim.SystemCompleted - if result.code ~= 0 then return nil, "Error: " .. (result.stderr or "Unknown error") end + local function handle_result(result) ---@param result vim.SystemCompleted + if result.code ~= 0 then return nil, "Error: " .. (result.stderr or "Unknown error") end - Utils.debug("output", result.stdout) - return result.stdout, nil - end - local job = vim.system( - { - "docker", - "run", - "--rm", - "-v", - abs_path .. ":" .. abs_path, - "-w", - abs_path, - container_image, - "python", - "-c", - opts.code, - }, - { - text = true, - cwd = abs_path, - }, - vim.schedule_wrap(function(result) - if not on_complete then return end - local output, err = handle_result(result) - on_complete(output, err) - end) + Utils.debug("output", result.stdout) + return result.stdout, nil + end + vim.system( + { + "docker", + "run", + "--rm", + "-v", + abs_path .. ":" .. abs_path, + "-w", + abs_path, + container_image, + "python", + "-c", + opts.code, + }, + { + text = true, + cwd = abs_path, + }, + vim.schedule_wrap(function(result) + if not on_complete then return end + local output, err = handle_result(result) + on_complete(output, err) + end) + ) + end ) - if on_complete then return end - local result = job:wait() - return handle_result(result) end +---@param user_input string +---@param history_messages AvanteLLMMessage[] ---@return AvanteLLMTool[] -function M.get_tools() +function M.get_tools(user_input, history_messages) local custom_tools = Config.custom_tools if type(custom_tools) == "function" then custom_tools = custom_tools() end ---@type AvanteLLMTool[] @@ -634,7 +710,7 @@ function M.get_tools() if tool.enabled == nil then return true else - return tool.enabled() + return tool.enabled({ user_input = user_input, history_messages = history_messages }) end end) :totable() @@ -644,7 +720,7 @@ end M._tools = { { name = "glob", - description = 'Fast file pattern matching using glob patterns like "**/*.js"', + description = 'Fast file pattern matching using glob patterns like "**/*.js", in current project scope', param = { type = "table", fields = { @@ -655,7 +731,7 @@ M._tools = { }, { name = "rel_path", - description = "Relative path to the directory, as cwd", + description = "Relative path to the project directory, as cwd", type = "string", }, }, @@ -704,7 +780,7 @@ M._tools = { }, { name = "python", - description = "Run python code. Can't use it to read files or modify files.", + description = "Run python code in current project scope. Can't use it to read files or modify files.", param = { type = "table", fields = { @@ -715,7 +791,7 @@ M._tools = { }, { name = "rel_path", - description = "Relative path to the directory, as cwd", + description = "Relative path to the project directory, as cwd", type = "string", }, }, @@ -736,7 +812,7 @@ M._tools = { }, { name = "git_diff", - description = "Get git diff for generating commit message", + description = "Get git diff for generating commit message in current project scope", param = { type = "table", fields = { @@ -763,7 +839,7 @@ M._tools = { }, { name = "git_commit", - description = "Commit changes with the given commit message", + description = "Commit changes with the given commit message in current project scope", param = { type = "table", fields = { @@ -796,13 +872,13 @@ M._tools = { }, { name = "list_files", - description = "List files in a directory", + description = "List files in current project scope", param = { type = "table", fields = { { name = "rel_path", - description = "Relative path to the directory", + description = "Relative path to the project directory", type = "string", }, { @@ -815,7 +891,7 @@ M._tools = { returns = { { name = "files", - description = "List of files in the directory", + description = "List of filepaths in the directory", type = "string[]", }, { @@ -828,13 +904,13 @@ M._tools = { }, { name = "search_files", - description = "Search for files in a directory", + description = "Search for files in current project scope", param = { type = "table", fields = { { name = "rel_path", - description = "Relative path to the directory", + description = "Relative path to the project directory", type = "string", }, { @@ -847,7 +923,7 @@ M._tools = { returns = { { name = "files", - description = "List of files that match the keyword", + description = "List of filepaths that match the keyword", type = "string", }, { @@ -860,13 +936,13 @@ M._tools = { }, { name = "grep_search", - description = "Search for a keyword in a directory using grep", + description = "Search for a keyword in a directory using grep in current project scope", param = { type = "table", fields = { { name = "rel_path", - description = "Relative path to the directory", + description = "Relative path to the project directory", type = "string", }, { @@ -911,13 +987,13 @@ M._tools = { }, { name = "read_file_toplevel_symbols", - description = "Read the top-level symbols of a file", + description = "Read the top-level symbols of a file in current project scope", param = { type = "table", fields = { { name = "rel_path", - description = "Relative path to the file", + description = "Relative path to the file in current project scope", type = "string", }, }, @@ -938,13 +1014,28 @@ M._tools = { }, { name = "read_file", - description = "Read the contents of a file. If the file content is already in the context, do not use this tool.", + description = "Read the contents of a file in current project scope. If the file content is already in the context, do not use this tool.", + enabled = function(opts) + if opts.user_input:match("@read_global_file") then return false end + for _, message in ipairs(opts.history_messages) do + if message.role == "user" then + local content = message.content + if type(content) == "string" and content:match("@read_global_file") then return false end + if type(content) == "table" then + for _, item in ipairs(content) do + if type(item) == "string" and item:match("@read_global_file") then return false end + end + end + end + end + return true + end, param = { type = "table", fields = { { name = "rel_path", - description = "Relative path to the file", + description = "Relative path to the file in current project scope", type = "string", }, }, @@ -963,15 +1054,104 @@ M._tools = { }, }, }, + { + name = "read_global_file", + description = "Read the contents of a file in the global scope. If the file content is already in the context, do not use this tool.", + enabled = function(opts) + if opts.user_input:match("@read_global_file") then return true end + for _, message in ipairs(opts.history_messages) do + if message.role == "user" then + local content = message.content + if type(content) == "string" and content:match("@read_global_file") then return true end + if type(content) == "table" then + for _, item in ipairs(content) do + if type(item) == "string" and item:match("@read_global_file") then return true end + end + end + end + end + return false + end, + param = { + type = "table", + fields = { + { + name = "abs_path", + description = "Absolute path to the file in global scope", + type = "string", + }, + }, + }, + returns = { + { + name = "content", + description = "Contents of the file", + type = "string", + }, + { + name = "error", + description = "Error message if the file was not read successfully", + type = "string", + optional = true, + }, + }, + }, + { + name = "write_global_file", + description = "Write to a file in the global scope", + enabled = function(opts) + if opts.user_input:match("@write_global_file") then return true end + for _, message in ipairs(opts.history_messages) do + if message.role == "user" then + local content = message.content + if type(content) == "string" and content:match("@write_global_file") then return true end + if type(content) == "table" then + for _, item in ipairs(content) do + if type(item) == "string" and item:match("@write_global_file") then return true end + end + end + end + end + return false + end, + param = { + type = "table", + fields = { + { + name = "abs_path", + description = "Absolute path to the file in global scope", + type = "string", + }, + { + name = "content", + description = "Content to write to the file", + type = "string", + }, + }, + }, + returns = { + { + name = "success", + description = "True if the file was written successfully, false otherwise", + type = "boolean", + }, + { + name = "error", + description = "Error message if the file was not written successfully", + type = "string", + optional = true, + }, + }, + }, { name = "create_file", - description = "Create a new file", + description = "Create a new file in current project scope", param = { type = "table", fields = { { name = "rel_path", - description = "Relative path to the file", + description = "Relative path to the file in current project scope", type = "string", }, }, @@ -992,13 +1172,13 @@ M._tools = { }, { name = "rename_file", - description = "Rename a file", + description = "Rename a file in current project scope", param = { type = "table", fields = { { name = "rel_path", - description = "Relative path to the file", + description = "Relative path to the file in current project scope", type = "string", }, { @@ -1024,13 +1204,13 @@ M._tools = { }, { name = "delete_file", - description = "Delete a file", + description = "Delete a file in current project scope", param = { type = "table", fields = { { name = "rel_path", - description = "Relative path to the file", + description = "Relative path to the file in current project scope", type = "string", }, }, @@ -1051,13 +1231,13 @@ M._tools = { }, { name = "create_dir", - description = "Create a new directory", + description = "Create a new directory in current project scope", param = { type = "table", fields = { { name = "rel_path", - description = "Relative path to the directory", + description = "Relative path to the project directory", type = "string", }, }, @@ -1078,13 +1258,13 @@ M._tools = { }, { name = "rename_dir", - description = "Rename a directory", + description = "Rename a directory in current project scope", param = { type = "table", fields = { { name = "rel_path", - description = "Relative path to the directory", + description = "Relative path to the project directory", type = "string", }, { @@ -1110,13 +1290,13 @@ M._tools = { }, { name = "delete_dir", - description = "Delete a directory", + description = "Delete a directory in current project scope", param = { type = "table", fields = { { name = "rel_path", - description = "Relative path to the directory", + description = "Relative path to the project directory", type = "string", }, }, @@ -1137,13 +1317,13 @@ M._tools = { }, { name = "bash", - description = "Run a bash command in a directory. Can't use search commands like find/grep or read tools like cat/ls. Can't use it to read files or modify files.", + description = "Run a bash command in current project scope. Can't use search commands like find/grep or read tools like cat/ls. Can't use it to read files or modify files.", param = { type = "table", fields = { { name = "rel_path", - description = "Relative path to the directory", + description = "Relative path to the project directory, as cwd", type = "string", }, { diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index b5ce6a3..7dda129 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -59,6 +59,8 @@ function M:parse_messages(opts) ---@type AvanteClaudeMessage[] local messages = {} + local provider_conf, _ = P.parse_config(self) + ---@type {idx: integer, length: integer}[] local messages_with_length = {} for idx, message in ipairs(opts.messages) do @@ -76,15 +78,46 @@ function M:parse_messages(opts) end for idx, message in ipairs(opts.messages) do + local content_items = message.content + local message_content = {} + if type(content_items) == "string" then + table.insert(message_content, { + type = "text", + text = message.content, + cache_control = top_two[idx] and { type = "ephemeral" } or nil, + }) + elseif type(content_items) == "table" then + ---@cast content_items AvanteLLMMessageContentItem[] + for _, item in ipairs(content_items) do + if type(item) == "string" then + table.insert( + message_content, + { type = "text", text = item, cache_control = top_two[idx] and { type = "ephemeral" } or nil } + ) + elseif type(item) == "table" and item.type == "text" then + table.insert( + message_content, + { type = "text", text = item.text, cache_control = top_two[idx] and { type = "ephemeral" } or nil } + ) + elseif type(item) == "table" and item.type == "image" then + table.insert(message_content, { type = "image", source = item.source }) + elseif not provider_conf.disable_tools and type(item) == "table" and item.type == "tool_use" then + table.insert(message_content, { type = "tool_use", name = item.name, id = item.id, input = item.input }) + elseif not provider_conf.disable_tools and type(item) == "table" and item.type == "tool_result" then + table.insert( + message_content, + { type = "tool_result", tool_use_id = item.tool_use_id, content = item.content, is_error = item.is_error } + ) + elseif type(item) == "table" and item.type == "thinking" then + table.insert(message_content, { type = "thinking", thinking = item.thinking, signature = item.signature }) + elseif type(item) == "table" and item.type == "redacted_thinking" then + table.insert(message_content, { type = "redacted_thinking", data = item.data }) + end + end + end table.insert(messages, { role = self.role_map[message.role], - content = { - { - type = "text", - text = message.content, - cache_control = top_two[idx] and { type = "ephemeral" } or nil, - }, - }, + content = message_content, }) end diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index eafc2b8..4718426 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -35,9 +35,36 @@ function M:parse_messages(opts) end end prev_role = role - table.insert(contents, { role = M.role_map[role] or role, parts = { - { text = message.content }, - } }) + local parts = {} + local content_items = message.content + if type(content_items) == "string" then + table.insert(parts, { text = content_items }) + elseif type(content_items) == "table" then + ---@cast content_items AvanteLLMMessageContentItem[] + for _, item in ipairs(content_items) do + if type(item) == "string" then + table.insert(parts, { text = item }) + elseif type(item) == "table" and item.type == "text" then + table.insert(parts, { text = item.text }) + elseif type(item) == "table" and item.type == "image" then + table.insert(parts, { + inline_data = { + mime_type = "image/png", + data = item.source.data, + }, + }) + elseif type(item) == "table" and item.type == "tool_use" then + table.insert(parts, { text = item.name }) + elseif type(item) == "table" and item.type == "tool_result" then + table.insert(parts, { text = item.content }) + elseif type(item) == "table" and item.type == "thinking" then + table.insert(parts, { text = item.thinking }) + elseif type(item) == "table" and item.type == "redacted_thinking" then + table.insert(parts, { text = item.data }) + end + end + end + table.insert(contents, { role = M.role_map[role] or role, parts = parts }) end) if Clipboard.support_paste_image() and opts.image_paths then diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index a29ac4d..69670af 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -81,9 +81,54 @@ function M:parse_messages(opts) table.insert(messages, { role = "system", content = opts.system_prompt }) end - vim - .iter(opts.messages) - :each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end) + vim.iter(opts.messages):each(function(msg) + if type(msg.content) == "string" then + table.insert(messages, { role = self.role_map[msg.role], content = msg.content }) + else + local content = {} + local tool_calls = {} + local tool_results = {} + for _, item in ipairs(msg.content) do + if type(item) == "string" then + table.insert(content, { type = "text", text = item }) + elseif item.type == "text" then + table.insert(content, { type = "text", text = item.text }) + elseif item.type == "image" then + table.insert(content, { + type = "image_url", + image_url = { + url = "data:" .. item.source.media_type .. ";" .. item.source.type .. "," .. item.source.data, + }, + }) + elseif item.type == "tool_use" then + table.insert(tool_calls, { + id = item.id, + type = "function", + ["function"] = { name = item.name, arguments = vim.json.encode(item.input) }, + }) + elseif item.type == "tool_result" then + table.insert( + tool_results, + { tool_call_id = item.tool_use_id, content = item.is_error and "Error: " .. item.content or item.content } + ) + end + end + table.insert(messages, { role = self.role_map[msg.role], content = content }) + if not provider_conf.disable_tools then + if #tool_calls > 0 then + table.insert(messages, { role = self.role_map["assistant"], tool_calls = tool_calls }) + end + if #tool_results > 0 then + for _, tool_result in ipairs(tool_results) do + table.insert( + messages, + { role = "tool", tool_call_id = tool_result.tool_call_id, content = tool_result.content or "" } + ) + end + end + end + end + end) if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then local message_content = messages[#messages].content diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 23fce22..ea7e349 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -43,6 +43,7 @@ local Sidebar = {} ---@field selected_files_container NuiSplit | nil ---@field input_container NuiSplit | nil ---@field file_selector FileSelector +---@field chat_history avante.ChatHistory | nil ---@param id integer the tabpage id retrieved from api.nvim_get_current_tabpage() function Sidebar:new(id) @@ -61,6 +62,7 @@ function Sidebar:new(id) input_container = nil, file_selector = FileSelector:new(id), is_generating = false, + chat_history = nil, }, { __index = self }) end @@ -397,6 +399,7 @@ local function transform_result_content(selected_files, result_content, prev_fil elseif line_content == "" then is_thinking = true last_think_tag_start_line = i + last_think_tag_end_line = 0 elseif line_content == "" then is_thinking = false last_think_tag_end_line = i @@ -1810,6 +1813,7 @@ function Sidebar:on_mount(opts) group = self.augroup, callback = function(args) local closed_winid = tonumber(args.match) + if closed_winid == self.winids.selected_files_container then return end if not self:is_focused_on(closed_winid) then return end self:close() end, @@ -1838,6 +1842,7 @@ function Sidebar:refresh_winids() local function switch_windows() local current_winid = api.nvim_get_current_win() + winids = vim.iter(winids):filter(function(winid) return api.nvim_win_is_valid(winid) end):totable() local current_idx = Utils.tbl_indexof(winids, current_winid) or 1 if current_idx == #winids then current_idx = 1 @@ -1906,6 +1911,8 @@ function Sidebar:initialize() self.file_selector:reset() self.file_selector:add_selected_file(filepath) + self:reload_chat_history() + return self end @@ -2095,6 +2102,7 @@ end function Sidebar:render_history_content(history) local content = "" for idx, entry in ipairs(history.entries) do + if entry.visible == false then goto continue end if entry.reset_memory then content = content .. "***MEMORY RESET***\n\n" if idx < #history.entries then content = content .. "-------\n\n" end @@ -2180,7 +2188,7 @@ end function Sidebar:new_chat(args, cb) Path.history.new(self.code.bufnr) - Sidebar.reload_chat_history() + self:reload_chat_history() self:update_content( "New chat", { ignore_history = true, focus = false, scroll = false, callback = function() self:focus_input() end } @@ -2188,6 +2196,26 @@ function Sidebar:new_chat(args, cb) if cb then cb(args) end end +---@param message AvanteLLMMessage +---@param options {visible?: boolean} +function Sidebar:add_chat_history(message, options) + local timestamp = get_timestamp() + self:reload_chat_history() + table.insert(self.chat_history.entries, { + timestamp = timestamp, + provider = Config.provider, + model = Config.get_provider_config(Config.provider).model, + request = message.role == "user" and message.content or "", + response = message.role == "assistant" and message.content or "", + original_response = "", + selected_filepaths = nil, + selected_code = nil, + reset_memory = false, + visible = options.visible, + }) + Path.history.save(self.code.bufnr, self.chat_history) +end + function Sidebar:reset_memory(args, cb) local chat_history = Path.history.load(self.code.bufnr) if next(chat_history) ~= nil then @@ -2203,7 +2231,7 @@ function Sidebar:reset_memory(args, cb) reset_memory = true, }) Path.history.save(self.code.bufnr, chat_history) - Sidebar.reload_chat_history() + self:reload_chat_history() local history_content = self:render_history_content(chat_history) self:update_content(history_content, { focus = false, @@ -2212,7 +2240,7 @@ function Sidebar:reset_memory(args, cb) }) if cb then cb(args) end else - Sidebar.reload_chat_history() + self:reload_chat_history() self:update_content( "Chat history is already empty", { focus = false, scroll = false, callback = function() self:focus_input() end } @@ -2321,46 +2349,18 @@ local generating_text = "**Generating response ...**\n" local hint_window = nil +function Sidebar:reload_chat_history() + if not self.code.bufnr or not api.nvim_buf_is_valid(self.code.bufnr) then return end + self.chat_history = Path.history.load(self.code.bufnr) +end + ---@param opts AskOptions function Sidebar:create_input_container(opts) if self.input_container then self.input_container:unmount() end if not self.code.bufnr or not api.nvim_buf_is_valid(self.code.bufnr) then return end - local chat_history = Path.history.load(self.code.bufnr) - - Sidebar.reload_chat_history = function() chat_history = Path.history.load(self.code.bufnr) end - - local tools = vim.deepcopy(LLMTools.get_tools()) - table.insert(tools, { - name = "add_file_to_context", - description = "Add a file to the context", - ---@type AvanteLLMToolFunc<{ rel_path: string }> - func = function(input) - self.file_selector:add_selected_file(input.rel_path) - return "Added file to context", nil - end, - param = { - type = "table", - fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } }, - }, - returns = {}, - }) - - table.insert(tools, { - name = "remove_file_from_context", - description = "Remove a file from the context", - ---@type AvanteLLMToolFunc<{ rel_path: string }> - func = function(input) - self.file_selector:remove_selected_file(input.rel_path) - return "Removed file from context", nil - end, - param = { - type = "table", - fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } }, - }, - returns = {}, - }) + if self.chat_history == nil then self:reload_chat_history() end ---@param request string ---@param summarize_memory boolean @@ -2399,17 +2399,48 @@ function Sidebar:create_input_container(opts) end end - local entries = Utils.history.filter_active_entries(chat_history.entries) + local entries = Utils.history.filter_active_entries(self.chat_history.entries) - if chat_history.memory then + if self.chat_history.memory then entries = vim .iter(entries) - :filter(function(entry) return entry.timestamp > chat_history.memory.last_summarized_timestamp end) + :filter(function(entry) return entry.timestamp > self.chat_history.memory.last_summarized_timestamp end) :totable() end local history_messages = Utils.history.entries_to_llm_messages(entries) + local tools = vim.deepcopy(LLMTools.get_tools(request, history_messages)) + table.insert(tools, { + name = "add_file_to_context", + description = "Add a file to the context", + ---@type AvanteLLMToolFunc<{ rel_path: string }> + func = function(input) + self.file_selector:add_selected_file(input.rel_path) + return "Added file to context", nil + end, + param = { + type = "table", + fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } }, + }, + returns = {}, + }) + + table.insert(tools, { + name = "remove_file_from_context", + description = "Remove a file from the context", + ---@type AvanteLLMToolFunc<{ rel_path: string }> + func = function(input) + self.file_selector:remove_selected_file(input.rel_path) + return "Removed file from context", nil + end, + param = { + type = "table", + fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } }, + }, + returns = {}, + }) + ---@type AvanteGeneratePromptsOptions local prompts_opts = { ask = opts.ask or true, @@ -2425,7 +2456,7 @@ function Sidebar:create_input_container(opts) tools = tools, } - if chat_history.memory then prompts_opts.memory = chat_history.memory.content end + if self.chat_history.memory then prompts_opts.memory = self.chat_history.memory.content end if not summarize_memory or #history_messages < 8 then cb(prompts_opts) @@ -2434,7 +2465,7 @@ function Sidebar:create_input_container(opts) prompts_opts.history_messages = vim.list_slice(prompts_opts.history_messages, 5) - Llm.summarize_memory(self.code.bufnr, chat_history, function(memory) + Llm.summarize_memory(self.code.bufnr, self.chat_history, function(memory) if memory then prompts_opts.memory = memory.content end cb(prompts_opts) end) @@ -2628,8 +2659,8 @@ function Sidebar:create_input_container(opts) end, 0) -- Save chat history - chat_history.entries = chat_history.entries or {} - table.insert(chat_history.entries, { + self.chat_history.entries = self.chat_history.entries or {} + table.insert(self.chat_history.entries, { timestamp = timestamp, provider = Config.provider, model = model, @@ -2638,8 +2669,9 @@ function Sidebar:create_input_container(opts) original_response = original_response, selected_filepaths = selected_filepaths, selected_code = selected_code, + tool_histories = stop_opts.tool_histories, }) - Path.history.save(self.code.bufnr, chat_history) + Path.history.save(self.code.bufnr, self.chat_history) end get_generate_prompts_options(request, true, function(generate_prompts_options) diff --git a/lua/avante/templates/_tools-guidelines.avanterules b/lua/avante/templates/_tools-guidelines.avanterules index f17d2d4..0a28641 100644 --- a/lua/avante/templates/_tools-guidelines.avanterules +++ b/lua/avante/templates/_tools-guidelines.avanterules @@ -19,3 +19,4 @@ Tools Usage Guide: - For any mathematical calculation problems, please prioritize using the `python` tool to solve them. Please try to avoid mathematical symbols in the return value of the `python` tool for mathematical problems and directly output human-readable results, because large models don't understand mathematical symbols, they only understand human natural language. - Do not use the `python` tool to read or modify files! If you use the `python` tool to read or modify files, you will be fired!!!!! - Do not use the `bash` tool to read or modify files! If you use the `bash` tool to read or modify files, you will be fired!!!!! + - If you are provided with the `write_file` tool, there's no need to output your change suggestions, just directly use the `write_file` tool to complete the changes. diff --git a/lua/avante/templates/base.avanterules b/lua/avante/templates/base.avanterules index 0818510..a938b75 100644 --- a/lua/avante/templates/base.avanterules +++ b/lua/avante/templates/base.avanterules @@ -1,11 +1,3 @@ -{# Uses https://mitsuhiko.github.io/minijinja-playground/ for testing: -{ - "ask": true, - "question": "Refactor to include tab flow", - "code_lang": "lua", - "file_content": "local Config = require('avante.config')" -} -#} Act as an expert software developer. Always use best practices when coding. Respect and use existing conventions, libraries, etc that are already present in the code base. diff --git a/lua/avante/types.lua b/lua/avante/types.lua index bd79221..f1c862b 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -76,9 +76,13 @@ vim.g.avante_login = vim.g.avante_login ---@field on_chunk AvanteLLMChunkCallback ---@field on_stop AvanteLLMStopCallback --- +---@alias AvanteLLMMessageContentItem string | { type: "text", text: string } | { type: "image", source: { type: "base64", media_type: string, data: string } } | { type: "tool_use", name: string, id: string, input: any } | { type: "tool_result", tool_use_id: string, content: string, is_error?: boolean } | { type: "thinking", thinking: string, signature: string } | { type: "redacted_thinking", data: string } +--- +---@alias AvanteLLMMessageContent AvanteLLMMessageContentItem[] | string +--- ---@class AvanteLLMMessage ---@field role "user" | "assistant" ----@field content string +---@field content AvanteLLMMessageContent --- ---@class AvanteLLMToolResult ---@field tool_name string @@ -245,6 +249,7 @@ vim.g.avante_login = vim.g.avante_login ---@field tool_use_list? AvanteLLMToolUse[] ---@field retry_after? integer ---@field headers? table +---@field tool_histories? AvanteLLMToolHistory[] --- ---@alias AvanteStreamParser fun(self: AvanteProviderFunctor, ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil ---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil @@ -342,7 +347,7 @@ vim.g.avante_login = vim.g.avante_login ---@field func? AvanteLLMToolFunc ---@field param AvanteLLMToolParam ---@field returns AvanteLLMToolReturn[] ----@field enabled? fun(): boolean +---@field enabled? fun(opts: { user_input: string, history_messages: AvanteLLMMessage[] }): boolean ---@class AvanteLLMToolPublic : AvanteLLMTool ---@field func AvanteLLMToolFunc @@ -374,6 +379,8 @@ vim.g.avante_login = vim.g.avante_login ---@field selected_code AvanteSelectedCode | nil ---@field reset_memory boolean? ---@field selected_filepaths string[] | nil +---@field visible boolean? +---@field tool_histories? AvanteLLMToolHistory[] --- ---@class avante.ChatHistory ---@field title string diff --git a/lua/avante/ui.lua b/lua/avante/ui.lua new file mode 100644 index 0000000..5ef52ff --- /dev/null +++ b/lua/avante/ui.lua @@ -0,0 +1,139 @@ +local Popup = require("nui.popup") +local NuiText = require("nui.text") +local event = require("nui.utils.autocmd").event +local Highlights = require("avante.highlights") + +local M = {} + +function M.confirm(message, callback) + local focus_index = 2 -- 1 = Yes, 2 = No + local yes_button_pos = { 18, 23 } + local no_button_pos = { 28, 32 } + + local BUTTON_NORMAL = Highlights.BUTTON_DEFAULT + local BUTTON_FOCUS = Highlights.BUTTON_DEFAULT_HOVER + + local popup = Popup({ + position = { + row = vim.o.lines - 5, + col = "50%", + }, + size = { width = 50, height = 7 }, + enter = true, + focusable = true, + border = { + style = "rounded", + text = { top = NuiText(" Confirmation ", Highlights.CONFIRM_TITLE) }, + }, + win_options = { + winblend = 10, + }, + }) + + local function focus_button() + if focus_index == 1 then + vim.api.nvim_win_set_cursor(popup.winid, { 4, yes_button_pos[1] }) + else + vim.api.nvim_win_set_cursor(popup.winid, { 4, no_button_pos[1] }) + end + end + + local function render_buttons() + local yes_style = (focus_index == 1) and BUTTON_FOCUS or BUTTON_NORMAL + local no_style = (focus_index == 2) and BUTTON_FOCUS or BUTTON_NORMAL + + vim.api.nvim_buf_set_lines(popup.bufnr, 0, -1, false, { + "", + " " .. message, + "", + " " .. " Yes No ", + "", + }) + + vim.api.nvim_buf_add_highlight(popup.bufnr, 0, yes_style, 3, yes_button_pos[1], yes_button_pos[2]) + vim.api.nvim_buf_add_highlight(popup.bufnr, 0, no_style, 3, no_button_pos[1], no_button_pos[2]) + focus_button() + end + + local function select_button() + popup:unmount() + callback(focus_index == 1) + end + + vim.keymap.set("n", "y", function() + focus_index = 1 + render_buttons() + select_button() + end, { buffer = popup.bufnr }) + + vim.keymap.set("n", "n", function() + focus_index = 2 + render_buttons() + select_button() + end, { buffer = popup.bufnr }) + + vim.keymap.set("n", "", function() + focus_index = 1 + focus_button() + end, { buffer = popup.bufnr }) + + vim.keymap.set("n", "", function() + focus_index = 2 + focus_button() + end, { buffer = popup.bufnr }) + + vim.keymap.set("n", "", function() + focus_index = (focus_index == 1) and 2 or 1 + focus_button() + end, { buffer = popup.bufnr }) + + vim.keymap.set("n", "", function() + focus_index = (focus_index == 1) and 2 or 1 + focus_button() + end, { buffer = popup.bufnr }) + + vim.keymap.set("n", "", function() select_button() end, { buffer = popup.bufnr }) + + vim.api.nvim_buf_set_keymap(popup.bufnr, "n", "", "", { + callback = function() + local pos = vim.fn.getmousepos() + local row, col = pos["winrow"], pos["wincol"] + if row == 4 then + if col >= yes_button_pos[1] and col <= yes_button_pos[2] then + focus_index = 1 + render_buttons() + select_button() + elseif col >= no_button_pos[1] and col <= no_button_pos[2] then + focus_index = 2 + render_buttons() + select_button() + end + end + end, + noremap = true, + silent = true, + }) + + vim.api.nvim_create_autocmd("CursorMoved", { + buffer = popup.bufnr, + callback = function() + local row, col = unpack(vim.api.nvim_win_get_cursor(0)) + if row == 4 then + if col >= yes_button_pos[1] and col <= yes_button_pos[2] then + focus_index = 1 + render_buttons() + elseif col >= no_button_pos[1] and col <= no_button_pos[2] then + focus_index = 2 + render_buttons() + end + end + end, + }) + + popup:on(event.BufLeave, function() popup:unmount() end) + + popup:mount() + render_buttons() +end + +return M diff --git a/lua/avante/utils/history.lua b/lua/avante/utils/history.lua index 5527b8b..9b72e2f 100644 --- a/lua/avante/utils/history.lua +++ b/lua/avante/utils/history.lua @@ -11,14 +11,6 @@ function M.filter_active_entries(entries) for i = #entries, 1, -1 do local entry = entries[i] if entry.reset_memory then break end - if - entry.request == nil - or entry.original_response == nil - or entry.request == "" - or entry.original_response == "" - then - break - end table.insert(entries_, 1, entry) end @@ -30,25 +22,62 @@ end function M.entries_to_llm_messages(entries) local messages = {} for _, entry in ipairs(entries) do - local user_content = "" - if entry.selected_filepaths ~= nil then - user_content = user_content .. "SELECTED FILES:\n\n" + if entry.selected_filepaths ~= nil and #entry.selected_filepaths > 0 then + local user_content = "SELECTED FILES:\n\n" for _, filepath in ipairs(entry.selected_filepaths) do user_content = user_content .. filepath .. "\n" end + table.insert(messages, { role = "user", content = user_content }) end if entry.selected_code ~= nil then - user_content = user_content - .. "SELECTED CODE:\n\n```" + local user_content_ = "SELECTED CODE:\n\n```" .. (entry.selected_code.file_type or "") .. (entry.selected_code.path and ":" .. entry.selected_code.path or "") .. "\n" .. entry.selected_code.content .. "\n```\n\n" + table.insert(messages, { role = "user", content = user_content_ }) end - user_content = user_content .. "USER PROMPT:\n\n" .. entry.request - table.insert(messages, { role = "user", content = user_content }) - table.insert(messages, { role = "assistant", content = Utils.trim_think_content(entry.original_response) }) + if entry.request ~= nil and entry.request ~= "" then + table.insert(messages, { role = "user", content = entry.request }) + end + if entry.tool_histories ~= nil and #entry.tool_histories > 0 then + for _, tool_history in ipairs(entry.tool_histories) do + local assistant_content = {} + if tool_history.tool_use ~= nil then + if tool_history.tool_use.response_contents ~= nil then + for _, response_content in ipairs(tool_history.tool_use.response_contents) do + table.insert(assistant_content, { type = "text", text = response_content }) + end + end + table.insert(assistant_content, { + type = "tool_use", + name = tool_history.tool_use.name, + id = tool_history.tool_use.id, + input = vim.json.decode(tool_history.tool_use.input_json), + }) + end + table.insert(messages, { + role = "assistant", + content = assistant_content, + }) + local user_content = {} + if tool_history.tool_result ~= nil and tool_history.tool_result.content ~= nil then + table.insert(user_content, { + type = "tool_result", + tool_use_id = tool_history.tool_result.tool_use_id, + content = tool_history.tool_result.content, + is_error = tool_history.tool_result.is_error, + }) + end + table.insert(messages, { + role = "user", + content = user_content, + }) + end + end + local assistant_content = Utils.trim_think_content(entry.original_response or "") + if assistant_content ~= "" then table.insert(messages, { role = "assistant", content = assistant_content }) end end return messages end diff --git a/lua/avante/utils/tokens.lua b/lua/avante/utils/tokens.lua index 4f1382d..3aaa9d7 100644 --- a/lua/avante/utils/tokens.lua +++ b/lua/avante/utils/tokens.lua @@ -10,9 +10,29 @@ local cost_per_token = { } --- Calculate the number of tokens in a given text. ----@param text string The text to calculate the number of tokens for. +---@param content AvanteLLMMessageContent The text to calculate the number of tokens in. ---@return integer The number of tokens in the given text. -function Tokens.calculate_tokens(text) +function Tokens.calculate_tokens(content) + local text = "" + + if type(content) == "string" then + text = content + elseif type(content) == "table" then + for _, item in ipairs(content) do + if type(item) == "string" then + text = text .. item + elseif type(item) == "table" and item.type == "text" then + text = text .. item.text + elseif type(item) == "table" and item.type == "image" then + text = text .. item.source.data + elseif type(item) == "table" and item.type == "tool_use" then + text = text .. item.name .. item.id + elseif type(item) == "table" and item.type == "tool_result" then + text = text .. item.tool_use_id .. item.content + end + end + end + if Tokenizer.available() then return Tokenizer.count(text) end local tokens = 0 diff --git a/tests/llm_tools_spec.lua b/tests/llm_tools_spec.lua index a278efe..717cfd0 100644 --- a/tests/llm_tools_spec.lua +++ b/tests/llm_tools_spec.lua @@ -3,7 +3,7 @@ local LlmTools = require("avante.llm_tools") local Config = require("avante.config") local Utils = require("avante.utils") -LlmTools.confirm = function(msg) return true end +LlmTools.confirm = function(msg, cb) return cb(true) end describe("llm_tools", function() local test_dir = "/tmp/test_llm_tools" @@ -85,34 +85,37 @@ describe("llm_tools", function() describe("create_file", function() it("should create new file", function() - local success, err = LlmTools.create_file({ rel_path = "new_file.txt" }) - assert.is_nil(err) - assert.is_true(success) + LlmTools.create_file({ rel_path = "new_file.txt" }, nil, function(success, err) + assert.is_nil(err) + assert.is_true(success) - local file_exists = io.open(test_dir .. "/new_file.txt", "r") ~= nil - assert.is_true(file_exists) + local file_exists = io.open(test_dir .. "/new_file.txt", "r") ~= nil + assert.is_true(file_exists) + end) end) end) describe("create_dir", function() it("should create new directory", function() - local success, err = LlmTools.create_dir({ rel_path = "new_dir" }) - assert.is_nil(err) - assert.is_true(success) + LlmTools.create_dir({ rel_path = "new_dir" }, nil, function(success, err) + assert.is_nil(err) + assert.is_true(success) - local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil - assert.is_true(dir_exists) + local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil + assert.is_true(dir_exists) + end) end) end) describe("delete_file", function() it("should delete existing file", function() - local success, err = LlmTools.delete_file({ rel_path = "test.txt" }) - assert.is_nil(err) - assert.is_true(success) + LlmTools.delete_file({ rel_path = "test.txt" }, nil, function(success, err) + assert.is_nil(err) + assert.is_true(success) - local file_exists = io.open(test_file, "r") ~= nil - assert.is_false(file_exists) + local file_exists = io.open(test_file, "r") ~= nil + assert.is_false(file_exists) + end) end) end) @@ -270,68 +273,93 @@ describe("llm_tools", function() describe("bash", function() it("should execute command and return output", function() - local result, err = LlmTools.bash({ rel_path = ".", command = "echo 'test'" }) - assert.is_nil(err) - assert.equals("test\n", result) + LlmTools.bash({ rel_path = ".", command = "echo 'test'" }, nil, function(result, err) + assert.is_nil(err) + assert.equals("test\n", result) + end) end) it("should return error when running outside current directory", function() - local result, err = LlmTools.bash({ rel_path = "../outside_project", command = "echo 'test'" }) - assert.is_false(result) - assert.truthy(err) - assert.truthy(err:find("No permission to access path")) + LlmTools.bash({ rel_path = "../outside_project", command = "echo 'test'" }, nil, function(result, err) + assert.is_false(result) + assert.truthy(err) + assert.truthy(err:find("No permission to access path")) + end) end) end) describe("python", function() - local original_system = vim.fn.system - it("should execute Python code and return output", function() - local result, err = LlmTools.python({ - rel_path = ".", - code = "print('Hello from Python')", - }) - assert.is_nil(err) - assert.equals("Hello from Python\n", result) + LlmTools.python( + { + rel_path = ".", + code = "print('Hello from Python')", + }, + nil, + function(result, err) + assert.is_nil(err) + assert.equals("Hello from Python\n", result) + end + ) end) it("should handle Python errors", function() - local result, err = LlmTools.python({ - rel_path = ".", - code = "print(undefined_variable)", - }) - assert.is_nil(result) - assert.truthy(err) - assert.truthy(err:find("Error")) + LlmTools.python( + { + rel_path = ".", + code = "print(undefined_variable)", + }, + nil, + function(result, err) + assert.is_nil(result) + assert.truthy(err) + assert.truthy(err:find("Error")) + end + ) end) it("should respect path permissions", function() - local result, err = LlmTools.python({ - rel_path = "../outside_project", - code = "print('test')", - }) - assert.is_nil(result) - assert.truthy(err:find("No permission to access path")) + LlmTools.python( + { + rel_path = "../outside_project", + code = "print('test')", + }, + nil, + function(result, err) + assert.is_nil(result) + assert.truthy(err:find("No permission to access path")) + end + ) end) it("should handle non-existent paths", function() - local result, err = LlmTools.python({ - rel_path = "non_existent_dir", - code = "print('test')", - }) - assert.is_nil(result) - assert.truthy(err:find("Path not found")) + LlmTools.python( + { + rel_path = "non_existent_dir", + code = "print('test')", + }, + nil, + function(result, err) + assert.is_nil(result) + assert.truthy(err:find("Path not found")) + end + ) end) it("should support custom container image", function() os.execute("docker image rm python:3.12-slim") - local result, err = LlmTools.python({ - rel_path = ".", - code = "print('Hello from custom container')", - container_image = "python:3.12-slim", - }) - assert.is_nil(err) - assert.equals("Hello from custom container\n", result) + LlmTools.python( + { + rel_path = ".", + code = "print('Hello from custom container')", + container_image = "python:3.12-slim", + }, + nil, + function(result, err) + assert.is_nil(err) + assert.equals("Hello from custom container\n", result) + end + ) end) end)