-- File: avante.nvim/lua/avante/llm.lua
-- Snapshot: 2025-03-14 14:13:47 +08:00 (647 lines, 22 KiB, Lua)
local api = vim.api
local fn = vim.fn
local uv = vim.uv
local curl = require("plenary.curl")
local Utils = require("avante.utils")
local Config = require("avante.config")
local Path = require("avante.path")
local Providers = require("avante.providers")
local LLMTools = require("avante.llm_tools")
---@class avante.LLM
local M = {}
-- User autocmd pattern fired by M.cancel_inflight_request; M.curl registers a
-- one-shot handler on it to shut down the active curl job.
M.CANCEL_PATTERN = "AvanteLLMEscape"
------------------------------Prompt and type------------------------------
-- Autocmd group holding the per-request cancel handlers (cleared on reload).
local group = api.nvim_create_augroup("avante_llm", { clear = true })
--- Summarizes the active chat history into a compact memory blob, saves it on
--- the history, and hands it to `cb`. When a previous memory exists, only the
--- entries newer than its timestamp are summarized (incremental update).
---@param bufnr integer
---@param history avante.ChatHistory
---@param cb fun(memory: avante.ChatMemory | nil): nil
function M.summarize_memory(bufnr, history, cb)
  local system_prompt =
    [[Summarize the following conversation to extract the most critical information (such as languages used, conversation style, tech stack, considerations, user information, etc.) for memory in subsequent conversations. Since it is for memory purposes, be detailed and rigorous to ensure that no information from previous summaries is lost in the newly generated summary.]]
  local active_entries = Utils.history.filter_active_entries(history.entries)
  if #active_entries == 0 then
    cb(nil)
    return
  end
  local prev_memory = history.memory
  if prev_memory then
    -- Incremental mode: feed the previous summary back in and only consider
    -- entries that arrived after the last summarization.
    system_prompt = system_prompt .. "\n\nThe previous summary is:\n\n" .. prev_memory.content
    active_entries = vim
      .iter(active_entries)
      :filter(function(entry) return entry.timestamp > prev_memory.last_summarized_timestamp end)
      :totable()
  end
  if #active_entries == 0 then
    cb(history.memory)
    return
  end
  local llm_messages = Utils.history.entries_to_llm_messages(active_entries)
  -- Cap the payload at the first four messages.
  llm_messages = vim.list_slice(llm_messages, 1, 4)
  if #llm_messages == 0 then
    cb(history.memory)
    return
  end
  Utils.debug("summarize memory", #llm_messages, llm_messages[#llm_messages].content)
  local accumulated = ""
  local provider = Providers[Config.memory_summary_provider or Config.provider]
  M.curl({
    provider = provider,
    prompt_opts = {
      system_prompt = system_prompt,
      messages = {
        { role = "user", content = vim.json.encode(llm_messages) },
      },
    },
    handler_opts = {
      on_start = function(_) end,
      on_chunk = function(chunk)
        if chunk then accumulated = accumulated .. chunk end
      end,
      on_stop = function(stop_opts)
        if stop_opts.error ~= nil then
          Utils.error(string.format("summarize failed: %s", vim.inspect(stop_opts.error)))
          return
        end
        if stop_opts.reason == "complete" then
          accumulated = Utils.trim_think_content(accumulated)
          local memory = {
            content = accumulated,
            last_summarized_timestamp = active_entries[#active_entries].timestamp,
          }
          history.memory = memory
          Path.history.save(bufnr, history)
          cb(memory)
        end
      end,
    },
  })
end
--- Renders the system prompt and the ordered message list for one LLM call:
--- extracts inline "image: <path>" references, renders the mode template plus
--- optional context templates (project/diagnostics/code/memory), appends the
--- user question, and trims history messages to the provider's token budget.
---@param opts AvanteGeneratePromptsOptions
---@return AvantePromptOptions
function M.generate_prompts(opts)
  ---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
  local provider = opts.provider or Providers[Config.provider]
  local mode = opts.mode or "planning"
  local _, request_body = Providers.parse_config(provider)
  local max_tokens = request_body.max_tokens or 4096
  -- Check if the instructions contains an image path
  local image_paths = {}
  local instructions = opts.instructions
  if instructions and instructions:match("image: ") then
    local lines = vim.split(instructions, "\n")
    -- BUGFIX: iterate backwards. The previous forward `ipairs` loop called
    -- table.remove() mid-iteration, which skips the line that follows every
    -- removed entry, so consecutive "image:" lines were only partially
    -- extracted. Reverse iteration makes removal safe.
    for i = #lines, 1, -1 do
      local line = lines[i]
      if line:match("^image: ") then
        -- Insert at the front so paths keep their top-to-bottom order.
        table.insert(image_paths, 1, (line:gsub("^image: ", "")))
        table.remove(lines, i)
      end
    end
    instructions = table.concat(lines, "\n")
  end
  local project_root = Utils.root.get()
  Path.prompts.initialize(Path.prompts.get_templates_dir(project_root))
  local system_info = Utils.get_system_info()
  local template_opts = {
    ask = opts.ask, -- TODO: add mode without ask instruction
    code_lang = opts.code_lang,
    selected_files = opts.selected_files,
    selected_code = opts.selected_code,
    recently_viewed_files = opts.recently_viewed_files,
    project_context = opts.project_context,
    diagnostics = opts.diagnostics,
    system_info = system_info,
    model_name = provider.model or "unknown",
    memory = opts.memory,
  }
  local system_prompt = Path.prompts.render_mode(mode, template_opts)
  -- A user-configured system prompt (string or factory function) is appended
  -- to the rendered one rather than replacing it.
  if Config.system_prompt ~= nil then
    local custom_system_prompt = Config.system_prompt
    if type(custom_system_prompt) == "function" then custom_system_prompt = custom_system_prompt() end
    if custom_system_prompt ~= nil and custom_system_prompt ~= "" and custom_system_prompt ~= "null" then
      system_prompt = system_prompt .. "\n\n" .. custom_system_prompt
    end
  end
  ---@type AvanteLLMMessage[]
  local messages = {}
  if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then
    local project_context = Path.prompts.render_file("_project.avanterules", template_opts)
    if project_context ~= "" then table.insert(messages, { role = "user", content = project_context }) end
  end
  if opts.diagnostics ~= nil and opts.diagnostics ~= "" and opts.diagnostics ~= "null" then
    local diagnostics = Path.prompts.render_file("_diagnostics.avanterules", template_opts)
    if diagnostics ~= "" then table.insert(messages, { role = "user", content = diagnostics }) end
  end
  if (opts.selected_files and #opts.selected_files > 0) or opts.selected_code ~= nil then
    local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
    if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
  end
  if opts.memory ~= nil and opts.memory ~= "" and opts.memory ~= "null" then
    local memory = Path.prompts.render_file("_memory.avanterules", template_opts)
    if memory ~= "" then table.insert(messages, { role = "user", content = memory }) end
  end
  if instructions then
    table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
  end
  -- Token budget left for history after the system prompt and context messages.
  local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)
  for _, message in ipairs(messages) do
    remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
  end
  if opts.history_messages then
    if Config.history.max_tokens > 0 then remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens) end
    -- Traverse the history in reverse, keeping only the latest history until the remaining tokens are exhausted and the first message role is "user"
    local history_messages = {}
    for i = #opts.history_messages, 1, -1 do
      local message = opts.history_messages[i]
      local tokens = Utils.tokens.calculate_tokens(message.content)
      remaining_tokens = remaining_tokens - tokens
      if remaining_tokens > 0 then
        table.insert(history_messages, message)
      else
        break
      end
    end
    -- prepend the history messages to the messages table (history_messages is
    -- newest-first, so repeated insert-at-1 restores chronological order)
    vim.iter(history_messages):each(function(msg) table.insert(messages, 1, msg) end)
    -- Conversations must not open with an assistant turn.
    if #messages > 0 and messages[1].role == "assistant" then table.remove(messages, 1) end
  end
  if opts.mode == "cursor-applying" then
    -- Long-bracket strings are verbatim, so the prompt text stays flush-left.
    local user_prompt = [[
Merge all changes from the <update> snippet into the <code> below.
- Preserve the code's structure, order, comments, and indentation exactly.
- Output only the updated code, enclosed within <updated-code> and </updated-code> tags.
- Do not include any additional text, explanations, placeholders, ellipses, or code fences.
]]
    user_prompt = user_prompt .. string.format("<code>\n%s\n</code>\n", opts.original_code)
    for _, snippet in ipairs(opts.update_snippets) do
      user_prompt = user_prompt .. string.format("<update>\n%s\n</update>\n", snippet)
    end
    user_prompt = user_prompt .. "Provide the complete updated code."
    table.insert(messages, { role = "user", content = user_prompt })
  end
  ---@type AvantePromptOptions
  return {
    system_prompt = system_prompt,
    messages = messages,
    image_paths = image_paths,
    tools = opts.tools,
    tool_histories = opts.tool_histories,
  }
end
--- Token count of the fully rendered prompt: the system prompt plus the
--- content of every generated message.
---@param opts AvanteGeneratePromptsOptions
---@return integer
function M.calculate_tokens(opts)
  local prompt_opts = M.generate_prompts(opts)
  local total = Utils.tokens.calculate_tokens(prompt_opts.system_prompt)
  for _, msg in ipairs(prompt_opts.messages) do
    total = total + Utils.tokens.calculate_tokens(msg.content)
  end
  return total
end
-- Parses a curl "-D" header dump file into a { name = value } table.
-- CR line endings are stripped; lines without a colon (e.g. the HTTP status
-- line) are skipped. Returns an empty table when the file cannot be opened.
local function parse_headers(headers_file)
  local parsed = {}
  local fh = io.open(headers_file, "r")
  if not fh then return parsed end
  for raw in fh:lines() do
    local line = raw:gsub("\r$", "")
    -- Lazy key capture stops at the first colon; surrounding spaces trimmed.
    local name, value = line:match("^%s*(.-)%s*:%s*(.*)$")
    if name and value then parsed[name] = value end
  end
  fh:close()
  return parsed
end
--- Performs one HTTP request to the LLM provider via plenary curl.
--- The request body is written to a temp file, response headers are dumped to
--- a second temp file and reported once via opts.on_response_headers, SSE
--- lines are routed to the provider's parser, and a one-shot autocmd on
--- M.CANCEL_PATTERN allows cancelling the in-flight job. Returns the plenary
--- job handle.
---@param opts avante.CurlOpts
function M.curl(opts)
  local provider = opts.provider
  local prompt_opts = opts.prompt_opts
  local handler_opts = opts.handler_opts
  ---@type AvanteCurlOutput
  local spec = provider:parse_curl_args(prompt_opts)
  ---@type string
  -- Remembers the most recent SSE "event:" name so "data:" lines can be
  -- routed with their event context.
  local current_event_state = nil
  local resp_ctx = {}
  ---@param line string
  -- Parses one SSE line: stores "event:" names, forwards "data:" payloads.
  local function parse_stream_data(line)
    local event = line:match("^event: (.+)$")
    if event then
      current_event_state = event
      return
    end
    local data_match = line:match("^data: (.+)$")
    if data_match then provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts) end
  end
  local function parse_response_without_stream(data)
    provider:parse_response_without_stream(data, current_event_state, handler_opts)
  end
  -- Guards against emitting on_stop more than once across async callbacks.
  local completed = false
  local active_job
  local temp_file = fn.tempname()
  local curl_body_file = temp_file .. "-request-body.json"
  local json_content = vim.json.encode(spec.body)
  fn.writefile(vim.split(json_content, "\n"), curl_body_file)
  Utils.debug("curl body file:", curl_body_file)
  local headers_file = temp_file .. "-headers.txt"
  Utils.debug("curl headers file:", headers_file)
  -- Temp files are kept for inspection when debug mode is on.
  local function cleanup()
    if Config.debug then return end
    vim.schedule(function()
      fn.delete(curl_body_file)
      fn.delete(headers_file)
    end)
  end
  local headers_reported = false
  active_job = curl.post(spec.url, {
    headers = spec.headers,
    proxy = spec.proxy,
    insecure = spec.insecure,
    body = curl_body_file,
    raw = spec.rawArgs,
    dump = { "-D", headers_file },
    stream = function(err, data, _)
      -- Report response headers exactly once, on the first streamed chunk.
      if not headers_reported and opts.on_response_headers then
        headers_reported = true
        opts.on_response_headers(parse_headers(headers_file))
      end
      if err then
        completed = true
        handler_opts.on_stop({ reason = "error", error = err })
        return
      end
      if not data then return end
      vim.schedule(function()
        -- NOTE(review): Config[Config.provider] being nil appears to
        -- distinguish built-in providers from user-configured ones -- confirm
        -- against Config's layout before relying on this.
        if Config[Config.provider] == nil and provider.parse_stream_data ~= nil then
          if provider.parse_response ~= nil then
            Utils.warn(
              "parse_stream_data and parse_response are mutually exclusive, and thus parse_response will be ignored. Make sure that you handle the incoming data correctly.",
              { once = true }
            )
          end
          provider:parse_stream_data(resp_ctx, data, handler_opts)
        else
          if provider.parse_stream_data ~= nil then
            provider:parse_stream_data(resp_ctx, data, handler_opts)
          else
            parse_stream_data(data)
          end
        end
      end)
    end,
    on_error = function(err)
      -- curl exit code 23 is a write error; a broken $XDG_RUNTIME_DIR is the
      -- common cause, so give an actionable message for that case.
      if err.exit == 23 then
        local xdg_runtime_dir = os.getenv("XDG_RUNTIME_DIR")
        if not xdg_runtime_dir or fn.isdirectory(xdg_runtime_dir) == 0 then
          Utils.error(
            "$XDG_RUNTIME_DIR="
              .. xdg_runtime_dir
              .. " is set but does not exist. curl could not write output. Please make sure it exists, or unset.",
            { title = "Avante" }
          )
        elseif not uv.fs_access(xdg_runtime_dir, "w") then
          Utils.error(
            "$XDG_RUNTIME_DIR="
              .. xdg_runtime_dir
              .. " exists but is not writable. curl could not write output. Please make sure it is writable, or unset.",
            { title = "Avante" }
          )
        end
      end
      active_job = nil
      completed = true
      cleanup()
      handler_opts.on_stop({ reason = "error", error = err })
    end,
    callback = function(result)
      active_job = nil
      cleanup()
      if result.status >= 400 then
        if provider.on_error then
          provider.on_error(result)
        else
          Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" })
        end
        if result.status == 429 then
          -- result.headers is a list of raw "Key: value" strings; fold them
          -- into a map so the retry-after header can be read.
          local headers_map = vim.iter(result.headers):fold({}, function(acc, value)
            local pieces = vim.split(value, ":")
            local key = pieces[1]
            local remain = vim.list_slice(pieces, 2)
            if not remain then return acc end
            local val = Utils.trim_spaces(table.concat(remain, ":"))
            acc[key] = val
            return acc
          end)
          -- Default backoff when the server sends no (or a bad) retry-after.
          local retry_after = 10
          if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end
          handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after })
          return
        end
        vim.schedule(function()
          if not completed then
            completed = true
            handler_opts.on_stop({
              reason = "error",
              error = "API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body),
            })
          end
        end)
      end
      -- If stream is not enabled, then handle the response here
      if provider:is_disable_stream() and result.status == 200 then
        vim.schedule(function()
          completed = true
          parse_response_without_stream(result.body)
        end)
      end
    end,
  })
  -- One-shot cancel hook: M.cancel_inflight_request fires this pattern.
  api.nvim_create_autocmd("User", {
    group = group,
    pattern = M.CANCEL_PATTERN,
    once = true,
    callback = function()
      -- Error: cannot resume dead coroutine
      if active_job then
        xpcall(function() active_job:shutdown() end, function(err) return err end)
        Utils.debug("LLM request cancelled")
        active_job = nil
      end
    end,
  })
  return active_job
end
--- Core streaming entry point: renders prompts, issues the request via
--- M.curl, and on stop handles two recursive cases -- tool use (executes each
--- requested tool, then re-streams with the accumulated tool histories) and
--- rate limiting (re-streams after a countdown).
---@param opts AvanteLLMStreamOptions
function M._stream(opts)
  local provider = opts.provider or Providers[Config.provider]
  ---@cast provider AvanteProviderFunctor
  local prompt_opts = M.generate_prompts(opts)
  -- Captured from the HTTP response; used for provider rate-limit sleep hints.
  local resp_headers = {}
  ---@type AvanteHandlerOptions
  local handler_opts = {
    on_start = opts.on_start,
    on_chunk = opts.on_chunk,
    on_stop = function(stop_opts)
      ---@param tool_use_list AvanteLLMToolUse[]
      ---@param tool_use_index integer
      ---@param tool_histories AvanteLLMToolHistory[]
      -- Executes tool uses one at a time; after the last one, re-streams with
      -- the collected tool histories (optionally sleeping first if the
      -- provider derives a rate-limit window from the response headers).
      local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories)
        if tool_use_index > #tool_use_list then
          local new_opts = vim.tbl_deep_extend("force", opts, {
            tool_histories = tool_histories,
          })
          if provider.get_rate_limit_sleep_time then
            local sleep_time = provider:get_rate_limit_sleep_time(resp_headers)
            if sleep_time and sleep_time > 0 then
              Utils.info("Rate limit reached. Sleeping for " .. sleep_time .. " seconds ...")
              vim.defer_fn(function() M._stream(new_opts) end, sleep_time * 1000)
              return
            end
          end
          M._stream(new_opts)
          return
        end
        local tool_use = tool_use_list[tool_use_index]
        ---@param result string | nil
        ---@param error string | nil
        -- Records one tool's outcome and advances to the next tool use.
        local function handle_tool_result(result, error)
          local tool_result = {
            tool_use_id = tool_use.id,
            content = error ~= nil and error or result,
            is_error = error ~= nil,
          }
          table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use })
          return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories)
        end
        -- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
        local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result)
        if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
      end
      if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
        -- Deep-copy so recursion never mutates the caller's tool histories.
        local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
        -- spairs gives a deterministic (sorted-key) execution order.
        local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
        for _, tool_use in vim.spairs(stop_opts.tool_use_list) do
          table.insert(sorted_tool_use_list, tool_use)
        end
        return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories)
      end
      if stop_opts.reason == "rate_limit" then
        local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..."
        opts.on_chunk("\n*[" .. msg .. "]*\n")
        local timer = vim.loop.new_timer()
        if timer then
          local retry_after = stop_opts.retry_after
          -- Re-arms the timer every second to emit a live countdown message.
          local function countdown()
            timer:start(
              1000,
              0,
              vim.schedule_wrap(function()
                if retry_after > 0 then retry_after = retry_after - 1 end
                local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..."
                -- NOTE(review): Lua long-bracket strings do not interpret
                -- escapes, so this prepends the literal text "\033[1A\033[K",
                -- not ANSI cursor-up/erase-line codes -- confirm intent.
                opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n")
                countdown()
              end)
            )
          end
          countdown()
        end
        Utils.info("Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds", { title = "Avante" })
        vim.defer_fn(function()
          if timer then timer:stop() end
          M._stream(opts)
        end, stop_opts.retry_after * 1000)
        return
      end
      return opts.on_stop(stop_opts)
    end,
  }
  return M.curl({
    provider = provider,
    prompt_opts = prompt_opts,
    handler_opts = handler_opts,
    on_response_headers = function(headers) resp_headers = headers end,
  })
end
-- Builds the dual-boost merge prompt by substituting both providers' outputs
-- into Config.dual_boost.prompt, appends it to the instructions, and
-- re-streams so the model produces a merged answer.
local function _merge_response(first_response, second_response, opts)
  local merged = "\n" .. Config.dual_boost.prompt
  -- Function replacements avoid gsub interpreting "%" captures in the outputs.
  merged = merged:gsub("{{[%s]*provider1_output[%s]*}}", function() return first_response end)
  merged = merged:gsub("{{[%s]*provider2_output[%s]*}}", function() return second_response end)
  merged = merged .. "\n"
  opts.instructions = opts.instructions .. merged
  M._stream(opts)
end
-- Merges the two collected responses; reports an error and stops if either
-- slot is still empty (failed or timed-out stream).
local function _collector_process_responses(collector, opts)
  if collector[1] and collector[2] then
    _merge_response(collector[1], collector[2], opts)
  else
    Utils.error("One or both responses failed to complete")
  end
end
-- Records one provider's finished response; once both slots are filled,
-- cancels the timeout timer and kicks off the merge.
local function _collector_add_response(collector, index, response, opts)
  collector[index] = response
  collector.count = collector.count + 1
  if collector.count ~= 2 then return end
  collector.timer:stop()
  _collector_process_responses(collector, opts)
end
--- Streams the same prompt through two providers concurrently, collects both
--- responses (with a timeout guard), then merges them via _merge_response.
function M.dual_boost_stream_impl(opts, Provider1, Provider2) end
function M._dual_boost_stream(opts, Provider1, Provider2)
  Utils.debug("Starting Dual Boost Stream")
  local collector = {
    count = 0,
    responses = {},
    timer = uv.new_timer(),
    timeout_ms = Config.dual_boost.timeout,
  }
  -- Timeout guard: if both responses have not arrived in time, merge whatever
  -- has been collected so far.
  collector.timer:start(
    collector.timeout_ms,
    0,
    vim.schedule_wrap(function()
      if collector.count >= 2 then return end
      Utils.warn("Dual boost stream timeout reached")
      collector.timer:stop()
      _collector_process_responses(collector, opts)
    end)
  )
  -- Per-stream options: buffer chunks locally, hand the final text to the
  -- collector when the stream stops cleanly.
  local function make_stream_opts(index)
    local buffer = {}
    return vim.tbl_extend("force", opts, {
      on_chunk = function(chunk)
        if chunk then buffer[#buffer + 1] = chunk end
      end,
      on_stop = function(stop_opts)
        if stop_opts.error then
          Utils.error(string.format("Stream %d failed: %s", index, stop_opts.error))
          return
        end
        Utils.debug(string.format("Response %d completed", index))
        _collector_add_response(collector, index, table.concat(buffer), opts)
      end,
    })
  end
  -- Start both streams; surface any synchronous startup failure.
  local ok, err = xpcall(function()
    local first_opts = make_stream_opts(1)
    first_opts.provider = Provider1
    M._stream(first_opts)
    local second_opts = make_stream_opts(2)
    second_opts.provider = Provider2
    M._stream(second_opts)
  end, function(e) return e end)
  if not ok then Utils.error("Failed to start dual_boost streams: " .. tostring(err)) end
end
--- Public streaming entry point. Wraps the caller's callbacks so they always
--- run on the main loop (vim.schedule) and are suppressed once the request
--- has completed or errored, then dispatches to the dual-boost or the plain
--- stream depending on configuration and mode.
---@param opts AvanteLLMStreamOptions
function M.stream(opts)
  local is_completed = false
  local orig_on_tool_log = opts.on_tool_log
  if orig_on_tool_log ~= nil then
    opts.on_tool_log = vim.schedule_wrap(function(tool_name, log)
      if not orig_on_tool_log then return end
      return orig_on_tool_log(tool_name, log)
    end)
  end
  local orig_on_chunk = opts.on_chunk
  if orig_on_chunk ~= nil then
    opts.on_chunk = vim.schedule_wrap(function(chunk)
      if is_completed then return end
      return orig_on_chunk(chunk)
    end)
  end
  local orig_on_stop = opts.on_stop
  if orig_on_stop ~= nil then
    opts.on_stop = vim.schedule_wrap(function(stop_opts)
      if is_completed then return end
      -- Terminal reasons latch the completed flag so late chunks are dropped.
      if stop_opts.reason == "complete" or stop_opts.reason == "error" then is_completed = true end
      return orig_on_stop(stop_opts)
    end)
  end
  opts.mode = opts.mode or "planning"
  -- Dual boost only applies to planning-style modes.
  local dual_boost_modes = {
    planning = true,
    ["cursor-planning"] = true,
  }
  if Config.dual_boost.enabled and dual_boost_modes[opts.mode] then
    M._dual_boost_stream(
      opts,
      Providers[Config.dual_boost.first_provider],
      Providers[Config.dual_boost.second_provider]
    )
  else
    M._stream(opts)
  end
end
--- Fires the cancel autocmd pattern; the one-shot handler registered in
--- M.curl shuts down the active job in response.
function M.cancel_inflight_request()
  api.nvim_exec_autocmds("User", { pattern = M.CANCEL_PATTERN })
end

return M