feat(api): enable customizable calls functions (#457)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham
2024-09-02 12:22:48 -04:00
committed by GitHub
parent d520f09333
commit 7266661413
9 changed files with 291 additions and 267 deletions

View File

@@ -1,15 +1,12 @@
-- This file COPY and MODIFIED based on: https://github.com/akinsho/git-conflict.nvim/blob/main/lua/git-conflict.lua
local M = {}
local api = vim.api
local Config = require("avante.config")
local Utils = require("avante.utils")
local Highlights = require("avante.highlights")
local fn = vim.fn
local api = vim.api
local fmt = string.format
local map = vim.keymap.set
local H = {}
local M = {}
-----------------------------------------------------------------------------//
-- REFERENCES:
-----------------------------------------------------------------------------//
@@ -31,7 +28,6 @@ local map = vim.keymap.set
--- @class AvanteConflictHighlights
--- @field current string
--- @field incoming string
--- @field ancestor string?
---@class RangeMark
---@field label integer
@@ -40,7 +36,6 @@ local map = vim.keymap.set
--- @class PositionMarks
--- @field current RangeMark
--- @field incoming RangeMark
--- @field ancestor RangeMark
--- @class Range
--- @field range_start integer
@@ -69,7 +64,6 @@ local SIDES = {
THEIRS = "theirs",
ALL_THEIRS = "all_theirs",
BOTH = "both",
BASE = "base",
NONE = "none",
CURSOR = "cursor",
}
@@ -78,7 +72,6 @@ local SIDES = {
local name_map = {
ours = "current",
theirs = "incoming",
base = "ancestor",
both = "both",
none = "none",
cursor = "cursor",
@@ -86,10 +79,8 @@ local name_map = {
local CURRENT_HL = "AvanteConflictCurrent"
local INCOMING_HL = "AvanteConflictIncoming"
local ANCESTOR_HL = "AvanteConflictAncestor"
local CURRENT_LABEL_HL = "AvanteConflictCurrentLabel"
local INCOMING_LABEL_HL = "AvanteConflictIncomingLabel"
local ANCESTOR_LABEL_HL = "AvanteConflictAncestorLabel"
local PRIORITY = vim.highlight.priorities.user
local NAMESPACE = api.nvim_create_namespace("avante-conflict")
local KEYBINDING_NAMESPACE = api.nvim_create_namespace("avante-conflict-keybinding")
@@ -98,7 +89,6 @@ local AUGROUP_NAME = "avante_conflicts"
local conflict_start = "^<<<<<<<"
local conflict_middle = "^======="
local conflict_end = "^>>>>>>>"
local conflict_ancestor = "^|||||||"
-----------------------------------------------------------------------------//
@@ -208,16 +198,7 @@ local function highlight_conflicts(positions, lines)
position.marks = {
current = { label = curr_label_id, content = curr_id },
incoming = { label = inc_label_id, content = inc_id },
ancestor = {},
}
if not vim.tbl_isempty(position.ancestor) then
local ancestor_start = position.ancestor.range_start
local ancestor_end = position.ancestor.range_end
local ancestor_label = lines[ancestor_start + 1] .. " (Base changes)"
local id = hl_range(bufnr, ANCESTOR_HL, ancestor_start + 1, ancestor_end + 1)
local label_id = draw_section_label(bufnr, ANCESTOR_LABEL_HL, ancestor_label, ancestor_start)
position.marks.ancestor = { label = label_id, content = id }
end
end
end
@@ -228,7 +209,7 @@ end
---@return ConflictPosition[]
local function detect_conflicts(lines)
local positions = {}
local position, has_middle, has_ancestor = nil, false, false
local position, has_middle = nil, false
for index, line in ipairs(lines) do
local lnum = index - 1
if line:match(conflict_start) then
@@ -236,25 +217,12 @@ local function detect_conflicts(lines)
current = { range_start = lnum, content_start = lnum + 1 },
middle = {},
incoming = {},
ancestor = {},
}
end
if position ~= nil and line:match(conflict_ancestor) then
has_ancestor = true
position.ancestor.range_start = lnum
position.ancestor.content_start = lnum + 1
position.current.range_end = lnum - 1
position.current.content_end = lnum - 1
end
if position ~= nil and line:match(conflict_middle) then
has_middle = true
if has_ancestor then
position.ancestor.content_end = lnum - 1
position.ancestor.range_end = lnum - 1
else
position.current.range_end = lnum - 1
position.current.content_end = lnum - 1
end
position.current.range_end = lnum - 1
position.current.content_end = lnum - 1
position.middle.range_start = lnum
position.middle.range_end = lnum + 1
position.incoming.range_start = lnum + 1
@@ -265,7 +233,7 @@ local function detect_conflicts(lines)
position.incoming.content_end = lnum - 1
positions[#positions + 1] = position
position, has_middle, has_ancestor = nil, false, false
position, has_middle = nil, false
end
end
return #positions > 0, positions
@@ -376,7 +344,7 @@ local function parse_buffer(bufnr, range_start, range_end)
else
M.clear(bufnr)
end
if prev_conflicts ~= has_conflict or not vim.b[bufnr].conflict_mappings_set then
if prev_conflicts ~= has_conflict or not vim.b[bufnr].avante_conflict_mappings_set then
local pattern = has_conflict and "AvanteConflictDetected" or "AvanteConflictResolved"
api.nvim_exec_autocmds("User", { pattern = pattern })
end
@@ -384,6 +352,8 @@ end
---Process a buffer if the changed tick has changed
---@param bufnr integer?
---@param range_start integer?
---@param range_end integer?
function M.process(bufnr, range_start, range_end)
bufnr = bufnr or api.nvim_get_current_buf()
if visited_buffers[bufnr] and visited_buffers[bufnr].tick == vim.b[bufnr].changedtick then
@@ -392,129 +362,62 @@ function M.process(bufnr, range_start, range_end)
parse_buffer(bufnr, range_start, range_end)
end
-----------------------------------------------------------------------------//
-- Commands
-----------------------------------------------------------------------------//
local function set_commands()
local command = api.nvim_create_user_command
command("AvanteConflictListQf", function()
M.conflicts_to_qf_items(function(items)
if #items > 0 then
fn.setqflist(items, "r")
if type(Config.diff.list_opener) == "function" then
Config.diff.list_opener()
else
vim.cmd(Config.diff.list_opener)
end
end
end)
end, { nargs = 0 })
command("AvanteConflictChooseOurs", function()
M.choose("ours")
end, { nargs = 0 })
command("AvanteConflictChooseTheirs", function()
M.choose("theirs")
end, { nargs = 0 })
command("AvanteConflictChooseAllTheirs", function()
M.choose("all_theirs")
end, { nargs = 0 })
command("AvanteConflictChooseBoth", function()
M.choose("both")
end, { nargs = 0 })
command("AvanteConflictChooseCursor", function()
M.choose("cursor")
end, { nargs = 0 })
command("AvanteConflictChooseBase", function()
M.choose("base")
end, { nargs = 0 })
command("AvanteConflictChooseNone", function()
M.choose("none")
end, { nargs = 0 })
command("AvanteConflictNextConflict", function()
M.find_next("ours")
end, { nargs = 0 })
command("AvanteConflictPrevConflict", function()
M.find_prev("ours")
end, { nargs = 0 })
end
-----------------------------------------------------------------------------//
-- Mappings
-----------------------------------------------------------------------------//
local function set_plug_mappings()
local function opts(desc)
return { silent = true, desc = "Git Conflict: " .. desc }
end
map({ "n", "v" }, "<Plug>(git-conflict-ours)", "<Cmd>AvanteConflictChooseOurs<CR>", opts("Choose Ours"))
map({ "n", "v" }, "<Plug>(git-conflict-both)", "<Cmd>AvanteConflictChooseBoth<CR>", opts("Choose Both"))
map({ "n", "v" }, "<Plug>(git-conflict-none)", "<Cmd>AvanteConflictChooseNone<CR>", opts("Choose None"))
map({ "n", "v" }, "<Plug>(git-conflict-theirs)", "<Cmd>AvanteConflictChooseTheirs<CR>", opts("Choose Theirs"))
map(
{ "n", "v" },
"<Plug>(git-conflict-all-theirs)",
"<Cmd>AvanteConflictChooseAllTheirs<CR>",
opts("Choose All Theirs")
)
map("n", "<Plug>(git-conflict-cursor)", "<Cmd>AvanteConflictChooseCursor<CR>", opts("Choose Cursor"))
map("n", "<Plug>(git-conflict-next-conflict)", "<Cmd>AvanteConflictNextConflict<CR>", opts("Next Conflict"))
map("n", "<Plug>(git-conflict-prev-conflict)", "<Cmd>AvanteConflictPrevConflict<CR>", opts("Previous Conflict"))
end
---@param bufnr integer given buffer id
local function setup_buffer_mappings(bufnr)
H.setup_buffer_mappings = function(bufnr)
---@param desc string
local function opts(desc)
return { silent = true, buffer = bufnr, desc = "Git Conflict: " .. desc }
return { silent = true, buffer = bufnr, desc = "avante(conflict): " .. desc }
end
map({ "n", "v" }, Config.diff.mappings.ours, "<Plug>(git-conflict-ours)", opts("Choose Ours"))
map({ "n", "v" }, Config.diff.mappings.both, "<Plug>(git-conflict-both)", opts("Choose Both"))
map({ "n", "v" }, Config.diff.mappings.none, "<Plug>(git-conflict-none)", opts("Choose None"))
map({ "n", "v" }, Config.diff.mappings.theirs, "<Plug>(git-conflict-theirs)", opts("Choose Theirs"))
map({ "n", "v" }, Config.diff.mappings.all_theirs, "<Plug>(git-conflict-all-theirs)", opts("Choose All Theirs"))
map({ "v", "v" }, Config.diff.mappings.ours, "<Plug>(git-conflict-ours)", opts("Choose Ours"))
map("n", Config.diff.mappings.cursor, "<Plug>(git-conflict-cursor)", opts("Choose Cursor"))
-- map('V', Config.diff.mappings.ours, '<Plug>(git-conflict-ours)', opts('Choose Ours'))
map("n", Config.diff.mappings.prev, "<Plug>(git-conflict-prev-conflict)", opts("Previous Conflict"))
map("n", Config.diff.mappings.next, "<Plug>(git-conflict-next-conflict)", opts("Next Conflict"))
vim.b[bufnr].conflict_mappings_set = true
vim.keymap.set({ "n", "v" }, Config.diff.mappings.ours, function()
M.choose("ours")
end, opts("choose ours"))
vim.keymap.set({ "n", "v" }, Config.diff.mappings.both, function()
M.choose("both")
end, opts("choose both"))
vim.keymap.set({ "n", "v" }, Config.diff.mappings.theirs, function()
M.choose("theirs")
end, opts("choose theirs"))
vim.keymap.set({ "n", "v" }, Config.diff.mappings.all_theirs, function()
M.choose("all_theirs")
end, opts("choose all theirs"))
vim.keymap.set("n", Config.diff.mappings.cursor, function()
M.choose("cursor")
end, opts("choose under cursor"))
vim.keymap.set("n", Config.diff.mappings.prev, function()
M.find_prev("ours")
end, opts("previous conflict"))
vim.keymap.set("n", Config.diff.mappings.next, function()
M.find_next("ours")
end, opts("next conflict"))
vim.b[bufnr].avante_conflict_mappings_set = true
end
---@param key string
---@param mode "'n'|'v'|'o'|'nv'|'nvo'"?
---@return boolean
local function is_mapped(key, mode)
return fn.hasmapto(key, mode or "n") > 0
end
local function clear_buffer_mappings(bufnr)
if not bufnr or not vim.b[bufnr].conflict_mappings_set then
---@param bufnr integer
H.clear_buffer_mappings = function(bufnr)
if not bufnr or not vim.b[bufnr].avante_conflict_mappings_set then
return
end
for _, mapping in pairs(Config.diff.mappings) do
if is_mapped(mapping) then
if vim.fn.hasmapto(mapping, "n") > 0 then
api.nvim_buf_del_keymap(bufnr, "n", mapping)
end
end
vim.b[bufnr].conflict_mappings_set = false
vim.b[bufnr].avante_conflict_mappings_set = false
end
M.augroup = api.nvim_create_augroup(AUGROUP_NAME, { clear = true })
function M.setup()
Highlights.conflict_highlights()
set_commands()
set_plug_mappings()
local augroup = api.nvim_create_augroup(AUGROUP_NAME, { clear = true })
local previous_inlay_enabled = nil
api.nvim_create_autocmd("User", {
group = augroup,
group = M.augroup,
pattern = "AvanteConflictDetected",
callback = function(ev)
vim.diagnostic.enable(false, { bufnr = ev.buf })
@@ -522,12 +425,12 @@ function M.setup()
previous_inlay_enabled = vim.lsp.inlay_hint.is_enabled({ bufnr = ev.buf })
vim.lsp.inlay_hint.enable(false, { bufnr = ev.buf })
end
setup_buffer_mappings(ev.buf)
H.setup_buffer_mappings(ev.buf)
end,
})
api.nvim_create_autocmd("User", {
group = AUGROUP_NAME,
group = M.augroup,
pattern = "AvanteConflictResolved",
callback = function(ev)
vim.diagnostic.enable(true, { bufnr = ev.buf })
@@ -535,7 +438,7 @@ function M.setup()
vim.lsp.inlay_hint.enable(previous_inlay_enabled, { bufnr = ev.buf })
previous_inlay_enabled = nil
end
clear_buffer_mappings(ev.buf)
H.clear_buffer_mappings(ev.buf)
end,
})
@@ -565,7 +468,7 @@ local function quickfix_items_from_positions(item, items, visited_buf)
if vim.tbl_contains({ name_map.ours, name_map.theirs, name_map.base }, key) and not vim.tbl_isempty(value) then
local lnum = value.range_start + 1
local next_item = vim.deepcopy(item)
next_item.text = fmt("%s change", key, lnum)
next_item.text = string.format("%s change", key, lnum)
next_item.lnum = lnum
next_item.col = 0
table.insert(items, next_item)
@@ -711,9 +614,6 @@ function M.process_position(bufnr, side, position, enable_autojump)
api.nvim_buf_set_lines(0, pos_start, pos_end, false, lines)
api.nvim_buf_del_extmark(0, NAMESPACE, position.marks.incoming.label)
api.nvim_buf_del_extmark(0, NAMESPACE, position.marks.current.label)
if position.marks.ancestor.label then
api.nvim_buf_del_extmark(0, NAMESPACE, position.marks.ancestor.label)
end
parse_buffer(bufnr)
if enable_autojump and Config.diff.autojump then
M.find_next(side)