refactor: history messages (#1934)
This commit is contained in:
@@ -10,6 +10,7 @@ local Path = require("avante.path")
|
||||
local Providers = require("avante.providers")
|
||||
local LLMToolHelpers = require("avante.llm_tools.helpers")
|
||||
local LLMTools = require("avante.llm_tools")
|
||||
local HistoryMessage = require("avante.history_message")
|
||||
|
||||
---@class avante.LLM
|
||||
local M = {}
|
||||
@@ -26,7 +27,7 @@ function M.summarize_chat_thread_title(content, cb)
|
||||
local system_prompt =
|
||||
[[Summarize the content as a title for the chat thread. The title should be a concise and informative summary of the conversation, capturing the main points and key takeaways. It should be no longer than 100 words and should be written in a clear and engaging style. The title should be suitable for use as the title of a chat thread on a messaging platform or other communication medium.]]
|
||||
local response_content = ""
|
||||
local provider = Providers[Config.memory_summary_provider or Config.provider]
|
||||
local provider = Providers.get_memory_summary_provider()
|
||||
M.curl({
|
||||
provider = provider,
|
||||
prompt_opts = {
|
||||
@@ -58,73 +59,49 @@ function M.summarize_chat_thread_title(content, cb)
|
||||
})
|
||||
end
|
||||
|
||||
---@param messages AvanteLLMMessage[]
|
||||
---@return AvanteLLMMessage[]
|
||||
local function filter_out_tool_use_messages(messages)
|
||||
local filtered_messages = {}
|
||||
for _, message in ipairs(messages) do
|
||||
local content = message.content
|
||||
if type(content) == "table" then
|
||||
local new_content = {}
|
||||
for _, item in ipairs(content) do
|
||||
if item.type == "tool_use" or item.type == "tool_result" then goto continue end
|
||||
table.insert(new_content, item)
|
||||
::continue::
|
||||
end
|
||||
content = new_content
|
||||
end
|
||||
if type(content) == "table" then
|
||||
if #content > 0 then table.insert(filtered_messages, { role = message.role, content = content }) end
|
||||
else
|
||||
table.insert(filtered_messages, { role = message.role, content = content })
|
||||
end
|
||||
end
|
||||
return filtered_messages
|
||||
end
|
||||
|
||||
---@param bufnr integer
|
||||
---@param history avante.ChatHistory
|
||||
---@param entries? avante.ChatHistoryEntry[]
|
||||
---@param history_messages avante.HistoryMessage[]
|
||||
---@param cb fun(memory: avante.ChatMemory | nil): nil
|
||||
function M.summarize_memory(bufnr, history, entries, cb)
|
||||
local system_prompt = [[You are a helpful AI assistant tasked with summarizing conversations.]]
|
||||
if not entries then entries = Utils.history.filter_active_entries(history.entries) end
|
||||
if #entries == 0 then
|
||||
cb(nil)
|
||||
return
|
||||
end
|
||||
if history.memory then
|
||||
entries = vim
|
||||
.iter(entries)
|
||||
:filter(function(entry) return entry.timestamp > history.memory.last_summarized_timestamp end)
|
||||
:totable()
|
||||
end
|
||||
if #entries == 0 then
|
||||
cb(history.memory)
|
||||
return
|
||||
end
|
||||
local history_messages = Utils.history.entries_to_llm_messages(entries)
|
||||
history_messages = filter_out_tool_use_messages(history_messages)
|
||||
history_messages = vim.list_slice(history_messages, 1, 4)
|
||||
function M.summarize_memory(bufnr, history, history_messages, cb)
|
||||
local system_prompt =
|
||||
[[You are an expert coding assistant. Your goal is to generate a concise, structured summary of the conversation below that captures all essential information needed to continue development after context replacement. Include tasks performed, code areas modified or reviewed, key decisions or assumptions, test results or errors, and outstanding tasks or next steps.]]
|
||||
if #history_messages == 0 then
|
||||
cb(history.memory)
|
||||
return
|
||||
end
|
||||
Utils.debug("summarize memory", #history_messages, history_messages[#history_messages].content)
|
||||
local user_prompt =
|
||||
[[Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.]]
|
||||
local latest_timestamp = history_messages[#history_messages].timestamp
|
||||
local latest_message_uuid = history_messages[#history_messages].uuid
|
||||
local conversation_items = vim
|
||||
.iter(history_messages)
|
||||
:filter(function(msg)
|
||||
if msg.just_for_display then return false end
|
||||
if msg.message.role ~= "assistant" and msg.message.role ~= "user" then return false end
|
||||
local content = msg.message.content
|
||||
if type(content) == "table" and content[1].type == "tool_result" then return false end
|
||||
if type(content) == "table" and content[1].type == "tool_use" then return false end
|
||||
return true
|
||||
end)
|
||||
:map(function(msg) return msg.message.role .. ": " .. Utils.message_to_text(msg) end)
|
||||
:totable()
|
||||
local conversation_text = table.concat(conversation_items, "\n")
|
||||
local user_prompt = "Here is the conversation so far:\n"
|
||||
.. conversation_text
|
||||
.. "\n\nPlease summarize this conversation, covering:\n1. Tasks performed and outcomes\n2. Code files, modules, or functions modified or examined\n3. Important decisions or assumptions made\n4. Errors encountered and test or build results\n5. Remaining tasks, open questions, or next steps\nProvide the summary in a clear, concise format."
|
||||
if history.memory then user_prompt = user_prompt .. "\n\nThe previous summary is:\n\n" .. history.memory.content end
|
||||
table.insert(history_messages, {
|
||||
role = "user",
|
||||
content = user_prompt,
|
||||
})
|
||||
local messages = {
|
||||
{
|
||||
role = "user",
|
||||
content = user_prompt,
|
||||
},
|
||||
}
|
||||
local response_content = ""
|
||||
local provider = Providers[Config.memory_summary_provider or Config.provider]
|
||||
local provider = Providers.get_memory_summary_provider()
|
||||
M.curl({
|
||||
provider = provider,
|
||||
prompt_opts = {
|
||||
system_prompt = system_prompt,
|
||||
messages = history_messages,
|
||||
messages = messages,
|
||||
},
|
||||
handler_opts = {
|
||||
on_start = function(_) end,
|
||||
@@ -141,11 +118,14 @@ function M.summarize_memory(bufnr, history, entries, cb)
|
||||
response_content = Utils.trim_think_content(response_content)
|
||||
local memory = {
|
||||
content = response_content,
|
||||
last_summarized_timestamp = entries[#entries].timestamp,
|
||||
last_summarized_timestamp = latest_timestamp,
|
||||
last_message_uuid = latest_message_uuid,
|
||||
}
|
||||
history.memory = memory
|
||||
Path.history.save(bufnr, history)
|
||||
cb(memory)
|
||||
else
|
||||
cb(history.memory)
|
||||
end
|
||||
end,
|
||||
},
|
||||
@@ -156,7 +136,7 @@ end
|
||||
---@return AvantePromptOptions
|
||||
function M.generate_prompts(opts)
|
||||
local provider = opts.provider or Providers[Config.provider]
|
||||
local mode = opts.mode or "planning"
|
||||
local mode = opts.mode or Config.mode
|
||||
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||
local _, request_body = Providers.parse_config(provider)
|
||||
local max_tokens = request_body.max_tokens or 4096
|
||||
@@ -166,22 +146,58 @@ function M.generate_prompts(opts)
|
||||
if opts.prompt_opts and opts.prompt_opts.image_paths then
|
||||
image_paths = vim.list_extend(image_paths, opts.prompt_opts.image_paths)
|
||||
end
|
||||
local instructions = opts.instructions
|
||||
if instructions and instructions:match("image: ") then
|
||||
local lines = vim.split(opts.instructions, "\n")
|
||||
for i, line in ipairs(lines) do
|
||||
if line:match("^image: ") then
|
||||
local image_path = line:gsub("^image: ", "")
|
||||
table.insert(image_paths, image_path)
|
||||
table.remove(lines, i)
|
||||
end
|
||||
end
|
||||
instructions = table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
local project_root = Utils.root.get()
|
||||
Path.prompts.initialize(Path.prompts.get_templates_dir(project_root))
|
||||
|
||||
local tool_id_to_tool_name = {}
|
||||
local tool_id_to_path = {}
|
||||
local viewed_files = {}
|
||||
if opts.history_messages then
|
||||
for _, message in ipairs(opts.history_messages) do
|
||||
local content = message.message.content
|
||||
if type(content) ~= "table" then goto continue end
|
||||
for _, item in ipairs(content) do
|
||||
if type(item) ~= "table" then goto continue1 end
|
||||
if item.type ~= "tool_use" then goto continue1 end
|
||||
local tool_name = item.name
|
||||
if tool_name ~= "view" then goto continue1 end
|
||||
local path = item.input.path
|
||||
tool_id_to_tool_name[item.id] = tool_name
|
||||
if path then
|
||||
local uniform_path = Utils.uniform_path(path)
|
||||
tool_id_to_path[item.id] = uniform_path
|
||||
viewed_files[uniform_path] = item.id
|
||||
end
|
||||
::continue1::
|
||||
end
|
||||
::continue::
|
||||
end
|
||||
for _, message in ipairs(opts.history_messages) do
|
||||
local content = message.message.content
|
||||
if type(content) == "table" then
|
||||
for _, item in ipairs(content) do
|
||||
if type(item) ~= "table" then goto continue end
|
||||
if item.type ~= "tool_result" then goto continue end
|
||||
local tool_name = tool_id_to_tool_name[item.tool_use_id]
|
||||
if tool_name ~= "view" then goto continue end
|
||||
local path = tool_id_to_path[item.tool_use_id]
|
||||
local latest_tool_id = viewed_files[path]
|
||||
if not latest_tool_id then goto continue end
|
||||
if latest_tool_id ~= item.tool_use_id then
|
||||
item.content = string.format("The file %s has been updated. Please use the latest view tool result!", path)
|
||||
else
|
||||
local lines, error = Utils.read_file_from_buf_or_disk(path)
|
||||
if error ~= nil then Utils.error("error reading file: " .. error) end
|
||||
lines = lines or {}
|
||||
item.content = table.concat(lines, "\n")
|
||||
end
|
||||
::continue::
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local system_info = Utils.get_system_info()
|
||||
|
||||
local selected_files = opts.selected_files or {}
|
||||
@@ -200,6 +216,8 @@ function M.generate_prompts(opts)
|
||||
end
|
||||
end
|
||||
|
||||
selected_files = vim.iter(selected_files):filter(function(file) return viewed_files[file.path] == nil end):totable()
|
||||
|
||||
local template_opts = {
|
||||
ask = opts.ask, -- TODO: add mode without ask instruction
|
||||
code_lang = opts.code_lang,
|
||||
@@ -229,36 +247,42 @@ function M.generate_prompts(opts)
|
||||
end
|
||||
|
||||
---@type AvanteLLMMessage[]
|
||||
local messages = {}
|
||||
local context_messages = {}
|
||||
if opts.prompt_opts and opts.prompt_opts.messages then
|
||||
messages = vim.list_extend(messages, opts.prompt_opts.messages)
|
||||
context_messages = vim.list_extend(context_messages, opts.prompt_opts.messages)
|
||||
end
|
||||
|
||||
if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then
|
||||
local project_context = Path.prompts.render_file("_project.avanterules", template_opts)
|
||||
if project_context ~= "" then table.insert(messages, { role = "user", content = project_context }) end
|
||||
if project_context ~= "" then
|
||||
table.insert(context_messages, { role = "user", content = project_context, visible = false, is_context = true })
|
||||
end
|
||||
end
|
||||
|
||||
if opts.diagnostics ~= nil and opts.diagnostics ~= "" and opts.diagnostics ~= "null" then
|
||||
local diagnostics = Path.prompts.render_file("_diagnostics.avanterules", template_opts)
|
||||
if diagnostics ~= "" then table.insert(messages, { role = "user", content = diagnostics }) end
|
||||
if diagnostics ~= "" then
|
||||
table.insert(context_messages, { role = "user", content = diagnostics, visible = false, is_context = true })
|
||||
end
|
||||
end
|
||||
|
||||
if #selected_files > 0 or opts.selected_code ~= nil then
|
||||
local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
|
||||
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
|
||||
if code_context ~= "" then
|
||||
table.insert(context_messages, { role = "user", content = code_context, visible = false, is_context = true })
|
||||
end
|
||||
end
|
||||
|
||||
if opts.memory ~= nil and opts.memory ~= "" and opts.memory ~= "null" then
|
||||
local memory = Path.prompts.render_file("_memory.avanterules", template_opts)
|
||||
if memory ~= "" then table.insert(messages, { role = "user", content = memory }) end
|
||||
if memory ~= "" then
|
||||
table.insert(context_messages, { role = "user", content = memory, visible = false, is_context = true })
|
||||
end
|
||||
end
|
||||
|
||||
if instructions then table.insert(messages, { role = "user", content = instructions }) end
|
||||
|
||||
local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)
|
||||
|
||||
for _, message in ipairs(messages) do
|
||||
for _, message in ipairs(context_messages) do
|
||||
remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
|
||||
end
|
||||
|
||||
@@ -267,47 +291,49 @@ function M.generate_prompts(opts)
|
||||
dropped_history_messages = vim.list_extend(dropped_history_messages, opts.prompt_opts.dropped_history_messages)
|
||||
end
|
||||
|
||||
local final_history_messages = {}
|
||||
if opts.history_messages then
|
||||
if Config.history.max_tokens > 0 then remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens) end
|
||||
-- Traverse the history in reverse, keeping only the latest history until the remaining tokens are exhausted and the first message role is "user"
|
||||
local history_messages = {}
|
||||
for i = #opts.history_messages, 1, -1 do
|
||||
local message = opts.history_messages[i]
|
||||
if Config.history.carried_entry_count ~= nil then
|
||||
if #history_messages > Config.history.carried_entry_count then break end
|
||||
table.insert(history_messages, message)
|
||||
local tokens = Utils.tokens.calculate_tokens(message.message.content)
|
||||
remaining_tokens = remaining_tokens - tokens
|
||||
if remaining_tokens > 0 then
|
||||
table.insert(history_messages, 1, message)
|
||||
else
|
||||
local tokens = Utils.tokens.calculate_tokens(message.content)
|
||||
remaining_tokens = remaining_tokens - tokens
|
||||
if remaining_tokens > 0 then
|
||||
table.insert(history_messages, message)
|
||||
else
|
||||
break
|
||||
end
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if #history_messages == 0 then
|
||||
history_messages = vim.list_slice(opts.history_messages, #opts.history_messages - 1, #opts.history_messages)
|
||||
end
|
||||
|
||||
dropped_history_messages = vim.list_slice(opts.history_messages, 1, #opts.history_messages - #history_messages)
|
||||
|
||||
-- prepend the history messages to the messages table
|
||||
vim.iter(history_messages):each(function(msg) table.insert(messages, 1, msg) end)
|
||||
if #messages > 0 and messages[1].role == "assistant" then table.remove(messages, 1) end
|
||||
vim.iter(history_messages):each(function(msg) table.insert(final_history_messages, msg) end)
|
||||
end
|
||||
|
||||
if opts.mode == "cursor-applying" then
|
||||
local user_prompt = [[
|
||||
Merge all changes from the <update> snippet into the <code> below.
|
||||
- Preserve the code's structure, order, comments, and indentation exactly.
|
||||
- Output only the updated code, enclosed within <updated-code> and </updated-code> tags.
|
||||
- Do not include any additional text, explanations, placeholders, ellipses, or code fences.
|
||||
-- Utils.debug("opts.history_messages", opts.history_messages)
|
||||
-- Utils.debug("final_history_messages", final_history_messages)
|
||||
|
||||
]]
|
||||
user_prompt = user_prompt .. string.format("<code>\n%s\n</code>\n", opts.original_code)
|
||||
for _, snippet in ipairs(opts.update_snippets) do
|
||||
user_prompt = user_prompt .. string.format("<update>\n%s\n</update>\n", snippet)
|
||||
end
|
||||
user_prompt = user_prompt .. "Provide the complete updated code."
|
||||
table.insert(messages, { role = "user", content = user_prompt })
|
||||
---@type AvanteLLMMessage[]
|
||||
local messages = vim.deepcopy(context_messages)
|
||||
for _, msg in ipairs(final_history_messages) do
|
||||
local message = msg.message
|
||||
table.insert(messages, message)
|
||||
end
|
||||
|
||||
messages = vim
|
||||
.iter(messages)
|
||||
:filter(function(msg) return type(msg.content) ~= "string" or msg.content ~= "" end)
|
||||
:totable()
|
||||
|
||||
if opts.instructions ~= nil and opts.instructions ~= "" then
|
||||
messages = vim.list_extend(messages, { { role = "user", content = opts.instructions } })
|
||||
end
|
||||
|
||||
opts.session_ctx = opts.session_ctx or {}
|
||||
@@ -318,19 +344,12 @@ Merge all changes from the <update> snippet into the <code> below.
|
||||
if opts.tools then tools = vim.list_extend(tools, opts.tools) end
|
||||
if opts.prompt_opts and opts.prompt_opts.tools then tools = vim.list_extend(tools, opts.prompt_opts.tools) end
|
||||
|
||||
local tool_histories = {}
|
||||
if opts.tool_histories then tool_histories = vim.list_extend(tool_histories, opts.tool_histories) end
|
||||
if opts.prompt_opts and opts.prompt_opts.tool_histories then
|
||||
tool_histories = vim.list_extend(tool_histories, opts.prompt_opts.tool_histories)
|
||||
end
|
||||
|
||||
---@type AvantePromptOptions
|
||||
return {
|
||||
system_prompt = system_prompt,
|
||||
messages = messages,
|
||||
image_paths = image_paths,
|
||||
tools = tools,
|
||||
tool_histories = tool_histories,
|
||||
dropped_history_messages = dropped_history_messages,
|
||||
}
|
||||
end
|
||||
@@ -372,7 +391,9 @@ function M.curl(opts)
|
||||
---@type string
|
||||
local current_event_state = nil
|
||||
local resp_ctx = {}
|
||||
resp_ctx.session_id = Utils.uuid()
|
||||
|
||||
local response_body = ""
|
||||
---@param line string
|
||||
local function parse_stream_data(line)
|
||||
local event = line:match("^event:%s*(.+)$")
|
||||
@@ -381,7 +402,16 @@ function M.curl(opts)
|
||||
return
|
||||
end
|
||||
local data_match = line:match("^data:%s*(.+)$")
|
||||
if data_match then provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts) end
|
||||
if data_match then
|
||||
provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts)
|
||||
else
|
||||
response_body = response_body .. line
|
||||
local ok, jsn = pcall(vim.json.decode, response_body)
|
||||
if ok then
|
||||
response_body = ""
|
||||
if jsn.error then handler_opts.on_stop({ reason = "error", error = jsn.error }) end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function parse_response_without_stream(data)
|
||||
@@ -394,10 +424,13 @@ function M.curl(opts)
|
||||
|
||||
local temp_file = fn.tempname()
|
||||
local curl_body_file = temp_file .. "-request-body.json"
|
||||
local resp_body_file = temp_file .. "-response-body.json"
|
||||
local json_content = vim.json.encode(spec.body)
|
||||
fn.writefile(vim.split(json_content, "\n"), curl_body_file)
|
||||
|
||||
Utils.debug("curl body file:", curl_body_file)
|
||||
Utils.debug("curl request body file:", curl_body_file)
|
||||
|
||||
Utils.debug("curl response body file:", resp_body_file)
|
||||
|
||||
local headers_file = temp_file .. "-headers.txt"
|
||||
|
||||
@@ -407,6 +440,7 @@ function M.curl(opts)
|
||||
if Config.debug then return end
|
||||
vim.schedule(function()
|
||||
fn.delete(curl_body_file)
|
||||
pcall(fn.delete, resp_body_file)
|
||||
fn.delete(headers_file)
|
||||
end)
|
||||
end
|
||||
@@ -431,6 +465,15 @@ function M.curl(opts)
|
||||
return
|
||||
end
|
||||
if not data then return end
|
||||
if Config.debug then
|
||||
if type(data) == "string" then
|
||||
local file = io.open(resp_body_file, "a")
|
||||
if file then
|
||||
file:write(data .. "\n")
|
||||
file:close()
|
||||
end
|
||||
end
|
||||
end
|
||||
vim.schedule(function()
|
||||
if Config[Config.provider] == nil and provider.parse_stream_data ~= nil then
|
||||
if provider.parse_response ~= nil then
|
||||
@@ -495,6 +538,17 @@ function M.curl(opts)
|
||||
Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" })
|
||||
end
|
||||
if result.status == 429 then
|
||||
Utils.debug("result", result)
|
||||
if result.body then
|
||||
local ok, jsn = pcall(vim.json.decode, result.body)
|
||||
if ok then
|
||||
if jsn.error and jsn.error.message then
|
||||
handler_opts.on_stop({ reason = "error", error = jsn.error.message })
|
||||
return
|
||||
end
|
||||
end
|
||||
end
|
||||
Utils.debug("result", result)
|
||||
local retry_after = 10
|
||||
if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end
|
||||
handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after })
|
||||
@@ -585,17 +639,34 @@ function M._stream(opts)
|
||||
|
||||
---@type AvanteHandlerOptions
|
||||
local handler_opts = {
|
||||
on_partial_tool_use = opts.on_partial_tool_use,
|
||||
on_messages_add = opts.on_messages_add,
|
||||
on_state_change = opts.on_state_change,
|
||||
on_start = opts.on_start,
|
||||
on_chunk = opts.on_chunk,
|
||||
on_stop = function(stop_opts)
|
||||
---@param tool_use_list AvanteLLMToolUse[]
|
||||
---@param tool_use_index integer
|
||||
---@param tool_histories AvanteLLMToolHistory[]
|
||||
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories)
|
||||
---@param tool_results AvanteLLMToolResult[]
|
||||
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_results)
|
||||
if tool_use_index > #tool_use_list then
|
||||
---@type avante.HistoryMessage[]
|
||||
local messages = {}
|
||||
for _, tool_result in ipairs(tool_results) do
|
||||
messages[#messages + 1] = HistoryMessage:new({
|
||||
role = "user",
|
||||
content = {
|
||||
{
|
||||
type = "tool_result",
|
||||
tool_use_id = tool_result.tool_use_id,
|
||||
content = tool_result.content,
|
||||
is_error = tool_result.is_error,
|
||||
},
|
||||
},
|
||||
})
|
||||
end
|
||||
opts.on_messages_add(messages)
|
||||
local new_opts = vim.tbl_deep_extend("force", opts, {
|
||||
tool_histories = tool_histories,
|
||||
history_messages = opts.get_history_messages(),
|
||||
})
|
||||
if provider.get_rate_limit_sleep_time then
|
||||
local sleep_time = provider:get_rate_limit_sleep_time(resp_headers)
|
||||
@@ -616,7 +687,7 @@ function M._stream(opts)
|
||||
if error == LLMToolHelpers.CANCEL_TOKEN then
|
||||
Utils.debug("Tool execution was cancelled by user")
|
||||
opts.on_chunk("\n*[Request cancelled by user during tool execution.]*\n")
|
||||
return opts.on_stop({ reason = "cancelled", tool_histories = tool_histories })
|
||||
return opts.on_stop({ reason = "cancelled" })
|
||||
end
|
||||
|
||||
local tool_result = {
|
||||
@@ -624,8 +695,8 @@ function M._stream(opts)
|
||||
content = error ~= nil and error or result,
|
||||
is_error = error ~= nil,
|
||||
}
|
||||
table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use })
|
||||
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories)
|
||||
table.insert(tool_results, tool_result)
|
||||
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_results)
|
||||
end
|
||||
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
|
||||
local result, error = LLMTools.process_tool_use(
|
||||
@@ -638,20 +709,53 @@ function M._stream(opts)
|
||||
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
|
||||
end
|
||||
if stop_opts.reason == "cancelled" then
|
||||
opts.on_chunk("\n*[Request cancelled by user.]*\n")
|
||||
return opts.on_stop({ reason = "cancelled", tool_histories = opts.tool_histories })
|
||||
if opts.on_chunk then opts.on_chunk("\n*[Request cancelled by user.]*\n") end
|
||||
if opts.on_messages_add then
|
||||
local message = HistoryMessage:new({
|
||||
role = "user",
|
||||
content = "[Request cancelled by user.]",
|
||||
})
|
||||
opts.on_messages_add({ message })
|
||||
end
|
||||
return opts.on_stop({ reason = "cancelled" })
|
||||
end
|
||||
if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
|
||||
local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
|
||||
if stop_opts.reason == "tool_use" then
|
||||
local tool_use_list = {} ---@type AvanteLLMToolUse[]
|
||||
local tool_result_seen = {}
|
||||
local history_messages = opts.get_history_messages and opts.get_history_messages() or {}
|
||||
for idx = #history_messages, 1, -1 do
|
||||
local message = history_messages[idx]
|
||||
local content = message.message.content
|
||||
if type(content) ~= "table" or #content == 0 then goto continue end
|
||||
if content[1].type == "tool_use" then
|
||||
if not tool_result_seen[content[1].id] then
|
||||
table.insert(tool_use_list, 1, content[1])
|
||||
else
|
||||
break
|
||||
end
|
||||
end
|
||||
if content[1].type == "tool_result" then tool_result_seen[content[1].tool_use_id] = true end
|
||||
::continue::
|
||||
end
|
||||
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
|
||||
for _, tool_use in vim.spairs(stop_opts.tool_use_list) do
|
||||
for _, tool_use in vim.spairs(tool_use_list) do
|
||||
table.insert(sorted_tool_use_list, tool_use)
|
||||
end
|
||||
return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories)
|
||||
return handle_next_tool_use(sorted_tool_use_list, 1, {})
|
||||
end
|
||||
if stop_opts.reason == "rate_limit" then
|
||||
local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..."
|
||||
opts.on_chunk("\n*[" .. msg .. "]*\n")
|
||||
local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*"
|
||||
if opts.on_chunk then opts.on_chunk("\n" .. msg_content .. "\n") end
|
||||
local message
|
||||
if opts.on_messages_add then
|
||||
message = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = "\n\n" .. msg_content,
|
||||
}, {
|
||||
just_for_display = true,
|
||||
})
|
||||
opts.on_messages_add({ message })
|
||||
end
|
||||
local timer = vim.loop.new_timer()
|
||||
if timer then
|
||||
local retry_after = stop_opts.retry_after
|
||||
@@ -661,8 +765,12 @@ function M._stream(opts)
|
||||
0,
|
||||
vim.schedule_wrap(function()
|
||||
if retry_after > 0 then retry_after = retry_after - 1 end
|
||||
local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..."
|
||||
opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n")
|
||||
local msg_content_ = "*[Rate limit reached. Retrying in " .. retry_after .. " seconds ...]*"
|
||||
if opts.on_chunk then opts.on_chunk([[\033[1A\033[K]] .. "\n" .. msg_content_ .. "\n") end
|
||||
if opts.on_messages_add and message then
|
||||
message.message.content = "\n\n" .. msg_content_
|
||||
opts.on_messages_add({ message })
|
||||
end
|
||||
countdown()
|
||||
end)
|
||||
)
|
||||
@@ -676,7 +784,6 @@ function M._stream(opts)
|
||||
end, stop_opts.retry_after * 1000)
|
||||
return
|
||||
end
|
||||
stop_opts.tool_histories = opts.tool_histories
|
||||
return opts.on_stop(stop_opts)
|
||||
end,
|
||||
}
|
||||
@@ -697,6 +804,8 @@ local function _merge_response(first_response, second_response, opts)
|
||||
|
||||
prompt = prompt .. "\n"
|
||||
|
||||
if opts.instructions == nil then opts.instructions = "" end
|
||||
|
||||
-- append this reference prompt to the prompt_opts messages at last
|
||||
opts.instructions = opts.instructions .. prompt
|
||||
|
||||
@@ -802,20 +911,12 @@ function M.stream(opts)
|
||||
return original_on_stop(stop_opts)
|
||||
end)
|
||||
end
|
||||
if opts.on_partial_tool_use ~= nil then
|
||||
local original_on_partial_tool_use = opts.on_partial_tool_use
|
||||
opts.on_partial_tool_use = vim.schedule_wrap(function(tool_use)
|
||||
if is_completed then return end
|
||||
return original_on_partial_tool_use(tool_use)
|
||||
end)
|
||||
end
|
||||
|
||||
local valid_dual_boost_modes = {
|
||||
planning = true,
|
||||
["cursor-planning"] = true,
|
||||
legacy = true,
|
||||
}
|
||||
|
||||
opts.mode = opts.mode or "planning"
|
||||
opts.mode = opts.mode or Config.mode
|
||||
|
||||
if Config.dual_boost.enabled and valid_dual_boost_modes[opts.mode] then
|
||||
M._dual_boost_stream(
|
||||
|
||||
Reference in New Issue
Block a user