feat: todos (#2184)
This commit is contained in:
@@ -129,6 +129,72 @@ function M.summarize_memory(prev_memory, history_messages, cb)
|
||||
})
|
||||
end
|
||||
|
||||
---@param user_input string
|
||||
---@param cb fun(error: string | nil): nil
|
||||
function M.generate_todos(user_input, cb)
|
||||
local system_prompt =
|
||||
[[You are an expert coding assistant. Please generate a todo list to complete the task based on the user input and pass the todo list to the add_todos tool.]]
|
||||
local messages = {
|
||||
{ role = "user", content = user_input },
|
||||
}
|
||||
|
||||
local provider = Providers[Config.provider]
|
||||
local tools = {
|
||||
require("avante.llm_tools.add_todos"),
|
||||
}
|
||||
|
||||
local history_messages = {}
|
||||
cb = Utils.call_once(cb)
|
||||
|
||||
M.curl({
|
||||
provider = provider,
|
||||
prompt_opts = {
|
||||
system_prompt = system_prompt,
|
||||
messages = messages,
|
||||
tools = tools,
|
||||
},
|
||||
handler_opts = {
|
||||
on_start = function() end,
|
||||
on_chunk = function() end,
|
||||
on_messages_add = function(msgs)
|
||||
msgs = vim.islist(msgs) and msgs or { msgs }
|
||||
for _, msg in ipairs(msgs) do
|
||||
if not msg.uuid then msg.uuid = Utils.uuid() end
|
||||
local idx = nil
|
||||
for i, m in ipairs(history_messages) do
|
||||
if m.uuid == msg.uuid then
|
||||
idx = i
|
||||
break
|
||||
end
|
||||
end
|
||||
if idx ~= nil then
|
||||
history_messages[idx] = msg
|
||||
else
|
||||
table.insert(history_messages, msg)
|
||||
end
|
||||
end
|
||||
end,
|
||||
on_stop = function(stop_opts)
|
||||
if stop_opts.error ~= nil then
|
||||
Utils.error(string.format("generate todos failed: %s", vim.inspect(stop_opts.error)))
|
||||
return
|
||||
end
|
||||
if stop_opts.reason == "tool_use" then
|
||||
local uncalled_tool_uses = Utils.get_uncalled_tool_uses(history_messages)
|
||||
for _, partial_tool_use in ipairs(uncalled_tool_uses) do
|
||||
if partial_tool_use.state == "generated" and partial_tool_use.name == "add_todos" then
|
||||
LLMTools.process_tool_use(tools, partial_tool_use, function() end, function() cb() end, {})
|
||||
cb()
|
||||
end
|
||||
end
|
||||
else
|
||||
cb()
|
||||
end
|
||||
end,
|
||||
},
|
||||
})
|
||||
end
|
||||
|
||||
---@param opts AvanteGeneratePromptsOptions
|
||||
---@return AvantePromptOptions
|
||||
function M.generate_prompts(opts)
|
||||
@@ -202,6 +268,11 @@ function M.generate_prompts(opts)
|
||||
memory = opts.memory,
|
||||
}
|
||||
|
||||
if opts.get_todos then
|
||||
local todos = opts.get_todos()
|
||||
if todos and #todos > 0 then template_opts.todos = vim.json.encode(todos) end
|
||||
end
|
||||
|
||||
local system_prompt
|
||||
if opts.prompt_opts and opts.prompt_opts.system_prompt then
|
||||
system_prompt = opts.prompt_opts.system_prompt
|
||||
@@ -492,8 +563,6 @@ function M.curl(opts)
|
||||
local retry_after = 10
|
||||
if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end
|
||||
if result.status == 429 then
|
||||
Utils.debug("result", result)
|
||||
|
||||
handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after })
|
||||
return
|
||||
end
|
||||
@@ -709,74 +778,69 @@ function M._stream(opts)
|
||||
end
|
||||
return opts.on_stop({ reason = "cancelled" })
|
||||
end
|
||||
local partial_tool_use_list = {} ---@type AvantePartialLLMToolUse[]
|
||||
local tool_result_seen = {}
|
||||
local history_messages = opts.get_history_messages and opts.get_history_messages({ all = true }) or {}
|
||||
for idx = #history_messages, 1, -1 do
|
||||
local message = history_messages[idx]
|
||||
local content = message.message.content
|
||||
if type(content) ~= "table" or #content == 0 then goto continue end
|
||||
local is_break = false
|
||||
for _, item in ipairs(content) do
|
||||
if item.type == "tool_use" then
|
||||
if not tool_result_seen[item.id] then
|
||||
local partial_tool_use = {
|
||||
name = item.name,
|
||||
id = item.id,
|
||||
input = item.input,
|
||||
state = message.state,
|
||||
}
|
||||
table.insert(partial_tool_use_list, 1, partial_tool_use)
|
||||
else
|
||||
is_break = true
|
||||
break
|
||||
end
|
||||
end
|
||||
if item.type == "tool_result" then tool_result_seen[item.tool_use_id] = true end
|
||||
end
|
||||
if is_break then break end
|
||||
::continue::
|
||||
end
|
||||
local uncalled_tool_uses = Utils.get_uncalled_tool_uses(history_messages)
|
||||
if stop_opts.reason == "complete" and Config.mode == "agentic" then
|
||||
if #partial_tool_use_list == 0 then
|
||||
local completed_attempt_completion_tool_use = nil
|
||||
for idx = #history_messages, 1, -1 do
|
||||
local message = history_messages[idx]
|
||||
if message.is_user_submission then break end
|
||||
if not Utils.is_tool_use_message(message) then goto continue end
|
||||
if message.message.content[1].name ~= "attempt_completion" then break end
|
||||
completed_attempt_completion_tool_use = message
|
||||
if message then break end
|
||||
::continue::
|
||||
end
|
||||
local user_reminder_count = opts.session_ctx.user_reminder_count or 0
|
||||
if not completed_attempt_completion_tool_use and opts.on_messages_add and user_reminder_count < 3 then
|
||||
opts.session_ctx.user_reminder_count = user_reminder_count + 1
|
||||
local message = HistoryMessage:new({
|
||||
local completed_attempt_completion_tool_use = nil
|
||||
for idx = #history_messages, 1, -1 do
|
||||
local message = history_messages[idx]
|
||||
if message.is_user_submission then break end
|
||||
if not Utils.is_tool_use_message(message) then goto continue end
|
||||
if message.message.content[1].name ~= "attempt_completion" then break end
|
||||
completed_attempt_completion_tool_use = message
|
||||
if message then break end
|
||||
::continue::
|
||||
end
|
||||
local unfinished_todos = {}
|
||||
if opts.get_todos then
|
||||
local todos = opts.get_todos()
|
||||
unfinished_todos = vim.tbl_filter(
|
||||
function(todo) return todo.status ~= "done" or todo.status ~= "cancelled" end,
|
||||
todos
|
||||
)
|
||||
end
|
||||
local user_reminder_count = opts.session_ctx.user_reminder_count or 0
|
||||
if
|
||||
not completed_attempt_completion_tool_use
|
||||
and opts.on_messages_add
|
||||
and (user_reminder_count < 3 or #unfinished_todos > 0)
|
||||
then
|
||||
opts.session_ctx.user_reminder_count = user_reminder_count + 1
|
||||
Utils.debug("user reminder count", user_reminder_count)
|
||||
local message
|
||||
if #unfinished_todos > 0 then
|
||||
message = HistoryMessage:new({
|
||||
role = "user",
|
||||
content = "<user-reminder>You should use tool calls to answer the question, for example, use update_todo_status if the task step is done or cancelled.</user-reminder>",
|
||||
}, {
|
||||
visible = false,
|
||||
})
|
||||
else
|
||||
message = HistoryMessage:new({
|
||||
role = "user",
|
||||
content = "<user-reminder>You should use tool calls to answer the question, for example, use attempt_completion if the job is done.</user-reminder>",
|
||||
}, {
|
||||
visible = false,
|
||||
})
|
||||
opts.on_messages_add({ message })
|
||||
local new_opts = vim.tbl_deep_extend("force", opts, {
|
||||
history_messages = opts.get_history_messages(),
|
||||
})
|
||||
if provider.get_rate_limit_sleep_time then
|
||||
local sleep_time = provider:get_rate_limit_sleep_time(resp_headers)
|
||||
if sleep_time and sleep_time > 0 then
|
||||
Utils.info("Rate limit reached. Sleeping for " .. sleep_time .. " seconds ...")
|
||||
vim.defer_fn(function() M._stream(new_opts) end, sleep_time * 1000)
|
||||
return
|
||||
end
|
||||
end
|
||||
M._stream(new_opts)
|
||||
return
|
||||
end
|
||||
opts.on_messages_add({ message })
|
||||
local new_opts = vim.tbl_deep_extend("force", opts, {
|
||||
history_messages = opts.get_history_messages(),
|
||||
})
|
||||
if provider.get_rate_limit_sleep_time then
|
||||
local sleep_time = provider:get_rate_limit_sleep_time(resp_headers)
|
||||
if sleep_time and sleep_time > 0 then
|
||||
Utils.info("Rate limit reached. Sleeping for " .. sleep_time .. " seconds ...")
|
||||
vim.defer_fn(function() M._stream(new_opts) end, sleep_time * 1000)
|
||||
return
|
||||
end
|
||||
end
|
||||
M._stream(new_opts)
|
||||
return
|
||||
end
|
||||
end
|
||||
if stop_opts.reason == "tool_use" then
|
||||
return handle_next_tool_use(partial_tool_use_list, 1, {}, stop_opts.streaming_tool_use)
|
||||
return handle_next_tool_use(uncalled_tool_uses, 1, {}, stop_opts.streaming_tool_use)
|
||||
end
|
||||
if stop_opts.reason == "rate_limit" then
|
||||
local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*"
|
||||
|
||||
Reference in New Issue
Block a user