-- avante.llm: LLM request orchestration (prompt generation, curl streaming,
-- tool-use continuation, dual-boost) for avante.nvim.
local api = vim.api
|
|
local fn = vim.fn
|
|
local uv = vim.uv
|
|
|
|
local curl = require("plenary.curl")
|
|
|
|
local Utils = require("avante.utils")
|
|
local Config = require("avante.config")
|
|
local Path = require("avante.path")
|
|
local Providers = require("avante.providers")
|
|
local LLMToolHelpers = require("avante.llm_tools.helpers")
|
|
local LLMTools = require("avante.llm_tools")
|
|
local HistoryMessage = require("avante.history_message")
|
|
|
|
---@class avante.LLM
local M = {}

-- User autocmd pattern fired to cancel an in-flight request; the handler is
-- registered per-request inside M.curl.
M.CANCEL_PATTERN = "AvanteLLMEscape"

------------------------------Prompt and type------------------------------

-- Autocmd group for this module; `clear = true` so reloading the module does
-- not accumulate stale cancel handlers.
local group = api.nvim_create_augroup("avante_llm", { clear = true })
|
|
|
|
---Asks the memory-summary provider for a short title describing `content`,
---then invokes `cb` with the cleaned-up result. Streamed chunks are
---accumulated; the final text is stripped of think-blocks, surrounding
---newlines and surrounding double quotes. On request error the failure is
---reported and `cb` is never called; `cb` is also not called for stop
---reasons other than "complete".
---@param content AvanteLLMMessageContent
---@param cb fun(title: string | nil): nil
function M.summarize_chat_thread_title(content, cb)
  local system_prompt =
    [[Summarize the content as a title for the chat thread. The title should be a concise and informative summary of the conversation, capturing the main points and key takeaways. It should be no longer than 100 words and should be written in a clear and engaging style. The title should be suitable for use as the title of a chat thread on a messaging platform or other communication medium.]]
  local streamed = ""
  M.curl({
    provider = Providers.get_memory_summary_provider(),
    prompt_opts = {
      system_prompt = system_prompt,
      messages = { { role = "user", content = content } },
    },
    handler_opts = {
      on_start = function(_) end,
      on_chunk = function(chunk)
        if chunk then streamed = streamed .. chunk end
      end,
      on_stop = function(stop_opts)
        if stop_opts.error ~= nil then
          Utils.error(string.format("summarize chat thread title failed: %s", vim.inspect(stop_opts.error)))
          return
        end
        if stop_opts.reason ~= "complete" then return end
        -- Normalize the raw model output into a bare title string.
        local title = Utils.trim_think_content(streamed)
        title = Utils.trim(title, { prefix = "\n", suffix = "\n" })
        title = Utils.trim(title, { prefix = '"', suffix = '"' })
        cb(title)
      end,
    },
  })
end
|
|
|
|
---Summarizes the given chat history into a compact "memory" record via the
---memory-summary provider, optionally folding in a previous summary.
---@param prev_memory string | nil previous summary text, appended to the user prompt when present
---@param history_messages avante.HistoryMessage[]
---@param cb fun(memory: avante.ChatMemory | nil): nil receives the new memory on
---success; nil when there is nothing to summarize or the stop reason was not "complete"
function M.summarize_memory(prev_memory, history_messages, cb)
  local system_prompt =
    [[You are an expert coding assistant. Your goal is to generate a concise, structured summary of the conversation below that captures all essential information needed to continue development after context replacement. Include tasks performed, code areas modified or reviewed, key decisions or assumptions, test results or errors, and outstanding tasks or next steps.]]
  if #history_messages == 0 then
    cb(nil)
    return
  end
  -- Record where this summary ends so callers can resume from that point.
  local latest_timestamp = history_messages[#history_messages].timestamp
  local latest_message_uuid = history_messages[#history_messages].uuid
  local conversation_items = vim
    .iter(history_messages)
    :filter(function(msg)
      if msg.just_for_display then return false end
      if msg.message.role ~= "assistant" and msg.message.role ~= "user" then return false end
      local content = msg.message.content
      -- BUGFIX: the previous code indexed content[1] unconditionally, which
      -- raised "attempt to index a nil value" for an empty content table.
      -- Tool traffic is excluded from the summarized transcript.
      if type(content) == "table" and content[1] ~= nil then
        if content[1].type == "tool_result" then return false end
        if content[1].type == "tool_use" then return false end
      end
      return true
    end)
    :map(function(msg) return msg.message.role .. ": " .. Utils.message_to_text(msg, history_messages) end)
    :totable()
  local conversation_text = table.concat(conversation_items, "\n")
  local user_prompt = "Here is the conversation so far:\n"
    .. conversation_text
    .. "\n\nPlease summarize this conversation, covering:\n1. Tasks performed and outcomes\n2. Code files, modules, or functions modified or examined\n3. Important decisions or assumptions made\n4. Errors encountered and test or build results\n5. Remaining tasks, open questions, or next steps\nProvide the summary in a clear, concise format."
  if prev_memory then user_prompt = user_prompt .. "\n\nThe previous summary is:\n\n" .. prev_memory end
  local messages = {
    {
      role = "user",
      content = user_prompt,
    },
  }
  local response_content = ""
  local provider = Providers.get_memory_summary_provider()
  M.curl({
    provider = provider,
    prompt_opts = {
      system_prompt = system_prompt,
      messages = messages,
    },
    handler_opts = {
      on_start = function(_) end,
      on_chunk = function(chunk)
        if not chunk then return end
        response_content = response_content .. chunk
      end,
      on_stop = function(stop_opts)
        if stop_opts.error ~= nil then
          Utils.error(string.format("summarize memory failed: %s", vim.inspect(stop_opts.error)))
          return
        end
        if stop_opts.reason == "complete" then
          response_content = Utils.trim_think_content(response_content)
          local memory = {
            content = response_content,
            last_summarized_timestamp = latest_timestamp,
            last_message_uuid = latest_message_uuid,
          }
          cb(memory)
        else
          -- Any other stop reason (e.g. cancellation) yields no memory.
          cb(nil)
        end
      end,
    },
  })
end
|
|
|
|
---Builds the final provider payload (system prompt, message list, tools,
---image paths) from raw session state. Visible responsibilities in this body:
---  1. After each successful replace_in_file / str_replace edit, inject dummy
---     view + get_diagnostics tool round-trips so the model sees the latest
---     file content and diagnostics.
---  2. Rewrite stale `view` tool results so only the latest view of a file
---     carries real content.
---  3. Render template-based context messages (project, diagnostics, code,
---     memory) and trim history to the provider token budget; overflow is
---     returned as `pending_compaction_history_messages`.
---@param opts AvanteGeneratePromptsOptions
---@return AvantePromptOptions
function M.generate_prompts(opts)
  local provider = opts.provider or Providers[Config.provider]
  local mode = opts.mode or Config.mode
  ---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
  local _, request_body = Providers.parse_config(provider)
  -- Fallback budget when the provider config sets no max_tokens.
  local max_tokens = request_body.max_tokens or 4096

  -- Check if the instructions contains an image path
  local image_paths = {}
  if opts.prompt_opts and opts.prompt_opts.image_paths then
    image_paths = vim.list_extend(image_paths, opts.prompt_opts.image_paths)
  end

  local project_root = Utils.root.get()
  Path.prompts.initialize(Path.prompts.get_templates_dir(project_root))

  -- Bookkeeping used below to rewrite stale `view` tool results.
  local tool_id_to_tool_name = {} -- tool_use id -> tool name (only "view" recorded)
  local tool_id_to_path = {} -- tool_use id -> uniform path viewed
  local viewed_files = {} -- uniform path -> id of the LATEST view tool_use
  local history_messages = {}
  if opts.history_messages then
    for _, message in ipairs(opts.history_messages) do
      table.insert(history_messages, message)
      if Utils.is_tool_result_message(message) then
        local tool_use_message = Utils.get_tool_use_message(message, opts.history_messages)
        local is_replace_func_call = false
        local is_str_replace_editor_func_call = false
        local path = nil
        if tool_use_message then
          -- NOTE(review): assumes tool_use messages always carry a non-empty
          -- content table -- confirm that invariant upstream.
          if tool_use_message.message.content[1].name == "replace_in_file" then
            is_replace_func_call = true
            path = tool_use_message.message.content[1].input.path
          end
          if tool_use_message.message.content[1].name == "str_replace_editor" then
            if tool_use_message.message.content[1].input.command == "str_replace" then
              is_replace_func_call = true
              is_str_replace_editor_func_call = true
              path = tool_use_message.message.content[1].input.path
            end
          end
        end
        --- For models like gpt-4o, the input parameter of replace_in_file is treated as the latest file content, so here we need to insert a fake view tool call to ensure it uses the latest file content
        if is_replace_func_call and path and not message.message.content[1].is_error then
          local lines = Utils.read_file_from_buf_or_disk(path)
          local get_diagnostics_tool_use_id = Utils.uuid()
          local view_tool_use_id = Utils.uuid()
          local view_tool_name = "view"
          local view_tool_input = { path = path }
          if is_str_replace_editor_func_call then
            -- Mirror the editor-style tool so the fake view matches the tool family used.
            view_tool_name = "str_replace_editor"
            view_tool_input = { command = "view", path = path }
          end
          local diagnostics = Utils.lsp.get_diagnostics_from_filepath(path)
          -- Six dummy messages: narration + view tool_use/tool_result, then
          -- narration + get_diagnostics tool_use/tool_result. All marked
          -- is_dummy so they can be filtered from compaction later.
          history_messages = vim.list_extend(history_messages, {
            HistoryMessage:new({
              role = "assistant",
              content = string.format("Viewing file %s to get the latest content", path),
            }, {
              is_dummy = true,
            }),
            HistoryMessage:new({
              role = "assistant",
              content = {
                {
                  type = "tool_use",
                  id = view_tool_use_id,
                  name = view_tool_name,
                  input = view_tool_input,
                },
              },
            }, {
              is_dummy = true,
            }),
            HistoryMessage:new({
              role = "user",
              content = {
                {
                  type = "tool_result",
                  tool_use_id = view_tool_use_id,
                  content = table.concat(lines or {}, "\n"),
                  is_error = false,
                },
              },
            }, {
              is_dummy = true,
            }),
            HistoryMessage:new({
              role = "assistant",
              content = string.format(
                "The file %s has been modified, let me check if there are any errors in the changes.",
                path
              ),
            }, {
              is_dummy = true,
            }),
            HistoryMessage:new({
              role = "assistant",
              content = {
                {
                  type = "tool_use",
                  id = get_diagnostics_tool_use_id,
                  name = "get_diagnostics",
                  input = { path = path },
                },
              },
            }, {
              is_dummy = true,
            }),
            HistoryMessage:new({
              role = "user",
              content = {
                {
                  type = "tool_result",
                  tool_use_id = get_diagnostics_tool_use_id,
                  content = vim.json.encode(diagnostics),
                  is_error = false,
                },
              },
            }, {
              is_dummy = true,
            }),
          })
        end
      end
    end
    -- Pass 1: index every `view` tool_use; the last writer into viewed_files
    -- wins, so viewed_files maps each path to its most recent view id.
    for _, message in ipairs(history_messages) do
      local content = message.message.content
      if type(content) ~= "table" then goto continue end
      for _, item in ipairs(content) do
        if type(item) ~= "table" then goto continue1 end
        if item.type ~= "tool_use" then goto continue1 end
        local tool_name = item.name
        if tool_name ~= "view" then goto continue1 end
        local path = item.input.path
        tool_id_to_tool_name[item.id] = tool_name
        if path then
          local uniform_path = Utils.uniform_path(path)
          tool_id_to_path[item.id] = uniform_path
          viewed_files[uniform_path] = item.id
        end
        ::continue1::
      end
      ::continue::
    end
    -- Pass 2: rewrite `view` tool_results -- stale views get a redirect
    -- message; the latest view is refreshed with the current file content.
    for _, message in ipairs(history_messages) do
      local content = message.message.content
      if type(content) == "table" then
        for _, item in ipairs(content) do
          if type(item) ~= "table" then goto continue end
          if item.type ~= "tool_result" then goto continue end
          local tool_name = tool_id_to_tool_name[item.tool_use_id]
          if tool_name ~= "view" then goto continue end
          if item.is_error then goto continue end
          local path = tool_id_to_path[item.tool_use_id]
          local latest_tool_id = viewed_files[path]
          if not latest_tool_id then goto continue end
          if latest_tool_id ~= item.tool_use_id then
            item.content =
              string.format("The file %s has been updated. Please use the latest `view` tool result!", path)
          else
            local lines, error = Utils.read_file_from_buf_or_disk(path)
            if error ~= nil then Utils.error("error reading file: " .. error) end
            lines = lines or {}
            item.content = table.concat(lines, "\n")
          end
          ::continue::
        end
      end
    end
  end

  local system_info = Utils.get_system_info()

  local selected_files = opts.selected_files or {}

  -- Load explicit filepath selections from buffer or disk; read failures are
  -- reported and the file is skipped.
  if opts.selected_filepaths then
    for _, filepath in ipairs(opts.selected_filepaths) do
      local lines, error = Utils.read_file_from_buf_or_disk(filepath)
      lines = lines or {}
      local filetype = Utils.get_filetype(filepath)
      if error ~= nil then
        Utils.error("error reading file: " .. error)
      else
        local content = table.concat(lines, "\n")
        table.insert(selected_files, { path = filepath, content = content, file_type = filetype })
      end
    end
  end

  -- Drop selected files already covered by a `view` tool result to avoid
  -- sending the same content twice.
  selected_files = vim.iter(selected_files):filter(function(file) return viewed_files[file.path] == nil end):totable()

  local template_opts = {
    ask = opts.ask, -- TODO: add mode without ask instruction
    code_lang = opts.code_lang,
    selected_files = selected_files,
    selected_code = opts.selected_code,
    recently_viewed_files = opts.recently_viewed_files,
    project_context = opts.project_context,
    diagnostics = opts.diagnostics,
    system_info = system_info,
    model_name = provider.model or "unknown",
    memory = opts.memory,
  }

  -- A caller-supplied system prompt wins over the mode template.
  local system_prompt
  if opts.prompt_opts and opts.prompt_opts.system_prompt then
    system_prompt = opts.prompt_opts.system_prompt
  else
    system_prompt = Path.prompts.render_mode(mode, template_opts)
  end

  -- Append the user-configured system prompt (string or factory function);
  -- "null" is treated as unset.
  if Config.system_prompt ~= nil then
    local custom_system_prompt = Config.system_prompt
    if type(custom_system_prompt) == "function" then custom_system_prompt = custom_system_prompt() end
    if custom_system_prompt ~= nil and custom_system_prompt ~= "" and custom_system_prompt ~= "null" then
      system_prompt = system_prompt .. "\n\n" .. custom_system_prompt
    end
  end

  ---@type AvanteLLMMessage[]
  local context_messages = {}
  if opts.prompt_opts and opts.prompt_opts.messages then
    context_messages = vim.list_extend(context_messages, opts.prompt_opts.messages)
  end

  -- Each rendered template becomes an invisible context message.
  if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then
    local project_context = Path.prompts.render_file("_project.avanterules", template_opts)
    if project_context ~= "" then
      table.insert(context_messages, { role = "user", content = project_context, visible = false, is_context = true })
    end
  end

  if opts.diagnostics ~= nil and opts.diagnostics ~= "" and opts.diagnostics ~= "null" then
    local diagnostics = Path.prompts.render_file("_diagnostics.avanterules", template_opts)
    if diagnostics ~= "" then
      table.insert(context_messages, { role = "user", content = diagnostics, visible = false, is_context = true })
    end
  end

  if #selected_files > 0 or opts.selected_code ~= nil then
    local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
    if code_context ~= "" then
      table.insert(context_messages, { role = "user", content = code_context, visible = false, is_context = true })
    end
  end

  if opts.memory ~= nil and opts.memory ~= "" and opts.memory ~= "null" then
    local memory = Path.prompts.render_file("_memory.avanterules", template_opts)
    if memory ~= "" then
      table.insert(context_messages, { role = "user", content = memory, visible = false, is_context = true })
    end
  end

  -- Token budget left for history after system prompt and context messages.
  local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)

  for _, message in ipairs(context_messages) do
    remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
  end

  local pending_compaction_history_messages = {}
  if opts.prompt_opts and opts.prompt_opts.pending_compaction_history_messages then
    pending_compaction_history_messages =
      vim.list_extend(pending_compaction_history_messages, opts.prompt_opts.pending_compaction_history_messages)
  end

  local cleaned_history_messages = history_messages

  local final_history_messages = {}
  if cleaned_history_messages then
    if opts.disable_compact_history_messages then
      -- No compaction: keep everything except orphaned tool_use / tool_result
      -- halves (providers reject unpaired tool messages).
      vim.iter(cleaned_history_messages):each(function(msg)
        if Utils.is_tool_use_message(msg) and not Utils.get_tool_result_message(msg, cleaned_history_messages) then
          return
        end
        if Utils.is_tool_result_message(msg) and not Utils.get_tool_use_message(msg, cleaned_history_messages) then
          return
        end
        table.insert(final_history_messages, msg)
      end)
    else
      if Config.history.max_tokens > 0 then
        remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens)
      end

      -- Traverse the history in reverse, keeping only the latest history until the remaining tokens are exhausted and the first message role is "user"
      local retained_history_messages = {}
      for i = #cleaned_history_messages, 1, -1 do
        local message = cleaned_history_messages[i]
        local tokens = Utils.tokens.calculate_tokens(message.message.content)
        remaining_tokens = remaining_tokens - tokens
        if remaining_tokens > 0 then
          table.insert(retained_history_messages, 1, message)
        else
          break
        end
      end

      -- Guarantee at least the last two messages survive even when over budget.
      if #retained_history_messages == 0 then
        retained_history_messages =
          vim.list_slice(cleaned_history_messages, #cleaned_history_messages - 1, #cleaned_history_messages)
      end

      -- Everything not retained is queued for memory compaction (dummy
      -- messages injected above are excluded).
      pending_compaction_history_messages =
        vim.list_slice(cleaned_history_messages, 1, #cleaned_history_messages - #retained_history_messages)

      pending_compaction_history_messages = vim
        .iter(pending_compaction_history_messages)
        :filter(function(msg) return msg.is_dummy ~= true end)
        :totable()

      -- Same orphan filtering as the no-compaction branch, scoped to the
      -- retained window.
      vim.iter(retained_history_messages):each(function(msg)
        if Utils.is_tool_use_message(msg) and not Utils.get_tool_result_message(msg, retained_history_messages) then
          return
        end
        if Utils.is_tool_result_message(msg) and not Utils.get_tool_use_message(msg, retained_history_messages) then
          return
        end
        table.insert(final_history_messages, msg)
      end)
    end
  end

  ---@type AvanteLLMMessage[]
  local messages = vim.deepcopy(context_messages)
  for _, msg in ipairs(final_history_messages) do
    local message = msg.message
    table.insert(messages, message)
  end

  -- Drop empty-string messages; table contents pass through untouched.
  messages = vim
    .iter(messages)
    :filter(function(msg) return type(msg.content) ~= "string" or msg.content ~= "" end)
    :totable()

  if opts.instructions ~= nil and opts.instructions ~= "" then
    messages = vim.list_extend(messages, { { role = "user", content = opts.instructions } })
  end

  -- Expose the final prompt pieces on the session context for tools/handlers.
  opts.session_ctx = opts.session_ctx or {}
  opts.session_ctx.system_prompt = system_prompt
  opts.session_ctx.messages = messages

  local tools = {}
  if opts.tools then tools = vim.list_extend(tools, opts.tools) end
  if opts.prompt_opts and opts.prompt_opts.tools then tools = vim.list_extend(tools, opts.prompt_opts.tools) end

  ---@type AvantePromptOptions
  return {
    system_prompt = system_prompt,
    messages = messages,
    image_paths = image_paths,
    tools = tools,
    pending_compaction_history_messages = pending_compaction_history_messages,
  }
end
|
|
|
|
---Estimates the total token count of the fully rendered prompt for `opts`:
---system prompt plus every message's content, as counted by Utils.tokens.
---@param opts AvanteGeneratePromptsOptions
---@return integer
function M.calculate_tokens(opts)
  local rendered = M.generate_prompts(opts)
  local total = Utils.tokens.calculate_tokens(rendered.system_prompt)
  for i = 1, #rendered.messages do
    total = total + Utils.tokens.calculate_tokens(rendered.messages[i].content)
  end
  return total
end
|
|
|
|
---Reads an HTTP response-header dump file (curl `-D` output) and returns a
---table mapping header names to values. Lines without a colon (e.g. the
---status line) are skipped; a trailing CR is stripped from each line; a
---missing/unreadable file yields an empty table.
local parse_headers = function(headers_file)
  local parsed = {}
  local fd = io.open(headers_file, "r")
  if not fd then return parsed end
  for raw_line in fd:lines() do
    -- Strip the CR of CRLF-terminated header lines.
    local line = raw_line:gsub("\r$", "")
    local name, value = line:match("^%s*(.-)%s*:%s*(.*)$")
    if name and value then parsed[name] = value end
  end
  fd:close()
  return parsed
end
|
|
|
|
---Executes one provider HTTP request via plenary curl, wiring streamed output
---into `opts.handler_opts` callbacks. The request body is written to a temp
---file; response headers are dumped with `-D` and reported once via
---`opts.on_response_headers`. A one-shot User autocmd on M.CANCEL_PATTERN
---cancels the in-flight job. Returns the plenary job handle (or nil once the
---job has finished).
---@param opts avante.CurlOpts
function M.curl(opts)
  local provider = opts.provider
  local prompt_opts = opts.prompt_opts
  local handler_opts = opts.handler_opts

  ---@type AvanteCurlOutput
  local spec = provider:parse_curl_args(prompt_opts)

  ---@type string
  -- Last SSE "event:" value seen; passed alongside each "data:" payload.
  local current_event_state = nil
  local resp_ctx = {}
  resp_ctx.session_id = Utils.uuid()

  -- Accumulates non-SSE lines until they form valid JSON (handles bodies
  -- split across stream callbacks).
  local response_body = ""
  ---@param line string
  local function parse_stream_data(line)
    local event = line:match("^event:%s*(.+)$")
    if event then
      current_event_state = event
      return
    end
    local data_match = line:match("^data:%s*(.+)$")
    if data_match then
      response_body = ""
      provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts)
    else
      response_body = response_body .. line
      local ok, jsn = pcall(vim.json.decode, response_body)
      if ok then
        if jsn.error then
          handler_opts.on_stop({ reason = "error", error = jsn.error })
        else
          provider:parse_response(resp_ctx, response_body, current_event_state, handler_opts)
        end
        response_body = ""
      end
    end
  end

  local function parse_response_without_stream(data)
    provider:parse_response_without_stream(data, current_event_state, handler_opts)
  end

  -- Guards against invoking on_stop more than once across stream/error/cancel paths.
  local completed = false

  local active_job

  -- Temp files for the request body, debug response capture, and header dump.
  local temp_file = fn.tempname()
  local curl_body_file = temp_file .. "-request-body.json"
  local resp_body_file = temp_file .. "-response-body.txt"
  local headers_file = temp_file .. "-response-headers.txt"
  local json_content = vim.json.encode(spec.body)
  fn.writefile(vim.split(json_content, "\n"), curl_body_file)

  Utils.debug("curl request body file:", curl_body_file)
  Utils.debug("curl response body file:", resp_body_file)
  Utils.debug("curl headers file:", headers_file)

  -- Deletes the temp files (kept when Config.debug for post-mortem inspection).
  local function cleanup()
    if Config.debug then return end
    vim.schedule(function()
      fn.delete(curl_body_file)
      pcall(fn.delete, resp_body_file)
      fn.delete(headers_file)
    end)
  end

  local headers_reported = false

  active_job = curl.post(spec.url, {
    headers = spec.headers,
    proxy = spec.proxy,
    insecure = spec.insecure,
    body = curl_body_file,
    raw = spec.rawArgs,
    dump = { "-D", headers_file },
    stream = function(err, data, _)
      -- Report response headers exactly once, on the first stream callback.
      if not headers_reported and opts.on_response_headers then
        headers_reported = true
        opts.on_response_headers(parse_headers(headers_file))
      end
      if err then
        completed = true
        handler_opts.on_stop({ reason = "error", error = err })
        return
      end
      if not data then return end
      -- In debug mode, append each raw chunk to the response capture file.
      if Config.debug then
        if type(data) == "string" then
          local file = io.open(resp_body_file, "a")
          if file then
            file:write(data .. "\n")
            file:close()
          end
        end
      end
      vim.schedule(function()
        -- NOTE(review): `Config[Config.provider] == nil` appears to detect a
        -- custom (non-built-in) provider -- confirm intent.
        if Config[Config.provider] == nil and provider.parse_stream_data ~= nil then
          if provider.parse_response ~= nil then
            Utils.warn(
              "parse_stream_data and parse_response are mutually exclusive, and thus parse_response will be ignored. Make sure that you handle the incoming data correctly.",
              { once = true }
            )
          end
          provider:parse_stream_data(resp_ctx, data, handler_opts)
        else
          if provider.parse_stream_data ~= nil then
            provider:parse_stream_data(resp_ctx, data, handler_opts)
          else
            parse_stream_data(data)
          end
        end
      end)
    end,
    on_error = function(err)
      -- curl exit 23: write error, commonly caused by $XDG_RUNTIME_DIR issues.
      if err.exit == 23 then
        local xdg_runtime_dir = os.getenv("XDG_RUNTIME_DIR")
        if not xdg_runtime_dir or fn.isdirectory(xdg_runtime_dir) == 0 then
          -- NOTE(review): when xdg_runtime_dir is nil this concatenation
          -- raises "attempt to concatenate a nil value" -- the message also
          -- wrongly says "is set" in that case. Needs a nil guard.
          Utils.error(
            "$XDG_RUNTIME_DIR="
              .. xdg_runtime_dir
              .. " is set but does not exist. curl could not write output. Please make sure it exists, or unset.",
            { title = "Avante" }
          )
        elseif not uv.fs_access(xdg_runtime_dir, "w") then
          Utils.error(
            "$XDG_RUNTIME_DIR="
              .. xdg_runtime_dir
              .. " exists but is not writable. curl could not write output. Please make sure it is writable, or unset.",
            { title = "Avante" }
          )
        end
      end

      active_job = nil
      if not completed then
        completed = true
        cleanup()
        handler_opts.on_stop({ reason = "error", error = err })
      end
    end,
    callback = function(result)
      active_job = nil
      cleanup()
      -- Fold raw "Key: value" header lines into a map; values may themselves
      -- contain colons, so the tail pieces are re-joined.
      local headers_map = vim.iter(result.headers):fold({}, function(acc, value)
        local pieces = vim.split(value, ":")
        local key = pieces[1]
        local remain = vim.list_slice(pieces, 2)
        if not remain then return acc end
        local val = Utils.trim_spaces(table.concat(remain, ":"))
        acc[key] = val
        return acc
      end)
      if result.status >= 400 then
        if provider.on_error then
          provider.on_error(result)
        else
          Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" })
        end
        -- Default 10s backoff unless the server supplies Retry-After.
        local retry_after = 10
        if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end
        if result.status == 429 then
          Utils.debug("result", result)

          -- NOTE(review): this path does not set `completed` before calling
          -- on_stop -- confirm no double on_stop is possible here.
          handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after })
          return
        end
        vim.schedule(function()
          if not completed then
            completed = true
            handler_opts.on_stop({
              reason = "error",
              error = "API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body),
            })
          end
        end)
      end

      -- If stream is not enabled, then handle the response here
      if provider:is_disable_stream() and result.status == 200 then
        vim.schedule(function()
          completed = true
          parse_response_without_stream(result.body)
        end)
      end

      -- OpenRouter misconfiguration returns an HTML landing page with 200.
      if result.status == 200 and spec.url:match("https://openrouter.ai") then
        local content_type = headers_map["content-type"]
        if content_type and content_type:match("text/html") then
          handler_opts.on_stop({
            reason = "error",
            error = "Your openrouter endpoint setting is incorrect, please set it to https://openrouter.ai/api/v1",
          })
        end
      end
    end,
  })

  -- One-shot cancel hook: marks the request complete, shuts the job down,
  -- cleans up, and reports "cancelled".
  api.nvim_create_autocmd("User", {
    group = group,
    pattern = M.CANCEL_PATTERN,
    once = true,
    callback = function()
      -- Error: cannot resume dead coroutine
      if active_job then
        -- Mark as completed first to prevent error handler from running
        completed = true

        -- Attempt to shutdown the active job, but ignore any errors
        xpcall(function() active_job:shutdown() end, function(err)
          Utils.debug("Ignored error during job shutdown: " .. vim.inspect(err))
          return err
        end)

        Utils.debug("LLM request cancelled")
        active_job = nil

        -- Clean up and notify of cancellation
        cleanup()
        vim.schedule(function() handler_opts.on_stop({ reason = "cancelled" }) end)
      end
    end,
  })

  return active_job
end
|
|
|
|
---Core streaming loop: generates the prompt, issues the request via M.curl,
---and on stop handles tool-use continuation (executing requested tools and
---recursing with the updated history), cancellation, and rate-limit retry
---with a visible countdown. Returns M.curl's job handle, or nothing when the
---request is deferred to memory summarization.
---@param opts AvanteLLMStreamOptions
function M._stream(opts)
  -- Reset the cancellation flag at the start of a new request
  if LLMToolHelpers then LLMToolHelpers.is_cancelled = false end

  local provider = opts.provider or Providers[Config.provider]
  opts.session_ctx = opts.session_ctx or {}

  -- Populate the session context with the caller's handlers only once, so
  -- recursive _stream calls reuse the originals.
  if not opts.session_ctx.on_messages_add then opts.session_ctx.on_messages_add = opts.on_messages_add end
  if not opts.session_ctx.on_state_change then opts.session_ctx.on_state_change = opts.on_state_change end
  if not opts.session_ctx.on_start then opts.session_ctx.on_start = opts.on_start end
  if not opts.session_ctx.on_chunk then opts.session_ctx.on_chunk = opts.on_chunk end
  if not opts.session_ctx.on_stop then opts.session_ctx.on_stop = opts.on_stop end
  if not opts.session_ctx.on_tool_log then opts.session_ctx.on_tool_log = opts.on_tool_log end
  if not opts.session_ctx.get_history_messages then
    opts.session_ctx.get_history_messages = opts.get_history_messages
  end

  ---@cast provider AvanteProviderFunctor

  local prompt_opts = M.generate_prompts(opts)

  -- History overflow: hand the overflow to the summarizer instead of sending
  -- the request; the caller is expected to re-stream after compaction.
  if
    prompt_opts.pending_compaction_history_messages
    and #prompt_opts.pending_compaction_history_messages > 0
    and opts.on_memory_summarize
  then
    opts.on_memory_summarize(prompt_opts.pending_compaction_history_messages)
    return
  end

  -- Captured from M.curl via on_response_headers; consulted for provider
  -- rate-limit sleep hints.
  local resp_headers = {}

  ---@type AvanteHandlerOptions
  local handler_opts = {
    on_messages_add = opts.on_messages_add,
    on_state_change = opts.on_state_change,
    on_start = opts.on_start,
    on_chunk = opts.on_chunk,
    on_stop = function(stop_opts)
      ---Executes tool_use_list[tool_use_index]; past the end, records all
      ---tool results as history messages and recurses into M._stream.
      ---@param tool_use_list AvanteLLMToolUse[]
      ---@param tool_use_index integer
      ---@param tool_results AvanteLLMToolResult[]
      local function handle_next_tool_use(tool_use_list, tool_use_index, tool_results)
        if tool_use_index > #tool_use_list then
          ---@type avante.HistoryMessage[]
          local messages = {}
          for _, tool_result in ipairs(tool_results) do
            messages[#messages + 1] = HistoryMessage:new({
              role = "user",
              content = {
                {
                  type = "tool_result",
                  tool_use_id = tool_result.tool_use_id,
                  content = tool_result.content,
                  is_error = tool_result.is_error,
                },
              },
            })
          end
          opts.on_messages_add(messages)
          local new_opts = vim.tbl_deep_extend("force", opts, {
            history_messages = opts.get_history_messages(),
          })
          -- Provider-specific pre-emptive backoff before the follow-up request.
          if provider.get_rate_limit_sleep_time then
            local sleep_time = provider:get_rate_limit_sleep_time(resp_headers)
            if sleep_time and sleep_time > 0 then
              Utils.info("Rate limit reached. Sleeping for " .. sleep_time .. " seconds ...")
              vim.defer_fn(function() M._stream(new_opts) end, sleep_time * 1000)
              return
            end
          end
          M._stream(new_opts)
          return
        end
        local tool_use = tool_use_list[tool_use_index]
        ---Receives one tool's outcome (sync or async) and advances the chain.
        ---@param result string | nil
        ---@param error string | nil
        local function handle_tool_result(result, error)
          -- Special handling for cancellation signal from tools
          if error == LLMToolHelpers.CANCEL_TOKEN then
            Utils.debug("Tool execution was cancelled by user")
            if opts.on_chunk then opts.on_chunk("\n*[Request cancelled by user during tool execution.]*\n") end
            if opts.on_messages_add then
              local message = HistoryMessage:new({
                role = "assistant",
                content = "\n*[Request cancelled by user during tool execution.]*\n",
              }, {
                just_for_display = true,
              })
              opts.on_messages_add({ message })
            end
            return opts.on_stop({ reason = "cancelled" })
          end

          local tool_result = {
            tool_use_id = tool_use.id,
            content = error ~= nil and error or result,
            is_error = error ~= nil,
          }
          table.insert(tool_results, tool_result)
          return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_results)
        end
        -- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
        local result, error = LLMTools.process_tool_use(
          prompt_opts.tools,
          tool_use,
          opts.on_tool_log,
          handle_tool_result,
          opts.session_ctx
        )
        if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
      end
      if stop_opts.reason == "cancelled" then
        if opts.on_chunk then opts.on_chunk("\n*[Request cancelled by user.]*\n") end
        if opts.on_messages_add then
          local message = HistoryMessage:new({
            role = "assistant",
            content = "\n*[Request cancelled by user.]*\n",
          }, {
            just_for_display = true,
          })
          opts.on_messages_add({ message })
        end
        return opts.on_stop({ reason = "cancelled" })
      end
      if stop_opts.reason == "tool_use" then
        -- Walk history backwards collecting tool_use entries that have not yet
        -- produced a tool_result; stop at the first already-answered one.
        local tool_use_list = {} ---@type AvanteLLMToolUse[]
        local tool_result_seen = {}
        local history_messages = opts.get_history_messages and opts.get_history_messages() or {}
        for idx = #history_messages, 1, -1 do
          local message = history_messages[idx]
          local content = message.message.content
          if type(content) ~= "table" or #content == 0 then goto continue end
          if content[1].type == "tool_use" then
            if not tool_result_seen[content[1].id] then
              table.insert(tool_use_list, 1, content[1])
            else
              break
            end
          end
          if content[1].type == "tool_result" then tool_result_seen[content[1].tool_use_id] = true end
          ::continue::
        end
        local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
        for _, tool_use in vim.spairs(tool_use_list) do
          table.insert(sorted_tool_use_list, tool_use)
        end
        return handle_next_tool_use(sorted_tool_use_list, 1, {})
      end
      if stop_opts.reason == "rate_limit" then
        local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*"
        if opts.on_chunk then opts.on_chunk("\n" .. msg_content .. "\n") end
        local message
        if opts.on_messages_add then
          message = HistoryMessage:new({
            role = "assistant",
            content = "\n\n" .. msg_content,
          }, {
            just_for_display = true,
          })
          opts.on_messages_add({ message })
        end
        -- Per-second countdown that updates the displayed message in place.
        local timer = vim.loop.new_timer()
        if timer then
          local retry_after = stop_opts.retry_after
          local function countdown()
            timer:start(
              1000,
              0,
              vim.schedule_wrap(function()
                if retry_after > 0 then retry_after = retry_after - 1 end
                local msg_content_ = "*[Rate limit reached. Retrying in " .. retry_after .. " seconds ...]*"
                -- NOTE(review): [[\033[1A\033[K]] is a long-bracket literal,
                -- so it contains literal backslash characters, NOT ANSI
                -- escape codes -- confirm whether on_chunk interprets this.
                if opts.on_chunk then opts.on_chunk([[\033[1A\033[K]] .. "\n" .. msg_content_ .. "\n") end
                if opts.on_messages_add and message then
                  message.message.content = "\n\n" .. msg_content_
                  opts.on_messages_add({ message })
                end
                countdown()
              end)
            )
          end
          countdown()
        end
        Utils.info("Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds", { title = "Avante" })
        -- The countdown timer is stopped just before the retry fires.
        vim.defer_fn(function()
          if timer then timer:stop() end
          M._stream(opts)
        end, stop_opts.retry_after * 1000)
        return
      end
      return opts.on_stop(stop_opts)
    end,
  }

  return M.curl({
    provider = provider,
    prompt_opts = prompt_opts,
    handler_opts = handler_opts,
    on_response_headers = function(headers) resp_headers = headers end,
  })
end
|
|
|
|
---Substitutes the two dual-boost provider outputs into the configured merge
---prompt, appends that prompt to `opts.instructions`, and starts the final
---merge stream. Function replacements are used with gsub so `%` characters
---inside the responses are not treated as capture references.
local function _merge_response(first_response, second_response, opts)
  local merge_prompt = "\n" .. Config.dual_boost.prompt
  -- Substitute provider1 first, then provider2 on the result (matches the
  -- original substitution order).
  merge_prompt = merge_prompt:gsub("{{[%s]*provider1_output[%s]*}}", function() return first_response end)
  merge_prompt = merge_prompt:gsub("{{[%s]*provider2_output[%s]*}}", function() return second_response end)
  merge_prompt = merge_prompt .. "\n"

  -- append this reference prompt to the prompt_opts messages at last
  opts.instructions = (opts.instructions or "") .. merge_prompt

  M._stream(opts)
end
|
|
|
|
---Merges the two collected dual-boost responses; reports an error and bails
---when either slot is still empty (stream failed or timed out).
local function _collector_process_responses(collector, opts)
  local first, second = collector[1], collector[2]
  if not (first and second) then
    Utils.error("One or both responses failed to complete")
    return
  end
  _merge_response(first, second, opts)
end
|
|
|
|
---Stores one stream's response at `collector[index]` and bumps the counter;
---once both streams have reported, cancels the timeout timer and merges.
local function _collector_add_response(collector, index, response, opts)
  collector[index] = response
  collector.count = collector.count + 1

  if collector.count ~= 2 then return end
  collector.timer:stop()
  _collector_process_responses(collector, opts)
end
|
|
|
|
--- Runs the same request against two providers concurrently ("dual boost").
-- Each provider's chunks are accumulated into a full response string; when
-- both streams finish (or the configured timeout fires first) the collected
-- responses are handed off for merging and re-streaming.
---@param opts table stream options shared by both underlying streams
---@param Provider1 table provider used for the first stream
---@param Provider2 table provider used for the second stream
function M._dual_boost_stream(opts, Provider1, Provider2)
  Utils.debug("Starting Dual Boost Stream")

  local collector = {
    count = 0,
    responses = {},
    timer = uv.new_timer(),
    timeout_ms = Config.dual_boost.timeout,
  }

  -- One-shot timeout: if both responses have not arrived in time, process
  -- whatever has been collected so the user is not left hanging.
  collector.timer:start(
    collector.timeout_ms,
    0,
    vim.schedule_wrap(function()
      if collector.count >= 2 then return end
      Utils.warn("Dual boost stream timeout reached")
      collector.timer:stop()
      _collector_process_responses(collector, opts)
    end)
  )

  -- Build per-stream options: accumulate chunks into a local buffer and
  -- hand the finished response to the collector on stop.
  local function make_stream_opts(index)
    local buffer = ""
    return vim.tbl_extend("force", opts, {
      on_chunk = function(chunk)
        if chunk then buffer = buffer .. chunk end
      end,
      on_stop = function(stop_opts)
        if stop_opts.error then
          Utils.error(string.format("Stream %d failed: %s", index, stop_opts.error))
          return
        end
        Utils.debug(string.format("Response %d completed", index))
        _collector_add_response(collector, index, buffer, opts)
      end,
    })
  end

  -- Kick off both streams; surface any startup error without crashing.
  local ok, start_err = xpcall(function()
    for index, provider in ipairs({ Provider1, Provider2 }) do
      local stream_opts = make_stream_opts(index)
      stream_opts.provider = provider
      M._stream(stream_opts)
    end
  end, function(e) return e end)

  if not ok then Utils.error("Failed to start dual_boost streams: " .. tostring(start_err)) end
end
|
|
|
|
--- Public streaming entry point.
-- Wraps the caller's callbacks so they run on the main loop via
-- vim.schedule_wrap, suppresses chunk/stop delivery after a terminal stop
-- reason has been seen, and dispatches either to the dual-boost path or to
-- the regular single-provider stream.
---@param opts AvanteLLMStreamOptions
function M.stream(opts)
  -- Flipped to true once a terminal stop reason arrives; later chunks and
  -- stop notifications are then dropped.
  local is_completed = false

  local user_on_tool_log = opts.on_tool_log
  if user_on_tool_log ~= nil then
    opts.on_tool_log = vim.schedule_wrap(function(...)
      if not user_on_tool_log then return end
      return user_on_tool_log(...)
    end)
  end

  local user_on_chunk = opts.on_chunk
  if user_on_chunk ~= nil then
    opts.on_chunk = vim.schedule_wrap(function(chunk)
      if is_completed then return end
      if user_on_chunk then return user_on_chunk(chunk) end
    end)
  end

  local user_on_stop = opts.on_stop
  if user_on_stop ~= nil then
    opts.on_stop = vim.schedule_wrap(function(stop_opts)
      if is_completed then return end
      local reason = stop_opts.reason
      if reason == "complete" or reason == "error" or reason == "cancelled" then is_completed = true end
      return user_on_stop(stop_opts)
    end)
  end

  opts.mode = opts.mode or Config.mode

  -- Dual boost is only applied in modes listed here.
  local dual_boost_modes = { legacy = true }

  if Config.dual_boost.enabled and dual_boost_modes[opts.mode] then
    M._dual_boost_stream(
      opts,
      Providers[Config.dual_boost.first_provider],
      Providers[Config.dual_boost.second_provider]
    )
  else
    M._stream(opts)
  end
end
|
|
|
|
--- Cancels any in-flight LLM request.
-- Sets the shared tool-cancellation flag, dismisses a pending confirmation
-- popup if one is showing, and fires the cancel User autocmd pattern.
function M.cancel_inflight_request()
  -- NOTE(review): the flag is only flipped when it has been initialized
  -- elsewhere (non-nil); a nil flag is intentionally left untouched.
  if LLMToolHelpers.is_cancelled ~= nil then LLMToolHelpers.is_cancelled = true end

  local popup = LLMToolHelpers.confirm_popup
  if popup ~= nil then
    popup:cancel()
    LLMToolHelpers.confirm_popup = nil
  end

  api.nvim_exec_autocmds("User", { pattern = M.CANCEL_PATTERN })
end
|
|
|
|
return M
|