From f2bd4adba4d11b75ad8bade11e6a9dae22272b73 Mon Sep 17 00:00:00 2001 From: yetone Date: Thu, 6 Feb 2025 16:00:14 +0800 Subject: [PATCH] feat: add `add_file_to_context` tool (#1191) --- lua/avante/file_selector.lua | 8 ++- lua/avante/llm.lua | 2 +- lua/avante/llm_tools.lua | 8 +-- lua/avante/sidebar.lua | 95 ++++++++++++++++++++++++++++-------- 4 files changed, 88 insertions(+), 25 deletions(-) diff --git a/lua/avante/file_selector.lua b/lua/avante/file_selector.lua index 6da117f..aa6e280 100644 --- a/lua/avante/file_selector.lua +++ b/lua/avante/file_selector.lua @@ -344,7 +344,7 @@ end ---@param idx integer ---@return boolean -function FileSelector:remove_selected_filepaths(idx) +function FileSelector:remove_selected_filepaths_with_index(idx) if idx > 0 and idx <= #self.selected_filepaths then table.remove(self.selected_filepaths, idx) self:emit("update") @@ -353,6 +353,12 @@ function FileSelector:remove_selected_filepaths(idx) return false end +function FileSelector:remove_selected_file(rel_path) + local uniform_path = Utils.uniform_path(rel_path) + local idx = Utils.tbl_indexof(self.selected_filepaths, uniform_path) + if idx then self:remove_selected_filepaths_with_index(idx) end +end + ---@return { path: string, content: string, file_type: string }[] function FileSelector:get_selected_files_contents() local contents = {} diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 2c481e4..ccb9217 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -146,7 +146,7 @@ M._stream = function(opts) on_chunk = opts.on_chunk, on_stop = function(stop_opts) if stop_opts.reason == "tool_use" and stop_opts.tool_use then - local result, error = LLMTools.process_tool_use(stop_opts.tool_use, opts.on_tool_log) + local result, error = LLMTools.process_tool_use(opts.tools, stop_opts.tool_use, opts.on_tool_log) local tool_result = { tool_use_id = stop_opts.tool_use.id, content = error ~= nil and error or result, diff --git a/lua/avante/llm_tools.lua b/lua/avante/llm_tools.lua index fa64f12..7a95713 100644 --- a/lua/avante/llm_tools.lua +++ b/lua/avante/llm_tools.lua @@ -311,6 +311,7 @@ end ---@class AvanteLLMTool ---@field name string ---@field description string +---@field func? fun(input: any): (string | nil, string | nil) ---@field param AvanteLLMToolParam ---@field returns AvanteLLMToolReturn[] @@ -716,16 +717,17 @@ M.tools = { }, } +---@param tools AvanteLLMTool[] ---@param tool_use AvanteLLMToolUse ---@param on_log? fun(tool_name: string, log: string): nil ---@return string | nil result ---@return string | nil error -function M.process_tool_use(tool_use, on_log) +function M.process_tool_use(tools, tool_use, on_log) Utils.debug("use tool", tool_use.name, tool_use.input_json) - local tool = vim.iter(M.tools):find(function(tool) return tool.name == tool_use.name end) + local tool = vim.iter(tools):find(function(tool) return tool.name == tool_use.name end) if tool == nil then return end local input_json = vim.json.decode(tool_use.input_json) - local func = M[tool.name] + local func = tool.func or M[tool.name] if on_log then on_log(tool_use.name, "running tool") end local result, error = func(input_json, function(log) if on_log then on_log(tool_use.name, log) end diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 2112719..fd0a79d 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -1628,6 +1628,41 @@ function Sidebar:create_input_container(opts) local chat_history = Path.history.load(self.code.bufnr) + local tools = vim.deepcopy(LLMTools.tools) + table.insert(tools, { + name = "add_file_to_context", + description = "Add a file to the context", + ---@param input { rel_path: string } + ---@return string | nil result + ---@return string | nil error + func = function(input) + self.file_selector:add_selected_file(input.rel_path) + return "Added file to context", nil + end, + param = { + type = "table", + fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } }, + }, + returns = {}, + }) + + table.insert(tools, { + name = "remove_file_from_context", + description = "Remove a file from the context", + ---@param input { rel_path: string } + ---@return string | nil result + ---@return string | nil error + func = function(input) + self.file_selector:remove_selected_file(input.rel_path) + return "Removed file from context", nil + end, + param = { + type = "table", + fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } }, + }, + returns = {}, + }) + ---@param request string ---@return GeneratePromptsOptions local function get_generate_prompts_options(request) @@ -1697,7 +1732,7 @@ function Sidebar:create_input_container(opts) selected_code = selected_code_content, instructions = request, mode = "planning", - tools = LLMTools.tools, + tools = tools, } end @@ -2133,6 +2168,41 @@ function Sidebar:get_selected_code_size() return selected_code_size end +function Sidebar:get_selected_files_size() + if not self.file_selector then return 0 end + + local selected_files_max_lines_count = 10 + + local selected_files = self.file_selector:get_selected_filepaths() + local selected_files_size = #selected_files + selected_files_size = math.min(selected_files_size, selected_files_max_lines_count) + + return selected_files_size +end + +function Sidebar:get_result_container_height() + local selected_code_size = self:get_selected_code_size() + local selected_files_size = self:get_selected_files_size() + + if self:get_layout() == "horizontal" then return math.floor(Config.windows.height / 100 * vim.o.lines) end + + return math.max(1, api.nvim_win_get_height(self.code.winid) - selected_files_size - selected_code_size - 3 - 8) +end + +function Sidebar:get_result_container_width() + if self:get_layout() == "vertical" then return math.floor(Config.windows.width / 100 * vim.o.columns) end + + return math.max(1, api.nvim_win_get_width(self.code.winid)) +end + +function Sidebar:adjust_result_container_layout() + local width = self:get_result_container_width() + local height = self:get_result_container_height() + + api.nvim_win_set_width(self.result_container.winid, width) + api.nvim_win_set_height(self.result_container.winid, height) +end + ---@param opts AskOptions function Sidebar:render(opts) local chat_history = Path.history.load(self.code.bufnr) @@ -2141,20 +2211,6 @@ function Sidebar:render(opts) return (opts and opts.win and opts.win.position) and opts.win.position or calculate_config_window_position() end - local get_height = function() - local selected_code_size = self:get_selected_code_size() - - if self:get_layout() == "horizontal" then return math.floor(Config.windows.height / 100 * vim.o.lines) end - - return math.max(1, api.nvim_win_get_height(self.code.winid) - selected_code_size - 3 - 8) - end - - local get_width = function() - if self:get_layout() == "vertical" then return math.floor(Config.windows.width / 100 * vim.o.columns) end - - return math.max(1, api.nvim_win_get_width(self.code.winid)) - end - self.result_container = Split({ enter = false, relative = "editor", @@ -2170,8 +2226,8 @@ function Sidebar:render(opts) wrap = Config.windows.wrap, }), size = { - width = get_width(), - height = get_height(), + width = self:get_result_container_width(), + height = self:get_result_container_height(), }, }) @@ -2270,13 +2326,12 @@ function Sidebar:create_selected_files_container() Highlights.SUBTITLE, Highlights.REVERSED_SUBTITLE ) + self:adjust_result_container_layout() end self.file_selector:on("update", render) - local remove_file = function(line_number) - if self.file_selector:remove_selected_filepaths(line_number) then render() end - end + local remove_file = function(line_number) self.file_selector:remove_selected_filepaths_with_index(line_number) end -- Function to show hint local function show_hint()