refactor: llm tool parameters (#2449)

This commit is contained in:
yetone
2025-07-15 16:40:25 +08:00
committed by GitHub
parent 0c6a8f5688
commit b8bb0fd969
25 changed files with 627 additions and 381 deletions

View File

@@ -3,6 +3,8 @@ local Config = require("avante.config")
local Utils = require("avante.utils")
local Base = require("avante.llm_tools.base")
local HistoryMessage = require("avante.history_message")
local Line = require("avante.ui.line")
local Highlights = require("avante.highlights")
---@class AvanteLLMTool
local M = setmetatable({}, Base)
@@ -79,12 +81,118 @@ local function get_available_tools()
}
end
---@type AvanteLLMToolFunc<{ prompt: string }>
function M.func(opts, on_log, on_complete, session_ctx)
---@class avante.DispatchAgentInput
---@field prompt string
---@type avante.LLMToolOnRender<avante.DispatchAgentInput>
function M.on_render(input, opts)
local result_message = opts.result_message
local store = opts.store or {}
local messages = store.messages or {}
local tool_use_summary = {}
for _, msg in ipairs(messages) do
local content = msg.message.content
local summary
if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then
local tool_result_message = Utils.get_tool_result_message(msg, messages)
if tool_result_message then
local tool_name = msg.message.content[1].name
if tool_name == "ls" then
local path = msg.message.content[1].input.path
if tool_result_message.message.content[1].is_error then
summary = string.format("Ls %s: failed", path)
else
local ok, filepaths = pcall(vim.json.decode, tool_result_message.message.content[1].content)
if ok then summary = string.format("Ls %s: %d paths", path, #filepaths) end
end
elseif tool_name == "grep" then
local path = msg.message.content[1].input.path
local query = msg.message.content[1].input.query
if tool_result_message.message.content[1].is_error then
summary = string.format("Grep %s in %s: failed", query, path)
else
local ok, filepaths = pcall(vim.json.decode, tool_result_message.message.content[1].content)
if ok then summary = string.format("Grep %s in %s: %d paths", query, path, #filepaths) end
end
elseif tool_name == "glob" then
local path = msg.message.content[1].input.path
local pattern = msg.message.content[1].input.pattern
if tool_result_message.message.content[1].is_error then
summary = string.format("Glob %s in %s: failed", pattern, path)
else
local ok, result = pcall(vim.json.decode, tool_result_message.message.content[1].content)
if ok then
local matches = result.matches
if matches then summary = string.format("Glob %s in %s: %d matches", pattern, path, #matches) end
end
end
elseif tool_name == "view" then
local path = msg.message.content[1].input.path
if tool_result_message.message.content[1].is_error then
summary = string.format("View %s: failed", path)
else
local ok, result = pcall(vim.json.decode, tool_result_message.message.content[1].content)
if ok then
local content_ = result.content
local lines = vim.split(content_, "\n")
summary = string.format("View %s: %d lines", path, #lines)
end
end
end
end
if summary then summary = " " .. Utils.icon("🛠️ ") .. summary end
elseif type(content) == "table" and #content > 0 and type(content[1]) == "table" and content[1].type == "text" then
summary = content[1].content
elseif type(content) == "table" and #content > 0 and type(content[1]) == "string" then
summary = content[1]
elseif type(content) == "string" then
summary = content
end
if summary then table.insert(tool_use_summary, summary) end
end
local state = "running"
local icon = Utils.icon("🔄 ")
local hl = Highlights.AVANTE_TASK_RUNNING
if result_message then
if result_message.message.content[1].is_error then
state = "failed"
icon = Utils.icon("")
hl = Highlights.AVANTE_TASK_FAILED
else
state = "completed"
icon = Utils.icon("")
hl = Highlights.AVANTE_TASK_COMPLETED
end
end
local lines = {}
table.insert(lines, Line:new({ { icon .. "Subtask " .. state, hl } }))
table.insert(lines, Line:new({ { "" } }))
table.insert(lines, Line:new({ { " Task:" } }))
local prompt_lines = vim.split(input.prompt or "", "\n")
for _, line in ipairs(prompt_lines) do
table.insert(lines, Line:new({ { " " .. line } }))
end
table.insert(lines, Line:new({ { "" } }))
table.insert(lines, Line:new({ { " Task summary:" } }))
for _, summary in ipairs(tool_use_summary) do
local summary_lines = vim.split(summary, "\n")
for _, line in ipairs(summary_lines) do
table.insert(lines, Line:new({ { " " .. line } }))
end
end
return lines
end
---@type AvanteLLMToolFunc<avante.DispatchAgentInput>
function M.func(input, opts)
local on_log = opts.on_log
local on_complete = opts.on_complete
local session_ctx = opts.session_ctx
local Llm = require("avante.llm")
if not on_complete then return false, "on_complete not provided" end
local prompt = opts.prompt
local prompt = input.prompt
local tools = get_available_tools()
local start_time = Utils.get_timestamp()
@@ -95,10 +203,11 @@ Your task is to help the user with their request: "${prompt}"
Be thorough and use the tools available to you to find the most relevant information.
When you're done, provide a clear and concise summary of what you found.]]):gsub("${prompt}", prompt)
local history_messages = {}
local tool_use_messages = {}
local total_tokens = 0
local final_response = ""
local result = ""
---@type avante.AgentLoopOptions
local agent_loop_options = {
@@ -108,19 +217,37 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
on_tool_log = session_ctx.on_tool_log,
on_messages_add = function(msgs)
msgs = vim.islist(msgs) and msgs or { msgs }
for _, msg in ipairs(msgs) do
local idx = nil
for i, m in ipairs(history_messages) do
if m.uuid == msg.uuid then
idx = i
break
end
end
if idx ~= nil then
history_messages[idx] = msg
else
table.insert(history_messages, msg)
end
end
if opts.set_store then opts.set_store("messages", history_messages) end
for _, msg in ipairs(msgs) do
local content = msg.message.content
if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then
tool_use_messages[msg.uuid] = true
if content[1].name == "attempt_completion" then
local input_ = content[1].input
if input_ and input_.result then result = input_.result end
end
end
end
if session_ctx.on_messages_add then session_ctx.on_messages_add(msgs) end
-- if session_ctx.on_messages_add then session_ctx.on_messages_add(msgs) end
end,
session_ctx = session_ctx,
on_start = session_ctx.on_start,
on_chunk = function(chunk)
if not chunk then return end
final_response = final_response .. chunk
total_tokens = total_tokens + (#vim.split(chunk, " ") * 1.3)
end,
on_complete = function(err)
@@ -148,8 +275,7 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
})
session_ctx.on_messages_add({ message })
end
local response = string.format("Final response:\n%s\n\nSummary:\n%s", summary, final_response)
on_complete(response, nil)
on_complete(result, nil)
end,
}