refactor: summarize memory (#1508)
This commit is contained in:
@@ -33,6 +33,7 @@ struct TemplateContext {
|
|||||||
diagnostics: Option<String>,
|
diagnostics: Option<String>,
|
||||||
system_info: Option<String>,
|
system_info: Option<String>,
|
||||||
model_name: Option<String>,
|
model_name: Option<String>,
|
||||||
|
memory: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Given the file name registered after add, the context table in Lua, resulted in a formatted
|
// Given the file name registered after add, the context table in Lua, resulted in a formatted
|
||||||
@@ -58,6 +59,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<
|
|||||||
diagnostics => context.diagnostics,
|
diagnostics => context.diagnostics,
|
||||||
system_info => context.system_info,
|
system_info => context.system_info,
|
||||||
model_name => context.model_name,
|
model_name => context.model_name,
|
||||||
|
memory => context.memory,
|
||||||
})
|
})
|
||||||
.map_err(LuaError::external)
|
.map_err(LuaError::external)
|
||||||
.unwrap())
|
.unwrap())
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ M._defaults = {
|
|||||||
-- Of course, you can reduce the request frequency by increasing `suggestion.debounce`.
|
-- Of course, you can reduce the request frequency by increasing `suggestion.debounce`.
|
||||||
auto_suggestions_provider = "claude",
|
auto_suggestions_provider = "claude",
|
||||||
cursor_applying_provider = nil,
|
cursor_applying_provider = nil,
|
||||||
|
memory_summary_provider = nil,
|
||||||
---@alias Tokenizer "tiktoken" | "hf"
|
---@alias Tokenizer "tiktoken" | "hf"
|
||||||
-- Used for counting tokens and encoding text.
|
-- Used for counting tokens and encoding text.
|
||||||
-- By default, we will use tiktoken.
|
-- By default, we will use tiktoken.
|
||||||
@@ -273,6 +274,10 @@ M._defaults = {
|
|||||||
temperature = 0,
|
temperature = 0,
|
||||||
max_tokens = 8000,
|
max_tokens = 8000,
|
||||||
},
|
},
|
||||||
|
["openai-gpt-4o-mini"] = {
|
||||||
|
__inherited_from = "openai",
|
||||||
|
model = "gpt-4o-mini",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
---Specify the special dual_boost mode
|
---Specify the special dual_boost mode
|
||||||
---1. enabled: Whether to enable dual_boost mode. Default to false.
|
---1. enabled: Whether to enable dual_boost mode. Default to false.
|
||||||
|
|||||||
@@ -19,6 +19,71 @@ M.CANCEL_PATTERN = "AvanteLLMEscape"
|
|||||||
|
|
||||||
local group = api.nvim_create_augroup("avante_llm", { clear = true })
|
local group = api.nvim_create_augroup("avante_llm", { clear = true })
|
||||||
|
|
||||||
|
---@param bufnr integer
|
||||||
|
---@param history avante.ChatHistory
|
||||||
|
---@param cb fun(memory: avante.ChatMemory | nil): nil
|
||||||
|
function M.summarize_memory(bufnr, history, cb)
|
||||||
|
local system_prompt =
|
||||||
|
[[Summarize the following conversation to extract the most critical information (such as languages used, conversation style, tech stack, considerations, user information, etc.) for memory in subsequent conversations. Since it is for memory purposes, be detailed and rigorous to ensure that no information from previous summaries is lost in the newly generated summary.]]
|
||||||
|
local entries = Utils.history.filter_active_entries(history.entries)
|
||||||
|
if #entries == 0 then
|
||||||
|
cb(nil)
|
||||||
|
return
|
||||||
|
end
|
||||||
|
if history.memory then
|
||||||
|
system_prompt = system_prompt .. "\n\nThe previous summary is:\n\n" .. history.memory.content
|
||||||
|
entries = vim
|
||||||
|
.iter(entries)
|
||||||
|
:filter(function(entry) return entry.timestamp > history.memory.last_summarized_timestamp end)
|
||||||
|
:totable()
|
||||||
|
end
|
||||||
|
if #entries == 0 then
|
||||||
|
cb(history.memory)
|
||||||
|
return
|
||||||
|
end
|
||||||
|
local history_messages = Utils.history.entries_to_llm_messages(entries)
|
||||||
|
history_messages = vim.list_slice(history_messages, 1, 4)
|
||||||
|
if #history_messages == 0 then
|
||||||
|
cb(history.memory)
|
||||||
|
return
|
||||||
|
end
|
||||||
|
Utils.debug("summarize memory", #history_messages, history_messages[#history_messages].content)
|
||||||
|
local response_content = ""
|
||||||
|
local provider = Providers[Config.memory_summary_provider or Config.provider]
|
||||||
|
M.curl({
|
||||||
|
provider = provider,
|
||||||
|
prompt_opts = {
|
||||||
|
system_prompt = system_prompt,
|
||||||
|
messages = {
|
||||||
|
{ role = "user", content = vim.json.encode(history_messages) },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
handler_opts = {
|
||||||
|
on_start = function(_) end,
|
||||||
|
on_chunk = function(chunk)
|
||||||
|
if not chunk then return end
|
||||||
|
response_content = response_content .. chunk
|
||||||
|
end,
|
||||||
|
on_stop = function(stop_opts)
|
||||||
|
if stop_opts.error ~= nil then
|
||||||
|
Utils.error(string.format("summarize failed: %s", vim.inspect(stop_opts.error)))
|
||||||
|
return
|
||||||
|
end
|
||||||
|
if stop_opts.reason == "complete" then
|
||||||
|
response_content = Utils.trim_think_content(response_content)
|
||||||
|
local memory = {
|
||||||
|
content = response_content,
|
||||||
|
last_summarized_timestamp = entries[#entries].timestamp,
|
||||||
|
}
|
||||||
|
history.memory = memory
|
||||||
|
Path.history.save(bufnr, history)
|
||||||
|
cb(memory)
|
||||||
|
end
|
||||||
|
end,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
end
|
||||||
|
|
||||||
---@param opts AvanteGeneratePromptsOptions
|
---@param opts AvanteGeneratePromptsOptions
|
||||||
---@return AvantePromptOptions
|
---@return AvantePromptOptions
|
||||||
function M.generate_prompts(opts)
|
function M.generate_prompts(opts)
|
||||||
@@ -58,6 +123,7 @@ function M.generate_prompts(opts)
|
|||||||
diagnostics = opts.diagnostics,
|
diagnostics = opts.diagnostics,
|
||||||
system_info = system_info,
|
system_info = system_info,
|
||||||
model_name = provider.model or "unknown",
|
model_name = provider.model or "unknown",
|
||||||
|
memory = opts.memory,
|
||||||
}
|
}
|
||||||
|
|
||||||
local system_prompt = Path.prompts.render_mode(mode, template_opts)
|
local system_prompt = Path.prompts.render_mode(mode, template_opts)
|
||||||
@@ -80,6 +146,11 @@ function M.generate_prompts(opts)
|
|||||||
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
|
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
if opts.memory ~= nil and opts.memory ~= "" and opts.memory ~= "null" then
|
||||||
|
local memory = Path.prompts.render_file("_memory.avanterules", template_opts)
|
||||||
|
if memory ~= "" then table.insert(messages, { role = "user", content = memory }) end
|
||||||
|
end
|
||||||
|
|
||||||
if instructions then
|
if instructions then
|
||||||
if opts.use_xml_format then
|
if opts.use_xml_format then
|
||||||
table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
|
table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
|
||||||
@@ -150,90 +221,17 @@ function M.calculate_tokens(opts)
|
|||||||
return tokens
|
return tokens
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param opts AvanteLLMStreamOptions
|
---@param opts avante.CurlOpts
|
||||||
function M._stream(opts)
|
function M.curl(opts)
|
||||||
local provider = opts.provider or Providers[Config.provider]
|
local provider = opts.provider
|
||||||
|
local prompt_opts = opts.prompt_opts
|
||||||
|
local handler_opts = opts.handler_opts
|
||||||
|
|
||||||
---@cast provider AvanteProviderFunctor
|
---@type AvanteCurlOutput
|
||||||
|
local spec = provider:parse_curl_args(prompt_opts)
|
||||||
local prompt_opts = M.generate_prompts(opts)
|
|
||||||
|
|
||||||
---@type string
|
---@type string
|
||||||
local current_event_state = nil
|
local current_event_state = nil
|
||||||
|
|
||||||
---@type AvanteHandlerOptions
|
|
||||||
local handler_opts = {
|
|
||||||
on_start = opts.on_start,
|
|
||||||
on_chunk = opts.on_chunk,
|
|
||||||
on_stop = function(stop_opts)
|
|
||||||
---@param tool_use_list AvanteLLMToolUse[]
|
|
||||||
---@param tool_use_index integer
|
|
||||||
---@param tool_histories AvanteLLMToolHistory[]
|
|
||||||
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories)
|
|
||||||
if tool_use_index > #tool_use_list then
|
|
||||||
local new_opts = vim.tbl_deep_extend("force", opts, {
|
|
||||||
tool_histories = tool_histories,
|
|
||||||
})
|
|
||||||
return M._stream(new_opts)
|
|
||||||
end
|
|
||||||
local tool_use = tool_use_list[tool_use_index]
|
|
||||||
---@param result string | nil
|
|
||||||
---@param error string | nil
|
|
||||||
local function handle_tool_result(result, error)
|
|
||||||
local tool_result = {
|
|
||||||
tool_use_id = tool_use.id,
|
|
||||||
content = error ~= nil and error or result,
|
|
||||||
is_error = error ~= nil,
|
|
||||||
}
|
|
||||||
table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use })
|
|
||||||
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories)
|
|
||||||
end
|
|
||||||
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
|
|
||||||
local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result)
|
|
||||||
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
|
|
||||||
end
|
|
||||||
if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
|
|
||||||
local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
|
|
||||||
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
|
|
||||||
for _, tool_use in vim.spairs(stop_opts.tool_use_list) do
|
|
||||||
table.insert(sorted_tool_use_list, tool_use)
|
|
||||||
end
|
|
||||||
return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories)
|
|
||||||
end
|
|
||||||
if stop_opts.reason == "rate_limit" then
|
|
||||||
local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..."
|
|
||||||
opts.on_chunk("\n*[" .. msg .. "]*\n")
|
|
||||||
local timer = vim.loop.new_timer()
|
|
||||||
if timer then
|
|
||||||
local retry_after = stop_opts.retry_after
|
|
||||||
local function countdown()
|
|
||||||
timer:start(
|
|
||||||
1000,
|
|
||||||
0,
|
|
||||||
vim.schedule_wrap(function()
|
|
||||||
if retry_after > 0 then retry_after = retry_after - 1 end
|
|
||||||
local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..."
|
|
||||||
opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n")
|
|
||||||
countdown()
|
|
||||||
end)
|
|
||||||
)
|
|
||||||
end
|
|
||||||
countdown()
|
|
||||||
end
|
|
||||||
Utils.info("Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds", { title = "Avante" })
|
|
||||||
vim.defer_fn(function()
|
|
||||||
if timer then timer:stop() end
|
|
||||||
M._stream(opts)
|
|
||||||
end, stop_opts.retry_after * 1000)
|
|
||||||
return
|
|
||||||
end
|
|
||||||
return opts.on_stop(stop_opts)
|
|
||||||
end,
|
|
||||||
}
|
|
||||||
|
|
||||||
---@type AvanteCurlOutput
|
|
||||||
local spec = provider:parse_curl_args(provider, prompt_opts)
|
|
||||||
|
|
||||||
local resp_ctx = {}
|
local resp_ctx = {}
|
||||||
|
|
||||||
---@param line string
|
---@param line string
|
||||||
@@ -383,6 +381,91 @@ function M._stream(opts)
|
|||||||
return active_job
|
return active_job
|
||||||
end
|
end
|
||||||
|
|
||||||
|
---@param opts AvanteLLMStreamOptions
|
||||||
|
function M._stream(opts)
|
||||||
|
local provider = opts.provider or Providers[Config.provider]
|
||||||
|
|
||||||
|
---@cast provider AvanteProviderFunctor
|
||||||
|
|
||||||
|
local prompt_opts = M.generate_prompts(opts)
|
||||||
|
|
||||||
|
---@type AvanteHandlerOptions
|
||||||
|
local handler_opts = {
|
||||||
|
on_start = opts.on_start,
|
||||||
|
on_chunk = opts.on_chunk,
|
||||||
|
on_stop = function(stop_opts)
|
||||||
|
---@param tool_use_list AvanteLLMToolUse[]
|
||||||
|
---@param tool_use_index integer
|
||||||
|
---@param tool_histories AvanteLLMToolHistory[]
|
||||||
|
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories)
|
||||||
|
if tool_use_index > #tool_use_list then
|
||||||
|
local new_opts = vim.tbl_deep_extend("force", opts, {
|
||||||
|
tool_histories = tool_histories,
|
||||||
|
})
|
||||||
|
return M._stream(new_opts)
|
||||||
|
end
|
||||||
|
local tool_use = tool_use_list[tool_use_index]
|
||||||
|
---@param result string | nil
|
||||||
|
---@param error string | nil
|
||||||
|
local function handle_tool_result(result, error)
|
||||||
|
local tool_result = {
|
||||||
|
tool_use_id = tool_use.id,
|
||||||
|
content = error ~= nil and error or result,
|
||||||
|
is_error = error ~= nil,
|
||||||
|
}
|
||||||
|
table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use })
|
||||||
|
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories)
|
||||||
|
end
|
||||||
|
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
|
||||||
|
local result, error = LLMTools.process_tool_use(opts.tools, tool_use, opts.on_tool_log, handle_tool_result)
|
||||||
|
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
|
||||||
|
end
|
||||||
|
if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
|
||||||
|
local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
|
||||||
|
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
|
||||||
|
for _, tool_use in vim.spairs(stop_opts.tool_use_list) do
|
||||||
|
table.insert(sorted_tool_use_list, tool_use)
|
||||||
|
end
|
||||||
|
return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories)
|
||||||
|
end
|
||||||
|
if stop_opts.reason == "rate_limit" then
|
||||||
|
local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..."
|
||||||
|
opts.on_chunk("\n*[" .. msg .. "]*\n")
|
||||||
|
local timer = vim.loop.new_timer()
|
||||||
|
if timer then
|
||||||
|
local retry_after = stop_opts.retry_after
|
||||||
|
local function countdown()
|
||||||
|
timer:start(
|
||||||
|
1000,
|
||||||
|
0,
|
||||||
|
vim.schedule_wrap(function()
|
||||||
|
if retry_after > 0 then retry_after = retry_after - 1 end
|
||||||
|
local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..."
|
||||||
|
opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n")
|
||||||
|
countdown()
|
||||||
|
end)
|
||||||
|
)
|
||||||
|
end
|
||||||
|
countdown()
|
||||||
|
end
|
||||||
|
Utils.info("Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds", { title = "Avante" })
|
||||||
|
vim.defer_fn(function()
|
||||||
|
if timer then timer:stop() end
|
||||||
|
M._stream(opts)
|
||||||
|
end, stop_opts.retry_after * 1000)
|
||||||
|
return
|
||||||
|
end
|
||||||
|
return opts.on_stop(stop_opts)
|
||||||
|
end,
|
||||||
|
}
|
||||||
|
|
||||||
|
return M.curl({
|
||||||
|
provider = provider,
|
||||||
|
prompt_opts = prompt_opts,
|
||||||
|
handler_opts = handler_opts,
|
||||||
|
})
|
||||||
|
end
|
||||||
|
|
||||||
local function _merge_response(first_response, second_response, opts)
|
local function _merge_response(first_response, second_response, opts)
|
||||||
local prompt = "\n" .. Config.dual_boost.prompt
|
local prompt = "\n" .. Config.dual_boost.prompt
|
||||||
prompt = prompt
|
prompt = prompt
|
||||||
|
|||||||
@@ -5,58 +5,75 @@ local Path = require("plenary.path")
|
|||||||
local Scan = require("plenary.scandir")
|
local Scan = require("plenary.scandir")
|
||||||
local Config = require("avante.config")
|
local Config = require("avante.config")
|
||||||
|
|
||||||
---@class avante.ChatHistoryEntry
|
|
||||||
---@field timestamp string
|
|
||||||
---@field provider string
|
|
||||||
---@field model string
|
|
||||||
---@field request string
|
|
||||||
---@field response string
|
|
||||||
---@field original_response string
|
|
||||||
---@field selected_file {filepath: string}?
|
|
||||||
---@field selected_code {filetype: string, content: string}?
|
|
||||||
---@field reset_memory boolean?
|
|
||||||
---@field selected_filepaths string[] | nil
|
|
||||||
|
|
||||||
---@class avante.Path
|
---@class avante.Path
|
||||||
---@field history_path Path
|
|
||||||
---@field cache_path Path
|
|
||||||
local P = {}
|
local P = {}
|
||||||
|
|
||||||
local history_file_cache = LRUCache:new(12)
|
local history_file_cache = LRUCache:new(12)
|
||||||
|
|
||||||
-- History path
|
---@param bufnr integer | nil
|
||||||
local History = {}
|
---@return string dirname
|
||||||
|
local function generate_project_dirname_in_storage(bufnr)
|
||||||
-- Get a chat history file name given a buffer
|
|
||||||
---@param bufnr integer
|
|
||||||
---@return string
|
|
||||||
function History.filename(bufnr)
|
|
||||||
local project_root = Utils.root.get({
|
local project_root = Utils.root.get({
|
||||||
buf = bufnr,
|
buf = bufnr,
|
||||||
})
|
})
|
||||||
-- Replace path separators with double underscores
|
-- Replace path separators with double underscores
|
||||||
local path_with_separators = fn.substitute(project_root, "/", "__", "g")
|
local path_with_separators = fn.substitute(project_root, "/", "__", "g")
|
||||||
-- Replace other non-alphanumeric characters with single underscores
|
-- Replace other non-alphanumeric characters with single underscores
|
||||||
return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. ".json"
|
local dirname = fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g")
|
||||||
|
return tostring(Path:new("projects"):joinpath(dirname))
|
||||||
|
end
|
||||||
|
|
||||||
|
-- History path
|
||||||
|
local History = {}
|
||||||
|
|
||||||
|
-- Get a chat history file name given a buffer
|
||||||
|
---@param bufnr integer
|
||||||
|
---@param new boolean
|
||||||
|
---@return Path
|
||||||
|
function History.filepath(bufnr, new)
|
||||||
|
local dirname = generate_project_dirname_in_storage(bufnr)
|
||||||
|
local history_dir = Path:new(Config.history.storage_path):joinpath(dirname):joinpath("history")
|
||||||
|
if not history_dir:exists() then history_dir:mkdir({ parents = true }) end
|
||||||
|
local pattern = tostring(history_dir:joinpath("*.json"))
|
||||||
|
local files = vim.fn.glob(pattern, true, true)
|
||||||
|
local filename = #files .. ".json"
|
||||||
|
if #files > 0 and not new then filename = (#files - 1) .. ".json" end
|
||||||
|
return history_dir:joinpath(filename)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Returns the Path to the chat history file for the given buffer.
|
-- Returns the Path to the chat history file for the given buffer.
|
||||||
---@param bufnr integer
|
---@param bufnr integer
|
||||||
---@return Path
|
---@return Path
|
||||||
function History.get(bufnr) return Path:new(Config.history.storage_path):joinpath(History.filename(bufnr)) end
|
function History.get(bufnr) return History.filepath(bufnr, false) end
|
||||||
|
|
||||||
|
---@param bufnr integer
|
||||||
|
function History.new(bufnr)
|
||||||
|
local filepath = History.filepath(bufnr, true)
|
||||||
|
local history = {
|
||||||
|
title = "untitled",
|
||||||
|
timestamp = tostring(os.date("%Y-%m-%d %H:%M:%S")),
|
||||||
|
entries = {},
|
||||||
|
}
|
||||||
|
filepath:write(vim.json.encode(history), "w")
|
||||||
|
end
|
||||||
|
|
||||||
-- Loads the chat history for the given buffer.
|
-- Loads the chat history for the given buffer.
|
||||||
---@param bufnr integer
|
---@param bufnr integer
|
||||||
---@return avante.ChatHistoryEntry[]
|
---@return avante.ChatHistory
|
||||||
function History.load(bufnr)
|
function History.load(bufnr)
|
||||||
local history_file = History.get(bufnr)
|
local history_file = History.get(bufnr)
|
||||||
local cached_key = tostring(history_file:absolute())
|
local cached_key = tostring(history_file:absolute())
|
||||||
local cached_value = history_file_cache:get(cached_key)
|
local cached_value = history_file_cache:get(cached_key)
|
||||||
if cached_value ~= nil then return cached_value end
|
if cached_value ~= nil then return cached_value end
|
||||||
local value = {}
|
---@type avante.ChatHistory
|
||||||
|
local value = {
|
||||||
|
title = "untitled",
|
||||||
|
timestamp = tostring(os.date("%Y-%m-%d %H:%M:%S")),
|
||||||
|
entries = {},
|
||||||
|
}
|
||||||
if history_file:exists() then
|
if history_file:exists() then
|
||||||
local content = history_file:read()
|
local content = history_file:read()
|
||||||
value = content ~= nil and vim.json.decode(content) or {}
|
value = content ~= nil and vim.json.decode(content) or value
|
||||||
end
|
end
|
||||||
history_file_cache:set(cached_key, value)
|
history_file_cache:set(cached_key, value)
|
||||||
return value
|
return value
|
||||||
@@ -64,7 +81,7 @@ end
|
|||||||
|
|
||||||
-- Saves the chat history for the given buffer.
|
-- Saves the chat history for the given buffer.
|
||||||
---@param bufnr integer
|
---@param bufnr integer
|
||||||
---@param history avante.ChatHistoryEntry[]
|
---@param history avante.ChatHistory
|
||||||
History.save = vim.schedule_wrap(function(bufnr, history)
|
History.save = vim.schedule_wrap(function(bufnr, history)
|
||||||
local history_file = History.get(bufnr)
|
local history_file = History.get(bufnr)
|
||||||
local cached_key = tostring(history_file:absolute())
|
local cached_key = tostring(history_file:absolute())
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ M.parse_response = O.parse_response
|
|||||||
M.parse_response_without_stream = O.parse_response_without_stream
|
M.parse_response_without_stream = O.parse_response_without_stream
|
||||||
M.is_disable_stream = O.is_disable_stream
|
M.is_disable_stream = O.is_disable_stream
|
||||||
|
|
||||||
function M:parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(self)
|
||||||
|
|
||||||
local headers = {
|
local headers = {
|
||||||
["Content-Type"] = "application/json",
|
["Content-Type"] = "application/json",
|
||||||
@@ -27,9 +27,9 @@ function M:parse_curl_args(provider, prompt_opts)
|
|||||||
|
|
||||||
if P.env.require_api_key(provider_conf) then
|
if P.env.require_api_key(provider_conf) then
|
||||||
if provider_conf.entra then
|
if provider_conf.entra then
|
||||||
headers["Authorization"] = "Bearer " .. provider.parse_api_key()
|
headers["Authorization"] = "Bearer " .. self.parse_api_key()
|
||||||
else
|
else
|
||||||
headers["api-key"] = provider.parse_api_key()
|
headers["api-key"] = self.parse_api_key()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -51,13 +51,12 @@ function M:parse_response_without_stream(data, event_state, opts)
|
|||||||
vim.schedule(function() opts.on_stop({ reason = "complete" }) end)
|
vim.schedule(function() opts.on_stop({ reason = "complete" }) end)
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param provider AvanteBedrockProviderFunctor
|
|
||||||
---@param prompt_opts AvantePromptOptions
|
---@param prompt_opts AvantePromptOptions
|
||||||
---@return table
|
---@return table
|
||||||
function M:parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local base, body_opts = P.parse_config(provider)
|
local base, body_opts = P.parse_config(self)
|
||||||
|
|
||||||
local api_key = provider.parse_api_key()
|
local api_key = self.parse_api_key()
|
||||||
if api_key == nil then error("Cannot get the bedrock api key!") end
|
if api_key == nil then error("Cannot get the bedrock api key!") end
|
||||||
local parts = vim.split(api_key, ",")
|
local parts = vim.split(api_key, ",")
|
||||||
local aws_access_key_id = parts[1]
|
local aws_access_key_id = parts[1]
|
||||||
|
|||||||
@@ -262,11 +262,10 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param provider AvanteProviderFunctor
|
|
||||||
---@param prompt_opts AvantePromptOptions
|
---@param prompt_opts AvantePromptOptions
|
||||||
---@return table
|
---@return table
|
||||||
function M:parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(self)
|
||||||
local disable_tools = provider_conf.disable_tools or false
|
local disable_tools = provider_conf.disable_tools or false
|
||||||
|
|
||||||
local headers = {
|
local headers = {
|
||||||
@@ -275,7 +274,7 @@ function M:parse_curl_args(provider, prompt_opts)
|
|||||||
["anthropic-beta"] = "prompt-caching-2024-07-31",
|
["anthropic-beta"] = "prompt-caching-2024-07-31",
|
||||||
}
|
}
|
||||||
|
|
||||||
if P.env.require_api_key(provider_conf) then headers["x-api-key"] = provider.parse_api_key() end
|
if P.env.require_api_key(provider_conf) then headers["x-api-key"] = self.parse_api_key() end
|
||||||
|
|
||||||
local messages = self:parse_messages(prompt_opts)
|
local messages = self:parse_messages(prompt_opts)
|
||||||
|
|
||||||
|
|||||||
@@ -71,8 +71,8 @@ function M:parse_stream_data(ctx, data, opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function M:parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(self)
|
||||||
|
|
||||||
local headers = {
|
local headers = {
|
||||||
["Accept"] = "application/json",
|
["Accept"] = "application/json",
|
||||||
@@ -84,7 +84,7 @@ function M:parse_curl_args(provider, prompt_opts)
|
|||||||
.. "."
|
.. "."
|
||||||
.. vim.version().patch,
|
.. vim.version().patch,
|
||||||
}
|
}
|
||||||
if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end
|
if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. self.parse_api_key() end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
url = Utils.url_join(provider_conf.endpoint, "/chat"),
|
url = Utils.url_join(provider_conf.endpoint, "/chat"),
|
||||||
|
|||||||
@@ -215,12 +215,12 @@ M.parse_messages = OpenAI.parse_messages
|
|||||||
|
|
||||||
M.parse_response = OpenAI.parse_response
|
M.parse_response = OpenAI.parse_response
|
||||||
|
|
||||||
function M:parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
-- refresh token synchronously, only if it has expired
|
-- refresh token synchronously, only if it has expired
|
||||||
-- (this should rarely happen, as we refresh the token in the background)
|
-- (this should rarely happen, as we refresh the token in the background)
|
||||||
H.refresh_token(false, false)
|
H.refresh_token(false, false)
|
||||||
|
|
||||||
local provider_conf, request_body = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(self)
|
||||||
local disable_tools = provider_conf.disable_tools or false
|
local disable_tools = provider_conf.disable_tools or false
|
||||||
|
|
||||||
local tools = {}
|
local tools = {}
|
||||||
|
|||||||
@@ -83,8 +83,8 @@ function M:parse_response(ctx, data_stream, _, opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function M:parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(self)
|
||||||
|
|
||||||
request_body = vim.tbl_deep_extend("force", request_body, {
|
request_body = vim.tbl_deep_extend("force", request_body, {
|
||||||
generationConfig = {
|
generationConfig = {
|
||||||
@@ -95,7 +95,7 @@ function M:parse_curl_args(provider, prompt_opts)
|
|||||||
request_body.temperature = nil
|
request_body.temperature = nil
|
||||||
request_body.max_tokens = nil
|
request_body.max_tokens = nil
|
||||||
|
|
||||||
local api_key = provider.parse_api_key()
|
local api_key = self.parse_api_key()
|
||||||
if api_key == nil then error("Cannot get the gemini api key!") end
|
if api_key == nil then error("Cannot get the gemini api key!") end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -261,6 +261,13 @@ function M.setup()
|
|||||||
E.setup({ provider = cursor_applying_provider })
|
E.setup({ provider = cursor_applying_provider })
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
if Config.memory_summary_provider then
|
||||||
|
local memory_summary_provider = M[Config.memory_summary_provider]
|
||||||
|
if memory_summary_provider and memory_summary_provider ~= provider then
|
||||||
|
E.setup({ provider = memory_summary_provider })
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param provider Provider
|
---@param provider Provider
|
||||||
|
|||||||
@@ -219,8 +219,8 @@ function M:parse_response_without_stream(data, _, opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function M:parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(self)
|
||||||
local disable_tools = provider_conf.disable_tools or false
|
local disable_tools = provider_conf.disable_tools or false
|
||||||
|
|
||||||
local headers = {
|
local headers = {
|
||||||
@@ -228,7 +228,7 @@ function M:parse_curl_args(provider, prompt_opts)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if P.env.require_api_key(provider_conf) then
|
if P.env.require_api_key(provider_conf) then
|
||||||
local api_key = provider.parse_api_key()
|
local api_key = self.parse_api_key()
|
||||||
if api_key == nil then
|
if api_key == nil then
|
||||||
error(Config.provider .. " API key is not set, please set it in your environment variable or config file")
|
error(Config.provider .. " API key is not set, please set it in your environment variable or config file")
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -32,8 +32,8 @@ function M.parse_api_key()
|
|||||||
return direct_output
|
return direct_output
|
||||||
end
|
end
|
||||||
|
|
||||||
function M:parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(self)
|
||||||
local location = vim.fn.getenv("LOCATION")
|
local location = vim.fn.getenv("LOCATION")
|
||||||
local project_id = vim.fn.getenv("PROJECT_ID")
|
local project_id = vim.fn.getenv("PROJECT_ID")
|
||||||
local model_id = provider_conf.model or "default-model-id"
|
local model_id = provider_conf.model or "default-model-id"
|
||||||
|
|||||||
@@ -2073,11 +2073,11 @@ function Sidebar:get_layout()
|
|||||||
return vim.tbl_contains({ "left", "right" }, calculate_config_window_position()) and "vertical" or "horizontal"
|
return vim.tbl_contains({ "left", "right" }, calculate_config_window_position()) and "vertical" or "horizontal"
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param history avante.ChatHistoryEntry[]
|
---@param history avante.ChatHistory
|
||||||
---@return string
|
---@return string
|
||||||
function Sidebar:render_history_content(history)
|
function Sidebar:render_history_content(history)
|
||||||
local content = ""
|
local content = ""
|
||||||
for idx, entry in ipairs(history) do
|
for idx, entry in ipairs(history.entries) do
|
||||||
if entry.reset_memory then
|
if entry.reset_memory then
|
||||||
content = content .. "***MEMORY RESET***\n\n"
|
content = content .. "***MEMORY RESET***\n\n"
|
||||||
if idx < #history then content = content .. "-------\n\n" end
|
if idx < #history then content = content .. "-------\n\n" end
|
||||||
@@ -2146,11 +2146,11 @@ end
|
|||||||
function Sidebar:clear_history(args, cb)
|
function Sidebar:clear_history(args, cb)
|
||||||
local chat_history = Path.history.load(self.code.bufnr)
|
local chat_history = Path.history.load(self.code.bufnr)
|
||||||
if next(chat_history) ~= nil then
|
if next(chat_history) ~= nil then
|
||||||
chat_history = {}
|
chat_history.entries = {}
|
||||||
Path.history.save(self.code.bufnr, chat_history)
|
Path.history.save(self.code.bufnr, chat_history)
|
||||||
self:update_content(
|
self:update_content(
|
||||||
"Chat history cleared",
|
"Chat history cleared",
|
||||||
{ focus = false, scroll = false, callback = function() self:focus_input() end }
|
{ ignore_history = true, focus = false, scroll = false, callback = function() self:focus_input() end }
|
||||||
)
|
)
|
||||||
if cb then cb(args) end
|
if cb then cb(args) end
|
||||||
else
|
else
|
||||||
@@ -2161,6 +2161,16 @@ function Sidebar:clear_history(args, cb)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function Sidebar:new_chat(args, cb)
|
||||||
|
Path.history.new(self.code.bufnr)
|
||||||
|
Sidebar.reload_chat_history()
|
||||||
|
self:update_content(
|
||||||
|
"New chat",
|
||||||
|
{ ignore_history = true, focus = false, scroll = false, callback = function() self:focus_input() end }
|
||||||
|
)
|
||||||
|
if cb then cb(args) end
|
||||||
|
end
|
||||||
|
|
||||||
function Sidebar:reset_memory(args, cb)
|
function Sidebar:reset_memory(args, cb)
|
||||||
local chat_history = Path.history.load(self.code.bufnr)
|
local chat_history = Path.history.load(self.code.bufnr)
|
||||||
if next(chat_history) ~= nil then
|
if next(chat_history) ~= nil then
|
||||||
@@ -2176,6 +2186,7 @@ function Sidebar:reset_memory(args, cb)
|
|||||||
reset_memory = true,
|
reset_memory = true,
|
||||||
})
|
})
|
||||||
Path.history.save(self.code.bufnr, chat_history)
|
Path.history.save(self.code.bufnr, chat_history)
|
||||||
|
Sidebar.reload_chat_history()
|
||||||
local history_content = self:render_history_content(chat_history)
|
local history_content = self:render_history_content(chat_history)
|
||||||
self:update_content(history_content, {
|
self:update_content(history_content, {
|
||||||
focus = false,
|
focus = false,
|
||||||
@@ -2184,6 +2195,7 @@ function Sidebar:reset_memory(args, cb)
|
|||||||
})
|
})
|
||||||
if cb then cb(args) end
|
if cb then cb(args) end
|
||||||
else
|
else
|
||||||
|
Sidebar.reload_chat_history()
|
||||||
self:update_content(
|
self:update_content(
|
||||||
"Chat history is already empty",
|
"Chat history is already empty",
|
||||||
{ focus = false, scroll = false, callback = function() self:focus_input() end }
|
{ focus = false, scroll = false, callback = function() self:focus_input() end }
|
||||||
@@ -2191,7 +2203,7 @@ function Sidebar:reset_memory(args, cb)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
---@alias AvanteSlashCommandType "clear" | "help" | "lines" | "reset" | "commit"
|
---@alias AvanteSlashCommandType "clear" | "help" | "lines" | "reset" | "commit" | "new"
|
||||||
---@alias AvanteSlashCommandCallback fun(args: string, cb?: fun(args: string): nil): nil
|
---@alias AvanteSlashCommandCallback fun(args: string, cb?: fun(args: string): nil): nil
|
||||||
---@alias AvanteSlashCommand {description: string, command: AvanteSlashCommandType, details: string, shorthelp?: string, callback?: AvanteSlashCommandCallback}
|
---@alias AvanteSlashCommand {description: string, command: AvanteSlashCommandType, details: string, shorthelp?: string, callback?: AvanteSlashCommandCallback}
|
||||||
---@return AvanteSlashCommand[]
|
---@return AvanteSlashCommand[]
|
||||||
@@ -2211,6 +2223,7 @@ function Sidebar:get_commands()
|
|||||||
{ description = "Show help message", command = "help" },
|
{ description = "Show help message", command = "help" },
|
||||||
{ description = "Clear chat history", command = "clear" },
|
{ description = "Clear chat history", command = "clear" },
|
||||||
{ description = "Reset memory", command = "reset" },
|
{ description = "Reset memory", command = "reset" },
|
||||||
|
{ description = "New chat", command = "new" },
|
||||||
{
|
{
|
||||||
shorthelp = "Ask a question about specific lines",
|
shorthelp = "Ask a question about specific lines",
|
||||||
description = "/lines <start>-<end> <question>",
|
description = "/lines <start>-<end> <question>",
|
||||||
@@ -2228,6 +2241,7 @@ function Sidebar:get_commands()
|
|||||||
end,
|
end,
|
||||||
clear = function(args, cb) self:clear_history(args, cb) end,
|
clear = function(args, cb) self:clear_history(args, cb) end,
|
||||||
reset = function(args, cb) self:reset_memory(args, cb) end,
|
reset = function(args, cb) self:reset_memory(args, cb) end,
|
||||||
|
new = function(args, cb) self:new_chat(args, cb) end,
|
||||||
lines = function(args, cb)
|
lines = function(args, cb)
|
||||||
if cb then cb(args) end
|
if cb then cb(args) end
|
||||||
end,
|
end,
|
||||||
@@ -2298,6 +2312,8 @@ function Sidebar:create_input_container(opts)
|
|||||||
|
|
||||||
local chat_history = Path.history.load(self.code.bufnr)
|
local chat_history = Path.history.load(self.code.bufnr)
|
||||||
|
|
||||||
|
Sidebar.reload_chat_history = function() chat_history = Path.history.load(self.code.bufnr) end
|
||||||
|
|
||||||
local tools = vim.deepcopy(LLMTools.get_tools())
|
local tools = vim.deepcopy(LLMTools.get_tools())
|
||||||
table.insert(tools, {
|
table.insert(tools, {
|
||||||
name = "add_file_to_context",
|
name = "add_file_to_context",
|
||||||
@@ -2330,8 +2346,9 @@ function Sidebar:create_input_container(opts)
|
|||||||
})
|
})
|
||||||
|
|
||||||
---@param request string
|
---@param request string
|
||||||
---@return AvanteGeneratePromptsOptions
|
---@param summarize_memory boolean
|
||||||
local function get_generate_prompts_options(request)
|
---@param cb fun(opts: AvanteGeneratePromptsOptions): nil
|
||||||
|
local function get_generate_prompts_options(request, summarize_memory, cb)
|
||||||
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr })
|
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr })
|
||||||
|
|
||||||
local selected_code_content = nil
|
local selected_code_content = nil
|
||||||
@@ -2355,40 +2372,19 @@ function Sidebar:create_input_container(opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
local history_messages = {}
|
local entries = Utils.history.filter_active_entries(chat_history.entries)
|
||||||
for i = #chat_history, 1, -1 do
|
|
||||||
local entry = chat_history[i]
|
if chat_history.memory then
|
||||||
if entry.reset_memory then break end
|
entries = vim
|
||||||
if
|
.iter(entries)
|
||||||
entry.request == nil
|
:filter(function(entry) return entry.timestamp > chat_history.memory.last_summarized_timestamp end)
|
||||||
or entry.original_response == nil
|
:totable()
|
||||||
or entry.request == ""
|
|
||||||
or entry.original_response == ""
|
|
||||||
then
|
|
||||||
break
|
|
||||||
end
|
|
||||||
table.insert(
|
|
||||||
history_messages,
|
|
||||||
1,
|
|
||||||
{ role = "assistant", content = Utils.trim_think_content(entry.original_response) }
|
|
||||||
)
|
|
||||||
local user_content = ""
|
|
||||||
if entry.selected_file ~= nil then
|
|
||||||
user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n"
|
|
||||||
end
|
|
||||||
if entry.selected_code ~= nil then
|
|
||||||
user_content = user_content
|
|
||||||
.. "SELECTED CODE:\n\n```"
|
|
||||||
.. entry.selected_code.filetype
|
|
||||||
.. "\n"
|
|
||||||
.. entry.selected_code.content
|
|
||||||
.. "\n```\n\n"
|
|
||||||
end
|
|
||||||
user_content = user_content .. "USER PROMPT:\n\n" .. entry.request
|
|
||||||
table.insert(history_messages, 1, { role = "user", content = user_content })
|
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
local history_messages = Utils.history.entries_to_llm_messages(entries)
|
||||||
|
|
||||||
|
---@type AvanteGeneratePromptsOptions
|
||||||
|
local prompts_opts = {
|
||||||
ask = opts.ask or true,
|
ask = opts.ask or true,
|
||||||
project_context = vim.json.encode(project_context),
|
project_context = vim.json.encode(project_context),
|
||||||
selected_files = selected_files_contents,
|
selected_files = selected_files_contents,
|
||||||
@@ -2400,6 +2396,20 @@ function Sidebar:create_input_container(opts)
|
|||||||
mode = Config.behaviour.enable_cursor_planning_mode and "cursor-planning" or "planning",
|
mode = Config.behaviour.enable_cursor_planning_mode and "cursor-planning" or "planning",
|
||||||
tools = tools,
|
tools = tools,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if chat_history.memory then prompts_opts.memory = chat_history.memory.content end
|
||||||
|
|
||||||
|
if not summarize_memory or #history_messages < 8 then
|
||||||
|
cb(prompts_opts)
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
|
prompts_opts.history_messages = vim.list_slice(prompts_opts.history_messages, 5)
|
||||||
|
|
||||||
|
Llm.summarize_memory(self.code.bufnr, chat_history, function(memory)
|
||||||
|
if memory then prompts_opts.memory = memory.content end
|
||||||
|
cb(prompts_opts)
|
||||||
|
end)
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param request string
|
---@param request string
|
||||||
@@ -2584,7 +2594,8 @@ function Sidebar:create_input_container(opts)
|
|||||||
end, 0)
|
end, 0)
|
||||||
|
|
||||||
-- Save chat history
|
-- Save chat history
|
||||||
table.insert(chat_history or {}, {
|
chat_history.entries = chat_history.entries or {}
|
||||||
|
table.insert(chat_history.entries, {
|
||||||
timestamp = timestamp,
|
timestamp = timestamp,
|
||||||
provider = Config.provider,
|
provider = Config.provider,
|
||||||
model = model,
|
model = model,
|
||||||
@@ -2597,17 +2608,18 @@ function Sidebar:create_input_container(opts)
|
|||||||
Path.history.save(self.code.bufnr, chat_history)
|
Path.history.save(self.code.bufnr, chat_history)
|
||||||
end
|
end
|
||||||
|
|
||||||
local generate_prompts_options = get_generate_prompts_options(request)
|
get_generate_prompts_options(request, true, function(generate_prompts_options)
|
||||||
---@type AvanteLLMStreamOptions
|
---@type AvanteLLMStreamOptions
|
||||||
---@diagnostic disable-next-line: assign-type-mismatch
|
---@diagnostic disable-next-line: assign-type-mismatch
|
||||||
local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, {
|
local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, {
|
||||||
on_start = on_start,
|
on_start = on_start,
|
||||||
on_chunk = on_chunk,
|
on_chunk = on_chunk,
|
||||||
on_stop = on_stop,
|
on_stop = on_stop,
|
||||||
on_tool_log = on_tool_log,
|
on_tool_log = on_tool_log,
|
||||||
})
|
})
|
||||||
|
|
||||||
Llm.stream(stream_options)
|
Llm.stream(stream_options)
|
||||||
|
end)
|
||||||
end
|
end
|
||||||
|
|
||||||
local function get_position()
|
local function get_position()
|
||||||
@@ -2762,37 +2774,43 @@ function Sidebar:create_input_container(opts)
|
|||||||
local hint_text = (fn.mode() ~= "i" and Config.mappings.submit.normal or Config.mappings.submit.insert)
|
local hint_text = (fn.mode() ~= "i" and Config.mappings.submit.normal or Config.mappings.submit.insert)
|
||||||
.. ": submit"
|
.. ": submit"
|
||||||
|
|
||||||
if Config.behaviour.enable_token_counting then
|
local function show()
|
||||||
local input_value = table.concat(api.nvim_buf_get_lines(self.input_container.bufnr, 0, -1, false), "\n")
|
local buf = api.nvim_create_buf(false, true)
|
||||||
local generate_prompts_options = get_generate_prompts_options(input_value)
|
api.nvim_buf_set_lines(buf, 0, -1, false, { hint_text })
|
||||||
local tokens = Llm.calculate_tokens(generate_prompts_options)
|
api.nvim_buf_add_highlight(buf, 0, "AvantePopupHint", 0, 0, -1)
|
||||||
hint_text = "Tokens: " .. tostring(tokens) .. "; " .. hint_text
|
|
||||||
|
-- Get the current window size
|
||||||
|
local win_width = api.nvim_win_get_width(self.input_container.winid)
|
||||||
|
local width = #hint_text
|
||||||
|
|
||||||
|
-- Set the floating window options
|
||||||
|
local win_opts = {
|
||||||
|
relative = "win",
|
||||||
|
win = self.input_container.winid,
|
||||||
|
width = width,
|
||||||
|
height = 1,
|
||||||
|
row = get_float_window_row(),
|
||||||
|
col = math.max(win_width - width, 0), -- Display in the bottom right corner
|
||||||
|
style = "minimal",
|
||||||
|
border = "none",
|
||||||
|
focusable = false,
|
||||||
|
zindex = 100,
|
||||||
|
}
|
||||||
|
|
||||||
|
-- Create the floating window
|
||||||
|
hint_window = api.nvim_open_win(buf, false, win_opts)
|
||||||
end
|
end
|
||||||
|
|
||||||
local buf = api.nvim_create_buf(false, true)
|
if Config.behaviour.enable_token_counting then
|
||||||
api.nvim_buf_set_lines(buf, 0, -1, false, { hint_text })
|
local input_value = table.concat(api.nvim_buf_get_lines(self.input_container.bufnr, 0, -1, false), "\n")
|
||||||
api.nvim_buf_add_highlight(buf, 0, "AvantePopupHint", 0, 0, -1)
|
get_generate_prompts_options(input_value, false, function(generate_prompts_options)
|
||||||
|
local tokens = Llm.calculate_tokens(generate_prompts_options)
|
||||||
-- Get the current window size
|
hint_text = "Tokens: " .. tostring(tokens) .. "; " .. hint_text
|
||||||
local win_width = api.nvim_win_get_width(self.input_container.winid)
|
show()
|
||||||
local width = #hint_text
|
end)
|
||||||
|
else
|
||||||
-- Set the floating window options
|
show()
|
||||||
local win_opts = {
|
end
|
||||||
relative = "win",
|
|
||||||
win = self.input_container.winid,
|
|
||||||
width = width,
|
|
||||||
height = 1,
|
|
||||||
row = get_float_window_row(),
|
|
||||||
col = math.max(win_width - width, 0), -- Display in the bottom right corner
|
|
||||||
style = "minimal",
|
|
||||||
border = "none",
|
|
||||||
focusable = false,
|
|
||||||
zindex = 100,
|
|
||||||
}
|
|
||||||
|
|
||||||
-- Create the floating window
|
|
||||||
hint_window = api.nvim_open_win(buf, false, win_opts)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
api.nvim_create_autocmd({ "TextChanged", "TextChangedI", "VimResized" }, {
|
api.nvim_create_autocmd({ "TextChanged", "TextChangedI", "VimResized" }, {
|
||||||
|
|||||||
5
lua/avante/templates/_memory.avanterules
Normal file
5
lua/avante/templates/_memory.avanterules
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
{%- if memory -%}
|
||||||
|
<memory>
|
||||||
|
{{memory}}
|
||||||
|
</memory>
|
||||||
|
{%- endif %}
|
||||||
@@ -190,13 +190,9 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@alias AvanteMessagesParser fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): AvanteChatMessage[]
|
---@alias AvanteMessagesParser fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): AvanteChatMessage[]
|
||||||
---
|
---
|
||||||
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil}
|
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil}
|
||||||
---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, provider: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput
|
---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput
|
||||||
---
|
---
|
||||||
---@class AvanteResponseParserOptions
|
---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteHandlerOptions): nil
|
||||||
---@field on_start AvanteLLMStartCallback
|
|
||||||
---@field on_chunk AvanteLLMChunkCallback
|
|
||||||
---@field on_stop AvanteLLMStopCallback
|
|
||||||
---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteResponseParserOptions): nil
|
|
||||||
---
|
---
|
||||||
---@class AvanteDefaultBaseProvider: table<string, any>
|
---@class AvanteDefaultBaseProvider: table<string, any>
|
||||||
---@field endpoint? string
|
---@field endpoint? string
|
||||||
@@ -305,6 +301,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field selected_files AvanteSelectedFiles[] | nil
|
---@field selected_files AvanteSelectedFiles[] | nil
|
||||||
---@field diagnostics string | nil
|
---@field diagnostics string | nil
|
||||||
---@field history_messages AvanteLLMMessage[] | nil
|
---@field history_messages AvanteLLMMessage[] | nil
|
||||||
|
---@field memory string | nil
|
||||||
---
|
---
|
||||||
---@class AvanteGeneratePromptsOptions: AvanteTemplateOptions
|
---@class AvanteGeneratePromptsOptions: AvanteTemplateOptions
|
||||||
---@field ask boolean
|
---@field ask boolean
|
||||||
@@ -358,3 +355,35 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field description string
|
---@field description string
|
||||||
---@field type 'string' | 'string[]' | 'boolean'
|
---@field type 'string' | 'string[]' | 'boolean'
|
||||||
---@field optional? boolean
|
---@field optional? boolean
|
||||||
|
---
|
||||||
|
---@class avante.ChatHistoryEntry
|
||||||
|
---@field timestamp string
|
||||||
|
---@field provider string
|
||||||
|
---@field model string
|
||||||
|
---@field request string
|
||||||
|
---@field response string
|
||||||
|
---@field original_response string
|
||||||
|
---@field selected_file {filepath: string}?
|
||||||
|
---@field selected_code {filetype: string, content: string}?
|
||||||
|
---@field reset_memory boolean?
|
||||||
|
---@field selected_filepaths string[] | nil
|
||||||
|
---
|
||||||
|
---@class avante.ChatHistory
|
||||||
|
---@field title string
|
||||||
|
---@field timestamp string
|
||||||
|
---@field entries avante.ChatHistoryEntry[]
|
||||||
|
---@field memory avante.ChatMemory | nil
|
||||||
|
---
|
||||||
|
---@class avante.ChatMemory
|
||||||
|
---@field content string
|
||||||
|
---@field last_summarized_timestamp string
|
||||||
|
|
||||||
|
---@class avante.Path
|
||||||
|
---@field history_path Path
|
||||||
|
---@field cache_path Path
|
||||||
|
---
|
||||||
|
---@class avante.CurlOpts
|
||||||
|
---@field provider AvanteProviderFunctor
|
||||||
|
---@field prompt_opts AvantePromptOptions
|
||||||
|
---@field handler_opts AvanteHandlerOptions
|
||||||
|
---
|
||||||
|
|||||||
52
lua/avante/utils/history.lua
Normal file
52
lua/avante/utils/history.lua
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
local Utils = require("avante.utils")
|
||||||
|
|
||||||
|
---@class avante.utils.history
|
||||||
|
local M = {}
|
||||||
|
|
||||||
|
---@param entries avante.ChatHistoryEntry[]
|
||||||
|
---@return avante.ChatHistoryEntry[]
|
||||||
|
function M.filter_active_entries(entries)
|
||||||
|
local entries_ = {}
|
||||||
|
|
||||||
|
for i = #entries, 1, -1 do
|
||||||
|
local entry = entries[i]
|
||||||
|
if entry.reset_memory then break end
|
||||||
|
if
|
||||||
|
entry.request == nil
|
||||||
|
or entry.original_response == nil
|
||||||
|
or entry.request == ""
|
||||||
|
or entry.original_response == ""
|
||||||
|
then
|
||||||
|
break
|
||||||
|
end
|
||||||
|
table.insert(entries_, 1, entry)
|
||||||
|
end
|
||||||
|
|
||||||
|
return entries_
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param entries avante.ChatHistoryEntry[]
|
||||||
|
---@return AvanteLLMMessage[]
|
||||||
|
function M.entries_to_llm_messages(entries)
|
||||||
|
local messages = {}
|
||||||
|
for _, entry in ipairs(entries) do
|
||||||
|
local user_content = ""
|
||||||
|
if entry.selected_file ~= nil then
|
||||||
|
user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n"
|
||||||
|
end
|
||||||
|
if entry.selected_code ~= nil then
|
||||||
|
user_content = user_content
|
||||||
|
.. "SELECTED CODE:\n\n```"
|
||||||
|
.. entry.selected_code.filetype
|
||||||
|
.. "\n"
|
||||||
|
.. entry.selected_code.content
|
||||||
|
.. "\n```\n\n"
|
||||||
|
end
|
||||||
|
user_content = user_content .. "USER PROMPT:\n\n" .. entry.request
|
||||||
|
table.insert(messages, { role = "user", content = user_content })
|
||||||
|
table.insert(messages, { role = "assistant", content = Utils.trim_think_content(entry.original_response) })
|
||||||
|
end
|
||||||
|
return messages
|
||||||
|
end
|
||||||
|
|
||||||
|
return M
|
||||||
@@ -8,6 +8,7 @@ local lsp = vim.lsp
|
|||||||
---@field tokens avante.utils.tokens
|
---@field tokens avante.utils.tokens
|
||||||
---@field root avante.utils.root
|
---@field root avante.utils.root
|
||||||
---@field file avante.utils.file
|
---@field file avante.utils.file
|
||||||
|
---@field history avante.utils.history
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
setmetatable(M, {
|
setmetatable(M, {
|
||||||
|
|||||||
Reference in New Issue
Block a user