1720 lines
54 KiB
Lua
1720 lines
54 KiB
Lua
local api = vim.api
|
||
local fn = vim.fn
|
||
local lsp = vim.lsp
|
||
|
||
local LRUCache = require("avante.utils.lru_cache")
|
||
local diff2search_replace = require("avante.utils.diff2search_replace")
|
||
|
||
---@class avante.utils: LazyUtilCore
|
||
---@field tokens avante.utils.tokens
|
||
---@field root avante.utils.root
|
||
---@field file avante.utils.file
|
||
---@field path avante.utils.path
|
||
---@field environment avante.utils.environment
|
||
---@field lsp avante.utils.lsp
|
||
---@field logger avante.utils.promptLogger
|
||
local M = {}
|
||
|
||
-- Lazy member resolution for the module table. A missing key is first looked
-- up in lazy.nvim's `lazy.core.util` (when that module can be required);
-- otherwise the submodule `avante.utils.<key>` is required and cached on the
-- table so later accesses skip this metamethod.
setmetatable(M, {
  __index = function(self, key)
    local has_lazy, lazy_util = pcall(require, "lazy.core.util")
    if has_lazy then
      local helper = lazy_util[key]
      if helper then return helper end
    end

    ---@diagnostic disable-next-line: no-unknown
    self[key] = require("avante.utils." .. key)
    return self[key]
  end,
})
|
||
|
||
---Check if a plugin is installed.
---When `lazy.core.config` can be required, the plugin is looked up in its
---`plugins` table; otherwise the result is whether `require(plugin)` succeeds.
---@param plugin string
---@return boolean
function M.has(plugin)
  local has_lazy, lazy_config = pcall(require, "lazy.core.config")
  if not has_lazy then
    local loadable = pcall(require, plugin)
    return loadable
  end
  return lazy_config.plugins[plugin] ~= nil
end
|
||
|
||
---Thin wrapper that forwards to `M.path.is_win()`.
function M.is_win()
  local on_windows = M.path.is_win()
  return on_windows
end
|
||
|
||
M.path_sep = M.path.SEP
|
||
|
||
---Map the `sysname` reported by `vim.uv.os_uname()` to a lowercase OS name.
---Raises an error for any sysname other than Linux, Darwin or Windows_NT.
---@return "linux" | "darwin" | "windows"
function M.get_os_name()
  local sysname = vim.uv.os_uname().sysname
  local known = { Linux = "linux", Darwin = "darwin", Windows_NT = "windows" }
  local os_name = known[sysname]
  if os_name == nil then error("Unsupported operating system: " .. sysname) end
  return os_name
end
|
||
|
||
---Build a multi-line, human-readable summary of the host environment:
---platform (sysname-release-machine), shell, language, current date, the
---project root (when `M.root.get()` returns one) and whether the current
---working directory contains a `.git` directory.
---@return string
function M.get_system_info()
  -- `vim.uv` is what the rest of this file uses; keep `vim.loop` as a fallback.
  local uv = vim.uv or vim.loop
  -- Query uname once instead of once per field.
  local uname = uv.os_uname()

  -- $SHELL / $LANG may be unset (typically on Windows); fall back so the
  -- report never contains a literal "nil".
  local shell = os.getenv("SHELL") or vim.o.shell
  local lang = os.getenv("LANG") or "unknown"

  local res = string.format(
    "- Platform: %s-%s-%s\n- Shell: %s\n- Language: %s\n- Current date: %s",
    uname.sysname,
    uname.release,
    uname.machine,
    shell,
    lang,
    os.date("%Y-%m-%d")
  )

  local project_root = M.root.get()
  if project_root then res = res .. string.format("\n- Project root: %s", project_root) end

  -- Relative check: looks for `.git` in the current working directory.
  local is_git_repo = vim.fn.isdirectory(".git") == 1
  if is_git_repo then res = res .. "\n- The user is operating inside a git repository" end

  return res
end
|
||
|
||
---Wrap a command line so it can be executed under the user's shell.
--- * `vim.o.shell` is PowerShell / pwsh: the command is returned unchanged.
--- * otherwise on win32: the command is wrapped in `powershell.exe -NoProfile
---   -Command "..."`, with double quotes replaced by single quotes.
--- * elsewhere: `<shell_cmd> <shell-escaped command>`, defaulting to `sh -c`.
---@param input_cmd string
---@param shell_cmd string?
---@return string
local function get_cmd_for_shell(input_cmd, shell_cmd)
  local current_shell = vim.o.shell:lower()

  if current_shell:match("powershell") or current_shell:match("pwsh") then return input_cmd end

  if fn.has("win32") > 0 then
    local single_quoted = input_cmd:gsub('"', "'")
    return 'powershell.exe -NoProfile -Command "' .. single_quoted .. '"'
  end

  local runner = shell_cmd or "sh -c"
  return runner .. " " .. fn.shellescape(input_cmd)
end
|
||
|
||
--- Run the given shell command synchronously through `fn.system`.
--- The returned table carries only `stdout` (the command output) and `code`
--- (the value of `vim.v.shell_error` after the call).
---@param input_cmd string
---@param shell_cmd string?
---@return vim.SystemCompleted
function M.shell_run(input_cmd, shell_cmd)
  local stdout = fn.system(get_cmd_for_shell(input_cmd, shell_cmd))
  local exit_code = vim.v.shell_error
  return { stdout = stdout, code = exit_code }
end
|
||
|
||
---Run a shell command asynchronously with `fn.jobstart`.
---stdout and stderr chunks are appended to one list and joined with "\n" when
---the job exits. `on_complete` is invoked exactly once: on exit, on timeout
---(message "Command timed out after <n>ms", code 124), or immediately when the
---job could not be started.
---@param input_cmd string
---@param shell_cmd string?
---@param on_complete fun(output: string, code: integer)
---@param cwd? string
---@param timeout? integer Timeout in milliseconds
function M.shell_run_async(input_cmd, shell_cmd, on_complete, cwd, timeout)
  local cmd = get_cmd_for_shell(input_cmd, shell_cmd)
  ---@type string[]
  local output = {}
  local timer = nil
  local completed = false

  -- Guarantees on_complete runs once even if exit and timeout race.
  local function complete_once(out, code)
    if completed then return end
    completed = true

    -- Release the timeout timer, if one was armed.
    if timer then
      timer:stop()
      timer:close()
      timer = nil
    end

    on_complete(out, code)
  end

  -- Shared collector for stdout and stderr.
  local function collect(_, data)
    if not data then return end
    vim.list_extend(output, data)
  end

  local job_id = fn.jobstart(cmd, {
    on_stdout = collect,
    on_stderr = collect,
    on_exit = function(_, exit_code) complete_once(table.concat(output, "\n"), exit_code) end,
    cwd = cwd,
  })

  -- jobstart() returns 0 (invalid arguments) or -1 (command not executable)
  -- on failure; on_exit never fires then, so report the failure here instead
  -- of leaving the caller waiting forever.
  if job_id <= 0 then
    complete_once("Failed to start command: " .. cmd, job_id == -1 and 127 or 1)
    return
  end

  if timeout and timeout > 0 then
    timer = (vim.uv or vim.loop).new_timer()
    if timer then
      timer:start(timeout, 0, function()
        -- Timer callbacks run in a fast context; hop to the main loop
        -- before calling Vimscript functions.
        vim.schedule(function()
          if completed then return end
          fn.jobstop(job_id)
          complete_once("Command timed out after " .. timeout .. "ms", 124)
        end)
      end)
    end
  end
end
|
||
|
||
---@see https://github.com/LazyVim/LazyVim/blob/main/lua/lazyvim/util/toggle.lua
|
||
---
|
||
---@alias _ToggleSet fun(state: boolean): nil
|
||
---@alias _ToggleGet fun(): boolean
|
||
---
|
||
---@class ToggleBind
|
||
---@field name string
|
||
---@field set _ToggleSet
|
||
---@field get _ToggleGet
|
||
---
|
||
---@class ToggleBind.wrap: ToggleBind
|
||
---@operator call:boolean
|
||
|
||
---Make a toggle table callable. Calling it flips the state through
---`toggle.set(not toggle.get())`, reports the new state (info when enabled,
---warn when disabled) and returns that state.
---@param toggle ToggleBind
function M.toggle_wrap(toggle)
  local function flip()
    toggle.set(not toggle.get())
    local enabled = toggle.get()
    if enabled then
      M.info("enabled: " .. toggle.name)
    else
      M.warn("disabled: " .. toggle.name)
    end
    return enabled
  end

  return setmetatable(toggle, { __call = flip }) --[[@as ToggleBind.wrap]]
end
|
||
|
||
-- Wrapper around vim.keymap.set that will
-- not create a keymap if a lazy key handler exists.
-- It will also set `silent` to true by default.
-- When lazy.nvim is not available the call is forwarded to vim.keymap.set
-- unchanged.
--
---@param mode string|string[] Mode short-name, see |nvim_set_keymap()|.
---                            Can also be list of modes to create mapping on multiple modes.
---@param lhs string           Left-hand side |{lhs}| of the mapping.
---@param rhs string|function  Right-hand side |{rhs}| of the mapping, can be a Lua function.
---
---@param opts? vim.keymap.set.Opts
function M.safe_keymap_set(mode, lhs, rhs, opts)
  ---@type boolean
  local ok
  ---@module "lazy.core.handler"
  local H

  ok, H = pcall(require, "lazy.core.handler")
  if not ok then
    M.debug("lazy.nvim is not available. Avante will use vim.keymap.set")
    vim.keymap.set(mode, lhs, rhs, opts)
    return
  end

  local Keys = H.handlers.keys
  ---@cast Keys LazyKeysHandler
  local modes = type(mode) == "string" and { mode } or mode
  ---@cast modes -string

  -- Keep only the modes for which lazy.nvim has no keys handler on `lhs`.
  ---@param m string
  ---@diagnostic disable-next-line: undefined-field
  modes = vim.tbl_filter(function(m) return not (Keys and Keys.have and Keys:have(lhs, m)) end, modes)

  -- don't create keymap if a lazy keys handler exists
  if #modes == 0 then return end

  -- Work on a shallow copy so the caller's table is not mutated.
  local set_opts = vim.tbl_extend("force", {}, opts or {})
  set_opts.silent = set_opts.silent ~= false
  if set_opts.remap and not vim.g.vscode then
    ---@diagnostic disable-next-line: no-unknown
    set_opts.remap = nil
  end

  -- Use the filtered list: passing the original `mode` here would recreate
  -- the mapping in modes that lazy.nvim already handles.
  vim.keymap.set(modes, lhs, rhs, set_opts)
end
|
||
|
||
---Strip one exact occurrence of `opts.suffix` from the end and/or
---`opts.prefix` from the start of `str`. The suffix is handled first.
---Without `opts` the string is returned unchanged.
---@param str string
---@param opts? {suffix?: string, prefix?: string}
function M.trim(str, opts)
  if not opts then return str end

  local result = str
  local suffix, prefix = opts.suffix, opts.prefix

  if suffix and result:sub(#result - #suffix + 1) == suffix then
    result = result:sub(1, #result - #suffix)
  end
  if prefix and result:sub(1, #prefix) == prefix then
    result = result:sub(#prefix + 1)
  end

  return result
end
|
||
|
||
---Whether the editor is currently in characterwise ("v"), linewise ("V") or
---blockwise (CTRL-V) Visual mode.
---@return boolean
function M.in_visual_mode()
  local current_mode = fn.mode()
  -- mode() reports blockwise Visual as the single control character CTRL-V
  -- (byte 22). Comparing against "" could never match, so block selections
  -- were not recognised.
  return current_mode == "v" or current_mode == "V" or current_mode == "\22"
end
|
||
|
||
---Get the selected content and range in Visual mode.
---The range records the exact start/end line and column, normalised so the
---start precedes the end. The content is always whole lines: the full text of
---every line touched by the selection, joined with "\n".
---@return avante.SelectionResult | nil Selected content and range
function M.get_visual_selection_and_range()
  if not M.in_visual_mode() then return nil end

  local Range = require("avante.range")
  local SelectionResult = require("avante.selection_result")

  -- "v" is the other end of the Visual area, "." is the cursor.
  local anchor = fn.getpos("v")
  local cursor = fn.getpos(".")
  local first_lnum, first_col = anchor[2], anchor[3]
  local last_lnum, last_col = cursor[2], cursor[3]

  -- Normalise a selection that was made backwards.
  local reversed = first_lnum > last_lnum or (first_lnum == last_lnum and first_col > last_col)
  if reversed then
    first_lnum, last_lnum = last_lnum, first_lnum
    first_col, last_col = last_col, first_col
  end

  local range = Range:new({ lnum = first_lnum, col = first_col }, { lnum = last_lnum, col = last_col })

  local content
  if first_lnum == last_lnum then
    -- Single line: the whole line is used, not just the selected columns.
    content = fn.getline(first_lnum)
  else
    -- Multiple lines: every line is taken in full.
    local selected = fn.getline(first_lnum, last_lnum)
    if type(selected) == "table" then
      content = table.concat(selected, "\n")
    else
      content = selected
    end
  end
  if not content then return nil end

  local filepath = fn.expand("%:p")
  local filetype = M.get_filetype(filepath)
  return SelectionResult:new(filepath, filetype, content, range)
end
|
||
|
||
---Wrapper around `api.nvim_buf_get_lines` which defaults to the current
---buffer and always passes `strict_indexing = false`.
---@param start integer
---@param end_ integer
---@param buf integer?
---@return string[]
function M.get_buf_lines(start, end_, buf)
  local target = buf or 0
  return api.nvim_buf_get_lines(target, start, end_, false)
end
|
||
|
||
---Get cursor row and column as (1, 0) based values, as returned by
---`api.nvim_win_get_cursor`. Defaults to the current window.
---@param win_id integer?
---@return integer, integer
function M.get_cursor_pos(win_id)
  local pos = api.nvim_win_get_cursor(win_id or 0)
  return pos[1], pos[2]
end
|
||
|
||
---Check if the buffer is likely to have actionable conflict markers: it must
---have an empty `buftype` and be modifiable. Defaults to the current buffer.
---@param bufnr integer?
---@return boolean
function M.is_valid_buf(bufnr)
  local bo = vim.bo[bufnr or 0]
  if bo.buftype ~= "" then return false end
  return bo.modifiable
end
|
||
|
||
--- Check if a NUI container is valid:
--- 1. the container itself is not nil
--- 2. its `bufnr` is set and refers to a valid buffer
--- 3. when `check_winid` is truthy, its `winid` is set and refers to a valid window
--- Always returns a boolean value.
---@param container NuiSplit | nil
---@param check_winid boolean? Whether to check window validity, defaults to false
---@return boolean
function M.is_valid_container(container, check_winid)
  if container == nil then return false end

  local bufnr = container.bufnr
  if bufnr == nil or not api.nvim_buf_is_valid(bufnr) then return false end

  -- Window validity is opt-in.
  if not check_winid then return true end

  local winid = container.winid
  return winid ~= nil and api.nvim_win_is_valid(winid)
end
|
||
|
||
---Fetch a highlight group definition from the global namespace with links
---resolved (`link = false`). Returns an empty table when no name is given.
---@param name string?
---@return table
function M.get_hl(name)
  if name then return api.nvim_get_hl(0, { name = name, link = false }) end
  return {}
end
|
||
|
||
--- vendor from lazy.nvim for early access and override
|
||
|
||
---Normalize a path by delegating to `M.path.normalize`.
---@param path string
---@return string
function M.norm(path)
  local normalized = M.path.normalize(path)
  return normalized
end
|
||
|
||
---Show a notification through `vim.notify` (or `vim.notify_once` when
---`opts.once` is set). A list message has its falsy entries dropped and is
---joined with newlines. When called from a fast event the call is re-scheduled
---onto the main loop. The notification window, if the backend opens one, is
---highlighted as `opts.lang` (default "markdown").
---@param msg string|string[]
---@param opts? LazyNotifyOpts
function M.notify(msg, opts)
  if msg == nil then return end
  if vim.in_fast_event() then
    return vim.schedule(function() M.notify(msg, opts) end)
  end

  opts = opts or {}

  local text = msg
  if type(text) == "table" then
    ---@diagnostic disable-next-line: no-unknown
    local kept = vim.tbl_filter(function(line) return line or false end, text)
    text = table.concat(kept, "\n")
  end
  ---@diagnostic disable-next-line: undefined-field
  if opts.stacktrace then
    ---@diagnostic disable-next-line: undefined-field
    text = text .. M.pretty_trace({ level = opts.stacklevel or 2 })
  end

  local lang = opts.lang or "markdown"

  -- Style the notification window: conceal markup, disable spell checking and
  -- start treesitter highlighting, falling back to filetype/syntax.
  local function style_window(win)
    pcall(function() vim.treesitter.language.add("markdown") end)
    local wo = vim.wo[win]
    wo.conceallevel = 3
    wo.concealcursor = ""
    wo.spell = false
    local buf = api.nvim_win_get_buf(win)
    if not pcall(vim.treesitter.start, buf, lang) then
      vim.bo[buf].filetype = lang
      vim.bo[buf].syntax = lang
    end
  end

  ---@diagnostic disable-next-line: undefined-field
  local notifier = opts.once and vim.notify_once or vim.notify
  notifier(text, opts.level or vim.log.levels.INFO, {
    on_open = style_window,
    title = opts.title or "Avante",
  })
end
|
||
|
||
---Notify at ERROR level. Sets `level` on the given `opts` table (or on a new
---one) and forwards to `M.notify`.
---@param msg string|string[]
---@param opts? LazyNotifyOpts
function M.error(msg, opts)
  local notify_opts = opts or {}
  notify_opts.level = vim.log.levels.ERROR
  M.notify(msg, notify_opts)
end
|
||
|
||
---Notify at INFO level. Sets `level` on the given `opts` table (or on a new
---one) and forwards to `M.notify`.
---@param msg string|string[]
---@param opts? LazyNotifyOpts
function M.info(msg, opts)
  local notify_opts = opts or {}
  notify_opts.level = vim.log.levels.INFO
  M.notify(msg, notify_opts)
end
|
||
|
||
---Notify at WARN level. Sets `level` on the given `opts` table (or on a new
---one) and forwards to `M.notify`.
---@param msg string|string[]
---@param opts? LazyNotifyOpts
function M.warn(msg, opts)
  local notify_opts = opts or {}
  notify_opts.level = vim.log.levels.WARN
  M.notify(msg, notify_opts)
end
|
||
|
||
---Print a debug line when `require("avante.config").debug` is truthy.
---The line is prefixed with a timestamp and the caller's module and line
---number; string arguments are printed as-is, everything else goes through
---`vim.inspect`.
---@param ... any values to print
function M.debug(...)
  if not require("avante.config").debug then return end

  -- select("#", ...) counts nil arguments too; `#{...}` and ipairs would
  -- silently drop or truncate at them.
  local argc = select("#", ...)
  if argc == 0 then return end

  -- Level 2 is the function that called M.debug.
  local info = debug.getinfo(2, "Sl")
  local source = info and info.source or ""
  local current_line = info and info.currentline or 0
  local caller_source = source:match("@(.+)$") or "unknown"
  local caller_module = caller_source:gsub("^.*/lua/", ""):gsub("%.lua$", ""):gsub("/", ".")

  local timestamp = M.get_timestamp()
  local formatted = {
    "[" .. timestamp .. "] [AVANTE] [DEBUG] [" .. caller_module .. ":" .. current_line .. "]",
  }

  for i = 1, argc do
    local arg = (select(i, ...))
    if type(arg) == "string" then
      formatted[#formatted + 1] = arg
    else
      formatted[#formatted + 1] = vim.inspect(arg)
    end
  end

  print(unpack(formatted))
end
|
||
|
||
---Return the index of the first element of the list `tbl` equal to `value`,
---or nil when there is none.
---@param tbl any[]
---@param value any
---@return integer | nil
function M.tbl_indexof(tbl, value)
  for index, item in ipairs(tbl) do
    if item == value then return index end
  end
  return nil
end
|
||
|
||
---Set one `key:value` entry inside a comma-separated window option, replacing
---the existing entry for `key` or appending a new one.
---Entries are matched on the exact key at the start of an entry. The previous
---implementation used `key` as an unescaped Lua pattern anywhere in the
---string, so a key like "Search" also rewrote "IncSearch:...", and magic
---characters in `key` or `%` in `value` corrupted the option.
---@param winid integer
---@param opt_name string
---@param key string
---@param value string
function M.update_win_options(winid, opt_name, key, value)
  local current = api.nvim_get_option_value(opt_name, { win = winid })

  local prefix = key .. ":"
  local entry = prefix .. value
  local pieces = {}
  local replaced = false

  for _, piece in ipairs(vim.split(current, ",", { plain = true })) do
    -- An empty option value splits into a single empty piece; skip it.
    if piece ~= "" then
      if vim.startswith(piece, prefix) then
        piece = entry
        replaced = true
      end
      pieces[#pieces + 1] = piece
    end
  end

  if not replaced then pieces[#pieces + 1] = entry end

  api.nvim_set_option_value(opt_name, table.concat(pieces, ","), { win = winid })
end
|
||
|
||
---Look up the value stored for `key` in a comma-separated `key:value` window
---option. Returns nil when the option has no entry for that key.
---@param winid integer
---@param opt_name string
---@param key string
---@return string | nil
function M.get_win_options(winid, opt_name, key)
  local option_value = api.nvim_get_option_value(opt_name, { win = winid })
  if not option_value then return end

  for _, entry in ipairs(vim.split(option_value, ",")) do
    local fields = vim.split(entry, ":")
    if fields[1] == key then return fields[2] end
  end
end
|
||
|
||
---Return the first window (in `nvim_list_wins` order) showing `bufnr`, or nil
---when no window displays it.
---@param bufnr integer
---@return integer | nil
function M.get_winid(bufnr)
  for _, candidate in ipairs(api.nvim_list_wins()) do
    local shown = api.nvim_win_get_buf(candidate)
    if shown == bufnr then return candidate end
  end
end
|
||
|
||
---Make a buffer editable: clear `readonly` and `modified`, set `modifiable`.
---@param bufnr integer
function M.unlock_buf(bufnr)
  local bo = vim.bo[bufnr]
  bo.readonly = false
  bo.modified = false
  bo.modifiable = true
end
|
||
|
||
---Make a buffer read-only: set `readonly`, clear `modified` and `modifiable`.
---If it is the current buffer, insert mode is left first (without autocmds).
---@param bufnr integer
function M.lock_buf(bufnr)
  local is_current = bufnr == api.nvim_get_current_buf()
  if is_current then vim.cmd("noautocmd stopinsert") end

  local bo = vim.bo[bufnr]
  bo.readonly = true
  bo.modified = false
  bo.modifiable = false
end
|
||
|
||
---Move the cursor of a window to the end of the last line of its buffer.
---Defaults to the current window.
---@param winnr? number
---@return nil
function M.scroll_to_end(winnr)
  winnr = winnr or 0
  local bufnr = api.nvim_win_get_buf(winnr)
  local lnum = api.nvim_buf_line_count(bufnr)
  local last_line = api.nvim_buf_get_lines(bufnr, -2, -1, true)[1]
  -- nvim_win_set_cursor takes a byte-indexed column, so use the byte length.
  -- The display width (nvim_strwidth) differs for multi-byte or wide
  -- characters and left the cursor short of the line end.
  api.nvim_win_set_cursor(winnr, { lnum, #last_line })
end
|
||
|
||
---Scroll every window that shows the buffer to its end. Defaults to the
---current buffer.
---@param bufnr nil|integer
---@return nil
function M.buf_scroll_to_end(bufnr)
  local windows = M.buf_list_wins(bufnr or 0)
  for idx = 1, #windows do
    M.scroll_to_end(windows[idx])
  end
end
|
||
|
||
---List the valid windows that currently display a buffer. A nil or 0 `bufnr`
---means the current buffer.
---@param bufnr nil|integer
---@return integer[]
function M.buf_list_wins(bufnr)
  local target = bufnr
  if not target or target == 0 then target = api.nvim_get_current_buf() end

  local function shows_target(winid)
    return api.nvim_win_is_valid(winid) and api.nvim_win_get_buf(winid) == target
  end

  return vim.tbl_filter(shows_target, api.nvim_list_wins())
end
|
||
|
||
local sidebar_buffer_var_name = "is_avante_sidebar_buffer"
|
||
|
||
---Flag a buffer as a sidebar buffer by setting a buffer-local variable to true.
---@param bufnr integer
function M.mark_as_sidebar_buffer(bufnr)
  api.nvim_buf_set_var(bufnr, sidebar_buffer_var_name, true)
end
|
||
|
||
---Whether the buffer carries the sidebar marker variable set to true.
---A failed lookup (for example the variable is not set) counts as false.
---@param bufnr integer
---@return boolean
function M.is_sidebar_buffer(bufnr)
  local found, marker = pcall(api.nvim_buf_get_var, bufnr, sidebar_buffer_var_name)
  return found and marker == true
end
|
||
|
||
---Strip leading and trailing whitespace from a string.
---@param s string
---@return string
function M.trim_spaces(s)
  local trimmed = s:match("^%s*(.-)%s*$")
  return trimmed
end
|
||
|
||
---Remove trailing whitespace from each line in a string.
---A falsy `content` is returned unchanged.
---@param content string The content to process
---@return string The content with trailing whitespace removed from each line
function M.remove_trailing_spaces(content)
  if not content then return content end

  local cleaned = {}
  for _, line in ipairs(vim.split(content, "\n")) do
    -- Parentheses keep only gsub's first result (the string).
    cleaned[#cleaned + 1] = (line:gsub("%s+$", ""))
  end
  return table.concat(cleaned, "\n")
end
|
||
|
||
---Return `v`, or `default_value` when `v` is nil.
---Written as an explicit branch: the `cond and a or b` idiom returned nil
---instead of the default whenever `default_value` was `false`.
---@generic T
---@param v T | nil
---@param default_value T
---@return T
function M.fallback(v, default_value)
  if v == nil then return default_value end
  return v
end
|
||
|
||
---Join URL parts together with single slashes.
---Nil or empty parts after the first are skipped; one trailing slash of the
---accumulated result and one leading slash of each part are dropped before
---joining, and one trailing slash is removed from the final result.
---@param ... string URL parts to join
---@return string Joined URL
function M.url_join(...)
  local parts = { ... }
  local joined = parts[1] or ""

  for idx = 2, #parts do
    local segment = parts[idx]
    if segment and segment ~= "" then
      if joined:sub(-1) == "/" then joined = joined:sub(1, -2) end
      if segment:sub(1, 1) == "/" then segment = segment:sub(2) end
      joined = joined .. "/" .. segment
    end
  end

  if joined:sub(-1) == "/" then joined = joined:sub(1, -2) end
  return joined
end
|
||
|
||
-- luacheck: push no max comment line length
---Check a value against a type name. Besides the names returned by `type()`,
---"list" (list-like table) and "map" (table that is not list-like) are accepted.
---@param type_name "'nil'" | "'number'" | "'string'" | "'boolean'" | "'table'" | "'function'" | "'thread'" | "'userdata'" | "'list'" | '"map"'
---@param v any
---@return boolean
function M.is_type(type_name, v)
  -- vim.tbl_islist is the older name of vim.islist.
  ---@diagnostic disable-next-line: deprecated
  local islist = vim.islist or vim.tbl_islist

  if type_name == "list" then
    return islist(v)
  elseif type_name == "map" then
    return type(v) == "table" and not islist(v)
  end

  return type(v) == type_name
end
-- luacheck: pop
|
||
|
||
---Return the leading whitespace of `text`, or "" when `text` is falsy.
---@param text string
---@return string
function M.get_indentation(text)
  if text then return text:match("^%s*") or "" end
  return ""
end
|
||
|
||
---Remove every whitespace character from `text` (not only at the edges).
---A falsy `text` is returned unchanged.
---@param text string
---@return string
function M.trim_space(text)
  if not text then return text end
  -- Parentheses drop gsub's second result (the substitution count), which
  -- previously leaked to callers as an extra return value. `%s+` avoids the
  -- empty matches `%s*` produced at every position.
  return (text:gsub("%s+", ""))
end
|
||
|
||
---Undo a fixed set of escape sequences in `text`, applied in this order:
---"//n", "//r", "//t" lose one slash; `/"` and `\"` become `"`; the
---two-character sequences `\n`, `\r`, `\t` become the real control characters.
---A falsy `text` is returned unchanged.
---@param text string
---@return string
function M.trim_escapes(text)
  if not text then return text end

  -- { pattern, replacement } pairs; order matters.
  local substitutions = {
    { "//n", "/n" },
    { "//r", "/r" },
    { "//t", "/t" },
    { '/"', '"' },
    { '\\"', '"' },
    { "\\n", "\n" },
    { "\\r", "\r" },
    { "\\t", "\t" },
  }

  local res = text
  for _, pair in ipairs(substitutions) do
    res = res:gsub(pair[1], pair[2])
  end
  return res
end
|
||
|
||
---Find the first position in `original_lines` where every line of
---`target_lines` matches the corresponding original line under `compare_fn`.
---@param original_lines string[]
---@param target_lines string[]
---@param compare_fn fun(line_a: string, line_b: string): boolean
---@return integer | nil start_line
---@return integer | nil end_line
function M.try_find_match(original_lines, target_lines, compare_fn)
  local window = #target_lines

  -- True when the target matches the originals starting right after `offset`.
  local function matches_after(offset)
    for j = 1, window do
      if not compare_fn(original_lines[offset + j], target_lines[j]) then return false end
    end
    return true
  end

  for first = 1, #original_lines - window + 1 do
    if matches_after(first - 1) then return first, first + window - 1 end
  end
  return nil, nil
end
|
||
|
||
---Locate `target_lines` inside `original_lines`, trying progressively looser
---comparisons and returning the first hit:
--- 1. exact line equality
--- 2. equality ignoring trailing spaces and tabs
--- 3. equality ignoring all whitespace
--- 4. equality after unescaping the target line (`M.trim_escapes`)
--- 5. unescaping combined with ignoring all whitespace
---@param original_lines string[]
---@param target_lines string[]
---@return integer | nil start_line
---@return integer | nil end_line
function M.fuzzy_match(original_lines, target_lines)
  -- Stage 2 used M.trim with suffix " \t", which strips only that literal
  -- two-character sequence; strip any run of trailing spaces/tabs instead.
  local function strip_trailing_blanks(line) return (line:gsub("[ \t]+$", "")) end

  local comparators = {
    function(line_a, line_b) return line_a == line_b end,
    function(line_a, line_b) return strip_trailing_blanks(line_a) == strip_trailing_blanks(line_b) end,
    function(line_a, line_b) return M.trim_space(line_a) == M.trim_space(line_b) end,
    function(line_a, line_b) return line_a == M.trim_escapes(line_b) end,
    function(line_a, line_b) return M.trim_space(line_a) == M.trim_space(M.trim_escapes(line_b)) end,
  }

  for _, compare in ipairs(comparators) do
    local start_line, end_line = M.try_find_match(original_lines, target_lines, compare)
    if start_line ~= nil and end_line ~= nil then return start_line, end_line end
  end

  return nil, nil
end
|
||
|
||
---Express an absolute path relative to the project root.
---@param absolute string
---@return string
function M.relative_path(absolute)
  return M.make_relative_path(absolute, M.get_project_root())
end
|
||
|
||
---Describe the current buffer and cursor position: LSP URI, `changedtick` as
---version, path relative to the project root, indentation settings, and the
---cursor position with a 1-based row and the column reported by the LSP
---position params (utf-8 offset encoding).
---@return table
function M.get_doc()
  local bufname = api.nvim_buf_get_name(0)
  local params = lsp.util.make_position_params(0, "utf-8")
  local lsp_position = params.position

  return {
    uri = params.textDocument.uri,
    version = api.nvim_buf_get_var(0, "changedtick"),
    relativePath = M.relative_path(bufname),
    insertSpaces = vim.o.expandtab,
    tabSize = fn.shiftwidth(),
    indentSize = fn.shiftwidth(),
    position = {
      row = lsp_position.line + 1,
      col = lsp_position.character,
    },
  }
end
|
||
|
||
---Prepends line numbers to each line in a list of strings, in the form
---"L<number>: <line>".
---@param lines string[] The lines of content to prepend line numbers to.
---@param start_line? integer The number given to the first line. Defaults to 1.
---@return string[] A new list of strings with line numbers prepended.
function M.prepend_line_numbers(lines, start_line)
  -- enumerate() counts from 1, so subtract one to make the first line carry
  -- `start_line` itself. Adding `start_line` directly numbered the first
  -- line start_line + 1, contradicting the documented default.
  local offset = (start_line or 1) - 1
  return vim
    .iter(lines)
    :enumerate()
    :map(function(i, line) return string.format("L%d: %s", i + offset, line) end)
    :totable()
end
|
||
|
||
---Return a new list in which a leading "L<number>: " prefix has been removed
---from every string that has one.
---@param content string[]
---@return string[]
function M.trim_line_numbers(content)
  local function strip_prefix(line)
    local stripped = line:gsub("^L%d+: ", "")
    return stripped
  end
  return vim.iter(content):map(strip_prefix):totable()
end
|
||
|
||
---Debounce a function call: every call cancels the pending invocation (if
---any) and schedules `func` with the latest arguments after `delay` ms.
---@param func fun(...) function to debounce
---@param delay integer delay in milliseconds
---@return fun(...): uv.uv_timer_t debounced function
function M.debounce(func, delay)
  local pending = nil

  -- Stop and close the previously scheduled timer unless it is already closing.
  local function cancel_pending()
    if pending and not pending:is_closing() then
      pending:stop()
      pending:close()
    end
  end

  return function(...)
    local call_args = { ... }
    cancel_pending()

    pending = vim.defer_fn(function()
      func(unpack(call_args))
      pending = nil
    end, delay)

    return pending
  end
end
|
||
|
||
---Throttle a function call: the first call schedules `func` to run after
---`delay` ms with that call's arguments; calls made while an invocation is
---pending are dropped.
---@param func fun(...) function to throttle
---@param delay integer delay in milliseconds
---@return fun(...): nil throttled function
function M.throttle(func, delay)
  local pending = nil

  return function(...)
    if pending ~= nil then return end

    local call_args = { ... }
    local function run()
      func(unpack(call_args))
      pending = nil
    end
    pending = vim.defer_fn(run, delay)
  end
end
|
||
|
||
---Screen line of the cursor inside the given window (`winline()` evaluated in
---that window's context). Returns 1 when `winid` is missing or invalid so
---callers land on the first line.
---@param winid integer?
---@return integer
function M.winline(winid)
  -- nvim_win_is_valid raises on a nil handle, so check for it explicitly.
  if not winid or not vim.api.nvim_win_is_valid(winid) then return 1 end

  local line = 1
  vim.api.nvim_win_call(winid, function() line = fn.winline() end)

  return line
end
|
||
|
||
---Return the project root as reported by `M.root.get()`.
function M.get_project_root()
  local root = M.root.get()
  return root
end
|
||
|
||
---Whether the extension of `filepath` equals `target_ext`. "ts"/"tsx" are
---treated as interchangeable, and so are "js"/"jsx".
---@param target_ext string
---@param filepath string
---@return boolean
function M.is_same_file_ext(target_ext, filepath)
  local ext = fn.fnamemodify(filepath, ":e")
  local interchangeable = { tsx = "ts", ts = "tsx", jsx = "js", js = "jsx" }
  if interchangeable[target_ext] == ext then return true end
  return ext == target_ext
end
|
||
|
||
---Get recent file paths (from `vim.v.oldfiles`) that live under the project
---root and have a file extension.
---@param limit? integer Maximum number of paths to return. Defaults to 10.
---@param filenames? string[] When non-empty, keep only paths containing at least one of these names.
---@param same_file_ext? boolean Keep only files whose extension matches the current buffer's.
---@return string[]
function M.get_recent_filepaths(limit, filenames, same_file_ext)
  local project_root = M.get_project_root()
  local current_ext = fn.expand("%:e")
  local max_files = limit or 10
  local filter_by_name = filenames ~= nil and #filenames > 0
  local recent_files = {}

  -- Plain-text search: file names routinely contain Lua pattern magic
  -- characters such as "." and "-". Returning on the first hit also keeps a
  -- file from being listed once per matching name.
  local function name_matches(file)
    for _, filename in ipairs(filenames) do
      if file:find(filename, 1, true) then return true end
    end
    return false
  end

  for _, file in ipairs(vim.v.oldfiles) do
    local eligible = vim.startswith(file, project_root)
      and file:match("%.%w+$") ~= nil
      and (not same_file_ext or M.is_same_file_ext(current_ext, file))
      and (not filter_by_name or name_matches(file))

    if eligible then
      recent_files[#recent_files + 1] = file
      if #recent_files >= max_files then break end
    end
  end

  return recent_files
end
|
||
|
||
---Translate a gitignore-style glob into a Lua pattern.
--- * Lua magic characters other than the wildcards are escaped
--- * "**/" becomes ".-/" (any leading path)
--- * "*" becomes "[^/]*" (anything within one path segment)
--- * "?" becomes "." (any single character)
--- * a trailing "/" gets ".*" appended so directory rules match their contents
---@param pattern string
---@return string
local function pattern_to_lua(pattern)
  -- "*" and "?" are left out of the escape set on purpose. Escaping them
  -- first turned "*.log" into "%*%.log", so the substitutions below produced
  -- "%[^/]*%.log", which never matches a real path.
  local lua_pattern = pattern:gsub("[%(%)%.%%%+%-%[%]%^%$]", "%%%0")
  lua_pattern = lua_pattern:gsub("%*%*/", ".-/")
  lua_pattern = lua_pattern:gsub("%*", "[^/]*")
  lua_pattern = lua_pattern:gsub("%?", ".")
  if lua_pattern:sub(-1) == "/" then lua_pattern = lua_pattern .. ".*" end
  return lua_pattern
end
|
||
|
||
---Read a gitignore file and convert its rules to Lua patterns.
---Blank lines and lines starting with "#" are skipped; rules starting with
---"!" go to the negate list. When the file could be opened, patterns for
---.git, .worktree, __pycache__ and node_modules are appended to the ignore
---list. If the file cannot be opened both lists are returned empty.
---@param gitignore_path string
---@return string[] ignore_patterns
---@return string[] negate_patterns
function M.parse_gitignore(gitignore_path)
  local ignore_patterns, negate_patterns = {}, {}

  local file = io.open(gitignore_path, "r")
  if not file then return ignore_patterns, negate_patterns end

  for line in file:lines() do
    local is_blank = not line:match("%S")
    local is_comment = line:match("^#") ~= nil
    if not is_blank and not is_comment then
      local rule = line:match("^%s*(.-)%s*$")
      if rule:sub(1, 1) == "!" then
        negate_patterns[#negate_patterns + 1] = pattern_to_lua(rule:sub(2))
      else
        ignore_patterns[#ignore_patterns + 1] = pattern_to_lua(rule)
      end
    end
  end
  file:close()

  -- Always-ignored directories, already written as Lua patterns.
  local builtin = { "%.git", "%.worktree", "__pycache__", "node_modules" }
  ignore_patterns = vim.list_extend(ignore_patterns, builtin)

  return ignore_patterns, negate_patterns
end
|
||
|
||
---Decide whether a path is ignored. Negate patterns win: a match on any of
---them returns false immediately. Otherwise the path is ignored when an ignore
---pattern matches followed by "/" or at the end of the path.
---@param file string
---@param ignore_patterns string[]
---@param negate_patterns string[]
---@return boolean
function M.is_ignored(file, ignore_patterns, negate_patterns)
  for _, negated in ipairs(negate_patterns) do
    if file:match(negated) then return false end
  end

  for _, ignored in ipairs(ignore_patterns) do
    local hit = file:match(ignored .. "/") or file:match(ignored .. "$")
    if hit then return true end
  end

  return false
end
|
||
|
||
---List the files under a directory with the first available tool: `rg`, `fd`,
---`fdfind`, or, inside a git checkout, `git ls-files`. Relative results are
---joined onto `options.directory`.
--- * `max_depth` is passed to rg/fd; for git it is applied afterwards by
---   counting "/"-separated segments of the relative path.
--- * `add_dirs` prepends the distinct parent directories of the found files.
---Returns an empty list (after reporting an error) when no tool is available.
---@param options { directory: string, add_dirs?: boolean, max_depth?: integer }
---@return string[]
function M.scan_directory(options)
  local directory = options.directory
  local max_depth = options.max_depth
  local cmd_supports_max_depth = true

  -- Append `--max-depth N` when a depth limit was requested. The number is
  -- stringified so every argv entry is a string.
  local function with_max_depth(argv)
    if max_depth ~= nil then vim.list_extend(argv, { "--max-depth", tostring(max_depth) }) end
    return argv
  end

  local cmd
  if vim.fn.executable("rg") == 1 then
    cmd = with_max_depth({ "rg", "--files", "--color", "never", "--no-require-git", "--no-ignore-parent" })
    table.insert(cmd, directory)
  else
    -- fd is shipped as `fdfind` on some distributions; same flags for both.
    for _, fd in ipairs({ "fd", "fdfind" }) do
      if vim.fn.executable(fd) == 1 then
        cmd = with_max_depth({ fd, "--type", "f", "--color", "never", "--no-require-git" })
        vim.list_extend(cmd, { "--base-directory", directory })
        break
      end
    end
  end

  if not cmd then
    if M.path_exists(M.join_paths(directory, ".git")) and vim.fn.executable("git") == 1 then
      -- Run git directly with `-C <dir>` as an argv list. The previous
      -- `bash -c "cd <dir> && ..."` / PowerShell `Push-Location '<dir>'`
      -- strings interpolated the directory unquoted, which broke on paths
      -- with spaces or quotes and allowed shell injection.
      cmd = { "git", "-C", directory, "ls-files", "-co", "--exclude-standard" }
      cmd_supports_max_depth = false
    else
      M.error("No search command found, please install fd or fdfind or rg")
      return {}
    end
  end

  local files = vim.fn.systemlist(cmd)

  files = vim
    .iter(files)
    :map(function(file)
      if not M.is_absolute_path(file) then return M.join_paths(directory, file) end
      return file
    end)
    :totable()

  if max_depth ~= nil and not cmd_supports_max_depth then
    -- Drop a trailing "/." so relative paths are computed from the real base.
    local base_dir = directory
    if base_dir:sub(-2) == "/." then base_dir = base_dir:sub(1, -3) end
    files = vim
      .iter(files)
      :filter(function(file)
        local rel_path = M.make_relative_path(file, base_dir)
        return #vim.split(rel_path, "/") <= max_depth
      end)
      :totable()
  end

  if options.add_dirs then
    local dirs = {}
    local dirs_seen = {}
    for _, file in ipairs(files) do
      local dir = M.get_parent_path(file)
      if not dirs_seen[dir] then
        dirs[#dirs + 1] = dir
        dirs_seen[dir] = true
      end
    end
    files = vim.list_extend(dirs, files)
  end

  return files
end
|
||
|
||
---Returns the parent directory portion of `filepath`.
---A single trailing separator is ignored. An empty path (or a path that is
---only a separator) yields "". When nothing is left after dropping the last
---segment, the result is the separator for absolute paths and "." otherwise.
---@param filepath string
---@return string
function M.get_parent_path(filepath)
  if filepath == nil then error("filepath cannot be nil") end
  if filepath == "" then return "" end

  local absolute = M.is_absolute_path(filepath)
  local sep = M.path_sep

  -- Ignore one trailing separator so "a/b/" behaves like "a/b".
  local trimmed = filepath
  if trimmed:sub(-1) == sep then trimmed = trimmed:sub(1, -2) end
  if trimmed == "" then return "" end

  local segments = vim.split(trimmed, sep)
  local parent = table.concat(vim.list_slice(segments, 1, #segments - 1), sep)
  if parent ~= "" then return parent end

  if absolute then return sep end
  return "."
end
|
||
|
||
---Thin wrapper around `M.path.relative`; note the swapped argument order
---(`base_dir` is passed first to the underlying helper).
---@param filepath string
---@param base_dir string
function M.make_relative_path(filepath, base_dir)
  return M.path.relative(base_dir, filepath, false)
end
|
||
|
||
---Thin wrapper that forwards to `M.path.is_absolute`.
---@param path string
function M.is_absolute_path(path)
  return M.path.is_absolute(path)
end
|
||
|
||
---Resolves `path` against the project root.
---nil/empty input, paths starting with "/" and "term://" URIs are returned
---unchanged; everything else is joined onto `M.get_project_root()`.
---@param path string|nil
function M.to_absolute_path(path)
  local is_blank = not path or path == ""
  if is_blank then return path end

  local already_resolved = path:sub(1, 1) == "/" or path:sub(1, 7) == "term://"
  if already_resolved then return path end

  return M.join_paths(M.get_project_root(), path)
end
|
||
|
||
---Joins path segments left to right and normalizes the result with `M.norm`.
---nil and empty segments (after the first) are skipped, and an absolute
---segment discards everything accumulated before it.
---@param ... string
function M.join_paths(...)
  local segments = { ... }
  local joined = segments[1] or ""

  for i = 2, #segments do
    local segment = segments[i]
    if segment ~= nil and segment ~= "" then
      if M.is_absolute_path(segment) then
        -- An absolute segment restarts the path.
        joined = segment
      elseif joined == "" then
        joined = segment
      else
        joined = M.path.join(joined, segment)
      end
    end
  end

  return M.norm(joined)
end
|
||
|
||
---Thin wrapper that forwards to `M.path.is_exist`.
---@param path string
function M.path_exists(path)
  return M.path.is_exist(path)
end
|
||
|
||
---Tells whether `str` begins with an ASCII uppercase letter (A-Z).
---@param str string
---@return boolean
function M.is_first_letter_uppercase(str)
  local first_upper = string.find(str, "^[A-Z]")
  return first_upper ~= nil
end
|
||
|
||
---Scans `content` for the `@codebase` and `@diagnostics` mentions.
---`@codebase` switches on project context and is stripped from the returned
---text; `@diagnostics` only switches on diagnostics and is left in place.
---@param content string
---@return { new_content: string, enable_project_context: boolean, enable_diagnostics: boolean }
function M.extract_mentions(content)
  local wants_codebase = content:match("@codebase") ~= nil
  local wants_diagnostics = content:match("@diagnostics") ~= nil

  local stripped = content
  if wants_codebase then stripped = content:gsub("@codebase", "") end

  return {
    new_content = stripped,
    enable_project_context = wants_codebase,
    enable_diagnostics = wants_diagnostics,
  }
end
|
||
|
||
---Lists the mentions that are always available: `codebase` and `diagnostics`.
---A fresh list is built on every call.
---@return AvanteMention[]
function M.get_mentions()
  local codebase = {
    description = "codebase",
    command = "codebase",
    details = "repo map",
  }
  local diagnostics = {
    description = "diagnostics",
    command = "diagnostics",
    details = "diagnostics",
  }
  return { codebase, diagnostics }
end
|
||
|
||
---Lists the mentions offered in chat: the base mentions from `M.get_mentions()`
---followed by `file`, `quickfix` and `buffers`, each of which invokes a method
---on the sidebar's file selector.
---@return AvanteMention[]
function M.get_chat_mentions()
  local mentions = M.get_mentions()

  local chat_only = {
    {
      description = "file",
      command = "file",
      details = "add files...",
      callback = function(sidebar) sidebar.file_selector:open() end,
    },
    {
      description = "quickfix",
      command = "quickfix",
      details = "add files in quickfix list to chat context",
      callback = function(sidebar) sidebar.file_selector:add_quickfix_files() end,
    },
    {
      description = "buffers",
      command = "buffers",
      details = "add open buffers to the chat context",
      callback = function(sidebar) sidebar.file_selector:add_buffer_files() end,
    },
  }

  for _, mention in ipairs(chat_only) do
    mentions[#mentions + 1] = mention
  end

  return mentions
end
|
||
|
||
---Returns the effective list of shortcuts.
---User shortcuts from `Config.shortcuts` come first, in configuration order.
---Built-in shortcuts follow in their declaration order, skipping any whose
---name a user shortcut has already claimed.
---@return AvanteShortcut[]
function M.get_shortcuts()
  local Config = require("avante.config")

  -- Built-in shortcuts
  local builtin_shortcuts = {
    {
      name = "refactor",
      description = "Refactor code with best practices",
      details = "Automatically refactor code to improve readability, maintainability, and follow best practices while preserving functionality",
      prompt = "Please refactor this code following best practices, improving readability and maintainability while preserving functionality.",
    },
    {
      name = "test",
      description = "Generate unit tests",
      details = "Create comprehensive unit tests covering edge cases, error scenarios, and various input conditions",
      prompt = "Please generate comprehensive unit tests for this code, covering edge cases and error scenarios.",
    },
    {
      name = "document",
      description = "Add documentation",
      details = "Add clear and comprehensive documentation including function descriptions, parameter explanations, and usage examples",
      prompt = "Please add clear and comprehensive documentation to this code, including function descriptions, parameter explanations, and usage examples.",
    },
    {
      name = "debug",
      description = "Add debugging information",
      details = "Add comprehensive debugging information including logging statements, error handling, and debugging utilities",
      prompt = "Please add comprehensive debugging information to this code, including logging statements, error handling, and debugging utilities.",
    },
    {
      name = "optimize",
      description = "Optimize performance",
      details = "Analyze and optimize code for better performance considering time complexity, memory usage, and algorithmic improvements",
      prompt = "Please analyze and optimize this code for better performance, considering time complexity, memory usage, and algorithmic improvements.",
    },
    {
      name = "security",
      description = "Security review",
      details = "Perform a security review identifying potential vulnerabilities, security best practices, and recommendations for improvement",
      prompt = "Please perform a security review of this code, identifying potential vulnerabilities, security best practices, and recommendations for improvement.",
    },
  }

  local user_shortcuts = Config.shortcuts or {}
  local result = {}

  -- Names claimed by user shortcuts; a built-in with the same name is overridden.
  local user_names = {}
  for _, user_shortcut in ipairs(user_shortcuts) do
    table.insert(result, user_shortcut)
    if user_shortcut.name ~= nil then user_names[user_shortcut.name] = true end
  end

  -- Append the remaining built-ins by walking the ordered list rather than a
  -- hash map, so the resulting order is stable between calls.
  for _, builtin_shortcut in ipairs(builtin_shortcuts) do
    if not user_names[builtin_shortcut.name] then table.insert(result, builtin_shortcut) end
  end

  return result
end
|
||
|
||
---Expands every `#<shortcut name>` occurrence in `content` into that
---shortcut's prompt text.
---The name is matched literally (Lua pattern magic characters are escaped) and
---the prompt is inserted verbatim, so "%" in a prompt is not treated as a
---gsub escape.
---@param content string
---@return string new_content
---@return boolean has_shortcuts
function M.extract_shortcuts(content)
  local shortcuts = M.get_shortcuts()
  local new_content = content
  local has_shortcuts = false

  for _, shortcut in ipairs(shortcuts) do
    -- vim.pesc escapes characters such as "-" or "." that may appear in user-defined names.
    local pattern = "#" .. vim.pesc(shortcut.name)
    if content:match(pattern) then
      has_shortcuts = true
      local prompt = shortcut.prompt
      -- A function replacement bypasses "%1"-style interpretation of the prompt text.
      new_content = new_content:gsub(pattern, function() return prompt end)
    end
  end

  return new_content, has_shortcuts
end
|
||
|
||
---Opens the buffer for `path` (resolved against the project root) and returns
---its buffer number.
---With `set_current_buf` (the default) the file is opened with `:edit` in the
---current window, after writing any loaded, modified buffer for the same file.
---Otherwise the buffer is created and loaded without being displayed.
---@param path string
---@param set_current_buf? boolean defaults to true
---@return integer bufnr
function M.open_buffer(path, set_current_buf)
  if set_current_buf == nil then set_current_buf = true end

  local abs_path = M.join_paths(M.get_project_root(), path)

  local bufnr ---@type integer
  if set_current_buf then
    bufnr = vim.fn.bufnr(abs_path)
    if bufnr ~= -1 and vim.api.nvim_buf_is_loaded(bufnr) and vim.bo[bufnr].modified then
      -- Persist pending changes first so `:edit` does not fail on a modified buffer.
      vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end)
    end
    -- Escape the path: spaces, "%", "#", "|" etc. would otherwise be interpreted by the Ex command line.
    vim.cmd("noautocmd edit " .. vim.fn.fnameescape(abs_path))
    bufnr = vim.api.nvim_get_current_buf()
  else
    bufnr = vim.fn.bufnr(abs_path, true)
    pcall(vim.fn.bufload, bufnr)
  end

  -- NOTE(review): this runs in the current buffer, which is not the target
  -- buffer when `set_current_buf` is false — confirm that is intended.
  vim.cmd("filetype detect")

  return bufnr
end
|
||
|
||
---Computes the line ranges that must be rewritten to turn `old_lines` into
---`new_lines`, comparing lines position by position.
---When `new_lines` is at least as long as `old_lines`, only the last 100 old
---lines (and the new lines from the same offset on) are compared.
---Ranges are 1-based with an exclusive `end_line` and are returned sorted by
---descending `start_line`.
---@param old_lines avante.ui.Line[]
---@param new_lines avante.ui.Line[]
---@return { start_line: integer, end_line: integer, content: avante.ui.Line[] }[]
local function get_lines_diff(old_lines, new_lines)
  local tail_window = 100
  local offset = 0
  if #new_lines >= #old_lines then
    offset = math.max(#old_lines - tail_window, 0)
    old_lines = vim.list_slice(old_lines, offset + 1)
    new_lines = vim.list_slice(new_lines, offset + 1)
  end

  local hunks = {}
  local run_start = nil -- index of the first line of the current run of differing lines
  for idx, line in ipairs(new_lines) do
    if line == old_lines[idx] then
      if run_start ~= nil then
        -- A matching line closes the current run.
        hunks[#hunks + 1] = {
          start_line = offset + run_start,
          end_line = offset + idx,
          content = vim.list_slice(new_lines, run_start, idx - 1),
        }
        run_start = nil
      end
    elseif run_start == nil then
      run_start = idx
    end
  end

  -- A run that reaches the end of new_lines is still open here.
  if run_start ~= nil then
    hunks[#hunks + 1] = {
      start_line = offset + run_start,
      end_line = offset + #new_lines + 1,
      content = vim.list_slice(new_lines, run_start),
    }
  end

  -- Old lines beyond the end of new_lines are removed by an empty replacement.
  if #new_lines < #old_lines then
    hunks[#hunks + 1] = {
      start_line = offset + #new_lines + 1,
      end_line = offset + #old_lines + 1,
      content = {},
    }
  end

  table.sort(hunks, function(a, b) return a.start_line > b.start_line end)
  return hunks
end
|
||
|
||
---Compares the current content of buffer `bufnr` with `new_lines`, position by
---position, and returns the ranges that have to be rewritten.
---Ranges are 1-based with an exclusive `end_line` and are returned sorted by
---descending `start_line`.
---@param bufnr integer
---@param new_lines string[]
---@return { start_line: integer, end_line: integer, content: string[] }[]
local function get_buffer_content_diffs(bufnr, new_lines)
  local current = api.nvim_buf_get_lines(bufnr, 0, -1, false)
  local hunks = {}
  local run_start = nil -- index of the first line of the current run of differing lines

  for idx, line in ipairs(new_lines) do
    if line == current[idx] then
      if run_start ~= nil then
        -- A matching line closes the current run.
        hunks[#hunks + 1] = {
          start_line = run_start,
          end_line = idx,
          content = vim.list_slice(new_lines, run_start, idx - 1),
        }
        run_start = nil
      end
    elseif run_start == nil then
      run_start = idx
    end
  end

  -- A run that reaches the end of new_lines is still open here.
  if run_start ~= nil then
    hunks[#hunks + 1] = {
      start_line = run_start,
      end_line = #new_lines + 1,
      content = vim.list_slice(new_lines, run_start),
    }
  end

  -- Buffer lines beyond the end of new_lines are removed by an empty replacement.
  if #new_lines < #current then
    hunks[#hunks + 1] = { start_line = #new_lines + 1, end_line = #current + 1, content = {} }
  end

  table.sort(hunks, function(a, b) return a.start_line > b.start_line end)
  return hunks
end
|
||
|
||
--- Brings buffer `bufnr` in line with `new_lines` by rewriting only the ranges
--- reported by `get_buffer_content_diffs`; does nothing when there are none.
---@param bufnr integer
---@param new_lines string[]
function M.update_buffer_content(bufnr, new_lines)
  local hunks = get_buffer_content_diffs(bufnr, new_lines)
  for i = 1, #hunks do
    local hunk = hunks[i]
    -- Hunk bounds are 1-based; nvim_buf_set_lines expects 0-based, end-exclusive indices.
    api.nvim_buf_set_lines(bufnr, hunk.start_line - 1, hunk.end_line - 1, false, hunk.content)
  end
end
|
||
|
||
---Rewrites the parts of buffer `bufnr` that differ between `old_lines` and
---`new_lines`, then applies each changed line's highlights.
---All buffer positions are shifted down by `skip_line_count`. A redraw is
---issued after every rewritten range.
---@param ns_id number
---@param bufnr integer
---@param old_lines avante.ui.Line[]
---@param new_lines avante.ui.Line[]
---@param skip_line_count? integer defaults to 0
function M.update_buffer_lines(ns_id, bufnr, old_lines, new_lines, skip_line_count)
  local offset = skip_line_count or 0
  local hunks = get_lines_diff(old_lines, new_lines)

  for _, hunk in ipairs(hunks) do
    local changed = hunk.content

    -- Remove newlines: a rendered line that contains "\n" becomes several buffer lines.
    local text = {}
    for _, line in ipairs(changed) do
      vim.list_extend(text, vim.split(tostring(line), "\n"))
    end

    vim.api.nvim_buf_set_lines(bufnr, offset + hunk.start_line - 1, offset + hunk.end_line - 1, false, text)

    for i, line in ipairs(changed) do
      -- The i-th changed line sits at 0-based row (start_line - 1) + (i - 1).
      line:set_highlights(ns_id, bufnr, offset + hunk.start_line + i - 2)
    end

    vim.cmd("redraw")
  end
end
|
||
|
||
---Normalizes a path for comparison: paths inside the project (as judged by
---`M.file.is_in_project`) are returned relative to the project root; all
---others are returned as given. Non-string input is converted with `tostring`.
---@param path any
function M.uniform_path(path)
  if type(path) ~= "string" then path = tostring(path) end
  if not M.file.is_in_project(path) then return path end

  local root = M.get_project_root()
  local absolute
  if M.is_absolute_path(path) then
    absolute = path
  else
    absolute = M.join_paths(root, path)
  end

  local relative = M.make_relative_path(absolute, root)
  return relative
end
|
||
|
||
---Tells whether both paths map to the same value under `M.uniform_path`.
---@return boolean
function M.is_same_file(filepath_a, filepath_b)
  local uniform_a = M.uniform_path(filepath_a)
  local uniform_b = M.uniform_path(filepath_b)
  return uniform_a == uniform_b
end
|
||
|
||
---Strips a leading `<think>...</think>` block (tags and enclosed text) from
---`content` and returns what follows. Only a block at the very start of the
---string is removed, and only the first one.
---@param content string
---@return string
function M.trim_think_content(content)
  local remainder = content:gsub("^<think>.-</think>", "", 1)
  return remainder
end
|
||
|
||
-- Cache of detected filetypes keyed by file path (LRUCache created with 60).
local _filetype_lru_cache = LRUCache:new(60)

---Detects the filetype for `filepath`, returning "" when none is detected.
---For a compound filetype such as "a.b" only the first part is returned.
---Results, including "", are cached per path.
---@param filepath string
---@return string
function M.get_filetype(filepath)
  local hit = _filetype_lru_cache:get(filepath)
  if hit then return hit end

  -- Some files are sometimes not detected correctly when buffer is not included
  -- https://github.com/neovim/neovim/issues/27265
  local scratch = vim.api.nvim_create_buf(false, true)
  local detected = vim.filetype.match({ filename = filepath, buf = scratch }) or ""
  vim.api.nvim_buf_delete(scratch, { force = true })

  -- Keep only the first filetype of a multi-filetype value
  local primary = detected:gsub("%..*$", "")

  _filetype_lru_cache:set(filepath, primary)
  return primary
end
|
||
|
||
---Reads a file as a list of lines, preferring the content of a loaded buffer
---over the file on disk.
---`term://` paths are used as given; anything else is resolved against the
---project root. On failure an empty list is returned together with an error
---message.
---@param filepath string
---@return string[]|nil lines
---@return string|nil error
function M.read_file_from_buf_or_disk(filepath)
  local abs_path = filepath:sub(1, 7) == "term://" and filepath or M.join_paths(M.get_project_root(), filepath)

  --- Lookup if the file is loaded in a buffer
  local ok, bufnr = pcall(vim.fn.bufnr, abs_path)
  if ok and bufnr ~= -1 and vim.api.nvim_buf_is_loaded(bufnr) then
    -- If buffer exists and is loaded, get buffer content
    return vim.api.nvim_buf_get_lines(bufnr, 0, -1, false), nil
  end

  local stat = vim.uv.fs_stat(abs_path)
  if stat and stat.type == "directory" then
    -- Separate the message from the path so the error stays readable.
    return {}, "Cannot read a directory as file: " .. filepath
  end

  -- Fallback: read file from disk
  local file, open_err = io.open(abs_path, "r")
  if not file then return {}, open_err end

  local content = file:read("*all")
  file:close()
  -- Normalize Windows line endings before splitting.
  content = content:gsub("\r\n", "\n")
  return vim.split(content, "\n"), nil
end
|
||
|
||
---Check if an icon plugin is installed
---(any of nvim-web-devicons, mini.icons or mini.nvim, as reported by `M.has`).
---@return boolean
function M.icons_enabled()
  local providers = { "nvim-web-devicons", "mini.icons", "mini.nvim" }
  for _, plugin in ipairs(providers) do
    if M.has(plugin) then return true end
  end
  return false
end
|
||
|
||
---Returns a string containing an icon only if an icon plugin is available.
---Dev icons are an optional install for avante; without one, the fallback
---(or an empty string) is returned so no unrenderable glyphs are displayed.
---@param string_with_icon string
---@param utf8_fallback string|nil
---@return string
function M.icon(string_with_icon, utf8_fallback)
  if not M.icons_enabled() then return utf8_fallback or "" end
  return string_with_icon
end
|
||
|
||
---Like `vim.tbl_deep_extend`, but re-attaches a metatable to the merged result.
---The metatable is taken from the first table, or from the last table when
---`behavior` is "keep".
---@param behavior string forwarded to `vim.tbl_deep_extend`
---@param ... table
---@return table
function M.deep_extend_with_metatable(behavior, ...)
  local candidates = { ... }

  local mt_source
  if behavior == "keep" then
    mt_source = candidates[#candidates]
  else
    mt_source = candidates[1]
  end
  local mt = getmetatable(mt_source)

  local merged = vim.tbl_deep_extend(behavior, ...)
  if mt then setmetatable(merged, mt) end

  return merged
end
|
||
|
||
---Returns the current UTC date and time formatted as "YYYY-MM-DD HH:MM:SS".
---@return string
function M.utc_now()
  -- The "!" prefix makes os.date format in UTC directly. Converting a UTC
  -- table through os.time and back through local-time os.date is off by the
  -- DST offset whenever local daylight saving time is in effect.
  return os.date("!%Y-%m-%d %H:%M:%S")
end
|
||
|
||
---Returns the number of seconds from `dt1` to `dt2` (positive when `dt2` is
---later). Both arguments are parsed as "YYYY-MM-DD HH:MM:SS".
---@param dt1 string
---@param dt2 string
---@return integer delta_seconds
function M.datetime_diff(dt1, dt2)
  local layout = "(%d+)-(%d+)-(%d+) (%d+):(%d+):(%d+)"
  local year_a, month_a, day_a, hour_a, min_a, sec_a = dt1:match(layout)
  local year_b, month_b, day_b, hour_b, min_b, sec_b = dt2:match(layout)

  local from = os.time({ year = year_a, month = month_a, day = day_a, hour = hour_a, min = min_a, sec = sec_a })
  local to = os.time({ year = year_b, month = month_b, day = day_b, hour = hour_b, min = min_b, sec = sec_b })

  return os.difftime(to, from)
end
|
||
|
||
---Converts a timestamp of the form "YYYY-MM-DDTHH:MM:SSZ" into
---"YYYY-MM-DD HH:MM:SS". Returns nil and an error message when the input does
---not contain that form.
---NOTE(review): the fields go through os.time/os.date in local time with
---`isdst = false`; the output may shift by an hour for dates on which local
---DST applies — confirm whether that is acceptable.
---@param iso_str string
---@return string|nil
---@return string|nil error
function M.parse_iso8601_date(iso_str)
  local parts = { iso_str:match("(%d+)-(%d+)-(%d+)T(%d+):(%d+):(%d+)Z") }
  if not parts[1] then return nil, "Invalid ISO 8601 format" end

  local timestamp = os.time({
    year = tonumber(parts[1]),
    month = tonumber(parts[2]),
    day = tonumber(parts[3]),
    hour = tonumber(parts[4]),
    min = tonumber(parts[5]),
    sec = tonumber(parts[6]),
    isdst = false,
  })

  local formatted = tostring(os.date("%Y-%m-%d %H:%M:%S", timestamp))
  return formatted, nil
end
|
||
|
||
---Builds a string of `length` characters drawn from [a-zA-Z0-9] with
---`math.random`.
---@param length integer
---@return string
function M.random_string(length)
  local alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
  local picked = {}
  for i = 1, length do
    local pos = math.random(1, #alphabet)
    picked[i] = alphabet:sub(pos, pos)
  end
  return table.concat(picked)
end
|
||
|
||
---Tells whether `win_b` starts exactly one column past the right edge of
---`win_a` (the skipped column presumably being the window separator).
---Returns false if either window is invalid.
---@param win_a integer
---@param win_b integer
---@return boolean
function M.is_left_adjacent(win_a, win_b)
  local both_valid = vim.api.nvim_win_is_valid(win_a) and vim.api.nvim_win_is_valid(win_b)
  if not both_valid then return false end

  -- win_screenpos returns { row, col }
  local col_a = vim.fn.win_screenpos(win_a)[2]
  local col_b = vim.fn.win_screenpos(win_b)[2]
  local right_edge_a = col_a + vim.api.nvim_win_get_width(win_a)

  return col_b == right_edge_a + 1
end
|
||
|
||
---Tells whether `win_b` starts exactly one row past the bottom edge of `win_a`.
---Returns false if either window is invalid, mirroring `M.is_left_adjacent`;
---without the guard `nvim_win_get_height` raises on an invalid window id.
---@param win_a integer
---@param win_b integer
---@return boolean
function M.is_top_adjacent(win_a, win_b)
  if not vim.api.nvim_win_is_valid(win_a) or not vim.api.nvim_win_is_valid(win_b) then return false end

  -- win_screenpos returns { row, col }
  local row_a = vim.fn.win_screenpos(win_a)[1]
  local row_b = vim.fn.win_screenpos(win_b)[1]
  local height_a = vim.api.nvim_win_get_height(win_a)

  return row_a + height_a + 1 == row_b
end
|
||
|
||
---Tells whether `win_b` directly follows `win_a`, either to its right
---(`M.is_left_adjacent`) or below it (`M.is_top_adjacent`).
function M.should_hidden_border(win_a, win_b)
  if M.is_left_adjacent(win_a, win_b) then return true end
  return M.is_top_adjacent(win_a, win_b)
end
|
||
|
||
---Converts tool parameter field definitions into JSON-schema style
---`properties` (keyed by field name) plus the list of required field names.
---Object fields with `fields` are converted recursively; array fields with
---`items` get an `items` schema; other fields may carry an `enum` built from
---`choices`. A field is required unless it is marked `optional`.
---@param fields AvanteLLMToolParamField[]
---@return table[] properties
---@return string[] required
function M.llm_tool_param_fields_to_json_schema(fields)
  local properties, required = {}, {}

  -- A dynamic get_description() takes precedence over the static description.
  local function describe(field)
    return field.get_description and field.get_description() or field.description
  end

  for _, field in ipairs(fields) do
    local schema
    if field.type == "object" and field.fields then
      local nested_properties, nested_required = M.llm_tool_param_fields_to_json_schema(field.fields)
      schema = {
        type = field.type,
        description = describe(field),
        properties = nested_properties,
        required = nested_required,
      }
    elseif field.type == "array" and field.items then
      -- Convert the single item definition and take its only schema entry.
      local item_properties = M.llm_tool_param_fields_to_json_schema({ field.items })
      local _, item_schema = next(item_properties)
      schema = {
        type = field.type,
        description = describe(field),
        items = item_schema,
      }
    else
      schema = {
        type = field.type,
        description = describe(field),
      }
      if field.choices then schema.enum = field.choices end
    end

    properties[field.name] = schema
    if not field.optional then table.insert(required, field.name) end
  end

  return properties, required
end
|
||
|
||
---Returns every slash command: the built-in ones (with their callbacks
---attached) followed by the user-defined `Config.slash_commands`, if any.
---@return AvanteSlashCommand[]
function M.get_commands()
  local Config = require("avante.config")

  ---Renders one "- name: text" line per item, preferring `shorthelp` over `description`.
  ---@param items_ {name: string, description: string, shorthelp?: string}[]
  ---@return string
  local function get_help_text(items_)
    local entries = {}
    for _, item in ipairs(items_) do
      entries[#entries + 1] = "- " .. item.name .. ": " .. (item.shorthelp or item.description) .. "\n"
    end
    return table.concat(entries)
  end

  local builtin_items = {
    { description = "Show help message", name = "help" },
    { description = "Init AGENTS.md based on the current project", name = "init" },
    { description = "Clear chat history", name = "clear" },
    { description = "New chat", name = "new" },
    { description = "Compact history messages to save tokens", name = "compact" },
    {
      shorthelp = "Ask a question about specific lines",
      description = "/lines <start>-<end> <question>",
      name = "lines",
    },
    { description = "Commit the changes", name = "commit" },
  }

  ---@type {[AvanteSlashCommandBuiltInName]: AvanteSlashCommandCallback}
  local builtin_cbs = {
    help = function(sidebar, args, cb)
      local help_text = get_help_text(builtin_items)
      sidebar:update_content(help_text, { focus = false, scroll = false })
      if cb then cb(args) end
    end,
    clear = function(sidebar, args, cb) sidebar:clear_history(args, cb) end,
    new = function(sidebar, args, cb) sidebar:new_chat(args, cb) end,
    compact = function(sidebar, args, cb) sidebar:compact_history_messages(args, cb) end,
    init = function(sidebar, args, cb) sidebar:init_current_project(args, cb) end,
    lines = function(_, args, cb)
      if cb then cb(args) end
    end,
    commit = function(_, _, cb)
      local question = "Please commit the changes"
      if cb then cb(question) end
    end,
  }

  local builtin_commands = vim
    .iter(builtin_items)
    :map(
      ---@param item AvanteSlashCommand
      function(item)
        return {
          name = item.name,
          description = item.description,
          callback = builtin_cbs[item.name],
          details = item.shorthelp and table.concat({ item.shorthelp, item.description }, "\n") or item.description,
        }
      end
    )
    :totable()

  -- Tolerate a missing user list, as M.get_shortcuts does for Config.shortcuts.
  return vim.list_extend(builtin_commands, Config.slash_commands or {})
end
|
||
|
||
---Returns the current local date and time formatted as "YYYY-MM-DD HH:MM:SS".
---@return string
function M.get_timestamp()
  local now = os.date("%Y-%m-%d %H:%M:%S")
  return tostring(now)
end
|
||
|
||
---Generates a random identifier in UUID version-4 layout using `math.random`.
---Returns exactly one value; `string.gsub`'s substitution count is discarded
---so the result is safe to pass straight into calls such as `table.insert`.
---@return string
function M.uuid()
  local template = "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx"
  local id = template:gsub("[xy]", function(c)
    -- "x" is any hex digit; "y" is restricted to 8..b
    local v = (c == "x") and math.random(0, 0xf) or math.random(8, 0xb)
    return string.format("%x", v)
  end)
  return id
end
|
||
|
||
---Serializes the `name` and `input` of a tool use as JSON wrapped in a
---`<tool_use>` element.
---@param tool_use AvanteLLMToolUse
---@return string
function M.tool_use_to_xml(tool_use)
  local payload = vim.json.encode({
    name = tool_use.name,
    input = tool_use.input,
  })
  return ("<tool_use>%s</tool_use>"):format(payload)
end
|
||
|
||
---Tells whether a tool use edits a file: `replace_in_file`, `edit_file`, or
---one of the str_replace editor tools invoked with the `str_replace` command.
---@param tool_use AvanteLLMToolUse
---@return boolean
function M.is_edit_tool_use(tool_use)
  local name = tool_use.name
  if name == "replace_in_file" or name == "edit_file" then return true end
  if name == "str_replace_editor" or name == "str_replace_based_edit_tool" then
    return tool_use.input.command == "str_replace"
  end
  return false
end
|
||
|
||
---Counts the lines in `str`; a trailing newline does not start another line.
---nil and "" count as zero lines.
---@param str string | nil
---@return integer
function M.count_lines(str)
  if not str or str == "" then return 0 end

  local newline_count = select(2, str:gsub("\n", "\n"))
  -- With a trailing newline every line is terminated, so the newline count is the line count.
  if str:sub(-1) == "\n" then return newline_count end
  return newline_count + 1
end
|
||
|
||
---Applies `override` to `value`.
---A function override is called with `value`, and its result is used unless
---it is nil/false, in which case `value` is returned. A table override (or
---none) is merged over `value` with `vim.tbl_extend("force", ...)`.
---@param value table
---@param override? table|fun(value: table): table|nil
function M.tbl_override(value, override)
  if type(override) == "function" then return override(value) or value end
  return vim.tbl_extend("force", value, override or {})
end
|
||
|
||
---Wraps `func` so that only the first call is forwarded; later calls return
---nothing. The guard is set before `func` runs.
---@param func function
---@return function
function M.call_once(func)
  local pending = true
  return function(...)
    if not pending then return end
    pending = false
    return func(...)
  end
end
|
||
|
||
--- Some models (e.g., gpt-4o) cannot correctly return diff content and often miss the SEARCH line, so this needs to be manually fixed in such cases.
---
--- After `diff2search_replace`, the block markers are normalized to
--- "------- SEARCH" / "=======" / "+++++++ REPLACE", a leading SEARCH header is
--- guaranteed, a missing "=======" separator is inserted where a REPLACE marker
--- arrives first, everything from a second "=======" within one block onwards
--- is dropped, and a closing "+++++++ REPLACE" is appended when none was seen.
---@param diff string
---@return string
function M.fix_diff(diff)
  diff = diff2search_replace(diff)

  -- "-" is a lazy quantifier in Lua patterns, so each dash of the marker must
  -- be escaped; unescaped, "------- " also matches any plain "- " in the code.
  local dashes = ("%-"):rep(7)

  -- Normalize block headers to the expected ones (fix for some LLMs output)
  diff = diff:gsub("<<<<<<<%s*SEARCH", "------- SEARCH")
  diff = diff:gsub(">>>>>>>%s*REPLACE", "+++++++ REPLACE")
  diff = diff:gsub(dashes .. "%s*REPLACE", "+++++++ REPLACE")
  -- A header line written as "------- <code>" lacks the SEARCH keyword and the
  -- line break; complete it. Headers that already read "------- SEARCH" are
  -- left untouched. A newline is prepended so the first line can be matched too.
  diff = ("\n" .. diff)
    :gsub("\n" .. dashes .. " ([^\n]*)", function(rest)
      if rest:match("^SEARCH") then return nil end
      return "\n------- SEARCH\n" .. rest
    end)
    :sub(2)
  diff = diff:gsub("======= ", "=======\n")

  local fixed_diff_lines = {}
  local lines = vim.split(diff, "\n")
  local first_line = lines[1]
  if first_line and first_line:match("^%s*```") then
    -- Keep the opening code fence and put the SEARCH header right after it.
    table.insert(fixed_diff_lines, first_line)
    table.insert(fixed_diff_lines, "------- SEARCH")
    fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2)
  else
    table.insert(fixed_diff_lines, "------- SEARCH")
    if first_line:match(dashes .. " SEARCH") then
      -- The header is already present: skip the original one.
      fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 2)
    else
      fixed_diff_lines = vim.list_extend(fixed_diff_lines, lines, 1)
    end
  end

  local the_final_diff_lines = {}
  local has_split_line = false -- a "=======" was seen in the current block
  local replace_block_closed = false -- a "+++++++ REPLACE" closed a block
  local should_delete_following_lines = false
  for _, line in ipairs(fixed_diff_lines) do
    if should_delete_following_lines then goto continue end
    if line:match("^" .. dashes .. "%s*SEARCH") then has_split_line = false end
    if line:match("^=======") then
      if has_split_line then
        -- A second separator inside one block: discard it and all that follows.
        should_delete_following_lines = true
        goto continue
      end
      has_split_line = true
    end
    if line:match("^+++++++%s*REPLACE") then
      if not has_split_line then
        -- The REPLACE marker stands where the separator should be.
        table.insert(the_final_diff_lines, "=======")
        has_split_line = true
        goto continue
      else
        replace_block_closed = true
      end
    end
    table.insert(the_final_diff_lines, line)
    ::continue::
  end
  if not replace_block_closed then table.insert(the_final_diff_lines, "+++++++ REPLACE") end
  return table.concat(the_final_diff_lines, "\n")
end
|
||
|
||
---Returns the unified diff between `text1` and `text2` via `vim.diff`.
---`result_type` is always "unified"; `ctxlen` defaults to 3. The caller's
---`opts` table is copied, not modified.
---@param text1 string
---@param text2 string
---@param opts? table extra options forwarded to `vim.diff`
function M.get_unified_diff(text1, text2, opts)
  -- Shallow copy so the forced fields do not leak into the caller's table.
  local diff_opts = vim.tbl_extend("force", {}, opts or {})
  diff_opts.result_type = "unified"
  diff_opts.ctxlen = diff_opts.ctxlen or 3

  return vim.diff(text1, text2, diff_opts)
end
|
||
|
||
-- Export the module table.
return M
|