fix: better sidebar (#1603)

* fix: better sidebar

* feat: better msg history

* fix: tests
This commit is contained in:
yetone
2025-03-17 01:40:05 +08:00
committed by GitHub
parent f60f150a21
commit 6e77da83c1
17 changed files with 870 additions and 319 deletions

View File

@@ -99,6 +99,8 @@ end
---@field win? table<string, any> windows options similar to |nvim_open_win()|
---@field ask? boolean
---@field floating? boolean whether to open a floating input to enter the question
---@field new_chat? boolean whether to open a new chat
---@field without_selection? boolean whether to open a new chat without selection
---@param opts? AskOptions
function M.ask(opts)
@@ -117,6 +119,7 @@ function M.ask(opts)
opts = vim.tbl_extend("force", { selection = Utils.get_visual_selection_and_range() }, opts)
---@param input string | nil
local function ask(input)
if input == nil or input == "" then input = opts.question end
local sidebar = require("avante").get()
@@ -124,6 +127,12 @@ function M.ask(opts)
sidebar:close({ goto_code_win = false })
end
require("avante").open_sidebar(opts)
if opts.new_chat then sidebar:new_chat() end
if opts.without_selection then
sidebar.code.selection = nil
sidebar.file_selector:reset()
if sidebar.selected_files_container then sidebar.selected_files_container:unmount() end
end
if input == nil or input == "" then return true end
vim.api.nvim_exec_autocmds("User", { pattern = "AvanteInputSubmitted", data = { request = input } })
return true

View File

@@ -195,7 +195,7 @@ M._defaults = {
model = "gpt-4o",
timeout = 30000, -- Timeout in milliseconds
temperature = 0,
max_tokens = 4096,
max_tokens = 8192,
},
---@type AvanteSupportedProvider
copilot = {
@@ -205,7 +205,7 @@ M._defaults = {
allow_insecure = false, -- Allow insecure server connections
timeout = 30000, -- Timeout in milliseconds
temperature = 0,
max_tokens = 4096,
max_tokens = 8192,
},
---@type AvanteAzureProvider
azure = {
@@ -214,7 +214,7 @@ M._defaults = {
api_version = "2024-06-01",
timeout = 30000, -- Timeout in milliseconds
temperature = 0,
max_tokens = 4096,
max_tokens = 8192,
},
---@type AvanteSupportedProvider
claude = {
@@ -222,14 +222,14 @@ M._defaults = {
model = "claude-3-7-sonnet-20250219",
timeout = 30000, -- Timeout in milliseconds
temperature = 0,
max_tokens = 8000,
max_tokens = 8192,
},
---@type AvanteSupportedProvider
bedrock = {
model = "anthropic.claude-3-5-sonnet-20241022-v2:0",
timeout = 30000, -- Timeout in milliseconds
temperature = 0,
max_tokens = 8000,
max_tokens = 8192,
},
---@type AvanteSupportedProvider
gemini = {
@@ -237,7 +237,7 @@ M._defaults = {
model = "gemini-1.5-flash-latest",
timeout = 30000, -- Timeout in milliseconds
temperature = 0,
max_tokens = 4096,
max_tokens = 8192,
},
---@type AvanteSupportedProvider
vertex = {
@@ -245,7 +245,7 @@ M._defaults = {
model = "gemini-1.5-flash-002",
timeout = 30000, -- Timeout in milliseconds
temperature = 0,
max_tokens = 4096,
max_tokens = 8192,
},
---@type AvanteSupportedProvider
cohere = {
@@ -253,7 +253,7 @@ M._defaults = {
model = "command-r-plus-08-2024",
timeout = 30000, -- Timeout in milliseconds
temperature = 0,
max_tokens = 4096,
max_tokens = 8192,
},
---@type AvanteSupportedProvider
ollama = {
@@ -261,7 +261,7 @@ M._defaults = {
timeout = 30000, -- Timeout in milliseconds
options = {
temperature = 0,
num_ctx = 4096,
num_ctx = 8192,
},
},
---@type AvanteSupportedProvider
@@ -270,7 +270,7 @@ M._defaults = {
model = "claude-3-5-sonnet-v2@20241022",
timeout = 30000, -- Timeout in milliseconds
temperature = 0,
max_tokens = 4096,
max_tokens = 8192,
},
---To add support for custom provider, follow the format below
---See https://github.com/yetone/avante.nvim/wiki#custom-providers for more details
@@ -282,7 +282,7 @@ M._defaults = {
model = "claude-3-5-haiku-20241022",
timeout = 30000, -- Timeout in milliseconds
temperature = 0,
max_tokens = 8000,
max_tokens = 8192,
},
---@type AvanteSupportedProvider
["claude-opus"] = {
@@ -290,7 +290,7 @@ M._defaults = {
model = "claude-3-opus-20240229",
timeout = 30000, -- Timeout in milliseconds
temperature = 0,
max_tokens = 8000,
max_tokens = 8192,
},
["openai-gpt-4o-mini"] = {
__inherited_from = "openai",

View File

@@ -85,6 +85,7 @@ end
function FileSelector:reset()
self.selected_filepaths = {}
self.event_handlers = {}
self:emit("update")
end
function FileSelector:add_selected_file(filepath)

View File

@@ -18,6 +18,13 @@ local Highlights = {
INLINE_HINT = { name = "AvanteInlineHint", link = "Keyword" },
TO_BE_DELETED = { name = "AvanteToBeDeleted", bg = "#ffcccc", strikethrough = true },
TO_BE_DELETED_WITHOUT_STRIKETHROUGH = { name = "AvanteToBeDeletedWOStrikethrough", bg = "#562C30" },
CONFIRM_TITLE = { name = "AvanteConfirmTitle", fg = "#1e222a", bg = "#e06c75" },
BUTTON_DEFAULT = { name = "AvanteButtonDefault", fg = "#1e222a", bg = "#ABB2BF" },
BUTTON_DEFAULT_HOVER = { name = "AvanteButtonDefaultHover", fg = "#1e222a", bg = "#a9cf8a" },
BUTTON_PRIMARY = { name = "AvanteButtonPrimary", fg = "#1e222a", bg = "#ABB2BF" },
BUTTON_PRIMARY_HOVER = { name = "AvanteButtonPrimaryHover", fg = "#1e222a", bg = "#56b6c2" },
BUTTON_DANGER = { name = "AvanteButtonDanger", fg = "#1e222a", bg = "#ABB2BF" },
BUTTON_DANGER_HOVER = { name = "AvanteButtonDangerHover", fg = "#1e222a", bg = "#e06c75" },
}
Highlights.conflict = {

View File

@@ -509,6 +509,7 @@ function M._stream(opts)
end, stop_opts.retry_after * 1000)
return
end
stop_opts.tool_histories = opts.tool_histories
return opts.on_stop(stop_opts)
end,
}

View File

@@ -17,9 +17,23 @@ local function get_abs_path(rel_path)
return p
end
function M.confirm(msg)
local ok = vim.fn.confirm(msg, "&Yes\n&No", 2)
return ok == 1
function M.confirm(message, callback)
local UI = require("avante.ui")
UI.confirm(message, callback)
end
---@param abs_path string
---@return boolean
local function is_ignored(abs_path)
local project_root = Utils.get_project_root()
local gitignore_path = project_root .. "/.gitignore"
local gitignore_patterns, gitignore_negate_patterns = Utils.parse_gitignore(gitignore_path)
-- The checker should only take care of the path inside the project root
-- Specifically, it should not check the project root itself
-- Otherwise if the binary is named the same as the project root (such as Go binary), any paths
-- insde the project root will be ignored
local rel_path = Utils.make_relative_path(abs_path, project_root)
return Utils.is_ignored(rel_path, gitignore_patterns, gitignore_negate_patterns)
end
---@param abs_path string
@@ -28,14 +42,7 @@ local function has_permission_to_access(abs_path)
if not Path:new(abs_path):is_absolute() then return false end
local project_root = Utils.get_project_root()
if abs_path:sub(1, #project_root) ~= project_root then return false end
local gitignore_path = project_root .. "/.gitignore"
local gitignore_patterns, gitignore_negate_patterns = Utils.parse_gitignore(gitignore_path)
-- The checker should only take care of the path inside the project root
-- Specifically, it should not check the project root itself
-- Otherwise if the binary is named the same as the project root (such as Go binary), any paths
-- insde the project root will be ignored
local rel_path = Utils.make_relative_path(abs_path, project_root)
return not Utils.is_ignored(rel_path, gitignore_patterns, gitignore_negate_patterns)
return not is_ignored(abs_path)
end
---@type AvanteLLMToolFunc<{ rel_path: string, pattern: string }>
@@ -164,6 +171,41 @@ function M.read_file(opts, on_log)
return content, nil
end
---@type AvanteLLMToolFunc<{ abs_path: string }>
function M.read_global_file(opts, on_log)
local abs_path = get_abs_path(opts.abs_path)
if is_ignored(abs_path) then return "", "This file is ignored: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
local file = io.open(abs_path, "r")
if not file then return "", "file not found: " .. abs_path end
local content = file:read("*a")
file:close()
return content, nil
end
---@type AvanteLLMToolFunc<{ abs_path: string, content: string }>
function M.write_global_file(opts, on_log, on_complete)
local abs_path = get_abs_path(opts.abs_path)
if is_ignored(abs_path) then return false, "This file is ignored: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
if on_log then on_log("content: " .. opts.content) end
if not on_complete then return false, "on_complete not provided" end
M.confirm("Are you sure you want to write to the file: " .. abs_path, function(ok)
if not ok then
on_complete(false, "User canceled")
return
end
local file = io.open(abs_path, "w")
if not file then
on_complete(false, "file not found: " .. abs_path)
return
end
file:write(opts.content)
file:close()
on_complete(true, nil)
end)
end
---@type AvanteLLMToolFunc<{ rel_path: string }>
function M.create_file(opts, on_log)
local abs_path = get_abs_path(opts.rel_path)
@@ -183,7 +225,7 @@ function M.create_file(opts, on_log)
end
---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }>
function M.rename_file(opts, on_log)
function M.rename_file(opts, on_log, on_complete)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
@@ -192,11 +234,15 @@ function M.rename_file(opts, on_log)
if on_log then on_log(abs_path .. " -> " .. new_abs_path) end
if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end
if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end
if not M.confirm("Are you sure you want to rename the file: " .. abs_path .. " to: " .. new_abs_path) then
return false, "User canceled"
end
os.rename(abs_path, new_abs_path)
return true, nil
if not on_complete then return false, "on_complete not provided" end
M.confirm("Are you sure you want to rename the file: " .. abs_path .. " to: " .. new_abs_path, function(ok)
if not ok then
on_complete(false, "User canceled")
return
end
os.rename(abs_path, new_abs_path)
on_complete(true, nil)
end)
end
---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }>
@@ -214,32 +260,42 @@ function M.copy_file(opts, on_log)
end
---@type AvanteLLMToolFunc<{ rel_path: string }>
function M.delete_file(opts, on_log)
function M.delete_file(opts, on_log, on_complete)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
if not M.confirm("Are you sure you want to delete the file: " .. abs_path) then return false, "User canceled" end
if on_log then on_log("Deleting file: " .. abs_path) end
os.remove(abs_path)
return true, nil
if not on_complete then return false, "on_complete not provided" end
M.confirm("Are you sure you want to delete the file: " .. abs_path, function(ok)
if not ok then
on_complete(false, "User canceled")
return
end
if on_log then on_log("Deleting file: " .. abs_path) end
os.remove(abs_path)
on_complete(true, nil)
end)
end
---@type AvanteLLMToolFunc<{ rel_path: string }>
function M.create_dir(opts, on_log)
function M.create_dir(opts, on_log, on_complete)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end
if not M.confirm("Are you sure you want to create the directory: " .. abs_path) then
return false, "User canceled"
end
if on_log then on_log("Creating directory: " .. abs_path) end
Path:new(abs_path):mkdir({ parents = true })
return true, nil
if not on_complete then return false, "on_complete not provided" end
M.confirm("Are you sure you want to create the directory: " .. abs_path, function(ok)
if not ok then
on_complete(false, "User canceled")
return
end
if on_log then on_log("Creating directory: " .. abs_path) end
Path:new(abs_path):mkdir({ parents = true })
on_complete(true, nil)
end)
end
---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }>
function M.rename_dir(opts, on_log)
function M.rename_dir(opts, on_log, on_complete)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
@@ -247,26 +303,34 @@ function M.rename_dir(opts, on_log)
local new_abs_path = get_abs_path(opts.new_rel_path)
if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end
if Path:new(new_abs_path):exists() then return false, "Directory already exists: " .. new_abs_path end
if not M.confirm("Are you sure you want to rename directory " .. abs_path .. " to " .. new_abs_path .. "?") then
return false, "User canceled"
end
if on_log then on_log("Renaming directory: " .. abs_path .. " to " .. new_abs_path) end
os.rename(abs_path, new_abs_path)
return true, nil
if not on_complete then return false, "on_complete not provided" end
M.confirm("Are you sure you want to rename directory " .. abs_path .. " to " .. new_abs_path .. "?", function(ok)
if not ok then
on_complete(false, "User canceled")
return
end
if on_log then on_log("Renaming directory: " .. abs_path .. " to " .. new_abs_path) end
os.rename(abs_path, new_abs_path)
on_complete(true, nil)
end)
end
---@type AvanteLLMToolFunc<{ rel_path: string }>
function M.delete_dir(opts, on_log)
function M.delete_dir(opts, on_log, on_complete)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end
if not M.confirm("Are you sure you want to delete the directory: " .. abs_path) then
return false, "User canceled"
end
if on_log then on_log("Deleting directory: " .. abs_path) end
os.remove(abs_path)
return true, nil
if not on_complete then return false, "on_complete not provided" end
M.confirm("Are you sure you want to delete the directory: " .. abs_path, function(ok)
if not ok then
on_complete(false, "User canceled")
return
end
if on_log then on_log("Deleting directory: " .. abs_path) end
os.remove(abs_path)
on_complete(true, nil)
end)
end
---@type AvanteLLMToolFunc<{ rel_path: string, command: string }>
@@ -275,11 +339,6 @@ function M.bash(opts, on_log, on_complete)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
if on_log then on_log("command: " .. opts.command) end
if
not M.confirm("Are you sure you want to run the command: `" .. opts.command .. "` in the directory: " .. abs_path)
then
return false, "User canceled"
end
---change cwd to abs_path
---@param output string
---@param exit_code integer
@@ -292,18 +351,20 @@ function M.bash(opts, on_log, on_complete)
end
return output, nil
end
if on_complete then
Utils.shell_run_async(opts.command, "bash -c", function(output, exit_code)
local result, err = handle_result(output, exit_code)
on_complete(result, err)
end, abs_path)
return nil, nil
end
local old_cwd = vim.fn.getcwd()
vim.fn.chdir(abs_path)
local res = Utils.shell_run(opts.command, "bash -c")
vim.fn.chdir(old_cwd)
return handle_result(res.stdout, res.code)
if not on_complete then return false, "on_complete not provided" end
M.confirm(
"Are you sure you want to run the command: `" .. opts.command .. "` in the directory: " .. abs_path,
function(ok)
if not ok then
on_complete(false, "User canceled")
return
end
Utils.shell_run_async(opts.command, "bash -c", function(output, exit_code)
local result, err = handle_result(output, exit_code)
on_complete(result, err)
end, abs_path)
end
)
end
---@type AvanteLLMToolFunc<{ query: string }>
@@ -464,7 +525,7 @@ function M.git_diff(opts, on_log)
end
---@type AvanteLLMToolFunc<{ message: string, scope?: string }>
function M.git_commit(opts, on_log)
function M.git_commit(opts, on_log, on_complete)
local git_cmd = vim.fn.exepath("git")
if git_cmd == "" then return false, "Git command not found" end
local project_root = Utils.get_project_root()
@@ -518,36 +579,46 @@ function M.git_commit(opts, on_log)
-- Construct full commit message for confirmation
local full_commit_msg = table.concat(commit_msg_lines, "\n")
if not on_complete then return false, "on_complete not provided" end
-- Confirm with user
if not M.confirm("Are you sure you want to commit with message:\n" .. full_commit_msg) then
return false, "User canceled"
end
M.confirm("Are you sure you want to commit with message:\n" .. full_commit_msg, function(ok)
if not ok then
on_complete(false, "User canceled")
return
end
-- Stage changes if scope is provided
if opts.scope then
local stage_cmd = string.format("git add %s", opts.scope)
if on_log then on_log("Staging files: " .. stage_cmd) end
local stage_result = vim.fn.system(stage_cmd)
if vim.v.shell_error ~= 0 then
on_complete(false, "Failed to stage files: " .. stage_result)
return
end
end
-- Stage changes if scope is provided
if opts.scope then
local stage_cmd = string.format("git add %s", opts.scope)
if on_log then on_log("Staging files: " .. stage_cmd) end
local stage_result = vim.fn.system(stage_cmd)
if vim.v.shell_error ~= 0 then return false, "Failed to stage files: " .. stage_result end
end
-- Construct git commit command
local cmd_parts = { "git", "commit" }
-- Only add -S flag if GPG is available
if has_gpg then table.insert(cmd_parts, "-S") end
for _, line in ipairs(commit_msg_lines) do
table.insert(cmd_parts, "-m")
table.insert(cmd_parts, '"' .. line .. '"')
end
local cmd = table.concat(cmd_parts, " ")
-- Construct git commit command
local cmd_parts = { "git", "commit" }
-- Only add -S flag if GPG is available
if has_gpg then table.insert(cmd_parts, "-S") end
for _, line in ipairs(commit_msg_lines) do
table.insert(cmd_parts, "-m")
table.insert(cmd_parts, '"' .. line .. '"')
end
local cmd = table.concat(cmd_parts, " ")
-- Execute git commit
if on_log then on_log("Running command: " .. cmd) end
local result = vim.fn.system(cmd)
-- Execute git commit
if on_log then on_log("Running command: " .. cmd) end
local result = vim.fn.system(cmd)
if vim.v.shell_error ~= 0 then
on_complete(false, "Failed to commit: " .. result)
return
end
if vim.v.shell_error ~= 0 then return false, "Failed to commit: " .. result end
return true, nil
on_complete(true, nil)
end)
end
---@type AvanteLLMToolFunc<{ query: string }>
@@ -571,57 +642,62 @@ function M.python(opts, on_log, on_complete)
if on_log then on_log("cwd: " .. abs_path) end
if on_log then on_log("code:\n" .. opts.code) end
local container_image = opts.container_image or "python:3.11-slim-bookworm"
if
not M.confirm(
"Are you sure you want to run the following python code in the `"
.. container_image
.. "` container, in the directory: `"
.. abs_path
.. "`?\n"
.. opts.code
)
then
return nil, "User canceled"
end
if vim.fn.executable("docker") == 0 then return nil, "Python tool is not available to execute any code" end
if not on_complete then return nil, "on_complete not provided" end
M.confirm(
"Are you sure you want to run the following python code in the `"
.. container_image
.. "` container, in the directory: `"
.. abs_path
.. "`?\n"
.. opts.code,
function(ok)
if not ok then
on_complete(nil, "User canceled")
return
end
if vim.fn.executable("docker") == 0 then
on_complete(nil, "Python tool is not available to execute any code")
return
end
local function handle_result(result) ---@param result vim.SystemCompleted
if result.code ~= 0 then return nil, "Error: " .. (result.stderr or "Unknown error") end
local function handle_result(result) ---@param result vim.SystemCompleted
if result.code ~= 0 then return nil, "Error: " .. (result.stderr or "Unknown error") end
Utils.debug("output", result.stdout)
return result.stdout, nil
end
local job = vim.system(
{
"docker",
"run",
"--rm",
"-v",
abs_path .. ":" .. abs_path,
"-w",
abs_path,
container_image,
"python",
"-c",
opts.code,
},
{
text = true,
cwd = abs_path,
},
vim.schedule_wrap(function(result)
if not on_complete then return end
local output, err = handle_result(result)
on_complete(output, err)
end)
Utils.debug("output", result.stdout)
return result.stdout, nil
end
vim.system(
{
"docker",
"run",
"--rm",
"-v",
abs_path .. ":" .. abs_path,
"-w",
abs_path,
container_image,
"python",
"-c",
opts.code,
},
{
text = true,
cwd = abs_path,
},
vim.schedule_wrap(function(result)
if not on_complete then return end
local output, err = handle_result(result)
on_complete(output, err)
end)
)
end
)
if on_complete then return end
local result = job:wait()
return handle_result(result)
end
---@param user_input string
---@param history_messages AvanteLLMMessage[]
---@return AvanteLLMTool[]
function M.get_tools()
function M.get_tools(user_input, history_messages)
local custom_tools = Config.custom_tools
if type(custom_tools) == "function" then custom_tools = custom_tools() end
---@type AvanteLLMTool[]
@@ -634,7 +710,7 @@ function M.get_tools()
if tool.enabled == nil then
return true
else
return tool.enabled()
return tool.enabled({ user_input = user_input, history_messages = history_messages })
end
end)
:totable()
@@ -644,7 +720,7 @@ end
M._tools = {
{
name = "glob",
description = 'Fast file pattern matching using glob patterns like "**/*.js"',
description = 'Fast file pattern matching using glob patterns like "**/*.js", in current project scope',
param = {
type = "table",
fields = {
@@ -655,7 +731,7 @@ M._tools = {
},
{
name = "rel_path",
description = "Relative path to the directory, as cwd",
description = "Relative path to the project directory, as cwd",
type = "string",
},
},
@@ -704,7 +780,7 @@ M._tools = {
},
{
name = "python",
description = "Run python code. Can't use it to read files or modify files.",
description = "Run python code in current project scope. Can't use it to read files or modify files.",
param = {
type = "table",
fields = {
@@ -715,7 +791,7 @@ M._tools = {
},
{
name = "rel_path",
description = "Relative path to the directory, as cwd",
description = "Relative path to the project directory, as cwd",
type = "string",
},
},
@@ -736,7 +812,7 @@ M._tools = {
},
{
name = "git_diff",
description = "Get git diff for generating commit message",
description = "Get git diff for generating commit message in current project scope",
param = {
type = "table",
fields = {
@@ -763,7 +839,7 @@ M._tools = {
},
{
name = "git_commit",
description = "Commit changes with the given commit message",
description = "Commit changes with the given commit message in current project scope",
param = {
type = "table",
fields = {
@@ -796,13 +872,13 @@ M._tools = {
},
{
name = "list_files",
description = "List files in a directory",
description = "List files in current project scope",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
description = "Relative path to the project directory",
type = "string",
},
{
@@ -815,7 +891,7 @@ M._tools = {
returns = {
{
name = "files",
description = "List of files in the directory",
description = "List of filepaths in the directory",
type = "string[]",
},
{
@@ -828,13 +904,13 @@ M._tools = {
},
{
name = "search_files",
description = "Search for files in a directory",
description = "Search for files in current project scope",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
description = "Relative path to the project directory",
type = "string",
},
{
@@ -847,7 +923,7 @@ M._tools = {
returns = {
{
name = "files",
description = "List of files that match the keyword",
description = "List of filepaths that match the keyword",
type = "string",
},
{
@@ -860,13 +936,13 @@ M._tools = {
},
{
name = "grep_search",
description = "Search for a keyword in a directory using grep",
description = "Search for a keyword in a directory using grep in current project scope",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
description = "Relative path to the project directory",
type = "string",
},
{
@@ -911,13 +987,13 @@ M._tools = {
},
{
name = "read_file_toplevel_symbols",
description = "Read the top-level symbols of a file",
description = "Read the top-level symbols of a file in current project scope",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the file",
description = "Relative path to the file in current project scope",
type = "string",
},
},
@@ -938,13 +1014,28 @@ M._tools = {
},
{
name = "read_file",
description = "Read the contents of a file. If the file content is already in the context, do not use this tool.",
description = "Read the contents of a file in current project scope. If the file content is already in the context, do not use this tool.",
enabled = function(opts)
if opts.user_input:match("@read_global_file") then return false end
for _, message in ipairs(opts.history_messages) do
if message.role == "user" then
local content = message.content
if type(content) == "string" and content:match("@read_global_file") then return false end
if type(content) == "table" then
for _, item in ipairs(content) do
if type(item) == "string" and item:match("@read_global_file") then return false end
end
end
end
end
return true
end,
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the file",
description = "Relative path to the file in current project scope",
type = "string",
},
},
@@ -963,15 +1054,104 @@ M._tools = {
},
},
},
{
name = "read_global_file",
description = "Read the contents of a file in the global scope. If the file content is already in the context, do not use this tool.",
enabled = function(opts)
if opts.user_input:match("@read_global_file") then return true end
for _, message in ipairs(opts.history_messages) do
if message.role == "user" then
local content = message.content
if type(content) == "string" and content:match("@read_global_file") then return true end
if type(content) == "table" then
for _, item in ipairs(content) do
if type(item) == "string" and item:match("@read_global_file") then return true end
end
end
end
end
return false
end,
param = {
type = "table",
fields = {
{
name = "abs_path",
description = "Absolute path to the file in global scope",
type = "string",
},
},
},
returns = {
{
name = "content",
description = "Contents of the file",
type = "string",
},
{
name = "error",
description = "Error message if the file was not read successfully",
type = "string",
optional = true,
},
},
},
{
name = "write_global_file",
description = "Write to a file in the global scope",
enabled = function(opts)
if opts.user_input:match("@write_global_file") then return true end
for _, message in ipairs(opts.history_messages) do
if message.role == "user" then
local content = message.content
if type(content) == "string" and content:match("@write_global_file") then return true end
if type(content) == "table" then
for _, item in ipairs(content) do
if type(item) == "string" and item:match("@write_global_file") then return true end
end
end
end
end
return false
end,
param = {
type = "table",
fields = {
{
name = "abs_path",
description = "Absolute path to the file in global scope",
type = "string",
},
{
name = "content",
description = "Content to write to the file",
type = "string",
},
},
},
returns = {
{
name = "success",
description = "True if the file was written successfully, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the file was not written successfully",
type = "string",
optional = true,
},
},
},
{
name = "create_file",
description = "Create a new file",
description = "Create a new file in current project scope",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the file",
description = "Relative path to the file in current project scope",
type = "string",
},
},
@@ -992,13 +1172,13 @@ M._tools = {
},
{
name = "rename_file",
description = "Rename a file",
description = "Rename a file in current project scope",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the file",
description = "Relative path to the file in current project scope",
type = "string",
},
{
@@ -1024,13 +1204,13 @@ M._tools = {
},
{
name = "delete_file",
description = "Delete a file",
description = "Delete a file in current project scope",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the file",
description = "Relative path to the file in current project scope",
type = "string",
},
},
@@ -1051,13 +1231,13 @@ M._tools = {
},
{
name = "create_dir",
description = "Create a new directory",
description = "Create a new directory in current project scope",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
description = "Relative path to the project directory",
type = "string",
},
},
@@ -1078,13 +1258,13 @@ M._tools = {
},
{
name = "rename_dir",
description = "Rename a directory",
description = "Rename a directory in current project scope",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
description = "Relative path to the project directory",
type = "string",
},
{
@@ -1110,13 +1290,13 @@ M._tools = {
},
{
name = "delete_dir",
description = "Delete a directory",
description = "Delete a directory in current project scope",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
description = "Relative path to the project directory",
type = "string",
},
},
@@ -1137,13 +1317,13 @@ M._tools = {
},
{
name = "bash",
description = "Run a bash command in a directory. Can't use search commands like find/grep or read tools like cat/ls. Can't use it to read files or modify files.",
description = "Run a bash command in current project scope. Can't use search commands like find/grep or read tools like cat/ls. Can't use it to read files or modify files.",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
description = "Relative path to the project directory, as cwd",
type = "string",
},
{

View File

@@ -59,6 +59,8 @@ function M:parse_messages(opts)
---@type AvanteClaudeMessage[]
local messages = {}
local provider_conf, _ = P.parse_config(self)
---@type {idx: integer, length: integer}[]
local messages_with_length = {}
for idx, message in ipairs(opts.messages) do
@@ -76,15 +78,46 @@ function M:parse_messages(opts)
end
for idx, message in ipairs(opts.messages) do
local content_items = message.content
local message_content = {}
if type(content_items) == "string" then
table.insert(message_content, {
type = "text",
text = message.content,
cache_control = top_two[idx] and { type = "ephemeral" } or nil,
})
elseif type(content_items) == "table" then
---@cast content_items AvanteLLMMessageContentItem[]
for _, item in ipairs(content_items) do
if type(item) == "string" then
table.insert(
message_content,
{ type = "text", text = item, cache_control = top_two[idx] and { type = "ephemeral" } or nil }
)
elseif type(item) == "table" and item.type == "text" then
table.insert(
message_content,
{ type = "text", text = item.text, cache_control = top_two[idx] and { type = "ephemeral" } or nil }
)
elseif type(item) == "table" and item.type == "image" then
table.insert(message_content, { type = "image", source = item.source })
elseif not provider_conf.disable_tools and type(item) == "table" and item.type == "tool_use" then
table.insert(message_content, { type = "tool_use", name = item.name, id = item.id, input = item.input })
elseif not provider_conf.disable_tools and type(item) == "table" and item.type == "tool_result" then
table.insert(
message_content,
{ type = "tool_result", tool_use_id = item.tool_use_id, content = item.content, is_error = item.is_error }
)
elseif type(item) == "table" and item.type == "thinking" then
table.insert(message_content, { type = "thinking", thinking = item.thinking, signature = item.signature })
elseif type(item) == "table" and item.type == "redacted_thinking" then
table.insert(message_content, { type = "redacted_thinking", data = item.data })
end
end
end
table.insert(messages, {
role = self.role_map[message.role],
content = {
{
type = "text",
text = message.content,
cache_control = top_two[idx] and { type = "ephemeral" } or nil,
},
},
content = message_content,
})
end

View File

@@ -35,9 +35,36 @@ function M:parse_messages(opts)
end
end
prev_role = role
table.insert(contents, { role = M.role_map[role] or role, parts = {
{ text = message.content },
} })
local parts = {}
local content_items = message.content
if type(content_items) == "string" then
table.insert(parts, { text = content_items })
elseif type(content_items) == "table" then
---@cast content_items AvanteLLMMessageContentItem[]
for _, item in ipairs(content_items) do
if type(item) == "string" then
table.insert(parts, { text = item })
elseif type(item) == "table" and item.type == "text" then
table.insert(parts, { text = item.text })
elseif type(item) == "table" and item.type == "image" then
table.insert(parts, {
inline_data = {
mime_type = "image/png",
data = item.source.data,
},
})
elseif type(item) == "table" and item.type == "tool_use" then
table.insert(parts, { text = item.name })
elseif type(item) == "table" and item.type == "tool_result" then
table.insert(parts, { text = item.content })
elseif type(item) == "table" and item.type == "thinking" then
table.insert(parts, { text = item.thinking })
elseif type(item) == "table" and item.type == "redacted_thinking" then
table.insert(parts, { text = item.data })
end
end
end
table.insert(contents, { role = M.role_map[role] or role, parts = parts })
end)
if Clipboard.support_paste_image() and opts.image_paths then

View File

@@ -81,9 +81,54 @@ function M:parse_messages(opts)
table.insert(messages, { role = "system", content = opts.system_prompt })
end
vim
.iter(opts.messages)
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
vim.iter(opts.messages):each(function(msg)
if type(msg.content) == "string" then
table.insert(messages, { role = self.role_map[msg.role], content = msg.content })
else
local content = {}
local tool_calls = {}
local tool_results = {}
for _, item in ipairs(msg.content) do
if type(item) == "string" then
table.insert(content, { type = "text", text = item })
elseif item.type == "text" then
table.insert(content, { type = "text", text = item.text })
elseif item.type == "image" then
table.insert(content, {
type = "image_url",
image_url = {
url = "data:" .. item.source.media_type .. ";" .. item.source.type .. "," .. item.source.data,
},
})
elseif item.type == "tool_use" then
table.insert(tool_calls, {
id = item.id,
type = "function",
["function"] = { name = item.name, arguments = vim.json.encode(item.input) },
})
elseif item.type == "tool_result" then
table.insert(
tool_results,
{ tool_call_id = item.tool_use_id, content = item.is_error and "Error: " .. item.content or item.content }
)
end
end
table.insert(messages, { role = self.role_map[msg.role], content = content })
if not provider_conf.disable_tools then
if #tool_calls > 0 then
table.insert(messages, { role = self.role_map["assistant"], tool_calls = tool_calls })
end
if #tool_results > 0 then
for _, tool_result in ipairs(tool_results) do
table.insert(
messages,
{ role = "tool", tool_call_id = tool_result.tool_call_id, content = tool_result.content or "" }
)
end
end
end
end
end)
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
local message_content = messages[#messages].content

View File

@@ -43,6 +43,7 @@ local Sidebar = {}
---@field selected_files_container NuiSplit | nil
---@field input_container NuiSplit | nil
---@field file_selector FileSelector
---@field chat_history avante.ChatHistory | nil
---@param id integer the tabpage id retrieved from api.nvim_get_current_tabpage()
function Sidebar:new(id)
@@ -61,6 +62,7 @@ function Sidebar:new(id)
input_container = nil,
file_selector = FileSelector:new(id),
is_generating = false,
chat_history = nil,
}, { __index = self })
end
@@ -397,6 +399,7 @@ local function transform_result_content(selected_files, result_content, prev_fil
elseif line_content == "<think>" then
is_thinking = true
last_think_tag_start_line = i
last_think_tag_end_line = 0
elseif line_content == "</think>" then
is_thinking = false
last_think_tag_end_line = i
@@ -1810,6 +1813,7 @@ function Sidebar:on_mount(opts)
group = self.augroup,
callback = function(args)
local closed_winid = tonumber(args.match)
if closed_winid == self.winids.selected_files_container then return end
if not self:is_focused_on(closed_winid) then return end
self:close()
end,
@@ -1838,6 +1842,7 @@ function Sidebar:refresh_winids()
local function switch_windows()
local current_winid = api.nvim_get_current_win()
winids = vim.iter(winids):filter(function(winid) return api.nvim_win_is_valid(winid) end):totable()
local current_idx = Utils.tbl_indexof(winids, current_winid) or 1
if current_idx == #winids then
current_idx = 1
@@ -1906,6 +1911,8 @@ function Sidebar:initialize()
self.file_selector:reset()
self.file_selector:add_selected_file(filepath)
self:reload_chat_history()
return self
end
@@ -2095,6 +2102,7 @@ end
function Sidebar:render_history_content(history)
local content = ""
for idx, entry in ipairs(history.entries) do
if entry.visible == false then goto continue end
if entry.reset_memory then
content = content .. "***MEMORY RESET***\n\n"
if idx < #history.entries then content = content .. "-------\n\n" end
@@ -2180,7 +2188,7 @@ end
function Sidebar:new_chat(args, cb)
Path.history.new(self.code.bufnr)
Sidebar.reload_chat_history()
self:reload_chat_history()
self:update_content(
"New chat",
{ ignore_history = true, focus = false, scroll = false, callback = function() self:focus_input() end }
@@ -2188,6 +2196,26 @@ function Sidebar:new_chat(args, cb)
if cb then cb(args) end
end
---@param message AvanteLLMMessage
---@param options {visible?: boolean}
function Sidebar:add_chat_history(message, options)
local timestamp = get_timestamp()
self:reload_chat_history()
table.insert(self.chat_history.entries, {
timestamp = timestamp,
provider = Config.provider,
model = Config.get_provider_config(Config.provider).model,
request = message.role == "user" and message.content or "",
response = message.role == "assistant" and message.content or "",
original_response = "",
selected_filepaths = nil,
selected_code = nil,
reset_memory = false,
visible = options.visible,
})
Path.history.save(self.code.bufnr, self.chat_history)
end
function Sidebar:reset_memory(args, cb)
local chat_history = Path.history.load(self.code.bufnr)
if next(chat_history) ~= nil then
@@ -2203,7 +2231,7 @@ function Sidebar:reset_memory(args, cb)
reset_memory = true,
})
Path.history.save(self.code.bufnr, chat_history)
Sidebar.reload_chat_history()
self:reload_chat_history()
local history_content = self:render_history_content(chat_history)
self:update_content(history_content, {
focus = false,
@@ -2212,7 +2240,7 @@ function Sidebar:reset_memory(args, cb)
})
if cb then cb(args) end
else
Sidebar.reload_chat_history()
self:reload_chat_history()
self:update_content(
"Chat history is already empty",
{ focus = false, scroll = false, callback = function() self:focus_input() end }
@@ -2321,46 +2349,18 @@ local generating_text = "**Generating response ...**\n"
local hint_window = nil
function Sidebar:reload_chat_history()
if not self.code.bufnr or not api.nvim_buf_is_valid(self.code.bufnr) then return end
self.chat_history = Path.history.load(self.code.bufnr)
end
---@param opts AskOptions
function Sidebar:create_input_container(opts)
if self.input_container then self.input_container:unmount() end
if not self.code.bufnr or not api.nvim_buf_is_valid(self.code.bufnr) then return end
local chat_history = Path.history.load(self.code.bufnr)
Sidebar.reload_chat_history = function() chat_history = Path.history.load(self.code.bufnr) end
local tools = vim.deepcopy(LLMTools.get_tools())
table.insert(tools, {
name = "add_file_to_context",
description = "Add a file to the context",
---@type AvanteLLMToolFunc<{ rel_path: string }>
func = function(input)
self.file_selector:add_selected_file(input.rel_path)
return "Added file to context", nil
end,
param = {
type = "table",
fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } },
},
returns = {},
})
table.insert(tools, {
name = "remove_file_from_context",
description = "Remove a file from the context",
---@type AvanteLLMToolFunc<{ rel_path: string }>
func = function(input)
self.file_selector:remove_selected_file(input.rel_path)
return "Removed file from context", nil
end,
param = {
type = "table",
fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } },
},
returns = {},
})
if self.chat_history == nil then self:reload_chat_history() end
---@param request string
---@param summarize_memory boolean
@@ -2399,17 +2399,48 @@ function Sidebar:create_input_container(opts)
end
end
local entries = Utils.history.filter_active_entries(chat_history.entries)
local entries = Utils.history.filter_active_entries(self.chat_history.entries)
if chat_history.memory then
if self.chat_history.memory then
entries = vim
.iter(entries)
:filter(function(entry) return entry.timestamp > chat_history.memory.last_summarized_timestamp end)
:filter(function(entry) return entry.timestamp > self.chat_history.memory.last_summarized_timestamp end)
:totable()
end
local history_messages = Utils.history.entries_to_llm_messages(entries)
local tools = vim.deepcopy(LLMTools.get_tools(request, history_messages))
table.insert(tools, {
name = "add_file_to_context",
description = "Add a file to the context",
---@type AvanteLLMToolFunc<{ rel_path: string }>
func = function(input)
self.file_selector:add_selected_file(input.rel_path)
return "Added file to context", nil
end,
param = {
type = "table",
fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } },
},
returns = {},
})
table.insert(tools, {
name = "remove_file_from_context",
description = "Remove a file from the context",
---@type AvanteLLMToolFunc<{ rel_path: string }>
func = function(input)
self.file_selector:remove_selected_file(input.rel_path)
return "Removed file from context", nil
end,
param = {
type = "table",
fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } },
},
returns = {},
})
---@type AvanteGeneratePromptsOptions
local prompts_opts = {
ask = opts.ask or true,
@@ -2425,7 +2456,7 @@ function Sidebar:create_input_container(opts)
tools = tools,
}
if chat_history.memory then prompts_opts.memory = chat_history.memory.content end
if self.chat_history.memory then prompts_opts.memory = self.chat_history.memory.content end
if not summarize_memory or #history_messages < 8 then
cb(prompts_opts)
@@ -2434,7 +2465,7 @@ function Sidebar:create_input_container(opts)
prompts_opts.history_messages = vim.list_slice(prompts_opts.history_messages, 5)
Llm.summarize_memory(self.code.bufnr, chat_history, function(memory)
Llm.summarize_memory(self.code.bufnr, self.chat_history, function(memory)
if memory then prompts_opts.memory = memory.content end
cb(prompts_opts)
end)
@@ -2628,8 +2659,8 @@ function Sidebar:create_input_container(opts)
end, 0)
-- Save chat history
chat_history.entries = chat_history.entries or {}
table.insert(chat_history.entries, {
self.chat_history.entries = self.chat_history.entries or {}
table.insert(self.chat_history.entries, {
timestamp = timestamp,
provider = Config.provider,
model = model,
@@ -2638,8 +2669,9 @@ function Sidebar:create_input_container(opts)
original_response = original_response,
selected_filepaths = selected_filepaths,
selected_code = selected_code,
tool_histories = stop_opts.tool_histories,
})
Path.history.save(self.code.bufnr, chat_history)
Path.history.save(self.code.bufnr, self.chat_history)
end
get_generate_prompts_options(request, true, function(generate_prompts_options)

View File

@@ -19,3 +19,4 @@ Tools Usage Guide:
- For any mathematical calculation problems, please prioritize using the `python` tool to solve them. Please try to avoid mathematical symbols in the return value of the `python` tool for mathematical problems and directly output human-readable results, because large models don't understand mathematical symbols, they only understand human natural language.
- Do not use the `python` tool to read or modify files! If you use the `python` tool to read or modify files, you will be fired!!!!!
- Do not use the `bash` tool to read or modify files! If you use the `bash` tool to read or modify files, you will be fired!!!!!
- If you are provided with the `write_file` tool, there's no need to output your change suggestions, just directly use the `write_file` tool to complete the changes.

View File

@@ -1,11 +1,3 @@
{# Uses https://mitsuhiko.github.io/minijinja-playground/ for testing:
{
"ask": true,
"question": "Refactor to include tab flow",
"code_lang": "lua",
"file_content": "local Config = require('avante.config')"
}
#}
Act as an expert software developer.
Always use best practices when coding.
Respect and use existing conventions, libraries, etc that are already present in the code base.

View File

@@ -76,9 +76,13 @@ vim.g.avante_login = vim.g.avante_login
---@field on_chunk AvanteLLMChunkCallback
---@field on_stop AvanteLLMStopCallback
---
---@alias AvanteLLMMessageContentItem string | { type: "text", text: string } | { type: "image", source: { type: "base64", media_type: string, data: string } } | { type: "tool_use", name: string, id: string, input: any } | { type: "tool_result", tool_use_id: string, content: string, is_error?: boolean } | { type: "thinking", thinking: string, signature: string } | { type: "redacted_thinking", data: string }
---
---@alias AvanteLLMMessageContent AvanteLLMMessageContentItem[] | string
---
---@class AvanteLLMMessage
---@field role "user" | "assistant"
---@field content string
---@field content AvanteLLMMessageContent
---
---@class AvanteLLMToolResult
---@field tool_name string
@@ -245,6 +249,7 @@ vim.g.avante_login = vim.g.avante_login
---@field tool_use_list? AvanteLLMToolUse[]
---@field retry_after? integer
---@field headers? table<string, string>
---@field tool_histories? AvanteLLMToolHistory[]
---
---@alias AvanteStreamParser fun(self: AvanteProviderFunctor, ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil
---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil
@@ -342,7 +347,7 @@ vim.g.avante_login = vim.g.avante_login
---@field func? AvanteLLMToolFunc
---@field param AvanteLLMToolParam
---@field returns AvanteLLMToolReturn[]
---@field enabled? fun(): boolean
---@field enabled? fun(opts: { user_input: string, history_messages: AvanteLLMMessage[] }): boolean
---@class AvanteLLMToolPublic : AvanteLLMTool
---@field func AvanteLLMToolFunc
@@ -374,6 +379,8 @@ vim.g.avante_login = vim.g.avante_login
---@field selected_code AvanteSelectedCode | nil
---@field reset_memory boolean?
---@field selected_filepaths string[] | nil
---@field visible boolean?
---@field tool_histories? AvanteLLMToolHistory[]
---
---@class avante.ChatHistory
---@field title string

139
lua/avante/ui.lua Normal file
View File

@@ -0,0 +1,139 @@
local Popup = require("nui.popup")
local NuiText = require("nui.text")
local event = require("nui.utils.autocmd").event
local Highlights = require("avante.highlights")
local M = {}
function M.confirm(message, callback)
local focus_index = 2 -- 1 = Yes, 2 = No
local yes_button_pos = { 18, 23 }
local no_button_pos = { 28, 32 }
local BUTTON_NORMAL = Highlights.BUTTON_DEFAULT
local BUTTON_FOCUS = Highlights.BUTTON_DEFAULT_HOVER
local popup = Popup({
position = {
row = vim.o.lines - 5,
col = "50%",
},
size = { width = 50, height = 7 },
enter = true,
focusable = true,
border = {
style = "rounded",
text = { top = NuiText(" Confirmation ", Highlights.CONFIRM_TITLE) },
},
win_options = {
winblend = 10,
},
})
local function focus_button()
if focus_index == 1 then
vim.api.nvim_win_set_cursor(popup.winid, { 4, yes_button_pos[1] })
else
vim.api.nvim_win_set_cursor(popup.winid, { 4, no_button_pos[1] })
end
end
local function render_buttons()
local yes_style = (focus_index == 1) and BUTTON_FOCUS or BUTTON_NORMAL
local no_style = (focus_index == 2) and BUTTON_FOCUS or BUTTON_NORMAL
vim.api.nvim_buf_set_lines(popup.bufnr, 0, -1, false, {
"",
" " .. message,
"",
" " .. " Yes No ",
"",
})
vim.api.nvim_buf_add_highlight(popup.bufnr, 0, yes_style, 3, yes_button_pos[1], yes_button_pos[2])
vim.api.nvim_buf_add_highlight(popup.bufnr, 0, no_style, 3, no_button_pos[1], no_button_pos[2])
focus_button()
end
local function select_button()
popup:unmount()
callback(focus_index == 1)
end
vim.keymap.set("n", "y", function()
focus_index = 1
render_buttons()
select_button()
end, { buffer = popup.bufnr })
vim.keymap.set("n", "n", function()
focus_index = 2
render_buttons()
select_button()
end, { buffer = popup.bufnr })
vim.keymap.set("n", "<Left>", function()
focus_index = 1
focus_button()
end, { buffer = popup.bufnr })
vim.keymap.set("n", "<Right>", function()
focus_index = 2
focus_button()
end, { buffer = popup.bufnr })
vim.keymap.set("n", "<Tab>", function()
focus_index = (focus_index == 1) and 2 or 1
focus_button()
end, { buffer = popup.bufnr })
vim.keymap.set("n", "<S-Tab>", function()
focus_index = (focus_index == 1) and 2 or 1
focus_button()
end, { buffer = popup.bufnr })
vim.keymap.set("n", "<CR>", function() select_button() end, { buffer = popup.bufnr })
vim.api.nvim_buf_set_keymap(popup.bufnr, "n", "<LeftMouse>", "", {
callback = function()
local pos = vim.fn.getmousepos()
local row, col = pos["winrow"], pos["wincol"]
if row == 4 then
if col >= yes_button_pos[1] and col <= yes_button_pos[2] then
focus_index = 1
render_buttons()
select_button()
elseif col >= no_button_pos[1] and col <= no_button_pos[2] then
focus_index = 2
render_buttons()
select_button()
end
end
end,
noremap = true,
silent = true,
})
vim.api.nvim_create_autocmd("CursorMoved", {
buffer = popup.bufnr,
callback = function()
local row, col = unpack(vim.api.nvim_win_get_cursor(0))
if row == 4 then
if col >= yes_button_pos[1] and col <= yes_button_pos[2] then
focus_index = 1
render_buttons()
elseif col >= no_button_pos[1] and col <= no_button_pos[2] then
focus_index = 2
render_buttons()
end
end
end,
})
popup:on(event.BufLeave, function() popup:unmount() end)
popup:mount()
render_buttons()
end
return M

View File

@@ -11,14 +11,6 @@ function M.filter_active_entries(entries)
for i = #entries, 1, -1 do
local entry = entries[i]
if entry.reset_memory then break end
if
entry.request == nil
or entry.original_response == nil
or entry.request == ""
or entry.original_response == ""
then
break
end
table.insert(entries_, 1, entry)
end
@@ -30,25 +22,62 @@ end
function M.entries_to_llm_messages(entries)
local messages = {}
for _, entry in ipairs(entries) do
local user_content = ""
if entry.selected_filepaths ~= nil then
user_content = user_content .. "SELECTED FILES:\n\n"
if entry.selected_filepaths ~= nil and #entry.selected_filepaths > 0 then
local user_content = "SELECTED FILES:\n\n"
for _, filepath in ipairs(entry.selected_filepaths) do
user_content = user_content .. filepath .. "\n"
end
table.insert(messages, { role = "user", content = user_content })
end
if entry.selected_code ~= nil then
user_content = user_content
.. "SELECTED CODE:\n\n```"
local user_content_ = "SELECTED CODE:\n\n```"
.. (entry.selected_code.file_type or "")
.. (entry.selected_code.path and ":" .. entry.selected_code.path or "")
.. "\n"
.. entry.selected_code.content
.. "\n```\n\n"
table.insert(messages, { role = "user", content = user_content_ })
end
user_content = user_content .. "USER PROMPT:\n\n" .. entry.request
table.insert(messages, { role = "user", content = user_content })
table.insert(messages, { role = "assistant", content = Utils.trim_think_content(entry.original_response) })
if entry.request ~= nil and entry.request ~= "" then
table.insert(messages, { role = "user", content = entry.request })
end
if entry.tool_histories ~= nil and #entry.tool_histories > 0 then
for _, tool_history in ipairs(entry.tool_histories) do
local assistant_content = {}
if tool_history.tool_use ~= nil then
if tool_history.tool_use.response_contents ~= nil then
for _, response_content in ipairs(tool_history.tool_use.response_contents) do
table.insert(assistant_content, { type = "text", text = response_content })
end
end
table.insert(assistant_content, {
type = "tool_use",
name = tool_history.tool_use.name,
id = tool_history.tool_use.id,
input = vim.json.decode(tool_history.tool_use.input_json),
})
end
table.insert(messages, {
role = "assistant",
content = assistant_content,
})
local user_content = {}
if tool_history.tool_result ~= nil and tool_history.tool_result.content ~= nil then
table.insert(user_content, {
type = "tool_result",
tool_use_id = tool_history.tool_result.tool_use_id,
content = tool_history.tool_result.content,
is_error = tool_history.tool_result.is_error,
})
end
table.insert(messages, {
role = "user",
content = user_content,
})
end
end
local assistant_content = Utils.trim_think_content(entry.original_response or "")
if assistant_content ~= "" then table.insert(messages, { role = "assistant", content = assistant_content }) end
end
return messages
end

View File

@@ -10,9 +10,29 @@ local cost_per_token = {
}
--- Calculate the number of tokens in a given text.
---@param text string The text to calculate the number of tokens for.
---@param content AvanteLLMMessageContent The text to calculate the number of tokens in.
---@return integer The number of tokens in the given text.
function Tokens.calculate_tokens(text)
function Tokens.calculate_tokens(content)
local text = ""
if type(content) == "string" then
text = content
elseif type(content) == "table" then
for _, item in ipairs(content) do
if type(item) == "string" then
text = text .. item
elseif type(item) == "table" and item.type == "text" then
text = text .. item.text
elseif type(item) == "table" and item.type == "image" then
text = text .. item.source.data
elseif type(item) == "table" and item.type == "tool_use" then
text = text .. item.name .. item.id
elseif type(item) == "table" and item.type == "tool_result" then
text = text .. item.tool_use_id .. item.content
end
end
end
if Tokenizer.available() then return Tokenizer.count(text) end
local tokens = 0

View File

@@ -3,7 +3,7 @@ local LlmTools = require("avante.llm_tools")
local Config = require("avante.config")
local Utils = require("avante.utils")
LlmTools.confirm = function(msg) return true end
LlmTools.confirm = function(msg, cb) return cb(true) end
describe("llm_tools", function()
local test_dir = "/tmp/test_llm_tools"
@@ -85,34 +85,37 @@ describe("llm_tools", function()
describe("create_file", function()
it("should create new file", function()
local success, err = LlmTools.create_file({ rel_path = "new_file.txt" })
assert.is_nil(err)
assert.is_true(success)
LlmTools.create_file({ rel_path = "new_file.txt" }, nil, function(success, err)
assert.is_nil(err)
assert.is_true(success)
local file_exists = io.open(test_dir .. "/new_file.txt", "r") ~= nil
assert.is_true(file_exists)
local file_exists = io.open(test_dir .. "/new_file.txt", "r") ~= nil
assert.is_true(file_exists)
end)
end)
end)
describe("create_dir", function()
it("should create new directory", function()
local success, err = LlmTools.create_dir({ rel_path = "new_dir" })
assert.is_nil(err)
assert.is_true(success)
LlmTools.create_dir({ rel_path = "new_dir" }, nil, function(success, err)
assert.is_nil(err)
assert.is_true(success)
local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil
assert.is_true(dir_exists)
local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil
assert.is_true(dir_exists)
end)
end)
end)
describe("delete_file", function()
it("should delete existing file", function()
local success, err = LlmTools.delete_file({ rel_path = "test.txt" })
assert.is_nil(err)
assert.is_true(success)
LlmTools.delete_file({ rel_path = "test.txt" }, nil, function(success, err)
assert.is_nil(err)
assert.is_true(success)
local file_exists = io.open(test_file, "r") ~= nil
assert.is_false(file_exists)
local file_exists = io.open(test_file, "r") ~= nil
assert.is_false(file_exists)
end)
end)
end)
@@ -270,68 +273,93 @@ describe("llm_tools", function()
describe("bash", function()
it("should execute command and return output", function()
local result, err = LlmTools.bash({ rel_path = ".", command = "echo 'test'" })
assert.is_nil(err)
assert.equals("test\n", result)
LlmTools.bash({ rel_path = ".", command = "echo 'test'" }, nil, function(result, err)
assert.is_nil(err)
assert.equals("test\n", result)
end)
end)
it("should return error when running outside current directory", function()
local result, err = LlmTools.bash({ rel_path = "../outside_project", command = "echo 'test'" })
assert.is_false(result)
assert.truthy(err)
assert.truthy(err:find("No permission to access path"))
LlmTools.bash({ rel_path = "../outside_project", command = "echo 'test'" }, nil, function(result, err)
assert.is_false(result)
assert.truthy(err)
assert.truthy(err:find("No permission to access path"))
end)
end)
end)
describe("python", function()
local original_system = vim.fn.system
it("should execute Python code and return output", function()
local result, err = LlmTools.python({
rel_path = ".",
code = "print('Hello from Python')",
})
assert.is_nil(err)
assert.equals("Hello from Python\n", result)
LlmTools.python(
{
rel_path = ".",
code = "print('Hello from Python')",
},
nil,
function(result, err)
assert.is_nil(err)
assert.equals("Hello from Python\n", result)
end
)
end)
it("should handle Python errors", function()
local result, err = LlmTools.python({
rel_path = ".",
code = "print(undefined_variable)",
})
assert.is_nil(result)
assert.truthy(err)
assert.truthy(err:find("Error"))
LlmTools.python(
{
rel_path = ".",
code = "print(undefined_variable)",
},
nil,
function(result, err)
assert.is_nil(result)
assert.truthy(err)
assert.truthy(err:find("Error"))
end
)
end)
it("should respect path permissions", function()
local result, err = LlmTools.python({
rel_path = "../outside_project",
code = "print('test')",
})
assert.is_nil(result)
assert.truthy(err:find("No permission to access path"))
LlmTools.python(
{
rel_path = "../outside_project",
code = "print('test')",
},
nil,
function(result, err)
assert.is_nil(result)
assert.truthy(err:find("No permission to access path"))
end
)
end)
it("should handle non-existent paths", function()
local result, err = LlmTools.python({
rel_path = "non_existent_dir",
code = "print('test')",
})
assert.is_nil(result)
assert.truthy(err:find("Path not found"))
LlmTools.python(
{
rel_path = "non_existent_dir",
code = "print('test')",
},
nil,
function(result, err)
assert.is_nil(result)
assert.truthy(err:find("Path not found"))
end
)
end)
it("should support custom container image", function()
os.execute("docker image rm python:3.12-slim")
local result, err = LlmTools.python({
rel_path = ".",
code = "print('Hello from custom container')",
container_image = "python:3.12-slim",
})
assert.is_nil(err)
assert.equals("Hello from custom container\n", result)
LlmTools.python(
{
rel_path = ".",
code = "print('Hello from custom container')",
container_image = "python:3.12-slim",
},
nil,
function(result, err)
assert.is_nil(err)
assert.equals("Hello from custom container\n", result)
end
)
end)
end)