diff --git a/lua/avante/llm_tools/bash.lua b/lua/avante/llm_tools/bash.lua index 7e75f03..c64a172 100644 --- a/lua/avante/llm_tools/bash.lua +++ b/lua/avante/llm_tools/bash.lua @@ -219,7 +219,7 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ rel_path: string, command: string }> -function M.func(opts, on_log, on_complete) +function M.func(opts, on_log, on_complete, session_ctx) local abs_path = Helpers.get_abs_path(opts.rel_path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end @@ -248,7 +248,9 @@ function M.func(opts, on_log, on_complete) local result, err = handle_result(output, exit_code) on_complete(result, err) end, abs_path) - end + end, + { focus = true }, + session_ctx ) end diff --git a/lua/avante/llm_tools/create.lua b/lua/avante/llm_tools/create.lua index 8e4e0e5..81300d0 100644 --- a/lua/avante/llm_tools/create.lua +++ b/lua/avante/llm_tools/create.lua @@ -45,7 +45,7 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ path: string, file_text: string }> -function M.func(opts, on_log, on_complete) +function M.func(opts, on_log, on_complete, session_ctx) if not on_complete then return false, "on_complete not provided" end if on_log then on_log("path: " .. opts.path) end if Helpers.already_in_context(opts.path) then @@ -60,11 +60,11 @@ function M.func(opts, on_log, on_complete) local bufnr, err = Helpers.get_bufnr(abs_path) if err then return false, err end vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, lines) - Helpers.confirm("Are you sure you want to create this file?", function(ok) + Helpers.confirm("Are you sure you want to create this file?", function(ok, reason) if not ok then -- close the buffer vim.api.nvim_buf_delete(bufnr, { force = true }) - on_complete(false, "User canceled") + on_complete(false, "User declined, reason: " .. (reason or "unknown")) return end -- save the file @@ -75,7 +75,7 @@ function M.func(opts, on_log, on_complete) vim.cmd("write") vim.api.nvim_set_current_win(current_winid) on_complete(true, nil) - end) + end, { focus = true }, session_ctx) end return M diff --git a/lua/avante/llm_tools/helpers.lua b/lua/avante/llm_tools/helpers.lua index 788e4eb..3fe523b 100644 --- a/lua/avante/llm_tools/helpers.lua +++ b/lua/avante/llm_tools/helpers.lua @@ -21,10 +21,15 @@ function M.get_abs_path(rel_path) end ---@param message string ----@param callback fun(yes: boolean) ----@param opts? { focus?: boolean } +---@param callback fun(yes: boolean, reason?: string) +---@param confirm_opts? { focus?: boolean } +---@param session_ctx? table ---@return avante.ui.Confirm | nil -function M.confirm(message, callback, opts) +function M.confirm(message, callback, confirm_opts, session_ctx) + if session_ctx and session_ctx.always_yes then + callback(true) + return + end local Confirm = require("avante.ui.confirm") local sidebar = require("avante").get() if not sidebar or not sidebar.input_container or not sidebar.input_container.winid then @@ -32,8 +37,22 @@ function M.confirm(message, callback, opts) callback(false) return end - local confirm_opts = vim.tbl_deep_extend("force", { container_winid = sidebar.input_container.winid }, opts or {}) - M.confirm_popup = Confirm:new(message, callback, confirm_opts) + confirm_opts = vim.tbl_deep_extend("force", { container_winid = sidebar.input_container.winid }, confirm_opts or {}) + M.confirm_popup = Confirm:new(message, function(type, reason) + if type == "yes" then + callback(true) + return + end + if type == "all" then + if session_ctx then session_ctx.always_yes = true end + callback(true) + return + end + if type == "no" then + callback(false, reason) + return + end + end, confirm_opts) M.confirm_popup:open() return M.confirm_popup end diff --git a/lua/avante/llm_tools/init.lua b/lua/avante/llm_tools/init.lua index 927b57f..3f809b2 100644 --- a/lua/avante/llm_tools/init.lua +++ b/lua/avante/llm_tools/init.lua @@ -42,11 +42,17 @@ function M.str_replace_editor(opts, on_log, on_complete, session_ctx) return view(opts_, on_log, on_complete, session_ctx) end if opts.command == "str_replace" then - return require("avante.llm_tools.str_replace").func(opts, on_log, on_complete) + return require("avante.llm_tools.str_replace").func(opts, on_log, on_complete, session_ctx) + end + if opts.command == "create" then + return require("avante.llm_tools.create").func(opts, on_log, on_complete, session_ctx) + end + if opts.command == "insert" then + return require("avante.llm_tools.insert").func(opts, on_log, on_complete, session_ctx) + end + if opts.command == "undo_edit" then + return require("avante.llm_tools.undo_edit").func(opts, on_log, on_complete, session_ctx) end - if opts.command == "create" then return require("avante.llm_tools.create").func(opts, on_log, on_complete) end - if opts.command == "insert" then return require("avante.llm_tools.insert").func(opts, on_log, on_complete) end - if opts.command == "undo_edit" then return require("avante.llm_tools.undo_edit").func(opts, on_log, on_complete) end return false, "Unknown command: " .. opts.command end diff --git a/lua/avante/llm_tools/insert.lua b/lua/avante/llm_tools/insert.lua index 45af507..32abcc8 100644 --- a/lua/avante/llm_tools/insert.lua +++ b/lua/avante/llm_tools/insert.lua @@ -50,7 +50,7 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ path: string, insert_line: integer, new_str: string }> -function M.func(opts, on_log, on_complete) +function M.func(opts, on_log, on_complete, session_ctx) if on_log then on_log("path: " .. opts.path) end local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -77,15 +77,15 @@ function M.func(opts, on_log, on_complete) hl_eol = true, hl_mode = "combine", }) - Helpers.confirm("Are you sure you want to insert these lines?", function(ok) + Helpers.confirm("Are you sure you want to insert these lines?", function(ok, reason) clear_highlights() if not ok then - on_complete(false, "User canceled") + on_complete(false, "User declined, reason: " .. (reason or "unknown")) return end vim.api.nvim_buf_set_lines(bufnr, opts.insert_line, opts.insert_line, false, new_lines) on_complete(true, nil) - end) + end, { focus = true }, session_ctx) end return M diff --git a/lua/avante/llm_tools/str_replace.lua b/lua/avante/llm_tools/str_replace.lua index f029b71..3ce4fd9 100644 --- a/lua/avante/llm_tools/str_replace.lua +++ b/lua/avante/llm_tools/str_replace.lua @@ -52,7 +52,7 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ path: string, old_str: string, new_str: string }> -function M.func(opts, on_log, on_complete) +function M.func(opts, on_log, on_complete, session_ctx) if on_log then on_log("path: " .. opts.path) end local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -135,21 +135,8 @@ function M.func(opts, on_log, on_complete) vim.cmd("normal! zz") vim.api.nvim_set_current_win(current_winid) local augroup = vim.api.nvim_create_augroup("avante_str_replace_editor", { clear = true }) - local confirm = Helpers.confirm("Are you sure you want to apply this modification?", function(ok) - pcall(vim.api.nvim_del_augroup_by_id, augroup) - vim.api.nvim_set_current_win(sidebar.code.winid) - vim.cmd("noautocmd stopinsert") - vim.cmd("noautocmd undo") - if not ok then - vim.api.nvim_set_current_win(current_winid) - on_complete(false, "User canceled") - return - end - vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, new_lines) - vim.api.nvim_set_current_win(current_winid) - on_complete(true, nil) - end, { focus = false }) vim.api.nvim_set_current_win(sidebar.code.winid) + local confirm vim.api.nvim_create_autocmd({ "TextChangedI", "TextChanged" }, { group = augroup, buffer = bufnr, @@ -167,6 +154,20 @@ function M.func(opts, on_log, on_complete) on_complete(true, nil) end, }) + confirm = Helpers.confirm("Are you sure you want to apply this modification?", function(ok, reason) + pcall(vim.api.nvim_del_augroup_by_id, augroup) + vim.api.nvim_set_current_win(sidebar.code.winid) + vim.cmd("noautocmd stopinsert") + vim.cmd("noautocmd undo") + if not ok then + vim.api.nvim_set_current_win(current_winid) + on_complete(false, "User declined, reason: " .. (reason or "unknown")) + return + end + vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, new_lines) + vim.api.nvim_set_current_win(current_winid) + on_complete(true, nil) + end, { focus = false }, session_ctx) end return M diff --git a/lua/avante/llm_tools/undo_edit.lua b/lua/avante/llm_tools/undo_edit.lua index 1502ebb..aaa20e0 100644 --- a/lua/avante/llm_tools/undo_edit.lua +++ b/lua/avante/llm_tools/undo_edit.lua @@ -40,7 +40,7 @@ M.returns = { } ---@type AvanteLLMToolFunc<{ path: string }> -function M.func(opts, on_log, on_complete) +function M.func(opts, on_log, on_complete, session_ctx) if on_log then on_log("path: " .. opts.path) end local abs_path = Helpers.get_abs_path(opts.path) if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end @@ -52,9 +52,9 @@ function M.func(opts, on_log, on_complete) local winid = Utils.get_winid(bufnr) vim.api.nvim_set_current_win(winid) vim.api.nvim_set_current_win(current_winid) - Helpers.confirm("Are you sure you want to undo edit this file?", function(ok) + Helpers.confirm("Are you sure you want to undo edit this file?", function(ok, reason) if not ok then - on_complete(false, "User canceled") + on_complete(false, "User declined, reason: " .. (reason or "unknown")) return end vim.api.nvim_set_current_win(winid) @@ -62,7 +62,7 @@ function M.func(opts, on_log, on_complete) vim.cmd("undo") vim.api.nvim_set_current_win(current_winid) on_complete(true, nil) - end) + end, { focus = true }, session_ctx) end return M diff --git a/lua/avante/selection.lua b/lua/avante/selection.lua index e02c813..c901343 100644 --- a/lua/avante/selection.lua +++ b/lua/avante/selection.lua @@ -251,7 +251,7 @@ function Selection:create_editing_input() cancel_callback = function() self:close_editing_input() end, win_opts = { border = Config.windows.edit.border, - title = { { "Edit selected block", "FloatTitle" } }, + title = { { "Avante edit selected block", "FloatTitle" } }, }, start_insert = Config.windows.edit.start_insert, }) diff --git a/lua/avante/ui/confirm.lua b/lua/avante/ui/confirm.lua index 4d256c7..11b06f4 100644 --- a/lua/avante/ui/confirm.lua +++ b/lua/avante/ui/confirm.lua @@ -3,11 +3,13 @@ local NuiText = require("nui.text") local Highlights = require("avante.highlights") local Utils = require("avante.utils") local Line = require("avante.ui.line") +local PromptInput = require("avante.ui.prompt_input") +local Config = require("avante.config") ---@class avante.ui.Confirm ---@field message string ----@field callback fun(yes: boolean) ----@field _container_winid number | nil +---@field callback fun(type: "yes" | "all" | "no", reason?: string) +---@field _container_winid number ---@field _focus boolean | nil ---@field _group number | nil ---@field _popup NuiPopup | nil @@ -16,7 +18,7 @@ local M = {} M.__index = M ---@param message string ----@param callback fun(yes: boolean) +---@param callback fun(type: "yes" | "all" | "no", reason?: string) ---@param opts { container_winid: number, focus?: boolean } ---@return avante.ui.Confirm function M:new(message, callback, opts) @@ -35,7 +37,7 @@ function M:open() local win_width = 60 - local focus_index = 2 -- 1 = Yes, 2 = No + local focus_index = 3 -- 1 = Yes, 2 = All Yes, 3 = No local BUTTON_NORMAL = Highlights.BUTTON_DEFAULT local BUTTON_FOCUS = Highlights.BUTTON_DEFAULT_HOVER @@ -57,11 +59,19 @@ function M:open() { " - input ", commentfg }, { " " }, }) - local buttons_content = " Yes No " + local buttons_line = Line:new({ + { " [Y]es ", function() return focus_index == 1 and BUTTON_FOCUS or BUTTON_NORMAL end }, + { " " }, + { " [A]ll yes ", function() return focus_index == 2 and BUTTON_FOCUS or BUTTON_NORMAL end }, + { " " }, + { " [N]o ", function() return focus_index == 3 and BUTTON_FOCUS or BUTTON_NORMAL end }, + }) + local buttons_content = tostring(buttons_line) local buttons_start_col = math.floor((win_width - #buttons_content) / 2) - local yes_button_pos = { buttons_start_col, buttons_start_col + 5 } - local no_button_pos = { buttons_start_col + 10, buttons_start_col + 14 } - local buttons_line = string.rep(" ", buttons_start_col) .. buttons_content + local yes_button_pos = buttons_line:get_section_pos(1, buttons_start_col) + local all_button_pos = buttons_line:get_section_pos(3, buttons_start_col) + local no_button_pos = buttons_line:get_section_pos(5, buttons_start_col) + local buttons_line_content = string.rep(" ", buttons_start_col) .. buttons_content local keybindings_line_num = 5 + #vim.split(message, "\n") local buttons_line_num = 2 + #vim.split(message, "\n") local content = vim @@ -69,7 +79,7 @@ function M:open() "", vim.tbl_map(function(line) return " " .. line end, vim.split(message, "\n")), "", - buttons_line, + buttons_line_content, "", "", tostring(keybindings_line), @@ -85,7 +95,7 @@ function M:open() local button_row = buttons_line_num + 1 - local container_winid = self._container_winid or vim.api.nvim_get_current_win() + local container_winid = self._container_winid local container_width = vim.api.nvim_win_get_width(container_winid) local popup = Popup({ @@ -119,32 +129,49 @@ function M:open() }, }) - local function focus_button(row) - row = row or button_row + local function focus_button() if focus_index == 1 then - vim.api.nvim_win_set_cursor(popup.winid, { row, yes_button_pos[1] }) + vim.api.nvim_win_set_cursor(popup.winid, { button_row, yes_button_pos[1] }) + elseif focus_index == 2 then + vim.api.nvim_win_set_cursor(popup.winid, { button_row, all_button_pos[1] }) else - vim.api.nvim_win_set_cursor(popup.winid, { row, no_button_pos[1] }) + vim.api.nvim_win_set_cursor(popup.winid, { button_row, no_button_pos[1] }) end end local function render_content() - local yes_style = (focus_index == 1) and BUTTON_FOCUS or BUTTON_NORMAL - local no_style = (focus_index == 2) and BUTTON_FOCUS or BUTTON_NORMAL - Utils.unlock_buf(popup.bufnr) vim.api.nvim_buf_set_lines(popup.bufnr, 0, -1, false, content) Utils.lock_buf(popup.bufnr) + buttons_line:set_highlights(0, popup.bufnr, buttons_line_num, buttons_start_col) keybindings_line:set_highlights(0, popup.bufnr, keybindings_line_num) - vim.api.nvim_buf_add_highlight(popup.bufnr, 0, yes_style, buttons_line_num, yes_button_pos[1], yes_button_pos[2]) - vim.api.nvim_buf_add_highlight(popup.bufnr, 0, no_style, buttons_line_num, no_button_pos[1], no_button_pos[2]) - focus_button(buttons_line_num + 1) + focus_button() end - local function select_button() + local function click_button() self:close() - callback(focus_index == 1) + if focus_index == 1 then + callback("yes") + return + end + if focus_index == 2 then + Utils.notify("Accept all") + callback("all") + return + end + local prompt_input = PromptInput:new({ + submit_callback = function(input) callback("no", input ~= "" and input or nil) end, + close_on_submit = true, + win_opts = { + relative = "win", + win = self._container_winid, + border = Config.windows.ask.border, + title = { { "Reject reason", "FloatTitle" } }, + }, + start_insert = Config.windows.ask.start_insert, + }) + prompt_input:open() end vim.keymap.set("n", "c", function() @@ -174,22 +201,48 @@ function M:open() vim.keymap.set("n", "y", function() focus_index = 1 render_content() - select_button() + click_button() + end, { buffer = popup.bufnr }) + + vim.keymap.set("n", "Y", function() + focus_index = 1 + render_content() + click_button() + end, { buffer = popup.bufnr }) + + vim.keymap.set("n", "a", function() + focus_index = 2 + render_content() + click_button() + end, { buffer = popup.bufnr }) + + vim.keymap.set("n", "A", function() + focus_index = 2 + render_content() + click_button() end, { buffer = popup.bufnr }) vim.keymap.set("n", "n", function() - focus_index = 2 + focus_index = 3 render_content() - select_button() + click_button() + end, { buffer = popup.bufnr }) + + vim.keymap.set("n", "N", function() + focus_index = 3 + render_content() + click_button() end, { buffer = popup.bufnr }) vim.keymap.set("n", "", function() - focus_index = 1 + focus_index = focus_index - 1 + if focus_index < 1 then focus_index = 3 end focus_button() end, { buffer = popup.bufnr }) vim.keymap.set("n", "", function() - focus_index = 2 + focus_index = focus_index + 1 + if focus_index > 3 then focus_index = 1 end focus_button() end, { buffer = popup.bufnr }) @@ -204,16 +257,18 @@ function M:open() end, { buffer = popup.bufnr }) vim.keymap.set("n", "", function() - focus_index = (focus_index == 1) and 2 or 1 + focus_index = focus_index + 1 + if focus_index > 3 then focus_index = 1 end focus_button() end, { buffer = popup.bufnr }) vim.keymap.set("n", "", function() - focus_index = (focus_index == 1) and 2 or 1 + focus_index = focus_index - 1 + if focus_index < 1 then focus_index = 3 end focus_button() end, { buffer = popup.bufnr }) - vim.keymap.set("n", "", function() select_button() end, { buffer = popup.bufnr }) + vim.keymap.set("n", "", function() click_button() end, { buffer = popup.bufnr }) vim.api.nvim_buf_set_keymap(popup.bufnr, "n", "", "", { callback = function() @@ -222,13 +277,13 @@ function M:open() if row == button_row then if col >= yes_button_pos[1] and col <= yes_button_pos[2] then focus_index = 1 - render_content() - select_button() - elseif col >= no_button_pos[1] and col <= no_button_pos[2] then + elseif col >= all_button_pos[1] and col <= all_button_pos[2] then focus_index = 2 - render_content() - select_button() + elseif col >= no_button_pos[1] and col <= no_button_pos[2] then + focus_index = 3 end + render_content() + click_button() end end, noremap = true, @@ -243,9 +298,12 @@ function M:open() if col >= yes_button_pos[1] and col <= yes_button_pos[2] then focus_index = 1 render_content() - elseif col >= no_button_pos[1] and col <= no_button_pos[2] then + elseif col >= all_button_pos[1] and col <= all_button_pos[2] then focus_index = 2 render_content() + elseif col >= no_button_pos[1] and col <= no_button_pos[2] then + focus_index = 3 + render_content() end end, }) @@ -286,7 +344,7 @@ end function M:unbind_window_focus_keymaps() pcall(vim.keymap.del, { "n", "i" }, "f") end function M:cancel() - self.callback(false) + self.callback("no", "cancel") return self:close() end diff --git a/lua/avante/ui/line.lua b/lua/avante/ui/line.lua index d049552..023f8fc 100644 --- a/lua/avante/ui/line.lua +++ b/lua/avante/ui/line.lua @@ -1,4 +1,4 @@ ----@alias avante.ui.LineSection string[] +---@alias avante.ui.LineSection table --- ---@class avante.ui.Line ---@field sections avante.ui.LineSection[] @@ -15,17 +15,34 @@ end ---@param ns_id number ---@param bufnr number ---@param line number -function M:set_highlights(ns_id, bufnr, line) +---@param offset number | nil +function M:set_highlights(ns_id, bufnr, line, offset) if not vim.api.nvim_buf_is_valid(bufnr) then return end - local col_start = 0 + local col_start = offset or 0 for _, section in ipairs(self.sections) do local text = section[1] local highlight = section[2] + if type(highlight) == "function" then highlight = highlight() end if highlight then vim.api.nvim_buf_add_highlight(bufnr, ns_id, highlight, line, col_start, col_start + #text) end col_start = col_start + #text end end +---@param section_index number +---@param offset number | nil +---@return number[] +function M:get_section_pos(section_index, offset) + offset = offset or 0 + local col_start = 0 + for i = 1, section_index - 1 do + if i == section_index then break end + local section = self.sections[i] + col_start = col_start + #section + end + + return { offset + col_start, offset + col_start + #self.sections[section_index] } +end + function M:__tostring() local content = {} for _, section in ipairs(self.sections) do