feat: ReAct tool calling (#2104)

This commit is contained in:
yetone
2025-05-31 08:53:34 +08:00
committed by GitHub
parent 22418bff8b
commit bc403ddcbf
25 changed files with 1358 additions and 188 deletions

View File

@@ -671,6 +671,7 @@ M.BASE_PROVIDER_KEYS = {
"disable_tools",
"entra",
"hide_in_model_selector",
"use_ReAct_prompt",
}
return M

View File

@@ -0,0 +1,517 @@
-- XML Parser for Lua
local XmlParser = {}
-- 流式解析器状态
local StreamParser = {}
StreamParser.__index = StreamParser
-- 创建新的流式解析器实例
function StreamParser.new()
local parser = {
buffer = "", -- 缓冲区存储未处理的内容
stack = {}, -- 标签栈
results = {}, -- 已完成的元素列表
current = nil, -- 当前正在处理的元素
root = nil, -- 当前根元素
position = 1, -- 当前解析位置
state = "ready", -- 解析状态: ready, parsing, incomplete, error
incomplete_tag = nil, -- 未完成的标签信息
last_error = nil, -- 最后的错误信息
}
setmetatable(parser, StreamParser)
return parser
end
-- 重置解析器状态
function StreamParser:reset()
self.buffer = ""
self.stack = {}
self.results = {}
self.current = nil
self.root = nil
self.position = 1
self.state = "ready"
self.incomplete_tag = nil
self.last_error = nil
end
-- 获取解析器状态信息
function StreamParser:getStatus()
return {
state = self.state,
completed_elements = #self.results,
stack_depth = #self.stack,
buffer_size = #self.buffer,
incomplete_tag = self.incomplete_tag,
last_error = self.last_error,
has_incomplete = self.state == "incomplete" or self.incomplete_tag ~= nil,
}
end
-- 辅助函数:去除字符串首尾空白
local function trim(s) return s:match("^%s*(.-)%s*$") end
-- 辅助函数:解析属性
local function parseAttributes(attrStr)
local attrs = {}
if not attrStr or attrStr == "" then return attrs end
-- 匹配属性模式name="value" 或 name='value'
for name, value in attrStr:gmatch("([-_%w]+)%s*=%s*[\"']([^\"']*)[\"']") do
attrs[name] = value
end
return attrs
end
-- 辅助函数HTML实体解码
local function decodeEntities(str)
local entities = {
["&lt;"] = "<",
["&gt;"] = ">",
["&amp;"] = "&",
["&quot;"] = '"',
["&apos;"] = "'",
}
for entity, char in pairs(entities) do
str = str:gsub(entity, char)
end
-- 处理数字实体 &#123; 和 &#x1A;
str = str:gsub("&#(%d+);", function(n)
local num = tonumber(n)
return num and string.char(num) or ""
end)
str = str:gsub("&#x(%x+);", function(n)
local num = tonumber(n, 16)
return num and string.char(num) or ""
end)
return str
end
-- 检查是否为有效的XML标签
local function isValidXmlTag(tag, xmlContent, tagStart)
-- 排除明显不是XML标签的内容比如数学表达式 < 或 >
-- 检查标签是否包含合理的XML标签格式
if not tag:match("^<[^<>]*>$") then return false end
-- 检查是否是合法的标签格式
if tag:match("^</[-_%w]+>$") then return true end -- 结束标签
if tag:match("^<[-_%w]+[^>]*/>$") then return true end -- 自闭合标签
if tag:match("^<[-_%w]+[^>]*>$") then
-- 对于开始标签,进行额外的上下文检查
local tagName = tag:match("^<([-_%w]+)")
-- 检查是否存在对应的结束标签
local closingTag = "</" .. tagName .. ">"
local hasClosingTag = xmlContent:find(closingTag, tagStart)
-- 如果是单个标签且没有结束标签,可能是文本中的引用
if not hasClosingTag then
-- 检查前后文本如果像是在描述而不是实际的XML结构则不认为是有效标签
local beforeText = xmlContent:sub(math.max(1, tagStart - 50), tagStart - 1)
local afterText = xmlContent:sub(tagStart + #tag, math.min(#xmlContent, tagStart + #tag + 50))
-- 如果前面有"provided in the"、"in the"等描述性文字,可能是文本引用
if
beforeText:match("provided in the%s*$")
or beforeText:match("in the%s*$")
or beforeText:match("see the%s*$")
or beforeText:match("use the%s*$")
then
return false
end
-- 如果后面紧跟着"tag"等描述性词汇,可能是文本引用
if afterText:match("^%s*tag") then return false end
end
return true
end
return false
end
-- 流式解析器方法:添加数据到缓冲区并解析
function StreamParser:addData(data)
if not data or data == "" then return end
self.buffer = self.buffer .. data
self:parseBuffer()
end
-- 获取当前解析深度
function StreamParser:getCurrentDepth() return #self.stack end
-- 解析缓冲区中的数据
function StreamParser:parseBuffer()
self.state = "parsing"
while self.position <= #self.buffer do
local remaining = self.buffer:sub(self.position)
-- 查找下一个标签
local tagStart, tagEnd = remaining:find("<[^>]*>")
if not tagStart then
-- 检查是否有未完成的开始标签(以<开始但没有>结束)
local incompleteStart = remaining:find("<[^>]*$")
if incompleteStart then
local incompleteContent = remaining:sub(incompleteStart)
-- 确保这确实是一个未完成的标签,而不是文本中的<符号
if incompleteContent:match("^<[%w_-]") then
-- 尝试解析未完成的开始标签
local tagName = incompleteContent:match("^<([%w_-]+)")
if tagName then
-- 处理未完成标签前的文本
if incompleteStart > 1 then
local precedingText = trim(remaining:sub(1, incompleteStart - 1))
if precedingText ~= "" then
if self.current then
-- 如果当前在某个标签内,添加到该标签的文本内容
precedingText = decodeEntities(precedingText)
if self.current._text then
self.current._text = self.current._text .. precedingText
else
self.current._text = precedingText
end
else
-- 如果是顶层文本,作为独立元素添加
local textElement = {
_name = "_text",
_text = decodeEntities(precedingText),
}
table.insert(self.results, textElement)
end
end
end
-- 创建未完成的元素
local element = {
_name = tagName,
_attr = {},
_state = "incomplete_start_tag",
}
if not self.root then
self.root = element
self.current = element
elseif self.current then
table.insert(self.stack, self.current)
if not self.current[tagName] then self.current[tagName] = {} end
table.insert(self.current[tagName], element)
self.current = element
end
self.incomplete_tag = {
start_pos = self.position + incompleteStart - 1,
content = incompleteContent,
element = element,
}
self.state = "incomplete"
return
end
end
end
-- 处理剩余的文本内容
if remaining ~= "" then
if self.current then
-- 检查当前深度,如果在第一层子元素中,保持原始文本
local currentDepth = #self.stack
if currentDepth >= 1 then
-- 在第一层子元素中,保持原始文本不变
if self.current._text then
self.current._text = self.current._text .. remaining
else
self.current._text = remaining
end
else
-- 在根级别,进行正常的文本处理
local text = trim(remaining)
if text ~= "" then
text = decodeEntities(text)
if self.current._text then
self.current._text = self.current._text .. text
else
self.current._text = text
end
end
end
else
-- 如果是顶层文本,作为独立元素添加
local text = trim(remaining)
if text ~= "" then
local textElement = {
_name = "_text",
_text = decodeEntities(text),
}
table.insert(self.results, textElement)
end
end
end
self.position = #self.buffer + 1
break
end
local tag = remaining:sub(tagStart, tagEnd)
local actualTagStart = self.position + tagStart - 1
local actualTagEnd = self.position + tagEnd - 1
-- 检查是否为有效的XML标签
if not isValidXmlTag(tag, self.buffer, actualTagStart) then
-- 如果不是有效标签,将其作为普通文本处理
local text = remaining:sub(1, tagEnd)
if text ~= "" then
if self.current then
-- 检查当前深度,如果在第一层子元素中,保持原始文本
local currentDepth = #self.stack
if currentDepth >= 1 then
-- 在第一层子元素中,保持原始文本不变
if self.current._text then
self.current._text = self.current._text .. text
else
self.current._text = text
end
else
-- 在根级别,进行正常的文本处理
text = trim(text)
if text ~= "" then
text = decodeEntities(text)
if self.current._text then
self.current._text = self.current._text .. text
else
self.current._text = text
end
end
end
else
-- 顶层文本作为独立元素
text = trim(text)
if text ~= "" then
local textElement = {
_name = "_text",
_text = decodeEntities(text),
}
table.insert(self.results, textElement)
end
end
end
self.position = actualTagEnd + 1
goto continue
end
-- 处理标签前的文本内容
if tagStart > 1 then
local precedingText = remaining:sub(1, tagStart - 1)
if precedingText ~= "" then
if self.current then
-- 如果当前在某个标签内,添加到该标签的文本内容
-- 检查当前深度如果在第一层子元素中不要进行实体解码和trim
local currentDepth = #self.stack
if currentDepth >= 1 then
-- 在第一层子元素中,保持原始文本不变
if self.current._text then
self.current._text = self.current._text .. precedingText
else
self.current._text = precedingText
end
else
-- 在根级别,进行正常的文本处理
precedingText = trim(precedingText)
if precedingText ~= "" then
precedingText = decodeEntities(precedingText)
if self.current._text then
self.current._text = self.current._text .. precedingText
else
self.current._text = precedingText
end
end
end
else
-- 如果是顶层文本,作为独立元素添加
precedingText = trim(precedingText)
if precedingText ~= "" then
local textElement = {
_name = "_text",
_text = decodeEntities(precedingText),
}
table.insert(self.results, textElement)
end
end
end
end
-- 检查当前深度,如果已经在第一层子元素中,将所有标签作为文本处理
local currentDepth = #self.stack
if currentDepth >= 1 then
-- 检查是否是当前元素的结束标签
if tag:match("^</[-_%w]+>$") and self.current then
local tagName = tag:match("^</([-_%w]+)>$")
if self.current._name == tagName then
-- 这是当前元素的结束标签,正常处理
if not self:processTag(tag) then
self.state = "error"
return
end
else
-- 不是当前元素的结束标签,作为文本处理
if self.current._text then
self.current._text = self.current._text .. tag
else
self.current._text = tag
end
end
else
-- 在第一层子元素中,将标签作为文本处理
if self.current then
if self.current._text then
self.current._text = self.current._text .. tag
else
self.current._text = tag
end
end
end
else
-- 处理标签
if not self:processTag(tag) then
self.state = "error"
return
end
end
self.position = actualTagEnd + 1
::continue::
end
-- 检查当前是否有未关闭的元素
if self.current and self.current._state ~= "complete" then
self.current._state = "incomplete_unclosed"
self.state = "incomplete"
elseif self.state ~= "incomplete" and self.state ~= "error" then
self.state = "ready"
end
end
-- 处理单个标签
function StreamParser:processTag(tag)
if tag:match("^</[-_%w]+>$") then
-- 结束标签
local tagName = tag:match("^</([-_%w]+)>$")
if self.current and self.current._name == tagName then
-- 标记当前元素为完成状态
self.current._state = "complete"
self.current = table.remove(self.stack)
-- 只有当栈为空且当前元素也为空时,说明完成了一个根级元素
if #self.stack == 0 and not self.current and self.root then
table.insert(self.results, self.root)
self.root = nil
end
else
self.last_error = "Mismatched closing tag: " .. tagName
return false
end
elseif tag:match("^<[-_%w]+[^>]*/>$") then
-- 自闭合标签
local tagName, attrs = tag:match("^<([-_%w]+)([^>]*)/>")
local element = {
_name = tagName,
_attr = parseAttributes(attrs),
_state = "complete",
children = {},
}
if not self.root then
-- 直接作为根级元素添加到结果中
table.insert(self.results, element)
elseif self.current then
if not self.current.children then self.current.children = {} end
table.insert(self.current.children, element)
end
elseif tag:match("^<[-_%w]+[^>]*>$") then
-- 开始标签
local tagName, attrs = tag:match("^<([-_%w]+)([^>]*)>")
local element = {
_name = tagName,
_attr = parseAttributes(attrs),
_state = "incomplete_open", -- 标记为未完成(等待结束标签)
children = {},
}
if not self.root then
self.root = element
self.current = element
elseif self.current then
table.insert(self.stack, self.current)
if not self.current.children then self.current.children = {} end
table.insert(self.current.children, element)
self.current = element
end
end
return true
end
-- 获取所有元素(已完成的和当前正在处理的)
function StreamParser:getAllElements()
local all_elements = {}
-- 添加所有已完成的元素
for _, element in ipairs(self.results) do
table.insert(all_elements, element)
end
-- 如果有当前正在处理的元素,也添加进去
if self.root then table.insert(all_elements, self.root) end
return all_elements
end
-- 获取已完成的元素(保留向后兼容性)
function StreamParser:getCompletedElements() return self.results end
-- 获取当前未完成的元素(保留向后兼容性)
function StreamParser:getCurrentElement() return self.root end
-- 强制完成解析(将未完成的内容作为已完成处理)
function StreamParser:finalize()
-- 首先处理当前正在解析的元素
if self.current then
-- 递归设置所有未完成元素的状态
local function markIncompleteElements(element)
if element._state and element._state:match("incomplete") then element._state = "incomplete_unclosed" end
-- 处理 children 数组中的子元素
if element.children and type(element.children) == "table" then
for _, child in ipairs(element.children) do
if type(child) == "table" and child._name then markIncompleteElements(child) end
end
end
end
-- 标记当前元素及其所有子元素为未完成状态,但保持层次结构
markIncompleteElements(self.current)
-- 向上遍历栈,标记所有祖先元素
for i = #self.stack, 1, -1 do
local ancestor = self.stack[i]
if ancestor._state and ancestor._state:match("incomplete") then ancestor._state = "incomplete_unclosed" end
end
end
-- 只有当存在根元素时才添加到结果中
if self.root then
table.insert(self.results, self.root)
self.root = nil
end
self.current = nil
self.stack = {}
self.state = "ready"
self.incomplete_tag = nil
end
-- 创建流式解析器实例
function XmlParser.createStreamParser() return StreamParser.new() end
return XmlParser

View File

@@ -25,7 +25,7 @@ local group = api.nvim_create_augroup("avante_llm", { clear = true })
---@param cb fun(title: string | nil): nil
function M.summarize_chat_thread_title(content, cb)
local system_prompt =
[[Summarize the content as a title for the chat thread. The title should be a concise and informative summary of the conversation, capturing the main points and key takeaways. It should be no longer than 100 words and should be written in a clear and engaging style. The title should be suitable for use as the title of a chat thread on a messaging platform or other communication medium.]]
[[Summarize the content as a title for the chat thread. The title should be a concise and informative summary of the conversation, capturing the main points and key takeaways. It should be no longer than 100 words and should be written in a clear and engaging style. The title should be suitable for use as the title of a chat thread on a messaging platform or other communication medium. /no_think]]
local response_content = ""
local provider = Providers.get_memory_summary_provider()
M.curl({
@@ -761,7 +761,7 @@ function M._stream(opts)
},
})
end
opts.on_messages_add(messages)
if opts.on_messages_add then opts.on_messages_add(messages) end
local the_last_tool_use = tool_use_list[#tool_use_list]
if the_last_tool_use and the_last_tool_use.name == "attempt_completion" then
opts.on_stop({ reason = "complete" })
@@ -854,12 +854,8 @@ function M._stream(opts)
if is_break then break end
::continue::
end
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
for _, tool_use in vim.spairs(tool_use_list) do
table.insert(sorted_tool_use_list, tool_use)
end
if stop_opts.reason == "complete" and Config.mode == "agentic" then
if #sorted_tool_use_list == 0 then
if #tool_use_list == 0 then
local completed_attempt_completion_tool_use = nil
for idx = #history_messages, 1, -1 do
local message = history_messages[idx]
@@ -896,7 +892,7 @@ function M._stream(opts)
end
end
end
if stop_opts.reason == "tool_use" then return handle_next_tool_use(sorted_tool_use_list, 1, {}) end
if stop_opts.reason == "tool_use" then return handle_next_tool_use(tool_use_list, 1, {}) end
if stop_opts.reason == "rate_limit" then
local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*"
if opts.on_chunk then opts.on_chunk("\n" .. msg_content .. "\n") end

View File

@@ -32,6 +32,10 @@ M.param = {
optional = true,
},
},
usage = {
result = "The result of the task. Formulate this result in a way that is final and does not require further input from the user. Don't end your result with questions or offers for further assistance.",
command = "A CLI command to execute to show a live demo of the result to the user. For example, use `open index.html` to display a created html website, or `open localhost:3000` to display a locally running development server. But DO NOT use commands like `echo` or `cat` that merely print text. This command should be valid for the current operating system. Ensure the command is properly formatted and does not contain any harmful instructions.",
},
}
---@type AvanteLLMToolReturn[]

View File

@@ -183,7 +183,7 @@ M.param = {
type = "table",
fields = {
{
name = "rel_path",
name = "path",
description = "Relative path to the project directory, as cwd",
type = "string",
},
@@ -193,6 +193,10 @@ M.param = {
type = "string",
},
},
usage = {
path = "Relative path to the project directory, as cwd",
command = "Command to run",
},
}
---@type AvanteLLMToolReturn[]
@@ -210,9 +214,9 @@ M.returns = {
},
}
---@type AvanteLLMToolFunc<{ rel_path: string, command: string }>
---@type AvanteLLMToolFunc<{ path: string, command: string }>
function M.func(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.rel_path)
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
if on_log then on_log("command: " .. opts.command) end

View File

@@ -49,6 +49,9 @@ M.param = {
},
},
required = { "prompt" },
usage = {
prompt = "The task for the agent to perform",
},
}
---@type AvanteLLMToolReturn[]

View File

@@ -19,6 +19,9 @@ M.param = {
type = "string",
},
},
usage = {
path = "The path to the file in the current project scope",
},
}
---@type AvanteLLMToolReturn[]

View File

@@ -18,11 +18,15 @@ M.param = {
type = "string",
},
{
name = "rel_path",
name = "path",
description = "Relative path to the project directory, as cwd",
type = "string",
},
},
usage = {
pattern = "Glob pattern",
path = "Relative path to the project directory, as cwd",
},
}
---@type AvanteLLMToolReturn[]
@@ -40,9 +44,9 @@ M.returns = {
},
}
---@type AvanteLLMToolFunc<{ rel_path: string, pattern: string }>
---@type AvanteLLMToolFunc<{ path: string, pattern: string }>
function M.func(opts, on_log, on_complete, session_ctx)
local abs_path = Helpers.get_abs_path(opts.rel_path)
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
if on_log then on_log("pattern: " .. opts.pattern) end

View File

@@ -15,7 +15,7 @@ M.param = {
type = "table",
fields = {
{
name = "rel_path",
name = "path",
description = "Relative path to the project directory",
type = "string",
},
@@ -44,6 +44,13 @@ M.param = {
optional = true,
},
},
usage = {
path = "Relative path to the project directory",
query = "Query to search for",
case_sensitive = "Whether to search case sensitively",
include_pattern = "Glob pattern to include files",
exclude_pattern = "Glob pattern to exclude files",
},
}
---@type AvanteLLMToolReturn[]
@@ -61,9 +68,9 @@ M.returns = {
},
}
---@type AvanteLLMToolFunc<{ rel_path: string, query: string, case_sensitive?: boolean, include_pattern?: string, exclude_pattern?: string }>
---@type AvanteLLMToolFunc<{ path: string, query: string, case_sensitive?: boolean, include_pattern?: string, exclude_pattern?: string }>
function M.func(opts, on_log)
local abs_path = Helpers.get_abs_path(opts.rel_path)
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return "", "No such file or directory: " .. abs_path end

View File

@@ -7,10 +7,10 @@ local Helpers = require("avante.llm_tools.helpers")
local M = {}
---@type AvanteLLMToolFunc<{ rel_path: string }>
---@type AvanteLLMToolFunc<{ path: string }>
function M.read_file_toplevel_symbols(opts, on_log)
local RepoMap = require("avante.repo_map")
local abs_path = Helpers.get_abs_path(opts.rel_path)
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
if not Path:new(abs_path):exists() then return "", "File does not exists: " .. abs_path end
@@ -121,13 +121,13 @@ function M.write_global_file(opts, on_log, on_complete)
end)
end
---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }>
---@type AvanteLLMToolFunc<{ path: string, new_path: string }>
function M.rename_file(opts, on_log, on_complete)
local abs_path = Helpers.get_abs_path(opts.rel_path)
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
local new_abs_path = Helpers.get_abs_path(opts.new_rel_path)
local new_abs_path = Helpers.get_abs_path(opts.new_path)
if on_log then on_log(abs_path .. " -> " .. new_abs_path) end
if not Helpers.has_permission_to_access(new_abs_path) then
return false, "No permission to access path: " .. new_abs_path
@@ -147,13 +147,13 @@ function M.rename_file(opts, on_log, on_complete)
)
end
---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }>
---@type AvanteLLMToolFunc<{ path: string, new_path: string }>
function M.copy_file(opts, on_log)
local abs_path = Helpers.get_abs_path(opts.rel_path)
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
local new_abs_path = Helpers.get_abs_path(opts.new_rel_path)
local new_abs_path = Helpers.get_abs_path(opts.new_path)
if not Helpers.has_permission_to_access(new_abs_path) then
return false, "No permission to access path: " .. new_abs_path
end
@@ -163,9 +163,9 @@ function M.copy_file(opts, on_log)
return true, nil
end
---@type AvanteLLMToolFunc<{ rel_path: string }>
---@type AvanteLLMToolFunc<{ path: string }>
function M.delete_file(opts, on_log, on_complete)
local abs_path = Helpers.get_abs_path(opts.rel_path)
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
@@ -181,9 +181,9 @@ function M.delete_file(opts, on_log, on_complete)
end)
end
---@type AvanteLLMToolFunc<{ rel_path: string }>
---@type AvanteLLMToolFunc<{ path: string }>
function M.create_dir(opts, on_log, on_complete)
local abs_path = Helpers.get_abs_path(opts.rel_path)
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end
if not on_complete then return false, "on_complete not provided" end
@@ -198,13 +198,13 @@ function M.create_dir(opts, on_log, on_complete)
end)
end
---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }>
---@type AvanteLLMToolFunc<{ path: string, new_path: string }>
function M.rename_dir(opts, on_log, on_complete)
local abs_path = Helpers.get_abs_path(opts.rel_path)
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end
local new_abs_path = Helpers.get_abs_path(opts.new_rel_path)
local new_abs_path = Helpers.get_abs_path(opts.new_path)
if not Helpers.has_permission_to_access(new_abs_path) then
return false, "No permission to access path: " .. new_abs_path
end
@@ -224,9 +224,9 @@ function M.rename_dir(opts, on_log, on_complete)
)
end
---@type AvanteLLMToolFunc<{ rel_path: string }>
---@type AvanteLLMToolFunc<{ path: string }>
function M.delete_dir(opts, on_log, on_complete)
local abs_path = Helpers.get_abs_path(opts.rel_path)
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end
@@ -552,9 +552,9 @@ function M.rag_search(opts, on_log, on_complete)
)
end
---@type AvanteLLMToolFunc<{ code: string, rel_path: string, container_image?: string }>
---@type AvanteLLMToolFunc<{ code: string, path: string, container_image?: string }>
function M.python(opts, on_log, on_complete)
local abs_path = Helpers.get_abs_path(opts.rel_path)
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return nil, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return nil, "Path not found: " .. abs_path end
if on_log then on_log("cwd: " .. abs_path) end
@@ -634,6 +634,18 @@ function M.get_tools(user_input, history_messages)
:totable()
end
function M.get_tool_names()
local custom_tools = Config.custom_tools
if type(custom_tools) == "function" then custom_tools = custom_tools() end
---@type AvanteLLMTool[]
local unfiltered_tools = vim.list_extend(vim.list_extend({}, M._tools), custom_tools)
local tool_names = {}
for _, tool in ipairs(unfiltered_tools) do
table.insert(tool_names, tool.name)
end
return tool_names
end
---@type AvanteLLMTool[]
M._tools = {
require("avante.llm_tools.replace_in_file"),
@@ -652,6 +664,9 @@ M._tools = {
type = "string",
},
},
usage = {
query = "Query to search",
},
},
returns = {
{
@@ -679,11 +694,15 @@ M._tools = {
type = "string",
},
{
name = "rel_path",
name = "path",
description = "Relative path to the project directory, as cwd",
type = "string",
},
},
usage = {
code = "Python code to run",
path = "Relative path to the project directory, as cwd",
},
},
returns = {
{
@@ -711,6 +730,9 @@ M._tools = {
type = "string",
},
},
usage = {
scope = "Scope for the git diff (e.g. specific files or directories)",
},
},
returns = {
{
@@ -744,6 +766,10 @@ M._tools = {
optional = true,
},
},
usage = {
message = "Commit message to use",
scope = "Scope for staging files (e.g. specific files or directories)",
},
},
returns = {
{
@@ -768,11 +794,14 @@ M._tools = {
type = "table",
fields = {
{
name = "rel_path",
name = "path",
description = "Relative path to the file in current project scope",
type = "string",
},
},
usage = {
path = "Relative path to the file in current project scope",
},
},
returns = {
{
@@ -790,7 +819,7 @@ M._tools = {
},
require("avante.llm_tools.str_replace"),
require("avante.llm_tools.view"),
require("avante.llm_tools.create"),
require("avante.llm_tools.write_to_file"),
require("avante.llm_tools.insert"),
require("avante.llm_tools.undo_edit"),
{
@@ -820,6 +849,9 @@ M._tools = {
type = "string",
},
},
usage = {
abs_path = "Absolute path to the file in global scope",
},
},
returns = {
{
@@ -867,6 +899,10 @@ M._tools = {
type = "string",
},
},
usage = {
abs_path = "The path to the file in the current project scope",
content = "The content to write to the file",
},
},
returns = {
{
@@ -889,16 +925,20 @@ M._tools = {
type = "table",
fields = {
{
name = "rel_path",
name = "path",
description = "Relative path to the file in current project scope",
type = "string",
},
{
name = "new_rel_path",
name = "new_path",
description = "New relative path for the file",
type = "string",
},
},
usage = {
path = "Relative path to the file in current project scope",
new_path = "New relative path for the file",
},
},
returns = {
{
@@ -921,11 +961,14 @@ M._tools = {
type = "table",
fields = {
{
name = "rel_path",
name = "path",
description = "Relative path to the file in current project scope",
type = "string",
},
},
usage = {
path = "Relative path to the file in current project scope",
},
},
returns = {
{
@@ -948,11 +991,14 @@ M._tools = {
type = "table",
fields = {
{
name = "rel_path",
name = "path",
description = "Relative path to the project directory",
type = "string",
},
},
usage = {
path = "Relative path to the project directory",
},
},
returns = {
{
@@ -975,16 +1021,20 @@ M._tools = {
type = "table",
fields = {
{
name = "rel_path",
name = "path",
description = "Relative path to the project directory",
type = "string",
},
{
name = "new_rel_path",
name = "new_path",
description = "New relative path for the directory",
type = "string",
},
},
usage = {
path = "Relative path to the project directory",
new_path = "New relative path for the directory",
},
},
returns = {
{
@@ -1007,11 +1057,14 @@ M._tools = {
type = "table",
fields = {
{
name = "rel_path",
name = "path",
description = "Relative path to the project directory",
type = "string",
},
},
usage = {
path = "Relative path to the project directory",
},
},
returns = {
{
@@ -1042,6 +1095,9 @@ M._tools = {
type = "string",
},
},
usage = {
query = "Query to search",
},
},
returns = {
{
@@ -1069,6 +1125,9 @@ M._tools = {
type = "string",
},
},
usage = {
url = "Url to fetch markdown from",
},
},
returns = {
{

View File

@@ -32,6 +32,11 @@ M.param = {
type = "string",
},
},
usage = {
path = "The path to the file to modify",
insert_line = "The line number after which to insert the text (0 for beginning of file)",
new_str = "The text to insert",
},
}
---@type AvanteLLMToolReturn[]

View File

@@ -14,7 +14,7 @@ M.param = {
type = "table",
fields = {
{
name = "rel_path",
name = "path",
description = "Relative path to the project directory",
type = "string",
},
@@ -24,6 +24,10 @@ M.param = {
type = "integer",
},
},
usage = {
path = "Relative path to the project directory",
max_depth = "Maximum depth of the directory",
},
}
---@type AvanteLLMToolReturn[]
@@ -41,9 +45,9 @@ M.returns = {
},
}
---@type AvanteLLMToolFunc<{ rel_path: string, max_depth?: integer }>
---@type AvanteLLMToolFunc<{ path: string, max_depth?: integer }>
function M.func(opts, on_log)
local abs_path = Helpers.get_abs_path(opts.rel_path)
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
if on_log then on_log("path: " .. abs_path) end
if on_log then on_log("max depth: " .. tostring(opts.max_depth)) end

View File

@@ -16,6 +16,8 @@ M.name = "replace_in_file"
M.description =
"Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file. This tool should be used when you need to make targeted changes to specific parts of a file."
-- function M.enabled() return Config.provider:match("ollama") == nil end
---@type AvanteLLMToolParam
M.param = {
type = "table",
@@ -57,6 +59,10 @@ One or more SEARCH/REPLACE blocks following this exact format:
type = "string",
},
},
usage = {
path = "File path here",
diff = "Search and replace blocks here",
},
}
---@type AvanteLLMToolReturn[]

View File

@@ -1,8 +1,4 @@
local Path = require("plenary.path")
local Utils = require("avante.utils")
local Base = require("avante.llm_tools.base")
local Helpers = require("avante.llm_tools.helpers")
local Diff = require("avante.diff")
local Config = require("avante.config")
---@class AvanteLLMTool
@@ -13,6 +9,7 @@ M.name = "str_replace"
M.description =
"The str_replace tool allows you to replace a specific string in a file with a new string. This is used for making precise edits."
-- function M.enabled() return Config.provider:match("ollama") ~= nil end
function M.enabled() return false end
---@type AvanteLLMToolParam
@@ -35,6 +32,11 @@ M.param = {
type = "string",
},
},
usage = {
path = "File path here",
old_str = "old str here",
new_str = "new str here",
},
}
---@type AvanteLLMToolReturn[]

View File

@@ -22,6 +22,9 @@ M.param = {
type = "string",
},
},
usage = {
path = "The path to the file whose last edit should be undone",
},
}
---@type AvanteLLMToolReturn[]

View File

@@ -39,23 +39,22 @@ M.param = {
type = "string",
},
{
name = "view_range",
description = "The range of the file to view. This parameter only applies when viewing files, not directories.",
type = "object",
name = "start_line",
description = "The start line of the view range, 1-indexed",
type = "integer",
optional = true,
fields = {
{
name = "start_line",
description = "The start line of the range, 1-indexed",
type = "integer",
},
{
name = "end_line",
description = "The end line of the range, 1-indexed, and -1 for the end line means read to the end of the file",
type = "integer",
},
},
},
{
name = "end_line",
description = "The end line of the view range, 1-indexed, and -1 for the end line means read to the end of the file",
type = "integer",
optional = true,
},
},
usage = {
path = "The path to the file in the current project scope",
start_line = "The start line of the view range, 1-indexed",
end_line = "The end line of the view range, 1-indexed, and -1 for the end line means read to the end of the file",
},
}
@@ -74,7 +73,7 @@ M.returns = {
},
}
---@type AvanteLLMToolFunc<{ path: string, view_range?: { start_line: integer, end_line: integer } }>
---@type AvanteLLMToolFunc<{ path: string, start_line?: integer, end_line?: integer }>
function M.func(opts, on_log, on_complete, session_ctx)
if on_log then on_log("path: " .. opts.path) end
local abs_path = Helpers.get_abs_path(opts.path)
@@ -84,11 +83,9 @@ function M.func(opts, on_log, on_complete, session_ctx)
local file = io.open(abs_path, "r")
if not file then return false, "file not found: " .. abs_path end
local lines = Utils.read_file_from_buf_or_disk(abs_path)
if opts.view_range then
local start_line = opts.view_range.start_line
local end_line = opts.view_range.end_line
if start_line and end_line and lines then lines = vim.list_slice(lines, start_line, end_line) end
end
local start_line = opts.start_line
local end_line = opts.end_line
if start_line and end_line and lines then lines = vim.list_slice(lines, start_line, end_line) end
local truncated_lines = {}
local is_truncated = false
local size = 0

View File

@@ -0,0 +1,74 @@
local Utils = require("avante.utils")
local Base = require("avante.llm_tools.base")
local Helpers = require("avante.llm_tools.helpers")
---@class AvanteLLMTool
local M = setmetatable({}, Base)
M.name = "write_to_file"
M.description =
"Request to write content to a file at the specified path. If the file exists, it will be overwritten with the provided content. If the file doesn't exist, it will be created. This tool will automatically create any directories needed to write the file."
function M.enabled() return require("avante.config").mode == "agentic" end
---@type AvanteLLMToolParam
M.param = {
type = "table",
fields = {
{
name = "path",
get_description = function()
local res = ("The path of the file to write to (relative to the current working directory {{cwd}})"):gsub(
"{{cwd}}",
Utils.get_project_root()
)
return res
end,
type = "string",
},
{
name = "content",
description = "The content to write to the file. ALWAYS provide the COMPLETE intended content of the file, without any truncation or omissions. You MUST include ALL parts of the file, even if they haven't been modified.",
type = "string",
},
},
usage = {
path = "File path here",
content = "File content here",
},
}
---@type AvanteLLMToolReturn[]
M.returns = {
{
name = "success",
description = "Whether the file was created successfully",
type = "boolean",
},
{
name = "error",
description = "Error message if the file was not created successfully",
type = "string",
optional = true,
},
}
---@type AvanteLLMToolFunc<{ path: string, content: string }>
function M.func(opts, on_log, on_complete, session_ctx)
if not on_complete then return false, "on_complete not provided" end
local abs_path = Helpers.get_abs_path(opts.path)
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if opts.content == nil then return false, "content not provided" end
local old_lines = Utils.read_file_from_buf_or_disk(abs_path)
local old_content = table.concat(old_lines or {}, "\n")
local replace_in_file = require("avante.llm_tools.replace_in_file")
local diff = "<<<<<<< SEARCH\n" .. old_content .. "\n=======\n" .. opts.content .. "\n>>>>>>> REPLACE"
local new_opts = {
path = opts.path,
diff = diff,
}
return replace_in_file.func(new_opts, on_log, on_complete, session_ctx)
end
return M

View File

@@ -1,10 +1,15 @@
local Utils = require("avante.utils")
local P = require("avante.providers")
local Providers = require("avante.providers")
local Config = require("avante.config")
local Clipboard = require("avante.clipboard")
local HistoryMessage = require("avante.history_message")
local Prompts = require("avante.utils.prompts")
---@class AvanteProviderFunctor
local M = {}
setmetatable(M, { __index = Providers.openai })
M.api_key_name = "" -- Ollama typically doesn't require API keys for local use
M.role_map = {
@@ -12,35 +17,182 @@ M.role_map = {
assistant = "assistant",
}
M.parse_messages = P.openai.parse_messages
M.is_reasoning_model = P.openai.is_reasoning_model
function M:parse_messages(opts)
local messages = {}
local provider_conf, _ = Providers.parse_config(self)
local system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts)
if self.is_reasoning_model(provider_conf.model) then
table.insert(messages, { role = "developer", content = system_prompt })
else
table.insert(messages, { role = "system", content = system_prompt })
end
vim.iter(opts.messages):each(function(msg)
if type(msg.content) == "string" then
table.insert(messages, { role = self.role_map[msg.role], content = msg.content })
elseif type(msg.content) == "table" then
local content = {}
for _, item in ipairs(msg.content) do
if type(item) == "string" then
table.insert(content, { type = "text", text = item })
elseif item.type == "text" then
table.insert(content, { type = "text", text = item.text })
elseif item.type == "image" then
table.insert(content, {
type = "image_url",
image_url = {
url = "data:" .. item.source.media_type .. ";" .. item.source.type .. "," .. item.source.data,
},
})
end
end
if not provider_conf.disable_tools then
if msg.content[1].type == "tool_result" then
local tool_use = nil
for _, msg_ in ipairs(opts.messages) do
if type(msg_.content) == "table" and #msg_.content > 0 then
if msg_.content[1].type == "tool_use" and msg_.content[1].id == msg.content[1].tool_use_id then
tool_use = msg_
break
end
end
end
if tool_use then
msg.role = "user"
table.insert(content, {
type = "text",
text = "["
.. tool_use.content[1].name
.. " for '"
.. (tool_use.content[1].input.path or tool_use.content[1].input.rel_path or "")
.. "'] Result:",
})
table.insert(content, {
type = "text",
text = msg.content[1].content,
})
end
end
end
if #content > 0 then
local text_content = {}
for _, item in ipairs(content) do
if type(item) == "table" and item.type == "text" then table.insert(text_content, item.text) end
end
table.insert(messages, { role = self.role_map[msg.role], content = table.concat(text_content, "\n\n") })
end
end
end)
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
local message_content = messages[#messages].content
if type(message_content) ~= "table" or message_content[1] == nil then
message_content = { { type = "text", text = message_content } }
end
for _, image_path in ipairs(opts.image_paths) do
table.insert(message_content, {
type = "image_url",
image_url = {
url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
},
})
end
messages[#messages].content = message_content
end
local final_messages = {}
local prev_role = nil
vim.iter(messages):each(function(message)
local role = message.role
if role == prev_role and role ~= "tool" then
if role == self.role_map["assistant"] then
table.insert(final_messages, { role = self.role_map["user"], content = "Ok" })
else
table.insert(final_messages, { role = self.role_map["assistant"], content = "Ok, I understand." })
end
end
prev_role = role
table.insert(final_messages, message)
end)
return final_messages
end
function M:is_disable_stream() return false end
---@class avante.OllamaFunction
---@field name string
---@field arguments table
---@class avante.OllamaToolCall
---@field function avante.OllamaFunction
---@param tool_calls avante.OllamaToolCall[]
---@param opts AvanteLLMStreamOptions
function M:add_tool_use_messages(tool_calls, opts)
local msgs = {}
for _, tool_call in ipairs(tool_calls) do
local id = Utils.uuid()
local func = tool_call["function"]
local msg = HistoryMessage:new({
role = "assistant",
content = {
{
type = "tool_use",
name = func.name,
id = id,
input = func.arguments,
},
},
}, {
state = "generated",
uuid = id,
})
table.insert(msgs, msg)
end
if opts.on_messages_add then opts.on_messages_add(msgs) end
end
function M:parse_stream_data(ctx, data, opts)
local ok, json_data = pcall(vim.json.decode, data)
if not ok or not json_data then
local ok, jsn = pcall(vim.json.decode, data)
if not ok or not jsn then
-- Add debug logging
Utils.debug("Failed to parse JSON", data)
return
end
if json_data.message and json_data.message.content then
local content = json_data.message.content
P.openai:add_text_message(ctx, content, "generating", opts)
if content and content ~= "" and opts.on_chunk then opts.on_chunk(content) end
if jsn.message then
if jsn.message.content then
local content = jsn.message.content
if content and content ~= "" then
Providers.openai:add_text_message(ctx, content, "generating", opts)
if opts.on_chunk then opts.on_chunk(content) end
end
end
if jsn.message.tool_calls then
ctx.has_tool_use = true
local tool_calls = jsn.message.tool_calls
self:add_tool_use_messages(tool_calls, opts)
end
end
if json_data.done then
P.openai:finish_pending_messages(ctx, opts)
opts.on_stop({ reason = "complete" })
if jsn.done then
Providers.openai:finish_pending_messages(ctx, opts)
if ctx.has_tool_use or (ctx.tool_use_list and #ctx.tool_use_list > 0) then
opts.on_stop({ reason = "tool_use" })
else
opts.on_stop({ reason = "complete" })
end
return
end
end
---@param prompt_opts AvantePromptOptions
function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(self)
local provider_conf, request_body = Providers.parse_config(self)
local keep_alive = provider_conf.keep_alive or "5m"
if not provider_conf.model or provider_conf.model == "" then error("Ollama model must be specified in config") end
@@ -50,7 +202,7 @@ function M:parse_curl_args(prompt_opts)
["Accept"] = "application/json",
}
if P.env.require_api_key(provider_conf) then
if Providers.env.require_api_key(provider_conf) then
local api_key = self.parse_api_key()
if api_key and api_key ~= "" then
headers["Authorization"] = "Bearer " .. api_key
@@ -66,7 +218,6 @@ function M:parse_curl_args(prompt_opts)
model = provider_conf.model,
messages = self:parse_messages(prompt_opts),
stream = true,
system = prompt_opts.system_prompt,
keep_alive = keep_alive,
}, request_body),
}

View File

@@ -3,6 +3,9 @@ local Config = require("avante.config")
local Clipboard = require("avante.clipboard")
local Providers = require("avante.providers")
local HistoryMessage = require("avante.history_message")
local XMLParser = require("avante.libs.xmlparser")
local Prompts = require("avante.utils.prompts")
local LlmTools = require("avante.llm_tools")
---@class AvanteProviderFunctor
local M = {}
@@ -76,10 +79,15 @@ function M:parse_messages(opts)
local messages = {}
local provider_conf, _ = Providers.parse_config(self)
local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
local system_prompt = opts.system_prompt
if use_ReAct_prompt then system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts) end
if self.is_reasoning_model(provider_conf.model) then
table.insert(messages, { role = "developer", content = opts.system_prompt })
table.insert(messages, { role = "developer", content = system_prompt })
else
table.insert(messages, { role = "system", content = opts.system_prompt })
table.insert(messages, { role = "system", content = system_prompt })
end
local has_tool_use = false
@@ -103,22 +111,50 @@ function M:parse_messages(opts)
url = "data:" .. item.source.media_type .. ";" .. item.source.type .. "," .. item.source.data,
},
})
elseif item.type == "tool_use" then
elseif item.type == "tool_use" and not use_ReAct_prompt then
has_tool_use = true
table.insert(tool_calls, {
id = item.id,
type = "function",
["function"] = { name = item.name, arguments = vim.json.encode(item.input) },
})
elseif item.type == "tool_result" and has_tool_use then
elseif item.type == "tool_result" and has_tool_use and not use_ReAct_prompt then
table.insert(
tool_results,
{ tool_call_id = item.tool_use_id, content = item.is_error and "Error: " .. item.content or item.content }
)
end
end
if not provider_conf.disable_tools and use_ReAct_prompt then
if msg.content[1].type == "tool_result" then
local tool_use = nil
for _, msg_ in ipairs(opts.messages) do
if type(msg_.content) == "table" and #msg_.content > 0 then
if msg_.content[1].type == "tool_use" and msg_.content[1].id == msg.content[1].tool_use_id then
tool_use = msg_
break
end
end
end
if tool_use then
msg.role = "user"
table.insert(content, {
type = "text",
text = "["
.. tool_use.content[1].name
.. " for '"
.. (tool_use.content[1].input.path or tool_use.content[1].input.rel_path or "")
.. "'] Result:",
})
table.insert(content, {
type = "text",
text = msg.content[1].content,
})
end
end
end
if #content > 0 then table.insert(messages, { role = self.role_map[msg.role], content = content }) end
if not provider_conf.disable_tools then
if not provider_conf.disable_tools and not use_ReAct_prompt then
if #tool_calls > 0 then
local last_message = messages[#messages]
if last_message and last_message.role == self.role_map["assistant"] and last_message.tool_calls then
@@ -183,7 +219,10 @@ function M:finish_pending_messages(ctx, opts)
end
end
local llm_tool_names = nil
function M:add_text_message(ctx, text, state, opts)
if llm_tool_names == nil then llm_tool_names = LlmTools.get_tool_names() end
if ctx.content == nil then ctx.content = "" end
ctx.content = ctx.content .. text
local msg = HistoryMessage:new({
@@ -194,7 +233,75 @@ function M:add_text_message(ctx, text, state, opts)
uuid = ctx.content_uuid,
})
ctx.content_uuid = msg.uuid
if opts.on_messages_add then opts.on_messages_add({ msg }) end
local msgs = { msg }
local stream_parser = XMLParser.createStreamParser()
stream_parser:addData(ctx.content)
local xml = stream_parser:getAllElements()
if xml then
local new_content_list = {}
local xml_md_openned = false
for idx, item in ipairs(xml) do
if item._name == "_text" then
local cleaned_lines = {}
local lines = vim.split(item._text, "\n")
for _, line in ipairs(lines) do
if line:match("^```xml") or line:match("^```tool_code") or line:match("^```tool_use") then
xml_md_openned = true
elseif line:match("^```$") then
if xml_md_openned then
xml_md_openned = false
else
table.insert(cleaned_lines, line)
end
else
table.insert(cleaned_lines, line)
end
end
table.insert(new_content_list, table.concat(cleaned_lines, "\n"))
goto continue
end
if not vim.tbl_contains(llm_tool_names, item._name) then goto continue end
local ok, input = pcall(vim.json.decode, item._text)
if not ok and item.children and #item.children > 0 then
input = {}
for _, item_ in ipairs(item.children) do
local ok_, input_ = pcall(vim.json.decode, item_._text)
if ok_ and input_ then
input[item_._name] = input_
else
input[item_._name] = item_._text
end
end
end
if input then
local tool_use_id = Utils.uuid()
local msg_ = HistoryMessage:new({
role = "assistant",
content = {
{
type = "tool_use",
name = item._name,
id = tool_use_id,
input = input,
},
},
}, {
state = state,
uuid = ctx.content_uuid .. "-" .. idx,
})
msgs[#msgs + 1] = msg_
ctx.tool_use_list = ctx.tool_use_list or {}
ctx.tool_use_list[#ctx.tool_use_list + 1] = {
id = tool_use_id,
name = item._name,
input_json = input,
}
end
if #new_content_list > 0 then msg.message.content = table.concat(new_content_list, "\n") end
::continue::
end
end
if opts.on_messages_add then opts.on_messages_add(msgs) end
end
function M:add_thinking_message(ctx, text, state, opts)
@@ -242,7 +349,11 @@ end
function M:parse_response(ctx, data_stream, _, opts)
if data_stream:match('"%[DONE%]":') then
self:finish_pending_messages(ctx, opts)
opts.on_stop({ reason = "complete" })
if ctx.tool_use_list and #ctx.tool_use_list > 0 then
opts.on_stop({ reason = "tool_use" })
else
opts.on_stop({ reason = "complete" })
end
return
end
if data_stream == "[DONE]" then return end
@@ -316,9 +427,13 @@ function M:parse_response(ctx, data_stream, _, opts)
self:add_text_message(ctx, delta.content, "generating", opts)
end
end
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" or choice.finish_reason == "length" then
self:finish_pending_messages(ctx, opts)
opts.on_stop({ reason = "complete" })
if ctx.tool_use_list and #ctx.tool_use_list > 0 then
opts.on_stop({ reason = "tool_use", usage = jsn.usage })
else
opts.on_stop({ reason = "complete", usage = jsn.usage })
end
end
if choice.finish_reason == "tool_calls" then
self:finish_pending_messages(ctx, opts)
@@ -372,8 +487,10 @@ function M:parse_curl_args(prompt_opts)
self.set_allowed_params(provider_conf, request_body)
local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
local tools = nil
if not disable_tools and prompt_opts.tools then
if not disable_tools and prompt_opts.tools and not use_ReAct_prompt then
tools = {}
for _, tool in ipairs(prompt_opts.tools) do
table.insert(tools, self:transform_tool(tool))

View File

@@ -56,6 +56,7 @@ Sidebar.__index = Sidebar
---@field current_state avante.GenerateState | nil
---@field state_timer table | nil
---@field state_spinner_chars string[]
---@field thinking_spinner_chars string[]
---@field state_spinner_idx integer
---@field state_extmark_id integer | nil
---@field scroll boolean
@@ -84,12 +85,16 @@ function Sidebar:new(id)
current_state = nil,
state_timer = nil,
state_spinner_chars = { "·", "", "", "", "", "" },
thinking_spinner_chars = { "🤯", "🙄" },
state_spinner_idx = 1,
state_extmark_id = nil,
scroll = true,
input_hint_window = nil,
ask_opts = {},
old_result_lines = {},
-- 缓存相关字段
_cached_history_lines = nil,
_history_cache_invalidated = true,
}, Sidebar)
end
@@ -1528,42 +1533,75 @@ end
---@param opts? {focus?: boolean, scroll?: boolean, backspace?: integer, callback?: fun(): nil} whether to focus the result view
function Sidebar:update_content(content, opts)
if not self.result_container or not self.result_container.bufnr then return end
-- 提前验证容器有效性,避免后续无效操作
if not Utils.is_valid_container(self.result_container) then return end
opts = vim.tbl_deep_extend("force", { focus = false, scroll = self.scroll, callback = nil }, opts or {})
local history_lines = self.get_history_lines(self.chat_history)
if content ~= nil and content ~= "" then
table.insert(history_lines, Line:new({ { "" } }))
local content_lines = vim.split(content, "\n")
for _, line in ipairs(content_lines) do
table.insert(history_lines, Line:new({ { line } }))
end
-- 缓存历史行,避免重复计算
local history_lines
if not self._cached_history_lines or self._history_cache_invalidated then
history_lines = self.get_history_lines(self.chat_history)
self._cached_history_lines = history_lines
self._history_cache_invalidated = false
else
history_lines = vim.deepcopy(self._cached_history_lines)
end
vim.defer_fn(function()
self:clear_state()
local f = function()
if not Utils.is_valid_container(self.result_container) then return end
Utils.unlock_buf(self.result_container.bufnr)
Utils.update_buffer_lines(
RESULT_BUF_HL_NAMESPACE,
self.result_container.bufnr,
self.old_result_lines,
history_lines
)
Utils.lock_buf(self.result_container.bufnr)
self.old_result_lines = history_lines
api.nvim_set_option_value("filetype", "Avante", { buf = self.result_container.bufnr })
vim.schedule(function() vim.cmd("redraw") end)
if opts.focus and not self:is_focused_on_result() then
--- set cursor to bottom of result view
xpcall(function() api.nvim_set_current_win(self.result_container.winid) end, function(err) return err end)
end
if opts.scroll then Utils.buf_scroll_to_end(self.result_container.bufnr) end
-- 批量处理内容行,减少表操作
if content ~= nil and content ~= "" then
local content_lines = vim.split(content, "\n")
local new_lines = { Line:new({ { "" } }) }
if opts.callback ~= nil then opts.callback() end
-- 预分配表大小,提升性能
for i = 1, #content_lines do
new_lines[i + 1] = Line:new({ { content_lines[i] } })
end
f()
-- 一次性扩展,而不是逐个插入
vim.list_extend(history_lines, new_lines)
end
-- 使用 vim.schedule 而不是 vim.defer_fn(0),性能更好
-- 再次检查容器有效性
if not Utils.is_valid_container(self.result_container) then return end
self:clear_state()
-- 批量更新操作
local bufnr = self.result_container.bufnr
Utils.unlock_buf(bufnr)
Utils.update_buffer_lines(RESULT_BUF_HL_NAMESPACE, bufnr, self.old_result_lines, history_lines)
-- 缓存结果行
self.old_result_lines = history_lines
-- 批量设置选项
api.nvim_set_option_value("filetype", "Avante", { buf = bufnr })
Utils.lock_buf(bufnr)
-- 处理焦点和滚动
if opts.focus and not self:is_focused_on_result() then
xpcall(function() api.nvim_set_current_win(self.result_container.winid) end, function(err)
Utils.debug("Failed to set current win:", err)
return err
end)
end
if opts.scroll then Utils.buf_scroll_to_end(bufnr) end
-- 延迟执行回调和状态渲染
if opts.callback then vim.schedule(opts.callback) end
-- 最后渲染状态
vim.schedule(function()
self:render_state()
end, 0)
-- 延迟重绘,避免阻塞
vim.defer_fn(function() vim.cmd("redraw") end, 10)
end)
return self
end
@@ -1842,8 +1880,8 @@ function Sidebar:render_state()
if self.state_extmark_id then
api.nvim_buf_del_extmark(self.result_container.bufnr, STATE_NAMESPACE, self.state_extmark_id)
end
local spinner_char = self.state_spinner_chars[self.state_spinner_idx]
self.state_spinner_idx = (self.state_spinner_idx % #self.state_spinner_chars) + 1
local spinner_chars = self.state_spinner_chars
if self.current_state == "thinking" then spinner_chars = self.thinking_spinner_chars end
local hl = "AvanteStateSpinnerGenerating"
if self.current_state == "tool calling" then hl = "AvanteStateSpinnerToolCalling" end
if self.current_state == "failed" then hl = "AvanteStateSpinnerFailed" end
@@ -1851,6 +1889,8 @@ function Sidebar:render_state()
if self.current_state == "searching" then hl = "AvanteStateSpinnerSearching" end
if self.current_state == "thinking" then hl = "AvanteStateSpinnerThinking" end
if self.current_state == "compacting" then hl = "AvanteStateSpinnerCompacting" end
local spinner_char = spinner_chars[self.state_spinner_idx]
self.state_spinner_idx = (self.state_spinner_idx % #spinner_chars) + 1
if
self.current_state ~= "generating"
and self.current_state ~= "tool calling"
@@ -1911,6 +1951,10 @@ function Sidebar:new_chat(args, cb)
if cb then cb(args) end
end
local _save_history = Utils.debounce(function(self) Path.history.save(self.code.bufnr, self.chat_history) end, 3000)
local save_history = vim.schedule_wrap(_save_history)
---@param messages avante.HistoryMessage | avante.HistoryMessage[]
function Sidebar:add_history_messages(messages)
local history_messages = Utils.get_history_messages(self.chat_history)
@@ -1934,22 +1978,30 @@ function Sidebar:add_history_messages(messages)
end
end
self.chat_history.messages = history_messages
Path.history.save(self.code.bufnr, self.chat_history)
-- 历史消息变更时,标记缓存失效
self._history_cache_invalidated = true
save_history(self)
if
self.chat_history.title == "untitled"
and #messages > 0
and messages[1].just_for_display ~= true
and messages[1].state == "generated"
then
self.chat_history.title = "generating..."
Llm.summarize_chat_thread_title(messages[1].message.content, function(title)
if title then
self.chat_history.title = title
else
self.chat_history.title = "untitled"
end
Path.history.save(self.code.bufnr, self.chat_history)
end)
-- self.chat_history.title = "generating..."
-- Llm.summarize_chat_thread_title(messages[1].message.content, function(title)
-- if title then
-- self.chat_history.title = title
-- else
-- self.chat_history.title = "untitled"
-- end
-- save_history(self)
-- end)
local first_msg_text = Utils.message_to_text(messages[1], messages)
local lines_ = vim.split(first_msg_text, "\n")
if #lines_ > 0 then
self.chat_history.title = lines_[1]
save_history(self)
end
end
local last_message = messages[#messages]
if last_message then
@@ -1964,7 +2016,10 @@ function Sidebar:add_history_messages(messages)
self.current_state = "generating"
end
end
self:update_content("")
xpcall(function() self:update_content("") end, function(err)
Utils.debug("Failed to update content:", err)
return nil
end)
end
---@param messages AvanteLLMMessage | AvanteLLMMessage[]
@@ -2135,6 +2190,8 @@ end
function Sidebar:reload_chat_history()
if not self.code.bufnr or not api.nvim_buf_is_valid(self.code.bufnr) then return end
self.chat_history = Path.history.load(self.code.bufnr)
-- 重新加载历史时,标记缓存失效
self._history_cache_invalidated = true
end
---@return avante.HistoryMessage[]

View File

@@ -5,9 +5,11 @@
RULES
- NEVER reply the updated code.
- Always reply to the user in the same language they are using.
- Don't just provide code suggestions, use the `replace_in_file` tool to help users fulfill their needs.
- Don't just provide code suggestions, use the `replace_in_file` tool or `str_replace` tool to help users fulfill their needs.
- After the tool call is complete, please do not output the entire file content.

View File

@@ -231,6 +231,7 @@ vim.g.avante_login = vim.g.avante_login
---@field disable_tools? boolean
---@field entra? boolean
---@field hide_in_model_selector? boolean
---@field use_ReAct_prompt? boolean
---
---@class AvanteSupportedProvider: AvanteDefaultBaseProvider
---@field __inherited_from? string
@@ -396,6 +397,7 @@ vim.g.avante_login = vim.g.avante_login
---@class AvanteLLMToolParam
---@field type 'table'
---@field fields AvanteLLMToolParamField[]
---@field usage? table
---@class AvanteLLMToolParamField
---@field name string

View File

@@ -576,7 +576,7 @@ end
--- remove indentation from code: spaces or tabs
function M.remove_indentation(code)
if not code then return code end
return code:gsub("^%s*", ""):gsub("%s*$", "")
return code:gsub("%s*", "")
end
function M.relative_path(absolute)
@@ -1056,12 +1056,12 @@ function M.update_buffer_lines(ns_id, bufnr, old_lines, new_lines)
if #diffs == 0 then return end
for _, diff in ipairs(diffs) do
local lines = diff.content
-- M.debug("lines", lines)
local text_lines = vim.tbl_map(function(line) return tostring(line) end, lines)
vim.api.nvim_buf_set_lines(bufnr, diff.start_line - 1, diff.end_line - 1, false, text_lines)
for i, line in ipairs(lines) do
line:set_highlights(ns_id, bufnr, diff.start_line + i - 2)
end
vim.cmd("redraw")
end
end
@@ -1467,6 +1467,23 @@ function M.text_to_lines(text, hl)
return lines
end
---@param thinking_text string
---@param hl string | nil
---@return avante.ui.Line[]
function M.thinking_to_lines(thinking_text, hl)
local Line = require("avante.ui.line")
local text_lines = vim.split(thinking_text, "\n")
local lines = {}
table.insert(lines, Line:new({ { M.icon("🤔 ") .. "Thought content:" } }))
table.insert(lines, Line:new({ { "" } }))
for _, text_line in ipairs(text_lines) do
local piece = { "> " .. text_line }
if hl then table.insert(piece, hl) end
table.insert(lines, Line:new({ piece }))
end
return lines
end
---@param item AvanteLLMMessageContentItem
---@param message avante.HistoryMessage
---@param messages avante.HistoryMessage[]
@@ -1475,6 +1492,9 @@ function M.message_content_item_to_lines(item, message, messages)
local Line = require("avante.ui.line")
if type(item) == "string" then return M.text_to_lines(item) end
if type(item) == "table" then
if item.type == "thinking" or item.type == "redacted_thinking" then
return M.thinking_to_lines(item.thinking or item.data or "")
end
if item.type == "text" then return M.text_to_lines(item.text) end
if item.type == "image" then
return { Line:new({ { "![image](" .. item.source.media_type .. ": " .. item.source.data .. ")" } }) }
@@ -1520,18 +1540,6 @@ function M.message_content_item_to_lines(item, message, messages)
end
end
end
elseif tool_result_message then
local tool_result = tool_result_message.message.content[1]
if tool_result.content then
local result_lines = vim.split(tool_result.content, "\n")
for idx, line in ipairs(result_lines) do
if idx ~= #result_lines then
table.insert(lines, Line:new({ { "" }, { string.format(" %s", line) } }))
else
table.insert(lines, Line:new({ { "╰─" }, { string.format(" %s", line) } }))
end
end
end
end
return lines
end

View File

@@ -0,0 +1,144 @@
local M = {}
---@param provider_conf AvanteDefaultBaseProvider
---@param opts AvantePromptOptions
---@return string
function M.get_ReAct_system_prompt(provider_conf, opts)
local system_prompt = opts.system_prompt
local disable_tools = provider_conf.disable_tools or false
if not disable_tools and opts.tools then
local tools_prompts = [[
====
TOOL USE
You have access to a set of tools that are executed upon the user's approval. You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
# Tool Use Formatting
Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure:
<tool_name>
<parameter1_name>value1</parameter1_name>
<parameter2_name>value2</parameter2_name>
...
</tool_name>
For example:
<view>
<path>src/main.js</path>
</view>
Always adhere to this format for the tool use to ensure proper parsing and execution.
# Tools
]]
for _, tool in ipairs(opts.tools) do
local tool_prompt = ([[
## {{name}}
Description: {{description}}
Parameters:
]]):gsub("{{name}}", tool.name):gsub(
"{{description}}",
tool.get_description and tool.get_description() or (tool.description or "")
)
for _, field in ipairs(tool.param.fields) do
if field.optional then
tool_prompt = tool_prompt .. string.format(" - %s: %s\n", field.name, field.description)
else
tool_prompt = tool_prompt
.. string.format(
" - %s: (required) %s\n",
field.name,
field.get_description and field.get_description() or (field.description or "")
)
end
end
if tool.param.usage then
tool_prompt = tool_prompt
.. ("Usage:\n<{{name}}>\n"):gsub("{{([%w_]+)}}", function(name) return tool[name] end)
for k, v in pairs(tool.param.usage) do
tool_prompt = tool_prompt .. "<" .. k .. ">" .. tostring(v) .. "</" .. k .. ">\n"
end
tool_prompt = tool_prompt .. ("</{{name}}>\n"):gsub("{{([%w_]+)}}", function(name) return tool[name] end)
end
tools_prompts = tools_prompts .. tool_prompt .. "\n"
end
system_prompt = system_prompt .. tools_prompts
system_prompt = system_prompt
.. [[
# Tool Use Examples
## Example 1: Requesting to execute a command
<bash>
<path>./src</path>
<command>npm run dev</command>
</bash>
## Example 2: Requesting to create a new file
<write_to_file>
<path>src/frontend-config.json</path>
<content>
{
"apiEndpoint": "https://api.example.com",
"theme": {
"primaryColor": "#007bff",
"secondaryColor": "#6c757d",
"fontFamily": "Arial, sans-serif"
},
"features": {
"darkMode": true,
"notifications": true,
"analytics": false
},
"version": "1.0.0"
}
</content>
</write_to_file>
## Example 3: Requesting to make targeted edits to a file
<replace_in_file>
<path>src/components/App.tsx</path>
<diff>
<<<<<<< SEARCH
import React from 'react';
=======
import React, { useState } from 'react';
>>>>>>> REPLACE
<<<<<<< SEARCH
function handleSubmit() {
saveData();
setLoading(false);
}
=======
>>>>>>> REPLACE
<<<<<<< SEARCH
return (
<div>
=======
function handleSubmit() {
saveData();
setLoading(false);
}
return (
<div>
>>>>>>> REPLACE
</diff>
</replace_in_file>
]]
end
return system_prompt
end
return M

View File

@@ -53,21 +53,21 @@ describe("llm_tools", function()
describe("ls", function()
it("should list files in directory", function()
local result, err = ls({ rel_path = ".", max_depth = 1 })
local result, err = ls({ path = ".", max_depth = 1 })
assert.is_nil(err)
assert.falsy(result:find("avante.nvim"))
assert.truthy(result:find("test.txt"))
assert.falsy(result:find("test1.txt"))
end)
it("should list files in directory with depth", function()
local result, err = ls({ rel_path = ".", max_depth = 2 })
local result, err = ls({ path = ".", max_depth = 2 })
assert.is_nil(err)
assert.falsy(result:find("avante.nvim"))
assert.truthy(result:find("test.txt"))
assert.truthy(result:find("test1.txt"))
end)
it("should list files respecting gitignore", function()
local result, err = ls({ rel_path = ".", max_depth = 2 })
local result, err = ls({ path = ".", max_depth = 2 })
assert.is_nil(err)
assert.falsy(result:find("avante.nvim"))
assert.truthy(result:find("test.txt"))
@@ -102,7 +102,7 @@ describe("llm_tools", function()
describe("create_dir", function()
it("should create new directory", function()
LlmTools.create_dir({ rel_path = "new_dir" }, nil, function(success, err)
LlmTools.create_dir({ path = "new_dir" }, nil, function(success, err)
assert.is_nil(err)
assert.is_true(success)
@@ -114,7 +114,7 @@ describe("llm_tools", function()
describe("delete_file", function()
it("should delete existing file", function()
LlmTools.delete_file({ rel_path = "test.txt" }, nil, function(success, err)
LlmTools.delete_file({ path = "test.txt" }, nil, function(success, err)
assert.is_nil(err)
assert.is_true(success)
@@ -147,28 +147,28 @@ describe("llm_tools", function()
file:write("this is nothing")
file:close()
local result, err = grep({ rel_path = ".", query = "Searchable", case_sensitive = false })
local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false })
assert.is_nil(err)
assert.truthy(result:find("searchable.txt"))
assert.falsy(result:find("nothing.txt"))
local result2, err2 = grep({ rel_path = ".", query = "searchable", case_sensitive = true })
local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true })
assert.is_nil(err2)
assert.truthy(result2:find("searchable.txt"))
assert.falsy(result2:find("nothing.txt"))
local result3, err3 = grep({ rel_path = ".", query = "Searchable", case_sensitive = true })
local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true })
assert.is_nil(err3)
assert.falsy(result3:find("searchable.txt"))
assert.falsy(result3:find("nothing.txt"))
local result4, err4 = grep({ rel_path = ".", query = "searchable", case_sensitive = false })
local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false })
assert.is_nil(err4)
assert.truthy(result4:find("searchable.txt"))
assert.falsy(result4:find("nothing.txt"))
local result5, err5 = grep({
rel_path = ".",
path = ".",
query = "searchable",
case_sensitive = false,
exclude_pattern = "search*",
@@ -191,7 +191,7 @@ describe("llm_tools", function()
file:write("content for ag test")
file:close()
local result, err = grep({ rel_path = ".", query = "ag test" })
local result, err = grep({ path = ".", query = "ag test" })
assert.is_nil(err)
assert.is_string(result)
assert.truthy(result:find("ag_test.txt"))
@@ -215,28 +215,28 @@ describe("llm_tools", function()
file:write("this is nothing")
file:close()
local result, err = grep({ rel_path = ".", query = "Searchable", case_sensitive = false })
local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false })
assert.is_nil(err)
assert.truthy(result:find("searchable.txt"))
assert.falsy(result:find("nothing.txt"))
local result2, err2 = grep({ rel_path = ".", query = "searchable", case_sensitive = true })
local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true })
assert.is_nil(err2)
assert.truthy(result2:find("searchable.txt"))
assert.falsy(result2:find("nothing.txt"))
local result3, err3 = grep({ rel_path = ".", query = "Searchable", case_sensitive = true })
local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true })
assert.is_nil(err3)
assert.falsy(result3:find("searchable.txt"))
assert.falsy(result3:find("nothing.txt"))
local result4, err4 = grep({ rel_path = ".", query = "searchable", case_sensitive = false })
local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false })
assert.is_nil(err4)
assert.truthy(result4:find("searchable.txt"))
assert.falsy(result4:find("nothing.txt"))
local result5, err5 = grep({
rel_path = ".",
path = ".",
query = "searchable",
case_sensitive = false,
exclude_pattern = "search*",
@@ -250,18 +250,18 @@ describe("llm_tools", function()
-- Mock exepath to return nothing
vim.fn.exepath = function() return "" end
local result, err = grep({ rel_path = ".", query = "test" })
local result, err = grep({ path = ".", query = "test" })
assert.equals("", result)
assert.equals("No search command found", err)
end)
it("should respect path permissions", function()
local result, err = grep({ rel_path = "../outside_project", query = "test" })
local result, err = grep({ path = "../outside_project", query = "test" })
assert.truthy(err:find("No permission to access path"))
end)
it("should handle non-existent paths", function()
local result, err = grep({ rel_path = "non_existent_dir", query = "test" })
local result, err = grep({ path = "non_existent_dir", query = "test" })
assert.equals("", result)
assert.truthy(err)
assert.truthy(err:find("No such file or directory"))
@@ -270,14 +270,14 @@ describe("llm_tools", function()
describe("bash", function()
-- it("should execute command and return output", function()
-- bash({ rel_path = ".", command = "echo 'test'" }, nil, function(result, err)
-- bash({ path = ".", command = "echo 'test'" }, nil, function(result, err)
-- assert.is_nil(err)
-- assert.equals("test\n", result)
-- end)
-- end)
it("should return error when running outside current directory", function()
bash({ rel_path = "../outside_project", command = "echo 'test'" }, nil, function(result, err)
bash({ path = "../outside_project", command = "echo 'test'" }, nil, function(result, err)
assert.is_false(result)
assert.truthy(err)
assert.truthy(err:find("No permission to access path"))
@@ -289,7 +289,7 @@ describe("llm_tools", function()
it("should execute Python code and return output", function()
LlmTools.python(
{
rel_path = ".",
path = ".",
code = "print('Hello from Python')",
},
nil,
@@ -303,7 +303,7 @@ describe("llm_tools", function()
it("should handle Python errors", function()
LlmTools.python(
{
rel_path = ".",
path = ".",
code = "print(undefined_variable)",
},
nil,
@@ -318,7 +318,7 @@ describe("llm_tools", function()
it("should respect path permissions", function()
LlmTools.python(
{
rel_path = "../outside_project",
path = "../outside_project",
code = "print('test')",
},
nil,
@@ -332,7 +332,7 @@ describe("llm_tools", function()
it("should handle non-existent paths", function()
LlmTools.python(
{
rel_path = "non_existent_dir",
path = "non_existent_dir",
code = "print('test')",
},
nil,
@@ -347,7 +347,7 @@ describe("llm_tools", function()
os.execute("docker image rm python:3.12-slim")
LlmTools.python(
{
rel_path = ".",
path = ".",
code = "print('Hello from custom container')",
container_image = "python:3.12-slim",
},
@@ -370,7 +370,7 @@ describe("llm_tools", function()
os.execute("touch " .. test_dir .. "/nested/file4.lua")
-- Test for lua files in the root
local result, err = glob({ rel_path = ".", pattern = "*.lua" })
local result, err = glob({ path = ".", pattern = "*.lua" })
assert.is_nil(err)
local files = vim.json.decode(result).matches
assert.equals(2, #files)
@@ -380,7 +380,7 @@ describe("llm_tools", function()
assert.falsy(vim.tbl_contains(files, test_dir .. "/nested/file4.lua"))
-- Test with recursive pattern
local result2, err2 = glob({ rel_path = ".", pattern = "**/*.lua" })
local result2, err2 = glob({ path = ".", pattern = "**/*.lua" })
assert.is_nil(err2)
local files2 = vim.json.decode(result2).matches
assert.equals(3, #files2)
@@ -390,13 +390,13 @@ describe("llm_tools", function()
end)
it("should respect path permissions", function()
local result, err = glob({ rel_path = "../outside_project", pattern = "*.txt" })
local result, err = glob({ path = "../outside_project", pattern = "*.txt" })
assert.equals("", result)
assert.truthy(err:find("No permission to access path"))
end)
it("should handle patterns without matches", function()
local result, err = glob({ rel_path = ".", pattern = "*.nonexistent" })
local result, err = glob({ path = ".", pattern = "*.nonexistent" })
assert.is_nil(err)
local files = vim.json.decode(result).matches
assert.equals(0, #files)
@@ -411,7 +411,7 @@ describe("llm_tools", function()
os.execute("touch " .. test_dir .. "/test_dir1/notignored1.lua")
os.execute("touch " .. test_dir .. "/test_dir1/notignored2.lua")
local result, err = glob({ rel_path = ".", pattern = "**/*.lua" })
local result, err = glob({ path = ".", pattern = "**/*.lua" })
assert.is_nil(err)
local files = vim.json.decode(result).matches