From f8cbc88424ece7cfa6d8be251c36a55d0346ca97 Mon Sep 17 00:00:00 2001 From: yetone Date: Sat, 17 Aug 2024 15:14:30 +0800 Subject: [PATCH] refactor(ui): bounding popover (#13) (#29) * refactor(ui): bounding popover (#13) * refactor(ui): bounding popover Signed-off-by: Aaron Pham * chore: update readme instructions on setting up render-markdown.nvim Signed-off-by: Aaron Pham * chore: align code style * fix: incorrect type annotation * fix: make it work with mouse movement Signed-off-by: Aaron Pham * fix: focus correct on render Signed-off-by: Aaron Pham * fix: make sure to close the view Signed-off-by: Aaron Pham * chore: cleanup cursor position Signed-off-by: Aaron Pham * docs: add notes on rc Signed-off-by: Aaron Pham * fix: make sure to apply if has diff Signed-off-by: Aaron Pham * fix: do not simulate user input --------- Signed-off-by: Aaron Pham Co-authored-by: yetone * fix(autocmd): make sure to load tiktoken on correct events (closes #16) (#24) Signed-off-by: Aaron Pham * feat(type): better hinting on nui components (#27) Signed-off-by: Aaron Pham * feat: scrollview and tracking config and lazy load and perf (#33) * feat: scrollview and tracking config and lazy load and perf Signed-off-by: Aaron Pham * fix: add back options Signed-off-by: Aaron Pham * revert: remove unused autocmd Signed-off-by: Aaron Pham * fix: get code content * fix: keybinding hint virtual text position --------- Signed-off-by: Aaron Pham Co-authored-by: yetone --------- Signed-off-by: Aaron Pham Co-authored-by: Aaron Pham --- .luarc.json | 9 - README.md | 156 +++++++----- lua/avante/ai_bot.lua | 59 +++-- lua/avante/config.lua | 37 ++- lua/avante/diff.lua | 2 +- lua/avante/init.lua | 140 ++++++++++- lua/avante/sidebar.lua | 547 ++++++++++++++++++++--------------------- lua/avante/types.lua | 36 +++ lua/avante/utils.lua | 4 + lua/avante/view.lua | 66 +++++ 10 files changed, 661 insertions(+), 395 deletions(-) delete mode 100644 .luarc.json create mode 100644 lua/avante/types.lua create mode 100644 lua/avante/view.lua diff --git a/.luarc.json b/.luarc.json deleted file mode 100644 index 8b2003a..0000000 --- a/.luarc.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "workspace.library": [ - "/Users/yetone/.local/share/nvim/lazy/neodev.nvim/types/stable", - "/opt/homebrew/Cellar/neovim/0.9.0/share/nvim/runtime/lua", - "/Users/yetone/.local/share/nvim/lazy/nvim-dap-ui/lua", - "/Users/yetone/.config/nvim/lua", - "${3rd}/luv/library" - ] -} \ No newline at end of file diff --git a/README.md b/README.md index cef70b4..2e30b48 100644 --- a/README.md +++ b/README.md @@ -7,12 +7,8 @@ https://github.com/user-attachments/assets/510e6270-b6cf-459d-9a2f-15b397d1fe53 - - https://github.com/user-attachments/assets/86140bfd-08b4-483d-a887-1b701d9e37dd - - ## Features - **AI-Powered Code Assistance**: Interact with AI to ask questions about your current code file and receive intelligent suggestions for improvement or modification. @@ -22,74 +18,89 @@ https://github.com/user-attachments/assets/86140bfd-08b4-483d-a887-1b701d9e37dd Install `avante.nvim` using [lazy.nvim](https://github.com/folke/lazy.nvim): - ```lua - { - "yetone/avante.nvim", - event = "VeryLazy", - opts = {}, - build = "make", - dependencies = { - "nvim-tree/nvim-web-devicons", - { - "grapp-dev/nui-components.nvim", - dependencies = { - "MunifTanjim/nui.nvim" - } - }, - "nvim-lua/plenary.nvim", - "MeanderingProgrammer/render-markdown.nvim", +```lua +{ + "yetone/avante.nvim", + event = "VeryLazy", + opts = {}, + build = "make", + dependencies = { + "nvim-tree/nvim-web-devicons", + { + "grapp-dev/nui-components.nvim", + dependencies = { + "MunifTanjim/nui.nvim" + } }, - } - ``` + "nvim-lua/plenary.nvim", + "MeanderingProgrammer/render-markdown.nvim", -- optional + }, +} +``` > [!IMPORTANT] > > If your neovim doesn't use LuaJIT, then change `build` to `make lua51`. By default running make will install luajit. > For ARM-based setup, make sure to also install cargo as we will have to build the tiktoken_core from source. +> [!note] `render-markdown.nvim` +> +> `render-markdown.nvim` is an optional dependency that is used to render the markdown content of the chat history. Make sure to also include `Avante` as a filetype +> to its setup: +> +> ```lua +> { +> "MeanderingProgrammer/markdown.nvim", +> opts = { +> file_types = { "markdown", "Avante" }, +> }, +> ft = { "markdown", "Avante" }, +> } +> ``` + Default setup configuration: - ```lua - { - provider = "claude", -- "claude" or "openai" or "azure" - openai = { - endpoint = "https://api.openai.com", - model = "gpt-4o", - temperature = 0, - max_tokens = 4096, +```lua +{ + provider = "claude", -- "claude" or "openai" or "azure" + openai = { + endpoint = "https://api.openai.com", + model = "gpt-4o", + temperature = 0, + max_tokens = 4096, + }, + azure = { + endpoint = "", -- Example: "https://.openai.azure.com" + deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment") + api_version = "2024-06-01", + temperature = 0, + max_tokens = 4096, + }, + claude = { + endpoint = "https://api.anthropic.com", + model = "claude-3-5-sonnet-20240620", + temperature = 0, + max_tokens = 4096, + }, + highlights = { + diff = { + current = "DiffText", -- need have background color + incoming = "DiffAdd", -- need have background color }, - azure = { - endpoint = "", -- Example: "https://.openai.azure.com" - deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment") - api_version = "2024-06-01", - temperature = 0, - max_tokens = 4096, + }, + mappings = { + show_sidebar = "aa", + diff = { + ours = "co", + theirs = "ct", + none = "c0", + both = "cb", + next = "]x", + prev = "[x", }, - claude = { - endpoint = "https://api.anthropic.com", - model = "claude-3-5-sonnet-20240620", - temperature = 0, - max_tokens = 4096, - }, - highlights = { - diff = { - current = "DiffText", -- need have background color - incoming = "DiffAdd", -- need have background color - }, - }, - mappings = { - show_sidebar = "aa", - diff = { - ours = "co", - theirs = "ct", - none = "c0", - both = "cb", - next = "]x", - prev = "[x", - }, - }, - } - ``` + }, +} +``` ## Usage @@ -165,6 +176,29 @@ To set up the development environment: pre-commit install --install-hooks ``` +For setting up lua_ls you can use the following for `nvim-lspconfig`: + +```lua +lua_ls = { + settings = { + Lua = { + runtime = { + version = "LuaJIT", + special = { reload = "require" }, + }, + workspace = { + library = { + vim.fn.expand "$VIMRUNTIME/lua", + vim.fn.expand "$VIMRUNTIME/lua/vim/lsp", + vim.fn.stdpath "data" .. "/lazy/lazy.nvim/lua/lazy", + "${3rd}/luv/library", + }, + }, + }, + }, +}, +``` + ## License avante.nvim is licensed under the Apache License. For more details, please refer to the [LICENSE](./LICENSE) file. diff --git a/lua/avante/ai_bot.lua b/lua/avante/ai_bot.lua index 7d4f473..fd3241a 100644 --- a/lua/avante/ai_bot.lua +++ b/lua/avante/ai_bot.lua @@ -1,11 +1,13 @@ -local M = {} +local fn = vim.fn local curl = require("plenary.curl") -local utils = require("avante.utils") -local config = require("avante.config") -local tiktoken = require("avante.tiktoken") -local fn = vim.fn +local Utils = require("avante.utils") +local Config = require("avante.config") +local Tiktoken = require("avante.tiktoken") + +---@class avante.AiBot +local M = {} local system_prompt = [[ You are an excellent programming expert. @@ -64,7 +66,7 @@ local function call_claude_api_stream(question, code_lang, code_content, on_chun local user_prompt = base_user_prompt - local tokens = config.get().claude.max_tokens + local tokens = Config.claude.max_tokens local headers = { ["Content-Type"] = "application/json", ["x-api-key"] = api_key, @@ -82,16 +84,16 @@ local function call_claude_api_stream(question, code_lang, code_content, on_chun text = user_prompt, } - if tiktoken.count(code_prompt_obj.text) > 1024 then + if Tiktoken.count(code_prompt_obj.text) > 1024 then code_prompt_obj.cache_control = { type = "ephemeral" } end - if tiktoken.count(user_prompt_obj.text) > 1024 then + if Tiktoken.count(user_prompt_obj.text) > 1024 then user_prompt_obj.cache_control = { type = "ephemeral" } end local body = { - model = config.get().claude.model, + model = Config.claude.model, system = system_prompt, messages = { { @@ -107,13 +109,13 @@ local function call_claude_api_stream(question, code_lang, code_content, on_chun }, }, stream = true, - temperature = config.get().claude.temperature, + temperature = Config.claude.temperature, max_tokens = tokens, } - local url = utils.trim_suffix(config.get().claude.endpoint, "/") .. "/v1/messages" + local url = Utils.trim_suffix(Config.claude.endpoint, "/") .. "/v1/messages" - print("Sending request to Claude API...") + -- print("Sending request to Claude API...") curl.post(url, { ---@diagnostic disable-next-line: unused-local @@ -154,7 +156,7 @@ end local function call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete) local api_key = os.getenv("OPENAI_API_KEY") - if not api_key and config.get().provider == "openai" then + if not api_key and Config.provider == "openai" then error("OPENAI_API_KEY environment variable is not set") end @@ -169,16 +171,16 @@ local function call_openai_api_stream(question, code_lang, code_content, on_chun .. "\n```" local url, headers, body - if config.get().provider == "azure" then + if Config.provider == "azure" then api_key = os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY") if not api_key then error("Azure OpenAI API key is not set. Please set AZURE_OPENAI_API_KEY or OPENAI_API_KEY environment variable.") end - url = config.get().azure.endpoint + url = Config.azure.endpoint .. "/openai/deployments/" - .. config.get().azure.deployment + .. Config.azure.deployment .. "/chat/completions?api-version=" - .. config.get().azure.api_version + .. Config.azure.api_version headers = { ["Content-Type"] = "application/json", ["api-key"] = api_key, @@ -188,29 +190,29 @@ local function call_openai_api_stream(question, code_lang, code_content, on_chun { role = "system", content = system_prompt }, { role = "user", content = user_prompt }, }, - temperature = config.get().azure.temperature, - max_tokens = config.get().azure.max_tokens, + temperature = Config.azure.temperature, + max_tokens = Config.azure.max_tokens, stream = true, } else - url = utils.trim_suffix(config.get().openai.endpoint, "/") .. "/v1/chat/completions" + url = Utils.trim_suffix(Config.openai.endpoint, "/") .. "/v1/chat/completions" headers = { ["Content-Type"] = "application/json", ["Authorization"] = "Bearer " .. api_key, } body = { - model = config.get().openai.model, + model = Config.openai.model, messages = { { role = "system", content = system_prompt }, { role = "user", content = user_prompt }, }, - temperature = config.get().openai.temperature, - max_tokens = config.get().openai.max_tokens, + temperature = Config.openai.temperature, + max_tokens = Config.openai.max_tokens, stream = true, } end - print("Sending request to " .. (config.get().provider == "azure" and "Azure OpenAI" or "OpenAI") .. " API...") + -- print("Sending request to " .. (config.get().provider == "azure" and "Azure OpenAI" or "OpenAI") .. " API...") curl.post(url, { ---@diagnostic disable-next-line: unused-local @@ -253,10 +255,15 @@ local function call_openai_api_stream(question, code_lang, code_content, on_chun }) end +---@param question string +---@param code_lang string +---@param code_content string +---@param on_chunk fun(chunk: string): any +---@param on_complete fun(err: string|nil): any function M.call_ai_api_stream(question, code_lang, code_content, on_chunk, on_complete) - if config.get().provider == "openai" or config.get().provider == "azure" then + if Config.provider == "openai" or Config.provider == "azure" then call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete) - elseif config.get().provider == "claude" then + elseif Config.provider == "claude" then call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete) end end diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 8d55518..40e44c8 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -1,6 +1,10 @@ +---NOTE: user will be merged with defaults and +---we add a default var_accessor for this table to config values. +---@class avante.CoreConfig: avante.Config local M = {} -local config = { +---@class avante.Config +M.defaults = { provider = "claude", -- "claude" or "openai" or "azure" openai = { endpoint = "https://api.openai.com", @@ -38,14 +42,37 @@ local config = { prev = "[x", }, }, + windows = { + width = 30, -- default % based on available width + }, } -function M.update(opts) - config = vim.tbl_deep_extend("force", config, opts or {}) +---@type avante.Config +M.options = {} + +---@param opts? avante.Config +function M.setup(opts) + M.options = vim.tbl_deep_extend("force", M.defaults, opts or {}) end -function M.get() - return config +M = setmetatable(M, { + __index = function(_, k) + if M.options[k] then + return M.options[k] + end + end, +}) + +function M.get_window_width() + return math.ceil(vim.o.columns * (M.windows.width / 100)) +end + +---@return {width: integer, height: integer, position: integer} +function M.get_renderer_layout_options() + local width = M.get_window_width() + local height = vim.o.lines + local position = vim.o.columns - width + return { width = width, height = height, position = position } end return M diff --git a/lua/avante/diff.lua b/lua/avante/diff.lua index 63cfbc2..f4c713b 100644 --- a/lua/avante/diff.lua +++ b/lua/avante/diff.lua @@ -392,7 +392,7 @@ local function register_cursor_move_events(bufnr) show_keybinding_hint_extmark_id = api.nvim_buf_set_extmark(bufnr, KEYBINDING_NAMESPACE, lnum - 1, -1, { hl_group = "Keyword", virt_text = { { hint, "Keyword" } }, - virt_text_win_col = col, + virt_text_pos = "right_align", priority = PRIORITY, }) end diff --git a/lua/avante/init.lua b/lua/avante/init.lua index 41156a4..fe5cdf3 100644 --- a/lua/avante/init.lua +++ b/lua/avante/init.lua @@ -1,30 +1,148 @@ -local M = {} -local sidebar = require("avante.sidebar") -local config = require("avante.config") +local api = vim.api -function M.setup(opts) +local Tiktoken = require("avante.tiktoken") +local Sidebar = require("avante.sidebar") +local Config = require("avante.config") +local Diff = require("avante.diff") + +---@class Avante +local M = { + ---@type avante.Sidebar[] we use this to track chat command across tabs + sidebars = {}, + ---@type avante.Sidebar + current = nil, + _once = false, +} + +local H = {} + +H.commands = function() + local cmd = function(n, c, o) + o = vim.tbl_extend("force", { nargs = 0 }, o or {}) + api.nvim_create_user_command("Avante" .. n, c, o) + end + + cmd("Ask", function() + M.toggle() + end, { desc = "avante: ask AI for code suggestions" }) + cmd("Close", function() + local sidebar = M._get() + if not sidebar then + return + end + sidebar:close() + end, { desc = "avante: close chat window" }) +end + +H.keymaps = function() + vim.keymap.set({ "n" }, Config.mappings.show_sidebar, M.toggle, { noremap = true }) +end + +H.autocmds = function() local ok, LazyConfig = pcall(require, "lazy.core.config") + if ok then local name = "avante.nvim" + local load_path = function() + require("tiktoken_lib").load() + Tiktoken.setup("gpt-4o") + end + if LazyConfig.plugins[name] and LazyConfig.plugins[name]._.loaded then - vim.schedule(function() - require("tiktoken_lib").load() - end) + vim.schedule(load_path) else - vim.api.nvim_create_autocmd("User", { + api.nvim_create_autocmd("User", { pattern = "LazyLoad", callback = function(event) if event.data == name then - require("tiktoken_lib").load() + load_path() return true end end, }) end + + api.nvim_create_autocmd("User", { + pattern = "VeryLazy", + callback = load_path, + }) end - config.update(opts) - sidebar.setup() + api.nvim_create_autocmd("TabClosed", { + pattern = "*", + callback = function(ev) + local tab = tonumber(ev.file) + local s = M.sidebars[tab] + if s then + s:destroy() + end + M.sidebars[tab] = nil + end, + }) + + -- automatically setup Avante filetype to markdown + vim.treesitter.language.register("markdown", "Avante") +end + +---@param current boolean? false to disable setting current, otherwise use this to track across tabs. +---@return avante.Sidebar +function M._get(current) + local tab = api.nvim_get_current_tabpage() + local sidebar = M.sidebars[tab] + if current ~= false then + M.current = sidebar + end + return sidebar +end + +M.open = function() + local tab = api.nvim_get_current_tabpage() + local sidebar = M.sidebars[tab] + + if not sidebar then + sidebar = Sidebar:new(tab) + M.sidebars[tab] = sidebar + end + + M.current = sidebar + + return sidebar:open() +end + +M.toggle = function() + local sidebar = M._get() + if not sidebar then + M.open() + return true + end + + return sidebar:toggle() +end + +---@param opts? avante.Config +function M.setup(opts) + ---PERF: we can still allow running require("avante").setup() multiple times to override config if users wish to + ---but most of the other functionality will only be called once from lazy.nvim + Config.setup(opts) + + if M._once then + return + end + Diff.setup({ + debug = false, -- log output to console + default_mappings = Config.mappings.diff, -- disable buffer local mapping created by this plugin + default_commands = true, -- disable commands created by this plugin + disable_diagnostics = true, -- This will disable the diagnostics in a buffer whilst it is conflicted + list_opener = "copen", + highlights = Config.highlights.diff, + }) + + -- setup helpers + H.autocmds() + H.commands() + H.keymaps() + + M._once = true end return M diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 9188a34..1c5c084 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -1,26 +1,199 @@ -local M = {} -local Path = require("plenary.path") -local n = require("nui-components") -local diff = require("avante.diff") -local tiktoken = require("avante.tiktoken") -local config = require("avante.config") -local ai_bot = require("avante.ai_bot") local api = vim.api local fn = vim.fn -local RESULT_BUF_NAME = "AVANTE_RESULT" -local CONFLICT_BUF_NAME = "AVANTE_CONFLICT" +local Path = require("plenary.path") +local N = require("nui-components") -local CODEBLOCK_KEYBINDING_NAMESPACE = vim.api.nvim_create_namespace("AVANTE_CODEBLOCK_KEYBINDING") +local Config = require("avante.config") +local View = require("avante.view") +local Diff = require("avante.diff") +local AiBot = require("avante.ai_bot") +local Utils = require("avante.utils") + +local CODEBLOCK_KEYBINDING_NAMESPACE = api.nvim_create_namespace("AVANTE_CODEBLOCK_KEYBINDING") local PRIORITY = vim.highlight.priorities.user +---@class avante.Sidebar +local Sidebar = {} + +---@class avante.SidebarState +---@field win integer +---@field buf integer + +---@class avante.Sidebar +---@field id integer +---@field view avante.View +---@field augroup integer +---@field code avante.SidebarState +---@field renderer NuiRenderer +---@field winid {result: integer, input: integer} + +---@param id integer the tabpage id retrieved from vim.api.nvim_get_current_tabpage() +function Sidebar:new(id) + return setmetatable({ + id = id, + code = { buf = 0, win = 0 }, + winid = { result = 0, input = 0 }, + view = View:new(), + renderer = nil, + }, { __index = Sidebar }) +end + +--- This function should only be used on TabClosed, nothing else. +function Sidebar:destroy() + self:delete_autocmds() + self.view = nil + self.code = nil + self.winid = nil + self.renderer = nil +end + +function Sidebar:delete_autocmds() + if self.augroup then + vim.api.nvim_del_augroup_by_id(self.augroup) + end + self.augroup = nil +end + +function Sidebar:reset() + self.code = { buf = 0, win = 0 } + self.winid = { result = 0, input = 0 } + self:delete_autocmds() +end + +function Sidebar:open() + if not self.view:is_open() then + self:intialize() + self:render() + else + self:focus() + end + return self +end + +function Sidebar:close() + self.renderer:close() + fn.win_gotoid(self.code.win) +end + +---@return boolean +function Sidebar:focus() + if self.view:is_open() then + fn.win_gotoid(self.view.win) + return true + end + return false +end + +function Sidebar:toggle() + if self.view:is_open() then + self:close() + return false + else + self:open() + return true + end +end + +function Sidebar:has_code_win() + return self.code.win + and self.code.buf + and self.code.win ~= 0 + and self.code.buf ~= 0 + and api.nvim_win_is_valid(self.code.win) + and api.nvim_buf_is_valid(self.code.buf) +end + +function Sidebar:intialize() + self.code.win = api.nvim_get_current_win() + self.code.buf = api.nvim_get_current_buf() + + local split_command = "botright vs" + local layout = Config.get_renderer_layout_options() + + self.view:setup(split_command, layout.width) + + --- setup coord + self.renderer = N.create_renderer({ + width = layout.width, + height = layout.height, + position = layout.position, + relative = { type = "win", winid = fn.bufwinid(self.view.buf) }, + }) + + self.renderer:on_mount(function() + local components = self.renderer:get_focusable_components() + -- current layout is a + -- [ chat ] + -- + -- [ input ] + self.winid.result = components[1].winid + self.winid.input = components[2].winid + self.augroup = api.nvim_create_augroup("avante_" .. self.id .. self.view.win, { clear = true }) + + api.nvim_create_autocmd("BufEnter", { + group = self.augroup, + buffer = self.view.buf, + callback = function() + self:focus() + vim.api.nvim_set_current_win(self.winid.input) + return true + end, + }) + + api.nvim_create_autocmd("VimResized", { + group = self.augroup, + callback = function() + local new_layout = Config.get_renderer_layout_options() + vim.api.nvim_win_set_width(self.view.win, new_layout.width) + vim.api.nvim_win_set_height(self.view.win, new_layout.height) + self.renderer:set_size({ width = new_layout.width, height = new_layout.height }) + end, + }) + end) + + self.renderer:on_unmount(function() + self.view:close() + end) + + -- reset states when buffer is closed + api.nvim_buf_attach(self.code.buf, false, { + on_detach = function(_, _) + self:reset() + end, + }) +end + +---@param content string concatenated content of the buffer +---@param focus? boolean whether to focus the result view +function Sidebar:update_content(content, focus) + focus = focus or false + vim.defer_fn(function() + api.nvim_set_option_value("modifiable", true, { buf = self.view.buf }) + api.nvim_buf_set_lines(self.view.buf, 0, -1, false, vim.split(content, "\n")) + api.nvim_set_option_value("modifiable", false, { buf = self.view.buf }) + api.nvim_set_option_value("filetype", "Avante", { buf = self.view.buf }) + if focus then + xpcall(function() + --- set cursor to bottom of result view + api.nvim_set_current_win(self.winid.result) + end, function(err) + -- XXX: omit error for now, but should fix me why it can't jump here. + return err + end) + api.nvim_win_set_cursor(self.winid.result, { api.nvim_buf_line_count(self.view.buf), 0 }) + end + end, 0) + return self +end + local function parse_codeblocks(buf) local codeblocks = {} local in_codeblock = false local start_line = nil local lang = nil - local lines = vim.api.nvim_buf_get_lines(buf, 0, -1, false) + local lines = api.nvim_buf_get_lines(buf, 0, -1, false) for i, line in ipairs(lines) do if line:match("^```") then -- parse language @@ -39,8 +212,9 @@ local function parse_codeblocks(buf) return codeblocks end +---@param codeblocks table local function is_cursor_in_codeblock(codeblocks) - local cursor_pos = vim.api.nvim_win_get_cursor(0) + local cursor_pos = api.nvim_win_get_cursor(0) local cursor_line = cursor_pos[1] - 1 -- 转换为 0-indexed 行号 for _, block in ipairs(codeblocks) do @@ -52,80 +226,6 @@ local function is_cursor_in_codeblock(codeblocks) return nil end -local function create_result_buf() - local buf = api.nvim_create_buf(false, true) - api.nvim_set_option_value("filetype", "markdown", { buf = buf }) - api.nvim_set_option_value("buftype", "nofile", { buf = buf }) - api.nvim_set_option_value("swapfile", false, { buf = buf }) - api.nvim_set_option_value("modifiable", false, { buf = buf }) - api.nvim_set_option_value("bufhidden", "wipe", { buf = buf }) - api.nvim_buf_set_name(buf, RESULT_BUF_NAME) - return buf -end - -local result_buf = create_result_buf() - -local function is_code_buf(buf) - local ignored_filetypes = { - "dashboard", - "alpha", - "neo-tree", - "NvimTree", - "TelescopePrompt", - "Prompt", - "qf", - "help", - } - - if api.nvim_buf_is_valid(buf) and api.nvim_get_option_value("buflisted", { buf = buf }) then - local buftype = api.nvim_get_option_value("buftype", { buf = buf }) - local filetype = api.nvim_get_option_value("filetype", { buf = buf }) - - if buftype == "" and filetype ~= "" and not vim.tbl_contains(ignored_filetypes, filetype) then - local bufname = api.nvim_buf_get_name(buf) - if bufname ~= "" and bufname ~= RESULT_BUF_NAME and bufname ~= CONFLICT_BUF_NAME then - return true - end - end - end - - return false -end - -local _cur_code_buf = nil - -local function get_cur_code_buf() - return _cur_code_buf -end - -local function get_cur_code_buf_name() - local code_buf = get_cur_code_buf() - if code_buf == nil then - print("Error: cannot get code buffer") - return - end - return api.nvim_buf_get_name(code_buf) -end - -local function get_cur_code_win() - local code_buf = get_cur_code_buf() - if code_buf == nil then - print("Error: cannot get code buffer") - return - end - return fn.bufwinid(code_buf) -end - -local function get_cur_code_buf_content() - local code_buf = get_cur_code_buf() - if code_buf == nil then - print("Error: cannot get code buffer") - return {} - end - local lines = api.nvim_buf_get_lines(code_buf, 0, -1, false) - return table.concat(lines, "\n") -end - local function prepend_line_number(content) local lines = vim.split(content, "\n") local result = {} @@ -179,31 +279,6 @@ local function extract_code_snippets(content) return snippets end -local function update_result_buf_content(content) - local current_win = api.nvim_get_current_win() - local result_win = fn.bufwinid(result_buf) - - vim.defer_fn(function() - api.nvim_set_option_value("modifiable", true, { buf = result_buf }) - api.nvim_buf_set_lines(result_buf, 0, -1, false, vim.split(content, "\n")) - api.nvim_set_option_value("modifiable", false, { buf = result_buf }) - api.nvim_set_option_value("filetype", "markdown", { buf = result_buf }) - if result_win ~= -1 then - -- Move to the bottom - api.nvim_win_set_cursor(result_win, { api.nvim_buf_line_count(result_buf), 0 }) - api.nvim_set_current_win(current_win) - end - end, 0) -end - --- Add a new function to display notifications -local function show_notification(message) - vim.notify(message, vim.log.levels.INFO, { - title = "AI Assistant", - timeout = 3000, - }) -end - -- Function to get the current project root directory local function get_project_root() local current_file = fn.expand("%:p") @@ -212,12 +287,9 @@ local function get_project_root() return git_root or current_dir end -local function get_chat_history_filename() - local code_buf_name = get_cur_code_buf_name() - if code_buf_name == nil then - print("Error: cannot get code buffer name") - return - end +---@param sidebar avante.Sidebar +local function get_chat_history_filename(sidebar) + local code_buf_name = api.nvim_buf_get_name(sidebar.code.buf) local relative_path = fn.fnamemodify(code_buf_name, ":~:.") -- Replace path separators with double underscores local path_with_separators = fn.substitute(relative_path, "/", "__", "g") @@ -226,9 +298,9 @@ local function get_chat_history_filename() end -- Function to get the chat history file path -local function get_chat_history_file() +local function get_chat_history_file(sidebar) local project_root = get_project_root() - local filename = get_chat_history_filename() + local filename = get_chat_history_filename(sidebar) local history_dir = Path:new(project_root, ".avante_chat_history") return history_dir:joinpath(filename .. ".json") end @@ -239,8 +311,8 @@ local function get_timestamp() end -- Function to load chat history -local function load_chat_history() - local history_file = get_chat_history_file() +local function load_chat_history(sidebar) + local history_file = get_chat_history_file(sidebar) if history_file:exists() then local content = history_file:read() return fn.json_decode(content) @@ -249,8 +321,8 @@ local function load_chat_history() end -- Function to save chat history -local function save_chat_history(history) - local history_file = get_chat_history_file() +local function save_chat_history(sidebar, history) + local history_file = get_chat_history_file(sidebar) local history_dir = history_file:parent() -- Create the directory if it doesn't exist @@ -261,7 +333,7 @@ local function save_chat_history(history) history_file:write(fn.json_encode(history), "w") end -local function update_result_buf_with_history(history) +function Sidebar:update_content_with_history(history) local content = "" for _, entry in ipairs(history) do content = content .. "## " .. entry.timestamp .. "\n\n" @@ -269,11 +341,7 @@ local function update_result_buf_with_history(history) content = content .. entry.response .. "\n\n" content = content .. "---\n\n" end - update_result_buf_content(content) -end - -local function trim_line_number_prefix(line) - return line:gsub("^L%d+: ", "") + self:update_content(content, true) end local function get_conflict_content(content, snippets) @@ -301,7 +369,7 @@ local function get_conflict_content(content, snippets) table.insert(result, "=======") for _, line in ipairs(vim.split(snippet.content, "\n")) do - line = trim_line_number_prefix(line) + line = Utils.trim_line_number_prefix(line) table.insert(result, line) end @@ -318,11 +386,17 @@ local function get_conflict_content(content, snippets) return result end -local function get_content_between_separators() +---@return string +function Sidebar:get_code_content() + local lines = api.nvim_buf_get_lines(self.code.buf, 0, -1, false) + return table.concat(lines, "\n") +end + +---@return string +function Sidebar:get_content_between_separators() local separator = "---" - local bufnr = vim.api.nvim_get_current_buf() - local cursor_line = vim.api.nvim_win_get_cursor(0)[1] - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local cursor_line = api.nvim_win_get_cursor(0)[1] + local lines = api.nvim_buf_get_lines(self.view.buf, 0, -1, false) local start_line, end_line for i = cursor_line, 1, -1 do @@ -353,29 +427,16 @@ local function get_content_between_separators() return content end -local get_renderer_size_and_position = function() - local renderer_width = math.ceil(vim.o.columns * 0.3) - local renderer_height = vim.o.lines - local renderer_position = vim.o.columns - renderer_width - return renderer_width, renderer_height, renderer_position -end - -function M.render_sidebar() - if result_buf ~= nil and api.nvim_buf_is_valid(result_buf) then - api.nvim_buf_delete(result_buf, { force = true }) - end - - result_buf = create_result_buf() - +function Sidebar:render() local current_apply_extmark_id = nil local function show_apply_button(block) if current_apply_extmark_id then - api.nvim_buf_del_extmark(result_buf, CODEBLOCK_KEYBINDING_NAMESPACE, current_apply_extmark_id) + api.nvim_buf_del_extmark(self.view.buf, CODEBLOCK_KEYBINDING_NAMESPACE, current_apply_extmark_id) end current_apply_extmark_id = - api.nvim_buf_set_extmark(result_buf, CODEBLOCK_KEYBINDING_NAMESPACE, block.start_line, -1, { + api.nvim_buf_set_extmark(self.view.buf, CODEBLOCK_KEYBINDING_NAMESPACE, block.start_line, -1, { virt_text = { { " [Press to Apply these patches] ", "Keyword" } }, virt_text_pos = "right_align", hl_group = "Keyword", @@ -384,132 +445,91 @@ function M.render_sidebar() end local function apply() - local code_buf = get_cur_code_buf() - if code_buf == nil then - error("Error: cannot get code buffer") - return - end - local content = get_cur_code_buf_content() - local response = get_content_between_separators() + local content = self:get_code_content() + local response = self:get_content_between_separators() local snippets = extract_code_snippets(response) local conflict_content = get_conflict_content(content, snippets) vim.defer_fn(function() - api.nvim_buf_set_lines(code_buf, 0, -1, false, conflict_content) - local code_win = get_cur_code_win() - if code_win == nil then - error("Error: cannot get code window") - return - end - api.nvim_set_current_win(code_win) + api.nvim_buf_set_lines(self.code.buf, 0, -1, false, conflict_content) + + api.nvim_set_current_win(self.code.win) api.nvim_feedkeys(api.nvim_replace_termcodes("", true, false, true), "n", true) - diff.add_visited_buffer(code_buf) - diff.process(code_buf) - api.nvim_feedkeys("gg", "n", false) + Diff.add_visited_buffer(self.code.buf) + Diff.process(self.code.buf) + api.nvim_win_set_cursor(self.code.win, { 1, 0 }) vim.defer_fn(function() vim.cmd("AvanteConflictNextConflict") - api.nvim_feedkeys("zz", "n", false) + vim.cmd("normal! zz") end, 1000) end, 10) end local function bind_apply_key() - vim.keymap.set("n", "A", apply, { buffer = result_buf, noremap = true, silent = true }) + vim.keymap.set("n", "A", apply, { buffer = self.view.buf, noremap = true, silent = true }) end local function unbind_apply_key() - pcall(vim.keymap.del, "n", "A", { buffer = result_buf }) + pcall(vim.keymap.del, "n", "A", { buffer = self.view.buf }) end local codeblocks = {} api.nvim_create_autocmd({ "CursorMoved", "CursorMovedI" }, { - buffer = result_buf, - callback = function() + buffer = self.view.buf, + callback = function(ev) local block = is_cursor_in_codeblock(codeblocks) if block then show_apply_button(block) bind_apply_key() else - api.nvim_buf_clear_namespace(result_buf, CODEBLOCK_KEYBINDING_NAMESPACE, 0, -1) + api.nvim_buf_clear_namespace(ev.buf, CODEBLOCK_KEYBINDING_NAMESPACE, 0, -1) unbind_apply_key() end end, }) api.nvim_create_autocmd({ "BufEnter", "BufWritePost" }, { - buffer = result_buf, - callback = function() - codeblocks = parse_codeblocks(result_buf) + buffer = self.view.buf, + callback = function(ev) + codeblocks = parse_codeblocks(ev.buf) end, }) - local renderer_width, renderer_height, renderer_position = get_renderer_size_and_position() + local signal = N.create_signal({ is_loading = false, text = "" }) - local renderer = n.create_renderer({ - width = renderer_width, - height = renderer_height, - position = renderer_position, - relative = "editor", - }) - - local autocmd_id - renderer:on_mount(function() - autocmd_id = api.nvim_create_autocmd("VimResized", { - callback = function() - local width, height, _ = get_renderer_size_and_position() - renderer:set_size({ width = width, height = height }) - end, - }) - end) - - renderer:on_unmount(function() - if autocmd_id ~= nil then - api.nvim_del_autocmd(autocmd_id) - end - end) - - local signal = n.create_signal({ - is_loading = false, - text = "", - }) - - local chat_history = load_chat_history() - update_result_buf_with_history(chat_history) + local chat_history = load_chat_history(self) + self:update_content_with_history(chat_history) local function handle_submit() local state = signal:get_value() local user_input = state.text local timestamp = get_timestamp() - update_result_buf_content( + self:update_content( "## " .. timestamp .. "\n\n> " .. user_input:gsub("\n", "\n> ") .. "\n\nGenerating response from " - .. config.get().provider + .. Config.provider .. " ...\n" ) - local code_buf = get_cur_code_buf() - if code_buf == nil then - error("Error: cannot get code buffer") - return - end - local content = get_cur_code_buf_content() + local content = self:get_code_content() local content_with_line_numbers = prepend_line_number(content) local full_response = "" signal.is_loading = true - local filetype = api.nvim_get_option_value("filetype", { buf = code_buf }) + local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf }) - ai_bot.call_ai_api_stream(user_input, filetype, content_with_line_numbers, function(chunk) + AiBot.call_ai_api_stream(user_input, filetype, content_with_line_numbers, function(chunk) full_response = full_response .. chunk - update_result_buf_content( - "## " .. timestamp .. "\n\n> " .. user_input:gsub("\n", "\n> ") .. "\n\n" .. full_response + self:update_content( + "## " .. timestamp .. "\n\n> " .. user_input:gsub("\n", "\n> ") .. "\n\n" .. full_response, + true ) vim.schedule(function() vim.cmd("redraw") @@ -518,7 +538,7 @@ function M.render_sidebar() signal.is_loading = false if err ~= nil then - update_result_buf_content( + self:update_content( "## " .. timestamp .. "\n\n> " @@ -526,53 +546,52 @@ function M.render_sidebar() .. "\n\n" .. full_response .. "\n\n🚨 Error: " - .. vim.inspect(err) + .. vim.inspect(err), + true ) return end -- Execute when the stream request is actually completed - update_result_buf_content( + self:update_content( "## " .. timestamp .. "\n\n> " .. user_input:gsub("\n", "\n> ") .. "\n\n" .. full_response - .. "\n\n**Generation complete!** Please review the code suggestions above.\n\n\n\n" + .. "\n\n**Generation complete!** Please review the code suggestions above.\n\n\n\n", + true ) + api.nvim_set_current_win(self.winid.result) + -- Display notification - show_notification("Content generation complete!") + -- show_notification("Content generation complete!") -- Save chat history table.insert(chat_history or {}, { timestamp = timestamp, requirement = user_input, response = full_response }) - save_chat_history(chat_history) + save_chat_history(self, chat_history) end) end local body = function() - local code_buf = get_cur_code_buf() - if code_buf == nil then - error("Error: cannot get code buffer") - return - end - local filetype = api.nvim_get_option_value("filetype", { buf = code_buf }) + local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf }) local icon = require("nvim-web-devicons").get_icon_by_filetype(filetype, {}) - local code_file_fullpath = api.nvim_buf_get_name(code_buf) + local code_file_fullpath = api.nvim_buf_get_name(self.code.buf) local code_filename = fn.fnamemodify(code_file_fullpath, ":t") - return n.rows( + return N.rows( { flex = 0 }, - n.box( + N.box( { direction = "column", size = vim.o.lines - 4, }, - n.buffer({ + N.buffer({ id = "response", flex = 1, - buf = result_buf, + buf = self.view.buf, autoscroll = true, border_label = { text = "💬 Avante Chat", @@ -586,14 +605,15 @@ function M.render_sidebar() }, }) ), - n.gap(1), - n.columns( + N.gap(1), + N.columns( { flex = 0 }, - n.text_input({ + N.text_input({ id = "text-input", border_label = { - text = string.format(" 🙋 Your question (with %s %s): ", icon, code_filename), + text = string.format(" 🙋 (with %s %s): ", icon, code_filename), }, + placeholder = "Enter your question", autofocus = true, wrap = true, flex = 1, @@ -610,55 +630,18 @@ function M.render_sidebar() end, padding = { left = 1, right = 1 }, }), - n.gap(1), - n.spinner({ + N.gap(1), + N.spinner({ is_loading = signal.is_loading, padding = { top = 1, right = 1 }, - ---@diagnostic disable-next-line: undefined-field hidden = signal.is_loading:negate(), }) ) ) end - renderer:render(body) + self.renderer:render(body) + return self end -function M.setup() - local bufnr = vim.api.nvim_get_current_buf() - if is_code_buf(bufnr) then - _cur_code_buf = bufnr - end - - tiktoken.setup("gpt-4o") - - diff.setup({ - debug = false, -- log output to console - default_mappings = config.get().mappings.diff, -- disable buffer local mapping created by this plugin - default_commands = true, -- disable commands created by this plugin - disable_diagnostics = true, -- This will disable the diagnostics in a buffer whilst it is conflicted - list_opener = "copen", - highlights = config.get().highlights.diff, - }) - - local function on_buf_enter() - bufnr = vim.api.nvim_get_current_buf() - if is_code_buf(bufnr) then - _cur_code_buf = bufnr - end - end - - api.nvim_create_autocmd("BufEnter", { - callback = on_buf_enter, - }) - - api.nvim_create_user_command("AvanteAsk", function() - M.render_sidebar() - end, { - nargs = 0, - }) - - api.nvim_set_keymap("n", config.get().mappings.show_sidebar, "AvanteAsk", { noremap = true, silent = true }) -end - -return M +return Sidebar diff --git a/lua/avante/types.lua b/lua/avante/types.lua new file mode 100644 index 0000000..070d7d9 --- /dev/null +++ b/lua/avante/types.lua @@ -0,0 +1,36 @@ +---@meta + +---@class NuiRenderer +_G.AvanteRenderer = require("nui-components.renderer") + +---@class NuiComponent +_G.AvanteComponent = require("nui-components.component") + +---@param opts table +---@return NuiRenderer +function AvanteRenderer.create(opts) end + +---@param body fun():NuiComponent +function AvanteRenderer:render(body) end + +---@return nil +function AvanteRenderer:focus() end + +---@return nil +function AvanteRenderer:close() end + +---@param callback fun():nil +---@return nil +function AvanteRenderer:on_mount(callback) end + +---@param callback fun():nil +---@return nil +function AvanteRenderer:on_unmount(callback) end + +---@class LayoutSize +---@field width integer? +---@field height integer? + +---@param size LayoutSize +---@return nil +function AvanteRenderer:set_size(size) end diff --git a/lua/avante/utils.lua b/lua/avante/utils.lua index 6993623..8bb3828 100644 --- a/lua/avante/utils.lua +++ b/lua/avante/utils.lua @@ -4,4 +4,8 @@ function M.trim_suffix(str, suffix) return string.gsub(str, suffix .. "$", "") end +function M.trim_line_number_prefix(line) + return line:gsub("^L%d+: ", "") +end + return M diff --git a/lua/avante/view.lua b/lua/avante/view.lua new file mode 100644 index 0000000..146d39d --- /dev/null +++ b/lua/avante/view.lua @@ -0,0 +1,66 @@ +local api = vim.api + +---@class avante.View +---@field buf integer +---@field win integer +---@field RESULT_BUF_NAME string +local View = {} + +local RESULT_BUF_NAME = "AVANTE_RESULT" + +function View:new() + return setmetatable({ buf = nil, win = nil }, { __index = View }) +end + +---setup view buffer +---@param split_command string A split command to position the side bar to +---@param size integer a given % to resize the chat window +---@return avante.View +function View:setup(split_command, size) + -- create a scratch unlisted buffer + self.buf = api.nvim_create_buf(false, true) + + -- set filetype + api.nvim_set_option_value("filetype", "Avante", { buf = self.buf }) + api.nvim_set_option_value("bufhidden", "wipe", { buf = self.buf }) + api.nvim_set_option_value("modifiable", false, { buf = self.buf }) + api.nvim_set_option_value("swapfile", false, { buf = self.buf }) + + -- create a split + vim.cmd(split_command) + + --get current window and attach the buffer to it + self.win = api.nvim_get_current_win() + api.nvim_win_set_buf(self.win, self.buf) + + vim.cmd("vertical resize " .. size) + + -- win stuff + api.nvim_set_option_value("spell", false, { win = self.win }) + api.nvim_set_option_value("signcolumn", "no", { win = self.win }) + api.nvim_set_option_value("foldcolumn", "0", { win = self.win }) + api.nvim_set_option_value("number", false, { win = self.win }) + api.nvim_set_option_value("relativenumber", false, { win = self.win }) + api.nvim_set_option_value("list", false, { win = self.win }) + api.nvim_set_option_value("wrap", false, { win = self.win }) + api.nvim_set_option_value("winhl", "", { win = self.win }) + + -- buffer stuff + api.nvim_buf_set_name(self.buf, RESULT_BUF_NAME) + + return self +end + +function View:close() + if self.win then + api.nvim_win_close(self.win, true) + self.win = nil + self.buf = nil + end +end + +function View:is_open() + return self.win and self.buf and api.nvim_buf_is_valid(self.buf) and api.nvim_win_is_valid(self.win) +end + +return View