From 8c71e1f62456baa006fea7697a19654796b1c248 Mon Sep 17 00:00:00 2001 From: yetone Date: Fri, 30 Aug 2024 15:01:23 +0800 Subject: [PATCH] feat: support apply current code snippet (#391) --- lua/avante/sidebar.lua | 51 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index c2bf51c..2bc9dcd 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -183,15 +183,26 @@ local function realign_line_numbers(code_lines, snippet) return snippet end +---@class AvanteCodeSnippet +---@field range integer[] +---@field content string +---@field lang string +---@field explanation string +---@field start_line_in_response_buf integer +---@field end_line_in_response_buf integer + +---@param code_content string +---@param response_content string +---@return AvanteCodeSnippet[] local function extract_code_snippets(code_content, response_content) local code_lines = vim.split(code_content, "\n") local snippets = {} local current_snippet = {} local in_code_block = false - local lang, start_line, end_line + local lang, start_line, end_line, start_line_in_response_buf local explanation = "" - for _, line in ipairs(vim.split(response_content, "\n")) do + for idx, line in ipairs(vim.split(response_content, "\n")) do local start_line_str, end_line_str = line:match("^Replace lines: (%d+)-(%d+)") if start_line_str ~= nil and end_line_str ~= nil then start_line = tonumber(start_line_str) @@ -205,6 +216,8 @@ local function extract_code_snippets(code_content, response_content) content = table.concat(current_snippet, "\n"), lang = lang, explanation = explanation, + start_line_in_response_buf = start_line_in_response_buf, + end_line_in_response_buf = idx, } snippet = realign_line_numbers(code_lines, snippet) table.insert(snippets, snippet) @@ -219,6 +232,7 @@ local function extract_code_snippets(code_content, response_content) lang = "text" end in_code_block = true + start_line_in_response_buf = idx end elseif in_code_block then table.insert(current_snippet, line) @@ -318,10 +332,26 @@ local function parse_codeblocks(buf) return codeblocks end -function Sidebar:apply() +---@param current_cursor boolean +function Sidebar:apply(current_cursor) local content = table.concat(Utils.get_buf_lines(0, -1, self.code.bufnr), "\n") - local response = self:get_content_between_separators() + local response, response_start_line = self:get_content_between_separators() local snippets = extract_code_snippets(content, response) + if current_cursor then + if self.result and self.result.winid then + local cursor_line = Utils.get_cursor_pos(self.result.winid) + for _, snippet in ipairs(snippets) do + if + cursor_line >= snippet.start_line_in_response_buf + response_start_line + and cursor_line <= snippet.end_line_in_response_buf + response_start_line + then + snippets = { snippet } + break + end + end + end + end + local conflict_content = get_conflict_content(content, snippets) vim.defer_fn(function() @@ -501,7 +531,7 @@ function Sidebar:on_mount() current_apply_extmark_id = api.nvim_buf_set_extmark(self.result.bufnr, CODEBLOCK_KEYBINDING_NAMESPACE, block.start_line, -1, { - virt_text = { { " [: apply patch] ", "Keyword" } }, + virt_text = { { " [: apply this, : apply all] ", "Keyword" } }, virt_text_pos = "right_align", hl_group = "Keyword", priority = PRIORITY, @@ -509,8 +539,11 @@ function Sidebar:on_mount() end local function bind_apply_key() + vim.keymap.set("n", "a", function() + self:apply(true) + end, { buffer = self.result.bufnr, noremap = true, silent = true }) vim.keymap.set("n", "A", function() - self:apply() + self:apply(false) end, { buffer = self.result.bufnr, noremap = true, silent = true }) end @@ -878,7 +911,7 @@ function Sidebar:update_content_with_history(history) self:update_content(content) end ----@return string +---@return string, integer function Sidebar:get_content_between_separators() local separator = "---" local cursor_line, _ = Utils.get_cursor_pos() @@ -910,7 +943,7 @@ function Sidebar:get_content_between_separators() end local content = table.concat(vim.list_slice(lines, start_line, end_line), "\n") - return content + return content, start_line end ---@alias AvanteSlashCommands "clear" | "help" | "lines" @@ -1150,7 +1183,7 @@ function Sidebar:create_input() ) if Config.behaviour.auto_apply_diff_after_generation then - self:apply() + self:apply(false) end end