feat: ReAct tool calling (#2104)
This commit is contained in:
@@ -671,6 +671,7 @@ M.BASE_PROVIDER_KEYS = {
|
||||
"disable_tools",
|
||||
"entra",
|
||||
"hide_in_model_selector",
|
||||
"use_ReAct_prompt",
|
||||
}
|
||||
|
||||
return M
|
||||
|
||||
517
lua/avante/libs/xmlparser.lua
Normal file
517
lua/avante/libs/xmlparser.lua
Normal file
@@ -0,0 +1,517 @@
|
||||
-- XML Parser for Lua
|
||||
local XmlParser = {}
|
||||
|
||||
-- 流式解析器状态
|
||||
local StreamParser = {}
|
||||
StreamParser.__index = StreamParser
|
||||
|
||||
-- 创建新的流式解析器实例
|
||||
function StreamParser.new()
|
||||
local parser = {
|
||||
buffer = "", -- 缓冲区存储未处理的内容
|
||||
stack = {}, -- 标签栈
|
||||
results = {}, -- 已完成的元素列表
|
||||
current = nil, -- 当前正在处理的元素
|
||||
root = nil, -- 当前根元素
|
||||
position = 1, -- 当前解析位置
|
||||
state = "ready", -- 解析状态: ready, parsing, incomplete, error
|
||||
incomplete_tag = nil, -- 未完成的标签信息
|
||||
last_error = nil, -- 最后的错误信息
|
||||
}
|
||||
setmetatable(parser, StreamParser)
|
||||
return parser
|
||||
end
|
||||
|
||||
-- 重置解析器状态
|
||||
function StreamParser:reset()
|
||||
self.buffer = ""
|
||||
self.stack = {}
|
||||
self.results = {}
|
||||
self.current = nil
|
||||
self.root = nil
|
||||
self.position = 1
|
||||
self.state = "ready"
|
||||
self.incomplete_tag = nil
|
||||
self.last_error = nil
|
||||
end
|
||||
|
||||
-- 获取解析器状态信息
|
||||
function StreamParser:getStatus()
|
||||
return {
|
||||
state = self.state,
|
||||
completed_elements = #self.results,
|
||||
stack_depth = #self.stack,
|
||||
buffer_size = #self.buffer,
|
||||
incomplete_tag = self.incomplete_tag,
|
||||
last_error = self.last_error,
|
||||
has_incomplete = self.state == "incomplete" or self.incomplete_tag ~= nil,
|
||||
}
|
||||
end
|
||||
|
||||
-- 辅助函数:去除字符串首尾空白
|
||||
local function trim(s) return s:match("^%s*(.-)%s*$") end
|
||||
|
||||
-- 辅助函数:解析属性
|
||||
local function parseAttributes(attrStr)
|
||||
local attrs = {}
|
||||
if not attrStr or attrStr == "" then return attrs end
|
||||
|
||||
-- 匹配属性模式:name="value" 或 name='value'
|
||||
for name, value in attrStr:gmatch("([-_%w]+)%s*=%s*[\"']([^\"']*)[\"']") do
|
||||
attrs[name] = value
|
||||
end
|
||||
return attrs
|
||||
end
|
||||
|
||||
-- 辅助函数:HTML实体解码
|
||||
local function decodeEntities(str)
|
||||
local entities = {
|
||||
["<"] = "<",
|
||||
[">"] = ">",
|
||||
["&"] = "&",
|
||||
["""] = '"',
|
||||
["'"] = "'",
|
||||
}
|
||||
|
||||
for entity, char in pairs(entities) do
|
||||
str = str:gsub(entity, char)
|
||||
end
|
||||
|
||||
-- 处理数字实体 { 和 
|
||||
str = str:gsub("&#(%d+);", function(n)
|
||||
local num = tonumber(n)
|
||||
return num and string.char(num) or ""
|
||||
end)
|
||||
str = str:gsub("&#x(%x+);", function(n)
|
||||
local num = tonumber(n, 16)
|
||||
return num and string.char(num) or ""
|
||||
end)
|
||||
|
||||
return str
|
||||
end
|
||||
|
||||
-- 检查是否为有效的XML标签
|
||||
local function isValidXmlTag(tag, xmlContent, tagStart)
|
||||
-- 排除明显不是XML标签的内容,比如数学表达式 < 或 >
|
||||
-- 检查标签是否包含合理的XML标签格式
|
||||
if not tag:match("^<[^<>]*>$") then return false end
|
||||
|
||||
-- 检查是否是合法的标签格式
|
||||
if tag:match("^</[-_%w]+>$") then return true end -- 结束标签
|
||||
if tag:match("^<[-_%w]+[^>]*/>$") then return true end -- 自闭合标签
|
||||
if tag:match("^<[-_%w]+[^>]*>$") then
|
||||
-- 对于开始标签,进行额外的上下文检查
|
||||
local tagName = tag:match("^<([-_%w]+)")
|
||||
|
||||
-- 检查是否存在对应的结束标签
|
||||
local closingTag = "</" .. tagName .. ">"
|
||||
local hasClosingTag = xmlContent:find(closingTag, tagStart)
|
||||
|
||||
-- 如果是单个标签且没有结束标签,可能是文本中的引用
|
||||
if not hasClosingTag then
|
||||
-- 检查前后文本,如果像是在描述而不是实际的XML结构,则不认为是有效标签
|
||||
local beforeText = xmlContent:sub(math.max(1, tagStart - 50), tagStart - 1)
|
||||
local afterText = xmlContent:sub(tagStart + #tag, math.min(#xmlContent, tagStart + #tag + 50))
|
||||
|
||||
-- 如果前面有"provided in the"、"in the"等描述性文字,可能是文本引用
|
||||
if
|
||||
beforeText:match("provided in the%s*$")
|
||||
or beforeText:match("in the%s*$")
|
||||
or beforeText:match("see the%s*$")
|
||||
or beforeText:match("use the%s*$")
|
||||
then
|
||||
return false
|
||||
end
|
||||
|
||||
-- 如果后面紧跟着"tag"等描述性词汇,可能是文本引用
|
||||
if afterText:match("^%s*tag") then return false end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
-- 流式解析器方法:添加数据到缓冲区并解析
|
||||
function StreamParser:addData(data)
|
||||
if not data or data == "" then return end
|
||||
|
||||
self.buffer = self.buffer .. data
|
||||
self:parseBuffer()
|
||||
end
|
||||
|
||||
-- 获取当前解析深度
|
||||
function StreamParser:getCurrentDepth() return #self.stack end
|
||||
|
||||
-- 解析缓冲区中的数据
|
||||
function StreamParser:parseBuffer()
|
||||
self.state = "parsing"
|
||||
|
||||
while self.position <= #self.buffer do
|
||||
local remaining = self.buffer:sub(self.position)
|
||||
|
||||
-- 查找下一个标签
|
||||
local tagStart, tagEnd = remaining:find("<[^>]*>")
|
||||
|
||||
if not tagStart then
|
||||
-- 检查是否有未完成的开始标签(以<开始但没有>结束)
|
||||
local incompleteStart = remaining:find("<[^>]*$")
|
||||
if incompleteStart then
|
||||
local incompleteContent = remaining:sub(incompleteStart)
|
||||
-- 确保这确实是一个未完成的标签,而不是文本中的<符号
|
||||
if incompleteContent:match("^<[%w_-]") then
|
||||
-- 尝试解析未完成的开始标签
|
||||
local tagName = incompleteContent:match("^<([%w_-]+)")
|
||||
if tagName then
|
||||
-- 处理未完成标签前的文本
|
||||
if incompleteStart > 1 then
|
||||
local precedingText = trim(remaining:sub(1, incompleteStart - 1))
|
||||
if precedingText ~= "" then
|
||||
if self.current then
|
||||
-- 如果当前在某个标签内,添加到该标签的文本内容
|
||||
precedingText = decodeEntities(precedingText)
|
||||
if self.current._text then
|
||||
self.current._text = self.current._text .. precedingText
|
||||
else
|
||||
self.current._text = precedingText
|
||||
end
|
||||
else
|
||||
-- 如果是顶层文本,作为独立元素添加
|
||||
local textElement = {
|
||||
_name = "_text",
|
||||
_text = decodeEntities(precedingText),
|
||||
}
|
||||
table.insert(self.results, textElement)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- 创建未完成的元素
|
||||
local element = {
|
||||
_name = tagName,
|
||||
_attr = {},
|
||||
_state = "incomplete_start_tag",
|
||||
}
|
||||
|
||||
if not self.root then
|
||||
self.root = element
|
||||
self.current = element
|
||||
elseif self.current then
|
||||
table.insert(self.stack, self.current)
|
||||
if not self.current[tagName] then self.current[tagName] = {} end
|
||||
table.insert(self.current[tagName], element)
|
||||
self.current = element
|
||||
end
|
||||
|
||||
self.incomplete_tag = {
|
||||
start_pos = self.position + incompleteStart - 1,
|
||||
content = incompleteContent,
|
||||
element = element,
|
||||
}
|
||||
self.state = "incomplete"
|
||||
return
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- 处理剩余的文本内容
|
||||
if remaining ~= "" then
|
||||
if self.current then
|
||||
-- 检查当前深度,如果在第一层子元素中,保持原始文本
|
||||
local currentDepth = #self.stack
|
||||
if currentDepth >= 1 then
|
||||
-- 在第一层子元素中,保持原始文本不变
|
||||
if self.current._text then
|
||||
self.current._text = self.current._text .. remaining
|
||||
else
|
||||
self.current._text = remaining
|
||||
end
|
||||
else
|
||||
-- 在根级别,进行正常的文本处理
|
||||
local text = trim(remaining)
|
||||
if text ~= "" then
|
||||
text = decodeEntities(text)
|
||||
if self.current._text then
|
||||
self.current._text = self.current._text .. text
|
||||
else
|
||||
self.current._text = text
|
||||
end
|
||||
end
|
||||
end
|
||||
else
|
||||
-- 如果是顶层文本,作为独立元素添加
|
||||
local text = trim(remaining)
|
||||
if text ~= "" then
|
||||
local textElement = {
|
||||
_name = "_text",
|
||||
_text = decodeEntities(text),
|
||||
}
|
||||
table.insert(self.results, textElement)
|
||||
end
|
||||
end
|
||||
end
|
||||
self.position = #self.buffer + 1
|
||||
break
|
||||
end
|
||||
|
||||
local tag = remaining:sub(tagStart, tagEnd)
|
||||
local actualTagStart = self.position + tagStart - 1
|
||||
local actualTagEnd = self.position + tagEnd - 1
|
||||
|
||||
-- 检查是否为有效的XML标签
|
||||
if not isValidXmlTag(tag, self.buffer, actualTagStart) then
|
||||
-- 如果不是有效标签,将其作为普通文本处理
|
||||
local text = remaining:sub(1, tagEnd)
|
||||
if text ~= "" then
|
||||
if self.current then
|
||||
-- 检查当前深度,如果在第一层子元素中,保持原始文本
|
||||
local currentDepth = #self.stack
|
||||
if currentDepth >= 1 then
|
||||
-- 在第一层子元素中,保持原始文本不变
|
||||
if self.current._text then
|
||||
self.current._text = self.current._text .. text
|
||||
else
|
||||
self.current._text = text
|
||||
end
|
||||
else
|
||||
-- 在根级别,进行正常的文本处理
|
||||
text = trim(text)
|
||||
if text ~= "" then
|
||||
text = decodeEntities(text)
|
||||
if self.current._text then
|
||||
self.current._text = self.current._text .. text
|
||||
else
|
||||
self.current._text = text
|
||||
end
|
||||
end
|
||||
end
|
||||
else
|
||||
-- 顶层文本作为独立元素
|
||||
text = trim(text)
|
||||
if text ~= "" then
|
||||
local textElement = {
|
||||
_name = "_text",
|
||||
_text = decodeEntities(text),
|
||||
}
|
||||
table.insert(self.results, textElement)
|
||||
end
|
||||
end
|
||||
end
|
||||
self.position = actualTagEnd + 1
|
||||
goto continue
|
||||
end
|
||||
|
||||
-- 处理标签前的文本内容
|
||||
if tagStart > 1 then
|
||||
local precedingText = remaining:sub(1, tagStart - 1)
|
||||
if precedingText ~= "" then
|
||||
if self.current then
|
||||
-- 如果当前在某个标签内,添加到该标签的文本内容
|
||||
-- 检查当前深度,如果在第一层子元素中,不要进行实体解码和trim
|
||||
local currentDepth = #self.stack
|
||||
if currentDepth >= 1 then
|
||||
-- 在第一层子元素中,保持原始文本不变
|
||||
if self.current._text then
|
||||
self.current._text = self.current._text .. precedingText
|
||||
else
|
||||
self.current._text = precedingText
|
||||
end
|
||||
else
|
||||
-- 在根级别,进行正常的文本处理
|
||||
precedingText = trim(precedingText)
|
||||
if precedingText ~= "" then
|
||||
precedingText = decodeEntities(precedingText)
|
||||
if self.current._text then
|
||||
self.current._text = self.current._text .. precedingText
|
||||
else
|
||||
self.current._text = precedingText
|
||||
end
|
||||
end
|
||||
end
|
||||
else
|
||||
-- 如果是顶层文本,作为独立元素添加
|
||||
precedingText = trim(precedingText)
|
||||
if precedingText ~= "" then
|
||||
local textElement = {
|
||||
_name = "_text",
|
||||
_text = decodeEntities(precedingText),
|
||||
}
|
||||
table.insert(self.results, textElement)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- 检查当前深度,如果已经在第一层子元素中,将所有标签作为文本处理
|
||||
local currentDepth = #self.stack
|
||||
if currentDepth >= 1 then
|
||||
-- 检查是否是当前元素的结束标签
|
||||
if tag:match("^</[-_%w]+>$") and self.current then
|
||||
local tagName = tag:match("^</([-_%w]+)>$")
|
||||
if self.current._name == tagName then
|
||||
-- 这是当前元素的结束标签,正常处理
|
||||
if not self:processTag(tag) then
|
||||
self.state = "error"
|
||||
return
|
||||
end
|
||||
else
|
||||
-- 不是当前元素的结束标签,作为文本处理
|
||||
if self.current._text then
|
||||
self.current._text = self.current._text .. tag
|
||||
else
|
||||
self.current._text = tag
|
||||
end
|
||||
end
|
||||
else
|
||||
-- 在第一层子元素中,将标签作为文本处理
|
||||
if self.current then
|
||||
if self.current._text then
|
||||
self.current._text = self.current._text .. tag
|
||||
else
|
||||
self.current._text = tag
|
||||
end
|
||||
end
|
||||
end
|
||||
else
|
||||
-- 处理标签
|
||||
if not self:processTag(tag) then
|
||||
self.state = "error"
|
||||
return
|
||||
end
|
||||
end
|
||||
|
||||
self.position = actualTagEnd + 1
|
||||
::continue::
|
||||
end
|
||||
|
||||
-- 检查当前是否有未关闭的元素
|
||||
if self.current and self.current._state ~= "complete" then
|
||||
self.current._state = "incomplete_unclosed"
|
||||
self.state = "incomplete"
|
||||
elseif self.state ~= "incomplete" and self.state ~= "error" then
|
||||
self.state = "ready"
|
||||
end
|
||||
end
|
||||
|
||||
-- 处理单个标签
|
||||
function StreamParser:processTag(tag)
|
||||
if tag:match("^</[-_%w]+>$") then
|
||||
-- 结束标签
|
||||
local tagName = tag:match("^</([-_%w]+)>$")
|
||||
if self.current and self.current._name == tagName then
|
||||
-- 标记当前元素为完成状态
|
||||
self.current._state = "complete"
|
||||
self.current = table.remove(self.stack)
|
||||
-- 只有当栈为空且当前元素也为空时,说明完成了一个根级元素
|
||||
if #self.stack == 0 and not self.current and self.root then
|
||||
table.insert(self.results, self.root)
|
||||
self.root = nil
|
||||
end
|
||||
else
|
||||
self.last_error = "Mismatched closing tag: " .. tagName
|
||||
return false
|
||||
end
|
||||
elseif tag:match("^<[-_%w]+[^>]*/>$") then
|
||||
-- 自闭合标签
|
||||
local tagName, attrs = tag:match("^<([-_%w]+)([^>]*)/>")
|
||||
local element = {
|
||||
_name = tagName,
|
||||
_attr = parseAttributes(attrs),
|
||||
_state = "complete",
|
||||
children = {},
|
||||
}
|
||||
|
||||
if not self.root then
|
||||
-- 直接作为根级元素添加到结果中
|
||||
table.insert(self.results, element)
|
||||
elseif self.current then
|
||||
if not self.current.children then self.current.children = {} end
|
||||
table.insert(self.current.children, element)
|
||||
end
|
||||
elseif tag:match("^<[-_%w]+[^>]*>$") then
|
||||
-- 开始标签
|
||||
local tagName, attrs = tag:match("^<([-_%w]+)([^>]*)>")
|
||||
local element = {
|
||||
_name = tagName,
|
||||
_attr = parseAttributes(attrs),
|
||||
_state = "incomplete_open", -- 标记为未完成(等待结束标签)
|
||||
children = {},
|
||||
}
|
||||
|
||||
if not self.root then
|
||||
self.root = element
|
||||
self.current = element
|
||||
elseif self.current then
|
||||
table.insert(self.stack, self.current)
|
||||
if not self.current.children then self.current.children = {} end
|
||||
table.insert(self.current.children, element)
|
||||
self.current = element
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
-- 获取所有元素(已完成的和当前正在处理的)
|
||||
function StreamParser:getAllElements()
|
||||
local all_elements = {}
|
||||
|
||||
-- 添加所有已完成的元素
|
||||
for _, element in ipairs(self.results) do
|
||||
table.insert(all_elements, element)
|
||||
end
|
||||
|
||||
-- 如果有当前正在处理的元素,也添加进去
|
||||
if self.root then table.insert(all_elements, self.root) end
|
||||
|
||||
return all_elements
|
||||
end
|
||||
|
||||
-- 获取已完成的元素(保留向后兼容性)
|
||||
function StreamParser:getCompletedElements() return self.results end
|
||||
|
||||
-- 获取当前未完成的元素(保留向后兼容性)
|
||||
function StreamParser:getCurrentElement() return self.root end
|
||||
|
||||
-- 强制完成解析(将未完成的内容作为已完成处理)
|
||||
function StreamParser:finalize()
|
||||
-- 首先处理当前正在解析的元素
|
||||
if self.current then
|
||||
-- 递归设置所有未完成元素的状态
|
||||
local function markIncompleteElements(element)
|
||||
if element._state and element._state:match("incomplete") then element._state = "incomplete_unclosed" end
|
||||
-- 处理 children 数组中的子元素
|
||||
if element.children and type(element.children) == "table" then
|
||||
for _, child in ipairs(element.children) do
|
||||
if type(child) == "table" and child._name then markIncompleteElements(child) end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- 标记当前元素及其所有子元素为未完成状态,但保持层次结构
|
||||
markIncompleteElements(self.current)
|
||||
|
||||
-- 向上遍历栈,标记所有祖先元素
|
||||
for i = #self.stack, 1, -1 do
|
||||
local ancestor = self.stack[i]
|
||||
if ancestor._state and ancestor._state:match("incomplete") then ancestor._state = "incomplete_unclosed" end
|
||||
end
|
||||
end
|
||||
|
||||
-- 只有当存在根元素时才添加到结果中
|
||||
if self.root then
|
||||
table.insert(self.results, self.root)
|
||||
self.root = nil
|
||||
end
|
||||
|
||||
self.current = nil
|
||||
self.stack = {}
|
||||
self.state = "ready"
|
||||
self.incomplete_tag = nil
|
||||
end
|
||||
|
||||
-- 创建流式解析器实例
|
||||
function XmlParser.createStreamParser() return StreamParser.new() end
|
||||
|
||||
return XmlParser
|
||||
@@ -25,7 +25,7 @@ local group = api.nvim_create_augroup("avante_llm", { clear = true })
|
||||
---@param cb fun(title: string | nil): nil
|
||||
function M.summarize_chat_thread_title(content, cb)
|
||||
local system_prompt =
|
||||
[[Summarize the content as a title for the chat thread. The title should be a concise and informative summary of the conversation, capturing the main points and key takeaways. It should be no longer than 100 words and should be written in a clear and engaging style. The title should be suitable for use as the title of a chat thread on a messaging platform or other communication medium.]]
|
||||
[[Summarize the content as a title for the chat thread. The title should be a concise and informative summary of the conversation, capturing the main points and key takeaways. It should be no longer than 100 words and should be written in a clear and engaging style. The title should be suitable for use as the title of a chat thread on a messaging platform or other communication medium. /no_think]]
|
||||
local response_content = ""
|
||||
local provider = Providers.get_memory_summary_provider()
|
||||
M.curl({
|
||||
@@ -761,7 +761,7 @@ function M._stream(opts)
|
||||
},
|
||||
})
|
||||
end
|
||||
opts.on_messages_add(messages)
|
||||
if opts.on_messages_add then opts.on_messages_add(messages) end
|
||||
local the_last_tool_use = tool_use_list[#tool_use_list]
|
||||
if the_last_tool_use and the_last_tool_use.name == "attempt_completion" then
|
||||
opts.on_stop({ reason = "complete" })
|
||||
@@ -854,12 +854,8 @@ function M._stream(opts)
|
||||
if is_break then break end
|
||||
::continue::
|
||||
end
|
||||
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
|
||||
for _, tool_use in vim.spairs(tool_use_list) do
|
||||
table.insert(sorted_tool_use_list, tool_use)
|
||||
end
|
||||
if stop_opts.reason == "complete" and Config.mode == "agentic" then
|
||||
if #sorted_tool_use_list == 0 then
|
||||
if #tool_use_list == 0 then
|
||||
local completed_attempt_completion_tool_use = nil
|
||||
for idx = #history_messages, 1, -1 do
|
||||
local message = history_messages[idx]
|
||||
@@ -896,7 +892,7 @@ function M._stream(opts)
|
||||
end
|
||||
end
|
||||
end
|
||||
if stop_opts.reason == "tool_use" then return handle_next_tool_use(sorted_tool_use_list, 1, {}) end
|
||||
if stop_opts.reason == "tool_use" then return handle_next_tool_use(tool_use_list, 1, {}) end
|
||||
if stop_opts.reason == "rate_limit" then
|
||||
local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*"
|
||||
if opts.on_chunk then opts.on_chunk("\n" .. msg_content .. "\n") end
|
||||
|
||||
@@ -32,6 +32,10 @@ M.param = {
|
||||
optional = true,
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
result = "The result of the task. Formulate this result in a way that is final and does not require further input from the user. Don't end your result with questions or offers for further assistance.",
|
||||
command = "A CLI command to execute to show a live demo of the result to the user. For example, use `open index.html` to display a created html website, or `open localhost:3000` to display a locally running development server. But DO NOT use commands like `echo` or `cat` that merely print text. This command should be valid for the current operating system. Ensure the command is properly formatted and does not contain any harmful instructions.",
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolReturn[]
|
||||
|
||||
@@ -183,7 +183,7 @@ M.param = {
|
||||
type = "table",
|
||||
fields = {
|
||||
{
|
||||
name = "rel_path",
|
||||
name = "path",
|
||||
description = "Relative path to the project directory, as cwd",
|
||||
type = "string",
|
||||
},
|
||||
@@ -193,6 +193,10 @@ M.param = {
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "Relative path to the project directory, as cwd",
|
||||
command = "Command to run",
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolReturn[]
|
||||
@@ -210,9 +214,9 @@ M.returns = {
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ rel_path: string, command: string }>
|
||||
---@type AvanteLLMToolFunc<{ path: string, command: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.rel_path)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
|
||||
if on_log then on_log("command: " .. opts.command) end
|
||||
|
||||
@@ -49,6 +49,9 @@ M.param = {
|
||||
},
|
||||
},
|
||||
required = { "prompt" },
|
||||
usage = {
|
||||
prompt = "The task for the agent to perform",
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolReturn[]
|
||||
|
||||
@@ -19,6 +19,9 @@ M.param = {
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "The path to the file in the current project scope",
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolReturn[]
|
||||
|
||||
@@ -18,11 +18,15 @@ M.param = {
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "rel_path",
|
||||
name = "path",
|
||||
description = "Relative path to the project directory, as cwd",
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
pattern = "Glob pattern",
|
||||
path = "Relative path to the project directory, as cwd",
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolReturn[]
|
||||
@@ -40,9 +44,9 @@ M.returns = {
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ rel_path: string, pattern: string }>
|
||||
---@type AvanteLLMToolFunc<{ path: string, pattern: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
local abs_path = Helpers.get_abs_path(opts.rel_path)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
|
||||
if on_log then on_log("path: " .. abs_path) end
|
||||
if on_log then on_log("pattern: " .. opts.pattern) end
|
||||
|
||||
@@ -15,7 +15,7 @@ M.param = {
|
||||
type = "table",
|
||||
fields = {
|
||||
{
|
||||
name = "rel_path",
|
||||
name = "path",
|
||||
description = "Relative path to the project directory",
|
||||
type = "string",
|
||||
},
|
||||
@@ -44,6 +44,13 @@ M.param = {
|
||||
optional = true,
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "Relative path to the project directory",
|
||||
query = "Query to search for",
|
||||
case_sensitive = "Whether to search case sensitively",
|
||||
include_pattern = "Glob pattern to include files",
|
||||
exclude_pattern = "Glob pattern to exclude files",
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolReturn[]
|
||||
@@ -61,9 +68,9 @@ M.returns = {
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ rel_path: string, query: string, case_sensitive?: boolean, include_pattern?: string, exclude_pattern?: string }>
|
||||
---@type AvanteLLMToolFunc<{ path: string, query: string, case_sensitive?: boolean, include_pattern?: string, exclude_pattern?: string }>
|
||||
function M.func(opts, on_log)
|
||||
local abs_path = Helpers.get_abs_path(opts.rel_path)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return "", "No such file or directory: " .. abs_path end
|
||||
|
||||
|
||||
@@ -7,10 +7,10 @@ local Helpers = require("avante.llm_tools.helpers")
|
||||
|
||||
local M = {}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ rel_path: string }>
|
||||
---@type AvanteLLMToolFunc<{ path: string }>
|
||||
function M.read_file_toplevel_symbols(opts, on_log)
|
||||
local RepoMap = require("avante.repo_map")
|
||||
local abs_path = Helpers.get_abs_path(opts.rel_path)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
|
||||
if on_log then on_log("path: " .. abs_path) end
|
||||
if not Path:new(abs_path):exists() then return "", "File does not exists: " .. abs_path end
|
||||
@@ -121,13 +121,13 @@ function M.write_global_file(opts, on_log, on_complete)
|
||||
end)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }>
|
||||
---@type AvanteLLMToolFunc<{ path: string, new_path: string }>
|
||||
function M.rename_file(opts, on_log, on_complete)
|
||||
local abs_path = Helpers.get_abs_path(opts.rel_path)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
|
||||
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
|
||||
local new_abs_path = Helpers.get_abs_path(opts.new_rel_path)
|
||||
local new_abs_path = Helpers.get_abs_path(opts.new_path)
|
||||
if on_log then on_log(abs_path .. " -> " .. new_abs_path) end
|
||||
if not Helpers.has_permission_to_access(new_abs_path) then
|
||||
return false, "No permission to access path: " .. new_abs_path
|
||||
@@ -147,13 +147,13 @@ function M.rename_file(opts, on_log, on_complete)
|
||||
)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }>
|
||||
---@type AvanteLLMToolFunc<{ path: string, new_path: string }>
|
||||
function M.copy_file(opts, on_log)
|
||||
local abs_path = Helpers.get_abs_path(opts.rel_path)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
|
||||
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
|
||||
local new_abs_path = Helpers.get_abs_path(opts.new_rel_path)
|
||||
local new_abs_path = Helpers.get_abs_path(opts.new_path)
|
||||
if not Helpers.has_permission_to_access(new_abs_path) then
|
||||
return false, "No permission to access path: " .. new_abs_path
|
||||
end
|
||||
@@ -163,9 +163,9 @@ function M.copy_file(opts, on_log)
|
||||
return true, nil
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ rel_path: string }>
|
||||
---@type AvanteLLMToolFunc<{ path: string }>
|
||||
function M.delete_file(opts, on_log, on_complete)
|
||||
local abs_path = Helpers.get_abs_path(opts.rel_path)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
|
||||
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
|
||||
@@ -181,9 +181,9 @@ function M.delete_file(opts, on_log, on_complete)
|
||||
end)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ rel_path: string }>
|
||||
---@type AvanteLLMToolFunc<{ path: string }>
|
||||
function M.create_dir(opts, on_log, on_complete)
|
||||
local abs_path = Helpers.get_abs_path(opts.rel_path)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
@@ -198,13 +198,13 @@ function M.create_dir(opts, on_log, on_complete)
|
||||
end)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ rel_path: string, new_rel_path: string }>
|
||||
---@type AvanteLLMToolFunc<{ path: string, new_path: string }>
|
||||
function M.rename_dir(opts, on_log, on_complete)
|
||||
local abs_path = Helpers.get_abs_path(opts.rel_path)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
|
||||
if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end
|
||||
local new_abs_path = Helpers.get_abs_path(opts.new_rel_path)
|
||||
local new_abs_path = Helpers.get_abs_path(opts.new_path)
|
||||
if not Helpers.has_permission_to_access(new_abs_path) then
|
||||
return false, "No permission to access path: " .. new_abs_path
|
||||
end
|
||||
@@ -224,9 +224,9 @@ function M.rename_dir(opts, on_log, on_complete)
|
||||
)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ rel_path: string }>
|
||||
---@type AvanteLLMToolFunc<{ path: string }>
|
||||
function M.delete_dir(opts, on_log, on_complete)
|
||||
local abs_path = Helpers.get_abs_path(opts.rel_path)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
|
||||
if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end
|
||||
@@ -552,9 +552,9 @@ function M.rag_search(opts, on_log, on_complete)
|
||||
)
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ code: string, rel_path: string, container_image?: string }>
|
||||
---@type AvanteLLMToolFunc<{ code: string, path: string, container_image?: string }>
|
||||
function M.python(opts, on_log, on_complete)
|
||||
local abs_path = Helpers.get_abs_path(opts.rel_path)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return nil, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return nil, "Path not found: " .. abs_path end
|
||||
if on_log then on_log("cwd: " .. abs_path) end
|
||||
@@ -634,6 +634,18 @@ function M.get_tools(user_input, history_messages)
|
||||
:totable()
|
||||
end
|
||||
|
||||
function M.get_tool_names()
|
||||
local custom_tools = Config.custom_tools
|
||||
if type(custom_tools) == "function" then custom_tools = custom_tools() end
|
||||
---@type AvanteLLMTool[]
|
||||
local unfiltered_tools = vim.list_extend(vim.list_extend({}, M._tools), custom_tools)
|
||||
local tool_names = {}
|
||||
for _, tool in ipairs(unfiltered_tools) do
|
||||
table.insert(tool_names, tool.name)
|
||||
end
|
||||
return tool_names
|
||||
end
|
||||
|
||||
---@type AvanteLLMTool[]
|
||||
M._tools = {
|
||||
require("avante.llm_tools.replace_in_file"),
|
||||
@@ -652,6 +664,9 @@ M._tools = {
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
query = "Query to search",
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
@@ -679,11 +694,15 @@ M._tools = {
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "rel_path",
|
||||
name = "path",
|
||||
description = "Relative path to the project directory, as cwd",
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
code = "Python code to run",
|
||||
path = "Relative path to the project directory, as cwd",
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
@@ -711,6 +730,9 @@ M._tools = {
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
scope = "Scope for the git diff (e.g. specific files or directories)",
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
@@ -744,6 +766,10 @@ M._tools = {
|
||||
optional = true,
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
message = "Commit message to use",
|
||||
scope = "Scope for staging files (e.g. specific files or directories)",
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
@@ -768,11 +794,14 @@ M._tools = {
|
||||
type = "table",
|
||||
fields = {
|
||||
{
|
||||
name = "rel_path",
|
||||
name = "path",
|
||||
description = "Relative path to the file in current project scope",
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "Relative path to the file in current project scope",
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
@@ -790,7 +819,7 @@ M._tools = {
|
||||
},
|
||||
require("avante.llm_tools.str_replace"),
|
||||
require("avante.llm_tools.view"),
|
||||
require("avante.llm_tools.create"),
|
||||
require("avante.llm_tools.write_to_file"),
|
||||
require("avante.llm_tools.insert"),
|
||||
require("avante.llm_tools.undo_edit"),
|
||||
{
|
||||
@@ -820,6 +849,9 @@ M._tools = {
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
abs_path = "Absolute path to the file in global scope",
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
@@ -867,6 +899,10 @@ M._tools = {
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
abs_path = "The path to the file in the current project scope",
|
||||
content = "The content to write to the file",
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
@@ -889,16 +925,20 @@ M._tools = {
|
||||
type = "table",
|
||||
fields = {
|
||||
{
|
||||
name = "rel_path",
|
||||
name = "path",
|
||||
description = "Relative path to the file in current project scope",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "new_rel_path",
|
||||
name = "new_path",
|
||||
description = "New relative path for the file",
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "Relative path to the file in current project scope",
|
||||
new_path = "New relative path for the file",
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
@@ -921,11 +961,14 @@ M._tools = {
|
||||
type = "table",
|
||||
fields = {
|
||||
{
|
||||
name = "rel_path",
|
||||
name = "path",
|
||||
description = "Relative path to the file in current project scope",
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "Relative path to the file in current project scope",
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
@@ -948,11 +991,14 @@ M._tools = {
|
||||
type = "table",
|
||||
fields = {
|
||||
{
|
||||
name = "rel_path",
|
||||
name = "path",
|
||||
description = "Relative path to the project directory",
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "Relative path to the project directory",
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
@@ -975,16 +1021,20 @@ M._tools = {
|
||||
type = "table",
|
||||
fields = {
|
||||
{
|
||||
name = "rel_path",
|
||||
name = "path",
|
||||
description = "Relative path to the project directory",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "new_rel_path",
|
||||
name = "new_path",
|
||||
description = "New relative path for the directory",
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "Relative path to the project directory",
|
||||
new_path = "New relative path for the directory",
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
@@ -1007,11 +1057,14 @@ M._tools = {
|
||||
type = "table",
|
||||
fields = {
|
||||
{
|
||||
name = "rel_path",
|
||||
name = "path",
|
||||
description = "Relative path to the project directory",
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "Relative path to the project directory",
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
@@ -1042,6 +1095,9 @@ M._tools = {
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
query = "Query to search",
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
@@ -1069,6 +1125,9 @@ M._tools = {
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
url = "Url to fetch markdown from",
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
|
||||
@@ -32,6 +32,11 @@ M.param = {
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "The path to the file to modify",
|
||||
insert_line = "The line number after which to insert the text (0 for beginning of file)",
|
||||
new_str = "The text to insert",
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolReturn[]
|
||||
|
||||
@@ -14,7 +14,7 @@ M.param = {
|
||||
type = "table",
|
||||
fields = {
|
||||
{
|
||||
name = "rel_path",
|
||||
name = "path",
|
||||
description = "Relative path to the project directory",
|
||||
type = "string",
|
||||
},
|
||||
@@ -24,6 +24,10 @@ M.param = {
|
||||
type = "integer",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "Relative path to the project directory",
|
||||
max_depth = "Maximum depth of the directory",
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolReturn[]
|
||||
@@ -41,9 +45,9 @@ M.returns = {
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ rel_path: string, max_depth?: integer }>
|
||||
---@type AvanteLLMToolFunc<{ path: string, max_depth?: integer }>
|
||||
function M.func(opts, on_log)
|
||||
local abs_path = Helpers.get_abs_path(opts.rel_path)
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
|
||||
if on_log then on_log("path: " .. abs_path) end
|
||||
if on_log then on_log("max depth: " .. tostring(opts.max_depth)) end
|
||||
|
||||
@@ -16,6 +16,8 @@ M.name = "replace_in_file"
|
||||
M.description =
|
||||
"Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file. This tool should be used when you need to make targeted changes to specific parts of a file."
|
||||
|
||||
-- function M.enabled() return Config.provider:match("ollama") == nil end
|
||||
|
||||
---@type AvanteLLMToolParam
|
||||
M.param = {
|
||||
type = "table",
|
||||
@@ -57,6 +59,10 @@ One or more SEARCH/REPLACE blocks following this exact format:
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "File path here",
|
||||
diff = "Search and replace blocks here",
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolReturn[]
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
local Path = require("plenary.path")
|
||||
local Utils = require("avante.utils")
|
||||
local Base = require("avante.llm_tools.base")
|
||||
local Helpers = require("avante.llm_tools.helpers")
|
||||
local Diff = require("avante.diff")
|
||||
local Config = require("avante.config")
|
||||
|
||||
---@class AvanteLLMTool
|
||||
@@ -13,6 +9,7 @@ M.name = "str_replace"
|
||||
M.description =
|
||||
"The str_replace tool allows you to replace a specific string in a file with a new string. This is used for making precise edits."
|
||||
|
||||
-- function M.enabled() return Config.provider:match("ollama") ~= nil end
|
||||
function M.enabled() return false end
|
||||
|
||||
---@type AvanteLLMToolParam
|
||||
@@ -35,6 +32,11 @@ M.param = {
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "File path here",
|
||||
old_str = "old str here",
|
||||
new_str = "new str here",
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolReturn[]
|
||||
|
||||
@@ -22,6 +22,9 @@ M.param = {
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "The path to the file whose last edit should be undone",
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolReturn[]
|
||||
|
||||
@@ -39,23 +39,22 @@ M.param = {
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "view_range",
|
||||
description = "The range of the file to view. This parameter only applies when viewing files, not directories.",
|
||||
type = "object",
|
||||
name = "start_line",
|
||||
description = "The start line of the view range, 1-indexed",
|
||||
type = "integer",
|
||||
optional = true,
|
||||
fields = {
|
||||
{
|
||||
name = "start_line",
|
||||
description = "The start line of the range, 1-indexed",
|
||||
type = "integer",
|
||||
},
|
||||
{
|
||||
name = "end_line",
|
||||
description = "The end line of the range, 1-indexed, and -1 for the end line means read to the end of the file",
|
||||
type = "integer",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name = "end_line",
|
||||
description = "The end line of the view range, 1-indexed, and -1 for the end line means read to the end of the file",
|
||||
type = "integer",
|
||||
optional = true,
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "The path to the file in the current project scope",
|
||||
start_line = "The start line of the view range, 1-indexed",
|
||||
end_line = "The end line of the view range, 1-indexed, and -1 for the end line means read to the end of the file",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -74,7 +73,7 @@ M.returns = {
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string, view_range?: { start_line: integer, end_line: integer } }>
|
||||
---@type AvanteLLMToolFunc<{ path: string, start_line?: integer, end_line?: integer }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
if on_log then on_log("path: " .. opts.path) end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
@@ -84,11 +83,9 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
local file = io.open(abs_path, "r")
|
||||
if not file then return false, "file not found: " .. abs_path end
|
||||
local lines = Utils.read_file_from_buf_or_disk(abs_path)
|
||||
if opts.view_range then
|
||||
local start_line = opts.view_range.start_line
|
||||
local end_line = opts.view_range.end_line
|
||||
if start_line and end_line and lines then lines = vim.list_slice(lines, start_line, end_line) end
|
||||
end
|
||||
local start_line = opts.start_line
|
||||
local end_line = opts.end_line
|
||||
if start_line and end_line and lines then lines = vim.list_slice(lines, start_line, end_line) end
|
||||
local truncated_lines = {}
|
||||
local is_truncated = false
|
||||
local size = 0
|
||||
|
||||
74
lua/avante/llm_tools/write_to_file.lua
Normal file
74
lua/avante/llm_tools/write_to_file.lua
Normal file
@@ -0,0 +1,74 @@
|
||||
local Utils = require("avante.utils")
|
||||
local Base = require("avante.llm_tools.base")
|
||||
local Helpers = require("avante.llm_tools.helpers")
|
||||
|
||||
---@class AvanteLLMTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "write_to_file"
|
||||
|
||||
M.description =
|
||||
"Request to write content to a file at the specified path. If the file exists, it will be overwritten with the provided content. If the file doesn't exist, it will be created. This tool will automatically create any directories needed to write the file."
|
||||
|
||||
function M.enabled() return require("avante.config").mode == "agentic" end
|
||||
|
||||
---@type AvanteLLMToolParam
|
||||
M.param = {
|
||||
type = "table",
|
||||
fields = {
|
||||
{
|
||||
name = "path",
|
||||
get_description = function()
|
||||
local res = ("The path of the file to write to (relative to the current working directory {{cwd}})"):gsub(
|
||||
"{{cwd}}",
|
||||
Utils.get_project_root()
|
||||
)
|
||||
return res
|
||||
end,
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "content",
|
||||
description = "The content to write to the file. ALWAYS provide the COMPLETE intended content of the file, without any truncation or omissions. You MUST include ALL parts of the file, even if they haven't been modified.",
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
usage = {
|
||||
path = "File path here",
|
||||
content = "File content here",
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolReturn[]
|
||||
M.returns = {
|
||||
{
|
||||
name = "success",
|
||||
description = "Whether the file was created successfully",
|
||||
type = "boolean",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if the file was not created successfully",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string, content: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if opts.content == nil then return false, "content not provided" end
|
||||
local old_lines = Utils.read_file_from_buf_or_disk(abs_path)
|
||||
local old_content = table.concat(old_lines or {}, "\n")
|
||||
local replace_in_file = require("avante.llm_tools.replace_in_file")
|
||||
local diff = "<<<<<<< SEARCH\n" .. old_content .. "\n=======\n" .. opts.content .. "\n>>>>>>> REPLACE"
|
||||
local new_opts = {
|
||||
path = opts.path,
|
||||
diff = diff,
|
||||
}
|
||||
return replace_in_file.func(new_opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,10 +1,15 @@
|
||||
local Utils = require("avante.utils")
|
||||
local P = require("avante.providers")
|
||||
local Providers = require("avante.providers")
|
||||
local Config = require("avante.config")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local HistoryMessage = require("avante.history_message")
|
||||
local Prompts = require("avante.utils.prompts")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
|
||||
setmetatable(M, { __index = Providers.openai })
|
||||
|
||||
M.api_key_name = "" -- Ollama typically doesn't require API keys for local use
|
||||
|
||||
M.role_map = {
|
||||
@@ -12,35 +17,182 @@ M.role_map = {
|
||||
assistant = "assistant",
|
||||
}
|
||||
|
||||
M.parse_messages = P.openai.parse_messages
|
||||
M.is_reasoning_model = P.openai.is_reasoning_model
|
||||
function M:parse_messages(opts)
|
||||
local messages = {}
|
||||
local provider_conf, _ = Providers.parse_config(self)
|
||||
|
||||
local system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts)
|
||||
|
||||
if self.is_reasoning_model(provider_conf.model) then
|
||||
table.insert(messages, { role = "developer", content = system_prompt })
|
||||
else
|
||||
table.insert(messages, { role = "system", content = system_prompt })
|
||||
end
|
||||
|
||||
vim.iter(opts.messages):each(function(msg)
|
||||
if type(msg.content) == "string" then
|
||||
table.insert(messages, { role = self.role_map[msg.role], content = msg.content })
|
||||
elseif type(msg.content) == "table" then
|
||||
local content = {}
|
||||
for _, item in ipairs(msg.content) do
|
||||
if type(item) == "string" then
|
||||
table.insert(content, { type = "text", text = item })
|
||||
elseif item.type == "text" then
|
||||
table.insert(content, { type = "text", text = item.text })
|
||||
elseif item.type == "image" then
|
||||
table.insert(content, {
|
||||
type = "image_url",
|
||||
image_url = {
|
||||
url = "data:" .. item.source.media_type .. ";" .. item.source.type .. "," .. item.source.data,
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
if not provider_conf.disable_tools then
|
||||
if msg.content[1].type == "tool_result" then
|
||||
local tool_use = nil
|
||||
for _, msg_ in ipairs(opts.messages) do
|
||||
if type(msg_.content) == "table" and #msg_.content > 0 then
|
||||
if msg_.content[1].type == "tool_use" and msg_.content[1].id == msg.content[1].tool_use_id then
|
||||
tool_use = msg_
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
if tool_use then
|
||||
msg.role = "user"
|
||||
table.insert(content, {
|
||||
type = "text",
|
||||
text = "["
|
||||
.. tool_use.content[1].name
|
||||
.. " for '"
|
||||
.. (tool_use.content[1].input.path or tool_use.content[1].input.rel_path or "")
|
||||
.. "'] Result:",
|
||||
})
|
||||
table.insert(content, {
|
||||
type = "text",
|
||||
text = msg.content[1].content,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
if #content > 0 then
|
||||
local text_content = {}
|
||||
for _, item in ipairs(content) do
|
||||
if type(item) == "table" and item.type == "text" then table.insert(text_content, item.text) end
|
||||
end
|
||||
table.insert(messages, { role = self.role_map[msg.role], content = table.concat(text_content, "\n\n") })
|
||||
end
|
||||
end
|
||||
end)
|
||||
|
||||
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
|
||||
local message_content = messages[#messages].content
|
||||
if type(message_content) ~= "table" or message_content[1] == nil then
|
||||
message_content = { { type = "text", text = message_content } }
|
||||
end
|
||||
for _, image_path in ipairs(opts.image_paths) do
|
||||
table.insert(message_content, {
|
||||
type = "image_url",
|
||||
image_url = {
|
||||
url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
|
||||
},
|
||||
})
|
||||
end
|
||||
messages[#messages].content = message_content
|
||||
end
|
||||
|
||||
local final_messages = {}
|
||||
local prev_role = nil
|
||||
|
||||
vim.iter(messages):each(function(message)
|
||||
local role = message.role
|
||||
if role == prev_role and role ~= "tool" then
|
||||
if role == self.role_map["assistant"] then
|
||||
table.insert(final_messages, { role = self.role_map["user"], content = "Ok" })
|
||||
else
|
||||
table.insert(final_messages, { role = self.role_map["assistant"], content = "Ok, I understand." })
|
||||
end
|
||||
end
|
||||
prev_role = role
|
||||
table.insert(final_messages, message)
|
||||
end)
|
||||
|
||||
return final_messages
|
||||
end
|
||||
|
||||
function M:is_disable_stream() return false end
|
||||
|
||||
---@class avante.OllamaFunction
|
||||
---@field name string
|
||||
---@field arguments table
|
||||
|
||||
---@class avante.OllamaToolCall
|
||||
---@field function avante.OllamaFunction
|
||||
|
||||
---@param tool_calls avante.OllamaToolCall[]
|
||||
---@param opts AvanteLLMStreamOptions
|
||||
function M:add_tool_use_messages(tool_calls, opts)
|
||||
local msgs = {}
|
||||
for _, tool_call in ipairs(tool_calls) do
|
||||
local id = Utils.uuid()
|
||||
local func = tool_call["function"]
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "tool_use",
|
||||
name = func.name,
|
||||
id = id,
|
||||
input = func.arguments,
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = "generated",
|
||||
uuid = id,
|
||||
})
|
||||
table.insert(msgs, msg)
|
||||
end
|
||||
if opts.on_messages_add then opts.on_messages_add(msgs) end
|
||||
end
|
||||
|
||||
function M:parse_stream_data(ctx, data, opts)
|
||||
local ok, json_data = pcall(vim.json.decode, data)
|
||||
if not ok or not json_data then
|
||||
local ok, jsn = pcall(vim.json.decode, data)
|
||||
if not ok or not jsn then
|
||||
-- Add debug logging
|
||||
Utils.debug("Failed to parse JSON", data)
|
||||
return
|
||||
end
|
||||
|
||||
if json_data.message and json_data.message.content then
|
||||
local content = json_data.message.content
|
||||
P.openai:add_text_message(ctx, content, "generating", opts)
|
||||
if content and content ~= "" and opts.on_chunk then opts.on_chunk(content) end
|
||||
if jsn.message then
|
||||
if jsn.message.content then
|
||||
local content = jsn.message.content
|
||||
if content and content ~= "" then
|
||||
Providers.openai:add_text_message(ctx, content, "generating", opts)
|
||||
if opts.on_chunk then opts.on_chunk(content) end
|
||||
end
|
||||
end
|
||||
if jsn.message.tool_calls then
|
||||
ctx.has_tool_use = true
|
||||
local tool_calls = jsn.message.tool_calls
|
||||
self:add_tool_use_messages(tool_calls, opts)
|
||||
end
|
||||
end
|
||||
|
||||
if json_data.done then
|
||||
P.openai:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
if jsn.done then
|
||||
Providers.openai:finish_pending_messages(ctx, opts)
|
||||
if ctx.has_tool_use or (ctx.tool_use_list and #ctx.tool_use_list > 0) then
|
||||
opts.on_stop({ reason = "tool_use" })
|
||||
else
|
||||
opts.on_stop({ reason = "complete" })
|
||||
end
|
||||
return
|
||||
end
|
||||
end
|
||||
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
local provider_conf, request_body = Providers.parse_config(self)
|
||||
local keep_alive = provider_conf.keep_alive or "5m"
|
||||
|
||||
if not provider_conf.model or provider_conf.model == "" then error("Ollama model must be specified in config") end
|
||||
@@ -50,7 +202,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
["Accept"] = "application/json",
|
||||
}
|
||||
|
||||
if P.env.require_api_key(provider_conf) then
|
||||
if Providers.env.require_api_key(provider_conf) then
|
||||
local api_key = self.parse_api_key()
|
||||
if api_key and api_key ~= "" then
|
||||
headers["Authorization"] = "Bearer " .. api_key
|
||||
@@ -66,7 +218,6 @@ function M:parse_curl_args(prompt_opts)
|
||||
model = provider_conf.model,
|
||||
messages = self:parse_messages(prompt_opts),
|
||||
stream = true,
|
||||
system = prompt_opts.system_prompt,
|
||||
keep_alive = keep_alive,
|
||||
}, request_body),
|
||||
}
|
||||
|
||||
@@ -3,6 +3,9 @@ local Config = require("avante.config")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local Providers = require("avante.providers")
|
||||
local HistoryMessage = require("avante.history_message")
|
||||
local XMLParser = require("avante.libs.xmlparser")
|
||||
local Prompts = require("avante.utils.prompts")
|
||||
local LlmTools = require("avante.llm_tools")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
@@ -76,10 +79,15 @@ function M:parse_messages(opts)
|
||||
local messages = {}
|
||||
local provider_conf, _ = Providers.parse_config(self)
|
||||
|
||||
local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
|
||||
local system_prompt = opts.system_prompt
|
||||
|
||||
if use_ReAct_prompt then system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts) end
|
||||
|
||||
if self.is_reasoning_model(provider_conf.model) then
|
||||
table.insert(messages, { role = "developer", content = opts.system_prompt })
|
||||
table.insert(messages, { role = "developer", content = system_prompt })
|
||||
else
|
||||
table.insert(messages, { role = "system", content = opts.system_prompt })
|
||||
table.insert(messages, { role = "system", content = system_prompt })
|
||||
end
|
||||
|
||||
local has_tool_use = false
|
||||
@@ -103,22 +111,50 @@ function M:parse_messages(opts)
|
||||
url = "data:" .. item.source.media_type .. ";" .. item.source.type .. "," .. item.source.data,
|
||||
},
|
||||
})
|
||||
elseif item.type == "tool_use" then
|
||||
elseif item.type == "tool_use" and not use_ReAct_prompt then
|
||||
has_tool_use = true
|
||||
table.insert(tool_calls, {
|
||||
id = item.id,
|
||||
type = "function",
|
||||
["function"] = { name = item.name, arguments = vim.json.encode(item.input) },
|
||||
})
|
||||
elseif item.type == "tool_result" and has_tool_use then
|
||||
elseif item.type == "tool_result" and has_tool_use and not use_ReAct_prompt then
|
||||
table.insert(
|
||||
tool_results,
|
||||
{ tool_call_id = item.tool_use_id, content = item.is_error and "Error: " .. item.content or item.content }
|
||||
)
|
||||
end
|
||||
end
|
||||
if not provider_conf.disable_tools and use_ReAct_prompt then
|
||||
if msg.content[1].type == "tool_result" then
|
||||
local tool_use = nil
|
||||
for _, msg_ in ipairs(opts.messages) do
|
||||
if type(msg_.content) == "table" and #msg_.content > 0 then
|
||||
if msg_.content[1].type == "tool_use" and msg_.content[1].id == msg.content[1].tool_use_id then
|
||||
tool_use = msg_
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
if tool_use then
|
||||
msg.role = "user"
|
||||
table.insert(content, {
|
||||
type = "text",
|
||||
text = "["
|
||||
.. tool_use.content[1].name
|
||||
.. " for '"
|
||||
.. (tool_use.content[1].input.path or tool_use.content[1].input.rel_path or "")
|
||||
.. "'] Result:",
|
||||
})
|
||||
table.insert(content, {
|
||||
type = "text",
|
||||
text = msg.content[1].content,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
if #content > 0 then table.insert(messages, { role = self.role_map[msg.role], content = content }) end
|
||||
if not provider_conf.disable_tools then
|
||||
if not provider_conf.disable_tools and not use_ReAct_prompt then
|
||||
if #tool_calls > 0 then
|
||||
local last_message = messages[#messages]
|
||||
if last_message and last_message.role == self.role_map["assistant"] and last_message.tool_calls then
|
||||
@@ -183,7 +219,10 @@ function M:finish_pending_messages(ctx, opts)
|
||||
end
|
||||
end
|
||||
|
||||
local llm_tool_names = nil
|
||||
|
||||
function M:add_text_message(ctx, text, state, opts)
|
||||
if llm_tool_names == nil then llm_tool_names = LlmTools.get_tool_names() end
|
||||
if ctx.content == nil then ctx.content = "" end
|
||||
ctx.content = ctx.content .. text
|
||||
local msg = HistoryMessage:new({
|
||||
@@ -194,7 +233,75 @@ function M:add_text_message(ctx, text, state, opts)
|
||||
uuid = ctx.content_uuid,
|
||||
})
|
||||
ctx.content_uuid = msg.uuid
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
local msgs = { msg }
|
||||
local stream_parser = XMLParser.createStreamParser()
|
||||
stream_parser:addData(ctx.content)
|
||||
local xml = stream_parser:getAllElements()
|
||||
if xml then
|
||||
local new_content_list = {}
|
||||
local xml_md_openned = false
|
||||
for idx, item in ipairs(xml) do
|
||||
if item._name == "_text" then
|
||||
local cleaned_lines = {}
|
||||
local lines = vim.split(item._text, "\n")
|
||||
for _, line in ipairs(lines) do
|
||||
if line:match("^```xml") or line:match("^```tool_code") or line:match("^```tool_use") then
|
||||
xml_md_openned = true
|
||||
elseif line:match("^```$") then
|
||||
if xml_md_openned then
|
||||
xml_md_openned = false
|
||||
else
|
||||
table.insert(cleaned_lines, line)
|
||||
end
|
||||
else
|
||||
table.insert(cleaned_lines, line)
|
||||
end
|
||||
end
|
||||
table.insert(new_content_list, table.concat(cleaned_lines, "\n"))
|
||||
goto continue
|
||||
end
|
||||
if not vim.tbl_contains(llm_tool_names, item._name) then goto continue end
|
||||
local ok, input = pcall(vim.json.decode, item._text)
|
||||
if not ok and item.children and #item.children > 0 then
|
||||
input = {}
|
||||
for _, item_ in ipairs(item.children) do
|
||||
local ok_, input_ = pcall(vim.json.decode, item_._text)
|
||||
if ok_ and input_ then
|
||||
input[item_._name] = input_
|
||||
else
|
||||
input[item_._name] = item_._text
|
||||
end
|
||||
end
|
||||
end
|
||||
if input then
|
||||
local tool_use_id = Utils.uuid()
|
||||
local msg_ = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "tool_use",
|
||||
name = item._name,
|
||||
id = tool_use_id,
|
||||
input = input,
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = state,
|
||||
uuid = ctx.content_uuid .. "-" .. idx,
|
||||
})
|
||||
msgs[#msgs + 1] = msg_
|
||||
ctx.tool_use_list = ctx.tool_use_list or {}
|
||||
ctx.tool_use_list[#ctx.tool_use_list + 1] = {
|
||||
id = tool_use_id,
|
||||
name = item._name,
|
||||
input_json = input,
|
||||
}
|
||||
end
|
||||
if #new_content_list > 0 then msg.message.content = table.concat(new_content_list, "\n") end
|
||||
::continue::
|
||||
end
|
||||
end
|
||||
if opts.on_messages_add then opts.on_messages_add(msgs) end
|
||||
end
|
||||
|
||||
function M:add_thinking_message(ctx, text, state, opts)
|
||||
@@ -242,7 +349,11 @@ end
|
||||
function M:parse_response(ctx, data_stream, _, opts)
|
||||
if data_stream:match('"%[DONE%]":') then
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
if ctx.tool_use_list and #ctx.tool_use_list > 0 then
|
||||
opts.on_stop({ reason = "tool_use" })
|
||||
else
|
||||
opts.on_stop({ reason = "complete" })
|
||||
end
|
||||
return
|
||||
end
|
||||
if data_stream == "[DONE]" then return end
|
||||
@@ -316,9 +427,13 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
self:add_text_message(ctx, delta.content, "generating", opts)
|
||||
end
|
||||
end
|
||||
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then
|
||||
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" or choice.finish_reason == "length" then
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
if ctx.tool_use_list and #ctx.tool_use_list > 0 then
|
||||
opts.on_stop({ reason = "tool_use", usage = jsn.usage })
|
||||
else
|
||||
opts.on_stop({ reason = "complete", usage = jsn.usage })
|
||||
end
|
||||
end
|
||||
if choice.finish_reason == "tool_calls" then
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
@@ -372,8 +487,10 @@ function M:parse_curl_args(prompt_opts)
|
||||
|
||||
self.set_allowed_params(provider_conf, request_body)
|
||||
|
||||
local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
|
||||
|
||||
local tools = nil
|
||||
if not disable_tools and prompt_opts.tools then
|
||||
if not disable_tools and prompt_opts.tools and not use_ReAct_prompt then
|
||||
tools = {}
|
||||
for _, tool in ipairs(prompt_opts.tools) do
|
||||
table.insert(tools, self:transform_tool(tool))
|
||||
|
||||
@@ -56,6 +56,7 @@ Sidebar.__index = Sidebar
|
||||
---@field current_state avante.GenerateState | nil
|
||||
---@field state_timer table | nil
|
||||
---@field state_spinner_chars string[]
|
||||
---@field thinking_spinner_chars string[]
|
||||
---@field state_spinner_idx integer
|
||||
---@field state_extmark_id integer | nil
|
||||
---@field scroll boolean
|
||||
@@ -84,12 +85,16 @@ function Sidebar:new(id)
|
||||
current_state = nil,
|
||||
state_timer = nil,
|
||||
state_spinner_chars = { "·", "✢", "✳", "∗", "✻", "✽" },
|
||||
thinking_spinner_chars = { "🤯", "🙄" },
|
||||
state_spinner_idx = 1,
|
||||
state_extmark_id = nil,
|
||||
scroll = true,
|
||||
input_hint_window = nil,
|
||||
ask_opts = {},
|
||||
old_result_lines = {},
|
||||
-- 缓存相关字段
|
||||
_cached_history_lines = nil,
|
||||
_history_cache_invalidated = true,
|
||||
}, Sidebar)
|
||||
end
|
||||
|
||||
@@ -1528,42 +1533,75 @@ end
|
||||
---@param opts? {focus?: boolean, scroll?: boolean, backspace?: integer, callback?: fun(): nil} whether to focus the result view
|
||||
function Sidebar:update_content(content, opts)
|
||||
if not self.result_container or not self.result_container.bufnr then return end
|
||||
|
||||
-- 提前验证容器有效性,避免后续无效操作
|
||||
if not Utils.is_valid_container(self.result_container) then return end
|
||||
|
||||
opts = vim.tbl_deep_extend("force", { focus = false, scroll = self.scroll, callback = nil }, opts or {})
|
||||
local history_lines = self.get_history_lines(self.chat_history)
|
||||
if content ~= nil and content ~= "" then
|
||||
table.insert(history_lines, Line:new({ { "" } }))
|
||||
local content_lines = vim.split(content, "\n")
|
||||
for _, line in ipairs(content_lines) do
|
||||
table.insert(history_lines, Line:new({ { line } }))
|
||||
end
|
||||
|
||||
-- 缓存历史行,避免重复计算
|
||||
local history_lines
|
||||
if not self._cached_history_lines or self._history_cache_invalidated then
|
||||
history_lines = self.get_history_lines(self.chat_history)
|
||||
self._cached_history_lines = history_lines
|
||||
self._history_cache_invalidated = false
|
||||
else
|
||||
history_lines = vim.deepcopy(self._cached_history_lines)
|
||||
end
|
||||
vim.defer_fn(function()
|
||||
self:clear_state()
|
||||
local f = function()
|
||||
if not Utils.is_valid_container(self.result_container) then return end
|
||||
Utils.unlock_buf(self.result_container.bufnr)
|
||||
Utils.update_buffer_lines(
|
||||
RESULT_BUF_HL_NAMESPACE,
|
||||
self.result_container.bufnr,
|
||||
self.old_result_lines,
|
||||
history_lines
|
||||
)
|
||||
Utils.lock_buf(self.result_container.bufnr)
|
||||
self.old_result_lines = history_lines
|
||||
api.nvim_set_option_value("filetype", "Avante", { buf = self.result_container.bufnr })
|
||||
vim.schedule(function() vim.cmd("redraw") end)
|
||||
if opts.focus and not self:is_focused_on_result() then
|
||||
--- set cursor to bottom of result view
|
||||
xpcall(function() api.nvim_set_current_win(self.result_container.winid) end, function(err) return err end)
|
||||
end
|
||||
|
||||
if opts.scroll then Utils.buf_scroll_to_end(self.result_container.bufnr) end
|
||||
-- 批量处理内容行,减少表操作
|
||||
if content ~= nil and content ~= "" then
|
||||
local content_lines = vim.split(content, "\n")
|
||||
local new_lines = { Line:new({ { "" } }) }
|
||||
|
||||
if opts.callback ~= nil then opts.callback() end
|
||||
-- 预分配表大小,提升性能
|
||||
for i = 1, #content_lines do
|
||||
new_lines[i + 1] = Line:new({ { content_lines[i] } })
|
||||
end
|
||||
f()
|
||||
|
||||
-- 一次性扩展,而不是逐个插入
|
||||
vim.list_extend(history_lines, new_lines)
|
||||
end
|
||||
|
||||
-- 使用 vim.schedule 而不是 vim.defer_fn(0),性能更好
|
||||
-- 再次检查容器有效性
|
||||
if not Utils.is_valid_container(self.result_container) then return end
|
||||
|
||||
self:clear_state()
|
||||
|
||||
-- 批量更新操作
|
||||
local bufnr = self.result_container.bufnr
|
||||
Utils.unlock_buf(bufnr)
|
||||
|
||||
Utils.update_buffer_lines(RESULT_BUF_HL_NAMESPACE, bufnr, self.old_result_lines, history_lines)
|
||||
|
||||
-- 缓存结果行
|
||||
self.old_result_lines = history_lines
|
||||
|
||||
-- 批量设置选项
|
||||
api.nvim_set_option_value("filetype", "Avante", { buf = bufnr })
|
||||
Utils.lock_buf(bufnr)
|
||||
|
||||
-- 处理焦点和滚动
|
||||
if opts.focus and not self:is_focused_on_result() then
|
||||
xpcall(function() api.nvim_set_current_win(self.result_container.winid) end, function(err)
|
||||
Utils.debug("Failed to set current win:", err)
|
||||
return err
|
||||
end)
|
||||
end
|
||||
|
||||
if opts.scroll then Utils.buf_scroll_to_end(bufnr) end
|
||||
|
||||
-- 延迟执行回调和状态渲染
|
||||
if opts.callback then vim.schedule(opts.callback) end
|
||||
|
||||
-- 最后渲染状态
|
||||
vim.schedule(function()
|
||||
self:render_state()
|
||||
end, 0)
|
||||
-- 延迟重绘,避免阻塞
|
||||
vim.defer_fn(function() vim.cmd("redraw") end, 10)
|
||||
end)
|
||||
|
||||
return self
|
||||
end
|
||||
|
||||
@@ -1842,8 +1880,8 @@ function Sidebar:render_state()
|
||||
if self.state_extmark_id then
|
||||
api.nvim_buf_del_extmark(self.result_container.bufnr, STATE_NAMESPACE, self.state_extmark_id)
|
||||
end
|
||||
local spinner_char = self.state_spinner_chars[self.state_spinner_idx]
|
||||
self.state_spinner_idx = (self.state_spinner_idx % #self.state_spinner_chars) + 1
|
||||
local spinner_chars = self.state_spinner_chars
|
||||
if self.current_state == "thinking" then spinner_chars = self.thinking_spinner_chars end
|
||||
local hl = "AvanteStateSpinnerGenerating"
|
||||
if self.current_state == "tool calling" then hl = "AvanteStateSpinnerToolCalling" end
|
||||
if self.current_state == "failed" then hl = "AvanteStateSpinnerFailed" end
|
||||
@@ -1851,6 +1889,8 @@ function Sidebar:render_state()
|
||||
if self.current_state == "searching" then hl = "AvanteStateSpinnerSearching" end
|
||||
if self.current_state == "thinking" then hl = "AvanteStateSpinnerThinking" end
|
||||
if self.current_state == "compacting" then hl = "AvanteStateSpinnerCompacting" end
|
||||
local spinner_char = spinner_chars[self.state_spinner_idx]
|
||||
self.state_spinner_idx = (self.state_spinner_idx % #spinner_chars) + 1
|
||||
if
|
||||
self.current_state ~= "generating"
|
||||
and self.current_state ~= "tool calling"
|
||||
@@ -1911,6 +1951,10 @@ function Sidebar:new_chat(args, cb)
|
||||
if cb then cb(args) end
|
||||
end
|
||||
|
||||
local _save_history = Utils.debounce(function(self) Path.history.save(self.code.bufnr, self.chat_history) end, 3000)
|
||||
|
||||
local save_history = vim.schedule_wrap(_save_history)
|
||||
|
||||
---@param messages avante.HistoryMessage | avante.HistoryMessage[]
|
||||
function Sidebar:add_history_messages(messages)
|
||||
local history_messages = Utils.get_history_messages(self.chat_history)
|
||||
@@ -1934,22 +1978,30 @@ function Sidebar:add_history_messages(messages)
|
||||
end
|
||||
end
|
||||
self.chat_history.messages = history_messages
|
||||
Path.history.save(self.code.bufnr, self.chat_history)
|
||||
-- 历史消息变更时,标记缓存失效
|
||||
self._history_cache_invalidated = true
|
||||
save_history(self)
|
||||
if
|
||||
self.chat_history.title == "untitled"
|
||||
and #messages > 0
|
||||
and messages[1].just_for_display ~= true
|
||||
and messages[1].state == "generated"
|
||||
then
|
||||
self.chat_history.title = "generating..."
|
||||
Llm.summarize_chat_thread_title(messages[1].message.content, function(title)
|
||||
if title then
|
||||
self.chat_history.title = title
|
||||
else
|
||||
self.chat_history.title = "untitled"
|
||||
end
|
||||
Path.history.save(self.code.bufnr, self.chat_history)
|
||||
end)
|
||||
-- self.chat_history.title = "generating..."
|
||||
-- Llm.summarize_chat_thread_title(messages[1].message.content, function(title)
|
||||
-- if title then
|
||||
-- self.chat_history.title = title
|
||||
-- else
|
||||
-- self.chat_history.title = "untitled"
|
||||
-- end
|
||||
-- save_history(self)
|
||||
-- end)
|
||||
local first_msg_text = Utils.message_to_text(messages[1], messages)
|
||||
local lines_ = vim.split(first_msg_text, "\n")
|
||||
if #lines_ > 0 then
|
||||
self.chat_history.title = lines_[1]
|
||||
save_history(self)
|
||||
end
|
||||
end
|
||||
local last_message = messages[#messages]
|
||||
if last_message then
|
||||
@@ -1964,7 +2016,10 @@ function Sidebar:add_history_messages(messages)
|
||||
self.current_state = "generating"
|
||||
end
|
||||
end
|
||||
self:update_content("")
|
||||
xpcall(function() self:update_content("") end, function(err)
|
||||
Utils.debug("Failed to update content:", err)
|
||||
return nil
|
||||
end)
|
||||
end
|
||||
|
||||
---@param messages AvanteLLMMessage | AvanteLLMMessage[]
|
||||
@@ -2135,6 +2190,8 @@ end
|
||||
function Sidebar:reload_chat_history()
|
||||
if not self.code.bufnr or not api.nvim_buf_is_valid(self.code.bufnr) then return end
|
||||
self.chat_history = Path.history.load(self.code.bufnr)
|
||||
-- 重新加载历史时,标记缓存失效
|
||||
self._history_cache_invalidated = true
|
||||
end
|
||||
|
||||
---@return avante.HistoryMessage[]
|
||||
|
||||
@@ -5,9 +5,11 @@
|
||||
|
||||
RULES
|
||||
|
||||
- NEVER reply the updated code.
|
||||
|
||||
- Always reply to the user in the same language they are using.
|
||||
|
||||
- Don't just provide code suggestions, use the `replace_in_file` tool to help users fulfill their needs.
|
||||
- Don't just provide code suggestions, use the `replace_in_file` tool or `str_replace` tool to help users fulfill their needs.
|
||||
|
||||
- After the tool call is complete, please do not output the entire file content.
|
||||
|
||||
|
||||
@@ -231,6 +231,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field disable_tools? boolean
|
||||
---@field entra? boolean
|
||||
---@field hide_in_model_selector? boolean
|
||||
---@field use_ReAct_prompt? boolean
|
||||
---
|
||||
---@class AvanteSupportedProvider: AvanteDefaultBaseProvider
|
||||
---@field __inherited_from? string
|
||||
@@ -396,6 +397,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@class AvanteLLMToolParam
|
||||
---@field type 'table'
|
||||
---@field fields AvanteLLMToolParamField[]
|
||||
---@field usage? table
|
||||
|
||||
---@class AvanteLLMToolParamField
|
||||
---@field name string
|
||||
|
||||
@@ -576,7 +576,7 @@ end
|
||||
--- remove indentation from code: spaces or tabs
|
||||
function M.remove_indentation(code)
|
||||
if not code then return code end
|
||||
return code:gsub("^%s*", ""):gsub("%s*$", "")
|
||||
return code:gsub("%s*", "")
|
||||
end
|
||||
|
||||
function M.relative_path(absolute)
|
||||
@@ -1056,12 +1056,12 @@ function M.update_buffer_lines(ns_id, bufnr, old_lines, new_lines)
|
||||
if #diffs == 0 then return end
|
||||
for _, diff in ipairs(diffs) do
|
||||
local lines = diff.content
|
||||
-- M.debug("lines", lines)
|
||||
local text_lines = vim.tbl_map(function(line) return tostring(line) end, lines)
|
||||
vim.api.nvim_buf_set_lines(bufnr, diff.start_line - 1, diff.end_line - 1, false, text_lines)
|
||||
for i, line in ipairs(lines) do
|
||||
line:set_highlights(ns_id, bufnr, diff.start_line + i - 2)
|
||||
end
|
||||
vim.cmd("redraw")
|
||||
end
|
||||
end
|
||||
|
||||
@@ -1467,6 +1467,23 @@ function M.text_to_lines(text, hl)
|
||||
return lines
|
||||
end
|
||||
|
||||
---@param thinking_text string
|
||||
---@param hl string | nil
|
||||
---@return avante.ui.Line[]
|
||||
function M.thinking_to_lines(thinking_text, hl)
|
||||
local Line = require("avante.ui.line")
|
||||
local text_lines = vim.split(thinking_text, "\n")
|
||||
local lines = {}
|
||||
table.insert(lines, Line:new({ { M.icon("🤔 ") .. "Thought content:" } }))
|
||||
table.insert(lines, Line:new({ { "" } }))
|
||||
for _, text_line in ipairs(text_lines) do
|
||||
local piece = { "> " .. text_line }
|
||||
if hl then table.insert(piece, hl) end
|
||||
table.insert(lines, Line:new({ piece }))
|
||||
end
|
||||
return lines
|
||||
end
|
||||
|
||||
---@param item AvanteLLMMessageContentItem
|
||||
---@param message avante.HistoryMessage
|
||||
---@param messages avante.HistoryMessage[]
|
||||
@@ -1475,6 +1492,9 @@ function M.message_content_item_to_lines(item, message, messages)
|
||||
local Line = require("avante.ui.line")
|
||||
if type(item) == "string" then return M.text_to_lines(item) end
|
||||
if type(item) == "table" then
|
||||
if item.type == "thinking" or item.type == "redacted_thinking" then
|
||||
return M.thinking_to_lines(item.thinking or item.data or "")
|
||||
end
|
||||
if item.type == "text" then return M.text_to_lines(item.text) end
|
||||
if item.type == "image" then
|
||||
return { Line:new({ { "" } }) }
|
||||
@@ -1520,18 +1540,6 @@ function M.message_content_item_to_lines(item, message, messages)
|
||||
end
|
||||
end
|
||||
end
|
||||
elseif tool_result_message then
|
||||
local tool_result = tool_result_message.message.content[1]
|
||||
if tool_result.content then
|
||||
local result_lines = vim.split(tool_result.content, "\n")
|
||||
for idx, line in ipairs(result_lines) do
|
||||
if idx ~= #result_lines then
|
||||
table.insert(lines, Line:new({ { "│" }, { string.format(" %s", line) } }))
|
||||
else
|
||||
table.insert(lines, Line:new({ { "╰─" }, { string.format(" %s", line) } }))
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
return lines
|
||||
end
|
||||
|
||||
144
lua/avante/utils/prompts.lua
Normal file
144
lua/avante/utils/prompts.lua
Normal file
@@ -0,0 +1,144 @@
|
||||
local M = {}
|
||||
|
||||
---@param provider_conf AvanteDefaultBaseProvider
|
||||
---@param opts AvantePromptOptions
|
||||
---@return string
|
||||
function M.get_ReAct_system_prompt(provider_conf, opts)
|
||||
local system_prompt = opts.system_prompt
|
||||
local disable_tools = provider_conf.disable_tools or false
|
||||
if not disable_tools and opts.tools then
|
||||
local tools_prompts = [[
|
||||
====
|
||||
|
||||
TOOL USE
|
||||
|
||||
You have access to a set of tools that are executed upon the user's approval. You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
|
||||
|
||||
# Tool Use Formatting
|
||||
|
||||
Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure:
|
||||
|
||||
<tool_name>
|
||||
<parameter1_name>value1</parameter1_name>
|
||||
<parameter2_name>value2</parameter2_name>
|
||||
...
|
||||
</tool_name>
|
||||
|
||||
For example:
|
||||
|
||||
<view>
|
||||
<path>src/main.js</path>
|
||||
</view>
|
||||
|
||||
Always adhere to this format for the tool use to ensure proper parsing and execution.
|
||||
|
||||
# Tools
|
||||
|
||||
]]
|
||||
for _, tool in ipairs(opts.tools) do
|
||||
local tool_prompt = ([[
|
||||
## {{name}}
|
||||
Description: {{description}}
|
||||
Parameters:
|
||||
]]):gsub("{{name}}", tool.name):gsub(
|
||||
"{{description}}",
|
||||
tool.get_description and tool.get_description() or (tool.description or "")
|
||||
)
|
||||
for _, field in ipairs(tool.param.fields) do
|
||||
if field.optional then
|
||||
tool_prompt = tool_prompt .. string.format(" - %s: %s\n", field.name, field.description)
|
||||
else
|
||||
tool_prompt = tool_prompt
|
||||
.. string.format(
|
||||
" - %s: (required) %s\n",
|
||||
field.name,
|
||||
field.get_description and field.get_description() or (field.description or "")
|
||||
)
|
||||
end
|
||||
end
|
||||
if tool.param.usage then
|
||||
tool_prompt = tool_prompt
|
||||
.. ("Usage:\n<{{name}}>\n"):gsub("{{([%w_]+)}}", function(name) return tool[name] end)
|
||||
for k, v in pairs(tool.param.usage) do
|
||||
tool_prompt = tool_prompt .. "<" .. k .. ">" .. tostring(v) .. "</" .. k .. ">\n"
|
||||
end
|
||||
tool_prompt = tool_prompt .. ("</{{name}}>\n"):gsub("{{([%w_]+)}}", function(name) return tool[name] end)
|
||||
end
|
||||
tools_prompts = tools_prompts .. tool_prompt .. "\n"
|
||||
end
|
||||
|
||||
system_prompt = system_prompt .. tools_prompts
|
||||
|
||||
system_prompt = system_prompt
|
||||
.. [[
|
||||
# Tool Use Examples
|
||||
|
||||
## Example 1: Requesting to execute a command
|
||||
|
||||
<bash>
|
||||
<path>./src</path>
|
||||
<command>npm run dev</command>
|
||||
</bash>
|
||||
|
||||
## Example 2: Requesting to create a new file
|
||||
|
||||
<write_to_file>
|
||||
<path>src/frontend-config.json</path>
|
||||
<content>
|
||||
{
|
||||
"apiEndpoint": "https://api.example.com",
|
||||
"theme": {
|
||||
"primaryColor": "#007bff",
|
||||
"secondaryColor": "#6c757d",
|
||||
"fontFamily": "Arial, sans-serif"
|
||||
},
|
||||
"features": {
|
||||
"darkMode": true,
|
||||
"notifications": true,
|
||||
"analytics": false
|
||||
},
|
||||
"version": "1.0.0"
|
||||
}
|
||||
</content>
|
||||
</write_to_file>
|
||||
|
||||
## Example 3: Requesting to make targeted edits to a file
|
||||
|
||||
<replace_in_file>
|
||||
<path>src/components/App.tsx</path>
|
||||
<diff>
|
||||
<<<<<<< SEARCH
|
||||
import React from 'react';
|
||||
=======
|
||||
import React, { useState } from 'react';
|
||||
>>>>>>> REPLACE
|
||||
|
||||
<<<<<<< SEARCH
|
||||
function handleSubmit() {
|
||||
saveData();
|
||||
setLoading(false);
|
||||
}
|
||||
|
||||
=======
|
||||
>>>>>>> REPLACE
|
||||
|
||||
<<<<<<< SEARCH
|
||||
return (
|
||||
<div>
|
||||
=======
|
||||
function handleSubmit() {
|
||||
saveData();
|
||||
setLoading(false);
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
>>>>>>> REPLACE
|
||||
</diff>
|
||||
</replace_in_file>
|
||||
]]
|
||||
end
|
||||
return system_prompt
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -53,21 +53,21 @@ describe("llm_tools", function()
|
||||
|
||||
describe("ls", function()
|
||||
it("should list files in directory", function()
|
||||
local result, err = ls({ rel_path = ".", max_depth = 1 })
|
||||
local result, err = ls({ path = ".", max_depth = 1 })
|
||||
assert.is_nil(err)
|
||||
assert.falsy(result:find("avante.nvim"))
|
||||
assert.truthy(result:find("test.txt"))
|
||||
assert.falsy(result:find("test1.txt"))
|
||||
end)
|
||||
it("should list files in directory with depth", function()
|
||||
local result, err = ls({ rel_path = ".", max_depth = 2 })
|
||||
local result, err = ls({ path = ".", max_depth = 2 })
|
||||
assert.is_nil(err)
|
||||
assert.falsy(result:find("avante.nvim"))
|
||||
assert.truthy(result:find("test.txt"))
|
||||
assert.truthy(result:find("test1.txt"))
|
||||
end)
|
||||
it("should list files respecting gitignore", function()
|
||||
local result, err = ls({ rel_path = ".", max_depth = 2 })
|
||||
local result, err = ls({ path = ".", max_depth = 2 })
|
||||
assert.is_nil(err)
|
||||
assert.falsy(result:find("avante.nvim"))
|
||||
assert.truthy(result:find("test.txt"))
|
||||
@@ -102,7 +102,7 @@ describe("llm_tools", function()
|
||||
|
||||
describe("create_dir", function()
|
||||
it("should create new directory", function()
|
||||
LlmTools.create_dir({ rel_path = "new_dir" }, nil, function(success, err)
|
||||
LlmTools.create_dir({ path = "new_dir" }, nil, function(success, err)
|
||||
assert.is_nil(err)
|
||||
assert.is_true(success)
|
||||
|
||||
@@ -114,7 +114,7 @@ describe("llm_tools", function()
|
||||
|
||||
describe("delete_file", function()
|
||||
it("should delete existing file", function()
|
||||
LlmTools.delete_file({ rel_path = "test.txt" }, nil, function(success, err)
|
||||
LlmTools.delete_file({ path = "test.txt" }, nil, function(success, err)
|
||||
assert.is_nil(err)
|
||||
assert.is_true(success)
|
||||
|
||||
@@ -147,28 +147,28 @@ describe("llm_tools", function()
|
||||
file:write("this is nothing")
|
||||
file:close()
|
||||
|
||||
local result, err = grep({ rel_path = ".", query = "Searchable", case_sensitive = false })
|
||||
local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false })
|
||||
assert.is_nil(err)
|
||||
assert.truthy(result:find("searchable.txt"))
|
||||
assert.falsy(result:find("nothing.txt"))
|
||||
|
||||
local result2, err2 = grep({ rel_path = ".", query = "searchable", case_sensitive = true })
|
||||
local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true })
|
||||
assert.is_nil(err2)
|
||||
assert.truthy(result2:find("searchable.txt"))
|
||||
assert.falsy(result2:find("nothing.txt"))
|
||||
|
||||
local result3, err3 = grep({ rel_path = ".", query = "Searchable", case_sensitive = true })
|
||||
local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true })
|
||||
assert.is_nil(err3)
|
||||
assert.falsy(result3:find("searchable.txt"))
|
||||
assert.falsy(result3:find("nothing.txt"))
|
||||
|
||||
local result4, err4 = grep({ rel_path = ".", query = "searchable", case_sensitive = false })
|
||||
local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false })
|
||||
assert.is_nil(err4)
|
||||
assert.truthy(result4:find("searchable.txt"))
|
||||
assert.falsy(result4:find("nothing.txt"))
|
||||
|
||||
local result5, err5 = grep({
|
||||
rel_path = ".",
|
||||
path = ".",
|
||||
query = "searchable",
|
||||
case_sensitive = false,
|
||||
exclude_pattern = "search*",
|
||||
@@ -191,7 +191,7 @@ describe("llm_tools", function()
|
||||
file:write("content for ag test")
|
||||
file:close()
|
||||
|
||||
local result, err = grep({ rel_path = ".", query = "ag test" })
|
||||
local result, err = grep({ path = ".", query = "ag test" })
|
||||
assert.is_nil(err)
|
||||
assert.is_string(result)
|
||||
assert.truthy(result:find("ag_test.txt"))
|
||||
@@ -215,28 +215,28 @@ describe("llm_tools", function()
|
||||
file:write("this is nothing")
|
||||
file:close()
|
||||
|
||||
local result, err = grep({ rel_path = ".", query = "Searchable", case_sensitive = false })
|
||||
local result, err = grep({ path = ".", query = "Searchable", case_sensitive = false })
|
||||
assert.is_nil(err)
|
||||
assert.truthy(result:find("searchable.txt"))
|
||||
assert.falsy(result:find("nothing.txt"))
|
||||
|
||||
local result2, err2 = grep({ rel_path = ".", query = "searchable", case_sensitive = true })
|
||||
local result2, err2 = grep({ path = ".", query = "searchable", case_sensitive = true })
|
||||
assert.is_nil(err2)
|
||||
assert.truthy(result2:find("searchable.txt"))
|
||||
assert.falsy(result2:find("nothing.txt"))
|
||||
|
||||
local result3, err3 = grep({ rel_path = ".", query = "Searchable", case_sensitive = true })
|
||||
local result3, err3 = grep({ path = ".", query = "Searchable", case_sensitive = true })
|
||||
assert.is_nil(err3)
|
||||
assert.falsy(result3:find("searchable.txt"))
|
||||
assert.falsy(result3:find("nothing.txt"))
|
||||
|
||||
local result4, err4 = grep({ rel_path = ".", query = "searchable", case_sensitive = false })
|
||||
local result4, err4 = grep({ path = ".", query = "searchable", case_sensitive = false })
|
||||
assert.is_nil(err4)
|
||||
assert.truthy(result4:find("searchable.txt"))
|
||||
assert.falsy(result4:find("nothing.txt"))
|
||||
|
||||
local result5, err5 = grep({
|
||||
rel_path = ".",
|
||||
path = ".",
|
||||
query = "searchable",
|
||||
case_sensitive = false,
|
||||
exclude_pattern = "search*",
|
||||
@@ -250,18 +250,18 @@ describe("llm_tools", function()
|
||||
-- Mock exepath to return nothing
|
||||
vim.fn.exepath = function() return "" end
|
||||
|
||||
local result, err = grep({ rel_path = ".", query = "test" })
|
||||
local result, err = grep({ path = ".", query = "test" })
|
||||
assert.equals("", result)
|
||||
assert.equals("No search command found", err)
|
||||
end)
|
||||
|
||||
it("should respect path permissions", function()
|
||||
local result, err = grep({ rel_path = "../outside_project", query = "test" })
|
||||
local result, err = grep({ path = "../outside_project", query = "test" })
|
||||
assert.truthy(err:find("No permission to access path"))
|
||||
end)
|
||||
|
||||
it("should handle non-existent paths", function()
|
||||
local result, err = grep({ rel_path = "non_existent_dir", query = "test" })
|
||||
local result, err = grep({ path = "non_existent_dir", query = "test" })
|
||||
assert.equals("", result)
|
||||
assert.truthy(err)
|
||||
assert.truthy(err:find("No such file or directory"))
|
||||
@@ -270,14 +270,14 @@ describe("llm_tools", function()
|
||||
|
||||
describe("bash", function()
|
||||
-- it("should execute command and return output", function()
|
||||
-- bash({ rel_path = ".", command = "echo 'test'" }, nil, function(result, err)
|
||||
-- bash({ path = ".", command = "echo 'test'" }, nil, function(result, err)
|
||||
-- assert.is_nil(err)
|
||||
-- assert.equals("test\n", result)
|
||||
-- end)
|
||||
-- end)
|
||||
|
||||
it("should return error when running outside current directory", function()
|
||||
bash({ rel_path = "../outside_project", command = "echo 'test'" }, nil, function(result, err)
|
||||
bash({ path = "../outside_project", command = "echo 'test'" }, nil, function(result, err)
|
||||
assert.is_false(result)
|
||||
assert.truthy(err)
|
||||
assert.truthy(err:find("No permission to access path"))
|
||||
@@ -289,7 +289,7 @@ describe("llm_tools", function()
|
||||
it("should execute Python code and return output", function()
|
||||
LlmTools.python(
|
||||
{
|
||||
rel_path = ".",
|
||||
path = ".",
|
||||
code = "print('Hello from Python')",
|
||||
},
|
||||
nil,
|
||||
@@ -303,7 +303,7 @@ describe("llm_tools", function()
|
||||
it("should handle Python errors", function()
|
||||
LlmTools.python(
|
||||
{
|
||||
rel_path = ".",
|
||||
path = ".",
|
||||
code = "print(undefined_variable)",
|
||||
},
|
||||
nil,
|
||||
@@ -318,7 +318,7 @@ describe("llm_tools", function()
|
||||
it("should respect path permissions", function()
|
||||
LlmTools.python(
|
||||
{
|
||||
rel_path = "../outside_project",
|
||||
path = "../outside_project",
|
||||
code = "print('test')",
|
||||
},
|
||||
nil,
|
||||
@@ -332,7 +332,7 @@ describe("llm_tools", function()
|
||||
it("should handle non-existent paths", function()
|
||||
LlmTools.python(
|
||||
{
|
||||
rel_path = "non_existent_dir",
|
||||
path = "non_existent_dir",
|
||||
code = "print('test')",
|
||||
},
|
||||
nil,
|
||||
@@ -347,7 +347,7 @@ describe("llm_tools", function()
|
||||
os.execute("docker image rm python:3.12-slim")
|
||||
LlmTools.python(
|
||||
{
|
||||
rel_path = ".",
|
||||
path = ".",
|
||||
code = "print('Hello from custom container')",
|
||||
container_image = "python:3.12-slim",
|
||||
},
|
||||
@@ -370,7 +370,7 @@ describe("llm_tools", function()
|
||||
os.execute("touch " .. test_dir .. "/nested/file4.lua")
|
||||
|
||||
-- Test for lua files in the root
|
||||
local result, err = glob({ rel_path = ".", pattern = "*.lua" })
|
||||
local result, err = glob({ path = ".", pattern = "*.lua" })
|
||||
assert.is_nil(err)
|
||||
local files = vim.json.decode(result).matches
|
||||
assert.equals(2, #files)
|
||||
@@ -380,7 +380,7 @@ describe("llm_tools", function()
|
||||
assert.falsy(vim.tbl_contains(files, test_dir .. "/nested/file4.lua"))
|
||||
|
||||
-- Test with recursive pattern
|
||||
local result2, err2 = glob({ rel_path = ".", pattern = "**/*.lua" })
|
||||
local result2, err2 = glob({ path = ".", pattern = "**/*.lua" })
|
||||
assert.is_nil(err2)
|
||||
local files2 = vim.json.decode(result2).matches
|
||||
assert.equals(3, #files2)
|
||||
@@ -390,13 +390,13 @@ describe("llm_tools", function()
|
||||
end)
|
||||
|
||||
it("should respect path permissions", function()
|
||||
local result, err = glob({ rel_path = "../outside_project", pattern = "*.txt" })
|
||||
local result, err = glob({ path = "../outside_project", pattern = "*.txt" })
|
||||
assert.equals("", result)
|
||||
assert.truthy(err:find("No permission to access path"))
|
||||
end)
|
||||
|
||||
it("should handle patterns without matches", function()
|
||||
local result, err = glob({ rel_path = ".", pattern = "*.nonexistent" })
|
||||
local result, err = glob({ path = ".", pattern = "*.nonexistent" })
|
||||
assert.is_nil(err)
|
||||
local files = vim.json.decode(result).matches
|
||||
assert.equals(0, #files)
|
||||
@@ -411,7 +411,7 @@ describe("llm_tools", function()
|
||||
os.execute("touch " .. test_dir .. "/test_dir1/notignored1.lua")
|
||||
os.execute("touch " .. test_dir .. "/test_dir1/notignored2.lua")
|
||||
|
||||
local result, err = glob({ rel_path = ".", pattern = "**/*.lua" })
|
||||
local result, err = glob({ path = ".", pattern = "**/*.lua" })
|
||||
assert.is_nil(err)
|
||||
local files = vim.json.decode(result).matches
|
||||
|
||||
|
||||
Reference in New Issue
Block a user