refactor: chat history based on project (#867)
This commit is contained in:
@@ -1,45 +1,52 @@
|
||||
local fn, api = vim.fn, vim.api
|
||||
local fn = vim.fn
|
||||
local Utils = require("avante.utils")
|
||||
local LRUCache = require("avante.utils.lru_cache")
|
||||
local Path = require("plenary.path")
|
||||
local Scan = require("plenary.scandir")
|
||||
local Config = require("avante.config")
|
||||
|
||||
---@class avante.ChatHistoryEntry
|
||||
---@field timestamp string
|
||||
---@field provider string
|
||||
---@field model string
|
||||
---@field request string
|
||||
---@field response string
|
||||
---@field original_response string
|
||||
---@field selected_file {filepath: string}?
|
||||
---@field selected_code {filetype: string, content: string}?
|
||||
---@field reset_memory boolean?
|
||||
|
||||
---@class avante.Path
|
||||
---@field history_path Path
|
||||
---@field cache_path Path
|
||||
local P = {}
|
||||
|
||||
-- Helpers
|
||||
local H = {}
|
||||
|
||||
-- Get a chat history file name given a buffer
|
||||
---@param bufnr integer
|
||||
---@return string
|
||||
H.filename = function(bufnr)
|
||||
local code_buf_name = api.nvim_buf_get_name(bufnr)
|
||||
-- Replace path separators with double underscores
|
||||
local path_with_separators = fn.substitute(code_buf_name, "/", "__", "g")
|
||||
-- Replace other non-alphanumeric characters with single underscores
|
||||
return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. ".json"
|
||||
end
|
||||
|
||||
-- Given a mode, return the file name for the custom prompt.
|
||||
---@param mode LlmMode
|
||||
H.get_mode_file = function(mode) return string.format("custom.%s.avanterules", mode) end
|
||||
|
||||
local history_file_cache = LRUCache:new(12)
|
||||
|
||||
-- History path
|
||||
local History = {}
|
||||
|
||||
-- Get a chat history file name given a buffer
|
||||
---@param bufnr integer
|
||||
---@return string
|
||||
History.filename = function(bufnr)
|
||||
local project_root = Utils.root.get({
|
||||
buf = bufnr,
|
||||
})
|
||||
-- Replace path separators with double underscores
|
||||
local path_with_separators = fn.substitute(project_root, "/", "__", "g")
|
||||
-- Replace other non-alphanumeric characters with single underscores
|
||||
return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. ".json"
|
||||
end
|
||||
|
||||
-- Returns the Path to the chat history file for the given buffer.
|
||||
---@param bufnr integer
|
||||
---@return Path
|
||||
History.get = function(bufnr) return Path:new(Config.history.storage_path):joinpath(H.filename(bufnr)) end
|
||||
History.get = function(bufnr) return Path:new(Config.history.storage_path):joinpath(History.filename(bufnr)) end
|
||||
|
||||
-- Loads the chat history for the given buffer.
|
||||
---@param bufnr integer
|
||||
---@return avante.ChatHistoryEntry[]
|
||||
History.load = function(bufnr)
|
||||
local history_file = History.get(bufnr)
|
||||
local cached_key = tostring(history_file:absolute())
|
||||
@@ -56,7 +63,7 @@ end
|
||||
|
||||
-- Saves the chat history for the given buffer.
|
||||
---@param bufnr integer
|
||||
---@param history table
|
||||
---@param history avante.ChatHistoryEntry[]
|
||||
History.save = vim.schedule_wrap(function(bufnr, history)
|
||||
local history_file = History.get(bufnr)
|
||||
local cached_key = tostring(history_file:absolute())
|
||||
@@ -69,6 +76,10 @@ P.history = History
|
||||
-- Prompt path
|
||||
local Prompt = {}
|
||||
|
||||
-- Given a mode, return the file name for the custom prompt.
|
||||
---@param mode LlmMode
|
||||
Prompt.get_mode_file = function(mode) return string.format("custom.%s.avanterules", mode) end
|
||||
|
||||
---@class AvanteTemplates
|
||||
---@field initialize fun(directory: string): nil
|
||||
---@field render fun(template: string, context: TemplateOptions): string
|
||||
@@ -110,7 +121,7 @@ Prompt.get = function(bufnr)
|
||||
:copy({ destination = cache_prompt_dir, recursive = true })
|
||||
|
||||
vim.iter(Prompt.templates):filter(function(_, v) return v ~= nil end):each(function(k, v)
|
||||
local f = cache_prompt_dir:joinpath(H.get_mode_file(k))
|
||||
local f = cache_prompt_dir:joinpath(Prompt.get_mode_file(k))
|
||||
f:write(v, "w")
|
||||
end)
|
||||
|
||||
@@ -119,7 +130,7 @@ end
|
||||
|
||||
---@param mode LlmMode
|
||||
Prompt.get_file = function(mode)
|
||||
if Prompt.templates[mode] ~= nil then return H.get_mode_file(mode) end
|
||||
if Prompt.templates[mode] ~= nil then return Prompt.get_mode_file(mode) end
|
||||
return string.format("%s.avanterules", mode)
|
||||
end
|
||||
|
||||
|
||||
Reference in New Issue
Block a user