fix: better sidebar (#1603)

* fix: better sidebar

* feat: better msg history

* fix: tests
This commit is contained in:
yetone
2025-03-17 01:40:05 +08:00
committed by GitHub
parent f60f150a21
commit 6e77da83c1
17 changed files with 870 additions and 319 deletions

View File

@@ -99,6 +99,8 @@ end
---@field win? table<string, any> windows options similar to |nvim_open_win()| ---@field win? table<string, any> windows options similar to |nvim_open_win()|
---@field ask? boolean ---@field ask? boolean
---@field floating? boolean whether to open a floating input to enter the question ---@field floating? boolean whether to open a floating input to enter the question
---@field new_chat? boolean whether to open a new chat
---@field without_selection? boolean whether to open a new chat without selection
---@param opts? AskOptions ---@param opts? AskOptions
function M.ask(opts) function M.ask(opts)
@@ -117,6 +119,7 @@ function M.ask(opts)
opts = vim.tbl_extend("force", { selection = Utils.get_visual_selection_and_range() }, opts) opts = vim.tbl_extend("force", { selection = Utils.get_visual_selection_and_range() }, opts)
---@param input string | nil
local function ask(input) local function ask(input)
if input == nil or input == "" then input = opts.question end if input == nil or input == "" then input = opts.question end
local sidebar = require("avante").get() local sidebar = require("avante").get()
@@ -124,6 +127,12 @@ function M.ask(opts)
sidebar:close({ goto_code_win = false }) sidebar:close({ goto_code_win = false })
end end
require("avante").open_sidebar(opts) require("avante").open_sidebar(opts)
if opts.new_chat then sidebar:new_chat() end
if opts.without_selection then
sidebar.code.selection = nil
sidebar.file_selector:reset()
if sidebar.selected_files_container then sidebar.selected_files_container:unmount() end
end
if input == nil or input == "" then return true end if input == nil or input == "" then return true end
vim.api.nvim_exec_autocmds("User", { pattern = "AvanteInputSubmitted", data = { request = input } }) vim.api.nvim_exec_autocmds("User", { pattern = "AvanteInputSubmitted", data = { request = input } })
return true return true

View File

@@ -195,7 +195,7 @@ M._defaults = {
model = "gpt-4o", model = "gpt-4o",
timeout = 30000, -- Timeout in milliseconds timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 8192,
}, },
---@type AvanteSupportedProvider ---@type AvanteSupportedProvider
copilot = { copilot = {
@@ -205,7 +205,7 @@ M._defaults = {
allow_insecure = false, -- Allow insecure server connections allow_insecure = false, -- Allow insecure server connections
timeout = 30000, -- Timeout in milliseconds timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 8192,
}, },
---@type AvanteAzureProvider ---@type AvanteAzureProvider
azure = { azure = {
@@ -214,7 +214,7 @@ M._defaults = {
api_version = "2024-06-01", api_version = "2024-06-01",
timeout = 30000, -- Timeout in milliseconds timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 8192,
}, },
---@type AvanteSupportedProvider ---@type AvanteSupportedProvider
claude = { claude = {
@@ -222,14 +222,14 @@ M._defaults = {
model = "claude-3-7-sonnet-20250219", model = "claude-3-7-sonnet-20250219",
timeout = 30000, -- Timeout in milliseconds timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 8000, max_tokens = 8192,
}, },
---@type AvanteSupportedProvider ---@type AvanteSupportedProvider
bedrock = { bedrock = {
model = "anthropic.claude-3-5-sonnet-20241022-v2:0", model = "anthropic.claude-3-5-sonnet-20241022-v2:0",
timeout = 30000, -- Timeout in milliseconds timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 8000, max_tokens = 8192,
}, },
---@type AvanteSupportedProvider ---@type AvanteSupportedProvider
gemini = { gemini = {
@@ -237,7 +237,7 @@ M._defaults = {
model = "gemini-1.5-flash-latest", model = "gemini-1.5-flash-latest",
timeout = 30000, -- Timeout in milliseconds timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 8192,
}, },
---@type AvanteSupportedProvider ---@type AvanteSupportedProvider
vertex = { vertex = {
@@ -245,7 +245,7 @@ M._defaults = {
model = "gemini-1.5-flash-002", model = "gemini-1.5-flash-002",
timeout = 30000, -- Timeout in milliseconds timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 8192,
}, },
---@type AvanteSupportedProvider ---@type AvanteSupportedProvider
cohere = { cohere = {
@@ -253,7 +253,7 @@ M._defaults = {
model = "command-r-plus-08-2024", model = "command-r-plus-08-2024",
timeout = 30000, -- Timeout in milliseconds timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 8192,
}, },
---@type AvanteSupportedProvider ---@type AvanteSupportedProvider
ollama = { ollama = {
@@ -261,7 +261,7 @@ M._defaults = {
timeout = 30000, -- Timeout in milliseconds timeout = 30000, -- Timeout in milliseconds
options = { options = {
temperature = 0, temperature = 0,
num_ctx = 4096, num_ctx = 8192,
}, },
}, },
---@type AvanteSupportedProvider ---@type AvanteSupportedProvider
@@ -270,7 +270,7 @@ M._defaults = {
model = "claude-3-5-sonnet-v2@20241022", model = "claude-3-5-sonnet-v2@20241022",
timeout = 30000, -- Timeout in milliseconds timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 8192,
}, },
---To add support for custom provider, follow the format below ---To add support for custom provider, follow the format below
---See https://github.com/yetone/avante.nvim/wiki#custom-providers for more details ---See https://github.com/yetone/avante.nvim/wiki#custom-providers for more details
@@ -282,7 +282,7 @@ M._defaults = {
model = "claude-3-5-haiku-20241022", model = "claude-3-5-haiku-20241022",
timeout = 30000, -- Timeout in milliseconds timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 8000, max_tokens = 8192,
}, },
---@type AvanteSupportedProvider ---@type AvanteSupportedProvider
["claude-opus"] = { ["claude-opus"] = {
@@ -290,7 +290,7 @@ M._defaults = {
model = "claude-3-opus-20240229", model = "claude-3-opus-20240229",
timeout = 30000, -- Timeout in milliseconds timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 8000, max_tokens = 8192,
}, },
["openai-gpt-4o-mini"] = { ["openai-gpt-4o-mini"] = {
__inherited_from = "openai", __inherited_from = "openai",

View File

@@ -85,6 +85,7 @@ end
function FileSelector:reset() function FileSelector:reset()
self.selected_filepaths = {} self.selected_filepaths = {}
self.event_handlers = {} self.event_handlers = {}
self:emit("update")
end end
function FileSelector:add_selected_file(filepath) function FileSelector:add_selected_file(filepath)

View File

@@ -18,6 +18,13 @@ local Highlights = {
INLINE_HINT = { name = "AvanteInlineHint", link = "Keyword" }, INLINE_HINT = { name = "AvanteInlineHint", link = "Keyword" },
TO_BE_DELETED = { name = "AvanteToBeDeleted", bg = "#ffcccc", strikethrough = true }, TO_BE_DELETED = { name = "AvanteToBeDeleted", bg = "#ffcccc", strikethrough = true },
TO_BE_DELETED_WITHOUT_STRIKETHROUGH = { name = "AvanteToBeDeletedWOStrikethrough", bg = "#562C30" }, TO_BE_DELETED_WITHOUT_STRIKETHROUGH = { name = "AvanteToBeDeletedWOStrikethrough", bg = "#562C30" },
CONFIRM_TITLE = { name = "AvanteConfirmTitle", fg = "#1e222a", bg = "#e06c75" },
BUTTON_DEFAULT = { name = "AvanteButtonDefault", fg = "#1e222a", bg = "#ABB2BF" },
BUTTON_DEFAULT_HOVER = { name = "AvanteButtonDefaultHover", fg = "#1e222a", bg = "#a9cf8a" },
BUTTON_PRIMARY = { name = "AvanteButtonPrimary", fg = "#1e222a", bg = "#ABB2BF" },
BUTTON_PRIMARY_HOVER = { name = "AvanteButtonPrimaryHover", fg = "#1e222a", bg = "#56b6c2" },
BUTTON_DANGER = { name = "AvanteButtonDanger", fg = "#1e222a", bg = "#ABB2BF" },
BUTTON_DANGER_HOVER = { name = "AvanteButtonDangerHover", fg = "#1e222a", bg = "#e06c75" },
} }
Highlights.conflict = { Highlights.conflict = {

View File

@@ -509,6 +509,7 @@ function M._stream(opts)
end, stop_opts.retry_after * 1000) end, stop_opts.retry_after * 1000)
return return
end end
stop_opts.tool_histories = opts.tool_histories
return opts.on_stop(stop_opts) return opts.on_stop(stop_opts)
end, end,
} }

View File

@@ -17,9 +17,23 @@ local function get_abs_path(rel_path)
return p return p
end end
function M.confirm(msg) function M.confirm(message, callback)
local ok = vim.fn.confirm(msg, "&Yes\n&No", 2) local UI = require("avante.ui")
return ok == 1 UI.confirm(message, callback)
end
---@param abs_path string
---@return boolean
local function is_ignored(abs_path)
local project_root = Utils.get_project_root()
local gitignore_path = project_root .. "/.gitignore"
local gitignore_patterns, gitignore_negate_patterns = Utils.parse_gitignore(gitignore_path)
-- The checker should only take care of the path inside the project root
-- Specifically, it should not check the project root itself
-- Otherwise if the binary is named the same as the project root (such as Go binary), any paths
-- insde the project root will be ignored
local rel_path = Utils.make_relative_path(abs_path, project_root)
return Utils.is_ignored(rel_path, gitignore_patterns, gitignore_negate_patterns)
end end
---@param abs_path string ---@param abs_path string
@@ -28,14 +42,7 @@ local function has_permission_to_access(abs_path)
if not Path:new(abs_path):is_absolute() then return false end if not Path:new(abs_path):is_absolute() then return false end
local project_root = Utils.get_project_root() local project_root = Utils.get_project_root()
if abs_path:sub(1, #project_root) ~= project_root then return false end if abs_path:sub(1, #project_root) ~= project_root then return false end
local gitignore_path = project_root .. "/.gitignore" return not is_ignored(abs_path)
local gitignore_patterns, gitignore_negate_patterns = Utils.parse_gitignore(gitignore_path)
-- The checker should only take care of the path inside the project root
-- Specifically, it should not check the project root itself
-- Otherwise if the binary is named the same as the project root (such as Go binary), any paths
-- insde the project root will be ignored
local rel_path = Utils.make_relative_path(abs_path, project_root)
return not Utils.is_ignored(rel_path, gitignore_patterns, gitignore_negate_patterns)
end end
---@type AvanteLLMToolFunc<{ rel_path: string, pattern: string }> ---@type AvanteLLMToolFunc<{ rel_path: string, pattern: string }>
@@ -164,6 +171,41 @@ function M.read_file(opts, on_log)
return content, nil return content, nil
end end
---@type AvanteLLMToolFunc<{ abs_path: string }>
function M.read_global_file(opts, on_log)
local abs_path = get_abs_path(opts.abs_path)
if is_ignored(abs_path) then return "", "This file is ignored: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
local file = io.open(abs_path, "r")
if not file then return "", "file not found: " .. abs_path end
local content = file:read("*a")
file:close()
return content, nil
end
---@type AvanteLLMToolFunc<{ abs_path: string, content: string }>
function M.write_global_file(opts, on_log, on_complete)
local abs_path = get_abs_path(opts.abs_path)
if is_ignored(abs_path) then return false, "This file is ignored: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
if on_log then on_log("content: " .. opts.content) end
if not on_complete then return false, "on_complete not provided" end
M.confirm("Are you sure you want to write to the file: " .. abs_path, function(ok)
if not ok then
on_complete(false, "User canceled")
return
end
local file = io.open(abs_path, "w")
if not file then
on_complete(false, "file not found: " .. abs_path)
return
end
file:write(opts.content)
file:close()
on_complete(true, nil)
end)
end
---@type AvanteLLMToolFunc<{ rel_path: string }> ---@type AvanteLLMToolFunc<{ rel_path: string }>
function M.create_file(opts, on_log) function M.create_file(opts, on_log)
local abs_path = get_abs_path(opts.rel_path) local abs_path = get_abs_path(opts.rel_path)
@@ -183,7 +225,7 @@ function M.create_file(opts, on_log)
end end
---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }> ---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }>
function M.rename_file(opts, on_log) function M.rename_file(opts, on_log, on_complete)
local abs_path = get_abs_path(opts.rel_path) local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
@@ -192,11 +234,15 @@ function M.rename_file(opts, on_log)
if on_log then on_log(abs_path .. " -> " .. new_abs_path) end if on_log then on_log(abs_path .. " -> " .. new_abs_path) end
if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end
if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end
if not M.confirm("Are you sure you want to rename the file: " .. abs_path .. " to: " .. new_abs_path) then if not on_complete then return false, "on_complete not provided" end
return false, "User canceled" M.confirm("Are you sure you want to rename the file: " .. abs_path .. " to: " .. new_abs_path, function(ok)
end if not ok then
os.rename(abs_path, new_abs_path) on_complete(false, "User canceled")
return true, nil return
end
os.rename(abs_path, new_abs_path)
on_complete(true, nil)
end)
end end
---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }> ---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }>
@@ -214,32 +260,42 @@ function M.copy_file(opts, on_log)
end end
---@type AvanteLLMToolFunc<{ rel_path: string }> ---@type AvanteLLMToolFunc<{ rel_path: string }>
function M.delete_file(opts, on_log) function M.delete_file(opts, on_log, on_complete)
local abs_path = get_abs_path(opts.rel_path) local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
if not M.confirm("Are you sure you want to delete the file: " .. abs_path) then return false, "User canceled" end if not on_complete then return false, "on_complete not provided" end
if on_log then on_log("Deleting file: " .. abs_path) end M.confirm("Are you sure you want to delete the file: " .. abs_path, function(ok)
os.remove(abs_path) if not ok then
return true, nil on_complete(false, "User canceled")
return
end
if on_log then on_log("Deleting file: " .. abs_path) end
os.remove(abs_path)
on_complete(true, nil)
end)
end end
---@type AvanteLLMToolFunc<{ rel_path: string }> ---@type AvanteLLMToolFunc<{ rel_path: string }>
function M.create_dir(opts, on_log) function M.create_dir(opts, on_log, on_complete)
local abs_path = get_abs_path(opts.rel_path) local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end
if not M.confirm("Are you sure you want to create the directory: " .. abs_path) then if not on_complete then return false, "on_complete not provided" end
return false, "User canceled" M.confirm("Are you sure you want to create the directory: " .. abs_path, function(ok)
end if not ok then
if on_log then on_log("Creating directory: " .. abs_path) end on_complete(false, "User canceled")
Path:new(abs_path):mkdir({ parents = true }) return
return true, nil end
if on_log then on_log("Creating directory: " .. abs_path) end
Path:new(abs_path):mkdir({ parents = true })
on_complete(true, nil)
end)
end end
---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }> ---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }>
function M.rename_dir(opts, on_log) function M.rename_dir(opts, on_log, on_complete)
local abs_path = get_abs_path(opts.rel_path) local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
@@ -247,26 +303,34 @@ function M.rename_dir(opts, on_log)
local new_abs_path = get_abs_path(opts.new_rel_path) local new_abs_path = get_abs_path(opts.new_rel_path)
if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end
if Path:new(new_abs_path):exists() then return false, "Directory already exists: " .. new_abs_path end if Path:new(new_abs_path):exists() then return false, "Directory already exists: " .. new_abs_path end
if not M.confirm("Are you sure you want to rename directory " .. abs_path .. " to " .. new_abs_path .. "?") then if not on_complete then return false, "on_complete not provided" end
return false, "User canceled" M.confirm("Are you sure you want to rename directory " .. abs_path .. " to " .. new_abs_path .. "?", function(ok)
end if not ok then
if on_log then on_log("Renaming directory: " .. abs_path .. " to " .. new_abs_path) end on_complete(false, "User canceled")
os.rename(abs_path, new_abs_path) return
return true, nil end
if on_log then on_log("Renaming directory: " .. abs_path .. " to " .. new_abs_path) end
os.rename(abs_path, new_abs_path)
on_complete(true, nil)
end)
end end
---@type AvanteLLMToolFunc<{ rel_path: string }> ---@type AvanteLLMToolFunc<{ rel_path: string }>
function M.delete_dir(opts, on_log) function M.delete_dir(opts, on_log, on_complete)
local abs_path = get_abs_path(opts.rel_path) local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end
if not M.confirm("Are you sure you want to delete the directory: " .. abs_path) then if not on_complete then return false, "on_complete not provided" end
return false, "User canceled" M.confirm("Are you sure you want to delete the directory: " .. abs_path, function(ok)
end if not ok then
if on_log then on_log("Deleting directory: " .. abs_path) end on_complete(false, "User canceled")
os.remove(abs_path) return
return true, nil end
if on_log then on_log("Deleting directory: " .. abs_path) end
os.remove(abs_path)
on_complete(true, nil)
end)
end end
---@type AvanteLLMToolFunc<{ rel_path: string, command: string }> ---@type AvanteLLMToolFunc<{ rel_path: string, command: string }>
@@ -275,11 +339,6 @@ function M.bash(opts, on_log, on_complete)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
if on_log then on_log("command: " .. opts.command) end if on_log then on_log("command: " .. opts.command) end
if
not M.confirm("Are you sure you want to run the command: `" .. opts.command .. "` in the directory: " .. abs_path)
then
return false, "User canceled"
end
---change cwd to abs_path ---change cwd to abs_path
---@param output string ---@param output string
---@param exit_code integer ---@param exit_code integer
@@ -292,18 +351,20 @@ function M.bash(opts, on_log, on_complete)
end end
return output, nil return output, nil
end end
if on_complete then if not on_complete then return false, "on_complete not provided" end
Utils.shell_run_async(opts.command, "bash -c", function(output, exit_code) M.confirm(
local result, err = handle_result(output, exit_code) "Are you sure you want to run the command: `" .. opts.command .. "` in the directory: " .. abs_path,
on_complete(result, err) function(ok)
end, abs_path) if not ok then
return nil, nil on_complete(false, "User canceled")
end return
local old_cwd = vim.fn.getcwd() end
vim.fn.chdir(abs_path) Utils.shell_run_async(opts.command, "bash -c", function(output, exit_code)
local res = Utils.shell_run(opts.command, "bash -c") local result, err = handle_result(output, exit_code)
vim.fn.chdir(old_cwd) on_complete(result, err)
return handle_result(res.stdout, res.code) end, abs_path)
end
)
end end
---@type AvanteLLMToolFunc<{ query: string }> ---@type AvanteLLMToolFunc<{ query: string }>
@@ -464,7 +525,7 @@ function M.git_diff(opts, on_log)
end end
---@type AvanteLLMToolFunc<{ message: string, scope?: string }> ---@type AvanteLLMToolFunc<{ message: string, scope?: string }>
function M.git_commit(opts, on_log) function M.git_commit(opts, on_log, on_complete)
local git_cmd = vim.fn.exepath("git") local git_cmd = vim.fn.exepath("git")
if git_cmd == "" then return false, "Git command not found" end if git_cmd == "" then return false, "Git command not found" end
local project_root = Utils.get_project_root() local project_root = Utils.get_project_root()
@@ -518,36 +579,46 @@ function M.git_commit(opts, on_log)
-- Construct full commit message for confirmation -- Construct full commit message for confirmation
local full_commit_msg = table.concat(commit_msg_lines, "\n") local full_commit_msg = table.concat(commit_msg_lines, "\n")
if not on_complete then return false, "on_complete not provided" end
-- Confirm with user -- Confirm with user
if not M.confirm("Are you sure you want to commit with message:\n" .. full_commit_msg) then M.confirm("Are you sure you want to commit with message:\n" .. full_commit_msg, function(ok)
return false, "User canceled" if not ok then
end on_complete(false, "User canceled")
return
end
-- Stage changes if scope is provided
if opts.scope then
local stage_cmd = string.format("git add %s", opts.scope)
if on_log then on_log("Staging files: " .. stage_cmd) end
local stage_result = vim.fn.system(stage_cmd)
if vim.v.shell_error ~= 0 then
on_complete(false, "Failed to stage files: " .. stage_result)
return
end
end
-- Stage changes if scope is provided -- Construct git commit command
if opts.scope then local cmd_parts = { "git", "commit" }
local stage_cmd = string.format("git add %s", opts.scope) -- Only add -S flag if GPG is available
if on_log then on_log("Staging files: " .. stage_cmd) end if has_gpg then table.insert(cmd_parts, "-S") end
local stage_result = vim.fn.system(stage_cmd) for _, line in ipairs(commit_msg_lines) do
if vim.v.shell_error ~= 0 then return false, "Failed to stage files: " .. stage_result end table.insert(cmd_parts, "-m")
end table.insert(cmd_parts, '"' .. line .. '"')
end
local cmd = table.concat(cmd_parts, " ")
-- Construct git commit command -- Execute git commit
local cmd_parts = { "git", "commit" } if on_log then on_log("Running command: " .. cmd) end
-- Only add -S flag if GPG is available local result = vim.fn.system(cmd)
if has_gpg then table.insert(cmd_parts, "-S") end
for _, line in ipairs(commit_msg_lines) do
table.insert(cmd_parts, "-m")
table.insert(cmd_parts, '"' .. line .. '"')
end
local cmd = table.concat(cmd_parts, " ")
-- Execute git commit if vim.v.shell_error ~= 0 then
if on_log then on_log("Running command: " .. cmd) end on_complete(false, "Failed to commit: " .. result)
local result = vim.fn.system(cmd) return
end
if vim.v.shell_error ~= 0 then return false, "Failed to commit: " .. result end on_complete(true, nil)
end)
return true, nil
end end
---@type AvanteLLMToolFunc<{ query: string }> ---@type AvanteLLMToolFunc<{ query: string }>
@@ -571,57 +642,62 @@ function M.python(opts, on_log, on_complete)
if on_log then on_log("cwd: " .. abs_path) end if on_log then on_log("cwd: " .. abs_path) end
if on_log then on_log("code:\n" .. opts.code) end if on_log then on_log("code:\n" .. opts.code) end
local container_image = opts.container_image or "python:3.11-slim-bookworm" local container_image = opts.container_image or "python:3.11-slim-bookworm"
if if not on_complete then return nil, "on_complete not provided" end
not M.confirm( M.confirm(
"Are you sure you want to run the following python code in the `" "Are you sure you want to run the following python code in the `"
.. container_image .. container_image
.. "` container, in the directory: `" .. "` container, in the directory: `"
.. abs_path .. abs_path
.. "`?\n" .. "`?\n"
.. opts.code .. opts.code,
) function(ok)
then if not ok then
return nil, "User canceled" on_complete(nil, "User canceled")
end return
if vim.fn.executable("docker") == 0 then return nil, "Python tool is not available to execute any code" end end
if vim.fn.executable("docker") == 0 then
on_complete(nil, "Python tool is not available to execute any code")
return
end
local function handle_result(result) ---@param result vim.SystemCompleted local function handle_result(result) ---@param result vim.SystemCompleted
if result.code ~= 0 then return nil, "Error: " .. (result.stderr or "Unknown error") end if result.code ~= 0 then return nil, "Error: " .. (result.stderr or "Unknown error") end
Utils.debug("output", result.stdout) Utils.debug("output", result.stdout)
return result.stdout, nil return result.stdout, nil
end end
local job = vim.system( vim.system(
{ {
"docker", "docker",
"run", "run",
"--rm", "--rm",
"-v", "-v",
abs_path .. ":" .. abs_path, abs_path .. ":" .. abs_path,
"-w", "-w",
abs_path, abs_path,
container_image, container_image,
"python", "python",
"-c", "-c",
opts.code, opts.code,
}, },
{ {
text = true, text = true,
cwd = abs_path, cwd = abs_path,
}, },
vim.schedule_wrap(function(result) vim.schedule_wrap(function(result)
if not on_complete then return end if not on_complete then return end
local output, err = handle_result(result) local output, err = handle_result(result)
on_complete(output, err) on_complete(output, err)
end) end)
)
end
) )
if on_complete then return end
local result = job:wait()
return handle_result(result)
end end
---@param user_input string
---@param history_messages AvanteLLMMessage[]
---@return AvanteLLMTool[] ---@return AvanteLLMTool[]
function M.get_tools() function M.get_tools(user_input, history_messages)
local custom_tools = Config.custom_tools local custom_tools = Config.custom_tools
if type(custom_tools) == "function" then custom_tools = custom_tools() end if type(custom_tools) == "function" then custom_tools = custom_tools() end
---@type AvanteLLMTool[] ---@type AvanteLLMTool[]
@@ -634,7 +710,7 @@ function M.get_tools()
if tool.enabled == nil then if tool.enabled == nil then
return true return true
else else
return tool.enabled() return tool.enabled({ user_input = user_input, history_messages = history_messages })
end end
end) end)
:totable() :totable()
@@ -644,7 +720,7 @@ end
M._tools = { M._tools = {
{ {
name = "glob", name = "glob",
description = 'Fast file pattern matching using glob patterns like "**/*.js"', description = 'Fast file pattern matching using glob patterns like "**/*.js", in current project scope',
param = { param = {
type = "table", type = "table",
fields = { fields = {
@@ -655,7 +731,7 @@ M._tools = {
}, },
{ {
name = "rel_path", name = "rel_path",
description = "Relative path to the directory, as cwd", description = "Relative path to the project directory, as cwd",
type = "string", type = "string",
}, },
}, },
@@ -704,7 +780,7 @@ M._tools = {
}, },
{ {
name = "python", name = "python",
description = "Run python code. Can't use it to read files or modify files.", description = "Run python code in current project scope. Can't use it to read files or modify files.",
param = { param = {
type = "table", type = "table",
fields = { fields = {
@@ -715,7 +791,7 @@ M._tools = {
}, },
{ {
name = "rel_path", name = "rel_path",
description = "Relative path to the directory, as cwd", description = "Relative path to the project directory, as cwd",
type = "string", type = "string",
}, },
}, },
@@ -736,7 +812,7 @@ M._tools = {
}, },
{ {
name = "git_diff", name = "git_diff",
description = "Get git diff for generating commit message", description = "Get git diff for generating commit message in current project scope",
param = { param = {
type = "table", type = "table",
fields = { fields = {
@@ -763,7 +839,7 @@ M._tools = {
}, },
{ {
name = "git_commit", name = "git_commit",
description = "Commit changes with the given commit message", description = "Commit changes with the given commit message in current project scope",
param = { param = {
type = "table", type = "table",
fields = { fields = {
@@ -796,13 +872,13 @@ M._tools = {
}, },
{ {
name = "list_files", name = "list_files",
description = "List files in a directory", description = "List files in current project scope",
param = { param = {
type = "table", type = "table",
fields = { fields = {
{ {
name = "rel_path", name = "rel_path",
description = "Relative path to the directory", description = "Relative path to the project directory",
type = "string", type = "string",
}, },
{ {
@@ -815,7 +891,7 @@ M._tools = {
returns = { returns = {
{ {
name = "files", name = "files",
description = "List of files in the directory", description = "List of filepaths in the directory",
type = "string[]", type = "string[]",
}, },
{ {
@@ -828,13 +904,13 @@ M._tools = {
}, },
{ {
name = "search_files", name = "search_files",
description = "Search for files in a directory", description = "Search for files in current project scope",
param = { param = {
type = "table", type = "table",
fields = { fields = {
{ {
name = "rel_path", name = "rel_path",
description = "Relative path to the directory", description = "Relative path to the project directory",
type = "string", type = "string",
}, },
{ {
@@ -847,7 +923,7 @@ M._tools = {
returns = { returns = {
{ {
name = "files", name = "files",
description = "List of files that match the keyword", description = "List of filepaths that match the keyword",
type = "string", type = "string",
}, },
{ {
@@ -860,13 +936,13 @@ M._tools = {
}, },
{ {
name = "grep_search", name = "grep_search",
description = "Search for a keyword in a directory using grep", description = "Search for a keyword in a directory using grep in current project scope",
param = { param = {
type = "table", type = "table",
fields = { fields = {
{ {
name = "rel_path", name = "rel_path",
description = "Relative path to the directory", description = "Relative path to the project directory",
type = "string", type = "string",
}, },
{ {
@@ -911,13 +987,13 @@ M._tools = {
}, },
{ {
name = "read_file_toplevel_symbols", name = "read_file_toplevel_symbols",
description = "Read the top-level symbols of a file", description = "Read the top-level symbols of a file in current project scope",
param = { param = {
type = "table", type = "table",
fields = { fields = {
{ {
name = "rel_path", name = "rel_path",
description = "Relative path to the file", description = "Relative path to the file in current project scope",
type = "string", type = "string",
}, },
}, },
@@ -938,13 +1014,28 @@ M._tools = {
}, },
{ {
name = "read_file", name = "read_file",
description = "Read the contents of a file. If the file content is already in the context, do not use this tool.", description = "Read the contents of a file in current project scope. If the file content is already in the context, do not use this tool.",
enabled = function(opts)
if opts.user_input:match("@read_global_file") then return false end
for _, message in ipairs(opts.history_messages) do
if message.role == "user" then
local content = message.content
if type(content) == "string" and content:match("@read_global_file") then return false end
if type(content) == "table" then
for _, item in ipairs(content) do
if type(item) == "string" and item:match("@read_global_file") then return false end
end
end
end
end
return true
end,
param = { param = {
type = "table", type = "table",
fields = { fields = {
{ {
name = "rel_path", name = "rel_path",
description = "Relative path to the file", description = "Relative path to the file in current project scope",
type = "string", type = "string",
}, },
}, },
@@ -963,15 +1054,104 @@ M._tools = {
}, },
}, },
}, },
{
name = "read_global_file",
description = "Read the contents of a file in the global scope. If the file content is already in the context, do not use this tool.",
enabled = function(opts)
if opts.user_input:match("@read_global_file") then return true end
for _, message in ipairs(opts.history_messages) do
if message.role == "user" then
local content = message.content
if type(content) == "string" and content:match("@read_global_file") then return true end
if type(content) == "table" then
for _, item in ipairs(content) do
if type(item) == "string" and item:match("@read_global_file") then return true end
end
end
end
end
return false
end,
param = {
type = "table",
fields = {
{
name = "abs_path",
description = "Absolute path to the file in global scope",
type = "string",
},
},
},
returns = {
{
name = "content",
description = "Contents of the file",
type = "string",
},
{
name = "error",
description = "Error message if the file was not read successfully",
type = "string",
optional = true,
},
},
},
{
name = "write_global_file",
description = "Write to a file in the global scope",
enabled = function(opts)
if opts.user_input:match("@write_global_file") then return true end
for _, message in ipairs(opts.history_messages) do
if message.role == "user" then
local content = message.content
if type(content) == "string" and content:match("@write_global_file") then return true end
if type(content) == "table" then
for _, item in ipairs(content) do
if type(item) == "string" and item:match("@write_global_file") then return true end
end
end
end
end
return false
end,
param = {
type = "table",
fields = {
{
name = "abs_path",
description = "Absolute path to the file in global scope",
type = "string",
},
{
name = "content",
description = "Content to write to the file",
type = "string",
},
},
},
returns = {
{
name = "success",
description = "True if the file was written successfully, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the file was not written successfully",
type = "string",
optional = true,
},
},
},
{ {
name = "create_file", name = "create_file",
description = "Create a new file", description = "Create a new file in current project scope",
param = { param = {
type = "table", type = "table",
fields = { fields = {
{ {
name = "rel_path", name = "rel_path",
description = "Relative path to the file", description = "Relative path to the file in current project scope",
type = "string", type = "string",
}, },
}, },
@@ -992,13 +1172,13 @@ M._tools = {
}, },
{ {
name = "rename_file", name = "rename_file",
description = "Rename a file", description = "Rename a file in current project scope",
param = { param = {
type = "table", type = "table",
fields = { fields = {
{ {
name = "rel_path", name = "rel_path",
description = "Relative path to the file", description = "Relative path to the file in current project scope",
type = "string", type = "string",
}, },
{ {
@@ -1024,13 +1204,13 @@ M._tools = {
}, },
{ {
name = "delete_file", name = "delete_file",
description = "Delete a file", description = "Delete a file in current project scope",
param = { param = {
type = "table", type = "table",
fields = { fields = {
{ {
name = "rel_path", name = "rel_path",
description = "Relative path to the file", description = "Relative path to the file in current project scope",
type = "string", type = "string",
}, },
}, },
@@ -1051,13 +1231,13 @@ M._tools = {
}, },
{ {
name = "create_dir", name = "create_dir",
description = "Create a new directory", description = "Create a new directory in current project scope",
param = { param = {
type = "table", type = "table",
fields = { fields = {
{ {
name = "rel_path", name = "rel_path",
description = "Relative path to the directory", description = "Relative path to the project directory",
type = "string", type = "string",
}, },
}, },
@@ -1078,13 +1258,13 @@ M._tools = {
}, },
{ {
name = "rename_dir", name = "rename_dir",
description = "Rename a directory", description = "Rename a directory in current project scope",
param = { param = {
type = "table", type = "table",
fields = { fields = {
{ {
name = "rel_path", name = "rel_path",
description = "Relative path to the directory", description = "Relative path to the project directory",
type = "string", type = "string",
}, },
{ {
@@ -1110,13 +1290,13 @@ M._tools = {
}, },
{ {
name = "delete_dir", name = "delete_dir",
description = "Delete a directory", description = "Delete a directory in current project scope",
param = { param = {
type = "table", type = "table",
fields = { fields = {
{ {
name = "rel_path", name = "rel_path",
description = "Relative path to the directory", description = "Relative path to the project directory",
type = "string", type = "string",
}, },
}, },
@@ -1137,13 +1317,13 @@ M._tools = {
}, },
{ {
name = "bash", name = "bash",
description = "Run a bash command in a directory. Can't use search commands like find/grep or read tools like cat/ls. Can't use it to read files or modify files.", description = "Run a bash command in current project scope. Can't use search commands like find/grep or read tools like cat/ls. Can't use it to read files or modify files.",
param = { param = {
type = "table", type = "table",
fields = { fields = {
{ {
name = "rel_path", name = "rel_path",
description = "Relative path to the directory", description = "Relative path to the project directory, as cwd",
type = "string", type = "string",
}, },
{ {

View File

@@ -59,6 +59,8 @@ function M:parse_messages(opts)
---@type AvanteClaudeMessage[] ---@type AvanteClaudeMessage[]
local messages = {} local messages = {}
local provider_conf, _ = P.parse_config(self)
---@type {idx: integer, length: integer}[] ---@type {idx: integer, length: integer}[]
local messages_with_length = {} local messages_with_length = {}
for idx, message in ipairs(opts.messages) do for idx, message in ipairs(opts.messages) do
@@ -76,15 +78,46 @@ function M:parse_messages(opts)
end end
for idx, message in ipairs(opts.messages) do for idx, message in ipairs(opts.messages) do
local content_items = message.content
local message_content = {}
if type(content_items) == "string" then
table.insert(message_content, {
type = "text",
text = message.content,
cache_control = top_two[idx] and { type = "ephemeral" } or nil,
})
elseif type(content_items) == "table" then
---@cast content_items AvanteLLMMessageContentItem[]
for _, item in ipairs(content_items) do
if type(item) == "string" then
table.insert(
message_content,
{ type = "text", text = item, cache_control = top_two[idx] and { type = "ephemeral" } or nil }
)
elseif type(item) == "table" and item.type == "text" then
table.insert(
message_content,
{ type = "text", text = item.text, cache_control = top_two[idx] and { type = "ephemeral" } or nil }
)
elseif type(item) == "table" and item.type == "image" then
table.insert(message_content, { type = "image", source = item.source })
elseif not provider_conf.disable_tools and type(item) == "table" and item.type == "tool_use" then
table.insert(message_content, { type = "tool_use", name = item.name, id = item.id, input = item.input })
elseif not provider_conf.disable_tools and type(item) == "table" and item.type == "tool_result" then
table.insert(
message_content,
{ type = "tool_result", tool_use_id = item.tool_use_id, content = item.content, is_error = item.is_error }
)
elseif type(item) == "table" and item.type == "thinking" then
table.insert(message_content, { type = "thinking", thinking = item.thinking, signature = item.signature })
elseif type(item) == "table" and item.type == "redacted_thinking" then
table.insert(message_content, { type = "redacted_thinking", data = item.data })
end
end
end
table.insert(messages, { table.insert(messages, {
role = self.role_map[message.role], role = self.role_map[message.role],
content = { content = message_content,
{
type = "text",
text = message.content,
cache_control = top_two[idx] and { type = "ephemeral" } or nil,
},
},
}) })
end end

View File

@@ -35,9 +35,36 @@ function M:parse_messages(opts)
end end
end end
prev_role = role prev_role = role
table.insert(contents, { role = M.role_map[role] or role, parts = { local parts = {}
{ text = message.content }, local content_items = message.content
} }) if type(content_items) == "string" then
table.insert(parts, { text = content_items })
elseif type(content_items) == "table" then
---@cast content_items AvanteLLMMessageContentItem[]
for _, item in ipairs(content_items) do
if type(item) == "string" then
table.insert(parts, { text = item })
elseif type(item) == "table" and item.type == "text" then
table.insert(parts, { text = item.text })
elseif type(item) == "table" and item.type == "image" then
table.insert(parts, {
inline_data = {
mime_type = "image/png",
data = item.source.data,
},
})
elseif type(item) == "table" and item.type == "tool_use" then
table.insert(parts, { text = item.name })
elseif type(item) == "table" and item.type == "tool_result" then
table.insert(parts, { text = item.content })
elseif type(item) == "table" and item.type == "thinking" then
table.insert(parts, { text = item.thinking })
elseif type(item) == "table" and item.type == "redacted_thinking" then
table.insert(parts, { text = item.data })
end
end
end
table.insert(contents, { role = M.role_map[role] or role, parts = parts })
end) end)
if Clipboard.support_paste_image() and opts.image_paths then if Clipboard.support_paste_image() and opts.image_paths then

View File

@@ -81,9 +81,54 @@ function M:parse_messages(opts)
table.insert(messages, { role = "system", content = opts.system_prompt }) table.insert(messages, { role = "system", content = opts.system_prompt })
end end
vim vim.iter(opts.messages):each(function(msg)
.iter(opts.messages) if type(msg.content) == "string" then
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end) table.insert(messages, { role = self.role_map[msg.role], content = msg.content })
else
local content = {}
local tool_calls = {}
local tool_results = {}
for _, item in ipairs(msg.content) do
if type(item) == "string" then
table.insert(content, { type = "text", text = item })
elseif item.type == "text" then
table.insert(content, { type = "text", text = item.text })
elseif item.type == "image" then
table.insert(content, {
type = "image_url",
image_url = {
url = "data:" .. item.source.media_type .. ";" .. item.source.type .. "," .. item.source.data,
},
})
elseif item.type == "tool_use" then
table.insert(tool_calls, {
id = item.id,
type = "function",
["function"] = { name = item.name, arguments = vim.json.encode(item.input) },
})
elseif item.type == "tool_result" then
table.insert(
tool_results,
{ tool_call_id = item.tool_use_id, content = item.is_error and "Error: " .. item.content or item.content }
)
end
end
table.insert(messages, { role = self.role_map[msg.role], content = content })
if not provider_conf.disable_tools then
if #tool_calls > 0 then
table.insert(messages, { role = self.role_map["assistant"], tool_calls = tool_calls })
end
if #tool_results > 0 then
for _, tool_result in ipairs(tool_results) do
table.insert(
messages,
{ role = "tool", tool_call_id = tool_result.tool_call_id, content = tool_result.content or "" }
)
end
end
end
end
end)
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
local message_content = messages[#messages].content local message_content = messages[#messages].content

View File

@@ -43,6 +43,7 @@ local Sidebar = {}
---@field selected_files_container NuiSplit | nil ---@field selected_files_container NuiSplit | nil
---@field input_container NuiSplit | nil ---@field input_container NuiSplit | nil
---@field file_selector FileSelector ---@field file_selector FileSelector
---@field chat_history avante.ChatHistory | nil
---@param id integer the tabpage id retrieved from api.nvim_get_current_tabpage() ---@param id integer the tabpage id retrieved from api.nvim_get_current_tabpage()
function Sidebar:new(id) function Sidebar:new(id)
@@ -61,6 +62,7 @@ function Sidebar:new(id)
input_container = nil, input_container = nil,
file_selector = FileSelector:new(id), file_selector = FileSelector:new(id),
is_generating = false, is_generating = false,
chat_history = nil,
}, { __index = self }) }, { __index = self })
end end
@@ -397,6 +399,7 @@ local function transform_result_content(selected_files, result_content, prev_fil
elseif line_content == "<think>" then elseif line_content == "<think>" then
is_thinking = true is_thinking = true
last_think_tag_start_line = i last_think_tag_start_line = i
last_think_tag_end_line = 0
elseif line_content == "</think>" then elseif line_content == "</think>" then
is_thinking = false is_thinking = false
last_think_tag_end_line = i last_think_tag_end_line = i
@@ -1810,6 +1813,7 @@ function Sidebar:on_mount(opts)
group = self.augroup, group = self.augroup,
callback = function(args) callback = function(args)
local closed_winid = tonumber(args.match) local closed_winid = tonumber(args.match)
if closed_winid == self.winids.selected_files_container then return end
if not self:is_focused_on(closed_winid) then return end if not self:is_focused_on(closed_winid) then return end
self:close() self:close()
end, end,
@@ -1838,6 +1842,7 @@ function Sidebar:refresh_winids()
local function switch_windows() local function switch_windows()
local current_winid = api.nvim_get_current_win() local current_winid = api.nvim_get_current_win()
winids = vim.iter(winids):filter(function(winid) return api.nvim_win_is_valid(winid) end):totable()
local current_idx = Utils.tbl_indexof(winids, current_winid) or 1 local current_idx = Utils.tbl_indexof(winids, current_winid) or 1
if current_idx == #winids then if current_idx == #winids then
current_idx = 1 current_idx = 1
@@ -1906,6 +1911,8 @@ function Sidebar:initialize()
self.file_selector:reset() self.file_selector:reset()
self.file_selector:add_selected_file(filepath) self.file_selector:add_selected_file(filepath)
self:reload_chat_history()
return self return self
end end
@@ -2095,6 +2102,7 @@ end
function Sidebar:render_history_content(history) function Sidebar:render_history_content(history)
local content = "" local content = ""
for idx, entry in ipairs(history.entries) do for idx, entry in ipairs(history.entries) do
if entry.visible == false then goto continue end
if entry.reset_memory then if entry.reset_memory then
content = content .. "***MEMORY RESET***\n\n" content = content .. "***MEMORY RESET***\n\n"
if idx < #history.entries then content = content .. "-------\n\n" end if idx < #history.entries then content = content .. "-------\n\n" end
@@ -2180,7 +2188,7 @@ end
function Sidebar:new_chat(args, cb) function Sidebar:new_chat(args, cb)
Path.history.new(self.code.bufnr) Path.history.new(self.code.bufnr)
Sidebar.reload_chat_history() self:reload_chat_history()
self:update_content( self:update_content(
"New chat", "New chat",
{ ignore_history = true, focus = false, scroll = false, callback = function() self:focus_input() end } { ignore_history = true, focus = false, scroll = false, callback = function() self:focus_input() end }
@@ -2188,6 +2196,26 @@ function Sidebar:new_chat(args, cb)
if cb then cb(args) end if cb then cb(args) end
end end
---@param message AvanteLLMMessage
---@param options {visible?: boolean}
function Sidebar:add_chat_history(message, options)
local timestamp = get_timestamp()
self:reload_chat_history()
table.insert(self.chat_history.entries, {
timestamp = timestamp,
provider = Config.provider,
model = Config.get_provider_config(Config.provider).model,
request = message.role == "user" and message.content or "",
response = message.role == "assistant" and message.content or "",
original_response = "",
selected_filepaths = nil,
selected_code = nil,
reset_memory = false,
visible = options.visible,
})
Path.history.save(self.code.bufnr, self.chat_history)
end
function Sidebar:reset_memory(args, cb) function Sidebar:reset_memory(args, cb)
local chat_history = Path.history.load(self.code.bufnr) local chat_history = Path.history.load(self.code.bufnr)
if next(chat_history) ~= nil then if next(chat_history) ~= nil then
@@ -2203,7 +2231,7 @@ function Sidebar:reset_memory(args, cb)
reset_memory = true, reset_memory = true,
}) })
Path.history.save(self.code.bufnr, chat_history) Path.history.save(self.code.bufnr, chat_history)
Sidebar.reload_chat_history() self:reload_chat_history()
local history_content = self:render_history_content(chat_history) local history_content = self:render_history_content(chat_history)
self:update_content(history_content, { self:update_content(history_content, {
focus = false, focus = false,
@@ -2212,7 +2240,7 @@ function Sidebar:reset_memory(args, cb)
}) })
if cb then cb(args) end if cb then cb(args) end
else else
Sidebar.reload_chat_history() self:reload_chat_history()
self:update_content( self:update_content(
"Chat history is already empty", "Chat history is already empty",
{ focus = false, scroll = false, callback = function() self:focus_input() end } { focus = false, scroll = false, callback = function() self:focus_input() end }
@@ -2321,46 +2349,18 @@ local generating_text = "**Generating response ...**\n"
local hint_window = nil local hint_window = nil
function Sidebar:reload_chat_history()
if not self.code.bufnr or not api.nvim_buf_is_valid(self.code.bufnr) then return end
self.chat_history = Path.history.load(self.code.bufnr)
end
---@param opts AskOptions ---@param opts AskOptions
function Sidebar:create_input_container(opts) function Sidebar:create_input_container(opts)
if self.input_container then self.input_container:unmount() end if self.input_container then self.input_container:unmount() end
if not self.code.bufnr or not api.nvim_buf_is_valid(self.code.bufnr) then return end if not self.code.bufnr or not api.nvim_buf_is_valid(self.code.bufnr) then return end
local chat_history = Path.history.load(self.code.bufnr) if self.chat_history == nil then self:reload_chat_history() end
Sidebar.reload_chat_history = function() chat_history = Path.history.load(self.code.bufnr) end
local tools = vim.deepcopy(LLMTools.get_tools())
table.insert(tools, {
name = "add_file_to_context",
description = "Add a file to the context",
---@type AvanteLLMToolFunc<{ rel_path: string }>
func = function(input)
self.file_selector:add_selected_file(input.rel_path)
return "Added file to context", nil
end,
param = {
type = "table",
fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } },
},
returns = {},
})
table.insert(tools, {
name = "remove_file_from_context",
description = "Remove a file from the context",
---@type AvanteLLMToolFunc<{ rel_path: string }>
func = function(input)
self.file_selector:remove_selected_file(input.rel_path)
return "Removed file from context", nil
end,
param = {
type = "table",
fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } },
},
returns = {},
})
---@param request string ---@param request string
---@param summarize_memory boolean ---@param summarize_memory boolean
@@ -2399,17 +2399,48 @@ function Sidebar:create_input_container(opts)
end end
end end
local entries = Utils.history.filter_active_entries(chat_history.entries) local entries = Utils.history.filter_active_entries(self.chat_history.entries)
if chat_history.memory then if self.chat_history.memory then
entries = vim entries = vim
.iter(entries) .iter(entries)
:filter(function(entry) return entry.timestamp > chat_history.memory.last_summarized_timestamp end) :filter(function(entry) return entry.timestamp > self.chat_history.memory.last_summarized_timestamp end)
:totable() :totable()
end end
local history_messages = Utils.history.entries_to_llm_messages(entries) local history_messages = Utils.history.entries_to_llm_messages(entries)
local tools = vim.deepcopy(LLMTools.get_tools(request, history_messages))
table.insert(tools, {
name = "add_file_to_context",
description = "Add a file to the context",
---@type AvanteLLMToolFunc<{ rel_path: string }>
func = function(input)
self.file_selector:add_selected_file(input.rel_path)
return "Added file to context", nil
end,
param = {
type = "table",
fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } },
},
returns = {},
})
table.insert(tools, {
name = "remove_file_from_context",
description = "Remove a file from the context",
---@type AvanteLLMToolFunc<{ rel_path: string }>
func = function(input)
self.file_selector:remove_selected_file(input.rel_path)
return "Removed file from context", nil
end,
param = {
type = "table",
fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } },
},
returns = {},
})
---@type AvanteGeneratePromptsOptions ---@type AvanteGeneratePromptsOptions
local prompts_opts = { local prompts_opts = {
ask = opts.ask or true, ask = opts.ask or true,
@@ -2425,7 +2456,7 @@ function Sidebar:create_input_container(opts)
tools = tools, tools = tools,
} }
if chat_history.memory then prompts_opts.memory = chat_history.memory.content end if self.chat_history.memory then prompts_opts.memory = self.chat_history.memory.content end
if not summarize_memory or #history_messages < 8 then if not summarize_memory or #history_messages < 8 then
cb(prompts_opts) cb(prompts_opts)
@@ -2434,7 +2465,7 @@ function Sidebar:create_input_container(opts)
prompts_opts.history_messages = vim.list_slice(prompts_opts.history_messages, 5) prompts_opts.history_messages = vim.list_slice(prompts_opts.history_messages, 5)
Llm.summarize_memory(self.code.bufnr, chat_history, function(memory) Llm.summarize_memory(self.code.bufnr, self.chat_history, function(memory)
if memory then prompts_opts.memory = memory.content end if memory then prompts_opts.memory = memory.content end
cb(prompts_opts) cb(prompts_opts)
end) end)
@@ -2628,8 +2659,8 @@ function Sidebar:create_input_container(opts)
end, 0) end, 0)
-- Save chat history -- Save chat history
chat_history.entries = chat_history.entries or {} self.chat_history.entries = self.chat_history.entries or {}
table.insert(chat_history.entries, { table.insert(self.chat_history.entries, {
timestamp = timestamp, timestamp = timestamp,
provider = Config.provider, provider = Config.provider,
model = model, model = model,
@@ -2638,8 +2669,9 @@ function Sidebar:create_input_container(opts)
original_response = original_response, original_response = original_response,
selected_filepaths = selected_filepaths, selected_filepaths = selected_filepaths,
selected_code = selected_code, selected_code = selected_code,
tool_histories = stop_opts.tool_histories,
}) })
Path.history.save(self.code.bufnr, chat_history) Path.history.save(self.code.bufnr, self.chat_history)
end end
get_generate_prompts_options(request, true, function(generate_prompts_options) get_generate_prompts_options(request, true, function(generate_prompts_options)

View File

@@ -19,3 +19,4 @@ Tools Usage Guide:
- For any mathematical calculation problems, please prioritize using the `python` tool to solve them. Please try to avoid mathematical symbols in the return value of the `python` tool for mathematical problems and directly output human-readable results, because large models don't understand mathematical symbols, they only understand human natural language. - For any mathematical calculation problems, please prioritize using the `python` tool to solve them. Please try to avoid mathematical symbols in the return value of the `python` tool for mathematical problems and directly output human-readable results, because large models don't understand mathematical symbols, they only understand human natural language.
- Do not use the `python` tool to read or modify files! If you use the `python` tool to read or modify files, you will be fired!!!!! - Do not use the `python` tool to read or modify files! If you use the `python` tool to read or modify files, you will be fired!!!!!
- Do not use the `bash` tool to read or modify files! If you use the `bash` tool to read or modify files, you will be fired!!!!! - Do not use the `bash` tool to read or modify files! If you use the `bash` tool to read or modify files, you will be fired!!!!!
- If you are provided with the `write_file` tool, there's no need to output your change suggestions, just directly use the `write_file` tool to complete the changes.

View File

@@ -1,11 +1,3 @@
{# Uses https://mitsuhiko.github.io/minijinja-playground/ for testing:
{
"ask": true,
"question": "Refactor to include tab flow",
"code_lang": "lua",
"file_content": "local Config = require('avante.config')"
}
#}
Act as an expert software developer. Act as an expert software developer.
Always use best practices when coding. Always use best practices when coding.
Respect and use existing conventions, libraries, etc that are already present in the code base. Respect and use existing conventions, libraries, etc that are already present in the code base.

View File

@@ -76,9 +76,13 @@ vim.g.avante_login = vim.g.avante_login
---@field on_chunk AvanteLLMChunkCallback ---@field on_chunk AvanteLLMChunkCallback
---@field on_stop AvanteLLMStopCallback ---@field on_stop AvanteLLMStopCallback
--- ---
---@alias AvanteLLMMessageContentItem string | { type: "text", text: string } | { type: "image", source: { type: "base64", media_type: string, data: string } } | { type: "tool_use", name: string, id: string, input: any } | { type: "tool_result", tool_use_id: string, content: string, is_error?: boolean } | { type: "thinking", thinking: string, signature: string } | { type: "redacted_thinking", data: string }
---
---@alias AvanteLLMMessageContent AvanteLLMMessageContentItem[] | string
---
---@class AvanteLLMMessage ---@class AvanteLLMMessage
---@field role "user" | "assistant" ---@field role "user" | "assistant"
---@field content string ---@field content AvanteLLMMessageContent
--- ---
---@class AvanteLLMToolResult ---@class AvanteLLMToolResult
---@field tool_name string ---@field tool_name string
@@ -245,6 +249,7 @@ vim.g.avante_login = vim.g.avante_login
---@field tool_use_list? AvanteLLMToolUse[] ---@field tool_use_list? AvanteLLMToolUse[]
---@field retry_after? integer ---@field retry_after? integer
---@field headers? table<string, string> ---@field headers? table<string, string>
---@field tool_histories? AvanteLLMToolHistory[]
--- ---
---@alias AvanteStreamParser fun(self: AvanteProviderFunctor, ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil ---@alias AvanteStreamParser fun(self: AvanteProviderFunctor, ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil
---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil ---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil
@@ -342,7 +347,7 @@ vim.g.avante_login = vim.g.avante_login
---@field func? AvanteLLMToolFunc ---@field func? AvanteLLMToolFunc
---@field param AvanteLLMToolParam ---@field param AvanteLLMToolParam
---@field returns AvanteLLMToolReturn[] ---@field returns AvanteLLMToolReturn[]
---@field enabled? fun(): boolean ---@field enabled? fun(opts: { user_input: string, history_messages: AvanteLLMMessage[] }): boolean
---@class AvanteLLMToolPublic : AvanteLLMTool ---@class AvanteLLMToolPublic : AvanteLLMTool
---@field func AvanteLLMToolFunc ---@field func AvanteLLMToolFunc
@@ -374,6 +379,8 @@ vim.g.avante_login = vim.g.avante_login
---@field selected_code AvanteSelectedCode | nil ---@field selected_code AvanteSelectedCode | nil
---@field reset_memory boolean? ---@field reset_memory boolean?
---@field selected_filepaths string[] | nil ---@field selected_filepaths string[] | nil
---@field visible boolean?
---@field tool_histories? AvanteLLMToolHistory[]
--- ---
---@class avante.ChatHistory ---@class avante.ChatHistory
---@field title string ---@field title string

139
lua/avante/ui.lua Normal file
View File

@@ -0,0 +1,139 @@
local Popup = require("nui.popup")
local NuiText = require("nui.text")
local event = require("nui.utils.autocmd").event
local Highlights = require("avante.highlights")
local M = {}
function M.confirm(message, callback)
local focus_index = 2 -- 1 = Yes, 2 = No
local yes_button_pos = { 18, 23 }
local no_button_pos = { 28, 32 }
local BUTTON_NORMAL = Highlights.BUTTON_DEFAULT
local BUTTON_FOCUS = Highlights.BUTTON_DEFAULT_HOVER
local popup = Popup({
position = {
row = vim.o.lines - 5,
col = "50%",
},
size = { width = 50, height = 7 },
enter = true,
focusable = true,
border = {
style = "rounded",
text = { top = NuiText(" Confirmation ", Highlights.CONFIRM_TITLE) },
},
win_options = {
winblend = 10,
},
})
local function focus_button()
if focus_index == 1 then
vim.api.nvim_win_set_cursor(popup.winid, { 4, yes_button_pos[1] })
else
vim.api.nvim_win_set_cursor(popup.winid, { 4, no_button_pos[1] })
end
end
local function render_buttons()
local yes_style = (focus_index == 1) and BUTTON_FOCUS or BUTTON_NORMAL
local no_style = (focus_index == 2) and BUTTON_FOCUS or BUTTON_NORMAL
vim.api.nvim_buf_set_lines(popup.bufnr, 0, -1, false, {
"",
" " .. message,
"",
" " .. " Yes No ",
"",
})
vim.api.nvim_buf_add_highlight(popup.bufnr, 0, yes_style, 3, yes_button_pos[1], yes_button_pos[2])
vim.api.nvim_buf_add_highlight(popup.bufnr, 0, no_style, 3, no_button_pos[1], no_button_pos[2])
focus_button()
end
local function select_button()
popup:unmount()
callback(focus_index == 1)
end
vim.keymap.set("n", "y", function()
focus_index = 1
render_buttons()
select_button()
end, { buffer = popup.bufnr })
vim.keymap.set("n", "n", function()
focus_index = 2
render_buttons()
select_button()
end, { buffer = popup.bufnr })
vim.keymap.set("n", "<Left>", function()
focus_index = 1
focus_button()
end, { buffer = popup.bufnr })
vim.keymap.set("n", "<Right>", function()
focus_index = 2
focus_button()
end, { buffer = popup.bufnr })
vim.keymap.set("n", "<Tab>", function()
focus_index = (focus_index == 1) and 2 or 1
focus_button()
end, { buffer = popup.bufnr })
vim.keymap.set("n", "<S-Tab>", function()
focus_index = (focus_index == 1) and 2 or 1
focus_button()
end, { buffer = popup.bufnr })
vim.keymap.set("n", "<CR>", function() select_button() end, { buffer = popup.bufnr })
vim.api.nvim_buf_set_keymap(popup.bufnr, "n", "<LeftMouse>", "", {
callback = function()
local pos = vim.fn.getmousepos()
local row, col = pos["winrow"], pos["wincol"]
if row == 4 then
if col >= yes_button_pos[1] and col <= yes_button_pos[2] then
focus_index = 1
render_buttons()
select_button()
elseif col >= no_button_pos[1] and col <= no_button_pos[2] then
focus_index = 2
render_buttons()
select_button()
end
end
end,
noremap = true,
silent = true,
})
vim.api.nvim_create_autocmd("CursorMoved", {
buffer = popup.bufnr,
callback = function()
local row, col = unpack(vim.api.nvim_win_get_cursor(0))
if row == 4 then
if col >= yes_button_pos[1] and col <= yes_button_pos[2] then
focus_index = 1
render_buttons()
elseif col >= no_button_pos[1] and col <= no_button_pos[2] then
focus_index = 2
render_buttons()
end
end
end,
})
popup:on(event.BufLeave, function() popup:unmount() end)
popup:mount()
render_buttons()
end
return M

View File

@@ -11,14 +11,6 @@ function M.filter_active_entries(entries)
for i = #entries, 1, -1 do for i = #entries, 1, -1 do
local entry = entries[i] local entry = entries[i]
if entry.reset_memory then break end if entry.reset_memory then break end
if
entry.request == nil
or entry.original_response == nil
or entry.request == ""
or entry.original_response == ""
then
break
end
table.insert(entries_, 1, entry) table.insert(entries_, 1, entry)
end end
@@ -30,25 +22,62 @@ end
function M.entries_to_llm_messages(entries) function M.entries_to_llm_messages(entries)
local messages = {} local messages = {}
for _, entry in ipairs(entries) do for _, entry in ipairs(entries) do
local user_content = "" if entry.selected_filepaths ~= nil and #entry.selected_filepaths > 0 then
if entry.selected_filepaths ~= nil then local user_content = "SELECTED FILES:\n\n"
user_content = user_content .. "SELECTED FILES:\n\n"
for _, filepath in ipairs(entry.selected_filepaths) do for _, filepath in ipairs(entry.selected_filepaths) do
user_content = user_content .. filepath .. "\n" user_content = user_content .. filepath .. "\n"
end end
table.insert(messages, { role = "user", content = user_content })
end end
if entry.selected_code ~= nil then if entry.selected_code ~= nil then
user_content = user_content local user_content_ = "SELECTED CODE:\n\n```"
.. "SELECTED CODE:\n\n```"
.. (entry.selected_code.file_type or "") .. (entry.selected_code.file_type or "")
.. (entry.selected_code.path and ":" .. entry.selected_code.path or "") .. (entry.selected_code.path and ":" .. entry.selected_code.path or "")
.. "\n" .. "\n"
.. entry.selected_code.content .. entry.selected_code.content
.. "\n```\n\n" .. "\n```\n\n"
table.insert(messages, { role = "user", content = user_content_ })
end end
user_content = user_content .. "USER PROMPT:\n\n" .. entry.request if entry.request ~= nil and entry.request ~= "" then
table.insert(messages, { role = "user", content = user_content }) table.insert(messages, { role = "user", content = entry.request })
table.insert(messages, { role = "assistant", content = Utils.trim_think_content(entry.original_response) }) end
if entry.tool_histories ~= nil and #entry.tool_histories > 0 then
for _, tool_history in ipairs(entry.tool_histories) do
local assistant_content = {}
if tool_history.tool_use ~= nil then
if tool_history.tool_use.response_contents ~= nil then
for _, response_content in ipairs(tool_history.tool_use.response_contents) do
table.insert(assistant_content, { type = "text", text = response_content })
end
end
table.insert(assistant_content, {
type = "tool_use",
name = tool_history.tool_use.name,
id = tool_history.tool_use.id,
input = vim.json.decode(tool_history.tool_use.input_json),
})
end
table.insert(messages, {
role = "assistant",
content = assistant_content,
})
local user_content = {}
if tool_history.tool_result ~= nil and tool_history.tool_result.content ~= nil then
table.insert(user_content, {
type = "tool_result",
tool_use_id = tool_history.tool_result.tool_use_id,
content = tool_history.tool_result.content,
is_error = tool_history.tool_result.is_error,
})
end
table.insert(messages, {
role = "user",
content = user_content,
})
end
end
local assistant_content = Utils.trim_think_content(entry.original_response or "")
if assistant_content ~= "" then table.insert(messages, { role = "assistant", content = assistant_content }) end
end end
return messages return messages
end end

View File

@@ -10,9 +10,29 @@ local cost_per_token = {
} }
--- Calculate the number of tokens in a given text. --- Calculate the number of tokens in a given text.
---@param text string The text to calculate the number of tokens for. ---@param content AvanteLLMMessageContent The text to calculate the number of tokens in.
---@return integer The number of tokens in the given text. ---@return integer The number of tokens in the given text.
function Tokens.calculate_tokens(text) function Tokens.calculate_tokens(content)
local text = ""
if type(content) == "string" then
text = content
elseif type(content) == "table" then
for _, item in ipairs(content) do
if type(item) == "string" then
text = text .. item
elseif type(item) == "table" and item.type == "text" then
text = text .. item.text
elseif type(item) == "table" and item.type == "image" then
text = text .. item.source.data
elseif type(item) == "table" and item.type == "tool_use" then
text = text .. item.name .. item.id
elseif type(item) == "table" and item.type == "tool_result" then
text = text .. item.tool_use_id .. item.content
end
end
end
if Tokenizer.available() then return Tokenizer.count(text) end if Tokenizer.available() then return Tokenizer.count(text) end
local tokens = 0 local tokens = 0

View File

@@ -3,7 +3,7 @@ local LlmTools = require("avante.llm_tools")
local Config = require("avante.config") local Config = require("avante.config")
local Utils = require("avante.utils") local Utils = require("avante.utils")
LlmTools.confirm = function(msg) return true end LlmTools.confirm = function(msg, cb) return cb(true) end
describe("llm_tools", function() describe("llm_tools", function()
local test_dir = "/tmp/test_llm_tools" local test_dir = "/tmp/test_llm_tools"
@@ -85,34 +85,37 @@ describe("llm_tools", function()
describe("create_file", function() describe("create_file", function()
it("should create new file", function() it("should create new file", function()
local success, err = LlmTools.create_file({ rel_path = "new_file.txt" }) LlmTools.create_file({ rel_path = "new_file.txt" }, nil, function(success, err)
assert.is_nil(err) assert.is_nil(err)
assert.is_true(success) assert.is_true(success)
local file_exists = io.open(test_dir .. "/new_file.txt", "r") ~= nil local file_exists = io.open(test_dir .. "/new_file.txt", "r") ~= nil
assert.is_true(file_exists) assert.is_true(file_exists)
end)
end) end)
end) end)
describe("create_dir", function() describe("create_dir", function()
it("should create new directory", function() it("should create new directory", function()
local success, err = LlmTools.create_dir({ rel_path = "new_dir" }) LlmTools.create_dir({ rel_path = "new_dir" }, nil, function(success, err)
assert.is_nil(err) assert.is_nil(err)
assert.is_true(success) assert.is_true(success)
local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil
assert.is_true(dir_exists) assert.is_true(dir_exists)
end)
end) end)
end) end)
describe("delete_file", function() describe("delete_file", function()
it("should delete existing file", function() it("should delete existing file", function()
local success, err = LlmTools.delete_file({ rel_path = "test.txt" }) LlmTools.delete_file({ rel_path = "test.txt" }, nil, function(success, err)
assert.is_nil(err) assert.is_nil(err)
assert.is_true(success) assert.is_true(success)
local file_exists = io.open(test_file, "r") ~= nil local file_exists = io.open(test_file, "r") ~= nil
assert.is_false(file_exists) assert.is_false(file_exists)
end)
end) end)
end) end)
@@ -270,68 +273,93 @@ describe("llm_tools", function()
describe("bash", function() describe("bash", function()
it("should execute command and return output", function() it("should execute command and return output", function()
local result, err = LlmTools.bash({ rel_path = ".", command = "echo 'test'" }) LlmTools.bash({ rel_path = ".", command = "echo 'test'" }, nil, function(result, err)
assert.is_nil(err) assert.is_nil(err)
assert.equals("test\n", result) assert.equals("test\n", result)
end)
end) end)
it("should return error when running outside current directory", function() it("should return error when running outside current directory", function()
local result, err = LlmTools.bash({ rel_path = "../outside_project", command = "echo 'test'" }) LlmTools.bash({ rel_path = "../outside_project", command = "echo 'test'" }, nil, function(result, err)
assert.is_false(result) assert.is_false(result)
assert.truthy(err) assert.truthy(err)
assert.truthy(err:find("No permission to access path")) assert.truthy(err:find("No permission to access path"))
end)
end) end)
end) end)
describe("python", function() describe("python", function()
local original_system = vim.fn.system
it("should execute Python code and return output", function() it("should execute Python code and return output", function()
local result, err = LlmTools.python({ LlmTools.python(
rel_path = ".", {
code = "print('Hello from Python')", rel_path = ".",
}) code = "print('Hello from Python')",
assert.is_nil(err) },
assert.equals("Hello from Python\n", result) nil,
function(result, err)
assert.is_nil(err)
assert.equals("Hello from Python\n", result)
end
)
end) end)
it("should handle Python errors", function() it("should handle Python errors", function()
local result, err = LlmTools.python({ LlmTools.python(
rel_path = ".", {
code = "print(undefined_variable)", rel_path = ".",
}) code = "print(undefined_variable)",
assert.is_nil(result) },
assert.truthy(err) nil,
assert.truthy(err:find("Error")) function(result, err)
assert.is_nil(result)
assert.truthy(err)
assert.truthy(err:find("Error"))
end
)
end) end)
it("should respect path permissions", function() it("should respect path permissions", function()
local result, err = LlmTools.python({ LlmTools.python(
rel_path = "../outside_project", {
code = "print('test')", rel_path = "../outside_project",
}) code = "print('test')",
assert.is_nil(result) },
assert.truthy(err:find("No permission to access path")) nil,
function(result, err)
assert.is_nil(result)
assert.truthy(err:find("No permission to access path"))
end
)
end) end)
it("should handle non-existent paths", function() it("should handle non-existent paths", function()
local result, err = LlmTools.python({ LlmTools.python(
rel_path = "non_existent_dir", {
code = "print('test')", rel_path = "non_existent_dir",
}) code = "print('test')",
assert.is_nil(result) },
assert.truthy(err:find("Path not found")) nil,
function(result, err)
assert.is_nil(result)
assert.truthy(err:find("Path not found"))
end
)
end) end)
it("should support custom container image", function() it("should support custom container image", function()
os.execute("docker image rm python:3.12-slim") os.execute("docker image rm python:3.12-slim")
local result, err = LlmTools.python({ LlmTools.python(
rel_path = ".", {
code = "print('Hello from custom container')", rel_path = ".",
container_image = "python:3.12-slim", code = "print('Hello from custom container')",
}) container_image = "python:3.12-slim",
assert.is_nil(err) },
assert.equals("Hello from custom container\n", result) nil,
function(result, err)
assert.is_nil(err)
assert.equals("Hello from custom container\n", result)
end
)
end) end)
end) end)