diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 1d0cd4e..7447c32 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -867,7 +867,7 @@ end function Sidebar:retry_user_request() local block = self:get_current_user_request_block() if not block then return end - self.handle_submit(block.content) + self:handle_submit(block.content) end function Sidebar:handle_expand_message(message_uuid, expanded) @@ -2623,6 +2623,245 @@ function Sidebar:get_generate_prompts_options(request, cb) if cb then cb(prompts_opts) end end +---@param request string +function Sidebar:handle_submit(request) + if Config.prompt_logger.enabled then PromptLogger.log_prompt(request) end + + if self.is_generating then + self:add_history_messages({ History.Message:new("user", request) }) + return + end + + if request:match("@codebase") and not vim.fn.expand("%:e") then + self:update_content("Please open a file first before using @codebase", { focus = false, scroll = false }) + return + end + + if request:sub(1, 1) == "/" then + local command, args = request:match("^/(%S+)%s*(.*)") + if command == nil then + self:update_content("Invalid command", { focus = false, scroll = false }) + return + end + local cmds = Utils.get_commands() + ---@type AvanteSlashCommand + local cmd = vim.iter(cmds):filter(function(cmd) return cmd.name == command end):totable()[1] + if cmd then + if command == "lines" then + cmd.callback(self, args, function(args_) + local _, _, question = args_:match("(%d+)-(%d+)%s+(.*)") + request = question + end) + elseif command == "commit" then + cmd.callback(self, args, function(question) request = question end) + else + cmd.callback(self, args) + return + end + else + self:update_content("Unknown command: " .. command, { focus = false, scroll = false }) + return + end + end + + -- Process shortcut replacements + local new_content, has_shortcuts = Utils.extract_shortcuts(request) + if has_shortcuts then request = new_content end + + local selected_filepaths = self.file_selector:get_selected_filepaths() + + ---@type AvanteSelectedCode | nil + local selected_code = self.code.selection + and { + path = self.code.selection.filepath, + file_type = self.code.selection.filetype, + content = self.code.selection.content, + } + + --- HACK: we need to set focus to true and scroll to false to + --- prevent the cursor from jumping to the bottom of the + --- buffer at the beginning + self:update_content("", { focus = true, scroll = false }) + + ---stop scroll when user presses j/k keys + local function on_j() + self.scroll = false + ---perform scroll + vim.cmd("normal! j") + end + + local function on_k() + self.scroll = false + ---perform scroll + vim.cmd("normal! k") + end + + local function on_G() + self.scroll = true + ---perform scroll + vim.cmd("normal! G") + end + + vim.keymap.set("n", "j", on_j, { buffer = self.containers.result.bufnr }) + vim.keymap.set("n", "k", on_k, { buffer = self.containers.result.bufnr }) + vim.keymap.set("n", "G", on_G, { buffer = self.containers.result.bufnr }) + + ---@type AvanteLLMStartCallback + local function on_start(_) end + + ---@param messages avante.HistoryMessage[] + local function on_messages_add(messages) self:add_history_messages(messages) end + + ---@param state avante.GenerateState + local function on_state_change(state) + self:clear_state() + self.current_state = state + self:render_state() + end + + ---@param tool_id string + ---@param tool_name string + ---@param log string + ---@param state AvanteLLMToolUseState + local function on_tool_log(tool_id, tool_name, log, state) + if state == "generating" then on_state_change("tool calling") end + local tool_use_message = History.Helpers.get_tool_use_message(tool_id, self.chat_history.messages) + if not tool_use_message then + -- Utils.debug("tool_use message not found", tool_id, tool_name) + return + end + + local tool_use_logs = tool_use_message.tool_use_logs or {} + local content = string.format("[%s]: %s", tool_name, log) + table.insert(tool_use_logs, content) + tool_use_message.tool_use_logs = tool_use_logs + + local orig_is_calling = tool_use_message.is_calling + tool_use_message.is_calling = true + self:update_content("") + tool_use_message.is_calling = orig_is_calling + self:save_history() + end + + local function set_tool_use_store(tool_id, key, value) + local tool_use_message = History.Helpers.get_tool_use_message(tool_id, self.chat_history.messages) + if tool_use_message then + local tool_use_store = tool_use_message.tool_use_store or {} + tool_use_store[key] = value + tool_use_message.tool_use_store = tool_use_store + self:save_history() + end + end + + ---@type AvanteLLMStopCallback + local function on_stop(stop_opts) + self.is_generating = false + + pcall(function() + ---remove keymaps + vim.keymap.del("n", "j", { buffer = self.containers.result.bufnr }) + vim.keymap.del("n", "k", { buffer = self.containers.result.bufnr }) + vim.keymap.del("n", "G", { buffer = self.containers.result.bufnr }) + end) + + if stop_opts.error ~= nil and stop_opts.error ~= vim.NIL then + local msg_content = stop_opts.error + if type(msg_content) ~= "string" then msg_content = vim.inspect(msg_content) end + self:add_history_messages({ + History.Message:new("assistant", "\n\nError: " .. msg_content, { + just_for_display = true, + }), + }) + on_state_change("failed") + return + end + + on_state_change("succeeded") + + self:update_content("", { + callback = function() api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN }) end, + }) + + vim.defer_fn(function() + if Utils.is_valid_container(self.containers.result, true) and Config.behaviour.jump_result_buffer_on_finish then + api.nvim_set_current_win(self.containers.result.winid) + end + if Config.behaviour.auto_apply_diff_after_generation then self:apply(false) end + end, 0) + + Path.history.save(self.code.bufnr, self.chat_history) + end + + if request and request ~= "" then + self:add_history_messages({ + History.Message:new("user", request, { + is_user_submission = true, + selected_filepaths = selected_filepaths, + selected_code = selected_code, + }), + }) + end + + self:get_generate_prompts_options(request, function(generate_prompts_options) + ---@type AvanteLLMStreamOptions + ---@diagnostic disable-next-line: assign-type-mismatch + local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, { + on_start = on_start, + on_stop = on_stop, + on_tool_log = on_tool_log, + on_messages_add = on_messages_add, + on_state_change = on_state_change, + acp_client = self.acp_client, + on_save_acp_client = function(client) self.acp_client = client end, + acp_session_id = self.chat_history.acp_session_id, + on_save_acp_session_id = function(session_id) + self.chat_history.acp_session_id = session_id + Path.history.save(self.code.bufnr, self.chat_history) + end, + set_tool_use_store = set_tool_use_store, + get_history_messages = function(opts) return self:get_history_messages_for_api(opts) end, + get_todos = function() + local history = Path.history.load(self.code.bufnr) + return history.todos + end, + update_todos = function(todos) self:update_todos(todos) end, + session_ctx = {}, + ---@param usage avante.LLMTokenUsage + update_tokens_usage = function(usage) + if not usage then return end + if usage.completion_tokens == nil then return end + if usage.prompt_tokens == nil then return end + self.chat_history.tokens_usage = usage + self:save_history() + end, + get_tokens_usage = function() return self.chat_history.tokens_usage end, + }) + + ---@param pending_compaction_history_messages avante.HistoryMessage[] + local function on_memory_summarize(pending_compaction_history_messages) + local history_memory = self.chat_history.memory + Llm.summarize_memory( + history_memory and history_memory.content, + pending_compaction_history_messages, + function(memory) + if memory then + self.chat_history.memory = memory + Path.history.save(self.code.bufnr, self.chat_history) + stream_options.memory = memory.content + end + stream_options.history_messages = self:get_history_messages_for_api() + Llm.stream(stream_options) + end + ) + end + + stream_options.on_memory_summarize = on_memory_summarize + + on_state_change("generating") + Llm.stream(stream_options) + end) +end + function Sidebar:initialize_token_count() if Config.behaviour.enable_token_counting then self:get_generate_prompts_options("") end end @@ -2634,250 +2873,6 @@ function Sidebar:create_input_container() if self.chat_history == nil then self:reload_chat_history() end - ---@param request string - local function handle_submit(request) - if Config.prompt_logger.enabled then PromptLogger.log_prompt(request) end - - if self.is_generating then - self:add_history_messages({ - History.Message:new("user", request), - }) - return - end - if request:match("@codebase") and not vim.fn.expand("%:e") then - self:update_content("Please open a file first before using @codebase", { focus = false, scroll = false }) - return - end - - if request:sub(1, 1) == "/" then - local command, args = request:match("^/(%S+)%s*(.*)") - if command == nil then - self:update_content("Invalid command", { focus = false, scroll = false }) - return - end - local cmds = Utils.get_commands() - ---@type AvanteSlashCommand - local cmd = vim.iter(cmds):filter(function(cmd) return cmd.name == command end):totable()[1] - if cmd then - if command == "lines" then - cmd.callback(self, args, function(args_) - local _, _, question = args_:match("(%d+)-(%d+)%s+(.*)") - request = question - end) - elseif command == "commit" then - cmd.callback(self, args, function(question) request = question end) - else - cmd.callback(self, args) - return - end - else - self:update_content("Unknown command: " .. command, { focus = false, scroll = false }) - return - end - end - - -- Process shortcut replacements - local new_content, has_shortcuts = Utils.extract_shortcuts(request) - if has_shortcuts then request = new_content end - - local selected_filepaths = self.file_selector:get_selected_filepaths() - - ---@type AvanteSelectedCode | nil - local selected_code = nil - if self.code.selection ~= nil then - selected_code = { - path = self.code.selection.filepath, - file_type = self.code.selection.filetype, - content = self.code.selection.content, - } - end - - --- HACK: we need to set focus to true and scroll to false to - --- prevent the cursor from jumping to the bottom of the - --- buffer at the beginning - self:update_content("", { focus = true, scroll = false }) - - ---stop scroll when user presses j/k keys - local function on_j() - self.scroll = false - ---perform scroll - vim.cmd("normal! j") - end - - local function on_k() - self.scroll = false - ---perform scroll - vim.cmd("normal! k") - end - - local function on_G() - self.scroll = true - ---perform scroll - vim.cmd("normal! G") - end - - vim.keymap.set("n", "j", on_j, { buffer = self.containers.result.bufnr }) - vim.keymap.set("n", "k", on_k, { buffer = self.containers.result.bufnr }) - vim.keymap.set("n", "G", on_G, { buffer = self.containers.result.bufnr }) - - ---@type AvanteLLMStartCallback - local function on_start(_) end - - ---@param messages avante.HistoryMessage[] - local function on_messages_add(messages) self:add_history_messages(messages) end - - ---@param state avante.GenerateState - local function on_state_change(state) - self:clear_state() - self.current_state = state - self:render_state() - end - - ---@param tool_id string - ---@param tool_name string - ---@param log string - ---@param state AvanteLLMToolUseState - local function on_tool_log(tool_id, tool_name, log, state) - if state == "generating" then on_state_change("tool calling") end - local tool_use_message = History.Helpers.get_tool_use_message(tool_id, self.chat_history.messages) - if not tool_use_message then - -- Utils.debug("tool_use message not found", tool_id, tool_name) - return - end - - local tool_use_logs = tool_use_message.tool_use_logs or {} - local content = string.format("[%s]: %s", tool_name, log) - table.insert(tool_use_logs, content) - tool_use_message.tool_use_logs = tool_use_logs - - local orig_is_calling = tool_use_message.is_calling - tool_use_message.is_calling = true - self:update_content("") - tool_use_message.is_calling = orig_is_calling - self:save_history() - end - - local function set_tool_use_store(tool_id, key, value) - local tool_use_message = History.Helpers.get_tool_use_message(tool_id, self.chat_history.messages) - if tool_use_message then - local tool_use_store = tool_use_message.tool_use_store or {} - tool_use_store[key] = value - tool_use_message.tool_use_store = tool_use_store - self:save_history() - end - end - - ---@type AvanteLLMStopCallback - local function on_stop(stop_opts) - self.is_generating = false - - pcall(function() - ---remove keymaps - vim.keymap.del("n", "j", { buffer = self.containers.result.bufnr }) - vim.keymap.del("n", "k", { buffer = self.containers.result.bufnr }) - vim.keymap.del("n", "G", { buffer = self.containers.result.bufnr }) - end) - - if stop_opts.error ~= nil and stop_opts.error ~= vim.NIL then - local msg_content = stop_opts.error - if type(msg_content) ~= "string" then msg_content = vim.inspect(msg_content) end - self:add_history_messages({ - History.Message:new("assistant", "\n\nError: " .. msg_content, { - just_for_display = true, - }), - }) - on_state_change("failed") - return - end - - on_state_change("succeeded") - - self:update_content("", { - callback = function() api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN }) end, - }) - - vim.defer_fn(function() - if - Utils.is_valid_container(self.containers.result, true) and Config.behaviour.jump_result_buffer_on_finish - then - api.nvim_set_current_win(self.containers.result.winid) - end - if Config.behaviour.auto_apply_diff_after_generation then self:apply(false) end - end, 0) - - Path.history.save(self.code.bufnr, self.chat_history) - end - - if request and request ~= "" then - self:add_history_messages({ - History.Message:new("user", request, { - is_user_submission = true, - selected_filepaths = selected_filepaths, - selected_code = selected_code, - }), - }) - end - - self:get_generate_prompts_options(request, function(generate_prompts_options) - ---@type AvanteLLMStreamOptions - ---@diagnostic disable-next-line: assign-type-mismatch - local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, { - on_start = on_start, - on_stop = on_stop, - on_tool_log = on_tool_log, - on_messages_add = on_messages_add, - on_state_change = on_state_change, - acp_client = self.acp_client, - on_save_acp_client = function(client) self.acp_client = client end, - acp_session_id = self.chat_history.acp_session_id, - on_save_acp_session_id = function(session_id) - self.chat_history.acp_session_id = session_id - Path.history.save(self.code.bufnr, self.chat_history) - end, - set_tool_use_store = set_tool_use_store, - get_history_messages = function(opts) return self:get_history_messages_for_api(opts) end, - get_todos = function() - local history = Path.history.load(self.code.bufnr) - return history.todos - end, - update_todos = function(todos) self:update_todos(todos) end, - session_ctx = {}, - ---@param usage avante.LLMTokenUsage - update_tokens_usage = function(usage) - if not usage then return end - if usage.completion_tokens == nil then return end - if usage.prompt_tokens == nil then return end - self.chat_history.tokens_usage = usage - self:save_history() - end, - get_tokens_usage = function() return self.chat_history.tokens_usage end, - }) - - ---@param pending_compaction_history_messages avante.HistoryMessage[] - local function on_memory_summarize(pending_compaction_history_messages) - local history_memory = self.chat_history.memory - Llm.summarize_memory( - history_memory and history_memory.content, - pending_compaction_history_messages, - function(memory) - if memory then - self.chat_history.memory = memory - Path.history.save(self.code.bufnr, self.chat_history) - stream_options.memory = memory.content - end - stream_options.history_messages = self:get_history_messages_for_api() - Llm.stream(stream_options) - end - ) - end - - stream_options.on_memory_summarize = on_memory_summarize - - on_state_change("generating") - Llm.stream(stream_options) - end) - end - local function get_position() if self:get_layout() == "vertical" then return "bottom" end return "right" @@ -2922,11 +2917,9 @@ function Sidebar:create_input_container() if request == "" then return end api.nvim_buf_set_lines(self.containers.input.bufnr, 0, -1, false, {}) api.nvim_win_set_cursor(self.containers.input.winid, { 1, 0 }) - handle_submit(request) + self:handle_submit(request) end - self.handle_submit = handle_submit - self.containers.input:mount() PromptLogger.init() @@ -3058,14 +3051,6 @@ function Sidebar:create_input_container() end end, }) - - api.nvim_create_autocmd("User", { - group = self.augroup, - pattern = "AvanteInputSubmitted", - callback = function(ev) - if ev.data and ev.data.request then handle_submit(ev.data.request) end - end, - }) end -- FIXME: this is used by external plugin users @@ -3227,6 +3212,14 @@ function Sidebar:render(opts) self:update_content_with_history() end + api.nvim_create_autocmd("User", { + group = self.augroup, + pattern = "AvanteInputSubmitted", + callback = function(ev) + if ev.data and ev.data.request then self:handle_submit(ev.data.request) end + end, + }) + return self end