refactor: llm tool parameters (#2449)
This commit is contained in:
@@ -3,6 +3,8 @@ local Config = require("avante.config")
|
||||
local Utils = require("avante.utils")
|
||||
local Base = require("avante.llm_tools.base")
|
||||
local HistoryMessage = require("avante.history_message")
|
||||
local Line = require("avante.ui.line")
|
||||
local Highlights = require("avante.highlights")
|
||||
|
||||
---@class AvanteLLMTool
|
||||
local M = setmetatable({}, Base)
|
||||
@@ -79,12 +81,118 @@ local function get_available_tools()
|
||||
}
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<{ prompt: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
---@class avante.DispatchAgentInput
|
||||
---@field prompt string
|
||||
|
||||
---@type avante.LLMToolOnRender<avante.DispatchAgentInput>
|
||||
function M.on_render(input, opts)
|
||||
local result_message = opts.result_message
|
||||
local store = opts.store or {}
|
||||
local messages = store.messages or {}
|
||||
local tool_use_summary = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
local content = msg.message.content
|
||||
local summary
|
||||
if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then
|
||||
local tool_result_message = Utils.get_tool_result_message(msg, messages)
|
||||
if tool_result_message then
|
||||
local tool_name = msg.message.content[1].name
|
||||
if tool_name == "ls" then
|
||||
local path = msg.message.content[1].input.path
|
||||
if tool_result_message.message.content[1].is_error then
|
||||
summary = string.format("Ls %s: failed", path)
|
||||
else
|
||||
local ok, filepaths = pcall(vim.json.decode, tool_result_message.message.content[1].content)
|
||||
if ok then summary = string.format("Ls %s: %d paths", path, #filepaths) end
|
||||
end
|
||||
elseif tool_name == "grep" then
|
||||
local path = msg.message.content[1].input.path
|
||||
local query = msg.message.content[1].input.query
|
||||
if tool_result_message.message.content[1].is_error then
|
||||
summary = string.format("Grep %s in %s: failed", query, path)
|
||||
else
|
||||
local ok, filepaths = pcall(vim.json.decode, tool_result_message.message.content[1].content)
|
||||
if ok then summary = string.format("Grep %s in %s: %d paths", query, path, #filepaths) end
|
||||
end
|
||||
elseif tool_name == "glob" then
|
||||
local path = msg.message.content[1].input.path
|
||||
local pattern = msg.message.content[1].input.pattern
|
||||
if tool_result_message.message.content[1].is_error then
|
||||
summary = string.format("Glob %s in %s: failed", pattern, path)
|
||||
else
|
||||
local ok, result = pcall(vim.json.decode, tool_result_message.message.content[1].content)
|
||||
if ok then
|
||||
local matches = result.matches
|
||||
if matches then summary = string.format("Glob %s in %s: %d matches", pattern, path, #matches) end
|
||||
end
|
||||
end
|
||||
elseif tool_name == "view" then
|
||||
local path = msg.message.content[1].input.path
|
||||
if tool_result_message.message.content[1].is_error then
|
||||
summary = string.format("View %s: failed", path)
|
||||
else
|
||||
local ok, result = pcall(vim.json.decode, tool_result_message.message.content[1].content)
|
||||
if ok then
|
||||
local content_ = result.content
|
||||
local lines = vim.split(content_, "\n")
|
||||
summary = string.format("View %s: %d lines", path, #lines)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
if summary then summary = " " .. Utils.icon("🛠️ ") .. summary end
|
||||
elseif type(content) == "table" and #content > 0 and type(content[1]) == "table" and content[1].type == "text" then
|
||||
summary = content[1].content
|
||||
elseif type(content) == "table" and #content > 0 and type(content[1]) == "string" then
|
||||
summary = content[1]
|
||||
elseif type(content) == "string" then
|
||||
summary = content
|
||||
end
|
||||
if summary then table.insert(tool_use_summary, summary) end
|
||||
end
|
||||
local state = "running"
|
||||
local icon = Utils.icon("🔄 ")
|
||||
local hl = Highlights.AVANTE_TASK_RUNNING
|
||||
if result_message then
|
||||
if result_message.message.content[1].is_error then
|
||||
state = "failed"
|
||||
icon = Utils.icon("❌ ")
|
||||
hl = Highlights.AVANTE_TASK_FAILED
|
||||
else
|
||||
state = "completed"
|
||||
icon = Utils.icon("✅ ")
|
||||
hl = Highlights.AVANTE_TASK_COMPLETED
|
||||
end
|
||||
end
|
||||
local lines = {}
|
||||
table.insert(lines, Line:new({ { icon .. "Subtask " .. state, hl } }))
|
||||
table.insert(lines, Line:new({ { "" } }))
|
||||
table.insert(lines, Line:new({ { " Task:" } }))
|
||||
local prompt_lines = vim.split(input.prompt or "", "\n")
|
||||
for _, line in ipairs(prompt_lines) do
|
||||
table.insert(lines, Line:new({ { " " .. line } }))
|
||||
end
|
||||
table.insert(lines, Line:new({ { "" } }))
|
||||
table.insert(lines, Line:new({ { " Task summary:" } }))
|
||||
for _, summary in ipairs(tool_use_summary) do
|
||||
local summary_lines = vim.split(summary, "\n")
|
||||
for _, line in ipairs(summary_lines) do
|
||||
table.insert(lines, Line:new({ { " " .. line } }))
|
||||
end
|
||||
end
|
||||
return lines
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolFunc<avante.DispatchAgentInput>
|
||||
function M.func(input, opts)
|
||||
local on_log = opts.on_log
|
||||
local on_complete = opts.on_complete
|
||||
local session_ctx = opts.session_ctx
|
||||
|
||||
local Llm = require("avante.llm")
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
|
||||
local prompt = opts.prompt
|
||||
local prompt = input.prompt
|
||||
local tools = get_available_tools()
|
||||
local start_time = Utils.get_timestamp()
|
||||
|
||||
@@ -95,10 +203,11 @@ Your task is to help the user with their request: "${prompt}"
|
||||
Be thorough and use the tools available to you to find the most relevant information.
|
||||
When you're done, provide a clear and concise summary of what you found.]]):gsub("${prompt}", prompt)
|
||||
|
||||
local history_messages = {}
|
||||
local tool_use_messages = {}
|
||||
|
||||
local total_tokens = 0
|
||||
local final_response = ""
|
||||
local result = ""
|
||||
|
||||
---@type avante.AgentLoopOptions
|
||||
local agent_loop_options = {
|
||||
@@ -108,19 +217,37 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
|
||||
on_tool_log = session_ctx.on_tool_log,
|
||||
on_messages_add = function(msgs)
|
||||
msgs = vim.islist(msgs) and msgs or { msgs }
|
||||
for _, msg in ipairs(msgs) do
|
||||
local idx = nil
|
||||
for i, m in ipairs(history_messages) do
|
||||
if m.uuid == msg.uuid then
|
||||
idx = i
|
||||
break
|
||||
end
|
||||
end
|
||||
if idx ~= nil then
|
||||
history_messages[idx] = msg
|
||||
else
|
||||
table.insert(history_messages, msg)
|
||||
end
|
||||
end
|
||||
if opts.set_store then opts.set_store("messages", history_messages) end
|
||||
for _, msg in ipairs(msgs) do
|
||||
local content = msg.message.content
|
||||
if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then
|
||||
tool_use_messages[msg.uuid] = true
|
||||
if content[1].name == "attempt_completion" then
|
||||
local input_ = content[1].input
|
||||
if input_ and input_.result then result = input_.result end
|
||||
end
|
||||
end
|
||||
end
|
||||
if session_ctx.on_messages_add then session_ctx.on_messages_add(msgs) end
|
||||
-- if session_ctx.on_messages_add then session_ctx.on_messages_add(msgs) end
|
||||
end,
|
||||
session_ctx = session_ctx,
|
||||
on_start = session_ctx.on_start,
|
||||
on_chunk = function(chunk)
|
||||
if not chunk then return end
|
||||
final_response = final_response .. chunk
|
||||
total_tokens = total_tokens + (#vim.split(chunk, " ") * 1.3)
|
||||
end,
|
||||
on_complete = function(err)
|
||||
@@ -148,8 +275,7 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
|
||||
})
|
||||
session_ctx.on_messages_add({ message })
|
||||
end
|
||||
local response = string.format("Final response:\n%s\n\nSummary:\n%s", summary, final_response)
|
||||
on_complete(response, nil)
|
||||
on_complete(result, nil)
|
||||
end,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user