fix: dispatch agent (#1953)
This commit is contained in:
@@ -59,15 +59,14 @@ function M.summarize_chat_thread_title(content, cb)
|
|||||||
})
|
})
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param bufnr integer
|
---@param prev_memory string | nil
|
||||||
---@param history avante.ChatHistory
|
|
||||||
---@param history_messages avante.HistoryMessage[]
|
---@param history_messages avante.HistoryMessage[]
|
||||||
---@param cb fun(memory: avante.ChatMemory | nil): nil
|
---@param cb fun(memory: avante.ChatMemory | nil): nil
|
||||||
function M.summarize_memory(bufnr, history, history_messages, cb)
|
function M.summarize_memory(prev_memory, history_messages, cb)
|
||||||
local system_prompt =
|
local system_prompt =
|
||||||
[[You are an expert coding assistant. Your goal is to generate a concise, structured summary of the conversation below that captures all essential information needed to continue development after context replacement. Include tasks performed, code areas modified or reviewed, key decisions or assumptions, test results or errors, and outstanding tasks or next steps.]]
|
[[You are an expert coding assistant. Your goal is to generate a concise, structured summary of the conversation below that captures all essential information needed to continue development after context replacement. Include tasks performed, code areas modified or reviewed, key decisions or assumptions, test results or errors, and outstanding tasks or next steps.]]
|
||||||
if #history_messages == 0 then
|
if #history_messages == 0 then
|
||||||
cb(history.memory)
|
cb(nil)
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
local latest_timestamp = history_messages[#history_messages].timestamp
|
local latest_timestamp = history_messages[#history_messages].timestamp
|
||||||
@@ -88,7 +87,7 @@ function M.summarize_memory(bufnr, history, history_messages, cb)
|
|||||||
local user_prompt = "Here is the conversation so far:\n"
|
local user_prompt = "Here is the conversation so far:\n"
|
||||||
.. conversation_text
|
.. conversation_text
|
||||||
.. "\n\nPlease summarize this conversation, covering:\n1. Tasks performed and outcomes\n2. Code files, modules, or functions modified or examined\n3. Important decisions or assumptions made\n4. Errors encountered and test or build results\n5. Remaining tasks, open questions, or next steps\nProvide the summary in a clear, concise format."
|
.. "\n\nPlease summarize this conversation, covering:\n1. Tasks performed and outcomes\n2. Code files, modules, or functions modified or examined\n3. Important decisions or assumptions made\n4. Errors encountered and test or build results\n5. Remaining tasks, open questions, or next steps\nProvide the summary in a clear, concise format."
|
||||||
if history.memory then user_prompt = user_prompt .. "\n\nThe previous summary is:\n\n" .. history.memory.content end
|
if prev_memory then user_prompt = user_prompt .. "\n\nThe previous summary is:\n\n" .. prev_memory end
|
||||||
local messages = {
|
local messages = {
|
||||||
{
|
{
|
||||||
role = "user",
|
role = "user",
|
||||||
@@ -121,11 +120,9 @@ function M.summarize_memory(bufnr, history, history_messages, cb)
|
|||||||
last_summarized_timestamp = latest_timestamp,
|
last_summarized_timestamp = latest_timestamp,
|
||||||
last_message_uuid = latest_message_uuid,
|
last_message_uuid = latest_message_uuid,
|
||||||
}
|
}
|
||||||
history.memory = memory
|
|
||||||
Path.history.save(bufnr, history)
|
|
||||||
cb(memory)
|
cb(memory)
|
||||||
else
|
else
|
||||||
cb(history.memory)
|
cb(nil)
|
||||||
end
|
end
|
||||||
end,
|
end,
|
||||||
},
|
},
|
||||||
@@ -622,6 +619,16 @@ function M._stream(opts)
|
|||||||
local provider = opts.provider or Providers[Config.provider]
|
local provider = opts.provider or Providers[Config.provider]
|
||||||
opts.session_ctx = opts.session_ctx or {}
|
opts.session_ctx = opts.session_ctx or {}
|
||||||
|
|
||||||
|
if not opts.session_ctx.on_messages_add then opts.session_ctx.on_messages_add = opts.on_messages_add end
|
||||||
|
if not opts.session_ctx.on_state_change then opts.session_ctx.on_state_change = opts.on_state_change end
|
||||||
|
if not opts.session_ctx.on_start then opts.session_ctx.on_start = opts.on_start end
|
||||||
|
if not opts.session_ctx.on_chunk then opts.session_ctx.on_chunk = opts.on_chunk end
|
||||||
|
if not opts.session_ctx.on_stop then opts.session_ctx.on_stop = opts.on_stop end
|
||||||
|
if not opts.session_ctx.on_tool_log then opts.session_ctx.on_tool_log = opts.on_tool_log end
|
||||||
|
if not opts.session_ctx.get_history_messages then
|
||||||
|
opts.session_ctx.get_history_messages = opts.get_history_messages
|
||||||
|
end
|
||||||
|
|
||||||
---@cast provider AvanteProviderFunctor
|
---@cast provider AvanteProviderFunctor
|
||||||
|
|
||||||
local prompt_opts = M.generate_prompts(opts)
|
local prompt_opts = M.generate_prompts(opts)
|
||||||
@@ -898,7 +905,7 @@ function M.stream(opts)
|
|||||||
local original_on_chunk = opts.on_chunk
|
local original_on_chunk = opts.on_chunk
|
||||||
opts.on_chunk = vim.schedule_wrap(function(chunk)
|
opts.on_chunk = vim.schedule_wrap(function(chunk)
|
||||||
if is_completed then return end
|
if is_completed then return end
|
||||||
return original_on_chunk(chunk)
|
if original_on_chunk then return original_on_chunk(chunk) end
|
||||||
end)
|
end)
|
||||||
end
|
end
|
||||||
if opts.on_stop ~= nil then
|
if opts.on_stop ~= nil then
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ local Providers = require("avante.providers")
|
|||||||
local Config = require("avante.config")
|
local Config = require("avante.config")
|
||||||
local Utils = require("avante.utils")
|
local Utils = require("avante.utils")
|
||||||
local Base = require("avante.llm_tools.base")
|
local Base = require("avante.llm_tools.base")
|
||||||
|
local HistoryMessage = require("avante.history_message")
|
||||||
|
|
||||||
---@class AvanteLLMTool
|
---@class AvanteLLMTool
|
||||||
local M = setmetatable({}, Base)
|
local M = setmetatable({}, Base)
|
||||||
@@ -80,29 +81,47 @@ Your task is to help the user with their request: "${prompt}"
|
|||||||
Be thorough and use the tools available to you to find the most relevant information.
|
Be thorough and use the tools available to you to find the most relevant information.
|
||||||
When you're done, provide a clear and concise summary of what you found.]]):gsub("${prompt}", prompt)
|
When you're done, provide a clear and concise summary of what you found.]]):gsub("${prompt}", prompt)
|
||||||
|
|
||||||
local messages = session_ctx and session_ctx.messages or {}
|
local messages = {}
|
||||||
messages = messages or {}
|
table.insert(messages, { role = "user", content = "go!" })
|
||||||
table.insert(messages, { role = "user", content = prompt })
|
|
||||||
|
|
||||||
local tool_use_messages = {}
|
local tool_use_messages = {}
|
||||||
|
|
||||||
local total_tokens = 0
|
local total_tokens = 0
|
||||||
local final_response = ""
|
local final_response = ""
|
||||||
Llm._stream({
|
|
||||||
|
local memory_content = nil
|
||||||
|
local history_messages = {}
|
||||||
|
|
||||||
|
local stream_options = {
|
||||||
ask = true,
|
ask = true,
|
||||||
|
memory = memory_content,
|
||||||
code_lang = "unknown",
|
code_lang = "unknown",
|
||||||
provider = Providers[Config.provider],
|
provider = Providers[Config.provider],
|
||||||
on_tool_log = function(tool_id, tool_name, log, state)
|
get_history_messages = function() return history_messages end,
|
||||||
if on_log then on_log(string.format("[%s] %s", tool_name, log)) end
|
on_tool_log = session_ctx.on_tool_log,
|
||||||
end,
|
|
||||||
on_messages_add = function(msgs)
|
on_messages_add = function(msgs)
|
||||||
msgs = vim.is_list(msgs) and msgs or { msgs }
|
msgs = vim.islist(msgs) and msgs or { msgs }
|
||||||
for _, msg in ipairs(msgs) do
|
for _, msg in ipairs(msgs) do
|
||||||
local content = msg.message.content
|
local content = msg.message.content
|
||||||
if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then
|
if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then
|
||||||
tool_use_messages[msg.uuid] = true
|
tool_use_messages[msg.uuid] = true
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
for _, msg in ipairs(msgs) do
|
||||||
|
local idx = nil
|
||||||
|
for i, m in ipairs(history_messages) do
|
||||||
|
if m.uuid == msg.uuid then
|
||||||
|
idx = i
|
||||||
|
break
|
||||||
|
end
|
||||||
|
end
|
||||||
|
if idx ~= nil then
|
||||||
|
history_messages[idx] = msg
|
||||||
|
else
|
||||||
|
table.insert(history_messages, msg)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
if session_ctx.on_messages_add then session_ctx.on_messages_add(msgs) end
|
||||||
end,
|
end,
|
||||||
session_ctx = session_ctx,
|
session_ctx = session_ctx,
|
||||||
prompt_opts = {
|
prompt_opts = {
|
||||||
@@ -110,7 +129,7 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
|
|||||||
tools = tools,
|
tools = tools,
|
||||||
messages = messages,
|
messages = messages,
|
||||||
},
|
},
|
||||||
on_start = function(_) end,
|
on_start = session_ctx.on_start,
|
||||||
on_chunk = function(chunk)
|
on_chunk = function(chunk)
|
||||||
if not chunk then return end
|
if not chunk then return end
|
||||||
final_response = final_response .. chunk
|
final_response = final_response .. chunk
|
||||||
@@ -125,18 +144,46 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
|
|||||||
local end_time = Utils.get_timestamp()
|
local end_time = Utils.get_timestamp()
|
||||||
local elapsed_time = Utils.datetime_diff(start_time, end_time)
|
local elapsed_time = Utils.datetime_diff(start_time, end_time)
|
||||||
local tool_use_count = vim.tbl_count(tool_use_messages)
|
local tool_use_count = vim.tbl_count(tool_use_messages)
|
||||||
local summary = "Done ("
|
local summary = "dispatch_agent Done ("
|
||||||
.. (tool_use_count <= 1 and "1 tool use" or tool_use_count .. " tool uses")
|
.. (tool_use_count <= 1 and "1 tool use" or tool_use_count .. " tool uses")
|
||||||
.. " · "
|
.. " · "
|
||||||
.. math.ceil(total_tokens)
|
.. math.ceil(total_tokens)
|
||||||
.. " tokens · "
|
.. " tokens · "
|
||||||
.. elapsed_time
|
.. elapsed_time
|
||||||
.. "s)"
|
.. "s)"
|
||||||
Utils.debug("summary", summary)
|
if session_ctx.on_messages_add then
|
||||||
|
local message = HistoryMessage:new({
|
||||||
|
role = "assistant",
|
||||||
|
content = "\n\n" .. summary,
|
||||||
|
}, {
|
||||||
|
just_for_display = true,
|
||||||
|
})
|
||||||
|
session_ctx.on_messages_add({ message })
|
||||||
|
end
|
||||||
local response = string.format("Final response:\n%s\n\nSummary:\n%s", summary, final_response)
|
local response = string.format("Final response:\n%s\n\nSummary:\n%s", summary, final_response)
|
||||||
on_complete(response, nil)
|
on_complete(response, nil)
|
||||||
end,
|
end,
|
||||||
})
|
}
|
||||||
|
|
||||||
|
local function on_memory_summarize(dropped_history_messages)
|
||||||
|
Llm.summarize_memory(memory_content, dropped_history_messages or {}, function(memory)
|
||||||
|
if memory then stream_options.memory = memory.content end
|
||||||
|
local new_history_messages = {}
|
||||||
|
for _, msg in ipairs(history_messages) do
|
||||||
|
if vim.iter(dropped_history_messages):find(function(dropped_msg) return dropped_msg.uuid == msg.uuid end) then
|
||||||
|
goto continue
|
||||||
|
end
|
||||||
|
table.insert(new_history_messages, msg)
|
||||||
|
::continue::
|
||||||
|
end
|
||||||
|
history_messages = new_history_messages
|
||||||
|
Llm._stream(stream_options)
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
|
||||||
|
stream_options.on_memory_summarize = on_memory_summarize
|
||||||
|
|
||||||
|
Llm._stream(stream_options)
|
||||||
end
|
end
|
||||||
|
|
||||||
return M
|
return M
|
||||||
|
|||||||
@@ -120,7 +120,12 @@ function M:parse_messages(opts)
|
|||||||
if #content > 0 then table.insert(messages, { role = self.role_map[msg.role], content = content }) end
|
if #content > 0 then table.insert(messages, { role = self.role_map[msg.role], content = content }) end
|
||||||
if not provider_conf.disable_tools then
|
if not provider_conf.disable_tools then
|
||||||
if #tool_calls > 0 then
|
if #tool_calls > 0 then
|
||||||
table.insert(messages, { role = self.role_map["assistant"], tool_calls = tool_calls })
|
local last_message = messages[#messages]
|
||||||
|
if last_message and last_message.role == self.role_map["assistant"] and last_message.tool_calls then
|
||||||
|
last_message.tool_calls = vim.list_extend(last_message.tool_calls, tool_calls)
|
||||||
|
else
|
||||||
|
table.insert(messages, { role = self.role_map["assistant"], tool_calls = tool_calls })
|
||||||
|
end
|
||||||
end
|
end
|
||||||
if #tool_results > 0 then
|
if #tool_results > 0 then
|
||||||
for _, tool_result in ipairs(tool_results) do
|
for _, tool_result in ipairs(tool_results) do
|
||||||
@@ -155,7 +160,7 @@ function M:parse_messages(opts)
|
|||||||
|
|
||||||
vim.iter(messages):each(function(message)
|
vim.iter(messages):each(function(message)
|
||||||
local role = message.role
|
local role = message.role
|
||||||
if role == prev_role then
|
if role == prev_role and role ~= "tool" then
|
||||||
if role == self.role_map["assistant"] then
|
if role == self.role_map["assistant"] then
|
||||||
table.insert(final_messages, { role = self.role_map["user"], content = "Ok" })
|
table.insert(final_messages, { role = self.role_map["user"], content = "Ok" })
|
||||||
else
|
else
|
||||||
|
|||||||
@@ -2311,11 +2311,14 @@ function Sidebar:create_input_container()
|
|||||||
|
|
||||||
---@param dropped_history_messages avante.HistoryMessage[]
|
---@param dropped_history_messages avante.HistoryMessage[]
|
||||||
local function on_memory_summarize(dropped_history_messages)
|
local function on_memory_summarize(dropped_history_messages)
|
||||||
Llm.summarize_memory(self.code.bufnr, self.chat_history, dropped_history_messages, function(memory)
|
local history_memory = self.chat_history.memory
|
||||||
if memory then stream_options.memory = memory.content end
|
Llm.summarize_memory(history_memory and history_memory.content, dropped_history_messages, function(memory)
|
||||||
|
if memory then
|
||||||
|
self.chat_history.memory = memory
|
||||||
|
Path.history.save(self.code.bufnr, self.chat_history)
|
||||||
|
stream_options.memory = memory.content
|
||||||
|
end
|
||||||
stream_options.history_messages = self:get_history_messages_for_api()
|
stream_options.history_messages = self:get_history_messages_for_api()
|
||||||
-- Utils.debug("dropping history messages", dropped_history_messages)
|
|
||||||
-- Utils.debug("history messages", stream_options.history_messages)
|
|
||||||
Llm.stream(stream_options)
|
Llm.stream(stream_options)
|
||||||
end)
|
end)
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -361,7 +361,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---
|
---
|
||||||
---@class AvanteLLMStreamOptions: AvanteGeneratePromptsOptions
|
---@class AvanteLLMStreamOptions: AvanteGeneratePromptsOptions
|
||||||
---@field on_start AvanteLLMStartCallback
|
---@field on_start AvanteLLMStartCallback
|
||||||
---@field on_chunk AvanteLLMChunkCallback
|
---@field on_chunk? AvanteLLMChunkCallback
|
||||||
---@field on_stop AvanteLLMStopCallback
|
---@field on_stop AvanteLLMStopCallback
|
||||||
---@field on_memory_summarize? AvanteLLMMemorySummarizeCallback
|
---@field on_memory_summarize? AvanteLLMMemorySummarizeCallback
|
||||||
---@field on_tool_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
|
---@field on_tool_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
|
||||||
|
|||||||
Reference in New Issue
Block a user