From bc403ddcbf98c4181ee2a7efd35cd1e18a2fdc5c Mon Sep 17 00:00:00 2001 From: yetone Date: Sat, 31 May 2025 08:53:34 +0800 Subject: [PATCH] feat: ReAct tool calling (#2104) --- lua/avante/config.lua | 1 + lua/avante/libs/xmlparser.lua | 517 ++++++++++++++++++++ lua/avante/llm.lua | 12 +- lua/avante/llm_tools/attempt_completion.lua | 4 + lua/avante/llm_tools/bash.lua | 10 +- lua/avante/llm_tools/dispatch_agent.lua | 3 + lua/avante/llm_tools/get_diagnostics.lua | 3 + lua/avante/llm_tools/glob.lua | 10 +- lua/avante/llm_tools/grep.lua | 13 +- lua/avante/llm_tools/init.lua | 117 +++-- lua/avante/llm_tools/insert.lua | 5 + lua/avante/llm_tools/ls.lua | 10 +- lua/avante/llm_tools/replace_in_file.lua | 6 + lua/avante/llm_tools/str_replace.lua | 10 +- lua/avante/llm_tools/undo_edit.lua | 3 + lua/avante/llm_tools/view.lua | 39 +- lua/avante/llm_tools/write_to_file.lua | 74 +++ lua/avante/providers/ollama.lua | 181 ++++++- lua/avante/providers/openai.lua | 137 +++++- lua/avante/sidebar.lua | 143 ++++-- lua/avante/templates/agentic.avanterules | 4 +- lua/avante/types.lua | 2 + lua/avante/utils/init.lua | 36 +- lua/avante/utils/prompts.lua | 144 ++++++ tests/llm_tools_spec.lua | 62 +-- 25 files changed, 1358 insertions(+), 188 deletions(-) create mode 100644 lua/avante/libs/xmlparser.lua create mode 100644 lua/avante/llm_tools/write_to_file.lua create mode 100644 lua/avante/utils/prompts.lua diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 4a04ac1..50d70ab 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -671,6 +671,7 @@ M.BASE_PROVIDER_KEYS = { "disable_tools", "entra", "hide_in_model_selector", + "use_ReAct_prompt", } return M diff --git a/lua/avante/libs/xmlparser.lua b/lua/avante/libs/xmlparser.lua new file mode 100644 index 0000000..c210247 --- /dev/null +++ b/lua/avante/libs/xmlparser.lua @@ -0,0 +1,517 @@ +-- XML Parser for Lua +local XmlParser = {} + +-- 流式解析器状态 +local StreamParser = {} +StreamParser.__index = StreamParser + +-- 创建新的流式解析器实例 +function StreamParser.new() + local parser = { + buffer = "", -- 缓冲区存储未处理的内容 + stack = {}, -- 标签栈 + results = {}, -- 已完成的元素列表 + current = nil, -- 当前正在处理的元素 + root = nil, -- 当前根元素 + position = 1, -- 当前解析位置 + state = "ready", -- 解析状态: ready, parsing, incomplete, error + incomplete_tag = nil, -- 未完成的标签信息 + last_error = nil, -- 最后的错误信息 + } + setmetatable(parser, StreamParser) + return parser +end + +-- 重置解析器状态 +function StreamParser:reset() + self.buffer = "" + self.stack = {} + self.results = {} + self.current = nil + self.root = nil + self.position = 1 + self.state = "ready" + self.incomplete_tag = nil + self.last_error = nil +end + +-- 获取解析器状态信息 +function StreamParser:getStatus() + return { + state = self.state, + completed_elements = #self.results, + stack_depth = #self.stack, + buffer_size = #self.buffer, + incomplete_tag = self.incomplete_tag, + last_error = self.last_error, + has_incomplete = self.state == "incomplete" or self.incomplete_tag ~= nil, + } +end + +-- 辅助函数:去除字符串首尾空白 +local function trim(s) return s:match("^%s*(.-)%s*$") end + +-- 辅助函数:解析属性 +local function parseAttributes(attrStr) + local attrs = {} + if not attrStr or attrStr == "" then return attrs end + + -- 匹配属性模式:name="value" 或 name='value' + for name, value in attrStr:gmatch("([-_%w]+)%s*=%s*[\"']([^\"']*)[\"']") do + attrs[name] = value + end + return attrs +end + +-- 辅助函数:HTML实体解码 +local function decodeEntities(str) + local entities = { + ["<"] = "<", + [">"] = ">", + ["&"] = "&", + ["""] = '"', + ["'"] = "'", + } + + for entity, char in pairs(entities) do + str = str:gsub(entity, char) + end + + -- 处理数字实体 { 和  + str = str:gsub("&#(%d+);", function(n) + local num = tonumber(n) + return num and string.char(num) or "" + end) + str = str:gsub("&#x(%x+);", function(n) + local num = tonumber(n, 16) + return num and string.char(num) or "" + end) + + return str +end + +-- 检查是否为有效的XML标签 +local function isValidXmlTag(tag, xmlContent, tagStart) + -- 排除明显不是XML标签的内容,比如数学表达式 < 或 > + -- 检查标签是否包含合理的XML标签格式 + if not tag:match("^<[^<>]*>$") then return false end + + -- 检查是否是合法的标签格式 + if tag:match("^$") then return true end -- 结束标签 + if tag:match("^<[-_%w]+[^>]*/>$") then return true end -- 自闭合标签 + if tag:match("^<[-_%w]+[^>]*>$") then + -- 对于开始标签,进行额外的上下文检查 + local tagName = tag:match("^<([-_%w]+)") + + -- 检查是否存在对应的结束标签 + local closingTag = "" + local hasClosingTag = xmlContent:find(closingTag, tagStart) + + -- 如果是单个标签且没有结束标签,可能是文本中的引用 + if not hasClosingTag then + -- 检查前后文本,如果像是在描述而不是实际的XML结构,则不认为是有效标签 + local beforeText = xmlContent:sub(math.max(1, tagStart - 50), tagStart - 1) + local afterText = xmlContent:sub(tagStart + #tag, math.min(#xmlContent, tagStart + #tag + 50)) + + -- 如果前面有"provided in the"、"in the"等描述性文字,可能是文本引用 + if + beforeText:match("provided in the%s*$") + or beforeText:match("in the%s*$") + or beforeText:match("see the%s*$") + or beforeText:match("use the%s*$") + then + return false + end + + -- 如果后面紧跟着"tag"等描述性词汇,可能是文本引用 + if afterText:match("^%s*tag") then return false end + end + + return true + end + + return false +end + +-- 流式解析器方法:添加数据到缓冲区并解析 +function StreamParser:addData(data) + if not data or data == "" then return end + + self.buffer = self.buffer .. data + self:parseBuffer() +end + +-- 获取当前解析深度 +function StreamParser:getCurrentDepth() return #self.stack end + +-- 解析缓冲区中的数据 +function StreamParser:parseBuffer() + self.state = "parsing" + + while self.position <= #self.buffer do + local remaining = self.buffer:sub(self.position) + + -- 查找下一个标签 + local tagStart, tagEnd = remaining:find("<[^>]*>") + + if not tagStart then + -- 检查是否有未完成的开始标签(以<开始但没有>结束) + local incompleteStart = remaining:find("<[^>]*$") + if incompleteStart then + local incompleteContent = remaining:sub(incompleteStart) + -- 确保这确实是一个未完成的标签,而不是文本中的<符号 + if incompleteContent:match("^<[%w_-]") then + -- 尝试解析未完成的开始标签 + local tagName = incompleteContent:match("^<([%w_-]+)") + if tagName then + -- 处理未完成标签前的文本 + if incompleteStart > 1 then + local precedingText = trim(remaining:sub(1, incompleteStart - 1)) + if precedingText ~= "" then + if self.current then + -- 如果当前在某个标签内,添加到该标签的文本内容 + precedingText = decodeEntities(precedingText) + if self.current._text then + self.current._text = self.current._text .. precedingText + else + self.current._text = precedingText + end + else + -- 如果是顶层文本,作为独立元素添加 + local textElement = { + _name = "_text", + _text = decodeEntities(precedingText), + } + table.insert(self.results, textElement) + end + end + end + + -- 创建未完成的元素 + local element = { + _name = tagName, + _attr = {}, + _state = "incomplete_start_tag", + } + + if not self.root then + self.root = element + self.current = element + elseif self.current then + table.insert(self.stack, self.current) + if not self.current[tagName] then self.current[tagName] = {} end + table.insert(self.current[tagName], element) + self.current = element + end + + self.incomplete_tag = { + start_pos = self.position + incompleteStart - 1, + content = incompleteContent, + element = element, + } + self.state = "incomplete" + return + end + end + end + + -- 处理剩余的文本内容 + if remaining ~= "" then + if self.current then + -- 检查当前深度,如果在第一层子元素中,保持原始文本 + local currentDepth = #self.stack + if currentDepth >= 1 then + -- 在第一层子元素中,保持原始文本不变 + if self.current._text then + self.current._text = self.current._text .. remaining + else + self.current._text = remaining + end + else + -- 在根级别,进行正常的文本处理 + local text = trim(remaining) + if text ~= "" then + text = decodeEntities(text) + if self.current._text then + self.current._text = self.current._text .. text + else + self.current._text = text + end + end + end + else + -- 如果是顶层文本,作为独立元素添加 + local text = trim(remaining) + if text ~= "" then + local textElement = { + _name = "_text", + _text = decodeEntities(text), + } + table.insert(self.results, textElement) + end + end + end + self.position = #self.buffer + 1 + break + end + + local tag = remaining:sub(tagStart, tagEnd) + local actualTagStart = self.position + tagStart - 1 + local actualTagEnd = self.position + tagEnd - 1 + + -- 检查是否为有效的XML标签 + if not isValidXmlTag(tag, self.buffer, actualTagStart) then + -- 如果不是有效标签,将其作为普通文本处理 + local text = remaining:sub(1, tagEnd) + if text ~= "" then + if self.current then + -- 检查当前深度,如果在第一层子元素中,保持原始文本 + local currentDepth = #self.stack + if currentDepth >= 1 then + -- 在第一层子元素中,保持原始文本不变 + if self.current._text then + self.current._text = self.current._text .. text + else + self.current._text = text + end + else + -- 在根级别,进行正常的文本处理 + text = trim(text) + if text ~= "" then + text = decodeEntities(text) + if self.current._text then + self.current._text = self.current._text .. text + else + self.current._text = text + end + end + end + else + -- 顶层文本作为独立元素 + text = trim(text) + if text ~= "" then + local textElement = { + _name = "_text", + _text = decodeEntities(text), + } + table.insert(self.results, textElement) + end + end + end + self.position = actualTagEnd + 1 + goto continue + end + + -- 处理标签前的文本内容 + if tagStart > 1 then + local precedingText = remaining:sub(1, tagStart - 1) + if precedingText ~= "" then + if self.current then + -- 如果当前在某个标签内,添加到该标签的文本内容 + -- 检查当前深度,如果在第一层子元素中,不要进行实体解码和trim + local currentDepth = #self.stack + if currentDepth >= 1 then + -- 在第一层子元素中,保持原始文本不变 + if self.current._text then + self.current._text = self.current._text .. precedingText + else + self.current._text = precedingText + end + else + -- 在根级别,进行正常的文本处理 + precedingText = trim(precedingText) + if precedingText ~= "" then + precedingText = decodeEntities(precedingText) + if self.current._text then + self.current._text = self.current._text .. precedingText + else + self.current._text = precedingText + end + end + end + else + -- 如果是顶层文本,作为独立元素添加 + precedingText = trim(precedingText) + if precedingText ~= "" then + local textElement = { + _name = "_text", + _text = decodeEntities(precedingText), + } + table.insert(self.results, textElement) + end + end + end + end + + -- 检查当前深度,如果已经在第一层子元素中,将所有标签作为文本处理 + local currentDepth = #self.stack + if currentDepth >= 1 then + -- 检查是否是当前元素的结束标签 + if tag:match("^$") and self.current then + local tagName = tag:match("^$") + if self.current._name == tagName then + -- 这是当前元素的结束标签,正常处理 + if not self:processTag(tag) then + self.state = "error" + return + end + else + -- 不是当前元素的结束标签,作为文本处理 + if self.current._text then + self.current._text = self.current._text .. tag + else + self.current._text = tag + end + end + else + -- 在第一层子元素中,将标签作为文本处理 + if self.current then + if self.current._text then + self.current._text = self.current._text .. tag + else + self.current._text = tag + end + end + end + else + -- 处理标签 + if not self:processTag(tag) then + self.state = "error" + return + end + end + + self.position = actualTagEnd + 1 + ::continue:: + end + + -- 检查当前是否有未关闭的元素 + if self.current and self.current._state ~= "complete" then + self.current._state = "incomplete_unclosed" + self.state = "incomplete" + elseif self.state ~= "incomplete" and self.state ~= "error" then + self.state = "ready" + end +end + +-- 处理单个标签 +function StreamParser:processTag(tag) + if tag:match("^$") then + -- 结束标签 + local tagName = tag:match("^$") + if self.current and self.current._name == tagName then + -- 标记当前元素为完成状态 + self.current._state = "complete" + self.current = table.remove(self.stack) + -- 只有当栈为空且当前元素也为空时,说明完成了一个根级元素 + if #self.stack == 0 and not self.current and self.root then + table.insert(self.results, self.root) + self.root = nil + end + else + self.last_error = "Mismatched closing tag: " .. tagName + return false + end + elseif tag:match("^<[-_%w]+[^>]*/>$") then + -- 自闭合标签 + local tagName, attrs = tag:match("^<([-_%w]+)([^>]*)/>") + local element = { + _name = tagName, + _attr = parseAttributes(attrs), + _state = "complete", + children = {}, + } + + if not self.root then + -- 直接作为根级元素添加到结果中 + table.insert(self.results, element) + elseif self.current then + if not self.current.children then self.current.children = {} end + table.insert(self.current.children, element) + end + elseif tag:match("^<[-_%w]+[^>]*>$") then + -- 开始标签 + local tagName, attrs = tag:match("^<([-_%w]+)([^>]*)>") + local element = { + _name = tagName, + _attr = parseAttributes(attrs), + _state = "incomplete_open", -- 标记为未完成(等待结束标签) + children = {}, + } + + if not self.root then + self.root = element + self.current = element + elseif self.current then + table.insert(self.stack, self.current) + if not self.current.children then self.current.children = {} end + table.insert(self.current.children, element) + self.current = element + end + end + + return true +end + +-- 获取所有元素(已完成的和当前正在处理的) +function StreamParser:getAllElements() + local all_elements = {} + + -- 添加所有已完成的元素 + for _, element in ipairs(self.results) do + table.insert(all_elements, element) + end + + -- 如果有当前正在处理的元素,也添加进去 + if self.root then table.insert(all_elements, self.root) end + + return all_elements +end + +-- 获取已完成的元素(保留向后兼容性) +function StreamParser:getCompletedElements() return self.results end + +-- 获取当前未完成的元素(保留向后兼容性) +function StreamParser:getCurrentElement() return self.root end + +-- 强制完成解析(将未完成的内容作为已完成处理) +function StreamParser:finalize() + -- 首先处理当前正在解析的元素 + if self.current then + -- 递归设置所有未完成元素的状态 + local function markIncompleteElements(element) + if element._state and element._state:match("incomplete") then element._state = "incomplete_unclosed" end + -- 处理 children 数组中的子元素 + if element.children and type(element.children) == "table" then + for _, child in ipairs(element.children) do + if type(child) == "table" and child._name then markIncompleteElements(child) end + end + end + end + + -- 标记当前元素及其所有子元素为未完成状态,但保持层次结构 + markIncompleteElements(self.current) + + -- 向上遍历栈,标记所有祖先元素 + for i = #self.stack, 1, -1 do + local ancestor = self.stack[i] + if ancestor._state and ancestor._state:match("incomplete") then ancestor._state = "incomplete_unclosed" end + end + end + + -- 只有当存在根元素时才添加到结果中 + if self.root then + table.insert(self.results, self.root) + self.root = nil + end + + self.current = nil + self.stack = {} + self.state = "ready" + self.incomplete_tag = nil +end + +-- 创建流式解析器实例 +function XmlParser.createStreamParser() return StreamParser.new() end + +return XmlParser diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index cd7218e..dc8a0b7 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -25,7 +25,7 @@ local group = api.nvim_create_augroup("avante_llm", { clear = true }) ---@param cb fun(title: string | nil): nil function M.summarize_chat_thread_title(content, cb) local system_prompt = - [[Summarize the content as a title for the chat thread. The title should be a concise and informative summary of the conversation, capturing the main points and key takeaways. It should be no longer than 100 words and should be written in a clear and engaging style. The title should be suitable for use as the title of a chat thread on a messaging platform or other communication medium.]] + [[Summarize the content as a title for the chat thread. The title should be a concise and informative summary of the conversation, capturing the main points and key takeaways. It should be no longer than 100 words and should be written in a clear and engaging style. The title should be suitable for use as the title of a chat thread on a messaging platform or other communication medium. /no_think]] local response_content = "" local provider = Providers.get_memory_summary_provider() M.curl({ @@ -761,7 +761,7 @@ function M._stream(opts) }, }) end - opts.on_messages_add(messages) + if opts.on_messages_add then opts.on_messages_add(messages) end local the_last_tool_use = tool_use_list[#tool_use_list] if the_last_tool_use and the_last_tool_use.name == "attempt_completion" then opts.on_stop({ reason = "complete" }) @@ -854,12 +854,8 @@ function M._stream(opts) if is_break then break end ::continue:: end - local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[] - for _, tool_use in vim.spairs(tool_use_list) do - table.insert(sorted_tool_use_list, tool_use) - end if stop_opts.reason == "complete" and Config.mode == "agentic" then - if #sorted_tool_use_list == 0 then + if #tool_use_list == 0 then local completed_attempt_completion_tool_use = nil for idx = #history_messages, 1, -1 do local message = history_messages[idx] @@ -896,7 +892,7 @@ function M._stream(opts) end end end - if stop_opts.reason == "tool_use" then return handle_next_tool_use(sorted_tool_use_list, 1, {}) end + if stop_opts.reason == "tool_use" then return handle_next_tool_use(tool_use_list, 1, {}) end if stop_opts.reason == "rate_limit" then local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*" if opts.on_chunk then opts.on_chunk("\n" .. msg_content .. "\n") end diff --git a/lua/avante/llm_tools/attempt_completion.lua b/lua/avante/llm_tools/attempt_completion.lua index dcab0cc..8a22eb5 100644 --- a/lua/avante/llm_tools/attempt_completion.lua +++ b/lua/avante/llm_tools/attempt_completion.lua @@ -32,6 +32,10 @@ M.param = { optional = true, }, }, + usage = { + result = "The result of the task. Formulate this result in a way that is final and does not require further input from the user. Don't end your result with questions or offers for further assistance.", + command = "A CLI command to execute to show a live demo of the result to the user. For example, use `open index.html` to display a created html website, or `open localhost:3000` to display a locally running development server. But DO NOT use commands like `echo` or `cat` that merely print text. This command should be valid for the current operating system. Ensure the command is properly formatted and does not contain any harmful instructions.", + }, } ---@type AvanteLLMToolReturn[] diff --git a/lua/avante/llm_tools/bash.lua b/lua/avante/llm_tools/bash.lua index b51e871..5f52261 100644 --- a/lua/avante/llm_tools/bash.lua +++ b/lua/avante/llm_tools/bash.lua @@ -183,7 +183,7 @@ M.param = { type = "table", fields = { { - name = "rel_path", + name = "path", description = "Relative path to the project directory, as cwd", type = "string", }, @@ -193,6 +193,10 @@ M.param = { type = "string", }, }, + usage = { + path = "Relative path to the project directory, as cwd", + command = "Command to run", + }, } ---@type AvanteLLMToolReturn[] @@ -210,9 +214,9 @@ M.returns = { }, } ----@type AvanteLLMToolFunc<{ rel_path: string, command: string }> +---@type AvanteLLMToolFunc<{ path: string, command: string }> function M.func(opts, on_log, on_complete, session_ctx) - local abs_path = Helpers.get_abs_path(opts.rel_path) + local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end if on_log then on_log("command: " .. opts.command) end diff --git a/lua/avante/llm_tools/dispatch_agent.lua b/lua/avante/llm_tools/dispatch_agent.lua index 465e9ad..fd17016 100644 --- a/lua/avante/llm_tools/dispatch_agent.lua +++ b/lua/avante/llm_tools/dispatch_agent.lua @@ -49,6 +49,9 @@ M.param = { }, }, required = { "prompt" }, + usage = { + prompt = "The task for the agent to perform", + }, } ---@type AvanteLLMToolReturn[] diff --git a/lua/avante/llm_tools/get_diagnostics.lua b/lua/avante/llm_tools/get_diagnostics.lua index 4d5d6de..2a217b5 100644 --- a/lua/avante/llm_tools/get_diagnostics.lua +++ b/lua/avante/llm_tools/get_diagnostics.lua @@ -19,6 +19,9 @@ M.param = { type = "string", }, }, + usage = { + path = "The path to the file in the current project scope", + }, } ---@type AvanteLLMToolReturn[] diff --git a/lua/avante/llm_tools/glob.lua b/lua/avante/llm_tools/glob.lua index 2e8f6b2..5e00478 100644 --- a/lua/avante/llm_tools/glob.lua +++ b/lua/avante/llm_tools/glob.lua @@ -18,11 +18,15 @@ M.param = { type = "string", }, { - name = "rel_path", + name = "path", description = "Relative path to the project directory, as cwd", type = "string", }, }, + usage = { + pattern = "Glob pattern", + path = "Relative path to the project directory, as cwd", + }, } ---@type AvanteLLMToolReturn[] @@ -40,9 +44,9 @@ M.returns = { }, } ----@type AvanteLLMToolFunc<{ rel_path: string, pattern: string }> +---@type AvanteLLMToolFunc<{ path: string, pattern: string }> function M.func(opts, on_log, on_complete, session_ctx) - local abs_path = Helpers.get_abs_path(opts.rel_path) + local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end if on_log then on_log("path: " .. abs_path) end if on_log then on_log("pattern: " .. opts.pattern) end diff --git a/lua/avante/llm_tools/grep.lua b/lua/avante/llm_tools/grep.lua index 937747f..c247c5a 100644 --- a/lua/avante/llm_tools/grep.lua +++ b/lua/avante/llm_tools/grep.lua @@ -15,7 +15,7 @@ M.param = { type = "table", fields = { { - name = "rel_path", + name = "path", description = "Relative path to the project directory", type = "string", }, @@ -44,6 +44,13 @@ M.param = { optional = true, }, }, + usage = { + path = "Relative path to the project directory", + query = "Query to search for", + case_sensitive = "Whether to search case sensitively", + include_pattern = "Glob pattern to include files", + exclude_pattern = "Glob pattern to exclude files", + }, } ---@type AvanteLLMToolReturn[] @@ -61,9 +68,9 @@ M.returns = { }, } ----@type AvanteLLMToolFunc<{ rel_path: string, query: string, case_sensitive?: boolean, include_pattern?: string, exclude_pattern?: string }> +---@type AvanteLLMToolFunc<{ path: string, query: string, case_sensitive?: boolean, include_pattern?: string, exclude_pattern?: string }> function M.func(opts, on_log) - local abs_path = Helpers.get_abs_path(opts.rel_path) + local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return "", "No such file or directory: " .. abs_path end diff --git a/lua/avante/llm_tools/init.lua b/lua/avante/llm_tools/init.lua index d39e350..7a90eb5 100644 --- a/lua/avante/llm_tools/init.lua +++ b/lua/avante/llm_tools/init.lua @@ -7,10 +7,10 @@ local Helpers = require("avante.llm_tools.helpers") local M = {} ----@type AvanteLLMToolFunc<{ rel_path: string }> +---@type AvanteLLMToolFunc<{ path: string }> function M.read_file_toplevel_symbols(opts, on_log) local RepoMap = require("avante.repo_map") - local abs_path = Helpers.get_abs_path(opts.rel_path) + local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end if on_log then on_log("path: " .. abs_path) end if not Path:new(abs_path):exists() then return "", "File does not exists: " .. abs_path end @@ -121,13 +121,13 @@ function M.write_global_file(opts, on_log, on_complete) end) end ----@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }> +---@type AvanteLLMToolFunc<{ path: string, new_path: string }> function M.rename_file(opts, on_log, on_complete) - local abs_path = Helpers.get_abs_path(opts.rel_path) + local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end - local new_abs_path = Helpers.get_abs_path(opts.new_rel_path) + local new_abs_path = Helpers.get_abs_path(opts.new_path) if on_log then on_log(abs_path .. " -> " .. new_abs_path) end if not Helpers.has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path @@ -147,13 +147,13 @@ function M.rename_file(opts, on_log, on_complete) ) end ----@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }> +---@type AvanteLLMToolFunc<{ path: string, new_path: string }> function M.copy_file(opts, on_log) - local abs_path = Helpers.get_abs_path(opts.rel_path) + local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end - local new_abs_path = Helpers.get_abs_path(opts.new_rel_path) + local new_abs_path = Helpers.get_abs_path(opts.new_path) if not Helpers.has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end @@ -163,9 +163,9 @@ function M.copy_file(opts, on_log) return true, nil end ----@type AvanteLLMToolFunc<{ rel_path: string }> +---@type AvanteLLMToolFunc<{ path: string }> function M.delete_file(opts, on_log, on_complete) - local abs_path = Helpers.get_abs_path(opts.rel_path) + local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end @@ -181,9 +181,9 @@ function M.delete_file(opts, on_log, on_complete) end) end ----@type AvanteLLMToolFunc<{ rel_path: string }> +---@type AvanteLLMToolFunc<{ path: string }> function M.create_dir(opts, on_log, on_complete) - local abs_path = Helpers.get_abs_path(opts.rel_path) + local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end if not on_complete then return false, "on_complete not provided" end @@ -198,13 +198,13 @@ function M.create_dir(opts, on_log, on_complete) end) end ----@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }> +---@type AvanteLLMToolFunc<{ path: string, new_path: string }> function M.rename_dir(opts, on_log, on_complete) - local abs_path = Helpers.get_abs_path(opts.rel_path) + local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end - local new_abs_path = Helpers.get_abs_path(opts.new_rel_path) + local new_abs_path = Helpers.get_abs_path(opts.new_path) if not Helpers.has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end @@ -224,9 +224,9 @@ function M.rename_dir(opts, on_log, on_complete) ) end ----@type AvanteLLMToolFunc<{ rel_path: string }> +---@type AvanteLLMToolFunc<{ path: string }> function M.delete_dir(opts, on_log, on_complete) - local abs_path = Helpers.get_abs_path(opts.rel_path) + local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end @@ -552,9 +552,9 @@ function M.rag_search(opts, on_log, on_complete) ) end ----@type AvanteLLMToolFunc<{ code: string, rel_path: string, container_image?: string }> +---@type AvanteLLMToolFunc<{ code: string, path: string, container_image?: string }> function M.python(opts, on_log, on_complete) - local abs_path = Helpers.get_abs_path(opts.rel_path) + local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return nil, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return nil, "Path not found: " .. abs_path end if on_log then on_log("cwd: " .. abs_path) end @@ -634,6 +634,18 @@ function M.get_tools(user_input, history_messages) :totable() end +function M.get_tool_names() + local custom_tools = Config.custom_tools + if type(custom_tools) == "function" then custom_tools = custom_tools() end + ---@type AvanteLLMTool[] + local unfiltered_tools = vim.list_extend(vim.list_extend({}, M._tools), custom_tools) + local tool_names = {} + for _, tool in ipairs(unfiltered_tools) do + table.insert(tool_names, tool.name) + end + return tool_names +end + ---@type AvanteLLMTool[] M._tools = { require("avante.llm_tools.replace_in_file"), @@ -652,6 +664,9 @@ M._tools = { type = "string", }, }, + usage = { + query = "Query to search", + }, }, returns = { { @@ -679,11 +694,15 @@ M._tools = { type = "string", }, { - name = "rel_path", + name = "path", description = "Relative path to the project directory, as cwd", type = "string", }, }, + usage = { + code = "Python code to run", + path = "Relative path to the project directory, as cwd", + }, }, returns = { { @@ -711,6 +730,9 @@ M._tools = { type = "string", }, }, + usage = { + scope = "Scope for the git diff (e.g. specific files or directories)", + }, }, returns = { { @@ -744,6 +766,10 @@ M._tools = { optional = true, }, }, + usage = { + message = "Commit message to use", + scope = "Scope for staging files (e.g. specific files or directories)", + }, }, returns = { { @@ -768,11 +794,14 @@ M._tools = { type = "table", fields = { { - name = "rel_path", + name = "path", description = "Relative path to the file in current project scope", type = "string", }, }, + usage = { + path = "Relative path to the file in current project scope", + }, }, returns = { { @@ -790,7 +819,7 @@ M._tools = { }, require("avante.llm_tools.str_replace"), require("avante.llm_tools.view"), - require("avante.llm_tools.create"), + require("avante.llm_tools.write_to_file"), require("avante.llm_tools.insert"), require("avante.llm_tools.undo_edit"), { @@ -820,6 +849,9 @@ M._tools = { type = "string", }, }, + usage = { + abs_path = "Absolute path to the file in global scope", + }, }, returns = { { @@ -867,6 +899,10 @@ M._tools = { type = "string", }, }, + usage = { + abs_path = "The path to the file in the current project scope", + content = "The content to write to the file", + }, }, returns = { { @@ -889,16 +925,20 @@ M._tools = { type = "table", fields = { { - name = "rel_path", + name = "path", description = "Relative path to the file in current project scope", type = "string", }, { - name = "new_rel_path", + name = "new_path", description = "New relative path for the file", type = "string", }, }, + usage = { + path = "Relative path to the file in current project scope", + new_path = "New relative path for the file", + }, }, returns = { { @@ -921,11 +961,14 @@ M._tools = { type = "table", fields = { { - name = "rel_path", + name = "path", description = "Relative path to the file in current project scope", type = "string", }, }, + usage = { + path = "Relative path to the file in current project scope", + }, }, returns = { { @@ -948,11 +991,14 @@ M._tools = { type = "table", fields = { { - name = "rel_path", + name = "path", description = "Relative path to the project directory", type = "string", }, }, + usage = { + path = "Relative path to the project directory", + }, }, returns = { { @@ -975,16 +1021,20 @@ M._tools = { type = "table", fields = { { - name = "rel_path", + name = "path", description = "Relative path to the project directory", type = "string", }, { - name = "new_rel_path", + name = "new_path", description = "New relative path for the directory", type = "string", }, }, + usage = { + path = "Relative path to the project directory", + new_path = "New relative path for the directory", + }, }, returns = { { @@ -1007,11 +1057,14 @@ M._tools = { type = "table", fields = { { - name = "rel_path", + name = "path", description = "Relative path to the project directory", type = "string", }, }, + usage = { + path = "Relative path to the project directory", + }, }, returns = { { @@ -1042,6 +1095,9 @@ M._tools = { type = "string", }, }, + usage = { + query = "Query to search", + }, }, returns = { { @@ -1069,6 +1125,9 @@ M._tools = { type = "string", }, }, + usage = { + url = "Url to fetch markdown from", + }, }, returns = { { diff --git a/lua/avante/llm_tools/insert.lua b/lua/avante/llm_tools/insert.lua index 60fc3c2..1247a86 100644 --- a/lua/avante/llm_tools/insert.lua +++ b/lua/avante/llm_tools/insert.lua @@ -32,6 +32,11 @@ M.param = { type = "string", }, }, + usage = { + path = "The path to the file to modify", + insert_line = "The line number after which to insert the text (0 for beginning of file)", + new_str = "The text to insert", + }, } ---@type AvanteLLMToolReturn[] diff --git a/lua/avante/llm_tools/ls.lua b/lua/avante/llm_tools/ls.lua index 2d19e7e..8d6c6af 100644 --- a/lua/avante/llm_tools/ls.lua +++ b/lua/avante/llm_tools/ls.lua @@ -14,7 +14,7 @@ M.param = { type = "table", fields = { { - name = "rel_path", + name = "path", description = "Relative path to the project directory", type = "string", }, @@ -24,6 +24,10 @@ M.param = { type = "integer", }, }, + usage = { + path = "Relative path to the project directory", + max_depth = "Maximum depth of the directory", + }, } ---@type AvanteLLMToolReturn[] @@ -41,9 +45,9 @@ M.returns = { }, } ----@type AvanteLLMToolFunc<{ rel_path: string, max_depth?: integer }> +---@type AvanteLLMToolFunc<{ path: string, max_depth?: integer }> function M.func(opts, on_log) - local abs_path = Helpers.get_abs_path(opts.rel_path) + local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end if on_log then on_log("path: " .. abs_path) end if on_log then on_log("max depth: " .. tostring(opts.max_depth)) end diff --git a/lua/avante/llm_tools/replace_in_file.lua b/lua/avante/llm_tools/replace_in_file.lua index 01c4d4b..6d3ca39 100644 --- a/lua/avante/llm_tools/replace_in_file.lua +++ b/lua/avante/llm_tools/replace_in_file.lua @@ -16,6 +16,8 @@ M.name = "replace_in_file" M.description = "Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file. This tool should be used when you need to make targeted changes to specific parts of a file." +-- function M.enabled() return Config.provider:match("ollama") == nil end + ---@type AvanteLLMToolParam M.param = { type = "table", @@ -57,6 +59,10 @@ One or more SEARCH/REPLACE blocks following this exact format: type = "string", }, }, + usage = { + path = "File path here", + diff = "Search and replace blocks here", + }, } ---@type AvanteLLMToolReturn[] diff --git a/lua/avante/llm_tools/str_replace.lua b/lua/avante/llm_tools/str_replace.lua index 495e0a1..6cced50 100644 --- a/lua/avante/llm_tools/str_replace.lua +++ b/lua/avante/llm_tools/str_replace.lua @@ -1,8 +1,4 @@ -local Path = require("plenary.path") -local Utils = require("avante.utils") local Base = require("avante.llm_tools.base") -local Helpers = require("avante.llm_tools.helpers") -local Diff = require("avante.diff") local Config = require("avante.config") ---@class AvanteLLMTool @@ -13,6 +9,7 @@ M.name = "str_replace" M.description = "The str_replace tool allows you to replace a specific string in a file with a new string. This is used for making precise edits." +-- function M.enabled() return Config.provider:match("ollama") ~= nil end function M.enabled() return false end ---@type AvanteLLMToolParam @@ -35,6 +32,11 @@ M.param = { type = "string", }, }, + usage = { + path = "File path here", + old_str = "old str here", + new_str = "new str here", + }, } ---@type AvanteLLMToolReturn[] diff --git a/lua/avante/llm_tools/undo_edit.lua b/lua/avante/llm_tools/undo_edit.lua index 919b862..b4f2ed5 100644 --- a/lua/avante/llm_tools/undo_edit.lua +++ b/lua/avante/llm_tools/undo_edit.lua @@ -22,6 +22,9 @@ M.param = { type = "string", }, }, + usage = { + path = "The path to the file whose last edit should be undone", + }, } ---@type AvanteLLMToolReturn[] diff --git a/lua/avante/llm_tools/view.lua b/lua/avante/llm_tools/view.lua index 28d2e21..4d134d5 100644 --- a/lua/avante/llm_tools/view.lua +++ b/lua/avante/llm_tools/view.lua @@ -39,23 +39,22 @@ M.param = { type = "string", }, { - name = "view_range", - description = "The range of the file to view. This parameter only applies when viewing files, not directories.", - type = "object", + name = "start_line", + description = "The start line of the view range, 1-indexed", + type = "integer", optional = true, - fields = { - { - name = "start_line", - description = "The start line of the range, 1-indexed", - type = "integer", - }, - { - name = "end_line", - description = "The end line of the range, 1-indexed, and -1 for the end line means read to the end of the file", - type = "integer", - }, - }, }, + { + name = "end_line", + description = "The end line of the view range, 1-indexed, and -1 for the end line means read to the end of the file", + type = "integer", + optional = true, + }, + }, + usage = { + path = "The path to the file in the current project scope", + start_line = "The start line of the view range, 1-indexed", + end_line = "The end line of the view range, 1-indexed, and -1 for the end line means read to the end of the file", }, } @@ -74,7 +73,7 @@ M.returns = { }, } ----@type AvanteLLMToolFunc<{ path: string, view_range?: { start_line: integer, end_line: integer } }> +---@type AvanteLLMToolFunc<{ path: string, start_line?: integer, end_line?: integer }> function M.func(opts, on_log, on_complete, session_ctx) if on_log then on_log("path: " .. opts.path) end local abs_path = Helpers.get_abs_path(opts.path) @@ -84,11 +83,9 @@ function M.func(opts, on_log, on_complete, session_ctx) local file = io.open(abs_path, "r") if not file then return false, "file not found: " .. abs_path end local lines = Utils.read_file_from_buf_or_disk(abs_path) - if opts.view_range then - local start_line = opts.view_range.start_line - local end_line = opts.view_range.end_line - if start_line and end_line and lines then lines = vim.list_slice(lines, start_line, end_line) end - end + local start_line = opts.start_line + local end_line = opts.end_line + if start_line and end_line and lines then lines = vim.list_slice(lines, start_line, end_line) end local truncated_lines = {} local is_truncated = false local size = 0 diff --git a/lua/avante/llm_tools/write_to_file.lua b/lua/avante/llm_tools/write_to_file.lua new file mode 100644 index 0000000..08041da --- /dev/null +++ b/lua/avante/llm_tools/write_to_file.lua @@ -0,0 +1,74 @@ +local Utils = require("avante.utils") +local Base = require("avante.llm_tools.base") +local Helpers = require("avante.llm_tools.helpers") + +---@class AvanteLLMTool +local M = setmetatable({}, Base) + +M.name = "write_to_file" + +M.description = + "Request to write content to a file at the specified path. If the file exists, it will be overwritten with the provided content. If the file doesn't exist, it will be created. This tool will automatically create any directories needed to write the file." + +function M.enabled() return require("avante.config").mode == "agentic" end + +---@type AvanteLLMToolParam +M.param = { + type = "table", + fields = { + { + name = "path", + get_description = function() + local res = ("The path of the file to write to (relative to the current working directory {{cwd}})"):gsub( + "{{cwd}}", + Utils.get_project_root() + ) + return res + end, + type = "string", + }, + { + name = "content", + description = "The content to write to the file. ALWAYS provide the COMPLETE intended content of the file, without any truncation or omissions. You MUST include ALL parts of the file, even if they haven't been modified.", + type = "string", + }, + }, + usage = { + path = "File path here", + content = "File content here", + }, +} + +---@type AvanteLLMToolReturn[] +M.returns = { + { + name = "success", + description = "Whether the file was created successfully", + type = "boolean", + }, + { + name = "error", + description = "Error message if the file was not created successfully", + type = "string", + optional = true, + }, +} + +---@type AvanteLLMToolFunc<{ path: string, content: string }> +function M.func(opts, on_log, on_complete, session_ctx) + if not on_complete then return false, "on_complete not provided" end + local abs_path = Helpers.get_abs_path(opts.path) + if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end + if opts.content == nil then return false, "content not provided" end + local old_lines = Utils.read_file_from_buf_or_disk(abs_path) + local old_content = table.concat(old_lines or {}, "\n") + local replace_in_file = require("avante.llm_tools.replace_in_file") + local diff = "<<<<<<< SEARCH\n" .. old_content .. "\n=======\n" .. opts.content .. "\n>>>>>>> REPLACE" + local new_opts = { + path = opts.path, + diff = diff, + } + return replace_in_file.func(new_opts, on_log, on_complete, session_ctx) +end + +return M diff --git a/lua/avante/providers/ollama.lua b/lua/avante/providers/ollama.lua index ada016a..a3569fb 100644 --- a/lua/avante/providers/ollama.lua +++ b/lua/avante/providers/ollama.lua @@ -1,10 +1,15 @@ local Utils = require("avante.utils") -local P = require("avante.providers") +local Providers = require("avante.providers") local Config = require("avante.config") +local Clipboard = require("avante.clipboard") +local HistoryMessage = require("avante.history_message") +local Prompts = require("avante.utils.prompts") ---@class AvanteProviderFunctor local M = {} +setmetatable(M, { __index = Providers.openai }) + M.api_key_name = "" -- Ollama typically doesn't require API keys for local use M.role_map = { @@ -12,35 +17,182 @@ M.role_map = { assistant = "assistant", } -M.parse_messages = P.openai.parse_messages -M.is_reasoning_model = P.openai.is_reasoning_model +function M:parse_messages(opts) + local messages = {} + local provider_conf, _ = Providers.parse_config(self) + + local system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts) + + if self.is_reasoning_model(provider_conf.model) then + table.insert(messages, { role = "developer", content = system_prompt }) + else + table.insert(messages, { role = "system", content = system_prompt }) + end + + vim.iter(opts.messages):each(function(msg) + if type(msg.content) == "string" then + table.insert(messages, { role = self.role_map[msg.role], content = msg.content }) + elseif type(msg.content) == "table" then + local content = {} + for _, item in ipairs(msg.content) do + if type(item) == "string" then + table.insert(content, { type = "text", text = item }) + elseif item.type == "text" then + table.insert(content, { type = "text", text = item.text }) + elseif item.type == "image" then + table.insert(content, { + type = "image_url", + image_url = { + url = "data:" .. item.source.media_type .. ";" .. item.source.type .. "," .. item.source.data, + }, + }) + end + end + if not provider_conf.disable_tools then + if msg.content[1].type == "tool_result" then + local tool_use = nil + for _, msg_ in ipairs(opts.messages) do + if type(msg_.content) == "table" and #msg_.content > 0 then + if msg_.content[1].type == "tool_use" and msg_.content[1].id == msg.content[1].tool_use_id then + tool_use = msg_ + break + end + end + end + if tool_use then + msg.role = "user" + table.insert(content, { + type = "text", + text = "[" + .. tool_use.content[1].name + .. " for '" + .. (tool_use.content[1].input.path or tool_use.content[1].input.rel_path or "") + .. "'] Result:", + }) + table.insert(content, { + type = "text", + text = msg.content[1].content, + }) + end + end + end + if #content > 0 then + local text_content = {} + for _, item in ipairs(content) do + if type(item) == "table" and item.type == "text" then table.insert(text_content, item.text) end + end + table.insert(messages, { role = self.role_map[msg.role], content = table.concat(text_content, "\n\n") }) + end + end + end) + + if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then + local message_content = messages[#messages].content + if type(message_content) ~= "table" or message_content[1] == nil then + message_content = { { type = "text", text = message_content } } + end + for _, image_path in ipairs(opts.image_paths) do + table.insert(message_content, { + type = "image_url", + image_url = { + url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path), + }, + }) + end + messages[#messages].content = message_content + end + + local final_messages = {} + local prev_role = nil + + vim.iter(messages):each(function(message) + local role = message.role + if role == prev_role and role ~= "tool" then + if role == self.role_map["assistant"] then + table.insert(final_messages, { role = self.role_map["user"], content = "Ok" }) + else + table.insert(final_messages, { role = self.role_map["assistant"], content = "Ok, I understand." }) + end + end + prev_role = role + table.insert(final_messages, message) + end) + + return final_messages +end function M:is_disable_stream() return false end +---@class avante.OllamaFunction +---@field name string +---@field arguments table + +---@class avante.OllamaToolCall +---@field function avante.OllamaFunction + +---@param tool_calls avante.OllamaToolCall[] +---@param opts AvanteLLMStreamOptions +function M:add_tool_use_messages(tool_calls, opts) + local msgs = {} + for _, tool_call in ipairs(tool_calls) do + local id = Utils.uuid() + local func = tool_call["function"] + local msg = HistoryMessage:new({ + role = "assistant", + content = { + { + type = "tool_use", + name = func.name, + id = id, + input = func.arguments, + }, + }, + }, { + state = "generated", + uuid = id, + }) + table.insert(msgs, msg) + end + if opts.on_messages_add then opts.on_messages_add(msgs) end +end + function M:parse_stream_data(ctx, data, opts) - local ok, json_data = pcall(vim.json.decode, data) - if not ok or not json_data then + local ok, jsn = pcall(vim.json.decode, data) + if not ok or not jsn then -- Add debug logging Utils.debug("Failed to parse JSON", data) return end - if json_data.message and json_data.message.content then - local content = json_data.message.content - P.openai:add_text_message(ctx, content, "generating", opts) - if content and content ~= "" and opts.on_chunk then opts.on_chunk(content) end + if jsn.message then + if jsn.message.content then + local content = jsn.message.content + if content and content ~= "" then + Providers.openai:add_text_message(ctx, content, "generating", opts) + if opts.on_chunk then opts.on_chunk(content) end + end + end + if jsn.message.tool_calls then + ctx.has_tool_use = true + local tool_calls = jsn.message.tool_calls + self:add_tool_use_messages(tool_calls, opts) + end end - if json_data.done then - P.openai:finish_pending_messages(ctx, opts) - opts.on_stop({ reason = "complete" }) + if jsn.done then + Providers.openai:finish_pending_messages(ctx, opts) + if ctx.has_tool_use or (ctx.tool_use_list and #ctx.tool_use_list > 0) then + opts.on_stop({ reason = "tool_use" }) + else + opts.on_stop({ reason = "complete" }) + end return end end ---@param prompt_opts AvantePromptOptions function M:parse_curl_args(prompt_opts) - local provider_conf, request_body = P.parse_config(self) + local provider_conf, request_body = Providers.parse_config(self) local keep_alive = provider_conf.keep_alive or "5m" if not provider_conf.model or provider_conf.model == "" then error("Ollama model must be specified in config") end @@ -50,7 +202,7 @@ function M:parse_curl_args(prompt_opts) ["Accept"] = "application/json", } - if P.env.require_api_key(provider_conf) then + if Providers.env.require_api_key(provider_conf) then local api_key = self.parse_api_key() if api_key and api_key ~= "" then headers["Authorization"] = "Bearer " .. api_key @@ -66,7 +218,6 @@ function M:parse_curl_args(prompt_opts) model = provider_conf.model, messages = self:parse_messages(prompt_opts), stream = true, - system = prompt_opts.system_prompt, keep_alive = keep_alive, }, request_body), } diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 9bf213c..a89fd8a 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -3,6 +3,9 @@ local Config = require("avante.config") local Clipboard = require("avante.clipboard") local Providers = require("avante.providers") local HistoryMessage = require("avante.history_message") +local XMLParser = require("avante.libs.xmlparser") +local Prompts = require("avante.utils.prompts") +local LlmTools = require("avante.llm_tools") ---@class AvanteProviderFunctor local M = {} @@ -76,10 +79,15 @@ function M:parse_messages(opts) local messages = {} local provider_conf, _ = Providers.parse_config(self) + local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true + local system_prompt = opts.system_prompt + + if use_ReAct_prompt then system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts) end + if self.is_reasoning_model(provider_conf.model) then - table.insert(messages, { role = "developer", content = opts.system_prompt }) + table.insert(messages, { role = "developer", content = system_prompt }) else - table.insert(messages, { role = "system", content = opts.system_prompt }) + table.insert(messages, { role = "system", content = system_prompt }) end local has_tool_use = false @@ -103,22 +111,50 @@ function M:parse_messages(opts) url = "data:" .. item.source.media_type .. ";" .. item.source.type .. "," .. item.source.data, }, }) - elseif item.type == "tool_use" then + elseif item.type == "tool_use" and not use_ReAct_prompt then has_tool_use = true table.insert(tool_calls, { id = item.id, type = "function", ["function"] = { name = item.name, arguments = vim.json.encode(item.input) }, }) - elseif item.type == "tool_result" and has_tool_use then + elseif item.type == "tool_result" and has_tool_use and not use_ReAct_prompt then table.insert( tool_results, { tool_call_id = item.tool_use_id, content = item.is_error and "Error: " .. item.content or item.content } ) end end + if not provider_conf.disable_tools and use_ReAct_prompt then + if msg.content[1].type == "tool_result" then + local tool_use = nil + for _, msg_ in ipairs(opts.messages) do + if type(msg_.content) == "table" and #msg_.content > 0 then + if msg_.content[1].type == "tool_use" and msg_.content[1].id == msg.content[1].tool_use_id then + tool_use = msg_ + break + end + end + end + if tool_use then + msg.role = "user" + table.insert(content, { + type = "text", + text = "[" + .. tool_use.content[1].name + .. " for '" + .. (tool_use.content[1].input.path or tool_use.content[1].input.rel_path or "") + .. "'] Result:", + }) + table.insert(content, { + type = "text", + text = msg.content[1].content, + }) + end + end + end if #content > 0 then table.insert(messages, { role = self.role_map[msg.role], content = content }) end - if not provider_conf.disable_tools then + if not provider_conf.disable_tools and not use_ReAct_prompt then if #tool_calls > 0 then local last_message = messages[#messages] if last_message and last_message.role == self.role_map["assistant"] and last_message.tool_calls then @@ -183,7 +219,10 @@ function M:finish_pending_messages(ctx, opts) end end +local llm_tool_names = nil + function M:add_text_message(ctx, text, state, opts) + if llm_tool_names == nil then llm_tool_names = LlmTools.get_tool_names() end if ctx.content == nil then ctx.content = "" end ctx.content = ctx.content .. text local msg = HistoryMessage:new({ @@ -194,7 +233,75 @@ function M:add_text_message(ctx, text, state, opts) uuid = ctx.content_uuid, }) ctx.content_uuid = msg.uuid - if opts.on_messages_add then opts.on_messages_add({ msg }) end + local msgs = { msg } + local stream_parser = XMLParser.createStreamParser() + stream_parser:addData(ctx.content) + local xml = stream_parser:getAllElements() + if xml then + local new_content_list = {} + local xml_md_openned = false + for idx, item in ipairs(xml) do + if item._name == "_text" then + local cleaned_lines = {} + local lines = vim.split(item._text, "\n") + for _, line in ipairs(lines) do + if line:match("^```xml") or line:match("^```tool_code") or line:match("^```tool_use") then + xml_md_openned = true + elseif line:match("^```$") then + if xml_md_openned then + xml_md_openned = false + else + table.insert(cleaned_lines, line) + end + else + table.insert(cleaned_lines, line) + end + end + table.insert(new_content_list, table.concat(cleaned_lines, "\n")) + goto continue + end + if not vim.tbl_contains(llm_tool_names, item._name) then goto continue end + local ok, input = pcall(vim.json.decode, item._text) + if not ok and item.children and #item.children > 0 then + input = {} + for _, item_ in ipairs(item.children) do + local ok_, input_ = pcall(vim.json.decode, item_._text) + if ok_ and input_ then + input[item_._name] = input_ + else + input[item_._name] = item_._text + end + end + end + if input then + local tool_use_id = Utils.uuid() + local msg_ = HistoryMessage:new({ + role = "assistant", + content = { + { + type = "tool_use", + name = item._name, + id = tool_use_id, + input = input, + }, + }, + }, { + state = state, + uuid = ctx.content_uuid .. "-" .. idx, + }) + msgs[#msgs + 1] = msg_ + ctx.tool_use_list = ctx.tool_use_list or {} + ctx.tool_use_list[#ctx.tool_use_list + 1] = { + id = tool_use_id, + name = item._name, + input_json = input, + } + end + if #new_content_list > 0 then msg.message.content = table.concat(new_content_list, "\n") end + ::continue:: + end + end + if opts.on_messages_add then opts.on_messages_add(msgs) end end function M:add_thinking_message(ctx, text, state, opts) @@ -242,7 +349,11 @@ end function M:parse_response(ctx, data_stream, _, opts) if data_stream:match('"%[DONE%]":') then self:finish_pending_messages(ctx, opts) - opts.on_stop({ reason = "complete" }) + if ctx.tool_use_list and #ctx.tool_use_list > 0 then + opts.on_stop({ reason = "tool_use" }) + else + opts.on_stop({ reason = "complete" }) + end return end if data_stream == "[DONE]" then return end @@ -316,9 +427,13 @@ function M:parse_response(ctx, data_stream, _, opts) self:add_text_message(ctx, delta.content, "generating", opts) end end - if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then + if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" or choice.finish_reason == "length" then self:finish_pending_messages(ctx, opts) - opts.on_stop({ reason = "complete" }) + if ctx.tool_use_list and #ctx.tool_use_list > 0 then + opts.on_stop({ reason = "tool_use", usage = jsn.usage }) + else + opts.on_stop({ reason = "complete", usage = jsn.usage }) + end end if choice.finish_reason == "tool_calls" then self:finish_pending_messages(ctx, opts) @@ -372,8 +487,10 @@ function M:parse_curl_args(prompt_opts) self.set_allowed_params(provider_conf, request_body) + local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true + local tools = nil - if not disable_tools and prompt_opts.tools then + if not disable_tools and prompt_opts.tools and not use_ReAct_prompt then tools = {} for _, tool in ipairs(prompt_opts.tools) do table.insert(tools, self:transform_tool(tool)) diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 474f4fe..c5c88fd 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -56,6 +56,7 @@ Sidebar.__index = Sidebar ---@field current_state avante.GenerateState | nil ---@field state_timer table | nil ---@field state_spinner_chars string[] +---@field thinking_spinner_chars string[] ---@field state_spinner_idx integer ---@field state_extmark_id integer | nil ---@field scroll boolean @@ -84,12 +85,16 @@ function Sidebar:new(id) current_state = nil, state_timer = nil, state_spinner_chars = { "·", "✢", "✳", "∗", "✻", "✽" }, + thinking_spinner_chars = { "🤯", "🙄" }, state_spinner_idx = 1, state_extmark_id = nil, scroll = true, input_hint_window = nil, ask_opts = {}, old_result_lines = {}, + -- 缓存相关字段 + _cached_history_lines = nil, + _history_cache_invalidated = true, }, Sidebar) end @@ -1528,42 +1533,75 @@ end ---@param opts? {focus?: boolean, scroll?: boolean, backspace?: integer, callback?: fun(): nil} whether to focus the result view function Sidebar:update_content(content, opts) if not self.result_container or not self.result_container.bufnr then return end + + -- 提前验证容器有效性,避免后续无效操作 + if not Utils.is_valid_container(self.result_container) then return end + opts = vim.tbl_deep_extend("force", { focus = false, scroll = self.scroll, callback = nil }, opts or {}) - local history_lines = self.get_history_lines(self.chat_history) - if content ~= nil and content ~= "" then - table.insert(history_lines, Line:new({ { "" } })) - local content_lines = vim.split(content, "\n") - for _, line in ipairs(content_lines) do - table.insert(history_lines, Line:new({ { line } })) - end + + -- 缓存历史行,避免重复计算 + local history_lines + if not self._cached_history_lines or self._history_cache_invalidated then + history_lines = self.get_history_lines(self.chat_history) + self._cached_history_lines = history_lines + self._history_cache_invalidated = false + else + history_lines = vim.deepcopy(self._cached_history_lines) end - vim.defer_fn(function() - self:clear_state() - local f = function() - if not Utils.is_valid_container(self.result_container) then return end - Utils.unlock_buf(self.result_container.bufnr) - Utils.update_buffer_lines( - RESULT_BUF_HL_NAMESPACE, - self.result_container.bufnr, - self.old_result_lines, - history_lines - ) - Utils.lock_buf(self.result_container.bufnr) - self.old_result_lines = history_lines - api.nvim_set_option_value("filetype", "Avante", { buf = self.result_container.bufnr }) - vim.schedule(function() vim.cmd("redraw") end) - if opts.focus and not self:is_focused_on_result() then - --- set cursor to bottom of result view - xpcall(function() api.nvim_set_current_win(self.result_container.winid) end, function(err) return err end) - end - if opts.scroll then Utils.buf_scroll_to_end(self.result_container.bufnr) end + -- 批量处理内容行,减少表操作 + if content ~= nil and content ~= "" then + local content_lines = vim.split(content, "\n") + local new_lines = { Line:new({ { "" } }) } - if opts.callback ~= nil then opts.callback() end + -- 预分配表大小,提升性能 + for i = 1, #content_lines do + new_lines[i + 1] = Line:new({ { content_lines[i] } }) end - f() + + -- 一次性扩展,而不是逐个插入 + vim.list_extend(history_lines, new_lines) + end + + -- 使用 vim.schedule 而不是 vim.defer_fn(0),性能更好 + -- 再次检查容器有效性 + if not Utils.is_valid_container(self.result_container) then return end + + self:clear_state() + + -- 批量更新操作 + local bufnr = self.result_container.bufnr + Utils.unlock_buf(bufnr) + + Utils.update_buffer_lines(RESULT_BUF_HL_NAMESPACE, bufnr, self.old_result_lines, history_lines) + + -- 缓存结果行 + self.old_result_lines = history_lines + + -- 批量设置选项 + api.nvim_set_option_value("filetype", "Avante", { buf = bufnr }) + Utils.lock_buf(bufnr) + + -- 处理焦点和滚动 + if opts.focus and not self:is_focused_on_result() then + xpcall(function() api.nvim_set_current_win(self.result_container.winid) end, function(err) + Utils.debug("Failed to set current win:", err) + return err + end) + end + + if opts.scroll then Utils.buf_scroll_to_end(bufnr) end + + -- 延迟执行回调和状态渲染 + if opts.callback then vim.schedule(opts.callback) end + + -- 最后渲染状态 + vim.schedule(function() self:render_state() - end, 0) + -- 延迟重绘,避免阻塞 + vim.defer_fn(function() vim.cmd("redraw") end, 10) + end) + return self end @@ -1842,8 +1880,8 @@ function Sidebar:render_state() if self.state_extmark_id then api.nvim_buf_del_extmark(self.result_container.bufnr, STATE_NAMESPACE, self.state_extmark_id) end - local spinner_char = self.state_spinner_chars[self.state_spinner_idx] - self.state_spinner_idx = (self.state_spinner_idx % #self.state_spinner_chars) + 1 + local spinner_chars = self.state_spinner_chars + if self.current_state == "thinking" then spinner_chars = self.thinking_spinner_chars end local hl = "AvanteStateSpinnerGenerating" if self.current_state == "tool calling" then hl = "AvanteStateSpinnerToolCalling" end if self.current_state == "failed" then hl = "AvanteStateSpinnerFailed" end @@ -1851,6 +1889,8 @@ function Sidebar:render_state() if self.current_state == "searching" then hl = "AvanteStateSpinnerSearching" end if self.current_state == "thinking" then hl = "AvanteStateSpinnerThinking" end if self.current_state == "compacting" then hl = "AvanteStateSpinnerCompacting" end + local spinner_char = spinner_chars[self.state_spinner_idx] + self.state_spinner_idx = (self.state_spinner_idx % #spinner_chars) + 1 if self.current_state ~= "generating" and self.current_state ~= "tool calling" @@ -1911,6 +1951,10 @@ function Sidebar:new_chat(args, cb) if cb then cb(args) end end +local _save_history = Utils.debounce(function(self) Path.history.save(self.code.bufnr, self.chat_history) end, 3000) + +local save_history = vim.schedule_wrap(_save_history) + ---@param messages avante.HistoryMessage | avante.HistoryMessage[] function Sidebar:add_history_messages(messages) local history_messages = Utils.get_history_messages(self.chat_history) @@ -1934,22 +1978,30 @@ function Sidebar:add_history_messages(messages) end end self.chat_history.messages = history_messages - Path.history.save(self.code.bufnr, self.chat_history) + -- 历史消息变更时,标记缓存失效 + self._history_cache_invalidated = true + save_history(self) if self.chat_history.title == "untitled" and #messages > 0 and messages[1].just_for_display ~= true and messages[1].state == "generated" then - self.chat_history.title = "generating..." - Llm.summarize_chat_thread_title(messages[1].message.content, function(title) - if title then - self.chat_history.title = title - else - self.chat_history.title = "untitled" - end - Path.history.save(self.code.bufnr, self.chat_history) - end) + -- self.chat_history.title = "generating..." + -- Llm.summarize_chat_thread_title(messages[1].message.content, function(title) + -- if title then + -- self.chat_history.title = title + -- else + -- self.chat_history.title = "untitled" + -- end + -- save_history(self) + -- end) + local first_msg_text = Utils.message_to_text(messages[1], messages) + local lines_ = vim.split(first_msg_text, "\n") + if #lines_ > 0 then + self.chat_history.title = lines_[1] + save_history(self) + end end local last_message = messages[#messages] if last_message then @@ -1964,7 +2016,10 @@ function Sidebar:add_history_messages(messages) self.current_state = "generating" end end - self:update_content("") + xpcall(function() self:update_content("") end, function(err) + Utils.debug("Failed to update content:", err) + return nil + end) end ---@param messages AvanteLLMMessage | AvanteLLMMessage[] @@ -2135,6 +2190,8 @@ end function Sidebar:reload_chat_history() if not self.code.bufnr or not api.nvim_buf_is_valid(self.code.bufnr) then return end self.chat_history = Path.history.load(self.code.bufnr) + -- 重新加载历史时,标记缓存失效 + self._history_cache_invalidated = true end ---@return avante.HistoryMessage[] diff --git a/lua/avante/templates/agentic.avanterules b/lua/avante/templates/agentic.avanterules index d3f00c4..51efb25 100644 --- a/lua/avante/templates/agentic.avanterules +++ b/lua/avante/templates/agentic.avanterules @@ -5,9 +5,11 @@ RULES +- NEVER reply the updated code. + - Always reply to the user in the same language they are using. -- Don't just provide code suggestions, use the `replace_in_file` tool to help users fulfill their needs. +- Don't just provide code suggestions, use the `replace_in_file` tool or `str_replace` tool to help users fulfill their needs. - After the tool call is complete, please do not output the entire file content. diff --git a/lua/avante/types.lua b/lua/avante/types.lua index c390973..c934cd2 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -231,6 +231,7 @@ vim.g.avante_login = vim.g.avante_login ---@field disable_tools? boolean ---@field entra? boolean ---@field hide_in_model_selector? boolean +---@field use_ReAct_prompt? boolean --- ---@class AvanteSupportedProvider: AvanteDefaultBaseProvider ---@field __inherited_from? string @@ -396,6 +397,7 @@ vim.g.avante_login = vim.g.avante_login ---@class AvanteLLMToolParam ---@field type 'table' ---@field fields AvanteLLMToolParamField[] +---@field usage? table ---@class AvanteLLMToolParamField ---@field name string diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 000e750..603e869 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -576,7 +576,7 @@ end --- remove indentation from code: spaces or tabs function M.remove_indentation(code) if not code then return code end - return code:gsub("^%s*", ""):gsub("%s*$", "") + return code:gsub("%s*", "") end function M.relative_path(absolute) @@ -1056,12 +1056,12 @@ function M.update_buffer_lines(ns_id, bufnr, old_lines, new_lines) if #diffs == 0 then return end for _, diff in ipairs(diffs) do local lines = diff.content - -- M.debug("lines", lines) local text_lines = vim.tbl_map(function(line) return tostring(line) end, lines) vim.api.nvim_buf_set_lines(bufnr, diff.start_line - 1, diff.end_line - 1, false, text_lines) for i, line in ipairs(lines) do line:set_highlights(ns_id, bufnr, diff.start_line + i - 2) end + vim.cmd("redraw") end end @@ -1467,6 +1467,23 @@ function M.text_to_lines(text, hl) return lines end +---@param thinking_text string +---@param hl string | nil +---@return avante.ui.Line[] +function M.thinking_to_lines(thinking_text, hl) + local Line = require("avante.ui.line") + local text_lines = vim.split(thinking_text, "\n") + local lines = {} + table.insert(lines, Line:new({ { M.icon("🤔 ") .. "Thought content:" } })) + table.insert(lines, Line:new({ { "" } })) + for _, text_line in ipairs(text_lines) do + local piece = { "> " .. text_line } + if hl then table.insert(piece, hl) end + table.insert(lines, Line:new({ piece })) + end + return lines +end + ---@param item AvanteLLMMessageContentItem ---@param message avante.HistoryMessage ---@param messages avante.HistoryMessage[] @@ -1475,6 +1492,9 @@ function M.message_content_item_to_lines(item, message, messages) local Line = require("avante.ui.line") if type(item) == "string" then return M.text_to_lines(item) end if type(item) == "table" then + if item.type == "thinking" or item.type == "redacted_thinking" then + return M.thinking_to_lines(item.thinking or item.data or "") + end if item.type == "text" then return M.text_to_lines(item.text) end if item.type == "image" then return { Line:new({ { "![image](" .. item.source.media_type .. ": " .. item.source.data .. ")" } }) } @@ -1520,18 +1540,6 @@ function M.message_content_item_to_lines(item, message, messages) end end end - elseif tool_result_message then - local tool_result = tool_result_message.message.content[1] - if tool_result.content then - local result_lines = vim.split(tool_result.content, "\n") - for idx, line in ipairs(result_lines) do - if idx ~= #result_lines then - table.insert(lines, Line:new({ { "│" }, { string.format(" %s", line) } })) - else - table.insert(lines, Line:new({ { "╰─" }, { string.format(" %s", line) } })) - end - end - end end return lines end diff --git a/lua/avante/utils/prompts.lua b/lua/avante/utils/prompts.lua new file mode 100644 index 0000000..091d469 --- /dev/null +++ b/lua/avante/utils/prompts.lua @@ -0,0 +1,144 @@ +local M = {} + +---@param provider_conf AvanteDefaultBaseProvider +---@param opts AvantePromptOptions +---@return string +function M.get_ReAct_system_prompt(provider_conf, opts) + local system_prompt = opts.system_prompt + local disable_tools = provider_conf.disable_tools or false + if not disable_tools and opts.tools then + local tools_prompts = [[ +==== + +TOOL USE + +You have access to a set of tools that are executed upon the user's approval. You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use. + +# Tool Use Formatting + +Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure: + + +value1 +value2 +... + + +For example: + + +src/main.js + + +Always adhere to this format for the tool use to ensure proper parsing and execution. + +# Tools + +]] + for _, tool in ipairs(opts.tools) do + local tool_prompt = ([[ +## {{name}} +Description: {{description}} +Parameters: +]]):gsub("{{name}}", tool.name):gsub( + "{{description}}", + tool.get_description and tool.get_description() or (tool.description or "") + ) + for _, field in ipairs(tool.param.fields) do + if field.optional then + tool_prompt = tool_prompt .. string.format(" - %s: %s\n", field.name, field.description) + else + tool_prompt = tool_prompt + .. string.format( + " - %s: (required) %s\n", + field.name, + field.get_description and field.get_description() or (field.description or "") + ) + end + end + if tool.param.usage then + tool_prompt = tool_prompt + .. ("Usage:\n<{{name}}>\n"):gsub("{{([%w_]+)}}", function(name) return tool[name] end) + for k, v in pairs(tool.param.usage) do + tool_prompt = tool_prompt .. "<" .. k .. ">" .. tostring(v) .. "\n" + end + tool_prompt = tool_prompt .. ("\n"):gsub("{{([%w_]+)}}", function(name) return tool[name] end) + end + tools_prompts = tools_prompts .. tool_prompt .. "\n" + end + + system_prompt = system_prompt .. tools_prompts + + system_prompt = system_prompt + .. [[ +# Tool Use Examples + +## Example 1: Requesting to execute a command + + +./src +npm run dev + + +## Example 2: Requesting to create a new file + + +src/frontend-config.json + +{ + "apiEndpoint": "https://api.example.com", + "theme": { + "primaryColor": "#007bff", + "secondaryColor": "#6c757d", + "fontFamily": "Arial, sans-serif" + }, + "features": { + "darkMode": true, + "notifications": true, + "analytics": false + }, + "version": "1.0.0" +} + + + +## Example 3: Requesting to make targeted edits to a file + + +src/components/App.tsx + +<<<<<<< SEARCH +import React from 'react'; +======= +import React, { useState } from 'react'; +>>>>>>> REPLACE + +<<<<<<< SEARCH +function handleSubmit() { + saveData(); + setLoading(false); +} + +======= +>>>>>>> REPLACE + +<<<<<<< SEARCH +return ( +
+======= +function handleSubmit() { + saveData(); + setLoading(false); +} + +return ( +
+>>>>>>> REPLACE + + +]] + end + return system_prompt +end + +return M diff --git a/tests/llm_tools_spec.lua b/tests/llm_tools_spec.lua index 7e5062e..c730c3c 100644 --- a/tests/llm_tools_spec.lua +++ b/tests/llm_tools_spec.lua @@ -53,21 +53,21 @@ describe("llm_tools", function() describe("ls", function() it("should list files in directory", function() - local result, err = ls({ rel_path = ".", max_depth = 1 }) + local result, err = ls({ path = ".", max_depth = 1 }) assert.is_nil(err) assert.falsy(result:find("avante.nvim")) assert.truthy(result:find("test.txt")) assert.falsy(result:find("test1.txt")) end) it("should list files in directory with depth", function() - local result, err = ls({ rel_path = ".", max_depth = 2 }) + local result, err = ls({ path = ".", max_depth = 2 }) assert.is_nil(err) assert.falsy(result:find("avante.nvim")) assert.truthy(result:find("test.txt")) assert.truthy(result:find("test1.txt")) end) it("should list files respecting gitignore", function() - local result, err = ls({ rel_path = ".", max_depth = 2 }) + local result, err = ls({ path = ".", max_depth = 2 }) assert.is_nil(err) assert.falsy(result:find("avante.nvim")) assert.truthy(result:find("test.txt")) @@ -102,7 +102,7 @@ describe("llm_tools", function() describe("create_dir", function() it("should create new directory", function() - LlmTools.create_dir({ rel_path = "new_dir" }, nil, function(success, err) + LlmTools.create_dir({ path = "new_dir" }, nil, function(success, err) assert.is_nil(err) assert.is_true(success) @@ -114,7 +114,7 @@ describe("llm_tools", function() describe("delete_file", function() it("should delete existing file", function() - LlmTools.delete_file({ rel_path = "test.txt" }, nil, function(success, err) + LlmTools.delete_file({ path = "test.txt" }, nil, function(success, err) assert.is_nil(err) assert.is_true(success) @@ -147,28 +147,28 @@ describe("llm_tools", function() file:write("this is nothing") file:close() - local result, err = grep({ rel_path = ".", query = "Searchable", case_sensitive = false }) + local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false }) assert.is_nil(err) assert.truthy(result:find("searchable.txt")) assert.falsy(result:find("nothing.txt")) - local result2, err2 = grep({ rel_path = ".", query = "searchable", case_sensitive = true }) + local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true }) assert.is_nil(err2) assert.truthy(result2:find("searchable.txt")) assert.falsy(result2:find("nothing.txt")) - local result3, err3 = grep({ rel_path = ".", query = "Searchable", case_sensitive = true }) + local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true }) assert.is_nil(err3) assert.falsy(result3:find("searchable.txt")) assert.falsy(result3:find("nothing.txt")) - local result4, err4 = grep({ rel_path = ".", query = "searchable", case_sensitive = false }) + local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false }) assert.is_nil(err4) assert.truthy(result4:find("searchable.txt")) assert.falsy(result4:find("nothing.txt")) local result5, err5 = grep({ - rel_path = ".", + path = ".", query = "searchable", case_sensitive = false, exclude_pattern = "search*", @@ -191,7 +191,7 @@ describe("llm_tools", function() file:write("content for ag test") file:close() - local result, err = grep({ rel_path = ".", query = "ag test" }) + local result, err = grep({ path = ".", query = "ag test" }) assert.is_nil(err) assert.is_string(result) assert.truthy(result:find("ag_test.txt")) @@ -215,28 +215,28 @@ describe("llm_tools", function() file:write("this is nothing") file:close() - local result, err = grep({ rel_path = ".", query = "Searchable", case_sensitive = false }) + local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false }) assert.is_nil(err) assert.truthy(result:find("searchable.txt")) assert.falsy(result:find("nothing.txt")) - local result2, err2 = grep({ rel_path = ".", query = "searchable", case_sensitive = true }) + local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true }) assert.is_nil(err2) assert.truthy(result2:find("searchable.txt")) assert.falsy(result2:find("nothing.txt")) - local result3, err3 = grep({ rel_path = ".", query = "Searchable", case_sensitive = true }) + local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true }) assert.is_nil(err3) assert.falsy(result3:find("searchable.txt")) assert.falsy(result3:find("nothing.txt")) - local result4, err4 = grep({ rel_path = ".", query = "searchable", case_sensitive = false }) + local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false }) assert.is_nil(err4) assert.truthy(result4:find("searchable.txt")) assert.falsy(result4:find("nothing.txt")) local result5, err5 = grep({ - rel_path = ".", + path = ".", query = "searchable", case_sensitive = false, exclude_pattern = "search*", @@ -250,18 +250,18 @@ describe("llm_tools", function() -- Mock exepath to return nothing vim.fn.exepath = function() return "" end - local result, err = grep({ rel_path = ".", query = "test" }) + local result, err = grep({ path = ".", query = "test" }) assert.equals("", result) assert.equals("No search command found", err) end) it("should respect path permissions", function() - local result, err = grep({ rel_path = "../outside_project", query = "test" }) + local result, err = grep({ path = "../outside_project", query = "test" }) assert.truthy(err:find("No permission to access path")) end) it("should handle non-existent paths", function() - local result, err = grep({ rel_path = "non_existent_dir", query = "test" }) + local result, err = grep({ path = "non_existent_dir", query = "test" }) assert.equals("", result) assert.truthy(err) assert.truthy(err:find("No such file or directory")) @@ -270,14 +270,14 @@ describe("llm_tools", function() describe("bash", function() -- it("should execute command and return output", function() - -- bash({ rel_path = ".", command = "echo 'test'" }, nil, function(result, err) + -- bash({ path = ".", command = "echo 'test'" }, nil, function(result, err) -- assert.is_nil(err) -- assert.equals("test\n", result) -- end) -- end) it("should return error when running outside current directory", function() - bash({ rel_path = "../outside_project", command = "echo 'test'" }, nil, function(result, err) + bash({ path = "../outside_project", command = "echo 'test'" }, nil, function(result, err) assert.is_false(result) assert.truthy(err) assert.truthy(err:find("No permission to access path")) @@ -289,7 +289,7 @@ describe("llm_tools", function() it("should execute Python code and return output", function() LlmTools.python( { - rel_path = ".", + path = ".", code = "print('Hello from Python')", }, nil, @@ -303,7 +303,7 @@ describe("llm_tools", function() it("should handle Python errors", function() LlmTools.python( { - rel_path = ".", + path = ".", code = "print(undefined_variable)", }, nil, @@ -318,7 +318,7 @@ describe("llm_tools", function() it("should respect path permissions", function() LlmTools.python( { - rel_path = "../outside_project", + path = "../outside_project", code = "print('test')", }, nil, @@ -332,7 +332,7 @@ describe("llm_tools", function() it("should handle non-existent paths", function() LlmTools.python( { - rel_path = "non_existent_dir", + path = "non_existent_dir", code = "print('test')", }, nil, @@ -347,7 +347,7 @@ describe("llm_tools", function() os.execute("docker image rm python:3.12-slim") LlmTools.python( { - rel_path = ".", + path = ".", code = "print('Hello from custom container')", container_image = "python:3.12-slim", }, @@ -370,7 +370,7 @@ describe("llm_tools", function() os.execute("touch " .. test_dir .. "/nested/file4.lua") -- Test for lua files in the root - local result, err = glob({ rel_path = ".", pattern = "*.lua" }) + local result, err = glob({ path = ".", pattern = "*.lua" }) assert.is_nil(err) local files = vim.json.decode(result).matches assert.equals(2, #files) @@ -380,7 +380,7 @@ describe("llm_tools", function() assert.falsy(vim.tbl_contains(files, test_dir .. "/nested/file4.lua")) -- Test with recursive pattern - local result2, err2 = glob({ rel_path = ".", pattern = "**/*.lua" }) + local result2, err2 = glob({ path = ".", pattern = "**/*.lua" }) assert.is_nil(err2) local files2 = vim.json.decode(result2).matches assert.equals(3, #files2) @@ -390,13 +390,13 @@ describe("llm_tools", function() end) it("should respect path permissions", function() - local result, err = glob({ rel_path = "../outside_project", pattern = "*.txt" }) + local result, err = glob({ path = "../outside_project", pattern = "*.txt" }) assert.equals("", result) assert.truthy(err:find("No permission to access path")) end) it("should handle patterns without matches", function() - local result, err = glob({ rel_path = ".", pattern = "*.nonexistent" }) + local result, err = glob({ path = ".", pattern = "*.nonexistent" }) assert.is_nil(err) local files = vim.json.decode(result).matches assert.equals(0, #files) @@ -411,7 +411,7 @@ describe("llm_tools", function() os.execute("touch " .. test_dir .. "/test_dir1/notignored1.lua") os.execute("touch " .. test_dir .. "/test_dir1/notignored2.lua") - local result, err = glob({ rel_path = ".", pattern = "**/*.lua" }) + local result, err = glob({ path = ".", pattern = "**/*.lua" }) assert.is_nil(err) local files = vim.json.decode(result).matches