diff --git a/.github/workflows/lua.yaml b/.github/workflows/lua.yaml index 70d690d..dd4bbb1 100644 --- a/.github/workflows/lua.yaml +++ b/.github/workflows/lua.yaml @@ -6,15 +6,41 @@ on: paths: - "lua/**/*.lua" - "plugin/**/*.lua" + - ".stylua.toml" + - ".luacheckrc" - .github/workflows/lua.yaml pull_request: branches: [master] paths: - "lua/**/*.lua" - "plugin/**/*.lua" + - ".stylua.toml" + - ".luacheckrc" - .github/workflows/lua.yaml jobs: + format: + name: StyLua auto-format + runs-on: ubuntu-latest + if: github.event_name == 'push' + permissions: + contents: write + steps: + - uses: actions/checkout@v4 + + - name: Run StyLua + uses: JohnnyMorganz/stylua-action@v4 + with: + token: ${{ secrets.GITHUB_TOKEN }} + version: latest + args: lua/ plugin/ + + - name: Commit formatting changes + uses: stefanzweifel/git-auto-commit-action@v5 + with: + commit_message: "style: auto-format with stylua" + file_pattern: "lua/**/*.lua plugin/**/*.lua" + lint: name: Luacheck runs-on: ubuntu-latest @@ -31,19 +57,7 @@ jobs: run: luarocks install luacheck - name: Run luacheck - run: luacheck lua/ plugin/ --globals vim describe it before_each after_each assert --no-max-line-length - - stylua: - name: StyLua format check - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - uses: JohnnyMorganz/stylua-action@v4 - with: - token: ${{ secrets.GITHUB_TOKEN }} - version: latest - args: --check lua/ plugin/ + run: luacheck lua/ plugin/ health: name: Plugin load check diff --git a/.luacheckrc b/.luacheckrc new file mode 100644 index 0000000..2b4f736 --- /dev/null +++ b/.luacheckrc @@ -0,0 +1,49 @@ +std = "luajit" + +globals = { + "vim", + "_", +} + +read_globals = { + "describe", + "it", + "before_each", + "after_each", + "assert", +} + +max_line_length = false + +ignore = { + "211", -- unused function + "212", -- unused argument + "213", -- unused loop variable + "311", -- value assigned is unused + "312", -- value of argument is unused + "314", -- value of field is overwritten before use + "411", -- variable redefines + "421", -- shadowing local variable + "431", -- shadowing upvalue + "432", -- shadowing upvalue argument + "511", -- unreachable code + "542", -- empty if branch + "631", -- max_line_length +} + +files["lua/codetyper/adapters/nvim/autocmds.lua"] = { + ignore = { "111", "113", "131", "231", "241" }, -- TODO: fix undefined refs and dead stores +} + +files["lua/codetyper/adapters/nvim/ui/context_modal.lua"] = { + ignore = { "113" }, -- TODO: fix undefined run_project_inspect +} + +files["lua/codetyper/core/scheduler/loop.lua"] = { + ignore = { "241" }, -- mutated but never accessed +} + +exclude_files = { + ".luarocks", + ".luacache", +} diff --git a/.stylua.toml b/.stylua.toml new file mode 100644 index 0000000..6090f42 --- /dev/null +++ b/.stylua.toml @@ -0,0 +1,6 @@ +column_width = 120 +line_endings = "Unix" +indent_type = "Spaces" +indent_width = 2 +quote_style = "AutoPreferDouble" +call_parentheses = "Always" diff --git a/lua/codetyper/adapters/nvim/autocmds.lua b/lua/codetyper/adapters/nvim/autocmds.lua index cd884c2..e5ea37c 100644 --- a/lua/codetyper/adapters/nvim/autocmds.lua +++ b/lua/codetyper/adapters/nvim/autocmds.lua @@ -33,241 +33,240 @@ local PROMPT_PROCESS_DEBOUNCE_MS = 200 -- Wait 200ms after mode change before pr ---@param prompt table Prompt object ---@return string Unique key local function get_prompt_key(bufnr, prompt) - return string.format("%d:%d:%d:%s", bufnr, prompt.start_line, prompt.end_line, prompt.content:sub(1, 50)) + return string.format("%d:%d:%d:%s", bufnr, prompt.start_line, prompt.end_line, prompt.content:sub(1, 50)) end --- Schedule tree update with debounce local function schedule_tree_update() - if tree_update_timer then - tree_update_timer:stop() - end + if tree_update_timer then + tree_update_timer:stop() + end - tree_update_timer = vim.defer_fn(function() - local tree = require("codetyper.support.tree") - tree.update_tree_log() - tree_update_timer = nil - end, TREE_UPDATE_DEBOUNCE_MS) + tree_update_timer = vim.defer_fn(function() + local tree = require("codetyper.support.tree") + tree.update_tree_log() + tree_update_timer = nil + end, TREE_UPDATE_DEBOUNCE_MS) end --- Setup autocommands function M.setup() - local group = vim.api.nvim_create_augroup(AUGROUP, { clear = true }) + local group = vim.api.nvim_create_augroup(AUGROUP, { clear = true }) - -- Auto-check for closed prompts when leaving insert mode (works on ALL files) - vim.api.nvim_create_autocmd("InsertLeave", { - group = group, - pattern = "*", - callback = function() - -- Skip special buffers - local buftype = vim.bo.buftype - if buftype ~= "" then - return - end - -- Auto-save coder files only - local filepath = vim.fn.expand("%:p") - if utils.is_coder_file(filepath) and vim.bo.modified then - vim.cmd("silent! write") - end - -- Check for closed prompts and auto-process (respects preferences) - M.check_for_closed_prompt_with_preference() - end, - desc = "Check for closed prompt tags on InsertLeave", - }) + -- Auto-check for closed prompts when leaving insert mode (works on ALL files) + vim.api.nvim_create_autocmd("InsertLeave", { + group = group, + pattern = "*", + callback = function() + -- Skip special buffers + local buftype = vim.bo.buftype + if buftype ~= "" then + return + end + -- Auto-save coder files only + local filepath = vim.fn.expand("%:p") + if utils.is_coder_file(filepath) and vim.bo.modified then + vim.cmd("silent! write") + end + -- Check for closed prompts and auto-process (respects preferences) + M.check_for_closed_prompt_with_preference() + end, + desc = "Check for closed prompt tags on InsertLeave", + }) - -- Track mode changes for visual mode detection - vim.api.nvim_create_autocmd("ModeChanged", { - group = group, - pattern = "*", - callback = function(ev) - -- Extract old mode from pattern (format: "old_mode:new_mode") - local old_mode = ev.match:match("^(.-):") - if old_mode then - previous_mode = old_mode - end - end, - desc = "Track previous mode for visual mode detection", - }) + -- Track mode changes for visual mode detection + vim.api.nvim_create_autocmd("ModeChanged", { + group = group, + pattern = "*", + callback = function(ev) + -- Extract old mode from pattern (format: "old_mode:new_mode") + local old_mode = ev.match:match("^(.-):") + if old_mode then + previous_mode = old_mode + end + end, + desc = "Track previous mode for visual mode detection", + }) - -- Auto-process prompts when entering normal mode (works on ALL files) - vim.api.nvim_create_autocmd("ModeChanged", { - group = group, - pattern = "*:n", - callback = function() - -- Skip special buffers - local buftype = vim.bo.buftype - if buftype ~= "" then - return - end + -- Auto-process prompts when entering normal mode (works on ALL files) + vim.api.nvim_create_autocmd("ModeChanged", { + group = group, + pattern = "*:n", + callback = function() + -- Skip special buffers + local buftype = vim.bo.buftype + if buftype ~= "" then + return + end - -- Skip if currently processing (avoid concurrent processing) - if is_processing then - return - end + -- Skip if currently processing (avoid concurrent processing) + if is_processing then + return + end - -- Skip if coming from visual mode (v, V, CTRL-V) - user is still editing - if previous_mode == "v" or previous_mode == "V" or previous_mode == "\22" then - return - end + -- Skip if coming from visual mode (v, V, CTRL-V) - user is still editing + if previous_mode == "v" or previous_mode == "V" or previous_mode == "\22" then + return + end - -- Cancel any pending processing timer - if prompt_process_timer then - prompt_process_timer:stop() - prompt_process_timer = nil - end + -- Cancel any pending processing timer + if prompt_process_timer then + prompt_process_timer:stop() + prompt_process_timer = nil + end - -- Debounced processing - wait for user to truly be idle - prompt_process_timer = vim.defer_fn(function() - prompt_process_timer = nil - -- Double-check we're still in normal mode - local mode = vim.api.nvim_get_mode().mode - if mode ~= "n" then - return - end - M.check_all_prompts_with_preference() - end, PROMPT_PROCESS_DEBOUNCE_MS) - end, - desc = "Auto-process closed prompts when entering normal mode", - }) + -- Debounced processing - wait for user to truly be idle + prompt_process_timer = vim.defer_fn(function() + prompt_process_timer = nil + -- Double-check we're still in normal mode + local mode = vim.api.nvim_get_mode().mode + if mode ~= "n" then + return + end + M.check_all_prompts_with_preference() + end, PROMPT_PROCESS_DEBOUNCE_MS) + end, + desc = "Auto-process closed prompts when entering normal mode", + }) - -- Also check on CursorHold as backup (works on ALL files) - vim.api.nvim_create_autocmd("CursorHold", { - group = group, - pattern = "*", - callback = function() - -- Skip special buffers - local buftype = vim.bo.buftype - if buftype ~= "" then - return - end - -- Skip if currently processing - if is_processing then - return - end - local mode = vim.api.nvim_get_mode().mode - if mode == "n" then - M.check_all_prompts_with_preference() - end - end, - desc = "Auto-process closed prompts when idle in normal mode", - }) + -- Also check on CursorHold as backup (works on ALL files) + vim.api.nvim_create_autocmd("CursorHold", { + group = group, + pattern = "*", + callback = function() + -- Skip special buffers + local buftype = vim.bo.buftype + if buftype ~= "" then + return + end + -- Skip if currently processing + if is_processing then + return + end + local mode = vim.api.nvim_get_mode().mode + if mode == "n" then + M.check_all_prompts_with_preference() + end + end, + desc = "Auto-process closed prompts when idle in normal mode", + }) - -- Auto-set filetype for coder files based on extension - vim.api.nvim_create_autocmd({ "BufRead", "BufNewFile" }, { - group = group, - pattern = "*.codetyper/*", - callback = function() - M.set_coder_filetype() - end, - desc = "Set filetype for coder files", - }) + -- Auto-set filetype for coder files based on extension + vim.api.nvim_create_autocmd({ "BufRead", "BufNewFile" }, { + group = group, + pattern = "*.codetyper/*", + callback = function() + M.set_coder_filetype() + end, + desc = "Set filetype for coder files", + }) + -- Cleanup on buffer close + vim.api.nvim_create_autocmd("BufWipeout", { + group = group, + pattern = "*.codetyper/*", + callback = function(ev) + -- Clear processed prompts for this buffer + local bufnr = ev.buf + for key, _ in pairs(processed_prompts) do + if key:match("^" .. bufnr .. ":") then + processed_prompts[key] = nil + end + end + -- Clear auto-opened tracking + M.clear_auto_opened(bufnr) + end, + desc = "Cleanup on coder buffer close", + }) - -- Cleanup on buffer close - vim.api.nvim_create_autocmd("BufWipeout", { - group = group, - pattern = "*.codetyper/*", - callback = function(ev) - -- Clear processed prompts for this buffer - local bufnr = ev.buf - for key, _ in pairs(processed_prompts) do - if key:match("^" .. bufnr .. ":") then - processed_prompts[key] = nil - end - end - -- Clear auto-opened tracking - M.clear_auto_opened(bufnr) - end, - desc = "Cleanup on coder buffer close", - }) + -- Update tree.log when files are created/written + vim.api.nvim_create_autocmd({ "BufWritePost", "BufNewFile" }, { + group = group, + pattern = "*", + callback = function(ev) + -- Skip coder files and tree.log itself + local filepath = ev.file or vim.fn.expand("%:p") + if filepath:match("%.codetyper%.") or filepath:match("tree%.log$") then + return + end + -- Skip non-project files + if filepath:match("node_modules") or filepath:match("%.git/") or filepath:match("%.codetyper/") then + return + end + -- Schedule tree update with debounce + schedule_tree_update() - -- Update tree.log when files are created/written - vim.api.nvim_create_autocmd({ "BufWritePost", "BufNewFile" }, { - group = group, - pattern = "*", - callback = function(ev) - -- Skip coder files and tree.log itself - local filepath = ev.file or vim.fn.expand("%:p") - if filepath:match("%.codetyper%.") or filepath:match("tree%.log$") then - return - end - -- Skip non-project files - if filepath:match("node_modules") or filepath:match("%.git/") or filepath:match("%.codetyper/") then - return - end - -- Schedule tree update with debounce - schedule_tree_update() + -- Trigger incremental indexing if enabled + local ok_indexer, indexer = pcall(require, "codetyper.indexer") + if ok_indexer then + indexer.schedule_index_file(filepath) + end - -- Trigger incremental indexing if enabled - local ok_indexer, indexer = pcall(require, "codetyper.indexer") - if ok_indexer then - indexer.schedule_index_file(filepath) - end + -- Update brain with file patterns + local ok_brain, brain = pcall(require, "codetyper.brain") + if ok_brain and brain.is_initialized and brain.is_initialized() then + vim.defer_fn(function() + M.update_brain_from_file(filepath) + end, 500) -- Debounce brain updates + end + end, + desc = "Update tree.log, index, and brain on file creation/save", + }) - -- Update brain with file patterns - local ok_brain, brain = pcall(require, "codetyper.brain") - if ok_brain and brain.is_initialized and brain.is_initialized() then - vim.defer_fn(function() - M.update_brain_from_file(filepath) - end, 500) -- Debounce brain updates - end - end, - desc = "Update tree.log, index, and brain on file creation/save", - }) + -- Update tree.log when files are deleted (via netrw or file explorer) + vim.api.nvim_create_autocmd("BufDelete", { + group = group, + pattern = "*", + callback = function(ev) + local filepath = ev.file or "" + -- Skip special buffers and coder files + if filepath == "" or filepath:match("%.codetyper%.") or filepath:match("tree%.log$") then + return + end + schedule_tree_update() + end, + desc = "Update tree.log on file deletion", + }) - -- Update tree.log when files are deleted (via netrw or file explorer) - vim.api.nvim_create_autocmd("BufDelete", { - group = group, - pattern = "*", - callback = function(ev) - local filepath = ev.file or "" - -- Skip special buffers and coder files - if filepath == "" or filepath:match("%.codetyper%.") or filepath:match("tree%.log$") then - return - end - schedule_tree_update() - end, - desc = "Update tree.log on file deletion", - }) + -- Update tree on directory change + vim.api.nvim_create_autocmd("DirChanged", { + group = group, + pattern = "*", + callback = function() + schedule_tree_update() + end, + desc = "Update tree.log on directory change", + }) - -- Update tree on directory change - vim.api.nvim_create_autocmd("DirChanged", { - group = group, - pattern = "*", - callback = function() - schedule_tree_update() - end, - desc = "Update tree.log on directory change", - }) + -- Shutdown brain on Vim exit + vim.api.nvim_create_autocmd("VimLeavePre", { + group = group, + pattern = "*", + callback = function() + local ok, brain = pcall(require, "codetyper.brain") + if ok and brain.is_initialized and brain.is_initialized() then + brain.shutdown() + end + end, + desc = "Shutdown brain and flush pending changes", + }) - -- Shutdown brain on Vim exit - vim.api.nvim_create_autocmd("VimLeavePre", { - group = group, - pattern = "*", - callback = function() - local ok, brain = pcall(require, "codetyper.brain") - if ok and brain.is_initialized and brain.is_initialized() then - brain.shutdown() - end - end, - desc = "Shutdown brain and flush pending changes", - }) + -- Auto-index: Create/open coder companion file when opening source files + vim.api.nvim_create_autocmd("BufEnter", { + group = group, + pattern = "*", + callback = function(ev) + -- Delay to ensure buffer is fully loaded + vim.defer_fn(function() + M.auto_index_file(ev.buf) + end, 100) + end, + desc = "Auto-index source files with coder companion", + }) - -- Auto-index: Create/open coder companion file when opening source files - vim.api.nvim_create_autocmd("BufEnter", { - group = group, - pattern = "*", - callback = function(ev) - -- Delay to ensure buffer is fully loaded - vim.defer_fn(function() - M.auto_index_file(ev.buf) - end, 100) - end, - desc = "Auto-index source files with coder companion", - }) - - -- Thinking indicator (throbber) cleanup on exit - local thinking = require("codetyper.adapters.nvim.ui.thinking") - thinking.setup() + -- Thinking indicator (throbber) cleanup on exit + local thinking = require("codetyper.adapters.nvim.ui.thinking") + thinking.setup() end --- Create extmarks for injection range so position survives user edits (99-style). @@ -275,30 +274,30 @@ end ---@param range { start_line: number, end_line: number } Range to mark (1-based) ---@return table|nil injection_marks { start_mark, end_mark } or nil if buffer invalid local function create_injection_marks(target_bufnr, range) - if not range or target_bufnr == -1 or not vim.api.nvim_buf_is_valid(target_bufnr) then - return nil - end - local line_count = vim.api.nvim_buf_line_count(target_bufnr) - if line_count == 0 then - return nil - end - -- Clamp to valid 1-based line range (event range may refer to source buffer, target can be different) - local start_line = math.max(1, math.min(range.start_line, line_count)) - local end_line = math.max(1, math.min(range.end_line, line_count)) - if start_line > end_line then - end_line = start_line - end - local marks = require("codetyper.core.marks") - local end_line_content = vim.api.nvim_buf_get_lines(target_bufnr, end_line - 1, end_line, false) - local end_col_0 = 0 - if end_line_content and end_line_content[1] then - end_col_0 = #end_line_content[1] - end - local start_mark, end_mark = marks.mark_range(target_bufnr, start_line, end_line, end_col_0) - if not start_mark.id or not end_mark.id then - return nil - end - return { start_mark = start_mark, end_mark = end_mark } + if not range or target_bufnr == -1 or not vim.api.nvim_buf_is_valid(target_bufnr) then + return nil + end + local line_count = vim.api.nvim_buf_line_count(target_bufnr) + if line_count == 0 then + return nil + end + -- Clamp to valid 1-based line range (event range may refer to source buffer, target can be different) + local start_line = math.max(1, math.min(range.start_line, line_count)) + local end_line = math.max(1, math.min(range.end_line, line_count)) + if start_line > end_line then + end_line = start_line + end + local marks = require("codetyper.core.marks") + local end_line_content = vim.api.nvim_buf_get_lines(target_bufnr, end_line - 1, end_line, false) + local end_col_0 = 0 + if end_line_content and end_line_content[1] then + end_col_0 = #end_line_content[1] + end + local start_mark, end_mark = marks.mark_range(target_bufnr, start_line, end_line, end_col_0) + if not start_mark.id or not end_mark.id then + return nil + end + return { start_mark = start_mark, end_mark = end_mark } end --- Read attached files from prompt content @@ -306,250 +305,248 @@ end ---@param base_path string Base path to resolve relative file paths ---@return table[] attached_files List of {path, content} tables local function read_attached_files(prompt_content, base_path) - local parser = require("codetyper.parser") - local file_refs = parser.extract_file_references(prompt_content) - local attached = {} - local cwd = vim.fn.getcwd() - local base_dir = vim.fn.fnamemodify(base_path, ":h") + local parser = require("codetyper.parser") + local file_refs = parser.extract_file_references(prompt_content) + local attached = {} + local cwd = vim.fn.getcwd() + local base_dir = vim.fn.fnamemodify(base_path, ":h") - for _, ref in ipairs(file_refs) do - local file_path = nil + for _, ref in ipairs(file_refs) do + local file_path = nil - -- Try resolving relative to cwd first - local cwd_path = cwd .. "/" .. ref - if utils.file_exists(cwd_path) then - file_path = cwd_path - else - -- Try resolving relative to base file directory - local rel_path = base_dir .. "/" .. ref - if utils.file_exists(rel_path) then - file_path = rel_path - end - end + -- Try resolving relative to cwd first + local cwd_path = cwd .. "/" .. ref + if utils.file_exists(cwd_path) then + file_path = cwd_path + else + -- Try resolving relative to base file directory + local rel_path = base_dir .. "/" .. ref + if utils.file_exists(rel_path) then + file_path = rel_path + end + end - if file_path then - local content = utils.read_file(file_path) - if content then - table.insert(attached, { - path = ref, - full_path = file_path, - content = content, - }) - end - end - end + if file_path then + local content = utils.read_file(file_path) + if content then + table.insert(attached, { + path = ref, + full_path = file_path, + content = content, + }) + end + end + end - return attached + return attached end --- Check if the buffer has a newly closed prompt and auto-process (works on ANY file) function M.check_for_closed_prompt() - -- Skip if already processing - if is_processing then - return - end - is_processing = true + -- Skip if already processing + if is_processing then + return + end + is_processing = true - local parser = require("codetyper.parser") + local parser = require("codetyper.parser") - local bufnr = vim.api.nvim_get_current_buf() - local current_file = vim.fn.expand("%:p") + local bufnr = vim.api.nvim_get_current_buf() + local current_file = vim.fn.expand("%:p") - -- Skip if no file - if current_file == "" then - is_processing = false - return - end + -- Skip if no file + if current_file == "" then + is_processing = false + return + end - -- Get current line - local cursor = vim.api.nvim_win_get_cursor(0) - local line = cursor[1] - local lines = vim.api.nvim_buf_get_lines(bufnr, line - 1, line, false) + -- Get current line + local cursor = vim.api.nvim_win_get_cursor(0) + local line = cursor[1] + local lines = vim.api.nvim_buf_get_lines(bufnr, line - 1, line, false) - if #lines == 0 then - is_processing = false - return - end + if #lines == 0 then + is_processing = false + return + end - local current_line = lines[1] + local current_line = lines[1] - -- Check if line contains closing tag - if parser.has_closing_tag(current_line, config.patterns.close_tag) then - -- Find the complete prompt - local prompt = parser.get_last_prompt(bufnr) - if prompt and prompt.content and prompt.content ~= "" then - -- Generate unique key for this prompt - local prompt_key = get_prompt_key(bufnr, prompt) + -- Check if line contains closing tag + if parser.has_closing_tag(current_line, config.patterns.close_tag) then + -- Find the complete prompt + local prompt = parser.get_last_prompt(bufnr) + if prompt and prompt.content and prompt.content ~= "" then + -- Generate unique key for this prompt + local prompt_key = get_prompt_key(bufnr, prompt) - -- Check if already processed - if processed_prompts[prompt_key] then - is_processing = false - return - end + -- Check if already processed + if processed_prompts[prompt_key] then + is_processing = false + return + end - -- Mark as processed - processed_prompts[prompt_key] = true + -- Mark as processed + processed_prompts[prompt_key] = true - -- Check if scheduler is enabled - local codetyper = require("codetyper") - local ct_config = codetyper.get_config() - local scheduler_enabled = ct_config and ct_config.scheduler and ct_config.scheduler.enabled + -- Check if scheduler is enabled + local codetyper = require("codetyper") + local ct_config = codetyper.get_config() + local scheduler_enabled = ct_config and ct_config.scheduler and ct_config.scheduler.enabled - if scheduler_enabled then - -- Event-driven: emit to queue - vim.schedule(function() - local queue = require("codetyper.core.events.queue") - local patch_mod = require("codetyper.core.diff.patch") - local intent_mod = require("codetyper.core.intent") - local scope_mod = require("codetyper.core.scope") - -- In-buffer placeholder "@thinking .... end thinking" is inserted when worker starts (scheduler) + if scheduler_enabled then + -- Event-driven: emit to queue + vim.schedule(function() + local queue = require("codetyper.core.events.queue") + local patch_mod = require("codetyper.core.diff.patch") + local intent_mod = require("codetyper.core.intent") + local scope_mod = require("codetyper.core.scope") + -- In-buffer placeholder "@thinking .... end thinking" is inserted when worker starts (scheduler) - -- Take buffer snapshot - local snapshot = patch_mod.snapshot_buffer(bufnr, { - start_line = prompt.start_line, - end_line = prompt.end_line, - }) + -- Take buffer snapshot + local snapshot = patch_mod.snapshot_buffer(bufnr, { + start_line = prompt.start_line, + end_line = prompt.end_line, + }) - -- Get target path - for coder files, get the target; for regular files, use self - local target_path - if utils.is_coder_file(current_file) then - target_path = utils.get_target_path(current_file) - else - target_path = current_file - end + -- Get target path - for coder files, get the target; for regular files, use self + local target_path + if utils.is_coder_file(current_file) then + target_path = utils.get_target_path(current_file) + else + target_path = current_file + end - -- Read attached files before cleaning - local attached_files = read_attached_files(prompt.content, current_file) + -- Read attached files before cleaning + local attached_files = read_attached_files(prompt.content, current_file) - -- Clean prompt content (strip file references) - local cleaned = parser.clean_prompt(parser.strip_file_references(prompt.content)) + -- Clean prompt content (strip file references) + local cleaned = parser.clean_prompt(parser.strip_file_references(prompt.content)) - -- Check if we're working from a coder file - local is_from_coder_file = utils.is_coder_file(current_file) + -- Check if we're working from a coder file + local is_from_coder_file = utils.is_coder_file(current_file) - -- Resolve scope in target file FIRST (need it to adjust intent) - -- Only resolve scope if NOT from coder file (line numbers don't apply) - local target_bufnr = vim.fn.bufnr(target_path) - local scope = nil - local scope_text = nil - local scope_range = nil + -- Resolve scope in target file FIRST (need it to adjust intent) + -- Only resolve scope if NOT from coder file (line numbers don't apply) + local target_bufnr = vim.fn.bufnr(target_path) + local scope = nil + local scope_text = nil + local scope_range = nil - if not is_from_coder_file then - -- Prompt is in the actual source file, use line position for scope - if target_bufnr == -1 then - target_bufnr = bufnr - end - scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1) - if scope and scope.type ~= "file" then - scope_text = scope.text - scope_range = { - start_line = scope.range.start_row, - end_line = scope.range.end_row, - } - end - else - -- Prompt is in coder file - load target if needed, but don't use scope - -- Code from coder files should append to target by default - if target_bufnr == -1 then - target_bufnr = vim.fn.bufadd(target_path) - if target_bufnr ~= 0 then - vim.fn.bufload(target_bufnr) - end - end - end + if not is_from_coder_file then + -- Prompt is in the actual source file, use line position for scope + if target_bufnr == -1 then + target_bufnr = bufnr + end + scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1) + if scope and scope.type ~= "file" then + scope_text = scope.text + scope_range = { + start_line = scope.range.start_row, + end_line = scope.range.end_row, + } + end + else + -- Prompt is in coder file - load target if needed, but don't use scope + -- Code from coder files should append to target by default + if target_bufnr == -1 then + target_bufnr = vim.fn.bufadd(target_path) + if target_bufnr ~= 0 then + vim.fn.bufload(target_bufnr) + end + end + end - -- Detect intent from prompt - local intent = intent_mod.detect(cleaned) + -- Detect intent from prompt + local intent = intent_mod.detect(cleaned) - -- IMPORTANT: If prompt is inside a function/method and intent is "add", - -- override to "complete" since we're completing the function body - -- But NOT for coder files - they should use "add/append" by default - if not is_from_coder_file and scope and (scope.type == "function" or scope.type == "method") then - if intent.type == "add" or intent.action == "insert" or intent.action == "append" then - -- Override to complete the function instead of adding new code - intent = { - type = "complete", - scope_hint = "function", - confidence = intent.confidence, - action = "replace", - keywords = intent.keywords, - } - end - end + -- IMPORTANT: If prompt is inside a function/method and intent is "add", + -- override to "complete" since we're completing the function body + -- But NOT for coder files - they should use "add/append" by default + if not is_from_coder_file and scope and (scope.type == "function" or scope.type == "method") then + if intent.type == "add" or intent.action == "insert" or intent.action == "append" then + -- Override to complete the function instead of adding new code + intent = { + type = "complete", + scope_hint = "function", + confidence = intent.confidence, + action = "replace", + keywords = intent.keywords, + } + end + end - -- For coder files, default to "add" with "append" action - if is_from_coder_file and (intent.action == "replace" or intent.type == "complete") then - intent = { - type = intent.type == "complete" and "add" or intent.type, - confidence = intent.confidence, - action = "append", - keywords = intent.keywords, - } - end + -- For coder files, default to "add" with "append" action + if is_from_coder_file and (intent.action == "replace" or intent.type == "complete") then + intent = { + type = intent.type == "complete" and "add" or intent.type, + confidence = intent.confidence, + action = "append", + keywords = intent.keywords, + } + end - -- Determine priority based on intent - local priority = 2 -- Normal - if intent.type == "fix" or intent.type == "complete" then - priority = 1 -- High priority for fixes and completions - elseif intent.type == "test" or intent.type == "document" then - priority = 3 -- Lower priority for tests and docs - end + -- Determine priority based on intent + local priority = 2 -- Normal + if intent.type == "fix" or intent.type == "complete" then + priority = 1 -- High priority for fixes and completions + elseif intent.type == "test" or intent.type == "document" then + priority = 3 -- Lower priority for tests and docs + end - -- Use captured injection range when provided, else prompt.start_line/end_line - local raw_start = (prompt.injection_range and prompt.injection_range.start_line) - or prompt.start_line - or 1 - local raw_end = (prompt.injection_range and prompt.injection_range.end_line) or prompt.end_line or 1 - local tc = vim.api.nvim_buf_line_count(target_bufnr) - tc = math.max(1, tc) - local rs = math.max(1, math.min(raw_start, tc)) - local re = math.max(1, math.min(raw_end, tc)) - if re < rs then - re = rs - end - local event_range = { start_line = rs, end_line = re } + -- Use captured injection range when provided, else prompt.start_line/end_line + local raw_start = (prompt.injection_range and prompt.injection_range.start_line) or prompt.start_line or 1 + local raw_end = (prompt.injection_range and prompt.injection_range.end_line) or prompt.end_line or 1 + local tc = vim.api.nvim_buf_line_count(target_bufnr) + tc = math.max(1, tc) + local rs = math.max(1, math.min(raw_start, tc)) + local re = math.max(1, math.min(raw_end, tc)) + if re < rs then + re = rs + end + local event_range = { start_line = rs, end_line = re } - -- Extmarks for injection range (99-style: position survives user typing) - local range_for_marks = scope_range or event_range - local injection_marks = create_injection_marks(target_bufnr, range_for_marks) + -- Extmarks for injection range (99-style: position survives user typing) + local range_for_marks = scope_range or event_range + local injection_marks = create_injection_marks(target_bufnr, range_for_marks) - -- Enqueue the event (event.range = where to apply the generated code) - queue.enqueue({ - id = queue.generate_id(), - bufnr = bufnr, - range = event_range, - timestamp = os.clock(), - changedtick = snapshot.changedtick, - content_hash = snapshot.content_hash, - prompt_content = cleaned, - target_path = target_path, - priority = priority, - status = "pending", - attempt_count = 0, - intent = intent, - scope = scope, - scope_text = scope_text, - scope_range = scope_range, - attached_files = attached_files, - injection_marks = injection_marks, - }) + -- Enqueue the event (event.range = where to apply the generated code) + queue.enqueue({ + id = queue.generate_id(), + bufnr = bufnr, + range = event_range, + timestamp = os.clock(), + changedtick = snapshot.changedtick, + content_hash = snapshot.content_hash, + prompt_content = cleaned, + target_path = target_path, + priority = priority, + status = "pending", + attempt_count = 0, + intent = intent, + scope = scope, + scope_text = scope_text, + scope_range = scope_range, + attached_files = attached_files, + injection_marks = injection_marks, + }) - local scope_info = scope - and scope.type ~= "file" - and string.format(" [%s: %s]", scope.type, scope.name or "anonymous") - or "" - utils.notify(string.format("Prompt queued: %s%s", intent.type, scope_info), vim.log.levels.INFO) - end) - else - -- Legacy: direct processing - utils.notify("Processing prompt...", vim.log.levels.INFO) - vim.schedule(function() - vim.cmd("CoderProcess") - end) - end - end - end - is_processing = false + local scope_info = scope + and scope.type ~= "file" + and string.format(" [%s: %s]", scope.type, scope.name or "anonymous") + or "" + utils.notify(string.format("Prompt queued: %s%s", intent.type, scope_info), vim.log.levels.INFO) + end) + else + -- Legacy: direct processing + utils.notify("Processing prompt...", vim.log.levels.INFO) + vim.schedule(function() + vim.cmd("CoderProcess") + end) + end + end + end + is_processing = false end --- Process a single prompt through the scheduler @@ -559,289 +556,289 @@ end ---@param current_file string Current file path ---@param skip_processed_check? boolean Skip the processed check (for manual mode) function M.process_single_prompt(bufnr, prompt, current_file, skip_processed_check) - local parser = require("codetyper.parser") - local scheduler = require("codetyper.core.scheduler.scheduler") + local parser = require("codetyper.parser") + local scheduler = require("codetyper.core.scheduler.scheduler") - if not prompt.content or prompt.content == "" then - return - end + if not prompt.content or prompt.content == "" then + return + end - -- Ensure scheduler is running - if not scheduler.status().running then - scheduler.start() - end + -- Ensure scheduler is running + if not scheduler.status().running then + scheduler.start() + end - -- Generate unique key for this prompt - local prompt_key = get_prompt_key(bufnr, prompt) + -- Generate unique key for this prompt + local prompt_key = get_prompt_key(bufnr, prompt) - -- Skip if already processed (unless overridden for manual mode) - if not skip_processed_check and processed_prompts[prompt_key] then - return - end + -- Skip if already processed (unless overridden for manual mode) + if not skip_processed_check and processed_prompts[prompt_key] then + return + end - -- Mark as processed - processed_prompts[prompt_key] = true + -- Mark as processed + processed_prompts[prompt_key] = true - -- Process this prompt - vim.schedule(function() - local queue = require("codetyper.core.events.queue") - local patch_mod = require("codetyper.core.diff.patch") - local intent_mod = require("codetyper.core.intent") - local scope_mod = require("codetyper.core.scope") - -- In-buffer placeholder "@thinking .... end thinking" is inserted when worker starts (scheduler) + -- Process this prompt + vim.schedule(function() + local queue = require("codetyper.core.events.queue") + local patch_mod = require("codetyper.core.diff.patch") + local intent_mod = require("codetyper.core.intent") + local scope_mod = require("codetyper.core.scope") + -- In-buffer placeholder "@thinking .... end thinking" is inserted when worker starts (scheduler) - -- Take buffer snapshot - local snapshot = patch_mod.snapshot_buffer(bufnr, { - start_line = prompt.start_line, - end_line = prompt.end_line, - }) + -- Take buffer snapshot + local snapshot = patch_mod.snapshot_buffer(bufnr, { + start_line = prompt.start_line, + end_line = prompt.end_line, + }) - -- Get target path - for coder files, get the target; for regular files, use self - local target_path - local is_from_coder_file = utils.is_coder_file(current_file) - if is_from_coder_file then - target_path = utils.get_target_path(current_file) - else - target_path = current_file - end + -- Get target path - for coder files, get the target; for regular files, use self + local target_path + local is_from_coder_file = utils.is_coder_file(current_file) + if is_from_coder_file then + target_path = utils.get_target_path(current_file) + else + target_path = current_file + end - -- Read attached files before cleaning - local attached_files = read_attached_files(prompt.content, current_file) + -- Read attached files before cleaning + local attached_files = read_attached_files(prompt.content, current_file) - -- Clean prompt content (strip file references) - local cleaned = parser.clean_prompt(parser.strip_file_references(prompt.content)) + -- Clean prompt content (strip file references) + local cleaned = parser.clean_prompt(parser.strip_file_references(prompt.content)) - -- Resolve scope in target file FIRST (need it to adjust intent) - -- Only resolve scope if NOT from coder file (line numbers don't apply) - local target_bufnr = vim.fn.bufnr(target_path) - local scope = nil - local scope_text = nil - local scope_range = nil + -- Resolve scope in target file FIRST (need it to adjust intent) + -- Only resolve scope if NOT from coder file (line numbers don't apply) + local target_bufnr = vim.fn.bufnr(target_path) + local scope = nil + local scope_text = nil + local scope_range = nil - if not is_from_coder_file then - -- Prompt is in the actual source file, use line position for scope - if target_bufnr == -1 then - target_bufnr = bufnr - end - scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1) - if scope and scope.type ~= "file" then - scope_text = scope.text - scope_range = { - start_line = scope.range.start_row, - end_line = scope.range.end_row, - } - end - else - -- Prompt is in coder file - load target if needed - if target_bufnr == -1 then - target_bufnr = vim.fn.bufadd(target_path) - if target_bufnr ~= 0 then - vim.fn.bufload(target_bufnr) - end - end - end + if not is_from_coder_file then + -- Prompt is in the actual source file, use line position for scope + if target_bufnr == -1 then + target_bufnr = bufnr + end + scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1) + if scope and scope.type ~= "file" then + scope_text = scope.text + scope_range = { + start_line = scope.range.start_row, + end_line = scope.range.end_row, + } + end + else + -- Prompt is in coder file - load target if needed + if target_bufnr == -1 then + target_bufnr = vim.fn.bufadd(target_path) + if target_bufnr ~= 0 then + vim.fn.bufload(target_bufnr) + end + end + end - -- Detect intent from prompt (honor explicit override from transform-selection) - local intent = intent_mod.detect(cleaned) + -- Detect intent from prompt (honor explicit override from transform-selection) + local intent = intent_mod.detect(cleaned) - if prompt.intent_override then - intent.action = prompt.intent_override.action or intent.action - if prompt.intent_override.type then - intent.type = prompt.intent_override.type - end - elseif not is_from_coder_file and scope and (scope.type == "function" or scope.type == "method") then - if intent.type == "add" or intent.action == "insert" or intent.action == "append" then - intent = { - type = "complete", - scope_hint = "function", - confidence = intent.confidence, - action = "replace", - keywords = intent.keywords, - } - end - end + if prompt.intent_override then + intent.action = prompt.intent_override.action or intent.action + if prompt.intent_override.type then + intent.type = prompt.intent_override.type + end + elseif not is_from_coder_file and scope and (scope.type == "function" or scope.type == "method") then + if intent.type == "add" or intent.action == "insert" or intent.action == "append" then + intent = { + type = "complete", + scope_hint = "function", + confidence = intent.confidence, + action = "replace", + keywords = intent.keywords, + } + end + end - -- For coder files, default to "add" with "append" action - if is_from_coder_file and (intent.action == "replace" or intent.type == "complete") then - intent = { - type = intent.type == "complete" and "add" or intent.type, - confidence = intent.confidence, - action = "append", - keywords = intent.keywords, - } - end + -- For coder files, default to "add" with "append" action + if is_from_coder_file and (intent.action == "replace" or intent.type == "complete") then + intent = { + type = intent.type == "complete" and "add" or intent.type, + confidence = intent.confidence, + action = "append", + keywords = intent.keywords, + } + end - -- For whole-file selections, gather project tree context - local project_context = nil - if prompt.is_whole_file then - pcall(function() - local tree = require("codetyper.support.tree") - local tree_log = tree.get_tree_log_path() - if tree_log and vim.fn.filereadable(tree_log) == 1 then - local tree_lines = vim.fn.readfile(tree_log) - if tree_lines and #tree_lines > 0 then - local tree_content = table.concat(tree_lines, "\n") - project_context = tree_content:sub(1, 4000) - end - end - end) - end + -- For whole-file selections, gather project tree context + local project_context = nil + if prompt.is_whole_file then + pcall(function() + local tree = require("codetyper.support.tree") + local tree_log = tree.get_tree_log_path() + if tree_log and vim.fn.filereadable(tree_log) == 1 then + local tree_lines = vim.fn.readfile(tree_log) + if tree_lines and #tree_lines > 0 then + local tree_content = table.concat(tree_lines, "\n") + project_context = tree_content:sub(1, 4000) + end + end + end) + end - -- Determine priority based on intent - local priority = 2 - if intent.type == "fix" or intent.type == "complete" then - priority = 1 - elseif intent.type == "test" or intent.type == "document" then - priority = 3 - end + -- Determine priority based on intent + local priority = 2 + if intent.type == "fix" or intent.type == "complete" then + priority = 1 + elseif intent.type == "test" or intent.type == "document" then + priority = 3 + end - -- Use captured injection range when provided (from transform-selection), else prompt.start_line/end_line - local raw_start = (prompt.injection_range and prompt.injection_range.start_line) or prompt.start_line or 1 - local raw_end = (prompt.injection_range and prompt.injection_range.end_line) or prompt.end_line or 1 - -- Clamp to target buffer (1-based, valid lines) - local tc = vim.api.nvim_buf_line_count(target_bufnr) - tc = math.max(1, tc) - local rs = math.max(1, math.min(raw_start, tc)) - local re = math.max(1, math.min(raw_end, tc)) - if re < rs then - re = rs - end - local event_range = { start_line = rs, end_line = re } + -- Use captured injection range when provided (from transform-selection), else prompt.start_line/end_line + local raw_start = (prompt.injection_range and prompt.injection_range.start_line) or prompt.start_line or 1 + local raw_end = (prompt.injection_range and prompt.injection_range.end_line) or prompt.end_line or 1 + -- Clamp to target buffer (1-based, valid lines) + local tc = vim.api.nvim_buf_line_count(target_bufnr) + tc = math.max(1, tc) + local rs = math.max(1, math.min(raw_start, tc)) + local re = math.max(1, math.min(raw_end, tc)) + if re < rs then + re = rs + end + local event_range = { start_line = rs, end_line = re } - -- Extmarks for injection range (99-style: position survives user typing) - local range_for_marks = scope_range or event_range - local injection_marks = create_injection_marks(target_bufnr, range_for_marks) + -- Extmarks for injection range (99-style: position survives user typing) + local range_for_marks = scope_range or event_range + local injection_marks = create_injection_marks(target_bufnr, range_for_marks) - -- Enqueue the event (event.range = where to apply the generated code) - queue.enqueue({ - id = queue.generate_id(), - bufnr = bufnr, - range = event_range, - timestamp = os.clock(), - changedtick = snapshot.changedtick, - content_hash = snapshot.content_hash, - prompt_content = cleaned, - target_path = target_path, - priority = priority, - status = "pending", - attempt_count = 0, - intent = intent, - intent_override = prompt.intent_override, - scope = scope, - scope_text = scope_text, - scope_range = scope_range, - attached_files = attached_files, - injection_marks = injection_marks, - injection_range = prompt.injection_range, - is_whole_file = prompt.is_whole_file, - project_context = project_context, - }) + -- Enqueue the event (event.range = where to apply the generated code) + queue.enqueue({ + id = queue.generate_id(), + bufnr = bufnr, + range = event_range, + timestamp = os.clock(), + changedtick = snapshot.changedtick, + content_hash = snapshot.content_hash, + prompt_content = cleaned, + target_path = target_path, + priority = priority, + status = "pending", + attempt_count = 0, + intent = intent, + intent_override = prompt.intent_override, + scope = scope, + scope_text = scope_text, + scope_range = scope_range, + attached_files = attached_files, + injection_marks = injection_marks, + injection_range = prompt.injection_range, + is_whole_file = prompt.is_whole_file, + project_context = project_context, + }) - local scope_info = scope - and scope.type ~= "file" - and string.format(" [%s: %s]", scope.type, scope.name or "anonymous") - or "" - utils.notify(string.format("Prompt queued: %s%s", intent.type, scope_info), vim.log.levels.INFO) - end) + local scope_info = scope + and scope.type ~= "file" + and string.format(" [%s: %s]", scope.type, scope.name or "anonymous") + or "" + utils.notify(string.format("Prompt queued: %s%s", intent.type, scope_info), vim.log.levels.INFO) + end) end --- Check and process all closed prompts in the buffer (works on ANY file) function M.check_all_prompts() - local parser = require("codetyper.parser") - local bufnr = vim.api.nvim_get_current_buf() - local current_file = vim.fn.expand("%:p") + local parser = require("codetyper.parser") + local bufnr = vim.api.nvim_get_current_buf() + local current_file = vim.fn.expand("%:p") - -- Skip if no file - if current_file == "" then - return - end + -- Skip if no file + if current_file == "" then + return + end - -- Find all prompts in buffer - local prompts = parser.find_prompts_in_buffer(bufnr) + -- Find all prompts in buffer + local prompts = parser.find_prompts_in_buffer(bufnr) - if #prompts == 0 then - return - end + if #prompts == 0 then + return + end - -- Check if scheduler is enabled - local codetyper = require("codetyper") - local ct_config = codetyper.get_config() - local scheduler_enabled = ct_config and ct_config.scheduler and ct_config.scheduler.enabled + -- Check if scheduler is enabled + local codetyper = require("codetyper") + local ct_config = codetyper.get_config() + local scheduler_enabled = ct_config and ct_config.scheduler and ct_config.scheduler.enabled - if not scheduler_enabled then - return - end + if not scheduler_enabled then + return + end - for _, prompt in ipairs(prompts) do - M.process_single_prompt(bufnr, prompt, current_file) - end + for _, prompt in ipairs(prompts) do + M.process_single_prompt(bufnr, prompt, current_file) + end end --- Check for closed prompt with preference check --- If user hasn't chosen auto/manual mode, ask them first function M.check_for_closed_prompt_with_preference() - local parser = require("codetyper.parser") + local parser = require("codetyper.parser") - -- First check if there are any prompts to process - local bufnr = vim.api.nvim_get_current_buf() - local prompts = parser.find_prompts_in_buffer(bufnr) - if #prompts == 0 then - return - end + -- First check if there are any prompts to process + local bufnr = vim.api.nvim_get_current_buf() + local prompts = parser.find_prompts_in_buffer(bufnr) + if #prompts == 0 then + return + end - if auto_process then - -- Automatic mode - process prompts - M.check_for_closed_prompt() - end - -- Manual mode - do nothing, user will run :CoderProcess + if auto_process then + -- Automatic mode - process prompts + M.check_for_closed_prompt() + end + -- Manual mode - do nothing, user will run :CoderProcess end --- Check all prompts with preference check function M.check_all_prompts_with_preference() - local preferences = require("codetyper.config.preferences") - local parser = require("codetyper.parser") + local preferences = require("codetyper.config.preferences") + local parser = require("codetyper.parser") - -- First check if there are any prompts to process - local bufnr = vim.api.nvim_get_current_buf() - local prompts = parser.find_prompts_in_buffer(bufnr) - if #prompts == 0 then - return - end + -- First check if there are any prompts to process + local bufnr = vim.api.nvim_get_current_buf() + local prompts = parser.find_prompts_in_buffer(bufnr) + if #prompts == 0 then + return + end - -- Check if any prompts are unprocessed - local has_unprocessed = false - for _, prompt in ipairs(prompts) do - local prompt_key = get_prompt_key(bufnr, prompt) - if not processed_prompts[prompt_key] then - has_unprocessed = true - break - end - end + -- Check if any prompts are unprocessed + local has_unprocessed = false + for _, prompt in ipairs(prompts) do + local prompt_key = get_prompt_key(bufnr, prompt) + if not processed_prompts[prompt_key] then + has_unprocessed = true + break + end + end - if not has_unprocessed then - return - end + if not has_unprocessed then + return + end - if auto_process then - -- Automatic mode - process prompts - M.check_all_prompts() - end - -- Manual mode - do nothing, user will run :CoderProcess + if auto_process then + -- Automatic mode - process prompts + M.check_all_prompts() + end + -- Manual mode - do nothing, user will run :CoderProcess end --- Reset processed prompts for a buffer (useful for re-processing) ---@param bufnr? number Buffer number (default: current) ---@param silent? boolean Suppress notification (default: false) function M.reset_processed(bufnr, silent) - bufnr = bufnr or vim.api.nvim_get_current_buf() - for key, _ in pairs(processed_prompts) do - if key:match("^" .. bufnr .. ":") then - processed_prompts[key] = nil - end - end - if not silent then - utils.notify("Prompt history cleared - prompts can be re-processed") - end + bufnr = bufnr or vim.api.nvim_get_current_buf() + for key, _ in pairs(processed_prompts) do + if key:match("^" .. bufnr .. ":") then + processed_prompts[key] = nil + end + end + if not silent then + utils.notify("Prompt history cleared - prompts can be re-processed") + end end --- Track if we already opened the split for this buffer @@ -851,51 +848,51 @@ local auto_opened_buffers = {} --- Clear auto-opened tracking for a buffer ---@param bufnr number Buffer number function M.clear_auto_opened(bufnr) - auto_opened_buffers[bufnr] = nil + auto_opened_buffers[bufnr] = nil end --- Set appropriate filetype for coder files function M.set_coder_filetype() - local filepath = vim.fn.expand("%:p") + local filepath = vim.fn.expand("%:p") - -- Extract the actual extension (e.g., index.codetyper/ts -> ts) - local ext = filepath:match("%.codetyper%.(%w+)$") + -- Extract the actual extension (e.g., index.codetyper/ts -> ts) + local ext = filepath:match("%.codetyper%.(%w+)$") - if ext then - -- Map extension to filetype - local ft_map = { - ts = "typescript", - tsx = "typescriptreact", - js = "javascript", - jsx = "javascriptreact", - py = "python", - lua = "lua", - go = "go", - rs = "rust", - rb = "ruby", - java = "java", - c = "c", - cpp = "cpp", - cs = "cs", - json = "json", - yaml = "yaml", - yml = "yaml", - md = "markdown", - html = "html", - css = "css", - scss = "scss", - vue = "vue", - svelte = "svelte", - } + if ext then + -- Map extension to filetype + local ft_map = { + ts = "typescript", + tsx = "typescriptreact", + js = "javascript", + jsx = "javascriptreact", + py = "python", + lua = "lua", + go = "go", + rs = "rust", + rb = "ruby", + java = "java", + c = "c", + cpp = "cpp", + cs = "cs", + json = "json", + yaml = "yaml", + yml = "yaml", + md = "markdown", + html = "html", + css = "css", + scss = "scss", + vue = "vue", + svelte = "svelte", + } - local filetype = ft_map[ext] or ext - vim.bo.filetype = filetype - end + local filetype = ft_map[ext] or ext + vim.bo.filetype = filetype + end end --- Clear all autocommands function M.clear() - vim.api.nvim_del_augroup_by_name(AUGROUP) + vim.api.nvim_del_augroup_by_name(AUGROUP) end --- Debounce timers for brain updates per file @@ -905,97 +902,97 @@ local brain_update_timers = {} --- Update brain with patterns from a file ---@param filepath string function M.update_brain_from_file(filepath) - local ok_brain, brain = pcall(require, "codetyper.brain") - if not ok_brain or not brain.is_initialized() then - return - end + local ok_brain, brain = pcall(require, "codetyper.brain") + if not ok_brain or not brain.is_initialized() then + return + end - -- Read file content - local content = utils.read_file(filepath) - if not content or content == "" then - return - end + -- Read file content + local content = utils.read_file(filepath) + if not content or content == "" then + return + end - local ext = vim.fn.fnamemodify(filepath, ":e") - local lines = vim.split(content, "\n") + local ext = vim.fn.fnamemodify(filepath, ":e") + local lines = vim.split(content, "\n") - -- Extract key patterns from the file - local functions = {} - local classes = {} - local imports = {} + -- Extract key patterns from the file + local functions = {} + local classes = {} + local imports = {} - for i, line in ipairs(lines) do - -- Functions - local func = line:match("^%s*function%s+([%w_:%.]+)%s*%(") - or line:match("^%s*local%s+function%s+([%w_]+)%s*%(") - or line:match("^%s*def%s+([%w_]+)%s*%(") - or line:match("^%s*func%s+([%w_]+)%s*%(") - or line:match("^%s*async%s+function%s+([%w_]+)%s*%(") - or line:match("^%s*public%s+.*%s+([%w_]+)%s*%(") - or line:match("^%s*private%s+.*%s+([%w_]+)%s*%(") - if func then - table.insert(functions, { name = func, line = i }) - end + for i, line in ipairs(lines) do + -- Functions + local func = line:match("^%s*function%s+([%w_:%.]+)%s*%(") + or line:match("^%s*local%s+function%s+([%w_]+)%s*%(") + or line:match("^%s*def%s+([%w_]+)%s*%(") + or line:match("^%s*func%s+([%w_]+)%s*%(") + or line:match("^%s*async%s+function%s+([%w_]+)%s*%(") + or line:match("^%s*public%s+.*%s+([%w_]+)%s*%(") + or line:match("^%s*private%s+.*%s+([%w_]+)%s*%(") + if func then + table.insert(functions, { name = func, line = i }) + end - -- Classes - local class = line:match("^%s*class%s+([%w_]+)") - or line:match("^%s*public%s+class%s+([%w_]+)") - or line:match("^%s*interface%s+([%w_]+)") - or line:match("^%s*struct%s+([%w_]+)") - if class then - table.insert(classes, { name = class, line = i }) - end + -- Classes + local class = line:match("^%s*class%s+([%w_]+)") + or line:match("^%s*public%s+class%s+([%w_]+)") + or line:match("^%s*interface%s+([%w_]+)") + or line:match("^%s*struct%s+([%w_]+)") + if class then + table.insert(classes, { name = class, line = i }) + end - -- Imports - local imp = line:match("import%s+.*%s+from%s+[\"']([^\"']+)[\"']") - or line:match("require%([\"']([^\"']+)[\"']%)") - or line:match("from%s+([%w_.]+)%s+import") - if imp then - table.insert(imports, imp) - end - end + -- Imports + local imp = line:match("import%s+.*%s+from%s+[\"']([^\"']+)[\"']") + or line:match("require%([\"']([^\"']+)[\"']%)") + or line:match("from%s+([%w_.]+)%s+import") + if imp then + table.insert(imports, imp) + end + end - -- Only store if file has meaningful content - if #functions == 0 and #classes == 0 then - return - end + -- Only store if file has meaningful content + if #functions == 0 and #classes == 0 then + return + end - -- Build summary - local parts = {} - if #functions > 0 then - local func_names = {} - for i, f in ipairs(functions) do - if i <= 5 then - table.insert(func_names, f.name) - end - end - table.insert(parts, "functions: " .. table.concat(func_names, ", ")) - end - if #classes > 0 then - local class_names = {} - for _, c in ipairs(classes) do - table.insert(class_names, c.name) - end - table.insert(parts, "classes: " .. table.concat(class_names, ", ")) - end + -- Build summary + local parts = {} + if #functions > 0 then + local func_names = {} + for i, f in ipairs(functions) do + if i <= 5 then + table.insert(func_names, f.name) + end + end + table.insert(parts, "functions: " .. table.concat(func_names, ", ")) + end + if #classes > 0 then + local class_names = {} + for _, c in ipairs(classes) do + table.insert(class_names, c.name) + end + table.insert(parts, "classes: " .. table.concat(class_names, ", ")) + end - local summary = vim.fn.fnamemodify(filepath, ":t") .. " - " .. table.concat(parts, "; ") + local summary = vim.fn.fnamemodify(filepath, ":t") .. " - " .. table.concat(parts, "; ") - -- Learn this pattern - use "pattern_detected" type to match the pattern learner - brain.learn({ - type = "pattern_detected", - file = filepath, - timestamp = os.time(), - data = { - name = summary, - description = #functions .. " functions, " .. #classes .. " classes", - language = ext, - symbols = vim.tbl_map(function(f) - return f.name - end, functions), - example = nil, - }, - }) + -- Learn this pattern - use "pattern_detected" type to match the pattern learner + brain.learn({ + type = "pattern_detected", + file = filepath, + timestamp = os.time(), + data = { + name = summary, + description = #functions .. " functions, " .. #classes .. " classes", + language = ext, + symbols = vim.tbl_map(function(f) + return f.name + end, functions), + example = nil, + }, + }) end --- Track buffers that have been auto-indexed @@ -1004,443 +1001,440 @@ local auto_indexed_buffers = {} --- Supported file extensions for auto-indexing local supported_extensions = { - "ts", - "tsx", - "js", - "jsx", - "py", - "lua", - "go", - "rs", - "rb", - "java", - "c", - "cpp", - "cs", - "json", - "yaml", - "yml", - "md", - "html", - "css", - "scss", - "vue", - "svelte", - "php", - "sh", - "zsh", + "ts", + "tsx", + "js", + "jsx", + "py", + "lua", + "go", + "rs", + "rb", + "java", + "c", + "cpp", + "cs", + "json", + "yaml", + "yml", + "md", + "html", + "css", + "scss", + "vue", + "svelte", + "php", + "sh", + "zsh", } --- Check if extension is supported ---@param ext string File extension ---@return boolean local function is_supported_extension(ext) - for _, supported in ipairs(supported_extensions) do - if ext == supported then - return true - end - end - return false + for _, supported in ipairs(supported_extensions) do + if ext == supported then + return true + end + end + return false end --- Auto-index a file by creating/opening its coder companion ---@param bufnr number Buffer number --- Directories to ignore for coder file creation local ignored_directories = { - ".git", - ".codetyper", - ".claude", - ".vscode", - ".idea", - "node_modules", - "vendor", - "dist", - "build", - "target", - "__pycache__", - ".cache", - ".npm", - ".yarn", - "coverage", - ".next", - ".nuxt", - ".svelte-kit", - "out", - "bin", - "obj", + ".git", + ".codetyper", + ".claude", + ".vscode", + ".idea", + "node_modules", + "vendor", + "dist", + "build", + "target", + "__pycache__", + ".cache", + ".npm", + ".yarn", + "coverage", + ".next", + ".nuxt", + ".svelte-kit", + "out", + "bin", + "obj", } --- Files to ignore for coder file creation (exact names or patterns) local ignored_files = { - -- Git files - ".gitignore", - ".gitattributes", - ".gitmodules", - -- Lock files - "package-lock.json", - "yarn.lock", - "pnpm-lock.yaml", - "Cargo.lock", - "Gemfile.lock", - "poetry.lock", - "composer.lock", - -- Config files that don't need coder companions - ".env", - ".env.local", - ".env.development", - ".env.production", - ".eslintrc", - ".eslintrc.json", - ".prettierrc", - ".prettierrc.json", - ".editorconfig", - ".dockerignore", - "Dockerfile", - "docker-compose.yml", - "docker-compose.yaml", - ".npmrc", - ".yarnrc", - ".nvmrc", - "tsconfig.json", - "jsconfig.json", - "babel.config.js", - "webpack.config.js", - "vite.config.js", - "rollup.config.js", - "jest.config.js", - "vitest.config.js", - ".stylelintrc", - "tailwind.config.js", - "postcss.config.js", - -- Other non-code files - "README.md", - "CHANGELOG.md", - "LICENSE", - "LICENSE.md", - "CONTRIBUTING.md", - "Makefile", - "CMakeLists.txt", + -- Git files + ".gitignore", + ".gitattributes", + ".gitmodules", + -- Lock files + "package-lock.json", + "yarn.lock", + "pnpm-lock.yaml", + "Cargo.lock", + "Gemfile.lock", + "poetry.lock", + "composer.lock", + -- Config files that don't need coder companions + ".env", + ".env.local", + ".env.development", + ".env.production", + ".eslintrc", + ".eslintrc.json", + ".prettierrc", + ".prettierrc.json", + ".editorconfig", + ".dockerignore", + "Dockerfile", + "docker-compose.yml", + "docker-compose.yaml", + ".npmrc", + ".yarnrc", + ".nvmrc", + "tsconfig.json", + "jsconfig.json", + "babel.config.js", + "webpack.config.js", + "vite.config.js", + "rollup.config.js", + "jest.config.js", + "vitest.config.js", + ".stylelintrc", + "tailwind.config.js", + "postcss.config.js", + -- Other non-code files + "README.md", + "CHANGELOG.md", + "LICENSE", + "LICENSE.md", + "CONTRIBUTING.md", + "Makefile", + "CMakeLists.txt", } --- Check if a file path contains an ignored directory ---@param filepath string Full file path ---@return boolean local function is_in_ignored_directory(filepath) - for _, dir in ipairs(ignored_directories) do - -- Check for /dirname/ or /dirname at end - if filepath:match("/" .. dir .. "/") or filepath:match("/" .. dir .. "$") then - return true - end - -- Also check for dirname/ at start (relative paths) - if filepath:match("^" .. dir .. "/") then - return true - end - end - return false + for _, dir in ipairs(ignored_directories) do + -- Check for /dirname/ or /dirname at end + if filepath:match("/" .. dir .. "/") or filepath:match("/" .. dir .. "$") then + return true + end + -- Also check for dirname/ at start (relative paths) + if filepath:match("^" .. dir .. "/") then + return true + end + end + return false end --- Check if a file should be ignored for coder companion creation ---@param filepath string Full file path ---@return boolean local function should_ignore_for_coder(filepath) - local filename = vim.fn.fnamemodify(filepath, ":t") + local filename = vim.fn.fnamemodify(filepath, ":t") - -- Check exact filename matches - for _, ignored in ipairs(ignored_files) do - if filename == ignored then - return true - end - end + -- Check exact filename matches + for _, ignored in ipairs(ignored_files) do + if filename == ignored then + return true + end + end - -- Check if file starts with dot (hidden/config files) - if filename:match("^%.") then - return true - end + -- Check if file starts with dot (hidden/config files) + if filename:match("^%.") then + return true + end - -- Check if in ignored directory - if is_in_ignored_directory(filepath) then - return true - end + -- Check if in ignored directory + if is_in_ignored_directory(filepath) then + return true + end - return false + return false end function M.auto_index_file(bufnr) - -- Skip if buffer is invalid - if not vim.api.nvim_buf_is_valid(bufnr) then - return - end + -- Skip if buffer is invalid + if not vim.api.nvim_buf_is_valid(bufnr) then + return + end - -- Skip if already indexed - if auto_indexed_buffers[bufnr] then - return - end + -- Skip if already indexed + if auto_indexed_buffers[bufnr] then + return + end - -- Get file path - local filepath = vim.api.nvim_buf_get_name(bufnr) - if not filepath or filepath == "" then - return - end + -- Get file path + local filepath = vim.api.nvim_buf_get_name(bufnr) + if not filepath or filepath == "" then + return + end - -- Skip coder files - if utils.is_coder_file(filepath) then - return - end + -- Skip coder files + if utils.is_coder_file(filepath) then + return + end - -- Skip special buffers - local buftype = vim.bo[bufnr].buftype - if buftype ~= "" then - return - end + -- Skip special buffers + local buftype = vim.bo[bufnr].buftype + if buftype ~= "" then + return + end - -- Skip unsupported file types - local ext = vim.fn.fnamemodify(filepath, ":e") - if ext == "" or not is_supported_extension(ext) then - return - end + -- Skip unsupported file types + local ext = vim.fn.fnamemodify(filepath, ":e") + if ext == "" or not is_supported_extension(ext) then + return + end - -- Skip ignored directories and files (node_modules, .git, config files, etc.) - if should_ignore_for_coder(filepath) then - return - end + -- Skip ignored directories and files (node_modules, .git, config files, etc.) + if should_ignore_for_coder(filepath) then + return + end - -- Skip if auto_index is disabled in config - local codetyper = require("codetyper") - local config = codetyper.get_config() - if config and config.auto_index == false then - return - end + -- Skip if auto_index is disabled in config + local codetyper = require("codetyper") + local config = codetyper.get_config() + if config and config.auto_index == false then + return + end - -- Mark as indexed - auto_indexed_buffers[bufnr] = true + -- Mark as indexed + auto_indexed_buffers[bufnr] = true - -- Get coder companion path - local coder_path = utils.get_coder_path(filepath) + -- Get coder companion path + local coder_path = utils.get_coder_path(filepath) - -- Check if coder file already exists - local coder_exists = utils.file_exists(coder_path) + -- Check if coder file already exists + local coder_exists = utils.file_exists(coder_path) - -- Create coder file with pseudo-code context if it doesn't exist - if not coder_exists then - local filename = vim.fn.fnamemodify(filepath, ":t") - local ext = vim.fn.fnamemodify(filepath, ":e") + -- Create coder file with pseudo-code context if it doesn't exist + if not coder_exists then + local filename = vim.fn.fnamemodify(filepath, ":t") + local ext = vim.fn.fnamemodify(filepath, ":e") - -- Determine comment style based on extension - local comment_prefix = "--" - local comment_block_start = "--[[" - local comment_block_end = "]]" - if - ext == "ts" - or ext == "tsx" - or ext == "js" - or ext == "jsx" - or ext == "java" - or ext == "c" - or ext == "cpp" - or ext == "cs" - or ext == "go" - or ext == "rs" - then - comment_prefix = "//" - comment_block_start = "/*" - comment_block_end = "*/" - elseif ext == "py" or ext == "rb" or ext == "yaml" or ext == "yml" then - comment_prefix = "#" - comment_block_start = '"""' - comment_block_end = '"""' - end + -- Determine comment style based on extension + local comment_prefix = "--" + local comment_block_start = "--[[" + local comment_block_end = "]]" + if + ext == "ts" + or ext == "tsx" + or ext == "js" + or ext == "jsx" + or ext == "java" + or ext == "c" + or ext == "cpp" + or ext == "cs" + or ext == "go" + or ext == "rs" + then + comment_prefix = "//" + comment_block_start = "/*" + comment_block_end = "*/" + elseif ext == "py" or ext == "rb" or ext == "yaml" or ext == "yml" then + comment_prefix = "#" + comment_block_start = '"""' + comment_block_end = '"""' + end - -- Read target file to analyze its structure - local content = "" - pcall(function() - local lines = vim.fn.readfile(filepath) - if lines then - content = table.concat(lines, "\n") - end - end) + -- Read target file to analyze its structure + local content = "" + pcall(function() + local lines = vim.fn.readfile(filepath) + if lines then + content = table.concat(lines, "\n") + end + end) - -- Extract structure from the file - local functions = extract_functions(content, ext) - local classes = extract_classes(content, ext) - local imports = extract_imports(content, ext) + -- Extract structure from the file + local functions = extract_functions(content, ext) + local classes = extract_classes(content, ext) + local imports = extract_imports(content, ext) - -- Build pseudo-code context - local pseudo_code = {} + -- Build pseudo-code context + local pseudo_code = {} - -- Header - table.insert( - pseudo_code, - comment_prefix - .. " ═══════════════════════════════════════════════════════════" - ) - table.insert(pseudo_code, comment_prefix .. " CODER COMPANION: " .. filename) - table.insert( - pseudo_code, - comment_prefix - .. " ═══════════════════════════════════════════════════════════" - ) - table.insert( - pseudo_code, - comment_prefix .. " This file describes the business logic and behavior of " .. filename - ) - table.insert(pseudo_code, comment_prefix .. " Edit this pseudo-code to guide code generation.") - table.insert(pseudo_code, comment_prefix .. "") + -- Header + table.insert( + pseudo_code, + comment_prefix + .. " ═══════════════════════════════════════════════════════════" + ) + table.insert(pseudo_code, comment_prefix .. " CODER COMPANION: " .. filename) + table.insert( + pseudo_code, + comment_prefix + .. " ═══════════════════════════════════════════════════════════" + ) + table.insert(pseudo_code, comment_prefix .. " This file describes the business logic and behavior of " .. filename) + table.insert(pseudo_code, comment_prefix .. " Edit this pseudo-code to guide code generation.") + table.insert(pseudo_code, comment_prefix .. "") - -- Module purpose - table.insert( - pseudo_code, - comment_prefix - .. " ─────────────────────────────────────────────────────────────" - ) - table.insert(pseudo_code, comment_prefix .. " MODULE PURPOSE:") - table.insert( - pseudo_code, - comment_prefix - .. " ─────────────────────────────────────────────────────────────" - ) - table.insert(pseudo_code, comment_prefix .. " TODO: Describe what this module/file is responsible for") - table.insert(pseudo_code, comment_prefix .. ' Example: "Handles user authentication and session management"') - table.insert(pseudo_code, comment_prefix .. "") + -- Module purpose + table.insert( + pseudo_code, + comment_prefix + .. " ─────────────────────────────────────────────────────────────" + ) + table.insert(pseudo_code, comment_prefix .. " MODULE PURPOSE:") + table.insert( + pseudo_code, + comment_prefix + .. " ─────────────────────────────────────────────────────────────" + ) + table.insert(pseudo_code, comment_prefix .. " TODO: Describe what this module/file is responsible for") + table.insert(pseudo_code, comment_prefix .. ' Example: "Handles user authentication and session management"') + table.insert(pseudo_code, comment_prefix .. "") - -- Dependencies section - if #imports > 0 then - table.insert( - pseudo_code, - comment_prefix - .. " ─────────────────────────────────────────────────────────────" - ) - table.insert(pseudo_code, comment_prefix .. " DEPENDENCIES:") - table.insert( - pseudo_code, - comment_prefix - .. " ─────────────────────────────────────────────────────────────" - ) - for _, imp in ipairs(imports) do - table.insert(pseudo_code, comment_prefix .. " • " .. imp) - end - table.insert(pseudo_code, comment_prefix .. "") - end + -- Dependencies section + if #imports > 0 then + table.insert( + pseudo_code, + comment_prefix + .. " ─────────────────────────────────────────────────────────────" + ) + table.insert(pseudo_code, comment_prefix .. " DEPENDENCIES:") + table.insert( + pseudo_code, + comment_prefix + .. " ─────────────────────────────────────────────────────────────" + ) + for _, imp in ipairs(imports) do + table.insert(pseudo_code, comment_prefix .. " • " .. imp) + end + table.insert(pseudo_code, comment_prefix .. "") + end - -- Classes section - if #classes > 0 then - table.insert( - pseudo_code, - comment_prefix - .. " ─────────────────────────────────────────────────────────────" - ) - table.insert(pseudo_code, comment_prefix .. " CLASSES:") - table.insert( - pseudo_code, - comment_prefix - .. " ─────────────────────────────────────────────────────────────" - ) - for _, class in ipairs(classes) do - table.insert(pseudo_code, comment_prefix .. "") - table.insert(pseudo_code, comment_prefix .. " class " .. class.name .. ":") - table.insert(pseudo_code, comment_prefix .. " PURPOSE: TODO - describe what this class represents") - table.insert(pseudo_code, comment_prefix .. " RESPONSIBILITIES:") - table.insert(pseudo_code, comment_prefix .. " - TODO: list main responsibilities") - end - table.insert(pseudo_code, comment_prefix .. "") - end + -- Classes section + if #classes > 0 then + table.insert( + pseudo_code, + comment_prefix + .. " ─────────────────────────────────────────────────────────────" + ) + table.insert(pseudo_code, comment_prefix .. " CLASSES:") + table.insert( + pseudo_code, + comment_prefix + .. " ─────────────────────────────────────────────────────────────" + ) + for _, class in ipairs(classes) do + table.insert(pseudo_code, comment_prefix .. "") + table.insert(pseudo_code, comment_prefix .. " class " .. class.name .. ":") + table.insert(pseudo_code, comment_prefix .. " PURPOSE: TODO - describe what this class represents") + table.insert(pseudo_code, comment_prefix .. " RESPONSIBILITIES:") + table.insert(pseudo_code, comment_prefix .. " - TODO: list main responsibilities") + end + table.insert(pseudo_code, comment_prefix .. "") + end - -- Functions section - if #functions > 0 then - table.insert( - pseudo_code, - comment_prefix - .. " ─────────────────────────────────────────────────────────────" - ) - table.insert(pseudo_code, comment_prefix .. " FUNCTIONS:") - table.insert( - pseudo_code, - comment_prefix - .. " ─────────────────────────────────────────────────────────────" - ) - for _, func in ipairs(functions) do - table.insert(pseudo_code, comment_prefix .. "") - table.insert(pseudo_code, comment_prefix .. " " .. func.name .. "():") - table.insert(pseudo_code, comment_prefix .. " PURPOSE: TODO - what does this function do?") - table.insert(pseudo_code, comment_prefix .. " INPUTS: TODO - describe parameters") - table.insert(pseudo_code, comment_prefix .. " OUTPUTS: TODO - describe return value") - table.insert(pseudo_code, comment_prefix .. " BEHAVIOR:") - table.insert(pseudo_code, comment_prefix .. " - TODO: describe step-by-step logic") - end - table.insert(pseudo_code, comment_prefix .. "") - end + -- Functions section + if #functions > 0 then + table.insert( + pseudo_code, + comment_prefix + .. " ─────────────────────────────────────────────────────────────" + ) + table.insert(pseudo_code, comment_prefix .. " FUNCTIONS:") + table.insert( + pseudo_code, + comment_prefix + .. " ─────────────────────────────────────────────────────────────" + ) + for _, func in ipairs(functions) do + table.insert(pseudo_code, comment_prefix .. "") + table.insert(pseudo_code, comment_prefix .. " " .. func.name .. "():") + table.insert(pseudo_code, comment_prefix .. " PURPOSE: TODO - what does this function do?") + table.insert(pseudo_code, comment_prefix .. " INPUTS: TODO - describe parameters") + table.insert(pseudo_code, comment_prefix .. " OUTPUTS: TODO - describe return value") + table.insert(pseudo_code, comment_prefix .. " BEHAVIOR:") + table.insert(pseudo_code, comment_prefix .. " - TODO: describe step-by-step logic") + end + table.insert(pseudo_code, comment_prefix .. "") + end - -- If empty file, provide starter template - if #functions == 0 and #classes == 0 then - table.insert( - pseudo_code, - comment_prefix - .. " ─────────────────────────────────────────────────────────────" - ) - table.insert(pseudo_code, comment_prefix .. " PLANNED STRUCTURE:") - table.insert( - pseudo_code, - comment_prefix - .. " ─────────────────────────────────────────────────────────────" - ) - table.insert(pseudo_code, comment_prefix .. " TODO: Describe what you want to build in this file") - table.insert(pseudo_code, comment_prefix .. "") - table.insert(pseudo_code, comment_prefix .. " Example pseudo-code:") + -- If empty file, provide starter template + if #functions == 0 and #classes == 0 then + table.insert( + pseudo_code, + comment_prefix + .. " ─────────────────────────────────────────────────────────────" + ) + table.insert(pseudo_code, comment_prefix .. " PLANNED STRUCTURE:") + table.insert( + pseudo_code, + comment_prefix + .. " ─────────────────────────────────────────────────────────────" + ) + table.insert(pseudo_code, comment_prefix .. " TODO: Describe what you want to build in this file") + table.insert(pseudo_code, comment_prefix .. "") + table.insert(pseudo_code, comment_prefix .. " Example pseudo-code:") - table.insert(pseudo_code, comment_prefix .. " Create a module that:") - table.insert(pseudo_code, comment_prefix .. " 1. Exports a main function") - table.insert(pseudo_code, comment_prefix .. " 2. Handles errors gracefully") - table.insert(pseudo_code, comment_prefix .. " 3. Returns structured data") - table.insert(pseudo_code, comment_prefix .. "") - end + table.insert(pseudo_code, comment_prefix .. " Create a module that:") + table.insert(pseudo_code, comment_prefix .. " 1. Exports a main function") + table.insert(pseudo_code, comment_prefix .. " 2. Handles errors gracefully") + table.insert(pseudo_code, comment_prefix .. " 3. Returns structured data") + table.insert(pseudo_code, comment_prefix .. "") + end - -- Business rules section - table.insert( - pseudo_code, - comment_prefix - .. " ─────────────────────────────────────────────────────────────" - ) - table.insert(pseudo_code, comment_prefix .. " BUSINESS RULES:") - table.insert( - pseudo_code, - comment_prefix - .. " ─────────────────────────────────────────────────────────────" - ) - table.insert(pseudo_code, comment_prefix .. " TODO: Document any business rules, constraints, or requirements") - table.insert(pseudo_code, comment_prefix .. " Example:") - table.insert(pseudo_code, comment_prefix .. " - Users must be authenticated before accessing this feature") - table.insert(pseudo_code, comment_prefix .. " - Data must be validated before saving") - table.insert(pseudo_code, comment_prefix .. " - Errors should be logged but not exposed to users") - table.insert(pseudo_code, comment_prefix .. "") + -- Business rules section + table.insert( + pseudo_code, + comment_prefix + .. " ─────────────────────────────────────────────────────────────" + ) + table.insert(pseudo_code, comment_prefix .. " BUSINESS RULES:") + table.insert( + pseudo_code, + comment_prefix + .. " ─────────────────────────────────────────────────────────────" + ) + table.insert(pseudo_code, comment_prefix .. " TODO: Document any business rules, constraints, or requirements") + table.insert(pseudo_code, comment_prefix .. " Example:") + table.insert(pseudo_code, comment_prefix .. " - Users must be authenticated before accessing this feature") + table.insert(pseudo_code, comment_prefix .. " - Data must be validated before saving") + table.insert(pseudo_code, comment_prefix .. " - Errors should be logged but not exposed to users") + table.insert(pseudo_code, comment_prefix .. "") - -- Footer with generation tags example - table.insert( - pseudo_code, - comment_prefix - .. " ═══════════════════════════════════════════════════════════" - ) - table.insert( - pseudo_code, - comment_prefix - .. " ═══════════════════════════════════════════════════════════" - ) - table.insert(pseudo_code, "") + -- Footer with generation tags example + table.insert( + pseudo_code, + comment_prefix + .. " ═══════════════════════════════════════════════════════════" + ) + table.insert( + pseudo_code, + comment_prefix + .. " ═══════════════════════════════════════════════════════════" + ) + table.insert(pseudo_code, "") - utils.write_file(coder_path, table.concat(pseudo_code, "\n")) - end + utils.write_file(coder_path, table.concat(pseudo_code, "\n")) + end - -- Notify user about the coder companion - local coder_filename = vim.fn.fnamemodify(coder_path, ":t") - if coder_exists then - utils.notify("Coder companion available: " .. coder_filename, vim.log.levels.DEBUG) - else - utils.notify("Created coder companion: " .. coder_filename, vim.log.levels.INFO) - end + -- Notify user about the coder companion + local coder_filename = vim.fn.fnamemodify(coder_path, ":t") + if coder_exists then + utils.notify("Coder companion available: " .. coder_filename, vim.log.levels.DEBUG) + else + utils.notify("Created coder companion: " .. coder_filename, vim.log.levels.INFO) + end end --- Clear auto-indexed tracking for a buffer ---@param bufnr number Buffer number function M.clear_auto_indexed(bufnr) - auto_indexed_buffers[bufnr] = nil + auto_indexed_buffers[bufnr] = nil end return M diff --git a/lua/codetyper/adapters/nvim/cmp/init.lua b/lua/codetyper/adapters/nvim/cmp/init.lua index fc9b321..1733b62 100644 --- a/lua/codetyper/adapters/nvim/cmp/init.lua +++ b/lua/codetyper/adapters/nvim/cmp/init.lua @@ -11,136 +11,136 @@ local source = {} --- Check if cmp is available ---@return boolean local function has_cmp() - return pcall(require, "cmp") + return pcall(require, "cmp") end --- Get completion items from brain context ---@param prefix string Current word prefix ---@return table[] items local function get_brain_completions(prefix) - local items = {} + local items = {} - local ok_brain, brain = pcall(require, "codetyper.brain") - if not ok_brain then - return items - end + local ok_brain, brain = pcall(require, "codetyper.brain") + if not ok_brain then + return items + end - -- Check if brain is initialized safely - local is_init = false - if brain.is_initialized then - local ok, result = pcall(brain.is_initialized) - is_init = ok and result - end + -- Check if brain is initialized safely + local is_init = false + if brain.is_initialized then + local ok, result = pcall(brain.is_initialized) + is_init = ok and result + end - if not is_init then - return items - end + if not is_init then + return items + end - -- Query brain for relevant patterns - local ok_query, result = pcall(brain.query, { - query = prefix, - max_results = 10, - types = { "pattern" }, - }) + -- Query brain for relevant patterns + local ok_query, result = pcall(brain.query, { + query = prefix, + max_results = 10, + types = { "pattern" }, + }) - if ok_query and result and result.nodes then - for _, node in ipairs(result.nodes) do - if node.c and node.c.s then - -- Extract function/class names from summary - local summary = node.c.s - for name in summary:gmatch("functions:%s*([^;]+)") do - for func in name:gmatch("([%w_]+)") do - if func:lower():find(prefix:lower(), 1, true) then - table.insert(items, { - label = func, - kind = 3, -- Function - detail = "[brain]", - documentation = summary, - }) - end - end - end - for name in summary:gmatch("classes:%s*([^;]+)") do - for class in name:gmatch("([%w_]+)") do - if class:lower():find(prefix:lower(), 1, true) then - table.insert(items, { - label = class, - kind = 7, -- Class - detail = "[brain]", - documentation = summary, - }) - end - end - end - end - end - end + if ok_query and result and result.nodes then + for _, node in ipairs(result.nodes) do + if node.c and node.c.s then + -- Extract function/class names from summary + local summary = node.c.s + for name in summary:gmatch("functions:%s*([^;]+)") do + for func in name:gmatch("([%w_]+)") do + if func:lower():find(prefix:lower(), 1, true) then + table.insert(items, { + label = func, + kind = 3, -- Function + detail = "[brain]", + documentation = summary, + }) + end + end + end + for name in summary:gmatch("classes:%s*([^;]+)") do + for class in name:gmatch("([%w_]+)") do + if class:lower():find(prefix:lower(), 1, true) then + table.insert(items, { + label = class, + kind = 7, -- Class + detail = "[brain]", + documentation = summary, + }) + end + end + end + end + end + end - return items + return items end --- Get completion items from indexer symbols ---@param prefix string Current word prefix ---@return table[] items local function get_indexer_completions(prefix) - local items = {} + local items = {} - local ok_indexer, indexer = pcall(require, "codetyper.indexer") - if not ok_indexer then - return items - end + local ok_indexer, indexer = pcall(require, "codetyper.indexer") + if not ok_indexer then + return items + end - local ok_load, index = pcall(indexer.load_index) - if not ok_load or not index then - return items - end + local ok_load, index = pcall(indexer.load_index) + if not ok_load or not index then + return items + end - -- Search symbols - if index.symbols then - for symbol, files in pairs(index.symbols) do - if symbol:lower():find(prefix:lower(), 1, true) then - local files_str = type(files) == "table" and table.concat(files, ", ") or tostring(files) - table.insert(items, { - label = symbol, - kind = 6, -- Variable (generic) - detail = "[index] " .. files_str:sub(1, 30), - documentation = "Symbol found in: " .. files_str, - }) - end - end - end + -- Search symbols + if index.symbols then + for symbol, files in pairs(index.symbols) do + if symbol:lower():find(prefix:lower(), 1, true) then + local files_str = type(files) == "table" and table.concat(files, ", ") or tostring(files) + table.insert(items, { + label = symbol, + kind = 6, -- Variable (generic) + detail = "[index] " .. files_str:sub(1, 30), + documentation = "Symbol found in: " .. files_str, + }) + end + end + end - -- Search functions in files - if index.files then - for filepath, file_index in pairs(index.files) do - if file_index and file_index.functions then - for _, func in ipairs(file_index.functions) do - if func.name and func.name:lower():find(prefix:lower(), 1, true) then - table.insert(items, { - label = func.name, - kind = 3, -- Function - detail = "[index] " .. vim.fn.fnamemodify(filepath, ":t"), - documentation = func.docstring or ("Function at line " .. (func.line or "?")), - }) - end - end - end - if file_index and file_index.classes then - for _, class in ipairs(file_index.classes) do - if class.name and class.name:lower():find(prefix:lower(), 1, true) then - table.insert(items, { - label = class.name, - kind = 7, -- Class - detail = "[index] " .. vim.fn.fnamemodify(filepath, ":t"), - documentation = class.docstring or ("Class at line " .. (class.line or "?")), - }) - end - end - end - end - end + -- Search functions in files + if index.files then + for filepath, file_index in pairs(index.files) do + if file_index and file_index.functions then + for _, func in ipairs(file_index.functions) do + if func.name and func.name:lower():find(prefix:lower(), 1, true) then + table.insert(items, { + label = func.name, + kind = 3, -- Function + detail = "[index] " .. vim.fn.fnamemodify(filepath, ":t"), + documentation = func.docstring or ("Function at line " .. (func.line or "?")), + }) + end + end + end + if file_index and file_index.classes then + for _, class in ipairs(file_index.classes) do + if class.name and class.name:lower():find(prefix:lower(), 1, true) then + table.insert(items, { + label = class.name, + kind = 7, -- Class + detail = "[index] " .. vim.fn.fnamemodify(filepath, ":t"), + documentation = class.docstring or ("Class at line " .. (class.line or "?")), + }) + end + end + end + end + end - return items + return items end --- Get completion items from current buffer (fallback) @@ -148,210 +148,209 @@ end ---@param bufnr number Buffer number ---@return table[] items local function get_buffer_completions(prefix, bufnr) - local items = {} - local seen = {} + local items = {} + local seen = {} - -- Get all lines in buffer - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local prefix_lower = prefix:lower() + -- Get all lines in buffer + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local prefix_lower = prefix:lower() - for _, line in ipairs(lines) do - -- Extract words that could be identifiers - for word in line:gmatch("[%a_][%w_]*") do - if #word >= 3 and word:lower():find(prefix_lower, 1, true) and not seen[word] and word ~= prefix then - seen[word] = true - table.insert(items, { - label = word, - kind = 1, -- Text - detail = "[buffer]", - }) - end - end - end + for _, line in ipairs(lines) do + -- Extract words that could be identifiers + for word in line:gmatch("[%a_][%w_]*") do + if #word >= 3 and word:lower():find(prefix_lower, 1, true) and not seen[word] and word ~= prefix then + seen[word] = true + table.insert(items, { + label = word, + kind = 1, -- Text + detail = "[buffer]", + }) + end + end + end - return items + return items end --- Try to get Copilot suggestion if plugin is installed ---@param prefix string ---@return string|nil suggestion local function get_copilot_suggestion(prefix) - -- Try copilot.lua suggestion API first - local ok, copilot_suggestion = pcall(require, "copilot.suggestion") - if ok and copilot_suggestion and type(copilot_suggestion.get_suggestion) == "function" then - local ok2, suggestion = pcall(copilot_suggestion.get_suggestion) - if ok2 and suggestion and suggestion ~= "" then - -- Only return if suggestion seems to start with prefix (best-effort) - if prefix == "" or suggestion:lower():match(prefix:lower(), 1) then - return suggestion - else - return suggestion - end - end - end + -- Try copilot.lua suggestion API first + local ok, copilot_suggestion = pcall(require, "copilot.suggestion") + if ok and copilot_suggestion and type(copilot_suggestion.get_suggestion) == "function" then + local ok2, suggestion = pcall(copilot_suggestion.get_suggestion) + if ok2 and suggestion and suggestion ~= "" then + -- Only return if suggestion seems to start with prefix (best-effort) + if prefix == "" or suggestion:lower():match(prefix:lower(), 1) then + return suggestion + else + return suggestion + end + end + end - -- Fallback: try older copilot module if present - local ok3, copilot = pcall(require, "copilot") - if ok3 and copilot and type(copilot.get_suggestion) == "function" then - local ok4, suggestion = pcall(copilot.get_suggestion) - if ok4 and suggestion and suggestion ~= "" then - return suggestion - end - end + -- Fallback: try older copilot module if present + local ok3, copilot = pcall(require, "copilot") + if ok3 and copilot and type(copilot.get_suggestion) == "function" then + local ok4, suggestion = pcall(copilot.get_suggestion) + if ok4 and suggestion and suggestion ~= "" then + return suggestion + end + end - return nil + return nil end --- Create new cmp source instance function source.new() - return setmetatable({}, { __index = source }) + return setmetatable({}, { __index = source }) end --- Get source name function source:get_keyword_pattern() - return [[\k\+]] + return [[\k\+]] end --- Check if source is available function source:is_available() - return true + return true end --- Get debug name function source:get_debug_name() - return "codetyper" + return "codetyper" end --- Get trigger characters function source:get_trigger_characters() - return { ".", ":", "_" } + return { ".", ":", "_" } end --- Complete ---@param params table ---@param callback fun(response: table|nil) function source:complete(params, callback) - local prefix = params.context.cursor_before_line:match("[%w_]+$") or "" + local prefix = params.context.cursor_before_line:match("[%w_]+$") or "" - if #prefix < 2 then - callback({ items = {}, isIncomplete = true }) - return - end + if #prefix < 2 then + callback({ items = {}, isIncomplete = true }) + return + end - -- Collect completions from brain, indexer, and buffer - local items = {} - local seen = {} + -- Collect completions from brain, indexer, and buffer + local items = {} + local seen = {} - -- Get brain completions (highest priority) - local ok1, brain_items = pcall(get_brain_completions, prefix) - if ok1 and brain_items then - for _, item in ipairs(brain_items) do - if not seen[item.label] then - seen[item.label] = true - item.sortText = "1" .. item.label - table.insert(items, item) - end - end - end + -- Get brain completions (highest priority) + local ok1, brain_items = pcall(get_brain_completions, prefix) + if ok1 and brain_items then + for _, item in ipairs(brain_items) do + if not seen[item.label] then + seen[item.label] = true + item.sortText = "1" .. item.label + table.insert(items, item) + end + end + end - -- Get indexer completions - local ok2, indexer_items = pcall(get_indexer_completions, prefix) - if ok2 and indexer_items then - for _, item in ipairs(indexer_items) do - if not seen[item.label] then - seen[item.label] = true - item.sortText = "2" .. item.label - table.insert(items, item) - end - end - end + -- Get indexer completions + local ok2, indexer_items = pcall(get_indexer_completions, prefix) + if ok2 and indexer_items then + for _, item in ipairs(indexer_items) do + if not seen[item.label] then + seen[item.label] = true + item.sortText = "2" .. item.label + table.insert(items, item) + end + end + end - -- Get buffer completions as fallback (lower priority) - local bufnr = params.context.bufnr - if bufnr then - local ok3, buffer_items = pcall(get_buffer_completions, prefix, bufnr) - if ok3 and buffer_items then - for _, item in ipairs(buffer_items) do - if not seen[item.label] then - seen[item.label] = true - item.sortText = "3" .. item.label - table.insert(items, item) - end - end - end - end + -- Get buffer completions as fallback (lower priority) + local bufnr = params.context.bufnr + if bufnr then + local ok3, buffer_items = pcall(get_buffer_completions, prefix, bufnr) + if ok3 and buffer_items then + for _, item in ipairs(buffer_items) do + if not seen[item.label] then + seen[item.label] = true + item.sortText = "3" .. item.label + table.insert(items, item) + end + end + end + end - -- If Copilot is installed, prefer its suggestion as a top-priority completion - local ok_cp, _ = pcall(require, "copilot") - if ok_cp then - local suggestion = nil - local ok_sug, res = pcall(get_copilot_suggestion, prefix) - if ok_sug then - suggestion = res - end - if suggestion and suggestion ~= "" then - -- Truncate suggestion to first line for label display - local first_line = suggestion:match("([^ -]+)") or suggestion - -- Avoid duplicates - if not seen[first_line] then - seen[first_line] = true - table.insert(items, 1, { - label = first_line, - kind = 1, - detail = "[copilot]", - documentation = suggestion, - sortText = "0" .. first_line, - }) - end - end - end + -- If Copilot is installed, prefer its suggestion as a top-priority completion + local ok_cp, _ = pcall(require, "copilot") + if ok_cp then + local suggestion = nil + local ok_sug, res = pcall(get_copilot_suggestion, prefix) + if ok_sug then + suggestion = res + end + if suggestion and suggestion ~= "" then + -- Truncate suggestion to first line for label display + local first_line = suggestion:match("([^\n]+)") or suggestion + -- Avoid duplicates + if not seen[first_line] then + seen[first_line] = true + table.insert(items, 1, { + label = first_line, + kind = 1, + detail = "[copilot]", + documentation = suggestion, + sortText = "0" .. first_line, + }) + end + end + end - callback({ - items = items, - isIncomplete = #items >= 50, - }) + callback({ + items = items, + isIncomplete = #items >= 50, + }) end --- Setup the completion source function M.setup() - if not has_cmp() then - return false - end + if not has_cmp() then + return false + end - local cmp = require("cmp") - local new_source = source.new() + local cmp = require("cmp") + local new_source = source.new() - -- Register the source - cmp.register_source("codetyper", new_source) + -- Register the source + cmp.register_source("codetyper", new_source) - return true + return true end --- Check if source is registered ---@return boolean function M.is_registered() - local ok, cmp = pcall(require, "cmp") - if not ok then - return false - end + local ok, cmp = pcall(require, "cmp") + if not ok then + return false + end - -- Try to get registered sources - local config = cmp.get_config() - if config and config.sources then - for _, src in ipairs(config.sources) do - if src.name == "codetyper" then - return true - end - end - end + -- Try to get registered sources + local config = cmp.get_config() + if config and config.sources then + for _, src in ipairs(config.sources) do + if src.name == "codetyper" then + return true + end + end + end - return false + return false end --- Get source for manual registration function M.get_source() - return source + return source end return M diff --git a/lua/codetyper/adapters/nvim/commands.lua b/lua/codetyper/adapters/nvim/commands.lua index 2d54427..4d070e0 100644 --- a/lua/codetyper/adapters/nvim/commands.lua +++ b/lua/codetyper/adapters/nvim/commands.lua @@ -7,413 +7,413 @@ local utils = require("codetyper.support.utils") --- Refresh tree.log manually local function cmd_tree() - local tree = require("codetyper.support.tree") - if tree.update_tree_log() then - utils.notify("Tree log updated: " .. tree.get_tree_log_path()) - else - utils.notify("Failed to update tree log", vim.log.levels.ERROR) - end + local tree = require("codetyper.support.tree") + if tree.update_tree_log() then + utils.notify("Tree log updated: " .. tree.get_tree_log_path()) + else + utils.notify("Failed to update tree log", vim.log.levels.ERROR) + end end --- Open tree.log file local function cmd_tree_view() - local tree = require("codetyper.support.tree") - local tree_log_path = tree.get_tree_log_path() + local tree = require("codetyper.support.tree") + local tree_log_path = tree.get_tree_log_path() - if not tree_log_path then - utils.notify("Could not find tree.log", vim.log.levels.WARN) - return - end + if not tree_log_path then + utils.notify("Could not find tree.log", vim.log.levels.WARN) + return + end - -- Ensure tree is up to date - tree.update_tree_log() + -- Ensure tree is up to date + tree.update_tree_log() - -- Open in a new split - vim.cmd("vsplit " .. vim.fn.fnameescape(tree_log_path)) - vim.bo.readonly = true - vim.bo.modifiable = false + -- Open in a new split + vim.cmd("vsplit " .. vim.fn.fnameescape(tree_log_path)) + vim.bo.readonly = true + vim.bo.modifiable = false end --- Reset processed prompts to allow re-processing local function cmd_reset() - local autocmds = require("codetyper.adapters.nvim.autocmds") - autocmds.reset_processed() + local autocmds = require("codetyper.adapters.nvim.autocmds") + autocmds.reset_processed() end --- Force update gitignore local function cmd_gitignore() - local gitignore = require("codetyper.support.gitignore") - gitignore.force_update() + local gitignore = require("codetyper.support.gitignore") + gitignore.force_update() end --- Index the entire project local function cmd_index_project() - local indexer = require("codetyper.features.indexer") + local indexer = require("codetyper.features.indexer") - utils.notify("Indexing project...", vim.log.levels.INFO) + utils.notify("Indexing project...", vim.log.levels.INFO) - indexer.index_project(function(index) - if index then - local msg = string.format( - "Indexed: %d files, %d functions, %d classes, %d exports", - index.stats.files, - index.stats.functions, - index.stats.classes, - index.stats.exports - ) - utils.notify(msg, vim.log.levels.INFO) - else - utils.notify("Failed to index project", vim.log.levels.ERROR) - end - end) + indexer.index_project(function(index) + if index then + local msg = string.format( + "Indexed: %d files, %d functions, %d classes, %d exports", + index.stats.files, + index.stats.functions, + index.stats.classes, + index.stats.exports + ) + utils.notify(msg, vim.log.levels.INFO) + else + utils.notify("Failed to index project", vim.log.levels.ERROR) + end + end) end --- Show index status local function cmd_index_status() - local indexer = require("codetyper.features.indexer") - local memory = require("codetyper.features.indexer.memory") + local indexer = require("codetyper.features.indexer") + local memory = require("codetyper.features.indexer.memory") - local status = indexer.get_status() - local mem_stats = memory.get_stats() + local status = indexer.get_status() + local mem_stats = memory.get_stats() - local lines = { - "Project Index Status", - "====================", - "", - } + local lines = { + "Project Index Status", + "====================", + "", + } - if status.indexed then - table.insert(lines, "Status: Indexed") - table.insert(lines, "Project Type: " .. (status.project_type or "unknown")) - table.insert(lines, "Last Indexed: " .. os.date("%Y-%m-%d %H:%M:%S", status.last_indexed)) - table.insert(lines, "") - table.insert(lines, "Stats:") - table.insert(lines, " Files: " .. (status.stats.files or 0)) - table.insert(lines, " Functions: " .. (status.stats.functions or 0)) - table.insert(lines, " Classes: " .. (status.stats.classes or 0)) - table.insert(lines, " Exports: " .. (status.stats.exports or 0)) - else - table.insert(lines, "Status: Not indexed") - table.insert(lines, "Run :CoderIndexProject to index") - end + if status.indexed then + table.insert(lines, "Status: Indexed") + table.insert(lines, "Project Type: " .. (status.project_type or "unknown")) + table.insert(lines, "Last Indexed: " .. os.date("%Y-%m-%d %H:%M:%S", status.last_indexed)) + table.insert(lines, "") + table.insert(lines, "Stats:") + table.insert(lines, " Files: " .. (status.stats.files or 0)) + table.insert(lines, " Functions: " .. (status.stats.functions or 0)) + table.insert(lines, " Classes: " .. (status.stats.classes or 0)) + table.insert(lines, " Exports: " .. (status.stats.exports or 0)) + else + table.insert(lines, "Status: Not indexed") + table.insert(lines, "Run :CoderIndexProject to index") + end - table.insert(lines, "") - table.insert(lines, "Memories:") - table.insert(lines, " Patterns: " .. mem_stats.patterns) - table.insert(lines, " Conventions: " .. mem_stats.conventions) - table.insert(lines, " Symbols: " .. mem_stats.symbols) + table.insert(lines, "") + table.insert(lines, "Memories:") + table.insert(lines, " Patterns: " .. mem_stats.patterns) + table.insert(lines, " Conventions: " .. mem_stats.conventions) + table.insert(lines, " Symbols: " .. mem_stats.symbols) - utils.notify(table.concat(lines, "\n")) + utils.notify(table.concat(lines, "\n")) end --- Show learned memories local function cmd_memories() - local memory = require("codetyper.features.indexer.memory") + local memory = require("codetyper.features.indexer.memory") - local all = memory.get_all() - local lines = { - "Learned Memories", - "================", - "", - "Patterns:", - } + local all = memory.get_all() + local lines = { + "Learned Memories", + "================", + "", + "Patterns:", + } - local pattern_count = 0 - for _, mem in pairs(all.patterns) do - pattern_count = pattern_count + 1 - if pattern_count <= 10 then - table.insert(lines, " - " .. (mem.content or ""):sub(1, 60)) - end - end - if pattern_count > 10 then - table.insert(lines, " ... and " .. (pattern_count - 10) .. " more") - elseif pattern_count == 0 then - table.insert(lines, " (none)") - end + local pattern_count = 0 + for _, mem in pairs(all.patterns) do + pattern_count = pattern_count + 1 + if pattern_count <= 10 then + table.insert(lines, " - " .. (mem.content or ""):sub(1, 60)) + end + end + if pattern_count > 10 then + table.insert(lines, " ... and " .. (pattern_count - 10) .. " more") + elseif pattern_count == 0 then + table.insert(lines, " (none)") + end - table.insert(lines, "") - table.insert(lines, "Conventions:") + table.insert(lines, "") + table.insert(lines, "Conventions:") - local conv_count = 0 - for _, mem in pairs(all.conventions) do - conv_count = conv_count + 1 - if conv_count <= 10 then - table.insert(lines, " - " .. (mem.content or ""):sub(1, 60)) - end - end - if conv_count > 10 then - table.insert(lines, " ... and " .. (conv_count - 10) .. " more") - elseif conv_count == 0 then - table.insert(lines, " (none)") - end + local conv_count = 0 + for _, mem in pairs(all.conventions) do + conv_count = conv_count + 1 + if conv_count <= 10 then + table.insert(lines, " - " .. (mem.content or ""):sub(1, 60)) + end + end + if conv_count > 10 then + table.insert(lines, " ... and " .. (conv_count - 10) .. " more") + elseif conv_count == 0 then + table.insert(lines, " (none)") + end - utils.notify(table.concat(lines, "\n")) + utils.notify(table.concat(lines, "\n")) end --- Clear memories ---@param pattern string|nil Optional pattern to match local function cmd_forget(pattern) - local memory = require("codetyper.features.indexer.memory") + local memory = require("codetyper.features.indexer.memory") - if not pattern or pattern == "" then - -- Confirm before clearing all - vim.ui.select({ "Yes", "No" }, { - prompt = "Clear all memories?", - }, function(choice) - if choice == "Yes" then - memory.clear() - utils.notify("All memories cleared", vim.log.levels.INFO) - end - end) - else - memory.clear(pattern) - utils.notify("Cleared memories matching: " .. pattern, vim.log.levels.INFO) - end + if not pattern or pattern == "" then + -- Confirm before clearing all + vim.ui.select({ "Yes", "No" }, { + prompt = "Clear all memories?", + }, function(choice) + if choice == "Yes" then + memory.clear() + utils.notify("All memories cleared", vim.log.levels.INFO) + end + end) + else + memory.clear(pattern) + utils.notify("Cleared memories matching: " .. pattern, vim.log.levels.INFO) + end end --- Main command dispatcher ---@param args table Command arguments --- Show LLM accuracy statistics local function cmd_llm_stats() - local llm = require("codetyper.core.llm") - local stats = llm.get_accuracy_stats() + local llm = require("codetyper.core.llm") + local stats = llm.get_accuracy_stats() - local lines = { - "LLM Provider Accuracy Statistics", - "================================", - "", - string.format("Ollama:"), - string.format(" Total requests: %d", stats.ollama.total), - string.format(" Correct: %d", stats.ollama.correct), - string.format(" Accuracy: %.1f%%", stats.ollama.accuracy * 100), - "", - string.format("Copilot:"), - string.format(" Total requests: %d", stats.copilot.total), - string.format(" Correct: %d", stats.copilot.correct), - string.format(" Accuracy: %.1f%%", stats.copilot.accuracy * 100), - "", - "Note: Smart selection prefers Ollama when brain memories", - "provide enough context. Accuracy improves over time via", - "pondering (verification with other LLMs).", - } + local lines = { + "LLM Provider Accuracy Statistics", + "================================", + "", + string.format("Ollama:"), + string.format(" Total requests: %d", stats.ollama.total), + string.format(" Correct: %d", stats.ollama.correct), + string.format(" Accuracy: %.1f%%", stats.ollama.accuracy * 100), + "", + string.format("Copilot:"), + string.format(" Total requests: %d", stats.copilot.total), + string.format(" Correct: %d", stats.copilot.correct), + string.format(" Accuracy: %.1f%%", stats.copilot.accuracy * 100), + "", + "Note: Smart selection prefers Ollama when brain memories", + "provide enough context. Accuracy improves over time via", + "pondering (verification with other LLMs).", + } - vim.notify(table.concat(lines, "\n"), vim.log.levels.INFO) + vim.notify(table.concat(lines, "\n"), vim.log.levels.INFO) end --- Report feedback on last LLM response ---@param was_good boolean Whether the response was good local function cmd_llm_feedback(was_good) - local llm = require("codetyper.core.llm") - -- Default to ollama for feedback - local provider = "ollama" + local llm = require("codetyper.core.llm") + -- Default to ollama for feedback + local provider = "ollama" - llm.report_feedback(provider, was_good) - local feedback_type = was_good and "positive" or "negative" - utils.notify(string.format("Reported %s feedback for %s", feedback_type, provider), vim.log.levels.INFO) + llm.report_feedback(provider, was_good) + local feedback_type = was_good and "positive" or "negative" + utils.notify(string.format("Reported %s feedback for %s", feedback_type, provider), vim.log.levels.INFO) end --- Reset LLM accuracy statistics local function cmd_llm_reset_stats() - local selector = require("codetyper.core.llm.selector") - selector.reset_accuracy_stats() - utils.notify("LLM accuracy statistics reset", vim.log.levels.INFO) + local selector = require("codetyper.core.llm.selector") + selector.reset_accuracy_stats() + utils.notify("LLM accuracy statistics reset", vim.log.levels.INFO) end local function coder_cmd(args) - local subcommand = args.fargs[1] or "version" + local subcommand = args.fargs[1] or "version" - local commands = { - ["version"] = function() - local codetyper = require("codetyper") - utils.notify("Codetyper.nvim " .. codetyper.version, vim.log.levels.INFO) - end, - tree = cmd_tree, - ["tree-view"] = cmd_tree_view, - reset = cmd_reset, - gitignore = cmd_gitignore, - ["transform-selection"] = transform.cmd_transform_selection, - ["index-project"] = cmd_index_project, - ["index-status"] = cmd_index_status, - ["llm-stats"] = cmd_llm_stats, - ["llm-reset-stats"] = cmd_llm_reset_stats, - ["cost"] = function() - local cost = require("codetyper.core.cost") - cost.toggle() - end, - ["cost-clear"] = function() - local cost = require("codetyper.core.cost") - cost.clear() - end, - ["credentials"] = function() - local credentials = require("codetyper.config.credentials") - credentials.show_status() - end, - ["switch-provider"] = function() - local credentials = require("codetyper.config.credentials") - credentials.interactive_switch_provider() - end, - ["model"] = function(args) - local credentials = require("codetyper.config.credentials") - local codetyper = require("codetyper") - local config = codetyper.get_config() - local provider = config.llm.provider + local commands = { + ["version"] = function() + local codetyper = require("codetyper") + utils.notify("Codetyper.nvim " .. codetyper.version, vim.log.levels.INFO) + end, + tree = cmd_tree, + ["tree-view"] = cmd_tree_view, + reset = cmd_reset, + gitignore = cmd_gitignore, + ["transform-selection"] = transform.cmd_transform_selection, + ["index-project"] = cmd_index_project, + ["index-status"] = cmd_index_status, + ["llm-stats"] = cmd_llm_stats, + ["llm-reset-stats"] = cmd_llm_reset_stats, + ["cost"] = function() + local cost = require("codetyper.core.cost") + cost.toggle() + end, + ["cost-clear"] = function() + local cost = require("codetyper.core.cost") + cost.clear() + end, + ["credentials"] = function() + local credentials = require("codetyper.config.credentials") + credentials.show_status() + end, + ["switch-provider"] = function() + local credentials = require("codetyper.config.credentials") + credentials.interactive_switch_provider() + end, + ["model"] = function(args) + local credentials = require("codetyper.config.credentials") + local codetyper = require("codetyper") + local config = codetyper.get_config() + local provider = config.llm.provider - if provider ~= "copilot" then - utils.notify( - "CoderModel is only available when using Copilot provider. Current: " .. provider:upper(), - vim.log.levels.WARN - ) - return - end + if provider ~= "copilot" then + utils.notify( + "CoderModel is only available when using Copilot provider. Current: " .. provider:upper(), + vim.log.levels.WARN + ) + return + end - local model_arg = args.fargs[2] - if model_arg and model_arg ~= "" then - local cost = credentials.get_copilot_model_cost(model_arg) or "custom" - credentials.set_credentials("copilot", { model = model_arg, configured = true }) - utils.notify("Copilot model set to: " .. model_arg .. " — " .. cost, vim.log.levels.INFO) - else - credentials.interactive_copilot_config(true) - end - end, - } + local model_arg = args.fargs[2] + if model_arg and model_arg ~= "" then + local cost = credentials.get_copilot_model_cost(model_arg) or "custom" + credentials.set_credentials("copilot", { model = model_arg, configured = true }) + utils.notify("Copilot model set to: " .. model_arg .. " — " .. cost, vim.log.levels.INFO) + else + credentials.interactive_copilot_config(true) + end + end, + } - local cmd_fn = commands[subcommand] - if cmd_fn then - cmd_fn(args) - else - utils.notify("Unknown subcommand: " .. subcommand, vim.log.levels.ERROR) - end + local cmd_fn = commands[subcommand] + if cmd_fn then + cmd_fn(args) + else + utils.notify("Unknown subcommand: " .. subcommand, vim.log.levels.ERROR) + end end --- Setup all commands function M.setup() - vim.api.nvim_create_user_command("Coder", coder_cmd, { - nargs = "?", - complete = function() - return { - "version", - "tree", - "tree-view", - "reset", - "gitignore", - "transform-selection", - "index-project", - "index-status", - "llm-stats", - "llm-reset-stats", - "cost", - "cost-clear", - "credentials", - "switch-provider", - "model", - } - end, - desc = "Codetyper.nvim commands", - }) + vim.api.nvim_create_user_command("Coder", coder_cmd, { + nargs = "?", + complete = function() + return { + "version", + "tree", + "tree-view", + "reset", + "gitignore", + "transform-selection", + "index-project", + "index-status", + "llm-stats", + "llm-reset-stats", + "cost", + "cost-clear", + "credentials", + "switch-provider", + "model", + } + end, + desc = "Codetyper.nvim commands", + }) - vim.api.nvim_create_user_command("CoderTree", function() - cmd_tree() - end, { desc = "Refresh tree.log" }) + vim.api.nvim_create_user_command("CoderTree", function() + cmd_tree() + end, { desc = "Refresh tree.log" }) - vim.api.nvim_create_user_command("CoderTreeView", function() - cmd_tree_view() - end, { desc = "View tree.log" }) + vim.api.nvim_create_user_command("CoderTreeView", function() + cmd_tree_view() + end, { desc = "View tree.log" }) - vim.api.nvim_create_user_command("CoderTransformSelection", function() - transform.cmd_transform_selection() - end, { desc = "Transform visual selection with custom prompt input" }) + vim.api.nvim_create_user_command("CoderTransformSelection", function() + transform.cmd_transform_selection() + end, { desc = "Transform visual selection with custom prompt input" }) - -- Project indexer commands - vim.api.nvim_create_user_command("CoderIndexProject", function() - cmd_index_project() - end, { desc = "Index the entire project" }) + -- Project indexer commands + vim.api.nvim_create_user_command("CoderIndexProject", function() + cmd_index_project() + end, { desc = "Index the entire project" }) - vim.api.nvim_create_user_command("CoderIndexStatus", function() - cmd_index_status() - end, { desc = "Show project index status" }) + vim.api.nvim_create_user_command("CoderIndexStatus", function() + cmd_index_status() + end, { desc = "Show project index status" }) - -- TODO: re-enable CoderMemories, CoderForget when memory UI is reworked - -- TODO: re-enable CoderFeedback when feedback loop is reworked - -- TODO: re-enable CoderBrain when brain management UI is reworked + -- TODO: re-enable CoderMemories, CoderForget when memory UI is reworked + -- TODO: re-enable CoderFeedback when feedback loop is reworked + -- TODO: re-enable CoderBrain when brain management UI is reworked - -- Cost estimation command - vim.api.nvim_create_user_command("CoderCost", function() - local cost = require("codetyper.core.cost") - cost.toggle() - end, { desc = "Show LLM cost estimation window" }) + -- Cost estimation command + vim.api.nvim_create_user_command("CoderCost", function() + local cost = require("codetyper.core.cost") + cost.toggle() + end, { desc = "Show LLM cost estimation window" }) - -- TODO: re-enable CoderAddApiKey when multi-provider support returns + -- TODO: re-enable CoderAddApiKey when multi-provider support returns - vim.api.nvim_create_user_command("CoderCredentials", function() - local credentials = require("codetyper.config.credentials") - credentials.show_status() - end, { desc = "Show credentials status" }) + vim.api.nvim_create_user_command("CoderCredentials", function() + local credentials = require("codetyper.config.credentials") + credentials.show_status() + end, { desc = "Show credentials status" }) - vim.api.nvim_create_user_command("CoderSwitchProvider", function() - local credentials = require("codetyper.config.credentials") - credentials.interactive_switch_provider() - end, { desc = "Switch active LLM provider" }) + vim.api.nvim_create_user_command("CoderSwitchProvider", function() + local credentials = require("codetyper.config.credentials") + credentials.interactive_switch_provider() + end, { desc = "Switch active LLM provider" }) - -- Quick model switcher command (Copilot only) - vim.api.nvim_create_user_command("CoderModel", function(opts) - local credentials = require("codetyper.config.credentials") - local codetyper = require("codetyper") - local config = codetyper.get_config() - local provider = config.llm.provider + -- Quick model switcher command (Copilot only) + vim.api.nvim_create_user_command("CoderModel", function(opts) + local credentials = require("codetyper.config.credentials") + local codetyper = require("codetyper") + local config = codetyper.get_config() + local provider = config.llm.provider - -- Only available for Copilot provider - if provider ~= "copilot" then - utils.notify( - "CoderModel is only available when using Copilot provider. Current: " .. provider:upper(), - vim.log.levels.WARN - ) - return - end + -- Only available for Copilot provider + if provider ~= "copilot" then + utils.notify( + "CoderModel is only available when using Copilot provider. Current: " .. provider:upper(), + vim.log.levels.WARN + ) + return + end - -- If an argument is provided, set the model directly - if opts.args and opts.args ~= "" then - local cost = credentials.get_copilot_model_cost(opts.args) or "custom" - credentials.set_credentials("copilot", { model = opts.args, configured = true }) - utils.notify("Copilot model set to: " .. opts.args .. " — " .. cost, vim.log.levels.INFO) - return - end + -- If an argument is provided, set the model directly + if opts.args and opts.args ~= "" then + local cost = credentials.get_copilot_model_cost(opts.args) or "custom" + credentials.set_credentials("copilot", { model = opts.args, configured = true }) + utils.notify("Copilot model set to: " .. opts.args .. " — " .. cost, vim.log.levels.INFO) + return + end - -- Show interactive selector with costs (silent mode - no OAuth message) - credentials.interactive_copilot_config(true) - end, { - nargs = "?", - desc = "Quick switch Copilot model (only available with Copilot provider)", - complete = function() - local codetyper = require("codetyper") - local credentials = require("codetyper.config.credentials") - local config = codetyper.get_config() - if config.llm.provider == "copilot" then - return credentials.get_copilot_model_names() - end - return {} - end, - }) + -- Show interactive selector with costs (silent mode - no OAuth message) + credentials.interactive_copilot_config(true) + end, { + nargs = "?", + desc = "Quick switch Copilot model (only available with Copilot provider)", + complete = function() + local codetyper = require("codetyper") + local credentials = require("codetyper.config.credentials") + local config = codetyper.get_config() + if config.llm.provider == "copilot" then + return credentials.get_copilot_model_names() + end + return {} + end, + }) - -- Setup default keymaps - M.setup_keymaps() + -- Setup default keymaps + M.setup_keymaps() end --- Setup default keymaps for transform commands function M.setup_keymaps() - -- Visual mode: transform selection with custom prompt input - vim.keymap.set("v", "ctt", function() - transform.cmd_transform_selection() - end, { - silent = true, - desc = "Coder: Transform selection with prompt", - }) - -- Normal mode: prompt only (no selection); request is entered in the prompt - vim.keymap.set("n", "ctt", function() - transform.cmd_transform_selection() - end, { - silent = true, - desc = "Coder: Open prompt window", - }) + -- Visual mode: transform selection with custom prompt input + vim.keymap.set("v", "ctt", function() + transform.cmd_transform_selection() + end, { + silent = true, + desc = "Coder: Transform selection with prompt", + }) + -- Normal mode: prompt only (no selection); request is entered in the prompt + vim.keymap.set("n", "ctt", function() + transform.cmd_transform_selection() + end, { + silent = true, + desc = "Coder: Open prompt window", + }) end return M diff --git a/lua/codetyper/adapters/nvim/ui/context_modal.lua b/lua/codetyper/adapters/nvim/ui/context_modal.lua index 0aa98e3..5dfd444 100644 --- a/lua/codetyper/adapters/nvim/ui/context_modal.lua +++ b/lua/codetyper/adapters/nvim/ui/context_modal.lua @@ -14,138 +14,139 @@ local M = {} ---@field llm_response string|nil LLM's response asking for context local state = { - buf = nil, - win = nil, - original_event = nil, - callback = nil, - llm_response = nil, - attached_files = nil, + buf = nil, + win = nil, + original_event = nil, + callback = nil, + llm_response = nil, + attached_files = nil, } --- Close the context modal function M.close() - if state.win and vim.api.nvim_win_is_valid(state.win) then - vim.api.nvim_win_close(state.win, true) - end - if state.buf and vim.api.nvim_buf_is_valid(state.buf) then - vim.api.nvim_buf_delete(state.buf, { force = true }) - end - state.win = nil - state.buf = nil - state.original_event = nil - state.callback = nil - state.llm_response = nil + if state.win and vim.api.nvim_win_is_valid(state.win) then + vim.api.nvim_win_close(state.win, true) + end + if state.buf and vim.api.nvim_buf_is_valid(state.buf) then + vim.api.nvim_buf_delete(state.buf, { force = true }) + end + state.win = nil + state.buf = nil + state.original_event = nil + state.callback = nil + state.llm_response = nil end --- Submit the additional context local function submit() - if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then - return - end + if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then + return + end - local lines = vim.api.nvim_buf_get_lines(state.buf, 0, -1, false) - local additional_context = table.concat(lines, "\n") + local lines = vim.api.nvim_buf_get_lines(state.buf, 0, -1, false) + local additional_context = table.concat(lines, "\n") - -- Trim whitespace - additional_context = additional_context:match("^%s*(.-)%s*$") or additional_context + -- Trim whitespace + additional_context = additional_context:match("^%s*(.-)%s*$") or additional_context - if additional_context == "" then - M.close() - return - end + if additional_context == "" then + M.close() + return + end - local original_event = state.original_event - local callback = state.callback + local original_event = state.original_event + local callback = state.callback - M.close() + M.close() - if callback and original_event then - -- Pass attached_files as third optional parameter - callback(original_event, additional_context, state.attached_files) - end + if callback and original_event then + -- Pass attached_files as third optional parameter + callback(original_event, additional_context, state.attached_files) + end end - --- Parse requested file paths from LLM response and resolve to full paths local function parse_requested_files(response) - if not response or response == "" then - return {} - end + if not response or response == "" then + return {} + end - local cwd = vim.fn.getcwd() - local candidates = {} - local seen = {} + local cwd = vim.fn.getcwd() + local candidates = {} + local seen = {} - for path in response:gmatch("`([%w%._%-%/]+%.[%w_]+)`") do - if not seen[path] then - table.insert(candidates, path) - seen[path] = true - end - end - for path in response:gmatch("([%w%._%-%/]+%.[%w_]+)") do - if not seen[path] then - table.insert(candidates, path) - seen[path] = true - end - end + for path in response:gmatch("`([%w%._%-%/]+%.[%w_]+)`") do + if not seen[path] then + table.insert(candidates, path) + seen[path] = true + end + end + for path in response:gmatch("([%w%._%-%/]+%.[%w_]+)") do + if not seen[path] then + table.insert(candidates, path) + seen[path] = true + end + end - -- Resolve to full paths using cwd and glob - local resolved = {} - for _, p in ipairs(candidates) do - local full = nil - if p:sub(1,1) == "/" and vim.fn.filereadable(p) == 1 then - full = p - else - local try1 = cwd .. "/" .. p - if vim.fn.filereadable(try1) == 1 then - full = try1 - else - local tail = p:match("[^/]+$") or p - local matches = vim.fn.globpath(cwd, "**/" .. tail, false, true) - if matches and #matches > 0 then - full = matches[1] - end - end - end - if full and vim.fn.filereadable(full) == 1 then - table.insert(resolved, full) - end - end - return resolved + -- Resolve to full paths using cwd and glob + local resolved = {} + for _, p in ipairs(candidates) do + local full = nil + if p:sub(1, 1) == "/" and vim.fn.filereadable(p) == 1 then + full = p + else + local try1 = cwd .. "/" .. p + if vim.fn.filereadable(try1) == 1 then + full = try1 + else + local tail = p:match("[^/]+$") or p + local matches = vim.fn.globpath(cwd, "**/" .. tail, false, true) + if matches and #matches > 0 then + full = matches[1] + end + end + end + if full and vim.fn.filereadable(full) == 1 then + table.insert(resolved, full) + end + end + return resolved end - --- Attach parsed files into the modal buffer and remember them for submission local function attach_requested_files() - if not state.llm_response or state.llm_response == "" then - return - end - local files = parse_requested_files(state.llm_response) - if #files == 0 then - local ui_prompts = require("codetyper.prompts.agents.modal").ui - vim.api.nvim_buf_set_lines(state.buf, vim.api.nvim_buf_line_count(state.buf), -1, false, ui_prompts.files_header) - return - end + if not state.llm_response or state.llm_response == "" then + return + end + local files = parse_requested_files(state.llm_response) + if #files == 0 then + local ui_prompts = require("codetyper.prompts.agents.modal").ui + vim.api.nvim_buf_set_lines(state.buf, vim.api.nvim_buf_line_count(state.buf), -1, false, ui_prompts.files_header) + return + end - state.attached_files = state.attached_files or {} + state.attached_files = state.attached_files or {} - for _, full in ipairs(files) do - local ok, lines = pcall(vim.fn.readfile, full) - if ok and lines and #lines > 0 then - table.insert(state.attached_files, { path = vim.fn.fnamemodify(full, ":~:." ) , full_path = full, content = table.concat(lines, "\n") }) - local insert_at = vim.api.nvim_buf_line_count(state.buf) - vim.api.nvim_buf_set_lines(state.buf, insert_at, insert_at, false, { "", "-- Attached: " .. full .. " --" }) - for i, l in ipairs(lines) do - vim.api.nvim_buf_set_lines(state.buf, insert_at + 1 + i, insert_at + 1 + i, false, { l }) - end - else - local insert_at = vim.api.nvim_buf_line_count(state.buf) - vim.api.nvim_buf_set_lines(state.buf, insert_at, insert_at, false, { "", "-- Failed to read: " .. full .. " --" }) - end - end - -- Move cursor to end and enter insert mode - vim.api.nvim_win_set_cursor(state.win, { vim.api.nvim_buf_line_count(state.buf), 0 }) - vim.cmd("startinsert") + for _, full in ipairs(files) do + local ok, lines = pcall(vim.fn.readfile, full) + if ok and lines and #lines > 0 then + table.insert( + state.attached_files, + { path = vim.fn.fnamemodify(full, ":~:."), full_path = full, content = table.concat(lines, "\n") } + ) + local insert_at = vim.api.nvim_buf_line_count(state.buf) + vim.api.nvim_buf_set_lines(state.buf, insert_at, insert_at, false, { "", "-- Attached: " .. full .. " --" }) + for i, l in ipairs(lines) do + vim.api.nvim_buf_set_lines(state.buf, insert_at + 1 + i, insert_at + 1 + i, false, { l }) + end + else + local insert_at = vim.api.nvim_buf_line_count(state.buf) + vim.api.nvim_buf_set_lines(state.buf, insert_at, insert_at, false, { "", "-- Failed to read: " .. full .. " --" }) + end + end + -- Move cursor to end and enter insert mode + vim.api.nvim_win_set_cursor(state.win, { vim.api.nvim_buf_line_count(state.buf), 0 }) + vim.cmd("startinsert") end --- Open the context modal @@ -154,228 +155,246 @@ end ---@param callback function(event: table, additional_context: string, attached_files?: table) ---@param suggested_commands table[]|nil Optional list of {label,cmd} suggested shell commands function M.open(original_event, llm_response, callback, suggested_commands) - -- Close any existing modal - M.close() + -- Close any existing modal + M.close() - state.original_event = original_event - state.llm_response = llm_response - state.callback = callback + state.original_event = original_event + state.llm_response = llm_response + state.callback = callback - -- Calculate window size - local width = math.min(80, vim.o.columns - 10) - local height = 10 + -- Calculate window size + local width = math.min(80, vim.o.columns - 10) + local height = 10 - -- Create buffer - state.buf = vim.api.nvim_create_buf(false, true) - vim.bo[state.buf].buftype = "nofile" - vim.bo[state.buf].bufhidden = "wipe" - vim.bo[state.buf].filetype = "markdown" + -- Create buffer + state.buf = vim.api.nvim_create_buf(false, true) + vim.bo[state.buf].buftype = "nofile" + vim.bo[state.buf].bufhidden = "wipe" + vim.bo[state.buf].filetype = "markdown" - -- Create window - local row = math.floor((vim.o.lines - height) / 2) - local col = math.floor((vim.o.columns - width) / 2) + -- Create window + local row = math.floor((vim.o.lines - height) / 2) + local col = math.floor((vim.o.columns - width) / 2) - state.win = vim.api.nvim_open_win(state.buf, true, { - relative = "editor", - row = row, - col = col, - width = width, - height = height, - style = "minimal", - border = "rounded", - title = " Additional Context Needed ", - title_pos = "center", - }) + state.win = vim.api.nvim_open_win(state.buf, true, { + relative = "editor", + row = row, + col = col, + width = width, + height = height, + style = "minimal", + border = "rounded", + title = " Additional Context Needed ", + title_pos = "center", + }) - -- Set window options - vim.wo[state.win].wrap = true - vim.wo[state.win].cursorline = true + -- Set window options + vim.wo[state.win].wrap = true + vim.wo[state.win].cursorline = true - local ui_prompts = require("codetyper.prompts.agents.modal").ui + local ui_prompts = require("codetyper.prompts.agents.modal").ui - -- Add header showing what the LLM said - local header_lines = { - ui_prompts.llm_response_header, - } + -- Add header showing what the LLM said + local header_lines = { + ui_prompts.llm_response_header, + } - -- Truncate LLM response for display - local response_preview = llm_response or "" - if #response_preview > 200 then - response_preview = response_preview:sub(1, 200) .. "..." - end - for line in response_preview:gmatch("[^\n]+") do - table.insert(header_lines, "-- " .. line) - end + -- Truncate LLM response for display + local response_preview = llm_response or "" + if #response_preview > 200 then + response_preview = response_preview:sub(1, 200) .. "..." + end + for line in response_preview:gmatch("[^\n]+") do + table.insert(header_lines, "-- " .. line) + end - -- If suggested commands were provided, show them in the header - if suggested_commands and #suggested_commands > 0 then - table.insert(header_lines, "") - table.insert(header_lines, ui_prompts.suggested_commands_header) - for i, s in ipairs(suggested_commands) do - local label = s.label or s.cmd - table.insert(header_lines, string.format("[%d] %s: %s", i, label, s.cmd)) - end - table.insert(header_lines, ui_prompts.commands_hint) - end + -- If suggested commands were provided, show them in the header + if suggested_commands and #suggested_commands > 0 then + table.insert(header_lines, "") + table.insert(header_lines, ui_prompts.suggested_commands_header) + for i, s in ipairs(suggested_commands) do + local label = s.label or s.cmd + table.insert(header_lines, string.format("[%d] %s: %s", i, label, s.cmd)) + end + table.insert(header_lines, ui_prompts.commands_hint) + end - table.insert(header_lines, "") - table.insert(header_lines, ui_prompts.input_header) - table.insert(header_lines, "") + table.insert(header_lines, "") + table.insert(header_lines, ui_prompts.input_header) + table.insert(header_lines, "") - vim.api.nvim_buf_set_lines(state.buf, 0, -1, false, header_lines) + vim.api.nvim_buf_set_lines(state.buf, 0, -1, false, header_lines) - -- Move cursor to the end - vim.api.nvim_win_set_cursor(state.win, { #header_lines, 0 }) + -- Move cursor to the end + vim.api.nvim_win_set_cursor(state.win, { #header_lines, 0 }) - -- Set up keymaps - local opts = { buffer = state.buf, noremap = true, silent = true } + -- Set up keymaps + local opts = { buffer = state.buf, noremap = true, silent = true } - -- Submit with Ctrl+Enter or s - vim.keymap.set("n", "", submit, opts) - vim.keymap.set("i", "", submit, opts) - vim.keymap.set("n", "s", submit, opts) - vim.keymap.set("n", "", submit, opts) + -- Submit with Ctrl+Enter or s + vim.keymap.set("n", "", submit, opts) + vim.keymap.set("i", "", submit, opts) + vim.keymap.set("n", "s", submit, opts) + vim.keymap.set("n", "", submit, opts) - -- Attach parsed files (from LLM response) - vim.keymap.set("n", "a", function() - attach_requested_files() - end, opts) + -- Attach parsed files (from LLM response) + vim.keymap.set("n", "a", function() + attach_requested_files() + end, opts) - -- Confirm and submit with 'c' (convenient when doing question round) - vim.keymap.set("n", "c", submit, opts) + -- Confirm and submit with 'c' (convenient when doing question round) + vim.keymap.set("n", "c", submit, opts) - -- Quick run of project inspection from modal with r / in insert mode - vim.keymap.set("n", "r", run_project_inspect, opts) - vim.keymap.set("i", "", function() - vim.schedule(run_project_inspect) - end, { buffer = state.buf, noremap = true, silent = true }) + -- Quick run of project inspection from modal with r / in insert mode + vim.keymap.set("n", "r", run_project_inspect, opts) + vim.keymap.set("i", "", function() + vim.schedule(run_project_inspect) + end, { buffer = state.buf, noremap = true, silent = true }) - -- If suggested commands provided, create per-command keymaps 1..n to run them - state.suggested_commands = suggested_commands - if suggested_commands and #suggested_commands > 0 then - for i, s in ipairs(suggested_commands) do - local key = "" .. tostring(i) - vim.keymap.set("n", key, function() - -- run this single command and append output - if not s or not s.cmd then - return - end - local ok, out = pcall(vim.fn.systemlist, s.cmd) - local insert_at = vim.api.nvim_buf_line_count(state.buf) - vim.api.nvim_buf_set_lines(state.buf, insert_at, insert_at, false, { "", "-- Output: " .. s.cmd .. " --" }) - if ok and out and #out > 0 then - for j, line in ipairs(out) do - vim.api.nvim_buf_set_lines(state.buf, insert_at + j, insert_at + j, false, { line }) - end - else - vim.api.nvim_buf_set_lines(state.buf, insert_at + 1, insert_at + 1, false, { "(no output or command failed)" }) - end - vim.api.nvim_win_set_cursor(state.win, { vim.api.nvim_buf_line_count(state.buf), 0 }) - vim.cmd("startinsert") - end, opts) - end - -- Also map 0 to run all suggested commands - vim.keymap.set("n", "0", function() - for _, s in ipairs(suggested_commands) do - pcall(function() - local ok, out = pcall(vim.fn.systemlist, s.cmd) - local insert_at = vim.api.nvim_buf_line_count(state.buf) - vim.api.nvim_buf_set_lines(state.buf, insert_at, insert_at, false, { "", "-- Output: " .. s.cmd .. " --" }) - if ok and out and #out > 0 then - for j, line in ipairs(out) do - vim.api.nvim_buf_set_lines(state.buf, insert_at + j, insert_at + j, false, { line }) - end - else - vim.api.nvim_buf_set_lines(state.buf, insert_at + 1, insert_at + 1, false, { "(no output or command failed)" }) - end - end) - end - vim.api.nvim_win_set_cursor(state.win, { vim.api.nvim_buf_line_count(state.buf), 0 }) - vim.cmd("startinsert") - end, opts) - end + -- If suggested commands provided, create per-command keymaps 1..n to run them + state.suggested_commands = suggested_commands + if suggested_commands and #suggested_commands > 0 then + for i, s in ipairs(suggested_commands) do + local key = "" .. tostring(i) + vim.keymap.set("n", key, function() + -- run this single command and append output + if not s or not s.cmd then + return + end + local ok, out = pcall(vim.fn.systemlist, s.cmd) + local insert_at = vim.api.nvim_buf_line_count(state.buf) + vim.api.nvim_buf_set_lines(state.buf, insert_at, insert_at, false, { "", "-- Output: " .. s.cmd .. " --" }) + if ok and out and #out > 0 then + for j, line in ipairs(out) do + vim.api.nvim_buf_set_lines(state.buf, insert_at + j, insert_at + j, false, { line }) + end + else + vim.api.nvim_buf_set_lines( + state.buf, + insert_at + 1, + insert_at + 1, + false, + { "(no output or command failed)" } + ) + end + vim.api.nvim_win_set_cursor(state.win, { vim.api.nvim_buf_line_count(state.buf), 0 }) + vim.cmd("startinsert") + end, opts) + end + -- Also map 0 to run all suggested commands + vim.keymap.set("n", "0", function() + for _, s in ipairs(suggested_commands) do + pcall(function() + local ok, out = pcall(vim.fn.systemlist, s.cmd) + local insert_at = vim.api.nvim_buf_line_count(state.buf) + vim.api.nvim_buf_set_lines(state.buf, insert_at, insert_at, false, { "", "-- Output: " .. s.cmd .. " --" }) + if ok and out and #out > 0 then + for j, line in ipairs(out) do + vim.api.nvim_buf_set_lines(state.buf, insert_at + j, insert_at + j, false, { line }) + end + else + vim.api.nvim_buf_set_lines( + state.buf, + insert_at + 1, + insert_at + 1, + false, + { "(no output or command failed)" } + ) + end + end) + end + vim.api.nvim_win_set_cursor(state.win, { vim.api.nvim_buf_line_count(state.buf), 0 }) + vim.cmd("startinsert") + end, opts) + end - -- Close with Esc or q - vim.keymap.set("n", "", M.close, opts) - vim.keymap.set("n", "q", M.close, opts) + -- Close with Esc or q + vim.keymap.set("n", "", M.close, opts) + vim.keymap.set("n", "q", M.close, opts) - -- Start in insert mode - vim.cmd("startinsert") + -- Start in insert mode + vim.cmd("startinsert") - -- Log - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = "Context modal opened - waiting for user input", - }) - end) + -- Log + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = "Context modal opened - waiting for user input", + }) + end) end --- Run a small set of safe project inspection commands and insert outputs into the modal buffer local function run_project_inspect() - if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then - return - end + if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then + return + end - local cmds = { - { label = "List files (ls -la)", cmd = "ls -la" }, - { label = "Git status (git status --porcelain)", cmd = "git status --porcelain" }, - { label = "Git top (git rev-parse --show-toplevel)", cmd = "git rev-parse --show-toplevel" }, - { label = "Show repo files (git ls-files)", cmd = "git ls-files" }, - } + local cmds = { + { label = "List files (ls -la)", cmd = "ls -la" }, + { label = "Git status (git status --porcelain)", cmd = "git status --porcelain" }, + { label = "Git top (git rev-parse --show-toplevel)", cmd = "git rev-parse --show-toplevel" }, + { label = "Show repo files (git ls-files)", cmd = "git ls-files" }, + } - local ui_prompts = require("codetyper.prompts.agents.modal").ui - local insert_pos = vim.api.nvim_buf_line_count(state.buf) - vim.api.nvim_buf_set_lines(state.buf, insert_pos, insert_pos, false, ui_prompts.project_inspect_header) + local ui_prompts = require("codetyper.prompts.agents.modal").ui + local insert_pos = vim.api.nvim_buf_line_count(state.buf) + vim.api.nvim_buf_set_lines(state.buf, insert_pos, insert_pos, false, ui_prompts.project_inspect_header) - for _, c in ipairs(cmds) do - local ok, out = pcall(vim.fn.systemlist, c.cmd) - if ok and out and #out > 0 then - vim.api.nvim_buf_set_lines(state.buf, insert_pos + 2, insert_pos + 2, false, { "-- " .. c.label .. " --" }) - for i, line in ipairs(out) do - vim.api.nvim_buf_set_lines(state.buf, insert_pos + 2 + i, insert_pos + 2 + i, false, { line }) - end - insert_pos = vim.api.nvim_buf_line_count(state.buf) - else - vim.api.nvim_buf_set_lines(state.buf, insert_pos + 2, insert_pos + 2, false, { "-- " .. c.label .. " --", "(no output or command failed)" }) - insert_pos = vim.api.nvim_buf_line_count(state.buf) - end - end + for _, c in ipairs(cmds) do + local ok, out = pcall(vim.fn.systemlist, c.cmd) + if ok and out and #out > 0 then + vim.api.nvim_buf_set_lines(state.buf, insert_pos + 2, insert_pos + 2, false, { "-- " .. c.label .. " --" }) + for i, line in ipairs(out) do + vim.api.nvim_buf_set_lines(state.buf, insert_pos + 2 + i, insert_pos + 2 + i, false, { line }) + end + insert_pos = vim.api.nvim_buf_line_count(state.buf) + else + vim.api.nvim_buf_set_lines( + state.buf, + insert_pos + 2, + insert_pos + 2, + false, + { "-- " .. c.label .. " --", "(no output or command failed)" } + ) + insert_pos = vim.api.nvim_buf_line_count(state.buf) + end + end - -- Move cursor to end - vim.api.nvim_win_set_cursor(state.win, { vim.api.nvim_buf_line_count(state.buf), 0 }) - vim.cmd("startinsert") + -- Move cursor to end + vim.api.nvim_win_set_cursor(state.win, { vim.api.nvim_buf_line_count(state.buf), 0 }) + vim.cmd("startinsert") end -- Provide a keybinding in the modal to run project inspection commands pcall(function() - if state.buf and vim.api.nvim_buf_is_valid(state.buf) then - vim.keymap.set("n", "r", run_project_inspect, { buffer = state.buf, noremap = true, silent = true }) - vim.keymap.set("i", "", function() - vim.schedule(run_project_inspect) - end, { buffer = state.buf, noremap = true, silent = true }) - end + if state.buf and vim.api.nvim_buf_is_valid(state.buf) then + vim.keymap.set("n", "r", run_project_inspect, { buffer = state.buf, noremap = true, silent = true }) + vim.keymap.set("i", "", function() + vim.schedule(run_project_inspect) + end, { buffer = state.buf, noremap = true, silent = true }) + end end) --- Check if modal is open ---@return boolean function M.is_open() - return state.win ~= nil and vim.api.nvim_win_is_valid(state.win) + return state.win ~= nil and vim.api.nvim_win_is_valid(state.win) end --- Setup autocmds for the context modal function M.setup() - local group = vim.api.nvim_create_augroup("CodetypeContextModal", { clear = true }) + local group = vim.api.nvim_create_augroup("CodetypeContextModal", { clear = true }) - -- Close context modal when exiting Neovim - vim.api.nvim_create_autocmd("VimLeavePre", { - group = group, - callback = function() - M.close() - end, - desc = "Close context modal before exiting Neovim", - }) + -- Close context modal when exiting Neovim + vim.api.nvim_create_autocmd("VimLeavePre", { + group = group, + callback = function() + M.close() + end, + desc = "Close context modal before exiting Neovim", + }) end return M diff --git a/lua/codetyper/adapters/nvim/ui/diff_review.lua b/lua/codetyper/adapters/nvim/ui/diff_review.lua index d124360..8fcd872 100644 --- a/lua/codetyper/adapters/nvim/ui/diff_review.lua +++ b/lua/codetyper/adapters/nvim/ui/diff_review.lua @@ -8,7 +8,6 @@ local M = {} local utils = require("codetyper.support.utils") local prompts = require("codetyper.prompts.agents.diff") - ---@class DiffEntry ---@field path string File path ---@field operation string "create"|"edit"|"delete" @@ -96,9 +95,16 @@ local function generate_diff_lines(original, modified, filepath) if orig ~= mod then if not in_change then - table.insert(lines, string.format("@@ -%d,%d +%d,%d @@", - math.max(1, i - 2), math.min(5, #orig_lines - i + 3), - math.max(1, i - 2), math.min(5, #mod_lines - i + 3))) + table.insert( + lines, + string.format( + "@@ -%d,%d +%d,%d @@", + math.max(1, i - 2), + math.min(5, #orig_lines - i + 3), + math.max(1, i - 2), + math.min(5, #mod_lines - i + 3) + ) + ) in_change = true end if orig ~= "" then @@ -140,10 +146,12 @@ local function update_diff_view() local status_icon = entry.applied and " " or (entry.approved and " " or " ") local op_icon = entry.operation == "create" and "+" or (entry.operation == "delete" and "-" or "~") local current_status = entry.applied and ui_prompts.status.applied - or (entry.approved and ui_prompts.status.approved or ui_prompts.status.pending) + or (entry.approved and ui_prompts.status.approved or ui_prompts.status.pending) - table.insert(lines, string.format(ui_prompts.diff_header.top, - status_icon, op_icon, vim.fn.fnamemodify(entry.path, ":t"))) + table.insert( + lines, + string.format(ui_prompts.diff_header.top, status_icon, op_icon, vim.fn.fnamemodify(entry.path, ":t")) + ) table.insert(lines, string.format(ui_prompts.diff_header.path, entry.path)) table.insert(lines, string.format(ui_prompts.diff_header.op, entry.operation)) table.insert(lines, string.format(ui_prompts.diff_header.status, current_status)) @@ -332,7 +340,9 @@ function M.open() vim.keymap.set("n", "k", M.prev, list_opts) vim.keymap.set("n", "", M.next, list_opts) vim.keymap.set("n", "", M.prev, list_opts) - vim.keymap.set("n", "", function() vim.api.nvim_set_current_win(state.diff_win) end, list_opts) + vim.keymap.set("n", "", function() + vim.api.nvim_set_current_win(state.diff_win) + end, list_opts) vim.keymap.set("n", "a", M.approve_current, list_opts) vim.keymap.set("n", "r", M.reject_current, list_opts) vim.keymap.set("n", "A", M.approve_all, list_opts) @@ -343,7 +353,9 @@ function M.open() local diff_opts = { buffer = state.diff_buf, noremap = true, silent = true } vim.keymap.set("n", "j", M.next, diff_opts) vim.keymap.set("n", "k", M.prev, diff_opts) - vim.keymap.set("n", "", function() vim.api.nvim_set_current_win(state.list_win) end, diff_opts) + vim.keymap.set("n", "", function() + vim.api.nvim_set_current_win(state.list_win) + end, diff_opts) vim.keymap.set("n", "a", M.approve_current, diff_opts) vim.keymap.set("n", "r", M.reject_current, diff_opts) vim.keymap.set("n", "A", M.approve_all, diff_opts) diff --git a/lua/codetyper/adapters/nvim/ui/logs.lua b/lua/codetyper/adapters/nvim/ui/logs.lua index ce904ee..8e3bb78 100644 --- a/lua/codetyper/adapters/nvim/ui/logs.lua +++ b/lua/codetyper/adapters/nvim/ui/logs.lua @@ -6,7 +6,6 @@ local M = {} local params = require("codetyper.params.agents.logs") - ---@class LogEntry ---@field timestamp string ISO timestamp ---@field level string "info" | "debug" | "request" | "response" | "tool" | "error" @@ -195,7 +194,10 @@ end ---@param tokens number Tokens used ---@param duration number Duration in seconds function M.explore_done(tool_uses, tokens, duration) - M.log("result", string.format(" ⎿ Done (%d tool uses · %.1fk tokens · %.1fs)", tool_uses, tokens / 1000, duration)) + M.log( + "result", + string.format(" ⎿ Done (%d tool uses · %.1fk tokens · %.1fs)", tool_uses, tokens / 1000, duration) + ) end --- Log update/edit operation diff --git a/lua/codetyper/adapters/nvim/ui/logs_panel.lua b/lua/codetyper/adapters/nvim/ui/logs_panel.lua index 5218bab..4678b0e 100644 --- a/lua/codetyper/adapters/nvim/ui/logs_panel.lua +++ b/lua/codetyper/adapters/nvim/ui/logs_panel.lua @@ -15,13 +15,13 @@ local queue = require("codetyper.core.events.queue") ---@field queue_listener_id number|nil Listener ID for queue local state = { - buf = nil, - win = nil, - queue_buf = nil, - queue_win = nil, - is_open = false, - listener_id = nil, - queue_listener_id = nil, + buf = nil, + win = nil, + queue_buf = nil, + queue_win = nil, + is_open = false, + listener_id = nil, + queue_listener_id = nil, } --- Namespace for highlights @@ -35,346 +35,346 @@ local QUEUE_HEIGHT = 8 --- Add a log entry to the buffer ---@param entry table Log entry local function add_log_entry(entry) - if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then - return - end + if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then + return + end - vim.schedule(function() - if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then - return - end + vim.schedule(function() + if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then + return + end - -- Handle clear event - if entry.level == "clear" then - vim.bo[state.buf].modifiable = true - vim.api.nvim_buf_set_lines(state.buf, 0, -1, false, { - "Generation Logs", - string.rep("─", LOGS_WIDTH - 2), - "", - }) - vim.bo[state.buf].modifiable = false - return - end + -- Handle clear event + if entry.level == "clear" then + vim.bo[state.buf].modifiable = true + vim.api.nvim_buf_set_lines(state.buf, 0, -1, false, { + "Generation Logs", + string.rep("─", LOGS_WIDTH - 2), + "", + }) + vim.bo[state.buf].modifiable = false + return + end - vim.bo[state.buf].modifiable = true + vim.bo[state.buf].modifiable = true - local formatted = logs.format_entry(entry) - local formatted_lines = vim.split(formatted, "\n", { plain = true }) - local line_count = vim.api.nvim_buf_line_count(state.buf) + local formatted = logs.format_entry(entry) + local formatted_lines = vim.split(formatted, "\n", { plain = true }) + local line_count = vim.api.nvim_buf_line_count(state.buf) - vim.api.nvim_buf_set_lines(state.buf, -1, -1, false, formatted_lines) + vim.api.nvim_buf_set_lines(state.buf, -1, -1, false, formatted_lines) - -- Apply highlighting based on level - local hl_map = { - info = "DiagnosticInfo", - debug = "Comment", - request = "DiagnosticWarn", - response = "DiagnosticOk", - tool = "DiagnosticHint", - error = "DiagnosticError", - } + -- Apply highlighting based on level + local hl_map = { + info = "DiagnosticInfo", + debug = "Comment", + request = "DiagnosticWarn", + response = "DiagnosticOk", + tool = "DiagnosticHint", + error = "DiagnosticError", + } - local hl = hl_map[entry.level] or "Normal" - for i = 0, #formatted_lines - 1 do - vim.api.nvim_buf_add_highlight(state.buf, ns_logs, hl, line_count + i, 0, -1) - end + local hl = hl_map[entry.level] or "Normal" + for i = 0, #formatted_lines - 1 do + vim.api.nvim_buf_add_highlight(state.buf, ns_logs, hl, line_count + i, 0, -1) + end - vim.bo[state.buf].modifiable = false + vim.bo[state.buf].modifiable = false - -- Auto-scroll logs - if state.win and vim.api.nvim_win_is_valid(state.win) then - local new_count = vim.api.nvim_buf_line_count(state.buf) - pcall(vim.api.nvim_win_set_cursor, state.win, { new_count, 0 }) - end - end) + -- Auto-scroll logs + if state.win and vim.api.nvim_win_is_valid(state.win) then + local new_count = vim.api.nvim_buf_line_count(state.buf) + pcall(vim.api.nvim_win_set_cursor, state.win, { new_count, 0 }) + end + end) end --- Update the title with token counts local function update_title() - if not state.win or not vim.api.nvim_win_is_valid(state.win) then - return - end + if not state.win or not vim.api.nvim_win_is_valid(state.win) then + return + end - local prompt_tokens, response_tokens = logs.get_token_totals() - local provider, model = logs.get_provider_info() + local prompt_tokens, response_tokens = logs.get_token_totals() + local provider, model = logs.get_provider_info() - if provider and state.buf and vim.api.nvim_buf_is_valid(state.buf) then - vim.bo[state.buf].modifiable = true - local title = string.format("%s | %d/%d tokens", (provider or ""):upper(), prompt_tokens, response_tokens) - vim.api.nvim_buf_set_lines(state.buf, 0, 1, false, { title }) - vim.bo[state.buf].modifiable = false - end + if provider and state.buf and vim.api.nvim_buf_is_valid(state.buf) then + vim.bo[state.buf].modifiable = true + local title = string.format("%s | %d/%d tokens", (provider or ""):upper(), prompt_tokens, response_tokens) + vim.api.nvim_buf_set_lines(state.buf, 0, 1, false, { title }) + vim.bo[state.buf].modifiable = false + end end --- Update the queue display local function update_queue_display() - if not state.queue_buf or not vim.api.nvim_buf_is_valid(state.queue_buf) then - return - end + if not state.queue_buf or not vim.api.nvim_buf_is_valid(state.queue_buf) then + return + end - vim.schedule(function() - if not state.queue_buf or not vim.api.nvim_buf_is_valid(state.queue_buf) then - return - end + vim.schedule(function() + if not state.queue_buf or not vim.api.nvim_buf_is_valid(state.queue_buf) then + return + end - vim.bo[state.queue_buf].modifiable = true + vim.bo[state.queue_buf].modifiable = true - local lines = { - "Queue", - string.rep("─", LOGS_WIDTH - 2), - } + local lines = { + "Queue", + string.rep("─", LOGS_WIDTH - 2), + } - -- Get all events (pending and processing) - local pending = queue.get_pending() - local processing = queue.get_processing() + -- Get all events (pending and processing) + local pending = queue.get_pending() + local processing = queue.get_processing() - -- Add processing events first - for _, event in ipairs(processing) do - local filename = vim.fn.fnamemodify(event.target_path or "", ":t") - local line_num = event.range and event.range.start_line or 0 - local prompt_preview = (event.prompt_content or ""):sub(1, 25):gsub("\n", " ") - if #(event.prompt_content or "") > 25 then - prompt_preview = prompt_preview .. "..." - end - table.insert(lines, string.format("▶ %s:%d %s", filename, line_num, prompt_preview)) - end + -- Add processing events first + for _, event in ipairs(processing) do + local filename = vim.fn.fnamemodify(event.target_path or "", ":t") + local line_num = event.range and event.range.start_line or 0 + local prompt_preview = (event.prompt_content or ""):sub(1, 25):gsub("\n", " ") + if #(event.prompt_content or "") > 25 then + prompt_preview = prompt_preview .. "..." + end + table.insert(lines, string.format("▶ %s:%d %s", filename, line_num, prompt_preview)) + end - -- Add pending events - for _, event in ipairs(pending) do - local filename = vim.fn.fnamemodify(event.target_path or "", ":t") - local line_num = event.range and event.range.start_line or 0 - local prompt_preview = (event.prompt_content or ""):sub(1, 25):gsub("\n", " ") - if #(event.prompt_content or "") > 25 then - prompt_preview = prompt_preview .. "..." - end - table.insert(lines, string.format("○ %s:%d %s", filename, line_num, prompt_preview)) - end + -- Add pending events + for _, event in ipairs(pending) do + local filename = vim.fn.fnamemodify(event.target_path or "", ":t") + local line_num = event.range and event.range.start_line or 0 + local prompt_preview = (event.prompt_content or ""):sub(1, 25):gsub("\n", " ") + if #(event.prompt_content or "") > 25 then + prompt_preview = prompt_preview .. "..." + end + table.insert(lines, string.format("○ %s:%d %s", filename, line_num, prompt_preview)) + end - if #pending == 0 and #processing == 0 then - table.insert(lines, " (empty)") - end + if #pending == 0 and #processing == 0 then + table.insert(lines, " (empty)") + end - vim.api.nvim_buf_set_lines(state.queue_buf, 0, -1, false, lines) + vim.api.nvim_buf_set_lines(state.queue_buf, 0, -1, false, lines) - -- Apply highlights - vim.api.nvim_buf_clear_namespace(state.queue_buf, ns_queue, 0, -1) - vim.api.nvim_buf_add_highlight(state.queue_buf, ns_queue, "Title", 0, 0, -1) - vim.api.nvim_buf_add_highlight(state.queue_buf, ns_queue, "Comment", 1, 0, -1) + -- Apply highlights + vim.api.nvim_buf_clear_namespace(state.queue_buf, ns_queue, 0, -1) + vim.api.nvim_buf_add_highlight(state.queue_buf, ns_queue, "Title", 0, 0, -1) + vim.api.nvim_buf_add_highlight(state.queue_buf, ns_queue, "Comment", 1, 0, -1) - local line_idx = 2 - for _ = 1, #processing do - vim.api.nvim_buf_add_highlight(state.queue_buf, ns_queue, "DiagnosticWarn", line_idx, 0, 1) - vim.api.nvim_buf_add_highlight(state.queue_buf, ns_queue, "String", line_idx, 2, -1) - line_idx = line_idx + 1 - end - for _ = 1, #pending do - vim.api.nvim_buf_add_highlight(state.queue_buf, ns_queue, "Comment", line_idx, 0, 1) - vim.api.nvim_buf_add_highlight(state.queue_buf, ns_queue, "Normal", line_idx, 2, -1) - line_idx = line_idx + 1 - end + local line_idx = 2 + for _ = 1, #processing do + vim.api.nvim_buf_add_highlight(state.queue_buf, ns_queue, "DiagnosticWarn", line_idx, 0, 1) + vim.api.nvim_buf_add_highlight(state.queue_buf, ns_queue, "String", line_idx, 2, -1) + line_idx = line_idx + 1 + end + for _ = 1, #pending do + vim.api.nvim_buf_add_highlight(state.queue_buf, ns_queue, "Comment", line_idx, 0, 1) + vim.api.nvim_buf_add_highlight(state.queue_buf, ns_queue, "Normal", line_idx, 2, -1) + line_idx = line_idx + 1 + end - vim.bo[state.queue_buf].modifiable = false - end) + vim.bo[state.queue_buf].modifiable = false + end) end --- Open the logs panel function M.open() - if state.is_open then - return - end + if state.is_open then + return + end - -- Clear previous logs - logs.clear() + -- Clear previous logs + logs.clear() - -- Create logs buffer - state.buf = vim.api.nvim_create_buf(false, true) - vim.bo[state.buf].buftype = "nofile" - vim.bo[state.buf].bufhidden = "hide" - vim.bo[state.buf].swapfile = false + -- Create logs buffer + state.buf = vim.api.nvim_create_buf(false, true) + vim.bo[state.buf].buftype = "nofile" + vim.bo[state.buf].bufhidden = "hide" + vim.bo[state.buf].swapfile = false - -- Create window on the right - vim.cmd("botright vsplit") - state.win = vim.api.nvim_get_current_win() - vim.api.nvim_win_set_buf(state.win, state.buf) - vim.api.nvim_win_set_width(state.win, LOGS_WIDTH) + -- Create window on the right + vim.cmd("botright vsplit") + state.win = vim.api.nvim_get_current_win() + vim.api.nvim_win_set_buf(state.win, state.buf) + vim.api.nvim_win_set_width(state.win, LOGS_WIDTH) - -- Window options for logs - vim.wo[state.win].number = false - vim.wo[state.win].relativenumber = false - vim.wo[state.win].signcolumn = "no" - vim.wo[state.win].wrap = true - vim.wo[state.win].linebreak = true - vim.wo[state.win].winfixwidth = true - vim.wo[state.win].cursorline = false + -- Window options for logs + vim.wo[state.win].number = false + vim.wo[state.win].relativenumber = false + vim.wo[state.win].signcolumn = "no" + vim.wo[state.win].wrap = true + vim.wo[state.win].linebreak = true + vim.wo[state.win].winfixwidth = true + vim.wo[state.win].cursorline = false - -- Set initial content for logs - vim.bo[state.buf].modifiable = true - vim.api.nvim_buf_set_lines(state.buf, 0, -1, false, { - "Generation Logs", - string.rep("─", LOGS_WIDTH - 2), - "", - }) - vim.bo[state.buf].modifiable = false + -- Set initial content for logs + vim.bo[state.buf].modifiable = true + vim.api.nvim_buf_set_lines(state.buf, 0, -1, false, { + "Generation Logs", + string.rep("─", LOGS_WIDTH - 2), + "", + }) + vim.bo[state.buf].modifiable = false - -- Create queue buffer - state.queue_buf = vim.api.nvim_create_buf(false, true) - vim.bo[state.queue_buf].buftype = "nofile" - vim.bo[state.queue_buf].bufhidden = "hide" - vim.bo[state.queue_buf].swapfile = false + -- Create queue buffer + state.queue_buf = vim.api.nvim_create_buf(false, true) + vim.bo[state.queue_buf].buftype = "nofile" + vim.bo[state.queue_buf].bufhidden = "hide" + vim.bo[state.queue_buf].swapfile = false - -- Create queue window as horizontal split at bottom of logs window - vim.cmd("belowright split") - state.queue_win = vim.api.nvim_get_current_win() - vim.api.nvim_win_set_buf(state.queue_win, state.queue_buf) - vim.api.nvim_win_set_height(state.queue_win, QUEUE_HEIGHT) + -- Create queue window as horizontal split at bottom of logs window + vim.cmd("belowright split") + state.queue_win = vim.api.nvim_get_current_win() + vim.api.nvim_win_set_buf(state.queue_win, state.queue_buf) + vim.api.nvim_win_set_height(state.queue_win, QUEUE_HEIGHT) - -- Window options for queue - vim.wo[state.queue_win].number = false - vim.wo[state.queue_win].relativenumber = false - vim.wo[state.queue_win].signcolumn = "no" - vim.wo[state.queue_win].wrap = true - vim.wo[state.queue_win].linebreak = true - vim.wo[state.queue_win].winfixheight = true - vim.wo[state.queue_win].cursorline = false + -- Window options for queue + vim.wo[state.queue_win].number = false + vim.wo[state.queue_win].relativenumber = false + vim.wo[state.queue_win].signcolumn = "no" + vim.wo[state.queue_win].wrap = true + vim.wo[state.queue_win].linebreak = true + vim.wo[state.queue_win].winfixheight = true + vim.wo[state.queue_win].cursorline = false - -- Setup keymaps for logs buffer - local opts = { buffer = state.buf, noremap = true, silent = true } - vim.keymap.set("n", "q", M.close, opts) - vim.keymap.set("n", "", M.close, opts) + -- Setup keymaps for logs buffer + local opts = { buffer = state.buf, noremap = true, silent = true } + vim.keymap.set("n", "q", M.close, opts) + vim.keymap.set("n", "", M.close, opts) - -- Setup keymaps for queue buffer - local queue_opts = { buffer = state.queue_buf, noremap = true, silent = true } - vim.keymap.set("n", "q", M.close, queue_opts) - vim.keymap.set("n", "", M.close, queue_opts) + -- Setup keymaps for queue buffer + local queue_opts = { buffer = state.queue_buf, noremap = true, silent = true } + vim.keymap.set("n", "q", M.close, queue_opts) + vim.keymap.set("n", "", M.close, queue_opts) - -- Register log listener - state.listener_id = logs.add_listener(function(entry) - add_log_entry(entry) - if entry.level == "response" then - vim.schedule(update_title) - end - end) + -- Register log listener + state.listener_id = logs.add_listener(function(entry) + add_log_entry(entry) + if entry.level == "response" then + vim.schedule(update_title) + end + end) - -- Register queue listener - state.queue_listener_id = queue.add_listener(function() - update_queue_display() - end) + -- Register queue listener + state.queue_listener_id = queue.add_listener(function() + update_queue_display() + end) - -- Initial queue display - update_queue_display() + -- Initial queue display + update_queue_display() - state.is_open = true + state.is_open = true - -- Return focus to previous window - vim.cmd("wincmd p") + -- Return focus to previous window + vim.cmd("wincmd p") - logs.info("Logs panel opened") + logs.info("Logs panel opened") end --- Close the logs panel ---@param force? boolean Force close even if not marked as open function M.close(force) - if not state.is_open and not force then - return - end + if not state.is_open and not force then + return + end - -- Remove log listener - if state.listener_id then - pcall(logs.remove_listener, state.listener_id) - state.listener_id = nil - end + -- Remove log listener + if state.listener_id then + pcall(logs.remove_listener, state.listener_id) + state.listener_id = nil + end - -- Remove queue listener - if state.queue_listener_id then - pcall(queue.remove_listener, state.queue_listener_id) - state.queue_listener_id = nil - end + -- Remove queue listener + if state.queue_listener_id then + pcall(queue.remove_listener, state.queue_listener_id) + state.queue_listener_id = nil + end - -- Close queue window first - if state.queue_win then - pcall(vim.api.nvim_win_close, state.queue_win, true) - state.queue_win = nil - end + -- Close queue window first + if state.queue_win then + pcall(vim.api.nvim_win_close, state.queue_win, true) + state.queue_win = nil + end - -- Close logs window - if state.win then - pcall(vim.api.nvim_win_close, state.win, true) - state.win = nil - end + -- Close logs window + if state.win then + pcall(vim.api.nvim_win_close, state.win, true) + state.win = nil + end - -- Delete queue buffer - if state.queue_buf then - pcall(vim.api.nvim_buf_delete, state.queue_buf, { force = true }) - state.queue_buf = nil - end + -- Delete queue buffer + if state.queue_buf then + pcall(vim.api.nvim_buf_delete, state.queue_buf, { force = true }) + state.queue_buf = nil + end - -- Delete logs buffer - if state.buf then - pcall(vim.api.nvim_buf_delete, state.buf, { force = true }) - state.buf = nil - end + -- Delete logs buffer + if state.buf then + pcall(vim.api.nvim_buf_delete, state.buf, { force = true }) + state.buf = nil + end - state.is_open = false + state.is_open = false end --- Toggle the logs panel function M.toggle() - if state.is_open then - M.close() - else - M.open() - end + if state.is_open then + M.close() + else + M.open() + end end --- Check if panel is open ---@return boolean function M.is_open() - return state.is_open + return state.is_open end --- Ensure panel is open (call before starting generation) function M.ensure_open() - if not state.is_open then - M.open() - end + if not state.is_open then + M.open() + end end --- Setup autocmds for the logs panel function M.setup() - local group = vim.api.nvim_create_augroup("CodetypeLogsPanel", { clear = true }) + local group = vim.api.nvim_create_augroup("CodetypeLogsPanel", { clear = true }) - -- Close logs panel when exiting Neovim - vim.api.nvim_create_autocmd("VimLeavePre", { - group = group, - callback = function() - -- Force close to ensure cleanup even in edge cases - M.close(true) - end, - desc = "Close logs panel before exiting Neovim", - }) + -- Close logs panel when exiting Neovim + vim.api.nvim_create_autocmd("VimLeavePre", { + group = group, + callback = function() + -- Force close to ensure cleanup even in edge cases + M.close(true) + end, + desc = "Close logs panel before exiting Neovim", + }) - -- Also clean up when QuitPre fires (handles :qa, :wqa, etc.) - vim.api.nvim_create_autocmd("QuitPre", { - group = group, - callback = function() - -- Check if this is the last window (about to quit Neovim) - local wins = vim.api.nvim_list_wins() - local real_wins = 0 - for _, win in ipairs(wins) do - local buf = vim.api.nvim_win_get_buf(win) - local buftype = vim.bo[buf].buftype - -- Count non-special windows - if buftype == "" or buftype == "help" then - real_wins = real_wins + 1 - end - end - -- If only logs/queue windows remain, close them - if real_wins <= 1 then - M.close(true) - end - end, - desc = "Close logs panel on quit", - }) + -- Also clean up when QuitPre fires (handles :qa, :wqa, etc.) + vim.api.nvim_create_autocmd("QuitPre", { + group = group, + callback = function() + -- Check if this is the last window (about to quit Neovim) + local wins = vim.api.nvim_list_wins() + local real_wins = 0 + for _, win in ipairs(wins) do + local buf = vim.api.nvim_win_get_buf(win) + local buftype = vim.bo[buf].buftype + -- Count non-special windows + if buftype == "" or buftype == "help" then + real_wins = real_wins + 1 + end + end + -- If only logs/queue windows remain, close them + if real_wins <= 1 then + M.close(true) + end + end, + desc = "Close logs panel on quit", + }) end return M diff --git a/lua/codetyper/adapters/nvim/ui/thinking.lua b/lua/codetyper/adapters/nvim/ui/thinking.lua index 53a8d67..1bb4aa8 100644 --- a/lua/codetyper/adapters/nvim/ui/thinking.lua +++ b/lua/codetyper/adapters/nvim/ui/thinking.lua @@ -17,164 +17,163 @@ local queue = require("codetyper.core.events.queue") ---@field timer number|nil Defer timer for polling local state = { - win_id = nil, - buf_id = nil, - throbber = nil, - queue_listener_id = nil, - timer = nil, - stage_text = "Thinking...", + win_id = nil, + buf_id = nil, + throbber = nil, + queue_listener_id = nil, + timer = nil, + stage_text = "Thinking...", } local function get_ui_dimensions() - local ui = vim.api.nvim_list_uis()[1] - if ui then - return ui.width, ui.height - end - return vim.o.columns, vim.o.lines + local ui = vim.api.nvim_list_uis()[1] + if ui then + return ui.width, ui.height + end + return vim.o.columns, vim.o.lines end --- Top-right status window config (like 99) local function status_window_config() - local width, _ = get_ui_dimensions() - local win_width = math.min(40, math.floor(width / 3)) - return { - relative = "editor", - row = 0, - col = width, - width = win_width, - height = 2, - anchor = "NE", - style = "minimal", - border = nil, - zindex = 100, - } + local width, _ = get_ui_dimensions() + local win_width = math.min(40, math.floor(width / 3)) + return { + relative = "editor", + row = 0, + col = width, + width = win_width, + height = 2, + anchor = "NE", + style = "minimal", + border = nil, + zindex = 100, + } end local function active_count() - return queue.pending_count() + queue.processing_count() + return queue.pending_count() + queue.processing_count() end local function close_window() - if state.timer then - pcall(vim.fn.timer_stop, state.timer) - state.timer = nil - end - if state.throbber then - state.throbber:stop() - state.throbber = nil - end - if state.queue_listener_id then - queue.remove_listener(state.queue_listener_id) - state.queue_listener_id = nil - end - if state.win_id and vim.api.nvim_win_is_valid(state.win_id) then - vim.api.nvim_win_close(state.win_id, true) - end - if state.buf_id and vim.api.nvim_buf_is_valid(state.buf_id) then - vim.api.nvim_buf_delete(state.buf_id, { force = true }) - end - state.win_id = nil - state.buf_id = nil + if state.timer then + pcall(vim.fn.timer_stop, state.timer) + state.timer = nil + end + if state.throbber then + state.throbber:stop() + state.throbber = nil + end + if state.queue_listener_id then + queue.remove_listener(state.queue_listener_id) + state.queue_listener_id = nil + end + if state.win_id and vim.api.nvim_win_is_valid(state.win_id) then + vim.api.nvim_win_close(state.win_id, true) + end + if state.buf_id and vim.api.nvim_buf_is_valid(state.buf_id) then + vim.api.nvim_buf_delete(state.buf_id, { force = true }) + end + state.win_id = nil + state.buf_id = nil end local function update_display(icon, force) - if not state.buf_id or not vim.api.nvim_buf_is_valid(state.buf_id) then - return - end - local count = active_count() - if count <= 0 and not force then - return - end - local text = state.stage_text or "Thinking..." - local line = (count <= 1) - and (icon .. " " .. text) - or (icon .. " " .. text .. " (" .. tostring(count) .. " requests)") - vim.schedule(function() - if state.buf_id and vim.api.nvim_buf_is_valid(state.buf_id) then - vim.bo[state.buf_id].modifiable = true - vim.api.nvim_buf_set_lines(state.buf_id, 0, -1, false, { line }) - vim.bo[state.buf_id].modifiable = false - end - end) + if not state.buf_id or not vim.api.nvim_buf_is_valid(state.buf_id) then + return + end + local count = active_count() + if count <= 0 and not force then + return + end + local text = state.stage_text or "Thinking..." + local line = (count <= 1) and (icon .. " " .. text) + or (icon .. " " .. text .. " (" .. tostring(count) .. " requests)") + vim.schedule(function() + if state.buf_id and vim.api.nvim_buf_is_valid(state.buf_id) then + vim.bo[state.buf_id].modifiable = true + vim.api.nvim_buf_set_lines(state.buf_id, 0, -1, false, { line }) + vim.bo[state.buf_id].modifiable = false + end + end) end local function check_and_hide() - if active_count() > 0 then - return - end - close_window() + if active_count() > 0 then + return + end + close_window() end --- Ensure the thinking status window is shown and throbber is running. --- Call when starting prompt processing (instead of logs_panel.ensure_open). function M.ensure_shown() - if state.win_id and vim.api.nvim_win_is_valid(state.win_id) then - -- Already shown; throbber keeps running - return - end + if state.win_id and vim.api.nvim_win_is_valid(state.win_id) then + -- Already shown; throbber keeps running + return + end - state.buf_id = vim.api.nvim_create_buf(false, true) - vim.bo[state.buf_id].buftype = "nofile" - vim.bo[state.buf_id].bufhidden = "wipe" - vim.bo[state.buf_id].swapfile = false + state.buf_id = vim.api.nvim_create_buf(false, true) + vim.bo[state.buf_id].buftype = "nofile" + vim.bo[state.buf_id].bufhidden = "wipe" + vim.bo[state.buf_id].swapfile = false - local config = status_window_config() - state.win_id = vim.api.nvim_open_win(state.buf_id, false, config) - vim.wo[state.win_id].wrap = true - vim.wo[state.win_id].number = false - vim.wo[state.win_id].relativenumber = false + local config = status_window_config() + state.win_id = vim.api.nvim_open_win(state.buf_id, false, config) + vim.wo[state.win_id].wrap = true + vim.wo[state.win_id].number = false + vim.wo[state.win_id].relativenumber = false - state.throbber = throbber.new(function(icon) - update_display(icon) - -- When active count drops to 0, hide after a short delay - if active_count() <= 0 then - vim.defer_fn(check_and_hide, 300) - end - end) - state.throbber:start() + state.throbber = throbber.new(function(icon) + update_display(icon) + -- When active count drops to 0, hide after a short delay + if active_count() <= 0 then + vim.defer_fn(check_and_hide, 300) + end + end) + state.throbber:start() - -- Queue listener: when queue updates, check if we should hide - state.queue_listener_id = queue.add_listener(function(_, _, _) - vim.schedule(function() - if active_count() <= 0 then - vim.defer_fn(check_and_hide, 400) - end - end) - end) + -- Queue listener: when queue updates, check if we should hide + state.queue_listener_id = queue.add_listener(function(_, _, _) + vim.schedule(function() + if active_count() <= 0 then + vim.defer_fn(check_and_hide, 400) + end + end) + end) - -- Initial line (force show before enqueue so window is not empty) - local icon = (state.throbber and state.throbber.icon_set and state.throbber.icon_set[1]) or "⠋" - update_display(icon, true) + -- Initial line (force show before enqueue so window is not empty) + local icon = (state.throbber and state.throbber.icon_set and state.throbber.icon_set[1]) or "⠋" + update_display(icon, true) end --- Update the displayed stage text (e.g. "Reading context...", "Sending to LLM..."). ---@param text string function M.update_stage(text) - state.stage_text = text + state.stage_text = text end --- Force close the thinking window (e.g. on VimLeavePre). function M.close() - state.stage_text = "Thinking..." - close_window() + state.stage_text = "Thinking..." + close_window() end --- Check if thinking window is currently visible. ---@return boolean function M.is_shown() - return state.win_id ~= nil and vim.api.nvim_win_is_valid(state.win_id) + return state.win_id ~= nil and vim.api.nvim_win_is_valid(state.win_id) end --- Register autocmds for cleanup on exit. function M.setup() - local group = vim.api.nvim_create_augroup("CodetyperThinking", { clear = true }) - vim.api.nvim_create_autocmd("VimLeavePre", { - group = group, - callback = function() - M.close() - end, - desc = "Close thinking window before exiting Neovim", - }) + local group = vim.api.nvim_create_augroup("CodetyperThinking", { clear = true }) + vim.api.nvim_create_autocmd("VimLeavePre", { + group = group, + callback = function() + M.close() + end, + desc = "Close thinking window before exiting Neovim", + }) end return M diff --git a/lua/codetyper/adapters/nvim/ui/throbber.lua b/lua/codetyper/adapters/nvim/ui/throbber.lua index 929c447..7b53e5d 100644 --- a/lua/codetyper/adapters/nvim/ui/throbber.lua +++ b/lua/codetyper/adapters/nvim/ui/throbber.lua @@ -6,11 +6,11 @@ local M = {} local throb_icons = { - { "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏" }, - { "◐", "◓", "◑", "◒" }, - { "⣾", "⣽", "⣻", "⢿", "⡿", "⣟", "⣯", "⣷" }, - { "◰", "◳", "◲", "◱" }, - { "◜", "◠", "◝", "◞", "◡", "◟" }, + { "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏" }, + { "◐", "◓", "◑", "◒" }, + { "⣾", "⣽", "⣻", "⢿", "⡿", "⣟", "⣯", "⣷" }, + { "◰", "◳", "◲", "◱" }, + { "◜", "◠", "◝", "◞", "◡", "◟" }, } local throb_time = 1200 @@ -18,7 +18,7 @@ local cooldown_time = 100 local tick_time = 100 local function now() - return vim.uv and vim.uv.now() or (os.clock() * 1000) + return vim.uv and vim.uv.now() or (os.clock() * 1000) end ---@class Throbber @@ -37,51 +37,51 @@ Throbber.__index = Throbber ---@param opts? { throb_time?: number, cooldown_time?: number } ---@return Throbber function M.new(cb, opts) - opts = opts or {} - local throb_time_ms = opts.throb_time or throb_time - local cooldown_ms = opts.cooldown_time or cooldown_time - local icon_set = throb_icons[math.random(#throb_icons)] - return setmetatable({ - state = "init", - start_time = 0, - section_time = throb_time_ms, - opts = { throb_time = throb_time_ms, cooldown_time = cooldown_ms }, - cb = cb, - icon_set = icon_set, - }, Throbber) + opts = opts or {} + local throb_time_ms = opts.throb_time or throb_time + local cooldown_ms = opts.cooldown_time or cooldown_time + local icon_set = throb_icons[math.random(#throb_icons)] + return setmetatable({ + state = "init", + start_time = 0, + section_time = throb_time_ms, + opts = { throb_time = throb_time_ms, cooldown_time = cooldown_ms }, + cb = cb, + icon_set = icon_set, + }, Throbber) end function Throbber:_run() - if self.state ~= "throbbing" and self.state ~= "cooldown" then - return - end - local elapsed = now() - self.start_time - local percent = math.min(1, elapsed / self.section_time) - local idx = math.floor(percent * #self.icon_set) + 1 - idx = math.min(idx, #self.icon_set) - local icon = self.icon_set[idx] + if self.state ~= "throbbing" and self.state ~= "cooldown" then + return + end + local elapsed = now() - self.start_time + local percent = math.min(1, elapsed / self.section_time) + local idx = math.floor(percent * #self.icon_set) + 1 + idx = math.min(idx, #self.icon_set) + local icon = self.icon_set[idx] - if percent >= 1 then - self.state = self.state == "cooldown" and "throbbing" or "cooldown" - self.start_time = now() - self.section_time = (self.state == "cooldown") and self.opts.cooldown_time or self.opts.throb_time - end + if percent >= 1 then + self.state = self.state == "cooldown" and "throbbing" or "cooldown" + self.start_time = now() + self.section_time = (self.state == "cooldown") and self.opts.cooldown_time or self.opts.throb_time + end - self.cb(icon) - vim.defer_fn(function() - self:_run() - end, tick_time) + self.cb(icon) + vim.defer_fn(function() + self:_run() + end, tick_time) end function Throbber:start() - self.start_time = now() - self.section_time = self.opts.throb_time - self.state = "throbbing" - self:_run() + self.start_time = now() + self.section_time = self.opts.throb_time + self.state = "throbbing" + self:_run() end function Throbber:stop() - self.state = "stopped" + self.state = "stopped" end return M diff --git a/lua/codetyper/config/credentials.lua b/lua/codetyper/config/credentials.lua index 07b9848..10e22c2 100644 --- a/lua/codetyper/config/credentials.lua +++ b/lua/codetyper/config/credentials.lua @@ -11,112 +11,112 @@ local utils = require("codetyper.support.utils") --- Get the credentials file path ---@return string Path to credentials file local function get_credentials_path() - local data_dir = vim.fn.stdpath("data") - return data_dir .. "/codetyper/configuration.json" + local data_dir = vim.fn.stdpath("data") + return data_dir .. "/codetyper/configuration.json" end --- Ensure the credentials directory exists ---@return boolean Success local function ensure_dir() - local data_dir = vim.fn.stdpath("data") - local codetyper_dir = data_dir .. "/codetyper" - return utils.ensure_dir(codetyper_dir) + local data_dir = vim.fn.stdpath("data") + local codetyper_dir = data_dir .. "/codetyper" + return utils.ensure_dir(codetyper_dir) end --- Load credentials from file ---@return table Credentials data function M.load() - local path = get_credentials_path() - local content = utils.read_file(path) + local path = get_credentials_path() + local content = utils.read_file(path) - if not content or content == "" then - return { - version = 1, - providers = {}, - } - end + if not content or content == "" then + return { + version = 1, + providers = {}, + } + end - local ok, data = pcall(vim.json.decode, content) - if not ok or not data then - return { - version = 1, - providers = {}, - } - end + local ok, data = pcall(vim.json.decode, content) + if not ok or not data then + return { + version = 1, + providers = {}, + } + end - return data + return data end --- Save credentials to file ---@param data table Credentials data ---@return boolean Success function M.save(data) - if not ensure_dir() then - return false - end + if not ensure_dir() then + return false + end - local path = get_credentials_path() - local ok, json = pcall(vim.json.encode, data) - if not ok then - return false - end + local path = get_credentials_path() + local ok, json = pcall(vim.json.encode, data) + if not ok then + return false + end - return utils.write_file(path, json) + return utils.write_file(path, json) end --- Get API key for a provider ---@param provider string Provider name (copilot, ollama) ---@return string|nil API key or nil if not found function M.get_api_key(provider) - local data = M.load() - local provider_data = data.providers and data.providers[provider] + local data = M.load() + local provider_data = data.providers and data.providers[provider] - if provider_data and provider_data.api_key then - return provider_data.api_key - end + if provider_data and provider_data.api_key then + return provider_data.api_key + end - return nil + return nil end --- Get model for a provider ---@param provider string Provider name ---@return string|nil Model name or nil if not found function M.get_model(provider) - local data = M.load() - local provider_data = data.providers and data.providers[provider] + local data = M.load() + local provider_data = data.providers and data.providers[provider] - if provider_data and provider_data.model then - return provider_data.model - end + if provider_data and provider_data.model then + return provider_data.model + end - return nil + return nil end --- Get endpoint for a provider (for custom OpenAI-compatible endpoints) ---@param provider string Provider name ---@return string|nil Endpoint URL or nil if not found function M.get_endpoint(provider) - local data = M.load() - local provider_data = data.providers and data.providers[provider] + local data = M.load() + local provider_data = data.providers and data.providers[provider] - if provider_data and provider_data.endpoint then - return provider_data.endpoint - end + if provider_data and provider_data.endpoint then + return provider_data.endpoint + end - return nil + return nil end --- Get host for Ollama ---@return string|nil Host URL or nil if not found function M.get_ollama_host() - local data = M.load() - local provider_data = data.providers and data.providers.ollama + local data = M.load() + local provider_data = data.providers and data.providers.ollama - if provider_data and provider_data.host then - return provider_data.host - end + if provider_data and provider_data.host then + return provider_data.host + end - return nil + return nil end --- Set credentials for a provider @@ -124,452 +124,452 @@ end ---@param credentials table Credentials (api_key, model, endpoint, host) ---@return boolean Success function M.set_credentials(provider, credentials) - local data = M.load() + local data = M.load() - if not data.providers then - data.providers = {} - end + if not data.providers then + data.providers = {} + end - if not data.providers[provider] then - data.providers[provider] = {} - end + if not data.providers[provider] then + data.providers[provider] = {} + end - -- Merge credentials - for key, value in pairs(credentials) do - if value and value ~= "" then - data.providers[provider][key] = value - end - end + -- Merge credentials + for key, value in pairs(credentials) do + if value and value ~= "" then + data.providers[provider][key] = value + end + end - data.updated = os.time() + data.updated = os.time() - return M.save(data) + return M.save(data) end --- Remove credentials for a provider ---@param provider string Provider name ---@return boolean Success function M.remove_credentials(provider) - local data = M.load() + local data = M.load() - if data.providers and data.providers[provider] then - data.providers[provider] = nil - data.updated = os.time() - return M.save(data) - end + if data.providers and data.providers[provider] then + data.providers[provider] = nil + data.updated = os.time() + return M.save(data) + end - return true + return true end --- List all configured providers (checks both stored credentials AND config) ---@return table List of provider names with their config status function M.list_providers() - local data = M.load() - local result = {} + local data = M.load() + local result = {} - local all_providers = { "copilot", "ollama" } + local all_providers = { "copilot", "ollama" } - for _, provider in ipairs(all_providers) do - local provider_data = data.providers and data.providers[provider] - local has_stored_key = provider_data and provider_data.api_key and provider_data.api_key ~= "" - local has_model = provider_data and provider_data.model and provider_data.model ~= "" + for _, provider in ipairs(all_providers) do + local provider_data = data.providers and data.providers[provider] + local has_stored_key = provider_data and provider_data.api_key and provider_data.api_key ~= "" + local has_model = provider_data and provider_data.model and provider_data.model ~= "" - local configured_from_config = false - local config_model = nil - local ok, codetyper = pcall(require, "codetyper") - if ok then - local config = codetyper.get_config() - if config and config.llm and config.llm[provider] then - local pc = config.llm[provider] - config_model = pc.model + local configured_from_config = false + local config_model = nil + local ok, codetyper = pcall(require, "codetyper") + if ok then + local config = codetyper.get_config() + if config and config.llm and config.llm[provider] then + local pc = config.llm[provider] + config_model = pc.model - if provider == "copilot" then - configured_from_config = true - elseif provider == "ollama" then - configured_from_config = pc.host ~= nil - end - end - end + if provider == "copilot" then + configured_from_config = true + elseif provider == "ollama" then + configured_from_config = pc.host ~= nil + end + end + end - local is_configured = has_stored_key - or (provider == "ollama" and provider_data ~= nil) - or (provider == "copilot" and (provider_data ~= nil or configured_from_config)) - or configured_from_config + local is_configured = has_stored_key + or (provider == "ollama" and provider_data ~= nil) + or (provider == "copilot" and (provider_data ~= nil or configured_from_config)) + or configured_from_config - table.insert(result, { - name = provider, - configured = is_configured, - has_api_key = has_stored_key, - has_model = has_model or config_model ~= nil, - model = (provider_data and provider_data.model) or config_model, - source = has_stored_key and "stored" or (configured_from_config and "config" or nil), - }) - end + table.insert(result, { + name = provider, + configured = is_configured, + has_api_key = has_stored_key, + has_model = has_model or config_model ~= nil, + model = (provider_data and provider_data.model) or config_model, + source = has_stored_key and "stored" or (configured_from_config and "config" or nil), + }) + end - return result + return result end --- Default models for each provider M.default_models = { - copilot = "claude-sonnet-4", - ollama = "deepseek-coder:6.7b", + copilot = "claude-sonnet-4", + ollama = "deepseek-coder:6.7b", } --- Available models for Copilot (GitHub Copilot Chat API) --- Models with cost multipliers: 0x = free, 0.33x = discount, 1x = standard, 3x = premium M.copilot_models = { - -- Free tier (0x) - { name = "gpt-4.1", cost = "0x" }, - { name = "gpt-4o", cost = "0x" }, - { name = "gpt-5-mini", cost = "0x" }, - { name = "grok-code-fast-1", cost = "0x" }, - { name = "raptor-mini", cost = "0x" }, - -- Discount tier (0.33x) - { name = "claude-haiku-4.5", cost = "0.33x" }, - { name = "gemini-3-flash", cost = "0.33x" }, - { name = "gpt-5.1-codex-mini", cost = "0.33x" }, - -- Standard tier (1x) - { name = "claude-sonnet-4", cost = "1x" }, - { name = "claude-sonnet-4.5", cost = "1x" }, - { name = "gemini-2.5-pro", cost = "1x" }, - { name = "gemini-3-pro", cost = "1x" }, - { name = "gpt-5", cost = "1x" }, - { name = "gpt-5-codex", cost = "1x" }, - { name = "gpt-5.1", cost = "1x" }, - { name = "gpt-5.1-codex", cost = "1x" }, - { name = "gpt-5.1-codex-max", cost = "1x" }, - { name = "gpt-5.2", cost = "1x" }, - { name = "gpt-5.2-codex", cost = "1x" }, - -- Premium tier (3x) - { name = "claude-opus-4.5", cost = "3x" }, + -- Free tier (0x) + { name = "gpt-4.1", cost = "0x" }, + { name = "gpt-4o", cost = "0x" }, + { name = "gpt-5-mini", cost = "0x" }, + { name = "grok-code-fast-1", cost = "0x" }, + { name = "raptor-mini", cost = "0x" }, + -- Discount tier (0.33x) + { name = "claude-haiku-4.5", cost = "0.33x" }, + { name = "gemini-3-flash", cost = "0.33x" }, + { name = "gpt-5.1-codex-mini", cost = "0.33x" }, + -- Standard tier (1x) + { name = "claude-sonnet-4", cost = "1x" }, + { name = "claude-sonnet-4.5", cost = "1x" }, + { name = "gemini-2.5-pro", cost = "1x" }, + { name = "gemini-3-pro", cost = "1x" }, + { name = "gpt-5", cost = "1x" }, + { name = "gpt-5-codex", cost = "1x" }, + { name = "gpt-5.1", cost = "1x" }, + { name = "gpt-5.1-codex", cost = "1x" }, + { name = "gpt-5.1-codex-max", cost = "1x" }, + { name = "gpt-5.2", cost = "1x" }, + { name = "gpt-5.2-codex", cost = "1x" }, + -- Premium tier (3x) + { name = "claude-opus-4.5", cost = "3x" }, } --- Get list of copilot model names (for completion) ---@return string[] function M.get_copilot_model_names() - local names = {} - for _, model in ipairs(M.copilot_models) do - table.insert(names, model.name) - end - return names + local names = {} + for _, model in ipairs(M.copilot_models) do + table.insert(names, model.name) + end + return names end --- Get cost for a copilot model ---@param model_name string ---@return string|nil function M.get_copilot_model_cost(model_name) - for _, model in ipairs(M.copilot_models) do - if model.name == model_name then - return model.cost - end - end - return nil + for _, model in ipairs(M.copilot_models) do + if model.name == model_name then + return model.cost + end + end + return nil end --- Interactive command to add/update configuration function M.interactive_add() - local providers = { "copilot", "ollama" } + local providers = { "copilot", "ollama" } - vim.ui.select(providers, { - prompt = "Select LLM provider:", - format_item = function(item) - local display = item:sub(1, 1):upper() .. item:sub(2) - local creds = M.load() - local configured = creds.providers and creds.providers[item] - if configured and (configured.configured or item == "ollama") then - return display .. " [configured]" - end - return display - end, - }, function(provider) - if not provider then - return - end + vim.ui.select(providers, { + prompt = "Select LLM provider:", + format_item = function(item) + local display = item:sub(1, 1):upper() .. item:sub(2) + local creds = M.load() + local configured = creds.providers and creds.providers[item] + if configured and (configured.configured or item == "ollama") then + return display .. " [configured]" + end + return display + end, + }, function(provider) + if not provider then + return + end - if provider == "ollama" then - M.interactive_ollama_config() - elseif provider == "copilot" then - M.interactive_copilot_config() - end - end) + if provider == "ollama" then + M.interactive_ollama_config() + elseif provider == "copilot" then + M.interactive_copilot_config() + end + end) end --- Interactive Copilot configuration (no API key, uses OAuth) ---@param silent? boolean If true, don't show the OAuth info message function M.interactive_copilot_config(silent) - if not silent then - utils.notify("Copilot uses OAuth from copilot.lua/copilot.vim - no API key needed", vim.log.levels.INFO) - end + if not silent then + utils.notify("Copilot uses OAuth from copilot.lua/copilot.vim - no API key needed", vim.log.levels.INFO) + end - -- Get current model if configured - local current_model = M.get_model("copilot") or M.default_models.copilot - local current_cost = M.get_copilot_model_cost(current_model) or "?" + -- Get current model if configured + local current_model = M.get_model("copilot") or M.default_models.copilot + local current_cost = M.get_copilot_model_cost(current_model) or "?" - -- Build model options with "Custom..." option - local model_options = vim.deepcopy(M.copilot_models) - table.insert(model_options, { name = "Custom...", cost = "" }) + -- Build model options with "Custom..." option + local model_options = vim.deepcopy(M.copilot_models) + table.insert(model_options, { name = "Custom...", cost = "" }) - vim.ui.select(model_options, { - prompt = "Select Copilot model (current: " .. current_model .. " — " .. current_cost .. "):", - format_item = function(item) - local display = item.name - if item.cost and item.cost ~= "" then - display = display .. " — " .. item.cost - end - if item.name == current_model then - display = display .. " [current]" - end - return display - end, - }, function(choice) - if choice == nil then - return -- Cancelled - end + vim.ui.select(model_options, { + prompt = "Select Copilot model (current: " .. current_model .. " — " .. current_cost .. "):", + format_item = function(item) + local display = item.name + if item.cost and item.cost ~= "" then + display = display .. " — " .. item.cost + end + if item.name == current_model then + display = display .. " [current]" + end + return display + end, + }, function(choice) + if choice == nil then + return -- Cancelled + end - if choice.name == "Custom..." then - -- Allow custom model input - vim.ui.input({ - prompt = "Enter custom model name: ", - default = current_model, - }, function(custom_model) - if custom_model and custom_model ~= "" then - M.save_and_notify("copilot", { - model = custom_model, - configured = true, - }) - end - end) - else - M.save_and_notify("copilot", { - model = choice.name, - configured = true, - }) - end - end) + if choice.name == "Custom..." then + -- Allow custom model input + vim.ui.input({ + prompt = "Enter custom model name: ", + default = current_model, + }, function(custom_model) + if custom_model and custom_model ~= "" then + M.save_and_notify("copilot", { + model = custom_model, + configured = true, + }) + end + end) + else + M.save_and_notify("copilot", { + model = choice.name, + configured = true, + }) + end + end) end --- Interactive Ollama configuration function M.interactive_ollama_config() - vim.ui.input({ - prompt = "Ollama host (default: http://localhost:11434): ", - default = "http://localhost:11434", - }, function(host) - if host == nil then - return -- Cancelled - end + vim.ui.input({ + prompt = "Ollama host (default: http://localhost:11434): ", + default = "http://localhost:11434", + }, function(host) + if host == nil then + return -- Cancelled + end - if host == "" then - host = "http://localhost:11434" - end + if host == "" then + host = "http://localhost:11434" + end - -- Get model - local default_model = M.default_models.ollama - vim.ui.input({ - prompt = string.format("Ollama model (default: %s): ", default_model), - default = default_model, - }, function(model) - if model == nil then - return -- Cancelled - end + -- Get model + local default_model = M.default_models.ollama + vim.ui.input({ + prompt = string.format("Ollama model (default: %s): ", default_model), + default = default_model, + }, function(model) + if model == nil then + return -- Cancelled + end - if model == "" then - model = default_model - end + if model == "" then + model = default_model + end - M.save_and_notify("ollama", { - host = host, - model = model, - }) - end) - end) + M.save_and_notify("ollama", { + host = host, + model = model, + }) + end) + end) end --- Save credentials and notify user ---@param provider string Provider name ---@param credentials table Credentials to save function M.save_and_notify(provider, credentials) - if M.set_credentials(provider, credentials) then - local msg = string.format("Saved %s configuration", provider:upper()) - if credentials.model then - msg = msg .. " (model: " .. credentials.model .. ")" - end - utils.notify(msg, vim.log.levels.INFO) - else - utils.notify("Failed to save credentials", vim.log.levels.ERROR) - end + if M.set_credentials(provider, credentials) then + local msg = string.format("Saved %s configuration", provider:upper()) + if credentials.model then + msg = msg .. " (model: " .. credentials.model .. ")" + end + utils.notify(msg, vim.log.levels.INFO) + else + utils.notify("Failed to save credentials", vim.log.levels.ERROR) + end end --- Show current credentials status function M.show_status() - local providers = M.list_providers() + local providers = M.list_providers() - -- Get current active provider - local codetyper = require("codetyper") - local current = codetyper.get_config().llm.provider + -- Get current active provider + local codetyper = require("codetyper") + local current = codetyper.get_config().llm.provider - local lines = { - "Codetyper Credentials Status", - "============================", - "", - "Storage: " .. get_credentials_path(), - "Active: " .. current:upper(), - "", - } + local lines = { + "Codetyper Credentials Status", + "============================", + "", + "Storage: " .. get_credentials_path(), + "Active: " .. current:upper(), + "", + } - for _, p in ipairs(providers) do - local status_icon = p.configured and "✓" or "✗" - local active_marker = p.name == current and " [ACTIVE]" or "" - local source_info = "" - if p.configured then - source_info = p.source == "stored" and " (stored)" or " (config)" - end - local model_info = p.model and (" - " .. p.model) or "" + for _, p in ipairs(providers) do + local status_icon = p.configured and "✓" or "✗" + local active_marker = p.name == current and " [ACTIVE]" or "" + local source_info = "" + if p.configured then + source_info = p.source == "stored" and " (stored)" or " (config)" + end + local model_info = p.model and (" - " .. p.model) or "" - table.insert( - lines, - string.format(" %s %s%s%s%s", status_icon, p.name:upper(), active_marker, source_info, model_info) - ) - end + table.insert( + lines, + string.format(" %s %s%s%s%s", status_icon, p.name:upper(), active_marker, source_info, model_info) + ) + end - table.insert(lines, "") - table.insert(lines, "Commands:") - table.insert(lines, " :CoderAddApiKey - Add/update credentials") - table.insert(lines, " :CoderSwitchProvider - Switch active provider") - table.insert(lines, " :CoderRemoveApiKey - Remove stored credentials") + table.insert(lines, "") + table.insert(lines, "Commands:") + table.insert(lines, " :CoderAddApiKey - Add/update credentials") + table.insert(lines, " :CoderSwitchProvider - Switch active provider") + table.insert(lines, " :CoderRemoveApiKey - Remove stored credentials") - utils.notify(table.concat(lines, "\n")) + utils.notify(table.concat(lines, "\n")) end --- Interactive remove credentials function M.interactive_remove() - local data = M.load() - local configured = {} + local data = M.load() + local configured = {} - for provider, _ in pairs(data.providers or {}) do - table.insert(configured, provider) - end + for provider, _ in pairs(data.providers or {}) do + table.insert(configured, provider) + end - if #configured == 0 then - utils.notify("No credentials configured", vim.log.levels.INFO) - return - end + if #configured == 0 then + utils.notify("No credentials configured", vim.log.levels.INFO) + return + end - vim.ui.select(configured, { - prompt = "Select provider to remove:", - }, function(provider) - if not provider then - return - end + vim.ui.select(configured, { + prompt = "Select provider to remove:", + }, function(provider) + if not provider then + return + end - vim.ui.select({ "Yes", "No" }, { - prompt = "Remove " .. provider:upper() .. " credentials?", - }, function(choice) - if choice == "Yes" then - if M.remove_credentials(provider) then - utils.notify("Removed " .. provider:upper() .. " credentials", vim.log.levels.INFO) - else - utils.notify("Failed to remove credentials", vim.log.levels.ERROR) - end - end - end) - end) + vim.ui.select({ "Yes", "No" }, { + prompt = "Remove " .. provider:upper() .. " credentials?", + }, function(choice) + if choice == "Yes" then + if M.remove_credentials(provider) then + utils.notify("Removed " .. provider:upper() .. " credentials", vim.log.levels.INFO) + else + utils.notify("Failed to remove credentials", vim.log.levels.ERROR) + end + end + end) + end) end --- Set the active provider ---@param provider string Provider name function M.set_active_provider(provider) - local data = M.load() - data.active_provider = provider - data.updated = os.time() - M.save(data) + local data = M.load() + data.active_provider = provider + data.updated = os.time() + M.save(data) - -- Also update the runtime config - local codetyper = require("codetyper") - local config = codetyper.get_config() - config.llm.provider = provider + -- Also update the runtime config + local codetyper = require("codetyper") + local config = codetyper.get_config() + config.llm.provider = provider - utils.notify("Active provider set to: " .. provider:upper(), vim.log.levels.INFO) + utils.notify("Active provider set to: " .. provider:upper(), vim.log.levels.INFO) end --- Get the active provider from stored config ---@return string|nil Active provider function M.get_active_provider() - local data = M.load() - return data.active_provider + local data = M.load() + return data.active_provider end --- Check if a provider is configured (from stored credentials OR config) ---@param provider string Provider name ---@return boolean configured, string|nil source local function is_provider_configured(provider) - local data = M.load() - local stored = data.providers and data.providers[provider] - if stored then - if stored.configured or provider == "ollama" or provider == "copilot" then - return true, "stored" - end - end + local data = M.load() + local stored = data.providers and data.providers[provider] + if stored then + if stored.configured or provider == "ollama" or provider == "copilot" then + return true, "stored" + end + end - local ok, codetyper = pcall(require, "codetyper") - if not ok then - return false, nil - end + local ok, codetyper = pcall(require, "codetyper") + if not ok then + return false, nil + end - local config = codetyper.get_config() - if not config or not config.llm then - return false, nil - end + local config = codetyper.get_config() + if not config or not config.llm then + return false, nil + end - local provider_config = config.llm[provider] - if not provider_config then - return false, nil - end + local provider_config = config.llm[provider] + if not provider_config then + return false, nil + end - if provider == "copilot" then - return true, "config" - elseif provider == "ollama" then - if provider_config.host then - return true, "config" - end - end + if provider == "copilot" then + return true, "config" + elseif provider == "ollama" then + if provider_config.host then + return true, "config" + end + end - return false, nil + return false, nil end --- Interactive switch provider function M.interactive_switch_provider() - local all_providers = { "copilot", "ollama" } - local available = {} - local sources = {} + local all_providers = { "copilot", "ollama" } + local available = {} + local sources = {} - for _, provider in ipairs(all_providers) do - local configured, source = is_provider_configured(provider) - if configured then - table.insert(available, provider) - sources[provider] = source - end - end + for _, provider in ipairs(all_providers) do + local configured, source = is_provider_configured(provider) + if configured then + table.insert(available, provider) + sources[provider] = source + end + end - if #available == 0 then - utils.notify("No providers configured. Use :CoderAddApiKey or add to your config.", vim.log.levels.WARN) - return - end + if #available == 0 then + utils.notify("No providers configured. Use :CoderAddApiKey or add to your config.", vim.log.levels.WARN) + return + end - local codetyper = require("codetyper") - local current = codetyper.get_config().llm.provider + local codetyper = require("codetyper") + local current = codetyper.get_config().llm.provider - vim.ui.select(available, { - prompt = "Select provider (current: " .. current .. "):", - format_item = function(item) - local marker = item == current and " [active]" or "" - local source_marker = sources[item] == "stored" and " (stored)" or " (config)" - return item:upper() .. marker .. source_marker - end, - }, function(provider) - if provider then - M.set_active_provider(provider) - end - end) + vim.ui.select(available, { + prompt = "Select provider (current: " .. current .. "):", + format_item = function(item) + local marker = item == current and " [active]" or "" + local source_marker = sources[item] == "stored" and " (stored)" or " (config)" + return item:upper() .. marker .. source_marker + end, + }, function(provider) + if provider then + M.set_active_provider(provider) + end + end) end return M diff --git a/lua/codetyper/config/defaults.lua b/lua/codetyper/config/defaults.lua index bd7ea40..70af306 100644 --- a/lua/codetyper/config/defaults.lua +++ b/lua/codetyper/config/defaults.lua @@ -4,48 +4,48 @@ local M = {} ---@type CoderConfig local defaults = { - llm = { - provider = "ollama", -- Options: "ollama", "copilot" - ollama = { - host = "http://localhost:11434", - model = "deepseek-coder:6.7b", - }, - copilot = { - model = "claude-sonnet-4", -- Uses GitHub Copilot authentication - }, - }, - auto_gitignore = true, - auto_index = false, -- Auto-create coder companion files on file open - indexer = { - enabled = true, -- Enable project indexing - auto_index = true, -- Index files on save - index_on_open = false, -- Index project when opening - max_file_size = 100000, -- Skip files larger than 100KB - excluded_dirs = { "node_modules", "dist", "build", ".git", ".codetyper", "__pycache__", "vendor", "target" }, - index_extensions = { "lua", "ts", "tsx", "js", "jsx", "py", "go", "rs", "rb", "java", "c", "cpp", "h", "hpp" }, - memory = { - enabled = true, -- Enable memory persistence - max_memories = 1000, -- Maximum stored memories - prune_threshold = 0.1, -- Remove low-weight memories - }, - }, - brain = { - enabled = true, -- Enable brain learning system - auto_learn = true, -- Auto-learn from events - auto_commit = true, -- Auto-commit after threshold - commit_threshold = 10, -- Changes before auto-commit - max_nodes = 5000, -- Maximum nodes before pruning - max_deltas = 500, -- Maximum delta history - prune = { - enabled = true, -- Enable auto-pruning - threshold = 0.1, -- Remove nodes below this weight - unused_days = 90, -- Remove unused nodes after N days - }, - output = { - max_tokens = 4000, -- Token budget for LLM context - format = "compact", -- "compact"|"json"|"natural" - }, - }, + llm = { + provider = "ollama", -- Options: "ollama", "copilot" + ollama = { + host = "http://localhost:11434", + model = "deepseek-coder:6.7b", + }, + copilot = { + model = "claude-sonnet-4", -- Uses GitHub Copilot authentication + }, + }, + auto_gitignore = true, + auto_index = false, -- Auto-create coder companion files on file open + indexer = { + enabled = true, -- Enable project indexing + auto_index = true, -- Index files on save + index_on_open = false, -- Index project when opening + max_file_size = 100000, -- Skip files larger than 100KB + excluded_dirs = { "node_modules", "dist", "build", ".git", ".codetyper", "__pycache__", "vendor", "target" }, + index_extensions = { "lua", "ts", "tsx", "js", "jsx", "py", "go", "rs", "rb", "java", "c", "cpp", "h", "hpp" }, + memory = { + enabled = true, -- Enable memory persistence + max_memories = 1000, -- Maximum stored memories + prune_threshold = 0.1, -- Remove low-weight memories + }, + }, + brain = { + enabled = true, -- Enable brain learning system + auto_learn = true, -- Auto-learn from events + auto_commit = true, -- Auto-commit after threshold + commit_threshold = 10, -- Changes before auto-commit + max_nodes = 5000, -- Maximum nodes before pruning + max_deltas = 500, -- Maximum delta history + prune = { + enabled = true, -- Enable auto-pruning + threshold = 0.1, -- Remove nodes below this weight + unused_days = 90, -- Remove unused nodes after N days + }, + output = { + max_tokens = 4000, -- Token budget for LLM context + format = "compact", -- "compact"|"json"|"natural" + }, + }, } --- Deep merge two tables @@ -53,53 +53,53 @@ local defaults = { ---@param t2 table Table to merge into base ---@return table Merged table local function deep_merge(t1, t2) - local result = vim.deepcopy(t1) - for k, v in pairs(t2) do - if type(v) == "table" and type(result[k]) == "table" then - result[k] = deep_merge(result[k], v) - else - result[k] = v - end - end - return result + local result = vim.deepcopy(t1) + for k, v in pairs(t2) do + if type(v) == "table" and type(result[k]) == "table" then + result[k] = deep_merge(result[k], v) + else + result[k] = v + end + end + return result end --- Setup configuration with user options ---@param opts? CoderConfig User configuration options ---@return CoderConfig Final configuration function M.setup(opts) - opts = opts or {} - return deep_merge(defaults, opts) + opts = opts or {} + return deep_merge(defaults, opts) end --- Get default configuration ---@return CoderConfig Default configuration function M.get_defaults() - return vim.deepcopy(defaults) + return vim.deepcopy(defaults) end --- Validate configuration ---@param config CoderConfig Configuration to validate ---@return boolean, string? Valid status and optional error message function M.validate(config) - if not config.llm then - return false, "Missing LLM configuration" - end + if not config.llm then + return false, "Missing LLM configuration" + end - local valid_providers = { "ollama", "copilot" } - local is_valid_provider = false - for _, p in ipairs(valid_providers) do - if config.llm.provider == p then - is_valid_provider = true - break - end - end + local valid_providers = { "ollama", "copilot" } + local is_valid_provider = false + for _, p in ipairs(valid_providers) do + if config.llm.provider == p then + is_valid_provider = true + break + end + end - if not is_valid_provider then - return false, "Invalid LLM provider. Must be one of: " .. table.concat(valid_providers, ", ") - end + if not is_valid_provider then + return false, "Invalid LLM provider. Must be one of: " .. table.concat(valid_providers, ", ") + end - return true + return true end return M diff --git a/lua/codetyper/config/preferences.lua b/lua/codetyper/config/preferences.lua index 022a830..fecef78 100644 --- a/lua/codetyper/config/preferences.lua +++ b/lua/codetyper/config/preferences.lua @@ -12,8 +12,8 @@ local utils = require("codetyper.support.utils") --- Default preferences local defaults = { - auto_process = nil, -- nil means "not yet decided" - asked_auto_process = false, + auto_process = nil, -- nil means "not yet decided" + asked_auto_process = false, } --- Cached preferences per project @@ -23,113 +23,113 @@ local cache = {} --- Get the preferences file path for current project ---@return string local function get_preferences_path() - local cwd = vim.fn.getcwd() - return cwd .. "/.codetyper/preferences.json" + local cwd = vim.fn.getcwd() + return cwd .. "/.codetyper/preferences.json" end --- Ensure .codetyper directory exists local function ensure_coder_dir() - local cwd = vim.fn.getcwd() - local coder_dir = cwd .. "/.codetyper" - if vim.fn.isdirectory(coder_dir) == 0 then - vim.fn.mkdir(coder_dir, "p") - end + local cwd = vim.fn.getcwd() + local coder_dir = cwd .. "/.codetyper" + if vim.fn.isdirectory(coder_dir) == 0 then + vim.fn.mkdir(coder_dir, "p") + end end --- Load preferences from file ---@return CoderPreferences function M.load() - local cwd = vim.fn.getcwd() + local cwd = vim.fn.getcwd() - -- Check cache first - if cache[cwd] then - return cache[cwd] - end + -- Check cache first + if cache[cwd] then + return cache[cwd] + end - local path = get_preferences_path() - local prefs = vim.deepcopy(defaults) + local path = get_preferences_path() + local prefs = vim.deepcopy(defaults) - if utils.file_exists(path) then - local content = utils.read_file(path) - if content then - local ok, decoded = pcall(vim.json.decode, content) - if ok and decoded then - -- Merge with defaults - for k, v in pairs(decoded) do - prefs[k] = v - end - end - end - end + if utils.file_exists(path) then + local content = utils.read_file(path) + if content then + local ok, decoded = pcall(vim.json.decode, content) + if ok and decoded then + -- Merge with defaults + for k, v in pairs(decoded) do + prefs[k] = v + end + end + end + end - -- Cache it - cache[cwd] = prefs - return prefs + -- Cache it + cache[cwd] = prefs + return prefs end --- Save preferences to file ---@param prefs CoderPreferences function M.save(prefs) - local cwd = vim.fn.getcwd() - ensure_coder_dir() + local cwd = vim.fn.getcwd() + ensure_coder_dir() - local path = get_preferences_path() - local ok, encoded = pcall(vim.json.encode, prefs) - if ok then - utils.write_file(path, encoded) - -- Update cache - cache[cwd] = prefs - end + local path = get_preferences_path() + local ok, encoded = pcall(vim.json.encode, prefs) + if ok then + utils.write_file(path, encoded) + -- Update cache + cache[cwd] = prefs + end end --- Get a specific preference ---@param key string ---@return any function M.get(key) - local prefs = M.load() - return prefs[key] + local prefs = M.load() + return prefs[key] end --- Set a specific preference ---@param key string ---@param value any function M.set(key, value) - local prefs = M.load() - prefs[key] = value - M.save(prefs) + local prefs = M.load() + prefs[key] = value + M.save(prefs) end --- Check if auto-process is enabled ---@return boolean|nil Returns true/false if set, nil if not yet decided function M.is_auto_process_enabled() - return M.get("auto_process") + return M.get("auto_process") end --- Set auto-process preference ---@param enabled boolean function M.set_auto_process(enabled) - M.set("auto_process", enabled) - M.set("asked_auto_process", true) + M.set("auto_process", enabled) + M.set("asked_auto_process", true) end --- Check if we've already asked the user about auto-process ---@return boolean function M.has_asked_auto_process() - return M.get("asked_auto_process") == true + return M.get("asked_auto_process") == true end --- Clear cached preferences (useful when changing projects) function M.clear_cache() - cache = {} + cache = {} end --- Toggle auto-process mode function M.toggle_auto_process() - local current = M.is_auto_process_enabled() - local new_value = not current - M.set_auto_process(new_value) - local mode = new_value and "automatic" or "manual" - vim.notify("Codetyper: Switched to " .. mode .. " mode", vim.log.levels.INFO) + local current = M.is_auto_process_enabled() + local new_value = not current + M.set_auto_process(new_value) + local mode = new_value and "automatic" or "manual" + vim.notify("Codetyper: Switched to " .. mode .. " mode", vim.log.levels.INFO) end return M diff --git a/lua/codetyper/core/cost/init.lua b/lua/codetyper/core/cost/init.lua index 96714cc..be6fbca 100644 --- a/lua/codetyper/core/cost/init.lua +++ b/lua/codetyper/core/cost/init.lua @@ -14,8 +14,8 @@ local COST_HISTORY_FILE = "cost_history.json" --- Get path to cost history file ---@return string File path local function get_history_path() - local root = utils.get_project_root() - return root .. "/.codetyper/" .. COST_HISTORY_FILE + local root = utils.get_project_root() + return root .. "/.codetyper/" .. COST_HISTORY_FILE end --- Default model for savings comparison (what you'd pay if not using Ollama) @@ -23,100 +23,100 @@ M.comparison_model = "gpt-4o" --- Models considered "free" (Ollama, local, Copilot subscription) M.free_models = { - ["ollama"] = true, - ["codellama"] = true, - ["llama2"] = true, - ["llama3"] = true, - ["mistral"] = true, - ["deepseek-coder"] = true, - ["copilot"] = true, + ["ollama"] = true, + ["codellama"] = true, + ["llama2"] = true, + ["llama3"] = true, + ["mistral"] = true, + ["deepseek-coder"] = true, + ["copilot"] = true, } --- Model pricing table (per 1M tokens in USD) ---@type table M.pricing = { - -- GPT-5.x series - ["gpt-5.2"] = { input = 1.75, cached_input = 0.175, output = 14.00 }, - ["gpt-5.1"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, - ["gpt-5"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, - ["gpt-5-mini"] = { input = 0.25, cached_input = 0.025, output = 2.00 }, - ["gpt-5-nano"] = { input = 0.05, cached_input = 0.005, output = 0.40 }, - ["gpt-5.2-chat-latest"] = { input = 1.75, cached_input = 0.175, output = 14.00 }, - ["gpt-5.1-chat-latest"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, - ["gpt-5-chat-latest"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, - ["gpt-5.2-codex"] = { input = 1.75, cached_input = 0.175, output = 14.00 }, - ["gpt-5.1-codex-max"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, - ["gpt-5.1-codex"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, - ["gpt-5-codex"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, - ["gpt-5.2-pro"] = { input = 21.00, cached_input = nil, output = 168.00 }, - ["gpt-5-pro"] = { input = 15.00, cached_input = nil, output = 120.00 }, - ["gpt-5.1-codex-mini"] = { input = 0.25, cached_input = 0.025, output = 2.00 }, - ["gpt-5-search-api"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + -- GPT-5.x series + ["gpt-5.2"] = { input = 1.75, cached_input = 0.175, output = 14.00 }, + ["gpt-5.1"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + ["gpt-5"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + ["gpt-5-mini"] = { input = 0.25, cached_input = 0.025, output = 2.00 }, + ["gpt-5-nano"] = { input = 0.05, cached_input = 0.005, output = 0.40 }, + ["gpt-5.2-chat-latest"] = { input = 1.75, cached_input = 0.175, output = 14.00 }, + ["gpt-5.1-chat-latest"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + ["gpt-5-chat-latest"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + ["gpt-5.2-codex"] = { input = 1.75, cached_input = 0.175, output = 14.00 }, + ["gpt-5.1-codex-max"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + ["gpt-5.1-codex"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + ["gpt-5-codex"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + ["gpt-5.2-pro"] = { input = 21.00, cached_input = nil, output = 168.00 }, + ["gpt-5-pro"] = { input = 15.00, cached_input = nil, output = 120.00 }, + ["gpt-5.1-codex-mini"] = { input = 0.25, cached_input = 0.025, output = 2.00 }, + ["gpt-5-search-api"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, - -- GPT-4.x series - ["gpt-4.1"] = { input = 2.00, cached_input = 0.50, output = 8.00 }, - ["gpt-4.1-mini"] = { input = 0.40, cached_input = 0.10, output = 1.60 }, - ["gpt-4.1-nano"] = { input = 0.10, cached_input = 0.025, output = 0.40 }, - ["gpt-4o"] = { input = 2.50, cached_input = 1.25, output = 10.00 }, - ["gpt-4o-2024-05-13"] = { input = 5.00, cached_input = nil, output = 15.00 }, - ["gpt-4o-mini"] = { input = 0.15, cached_input = 0.075, output = 0.60 }, + -- GPT-4.x series + ["gpt-4.1"] = { input = 2.00, cached_input = 0.50, output = 8.00 }, + ["gpt-4.1-mini"] = { input = 0.40, cached_input = 0.10, output = 1.60 }, + ["gpt-4.1-nano"] = { input = 0.10, cached_input = 0.025, output = 0.40 }, + ["gpt-4o"] = { input = 2.50, cached_input = 1.25, output = 10.00 }, + ["gpt-4o-2024-05-13"] = { input = 5.00, cached_input = nil, output = 15.00 }, + ["gpt-4o-mini"] = { input = 0.15, cached_input = 0.075, output = 0.60 }, - -- Realtime models - ["gpt-realtime"] = { input = 4.00, cached_input = 0.40, output = 16.00 }, - ["gpt-realtime-mini"] = { input = 0.60, cached_input = 0.06, output = 2.40 }, - ["gpt-4o-realtime-preview"] = { input = 5.00, cached_input = 2.50, output = 20.00 }, - ["gpt-4o-mini-realtime-preview"] = { input = 0.60, cached_input = 0.30, output = 2.40 }, + -- Realtime models + ["gpt-realtime"] = { input = 4.00, cached_input = 0.40, output = 16.00 }, + ["gpt-realtime-mini"] = { input = 0.60, cached_input = 0.06, output = 2.40 }, + ["gpt-4o-realtime-preview"] = { input = 5.00, cached_input = 2.50, output = 20.00 }, + ["gpt-4o-mini-realtime-preview"] = { input = 0.60, cached_input = 0.30, output = 2.40 }, - -- Audio models - ["gpt-audio"] = { input = 2.50, cached_input = nil, output = 10.00 }, - ["gpt-audio-mini"] = { input = 0.60, cached_input = nil, output = 2.40 }, - ["gpt-4o-audio-preview"] = { input = 2.50, cached_input = nil, output = 10.00 }, - ["gpt-4o-mini-audio-preview"] = { input = 0.15, cached_input = nil, output = 0.60 }, + -- Audio models + ["gpt-audio"] = { input = 2.50, cached_input = nil, output = 10.00 }, + ["gpt-audio-mini"] = { input = 0.60, cached_input = nil, output = 2.40 }, + ["gpt-4o-audio-preview"] = { input = 2.50, cached_input = nil, output = 10.00 }, + ["gpt-4o-mini-audio-preview"] = { input = 0.15, cached_input = nil, output = 0.60 }, - -- O-series reasoning models - ["o1"] = { input = 15.00, cached_input = 7.50, output = 60.00 }, - ["o1-pro"] = { input = 150.00, cached_input = nil, output = 600.00 }, - ["o3-pro"] = { input = 20.00, cached_input = nil, output = 80.00 }, - ["o3"] = { input = 2.00, cached_input = 0.50, output = 8.00 }, - ["o3-deep-research"] = { input = 10.00, cached_input = 2.50, output = 40.00 }, - ["o4-mini"] = { input = 1.10, cached_input = 0.275, output = 4.40 }, - ["o4-mini-deep-research"] = { input = 2.00, cached_input = 0.50, output = 8.00 }, - ["o3-mini"] = { input = 1.10, cached_input = 0.55, output = 4.40 }, - ["o1-mini"] = { input = 1.10, cached_input = 0.55, output = 4.40 }, + -- O-series reasoning models + ["o1"] = { input = 15.00, cached_input = 7.50, output = 60.00 }, + ["o1-pro"] = { input = 150.00, cached_input = nil, output = 600.00 }, + ["o3-pro"] = { input = 20.00, cached_input = nil, output = 80.00 }, + ["o3"] = { input = 2.00, cached_input = 0.50, output = 8.00 }, + ["o3-deep-research"] = { input = 10.00, cached_input = 2.50, output = 40.00 }, + ["o4-mini"] = { input = 1.10, cached_input = 0.275, output = 4.40 }, + ["o4-mini-deep-research"] = { input = 2.00, cached_input = 0.50, output = 8.00 }, + ["o3-mini"] = { input = 1.10, cached_input = 0.55, output = 4.40 }, + ["o1-mini"] = { input = 1.10, cached_input = 0.55, output = 4.40 }, - -- Codex - ["codex-mini-latest"] = { input = 1.50, cached_input = 0.375, output = 6.00 }, + -- Codex + ["codex-mini-latest"] = { input = 1.50, cached_input = 0.375, output = 6.00 }, - -- Search models - ["gpt-4o-mini-search-preview"] = { input = 0.15, cached_input = nil, output = 0.60 }, - ["gpt-4o-search-preview"] = { input = 2.50, cached_input = nil, output = 10.00 }, + -- Search models + ["gpt-4o-mini-search-preview"] = { input = 0.15, cached_input = nil, output = 0.60 }, + ["gpt-4o-search-preview"] = { input = 2.50, cached_input = nil, output = 10.00 }, - -- Computer use - ["computer-use-preview"] = { input = 3.00, cached_input = nil, output = 12.00 }, + -- Computer use + ["computer-use-preview"] = { input = 3.00, cached_input = nil, output = 12.00 }, - -- Image models - ["gpt-image-1.5"] = { input = 5.00, cached_input = 1.25, output = 10.00 }, - ["chatgpt-image-latest"] = { input = 5.00, cached_input = 1.25, output = 10.00 }, - ["gpt-image-1"] = { input = 5.00, cached_input = 1.25, output = nil }, - ["gpt-image-1-mini"] = { input = 2.00, cached_input = 0.20, output = nil }, + -- Image models + ["gpt-image-1.5"] = { input = 5.00, cached_input = 1.25, output = 10.00 }, + ["chatgpt-image-latest"] = { input = 5.00, cached_input = 1.25, output = 10.00 }, + ["gpt-image-1"] = { input = 5.00, cached_input = 1.25, output = nil }, + ["gpt-image-1-mini"] = { input = 2.00, cached_input = 0.20, output = nil }, - -- Claude models - ["claude-3-opus"] = { input = 15.00, cached_input = 7.50, output = 75.00 }, - ["claude-3-sonnet"] = { input = 3.00, cached_input = 1.50, output = 15.00 }, - ["claude-3-haiku"] = { input = 0.25, cached_input = 0.125, output = 1.25 }, - ["claude-3.5-sonnet"] = { input = 3.00, cached_input = 1.50, output = 15.00 }, - ["claude-3.5-haiku"] = { input = 0.80, cached_input = 0.40, output = 4.00 }, + -- Claude models + ["claude-3-opus"] = { input = 15.00, cached_input = 7.50, output = 75.00 }, + ["claude-3-sonnet"] = { input = 3.00, cached_input = 1.50, output = 15.00 }, + ["claude-3-haiku"] = { input = 0.25, cached_input = 0.125, output = 1.25 }, + ["claude-3.5-sonnet"] = { input = 3.00, cached_input = 1.50, output = 15.00 }, + ["claude-3.5-haiku"] = { input = 0.80, cached_input = 0.40, output = 4.00 }, - -- Ollama/Local models (free) - ["ollama"] = { input = 0, cached_input = 0, output = 0 }, - ["codellama"] = { input = 0, cached_input = 0, output = 0 }, - ["llama2"] = { input = 0, cached_input = 0, output = 0 }, - ["llama3"] = { input = 0, cached_input = 0, output = 0 }, - ["mistral"] = { input = 0, cached_input = 0, output = 0 }, - ["deepseek-coder"] = { input = 0, cached_input = 0, output = 0 }, + -- Ollama/Local models (free) + ["ollama"] = { input = 0, cached_input = 0, output = 0 }, + ["codellama"] = { input = 0, cached_input = 0, output = 0 }, + ["llama2"] = { input = 0, cached_input = 0, output = 0 }, + ["llama3"] = { input = 0, cached_input = 0, output = 0 }, + ["mistral"] = { input = 0, cached_input = 0, output = 0 }, + ["deepseek-coder"] = { input = 0, cached_input = 0, output = 0 }, - -- Copilot (included in subscription, but tracking usage) - ["copilot"] = { input = 0, cached_input = 0, output = 0 }, + -- Copilot (included in subscription, but tracking usage) + ["copilot"] = { input = 0, cached_input = 0, output = 0 }, } ---@class CostUsage @@ -135,130 +135,134 @@ M.pricing = { ---@field buf number|nil Buffer handle ---@field loaded boolean Whether historical data has been loaded local state = { - usage = {}, - all_usage = {}, - session_start = os.time(), - win = nil, - buf = nil, - loaded = false, + usage = {}, + all_usage = {}, + session_start = os.time(), + win = nil, + buf = nil, + loaded = false, } --- Load historical usage from disk function M.load_from_history() - if state.loaded then - return - end + if state.loaded then + return + end - local history_path = get_history_path() - local content = utils.read_file(history_path) + local history_path = get_history_path() + local content = utils.read_file(history_path) - if content and content ~= "" then - local ok, data = pcall(vim.json.decode, content) - if ok and data and data.usage then - state.all_usage = data.usage - end - end + if content and content ~= "" then + local ok, data = pcall(vim.json.decode, content) + if ok and data and data.usage then + state.all_usage = data.usage + end + end - state.loaded = true + state.loaded = true end --- Save all usage to disk (debounced) local save_timer = nil local function save_to_disk() - -- Cancel existing timer - if save_timer then - save_timer:stop() - save_timer = nil - end + -- Cancel existing timer + if save_timer then + save_timer:stop() + save_timer = nil + end - -- Debounce writes (500ms) - save_timer = vim.loop.new_timer() - save_timer:start(500, 0, vim.schedule_wrap(function() - local history_path = get_history_path() + -- Debounce writes (500ms) + save_timer = vim.loop.new_timer() + save_timer:start( + 500, + 0, + vim.schedule_wrap(function() + local history_path = get_history_path() - -- Ensure directory exists - local dir = vim.fn.fnamemodify(history_path, ":h") - utils.ensure_dir(dir) + -- Ensure directory exists + local dir = vim.fn.fnamemodify(history_path, ":h") + utils.ensure_dir(dir) - -- Merge session and historical usage - local all_data = vim.deepcopy(state.all_usage) - for _, usage in ipairs(state.usage) do - table.insert(all_data, usage) - end + -- Merge session and historical usage + local all_data = vim.deepcopy(state.all_usage) + for _, usage in ipairs(state.usage) do + table.insert(all_data, usage) + end - -- Save to file - local data = { - version = 1, - updated = os.time(), - usage = all_data, - } + -- Save to file + local data = { + version = 1, + updated = os.time(), + usage = all_data, + } - local ok, json = pcall(vim.json.encode, data) - if ok then - utils.write_file(history_path, json) - end + local ok, json = pcall(vim.json.encode, data) + if ok then + utils.write_file(history_path, json) + end - save_timer = nil - end)) + save_timer = nil + end) + ) end --- Normalize model name for pricing lookup ---@param model string Model name from API ---@return string Normalized model name local function normalize_model(model) - if not model then - return "unknown" - end + if not model then + return "unknown" + end - -- Convert to lowercase - local normalized = model:lower() + -- Convert to lowercase + local normalized = model:lower() - -- Handle Copilot models - if normalized:match("copilot") then - return "copilot" - end + -- Handle Copilot models + if normalized:match("copilot") then + return "copilot" + end - -- Handle common prefixes - normalized = normalized:gsub("^copilot/", "") + -- Handle common prefixes + normalized = normalized:gsub("^copilot/", "") - -- Try exact match first - if M.pricing[normalized] then - return normalized - end + -- Try exact match first + if M.pricing[normalized] then + return normalized + end - -- Try partial matches - for price_model, _ in pairs(M.pricing) do - if normalized:match(price_model) or price_model:match(normalized) then - return price_model - end - end + -- Try partial matches + for price_model, _ in pairs(M.pricing) do + if normalized:match(price_model) or price_model:match(normalized) then + return price_model + end + end - return normalized + return normalized end --- Check if a model is considered "free" (local/Ollama/Copilot subscription) ---@param model string Model name ---@return boolean True if free function M.is_free_model(model) - local normalized = normalize_model(model) + local normalized = normalize_model(model) - -- Check direct match - if M.free_models[normalized] then - return true - end + -- Check direct match + if M.free_models[normalized] then + return true + end - -- Check if it's an Ollama model (any model with : in name like deepseek-coder:6.7b) - if model:match(":") then - return true - end + -- Check if it's an Ollama model (any model with : in name like deepseek-coder:6.7b) + if model:match(":") then + return true + end - -- Check pricing - if cost is 0, it's free - local pricing = M.pricing[normalized] - if pricing and pricing.input == 0 and pricing.output == 0 then - return true - end + -- Check pricing - if cost is 0, it's free + local pricing = M.pricing[normalized] + if pricing and pricing.input == 0 and pricing.output == 0 then + return true + end - return false + return false end --- Calculate cost for token usage @@ -268,23 +272,23 @@ end ---@param cached_tokens? number Cached input tokens ---@return number Cost in USD function M.calculate_cost(model, input_tokens, output_tokens, cached_tokens) - local normalized = normalize_model(model) - local pricing = M.pricing[normalized] + local normalized = normalize_model(model) + local pricing = M.pricing[normalized] - if not pricing then - -- Unknown model, return 0 - return 0 - end + if not pricing then + -- Unknown model, return 0 + return 0 + end - cached_tokens = cached_tokens or 0 - local regular_input = input_tokens - cached_tokens + cached_tokens = cached_tokens or 0 + local regular_input = input_tokens - cached_tokens - -- Calculate cost (prices are per 1M tokens) - local input_cost = (regular_input / 1000000) * (pricing.input or 0) - local cached_cost = (cached_tokens / 1000000) * (pricing.cached_input or pricing.input or 0) - local output_cost = (output_tokens / 1000000) * (pricing.output or 0) + -- Calculate cost (prices are per 1M tokens) + local input_cost = (regular_input / 1000000) * (pricing.input or 0) + local cached_cost = (cached_tokens / 1000000) * (pricing.cached_input or pricing.input or 0) + local output_cost = (output_tokens / 1000000) * (pricing.output or 0) - return input_cost + cached_cost + output_cost + return input_cost + cached_cost + output_cost end --- Calculate estimated savings (what would have been paid if using comparison model) @@ -293,8 +297,8 @@ end ---@param cached_tokens? number Cached input tokens ---@return number Estimated savings in USD function M.calculate_savings(input_tokens, output_tokens, cached_tokens) - -- Calculate what it would have cost with the comparison model - return M.calculate_cost(M.comparison_model, input_tokens, output_tokens, cached_tokens) + -- Calculate what it would have cost with the comparison model + return M.calculate_cost(M.comparison_model, input_tokens, output_tokens, cached_tokens) end --- Record token usage @@ -303,447 +307,469 @@ end ---@param output_tokens number Output tokens ---@param cached_tokens? number Cached input tokens function M.record_usage(model, input_tokens, output_tokens, cached_tokens) - cached_tokens = cached_tokens or 0 - local cost = M.calculate_cost(model, input_tokens, output_tokens, cached_tokens) + cached_tokens = cached_tokens or 0 + local cost = M.calculate_cost(model, input_tokens, output_tokens, cached_tokens) - -- Calculate savings if using a free model - local savings = 0 - if M.is_free_model(model) then - savings = M.calculate_savings(input_tokens, output_tokens, cached_tokens) - end + -- Calculate savings if using a free model + local savings = 0 + if M.is_free_model(model) then + savings = M.calculate_savings(input_tokens, output_tokens, cached_tokens) + end - table.insert(state.usage, { - model = model, - input_tokens = input_tokens, - output_tokens = output_tokens, - cached_tokens = cached_tokens, - timestamp = os.time(), - cost = cost, - savings = savings, - is_free = M.is_free_model(model), - }) + table.insert(state.usage, { + model = model, + input_tokens = input_tokens, + output_tokens = output_tokens, + cached_tokens = cached_tokens, + timestamp = os.time(), + cost = cost, + savings = savings, + is_free = M.is_free_model(model), + }) - -- Save to disk (debounced) - save_to_disk() + -- Save to disk (debounced) + save_to_disk() - -- Update window if open - if state.win and vim.api.nvim_win_is_valid(state.win) then - M.refresh_window() - end + -- Update window if open + if state.win and vim.api.nvim_win_is_valid(state.win) then + M.refresh_window() + end end --- Aggregate usage data into stats ---@param usage_list CostUsage[] List of usage records ---@return table Stats local function aggregate_usage(usage_list) - local stats = { - total_input = 0, - total_output = 0, - total_cached = 0, - total_cost = 0, - total_savings = 0, - free_requests = 0, - paid_requests = 0, - by_model = {}, - request_count = #usage_list, - } + local stats = { + total_input = 0, + total_output = 0, + total_cached = 0, + total_cost = 0, + total_savings = 0, + free_requests = 0, + paid_requests = 0, + by_model = {}, + request_count = #usage_list, + } - for _, usage in ipairs(usage_list) do - stats.total_input = stats.total_input + (usage.input_tokens or 0) - stats.total_output = stats.total_output + (usage.output_tokens or 0) - stats.total_cached = stats.total_cached + (usage.cached_tokens or 0) - stats.total_cost = stats.total_cost + (usage.cost or 0) + for _, usage in ipairs(usage_list) do + stats.total_input = stats.total_input + (usage.input_tokens or 0) + stats.total_output = stats.total_output + (usage.output_tokens or 0) + stats.total_cached = stats.total_cached + (usage.cached_tokens or 0) + stats.total_cost = stats.total_cost + (usage.cost or 0) - -- Track savings - local usage_savings = usage.savings or 0 - -- For historical data without savings field, calculate it - if usage_savings == 0 and usage.is_free == nil then - local model = usage.model or "unknown" - if M.is_free_model(model) then - usage_savings = M.calculate_savings( - usage.input_tokens or 0, - usage.output_tokens or 0, - usage.cached_tokens or 0 - ) - end - end - stats.total_savings = stats.total_savings + usage_savings + -- Track savings + local usage_savings = usage.savings or 0 + -- For historical data without savings field, calculate it + if usage_savings == 0 and usage.is_free == nil then + local model = usage.model or "unknown" + if M.is_free_model(model) then + usage_savings = M.calculate_savings(usage.input_tokens or 0, usage.output_tokens or 0, usage.cached_tokens or 0) + end + end + stats.total_savings = stats.total_savings + usage_savings - -- Track free vs paid - local is_free = usage.is_free - if is_free == nil then - is_free = M.is_free_model(usage.model or "unknown") - end - if is_free then - stats.free_requests = stats.free_requests + 1 - else - stats.paid_requests = stats.paid_requests + 1 - end + -- Track free vs paid + local is_free = usage.is_free + if is_free == nil then + is_free = M.is_free_model(usage.model or "unknown") + end + if is_free then + stats.free_requests = stats.free_requests + 1 + else + stats.paid_requests = stats.paid_requests + 1 + end - local model = usage.model or "unknown" - if not stats.by_model[model] then - stats.by_model[model] = { - input_tokens = 0, - output_tokens = 0, - cached_tokens = 0, - cost = 0, - savings = 0, - requests = 0, - is_free = is_free, - } - end + local model = usage.model or "unknown" + if not stats.by_model[model] then + stats.by_model[model] = { + input_tokens = 0, + output_tokens = 0, + cached_tokens = 0, + cost = 0, + savings = 0, + requests = 0, + is_free = is_free, + } + end - stats.by_model[model].input_tokens = stats.by_model[model].input_tokens + (usage.input_tokens or 0) - stats.by_model[model].output_tokens = stats.by_model[model].output_tokens + (usage.output_tokens or 0) - stats.by_model[model].cached_tokens = stats.by_model[model].cached_tokens + (usage.cached_tokens or 0) - stats.by_model[model].cost = stats.by_model[model].cost + (usage.cost or 0) - stats.by_model[model].savings = stats.by_model[model].savings + usage_savings - stats.by_model[model].requests = stats.by_model[model].requests + 1 - end + stats.by_model[model].input_tokens = stats.by_model[model].input_tokens + (usage.input_tokens or 0) + stats.by_model[model].output_tokens = stats.by_model[model].output_tokens + (usage.output_tokens or 0) + stats.by_model[model].cached_tokens = stats.by_model[model].cached_tokens + (usage.cached_tokens or 0) + stats.by_model[model].cost = stats.by_model[model].cost + (usage.cost or 0) + stats.by_model[model].savings = stats.by_model[model].savings + usage_savings + stats.by_model[model].requests = stats.by_model[model].requests + 1 + end - return stats + return stats end --- Get session statistics ---@return table Statistics function M.get_stats() - local stats = aggregate_usage(state.usage) - stats.session_duration = os.time() - state.session_start - return stats + local stats = aggregate_usage(state.usage) + stats.session_duration = os.time() - state.session_start + return stats end --- Get all-time statistics (session + historical) ---@return table Statistics function M.get_all_time_stats() - -- Load history if not loaded - M.load_from_history() + -- Load history if not loaded + M.load_from_history() - -- Combine session and historical usage - local all_usage = vim.deepcopy(state.all_usage) - for _, usage in ipairs(state.usage) do - table.insert(all_usage, usage) - end + -- Combine session and historical usage + local all_usage = vim.deepcopy(state.all_usage) + for _, usage in ipairs(state.usage) do + table.insert(all_usage, usage) + end - local stats = aggregate_usage(all_usage) + local stats = aggregate_usage(all_usage) - -- Calculate time span - if #all_usage > 0 then - local oldest = all_usage[1].timestamp or os.time() - for _, usage in ipairs(all_usage) do - if usage.timestamp and usage.timestamp < oldest then - oldest = usage.timestamp - end - end - stats.time_span = os.time() - oldest - else - stats.time_span = 0 - end + -- Calculate time span + if #all_usage > 0 then + local oldest = all_usage[1].timestamp or os.time() + for _, usage in ipairs(all_usage) do + if usage.timestamp and usage.timestamp < oldest then + oldest = usage.timestamp + end + end + stats.time_span = os.time() - oldest + else + stats.time_span = 0 + end - return stats + return stats end --- Format cost as string ---@param cost number Cost in USD ---@return string Formatted cost local function format_cost(cost) - if cost < 0.01 then - return string.format("$%.4f", cost) - elseif cost < 1 then - return string.format("$%.3f", cost) - else - return string.format("$%.2f", cost) - end + if cost < 0.01 then + return string.format("$%.4f", cost) + elseif cost < 1 then + return string.format("$%.3f", cost) + else + return string.format("$%.2f", cost) + end end --- Format token count ---@param tokens number Token count ---@return string Formatted count local function format_tokens(tokens) - if tokens >= 1000000 then - return string.format("%.2fM", tokens / 1000000) - elseif tokens >= 1000 then - return string.format("%.1fK", tokens / 1000) - else - return tostring(tokens) - end + if tokens >= 1000000 then + return string.format("%.2fM", tokens / 1000000) + elseif tokens >= 1000 then + return string.format("%.1fK", tokens / 1000) + else + return tostring(tokens) + end end --- Format duration ---@param seconds number Duration in seconds ---@return string Formatted duration local function format_duration(seconds) - if seconds < 60 then - return string.format("%ds", seconds) - elseif seconds < 3600 then - return string.format("%dm %ds", math.floor(seconds / 60), seconds % 60) - else - local hours = math.floor(seconds / 3600) - local mins = math.floor((seconds % 3600) / 60) - return string.format("%dh %dm", hours, mins) - end + if seconds < 60 then + return string.format("%ds", seconds) + elseif seconds < 3600 then + return string.format("%dm %ds", math.floor(seconds / 60), seconds % 60) + else + local hours = math.floor(seconds / 3600) + local mins = math.floor((seconds % 3600) / 60) + return string.format("%dh %dm", hours, mins) + end end --- Generate model breakdown section ---@param stats table Stats with by_model ---@return string[] Lines local function generate_model_breakdown(stats) - local lines = {} + local lines = {} - if next(stats.by_model) then - -- Sort models by cost (descending) - local models = {} - for model, data in pairs(stats.by_model) do - table.insert(models, { name = model, data = data }) - end - table.sort(models, function(a, b) - return a.data.cost > b.data.cost - end) + if next(stats.by_model) then + -- Sort models by cost (descending) + local models = {} + for model, data in pairs(stats.by_model) do + table.insert(models, { name = model, data = data }) + end + table.sort(models, function(a, b) + return a.data.cost > b.data.cost + end) - for _, item in ipairs(models) do - local model = item.name - local data = item.data - local pricing = M.pricing[normalize_model(model)] - local is_free = data.is_free or M.is_free_model(model) + for _, item in ipairs(models) do + local model = item.name + local data = item.data + local pricing = M.pricing[normalize_model(model)] + local is_free = data.is_free or M.is_free_model(model) - table.insert(lines, "") - local model_icon = is_free and "🆓" or "💳" - table.insert(lines, string.format(" %s %s", model_icon, model)) - table.insert(lines, string.format(" Requests: %d", data.requests)) - table.insert(lines, string.format(" Input: %s tokens", format_tokens(data.input_tokens))) - table.insert(lines, string.format(" Output: %s tokens", format_tokens(data.output_tokens))) + table.insert(lines, "") + local model_icon = is_free and "🆓" or "💳" + table.insert(lines, string.format(" %s %s", model_icon, model)) + table.insert(lines, string.format(" Requests: %d", data.requests)) + table.insert(lines, string.format(" Input: %s tokens", format_tokens(data.input_tokens))) + table.insert(lines, string.format(" Output: %s tokens", format_tokens(data.output_tokens))) - if is_free then - -- Show savings for free models - if data.savings and data.savings > 0 then - table.insert(lines, string.format(" Saved: %s", format_cost(data.savings))) - end - else - table.insert(lines, string.format(" Cost: %s", format_cost(data.cost))) - end + if is_free then + -- Show savings for free models + if data.savings and data.savings > 0 then + table.insert(lines, string.format(" Saved: %s", format_cost(data.savings))) + end + else + table.insert(lines, string.format(" Cost: %s", format_cost(data.cost))) + end - -- Show pricing info for paid models - if pricing and not is_free then - local price_info = string.format( - " Rate: $%.2f/1M in, $%.2f/1M out", - pricing.input or 0, - pricing.output or 0 - ) - table.insert(lines, price_info) - end - end - else - table.insert(lines, " No usage recorded.") - end + -- Show pricing info for paid models + if pricing and not is_free then + local price_info = + string.format(" Rate: $%.2f/1M in, $%.2f/1M out", pricing.input or 0, pricing.output or 0) + table.insert(lines, price_info) + end + end + else + table.insert(lines, " No usage recorded.") + end - return lines + return lines end --- Generate window content ---@return string[] Lines for the buffer local function generate_content() - local session_stats = M.get_stats() - local all_time_stats = M.get_all_time_stats() - local lines = {} + local session_stats = M.get_stats() + local all_time_stats = M.get_all_time_stats() + local lines = {} - -- Header - table.insert(lines, "╔══════════════════════════════════════════════════════╗") - table.insert(lines, "║ 💰 LLM Cost Estimation ║") - table.insert(lines, "╠══════════════════════════════════════════════════════╣") - table.insert(lines, "") + -- Header + table.insert( + lines, + "╔══════════════════════════════════════════════════════╗" + ) + table.insert(lines, "║ 💰 LLM Cost Estimation ║") + table.insert( + lines, + "╠══════════════════════════════════════════════════════╣" + ) + table.insert(lines, "") - -- All-time summary (prominent) - table.insert(lines, "🌐 All-Time Summary (Project)") - table.insert(lines, "───────────────────────────────────────────────────────") - if all_time_stats.time_span > 0 then - table.insert(lines, string.format(" Time span: %s", format_duration(all_time_stats.time_span))) - end - table.insert(lines, string.format(" Requests: %d total", all_time_stats.request_count)) - table.insert(lines, string.format(" Local/Free: %d requests", all_time_stats.free_requests or 0)) - table.insert(lines, string.format(" Paid API: %d requests", all_time_stats.paid_requests or 0)) - table.insert(lines, string.format(" Input tokens: %s", format_tokens(all_time_stats.total_input))) - table.insert(lines, string.format(" Output tokens: %s", format_tokens(all_time_stats.total_output))) - if all_time_stats.total_cached > 0 then - table.insert(lines, string.format(" Cached tokens: %s", format_tokens(all_time_stats.total_cached))) - end - table.insert(lines, "") - table.insert(lines, string.format(" 💵 Total Cost: %s", format_cost(all_time_stats.total_cost))) + -- All-time summary (prominent) + table.insert(lines, "🌐 All-Time Summary (Project)") + table.insert( + lines, + "───────────────────────────────────────────────────────" + ) + if all_time_stats.time_span > 0 then + table.insert(lines, string.format(" Time span: %s", format_duration(all_time_stats.time_span))) + end + table.insert(lines, string.format(" Requests: %d total", all_time_stats.request_count)) + table.insert(lines, string.format(" Local/Free: %d requests", all_time_stats.free_requests or 0)) + table.insert(lines, string.format(" Paid API: %d requests", all_time_stats.paid_requests or 0)) + table.insert(lines, string.format(" Input tokens: %s", format_tokens(all_time_stats.total_input))) + table.insert(lines, string.format(" Output tokens: %s", format_tokens(all_time_stats.total_output))) + if all_time_stats.total_cached > 0 then + table.insert(lines, string.format(" Cached tokens: %s", format_tokens(all_time_stats.total_cached))) + end + table.insert(lines, "") + table.insert(lines, string.format(" 💵 Total Cost: %s", format_cost(all_time_stats.total_cost))) - -- Show savings prominently if there are any - if all_time_stats.total_savings and all_time_stats.total_savings > 0 then - table.insert(lines, string.format(" 💚 Saved: %s (vs %s)", format_cost(all_time_stats.total_savings), M.comparison_model)) - end - table.insert(lines, "") + -- Show savings prominently if there are any + if all_time_stats.total_savings and all_time_stats.total_savings > 0 then + table.insert( + lines, + string.format(" 💚 Saved: %s (vs %s)", format_cost(all_time_stats.total_savings), M.comparison_model) + ) + end + table.insert(lines, "") - -- Session summary - table.insert(lines, "📊 Current Session") - table.insert(lines, "───────────────────────────────────────────────────────") - table.insert(lines, string.format(" Duration: %s", format_duration(session_stats.session_duration))) - table.insert(lines, string.format(" Requests: %d (%d free, %d paid)", - session_stats.request_count, - session_stats.free_requests or 0, - session_stats.paid_requests or 0)) - table.insert(lines, string.format(" Input tokens: %s", format_tokens(session_stats.total_input))) - table.insert(lines, string.format(" Output tokens: %s", format_tokens(session_stats.total_output))) - if session_stats.total_cached > 0 then - table.insert(lines, string.format(" Cached tokens: %s", format_tokens(session_stats.total_cached))) - end - table.insert(lines, string.format(" Session Cost: %s", format_cost(session_stats.total_cost))) - if session_stats.total_savings and session_stats.total_savings > 0 then - table.insert(lines, string.format(" Session Saved: %s", format_cost(session_stats.total_savings))) - end - table.insert(lines, "") + -- Session summary + table.insert(lines, "📊 Current Session") + table.insert( + lines, + "───────────────────────────────────────────────────────" + ) + table.insert(lines, string.format(" Duration: %s", format_duration(session_stats.session_duration))) + table.insert( + lines, + string.format( + " Requests: %d (%d free, %d paid)", + session_stats.request_count, + session_stats.free_requests or 0, + session_stats.paid_requests or 0 + ) + ) + table.insert(lines, string.format(" Input tokens: %s", format_tokens(session_stats.total_input))) + table.insert(lines, string.format(" Output tokens: %s", format_tokens(session_stats.total_output))) + if session_stats.total_cached > 0 then + table.insert(lines, string.format(" Cached tokens: %s", format_tokens(session_stats.total_cached))) + end + table.insert(lines, string.format(" Session Cost: %s", format_cost(session_stats.total_cost))) + if session_stats.total_savings and session_stats.total_savings > 0 then + table.insert(lines, string.format(" Session Saved: %s", format_cost(session_stats.total_savings))) + end + table.insert(lines, "") - -- Per-model breakdown (all-time) - table.insert(lines, "📈 Cost by Model (All-Time)") - table.insert(lines, "───────────────────────────────────────────────────────") - local model_lines = generate_model_breakdown(all_time_stats) - for _, line in ipairs(model_lines) do - table.insert(lines, line) - end + -- Per-model breakdown (all-time) + table.insert(lines, "📈 Cost by Model (All-Time)") + table.insert( + lines, + "───────────────────────────────────────────────────────" + ) + local model_lines = generate_model_breakdown(all_time_stats) + for _, line in ipairs(model_lines) do + table.insert(lines, line) + end - table.insert(lines, "") - table.insert(lines, "───────────────────────────────────────────────────────") - table.insert(lines, " 'q' close | 'r' refresh | 'c' clear session | 'C' all") - table.insert(lines, "╚══════════════════════════════════════════════════════╝") + table.insert(lines, "") + table.insert( + lines, + "───────────────────────────────────────────────────────" + ) + table.insert(lines, " 'q' close | 'r' refresh | 'c' clear session | 'C' all") + table.insert( + lines, + "╚══════════════════════════════════════════════════════╝" + ) - return lines + return lines end --- Refresh the cost window content function M.refresh_window() - if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then - return - end + if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then + return + end - local lines = generate_content() + local lines = generate_content() - vim.bo[state.buf].modifiable = true - vim.api.nvim_buf_set_lines(state.buf, 0, -1, false, lines) - vim.bo[state.buf].modifiable = false + vim.bo[state.buf].modifiable = true + vim.api.nvim_buf_set_lines(state.buf, 0, -1, false, lines) + vim.bo[state.buf].modifiable = false end --- Open the cost estimation window function M.open() - -- Load historical data if not loaded - M.load_from_history() + -- Load historical data if not loaded + M.load_from_history() - -- Close existing window if open - if state.win and vim.api.nvim_win_is_valid(state.win) then - vim.api.nvim_win_close(state.win, true) - end + -- Close existing window if open + if state.win and vim.api.nvim_win_is_valid(state.win) then + vim.api.nvim_win_close(state.win, true) + end - -- Create buffer - state.buf = vim.api.nvim_create_buf(false, true) - vim.bo[state.buf].buftype = "nofile" - vim.bo[state.buf].bufhidden = "wipe" - vim.bo[state.buf].swapfile = false - vim.bo[state.buf].filetype = "codetyper-cost" + -- Create buffer + state.buf = vim.api.nvim_create_buf(false, true) + vim.bo[state.buf].buftype = "nofile" + vim.bo[state.buf].bufhidden = "wipe" + vim.bo[state.buf].swapfile = false + vim.bo[state.buf].filetype = "codetyper-cost" - -- Calculate window size - local width = 58 - local height = 40 - local row = math.floor((vim.o.lines - height) / 2) - local col = math.floor((vim.o.columns - width) / 2) + -- Calculate window size + local width = 58 + local height = 40 + local row = math.floor((vim.o.lines - height) / 2) + local col = math.floor((vim.o.columns - width) / 2) - -- Create floating window - state.win = vim.api.nvim_open_win(state.buf, true, { - relative = "editor", - width = width, - height = height, - row = row, - col = col, - style = "minimal", - border = "rounded", - title = " Cost Estimation ", - title_pos = "center", - }) + -- Create floating window + state.win = vim.api.nvim_open_win(state.buf, true, { + relative = "editor", + width = width, + height = height, + row = row, + col = col, + style = "minimal", + border = "rounded", + title = " Cost Estimation ", + title_pos = "center", + }) - -- Set window options - vim.wo[state.win].wrap = false - vim.wo[state.win].cursorline = false + -- Set window options + vim.wo[state.win].wrap = false + vim.wo[state.win].cursorline = false - -- Populate content - M.refresh_window() + -- Populate content + M.refresh_window() - -- Set up keymaps - local opts = { buffer = state.buf, silent = true } - vim.keymap.set("n", "q", function() - M.close() - end, opts) - vim.keymap.set("n", "", function() - M.close() - end, opts) - vim.keymap.set("n", "r", function() - M.refresh_window() - end, opts) - vim.keymap.set("n", "c", function() - M.clear_session() - M.refresh_window() - end, opts) - vim.keymap.set("n", "C", function() - M.clear_all() - M.refresh_window() - end, opts) + -- Set up keymaps + local opts = { buffer = state.buf, silent = true } + vim.keymap.set("n", "q", function() + M.close() + end, opts) + vim.keymap.set("n", "", function() + M.close() + end, opts) + vim.keymap.set("n", "r", function() + M.refresh_window() + end, opts) + vim.keymap.set("n", "c", function() + M.clear_session() + M.refresh_window() + end, opts) + vim.keymap.set("n", "C", function() + M.clear_all() + M.refresh_window() + end, opts) - -- Set up highlights - vim.api.nvim_buf_call(state.buf, function() - vim.fn.matchadd("Title", "LLM Cost Estimation") - vim.fn.matchadd("Number", "\\$[0-9.]*") - vim.fn.matchadd("Keyword", "[0-9.]*[KM]\\? tokens") - vim.fn.matchadd("Special", "🤖\\|💰\\|📊\\|📈\\|💵") - end) + -- Set up highlights + vim.api.nvim_buf_call(state.buf, function() + vim.fn.matchadd("Title", "LLM Cost Estimation") + vim.fn.matchadd("Number", "\\$[0-9.]*") + vim.fn.matchadd("Keyword", "[0-9.]*[KM]\\? tokens") + vim.fn.matchadd("Special", "🤖\\|💰\\|📊\\|📈\\|💵") + end) end --- Close the cost window function M.close() - if state.win and vim.api.nvim_win_is_valid(state.win) then - vim.api.nvim_win_close(state.win, true) - end - state.win = nil - state.buf = nil + if state.win and vim.api.nvim_win_is_valid(state.win) then + vim.api.nvim_win_close(state.win, true) + end + state.win = nil + state.buf = nil end --- Toggle the cost window function M.toggle() - if state.win and vim.api.nvim_win_is_valid(state.win) then - M.close() - else - M.open() - end + if state.win and vim.api.nvim_win_is_valid(state.win) then + M.close() + else + M.open() + end end --- Clear session usage (not history) function M.clear_session() - state.usage = {} - state.session_start = os.time() - utils.notify("Session cost tracking cleared", vim.log.levels.INFO) + state.usage = {} + state.session_start = os.time() + utils.notify("Session cost tracking cleared", vim.log.levels.INFO) end --- Clear all history (session + saved) function M.clear_all() - state.usage = {} - state.all_usage = {} - state.session_start = os.time() - state.loaded = false + state.usage = {} + state.all_usage = {} + state.session_start = os.time() + state.loaded = false - -- Delete history file - local history_path = get_history_path() - local ok, err = os.remove(history_path) - if not ok and err and not err:match("No such file") then - utils.notify("Failed to delete history: " .. err, vim.log.levels.WARN) - end + -- Delete history file + local history_path = get_history_path() + local ok, err = os.remove(history_path) + if not ok and err and not err:match("No such file") then + utils.notify("Failed to delete history: " .. err, vim.log.levels.WARN) + end - utils.notify("All cost history cleared", vim.log.levels.INFO) + utils.notify("All cost history cleared", vim.log.levels.INFO) end --- Clear usage history (alias for clear_session) function M.clear() - M.clear_session() + M.clear_session() end --- Reset session function M.reset() - M.clear_session() + M.clear_session() end return M diff --git a/lua/codetyper/core/diff/conflict.lua b/lua/codetyper/core/diff/conflict.lua index 7465ad8..a85d873 100644 --- a/lua/codetyper/core/diff/conflict.lua +++ b/lua/codetyper/core/diff/conflict.lua @@ -19,7 +19,7 @@ local params = require("codetyper.params.agents.conflict") --- Lazy load linter module local function get_linter() - return require("codetyper.features.agents.linter") + return require("codetyper.features.agents.linter") end --- Configuration @@ -46,237 +46,237 @@ local conflict_buffers = {} ---@param end_line number End line of changed region ---@param accepted_type string Type of acceptance ("theirs", "both") local function validate_after_accept(bufnr, start_line, end_line, accepted_type) - if not config.lint_after_accept then - return - end + if not config.lint_after_accept then + return + end - -- Only validate when accepting AI suggestions - if accepted_type ~= "theirs" and accepted_type ~= "both" then - return - end + -- Only validate when accepting AI suggestions + if accepted_type ~= "theirs" and accepted_type ~= "both" then + return + end - local linter = get_linter() + local linter = get_linter() - -- Validate the changed region - linter.validate_after_injection(bufnr, start_line, end_line, function(result) - if not result then - return - end + -- Validate the changed region + linter.validate_after_injection(bufnr, start_line, end_line, function(result) + if not result then + return + end - -- If errors found and auto-fix is enabled, queue fix automatically - if result.has_errors and config.auto_fix_lint_errors then - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = "Auto-queuing fix for lint errors...", - }) - end) - linter.request_ai_fix(bufnr, result) - end - end) + -- If errors found and auto-fix is enabled, queue fix automatically + if result.has_errors and config.auto_fix_lint_errors then + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = "Auto-queuing fix for lint errors...", + }) + end) + linter.request_ai_fix(bufnr, result) + end + end) end --- Configure conflict behavior ---@param opts table Configuration options function M.configure(opts) - for k, v in pairs(opts) do - if config[k] ~= nil then - config[k] = v - end - end + for k, v in pairs(opts) do + if config[k] ~= nil then + config[k] = v + end + end end --- Get current configuration ---@return table function M.get_config() - return vim.deepcopy(config) + return vim.deepcopy(config) end --- Auto-show menu for next conflict if enabled and conflicts remain ---@param bufnr number Buffer number local function auto_show_next_conflict_menu(bufnr) - if not config.auto_show_next_menu then - return - end + if not config.auto_show_next_menu then + return + end - vim.schedule(function() - if not vim.api.nvim_buf_is_valid(bufnr) then - return - end + vim.schedule(function() + if not vim.api.nvim_buf_is_valid(bufnr) then + return + end - local conflicts = M.detect_conflicts(bufnr) - if #conflicts > 0 then - -- Jump to first remaining conflict and show menu - local conflict = conflicts[1] - local win = vim.api.nvim_get_current_win() - if vim.api.nvim_win_get_buf(win) == bufnr then - vim.api.nvim_win_set_cursor(win, { conflict.start_line, 0 }) - vim.cmd("normal! zz") - M.show_floating_menu(bufnr) - end - end - end) + local conflicts = M.detect_conflicts(bufnr) + if #conflicts > 0 then + -- Jump to first remaining conflict and show menu + local conflict = conflicts[1] + local win = vim.api.nvim_get_current_win() + if vim.api.nvim_win_get_buf(win) == bufnr then + vim.api.nvim_win_set_cursor(win, { conflict.start_line, 0 }) + vim.cmd("normal! zz") + M.show_floating_menu(bufnr) + end + end + end) end --- Setup highlight groups local function setup_highlights() - -- Current (original) code - green tint - vim.api.nvim_set_hl(0, HL_GROUPS.current, { - bg = "#2d4a3e", - default = true, - }) - vim.api.nvim_set_hl(0, HL_GROUPS.current_label, { - fg = "#98c379", - bg = "#2d4a3e", - bold = true, - default = true, - }) + -- Current (original) code - green tint + vim.api.nvim_set_hl(0, HL_GROUPS.current, { + bg = "#2d4a3e", + default = true, + }) + vim.api.nvim_set_hl(0, HL_GROUPS.current_label, { + fg = "#98c379", + bg = "#2d4a3e", + bold = true, + default = true, + }) - -- Incoming (AI suggestion) code - blue tint - vim.api.nvim_set_hl(0, HL_GROUPS.incoming, { - bg = "#2d3a4a", - default = true, - }) - vim.api.nvim_set_hl(0, HL_GROUPS.incoming_label, { - fg = "#61afef", - bg = "#2d3a4a", - bold = true, - default = true, - }) + -- Incoming (AI suggestion) code - blue tint + vim.api.nvim_set_hl(0, HL_GROUPS.incoming, { + bg = "#2d3a4a", + default = true, + }) + vim.api.nvim_set_hl(0, HL_GROUPS.incoming_label, { + fg = "#61afef", + bg = "#2d3a4a", + bold = true, + default = true, + }) - -- Separator line - vim.api.nvim_set_hl(0, HL_GROUPS.separator, { - fg = "#5c6370", - bg = "#3e4451", - bold = true, - default = true, - }) + -- Separator line + vim.api.nvim_set_hl(0, HL_GROUPS.separator, { + fg = "#5c6370", + bg = "#3e4451", + bold = true, + default = true, + }) - -- Keybinding hints - vim.api.nvim_set_hl(0, HL_GROUPS.hint, { - fg = "#5c6370", - italic = true, - default = true, - }) + -- Keybinding hints + vim.api.nvim_set_hl(0, HL_GROUPS.hint, { + fg = "#5c6370", + italic = true, + default = true, + }) end --- Parse a buffer and find all conflict regions ---@param bufnr number Buffer number ---@return table[] conflicts List of conflict positions function M.detect_conflicts(bufnr) - if not vim.api.nvim_buf_is_valid(bufnr) then - return {} - end + if not vim.api.nvim_buf_is_valid(bufnr) then + return {} + end - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local conflicts = {} - local current_conflict = nil + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local conflicts = {} + local current_conflict = nil - for i, line in ipairs(lines) do - if line:match("^<<<<<<<") then - current_conflict = { - start_line = i, - current_start = i, - current_end = nil, - separator = nil, - incoming_start = nil, - incoming_end = nil, - end_line = nil, - } - elseif line:match("^=======") and current_conflict then - current_conflict.current_end = i - 1 - current_conflict.separator = i - current_conflict.incoming_start = i + 1 - elseif line:match("^>>>>>>>") and current_conflict then - current_conflict.incoming_end = i - 1 - current_conflict.end_line = i - table.insert(conflicts, current_conflict) - current_conflict = nil - end - end + for i, line in ipairs(lines) do + if line:match("^<<<<<<<") then + current_conflict = { + start_line = i, + current_start = i, + current_end = nil, + separator = nil, + incoming_start = nil, + incoming_end = nil, + end_line = nil, + } + elseif line:match("^=======") and current_conflict then + current_conflict.current_end = i - 1 + current_conflict.separator = i + current_conflict.incoming_start = i + 1 + elseif line:match("^>>>>>>>") and current_conflict then + current_conflict.incoming_end = i - 1 + current_conflict.end_line = i + table.insert(conflicts, current_conflict) + current_conflict = nil + end + end - return conflicts + return conflicts end --- Highlight conflicts in buffer using extmarks ---@param bufnr number Buffer number ---@param conflicts table[] Conflict positions function M.highlight_conflicts(bufnr, conflicts) - if not vim.api.nvim_buf_is_valid(bufnr) then - return - end + if not vim.api.nvim_buf_is_valid(bufnr) then + return + end - -- Clear existing highlights - vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1) - vim.api.nvim_buf_clear_namespace(bufnr, HINT_NAMESPACE, 0, -1) + -- Clear existing highlights + vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1) + vim.api.nvim_buf_clear_namespace(bufnr, HINT_NAMESPACE, 0, -1) - for _, conflict in ipairs(conflicts) do - -- Highlight <<<<<<< CURRENT line - vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, conflict.start_line - 1, 0, { - end_row = conflict.start_line - 1, - end_col = 0, - line_hl_group = HL_GROUPS.current_label, - priority = 100, - }) + for _, conflict in ipairs(conflicts) do + -- Highlight <<<<<<< CURRENT line + vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, conflict.start_line - 1, 0, { + end_row = conflict.start_line - 1, + end_col = 0, + line_hl_group = HL_GROUPS.current_label, + priority = 100, + }) - -- Highlight current (original) code section - if conflict.current_start and conflict.current_end then - for row = conflict.current_start, conflict.current_end do - if row <= conflict.current_end then - vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, row - 1, 0, { - end_row = row - 1, - end_col = 0, - line_hl_group = HL_GROUPS.current, - priority = 90, - }) - end - end - end + -- Highlight current (original) code section + if conflict.current_start and conflict.current_end then + for row = conflict.current_start, conflict.current_end do + if row <= conflict.current_end then + vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, row - 1, 0, { + end_row = row - 1, + end_col = 0, + line_hl_group = HL_GROUPS.current, + priority = 90, + }) + end + end + end - -- Highlight ======= separator - if conflict.separator then - vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, conflict.separator - 1, 0, { - end_row = conflict.separator - 1, - end_col = 0, - line_hl_group = HL_GROUPS.separator, - priority = 100, - }) - end + -- Highlight ======= separator + if conflict.separator then + vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, conflict.separator - 1, 0, { + end_row = conflict.separator - 1, + end_col = 0, + line_hl_group = HL_GROUPS.separator, + priority = 100, + }) + end - -- Highlight incoming (AI suggestion) code section - if conflict.incoming_start and conflict.incoming_end then - for row = conflict.incoming_start, conflict.incoming_end do - if row <= conflict.incoming_end then - vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, row - 1, 0, { - end_row = row - 1, - end_col = 0, - line_hl_group = HL_GROUPS.incoming, - priority = 90, - }) - end - end - end + -- Highlight incoming (AI suggestion) code section + if conflict.incoming_start and conflict.incoming_end then + for row = conflict.incoming_start, conflict.incoming_end do + if row <= conflict.incoming_end then + vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, row - 1, 0, { + end_row = row - 1, + end_col = 0, + line_hl_group = HL_GROUPS.incoming, + priority = 90, + }) + end + end + end - -- Highlight >>>>>>> INCOMING line - if conflict.end_line then - vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, conflict.end_line - 1, 0, { - end_row = conflict.end_line - 1, - end_col = 0, - line_hl_group = HL_GROUPS.incoming_label, - priority = 100, - }) - end + -- Highlight >>>>>>> INCOMING line + if conflict.end_line then + vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, conflict.end_line - 1, 0, { + end_row = conflict.end_line - 1, + end_col = 0, + line_hl_group = HL_GROUPS.incoming_label, + priority = 100, + }) + end - -- Add virtual text hint on the <<<<<<< line - vim.api.nvim_buf_set_extmark(bufnr, HINT_NAMESPACE, conflict.start_line - 1, 0, { - virt_text = { - { " [co]=ours [ct]=theirs [cb]=both [cn]=none [x/]x=nav", HL_GROUPS.hint }, - }, - virt_text_pos = "eol", - priority = 50, - }) - end + -- Add virtual text hint on the <<<<<<< line + vim.api.nvim_buf_set_extmark(bufnr, HINT_NAMESPACE, conflict.start_line - 1, 0, { + virt_text = { + { " [co]=ours [ct]=theirs [cb]=both [cn]=none [x/]x=nav", HL_GROUPS.hint }, + }, + virt_text_pos = "eol", + priority = 50, + }) + end end --- Get the conflict at the current cursor position @@ -284,565 +284,585 @@ end ---@param cursor_line number Current line (1-indexed) ---@return table|nil conflict The conflict at cursor, or nil function M.get_conflict_at_cursor(bufnr, cursor_line) - local conflicts = M.detect_conflicts(bufnr) + local conflicts = M.detect_conflicts(bufnr) - for _, conflict in ipairs(conflicts) do - if cursor_line >= conflict.start_line and cursor_line <= conflict.end_line then - return conflict - end - end + for _, conflict in ipairs(conflicts) do + if cursor_line >= conflict.start_line and cursor_line <= conflict.end_line then + return conflict + end + end - return nil + return nil end --- Accept "ours" - keep the original code ---@param bufnr number Buffer number function M.accept_ours(bufnr) - local cursor = vim.api.nvim_win_get_cursor(0) - local conflict = M.get_conflict_at_cursor(bufnr, cursor[1]) + local cursor = vim.api.nvim_win_get_cursor(0) + local conflict = M.get_conflict_at_cursor(bufnr, cursor[1]) - if not conflict then - vim.notify("No conflict at cursor position", vim.log.levels.WARN) - return - end + if not conflict then + vim.notify("No conflict at cursor position", vim.log.levels.WARN) + return + end - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - -- Extract the "current" (original) lines - local keep_lines = {} - if conflict.current_start and conflict.current_end then - for i = conflict.current_start + 1, conflict.current_end do - table.insert(keep_lines, lines[i]) - end - end + -- Extract the "current" (original) lines + local keep_lines = {} + if conflict.current_start and conflict.current_end then + for i = conflict.current_start + 1, conflict.current_end do + table.insert(keep_lines, lines[i]) + end + end - -- Replace the entire conflict region with the kept lines - vim.api.nvim_buf_set_lines(bufnr, conflict.start_line - 1, conflict.end_line, false, keep_lines) + -- Replace the entire conflict region with the kept lines + vim.api.nvim_buf_set_lines(bufnr, conflict.start_line - 1, conflict.end_line, false, keep_lines) - -- Re-process remaining conflicts - M.process(bufnr) + -- Re-process remaining conflicts + M.process(bufnr) - vim.notify("Accepted CURRENT (original) code", vim.log.levels.INFO) + vim.notify("Accepted CURRENT (original) code", vim.log.levels.INFO) - -- Auto-show menu for next conflict if any remain - auto_show_next_conflict_menu(bufnr) + -- Auto-show menu for next conflict if any remain + auto_show_next_conflict_menu(bufnr) end --- Accept "theirs" - use the AI suggestion ---@param bufnr number Buffer number function M.accept_theirs(bufnr) - local cursor = vim.api.nvim_win_get_cursor(0) - local conflict = M.get_conflict_at_cursor(bufnr, cursor[1]) + local cursor = vim.api.nvim_win_get_cursor(0) + local conflict = M.get_conflict_at_cursor(bufnr, cursor[1]) - if not conflict then - vim.notify("No conflict at cursor position", vim.log.levels.WARN) - return - end + if not conflict then + vim.notify("No conflict at cursor position", vim.log.levels.WARN) + return + end - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - -- Extract the "incoming" (AI suggestion) lines - local keep_lines = {} - if conflict.incoming_start and conflict.incoming_end then - for i = conflict.incoming_start, conflict.incoming_end do - table.insert(keep_lines, lines[i]) - end - end + -- Extract the "incoming" (AI suggestion) lines + local keep_lines = {} + if conflict.incoming_start and conflict.incoming_end then + for i = conflict.incoming_start, conflict.incoming_end do + table.insert(keep_lines, lines[i]) + end + end - -- Track where the code will be inserted - local insert_start = conflict.start_line - local insert_end = insert_start + #keep_lines - 1 + -- Track where the code will be inserted + local insert_start = conflict.start_line + local insert_end = insert_start + #keep_lines - 1 - -- Replace the entire conflict region with the kept lines - vim.api.nvim_buf_set_lines(bufnr, conflict.start_line - 1, conflict.end_line, false, keep_lines) + -- Replace the entire conflict region with the kept lines + vim.api.nvim_buf_set_lines(bufnr, conflict.start_line - 1, conflict.end_line, false, keep_lines) - -- Re-process remaining conflicts - M.process(bufnr) + -- Re-process remaining conflicts + M.process(bufnr) - vim.notify("Accepted INCOMING (AI suggestion) code", vim.log.levels.INFO) + vim.notify("Accepted INCOMING (AI suggestion) code", vim.log.levels.INFO) - -- Run linter validation on the accepted code - validate_after_accept(bufnr, insert_start, insert_end, "theirs") + -- Run linter validation on the accepted code + validate_after_accept(bufnr, insert_start, insert_end, "theirs") - -- Auto-show menu for next conflict if any remain - auto_show_next_conflict_menu(bufnr) + -- Auto-show menu for next conflict if any remain + auto_show_next_conflict_menu(bufnr) end --- Accept "both" - keep both versions ---@param bufnr number Buffer number function M.accept_both(bufnr) - local cursor = vim.api.nvim_win_get_cursor(0) - local conflict = M.get_conflict_at_cursor(bufnr, cursor[1]) + local cursor = vim.api.nvim_win_get_cursor(0) + local conflict = M.get_conflict_at_cursor(bufnr, cursor[1]) - if not conflict then - vim.notify("No conflict at cursor position", vim.log.levels.WARN) - return - end + if not conflict then + vim.notify("No conflict at cursor position", vim.log.levels.WARN) + return + end - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - -- Extract both "current" and "incoming" lines - local keep_lines = {} + -- Extract both "current" and "incoming" lines + local keep_lines = {} - -- Add current lines - if conflict.current_start and conflict.current_end then - for i = conflict.current_start + 1, conflict.current_end do - table.insert(keep_lines, lines[i]) - end - end + -- Add current lines + if conflict.current_start and conflict.current_end then + for i = conflict.current_start + 1, conflict.current_end do + table.insert(keep_lines, lines[i]) + end + end - -- Add incoming lines - if conflict.incoming_start and conflict.incoming_end then - for i = conflict.incoming_start, conflict.incoming_end do - table.insert(keep_lines, lines[i]) - end - end + -- Add incoming lines + if conflict.incoming_start and conflict.incoming_end then + for i = conflict.incoming_start, conflict.incoming_end do + table.insert(keep_lines, lines[i]) + end + end - -- Track where the code will be inserted - local insert_start = conflict.start_line - local insert_end = insert_start + #keep_lines - 1 + -- Track where the code will be inserted + local insert_start = conflict.start_line + local insert_end = insert_start + #keep_lines - 1 - -- Replace the entire conflict region with the kept lines - vim.api.nvim_buf_set_lines(bufnr, conflict.start_line - 1, conflict.end_line, false, keep_lines) + -- Replace the entire conflict region with the kept lines + vim.api.nvim_buf_set_lines(bufnr, conflict.start_line - 1, conflict.end_line, false, keep_lines) - -- Re-process remaining conflicts - M.process(bufnr) + -- Re-process remaining conflicts + M.process(bufnr) - vim.notify("Accepted BOTH (current + incoming) code", vim.log.levels.INFO) + vim.notify("Accepted BOTH (current + incoming) code", vim.log.levels.INFO) - -- Run linter validation on the accepted code - validate_after_accept(bufnr, insert_start, insert_end, "both") + -- Run linter validation on the accepted code + validate_after_accept(bufnr, insert_start, insert_end, "both") - -- Auto-show menu for next conflict if any remain - auto_show_next_conflict_menu(bufnr) + -- Auto-show menu for next conflict if any remain + auto_show_next_conflict_menu(bufnr) end --- Accept "none" - delete both versions ---@param bufnr number Buffer number function M.accept_none(bufnr) - local cursor = vim.api.nvim_win_get_cursor(0) - local conflict = M.get_conflict_at_cursor(bufnr, cursor[1]) + local cursor = vim.api.nvim_win_get_cursor(0) + local conflict = M.get_conflict_at_cursor(bufnr, cursor[1]) - if not conflict then - vim.notify("No conflict at cursor position", vim.log.levels.WARN) - return - end + if not conflict then + vim.notify("No conflict at cursor position", vim.log.levels.WARN) + return + end - -- Replace the entire conflict region with nothing - vim.api.nvim_buf_set_lines(bufnr, conflict.start_line - 1, conflict.end_line, false, {}) + -- Replace the entire conflict region with nothing + vim.api.nvim_buf_set_lines(bufnr, conflict.start_line - 1, conflict.end_line, false, {}) - -- Re-process remaining conflicts - M.process(bufnr) + -- Re-process remaining conflicts + M.process(bufnr) - vim.notify("Deleted conflict (accepted NONE)", vim.log.levels.INFO) + vim.notify("Deleted conflict (accepted NONE)", vim.log.levels.INFO) - -- Auto-show menu for next conflict if any remain - auto_show_next_conflict_menu(bufnr) + -- Auto-show menu for next conflict if any remain + auto_show_next_conflict_menu(bufnr) end --- Navigate to the next conflict ---@param bufnr number Buffer number ---@return boolean found Whether a conflict was found function M.goto_next(bufnr) - local cursor = vim.api.nvim_win_get_cursor(0) - local cursor_line = cursor[1] - local conflicts = M.detect_conflicts(bufnr) + local cursor = vim.api.nvim_win_get_cursor(0) + local cursor_line = cursor[1] + local conflicts = M.detect_conflicts(bufnr) - for _, conflict in ipairs(conflicts) do - if conflict.start_line > cursor_line then - vim.api.nvim_win_set_cursor(0, { conflict.start_line, 0 }) - vim.cmd("normal! zz") - return true - end - end + for _, conflict in ipairs(conflicts) do + if conflict.start_line > cursor_line then + vim.api.nvim_win_set_cursor(0, { conflict.start_line, 0 }) + vim.cmd("normal! zz") + return true + end + end - -- Wrap around to first conflict - if #conflicts > 0 then - vim.api.nvim_win_set_cursor(0, { conflicts[1].start_line, 0 }) - vim.cmd("normal! zz") - vim.notify("Wrapped to first conflict", vim.log.levels.INFO) - return true - end + -- Wrap around to first conflict + if #conflicts > 0 then + vim.api.nvim_win_set_cursor(0, { conflicts[1].start_line, 0 }) + vim.cmd("normal! zz") + vim.notify("Wrapped to first conflict", vim.log.levels.INFO) + return true + end - vim.notify("No more conflicts", vim.log.levels.INFO) - return false + vim.notify("No more conflicts", vim.log.levels.INFO) + return false end --- Navigate to the previous conflict ---@param bufnr number Buffer number ---@return boolean found Whether a conflict was found function M.goto_prev(bufnr) - local cursor = vim.api.nvim_win_get_cursor(0) - local cursor_line = cursor[1] - local conflicts = M.detect_conflicts(bufnr) + local cursor = vim.api.nvim_win_get_cursor(0) + local cursor_line = cursor[1] + local conflicts = M.detect_conflicts(bufnr) - for i = #conflicts, 1, -1 do - local conflict = conflicts[i] - if conflict.start_line < cursor_line then - vim.api.nvim_win_set_cursor(0, { conflict.start_line, 0 }) - vim.cmd("normal! zz") - return true - end - end + for i = #conflicts, 1, -1 do + local conflict = conflicts[i] + if conflict.start_line < cursor_line then + vim.api.nvim_win_set_cursor(0, { conflict.start_line, 0 }) + vim.cmd("normal! zz") + return true + end + end - -- Wrap around to last conflict - if #conflicts > 0 then - vim.api.nvim_win_set_cursor(0, { conflicts[#conflicts].start_line, 0 }) - vim.cmd("normal! zz") - vim.notify("Wrapped to last conflict", vim.log.levels.INFO) - return true - end + -- Wrap around to last conflict + if #conflicts > 0 then + vim.api.nvim_win_set_cursor(0, { conflicts[#conflicts].start_line, 0 }) + vim.cmd("normal! zz") + vim.notify("Wrapped to last conflict", vim.log.levels.INFO) + return true + end - vim.notify("No more conflicts", vim.log.levels.INFO) - return false + vim.notify("No more conflicts", vim.log.levels.INFO) + return false end --- Show conflict resolution menu modal ---@param bufnr number Buffer number function M.show_menu(bufnr) - local cursor = vim.api.nvim_win_get_cursor(0) - local conflict = M.get_conflict_at_cursor(bufnr, cursor[1]) + local cursor = vim.api.nvim_win_get_cursor(0) + local conflict = M.get_conflict_at_cursor(bufnr, cursor[1]) - if not conflict then - vim.notify("No conflict at cursor position", vim.log.levels.WARN) - return - end + if not conflict then + vim.notify("No conflict at cursor position", vim.log.levels.WARN) + return + end - -- Get preview of both versions - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + -- Get preview of both versions + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local current_preview = "" - if conflict.current_start and conflict.current_end then - local current_lines = {} - for i = conflict.current_start + 1, math.min(conflict.current_end, conflict.current_start + 3) do - if lines[i] then - table.insert(current_lines, " " .. lines[i]:sub(1, 50)) - end - end - if conflict.current_end - conflict.current_start > 3 then - table.insert(current_lines, " ...") - end - current_preview = table.concat(current_lines, "\n") - end + local current_preview = "" + if conflict.current_start and conflict.current_end then + local current_lines = {} + for i = conflict.current_start + 1, math.min(conflict.current_end, conflict.current_start + 3) do + if lines[i] then + table.insert(current_lines, " " .. lines[i]:sub(1, 50)) + end + end + if conflict.current_end - conflict.current_start > 3 then + table.insert(current_lines, " ...") + end + current_preview = table.concat(current_lines, "\n") + end - local incoming_preview = "" - if conflict.incoming_start and conflict.incoming_end then - local incoming_lines = {} - for i = conflict.incoming_start, math.min(conflict.incoming_end, conflict.incoming_start + 2) do - if lines[i] then - table.insert(incoming_lines, " " .. lines[i]:sub(1, 50)) - end - end - if conflict.incoming_end - conflict.incoming_start > 3 then - table.insert(incoming_lines, " ...") - end - incoming_preview = table.concat(incoming_lines, "\n") - end + local incoming_preview = "" + if conflict.incoming_start and conflict.incoming_end then + local incoming_lines = {} + for i = conflict.incoming_start, math.min(conflict.incoming_end, conflict.incoming_start + 2) do + if lines[i] then + table.insert(incoming_lines, " " .. lines[i]:sub(1, 50)) + end + end + if conflict.incoming_end - conflict.incoming_start > 3 then + table.insert(incoming_lines, " ...") + end + incoming_preview = table.concat(incoming_lines, "\n") + end - -- Count lines in each section - local current_count = conflict.current_end and conflict.current_start - and (conflict.current_end - conflict.current_start) or 0 - local incoming_count = conflict.incoming_end and conflict.incoming_start - and (conflict.incoming_end - conflict.incoming_start + 1) or 0 + -- Count lines in each section + local current_count = conflict.current_end + and conflict.current_start + and (conflict.current_end - conflict.current_start) + or 0 + local incoming_count = conflict.incoming_end + and conflict.incoming_start + and (conflict.incoming_end - conflict.incoming_start + 1) + or 0 - -- Build menu options - local options = { - { - label = string.format("Accept CURRENT (original) - %d lines", current_count), - key = "co", - action = function() M.accept_ours(bufnr) end, - preview = current_preview, - }, - { - label = string.format("Accept INCOMING (AI suggestion) - %d lines", incoming_count), - key = "ct", - action = function() M.accept_theirs(bufnr) end, - preview = incoming_preview, - }, - { - label = string.format("Accept BOTH versions - %d lines total", current_count + incoming_count), - key = "cb", - action = function() M.accept_both(bufnr) end, - }, - { - label = "Delete conflict (accept NONE)", - key = "cn", - action = function() M.accept_none(bufnr) end, - }, - { - label = "─────────────────────────", - key = "", - action = nil, - separator = true, - }, - { - label = "Next conflict", - key = "]x", - action = function() M.goto_next(bufnr) end, - }, - { - label = "Previous conflict", - key = "[x", - action = function() M.goto_prev(bufnr) end, - }, - } + -- Build menu options + local options = { + { + label = string.format("Accept CURRENT (original) - %d lines", current_count), + key = "co", + action = function() + M.accept_ours(bufnr) + end, + preview = current_preview, + }, + { + label = string.format("Accept INCOMING (AI suggestion) - %d lines", incoming_count), + key = "ct", + action = function() + M.accept_theirs(bufnr) + end, + preview = incoming_preview, + }, + { + label = string.format("Accept BOTH versions - %d lines total", current_count + incoming_count), + key = "cb", + action = function() + M.accept_both(bufnr) + end, + }, + { + label = "Delete conflict (accept NONE)", + key = "cn", + action = function() + M.accept_none(bufnr) + end, + }, + { + label = "─────────────────────────", + key = "", + action = nil, + separator = true, + }, + { + label = "Next conflict", + key = "]x", + action = function() + M.goto_next(bufnr) + end, + }, + { + label = "Previous conflict", + key = "[x", + action = function() + M.goto_prev(bufnr) + end, + }, + } - -- Build display labels - local labels = {} - for _, opt in ipairs(options) do - if opt.separator then - table.insert(labels, opt.label) - else - table.insert(labels, string.format("[%s] %s", opt.key, opt.label)) - end - end + -- Build display labels + local labels = {} + for _, opt in ipairs(options) do + if opt.separator then + table.insert(labels, opt.label) + else + table.insert(labels, string.format("[%s] %s", opt.key, opt.label)) + end + end - -- Show menu using vim.ui.select - vim.ui.select(labels, { - prompt = "Resolve Conflict:", - format_item = function(item) - return item - end, - }, function(choice, idx) - if not choice or not idx then - return - end + -- Show menu using vim.ui.select + vim.ui.select(labels, { + prompt = "Resolve Conflict:", + format_item = function(item) + return item + end, + }, function(choice, idx) + if not choice or not idx then + return + end - local selected = options[idx] - if selected and selected.action then - selected.action() - end - end) + local selected = options[idx] + if selected and selected.action then + selected.action() + end + end) end --- Show floating window menu for conflict resolution ---@param bufnr number Buffer number function M.show_floating_menu(bufnr) - local cursor = vim.api.nvim_win_get_cursor(0) - local conflict = M.get_conflict_at_cursor(bufnr, cursor[1]) + local cursor = vim.api.nvim_win_get_cursor(0) + local conflict = M.get_conflict_at_cursor(bufnr, cursor[1]) - if not conflict then - vim.notify("No conflict at cursor position", vim.log.levels.WARN) - return - end + if not conflict then + vim.notify("No conflict at cursor position", vim.log.levels.WARN) + return + end - -- Get lines for preview - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + -- Get lines for preview + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - -- Count lines - local current_count = conflict.current_end and conflict.current_start - and (conflict.current_end - conflict.current_start) or 0 - local incoming_count = conflict.incoming_end and conflict.incoming_start - and (conflict.incoming_end - conflict.incoming_start + 1) or 0 + -- Count lines + local current_count = conflict.current_end + and conflict.current_start + and (conflict.current_end - conflict.current_start) + or 0 + local incoming_count = conflict.incoming_end + and conflict.incoming_start + and (conflict.incoming_end - conflict.incoming_start + 1) + or 0 - -- Build menu content - local menu_lines = { - "╭─────────────────────────────────────────╮", - "│ Resolve Conflict │", - "├─────────────────────────────────────────┤", - string.format("│ [co] Accept CURRENT (original) %3d lines│", current_count), - string.format("│ [ct] Accept INCOMING (AI) %3d lines│", incoming_count), - string.format("│ [cb] Accept BOTH %3d lines│", current_count + incoming_count), - "│ [cn] Delete conflict (NONE) │", - "├─────────────────────────────────────────┤", - "│ []x] Next conflict │", - "│ [[x] Previous conflict │", - "│ [q] Close menu │", - "╰─────────────────────────────────────────╯", - } + -- Build menu content + local menu_lines = { + "╭─────────────────────────────────────────╮", + "│ Resolve Conflict │", + "├─────────────────────────────────────────┤", + string.format("│ [co] Accept CURRENT (original) %3d lines│", current_count), + string.format("│ [ct] Accept INCOMING (AI) %3d lines│", incoming_count), + string.format("│ [cb] Accept BOTH %3d lines│", current_count + incoming_count), + "│ [cn] Delete conflict (NONE) │", + "├─────────────────────────────────────────┤", + "│ []x] Next conflict │", + "│ [[x] Previous conflict │", + "│ [q] Close menu │", + "╰─────────────────────────────────────────╯", + } - -- Create floating window - local width = 43 - local height = #menu_lines + -- Create floating window + local width = 43 + local height = #menu_lines - local float_opts = { - relative = "cursor", - row = 1, - col = 0, - width = width, - height = height, - style = "minimal", - border = "none", - focusable = true, - } + local float_opts = { + relative = "cursor", + row = 1, + col = 0, + width = width, + height = height, + style = "minimal", + border = "none", + focusable = true, + } - -- Create buffer for menu - local menu_bufnr = vim.api.nvim_create_buf(false, true) - vim.api.nvim_buf_set_lines(menu_bufnr, 0, -1, false, menu_lines) - vim.bo[menu_bufnr].modifiable = false - vim.bo[menu_bufnr].bufhidden = "wipe" + -- Create buffer for menu + local menu_bufnr = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_lines(menu_bufnr, 0, -1, false, menu_lines) + vim.bo[menu_bufnr].modifiable = false + vim.bo[menu_bufnr].bufhidden = "wipe" - -- Open floating window - local win = vim.api.nvim_open_win(menu_bufnr, true, float_opts) + -- Open floating window + local win = vim.api.nvim_open_win(menu_bufnr, true, float_opts) - -- Set highlights - vim.api.nvim_set_hl(0, "CoderConflictMenuBorder", { fg = "#61afef", default = true }) - vim.api.nvim_set_hl(0, "CoderConflictMenuTitle", { fg = "#e5c07b", bold = true, default = true }) - vim.api.nvim_set_hl(0, "CoderConflictMenuKey", { fg = "#98c379", bold = true, default = true }) + -- Set highlights + vim.api.nvim_set_hl(0, "CoderConflictMenuBorder", { fg = "#61afef", default = true }) + vim.api.nvim_set_hl(0, "CoderConflictMenuTitle", { fg = "#e5c07b", bold = true, default = true }) + vim.api.nvim_set_hl(0, "CoderConflictMenuKey", { fg = "#98c379", bold = true, default = true }) - vim.wo[win].winhl = "Normal:Normal,FloatBorder:CoderConflictMenuBorder" + vim.wo[win].winhl = "Normal:Normal,FloatBorder:CoderConflictMenuBorder" - -- Add syntax highlighting to menu buffer - vim.api.nvim_buf_add_highlight(menu_bufnr, -1, "CoderConflictMenuTitle", 1, 0, -1) - for i = 3, 9 do - -- Highlight the key in brackets - local line = menu_lines[i + 1] - if line then - local start_col = line:find("%[") - local end_col = line:find("%]") - if start_col and end_col then - vim.api.nvim_buf_add_highlight(menu_bufnr, -1, "CoderConflictMenuKey", i, start_col - 1, end_col) - end - end - end + -- Add syntax highlighting to menu buffer + vim.api.nvim_buf_add_highlight(menu_bufnr, -1, "CoderConflictMenuTitle", 1, 0, -1) + for i = 3, 9 do + -- Highlight the key in brackets + local line = menu_lines[i + 1] + if line then + local start_col = line:find("%[") + local end_col = line:find("%]") + if start_col and end_col then + vim.api.nvim_buf_add_highlight(menu_bufnr, -1, "CoderConflictMenuKey", i, start_col - 1, end_col) + end + end + end - -- Setup keymaps for the menu - local close_menu = function() - if vim.api.nvim_win_is_valid(win) then - vim.api.nvim_win_close(win, true) - end - end + -- Setup keymaps for the menu + local close_menu = function() + if vim.api.nvim_win_is_valid(win) then + vim.api.nvim_win_close(win, true) + end + end - -- Use nowait to prevent delay from built-in 'c' command - local menu_opts = { buffer = menu_bufnr, silent = true, noremap = true, nowait = true } + -- Use nowait to prevent delay from built-in 'c' command + local menu_opts = { buffer = menu_bufnr, silent = true, noremap = true, nowait = true } - vim.keymap.set("n", "q", close_menu, menu_opts) - vim.keymap.set("n", "", close_menu, menu_opts) + vim.keymap.set("n", "q", close_menu, menu_opts) + vim.keymap.set("n", "", close_menu, menu_opts) - vim.keymap.set("n", "co", function() - close_menu() - M.accept_ours(bufnr) - end, menu_opts) + vim.keymap.set("n", "co", function() + close_menu() + M.accept_ours(bufnr) + end, menu_opts) - vim.keymap.set("n", "ct", function() - close_menu() - M.accept_theirs(bufnr) - end, menu_opts) + vim.keymap.set("n", "ct", function() + close_menu() + M.accept_theirs(bufnr) + end, menu_opts) - vim.keymap.set("n", "cb", function() - close_menu() - M.accept_both(bufnr) - end, menu_opts) + vim.keymap.set("n", "cb", function() + close_menu() + M.accept_both(bufnr) + end, menu_opts) - vim.keymap.set("n", "cn", function() - close_menu() - M.accept_none(bufnr) - end, menu_opts) + vim.keymap.set("n", "cn", function() + close_menu() + M.accept_none(bufnr) + end, menu_opts) - vim.keymap.set("n", "]x", function() - close_menu() - M.goto_next(bufnr) - end, menu_opts) + vim.keymap.set("n", "]x", function() + close_menu() + M.goto_next(bufnr) + end, menu_opts) - vim.keymap.set("n", "[x", function() - close_menu() - M.goto_prev(bufnr) - end, menu_opts) + vim.keymap.set("n", "[x", function() + close_menu() + M.goto_prev(bufnr) + end, menu_opts) - -- Also support number keys for quick selection - vim.keymap.set("n", "1", function() - close_menu() - M.accept_ours(bufnr) - end, menu_opts) + -- Also support number keys for quick selection + vim.keymap.set("n", "1", function() + close_menu() + M.accept_ours(bufnr) + end, menu_opts) - vim.keymap.set("n", "2", function() - close_menu() - M.accept_theirs(bufnr) - end, menu_opts) + vim.keymap.set("n", "2", function() + close_menu() + M.accept_theirs(bufnr) + end, menu_opts) - vim.keymap.set("n", "3", function() - close_menu() - M.accept_both(bufnr) - end, menu_opts) + vim.keymap.set("n", "3", function() + close_menu() + M.accept_both(bufnr) + end, menu_opts) - vim.keymap.set("n", "4", function() - close_menu() - M.accept_none(bufnr) - end, menu_opts) + vim.keymap.set("n", "4", function() + close_menu() + M.accept_none(bufnr) + end, menu_opts) - -- Close on focus lost - vim.api.nvim_create_autocmd("WinLeave", { - buffer = menu_bufnr, - once = true, - callback = close_menu, - }) + -- Close on focus lost + vim.api.nvim_create_autocmd("WinLeave", { + buffer = menu_bufnr, + once = true, + callback = close_menu, + }) end --- Setup keybindings for conflict resolution in a buffer ---@param bufnr number Buffer number function M.setup_keymaps(bufnr) - -- Use nowait to prevent delay from built-in 'c' command - local opts = { buffer = bufnr, silent = true, noremap = true, nowait = true } + -- Use nowait to prevent delay from built-in 'c' command + local opts = { buffer = bufnr, silent = true, noremap = true, nowait = true } - -- Accept ours (original) - vim.keymap.set("n", "co", function() - M.accept_ours(bufnr) - end, vim.tbl_extend("force", opts, { desc = "Accept CURRENT (original) code" })) + -- Accept ours (original) + vim.keymap.set("n", "co", function() + M.accept_ours(bufnr) + end, vim.tbl_extend("force", opts, { desc = "Accept CURRENT (original) code" })) - -- Accept theirs (AI suggestion) - vim.keymap.set("n", "ct", function() - M.accept_theirs(bufnr) - end, vim.tbl_extend("force", opts, { desc = "Accept INCOMING (AI suggestion) code" })) + -- Accept theirs (AI suggestion) + vim.keymap.set("n", "ct", function() + M.accept_theirs(bufnr) + end, vim.tbl_extend("force", opts, { desc = "Accept INCOMING (AI suggestion) code" })) - -- Accept both - vim.keymap.set("n", "cb", function() - M.accept_both(bufnr) - end, vim.tbl_extend("force", opts, { desc = "Accept BOTH versions" })) + -- Accept both + vim.keymap.set("n", "cb", function() + M.accept_both(bufnr) + end, vim.tbl_extend("force", opts, { desc = "Accept BOTH versions" })) - -- Accept none - vim.keymap.set("n", "cn", function() - M.accept_none(bufnr) - end, vim.tbl_extend("force", opts, { desc = "Delete conflict (accept NONE)" })) + -- Accept none + vim.keymap.set("n", "cn", function() + M.accept_none(bufnr) + end, vim.tbl_extend("force", opts, { desc = "Delete conflict (accept NONE)" })) - -- Navigate to next conflict - vim.keymap.set("n", "]x", function() - M.goto_next(bufnr) - end, vim.tbl_extend("force", opts, { desc = "Go to next conflict" })) + -- Navigate to next conflict + vim.keymap.set("n", "]x", function() + M.goto_next(bufnr) + end, vim.tbl_extend("force", opts, { desc = "Go to next conflict" })) - -- Navigate to previous conflict - vim.keymap.set("n", "[x", function() - M.goto_prev(bufnr) - end, vim.tbl_extend("force", opts, { desc = "Go to previous conflict" })) + -- Navigate to previous conflict + vim.keymap.set("n", "[x", function() + M.goto_prev(bufnr) + end, vim.tbl_extend("force", opts, { desc = "Go to previous conflict" })) - -- Show menu modal - vim.keymap.set("n", "cm", function() - M.show_floating_menu(bufnr) - end, vim.tbl_extend("force", opts, { desc = "Show conflict resolution menu" })) + -- Show menu modal + vim.keymap.set("n", "cm", function() + M.show_floating_menu(bufnr) + end, vim.tbl_extend("force", opts, { desc = "Show conflict resolution menu" })) - -- Also map to show menu when on conflict - vim.keymap.set("n", "", function() - local cursor = vim.api.nvim_win_get_cursor(0) - if M.get_conflict_at_cursor(bufnr, cursor[1]) then - M.show_floating_menu(bufnr) - else - -- Default behavior - vim.api.nvim_feedkeys(vim.api.nvim_replace_termcodes("", true, false, true), "n", false) - end - end, vim.tbl_extend("force", opts, { desc = "Show conflict menu or default action" })) + -- Also map to show menu when on conflict + vim.keymap.set("n", "", function() + local cursor = vim.api.nvim_win_get_cursor(0) + if M.get_conflict_at_cursor(bufnr, cursor[1]) then + M.show_floating_menu(bufnr) + else + -- Default behavior + vim.api.nvim_feedkeys(vim.api.nvim_replace_termcodes("", true, false, true), "n", false) + end + end, vim.tbl_extend("force", opts, { desc = "Show conflict menu or default action" })) - -- Mark buffer as having conflict keymaps - conflict_buffers[bufnr] = { - keymaps_set = true, - } + -- Mark buffer as having conflict keymaps + conflict_buffers[bufnr] = { + keymaps_set = true, + } end --- Remove keybindings from a buffer ---@param bufnr number Buffer number function M.remove_keymaps(bufnr) - if not conflict_buffers[bufnr] then - return - end + if not conflict_buffers[bufnr] then + return + end - pcall(vim.keymap.del, "n", "co", { buffer = bufnr }) - pcall(vim.keymap.del, "n", "ct", { buffer = bufnr }) - pcall(vim.keymap.del, "n", "cb", { buffer = bufnr }) - pcall(vim.keymap.del, "n", "cn", { buffer = bufnr }) - pcall(vim.keymap.del, "n", "cm", { buffer = bufnr }) - pcall(vim.keymap.del, "n", "]x", { buffer = bufnr }) - pcall(vim.keymap.del, "n", "[x", { buffer = bufnr }) - pcall(vim.keymap.del, "n", "", { buffer = bufnr }) + pcall(vim.keymap.del, "n", "co", { buffer = bufnr }) + pcall(vim.keymap.del, "n", "ct", { buffer = bufnr }) + pcall(vim.keymap.del, "n", "cb", { buffer = bufnr }) + pcall(vim.keymap.del, "n", "cn", { buffer = bufnr }) + pcall(vim.keymap.del, "n", "cm", { buffer = bufnr }) + pcall(vim.keymap.del, "n", "]x", { buffer = bufnr }) + pcall(vim.keymap.del, "n", "[x", { buffer = bufnr }) + pcall(vim.keymap.del, "n", "", { buffer = bufnr }) - conflict_buffers[bufnr] = nil + conflict_buffers[bufnr] = nil end --- Insert conflict markers for a code change @@ -852,201 +872,201 @@ end ---@param new_lines string[] New lines to insert as "incoming" ---@param label? string Optional label for the incoming section function M.insert_conflict(bufnr, start_line, end_line, new_lines, label) - if not vim.api.nvim_buf_is_valid(bufnr) then - return - end + if not vim.api.nvim_buf_is_valid(bufnr) then + return + end - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - -- Clamp to valid range - local line_count = #lines - start_line = math.max(1, math.min(start_line, line_count + 1)) - end_line = math.max(start_line, math.min(end_line, line_count)) + -- Clamp to valid range + local line_count = #lines + start_line = math.max(1, math.min(start_line, line_count + 1)) + end_line = math.max(start_line, math.min(end_line, line_count)) - -- Extract current lines - local current_lines = {} - for i = start_line, end_line do - if lines[i] then - table.insert(current_lines, lines[i]) - end - end + -- Extract current lines + local current_lines = {} + for i = start_line, end_line do + if lines[i] then + table.insert(current_lines, lines[i]) + end + end - -- Build conflict block - local conflict_block = {} - table.insert(conflict_block, MARKERS.current_start) - for _, line in ipairs(current_lines) do - table.insert(conflict_block, line) - end - table.insert(conflict_block, MARKERS.separator) - for _, line in ipairs(new_lines) do - table.insert(conflict_block, line) - end - table.insert(conflict_block, label and (">>>>>>> " .. label) or MARKERS.incoming_end) + -- Build conflict block + local conflict_block = {} + table.insert(conflict_block, MARKERS.current_start) + for _, line in ipairs(current_lines) do + table.insert(conflict_block, line) + end + table.insert(conflict_block, MARKERS.separator) + for _, line in ipairs(new_lines) do + table.insert(conflict_block, line) + end + table.insert(conflict_block, label and (">>>>>>> " .. label) or MARKERS.incoming_end) - -- Replace the range with conflict block - vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, conflict_block) + -- Replace the range with conflict block + vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, conflict_block) end --- Process buffer and auto-show menu for first conflict --- Call this after inserting conflict(s) to set up highlights and show menu ---@param bufnr number Buffer number function M.process_and_show_menu(bufnr) - if not vim.api.nvim_buf_is_valid(bufnr) then - return - end + if not vim.api.nvim_buf_is_valid(bufnr) then + return + end - -- Process to set up highlights and keymaps - local conflict_count = M.process(bufnr) + -- Process to set up highlights and keymaps + local conflict_count = M.process(bufnr) - -- Auto-show menu if enabled and conflicts exist - if config.auto_show_menu and conflict_count > 0 then - vim.schedule(function() - if not vim.api.nvim_buf_is_valid(bufnr) then - return - end + -- Auto-show menu if enabled and conflicts exist + if config.auto_show_menu and conflict_count > 0 then + vim.schedule(function() + if not vim.api.nvim_buf_is_valid(bufnr) then + return + end - -- Find window showing this buffer and focus it - local win = nil - for _, w in ipairs(vim.api.nvim_list_wins()) do - if vim.api.nvim_win_get_buf(w) == bufnr then - win = w - break - end - end + -- Find window showing this buffer and focus it + local win = nil + for _, w in ipairs(vim.api.nvim_list_wins()) do + if vim.api.nvim_win_get_buf(w) == bufnr then + win = w + break + end + end - if win then - vim.api.nvim_set_current_win(win) - -- Jump to first conflict - local conflicts = M.detect_conflicts(bufnr) - if #conflicts > 0 then - vim.api.nvim_win_set_cursor(win, { conflicts[1].start_line, 0 }) - vim.cmd("normal! zz") - -- Show the menu - M.show_floating_menu(bufnr) - end - end - end) - end + if win then + vim.api.nvim_set_current_win(win) + -- Jump to first conflict + local conflicts = M.detect_conflicts(bufnr) + if #conflicts > 0 then + vim.api.nvim_win_set_cursor(win, { conflicts[1].start_line, 0 }) + vim.cmd("normal! zz") + -- Show the menu + M.show_floating_menu(bufnr) + end + end + end) + end end --- Process a buffer for conflicts - detect, highlight, and setup keymaps ---@param bufnr number Buffer number ---@return number conflict_count Number of conflicts found function M.process(bufnr) - if not vim.api.nvim_buf_is_valid(bufnr) then - return 0 - end + if not vim.api.nvim_buf_is_valid(bufnr) then + return 0 + end - -- Setup highlights if not done - setup_highlights() + -- Setup highlights if not done + setup_highlights() - -- Detect conflicts - local conflicts = M.detect_conflicts(bufnr) + -- Detect conflicts + local conflicts = M.detect_conflicts(bufnr) - if #conflicts > 0 then - -- Highlight conflicts - M.highlight_conflicts(bufnr, conflicts) + if #conflicts > 0 then + -- Highlight conflicts + M.highlight_conflicts(bufnr, conflicts) - -- Setup keymaps if not already done - if not conflict_buffers[bufnr] then - M.setup_keymaps(bufnr) - end + -- Setup keymaps if not already done + if not conflict_buffers[bufnr] then + M.setup_keymaps(bufnr) + end - -- Log - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.info(string.format("Found %d conflict(s) - use co/ct/cb/cn to resolve, [x/]x to navigate", #conflicts)) - end) - else - -- No conflicts - clean up - vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1) - vim.api.nvim_buf_clear_namespace(bufnr, HINT_NAMESPACE, 0, -1) - M.remove_keymaps(bufnr) - end + -- Log + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.info(string.format("Found %d conflict(s) - use co/ct/cb/cn to resolve, [x/]x to navigate", #conflicts)) + end) + else + -- No conflicts - clean up + vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1) + vim.api.nvim_buf_clear_namespace(bufnr, HINT_NAMESPACE, 0, -1) + M.remove_keymaps(bufnr) + end - return #conflicts + return #conflicts end --- Check if a buffer has conflicts ---@param bufnr number Buffer number ---@return boolean function M.has_conflicts(bufnr) - local conflicts = M.detect_conflicts(bufnr) - return #conflicts > 0 + local conflicts = M.detect_conflicts(bufnr) + return #conflicts > 0 end --- Get conflict count for a buffer ---@param bufnr number Buffer number ---@return number function M.count_conflicts(bufnr) - local conflicts = M.detect_conflicts(bufnr) - return #conflicts + local conflicts = M.detect_conflicts(bufnr) + return #conflicts end --- Clear all conflicts from a buffer (remove markers but keep current code) ---@param bufnr number Buffer number ---@param keep "ours"|"theirs"|"both"|"none" Which version to keep function M.resolve_all(bufnr, keep) - local conflicts = M.detect_conflicts(bufnr) + local conflicts = M.detect_conflicts(bufnr) - -- Process in reverse order to maintain line numbers - for i = #conflicts, 1, -1 do - -- Move cursor to conflict - vim.api.nvim_win_set_cursor(0, { conflicts[i].start_line, 0 }) + -- Process in reverse order to maintain line numbers + for i = #conflicts, 1, -1 do + -- Move cursor to conflict + vim.api.nvim_win_set_cursor(0, { conflicts[i].start_line, 0 }) - -- Accept based on preference - if keep == "ours" then - M.accept_ours(bufnr) - elseif keep == "theirs" then - M.accept_theirs(bufnr) - elseif keep == "both" then - M.accept_both(bufnr) - else - M.accept_none(bufnr) - end - end + -- Accept based on preference + if keep == "ours" then + M.accept_ours(bufnr) + elseif keep == "theirs" then + M.accept_theirs(bufnr) + elseif keep == "both" then + M.accept_both(bufnr) + else + M.accept_none(bufnr) + end + end end --- Add a buffer to conflict tracking (for auto-follow) ---@param bufnr number Buffer number function M.add_tracked_buffer(bufnr) - if not conflict_buffers[bufnr] then - conflict_buffers[bufnr] = {} - end + if not conflict_buffers[bufnr] then + conflict_buffers[bufnr] = {} + end end --- Get all tracked buffers with conflicts ---@return number[] buffers List of buffer numbers function M.get_tracked_buffers() - local buffers = {} - for bufnr, _ in pairs(conflict_buffers) do - if vim.api.nvim_buf_is_valid(bufnr) and M.has_conflicts(bufnr) then - table.insert(buffers, bufnr) - end - end - return buffers + local buffers = {} + for bufnr, _ in pairs(conflict_buffers) do + if vim.api.nvim_buf_is_valid(bufnr) and M.has_conflicts(bufnr) then + table.insert(buffers, bufnr) + end + end + return buffers end --- Clear tracking for a buffer ---@param bufnr number Buffer number function M.clear_buffer(bufnr) - vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1) - vim.api.nvim_buf_clear_namespace(bufnr, HINT_NAMESPACE, 0, -1) - M.remove_keymaps(bufnr) - conflict_buffers[bufnr] = nil + vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1) + vim.api.nvim_buf_clear_namespace(bufnr, HINT_NAMESPACE, 0, -1) + M.remove_keymaps(bufnr) + conflict_buffers[bufnr] = nil end --- Initialize the conflict module function M.setup() - setup_highlights() + setup_highlights() - -- Auto-clean up when buffers are deleted - vim.api.nvim_create_autocmd("BufDelete", { - group = vim.api.nvim_create_augroup("CoderConflict", { clear = true }), - callback = function(ev) - conflict_buffers[ev.buf] = nil - end, - }) + -- Auto-clean up when buffers are deleted + vim.api.nvim_create_autocmd("BufDelete", { + group = vim.api.nvim_create_augroup("CoderConflict", { clear = true }), + callback = function(ev) + conflict_buffers[ev.buf] = nil + end, + }) end return M diff --git a/lua/codetyper/core/diff/diff.lua b/lua/codetyper/core/diff/diff.lua index 53c0ea1..da10e91 100644 --- a/lua/codetyper/core/diff/diff.lua +++ b/lua/codetyper/core/diff/diff.lua @@ -163,9 +163,9 @@ function M.show_diff(diff_data, callback) local final_help = {} for _, item in ipairs(help_msg) do if item[1] == "{path}" then - table.insert(final_help, { diff_data.path, item[2] }) + table.insert(final_help, { diff_data.path, item[2] }) else - table.insert(final_help, item) + table.insert(final_help, item) end end @@ -210,7 +210,7 @@ function M.show_bash_approval(command, callback) table.insert(lines, approval_prompts.divider) table.insert(lines, "") for _, opt in ipairs(approval_prompts.options) do - table.insert(lines, opt) + table.insert(lines, opt) end table.insert(lines, "") table.insert(lines, approval_prompts.divider) diff --git a/lua/codetyper/core/diff/patch.lua b/lua/codetyper/core/diff/patch.lua index f9e7add..836a107 100644 --- a/lua/codetyper/core/diff/patch.lua +++ b/lua/codetyper/core/diff/patch.lua @@ -10,20 +10,19 @@ local M = {} local params = require("codetyper.params.agents.patch") local logger = require("codetyper.support.logger") - --- Lazy load inject module to avoid circular requires local function get_inject_module() - return require("codetyper.inject") + return require("codetyper.inject") end --- Lazy load search_replace module local function get_search_replace_module() - return require("codetyper.core.diff.search_replace") + return require("codetyper.core.diff.search_replace") end --- Lazy load conflict module local function get_conflict_module() - return require("codetyper.core.diff.conflict") + return require("codetyper.core.diff.conflict") end --- Configuration for patch behavior @@ -62,8 +61,8 @@ local patch_counter = 0 --- Generate unique patch ID ---@return string function M.generate_id() - patch_counter = patch_counter + 1 - return string.format("patch_%d_%d", os.time(), patch_counter) + patch_counter = patch_counter + 1 + return string.format("patch_%d_%d", os.time(), patch_counter) end --- Hash buffer content in range @@ -72,23 +71,23 @@ end ---@param end_line number|nil 1-indexed, nil for whole buffer ---@return string local function hash_buffer_range(bufnr, start_line, end_line) - if not vim.api.nvim_buf_is_valid(bufnr) then - return "" - end + if not vim.api.nvim_buf_is_valid(bufnr) then + return "" + end - local lines - if start_line and end_line then - lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false) - else - lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - end + local lines + if start_line and end_line then + lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false) + else + lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + end - local content = table.concat(lines, "\n") - local hash = 0 - for i = 1, #content do - hash = (hash * 31 + string.byte(content, i)) % 2147483647 - end - return string.format("%x", hash) + local content = table.concat(lines, "\n") + local hash = 0 + for i = 1, #content do + hash = (hash * 31 + string.byte(content, i)) % 2147483647 + end + return string.format("%x", hash) end --- Take a snapshot of buffer state @@ -96,24 +95,24 @@ end ---@param range {start_line: number, end_line: number}|nil Optional range ---@return BufferSnapshot function M.snapshot_buffer(bufnr, range) - local changedtick = 0 - if vim.api.nvim_buf_is_valid(bufnr) then - changedtick = vim.api.nvim_buf_get_var(bufnr, "changedtick") or vim.b[bufnr].changedtick or 0 - end + local changedtick = 0 + if vim.api.nvim_buf_is_valid(bufnr) then + changedtick = vim.api.nvim_buf_get_var(bufnr, "changedtick") or vim.b[bufnr].changedtick or 0 + end - local content_hash - if range then - content_hash = hash_buffer_range(bufnr, range.start_line, range.end_line) - else - content_hash = hash_buffer_range(bufnr, nil, nil) - end + local content_hash + if range then + content_hash = hash_buffer_range(bufnr, range.start_line, range.end_line) + else + content_hash = hash_buffer_range(bufnr, nil, nil) + end - return { - bufnr = bufnr, - changedtick = changedtick, - content_hash = content_hash, - range = range, - } + return { + bufnr = bufnr, + changedtick = changedtick, + content_hash = content_hash, + range = range, + } end --- Check if buffer changed since snapshot @@ -121,34 +120,29 @@ end ---@return boolean is_stale ---@return string|nil reason function M.is_snapshot_stale(snapshot) - if not vim.api.nvim_buf_is_valid(snapshot.bufnr) then - return true, "buffer_invalid" - end + if not vim.api.nvim_buf_is_valid(snapshot.bufnr) then + return true, "buffer_invalid" + end - -- Check changedtick first (fast path) - local current_tick = vim.api.nvim_buf_get_var(snapshot.bufnr, "changedtick") - or vim.b[snapshot.bufnr].changedtick or 0 + -- Check changedtick first (fast path) + local current_tick = vim.api.nvim_buf_get_var(snapshot.bufnr, "changedtick") or vim.b[snapshot.bufnr].changedtick or 0 - if current_tick ~= snapshot.changedtick then - -- Changedtick differs, but might be just cursor movement - -- Verify with content hash - local current_hash - if snapshot.range then - current_hash = hash_buffer_range( - snapshot.bufnr, - snapshot.range.start_line, - snapshot.range.end_line - ) - else - current_hash = hash_buffer_range(snapshot.bufnr, nil, nil) - end + if current_tick ~= snapshot.changedtick then + -- Changedtick differs, but might be just cursor movement + -- Verify with content hash + local current_hash + if snapshot.range then + current_hash = hash_buffer_range(snapshot.bufnr, snapshot.range.start_line, snapshot.range.end_line) + else + current_hash = hash_buffer_range(snapshot.bufnr, nil, nil) + end - if current_hash ~= snapshot.content_hash then - return true, "content_changed" - end - end + if current_hash ~= snapshot.content_hash then + return true, "content_changed" + end + end - return false, nil + return false, nil end --- Check if a patch is stale @@ -156,38 +150,35 @@ end ---@return boolean ---@return string|nil reason function M.is_stale(patch) - return M.is_snapshot_stale(patch.original_snapshot) + return M.is_snapshot_stale(patch.original_snapshot) end --- Queue a patch for deferred application ---@param patch PatchCandidate ---@return PatchCandidate function M.queue_patch(patch) - patch.id = patch.id or M.generate_id() - patch.status = patch.status or "pending" - patch.created_at = patch.created_at or os.time() + patch.id = patch.id or M.generate_id() + patch.status = patch.status or "pending" + patch.created_at = patch.created_at or os.time() - table.insert(patches, patch) + table.insert(patches, patch) - -- Log patch creation - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "patch", - message = string.format( - "Patch queued: %s (confidence: %.2f)", - patch.id, patch.confidence or 0 - ), - data = { - patch_id = patch.id, - event_id = patch.event_id, - target_path = patch.target_path, - code_preview = patch.generated_code:sub(1, 50), - }, - }) - end) + -- Log patch creation + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "patch", + message = string.format("Patch queued: %s (confidence: %.2f)", patch.id, patch.confidence or 0), + data = { + patch_id = patch.id, + event_id = patch.event_id, + target_path = patch.target_path, + code_preview = patch.generated_code:sub(1, 50), + }, + }) + end) - return patch + return patch end --- Create patch from event and response @@ -197,231 +188,240 @@ end ---@param strategy string|nil Injection strategy (overrides intent-based) ---@return PatchCandidate function M.create_from_event(event, generated_code, confidence, strategy) - -- Source buffer is where the prompt tags are (could be coder file) - local source_bufnr = event.bufnr + -- Source buffer is where the prompt tags are (could be coder file) + local source_bufnr = event.bufnr - -- Get target buffer (where code should be injected - the real file) - local target_bufnr = vim.fn.bufnr(event.target_path) - if target_bufnr == -1 then - -- Try to find by filename - for _, buf in ipairs(vim.api.nvim_list_bufs()) do - local name = vim.api.nvim_buf_get_name(buf) - if name == event.target_path then - target_bufnr = buf - break - end - end - end + -- Get target buffer (where code should be injected - the real file) + local target_bufnr = vim.fn.bufnr(event.target_path) + if target_bufnr == -1 then + -- Try to find by filename + for _, buf in ipairs(vim.api.nvim_list_bufs()) do + local name = vim.api.nvim_buf_get_name(buf) + if name == event.target_path then + target_bufnr = buf + break + end + end + end - -- Detect if this is an inline prompt (source == target, not a .codetyper/ file) - local is_inline = (source_bufnr == target_bufnr) or - (event.target_path and not event.target_path:match("%.codetyper%.")) + -- Detect if this is an inline prompt (source == target, not a .codetyper/ file) + local is_inline = (source_bufnr == target_bufnr) + or (event.target_path and not event.target_path:match("%.codetyper%.")) - -- Take snapshot of the scope range in target buffer (for staleness detection) - local snapshot_range = event.scope_range or event.range - local snapshot = M.snapshot_buffer( - target_bufnr ~= -1 and target_bufnr or event.bufnr, - snapshot_range - ) + -- Take snapshot of the scope range in target buffer (for staleness detection) + local snapshot_range = event.scope_range or event.range + local snapshot = M.snapshot_buffer(target_bufnr ~= -1 and target_bufnr or event.bufnr, snapshot_range) - -- Check if the response contains SEARCH/REPLACE blocks - local search_replace = get_search_replace_module() - local sr_blocks = search_replace.parse_blocks(generated_code) - local use_search_replace = #sr_blocks > 0 + -- Check if the response contains SEARCH/REPLACE blocks + local search_replace = get_search_replace_module() + local sr_blocks = search_replace.parse_blocks(generated_code) + local use_search_replace = #sr_blocks > 0 - -- Determine injection strategy and range based on intent - local injection_strategy = strategy - local injection_range = nil + -- Determine injection strategy and range based on intent + local injection_strategy = strategy + local injection_range = nil - -- Handle intent_override from transform-selection (e.g., cursor insert mode) - if event.intent_override and event.intent_override.action then - injection_strategy = event.intent_override.action - -- Use injection_range from transform-selection, not event.range - injection_range = event.injection_range or (event.range and { - start_line = event.range.start_line, - end_line = event.range.end_line, - }) - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = string.format("Using override strategy: %s (%s)", injection_strategy, - injection_range and (injection_range.start_line .. "-" .. injection_range.end_line) or "nil"), - }) - end) - -- If we have SEARCH/REPLACE blocks, use that strategy - elseif use_search_replace then - injection_strategy = "search_replace" - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = string.format("Using SEARCH/REPLACE mode with %d block(s)", #sr_blocks), - }) - end) - elseif is_inline and event.range then - -- Inline prompts: always replace the selection (we asked LLM for "code that replaces lines X-Y") - injection_strategy = "replace" - local start_line = math.max(1, event.range.start_line or 1) - local end_line = math.max(1, event.range.end_line or 1) - if end_line < start_line then - end_line = start_line - end - -- Prefer scope_range if event.range is invalid (0-0) and we have scope - if (event.range.start_line == 0 or event.range.end_line == 0) and event.scope_range then - start_line = math.max(1, event.scope_range.start_line or 1) - end_line = math.max(1, event.scope_range.end_line or 1) - if end_line < start_line then - end_line = start_line - end - end - injection_range = { start_line = start_line, end_line = end_line } - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = string.format("Inline: replace lines %d-%d", start_line, end_line), - }) - end) - elseif not injection_strategy and event.intent then - local intent_mod = require("codetyper.core.intent") - if intent_mod.is_replacement(event.intent) then - injection_strategy = "replace" + -- Handle intent_override from transform-selection (e.g., cursor insert mode) + if event.intent_override and event.intent_override.action then + injection_strategy = event.intent_override.action + -- Use injection_range from transform-selection, not event.range + injection_range = event.injection_range + or ( + event.range + and { + start_line = event.range.start_line, + end_line = event.range.end_line, + } + ) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = string.format( + "Using override strategy: %s (%s)", + injection_strategy, + injection_range and (injection_range.start_line .. "-" .. injection_range.end_line) or "nil" + ), + }) + end) + -- If we have SEARCH/REPLACE blocks, use that strategy + elseif use_search_replace then + injection_strategy = "search_replace" + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = string.format("Using SEARCH/REPLACE mode with %d block(s)", #sr_blocks), + }) + end) + elseif is_inline and event.range then + -- Inline prompts: always replace the selection (we asked LLM for "code that replaces lines X-Y") + injection_strategy = "replace" + local start_line = math.max(1, event.range.start_line or 1) + local end_line = math.max(1, event.range.end_line or 1) + if end_line < start_line then + end_line = start_line + end + -- Prefer scope_range if event.range is invalid (0-0) and we have scope + if (event.range.start_line == 0 or event.range.end_line == 0) and event.scope_range then + start_line = math.max(1, event.scope_range.start_line or 1) + end_line = math.max(1, event.scope_range.end_line or 1) + if end_line < start_line then + end_line = start_line + end + end + injection_range = { start_line = start_line, end_line = end_line } + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = string.format("Inline: replace lines %d-%d", start_line, end_line), + }) + end) + elseif not injection_strategy and event.intent then + local intent_mod = require("codetyper.core.intent") + if intent_mod.is_replacement(event.intent) then + injection_strategy = "replace" - -- INLINE PROMPTS: Always use tag range - -- The LLM is told specifically to replace the tagged region - if is_inline and event.range then - injection_range = { - start_line = event.range.start_line, - end_line = event.range.end_line, - } - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = string.format("Inline prompt: will replace tag region (lines %d-%d)", - event.range.start_line, event.range.end_line), - }) - end) - -- CODER FILES: Use scope range for replacement - elseif event.scope_range then - injection_range = event.scope_range - else - -- Fallback: no scope found (treesitter didn't find function) - -- Use tag range - the generated code will replace the tag region - injection_range = { - start_line = event.range.start_line, - end_line = event.range.end_line, - } - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "warning", - message = "No scope found, using tag range as fallback", - }) - end) - end - elseif event.intent.action == "insert" then - injection_strategy = "insert" - -- Insert at prompt location (use full tag range) - injection_range = { start_line = event.range.start_line, end_line = event.range.end_line } - elseif event.intent.action == "append" then - injection_strategy = "append" - -- Will append to end of file - else - injection_strategy = "append" - end - end + -- INLINE PROMPTS: Always use tag range + -- The LLM is told specifically to replace the tagged region + if is_inline and event.range then + injection_range = { + start_line = event.range.start_line, + end_line = event.range.end_line, + } + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = string.format( + "Inline prompt: will replace tag region (lines %d-%d)", + event.range.start_line, + event.range.end_line + ), + }) + end) + -- CODER FILES: Use scope range for replacement + elseif event.scope_range then + injection_range = event.scope_range + else + -- Fallback: no scope found (treesitter didn't find function) + -- Use tag range - the generated code will replace the tag region + injection_range = { + start_line = event.range.start_line, + end_line = event.range.end_line, + } + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "warning", + message = "No scope found, using tag range as fallback", + }) + end) + end + elseif event.intent.action == "insert" then + injection_strategy = "insert" + -- Insert at prompt location (use full tag range) + injection_range = { start_line = event.range.start_line, end_line = event.range.end_line } + elseif event.intent.action == "append" then + injection_strategy = "append" + -- Will append to end of file + else + injection_strategy = "append" + end + end - injection_strategy = injection_strategy or "append" + injection_strategy = injection_strategy or "append" - local range_str = injection_range - and string.format("%d-%d", injection_range.start_line, injection_range.end_line) - or "nil" - logger.info("patch", string.format( - "create: is_inline=%s strategy=%s range=%s use_sr=%s intent_action=%s", - tostring(is_inline), - injection_strategy, - range_str, - tostring(use_search_replace), - event.intent and event.intent.action or "nil" - )) + local range_str = injection_range and string.format("%d-%d", injection_range.start_line, injection_range.end_line) + or "nil" + logger.info( + "patch", + string.format( + "create: is_inline=%s strategy=%s range=%s use_sr=%s intent_action=%s", + tostring(is_inline), + injection_strategy, + range_str, + tostring(use_search_replace), + event.intent and event.intent.action or "nil" + ) + ) - return { - id = M.generate_id(), - event_id = event.id, - source_bufnr = source_bufnr, -- Where prompt tags are (coder file) - target_bufnr = target_bufnr, -- Where code goes (real file) - target_path = event.target_path, - original_snapshot = snapshot, - generated_code = generated_code, - injection_range = injection_range, - injection_strategy = injection_strategy, - confidence = confidence, - status = "pending", - created_at = os.time(), - intent = event.intent, - scope = event.scope, - -- Store the prompt tag range so we can delete it after applying - prompt_tag_range = event.range, - -- Mark if this is an inline prompt (tags in source file, not coder file) - is_inline_prompt = is_inline, - -- SEARCH/REPLACE support - use_search_replace = use_search_replace, - search_replace_blocks = use_search_replace and sr_blocks or nil, - -- Extmarks for injection range (99-style: apply at current position after user edits) - injection_marks = event.injection_marks, - } + return { + id = M.generate_id(), + event_id = event.id, + source_bufnr = source_bufnr, -- Where prompt tags are (coder file) + target_bufnr = target_bufnr, -- Where code goes (real file) + target_path = event.target_path, + original_snapshot = snapshot, + generated_code = generated_code, + injection_range = injection_range, + injection_strategy = injection_strategy, + confidence = confidence, + status = "pending", + created_at = os.time(), + intent = event.intent, + scope = event.scope, + -- Store the prompt tag range so we can delete it after applying + prompt_tag_range = event.range, + -- Mark if this is an inline prompt (tags in source file, not coder file) + is_inline_prompt = is_inline, + -- SEARCH/REPLACE support + use_search_replace = use_search_replace, + search_replace_blocks = use_search_replace and sr_blocks or nil, + -- Extmarks for injection range (99-style: apply at current position after user edits) + injection_marks = event.injection_marks, + } end --- Get all pending patches ---@return PatchCandidate[] function M.get_pending() - local pending = {} - for _, patch in ipairs(patches) do - if patch.status == "pending" then - table.insert(pending, patch) - end - end - return pending + local pending = {} + for _, patch in ipairs(patches) do + if patch.status == "pending" then + table.insert(pending, patch) + end + end + return pending end --- Get patch by ID ---@param id string ---@return PatchCandidate|nil function M.get(id) - for _, patch in ipairs(patches) do - if patch.id == id then - return patch - end - end - return nil + for _, patch in ipairs(patches) do + if patch.id == id then + return patch + end + end + return nil end --- Get patches for event ---@param event_id string ---@return PatchCandidate[] function M.get_for_event(event_id) - local result = {} - for _, patch in ipairs(patches) do - if patch.event_id == event_id then - table.insert(result, patch) - end - end - return result + local result = {} + for _, patch in ipairs(patches) do + if patch.event_id == event_id then + table.insert(result, patch) + end + end + return result end --- Mark patch as applied ---@param id string ---@return boolean function M.mark_applied(id) - local patch = M.get(id) - if patch then - patch.status = "applied" - patch.applied_at = os.time() - return true - end - return false + local patch = M.get(id) + if patch then + patch.status = "applied" + patch.applied_at = os.time() + return true + end + return false end --- Mark patch as stale @@ -429,13 +429,13 @@ end ---@param reason string|nil ---@return boolean function M.mark_stale(id, reason) - local patch = M.get(id) - if patch then - patch.status = "stale" - patch.stale_reason = reason - return true - end - return false + local patch = M.get(id) + if patch then + patch.status = "stale" + patch.stale_reason = reason + return true + end + return false end --- Mark patch as rejected @@ -443,31 +443,31 @@ end ---@param reason string|nil ---@return boolean function M.mark_rejected(id, reason) - local patch = M.get(id) - if patch then - patch.status = "rejected" - patch.reject_reason = reason - return true - end - return false + local patch = M.get(id) + if patch then + patch.status = "rejected" + patch.reject_reason = reason + return true + end + return false end --- Check if it's safe to modify the buffer (not in insert or visual mode) ---@return boolean local function is_safe_to_modify() - local mode = vim.fn.mode() - -- Don't modify if in insert mode, visual mode, or completion is visible - if mode == "i" or mode == "ic" or mode == "ix" then - return false - end - -- Visual modes: v (char), V (line), \22 (block) - if mode == "v" or mode == "V" or mode == "\22" then - return false - end - if vim.fn.pumvisible() == 1 then - return false - end - return true + local mode = vim.fn.mode() + -- Don't modify if in insert mode, visual mode, or completion is visible + if mode == "i" or mode == "ic" or mode == "ix" then + return false + end + -- Visual modes: v (char), V (line), \22 (block) + if mode == "v" or mode == "V" or mode == "\22" then + return false + end + if vim.fn.pumvisible() == 1 then + return false + end + return true end --- Apply a patch to the target buffer @@ -475,374 +475,402 @@ end ---@return boolean success ---@return string|nil error function M.apply(patch) - logger.info("patch", string.format("apply() entered: id=%s strategy=%s has_range=%s", patch.id, tostring(patch.injection_strategy), patch.injection_range and "yes" or "no")) + logger.info( + "patch", + string.format( + "apply() entered: id=%s strategy=%s has_range=%s", + patch.id, + tostring(patch.injection_strategy), + patch.injection_range and "yes" or "no" + ) + ) - -- Check if safe to modify (not in insert or visual mode) - if not is_safe_to_modify() then - logger.info("patch", "apply aborted: not safe (insert/visual mode or pum visible)") - return false, "user_typing" - end + -- Check if safe to modify (not in insert or visual mode) + if not is_safe_to_modify() then + logger.info("patch", "apply aborted: not safe (insert/visual mode or pum visible)") + return false, "user_typing" + end - -- Check staleness (skip when we have valid extmarks - 99-style: position tracked across edits) - local is_stale, stale_reason = true, nil - if patch.injection_marks and patch.injection_marks.start_mark and patch.injection_marks.end_mark then - local marks_mod = require("codetyper.core.marks") - if marks_mod.is_valid(patch.injection_marks.start_mark) and marks_mod.is_valid(patch.injection_marks.end_mark) then - is_stale = false - end - end - if is_stale then - is_stale, stale_reason = M.is_stale(patch) - end - if is_stale then - M.mark_stale(patch.id, stale_reason) - logger.warn("patch", string.format("Patch %s is stale: %s", patch.id, stale_reason or "unknown")) - return false, "patch_stale: " .. (stale_reason or "unknown") - end + -- Check staleness (skip when we have valid extmarks - 99-style: position tracked across edits) + local is_stale, stale_reason = true, nil + if patch.injection_marks and patch.injection_marks.start_mark and patch.injection_marks.end_mark then + local marks_mod = require("codetyper.core.marks") + if marks_mod.is_valid(patch.injection_marks.start_mark) and marks_mod.is_valid(patch.injection_marks.end_mark) then + is_stale = false + end + end + if is_stale then + is_stale, stale_reason = M.is_stale(patch) + end + if is_stale then + M.mark_stale(patch.id, stale_reason) + logger.warn("patch", string.format("Patch %s is stale: %s", patch.id, stale_reason or "unknown")) + return false, "patch_stale: " .. (stale_reason or "unknown") + end - -- Ensure target buffer is valid - local target_bufnr = patch.target_bufnr - if target_bufnr == -1 or not vim.api.nvim_buf_is_valid(target_bufnr) then - -- Try to load buffer from path - target_bufnr = vim.fn.bufadd(patch.target_path) - if target_bufnr == 0 then - M.mark_rejected(patch.id, "buffer_not_found") - return false, "target buffer not found" - end - vim.fn.bufload(target_bufnr) - patch.target_bufnr = target_bufnr - end + -- Ensure target buffer is valid + local target_bufnr = patch.target_bufnr + if target_bufnr == -1 or not vim.api.nvim_buf_is_valid(target_bufnr) then + -- Try to load buffer from path + target_bufnr = vim.fn.bufadd(patch.target_path) + if target_bufnr == 0 then + M.mark_rejected(patch.id, "buffer_not_found") + return false, "target buffer not found" + end + vim.fn.bufload(target_bufnr) + patch.target_bufnr = target_bufnr + end - -- Prepare code to inject (may be overwritten when SEARCH/REPLACE fails and we use REPLACE parts only) - local code_to_inject = patch.generated_code - local code_lines = vim.split(patch.generated_code, "\n", { plain = true }) + -- Prepare code to inject (may be overwritten when SEARCH/REPLACE fails and we use REPLACE parts only) + local code_to_inject = patch.generated_code + local code_lines = vim.split(patch.generated_code, "\n", { plain = true }) - -- Replace in-buffer thinking placeholder with actual code (if we inserted one when worker started). - -- Skip when patch uses SEARCH/REPLACE: that path needs the original buffer content and parses blocks itself. - local thinking_placeholder = require("codetyper.core.thinking_placeholder") - local ph = thinking_placeholder.get(patch.event_id) - if ph and ph.bufnr and vim.api.nvim_buf_is_valid(ph.bufnr) - and not (patch.use_search_replace and patch.search_replace_blocks and #patch.search_replace_blocks > 0) then - local marks_mod = require("codetyper.core.marks") - if marks_mod.is_valid(ph.start_mark) and marks_mod.is_valid(ph.end_mark) then - local sr, sc, er, ec = marks_mod.range_to_vim(ph.start_mark, ph.end_mark) - if sr ~= nil then - vim.api.nvim_buf_set_text(ph.bufnr, sr, sc, er, ec, code_lines) - thinking_placeholder.clear(patch.event_id) - M.mark_applied(patch.id) - return true - end - end - thinking_placeholder.clear(patch.event_id) - end + -- Replace in-buffer thinking placeholder with actual code (if we inserted one when worker started). + -- Skip when patch uses SEARCH/REPLACE: that path needs the original buffer content and parses blocks itself. + local thinking_placeholder = require("codetyper.core.thinking_placeholder") + local ph = thinking_placeholder.get(patch.event_id) + if + ph + and ph.bufnr + and vim.api.nvim_buf_is_valid(ph.bufnr) + and not (patch.use_search_replace and patch.search_replace_blocks and #patch.search_replace_blocks > 0) + then + local marks_mod = require("codetyper.core.marks") + if marks_mod.is_valid(ph.start_mark) and marks_mod.is_valid(ph.end_mark) then + local sr, sc, er, ec = marks_mod.range_to_vim(ph.start_mark, ph.end_mark) + if sr ~= nil then + vim.api.nvim_buf_set_text(ph.bufnr, sr, sc, er, ec, code_lines) + thinking_placeholder.clear(patch.event_id) + M.mark_applied(patch.id) + return true + end + end + thinking_placeholder.clear(patch.event_id) + end - -- Use the stored inline prompt flag (computed during patch creation) - -- For inline prompts, we replace the tag region directly instead of separate remove + inject - local source_bufnr = patch.source_bufnr - local is_inline_prompt = patch.is_inline_prompt or (source_bufnr == target_bufnr) - local tags_removed = 0 + -- Use the stored inline prompt flag (computed during patch creation) + -- For inline prompts, we replace the tag region directly instead of separate remove + inject + local source_bufnr = patch.source_bufnr + local is_inline_prompt = patch.is_inline_prompt or (source_bufnr == target_bufnr) + local tags_removed = 0 - -- Get filetype for smart injection - local filetype = vim.fn.fnamemodify(patch.target_path or "", ":e") + -- Get filetype for smart injection + local filetype = vim.fn.fnamemodify(patch.target_path or "", ":e") - if patch.use_search_replace and patch.search_replace_blocks and #patch.search_replace_blocks > 0 then - -- Apply SEARCH/REPLACE blocks - local search_replace = get_search_replace_module() - local success, err = search_replace.apply_to_buffer(target_bufnr, patch.search_replace_blocks) + if patch.use_search_replace and patch.search_replace_blocks and #patch.search_replace_blocks > 0 then + -- Apply SEARCH/REPLACE blocks + local search_replace = get_search_replace_module() + local success, err = search_replace.apply_to_buffer(target_bufnr, patch.search_replace_blocks) - if success then - M.mark_applied(patch.id) + if success then + M.mark_applied(patch.id) - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "success", - message = string.format("Patch %s applied via SEARCH/REPLACE (%d block(s))", - patch.id, #patch.search_replace_blocks), - data = { - target_path = patch.target_path, - blocks_applied = #patch.search_replace_blocks, - }, - }) - end) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "success", + message = string.format( + "Patch %s applied via SEARCH/REPLACE (%d block(s))", + patch.id, + #patch.search_replace_blocks + ), + data = { + target_path = patch.target_path, + blocks_applied = #patch.search_replace_blocks, + }, + }) + end) - -- Learn from successful code generation - pcall(function() - local brain = require("codetyper.core.memory") - if brain.is_initialized() then - local intent_type = patch.intent and patch.intent.type or "unknown" - brain.learn({ - type = "code_completion", - file = patch.target_path, - timestamp = os.time(), - data = { - intent = intent_type, - method = "search_replace", - language = filetype, - confidence = patch.confidence or 0.5, - }, - }) - end - end) + -- Learn from successful code generation + pcall(function() + local brain = require("codetyper.core.memory") + if brain.is_initialized() then + local intent_type = patch.intent and patch.intent.type or "unknown" + brain.learn({ + type = "code_completion", + file = patch.target_path, + timestamp = os.time(), + data = { + intent = intent_type, + method = "search_replace", + language = filetype, + confidence = patch.confidence or 0.5, + }, + }) + end + end) - return true, nil - else - -- SEARCH/REPLACE failed: use only REPLACE parts for fallback (never inject raw markers) - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "warning", - message = string.format("SEARCH/REPLACE failed: %s. Using REPLACE content only for injection.", err or "unknown"), - }) - end) - local replace_only = {} - for _, block in ipairs(patch.search_replace_blocks) do - if block.replace and block.replace ~= "" then - for _, line in ipairs(vim.split(block.replace, "\n", { plain = true })) do - table.insert(replace_only, line) - end - end - end - if #replace_only > 0 then - code_lines = replace_only - code_to_inject = table.concat(replace_only, "\n") - end - -- Fall through to line-based injection - end - end + return true, nil + else + -- SEARCH/REPLACE failed: use only REPLACE parts for fallback (never inject raw markers) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "warning", + message = string.format( + "SEARCH/REPLACE failed: %s. Using REPLACE content only for injection.", + err or "unknown" + ), + }) + end) + local replace_only = {} + for _, block in ipairs(patch.search_replace_blocks) do + if block.replace and block.replace ~= "" then + for _, line in ipairs(vim.split(block.replace, "\n", { plain = true })) do + table.insert(replace_only, line) + end + end + end + if #replace_only > 0 then + code_lines = replace_only + code_to_inject = table.concat(replace_only, "\n") + end + -- Fall through to line-based injection + end + end - -- Use smart injection module for intelligent import handling - local inject = get_inject_module() - local inject_result = nil + -- Use smart injection module for intelligent import handling + local inject = get_inject_module() + local inject_result = nil - local has_range = patch.injection_range ~= nil - local apply_msg = string.format("apply: id=%s strategy=%s has_range=%s is_inline=%s target_bufnr=%s", - patch.id, - patch.injection_strategy or "nil", - tostring(has_range), - tostring(is_inline_prompt), - tostring(target_bufnr)) - logger.info("patch", apply_msg) + local has_range = patch.injection_range ~= nil + local apply_msg = string.format( + "apply: id=%s strategy=%s has_range=%s is_inline=%s target_bufnr=%s", + patch.id, + patch.injection_strategy or "nil", + tostring(has_range), + tostring(is_inline_prompt), + tostring(target_bufnr) + ) + logger.info("patch", apply_msg) - -- Apply based on strategy using smart injection - local ok, err = pcall(function() - -- Prepare injection options - local inject_opts = { - strategy = patch.injection_strategy, - filetype = filetype, - sort_imports = true, - } + -- Apply based on strategy using smart injection + local ok, err = pcall(function() + -- Prepare injection options + local inject_opts = { + strategy = patch.injection_strategy, + filetype = filetype, + sort_imports = true, + } - if patch.injection_strategy == "replace" and patch.injection_range then - -- Replace the scope range with the new code - local start_line = patch.injection_range.start_line - local end_line = patch.injection_range.end_line + if patch.injection_strategy == "replace" and patch.injection_range then + -- Replace the scope range with the new code + local start_line = patch.injection_range.start_line + local end_line = patch.injection_range.end_line - -- 99-style: use extmarks so we apply at current position (survives user typing) - local marks = require("codetyper.core.marks") - if patch.injection_marks and patch.injection_marks.start_mark and patch.injection_marks.end_mark then - local sm, em = patch.injection_marks.start_mark, patch.injection_marks.end_mark - if marks.is_valid(sm) and marks.is_valid(em) then - local sr, sc, er, ec = marks.range_to_vim(sm, em) - if sr ~= nil then - start_line = sr + 1 - end_line = er + 1 - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = string.format("Applying at extmark range (lines %d-%d)", start_line, end_line), - }) - end) - marks.delete(sm) - marks.delete(em) - end - end - end + -- 99-style: use extmarks so we apply at current position (survives user typing) + local marks = require("codetyper.core.marks") + if patch.injection_marks and patch.injection_marks.start_mark and patch.injection_marks.end_mark then + local sm, em = patch.injection_marks.start_mark, patch.injection_marks.end_mark + if marks.is_valid(sm) and marks.is_valid(em) then + local sr, sc, er, ec = marks.range_to_vim(sm, em) + if sr ~= nil then + start_line = sr + 1 + end_line = er + 1 + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = string.format("Applying at extmark range (lines %d-%d)", start_line, end_line), + }) + end) + marks.delete(sm) + marks.delete(em) + end + end + end - -- For inline prompts, use scope range directly (tags are inside scope) - -- No adjustment needed since we didn't remove tags yet - if not is_inline_prompt and patch.scope and patch.scope.type then - -- For coder files, tags were already removed, so we may need to find the scope again - local found_range = nil - pcall(function() - local parsers = require("nvim-treesitter.parsers") - local parser = parsers.get_parser(target_bufnr) - if parser then - local tree = parser:parse()[1] - if tree then - local root = tree:root() - -- Find the function/method node that contains our original position - local function find_scope_node(node) - local node_type = node:type() - local is_scope = node_type:match("function") - or node_type:match("method") - or node_type:match("class") - or node_type:match("declaration") + -- For inline prompts, use scope range directly (tags are inside scope) + -- No adjustment needed since we didn't remove tags yet + if not is_inline_prompt and patch.scope and patch.scope.type then + -- For coder files, tags were already removed, so we may need to find the scope again + local found_range = nil + pcall(function() + local parsers = require("nvim-treesitter.parsers") + local parser = parsers.get_parser(target_bufnr) + if parser then + local tree = parser:parse()[1] + if tree then + local root = tree:root() + -- Find the function/method node that contains our original position + local function find_scope_node(node) + local node_type = node:type() + local is_scope = node_type:match("function") + or node_type:match("method") + or node_type:match("class") + or node_type:match("declaration") - if is_scope then - local s_row, _, e_row, _ = node:range() - -- Check if this scope roughly matches our expected range - if math.abs(s_row - (start_line - 1)) <= 5 then - found_range = { start_line = s_row + 1, end_line = e_row + 1 } - return true - end - end + if is_scope then + local s_row, _, e_row, _ = node:range() + -- Check if this scope roughly matches our expected range + if math.abs(s_row - (start_line - 1)) <= 5 then + found_range = { start_line = s_row + 1, end_line = e_row + 1 } + return true + end + end - for child in node:iter_children() do - if find_scope_node(child) then - return true - end - end - return false - end - find_scope_node(root) - end - end - end) + for child in node:iter_children() do + if find_scope_node(child) then + return true + end + end + return false + end + find_scope_node(root) + end + end + end) - if found_range then - start_line = found_range.start_line - end_line = found_range.end_line - end - end + if found_range then + start_line = found_range.start_line + end_line = found_range.end_line + end + end - -- Clamp to valid range - local line_count = vim.api.nvim_buf_line_count(target_bufnr) - start_line = math.max(1, start_line) - end_line = math.min(line_count, end_line) + -- Clamp to valid range + local line_count = vim.api.nvim_buf_line_count(target_bufnr) + start_line = math.max(1, start_line) + end_line = math.min(line_count, end_line) - inject_opts.range = { start_line = start_line, end_line = end_line } - elseif patch.injection_strategy == "insert" and patch.injection_range then - -- For inline prompts with "insert" strategy, replace the TAG RANGE - -- (the tag itself gets replaced with the new code) - if is_inline_prompt and patch.prompt_tag_range then - inject_opts.range = { - start_line = patch.prompt_tag_range.start_line, - end_line = patch.prompt_tag_range.end_line - } - -- Switch to replace strategy for the tag range - inject_opts.strategy = "replace" - else - inject_opts.range = { start_line = patch.injection_range.start_line } - end - end + inject_opts.range = { start_line = start_line, end_line = end_line } + elseif patch.injection_strategy == "insert" and patch.injection_range then + -- For inline prompts with "insert" strategy, replace the TAG RANGE + -- (the tag itself gets replaced with the new code) + if is_inline_prompt and patch.prompt_tag_range then + inject_opts.range = { + start_line = patch.prompt_tag_range.start_line, + end_line = patch.prompt_tag_range.end_line, + } + -- Switch to replace strategy for the tag range + inject_opts.strategy = "replace" + else + inject_opts.range = { start_line = patch.injection_range.start_line } + end + end - -- Log inline prompt handling - if is_inline_prompt then - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = string.format("Inline prompt: replacing lines %d-%d", - inject_opts.range and inject_opts.range.start_line or 0, - inject_opts.range and inject_opts.range.end_line or 0), - }) - end) - end + -- Log inline prompt handling + if is_inline_prompt then + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = string.format( + "Inline prompt: replacing lines %d-%d", + inject_opts.range and inject_opts.range.start_line or 0, + inject_opts.range and inject_opts.range.end_line or 0 + ), + }) + end) + end - -- Diagnostic: log inject_opts before calling inject (why injection might not run) - local range_str = inject_opts.range - and string.format("%d-%d", inject_opts.range.start_line, inject_opts.range.end_line) - or "nil" - logger.info("patch", string.format( - "inject_opts: strategy=%s range=%s code_len=%d", - inject_opts.strategy or "nil", - range_str, - code_to_inject and #code_to_inject or 0 - )) + -- Diagnostic: log inject_opts before calling inject (why injection might not run) + local range_str = inject_opts.range + and string.format("%d-%d", inject_opts.range.start_line, inject_opts.range.end_line) + or "nil" + logger.info( + "patch", + string.format( + "inject_opts: strategy=%s range=%s code_len=%d", + inject_opts.strategy or "nil", + range_str, + code_to_inject and #code_to_inject or 0 + ) + ) - if not inject_opts.range then - logger.warn("patch", string.format( - "inject has no range (strategy=%s) - inject may append or skip", - tostring(patch.injection_strategy) - )) - end + if not inject_opts.range then + logger.warn( + "patch", + string.format( + "inject has no range (strategy=%s) - inject may append or skip", + tostring(patch.injection_strategy) + ) + ) + end - -- Use smart injection - handles imports automatically - inject_result = inject.inject(target_bufnr, code_to_inject, inject_opts) + -- Use smart injection - handles imports automatically + inject_result = inject.inject(target_bufnr, code_to_inject, inject_opts) - -- Log injection details - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - if inject_result.imports_added > 0 then - logs.add({ - type = "info", - message = string.format( - "%s %d import(s), injected %d body line(s)", - inject_result.imports_merged and "Merged" or "Added", - inject_result.imports_added, - inject_result.body_lines - ), - }) - else - logs.add({ - type = "info", - message = string.format("Injected %d line(s) of code", inject_result.body_lines), - }) - end - end) - end) + -- Log injection details + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + if inject_result.imports_added > 0 then + logs.add({ + type = "info", + message = string.format( + "%s %d import(s), injected %d body line(s)", + inject_result.imports_merged and "Merged" or "Added", + inject_result.imports_added, + inject_result.body_lines + ), + }) + else + logs.add({ + type = "info", + message = string.format("Injected %d line(s) of code", inject_result.body_lines), + }) + end + end) + end) - if not ok then - logger.error("patch", string.format("inject failed: %s", tostring(err))) - M.mark_rejected(patch.id, err) - return false, err - end + if not ok then + logger.error("patch", string.format("inject failed: %s", tostring(err))) + M.mark_rejected(patch.id, err) + return false, err + end - local body_lines = inject_result and inject_result.body_lines or "nil" - logger.info("patch", string.format("inject done: body_lines=%s", tostring(body_lines))) + local body_lines = inject_result and inject_result.body_lines or "nil" + logger.info("patch", string.format("inject done: body_lines=%s", tostring(body_lines))) - M.mark_applied(patch.id) + M.mark_applied(patch.id) - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "success", - message = string.format("Patch %s applied successfully", patch.id), - data = { - target_path = patch.target_path, - lines_added = #code_lines, - }, - }) - end) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "success", + message = string.format("Patch %s applied successfully", patch.id), + data = { + target_path = patch.target_path, + lines_added = #code_lines, + }, + }) + end) - -- Learn from successful code generation - this builds neural pathways - -- The more code is successfully applied, the better the brain becomes - pcall(function() - local brain = require("codetyper.core.memory") - if brain.is_initialized() then - -- Learn the successful pattern - local intent_type = patch.intent and patch.intent.type or "unknown" - local scope_type = patch.scope and patch.scope.type or "file" - local scope_name = patch.scope and patch.scope.name or "" + -- Learn from successful code generation - this builds neural pathways + -- The more code is successfully applied, the better the brain becomes + pcall(function() + local brain = require("codetyper.core.memory") + if brain.is_initialized() then + -- Learn the successful pattern + local intent_type = patch.intent and patch.intent.type or "unknown" + local scope_type = patch.scope and patch.scope.type or "file" + local scope_name = patch.scope and patch.scope.name or "" - -- Create a meaningful summary for this learning - local summary = string.format( - "Generated %s: %s %s in %s", - intent_type, - scope_type, - scope_name ~= "" and scope_name or "", - vim.fn.fnamemodify(patch.target_path or "", ":t") - ) + -- Create a meaningful summary for this learning + local summary = string.format( + "Generated %s: %s %s in %s", + intent_type, + scope_type, + scope_name ~= "" and scope_name or "", + vim.fn.fnamemodify(patch.target_path or "", ":t") + ) - brain.learn({ - type = "code_completion", - file = patch.target_path, - timestamp = os.time(), - data = { - intent = intent_type, - code = patch.generated_code:sub(1, 500), -- Store first 500 chars - language = vim.fn.fnamemodify(patch.target_path or "", ":e"), - function_name = scope_name, - prompt = patch.prompt_content, - confidence = patch.confidence or 0.5, - }, - }) - end - end) + brain.learn({ + type = "code_completion", + file = patch.target_path, + timestamp = os.time(), + data = { + intent = intent_type, + code = patch.generated_code:sub(1, 500), -- Store first 500 chars + language = vim.fn.fnamemodify(patch.target_path or "", ":e"), + function_name = scope_name, + prompt = patch.prompt_content, + confidence = patch.confidence or 0.5, + }, + }) + end + end) - return true, nil + return true, nil end --- Flush all pending patches that are safe to apply @@ -850,81 +878,80 @@ end ---@return number stale_count ---@return number deferred_count function M.flush_pending() - local applied = 0 - local stale = 0 - local deferred = 0 + local applied = 0 + local stale = 0 + local deferred = 0 - for _, p in ipairs(patches) do - if p.status == "pending" then - local success, err = M.apply(p) - if success then - applied = applied + 1 - elseif err == "user_typing" then - -- Keep pending, will retry later - deferred = deferred + 1 - else - stale = stale + 1 - end - end - end + for _, p in ipairs(patches) do + if p.status == "pending" then + local success, err = M.apply(p) + if success then + applied = applied + 1 + elseif err == "user_typing" then + -- Keep pending, will retry later + deferred = deferred + 1 + else + stale = stale + 1 + end + end + end - return applied, stale, deferred + return applied, stale, deferred end --- Cancel all pending patches for a buffer ---@param bufnr number ---@return number cancelled_count function M.cancel_for_buffer(bufnr) - local cancelled = 0 - for _, patch in ipairs(patches) do - if patch.status == "pending" and - (patch.target_bufnr == bufnr or patch.original_snapshot.bufnr == bufnr) then - patch.status = "cancelled" - cancelled = cancelled + 1 - end - end - return cancelled + local cancelled = 0 + for _, patch in ipairs(patches) do + if patch.status == "pending" and (patch.target_bufnr == bufnr or patch.original_snapshot.bufnr == bufnr) then + patch.status = "cancelled" + cancelled = cancelled + 1 + end + end + return cancelled end --- Cleanup old patches ---@param max_age number Max age in seconds (default: 300) function M.cleanup(max_age) - max_age = max_age or 300 - local now = os.time() - local i = 1 - while i <= #patches do - local patch = patches[i] - if patch.status ~= "pending" and (now - patch.created_at) > max_age then - table.remove(patches, i) - else - i = i + 1 - end - end + max_age = max_age or 300 + local now = os.time() + local i = 1 + while i <= #patches do + local patch = patches[i] + if patch.status ~= "pending" and (now - patch.created_at) > max_age then + table.remove(patches, i) + else + i = i + 1 + end + end end --- Get statistics ---@return table function M.stats() - local stats = { - total = #patches, - pending = 0, - applied = 0, - stale = 0, - rejected = 0, - cancelled = 0, - } - for _, patch in ipairs(patches) do - local s = patch.status - if stats[s] then - stats[s] = stats[s] + 1 - end - end - return stats + local stats = { + total = #patches, + pending = 0, + applied = 0, + stale = 0, + rejected = 0, + cancelled = 0, + } + for _, patch in ipairs(patches) do + local s = patch.status + if stats[s] then + stats[s] = stats[s] + 1 + end + end + return stats end --- Clear all patches function M.clear() - patches = {} + patches = {} end --- Configure patch behavior @@ -932,24 +959,24 @@ end --- - use_conflict_mode: boolean Use conflict markers instead of direct apply --- - auto_jump_to_conflict: boolean Auto-jump to first conflict after applying function M.configure(opts) - if opts.use_conflict_mode ~= nil then - config.use_conflict_mode = opts.use_conflict_mode - end - if opts.auto_jump_to_conflict ~= nil then - config.auto_jump_to_conflict = opts.auto_jump_to_conflict - end + if opts.use_conflict_mode ~= nil then + config.use_conflict_mode = opts.use_conflict_mode + end + if opts.auto_jump_to_conflict ~= nil then + config.auto_jump_to_conflict = opts.auto_jump_to_conflict + end end --- Get current configuration ---@return table function M.get_config() - return vim.deepcopy(config) + return vim.deepcopy(config) end --- Check if conflict mode is enabled ---@return boolean function M.is_conflict_mode() - return config.use_conflict_mode + return config.use_conflict_mode end --- Apply a patch using conflict markers for interactive review @@ -958,120 +985,110 @@ end ---@return boolean success ---@return string|nil error function M.apply_with_conflict(patch) - -- Check if safe to modify (not in insert mode) - if not is_safe_to_modify() then - return false, "user_typing" - end + -- Check if safe to modify (not in insert mode) + if not is_safe_to_modify() then + return false, "user_typing" + end - -- Check staleness first - local is_stale, stale_reason = M.is_stale(patch) - if is_stale then - M.mark_stale(patch.id, stale_reason) - return false, "patch_stale: " .. (stale_reason or "unknown") - end + -- Check staleness first + local is_stale, stale_reason = M.is_stale(patch) + if is_stale then + M.mark_stale(patch.id, stale_reason) + return false, "patch_stale: " .. (stale_reason or "unknown") + end - -- Ensure target buffer is valid - local target_bufnr = patch.target_bufnr - if target_bufnr == -1 or not vim.api.nvim_buf_is_valid(target_bufnr) then - target_bufnr = vim.fn.bufadd(patch.target_path) - if target_bufnr == 0 then - M.mark_rejected(patch.id, "buffer_not_found") - return false, "target buffer not found" - end - vim.fn.bufload(target_bufnr) - patch.target_bufnr = target_bufnr - end + -- Ensure target buffer is valid + local target_bufnr = patch.target_bufnr + if target_bufnr == -1 or not vim.api.nvim_buf_is_valid(target_bufnr) then + target_bufnr = vim.fn.bufadd(patch.target_path) + if target_bufnr == 0 then + M.mark_rejected(patch.id, "buffer_not_found") + return false, "target buffer not found" + end + vim.fn.bufload(target_bufnr) + patch.target_bufnr = target_bufnr + end - local conflict = get_conflict_module() - local source_bufnr = patch.source_bufnr - local is_inline_prompt = patch.is_inline_prompt or (source_bufnr == target_bufnr) + local conflict = get_conflict_module() + local source_bufnr = patch.source_bufnr + local is_inline_prompt = patch.is_inline_prompt or (source_bufnr == target_bufnr) - - -- For SEARCH/REPLACE blocks, convert each block to a conflict - if patch.use_search_replace and patch.search_replace_blocks and #patch.search_replace_blocks > 0 then - local search_replace = get_search_replace_module() - local content = table.concat(vim.api.nvim_buf_get_lines(target_bufnr, 0, -1, false), "\n") - local applied_count = 0 + -- For SEARCH/REPLACE blocks, convert each block to a conflict + if patch.use_search_replace and patch.search_replace_blocks and #patch.search_replace_blocks > 0 then + local search_replace = get_search_replace_module() + local content = table.concat(vim.api.nvim_buf_get_lines(target_bufnr, 0, -1, false), "\n") + local applied_count = 0 - -- Sort blocks by position (bottom to top) to maintain line numbers - local sorted_blocks = {} - for _, block in ipairs(patch.search_replace_blocks) do - local match = search_replace.find_match(content, block.search) - if match then - block._match = match - table.insert(sorted_blocks, block) - end - end - table.sort(sorted_blocks, function(a, b) - return (a._match and a._match.start_line or 0) > (b._match and b._match.start_line or 0) - end) + -- Sort blocks by position (bottom to top) to maintain line numbers + local sorted_blocks = {} + for _, block in ipairs(patch.search_replace_blocks) do + local match = search_replace.find_match(content, block.search) + if match then + block._match = match + table.insert(sorted_blocks, block) + end + end + table.sort(sorted_blocks, function(a, b) + return (a._match and a._match.start_line or 0) > (b._match and b._match.start_line or 0) + end) - -- Apply each block as a conflict - for _, block in ipairs(sorted_blocks) do - local match = block._match - if match then - local new_lines = vim.split(block.replace, "\n", { plain = true }) - conflict.insert_conflict( - target_bufnr, - match.start_line, - match.end_line, - new_lines, - "AI SUGGESTION" - ) - applied_count = applied_count + 1 - -- Re-read content for next match (line numbers changed) - content = table.concat(vim.api.nvim_buf_get_lines(target_bufnr, 0, -1, false), "\n") - end - end + -- Apply each block as a conflict + for _, block in ipairs(sorted_blocks) do + local match = block._match + if match then + local new_lines = vim.split(block.replace, "\n", { plain = true }) + conflict.insert_conflict(target_bufnr, match.start_line, match.end_line, new_lines, "AI SUGGESTION") + applied_count = applied_count + 1 + -- Re-read content for next match (line numbers changed) + content = table.concat(vim.api.nvim_buf_get_lines(target_bufnr, 0, -1, false), "\n") + end + end - if applied_count > 0 then - -- Process conflicts (highlight, keymaps) and show menu - conflict.process_and_show_menu(target_bufnr) + if applied_count > 0 then + -- Process conflicts (highlight, keymaps) and show menu + conflict.process_and_show_menu(target_bufnr) - M.mark_applied(patch.id) + M.mark_applied(patch.id) - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "success", - message = string.format( - "Created %d conflict(s) for review - use co/ct/cb/cn to resolve", - applied_count - ), - }) - end) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "success", + message = string.format("Created %d conflict(s) for review - use co/ct/cb/cn to resolve", applied_count), + }) + end) - return true, nil - end - end + return true, nil + end + end - -- Fallback: Use injection range if available - if patch.injection_range then - local start_line = patch.injection_range.start_line - local end_line = patch.injection_range.end_line - local new_lines = vim.split(patch.generated_code, "\n", { plain = true }) + -- Fallback: Use injection range if available + if patch.injection_range then + local start_line = patch.injection_range.start_line + local end_line = patch.injection_range.end_line + local new_lines = vim.split(patch.generated_code, "\n", { plain = true }) - -- Insert conflict markers - conflict.insert_conflict(target_bufnr, start_line, end_line, new_lines, "AI SUGGESTION") + -- Insert conflict markers + conflict.insert_conflict(target_bufnr, start_line, end_line, new_lines, "AI SUGGESTION") - -- Process conflicts (highlight, keymaps) and show menu - conflict.process_and_show_menu(target_bufnr) + -- Process conflicts (highlight, keymaps) and show menu + conflict.process_and_show_menu(target_bufnr) - M.mark_applied(patch.id) + M.mark_applied(patch.id) - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "success", - message = "Created conflict for review - use co/ct/cb/cn to resolve", - }) - end) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "success", + message = "Created conflict for review - use co/ct/cb/cn to resolve", + }) + end) - return true, nil - end + return true, nil + end - -- No suitable range found, fall back to direct apply - return M.apply(patch) + -- No suitable range found, fall back to direct apply + return M.apply(patch) end --- Smart apply - uses conflict mode if enabled, otherwise direct apply @@ -1079,11 +1096,11 @@ end ---@return boolean success ---@return string|nil error function M.smart_apply(patch) - if config.use_conflict_mode then - return M.apply_with_conflict(patch) - else - return M.apply(patch) - end + if config.use_conflict_mode then + return M.apply_with_conflict(patch) + else + return M.apply(patch) + end end --- Flush all pending patches using smart apply @@ -1091,28 +1108,28 @@ end ---@return number stale_count ---@return number deferred_count function M.flush_pending_smart() - local applied = 0 - local stale = 0 - local deferred = 0 + local applied = 0 + local stale = 0 + local deferred = 0 - for _, p in ipairs(patches) do - if p.status == "pending" then - logger.info("patch", string.format("flush trying: id=%s", p.id)) - local success, err = M.smart_apply(p) - if success then - applied = applied + 1 - logger.info("patch", string.format("flush result: id=%s success", p.id)) - elseif err == "user_typing" then - deferred = deferred + 1 - logger.info("patch", string.format("flush result: id=%s deferred (user_typing)", p.id)) - else - stale = stale + 1 - logger.info("patch", string.format("flush result: id=%s stale (%s)", p.id, tostring(err))) - end - end - end + for _, p in ipairs(patches) do + if p.status == "pending" then + logger.info("patch", string.format("flush trying: id=%s", p.id)) + local success, err = M.smart_apply(p) + if success then + applied = applied + 1 + logger.info("patch", string.format("flush result: id=%s success", p.id)) + elseif err == "user_typing" then + deferred = deferred + 1 + logger.info("patch", string.format("flush result: id=%s deferred (user_typing)", p.id)) + else + stale = stale + 1 + logger.info("patch", string.format("flush result: id=%s stale (%s)", p.id, tostring(err))) + end + end + end - return applied, stale, deferred + return applied, stale, deferred end return M diff --git a/lua/codetyper/core/diff/search_replace.lua b/lua/codetyper/core/diff/search_replace.lua index 0fc6eb9..a830bc9 100644 --- a/lua/codetyper/core/diff/search_replace.lua +++ b/lua/codetyper/core/diff/search_replace.lua @@ -47,92 +47,92 @@ local params = require("codetyper.params.agents.search_replace").patterns ---@param response string LLM response text ---@return SearchReplaceBlock[] function M.parse_blocks(response) - local blocks = {} + local blocks = {} - -- Try dash-style format: ------- SEARCH ... ======= ... +++++++ REPLACE - for search, replace in response:gmatch(params.dash_style) do - table.insert(blocks, { search = search, replace = replace }) - end + -- Try dash-style format: ------- SEARCH ... ======= ... +++++++ REPLACE + for search, replace in response:gmatch(params.dash_style) do + table.insert(blocks, { search = search, replace = replace }) + end - if #blocks > 0 then - return blocks - end + if #blocks > 0 then + return blocks + end - -- Try claude-style format: <<<<<<< SEARCH ... ======= ... >>>>>>> REPLACE - for search, replace in response:gmatch(params.claude_style) do - table.insert(blocks, { search = search, replace = replace }) - end + -- Try claude-style format: <<<<<<< SEARCH ... ======= ... >>>>>>> REPLACE + for search, replace in response:gmatch(params.claude_style) do + table.insert(blocks, { search = search, replace = replace }) + end - if #blocks > 0 then - return blocks - end + if #blocks > 0 then + return blocks + end - -- Try simple format: [SEARCH] ... [REPLACE] ... [END] - for search, replace in response:gmatch(params.simple_style) do - table.insert(blocks, { search = search, replace = replace }) - end + -- Try simple format: [SEARCH] ... [REPLACE] ... [END] + for search, replace in response:gmatch(params.simple_style) do + table.insert(blocks, { search = search, replace = replace }) + end - if #blocks > 0 then - return blocks - end + if #blocks > 0 then + return blocks + end - -- Try markdown diff format: ```diff ... ``` - local diff_block = response:match(params.diff_block) - if diff_block then - local old_lines = {} - local new_lines = {} - for line in diff_block:gmatch("[^\n]+") do - if line:match("^%-[^%-]") then - -- Removed line (starts with single -) - table.insert(old_lines, line:sub(2)) - elseif line:match("^%+[^%+]") then - -- Added line (starts with single +) - table.insert(new_lines, line:sub(2)) - elseif line:match("^%s") or line:match("^[^%-%+@]") then - -- Context line - table.insert(old_lines, line:match("^%s?(.*)")) - table.insert(new_lines, line:match("^%s?(.*)")) - end - end - if #old_lines > 0 or #new_lines > 0 then - table.insert(blocks, { - search = table.concat(old_lines, "\n"), - replace = table.concat(new_lines, "\n"), - }) - end - end + -- Try markdown diff format: ```diff ... ``` + local diff_block = response:match(params.diff_block) + if diff_block then + local old_lines = {} + local new_lines = {} + for line in diff_block:gmatch("[^\n]+") do + if line:match("^%-[^%-]") then + -- Removed line (starts with single -) + table.insert(old_lines, line:sub(2)) + elseif line:match("^%+[^%+]") then + -- Added line (starts with single +) + table.insert(new_lines, line:sub(2)) + elseif line:match("^%s") or line:match("^[^%-%+@]") then + -- Context line + table.insert(old_lines, line:match("^%s?(.*)")) + table.insert(new_lines, line:match("^%s?(.*)")) + end + end + if #old_lines > 0 or #new_lines > 0 then + table.insert(blocks, { + search = table.concat(old_lines, "\n"), + replace = table.concat(new_lines, "\n"), + }) + end + end - return blocks + return blocks end --- Get indentation of a line ---@param line string ---@return string local function get_indentation(line) - if not line then - return "" - end - return line:match("^(%s*)") or "" + if not line then + return "" + end + return line:match("^(%s*)") or "" end --- Normalize whitespace in a string (collapse multiple spaces to one) ---@param str string ---@return string local function normalize_whitespace(str) - -- Wrap in parentheses to only return first value (gsub returns string + count) - return (str:gsub("%s+", " "):gsub("^%s*", ""):gsub("%s*$", "")) + -- Wrap in parentheses to only return first value (gsub returns string + count) + return (str:gsub("%s+", " "):gsub("^%s*", ""):gsub("%s*$", "")) end --- Trim trailing whitespace from each line ---@param str string ---@return string local function trim_lines(str) - local lines = vim.split(str, "\n", { plain = true }) - for i, line in ipairs(lines) do - -- Wrap in parentheses to only get string, not count - lines[i] = (line:gsub("%s+$", "")) - end - return table.concat(lines, "\n") + local lines = vim.split(str, "\n", { plain = true }) + for i, line in ipairs(lines) do + -- Wrap in parentheses to only get string, not count + lines[i] = (line:gsub("%s+$", "")) + end + return table.concat(lines, "\n") end --- Calculate Levenshtein distance between two strings @@ -140,34 +140,30 @@ end ---@param s2 string ---@return number local function levenshtein(s1, s2) - local len1, len2 = #s1, #s2 - if len1 == 0 then - return len2 - end - if len2 == 0 then - return len1 - end + local len1, len2 = #s1, #s2 + if len1 == 0 then + return len2 + end + if len2 == 0 then + return len1 + end - local matrix = {} - for i = 0, len1 do - matrix[i] = { [0] = i } - end - for j = 0, len2 do - matrix[0][j] = j - end + local matrix = {} + for i = 0, len1 do + matrix[i] = { [0] = i } + end + for j = 0, len2 do + matrix[0][j] = j + end - for i = 1, len1 do - for j = 1, len2 do - local cost = (s1:sub(i, i) == s2:sub(j, j)) and 0 or 1 - matrix[i][j] = math.min( - matrix[i - 1][j] + 1, - matrix[i][j - 1] + 1, - matrix[i - 1][j - 1] + cost - ) - end - end + for i = 1, len1 do + for j = 1, len2 do + local cost = (s1:sub(i, i) == s2:sub(j, j)) and 0 or 1 + matrix[i][j] = math.min(matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, matrix[i - 1][j - 1] + cost) + end + end - return matrix[len1][len2] + return matrix[len1][len2] end --- Calculate similarity ratio (0.0-1.0) between two strings @@ -175,15 +171,15 @@ end ---@param s2 string ---@return number local function similarity(s1, s2) - if s1 == s2 then - return 1.0 - end - local max_len = math.max(#s1, #s2) - if max_len == 0 then - return 1.0 - end - local distance = levenshtein(s1, s2) - return 1.0 - (distance / max_len) + if s1 == s2 then + return 1.0 + end + local max_len = math.max(#s1, #s2) + if max_len == 0 then + return 1.0 + end + local distance = levenshtein(s1, s2) + return 1.0 - (distance / max_len) end --- Strategy 1: Exact match @@ -191,31 +187,31 @@ end ---@param search_lines string[] ---@return MatchResult|nil local function exact_match(content_lines, search_lines) - if #search_lines == 0 then - return nil - end + if #search_lines == 0 then + return nil + end - for i = 1, #content_lines - #search_lines + 1 do - local match = true - for j = 1, #search_lines do - if content_lines[i + j - 1] ~= search_lines[j] then - match = false - break - end - end - if match then - return { - start_line = i, - end_line = i + #search_lines - 1, - start_col = 1, - end_col = #content_lines[i + #search_lines - 1], - strategy = "exact", - confidence = 1.0, - } - end - end + for i = 1, #content_lines - #search_lines + 1 do + local match = true + for j = 1, #search_lines do + if content_lines[i + j - 1] ~= search_lines[j] then + match = false + break + end + end + if match then + return { + start_line = i, + end_line = i + #search_lines - 1, + start_col = 1, + end_col = #content_lines[i + #search_lines - 1], + strategy = "exact", + confidence = 1.0, + } + end + end - return nil + return nil end --- Strategy 2: Line-trimmed match (ignore trailing whitespace) @@ -223,37 +219,37 @@ end ---@param search_lines string[] ---@return MatchResult|nil local function line_trimmed_match(content_lines, search_lines) - if #search_lines == 0 then - return nil - end + if #search_lines == 0 then + return nil + end - local trimmed_search = {} - for _, line in ipairs(search_lines) do - table.insert(trimmed_search, (line:gsub("%s+$", ""))) - end + local trimmed_search = {} + for _, line in ipairs(search_lines) do + table.insert(trimmed_search, (line:gsub("%s+$", ""))) + end - for i = 1, #content_lines - #search_lines + 1 do - local match = true - for j = 1, #search_lines do - local trimmed_content = content_lines[i + j - 1]:gsub("%s+$", "") - if trimmed_content ~= trimmed_search[j] then - match = false - break - end - end - if match then - return { - start_line = i, - end_line = i + #search_lines - 1, - start_col = 1, - end_col = #content_lines[i + #search_lines - 1], - strategy = "line_trimmed", - confidence = 0.95, - } - end - end + for i = 1, #content_lines - #search_lines + 1 do + local match = true + for j = 1, #search_lines do + local trimmed_content = content_lines[i + j - 1]:gsub("%s+$", "") + if trimmed_content ~= trimmed_search[j] then + match = false + break + end + end + if match then + return { + start_line = i, + end_line = i + #search_lines - 1, + start_col = 1, + end_col = #content_lines[i + #search_lines - 1], + strategy = "line_trimmed", + confidence = 0.95, + } + end + end - return nil + return nil end --- Strategy 3: Indentation-flexible match (normalize indentation) @@ -261,65 +257,65 @@ end ---@param search_lines string[] ---@return MatchResult|nil local function indentation_flexible_match(content_lines, search_lines) - if #search_lines == 0 then - return nil - end + if #search_lines == 0 then + return nil + end - -- Get base indentation from search (first non-empty line) - local search_indent = "" - for _, line in ipairs(search_lines) do - if line:match("%S") then - search_indent = get_indentation(line) - break - end - end + -- Get base indentation from search (first non-empty line) + local search_indent = "" + for _, line in ipairs(search_lines) do + if line:match("%S") then + search_indent = get_indentation(line) + break + end + end - -- Strip common indentation from search - local stripped_search = {} - for _, line in ipairs(search_lines) do - if line:match("^" .. vim.pesc(search_indent)) then - table.insert(stripped_search, line:sub(#search_indent + 1)) - else - table.insert(stripped_search, line) - end - end + -- Strip common indentation from search + local stripped_search = {} + for _, line in ipairs(search_lines) do + if line:match("^" .. vim.pesc(search_indent)) then + table.insert(stripped_search, line:sub(#search_indent + 1)) + else + table.insert(stripped_search, line) + end + end - for i = 1, #content_lines - #search_lines + 1 do - -- Get content indentation at this position - local content_indent = "" - for j = 0, #search_lines - 1 do - local line = content_lines[i + j] - if line:match("%S") then - content_indent = get_indentation(line) - break - end - end + for i = 1, #content_lines - #search_lines + 1 do + -- Get content indentation at this position + local content_indent = "" + for j = 0, #search_lines - 1 do + local line = content_lines[i + j] + if line:match("%S") then + content_indent = get_indentation(line) + break + end + end - local match = true - for j = 1, #search_lines do - local content_line = content_lines[i + j - 1] - local expected = content_indent .. stripped_search[j] + local match = true + for j = 1, #search_lines do + local content_line = content_lines[i + j - 1] + local expected = content_indent .. stripped_search[j] - -- Compare with normalized indentation - if content_line:gsub("%s+$", "") ~= expected:gsub("%s+$", "") then - match = false - break - end - end + -- Compare with normalized indentation + if content_line:gsub("%s+$", "") ~= expected:gsub("%s+$", "") then + match = false + break + end + end - if match then - return { - start_line = i, - end_line = i + #search_lines - 1, - start_col = 1, - end_col = #content_lines[i + #search_lines - 1], - strategy = "indentation_flexible", - confidence = 0.9, - } - end - end + if match then + return { + start_line = i, + end_line = i + #search_lines - 1, + start_col = 1, + end_col = #content_lines[i + #search_lines - 1], + strategy = "indentation_flexible", + confidence = 0.9, + } + end + end - return nil + return nil end --- Strategy 4: Block anchor match (match first/last lines, fuzzy middle) @@ -327,56 +323,56 @@ end ---@param search_lines string[] ---@return MatchResult|nil local function block_anchor_match(content_lines, search_lines) - if #search_lines < 2 then - return nil - end + if #search_lines < 2 then + return nil + end - local first_search = search_lines[1]:gsub("%s+$", "") - local last_search = search_lines[#search_lines]:gsub("%s+$", "") + local first_search = search_lines[1]:gsub("%s+$", "") + local last_search = search_lines[#search_lines]:gsub("%s+$", "") - -- Find potential start positions - local candidates = {} - for i = 1, #content_lines - #search_lines + 1 do - local first_content = content_lines[i]:gsub("%s+$", "") - if similarity(first_content, first_search) > 0.8 then - -- Check if last line also matches - local last_idx = i + #search_lines - 1 - if last_idx <= #content_lines then - local last_content = content_lines[last_idx]:gsub("%s+$", "") - if similarity(last_content, last_search) > 0.8 then - -- Calculate overall similarity - local total_sim = 0 - for j = 1, #search_lines do - local c = content_lines[i + j - 1]:gsub("%s+$", "") - local s = search_lines[j]:gsub("%s+$", "") - total_sim = total_sim + similarity(c, s) - end - local avg_sim = total_sim / #search_lines - if avg_sim > 0.7 then - table.insert(candidates, { start = i, similarity = avg_sim }) - end - end - end - end - end + -- Find potential start positions + local candidates = {} + for i = 1, #content_lines - #search_lines + 1 do + local first_content = content_lines[i]:gsub("%s+$", "") + if similarity(first_content, first_search) > 0.8 then + -- Check if last line also matches + local last_idx = i + #search_lines - 1 + if last_idx <= #content_lines then + local last_content = content_lines[last_idx]:gsub("%s+$", "") + if similarity(last_content, last_search) > 0.8 then + -- Calculate overall similarity + local total_sim = 0 + for j = 1, #search_lines do + local c = content_lines[i + j - 1]:gsub("%s+$", "") + local s = search_lines[j]:gsub("%s+$", "") + total_sim = total_sim + similarity(c, s) + end + local avg_sim = total_sim / #search_lines + if avg_sim > 0.7 then + table.insert(candidates, { start = i, similarity = avg_sim }) + end + end + end + end + end - -- Return best match - if #candidates > 0 then - table.sort(candidates, function(a, b) - return a.similarity > b.similarity - end) - local best = candidates[1] - return { - start_line = best.start, - end_line = best.start + #search_lines - 1, - start_col = 1, - end_col = #content_lines[best.start + #search_lines - 1], - strategy = "block_anchor", - confidence = best.similarity * 0.85, - } - end + -- Return best match + if #candidates > 0 then + table.sort(candidates, function(a, b) + return a.similarity > b.similarity + end) + local best = candidates[1] + return { + start_line = best.start, + end_line = best.start + #search_lines - 1, + start_col = 1, + end_col = #content_lines[best.start + #search_lines - 1], + strategy = "block_anchor", + confidence = best.similarity * 0.85, + } + end - return nil + return nil end --- Strategy 5: Whitespace-normalized match @@ -384,38 +380,38 @@ end ---@param search_lines string[] ---@return MatchResult|nil local function whitespace_normalized_match(content_lines, search_lines) - if #search_lines == 0 then - return nil - end + if #search_lines == 0 then + return nil + end - -- Normalize search lines - local norm_search = {} - for _, line in ipairs(search_lines) do - table.insert(norm_search, normalize_whitespace(line)) - end + -- Normalize search lines + local norm_search = {} + for _, line in ipairs(search_lines) do + table.insert(norm_search, normalize_whitespace(line)) + end - for i = 1, #content_lines - #search_lines + 1 do - local match = true - for j = 1, #search_lines do - local norm_content = normalize_whitespace(content_lines[i + j - 1]) - if norm_content ~= norm_search[j] then - match = false - break - end - end - if match then - return { - start_line = i, - end_line = i + #search_lines - 1, - start_col = 1, - end_col = #content_lines[i + #search_lines - 1], - strategy = "whitespace_normalized", - confidence = 0.8, - } - end - end + for i = 1, #content_lines - #search_lines + 1 do + local match = true + for j = 1, #search_lines do + local norm_content = normalize_whitespace(content_lines[i + j - 1]) + if norm_content ~= norm_search[j] then + match = false + break + end + end + if match then + return { + start_line = i, + end_line = i + #search_lines - 1, + start_col = 1, + end_col = #content_lines[i + #search_lines - 1], + strategy = "whitespace_normalized", + confidence = 0.8, + } + end + end - return nil + return nil end --- Find the best match for search text in content @@ -423,35 +419,35 @@ end ---@param search string Text to search for ---@return MatchResult|nil function M.find_match(content, search) - local content_lines = vim.split(content, "\n", { plain = true }) - local search_lines = vim.split(search, "\n", { plain = true }) + local content_lines = vim.split(content, "\n", { plain = true }) + local search_lines = vim.split(search, "\n", { plain = true }) - -- Remove trailing empty lines from search - while #search_lines > 0 and search_lines[#search_lines]:match("^%s*$") do - table.remove(search_lines) - end + -- Remove trailing empty lines from search + while #search_lines > 0 and search_lines[#search_lines]:match("^%s*$") do + table.remove(search_lines) + end - if #search_lines == 0 then - return nil - end + if #search_lines == 0 then + return nil + end - -- Try strategies in order of strictness - local strategies = { - exact_match, - line_trimmed_match, - indentation_flexible_match, - block_anchor_match, - whitespace_normalized_match, - } + -- Try strategies in order of strictness + local strategies = { + exact_match, + line_trimmed_match, + indentation_flexible_match, + block_anchor_match, + whitespace_normalized_match, + } - for _, strategy in ipairs(strategies) do - local result = strategy(content_lines, search_lines) - if result then - return result - end - end + for _, strategy in ipairs(strategies) do + local result = strategy(content_lines, search_lines) + if result then + return result + end + end - return nil + return nil end --- Apply a single SEARCH/REPLACE block to content @@ -461,49 +457,49 @@ end ---@return MatchResult|nil match_info ---@return string|nil error function M.apply_block(content, block) - local match = M.find_match(content, block.search) - if not match then - return nil, nil, "Could not find search text in file" - end + local match = M.find_match(content, block.search) + if not match then + return nil, nil, "Could not find search text in file" + end - local content_lines = vim.split(content, "\n", { plain = true }) - local replace_lines = vim.split(block.replace, "\n", { plain = true }) + local content_lines = vim.split(content, "\n", { plain = true }) + local replace_lines = vim.split(block.replace, "\n", { plain = true }) - -- Adjust indentation of replacement to match original - local original_indent = get_indentation(content_lines[match.start_line]) - local replace_indent = "" - for _, line in ipairs(replace_lines) do - if line:match("%S") then - replace_indent = get_indentation(line) - break - end - end + -- Adjust indentation of replacement to match original + local original_indent = get_indentation(content_lines[match.start_line]) + local replace_indent = "" + for _, line in ipairs(replace_lines) do + if line:match("%S") then + replace_indent = get_indentation(line) + break + end + end - -- Apply indentation adjustment - local adjusted_replace = {} - for _, line in ipairs(replace_lines) do - if line:match("^" .. vim.pesc(replace_indent)) then - table.insert(adjusted_replace, original_indent .. line:sub(#replace_indent + 1)) - elseif line:match("^%s*$") then - table.insert(adjusted_replace, "") - else - table.insert(adjusted_replace, original_indent .. line) - end - end + -- Apply indentation adjustment + local adjusted_replace = {} + for _, line in ipairs(replace_lines) do + if line:match("^" .. vim.pesc(replace_indent)) then + table.insert(adjusted_replace, original_indent .. line:sub(#replace_indent + 1)) + elseif line:match("^%s*$") then + table.insert(adjusted_replace, "") + else + table.insert(adjusted_replace, original_indent .. line) + end + end - -- Build new content - local new_lines = {} - for i = 1, match.start_line - 1 do - table.insert(new_lines, content_lines[i]) - end - for _, line in ipairs(adjusted_replace) do - table.insert(new_lines, line) - end - for i = match.end_line + 1, #content_lines do - table.insert(new_lines, content_lines[i]) - end + -- Build new content + local new_lines = {} + for i = 1, match.start_line - 1 do + table.insert(new_lines, content_lines[i]) + end + for _, line in ipairs(adjusted_replace) do + table.insert(new_lines, line) + end + for i = match.end_line + 1, #content_lines do + table.insert(new_lines, content_lines[i]) + end - return table.concat(new_lines, "\n"), match, nil + return table.concat(new_lines, "\n"), match, nil end --- Apply multiple SEARCH/REPLACE blocks to content @@ -512,20 +508,20 @@ end ---@return string new_content ---@return table results Array of {success: boolean, match: MatchResult|nil, error: string|nil} function M.apply_blocks(content, blocks) - local current_content = content - local results = {} + local current_content = content + local results = {} - for _, block in ipairs(blocks) do - local new_content, match, err = M.apply_block(current_content, block) - if new_content then - current_content = new_content - table.insert(results, { success = true, match = match }) - else - table.insert(results, { success = false, error = err }) - end - end + for _, block in ipairs(blocks) do + local new_content, match, err = M.apply_block(current_content, block) + if new_content then + current_content = new_content + table.insert(results, { success = true, match = match }) + else + table.insert(results, { success = false, error = err }) + end + end - return current_content, results + return current_content, results end --- Apply SEARCH/REPLACE blocks to a buffer @@ -534,39 +530,39 @@ end ---@return boolean success ---@return string|nil error function M.apply_to_buffer(bufnr, blocks) - if not vim.api.nvim_buf_is_valid(bufnr) then - return false, "Invalid buffer" - end + if not vim.api.nvim_buf_is_valid(bufnr) then + return false, "Invalid buffer" + end - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local content = table.concat(lines, "\n") + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local content = table.concat(lines, "\n") - local new_content, results = M.apply_blocks(content, blocks) + local new_content, results = M.apply_blocks(content, blocks) - -- Check for any failures - local failures = {} - for i, result in ipairs(results) do - if not result.success then - table.insert(failures, string.format("Block %d: %s", i, result.error or "unknown error")) - end - end + -- Check for any failures + local failures = {} + for i, result in ipairs(results) do + if not result.success then + table.insert(failures, string.format("Block %d: %s", i, result.error or "unknown error")) + end + end - if #failures > 0 then - return false, table.concat(failures, "; ") - end + if #failures > 0 then + return false, table.concat(failures, "; ") + end - -- Apply to buffer - local new_lines = vim.split(new_content, "\n", { plain = true }) - vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, new_lines) + -- Apply to buffer + local new_lines = vim.split(new_content, "\n", { plain = true }) + vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, new_lines) - return true, nil + return true, nil end --- Check if response contains SEARCH/REPLACE blocks ---@param response string ---@return boolean function M.has_blocks(response) - return #M.parse_blocks(response) > 0 + return #M.parse_blocks(response) > 0 end return M diff --git a/lua/codetyper/core/events/queue.lua b/lua/codetyper/core/events/queue.lua index ae0c149..f1ae55d 100644 --- a/lua/codetyper/core/events/queue.lua +++ b/lua/codetyper/core/events/queue.lua @@ -45,44 +45,44 @@ local event_counter = 0 --- Generate unique event ID ---@return string function M.generate_id() - event_counter = event_counter + 1 - return string.format("evt_%d_%d", os.time(), event_counter) + event_counter = event_counter + 1 + return string.format("evt_%d_%d", os.time(), event_counter) end --- Simple hash function for content ---@param content string ---@return string function M.hash_content(content) - local hash = 0 - for i = 1, #content do - hash = (hash * 31 + string.byte(content, i)) % 2147483647 - end - return string.format("%x", hash) + local hash = 0 + for i = 1, #content do + hash = (hash * 31 + string.byte(content, i)) % 2147483647 + end + return string.format("%x", hash) end --- Notify all listeners of queue change ---@param event_type string "enqueue"|"dequeue"|"update"|"cancel" ---@param event PromptEvent|nil The affected event local function notify_listeners(event_type, event) - for _, listener in ipairs(listeners) do - pcall(listener, event_type, event, #queue) - end + for _, listener in ipairs(listeners) do + pcall(listener, event_type, event, #queue) + end end --- Add event listener ---@param callback function(event_type: string, event: PromptEvent|nil, queue_size: number) ---@return number Listener ID for removal function M.add_listener(callback) - table.insert(listeners, callback) - return #listeners + table.insert(listeners, callback) + return #listeners end --- Remove event listener ---@param listener_id number function M.remove_listener(listener_id) - if listener_id > 0 and listener_id <= #listeners then - table.remove(listeners, listener_id) - end + if listener_id > 0 and listener_id <= #listeners then + table.remove(listeners, listener_id) + end end --- Compare events for priority sorting @@ -90,12 +90,12 @@ end ---@param b PromptEvent ---@return boolean local function compare_priority(a, b) - -- Lower priority number = higher priority - if a.priority ~= b.priority then - return a.priority < b.priority - end - -- Same priority: older events first (FIFO) - return a.timestamp < b.timestamp + -- Lower priority number = higher priority + if a.priority ~= b.priority then + return a.priority < b.priority + end + -- Same priority: older events first (FIFO) + return a.timestamp < b.timestamp end --- Check if two events are in the same scope @@ -103,40 +103,39 @@ end ---@param b PromptEvent ---@return boolean local function same_scope(a, b) - -- Same buffer - if a.target_path ~= b.target_path then - return false - end + -- Same buffer + if a.target_path ~= b.target_path then + return false + end - -- Both have scope ranges - if a.scope_range and b.scope_range then - -- Check if ranges overlap - return a.scope_range.start_line <= b.scope_range.end_line - and b.scope_range.start_line <= a.scope_range.end_line - end + -- Both have scope ranges + if a.scope_range and b.scope_range then + -- Check if ranges overlap + return a.scope_range.start_line <= b.scope_range.end_line and b.scope_range.start_line <= a.scope_range.end_line + end - -- Fallback: check if prompt ranges are close (within 10 lines) - if a.range and b.range then - local distance = math.abs(a.range.start_line - b.range.start_line) - return distance < 10 - end + -- Fallback: check if prompt ranges are close (within 10 lines) + if a.range and b.range then + local distance = math.abs(a.range.start_line - b.range.start_line) + return distance < 10 + end - return false + return false end --- Find conflicting events in the same scope ---@param event PromptEvent ---@return PromptEvent[] Conflicting pending events function M.find_conflicts(event) - local conflicts = {} - for _, existing in ipairs(queue) do - if existing.status == "pending" and existing.id ~= event.id then - if same_scope(event, existing) then - table.insert(conflicts, existing) - end - end - end - return conflicts + local conflicts = {} + for _, existing in ipairs(queue) do + if existing.status == "pending" and existing.id ~= event.id then + if same_scope(event, existing) then + table.insert(conflicts, existing) + end + end + end + return conflicts end --- Check if an event should be skipped due to conflicts (first tag wins) @@ -144,105 +143,102 @@ end ---@return boolean should_skip ---@return string|nil reason function M.check_precedence(event) - local conflicts = M.find_conflicts(event) + local conflicts = M.find_conflicts(event) - for _, conflict in ipairs(conflicts) do - -- First (older) tag wins - if conflict.timestamp < event.timestamp then - return true, string.format( - "Skipped: earlier tag in same scope (event %s)", - conflict.id - ) - end - end + for _, conflict in ipairs(conflicts) do + -- First (older) tag wins + if conflict.timestamp < event.timestamp then + return true, string.format("Skipped: earlier tag in same scope (event %s)", conflict.id) + end + end - return false, nil + return false, nil end --- Insert event maintaining priority order ---@param event PromptEvent local function insert_sorted(event) - local pos = #queue + 1 - for i, existing in ipairs(queue) do - if compare_priority(event, existing) then - pos = i - break - end - end - table.insert(queue, pos, event) + local pos = #queue + 1 + for i, existing in ipairs(queue) do + if compare_priority(event, existing) then + pos = i + break + end + end + table.insert(queue, pos, event) end --- Enqueue a new event ---@param event PromptEvent ---@return PromptEvent The enqueued event with generated ID if missing function M.enqueue(event) - -- Ensure required fields - event.id = event.id or M.generate_id() - event.timestamp = event.timestamp or os.clock() - event.created_at = event.created_at or os.time() - event.status = event.status or "pending" - event.priority = event.priority or 2 - event.attempt_count = event.attempt_count or 0 + -- Ensure required fields + event.id = event.id or M.generate_id() + event.timestamp = event.timestamp or os.clock() + event.created_at = event.created_at or os.time() + event.status = event.status or "pending" + event.priority = event.priority or 2 + event.attempt_count = event.attempt_count or 0 - -- Generate content hash if not provided - if not event.content_hash and event.prompt_content then - event.content_hash = M.hash_content(event.prompt_content) - end + -- Generate content hash if not provided + if not event.content_hash and event.prompt_content then + event.content_hash = M.hash_content(event.prompt_content) + end - insert_sorted(event) - notify_listeners("enqueue", event) + insert_sorted(event) + notify_listeners("enqueue", event) - -- Log to agent logs if available - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "queue", - message = string.format("Event queued: %s (priority: %d)", event.id, event.priority), - data = { - event_id = event.id, - bufnr = event.bufnr, - prompt_preview = event.prompt_content:sub(1, 50), - }, - }) - end) + -- Log to agent logs if available + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "queue", + message = string.format("Event queued: %s (priority: %d)", event.id, event.priority), + data = { + event_id = event.id, + bufnr = event.bufnr, + prompt_preview = event.prompt_content:sub(1, 50), + }, + }) + end) - return event + return event end --- Dequeue highest priority pending event ---@return PromptEvent|nil function M.dequeue() - for i, event in ipairs(queue) do - if event.status == "pending" then - event.status = "processing" - notify_listeners("dequeue", event) - return event - end - end - return nil + for i, event in ipairs(queue) do + if event.status == "pending" then + event.status = "processing" + notify_listeners("dequeue", event) + return event + end + end + return nil end --- Peek at next pending event without removing ---@return PromptEvent|nil function M.peek() - for _, event in ipairs(queue) do - if event.status == "pending" then - return event - end - end - return nil + for _, event in ipairs(queue) do + if event.status == "pending" then + return event + end + end + return nil end --- Get event by ID ---@param id string ---@return PromptEvent|nil function M.get(id) - for _, event in ipairs(queue) do - if event.id == id then - return event - end - end - return nil + for _, event in ipairs(queue) do + if event.id == id then + return event + end + end + return nil end --- Update event status @@ -251,201 +247,207 @@ end ---@param extra table|nil Additional fields to update ---@return boolean Success function M.update_status(id, status, extra) - for _, event in ipairs(queue) do - if event.id == id then - event.status = status - if extra then - for k, v in pairs(extra) do - event[k] = v - end - end - notify_listeners("update", event) - return true - end - end - return false + for _, event in ipairs(queue) do + if event.id == id then + event.status = status + if extra then + for k, v in pairs(extra) do + event[k] = v + end + end + notify_listeners("update", event) + return true + end + end + return false end --- Mark event as completed ---@param id string ---@return boolean function M.complete(id) - return M.update_status(id, "completed") + return M.update_status(id, "completed") end --- Mark event as escalated (needs remote LLM) ---@param id string ---@return boolean function M.escalate(id) - local event = M.get(id) - if event then - event.status = "escalated" - event.attempt_count = event.attempt_count + 1 - -- Re-queue as pending with same priority - event.status = "pending" - notify_listeners("update", event) - return true - end - return false + local event = M.get(id) + if event then + event.status = "escalated" + event.attempt_count = event.attempt_count + 1 + -- Re-queue as pending with same priority + event.status = "pending" + notify_listeners("update", event) + return true + end + return false end --- Cancel all events for a buffer ---@param bufnr number ---@return number Number of cancelled events function M.cancel_for_buffer(bufnr) - local cancelled = 0 - for _, event in ipairs(queue) do - if event.bufnr == bufnr and event.status == "pending" then - event.status = "cancelled" - cancelled = cancelled + 1 - notify_listeners("cancel", event) - end - end - return cancelled + local cancelled = 0 + for _, event in ipairs(queue) do + if event.bufnr == bufnr and event.status == "pending" then + event.status = "cancelled" + cancelled = cancelled + 1 + notify_listeners("cancel", event) + end + end + return cancelled end --- Cancel event by ID ---@param id string ---@return boolean function M.cancel(id) - return M.update_status(id, "cancelled") + return M.update_status(id, "cancelled") end --- Get all pending events ---@return PromptEvent[] function M.get_pending() - local pending = {} - for _, event in ipairs(queue) do - if event.status == "pending" then - table.insert(pending, event) - end - end - return pending + local pending = {} + for _, event in ipairs(queue) do + if event.status == "pending" then + table.insert(pending, event) + end + end + return pending end --- Get all processing events ---@return PromptEvent[] function M.get_processing() - local processing = {} - for _, event in ipairs(queue) do - if event.status == "processing" then - table.insert(processing, event) - end - end - return processing + local processing = {} + for _, event in ipairs(queue) do + if event.status == "processing" then + table.insert(processing, event) + end + end + return processing end --- Get queue size (all events) ---@return number function M.size() - return #queue + return #queue end --- Get count of pending events ---@return number function M.pending_count() - local count = 0 - for _, event in ipairs(queue) do - if event.status == "pending" then - count = count + 1 - end - end - return count + local count = 0 + for _, event in ipairs(queue) do + if event.status == "pending" then + count = count + 1 + end + end + return count end --- Get count of processing events ---@return number function M.processing_count() - local count = 0 - for _, event in ipairs(queue) do - if event.status == "processing" then - count = count + 1 - end - end - return count + local count = 0 + for _, event in ipairs(queue) do + if event.status == "processing" then + count = count + 1 + end + end + return count end --- Check if queue is empty (no pending events) ---@return boolean function M.is_empty() - return M.pending_count() == 0 + return M.pending_count() == 0 end --- Clear all events (optionally filter by status) ---@param status string|nil Status to clear, or nil for all function M.clear(status) - if status then - local i = 1 - while i <= #queue do - if queue[i].status == status then - table.remove(queue, i) - else - i = i + 1 - end - end - else - queue = {} - end - notify_listeners("update", nil) + if status then + local i = 1 + while i <= #queue do + if queue[i].status == status then + table.remove(queue, i) + else + i = i + 1 + end + end + else + queue = {} + end + notify_listeners("update", nil) end --- Cleanup completed/cancelled/failed events older than max_age seconds ---@param max_age number Maximum age in seconds (default: 300) function M.cleanup(max_age) - max_age = max_age or 300 - local now = os.time() - local terminal_statuses = { - completed = true, - cancelled = true, - failed = true, - needs_context = true, - } - local i = 1 - while i <= #queue do - local event = queue[i] - if terminal_statuses[event.status] and (now - event.created_at) > max_age then - table.remove(queue, i) - else - i = i + 1 - end - end + max_age = max_age or 300 + local now = os.time() + local terminal_statuses = { + completed = true, + cancelled = true, + failed = true, + needs_context = true, + } + local i = 1 + while i <= #queue do + local event = queue[i] + if terminal_statuses[event.status] and (now - event.created_at) > max_age then + table.remove(queue, i) + else + i = i + 1 + end + end end --- Get queue statistics ---@return table function M.stats() - local stats = { - total = #queue, - pending = 0, - processing = 0, - completed = 0, - cancelled = 0, - escalated = 0, - failed = 0, - needs_context = 0, - } - for _, event in ipairs(queue) do - local s = event.status - if stats[s] then - stats[s] = stats[s] + 1 - end - end - return stats + local stats = { + total = #queue, + pending = 0, + processing = 0, + completed = 0, + cancelled = 0, + escalated = 0, + failed = 0, + needs_context = 0, + } + for _, event in ipairs(queue) do + local s = event.status + if stats[s] then + stats[s] = stats[s] + 1 + end + end + return stats end --- Debug: dump queue contents ---@return string function M.dump() - local lines = { "Queue contents:" } - for i, event in ipairs(queue) do - table.insert(lines, string.format( - " %d. [%s] %s (p:%d, status:%s, attempts:%d)", - i, event.id, - event.prompt_content:sub(1, 30):gsub("\n", " "), - event.priority, event.status, event.attempt_count - )) - end - return table.concat(lines, "\n") + local lines = { "Queue contents:" } + for i, event in ipairs(queue) do + table.insert( + lines, + string.format( + " %d. [%s] %s (p:%d, status:%s, attempts:%d)", + i, + event.id, + event.prompt_content:sub(1, 30):gsub("\n", " "), + event.priority, + event.status, + event.attempt_count + ) + ) + end + return table.concat(lines, "\n") end return M diff --git a/lua/codetyper/core/intent/init.lua b/lua/codetyper/core/intent/init.lua index 169e5c9..8c3b650 100644 --- a/lua/codetyper/core/intent/init.lua +++ b/lua/codetyper/core/intent/init.lua @@ -22,96 +22,96 @@ local prompts = require("codetyper.prompts.agents.intent") ---@param prompt string The prompt content ---@return Intent function M.detect(prompt) - local lower = prompt:lower() - local best_match = nil - local best_priority = 999 - local matched_keywords = {} + local lower = prompt:lower() + local best_match = nil + local best_priority = 999 + local matched_keywords = {} - -- Check each intent type - for intent_type, config in pairs(intent_patterns) do - for _, pattern in ipairs(config.patterns) do - if lower:find(pattern, 1, true) then - if config.priority < best_priority then - best_match = intent_type - best_priority = config.priority - matched_keywords = { pattern } - elseif config.priority == best_priority and best_match == intent_type then - table.insert(matched_keywords, pattern) - end - end - end - end + -- Check each intent type + for intent_type, config in pairs(intent_patterns) do + for _, pattern in ipairs(config.patterns) do + if lower:find(pattern, 1, true) then + if config.priority < best_priority then + best_match = intent_type + best_priority = config.priority + matched_keywords = { pattern } + elseif config.priority == best_priority and best_match == intent_type then + table.insert(matched_keywords, pattern) + end + end + end + end - -- Default to "add" if no clear intent - if not best_match then - best_match = "add" - matched_keywords = {} - end + -- Default to "add" if no clear intent + if not best_match then + best_match = "add" + matched_keywords = {} + end - local config = intent_patterns[best_match] + local config = intent_patterns[best_match] - -- Detect scope hint from prompt - local scope_hint = config.scope_hint - for pattern, hint in pairs(scope_patterns) do - if lower:find(pattern, 1, true) then - scope_hint = hint or scope_hint - break - end - end + -- Detect scope hint from prompt + local scope_hint = config.scope_hint + for pattern, hint in pairs(scope_patterns) do + if lower:find(pattern, 1, true) then + scope_hint = hint or scope_hint + break + end + end - -- Calculate confidence based on keyword matches - local confidence = 0.5 + (#matched_keywords * 0.15) - confidence = math.min(confidence, 1.0) + -- Calculate confidence based on keyword matches + local confidence = 0.5 + (#matched_keywords * 0.15) + confidence = math.min(confidence, 1.0) - return { - type = best_match, - scope_hint = scope_hint, - confidence = confidence, - action = config.action, - keywords = matched_keywords, - } + return { + type = best_match, + scope_hint = scope_hint, + confidence = confidence, + action = config.action, + keywords = matched_keywords, + } end --- Check if intent requires code modification ---@param intent Intent ---@return boolean function M.modifies_code(intent) - return intent.action ~= "none" + return intent.action ~= "none" end --- Check if intent should replace existing code ---@param intent Intent ---@return boolean function M.is_replacement(intent) - return intent.action == "replace" + return intent.action == "replace" end --- Check if intent adds new code ---@param intent Intent ---@return boolean function M.is_insertion(intent) - return intent.action == "insert" or intent.action == "append" + return intent.action == "insert" or intent.action == "append" end --- Get system prompt modifier based on intent ---@param intent Intent ---@return string function M.get_prompt_modifier(intent) - local modifiers = prompts.modifiers - return modifiers[intent.type] or modifiers.add + local modifiers = prompts.modifiers + return modifiers[intent.type] or modifiers.add end --- Format intent for logging ---@param intent Intent ---@return string function M.format(intent) - return string.format( - "%s (scope: %s, action: %s, confidence: %.2f)", - intent.type, - intent.scope_hint or "auto", - intent.action, - intent.confidence - ) + return string.format( + "%s (scope: %s, action: %s, confidence: %.2f)", + intent.type, + intent.scope_hint or "auto", + intent.action, + intent.confidence + ) end return M diff --git a/lua/codetyper/core/llm/confidence.lua b/lua/codetyper/core/llm/confidence.lua index 5dc5870..678d811 100644 --- a/lua/codetyper/core/llm/confidence.lua +++ b/lua/codetyper/core/llm/confidence.lua @@ -19,168 +19,168 @@ local uncertainty_phrases = params.uncertainty_phrases ---@param prompt string ---@return number 0.0-1.0 local function score_length(response, prompt) - local response_len = #response - local prompt_len = #prompt + local response_len = #response + local prompt_len = #prompt - -- Very short response to long prompt is suspicious - if prompt_len > 50 and response_len < 20 then - return 0.2 - end + -- Very short response to long prompt is suspicious + if prompt_len > 50 and response_len < 20 then + return 0.2 + end - -- Response should generally be longer than prompt for code generation - local ratio = response_len / math.max(prompt_len, 1) + -- Response should generally be longer than prompt for code generation + local ratio = response_len / math.max(prompt_len, 1) - if ratio < 0.5 then - return 0.3 - elseif ratio < 1.0 then - return 0.6 - elseif ratio < 2.0 then - return 0.8 - else - return 1.0 - end + if ratio < 0.5 then + return 0.3 + elseif ratio < 1.0 then + return 0.6 + elseif ratio < 2.0 then + return 0.8 + else + return 1.0 + end end --- Score based on uncertainty phrases ---@param response string ---@return number 0.0-1.0 local function score_uncertainty(response) - local lower = response:lower() - local found = 0 + local lower = response:lower() + local found = 0 - for _, phrase in ipairs(uncertainty_phrases) do - if lower:find(phrase:lower(), 1, true) then - found = found + 1 - end - end + for _, phrase in ipairs(uncertainty_phrases) do + if lower:find(phrase:lower(), 1, true) then + found = found + 1 + end + end - -- More uncertainty phrases = lower score - if found == 0 then - return 1.0 - elseif found == 1 then - return 0.7 - elseif found == 2 then - return 0.5 - else - return 0.2 - end + -- More uncertainty phrases = lower score + if found == 0 then + return 1.0 + elseif found == 1 then + return 0.7 + elseif found == 2 then + return 0.5 + else + return 0.2 + end end --- Score based on syntax completeness ---@param response string ---@return number 0.0-1.0 local function score_syntax(response) - local score = 1.0 + local score = 1.0 - -- Check bracket balance - if not require("codetyper.support.utils").check_brackets(response) then - score = score - 0.4 - end + -- Check bracket balance + if not require("codetyper.support.utils").check_brackets(response) then + score = score - 0.4 + end - -- Check for common incomplete patterns + -- Check for common incomplete patterns - -- Lua: unbalanced end/function - local function_count = select(2, response:gsub("function%s*%(", "")) - + select(2, response:gsub("function%s+%w+%(", "")) - local end_count = select(2, response:gsub("%f[%w]end%f[%W]", "")) - if function_count > end_count + 2 then - score = score - 0.2 - end + -- Lua: unbalanced end/function + local function_count = select(2, response:gsub("function%s*%(", "")) + + select(2, response:gsub("function%s+%w+%(", "")) + local end_count = select(2, response:gsub("%f[%w]end%f[%W]", "")) + if function_count > end_count + 2 then + score = score - 0.2 + end - -- JavaScript/TypeScript: unclosed template literals - local backtick_count = select(2, response:gsub("`", "")) - if backtick_count % 2 ~= 0 then - score = score - 0.2 - end + -- JavaScript/TypeScript: unclosed template literals + local backtick_count = select(2, response:gsub("`", "")) + if backtick_count % 2 ~= 0 then + score = score - 0.2 + end - -- String quotes balance - local double_quotes = select(2, response:gsub('"', "")) - local single_quotes = select(2, response:gsub("'", "")) - -- Allow for escaped quotes by being lenient - if double_quotes % 2 ~= 0 and not response:find('\\"') then - score = score - 0.1 - end - if single_quotes % 2 ~= 0 and not response:find("\\'") then - score = score - 0.1 - end + -- String quotes balance + local double_quotes = select(2, response:gsub('"', "")) + local single_quotes = select(2, response:gsub("'", "")) + -- Allow for escaped quotes by being lenient + if double_quotes % 2 ~= 0 and not response:find('\\"') then + score = score - 0.1 + end + if single_quotes % 2 ~= 0 and not response:find("\\'") then + score = score - 0.1 + end - return math.max(0, score) + return math.max(0, score) end --- Score based on line repetition ---@param response string ---@return number 0.0-1.0 local function score_repetition(response) - local lines = vim.split(response, "\n", { plain = true }) - if #lines < 3 then - return 1.0 - end + local lines = vim.split(response, "\n", { plain = true }) + if #lines < 3 then + return 1.0 + end - -- Count duplicate non-empty lines - local seen = {} - local duplicates = 0 + -- Count duplicate non-empty lines + local seen = {} + local duplicates = 0 - for _, line in ipairs(lines) do - local trimmed = vim.trim(line) - if #trimmed > 10 then -- Only check substantial lines - if seen[trimmed] then - duplicates = duplicates + 1 - end - seen[trimmed] = true - end - end + for _, line in ipairs(lines) do + local trimmed = vim.trim(line) + if #trimmed > 10 then -- Only check substantial lines + if seen[trimmed] then + duplicates = duplicates + 1 + end + seen[trimmed] = true + end + end - local dup_ratio = duplicates / #lines + local dup_ratio = duplicates / #lines - if dup_ratio < 0.1 then - return 1.0 - elseif dup_ratio < 0.2 then - return 0.8 - elseif dup_ratio < 0.3 then - return 0.5 - else - return 0.2 -- High repetition = degraded output - end + if dup_ratio < 0.1 then + return 1.0 + elseif dup_ratio < 0.2 then + return 0.8 + elseif dup_ratio < 0.3 then + return 0.5 + else + return 0.2 -- High repetition = degraded output + end end --- Score based on truncation indicators ---@param response string ---@return number 0.0-1.0 local function score_truncation(response) - local score = 1.0 + local score = 1.0 - -- Ends with ellipsis - if response:match("%.%.%.$") then - score = score - 0.5 - end + -- Ends with ellipsis + if response:match("%.%.%.$") then + score = score - 0.5 + end - -- Ends with incomplete comment - if response:match("/%*[^*/]*$") then -- Unclosed /* comment - score = score - 0.4 - end - if response:match("]*$") then -- Unclosed HTML comment - score = score - 0.4 - end + -- Ends with incomplete comment + if response:match("/%*[^*/]*$") then -- Unclosed /* comment + score = score - 0.4 + end + if response:match("]*$") then -- Unclosed HTML comment + score = score - 0.4 + end - -- Ends mid-statement (common patterns) - local trimmed = vim.trim(response) - local last_char = trimmed:sub(-1) + -- Ends mid-statement (common patterns) + local trimmed = vim.trim(response) + local last_char = trimmed:sub(-1) - -- Suspicious endings - if last_char == "=" or last_char == "," or last_char == "(" then - score = score - 0.3 - end + -- Suspicious endings + if last_char == "=" or last_char == "," or last_char == "(" then + score = score - 0.3 + end - -- Very short last line after long response - local lines = vim.split(response, "\n", { plain = true }) - if #lines > 5 then - local last_line = vim.trim(lines[#lines]) - if #last_line < 5 and not last_line:match("^[%}%]%)%;end]") then - score = score - 0.2 - end - end + -- Very short last line after long response + local lines = vim.split(response, "\n", { plain = true }) + if #lines > 5 then + local last_line = vim.trim(lines[#lines]) + if #last_line < 5 and not last_line:match("^[%}%]%)%;end]") then + score = score - 0.2 + end + end - return math.max(0, score) + return math.max(0, score) end ---@class ConfidenceBreakdown @@ -198,37 +198,37 @@ end ---@return number confidence 0.0-1.0 ---@return ConfidenceBreakdown breakdown Individual scores function M.score(response, prompt, context) - _ = context -- Reserved for future use + _ = context -- Reserved for future use - if not response or #response == 0 then - return 0, - { - length = 0, - uncertainty = 0, - syntax = 0, - repetition = 0, - truncation = 0, - weighted_total = 0, - } - end + if not response or #response == 0 then + return 0, + { + length = 0, + uncertainty = 0, + syntax = 0, + repetition = 0, + truncation = 0, + weighted_total = 0, + } + end - local scores = { - length = score_length(response, prompt or ""), - uncertainty = score_uncertainty(response), - syntax = score_syntax(response), - repetition = score_repetition(response), - truncation = score_truncation(response), - } + local scores = { + length = score_length(response, prompt or ""), + uncertainty = score_uncertainty(response), + syntax = score_syntax(response), + repetition = score_repetition(response), + truncation = score_truncation(response), + } - -- Calculate weighted total - local weighted = 0 - for key, weight in pairs(M.weights) do - weighted = weighted + (scores[key] * weight) - end + -- Calculate weighted total + local weighted = 0 + for key, weight in pairs(M.weights) do + weighted = weighted + (scores[key] * weight) + end - scores.weighted_total = weighted + scores.weighted_total = weighted - return weighted, scores + return weighted, scores end --- Check if response needs escalation @@ -236,40 +236,40 @@ end ---@param threshold number|nil Default: 0.7 ---@return boolean needs_escalation function M.needs_escalation(confidence, threshold) - threshold = threshold or 0.7 - return confidence < threshold + threshold = threshold or 0.7 + return confidence < threshold end --- Get human-readable confidence level ---@param confidence number ---@return string function M.level_name(confidence) - if confidence >= 0.9 then - return "excellent" - elseif confidence >= 0.8 then - return "good" - elseif confidence >= 0.7 then - return "acceptable" - elseif confidence >= 0.5 then - return "uncertain" - else - return "poor" - end + if confidence >= 0.9 then + return "excellent" + elseif confidence >= 0.8 then + return "good" + elseif confidence >= 0.7 then + return "acceptable" + elseif confidence >= 0.5 then + return "uncertain" + else + return "poor" + end end --- Format breakdown for logging ---@param breakdown ConfidenceBreakdown ---@return string function M.format_breakdown(breakdown) - return string.format( - "len:%.2f unc:%.2f syn:%.2f rep:%.2f tru:%.2f = %.2f", - breakdown.length, - breakdown.uncertainty, - breakdown.syntax, - breakdown.repetition, - breakdown.truncation, - breakdown.weighted_total - ) + return string.format( + "len:%.2f unc:%.2f syn:%.2f rep:%.2f tru:%.2f = %.2f", + breakdown.length, + breakdown.uncertainty, + breakdown.syntax, + breakdown.repetition, + breakdown.truncation, + breakdown.weighted_total + ) end return M diff --git a/lua/codetyper/core/llm/copilot.lua b/lua/codetyper/core/llm/copilot.lua index dc4199b..9448d12 100644 --- a/lua/codetyper/core/llm/copilot.lua +++ b/lua/codetyper/core/llm/copilot.lua @@ -20,184 +20,184 @@ local ollama_fallback_suggested = false --- Suggest switching to Ollama when rate limits are hit ---@param error_msg string The error message that triggered this function M.suggest_ollama_fallback(error_msg) - if ollama_fallback_suggested then - return - end + if ollama_fallback_suggested then + return + end - -- Check if Ollama is available - local ollama_available = false - vim.fn.jobstart({ "curl", "-s", "http://localhost:11434/api/tags" }, { - on_exit = function(_, code) - if code == 0 then - ollama_available = true - end + -- Check if Ollama is available + local ollama_available = false + vim.fn.jobstart({ "curl", "-s", "http://localhost:11434/api/tags" }, { + on_exit = function(_, code) + if code == 0 then + ollama_available = true + end - vim.schedule(function() - if ollama_available then - -- Switch to Ollama automatically - local codetyper = require("codetyper") - local config = codetyper.get_config() - config.llm.provider = "ollama" + vim.schedule(function() + if ollama_available then + -- Switch to Ollama automatically + local codetyper = require("codetyper") + local config = codetyper.get_config() + config.llm.provider = "ollama" - ollama_fallback_suggested = true - utils.notify( - "⚠️ Copilot rate limit reached. Switched to Ollama automatically.\n" - .. "Original error: " - .. error_msg:sub(1, 100), - vim.log.levels.WARN - ) - else - utils.notify( - "⚠️ Copilot rate limit reached. Ollama not available.\n" - .. "Start Ollama with: ollama serve\n" - .. "Or wait for Copilot limits to reset.", - vim.log.levels.WARN - ) - end - end) - end, - }) + ollama_fallback_suggested = true + utils.notify( + "⚠️ Copilot rate limit reached. Switched to Ollama automatically.\n" + .. "Original error: " + .. error_msg:sub(1, 100), + vim.log.levels.WARN + ) + else + utils.notify( + "⚠️ Copilot rate limit reached. Ollama not available.\n" + .. "Start Ollama with: ollama serve\n" + .. "Or wait for Copilot limits to reset.", + vim.log.levels.WARN + ) + end + end) + end, + }) end --- Get OAuth token from copilot.lua or copilot.vim config ---@return string|nil OAuth token local function get_oauth_token() - local xdg_config = vim.fn.expand("$XDG_CONFIG_HOME") - local os_name = vim.loop.os_uname().sysname:lower() + local xdg_config = vim.fn.expand("$XDG_CONFIG_HOME") + local os_name = vim.loop.os_uname().sysname:lower() - local config_dir - if xdg_config and vim.fn.isdirectory(xdg_config) > 0 then - config_dir = xdg_config - elseif os_name:match("linux") or os_name:match("darwin") then - config_dir = vim.fn.expand("~/.config") - else - config_dir = vim.fn.expand("~/AppData/Local") - end + local config_dir + if xdg_config and vim.fn.isdirectory(xdg_config) > 0 then + config_dir = xdg_config + elseif os_name:match("linux") or os_name:match("darwin") then + config_dir = vim.fn.expand("~/.config") + else + config_dir = vim.fn.expand("~/AppData/Local") + end - -- Try hosts.json (copilot.lua) and apps.json (copilot.vim) - local paths = { "hosts.json", "apps.json" } - for _, filename in ipairs(paths) do - local path = config_dir .. "/github-copilot/" .. filename - if vim.fn.filereadable(path) == 1 then - local content = vim.fn.readfile(path) - if content and #content > 0 then - local ok, data = pcall(vim.json.decode, table.concat(content, "\n")) - if ok and data then - for key, value in pairs(data) do - if key:match("github.com") and value.oauth_token then - return value.oauth_token - end - end - end - end - end - end + -- Try hosts.json (copilot.lua) and apps.json (copilot.vim) + local paths = { "hosts.json", "apps.json" } + for _, filename in ipairs(paths) do + local path = config_dir .. "/github-copilot/" .. filename + if vim.fn.filereadable(path) == 1 then + local content = vim.fn.readfile(path) + if content and #content > 0 then + local ok, data = pcall(vim.json.decode, table.concat(content, "\n")) + if ok and data then + for key, value in pairs(data) do + if key:match("github.com") and value.oauth_token then + return value.oauth_token + end + end + end + end + end + end - return nil + return nil end --- Get model from stored credentials or config ---@return string Model name local function get_model() - -- Priority: stored credentials > config - local credentials = require("codetyper.config.credentials") - local stored_model = credentials.get_model("copilot") - if stored_model then - return stored_model - end + -- Priority: stored credentials > config + local credentials = require("codetyper.config.credentials") + local stored_model = credentials.get_model("copilot") + if stored_model then + return stored_model + end - local codetyper = require("codetyper") - local config = codetyper.get_config() - return config.llm.copilot.model + local codetyper = require("codetyper") + local config = codetyper.get_config() + return config.llm.copilot.model end --- Refresh GitHub token using OAuth token ---@param callback fun(token: table|nil, error: string|nil) local function refresh_token(callback) - if not M.state or not M.state.oauth_token then - callback(nil, "No OAuth token available") - return - end + if not M.state or not M.state.oauth_token then + callback(nil, "No OAuth token available") + return + end - -- Check if current token is still valid - if M.state.github_token and M.state.github_token.expires_at then - if M.state.github_token.expires_at > os.time() then - callback(M.state.github_token, nil) - return - end - end + -- Check if current token is still valid + if M.state.github_token and M.state.github_token.expires_at then + if M.state.github_token.expires_at > os.time() then + callback(M.state.github_token, nil) + return + end + end - local cmd = { - "curl", - "-s", - "-X", - "GET", - AUTH_URL, - "-H", - "Authorization: token " .. M.state.oauth_token, - "-H", - "Accept: application/json", - } + local cmd = { + "curl", + "-s", + "-X", + "GET", + AUTH_URL, + "-H", + "Authorization: token " .. M.state.oauth_token, + "-H", + "Accept: application/json", + } - vim.fn.jobstart(cmd, { - stdout_buffered = true, - on_stdout = function(_, data) - if not data or #data == 0 or (data[1] == "" and #data == 1) then - return - end + vim.fn.jobstart(cmd, { + stdout_buffered = true, + on_stdout = function(_, data) + if not data or #data == 0 or (data[1] == "" and #data == 1) then + return + end - local response_text = table.concat(data, "\n") - local ok, token = pcall(vim.json.decode, response_text) + local response_text = table.concat(data, "\n") + local ok, token = pcall(vim.json.decode, response_text) - if not ok then - vim.schedule(function() - callback(nil, "Failed to parse token response") - end) - return - end + if not ok then + vim.schedule(function() + callback(nil, "Failed to parse token response") + end) + return + end - if token.error then - vim.schedule(function() - callback(nil, token.error_description or "Token refresh failed") - end) - return - end + if token.error then + vim.schedule(function() + callback(nil, token.error_description or "Token refresh failed") + end) + return + end - M.state.github_token = token - vim.schedule(function() - callback(token, nil) - end) - end, - on_stderr = function(_, data) - if data and #data > 0 and data[1] ~= "" then - vim.schedule(function() - callback(nil, "Token refresh failed: " .. table.concat(data, "\n")) - end) - end - end, - on_exit = function(_, code) - if code ~= 0 then - vim.schedule(function() - callback(nil, "Token refresh failed with code: " .. code) - end) - end - end, - }) + M.state.github_token = token + vim.schedule(function() + callback(token, nil) + end) + end, + on_stderr = function(_, data) + if data and #data > 0 and data[1] ~= "" then + vim.schedule(function() + callback(nil, "Token refresh failed: " .. table.concat(data, "\n")) + end) + end + end, + on_exit = function(_, code) + if code ~= 0 then + vim.schedule(function() + callback(nil, "Token refresh failed with code: " .. code) + end) + end + end, + }) end --- Build request headers ---@param token table GitHub token ---@return table Headers local function build_headers(token) - return { - "Authorization: Bearer " .. token.token, - "Content-Type: application/json", - "User-Agent: GitHubCopilotChat/0.26.7", - "Editor-Version: vscode/1.105.1", - "Editor-Plugin-Version: copilot-chat/0.26.7", - "Copilot-Integration-Id: vscode-chat", - "Openai-Intent: conversation-edits", - } + return { + "Authorization: Bearer " .. token.token, + "Content-Type: application/json", + "User-Agent: GitHubCopilotChat/0.26.7", + "Editor-Version: vscode/1.105.1", + "Editor-Plugin-Version: copilot-chat/0.26.7", + "Copilot-Integration-Id: vscode-chat", + "Openai-Intent: conversation-edits", + } end --- Build request body for Copilot API @@ -205,18 +205,18 @@ end ---@param context table Context information ---@return table Request body local function build_request_body(prompt, context) - local system_prompt = llm.build_system_prompt(context) + local system_prompt = llm.build_system_prompt(context) - return { - model = get_model(), - messages = { - { role = "system", content = system_prompt }, - { role = "user", content = prompt }, - }, - max_tokens = 4096, - temperature = 0.2, - stream = false, - } + return { + model = get_model(), + messages = { + { role = "system", content = system_prompt }, + { role = "user", content = prompt }, + }, + max_tokens = 4096, + temperature = 0.2, + stream = false, + } end --- Make HTTP request to Copilot API @@ -224,125 +224,122 @@ end ---@param body table Request body ---@param callback fun(response: string|nil, error: string|nil, usage: table|nil) local function make_request(token, body, callback) - local endpoint = (token.endpoints and token.endpoints.api or "https://api.githubcopilot.com") .. "/chat/completions" - local json_body = vim.json.encode(body) + local endpoint = (token.endpoints and token.endpoints.api or "https://api.githubcopilot.com") .. "/chat/completions" + local json_body = vim.json.encode(body) - local headers = build_headers(token) - local cmd = { - "curl", - "-s", - "-X", - "POST", - endpoint, - } + local headers = build_headers(token) + local cmd = { + "curl", + "-s", + "-X", + "POST", + endpoint, + } - for _, header in ipairs(headers) do - table.insert(cmd, "-H") - table.insert(cmd, header) - end + for _, header in ipairs(headers) do + table.insert(cmd, "-H") + table.insert(cmd, header) + end - table.insert(cmd, "-d") - table.insert(cmd, json_body) + table.insert(cmd, "-d") + table.insert(cmd, json_body) - vim.fn.jobstart(cmd, { - stdout_buffered = true, - on_stdout = function(_, data) - if not data or #data == 0 or (data[1] == "" and #data == 1) then - return - end + vim.fn.jobstart(cmd, { + stdout_buffered = true, + on_stdout = function(_, data) + if not data or #data == 0 or (data[1] == "" and #data == 1) then + return + end - local response_text = table.concat(data, "\n") - local ok, response = pcall(vim.json.decode, response_text) + local response_text = table.concat(data, "\n") + local ok, response = pcall(vim.json.decode, response_text) - if not ok then - -- Show the actual response text as the error (truncated if too long) - local error_msg = response_text - if #error_msg > 200 then - error_msg = error_msg:sub(1, 200) .. "..." - end + if not ok then + -- Show the actual response text as the error (truncated if too long) + local error_msg = response_text + if #error_msg > 200 then + error_msg = error_msg:sub(1, 200) .. "..." + end - -- Clean up common patterns - if response_text:match(" 0 and data[1] ~= "" then - vim.schedule(function() - callback(nil, "Copilot API request failed: " .. table.concat(data, "\n"), nil) - end) - end - end, - on_exit = function(_, code) - if code ~= 0 then - vim.schedule(function() - callback(nil, "Copilot API request failed with code: " .. code, nil) - end) - end - end, - }) + if response.choices and response.choices[1] and response.choices[1].message then + local code = llm.extract_code(response.choices[1].message.content) + vim.schedule(function() + callback(code, nil, usage) + end) + else + vim.schedule(function() + callback(nil, "No content in Copilot response", nil) + end) + end + end, + on_stderr = function(_, data) + if data and #data > 0 and data[1] ~= "" then + vim.schedule(function() + callback(nil, "Copilot API request failed: " .. table.concat(data, "\n"), nil) + end) + end + end, + on_exit = function(_, code) + if code ~= 0 then + vim.schedule(function() + callback(nil, "Copilot API request failed with code: " .. code, nil) + end) + end + end, + }) end --- Initialize Copilot state local function ensure_initialized() - if not M.state then - M.state = { - oauth_token = get_oauth_token(), - github_token = nil, - } - end + if not M.state then + M.state = { + oauth_token = get_oauth_token(), + github_token = nil, + } + end end --- Generate code using Copilot API @@ -350,44 +347,44 @@ end ---@param context table Context information ---@param callback fun(response: string|nil, error: string|nil) function M.generate(prompt, context, callback) - ensure_initialized() + ensure_initialized() - if not M.state.oauth_token then - local err = "Copilot not authenticated. Please set up copilot.lua or copilot.vim first." - callback(nil, err) - return - end + if not M.state.oauth_token then + local err = "Copilot not authenticated. Please set up copilot.lua or copilot.vim first." + callback(nil, err) + return + end - refresh_token(function(token, err) - if err then - utils.notify(err, vim.log.levels.ERROR) - callback(nil, err) - return - end + refresh_token(function(token, err) + if err then + utils.notify(err, vim.log.levels.ERROR) + callback(nil, err) + return + end - local body = build_request_body(prompt, context) - utils.notify("Sending request to Copilot...", vim.log.levels.INFO) + local body = build_request_body(prompt, context) + utils.notify("Sending request to Copilot...", vim.log.levels.INFO) - make_request(token, body, function(response, request_err, usage) - if request_err then - utils.notify(request_err, vim.log.levels.ERROR) - callback(nil, request_err) - else - utils.notify("Code generated successfully", vim.log.levels.INFO) - callback(response, nil) - end - end) - end) + make_request(token, body, function(response, request_err, usage) + if request_err then + utils.notify(request_err, vim.log.levels.ERROR) + callback(nil, request_err) + else + utils.notify("Code generated successfully", vim.log.levels.INFO) + callback(response, nil) + end + end) + end) end --- Check if Copilot is properly configured ---@return boolean, string? Valid status and optional error message function M.validate() - ensure_initialized() - if not M.state.oauth_token then - return false, "Copilot not authenticated. Set up copilot.lua or copilot.vim first." - end - return true + ensure_initialized() + if not M.state.oauth_token then + return false, "Copilot not authenticated. Set up copilot.lua or copilot.vim first." + end + return true end return M diff --git a/lua/codetyper/core/llm/init.lua b/lua/codetyper/core/llm/init.lua index 1736c3a..410b3a8 100644 --- a/lua/codetyper/core/llm/init.lua +++ b/lua/codetyper/core/llm/init.lua @@ -7,16 +7,16 @@ local utils = require("codetyper.support.utils") --- Get the appropriate LLM client based on configuration ---@return table LLM client module function M.get_client() - local codetyper = require("codetyper") - local config = codetyper.get_config() + local codetyper = require("codetyper") + local config = codetyper.get_config() - if config.llm.provider == "ollama" then - return require("codetyper.core.llm.ollama") - elseif config.llm.provider == "copilot" then - return require("codetyper.core.llm.copilot") - else - error("Unknown LLM provider: " .. config.llm.provider .. ". Supported: ollama, copilot") - end + if config.llm.provider == "ollama" then + return require("codetyper.core.llm.ollama") + elseif config.llm.provider == "copilot" then + return require("codetyper.core.llm.copilot") + else + error("Unknown LLM provider: " .. config.llm.provider .. ". Supported: ollama, copilot") + end end --- Generate code from a prompt @@ -24,8 +24,8 @@ end ---@param context table Context information (file content, language, etc.) ---@param callback fun(response: string|nil, error: string|nil) Callback function function M.generate(prompt, context, callback) - local client = M.get_client() - client.generate(prompt, context, callback) + local client = M.get_client() + client.generate(prompt, context, callback) end --- Smart generate with automatic provider selection based on brain memories @@ -35,97 +35,97 @@ end ---@param context table Context information ---@param callback fun(response: string|nil, error: string|nil, metadata: table|nil) Callback function M.smart_generate(prompt, context, callback) - local selector = require("codetyper.core.llm.selector") - selector.smart_generate(prompt, context, callback) + local selector = require("codetyper.core.llm.selector") + selector.smart_generate(prompt, context, callback) end --- Get accuracy statistics for providers ---@return table Statistics for each provider function M.get_accuracy_stats() - local selector = require("codetyper.core.llm.selector") - return selector.get_accuracy_stats() + local selector = require("codetyper.core.llm.selector") + return selector.get_accuracy_stats() end --- Report user feedback on response quality (for reinforcement learning) ---@param provider string Which provider generated the response ---@param was_correct boolean Whether the response was good function M.report_feedback(provider, was_correct) - local selector = require("codetyper.core.llm.selector") - selector.report_feedback(provider, was_correct) + local selector = require("codetyper.core.llm.selector") + selector.report_feedback(provider, was_correct) end --- Build the system prompt for code generation ---@param context table Context information ---@return string System prompt function M.build_system_prompt(context) - local prompts = require("codetyper.prompts") + local prompts = require("codetyper.prompts") - -- Select appropriate system prompt based on context - local prompt_type = context.prompt_type or "code_generation" - local system_prompts = prompts.system + -- Select appropriate system prompt based on context + local prompt_type = context.prompt_type or "code_generation" + local system_prompts = prompts.system - local system = system_prompts[prompt_type] or system_prompts.code_generation + local system = system_prompts[prompt_type] or system_prompts.code_generation - -- Substitute variables - system = system:gsub("{{language}}", context.language or "unknown") - system = system:gsub("{{filepath}}", context.file_path or "unknown") + -- Substitute variables + system = system:gsub("{{language}}", context.language or "unknown") + system = system:gsub("{{filepath}}", context.file_path or "unknown") - -- For agent mode, include project context - if prompt_type == "agent" then - local project_info = "\n\n## PROJECT CONTEXT\n" + -- For agent mode, include project context + if prompt_type == "agent" then + local project_info = "\n\n## PROJECT CONTEXT\n" - if context.project_root then - project_info = project_info .. "- Project root: " .. context.project_root .. "\n" - end - if context.cwd then - project_info = project_info .. "- Working directory: " .. context.cwd .. "\n" - end - if context.project_type then - project_info = project_info .. "- Project type: " .. context.project_type .. "\n" - end - if context.project_stats then - project_info = project_info - .. string.format( - "- Stats: %d files, %d functions, %d classes\n", - context.project_stats.files or 0, - context.project_stats.functions or 0, - context.project_stats.classes or 0 - ) - end - if context.file_path then - project_info = project_info .. "- Current file: " .. context.file_path .. "\n" - end + if context.project_root then + project_info = project_info .. "- Project root: " .. context.project_root .. "\n" + end + if context.cwd then + project_info = project_info .. "- Working directory: " .. context.cwd .. "\n" + end + if context.project_type then + project_info = project_info .. "- Project type: " .. context.project_type .. "\n" + end + if context.project_stats then + project_info = project_info + .. string.format( + "- Stats: %d files, %d functions, %d classes\n", + context.project_stats.files or 0, + context.project_stats.functions or 0, + context.project_stats.classes or 0 + ) + end + if context.file_path then + project_info = project_info .. "- Current file: " .. context.file_path .. "\n" + end - system = system .. project_info - return system - end + system = system .. project_info + return system + end - -- For "ask" or "explain" mode, don't add code generation instructions - if prompt_type == "ask" or prompt_type == "explain" then - -- Just add context about the file if available - if context.file_path then - system = system .. "\n\nContext: The user is working with " .. context.file_path - if context.language then - system = system .. " (" .. context.language .. ")" - end - end - return system - end + -- For "ask" or "explain" mode, don't add code generation instructions + if prompt_type == "ask" or prompt_type == "explain" then + -- Just add context about the file if available + if context.file_path then + system = system .. "\n\nContext: The user is working with " .. context.file_path + if context.language then + system = system .. " (" .. context.language .. ")" + end + end + return system + end - -- Add file content with analysis hints (for code generation modes) - if context.file_content and context.file_content ~= "" then - system = system .. "\n\n===== EXISTING FILE CONTENT (analyze and match this style) =====\n" - system = system .. context.file_content - system = system .. "\n===== END OF EXISTING FILE =====\n" - system = system .. "\nYour generated code MUST follow the exact patterns shown above." - else - system = system - .. "\n\nThis is a new/empty file. Generate clean, idiomatic " - .. (context.language or "code") - .. " following best practices." - end + -- Add file content with analysis hints (for code generation modes) + if context.file_content and context.file_content ~= "" then + system = system .. "\n\n===== EXISTING FILE CONTENT (analyze and match this style) =====\n" + system = system .. context.file_content + system = system .. "\n===== END OF EXISTING FILE =====\n" + system = system .. "\nYour generated code MUST follow the exact patterns shown above." + else + system = system + .. "\n\nThis is a new/empty file. Generate clean, idiomatic " + .. (context.language or "code") + .. " following best practices." + end - return system + return system end --- Build context for LLM request @@ -133,70 +133,70 @@ end ---@param prompt_type string Type of prompt ---@return table Context object function M.build_context(target_path, prompt_type) - local content = utils.read_file(target_path) - local ext = vim.fn.fnamemodify(target_path, ":e") + local content = utils.read_file(target_path) + local ext = vim.fn.fnamemodify(target_path, ":e") - local context = { - file_content = content, - language = lang_map[ext] or ext, - extension = ext, - prompt_type = prompt_type, - file_path = target_path, - } + local context = { + file_content = content, + language = lang_map[ext] or ext, + extension = ext, + prompt_type = prompt_type, + file_path = target_path, + } - -- For agent mode, include additional project context - if prompt_type == "agent" then - local project_root = utils.get_project_root() - context.project_root = project_root + -- For agent mode, include additional project context + if prompt_type == "agent" then + local project_root = utils.get_project_root() + context.project_root = project_root - -- Try to get project info from indexer - local ok_indexer, indexer = pcall(require, "codetyper.indexer") - if ok_indexer then - local status = indexer.get_status() - if status.indexed then - context.project_type = status.project_type - context.project_stats = status.stats - end - end + -- Try to get project info from indexer + local ok_indexer, indexer = pcall(require, "codetyper.indexer") + if ok_indexer then + local status = indexer.get_status() + if status.indexed then + context.project_type = status.project_type + context.project_stats = status.stats + end + end - -- Include working directory - context.cwd = vim.fn.getcwd() - end + -- Include working directory + context.cwd = vim.fn.getcwd() + end - return context + return context end --- Parse LLM response and extract code ---@param response string Raw LLM response ---@return string Extracted code function M.extract_code(response) - local code = response + local code = response - -- Remove markdown code blocks with language tags (```typescript, ```javascript, etc.) - code = code:gsub("```%w+%s*\n", "") - code = code:gsub("```%w+%s*$", "") - code = code:gsub("^```%w*\n?", "") - code = code:gsub("\n?```%s*$", "") - code = code:gsub("\n```\n", "\n") - code = code:gsub("```", "") + -- Remove markdown code blocks with language tags (```typescript, ```javascript, etc.) + code = code:gsub("```%w+%s*\n", "") + code = code:gsub("```%w+%s*$", "") + code = code:gsub("^```%w*\n?", "") + code = code:gsub("\n?```%s*$", "") + code = code:gsub("\n```\n", "\n") + code = code:gsub("```", "") - -- Remove common explanation prefixes that LLMs sometimes add - code = code:gsub("^Here.-:\n", "") - code = code:gsub("^Here's.-:\n", "") - code = code:gsub("^This.-:\n", "") - code = code:gsub("^The following.-:\n", "") - code = code:gsub("^Below.-:\n", "") + -- Remove common explanation prefixes that LLMs sometimes add + code = code:gsub("^Here.-:\n", "") + code = code:gsub("^Here's.-:\n", "") + code = code:gsub("^This.-:\n", "") + code = code:gsub("^The following.-:\n", "") + code = code:gsub("^Below.-:\n", "") - -- Remove common explanation suffixes - code = code:gsub("\n\nThis code.-$", "") - code = code:gsub("\n\nThe above.-$", "") - code = code:gsub("\n\nNote:.-$", "") - code = code:gsub("\n\nExplanation:.-$", "") + -- Remove common explanation suffixes + code = code:gsub("\n\nThis code.-$", "") + code = code:gsub("\n\nThe above.-$", "") + code = code:gsub("\n\nNote:.-$", "") + code = code:gsub("\n\nExplanation:.-$", "") - -- Trim leading/trailing whitespace but preserve internal formatting - code = code:match("^%s*(.-)%s*$") or code + -- Trim leading/trailing whitespace but preserve internal formatting + code = code:match("^%s*(.-)%s*$") or code - return code + return code end return M diff --git a/lua/codetyper/core/llm/ollama.lua b/lua/codetyper/core/llm/ollama.lua index 4b27672..3a2a8dc 100644 --- a/lua/codetyper/core/llm/ollama.lua +++ b/lua/codetyper/core/llm/ollama.lua @@ -8,31 +8,31 @@ local llm = require("codetyper.core.llm") --- Get Ollama host from stored credentials or config ---@return string Host URL local function get_host() - -- Priority: stored credentials > config - local credentials = require("codetyper.config.credentials") - local stored_host = credentials.get_ollama_host() - if stored_host then - return stored_host - end + -- Priority: stored credentials > config + local credentials = require("codetyper.config.credentials") + local stored_host = credentials.get_ollama_host() + if stored_host then + return stored_host + end - local codetyper = require("codetyper") - local config = codetyper.get_config() - return config.llm.ollama.host + local codetyper = require("codetyper") + local config = codetyper.get_config() + return config.llm.ollama.host end --- Get model from stored credentials or config ---@return string Model name local function get_model() - -- Priority: stored credentials > config - local credentials = require("codetyper.config.credentials") - local stored_model = credentials.get_model("ollama") - if stored_model then - return stored_model - end + -- Priority: stored credentials > config + local credentials = require("codetyper.config.credentials") + local stored_model = credentials.get_model("ollama") + if stored_model then + return stored_model + end - local codetyper = require("codetyper") - local config = codetyper.get_config() - return config.llm.ollama.model + local codetyper = require("codetyper") + local config = codetyper.get_config() + return config.llm.ollama.model end --- Build request body for Ollama API @@ -40,96 +40,96 @@ end ---@param context table Context information ---@return table Request body local function build_request_body(prompt, context) - local system_prompt = llm.build_system_prompt(context) + local system_prompt = llm.build_system_prompt(context) - return { - model = get_model(), - system = system_prompt, - prompt = prompt, - stream = false, - options = { - temperature = 0.2, - num_predict = 4096, - }, - } + return { + model = get_model(), + system = system_prompt, + prompt = prompt, + stream = false, + options = { + temperature = 0.2, + num_predict = 4096, + }, + } end --- Make HTTP request to Ollama API ---@param body table Request body ---@param callback fun(response: string|nil, error: string|nil, usage: table|nil) Callback function local function make_request(body, callback) - local host = get_host() - local url = host .. "/api/generate" - local json_body = vim.json.encode(body) + local host = get_host() + local url = host .. "/api/generate" + local json_body = vim.json.encode(body) - local cmd = { - "curl", - "-s", - "-X", - "POST", - url, - "-H", - "Content-Type: application/json", - "-d", - json_body, - } + local cmd = { + "curl", + "-s", + "-X", + "POST", + url, + "-H", + "Content-Type: application/json", + "-d", + json_body, + } - vim.fn.jobstart(cmd, { - stdout_buffered = true, - on_stdout = function(_, data) - if not data or #data == 0 or (data[1] == "" and #data == 1) then - return - end + vim.fn.jobstart(cmd, { + stdout_buffered = true, + on_stdout = function(_, data) + if not data or #data == 0 or (data[1] == "" and #data == 1) then + return + end - local response_text = table.concat(data, "\n") - local ok, response = pcall(vim.json.decode, response_text) + local response_text = table.concat(data, "\n") + local ok, response = pcall(vim.json.decode, response_text) - if not ok then - vim.schedule(function() - callback(nil, "Failed to parse Ollama response", nil) - end) - return - end + if not ok then + vim.schedule(function() + callback(nil, "Failed to parse Ollama response", nil) + end) + return + end - if response.error then - vim.schedule(function() - callback(nil, response.error or "Ollama API error", nil) - end) - return - end + if response.error then + vim.schedule(function() + callback(nil, response.error or "Ollama API error", nil) + end) + return + end - -- Extract usage info - local usage = { - prompt_tokens = response.prompt_eval_count or 0, - response_tokens = response.eval_count or 0, - } + -- Extract usage info + local usage = { + prompt_tokens = response.prompt_eval_count or 0, + response_tokens = response.eval_count or 0, + } - if response.response then - local code = llm.extract_code(response.response) - vim.schedule(function() - callback(code, nil, usage) - end) - else - vim.schedule(function() - callback(nil, "No response from Ollama", nil) - end) - end - end, - on_stderr = function(_, data) - if data and #data > 0 and data[1] ~= "" then - vim.schedule(function() - callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil) - end) - end - end, - on_exit = function(_, code) - if code ~= 0 then - vim.schedule(function() - callback(nil, "Ollama API request failed with code: " .. code, nil) - end) - end - end, - }) + if response.response then + local code = llm.extract_code(response.response) + vim.schedule(function() + callback(code, nil, usage) + end) + else + vim.schedule(function() + callback(nil, "No response from Ollama", nil) + end) + end + end, + on_stderr = function(_, data) + if data and #data > 0 and data[1] ~= "" then + vim.schedule(function() + callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil) + end) + end + end, + on_exit = function(_, code) + if code ~= 0 then + vim.schedule(function() + callback(nil, "Ollama API request failed with code: " .. code, nil) + end) + end + end, + }) end --- Generate code using Ollama API @@ -137,60 +137,60 @@ end ---@param context table Context information ---@param callback fun(response: string|nil, error: string|nil) Callback function function M.generate(prompt, context, callback) - local model = get_model() + local model = get_model() - local body = build_request_body(prompt, context) - utils.notify("Sending request to Ollama...", vim.log.levels.INFO) + local body = build_request_body(prompt, context) + utils.notify("Sending request to Ollama...", vim.log.levels.INFO) - make_request(body, function(response, err, usage) - if err then - utils.notify(err, vim.log.levels.ERROR) - callback(nil, err) - else - utils.notify("Code generated successfully", vim.log.levels.INFO) - callback(response, nil) - end - end) + make_request(body, function(response, err, usage) + if err then + utils.notify(err, vim.log.levels.ERROR) + callback(nil, err) + else + utils.notify("Code generated successfully", vim.log.levels.INFO) + callback(response, nil) + end + end) end --- Check if Ollama is reachable ---@param callback fun(ok: boolean, error: string|nil) Callback function function M.health_check(callback) - local host = get_host() + local host = get_host() - local cmd = { "curl", "-s", host .. "/api/tags" } + local cmd = { "curl", "-s", host .. "/api/tags" } - vim.fn.jobstart(cmd, { - stdout_buffered = true, - on_stdout = function(_, data) - if data and #data > 0 and data[1] ~= "" then - vim.schedule(function() - callback(true, nil) - end) - end - end, - on_exit = function(_, code) - if code ~= 0 then - vim.schedule(function() - callback(false, "Cannot connect to Ollama at " .. host) - end) - end - end, - }) + vim.fn.jobstart(cmd, { + stdout_buffered = true, + on_stdout = function(_, data) + if data and #data > 0 and data[1] ~= "" then + vim.schedule(function() + callback(true, nil) + end) + end + end, + on_exit = function(_, code) + if code ~= 0 then + vim.schedule(function() + callback(false, "Cannot connect to Ollama at " .. host) + end) + end + end, + }) end --- Check if Ollama is properly configured ---@return boolean, string? Valid status and optional error message function M.validate() - local host = get_host() - if not host or host == "" then - return false, "Ollama host not configured" - end - local model = get_model() - if not model or model == "" then - return false, "Ollama model not configured" - end - return true + local host = get_host() + if not host or host == "" then + return false, "Ollama host not configured" + end + local model = get_model() + if not model or model == "" then + return false, "Ollama model not configured" + end + return true end return M diff --git a/lua/codetyper/core/llm/selector.lua b/lua/codetyper/core/llm/selector.lua index 429b146..7251dcf 100644 --- a/lua/codetyper/core/llm/selector.lua +++ b/lua/codetyper/core/llm/selector.lua @@ -34,76 +34,76 @@ local PONDER_SAMPLE_RATE = 0.2 --- Provider accuracy tracking (persisted in brain) local accuracy_cache = { - ollama = { correct = 0, total = 0 }, - copilot = { correct = 0, total = 0 }, + ollama = { correct = 0, total = 0 }, + copilot = { correct = 0, total = 0 }, } --- Get the brain module safely ---@return table|nil local function get_brain() - local ok, brain = pcall(require, "codetyper.brain") - if ok and brain.is_initialized and brain.is_initialized() then - return brain - end - return nil + local ok, brain = pcall(require, "codetyper.brain") + if ok and brain.is_initialized and brain.is_initialized() then + return brain + end + return nil end --- Load accuracy stats from brain local function load_accuracy_stats() - local brain = get_brain() - if not brain then - return - end + local brain = get_brain() + if not brain then + return + end - -- Query for accuracy tracking nodes - pcall(function() - local result = brain.query({ - query = "provider_accuracy_stats", - types = { "metric" }, - limit = 1, - }) + -- Query for accuracy tracking nodes + pcall(function() + local result = brain.query({ + query = "provider_accuracy_stats", + types = { "metric" }, + limit = 1, + }) - if result and result.nodes and #result.nodes > 0 then - local node = result.nodes[1] - if node.c and node.c.d then - local ok, stats = pcall(vim.json.decode, node.c.d) - if ok and stats then - accuracy_cache = stats - end - end - end - end) + if result and result.nodes and #result.nodes > 0 then + local node = result.nodes[1] + if node.c and node.c.d then + local ok, stats = pcall(vim.json.decode, node.c.d) + if ok and stats then + accuracy_cache = stats + end + end + end + end) end --- Save accuracy stats to brain local function save_accuracy_stats() - local brain = get_brain() - if not brain then - return - end + local brain = get_brain() + if not brain then + return + end - pcall(function() - brain.learn({ - type = "metric", - summary = "provider_accuracy_stats", - detail = vim.json.encode(accuracy_cache), - weight = 1.0, - }) - end) + pcall(function() + brain.learn({ + type = "metric", + summary = "provider_accuracy_stats", + detail = vim.json.encode(accuracy_cache), + weight = 1.0, + }) + end) end --- Calculate Ollama confidence based on historical accuracy ---@return number confidence (0-1) local function get_ollama_historical_confidence() - local stats = accuracy_cache.ollama - if stats.total < 5 then - -- Not enough data, return neutral confidence - return 0.5 - end + local stats = accuracy_cache.ollama + if stats.total < 5 then + -- Not enough data, return neutral confidence + return 0.5 + end - local accuracy = stats.correct / stats.total - -- Boost confidence if accuracy is high - return math.min(1.0, accuracy * 1.2) + local accuracy = stats.correct / stats.total + -- Boost confidence if accuracy is high + return math.min(1.0, accuracy * 1.2) end --- Query brain for relevant context @@ -111,46 +111,46 @@ end ---@param file_path string|nil Current file path ---@return table result {memories: table[], relevance: number, count: number} local function query_brain_context(prompt, file_path) - local result = { - memories = {}, - relevance = 0, - count = 0, - } + local result = { + memories = {}, + relevance = 0, + count = 0, + } - local brain = get_brain() - if not brain then - return result - end + local brain = get_brain() + if not brain then + return result + end - -- Query brain with multiple dimensions - local ok, query_result = pcall(function() - return brain.query({ - query = prompt, - file = file_path, - limit = 10, - types = { "pattern", "correction", "convention", "fact" }, - }) - end) + -- Query brain with multiple dimensions + local ok, query_result = pcall(function() + return brain.query({ + query = prompt, + file = file_path, + limit = 10, + types = { "pattern", "correction", "convention", "fact" }, + }) + end) - if not ok or not query_result then - return result - end + if not ok or not query_result then + return result + end - result.memories = query_result.nodes or {} - result.count = #result.memories + result.memories = query_result.nodes or {} + result.count = #result.memories - -- Calculate average relevance - if result.count > 0 then - local total_relevance = 0 - for _, node in ipairs(result.memories) do - -- Use node weight and success rate as relevance indicators - local node_relevance = (node.sc and node.sc.w or 0.5) * (node.sc and node.sc.sr or 0.5) - total_relevance = total_relevance + node_relevance - end - result.relevance = total_relevance / result.count - end + -- Calculate average relevance + if result.count > 0 then + local total_relevance = 0 + for _, node in ipairs(result.memories) do + -- Use node weight and success rate as relevance indicators + local node_relevance = (node.sc and node.sc.w or 0.5) * (node.sc and node.sc.sr or 0.5) + total_relevance = total_relevance + node_relevance + end + result.relevance = total_relevance / result.count + end - return result + return result end --- Select the best LLM provider based on context @@ -158,80 +158,77 @@ end ---@param context table LLM context ---@return SelectionResult function M.select_provider(prompt, context) - -- Load accuracy stats on first call - if accuracy_cache.ollama.total == 0 then - load_accuracy_stats() - end + -- Load accuracy stats on first call + if accuracy_cache.ollama.total == 0 then + load_accuracy_stats() + end - local file_path = context.file_path + local file_path = context.file_path - -- Query brain for relevant memories - local brain_context = query_brain_context(prompt, file_path) + -- Query brain for relevant memories + local brain_context = query_brain_context(prompt, file_path) - -- Calculate base confidence from memories - local memory_confidence = 0 - if brain_context.count >= MIN_MEMORIES_FOR_LOCAL then - memory_confidence = math.min(1.0, brain_context.count / 10) * brain_context.relevance - end + -- Calculate base confidence from memories + local memory_confidence = 0 + if brain_context.count >= MIN_MEMORIES_FOR_LOCAL then + memory_confidence = math.min(1.0, brain_context.count / 10) * brain_context.relevance + end - -- Factor in historical Ollama accuracy - local historical_confidence = get_ollama_historical_confidence() + -- Factor in historical Ollama accuracy + local historical_confidence = get_ollama_historical_confidence() - -- Combined confidence score - local combined_confidence = (memory_confidence * 0.6) + (historical_confidence * 0.4) + -- Combined confidence score + local combined_confidence = (memory_confidence * 0.6) + (historical_confidence * 0.4) - -- Decision logic - local provider = "copilot" -- Default to more capable - local reason = "" + -- Decision logic + local provider = "copilot" -- Default to more capable + local reason = "" - if brain_context.count >= MIN_MEMORIES_FOR_LOCAL and combined_confidence >= MIN_RELEVANCE_FOR_LOCAL then - provider = "ollama" - reason = string.format( - "Rich context: %d memories (%.1f%% relevance), historical accuracy: %.1f%%", - brain_context.count, - brain_context.relevance * 100, - historical_confidence * 100 - ) - elseif brain_context.count > 0 and combined_confidence >= 0.4 then - -- Medium confidence - use Ollama but with pondering - provider = "ollama" - reason = string.format( - "Moderate context: %d memories, will verify with pondering", - brain_context.count - ) - else - reason = string.format( - "Insufficient context: %d memories (need %d), using capable provider", - brain_context.count, - MIN_MEMORIES_FOR_LOCAL - ) - end + if brain_context.count >= MIN_MEMORIES_FOR_LOCAL and combined_confidence >= MIN_RELEVANCE_FOR_LOCAL then + provider = "ollama" + reason = string.format( + "Rich context: %d memories (%.1f%% relevance), historical accuracy: %.1f%%", + brain_context.count, + brain_context.relevance * 100, + historical_confidence * 100 + ) + elseif brain_context.count > 0 and combined_confidence >= 0.4 then + -- Medium confidence - use Ollama but with pondering + provider = "ollama" + reason = string.format("Moderate context: %d memories, will verify with pondering", brain_context.count) + else + reason = string.format( + "Insufficient context: %d memories (need %d), using capable provider", + brain_context.count, + MIN_MEMORIES_FOR_LOCAL + ) + end - return { - provider = provider, - confidence = combined_confidence, - memory_count = brain_context.count, - reason = reason, - memories = brain_context.memories, - } + return { + provider = provider, + confidence = combined_confidence, + memory_count = brain_context.count, + reason = reason, + memories = brain_context.memories, + } end --- Check if we should ponder (verify) this Ollama response ---@param confidence number Current confidence level ---@return boolean function M.should_ponder(confidence) - -- Always ponder when confidence is medium - if confidence >= 0.4 and confidence < 0.7 then - return true - end + -- Always ponder when confidence is medium + if confidence >= 0.4 and confidence < 0.7 then + return true + end - -- Random sampling for high confidence to keep learning - if confidence >= 0.7 then - return math.random() < PONDER_SAMPLE_RATE - end + -- Random sampling for high confidence to keep learning + if confidence >= 0.7 then + return math.random() < PONDER_SAMPLE_RATE + end - -- Low confidence shouldn't reach Ollama anyway - return false + -- Low confidence shouldn't reach Ollama anyway + return false end --- Calculate agreement score between two responses @@ -239,54 +236,54 @@ end ---@param response2 string Second response ---@return number Agreement score (0-1) local function calculate_agreement(response1, response2) - -- Normalize responses - local norm1 = response1:lower():gsub("%s+", " "):gsub("[^%w%s]", "") - local norm2 = response2:lower():gsub("%s+", " "):gsub("[^%w%s]", "") + -- Normalize responses + local norm1 = response1:lower():gsub("%s+", " "):gsub("[^%w%s]", "") + local norm2 = response2:lower():gsub("%s+", " "):gsub("[^%w%s]", "") - -- Extract words - local words1 = {} - for word in norm1:gmatch("%w+") do - words1[word] = (words1[word] or 0) + 1 - end + -- Extract words + local words1 = {} + for word in norm1:gmatch("%w+") do + words1[word] = (words1[word] or 0) + 1 + end - local words2 = {} - for word in norm2:gmatch("%w+") do - words2[word] = (words2[word] or 0) + 1 - end + local words2 = {} + for word in norm2:gmatch("%w+") do + words2[word] = (words2[word] or 0) + 1 + end - -- Calculate Jaccard similarity - local intersection = 0 - local union = 0 + -- Calculate Jaccard similarity + local intersection = 0 + local union = 0 - for word, count1 in pairs(words1) do - local count2 = words2[word] or 0 - intersection = intersection + math.min(count1, count2) - union = union + math.max(count1, count2) - end + for word, count1 in pairs(words1) do + local count2 = words2[word] or 0 + intersection = intersection + math.min(count1, count2) + union = union + math.max(count1, count2) + end - for word, count2 in pairs(words2) do - if not words1[word] then - union = union + count2 - end - end + for word, count2 in pairs(words2) do + if not words1[word] then + union = union + count2 + end + end - if union == 0 then - return 1.0 -- Both empty - end + if union == 0 then + return 1.0 -- Both empty + end - -- Also check structural similarity (code structure) - local struct_score = 0 - local function_count1 = select(2, response1:gsub("function", "")) - local function_count2 = select(2, response2:gsub("function", "")) - if function_count1 > 0 or function_count2 > 0 then - struct_score = 1 - math.abs(function_count1 - function_count2) / math.max(function_count1, function_count2, 1) - else - struct_score = 1.0 - end + -- Also check structural similarity (code structure) + local struct_score = 0 + local function_count1 = select(2, response1:gsub("function", "")) + local function_count2 = select(2, response2:gsub("function", "")) + if function_count1 > 0 or function_count2 > 0 then + struct_score = 1 - math.abs(function_count1 - function_count2) / math.max(function_count1, function_count2, 1) + else + struct_score = 1.0 + end - -- Combined score - local jaccard = intersection / union - return (jaccard * 0.7) + (struct_score * 0.3) + -- Combined score + local jaccard = intersection / union + return (jaccard * 0.7) + (struct_score * 0.3) end --- Ponder (verify) Ollama's response with another LLM @@ -295,92 +292,85 @@ end ---@param ollama_response string Ollama's response ---@param callback fun(result: PonderResult) Callback with pondering result function M.ponder(prompt, context, ollama_response, callback) - -- Use Copilot as verifier - local copilot = require("codetyper.core.llm.copilot") + -- Use Copilot as verifier + local copilot = require("codetyper.core.llm.copilot") - -- Build verification prompt - local verify_prompt = prompt + -- Build verification prompt + local verify_prompt = prompt - copilot.generate(verify_prompt, context, function(verifier_response, error) - if error or not verifier_response then - -- Verification failed, assume Ollama is correct - callback({ - ollama_response = ollama_response, - verifier_response = "", - agreement_score = 1.0, - ollama_correct = true, - feedback = "Verification unavailable, trusting Ollama", - }) - return - end + copilot.generate(verify_prompt, context, function(verifier_response, error) + if error or not verifier_response then + -- Verification failed, assume Ollama is correct + callback({ + ollama_response = ollama_response, + verifier_response = "", + agreement_score = 1.0, + ollama_correct = true, + feedback = "Verification unavailable, trusting Ollama", + }) + return + end - -- Calculate agreement - local agreement = calculate_agreement(ollama_response, verifier_response) + -- Calculate agreement + local agreement = calculate_agreement(ollama_response, verifier_response) - -- Determine if Ollama was correct - local ollama_correct = agreement >= AGREEMENT_THRESHOLD + -- Determine if Ollama was correct + local ollama_correct = agreement >= AGREEMENT_THRESHOLD - -- Generate feedback - local feedback - if ollama_correct then - feedback = string.format("Agreement: %.1f%% - Ollama response validated", agreement * 100) - else - feedback = string.format( - "Disagreement: %.1f%% - Ollama may need correction", - (1 - agreement) * 100 - ) - end + -- Generate feedback + local feedback + if ollama_correct then + feedback = string.format("Agreement: %.1f%% - Ollama response validated", agreement * 100) + else + feedback = string.format("Disagreement: %.1f%% - Ollama may need correction", (1 - agreement) * 100) + end - -- Update accuracy tracking - accuracy_cache.ollama.total = accuracy_cache.ollama.total + 1 - if ollama_correct then - accuracy_cache.ollama.correct = accuracy_cache.ollama.correct + 1 - end - save_accuracy_stats() + -- Update accuracy tracking + accuracy_cache.ollama.total = accuracy_cache.ollama.total + 1 + if ollama_correct then + accuracy_cache.ollama.correct = accuracy_cache.ollama.correct + 1 + end + save_accuracy_stats() - -- Learn from this verification - local brain = get_brain() - if brain then - pcall(function() - if ollama_correct then - -- Reinforce the pattern - brain.learn({ - type = "correction", - summary = "Ollama verified correct", - detail = string.format( - "Prompt: %s\nAgreement: %.1f%%", - prompt:sub(1, 100), - agreement * 100 - ), - weight = 0.8, - file = context.file_path, - }) - else - -- Learn the correction - brain.learn({ - type = "correction", - summary = "Ollama needed correction", - detail = string.format( - "Prompt: %s\nOllama: %s\nCorrect: %s", - prompt:sub(1, 100), - ollama_response:sub(1, 200), - verifier_response:sub(1, 200) - ), - weight = 0.9, - file = context.file_path, - }) - end - end) - end + -- Learn from this verification + local brain = get_brain() + if brain then + pcall(function() + if ollama_correct then + -- Reinforce the pattern + brain.learn({ + type = "correction", + summary = "Ollama verified correct", + detail = string.format("Prompt: %s\nAgreement: %.1f%%", prompt:sub(1, 100), agreement * 100), + weight = 0.8, + file = context.file_path, + }) + else + -- Learn the correction + brain.learn({ + type = "correction", + summary = "Ollama needed correction", + detail = string.format( + "Prompt: %s\nOllama: %s\nCorrect: %s", + prompt:sub(1, 100), + ollama_response:sub(1, 200), + verifier_response:sub(1, 200) + ), + weight = 0.9, + file = context.file_path, + }) + end + end) + end - callback({ - ollama_response = ollama_response, - verifier_response = verifier_response, - agreement_score = agreement, - ollama_correct = ollama_correct, - feedback = feedback, - }) - end) + callback({ + ollama_response = ollama_response, + verifier_response = verifier_response, + agreement_score = agreement, + ollama_correct = ollama_correct, + feedback = feedback, + }) + end) end --- Smart generate with automatic provider selection and pondering @@ -388,127 +378,124 @@ end ---@param context table LLM context ---@param callback fun(response: string|nil, error: string|nil, metadata: table|nil) Callback function M.smart_generate(prompt, context, callback) - -- Select provider - local selection = M.select_provider(prompt, context) + -- Select provider + local selection = M.select_provider(prompt, context) - -- Log selection - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = string.format( - "LLM: %s (confidence: %.1f%%, %s)", - selection.provider, - selection.confidence * 100, - selection.reason - ), - }) - end) + -- Log selection + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = string.format( + "LLM: %s (confidence: %.1f%%, %s)", + selection.provider, + selection.confidence * 100, + selection.reason + ), + }) + end) - -- Get the selected client - local client - if selection.provider == "ollama" then - client = require("codetyper.core.llm.ollama") - else - client = require("codetyper.core.llm.copilot") - end + -- Get the selected client + local client + if selection.provider == "ollama" then + client = require("codetyper.core.llm.ollama") + else + client = require("codetyper.core.llm.copilot") + end - -- Generate response - client.generate(prompt, context, function(response, error) - if error then - -- Fallback on error - if selection.provider == "ollama" then - -- Try Copilot as fallback - local copilot = require("codetyper.core.llm.copilot") - copilot.generate(prompt, context, function(fallback_response, fallback_error) - callback(fallback_response, fallback_error, { - provider = "copilot", - fallback = true, - original_provider = "ollama", - original_error = error, - }) - end) - return - end - callback(nil, error, { provider = selection.provider }) - return - end + -- Generate response + client.generate(prompt, context, function(response, error) + if error then + -- Fallback on error + if selection.provider == "ollama" then + -- Try Copilot as fallback + local copilot = require("codetyper.core.llm.copilot") + copilot.generate(prompt, context, function(fallback_response, fallback_error) + callback(fallback_response, fallback_error, { + provider = "copilot", + fallback = true, + original_provider = "ollama", + original_error = error, + }) + end) + return + end + callback(nil, error, { provider = selection.provider }) + return + end - -- Check if we should ponder - if selection.provider == "ollama" and M.should_ponder(selection.confidence) then - M.ponder(prompt, context, response, function(ponder_result) - if ponder_result.ollama_correct then - -- Ollama was correct, use its response - callback(response, nil, { - provider = "ollama", - pondered = true, - agreement = ponder_result.agreement_score, - confidence = selection.confidence, - }) - else - -- Use verifier's response instead - callback(ponder_result.verifier_response, nil, { - provider = "copilot", - pondered = true, - agreement = ponder_result.agreement_score, - original_provider = "ollama", - corrected = true, - }) - end - end) - else - -- No pondering needed - callback(response, nil, { - provider = selection.provider, - pondered = false, - confidence = selection.confidence, - }) - end - end) + -- Check if we should ponder + if selection.provider == "ollama" and M.should_ponder(selection.confidence) then + M.ponder(prompt, context, response, function(ponder_result) + if ponder_result.ollama_correct then + -- Ollama was correct, use its response + callback(response, nil, { + provider = "ollama", + pondered = true, + agreement = ponder_result.agreement_score, + confidence = selection.confidence, + }) + else + -- Use verifier's response instead + callback(ponder_result.verifier_response, nil, { + provider = "copilot", + pondered = true, + agreement = ponder_result.agreement_score, + original_provider = "ollama", + corrected = true, + }) + end + end) + else + -- No pondering needed + callback(response, nil, { + provider = selection.provider, + pondered = false, + confidence = selection.confidence, + }) + end + end) end --- Get current accuracy statistics ---@return table {ollama: {correct, total, accuracy}, copilot: {correct, total, accuracy}} function M.get_accuracy_stats() - local stats = { - ollama = { - correct = accuracy_cache.ollama.correct, - total = accuracy_cache.ollama.total, - accuracy = accuracy_cache.ollama.total > 0 - and (accuracy_cache.ollama.correct / accuracy_cache.ollama.total) - or 0, - }, - copilot = { - correct = accuracy_cache.copilot.correct, - total = accuracy_cache.copilot.total, - accuracy = accuracy_cache.copilot.total > 0 - and (accuracy_cache.copilot.correct / accuracy_cache.copilot.total) - or 0, - }, - } - return stats + local stats = { + ollama = { + correct = accuracy_cache.ollama.correct, + total = accuracy_cache.ollama.total, + accuracy = accuracy_cache.ollama.total > 0 and (accuracy_cache.ollama.correct / accuracy_cache.ollama.total) or 0, + }, + copilot = { + correct = accuracy_cache.copilot.correct, + total = accuracy_cache.copilot.total, + accuracy = accuracy_cache.copilot.total > 0 and (accuracy_cache.copilot.correct / accuracy_cache.copilot.total) + or 0, + }, + } + return stats end --- Reset accuracy statistics function M.reset_accuracy_stats() - accuracy_cache = { - ollama = { correct = 0, total = 0 }, - copilot = { correct = 0, total = 0 }, - } - save_accuracy_stats() + accuracy_cache = { + ollama = { correct = 0, total = 0 }, + copilot = { correct = 0, total = 0 }, + } + save_accuracy_stats() end --- Report user feedback on response quality ---@param provider string Which provider generated the response ---@param was_correct boolean Whether the response was good function M.report_feedback(provider, was_correct) - if accuracy_cache[provider] then - accuracy_cache[provider].total = accuracy_cache[provider].total + 1 - if was_correct then - accuracy_cache[provider].correct = accuracy_cache[provider].correct + 1 - end - save_accuracy_stats() - end + if accuracy_cache[provider] then + accuracy_cache[provider].total = accuracy_cache[provider].total + 1 + if was_correct then + accuracy_cache[provider].correct = accuracy_cache[provider].correct + 1 + end + save_accuracy_stats() + end end return M diff --git a/lua/codetyper/core/marks.lua b/lua/codetyper/core/marks.lua index 6cf9666..0bfb556 100644 --- a/lua/codetyper/core/marks.lua +++ b/lua/codetyper/core/marks.lua @@ -19,19 +19,19 @@ local nsid = vim.api.nvim_create_namespace("codetyper.marks") ---@param col_0 number 0-based column ---@return Mark function M.mark_point(buffer, row_0, col_0) - if not vim.api.nvim_buf_is_valid(buffer) then - return { id = nil, buffer = buffer, nsid = nsid } - end - local line_count = vim.api.nvim_buf_line_count(buffer) - if line_count == 0 or row_0 < 0 or row_0 >= line_count then - return { id = nil, buffer = buffer, nsid = nsid } - end - local id = vim.api.nvim_buf_set_extmark(buffer, nsid, row_0, col_0, {}) - return { - id = id, - buffer = buffer, - nsid = nsid, - } + if not vim.api.nvim_buf_is_valid(buffer) then + return { id = nil, buffer = buffer, nsid = nsid } + end + local line_count = vim.api.nvim_buf_line_count(buffer) + if line_count == 0 or row_0 < 0 or row_0 >= line_count then + return { id = nil, buffer = buffer, nsid = nsid } + end + local id = vim.api.nvim_buf_set_extmark(buffer, nsid, row_0, col_0, {}) + return { + id = id, + buffer = buffer, + nsid = nsid, + } end --- Create marks for a range. start/end are 1-based line numbers; end_col_0 is 0-based column on end line. @@ -42,10 +42,10 @@ end ---@return Mark start_mark ---@return Mark end_mark function M.mark_range(buffer, start_line, end_line, end_col_0) - end_col_0 = end_col_0 or 0 - local start_mark = M.mark_point(buffer, start_line - 1, 0) - local end_mark = M.mark_point(buffer, end_line - 1, end_col_0) - return start_mark, end_mark + end_col_0 = end_col_0 or 0 + local start_mark = M.mark_point(buffer, start_line - 1, 0) + local end_mark = M.mark_point(buffer, end_line - 1, end_col_0) + return start_mark, end_mark end --- Get current 0-based (row, col) of a mark. Returns nil if mark invalid. @@ -53,25 +53,25 @@ end ---@return number|nil row_0 ---@return number|nil col_0 function M.get_position(mark) - if not mark or not mark.id or not vim.api.nvim_buf_is_valid(mark.buffer) then - return nil, nil - end - local pos = vim.api.nvim_buf_get_extmark_by_id(mark.buffer, mark.nsid, mark.id, {}) - if not pos or #pos < 2 then - return nil, nil - end - return pos[1], pos[2] + if not mark or not mark.id or not vim.api.nvim_buf_is_valid(mark.buffer) then + return nil, nil + end + local pos = vim.api.nvim_buf_get_extmark_by_id(mark.buffer, mark.nsid, mark.id, {}) + if not pos or #pos < 2 then + return nil, nil + end + return pos[1], pos[2] end --- Check if mark still exists and buffer valid. ---@param mark Mark ---@return boolean function M.is_valid(mark) - if not mark or not mark.id then - return false - end - local row, col = M.get_position(mark) - return row ~= nil and col ~= nil + if not mark or not mark.id then + return false + end + local row, col = M.get_position(mark) + return row ~= nil and col ~= nil end --- Get current range as 0-based (start_row, start_col, end_row, end_col) for nvim_buf_set_text. Returns nil if any mark invalid. @@ -79,12 +79,12 @@ end ---@param end_mark Mark ---@return number|nil, number|nil, number|nil, number|nil function M.range_to_vim(start_mark, end_mark) - local sr, sc = M.get_position(start_mark) - local er, ec = M.get_position(end_mark) - if sr == nil or er == nil then - return nil, nil, nil, nil - end - return sr, sc, er, ec + local sr, sc = M.get_position(start_mark) + local er, ec = M.get_position(end_mark) + if sr == nil or er == nil then + return nil, nil, nil, nil + end + return sr, sc, er, ec end --- Replace text between two marks with lines (like 99 Range:replace_text). Uses current positions from extmarks. @@ -94,24 +94,24 @@ end ---@param lines string[] ---@return boolean success function M.replace_text(buffer, start_mark, end_mark, lines) - local sr, sc, er, ec = M.range_to_vim(start_mark, end_mark) - if sr == nil then - return false - end - if not vim.api.nvim_buf_is_valid(buffer) then - return false - end - vim.api.nvim_buf_set_text(buffer, sr, sc, er, ec, lines) - return true + local sr, sc, er, ec = M.range_to_vim(start_mark, end_mark) + if sr == nil then + return false + end + if not vim.api.nvim_buf_is_valid(buffer) then + return false + end + vim.api.nvim_buf_set_text(buffer, sr, sc, er, ec, lines) + return true end --- Delete extmark (cleanup). ---@param mark Mark function M.delete(mark) - if not mark or not mark.id or not vim.api.nvim_buf_is_valid(mark.buffer) then - return - end - pcall(vim.api.nvim_buf_del_extmark, mark.buffer, mark.nsid, mark.id) + if not mark or not mark.id or not vim.api.nvim_buf_is_valid(mark.buffer) then + return + end + pcall(vim.api.nvim_buf_del_extmark, mark.buffer, mark.nsid, mark.id) end return M diff --git a/lua/codetyper/core/memory/delta/commit.lua b/lua/codetyper/core/memory/delta/commit.lua index d84dad5..a9cc2d4 100644 --- a/lua/codetyper/core/memory/delta/commit.lua +++ b/lua/codetyper/core/memory/delta/commit.lua @@ -282,7 +282,12 @@ function M.format(delta) "", " " .. (delta.m.msg or "No message"), "", - string.format(" %d additions, %d modifications, %d deletions", summary.stats.adds, summary.stats.modifies, summary.stats.deletes), + string.format( + " %d additions, %d modifications, %d deletions", + summary.stats.adds, + summary.stats.modifies, + summary.stats.deletes + ), } return lines diff --git a/lua/codetyper/core/memory/graph/query.lua b/lua/codetyper/core/memory/graph/query.lua index 1761c80..12c2d8b 100644 --- a/lua/codetyper/core/memory/graph/query.lua +++ b/lua/codetyper/core/memory/graph/query.lua @@ -323,8 +323,8 @@ function M.execute(opts) -- Nodes connected to multiple relevant seeds get higher activation local final_activations = spreading_activation( seed_activations, - opts.spread_iterations or 3, -- How far activation spreads - opts.spread_decay or 0.5, -- How much activation decays per hop + opts.spread_iterations or 3, -- How far activation spreads + opts.spread_decay or 0.5, -- How much activation decays per hop opts.spread_threshold or 0.05 -- Minimum activation to continue spreading ) diff --git a/lua/codetyper/core/memory/output/formatter.lua b/lua/codetyper/core/memory/output/formatter.lua index 0919dbe..18f6346 100644 --- a/lua/codetyper/core/memory/output/formatter.lua +++ b/lua/codetyper/core/memory/output/formatter.lua @@ -80,13 +80,7 @@ function M.to_compact(result, opts) break end - local conn_line = string.format( - " %s --%s(%.2f)--> %s", - edge.s:sub(-8), - edge.ty, - edge.p.w or 0.5, - edge.t:sub(-8) - ) + local conn_line = string.format(" %s --%s(%.2f)--> %s", edge.s:sub(-8), edge.ty, edge.p.w or 0.5, edge.t:sub(-8)) table.insert(lines, conn_line) current_tokens = current_tokens + M.estimate_tokens(conn_line) end @@ -235,7 +229,10 @@ function M.format_chain(chain) for i, item in ipairs(chain) do if item.node then local prefix = i == 1 and "" or " -> " - table.insert(lines, string.format("%s[%s] %s (w:%.2f)", prefix, item.node.t:upper(), item.node.c.s:sub(1, 50), item.node.sc.w)) + table.insert( + lines, + string.format("%s[%s] %s (w:%.2f)", prefix, item.node.t:upper(), item.node.c.s:sub(1, 50), item.node.sc.w) + ) end if item.edge then table.insert(lines, string.format(" via %s (w:%.2f)", item.edge.ty, item.edge.p.w)) diff --git a/lua/codetyper/core/scheduler/executor.lua b/lua/codetyper/core/scheduler/executor.lua index 607334b..257aacf 100644 --- a/lua/codetyper/core/scheduler/executor.lua +++ b/lua/codetyper/core/scheduler/executor.lua @@ -247,7 +247,10 @@ function M.handle_bash(params, callback) local command = params.command -- Log the bash operation - logs.add({ type = "action", message = string.format("Bash(%s)", command:sub(1, 50) .. (#command > 50 and "..." or "")) }) + logs.add({ + type = "action", + message = string.format("Bash(%s)", command:sub(1, 50) .. (#command > 50 and "..." or "")), + }) logs.add({ type = "result", message = " ⎿ Pending approval" }) -- Requires user approval first @@ -374,7 +377,8 @@ function M.handle_search_files(params, callback) if content_search then -- Search by content using grep local grep_results = {} - local grep_cmd = string.format("grep -rl '%s' '%s' 2>/dev/null | head -20", content_search:gsub("'", "\\'"), search_path) + local grep_cmd = + string.format("grep -rl '%s' '%s' 2>/dev/null | head -20", content_search:gsub("'", "\\'"), search_path) local handle = io.popen(grep_cmd) if handle then diff --git a/lua/codetyper/core/scheduler/loop.lua b/lua/codetyper/core/scheduler/loop.lua index 2f73222..d89f1c0 100644 --- a/lua/codetyper/core/scheduler/loop.lua +++ b/lua/codetyper/core/scheduler/loop.lua @@ -33,326 +33,326 @@ local prompts = require("codetyper.prompts.agents.loop") ---@param tools CoderTool[] ---@return table[] local function format_tools_for_api(tools) - local formatted = {} - for _, tool in ipairs(tools) do - local properties = {} - local required = {} + local formatted = {} + for _, tool in ipairs(tools) do + local properties = {} + local required = {} - for _, param in ipairs(tool.params or {}) do - properties[param.name] = { - type = param.type == "integer" and "number" or param.type, - description = param.description, - } - if not param.optional then - table.insert(required, param.name) - end - end + for _, param in ipairs(tool.params or {}) do + properties[param.name] = { + type = param.type == "integer" and "number" or param.type, + description = param.description, + } + if not param.optional then + table.insert(required, param.name) + end + end - table.insert(formatted, { - type = "function", - ["function"] = { - name = tool.name, - description = type(tool.description) == "function" and tool.description() or tool.description, - parameters = { - type = "object", - properties = properties, - required = required, - }, - }, - }) - end - return formatted + table.insert(formatted, { + type = "function", + ["function"] = { + name = tool.name, + description = type(tool.description) == "function" and tool.description() or tool.description, + parameters = { + type = "object", + properties = properties, + required = required, + }, + }, + }) + end + return formatted end --- Parse tool calls from LLM response ---@param response table LLM response ---@return table[] tool_calls local function parse_tool_calls(response) - local tool_calls = {} + local tool_calls = {} - -- Handle different response formats - if response.tool_calls then - -- OpenAI format - for _, call in ipairs(response.tool_calls) do - local args = call["function"].arguments - if type(args) == "string" then - local ok, parsed = pcall(vim.json.decode, args) - if ok then - args = parsed - end - end - table.insert(tool_calls, { - id = call.id, - name = call["function"].name, - input = args, - }) - end - elseif response.content and type(response.content) == "table" then - -- Claude format (content blocks) - for _, block in ipairs(response.content) do - if block.type == "tool_use" then - table.insert(tool_calls, { - id = block.id, - name = block.name, - input = block.input, - }) - end - end - end + -- Handle different response formats + if response.tool_calls then + -- OpenAI format + for _, call in ipairs(response.tool_calls) do + local args = call["function"].arguments + if type(args) == "string" then + local ok, parsed = pcall(vim.json.decode, args) + if ok then + args = parsed + end + end + table.insert(tool_calls, { + id = call.id, + name = call["function"].name, + input = args, + }) + end + elseif response.content and type(response.content) == "table" then + -- Claude format (content blocks) + for _, block in ipairs(response.content) do + if block.type == "tool_use" then + table.insert(tool_calls, { + id = block.id, + name = block.name, + input = block.input, + }) + end + end + end - return tool_calls + return tool_calls end --- Build messages for LLM request ---@param history AgentMessage[] ---@return table[] local function build_messages(history) - local messages = {} + local messages = {} - for _, msg in ipairs(history) do - if msg.role == "system" then - table.insert(messages, { - role = "system", - content = msg.content, - }) - elseif msg.role == "user" then - table.insert(messages, { - role = "user", - content = msg.content, - }) - elseif msg.role == "assistant" then - local message = { - role = "assistant", - content = msg.content, - } - if msg.tool_calls then - message.tool_calls = msg.tool_calls - end - table.insert(messages, message) - elseif msg.role == "tool" then - table.insert(messages, { - role = "tool", - tool_call_id = msg.tool_call_id, - content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content), - }) - end - end + for _, msg in ipairs(history) do + if msg.role == "system" then + table.insert(messages, { + role = "system", + content = msg.content, + }) + elseif msg.role == "user" then + table.insert(messages, { + role = "user", + content = msg.content, + }) + elseif msg.role == "assistant" then + local message = { + role = "assistant", + content = msg.content, + } + if msg.tool_calls then + message.tool_calls = msg.tool_calls + end + table.insert(messages, message) + elseif msg.role == "tool" then + table.insert(messages, { + role = "tool", + tool_call_id = msg.tool_call_id, + content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content), + }) + end + end - return messages + return messages end --- Execute the agent loop ---@param opts AgentLoopOpts function M.run(opts) - local tools_mod = require("codetyper.core.tools") - local llm = require("codetyper.core.llm") + local tools_mod = require("codetyper.core.tools") + local llm = require("codetyper.core.llm") - -- Get tools - local tools = opts.tools or tools_mod.list() - local tool_map = {} - for _, tool in ipairs(tools) do - tool_map[tool.name] = tool - end + -- Get tools + local tools = opts.tools or tools_mod.list() + local tool_map = {} + for _, tool in ipairs(tools) do + tool_map[tool.name] = tool + end - -- Initialize conversation history - ---@type AgentMessage[] - local history = { - { role = "system", content = opts.system_prompt }, - { role = "user", content = opts.user_input }, - } + -- Initialize conversation history + ---@type AgentMessage[] + local history = { + { role = "system", content = opts.system_prompt }, + { role = "user", content = opts.user_input }, + } - local session_ctx = opts.session_ctx or {} - local max_iterations = opts.max_iterations or 10 - local iteration = 0 + local session_ctx = opts.session_ctx or {} + local max_iterations = opts.max_iterations or 10 + local iteration = 0 - -- Callback wrappers - local function on_message(msg) - if opts.on_message then - opts.on_message(msg) - end - end + -- Callback wrappers + local function on_message(msg) + if opts.on_message then + opts.on_message(msg) + end + end - -- Notify of initial messages - for _, msg in ipairs(history) do - on_message(msg) - end + -- Notify of initial messages + for _, msg in ipairs(history) do + on_message(msg) + end - -- Start notification - if opts.on_start then - opts.on_start() - end + -- Start notification + if opts.on_start then + opts.on_start() + end - --- Process one iteration of the loop - local function process_iteration() - iteration = iteration + 1 + --- Process one iteration of the loop + local function process_iteration() + iteration = iteration + 1 - if iteration > max_iterations then - if opts.on_complete then - opts.on_complete(nil, "Max iterations reached") - end - return - end + if iteration > max_iterations then + if opts.on_complete then + opts.on_complete(nil, "Max iterations reached") + end + return + end - -- Build request - local messages = build_messages(history) - local formatted_tools = format_tools_for_api(tools) + -- Build request + local messages = build_messages(history) + local formatted_tools = format_tools_for_api(tools) - -- Build context for LLM - local context = { - file_content = "", - language = "lua", - extension = "lua", - prompt_type = "agent", - tools = formatted_tools, - } + -- Build context for LLM + local context = { + file_content = "", + language = "lua", + extension = "lua", + prompt_type = "agent", + tools = formatted_tools, + } - -- Get LLM response - local client = llm.get_client() - if not client then - if opts.on_complete then - opts.on_complete(nil, "No LLM client available") - end - return - end + -- Get LLM response + local client = llm.get_client() + if not client then + if opts.on_complete then + opts.on_complete(nil, "No LLM client available") + end + return + end - -- Build prompt from messages - local prompt_parts = {} - for _, msg in ipairs(messages) do - if msg.role ~= "system" then - table.insert(prompt_parts, string.format("[%s]: %s", msg.role, msg.content or "")) - end - end - local prompt = table.concat(prompt_parts, "\n\n") + -- Build prompt from messages + local prompt_parts = {} + for _, msg in ipairs(messages) do + if msg.role ~= "system" then + table.insert(prompt_parts, string.format("[%s]: %s", msg.role, msg.content or "")) + end + end + local prompt = table.concat(prompt_parts, "\n\n") - client.generate(prompt, context, function(response, error) - if error then - if opts.on_complete then - opts.on_complete(nil, error) - end - return - end + client.generate(prompt, context, function(response, error) + if error then + if opts.on_complete then + opts.on_complete(nil, error) + end + return + end - -- Chunk callback - if opts.on_chunk then - opts.on_chunk(response) - end + -- Chunk callback + if opts.on_chunk then + opts.on_chunk(response) + end - -- Parse response for tool calls - -- For now, we'll use a simple heuristic to detect tool calls in the response - -- In a full implementation, the LLM would return structured tool calls - local tool_calls = {} + -- Parse response for tool calls + -- For now, we'll use a simple heuristic to detect tool calls in the response + -- In a full implementation, the LLM would return structured tool calls + local tool_calls = {} - -- Try to parse JSON tool calls from response - local json_match = response:match("```json%s*(%b{})%s*```") - if json_match then - local ok, parsed = pcall(vim.json.decode, json_match) - if ok and parsed.tool_calls then - tool_calls = parsed.tool_calls - end - end + -- Try to parse JSON tool calls from response + local json_match = response:match("```json%s*(%b{})%s*```") + if json_match then + local ok, parsed = pcall(vim.json.decode, json_match) + if ok and parsed.tool_calls then + tool_calls = parsed.tool_calls + end + end - -- Add assistant message - local assistant_msg = { - role = "assistant", - content = response, - tool_calls = #tool_calls > 0 and tool_calls or nil, - } - table.insert(history, assistant_msg) - on_message(assistant_msg) + -- Add assistant message + local assistant_msg = { + role = "assistant", + content = response, + tool_calls = #tool_calls > 0 and tool_calls or nil, + } + table.insert(history, assistant_msg) + on_message(assistant_msg) - -- Process tool calls - if #tool_calls > 0 then - local pending = #tool_calls - local results = {} + -- Process tool calls + if #tool_calls > 0 then + local pending = #tool_calls + local results = {} - for i, call in ipairs(tool_calls) do - local tool = tool_map[call.name] - if not tool then - results[i] = { error = "Unknown tool: " .. call.name } - pending = pending - 1 - else - -- Notify of tool call - if opts.on_tool_call then - opts.on_tool_call(call.name, call.input) - end + for i, call in ipairs(tool_calls) do + local tool = tool_map[call.name] + if not tool then + results[i] = { error = "Unknown tool: " .. call.name } + pending = pending - 1 + else + -- Notify of tool call + if opts.on_tool_call then + opts.on_tool_call(call.name, call.input) + end - -- Execute tool - local tool_opts = { - on_log = function(msg) - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ type = "tool", message = msg }) - end) - end, - on_complete = function(result, err) - results[i] = { result = result, error = err } - pending = pending - 1 + -- Execute tool + local tool_opts = { + on_log = function(msg) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ type = "tool", message = msg }) + end) + end, + on_complete = function(result, err) + results[i] = { result = result, error = err } + pending = pending - 1 - -- Notify of tool result - if opts.on_tool_result then - opts.on_tool_result(call.name, result, err) - end + -- Notify of tool result + if opts.on_tool_result then + opts.on_tool_result(call.name, result, err) + end - -- Add tool response to history - local tool_msg = { - role = "tool", - tool_call_id = call.id or tostring(i), - name = call.name, - content = err or result, - } - table.insert(history, tool_msg) - on_message(tool_msg) + -- Add tool response to history + local tool_msg = { + role = "tool", + tool_call_id = call.id or tostring(i), + name = call.name, + content = err or result, + } + table.insert(history, tool_msg) + on_message(tool_msg) - -- Continue loop when all tools complete - if pending == 0 then - vim.schedule(process_iteration) - end - end, - session_ctx = session_ctx, - } + -- Continue loop when all tools complete + if pending == 0 then + vim.schedule(process_iteration) + end + end, + session_ctx = session_ctx, + } - -- Validate and execute - local valid, validation_err = true, nil - if tool.validate_input then - valid, validation_err = tool:validate_input(call.input) - end + -- Validate and execute + local valid, validation_err = true, nil + if tool.validate_input then + valid, validation_err = tool:validate_input(call.input) + end - if not valid then - tool_opts.on_complete(nil, validation_err) - else - local result, err = tool.func(call.input, tool_opts) - -- If sync result, call on_complete - if result ~= nil or err ~= nil then - tool_opts.on_complete(result, err) - end - end - end - end - else - -- No tool calls - loop complete - if opts.on_complete then - opts.on_complete(response, nil) - end - end - end) - end + if not valid then + tool_opts.on_complete(nil, validation_err) + else + local result, err = tool.func(call.input, tool_opts) + -- If sync result, call on_complete + if result ~= nil or err ~= nil then + tool_opts.on_complete(result, err) + end + end + end + end + else + -- No tool calls - loop complete + if opts.on_complete then + opts.on_complete(response, nil) + end + end + end) + end - -- Start the loop - process_iteration() + -- Start the loop + process_iteration() end --- Create an agent with default settings ---@param task string Task description ---@param opts? AgentLoopOpts Additional options function M.create(task, opts) - opts = opts or {} + opts = opts or {} - local system_prompt = opts.system_prompt or prompts.default_system_prompt + local system_prompt = opts.system_prompt or prompts.default_system_prompt - M.run(vim.tbl_extend("force", opts, { - system_prompt = system_prompt, - user_input = task, - })) + M.run(vim.tbl_extend("force", opts, { + system_prompt = system_prompt, + user_input = task, + })) end --- Simple dispatch agent for sub-tasks @@ -360,22 +360,22 @@ end ---@param on_complete fun(result: string|nil, error: string|nil) Completion callback ---@param opts? table Additional options function M.dispatch(prompt, on_complete, opts) - opts = opts or {} + opts = opts or {} - -- Sub-agents get limited tools by default - local tools_mod = require("codetyper.core.tools") - local safe_tools = tools_mod.list(function(tool) - return tool.name == "view" or tool.name == "grep" or tool.name == "glob" - end) + -- Sub-agents get limited tools by default + local tools_mod = require("codetyper.core.tools") + local safe_tools = tools_mod.list(function(tool) + return tool.name == "view" or tool.name == "grep" or tool.name == "glob" + end) - M.run({ - system_prompt = prompts.dispatch_prompt, - user_input = prompt, - tools = opts.tools or safe_tools, - max_iterations = opts.max_iterations or 5, - on_complete = on_complete, - session_ctx = opts.session_ctx, - }) + M.run({ + system_prompt = prompts.dispatch_prompt, + user_input = prompt, + tools = opts.tools or safe_tools, + max_iterations = opts.max_iterations or 5, + on_complete = on_complete, + session_ctx = opts.session_ctx, + }) end return M diff --git a/lua/codetyper/core/scheduler/scheduler.lua b/lua/codetyper/core/scheduler/scheduler.lua index e0789af..a97afb6 100644 --- a/lua/codetyper/core/scheduler/scheduler.lua +++ b/lua/codetyper/core/scheduler/scheduler.lua @@ -19,11 +19,11 @@ context_modal.setup() --- Scheduler state local state = { - running = false, - timer = nil, - poll_interval = 100, -- ms - paused = false, - config = params.config, + running = false, + timer = nil, + poll_interval = 100, -- ms + paused = false, + config = params.config, } --- Autocommand group for injection timing @@ -32,465 +32,464 @@ local augroup = nil --- Check if completion popup is visible ---@return boolean function M.is_completion_visible() - -- Check native popup menu - if vim.fn.pumvisible() == 1 then - return true - end + -- Check native popup menu + if vim.fn.pumvisible() == 1 then + return true + end - -- Check nvim-cmp - local ok, cmp = pcall(require, "cmp") - if ok and cmp.visible and cmp.visible() then - return true - end + -- Check nvim-cmp + local ok, cmp = pcall(require, "cmp") + if ok and cmp.visible and cmp.visible() then + return true + end - -- Check coq_nvim - local coq_ok, coq = pcall(require, "coq") - if coq_ok and coq and type(coq.visible) == "function" and coq.visible() then - return true - end + -- Check coq_nvim + local coq_ok, coq = pcall(require, "coq") + if coq_ok and coq and type(coq.visible) == "function" and coq.visible() then + return true + end - return false + return false end --- Check if we're in insert mode ---@return boolean function M.is_insert_mode() - local mode = vim.fn.mode() - return mode == "i" or mode == "ic" or mode == "ix" + local mode = vim.fn.mode() + return mode == "i" or mode == "ic" or mode == "ix" end --- Check if we're in visual mode ---@return boolean function M.is_visual_mode() - local mode = vim.fn.mode() - return mode == "v" or mode == "V" or mode == "\22" + local mode = vim.fn.mode() + return mode == "v" or mode == "V" or mode == "\22" end --- Check if it's safe to inject code ---@return boolean ---@return string|nil reason if not safe function M.is_safe_to_inject() - if M.is_completion_visible() then - return false, "completion_visible" - end + if M.is_completion_visible() then + return false, "completion_visible" + end - if M.is_insert_mode() then - return false, "insert_mode" - end + if M.is_insert_mode() then + return false, "insert_mode" + end - if M.is_visual_mode() then - return false, "visual_mode" - end + if M.is_visual_mode() then + return false, "visual_mode" + end - return true, nil + return true, nil end --- Get the provider for escalation ---@return string local function get_remote_provider() - local ok, codetyper = pcall(require, "codetyper") - if ok then - local config = codetyper.get_config() - if config and config.llm and config.llm.provider then - if config.llm.provider == "ollama" then - return "copilot" - end - return config.llm.provider - end - end - return "copilot" + local ok, codetyper = pcall(require, "codetyper") + if ok then + local config = codetyper.get_config() + if config and config.llm and config.llm.provider then + if config.llm.provider == "ollama" then + return "copilot" + end + return config.llm.provider + end + end + return "copilot" end --- Get the primary provider (ollama if scout enabled, else configured) ---@return string local function get_primary_provider() - if state.config.ollama_scout then - return "ollama" - end + if state.config.ollama_scout then + return "ollama" + end - local ok, codetyper = pcall(require, "codetyper") - if ok then - local config = codetyper.get_config() - if config and config.llm and config.llm.provider then - return config.llm.provider - end - end - return "ollama" + local ok, codetyper = pcall(require, "codetyper") + if ok then + local config = codetyper.get_config() + if config and config.llm and config.llm.provider then + return config.llm.provider + end + end + return "ollama" end --- Retry event with additional context ---@param original_event table Original prompt event ---@param additional_context string Additional context from user local function retry_with_context(original_event, additional_context, attached_files) - -- Create new prompt content combining original + additional - local combined_prompt = string.format( - "%s\n\nAdditional context:\n%s", - original_event.prompt_content, - additional_context - ) + -- Create new prompt content combining original + additional + local combined_prompt = + string.format("%s\n\nAdditional context:\n%s", original_event.prompt_content, additional_context) - -- Create a new event with the combined prompt - local new_event = vim.deepcopy(original_event) - new_event.id = nil -- Will be assigned a new ID - new_event.prompt_content = combined_prompt - new_event.attempt_count = 0 - new_event.status = nil - -- Preserve any attached files provided by the context modal - if attached_files and #attached_files > 0 then - new_event.attached_files = attached_files - end + -- Create a new event with the combined prompt + local new_event = vim.deepcopy(original_event) + new_event.id = nil -- Will be assigned a new ID + new_event.prompt_content = combined_prompt + new_event.attempt_count = 0 + new_event.status = nil + -- Preserve any attached files provided by the context modal + if attached_files and #attached_files > 0 then + new_event.attached_files = attached_files + end - -- Log the retry - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = string.format("Retrying with additional context (original: %s)", original_event.id), - }) - end) + -- Log the retry + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = string.format("Retrying with additional context (original: %s)", original_event.id), + }) + end) - -- Queue the new event - queue.enqueue(new_event) + -- Queue the new event + queue.enqueue(new_event) end --- Try to parse requested file paths from an LLM response asking for more context ---@param response string ---@return string[] list of resolved full paths local function parse_requested_files(response) - if not response or response == "" then - return {} - end + if not response or response == "" then + return {} + end - local cwd = vim.fn.getcwd() - local results = {} - local seen = {} + local cwd = vim.fn.getcwd() + local results = {} + local seen = {} - -- Heuristics: capture backticked paths, lines starting with - or *, or raw paths with slashes and extension - for path in response:gmatch("`([%w%._%-%/]+%.[%w_]+)`") do - if not seen[path] then - table.insert(results, path) - seen[path] = true - end - end + -- Heuristics: capture backticked paths, lines starting with - or *, or raw paths with slashes and extension + for path in response:gmatch("`([%w%._%-%/]+%.[%w_]+)`") do + if not seen[path] then + table.insert(results, path) + seen[path] = true + end + end - for path in response:gmatch("([%w%._%-%/]+%.[%w_]+)") do - if not seen[path] then - -- Filter out common English words that match the pattern - if not path:match("^[Ii]$") and not path:match("^[Tt]his$") then - table.insert(results, path) - seen[path] = true - end - end - end + for path in response:gmatch("([%w%._%-%/]+%.[%w_]+)") do + if not seen[path] then + -- Filter out common English words that match the pattern + if not path:match("^[Ii]$") and not path:match("^[Tt]his$") then + table.insert(results, path) + seen[path] = true + end + end + end - -- Also capture list items like '- src/foo.lua' - for line in response:gmatch("[^\\n]+") do - local m = line:match("^%s*[-*]%s*([%w%._%-%/]+%.[%w_]+)%s*$") - if m and not seen[m] then - table.insert(results, m) - seen[m] = true - end - end + -- Also capture list items like '- src/foo.lua' + for line in response:gmatch("[^\\n]+") do + local m = line:match("^%s*[-*]%s*([%w%._%-%/]+%.[%w_]+)%s*$") + if m and not seen[m] then + table.insert(results, m) + seen[m] = true + end + end - -- Resolve each candidate to a full path by checking cwd and globbing - local resolved = {} - for _, p in ipairs(results) do - local candidate = p - local full = nil + -- Resolve each candidate to a full path by checking cwd and globbing + local resolved = {} + for _, p in ipairs(results) do + local candidate = p + local full = nil - -- If absolute or already rooted - if candidate:sub(1,1) == "/" and vim.fn.filereadable(candidate) == 1 then - full = candidate - else - -- Try relative to cwd - local try1 = cwd .. "/" .. candidate - if vim.fn.filereadable(try1) == 1 then - full = try1 - else - -- Try globbing for filename anywhere in project - local basename = candidate - -- If candidate contains slashes, try the tail - local tail = candidate:match("[^/]+$") or candidate - local matches = vim.fn.globpath(cwd, "**/" .. tail, false, true) - if matches and #matches > 0 then - full = matches[1] - end - end - end + -- If absolute or already rooted + if candidate:sub(1, 1) == "/" and vim.fn.filereadable(candidate) == 1 then + full = candidate + else + -- Try relative to cwd + local try1 = cwd .. "/" .. candidate + if vim.fn.filereadable(try1) == 1 then + full = try1 + else + -- Try globbing for filename anywhere in project + local basename = candidate + -- If candidate contains slashes, try the tail + local tail = candidate:match("[^/]+$") or candidate + local matches = vim.fn.globpath(cwd, "**/" .. tail, false, true) + if matches and #matches > 0 then + full = matches[1] + end + end + end - if full and vim.fn.filereadable(full) == 1 then - table.insert(resolved, full) - end - end + if full and vim.fn.filereadable(full) == 1 then + table.insert(resolved, full) + end + end - return resolved + return resolved end --- Process worker result and decide next action ---@param event table PromptEvent ---@param result table WorkerResult local function handle_worker_result(event, result) - -- Clear 99-style inline "Thinking..." virtual text when worker finishes (any outcome) - require("codetyper.core.thinking_placeholder").clear_inline(event.id) + -- Clear 99-style inline "Thinking..." virtual text when worker finishes (any outcome) + require("codetyper.core.thinking_placeholder").clear_inline(event.id) - -- Check if LLM needs more context - if result.needs_context then - require("codetyper.core.thinking_placeholder").remove_on_failure(event.id) - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = string.format("Event %s: LLM needs more context, opening modal", event.id), - }) - end) + -- Check if LLM needs more context + if result.needs_context then + require("codetyper.core.thinking_placeholder").remove_on_failure(event.id) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = string.format("Event %s: LLM needs more context, opening modal", event.id), + }) + end) - -- Try to auto-attach any files the LLM specifically requested in its response - local requested = parse_requested_files(result.response or "") + -- Try to auto-attach any files the LLM specifically requested in its response + local requested = parse_requested_files(result.response or "") - -- Detect suggested shell commands the LLM may want executed (e.g., "run ls -la", "please run git status") - local function detect_suggested_commands(response) - if not response then - return {} - end - local cmds = {} - -- capture backticked commands: `ls -la` - for c in response:gmatch("`([^`]+)`") do - if #c > 1 and not c:match("%-%-help") then - table.insert(cmds, { label = c, cmd = c }) - end - end - -- capture phrases like: run ls -la or run `ls -la` - for m in response:gmatch("[Rr]un%s+([%w%p%s%-_/]+)") do - local cand = m:gsub("^%s+",""):gsub("%s+$","") - if cand and #cand > 1 then - -- ignore long sentences; keep first line or command-like substring - local line = cand:match("[^\n]+") or cand - line = line:gsub("and then.*","") - line = line:gsub("please.*","") - if not line:match("%a+%s+files") then - table.insert(cmds, { label = line, cmd = line }) - end - end - end - -- dedupe - local seen = {} - local out = {} - for _, v in ipairs(cmds) do - if v.cmd and not seen[v.cmd] then - seen[v.cmd] = true - table.insert(out, v) - end - end - return out - end + -- Detect suggested shell commands the LLM may want executed (e.g., "run ls -la", "please run git status") + local function detect_suggested_commands(response) + if not response then + return {} + end + local cmds = {} + -- capture backticked commands: `ls -la` + for c in response:gmatch("`([^`]+)`") do + if #c > 1 and not c:match("%-%-help") then + table.insert(cmds, { label = c, cmd = c }) + end + end + -- capture phrases like: run ls -la or run `ls -la` + for m in response:gmatch("[Rr]un%s+([%w%p%s%-_/]+)") do + local cand = m:gsub("^%s+", ""):gsub("%s+$", "") + if cand and #cand > 1 then + -- ignore long sentences; keep first line or command-like substring + local line = cand:match("[^\n]+") or cand + line = line:gsub("and then.*", "") + line = line:gsub("please.*", "") + if not line:match("%a+%s+files") then + table.insert(cmds, { label = line, cmd = line }) + end + end + end + -- dedupe + local seen = {} + local out = {} + for _, v in ipairs(cmds) do + if v.cmd and not seen[v.cmd] then + seen[v.cmd] = true + table.insert(out, v) + end + end + return out + end - local suggested_cmds = detect_suggested_commands(result.response or "") - if suggested_cmds and #suggested_cmds > 0 then - -- Open modal and show suggested commands for user approval - context_modal.open(result.original_event or event, result.response or "", retry_with_context, suggested_cmds) - queue.update_status(event.id, "needs_context", { response = result.response }) - return - end - if requested and #requested > 0 then - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ type = "info", message = string.format("Auto-attaching %d requested file(s)", #requested) }) - end) + local suggested_cmds = detect_suggested_commands(result.response or "") + if suggested_cmds and #suggested_cmds > 0 then + -- Open modal and show suggested commands for user approval + context_modal.open(result.original_event or event, result.response or "", retry_with_context, suggested_cmds) + queue.update_status(event.id, "needs_context", { response = result.response }) + return + end + if requested and #requested > 0 then + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ type = "info", message = string.format("Auto-attaching %d requested file(s)", #requested) }) + end) - -- Build attached_files entries - local attached = event.attached_files or {} - for _, full in ipairs(requested) do - local ok, content = pcall(function() - return table.concat(vim.fn.readfile(full), "\n") - end) - if ok and content then - table.insert(attached, { - path = vim.fn.fnamemodify(full, ":~:."), - full_path = full, - content = content, - }) - end - end + -- Build attached_files entries + local attached = event.attached_files or {} + for _, full in ipairs(requested) do + local ok, content = pcall(function() + return table.concat(vim.fn.readfile(full), "\n") + end) + if ok and content then + table.insert(attached, { + path = vim.fn.fnamemodify(full, ":~:."), + full_path = full, + content = content, + }) + end + end - -- Retry automatically with same prompt but attached files - local new_event = vim.deepcopy(result.original_event or event) - new_event.id = nil - new_event.attached_files = attached - new_event.attempt_count = 0 - new_event.status = nil - queue.enqueue(new_event) + -- Retry automatically with same prompt but attached files + local new_event = vim.deepcopy(result.original_event or event) + new_event.id = nil + new_event.attached_files = attached + new_event.attempt_count = 0 + new_event.status = nil + queue.enqueue(new_event) - queue.update_status(event.id, "needs_context", { response = result.response }) - return - end + queue.update_status(event.id, "needs_context", { response = result.response }) + return + end - -- If no files parsed, open modal for manual context entry - context_modal.open(result.original_event or event, result.response or "", retry_with_context) + -- If no files parsed, open modal for manual context entry + context_modal.open(result.original_event or event, result.response or "", retry_with_context) - -- Mark original event as needing context (not failed) - queue.update_status(event.id, "needs_context", { response = result.response }) - return - end + -- Mark original event as needing context (not failed) + queue.update_status(event.id, "needs_context", { response = result.response }) + return + end - if not result.success then - -- Remove in-buffer placeholder on failure (will be re-inserted if we escalate/retry) - require("codetyper.core.thinking_placeholder").remove_on_failure(event.id) - -- Failed - try escalation if this was ollama - if result.worker_type == "ollama" and event.attempt_count < 2 then - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = string.format( - "Escalating event %s to remote provider (ollama failed)", - event.id - ), - }) - end) + if not result.success then + -- Remove in-buffer placeholder on failure (will be re-inserted if we escalate/retry) + require("codetyper.core.thinking_placeholder").remove_on_failure(event.id) + -- Failed - try escalation if this was ollama + if result.worker_type == "ollama" and event.attempt_count < 2 then + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = string.format("Escalating event %s to remote provider (ollama failed)", event.id), + }) + end) - event.attempt_count = event.attempt_count + 1 - event.status = "pending" - event.worker_type = get_remote_provider() - return - end + event.attempt_count = event.attempt_count + 1 + event.status = "pending" + event.worker_type = get_remote_provider() + return + end - -- Mark as failed - queue.update_status(event.id, "failed", { error = result.error }) - return - end + -- Mark as failed + queue.update_status(event.id, "failed", { error = result.error }) + return + end - -- Success - check confidence - local needs_escalation = confidence_mod.needs_escalation( - result.confidence, - state.config.escalation_threshold - ) + -- Success - check confidence + local needs_escalation = confidence_mod.needs_escalation(result.confidence, state.config.escalation_threshold) - if needs_escalation and result.worker_type == "ollama" and event.attempt_count < 2 then - -- Low confidence from ollama - escalate to remote - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = string.format( - "Escalating event %s to remote provider (confidence: %.2f < %.2f)", - event.id, result.confidence, state.config.escalation_threshold - ), - }) - end) + if needs_escalation and result.worker_type == "ollama" and event.attempt_count < 2 then + -- Low confidence from ollama - escalate to remote + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = string.format( + "Escalating event %s to remote provider (confidence: %.2f < %.2f)", + event.id, + result.confidence, + state.config.escalation_threshold + ), + }) + end) - event.attempt_count = event.attempt_count + 1 - event.status = "pending" - event.worker_type = get_remote_provider() - return - end + event.attempt_count = event.attempt_count + 1 + event.status = "pending" + event.worker_type = get_remote_provider() + return + end - -- Good enough or final attempt - create patch - pcall(function() - local tp = require("codetyper.core.thinking_placeholder") - tp.update_inline_status(event.id, "Generating patch...") - local thinking = require("codetyper.adapters.nvim.ui.thinking") - thinking.update_stage("Generating patch...") - end) - vim.notify("Generating patch...", vim.log.levels.INFO) + -- Good enough or final attempt - create patch + pcall(function() + local tp = require("codetyper.core.thinking_placeholder") + tp.update_inline_status(event.id, "Generating patch...") + local thinking = require("codetyper.adapters.nvim.ui.thinking") + thinking.update_stage("Generating patch...") + end) + vim.notify("Generating patch...", vim.log.levels.INFO) - local p = patch.create_from_event(event, result.response, result.confidence) - patch.queue_patch(p) + local p = patch.create_from_event(event, result.response, result.confidence) + patch.queue_patch(p) - queue.complete(event.id) + queue.complete(event.id) - -- Schedule patch application after delay (gives user time to review/cancel) - local delay = state.config.apply_delay_ms or 5000 - pcall(function() - local tp = require("codetyper.core.thinking_placeholder") - tp.update_inline_status(event.id, "Applying code...") - local thinking = require("codetyper.adapters.nvim.ui.thinking") - thinking.update_stage("Applying code...") - end) - vim.notify("Applying code...", vim.log.levels.INFO) + -- Schedule patch application after delay (gives user time to review/cancel) + local delay = state.config.apply_delay_ms or 5000 + pcall(function() + local tp = require("codetyper.core.thinking_placeholder") + tp.update_inline_status(event.id, "Applying code...") + local thinking = require("codetyper.adapters.nvim.ui.thinking") + thinking.update_stage("Applying code...") + end) + vim.notify("Applying code...", vim.log.levels.INFO) - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = string.format("Code ready. Applying in %.1f seconds...", delay / 1000), - }) - end) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = string.format("Code ready. Applying in %.1f seconds...", delay / 1000), + }) + end) - vim.defer_fn(function() - M.schedule_patch_flush() - end, delay) + vim.defer_fn(function() + M.schedule_patch_flush() + end, delay) end --- Dispatch next event from queue local function dispatch_next() - if state.paused then - return - end + if state.paused then + return + end - -- Check concurrent limit - if worker.active_count() >= state.config.max_concurrent then - return - end + -- Check concurrent limit + if worker.active_count() >= state.config.max_concurrent then + return + end - -- Get next pending event - local event = queue.dequeue() - if not event then - return - end + -- Get next pending event + local event = queue.dequeue() + if not event then + return + end - -- Check for precedence conflicts (multiple tags in same scope) - local should_skip, skip_reason = queue.check_precedence(event) - if should_skip then - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "warning", - message = string.format("Event %s skipped: %s", event.id, skip_reason or "conflict"), - }) - end) - queue.cancel(event.id) - -- Try next event - return dispatch_next() - end + -- Check for precedence conflicts (multiple tags in same scope) + local should_skip, skip_reason = queue.check_precedence(event) + if should_skip then + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "warning", + message = string.format("Event %s skipped: %s", event.id, skip_reason or "conflict"), + }) + end) + queue.cancel(event.id) + -- Try next event + return dispatch_next() + end - -- Determine which provider to use - local provider = event.worker_type or get_primary_provider() + -- Determine which provider to use + local provider = event.worker_type or get_primary_provider() - -- Log dispatch with intent/scope info - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - local intent_info = event.intent and event.intent.type or "unknown" - local scope_info = event.scope and event.scope.type ~= "file" - and string.format("%s:%s", event.scope.type, event.scope.name or "anon") - or "file" - logs.add({ - type = "info", - message = string.format( - "Dispatching %s [intent: %s, scope: %s, provider: %s]", - event.id, intent_info, scope_info, provider - ), - }) - end) + -- Log dispatch with intent/scope info + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + local intent_info = event.intent and event.intent.type or "unknown" + local scope_info = event.scope + and event.scope.type ~= "file" + and string.format("%s:%s", event.scope.type, event.scope.name or "anon") + or "file" + logs.add({ + type = "info", + message = string.format( + "Dispatching %s [intent: %s, scope: %s, provider: %s]", + event.id, + intent_info, + scope_info, + provider + ), + }) + end) - -- Show thinking indicator: top-right window (always) + in-buffer or 99-style inline - local thinking = require("codetyper.adapters.nvim.ui.thinking") - thinking.ensure_shown() + -- Show thinking indicator: top-right window (always) + in-buffer or 99-style inline + local thinking = require("codetyper.adapters.nvim.ui.thinking") + thinking.ensure_shown() - local is_inline = event.target_path and not event.target_path:match("%.codetyper%.") and (event.bufnr == vim.fn.bufnr(event.target_path)) - local thinking_placeholder = require("codetyper.core.thinking_placeholder") - if is_inline then - -- 99-style: virtual text "⠋ Thinking..." at selection (no buffer change, SEARCH/REPLACE safe) - thinking_placeholder.start_inline(event) - else - thinking_placeholder.insert(event) - end + local is_inline = event.target_path + and not event.target_path:match("%.codetyper%.") + and (event.bufnr == vim.fn.bufnr(event.target_path)) + local thinking_placeholder = require("codetyper.core.thinking_placeholder") + if is_inline then + -- 99-style: virtual text "⠋ Thinking..." at selection (no buffer change, SEARCH/REPLACE safe) + thinking_placeholder.start_inline(event) + else + thinking_placeholder.insert(event) + end - -- Create worker - worker.create(event, provider, function(result) - vim.schedule(function() - handle_worker_result(event, result) - end) - end) + -- Create worker + worker.create(event, provider, function(result) + vim.schedule(function() + handle_worker_result(event, result) + end) + end) end --- Track if we're already waiting to flush (avoid spam logs) @@ -499,279 +498,279 @@ local waiting_to_flush = false --- Schedule patch flush after delay (completion safety) --- Will keep retrying until safe to inject or no pending patches function M.schedule_patch_flush() - vim.defer_fn(function() - -- Check if there are any pending patches - local pending = patch.get_pending() - logger.info("scheduler", string.format("schedule_patch_flush: %d pending", #pending)) - if #pending == 0 then - waiting_to_flush = false - return -- Nothing to apply - end + vim.defer_fn(function() + -- Check if there are any pending patches + local pending = patch.get_pending() + logger.info("scheduler", string.format("schedule_patch_flush: %d pending", #pending)) + if #pending == 0 then + waiting_to_flush = false + return -- Nothing to apply + end - local safe, reason = M.is_safe_to_inject() - logger.info("scheduler", string.format("is_safe_to_inject=%s (%s)", tostring(safe), tostring(reason or "ok"))) - if safe then - waiting_to_flush = false - local applied, stale = patch.flush_pending_smart() - if applied > 0 or stale > 0 then - logger.info("scheduler", string.format("Patches flushed: %d applied, %d stale", applied, stale)) - end - else - -- Not safe yet (user is typing), reschedule to try again - -- Only log once when we start waiting - if not waiting_to_flush then - waiting_to_flush = true - logger.info("scheduler", "Waiting for user to finish typing before applying code...") - -- Notify user about the wait - local utils = require("codetyper.support.utils") - if reason == "visual_mode" then - utils.notify("Queue waiting: exit Visual mode to inject code", vim.log.levels.INFO) - elseif reason == "insert_mode" then - utils.notify("Queue waiting: exit Insert mode to inject code", vim.log.levels.INFO) - end - end - -- Retry after a delay - keep waiting for user to finish typing - M.schedule_patch_flush() - end - end, state.config.completion_delay_ms) + local safe, reason = M.is_safe_to_inject() + logger.info("scheduler", string.format("is_safe_to_inject=%s (%s)", tostring(safe), tostring(reason or "ok"))) + if safe then + waiting_to_flush = false + local applied, stale = patch.flush_pending_smart() + if applied > 0 or stale > 0 then + logger.info("scheduler", string.format("Patches flushed: %d applied, %d stale", applied, stale)) + end + else + -- Not safe yet (user is typing), reschedule to try again + -- Only log once when we start waiting + if not waiting_to_flush then + waiting_to_flush = true + logger.info("scheduler", "Waiting for user to finish typing before applying code...") + -- Notify user about the wait + local utils = require("codetyper.support.utils") + if reason == "visual_mode" then + utils.notify("Queue waiting: exit Visual mode to inject code", vim.log.levels.INFO) + elseif reason == "insert_mode" then + utils.notify("Queue waiting: exit Insert mode to inject code", vim.log.levels.INFO) + end + end + -- Retry after a delay - keep waiting for user to finish typing + M.schedule_patch_flush() + end + end, state.config.completion_delay_ms) end --- Main scheduler loop local function scheduler_loop() - if not state.running then - return - end + if not state.running then + return + end - dispatch_next() + dispatch_next() - -- Cleanup old items periodically - if math.random() < 0.01 then -- ~1% chance each tick - queue.cleanup(300) - patch.cleanup(300) - end + -- Cleanup old items periodically + if math.random() < 0.01 then -- ~1% chance each tick + queue.cleanup(300) + patch.cleanup(300) + end - -- Schedule next tick - state.timer = vim.defer_fn(scheduler_loop, state.poll_interval) + -- Schedule next tick + state.timer = vim.defer_fn(scheduler_loop, state.poll_interval) end --- Setup autocommands for injection timing local function setup_autocmds() - if augroup then - pcall(vim.api.nvim_del_augroup_by_id, augroup) - end + if augroup then + pcall(vim.api.nvim_del_augroup_by_id, augroup) + end - augroup = vim.api.nvim_create_augroup("CodetypeScheduler", { clear = true }) + augroup = vim.api.nvim_create_augroup("CodetypeScheduler", { clear = true }) - -- Flush patches when leaving insert mode - vim.api.nvim_create_autocmd("InsertLeave", { - group = augroup, - callback = function() - vim.defer_fn(function() - if not M.is_completion_visible() then - patch.flush_pending_smart() - end - end, state.config.completion_delay_ms) - end, - desc = "Flush pending patches on InsertLeave", - }) + -- Flush patches when leaving insert mode + vim.api.nvim_create_autocmd("InsertLeave", { + group = augroup, + callback = function() + vim.defer_fn(function() + if not M.is_completion_visible() then + patch.flush_pending_smart() + end + end, state.config.completion_delay_ms) + end, + desc = "Flush pending patches on InsertLeave", + }) - -- Flush patches when leaving visual mode - vim.api.nvim_create_autocmd("ModeChanged", { - group = augroup, - pattern = "[vV\x16]*:*", -- visual mode to any other mode - callback = function() - vim.defer_fn(function() - if not M.is_insert_mode() and not M.is_completion_visible() then - patch.flush_pending_smart() - end - end, state.config.completion_delay_ms) - end, - desc = "Flush pending patches on VisualLeave", - }) + -- Flush patches when leaving visual mode + vim.api.nvim_create_autocmd("ModeChanged", { + group = augroup, + pattern = "[vV\x16]*:*", -- visual mode to any other mode + callback = function() + vim.defer_fn(function() + if not M.is_insert_mode() and not M.is_completion_visible() then + patch.flush_pending_smart() + end + end, state.config.completion_delay_ms) + end, + desc = "Flush pending patches on VisualLeave", + }) - -- Flush patches on cursor hold - vim.api.nvim_create_autocmd("CursorHold", { - group = augroup, - callback = function() - if not M.is_insert_mode() and not M.is_completion_visible() then - patch.flush_pending_smart() - end - end, - desc = "Flush pending patches on CursorHold", - }) + -- Flush patches on cursor hold + vim.api.nvim_create_autocmd("CursorHold", { + group = augroup, + callback = function() + if not M.is_insert_mode() and not M.is_completion_visible() then + patch.flush_pending_smart() + end + end, + desc = "Flush pending patches on CursorHold", + }) - -- Cancel patches when buffer changes significantly - vim.api.nvim_create_autocmd("BufWritePre", { - group = augroup, - callback = function(ev) - -- Mark relevant patches as potentially stale - -- They'll be checked on next flush attempt - end, - desc = "Check patch staleness on save", - }) + -- Cancel patches when buffer changes significantly + vim.api.nvim_create_autocmd("BufWritePre", { + group = augroup, + callback = function(ev) + -- Mark relevant patches as potentially stale + -- They'll be checked on next flush attempt + end, + desc = "Check patch staleness on save", + }) - -- Cleanup when buffer is deleted - vim.api.nvim_create_autocmd("BufDelete", { - group = augroup, - callback = function(ev) - queue.cancel_for_buffer(ev.buf) - patch.cancel_for_buffer(ev.buf) - worker.cancel_for_event(ev.buf) - end, - desc = "Cleanup on buffer delete", - }) + -- Cleanup when buffer is deleted + vim.api.nvim_create_autocmd("BufDelete", { + group = augroup, + callback = function(ev) + queue.cancel_for_buffer(ev.buf) + patch.cancel_for_buffer(ev.buf) + worker.cancel_for_event(ev.buf) + end, + desc = "Cleanup on buffer delete", + }) - -- Stop scheduler when exiting Neovim - vim.api.nvim_create_autocmd("VimLeavePre", { - group = augroup, - callback = function() - M.stop() - end, - desc = "Stop scheduler before exiting Neovim", - }) + -- Stop scheduler when exiting Neovim + vim.api.nvim_create_autocmd("VimLeavePre", { + group = augroup, + callback = function() + M.stop() + end, + desc = "Stop scheduler before exiting Neovim", + }) end --- Start the scheduler ---@param config table|nil Configuration overrides function M.start(config) - if state.running then - return - end + if state.running then + return + end - -- Merge config - if config then - for k, v in pairs(config) do - state.config[k] = v - end - end + -- Merge config + if config then + for k, v in pairs(config) do + state.config[k] = v + end + end - -- Load config from codetyper if available - pcall(function() - local codetyper = require("codetyper") - local ct_config = codetyper.get_config() - if ct_config and ct_config.scheduler then - for k, v in pairs(ct_config.scheduler) do - state.config[k] = v - end - end - end) + -- Load config from codetyper if available + pcall(function() + local codetyper = require("codetyper") + local ct_config = codetyper.get_config() + if ct_config and ct_config.scheduler then + for k, v in pairs(ct_config.scheduler) do + state.config[k] = v + end + end + end) - if not state.config.enabled then - return - end + if not state.config.enabled then + return + end - state.running = true - state.paused = false + state.running = true + state.paused = false - -- Setup autocmds - setup_autocmds() + -- Setup autocmds + setup_autocmds() - -- Add queue listener - queue.add_listener(function(event_type, event, queue_size) - if event_type == "enqueue" and not state.paused then - -- New event - try to dispatch immediately - vim.schedule(dispatch_next) - end - end) + -- Add queue listener + queue.add_listener(function(event_type, event, queue_size) + if event_type == "enqueue" and not state.paused then + -- New event - try to dispatch immediately + vim.schedule(dispatch_next) + end + end) - -- Start main loop - scheduler_loop() + -- Start main loop + scheduler_loop() - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = "Scheduler started", - data = { - ollama_scout = state.config.ollama_scout, - escalation_threshold = state.config.escalation_threshold, - max_concurrent = state.config.max_concurrent, - }, - }) - end) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = "Scheduler started", + data = { + ollama_scout = state.config.ollama_scout, + escalation_threshold = state.config.escalation_threshold, + max_concurrent = state.config.max_concurrent, + }, + }) + end) end --- Stop the scheduler function M.stop() - state.running = false + state.running = false - if state.timer then - pcall(function() - if type(state.timer) == "userdata" and state.timer.stop then - state.timer:stop() - end - end) - state.timer = nil - end + if state.timer then + pcall(function() + if type(state.timer) == "userdata" and state.timer.stop then + state.timer:stop() + end + end) + state.timer = nil + end - if augroup then - pcall(vim.api.nvim_del_augroup_by_id, augroup) - augroup = nil - end + if augroup then + pcall(vim.api.nvim_del_augroup_by_id, augroup) + augroup = nil + end - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = "Scheduler stopped", - }) - end) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = "Scheduler stopped", + }) + end) end --- Pause the scheduler (don't process new events) function M.pause() - state.paused = true + state.paused = true end --- Resume the scheduler function M.resume() - state.paused = false - vim.schedule(dispatch_next) + state.paused = false + vim.schedule(dispatch_next) end --- Check if scheduler is running ---@return boolean function M.is_running() - return state.running + return state.running end --- Check if scheduler is paused ---@return boolean function M.is_paused() - return state.paused + return state.paused end --- Get scheduler status ---@return table function M.status() - return { - running = state.running, - paused = state.paused, - queue_stats = queue.stats(), - patch_stats = patch.stats(), - active_workers = worker.active_count(), - config = vim.deepcopy(state.config), - } + return { + running = state.running, + paused = state.paused, + queue_stats = queue.stats(), + patch_stats = patch.stats(), + active_workers = worker.active_count(), + config = vim.deepcopy(state.config), + } end --- Manually trigger dispatch function M.dispatch() - if state.running and not state.paused then - dispatch_next() - end + if state.running and not state.paused then + dispatch_next() + end end --- Force flush all pending patches (ignores completion check) function M.force_flush() - return patch.flush_pending_smart() + return patch.flush_pending_smart() end --- Update configuration ---@param config table function M.configure(config) - for k, v in pairs(config) do - state.config[k] = v - end + for k, v in pairs(config) do + state.config[k] = v + end end --- Queue a prompt for processing @@ -786,21 +785,21 @@ end --- - priority: number|nil Priority (1=high, 2=normal, 3=low) default 2 ---@return table The enqueued event function M.queue_prompt(opts) - -- Build the PromptEvent structure - local event = { - bufnr = opts.bufnr, - filepath = opts.filepath, - target_path = opts.target_path or opts.filepath, - prompt_content = opts.prompt_content, - range = opts.range, - priority = opts.priority or 2, - source = opts.source or "manual", - -- Capture buffer state for staleness detection - changedtick = vim.api.nvim_buf_get_changedtick(opts.bufnr), - } + -- Build the PromptEvent structure + local event = { + bufnr = opts.bufnr, + filepath = opts.filepath, + target_path = opts.target_path or opts.filepath, + prompt_content = opts.prompt_content, + range = opts.range, + priority = opts.priority or 2, + source = opts.source or "manual", + -- Capture buffer state for staleness detection + changedtick = vim.api.nvim_buf_get_changedtick(opts.bufnr), + } - -- Enqueue through the queue module - return queue.enqueue(event) + -- Enqueue through the queue module + return queue.enqueue(event) end return M diff --git a/lua/codetyper/core/scheduler/worker.lua b/lua/codetyper/core/scheduler/worker.lua index a6230f3..a7f2a18 100644 --- a/lua/codetyper/core/scheduler/worker.lua +++ b/lua/codetyper/core/scheduler/worker.lua @@ -34,17 +34,17 @@ local worker_counter = 0 ---@param event_id string|nil ---@param text string Status text local function notify_stage(event_id, text) - pcall(function() - local tp = require("codetyper.core.thinking_placeholder") - if event_id then - tp.update_inline_status(event_id, text) - end - end) - pcall(function() - local thinking = require("codetyper.adapters.nvim.ui.thinking") - thinking.update_stage(text) - end) - vim.notify(text, vim.log.levels.INFO) + pcall(function() + local tp = require("codetyper.core.thinking_placeholder") + if event_id then + tp.update_inline_status(event_id, text) + end + end) + pcall(function() + local thinking = require("codetyper.adapters.nvim.ui.thinking") + thinking.update_stage(text) + end) + vim.notify(text, vim.log.levels.INFO) end --- Patterns that indicate LLM needs more context (must be near start of response) @@ -55,48 +55,56 @@ local context_needed_patterns = params.context_needed_patterns ---@param response string ---@return boolean local function needs_more_context(response) - if not response then - return false - end + if not response then + return false + end - -- If response has substantial code (more than 5 lines with code-like content), don't ask for context - local lines = vim.split(response, "\n") - local code_lines = 0 - for _, line in ipairs(lines) do - -- Count lines that look like code (have programming constructs) - if line:match("[{}();=]") or line:match("function") or line:match("def ") - or line:match("class ") or line:match("return ") or line:match("import ") - or line:match("public ") or line:match("private ") or line:match("local ") then - code_lines = code_lines + 1 - end - end + -- If response has substantial code (more than 5 lines with code-like content), don't ask for context + local lines = vim.split(response, "\n") + local code_lines = 0 + for _, line in ipairs(lines) do + -- Count lines that look like code (have programming constructs) + if + line:match("[{}();=]") + or line:match("function") + or line:match("def ") + or line:match("class ") + or line:match("return ") + or line:match("import ") + or line:match("public ") + or line:match("private ") + or line:match("local ") + then + code_lines = code_lines + 1 + end + end - -- If there's substantial code, don't trigger context request - if code_lines >= 3 then - return false - end + -- If there's substantial code, don't trigger context request + if code_lines >= 3 then + return false + end - -- Check if the response STARTS with a context-needed phrase - local lower = response:lower() - for _, pattern in ipairs(context_needed_patterns) do - if lower:match(pattern) then - return true - end - end - return false + -- Check if the response STARTS with a context-needed phrase + local lower = response:lower() + for _, pattern in ipairs(context_needed_patterns) do + if lower:match(pattern) then + return true + end + end + return false end --- Check if response contains SEARCH/REPLACE blocks ---@param response string ---@return boolean local function has_search_replace_blocks(response) - if not response then - return false - end - -- Check for any of the supported SEARCH/REPLACE formats - return response:match("<<<<<<<%s*SEARCH") ~= nil - or response:match("%-%-%-%-%-%-%-?%s*SEARCH") ~= nil - or response:match("%[SEARCH%]") ~= nil + if not response then + return false + end + -- Check for any of the supported SEARCH/REPLACE formats + return response:match("<<<<<<<%s*SEARCH") ~= nil + or response:match("%-%-%-%-%-%-%-?%s*SEARCH") ~= nil + or response:match("%[SEARCH%]") ~= nil end --- Clean LLM response to extract only code @@ -107,15 +115,15 @@ end ---@param text string Raw response that may start with @thinking ... end thinking ---@return string Text with thinking block removed (or original if no block) local function strip_thinking_block(text) - if not text or text == "" then - return text or "" - end - -- Match from start: @thinking, any content, then line "end thinking"; capture everything after that - local after = text:match("^%s*@thinking[%s%S]*\nend thinking%s*\n(.*)") - if after then - return after:match("^%s*(.-)%s*$") or after - end - return text + if not text or text == "" then + return text or "" + end + -- Match from start: @thinking, any content, then line "end thinking"; capture everything after that + local after = text:match("^%s*@thinking[%s%S]*\nend thinking%s*\n(.*)") + if after then + return after:match("^%s*(.-)%s*$") or after + end + return text end --- Clean LLM response to extract only code @@ -123,109 +131,108 @@ end ---@param filetype string|nil File type for language detection ---@return string Cleaned code local function clean_response(response, filetype) - if not response then - return "" - end + if not response then + return "" + end - local cleaned = response + local cleaned = response - -- Remove @thinking ... end thinking block first (we show thinking in placeholder; inject only code) - cleaned = strip_thinking_block(cleaned) + -- Remove @thinking ... end thinking block first (we show thinking in placeholder; inject only code) + cleaned = strip_thinking_block(cleaned) - -- Remove LLM special tokens (deepseek, llama, etc.) - cleaned = cleaned:gsub("<|begin▁of▁sentence|>", "") - cleaned = cleaned:gsub("<|end▁of▁sentence|>", "") - cleaned = cleaned:gsub("<|im_start|>", "") - cleaned = cleaned:gsub("<|im_end|>", "") - cleaned = cleaned:gsub("", "") - cleaned = cleaned:gsub("", "") - cleaned = cleaned:gsub("<|endoftext|>", "") + -- Remove LLM special tokens (deepseek, llama, etc.) + cleaned = cleaned:gsub("<|begin▁of▁sentence|>", "") + cleaned = cleaned:gsub("<|end▁of▁sentence|>", "") + cleaned = cleaned:gsub("<|im_start|>", "") + cleaned = cleaned:gsub("<|im_end|>", "") + cleaned = cleaned:gsub("", "") + cleaned = cleaned:gsub("", "") + cleaned = cleaned:gsub("<|endoftext|>", "") - -- Remove the original prompt tags /@ ... @/ if they appear in output - -- Use [%s%S] to match any character including newlines (Lua's . doesn't match newlines) - cleaned = cleaned:gsub("/@[%s%S]-@/", "") + -- Remove the original prompt tags /@ ... @/ if they appear in output + -- Use [%s%S] to match any character including newlines (Lua's . doesn't match newlines) + cleaned = cleaned:gsub("/@[%s%S]-@/", "") - -- IMPORTANT: If response contains SEARCH/REPLACE blocks, preserve them! - -- Don't extract from markdown or remove "explanations" that are actually part of the format - if has_search_replace_blocks(cleaned) then - -- Just trim whitespace and return - the blocks will be parsed by search_replace module - return cleaned:match("^%s*(.-)%s*$") or cleaned - end + -- IMPORTANT: If response contains SEARCH/REPLACE blocks, preserve them! + -- Don't extract from markdown or remove "explanations" that are actually part of the format + if has_search_replace_blocks(cleaned) then + -- Just trim whitespace and return - the blocks will be parsed by search_replace module + return cleaned:match("^%s*(.-)%s*$") or cleaned + end - -- Try to extract code from markdown code blocks - -- Match ```language\n...\n``` or just ```\n...\n``` - local code_block = cleaned:match("```[%w]*\n(.-)\n```") - if not code_block then - -- Try without newline after language - code_block = cleaned:match("```[%w]*(.-)\n```") - end - if not code_block then - -- Try single line code block - code_block = cleaned:match("```(.-)```") - end + -- Try to extract code from markdown code blocks + -- Match ```language\n...\n``` or just ```\n...\n``` + local code_block = cleaned:match("```[%w]*\n(.-)\n```") + if not code_block then + -- Try without newline after language + code_block = cleaned:match("```[%w]*(.-)\n```") + end + if not code_block then + -- Try single line code block + code_block = cleaned:match("```(.-)```") + end - if code_block then - cleaned = code_block - else - -- No code block found, try to remove common prefixes/suffixes - -- Remove common apology/explanation phrases at the start - local explanation_starts = { - "^[Ii]'m sorry.-\n", - "^[Ii] apologize.-\n", - "^[Hh]ere is.-:\n", - "^[Hh]ere's.-:\n", - "^[Tt]his is.-:\n", - "^[Bb]ased on.-:\n", - "^[Ss]ure.-:\n", - "^[Oo][Kk].-:\n", - "^[Cc]ertainly.-:\n", - } - for _, pattern in ipairs(explanation_starts) do - cleaned = cleaned:gsub(pattern, "") - end + if code_block then + cleaned = code_block + else + -- No code block found, try to remove common prefixes/suffixes + -- Remove common apology/explanation phrases at the start + local explanation_starts = { + "^[Ii]'m sorry.-\n", + "^[Ii] apologize.-\n", + "^[Hh]ere is.-:\n", + "^[Hh]ere's.-:\n", + "^[Tt]his is.-:\n", + "^[Bb]ased on.-:\n", + "^[Ss]ure.-:\n", + "^[Oo][Kk].-:\n", + "^[Cc]ertainly.-:\n", + } + for _, pattern in ipairs(explanation_starts) do + cleaned = cleaned:gsub(pattern, "") + end - -- Remove trailing explanations - local explanation_ends = { - "\n[Tt]his code.-$", - "\n[Tt]his function.-$", - "\n[Tt]his is a.-$", - "\n[Ii] hope.-$", - "\n[Ll]et me know.-$", - "\n[Ff]eel free.-$", - "\n[Nn]ote:.-$", - "\n[Pp]lease replace.-$", - "\n[Pp]lease note.-$", - "\n[Yy]ou might want.-$", - "\n[Yy]ou may want.-$", - "\n[Mm]ake sure.-$", - "\n[Aa]lso,.-$", - "\n[Rr]emember.-$", - } - for _, pattern in ipairs(explanation_ends) do - cleaned = cleaned:gsub(pattern, "") - end - end + -- Remove trailing explanations + local explanation_ends = { + "\n[Tt]his code.-$", + "\n[Tt]his function.-$", + "\n[Tt]his is a.-$", + "\n[Ii] hope.-$", + "\n[Ll]et me know.-$", + "\n[Ff]eel free.-$", + "\n[Nn]ote:.-$", + "\n[Pp]lease replace.-$", + "\n[Pp]lease note.-$", + "\n[Yy]ou might want.-$", + "\n[Yy]ou may want.-$", + "\n[Mm]ake sure.-$", + "\n[Aa]lso,.-$", + "\n[Rr]emember.-$", + } + for _, pattern in ipairs(explanation_ends) do + cleaned = cleaned:gsub(pattern, "") + end + end - -- Remove any remaining markdown artifacts - cleaned = cleaned:gsub("^```[%w]*\n?", "") - cleaned = cleaned:gsub("\n?```$", "") + -- Remove any remaining markdown artifacts + cleaned = cleaned:gsub("^```[%w]*\n?", "") + cleaned = cleaned:gsub("\n?```$", "") - -- Trim whitespace - cleaned = cleaned:match("^%s*(.-)%s*$") or cleaned + -- Trim whitespace + cleaned = cleaned:match("^%s*(.-)%s*$") or cleaned - return cleaned + return cleaned end --- Active workers ---@type table local active_workers = {} - --- Generate worker ID ---@return string local function generate_id() - worker_counter = worker_counter + 1 - return string.format("worker_%d_%d", os.time(), worker_counter) + worker_counter = worker_counter + 1 + return string.format("worker_%d_%d", os.time(), worker_counter) end --- Get LLM client by type @@ -233,178 +240,183 @@ end ---@return table|nil client ---@return string|nil error local function get_client(worker_type) - local ok, client = pcall(require, "codetyper.llm." .. worker_type) - if ok and client then - return client, nil - end - return nil, "Unknown provider: " .. worker_type + local ok, client = pcall(require, "codetyper.llm." .. worker_type) + if ok and client then + return client, nil + end + return nil, "Unknown provider: " .. worker_type end --- Format attached files for inclusion in prompt ---@param attached_files table[]|nil ---@return string local function format_attached_files(attached_files) - if not attached_files or #attached_files == 0 then - return "" - end + if not attached_files or #attached_files == 0 then + return "" + end - local parts = { "\n\n--- Referenced Files ---" } - for _, file in ipairs(attached_files) do - local ext = vim.fn.fnamemodify(file.path, ":e") - table.insert(parts, string.format( - "\n\nFile: %s\n```%s\n%s\n```", - file.path, - ext, - file.content:sub(1, 3000) -- Limit each file to 3000 chars - )) - end + local parts = { "\n\n--- Referenced Files ---" } + for _, file in ipairs(attached_files) do + local ext = vim.fn.fnamemodify(file.path, ":e") + table.insert( + parts, + string.format( + "\n\nFile: %s\n```%s\n%s\n```", + file.path, + ext, + file.content:sub(1, 3000) -- Limit each file to 3000 chars + ) + ) + end - return table.concat(parts, "") + return table.concat(parts, "") end --- Get coder companion file path for a target file ---@param target_path string Target file path ---@return string|nil Coder file path if exists local function get_coder_companion_path(target_path) - if not target_path or target_path == "" then - return nil - end + if not target_path or target_path == "" then + return nil + end - -- Skip if target is already a coder file - if target_path:match("%.codetyper%.") then - return nil - end + -- Skip if target is already a coder file + if target_path:match("%.codetyper%.") then + return nil + end - local dir = vim.fn.fnamemodify(target_path, ":h") - local name = vim.fn.fnamemodify(target_path, ":t:r") -- filename without extension - local ext = vim.fn.fnamemodify(target_path, ":e") + local dir = vim.fn.fnamemodify(target_path, ":h") + local name = vim.fn.fnamemodify(target_path, ":t:r") -- filename without extension + local ext = vim.fn.fnamemodify(target_path, ":e") - local coder_path = dir .. "/" .. name .. ".codetyper/" .. ext - if vim.fn.filereadable(coder_path) == 1 then - return coder_path - end + local coder_path = dir .. "/" .. name .. ".codetyper/" .. ext + if vim.fn.filereadable(coder_path) == 1 then + return coder_path + end - return nil + return nil end --- Read and format coder companion context (business logic, pseudo-code) ---@param target_path string Target file path ---@return string Formatted coder context local function get_coder_context(target_path) - local coder_path = get_coder_companion_path(target_path) - if not coder_path then - return "" - end + local coder_path = get_coder_companion_path(target_path) + if not coder_path then + return "" + end - local ok, lines = pcall(function() - return vim.fn.readfile(coder_path) - end) + local ok, lines = pcall(function() + return vim.fn.readfile(coder_path) + end) - if not ok or not lines or #lines == 0 then - return "" - end + if not ok or not lines or #lines == 0 then + return "" + end - local content = table.concat(lines, "\n") + local content = table.concat(lines, "\n") - -- Skip if only template comments (no actual content) - local stripped = content:gsub("^%s*", ""):gsub("%s*$", "") - if stripped == "" then - return "" - end + -- Skip if only template comments (no actual content) + local stripped = content:gsub("^%s*", ""):gsub("%s*$", "") + if stripped == "" then + return "" + end - -- Check if there's meaningful content (not just template) - local has_content = false - for _, line in ipairs(lines) do - -- Skip comment lines that are part of the template - local trimmed = line:gsub("^%s*", "") - if not trimmed:match("^[%-#/]+%s*Coder companion") - and not trimmed:match("^[%-#/]+%s*Use /@ @/") - and not trimmed:match("^[%-#/]+%s*Example:") - and not trimmed:match("^ 0 then - table.insert(symbol_list, symbol .. " (in " .. files[1] .. ")") - end - end - if #symbol_list > 0 then - table.insert(parts, "Relevant symbols: " .. table.concat(symbol_list, ", ")) - end - end + -- Relevant symbols + if indexed_context.relevant_symbols then + local symbol_list = {} + for symbol, files in pairs(indexed_context.relevant_symbols) do + if #files > 0 then + table.insert(symbol_list, symbol .. " (in " .. files[1] .. ")") + end + end + if #symbol_list > 0 then + table.insert(parts, "Relevant symbols: " .. table.concat(symbol_list, ", ")) + end + end - -- Learned patterns - if indexed_context.patterns and #indexed_context.patterns > 0 then - local pattern_list = {} - for i, p in ipairs(indexed_context.patterns) do - if i <= 3 then - table.insert(pattern_list, p.content or "") - end - end - if #pattern_list > 0 then - table.insert(parts, "Project conventions: " .. table.concat(pattern_list, "; ")) - end - end + -- Learned patterns + if indexed_context.patterns and #indexed_context.patterns > 0 then + local pattern_list = {} + for i, p in ipairs(indexed_context.patterns) do + if i <= 3 then + table.insert(pattern_list, p.content or "") + end + end + if #pattern_list > 0 then + table.insert(parts, "Project conventions: " .. table.concat(pattern_list, "; ")) + end + end - if #parts == 0 then - return "" - end + if #parts == 0 then + return "" + end - return "\n\n--- Project Context ---\n" .. table.concat(parts, "\n") + return "\n\n--- Project Context ---\n" .. table.concat(parts, "\n") end --- Check if this is an inline prompt (tags in target file, not a coder file) ---@param event table ---@return boolean local function is_inline_prompt(event) - -- Inline prompts have a range with start_line/end_line from tag detection - -- and the source file is the same as target (not a .codetyper/ file) - if not event.range or not event.range.start_line then - return false - end - -- Check if source path (if any) equals target, or if target has no .codetyper/ in it - local target = event.target_path or "" - if target:match("%.codetyper%.") then - return false - end - return true + -- Inline prompts have a range with start_line/end_line from tag detection + -- and the source file is the same as target (not a .codetyper/ file) + if not event.range or not event.range.start_line then + return false + end + -- Check if source path (if any) equals target, or if target has no .codetyper/ in it + local target = event.target_path or "" + if target:match("%.codetyper%.") then + return false + end + return true end --- Build file content with marked region for inline prompts @@ -414,19 +426,19 @@ end ---@param prompt_content string The prompt inside the tags ---@return string local function build_marked_file_content(lines, start_line, end_line, prompt_content) - local result = {} - for i, line in ipairs(lines) do - if i == start_line then - -- Mark the start of the region to be replaced - table.insert(result, ">>> REPLACE THIS REGION (lines " .. start_line .. "-" .. end_line .. ") <<<") - table.insert(result, "--- User request: " .. prompt_content:gsub("\n", " "):sub(1, 100) .. " ---") - end - table.insert(result, line) - if i == end_line then - table.insert(result, ">>> END OF REGION TO REPLACE <<<") - end - end - return table.concat(result, "\n") + local result = {} + for i, line in ipairs(lines) do + if i == start_line then + -- Mark the start of the region to be replaced + table.insert(result, ">>> REPLACE THIS REGION (lines " .. start_line .. "-" .. end_line .. ") <<<") + table.insert(result, "--- User request: " .. prompt_content:gsub("\n", " "):sub(1, 100) .. " ---") + end + table.insert(result, line) + if i == end_line then + table.insert(result, ">>> END OF REGION TO REPLACE <<<") + end + end + return table.concat(result, "\n") end --- Build prompt for code generation @@ -434,119 +446,120 @@ end ---@return string prompt ---@return table context local function build_prompt(event) - local intent_mod = require("codetyper.core.intent") - local eid = event and event.id + local intent_mod = require("codetyper.core.intent") + local eid = event and event.id - notify_stage(eid, "Reading file...") + notify_stage(eid, "Reading file...") - local target_content = "" - local target_lines = {} - if event.target_path then - local ok, lines = pcall(function() - return vim.fn.readfile(event.target_path) - end) - if ok and lines then - target_lines = lines - target_content = table.concat(lines, "\n") - end - end + local target_content = "" + local target_lines = {} + if event.target_path then + local ok, lines = pcall(function() + return vim.fn.readfile(event.target_path) + end) + if ok and lines then + target_lines = lines + target_content = table.concat(lines, "\n") + end + end - local filetype = vim.fn.fnamemodify(event.target_path or "", ":e") + local filetype = vim.fn.fnamemodify(event.target_path or "", ":e") - notify_stage(eid, "Searching index...") + notify_stage(eid, "Searching index...") - local indexed_context = nil - local indexed_content = "" - pcall(function() - local indexer = require("codetyper.features.indexer") - indexed_context = indexer.get_context_for({ - file = event.target_path, - intent = event.intent, - prompt = event.prompt_content, - scope = event.scope_text, - }) - indexed_content = format_indexed_context(indexed_context) - end) + local indexed_context = nil + local indexed_content = "" + pcall(function() + local indexer = require("codetyper.features.indexer") + indexed_context = indexer.get_context_for({ + file = event.target_path, + intent = event.intent, + prompt = event.prompt_content, + scope = event.scope_text, + }) + indexed_content = format_indexed_context(indexed_context) + end) - local attached_content = format_attached_files(event.attached_files) + local attached_content = format_attached_files(event.attached_files) - notify_stage(eid, "Gathering context...") + notify_stage(eid, "Gathering context...") - local coder_context = get_coder_context(event.target_path) + local coder_context = get_coder_context(event.target_path) - notify_stage(eid, "Recalling patterns...") + notify_stage(eid, "Recalling patterns...") - local brain_context = "" - pcall(function() - local brain = require("codetyper.core.memory") - if brain.is_initialized() then - local query_text = event.prompt_content or "" - if event.scope and event.scope.name then - query_text = event.scope.name .. " " .. query_text - end + local brain_context = "" + pcall(function() + local brain = require("codetyper.core.memory") + if brain.is_initialized() then + local query_text = event.prompt_content or "" + if event.scope and event.scope.name then + query_text = event.scope.name .. " " .. query_text + end - local result = brain.query({ - query = query_text, - file = event.target_path, - max_results = 5, - types = { "pattern", "correction", "convention" }, - }) + local result = brain.query({ + query = query_text, + file = event.target_path, + max_results = 5, + types = { "pattern", "correction", "convention" }, + }) - if result and result.nodes and #result.nodes > 0 then - local memories = { "\n\n--- Learned Patterns & Conventions ---" } - for _, node in ipairs(result.nodes) do - if node.c then - local summary = node.c.s or "" - local detail = node.c.d or "" - if summary ~= "" then - table.insert(memories, "• " .. summary) - if detail ~= "" and #detail < 200 then - table.insert(memories, " " .. detail) - end - end - end - end - if #memories > 1 then - brain_context = table.concat(memories, "\n") - end - end - end - end) + if result and result.nodes and #result.nodes > 0 then + local memories = { "\n\n--- Learned Patterns & Conventions ---" } + for _, node in ipairs(result.nodes) do + if node.c then + local summary = node.c.s or "" + local detail = node.c.d or "" + if summary ~= "" then + table.insert(memories, "• " .. summary) + if detail ~= "" and #detail < 200 then + table.insert(memories, " " .. detail) + end + end + end + end + if #memories > 1 then + brain_context = table.concat(memories, "\n") + end + end + end + end) - notify_stage(eid, "Building prompt...") + notify_stage(eid, "Building prompt...") - -- Include project tree context for whole-file selections - local project_context = "" - if event.is_whole_file and event.project_context then - project_context = "\n\n--- Project Structure ---\n" .. event.project_context - end + -- Include project tree context for whole-file selections + local project_context = "" + if event.is_whole_file and event.project_context then + project_context = "\n\n--- Project Structure ---\n" .. event.project_context + end - -- Combine all context sources: brain memories first, then coder context, attached files, indexed, project - local extra_context = brain_context .. coder_context .. attached_content .. indexed_content .. project_context + -- Combine all context sources: brain memories first, then coder context, attached files, indexed, project + local extra_context = brain_context .. coder_context .. attached_content .. indexed_content .. project_context - -- Build context with scope information - local context = { - target_path = event.target_path, - target_content = target_content, - filetype = filetype, - scope = event.scope, - scope_text = event.scope_text, - scope_range = event.scope_range, - intent = event.intent, - attached_files = event.attached_files, - indexed_context = indexed_context, - } + -- Build context with scope information + local context = { + target_path = event.target_path, + target_content = target_content, + filetype = filetype, + scope = event.scope, + scope_text = event.scope_text, + scope_range = event.scope_range, + intent = event.intent, + attached_files = event.attached_files, + indexed_context = indexed_context, + } - -- Build the actual prompt based on intent and scope - local system_prompt = "" - local user_prompt = event.prompt_content + -- Build the actual prompt based on intent and scope + local system_prompt = "" + local user_prompt = event.prompt_content - if event.intent then - system_prompt = intent_mod.get_prompt_modifier(event.intent) - end + if event.intent then + system_prompt = intent_mod.get_prompt_modifier(event.intent) + end - -- Ask the LLM to show its thinking (so we can display it in the buffer) - system_prompt = system_prompt .. [[ + -- Ask the LLM to show its thinking (so we can display it in the buffer) + system_prompt = system_prompt + .. [[ OUTPUT FORMAT - Show your reasoning first: 1. Start with exactly this line: @thinking @@ -561,17 +574,17 @@ end thinking ]] - -- SPECIAL HANDLING: Inline prompts with /@ ... @/ tags - -- Output only the code that replaces the tagged region (no SEARCH/REPLACE markers) - if is_inline_prompt(event) and event.range and event.range.start_line then - local start_line = event.range.start_line - local end_line = event.range.end_line or start_line + -- SPECIAL HANDLING: Inline prompts with /@ ... @/ tags + -- Output only the code that replaces the tagged region (no SEARCH/REPLACE markers) + if is_inline_prompt(event) and event.range and event.range.start_line then + local start_line = event.range.start_line + local end_line = event.range.end_line or start_line - -- Full file content for context - local file_content = table.concat(target_lines, "\n"):sub(1, 12000) + -- Full file content for context + local file_content = table.concat(target_lines, "\n"):sub(1, 12000) - user_prompt = string.format( - [[You are editing a %s file: %s + user_prompt = string.format( + [[You are editing a %s file: %s TASK: %s %s @@ -582,32 +595,32 @@ FULL FILE: The user has selected lines %d-%d. Your output will REPLACE those lines exactly. Output ONLY the new code for that region (no markers, no explanations, no code fences). Your response replaces the selection. Preserve indentation.]], - filetype, - vim.fn.fnamemodify(event.target_path or "", ":t"), - event.prompt_content, - extra_context, - filetype, - file_content, - start_line, - end_line - ) + filetype, + vim.fn.fnamemodify(event.target_path or "", ":t"), + event.prompt_content, + extra_context, + filetype, + file_content, + start_line, + end_line + ) - context.system_prompt = system_prompt - context.formatted_prompt = user_prompt - context.is_inline_prompt = true + context.system_prompt = system_prompt + context.formatted_prompt = user_prompt + context.is_inline_prompt = true - return user_prompt, context - end + return user_prompt, context + end - -- If we have a scope (function/method), include it in the prompt - if event.scope_text and event.scope and event.scope.type ~= "file" then - local scope_type = event.scope.type - local scope_name = event.scope.name or "anonymous" + -- If we have a scope (function/method), include it in the prompt + if event.scope_text and event.scope and event.scope.type ~= "file" then + local scope_type = event.scope.type + local scope_name = event.scope.name or "anonymous" - -- Special handling for "complete" intent - fill in the function body - if event.intent and event.intent.type == "complete" then - user_prompt = string.format( - [[Complete this %s. Fill in the implementation based on the description. + -- Special handling for "complete" intent - fill in the function body + if event.intent and event.intent.type == "complete" then + user_prompt = string.format( + [[Complete this %s. Fill in the implementation based on the description. IMPORTANT: - Keep the EXACT same function signature (name, parameters, return type) @@ -623,23 +636,24 @@ Current %s (incomplete): What it should do: %s Return ONLY the complete %s with implementation. No explanations, no duplicates.]], - scope_type, - scope_type, - filetype, - event.scope_text, - extra_context, - event.prompt_content, - scope_type - ) - -- Remind the LLM not to repeat the original file content; ask for only the new/updated code or a unified diff - user_prompt = user_prompt .. [[ + scope_type, + scope_type, + filetype, + event.scope_text, + extra_context, + event.prompt_content, + scope_type + ) + -- Remind the LLM not to repeat the original file content; ask for only the new/updated code or a unified diff + user_prompt = user_prompt + .. [[ IMPORTANT: Do NOT repeat the existing code provided above. Return ONLY the new or modified code (the updated function body). If you modify the file, prefer outputting a unified diff patch using standard diff headers (--- a/ / +++ b/ and @@ hunks). No explanations, no markdown, no code fences. ]] - -- For other replacement intents, provide the full scope to transform - elseif event.intent and intent_mod.is_replacement(event.intent) then - user_prompt = string.format( - [[Here is a %s named "%s" in a %s file: + -- For other replacement intents, provide the full scope to transform + elseif event.intent and intent_mod.is_replacement(event.intent) then + user_prompt = string.format( + [[Here is a %s named "%s" in a %s file: ```%s %s @@ -648,19 +662,19 @@ IMPORTANT: Do NOT repeat the existing code provided above. Return ONLY the new o User request: %s Return the complete transformed %s. Output only code, no explanations.]], - scope_type, - scope_name, - filetype, - filetype, - event.scope_text, - extra_context, - event.prompt_content, - scope_type - ) - else - -- For insertion intents, provide context - user_prompt = string.format( - [[Context - this code is inside a %s named "%s": + scope_type, + scope_name, + filetype, + filetype, + event.scope_text, + extra_context, + event.prompt_content, + scope_type + ) + else + -- For insertion intents, provide context + user_prompt = string.format( + [[Context - this code is inside a %s named "%s": ```%s %s @@ -669,30 +683,32 @@ Return the complete transformed %s. Output only code, no explanations.]], User request: %s Output only the code to insert, no explanations.]], - scope_type, - scope_name, - filetype, - event.scope_text, - extra_context, - event.prompt_content - ) + scope_type, + scope_name, + filetype, + event.scope_text, + extra_context, + event.prompt_content + ) - -- Remind the LLM not to repeat the full file content; ask for only the new/modified code or unified diff - user_prompt = user_prompt .. [[ + -- Remind the LLM not to repeat the full file content; ask for only the new/modified code or unified diff + user_prompt = user_prompt + .. [[ IMPORTANT: Do NOT repeat the full file content shown above. Return ONLY the new or modified code required to satisfy the request. If you modify the file, prefer outputting a unified diff patch using standard diff headers (--- a/ / +++ b/ and @@ hunks). No explanations, no markdown, no code fences. ]] - -- Remind the LLM not to repeat the original file content; ask for only the inserted code or a unified diff - user_prompt = user_prompt .. [[ + -- Remind the LLM not to repeat the original file content; ask for only the inserted code or a unified diff + user_prompt = user_prompt + .. [[ IMPORTANT: Do NOT repeat the surrounding code provided above. Return ONLY the code to insert (the new snippet). If you modify multiple parts of the file, prefer outputting a unified diff patch using standard diff headers (--- a/ / +++ b/ and @@ hunks). No explanations, no markdown, no code fences. ]] - end - else - -- No scope resolved, use full file context - user_prompt = string.format( - [[File: %s (%s) + end + else + -- No scope resolved, use full file context + user_prompt = string.format( + [[File: %s (%s) ```%s %s @@ -701,19 +717,19 @@ IMPORTANT: Do NOT repeat the surrounding code provided above. Return ONLY the co User request: %s Output only code, no explanations.]], - vim.fn.fnamemodify(event.target_path or "", ":t"), - filetype, - filetype, - target_content:sub(1, 4000), -- Limit context size - extra_context, - event.prompt_content - ) - end + vim.fn.fnamemodify(event.target_path or "", ":t"), + filetype, + filetype, + target_content:sub(1, 4000), -- Limit context size + extra_context, + event.prompt_content + ) + end - context.system_prompt = system_prompt - context.formatted_prompt = user_prompt + context.system_prompt = system_prompt + context.formatted_prompt = user_prompt - return user_prompt, context + return user_prompt, context end --- Create and start a worker @@ -722,101 +738,101 @@ end ---@param callback function(result: WorkerResult) ---@return Worker function M.create(event, worker_type, callback) - local worker = { - id = generate_id(), - event = event, - worker_type = worker_type, - status = "pending", - start_time = os.clock(), - callback = callback, - } + local worker = { + id = generate_id(), + event = event, + worker_type = worker_type, + status = "pending", + start_time = os.clock(), + callback = callback, + } - active_workers[worker.id] = worker + active_workers[worker.id] = worker - -- Log worker creation - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "worker", - message = string.format("Worker %s started (%s)", worker.id, worker_type), - data = { - worker_id = worker.id, - event_id = event.id, - provider = worker_type, - }, - }) - end) + -- Log worker creation + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "worker", + message = string.format("Worker %s started (%s)", worker.id, worker_type), + data = { + worker_id = worker.id, + event_id = event.id, + provider = worker_type, + }, + }) + end) - -- Start the work - M.start(worker) + -- Start the work + M.start(worker) - return worker + return worker end --- Start worker execution ---@param worker Worker function M.start(worker) - worker.status = "running" - local eid = worker.event and worker.event.id + worker.status = "running" + local eid = worker.event and worker.event.id - notify_stage(eid, "Reading context...") + notify_stage(eid, "Reading context...") - local prompt, context = build_prompt(worker.event) + local prompt, context = build_prompt(worker.event) - -- Check if smart selection is enabled (memory-based provider selection) - local use_smart_selection = false - pcall(function() - local codetyper = require("codetyper") - local config = codetyper.get_config() - use_smart_selection = config.llm.smart_selection ~= false -- Default to true - end) + -- Check if smart selection is enabled (memory-based provider selection) + local use_smart_selection = false + pcall(function() + local codetyper = require("codetyper") + local config = codetyper.get_config() + use_smart_selection = config.llm.smart_selection ~= false -- Default to true + end) - local provider_label = worker.worker_type or "LLM" - notify_stage(eid, "Sending to " .. provider_label .. "...") + local provider_label = worker.worker_type or "LLM" + notify_stage(eid, "Sending to " .. provider_label .. "...") - -- Define the response handler - local function handle_response(response, err, usage_or_metadata) - if worker.status ~= "running" then - return -- Already cancelled - end + -- Define the response handler + local function handle_response(response, err, usage_or_metadata) + if worker.status ~= "running" then + return -- Already cancelled + end - notify_stage(eid, "Processing response...") + notify_stage(eid, "Processing response...") - -- Extract usage from metadata if smart_generate was used - local usage = usage_or_metadata - if type(usage_or_metadata) == "table" and usage_or_metadata.provider then - usage = nil - worker.worker_type = usage_or_metadata.provider - if usage_or_metadata.pondered then - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = string.format( - "Pondering: %s (agreement: %.0f%%)", - usage_or_metadata.corrected and "corrected" or "validated", - (usage_or_metadata.agreement or 1) * 100 - ), - }) - end) - end - end + -- Extract usage from metadata if smart_generate was used + local usage = usage_or_metadata + if type(usage_or_metadata) == "table" and usage_or_metadata.provider then + usage = nil + worker.worker_type = usage_or_metadata.provider + if usage_or_metadata.pondered then + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = string.format( + "Pondering: %s (agreement: %.0f%%)", + usage_or_metadata.corrected and "corrected" or "validated", + (usage_or_metadata.agreement or 1) * 100 + ), + }) + end) + end + end - M.complete(worker, response, err, usage) - end + M.complete(worker, response, err, usage) + end - -- Use smart selection or direct client - if use_smart_selection then - local llm = require("codetyper.core.llm") - llm.smart_generate(prompt, context, handle_response) - else - local client, client_err = get_client(worker.worker_type) - if not client then - M.complete(worker, nil, client_err) - return - end - client.generate(prompt, context, handle_response) - end + -- Use smart selection or direct client + if use_smart_selection then + local llm = require("codetyper.core.llm") + llm.smart_generate(prompt, context, handle_response) + else + local client, client_err = get_client(worker.worker_type) + if not client then + M.complete(worker, nil, client_err) + return + end + client.generate(prompt, context, handle_response) + end end --- Complete worker execution @@ -825,174 +841,176 @@ end ---@param error string|nil ---@param usage table|nil function M.complete(worker, response, error, usage) - local duration = os.clock() - worker.start_time + local duration = os.clock() - worker.start_time - if error then - worker.status = "failed" - active_workers[worker.id] = nil + if error then + worker.status = "failed" + active_workers[worker.id] = nil - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "error", - message = string.format("Worker %s failed: %s", worker.id, error), - }) - end) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "error", + message = string.format("Worker %s failed: %s", worker.id, error), + }) + end) - worker.callback({ - success = false, - response = nil, - error = error, - confidence = 0, - confidence_breakdown = {}, - duration = duration, - worker_type = worker.worker_type, - usage = usage, - }) - return - end + worker.callback({ + success = false, + response = nil, + error = error, + confidence = 0, + confidence_breakdown = {}, + duration = duration, + worker_type = worker.worker_type, + usage = usage, + }) + return + end - -- Check if LLM needs more context - if needs_more_context(response) then - worker.status = "needs_context" - active_workers[worker.id] = nil + -- Check if LLM needs more context + if needs_more_context(response) then + worker.status = "needs_context" + active_workers[worker.id] = nil - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = string.format("Worker %s: LLM needs more context", worker.id), - }) - end) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = string.format("Worker %s: LLM needs more context", worker.id), + }) + end) - worker.callback({ - success = false, - response = response, - error = nil, - needs_context = true, - original_event = worker.event, - confidence = 0, - confidence_breakdown = {}, - duration = duration, - worker_type = worker.worker_type, - usage = usage, - }) - return - end + worker.callback({ + success = false, + response = response, + error = nil, + needs_context = true, + original_event = worker.event, + confidence = 0, + confidence_breakdown = {}, + duration = duration, + worker_type = worker.worker_type, + usage = usage, + }) + return + end - -- Log the full raw LLM response (for debugging) - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "response", - message = "--- LLM Response ---", - data = { - raw_response = response, - }, - }) - end) + -- Log the full raw LLM response (for debugging) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "response", + message = "--- LLM Response ---", + data = { + raw_response = response, + }, + }) + end) - -- Clean the response (remove markdown, explanations, etc.) - local filetype = vim.fn.fnamemodify(worker.event.target_path or "", ":e") - local cleaned_response = clean_response(response, filetype) + -- Clean the response (remove markdown, explanations, etc.) + local filetype = vim.fn.fnamemodify(worker.event.target_path or "", ":e") + local cleaned_response = clean_response(response, filetype) - -- Score confidence on cleaned response - local conf_score, breakdown = confidence.score(cleaned_response, worker.event.prompt_content) + -- Score confidence on cleaned response + local conf_score, breakdown = confidence.score(cleaned_response, worker.event.prompt_content) - worker.status = "completed" - active_workers[worker.id] = nil + worker.status = "completed" + active_workers[worker.id] = nil - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "success", - message = string.format( - "Worker %s completed (%.2fs, confidence: %.2f - %s)", - worker.id, duration, conf_score, confidence.level_name(conf_score) - ), - data = { - confidence_breakdown = confidence.format_breakdown(breakdown), - usage = usage, - }, - }) - end) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "success", + message = string.format( + "Worker %s completed (%.2fs, confidence: %.2f - %s)", + worker.id, + duration, + conf_score, + confidence.level_name(conf_score) + ), + data = { + confidence_breakdown = confidence.format_breakdown(breakdown), + usage = usage, + }, + }) + end) - worker.callback({ - success = true, - response = cleaned_response, - error = nil, - confidence = conf_score, - confidence_breakdown = breakdown, - duration = duration, - worker_type = worker.worker_type, - usage = usage, - }) + worker.callback({ + success = true, + response = cleaned_response, + error = nil, + confidence = conf_score, + confidence_breakdown = breakdown, + duration = duration, + worker_type = worker.worker_type, + usage = usage, + }) end --- Cancel a worker ---@param worker_id string ---@return boolean function M.cancel(worker_id) - local worker = active_workers[worker_id] - if not worker then - return false - end + local worker = active_workers[worker_id] + if not worker then + return false + end - worker.status = "cancelled" - active_workers[worker_id] = nil + worker.status = "cancelled" + active_workers[worker_id] = nil - pcall(function() - local logs = require("codetyper.adapters.nvim.ui.logs") - logs.add({ - type = "info", - message = string.format("Worker %s cancelled", worker_id), - }) - end) + pcall(function() + local logs = require("codetyper.adapters.nvim.ui.logs") + logs.add({ + type = "info", + message = string.format("Worker %s cancelled", worker_id), + }) + end) - return true + return true end --- Get active worker count ---@return number function M.active_count() - local count = 0 - for _ in pairs(active_workers) do - count = count + 1 - end - return count + local count = 0 + for _ in pairs(active_workers) do + count = count + 1 + end + return count end --- Get all active workers ---@return Worker[] function M.get_active() - local workers = {} - for _, worker in pairs(active_workers) do - table.insert(workers, worker) - end - return workers + local workers = {} + for _, worker in pairs(active_workers) do + table.insert(workers, worker) + end + return workers end --- Check if worker exists and is running ---@param worker_id string ---@return boolean function M.is_running(worker_id) - local worker = active_workers[worker_id] - return worker ~= nil and worker.status == "running" + local worker = active_workers[worker_id] + return worker ~= nil and worker.status == "running" end --- Cancel all workers for an event ---@param event_id string ---@return number cancelled_count function M.cancel_for_event(event_id) - local cancelled = 0 - for id, worker in pairs(active_workers) do - if worker.event.id == event_id then - M.cancel(id) - cancelled = cancelled + 1 - end - end - return cancelled + local cancelled = 0 + for id, worker in pairs(active_workers) do + if worker.event.id == event_id then + M.cancel(id) + cancelled = cancelled + 1 + end + end + return cancelled end - return M diff --git a/lua/codetyper/core/scope/init.lua b/lua/codetyper/core/scope/init.lua index 58291e2..60a3819 100644 --- a/lua/codetyper/core/scope/init.lua +++ b/lua/codetyper/core/scope/init.lua @@ -23,43 +23,43 @@ local block_nodes = params.block_nodes ---@param bufnr number ---@return boolean function M.has_treesitter(bufnr) - -- Try to get the language for this buffer - local lang = nil + -- Try to get the language for this buffer + local lang = nil - -- Method 1: Use vim.treesitter (Neovim 0.9+) - if vim.treesitter and vim.treesitter.language then - local ft = vim.bo[bufnr].filetype - if vim.treesitter.language.get_lang then - lang = vim.treesitter.language.get_lang(ft) - else - lang = ft - end - end + -- Method 1: Use vim.treesitter (Neovim 0.9+) + if vim.treesitter and vim.treesitter.language then + local ft = vim.bo[bufnr].filetype + if vim.treesitter.language.get_lang then + lang = vim.treesitter.language.get_lang(ft) + else + lang = ft + end + end - -- Method 2: Try nvim-treesitter parsers module - if not lang then - local ok, parsers = pcall(require, "nvim-treesitter.parsers") - if ok and parsers then - if parsers.get_buf_lang then - lang = parsers.get_buf_lang(bufnr) - elseif parsers.ft_to_lang then - lang = parsers.ft_to_lang(vim.bo[bufnr].filetype) - end - end - end + -- Method 2: Try nvim-treesitter parsers module + if not lang then + local ok, parsers = pcall(require, "nvim-treesitter.parsers") + if ok and parsers then + if parsers.get_buf_lang then + lang = parsers.get_buf_lang(bufnr) + elseif parsers.ft_to_lang then + lang = parsers.ft_to_lang(vim.bo[bufnr].filetype) + end + end + end - -- Fallback to filetype - if not lang then - lang = vim.bo[bufnr].filetype - end + -- Fallback to filetype + if not lang then + lang = vim.bo[bufnr].filetype + end - if not lang or lang == "" then - return false - end + if not lang or lang == "" then + return false + end - -- Check if parser is available - local has_parser = pcall(vim.treesitter.get_parser, bufnr, lang) - return has_parser + -- Check if parser is available + local has_parser = pcall(vim.treesitter.get_parser, bufnr, lang) + return has_parser end --- Get Tree-sitter node at position @@ -68,30 +68,30 @@ end ---@param col number 0-indexed ---@return TSNode|nil local function get_node_at_pos(bufnr, row, col) - local ok, ts_utils = pcall(require, "nvim-treesitter.ts_utils") - if not ok then - return nil - end + local ok, ts_utils = pcall(require, "nvim-treesitter.ts_utils") + if not ok then + return nil + end - -- Try to get the node at the cursor position - local node = ts_utils.get_node_at_cursor() - if node then - return node - end + -- Try to get the node at the cursor position + local node = ts_utils.get_node_at_cursor() + if node then + return node + end - -- Fallback: get root and find node - local parser = vim.treesitter.get_parser(bufnr) - if not parser then - return nil - end + -- Fallback: get root and find node + local parser = vim.treesitter.get_parser(bufnr) + if not parser then + return nil + end - local tree = parser:parse()[1] - if not tree then - return nil - end + local tree = parser:parse()[1] + if not tree then + return nil + end - local root = tree:root() - return root:named_descendant_for_range(row, col, row, col) + local root = tree:root() + return root:named_descendant_for_range(row, col, row, col) end --- Find enclosing scope node of specific types @@ -99,15 +99,15 @@ end ---@param node_types table ---@return TSNode|nil, string|nil scope_type local function find_enclosing_scope(node, node_types) - local current = node - while current do - local node_type = current:type() - if node_types[node_type] then - return current, node_types[node_type] - end - current = current:parent() - end - return nil, nil + local current = node + while current do + local node_type = current:type() + if node_types[node_type] then + return current, node_types[node_type] + end + current = current:parent() + end + return nil, nil end --- Extract function/method name from node @@ -115,20 +115,20 @@ end ---@param bufnr number ---@return string|nil local function get_scope_name(node, bufnr) - -- Try to find name child node - local name_node = node:field("name")[1] - if name_node then - return vim.treesitter.get_node_text(name_node, bufnr) - end + -- Try to find name child node + local name_node = node:field("name")[1] + if name_node then + return vim.treesitter.get_node_text(name_node, bufnr) + end - -- Try identifier child - for child in node:iter_children() do - if child:type() == "identifier" or child:type() == "property_identifier" then - return vim.treesitter.get_node_text(child, bufnr) - end - end + -- Try identifier child + for child in node:iter_children() do + if child:type() == "identifier" or child:type() == "property_identifier" then + return vim.treesitter.get_node_text(child, bufnr) + end + end - return nil + return nil end --- Resolve scope at position using Tree-sitter @@ -137,74 +137,74 @@ end ---@param col number 1-indexed column number ---@return ScopeInfo function M.resolve_scope(bufnr, row, col) - -- Default to file scope - local default_scope = { - type = "file", - node_type = "file", - range = { - start_row = 1, - start_col = 0, - end_row = vim.api.nvim_buf_line_count(bufnr), - end_col = 0, - }, - text = table.concat(vim.api.nvim_buf_get_lines(bufnr, 0, -1, false), "\n"), - name = vim.fn.fnamemodify(vim.api.nvim_buf_get_name(bufnr), ":t"), - } + -- Default to file scope + local default_scope = { + type = "file", + node_type = "file", + range = { + start_row = 1, + start_col = 0, + end_row = vim.api.nvim_buf_line_count(bufnr), + end_col = 0, + }, + text = table.concat(vim.api.nvim_buf_get_lines(bufnr, 0, -1, false), "\n"), + name = vim.fn.fnamemodify(vim.api.nvim_buf_get_name(bufnr), ":t"), + } - -- Check if Tree-sitter is available - if not M.has_treesitter(bufnr) then - -- Fall back to heuristic-based scope resolution - return M.resolve_scope_heuristic(bufnr, row, col) or default_scope - end + -- Check if Tree-sitter is available + if not M.has_treesitter(bufnr) then + -- Fall back to heuristic-based scope resolution + return M.resolve_scope_heuristic(bufnr, row, col) or default_scope + end - -- Convert to 0-indexed for Tree-sitter - local ts_row = row - 1 - local ts_col = col - 1 + -- Convert to 0-indexed for Tree-sitter + local ts_row = row - 1 + local ts_col = col - 1 - -- Get node at position - local node = get_node_at_pos(bufnr, ts_row, ts_col) - if not node then - return default_scope - end + -- Get node at position + local node = get_node_at_pos(bufnr, ts_row, ts_col) + if not node then + return default_scope + end - -- Try to find function scope first - local scope_node, scope_type = find_enclosing_scope(node, function_nodes) + -- Try to find function scope first + local scope_node, scope_type = find_enclosing_scope(node, function_nodes) - -- If no function, try class - if not scope_node then - scope_node, scope_type = find_enclosing_scope(node, class_nodes) - end + -- If no function, try class + if not scope_node then + scope_node, scope_type = find_enclosing_scope(node, class_nodes) + end - -- If no class, try block - if not scope_node then - scope_node, scope_type = find_enclosing_scope(node, block_nodes) - end + -- If no class, try block + if not scope_node then + scope_node, scope_type = find_enclosing_scope(node, block_nodes) + end - if not scope_node then - return default_scope - end + if not scope_node then + return default_scope + end - -- Get range (convert back to 1-indexed) - local start_row, start_col, end_row, end_col = scope_node:range() + -- Get range (convert back to 1-indexed) + local start_row, start_col, end_row, end_col = scope_node:range() - -- Get text - local text = vim.treesitter.get_node_text(scope_node, bufnr) + -- Get text + local text = vim.treesitter.get_node_text(scope_node, bufnr) - -- Get name - local name = get_scope_name(scope_node, bufnr) + -- Get name + local name = get_scope_name(scope_node, bufnr) - return { - type = scope_type, - node_type = scope_node:type(), - range = { - start_row = start_row + 1, - start_col = start_col, - end_row = end_row + 1, - end_col = end_col, - }, - text = text, - name = name, - } + return { + type = scope_type, + node_type = scope_node:type(), + range = { + start_row = start_row + 1, + start_col = start_col, + end_row = end_row + 1, + end_col = end_col, + }, + text = text, + name = name, + } end --- Heuristic fallback for scope resolution (no Tree-sitter) @@ -213,151 +213,157 @@ end ---@param col number 1-indexed ---@return ScopeInfo|nil function M.resolve_scope_heuristic(bufnr, row, col) - _ = col -- unused in heuristic - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local filetype = vim.bo[bufnr].filetype + _ = col -- unused in heuristic + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local filetype = vim.bo[bufnr].filetype - -- Language-specific function patterns - local patterns = { - lua = { - start = "^%s*local%s+function%s+", - start_alt = "^%s*function%s+", - ending = "^%s*end%s*$", - }, - python = { - start = "^%s*def%s+", - start_alt = "^%s*async%s+def%s+", - ending = nil, -- Python uses indentation - }, - javascript = { - start = "^%s*export%s+function%s+", - start_alt = "^%s*function%s+", - start_alt2 = "^%s*export%s+const%s+%w+%s*=", - start_alt3 = "^%s*const%s+%w+%s*=%s*", - start_alt4 = "^%s*export%s+async%s+function%s+", - start_alt5 = "^%s*async%s+function%s+", - ending = "^%s*}%s*$", - }, - typescript = { - start = "^%s*export%s+function%s+", - start_alt = "^%s*function%s+", - start_alt2 = "^%s*export%s+const%s+%w+%s*=", - start_alt3 = "^%s*const%s+%w+%s*=%s*", - start_alt4 = "^%s*export%s+async%s+function%s+", - start_alt5 = "^%s*async%s+function%s+", - ending = "^%s*}%s*$", - }, - } + -- Language-specific function patterns + local patterns = { + lua = { + start = "^%s*local%s+function%s+", + start_alt = "^%s*function%s+", + ending = "^%s*end%s*$", + }, + python = { + start = "^%s*def%s+", + start_alt = "^%s*async%s+def%s+", + ending = nil, -- Python uses indentation + }, + javascript = { + start = "^%s*export%s+function%s+", + start_alt = "^%s*function%s+", + start_alt2 = "^%s*export%s+const%s+%w+%s*=", + start_alt3 = "^%s*const%s+%w+%s*=%s*", + start_alt4 = "^%s*export%s+async%s+function%s+", + start_alt5 = "^%s*async%s+function%s+", + ending = "^%s*}%s*$", + }, + typescript = { + start = "^%s*export%s+function%s+", + start_alt = "^%s*function%s+", + start_alt2 = "^%s*export%s+const%s+%w+%s*=", + start_alt3 = "^%s*const%s+%w+%s*=%s*", + start_alt4 = "^%s*export%s+async%s+function%s+", + start_alt5 = "^%s*async%s+function%s+", + ending = "^%s*}%s*$", + }, + } - local lang_patterns = patterns[filetype] - if not lang_patterns then - return nil - end + local lang_patterns = patterns[filetype] + if not lang_patterns then + return nil + end - -- Find function start (search backwards) - local start_line = nil - for i = row, 1, -1 do - local line = lines[i] - -- Check all start patterns - if line:match(lang_patterns.start) - or (lang_patterns.start_alt and line:match(lang_patterns.start_alt)) - or (lang_patterns.start_alt2 and line:match(lang_patterns.start_alt2)) - or (lang_patterns.start_alt3 and line:match(lang_patterns.start_alt3)) - or (lang_patterns.start_alt4 and line:match(lang_patterns.start_alt4)) - or (lang_patterns.start_alt5 and line:match(lang_patterns.start_alt5)) then - start_line = i - break - end - end + -- Find function start (search backwards) + local start_line = nil + for i = row, 1, -1 do + local line = lines[i] + -- Check all start patterns + if + line:match(lang_patterns.start) + or (lang_patterns.start_alt and line:match(lang_patterns.start_alt)) + or (lang_patterns.start_alt2 and line:match(lang_patterns.start_alt2)) + or (lang_patterns.start_alt3 and line:match(lang_patterns.start_alt3)) + or (lang_patterns.start_alt4 and line:match(lang_patterns.start_alt4)) + or (lang_patterns.start_alt5 and line:match(lang_patterns.start_alt5)) + then + start_line = i + break + end + end - if not start_line then - return nil - end + if not start_line then + return nil + end - -- Find function end - local end_line = nil - if lang_patterns.ending then - -- Brace/end based languages - local depth = 0 - for i = start_line, #lines do - local line = lines[i] - -- Count braces or end keywords - if filetype == "lua" then - if line:match("function") or line:match("if") or line:match("for") or line:match("while") then - depth = depth + 1 - end - if line:match("^%s*end") then - depth = depth - 1 - if depth <= 0 then - end_line = i - break - end - end - else - -- JavaScript/TypeScript brace counting - for _ in line:gmatch("{") do depth = depth + 1 end - for _ in line:gmatch("}") do depth = depth - 1 end - if depth <= 0 and i > start_line then - end_line = i - break - end - end - end - else - -- Python: use indentation - local base_indent = #(lines[start_line]:match("^%s*") or "") - for i = start_line + 1, #lines do - local line = lines[i] - if line:match("^%s*$") then - goto continue - end - local indent = #(line:match("^%s*") or "") - if indent <= base_indent then - end_line = i - 1 - break - end - ::continue:: - end - end_line = end_line or #lines - end + -- Find function end + local end_line = nil + if lang_patterns.ending then + -- Brace/end based languages + local depth = 0 + for i = start_line, #lines do + local line = lines[i] + -- Count braces or end keywords + if filetype == "lua" then + if line:match("function") or line:match("if") or line:match("for") or line:match("while") then + depth = depth + 1 + end + if line:match("^%s*end") then + depth = depth - 1 + if depth <= 0 then + end_line = i + break + end + end + else + -- JavaScript/TypeScript brace counting + for _ in line:gmatch("{") do + depth = depth + 1 + end + for _ in line:gmatch("}") do + depth = depth - 1 + end + if depth <= 0 and i > start_line then + end_line = i + break + end + end + end + else + -- Python: use indentation + local base_indent = #(lines[start_line]:match("^%s*") or "") + for i = start_line + 1, #lines do + local line = lines[i] + if line:match("^%s*$") then + goto continue + end + local indent = #(line:match("^%s*") or "") + if indent <= base_indent then + end_line = i - 1 + break + end + ::continue:: + end + end_line = end_line or #lines + end - if not end_line then - end_line = #lines - end + if not end_line then + end_line = #lines + end - -- Extract text - local scope_lines = {} - for i = start_line, end_line do - table.insert(scope_lines, lines[i]) - end + -- Extract text + local scope_lines = {} + for i = start_line, end_line do + table.insert(scope_lines, lines[i]) + end - -- Try to extract function name - local name = nil - local first_line = lines[start_line] - name = first_line:match("function%s+([%w_]+)") or - first_line:match("def%s+([%w_]+)") or - first_line:match("const%s+([%w_]+)") + -- Try to extract function name + local name = nil + local first_line = lines[start_line] + name = first_line:match("function%s+([%w_]+)") + or first_line:match("def%s+([%w_]+)") + or first_line:match("const%s+([%w_]+)") - return { - type = "function", - node_type = "heuristic", - range = { - start_row = start_line, - start_col = 0, - end_row = end_line, - end_col = #lines[end_line], - }, - text = table.concat(scope_lines, "\n"), - name = name, - } + return { + type = "function", + node_type = "heuristic", + range = { + start_row = start_line, + start_col = 0, + end_row = end_line, + end_col = #lines[end_line], + }, + text = table.concat(scope_lines, "\n"), + name = name, + } end --- Get scope for the current cursor position ---@return ScopeInfo function M.resolve_scope_at_cursor() - local bufnr = vim.api.nvim_get_current_buf() - local cursor = vim.api.nvim_win_get_cursor(0) - return M.resolve_scope(bufnr, cursor[1], cursor[2] + 1) + local bufnr = vim.api.nvim_get_current_buf() + local cursor = vim.api.nvim_win_get_cursor(0) + return M.resolve_scope(bufnr, cursor[1], cursor[2] + 1) end --- Check if position is inside a function/method @@ -366,66 +372,66 @@ end ---@param col number 1-indexed ---@return boolean function M.is_in_function(bufnr, row, col) - local scope = M.resolve_scope(bufnr, row, col) - return scope.type == "function" or scope.type == "method" + local scope = M.resolve_scope(bufnr, row, col) + return scope.type == "function" or scope.type == "method" end --- Get all functions in buffer ---@param bufnr number ---@return ScopeInfo[] function M.get_all_functions(bufnr) - local functions = {} + local functions = {} - if not M.has_treesitter(bufnr) then - return functions - end + if not M.has_treesitter(bufnr) then + return functions + end - local parser = vim.treesitter.get_parser(bufnr) - if not parser then - return functions - end + local parser = vim.treesitter.get_parser(bufnr) + if not parser then + return functions + end - local tree = parser:parse()[1] - if not tree then - return functions - end + local tree = parser:parse()[1] + if not tree then + return functions + end - local root = tree:root() + local root = tree:root() - -- Query for all function nodes - local lang = parser:lang() - local query_string = [[ + -- Query for all function nodes + local lang = parser:lang() + local query_string = [[ (function_declaration) @func (function_definition) @func (method_definition) @func (arrow_function) @func ]] - local ok, query = pcall(vim.treesitter.query.parse, lang, query_string) - if not ok then - return functions - end + local ok, query = pcall(vim.treesitter.query.parse, lang, query_string) + if not ok then + return functions + end - for _, node in query:iter_captures(root, bufnr, 0, -1) do - local start_row, start_col, end_row, end_col = node:range() - local text = vim.treesitter.get_node_text(node, bufnr) - local name = get_scope_name(node, bufnr) + for _, node in query:iter_captures(root, bufnr, 0, -1) do + local start_row, start_col, end_row, end_col = node:range() + local text = vim.treesitter.get_node_text(node, bufnr) + local name = get_scope_name(node, bufnr) - table.insert(functions, { - type = function_nodes[node:type()] or "function", - node_type = node:type(), - range = { - start_row = start_row + 1, - start_col = start_col, - end_row = end_row + 1, - end_col = end_col, - }, - text = text, - name = name, - }) - end + table.insert(functions, { + type = function_nodes[node:type()] or "function", + node_type = node:type(), + range = { + start_row = start_row + 1, + start_col = start_col, + end_row = end_row + 1, + end_col = end_col, + }, + text = text, + name = name, + }) + end - return functions + return functions end --- Resolve enclosing context for a selection range. @@ -436,133 +442,137 @@ end ---@param sel_end number 1-indexed end line of selection ---@return table context { type: string, scopes: ScopeInfo[], expanded_start: number, expanded_end: number } function M.resolve_selection_context(bufnr, sel_start, sel_end) - local all_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local total_lines = #all_lines + local all_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local total_lines = #all_lines - local scope_start = M.resolve_scope(bufnr, sel_start, 1) - local scope_end = M.resolve_scope(bufnr, sel_end, 1) + local scope_start = M.resolve_scope(bufnr, sel_start, 1) + local scope_end = M.resolve_scope(bufnr, sel_end, 1) - local selected_lines = sel_end - sel_start + 1 + local selected_lines = sel_end - sel_start + 1 - if selected_lines >= (total_lines * 0.8) then - return { - type = "file", - scopes = {}, - expanded_start = 1, - expanded_end = total_lines, - } - end + if selected_lines >= (total_lines * 0.8) then + return { + type = "file", + scopes = {}, + expanded_start = 1, + expanded_end = total_lines, + } + end - -- Both ends resolve to the same function/method - if scope_start.type ~= "file" and scope_end.type ~= "file" - and scope_start.name == scope_end.name - and scope_start.range.start_row == scope_end.range.start_row then + -- Both ends resolve to the same function/method + if + scope_start.type ~= "file" + and scope_end.type ~= "file" + and scope_start.name == scope_end.name + and scope_start.range.start_row == scope_end.range.start_row + then + local fn_start = scope_start.range.start_row + local fn_end = scope_start.range.end_row + local fn_lines = fn_end - fn_start + 1 + local is_whole_fn = selected_lines >= (fn_lines * 0.85) - local fn_start = scope_start.range.start_row - local fn_end = scope_start.range.end_row - local fn_lines = fn_end - fn_start + 1 - local is_whole_fn = selected_lines >= (fn_lines * 0.85) + if is_whole_fn then + return { + type = "whole_function", + scopes = { scope_start }, + expanded_start = fn_start, + expanded_end = fn_end, + } + else + return { + type = "partial_function", + scopes = { scope_start }, + expanded_start = sel_start, + expanded_end = sel_end, + } + end + end - if is_whole_fn then - return { - type = "whole_function", - scopes = { scope_start }, - expanded_start = fn_start, - expanded_end = fn_end, - } - else - return { - type = "partial_function", - scopes = { scope_start }, - expanded_start = sel_start, - expanded_end = sel_end, - } - end - end + -- Selection spans across multiple functions or one end is file-level + local affected = {} + local functions = M.get_all_functions(bufnr) - -- Selection spans across multiple functions or one end is file-level - local affected = {} - local functions = M.get_all_functions(bufnr) + if #functions > 0 then + for _, fn in ipairs(functions) do + local fn_start = fn.range.start_row + local fn_end = fn.range.end_row + if fn_end >= sel_start and fn_start <= sel_end then + table.insert(affected, fn) + end + end + end - if #functions > 0 then - for _, fn in ipairs(functions) do - local fn_start = fn.range.start_row - local fn_end = fn.range.end_row - if fn_end >= sel_start and fn_start <= sel_end then - table.insert(affected, fn) - end - end - end + if #affected > 0 then + local exp_start = sel_start + local exp_end = sel_end + for _, fn in ipairs(affected) do + exp_start = math.min(exp_start, fn.range.start_row) + exp_end = math.max(exp_end, fn.range.end_row) + end + return { + type = "multi_function", + scopes = affected, + expanded_start = exp_start, + expanded_end = exp_end, + } + end - if #affected > 0 then - local exp_start = sel_start - local exp_end = sel_end - for _, fn in ipairs(affected) do - exp_start = math.min(exp_start, fn.range.start_row) - exp_end = math.max(exp_end, fn.range.end_row) - end - return { - type = "multi_function", - scopes = affected, - expanded_start = exp_start, - expanded_end = exp_end, - } - end + -- Indentation-based fallback: walk outward to find the enclosing block + local base_indent = math.huge + for i = sel_start, math.min(sel_end, total_lines) do + local line = all_lines[i] + if line and not line:match("^%s*$") then + local indent = #(line:match("^(%s*)") or "") + base_indent = math.min(base_indent, indent) + end + end + if base_indent == math.huge then + base_indent = 0 + end - -- Indentation-based fallback: walk outward to find the enclosing block - local base_indent = math.huge - for i = sel_start, math.min(sel_end, total_lines) do - local line = all_lines[i] - if line and not line:match("^%s*$") then - local indent = #(line:match("^(%s*)") or "") - base_indent = math.min(base_indent, indent) - end - end - if base_indent == math.huge then - base_indent = 0 - end + local block_start = sel_start + for i = sel_start - 1, 1, -1 do + local line = all_lines[i] + if line and not line:match("^%s*$") then + local indent = #(line:match("^(%s*)") or "") + if indent < base_indent then + block_start = i + break + end + end + end - local block_start = sel_start - for i = sel_start - 1, 1, -1 do - local line = all_lines[i] - if line and not line:match("^%s*$") then - local indent = #(line:match("^(%s*)") or "") - if indent < base_indent then - block_start = i - break - end - end - end + local block_end = sel_end + for i = sel_end + 1, total_lines do + local line = all_lines[i] + if line and not line:match("^%s*$") then + local indent = #(line:match("^(%s*)") or "") + if indent < base_indent then + block_end = i + break + end + end + end - local block_end = sel_end - for i = sel_end + 1, total_lines do - local line = all_lines[i] - if line and not line:match("^%s*$") then - local indent = #(line:match("^(%s*)") or "") - if indent < base_indent then - block_end = i - break - end - end - end + local block_lines = {} + for i = block_start, math.min(block_end, total_lines) do + table.insert(block_lines, all_lines[i]) + end - local block_lines = {} - for i = block_start, math.min(block_end, total_lines) do - table.insert(block_lines, all_lines[i]) - end - - return { - type = "indent_block", - scopes = {{ - type = "block", - node_type = "indentation", - range = { start_row = block_start, end_row = block_end }, - text = table.concat(block_lines, "\n"), - name = nil, - }}, - expanded_start = block_start, - expanded_end = block_end, - } + return { + type = "indent_block", + scopes = { + { + type = "block", + node_type = "indentation", + range = { start_row = block_start, end_row = block_end }, + text = table.concat(block_lines, "\n"), + name = nil, + }, + }, + expanded_start = block_start, + expanded_end = block_end, + } end return M diff --git a/lua/codetyper/core/thinking_placeholder.lua b/lua/codetyper/core/thinking_placeholder.lua index a673432..047c58e 100644 --- a/lua/codetyper/core/thinking_placeholder.lua +++ b/lua/codetyper/core/thinking_placeholder.lua @@ -23,176 +23,176 @@ local inline_status = {} ---@param event table PromptEvent with range, scope_range, target_path ---@return boolean success function M.insert(event) - if not event or not event.range then - return false - end - local range = event.scope_range or event.range - local target_bufnr = vim.fn.bufnr(event.target_path) - if target_bufnr == -1 then - for _, buf in ipairs(vim.api.nvim_list_bufs()) do - if vim.api.nvim_buf_get_name(buf) == event.target_path then - target_bufnr = buf - break - end - end - end - if target_bufnr == -1 or not vim.api.nvim_buf_is_valid(target_bufnr) then - target_bufnr = vim.fn.bufadd(event.target_path) - if target_bufnr > 0 then - vim.fn.bufload(target_bufnr) - end - end - if target_bufnr <= 0 or not vim.api.nvim_buf_is_valid(target_bufnr) then - return false - end + if not event or not event.range then + return false + end + local range = event.scope_range or event.range + local target_bufnr = vim.fn.bufnr(event.target_path) + if target_bufnr == -1 then + for _, buf in ipairs(vim.api.nvim_list_bufs()) do + if vim.api.nvim_buf_get_name(buf) == event.target_path then + target_bufnr = buf + break + end + end + end + if target_bufnr == -1 or not vim.api.nvim_buf_is_valid(target_bufnr) then + target_bufnr = vim.fn.bufadd(event.target_path) + if target_bufnr > 0 then + vim.fn.bufload(target_bufnr) + end + end + if target_bufnr <= 0 or not vim.api.nvim_buf_is_valid(target_bufnr) then + return false + end - local line_count = vim.api.nvim_buf_line_count(target_bufnr) - local end_line = range.end_line - -- Include next line if it's only "}" (or whitespace + "}") so we don't leave a stray closing brace - if end_line < line_count then - local next_line = vim.api.nvim_buf_get_lines(target_bufnr, end_line, end_line + 1, false) - if next_line and next_line[1] and next_line[1]:match("^%s*}$") then - end_line = end_line + 1 - end - end + local line_count = vim.api.nvim_buf_line_count(target_bufnr) + local end_line = range.end_line + -- Include next line if it's only "}" (or whitespace + "}") so we don't leave a stray closing brace + if end_line < line_count then + local next_line = vim.api.nvim_buf_get_lines(target_bufnr, end_line, end_line + 1, false) + if next_line and next_line[1] and next_line[1]:match("^%s*}$") then + end_line = end_line + 1 + end + end - local start_row_0 = range.start_line - 1 - local end_row_0 = end_line - -- Replace range with single placeholder line - vim.api.nvim_buf_set_lines(target_bufnr, start_row_0, end_row_0, false, { PLACEHOLDER_TEXT }) - -- Gray out: extmark over the whole line - vim.api.nvim_buf_set_extmark(target_bufnr, ns_highlight, start_row_0, 0, { - end_row = start_row_0 + 1, - hl_group = "Comment", - hl_eol = true, - }) - -- Store marks for this placeholder so patch can replace it - local start_mark = marks.mark_point(target_bufnr, start_row_0, 0) - local end_mark = marks.mark_point(target_bufnr, start_row_0, #PLACEHOLDER_TEXT) - placeholders[event.id] = { - start_mark = start_mark, - end_mark = end_mark, - bufnr = target_bufnr, - } - return true + local start_row_0 = range.start_line - 1 + local end_row_0 = end_line + -- Replace range with single placeholder line + vim.api.nvim_buf_set_lines(target_bufnr, start_row_0, end_row_0, false, { PLACEHOLDER_TEXT }) + -- Gray out: extmark over the whole line + vim.api.nvim_buf_set_extmark(target_bufnr, ns_highlight, start_row_0, 0, { + end_row = start_row_0 + 1, + hl_group = "Comment", + hl_eol = true, + }) + -- Store marks for this placeholder so patch can replace it + local start_mark = marks.mark_point(target_bufnr, start_row_0, 0) + local end_mark = marks.mark_point(target_bufnr, start_row_0, #PLACEHOLDER_TEXT) + placeholders[event.id] = { + start_mark = start_mark, + end_mark = end_mark, + bufnr = target_bufnr, + } + return true end --- Get placeholder marks for an event (so patch can replace that range with code). ---@param event_id string ---@return table|nil { start_mark, end_mark, bufnr } or nil function M.get(event_id) - return placeholders[event_id] + return placeholders[event_id] end --- Clear placeholder entry after applying (and optionally delete marks). ---@param event_id string function M.clear(event_id) - local p = placeholders[event_id] - if p then - marks.delete(p.start_mark) - marks.delete(p.end_mark) - placeholders[event_id] = nil - end + local p = placeholders[event_id] + if p then + marks.delete(p.start_mark) + marks.delete(p.end_mark) + placeholders[event_id] = nil + end end --- Remove placeholder from buffer (e.g. on failure/cancel) and clear. Replaces placeholder line with empty line. ---@param event_id string function M.remove_on_failure(event_id) - local p = placeholders[event_id] - if not p or not p.bufnr or not vim.api.nvim_buf_is_valid(p.bufnr) then - M.clear(event_id) - return - end - if marks.is_valid(p.start_mark) and marks.is_valid(p.end_mark) then - local sr, sc, er, ec = marks.range_to_vim(p.start_mark, p.end_mark) - if sr ~= nil then - vim.api.nvim_buf_set_text(p.bufnr, sr, sc, er, ec, { "" }) - end - end - M.clear(event_id) + local p = placeholders[event_id] + if not p or not p.bufnr or not vim.api.nvim_buf_is_valid(p.bufnr) then + M.clear(event_id) + return + end + if marks.is_valid(p.start_mark) and marks.is_valid(p.end_mark) then + local sr, sc, er, ec = marks.range_to_vim(p.start_mark, p.end_mark) + if sr ~= nil then + vim.api.nvim_buf_set_text(p.bufnr, sr, sc, er, ec, { "" }) + end + end + M.clear(event_id) end --- 99-style: show "⠋ Implementing..." as virtual text at the line above the selection (no buffer change). --- Use for inline requests where we must not insert placeholder (e.g. SEARCH/REPLACE). ---@param event table PromptEvent with id, range, target_path function M.start_inline(event) - if not event or not event.id or not event.range then - return - end - local range = event.range - local target_bufnr = vim.fn.bufnr(event.target_path) - if target_bufnr == -1 then - for _, buf in ipairs(vim.api.nvim_list_bufs()) do - if vim.api.nvim_buf_get_name(buf) == event.target_path then - target_bufnr = buf - break - end - end - end - if target_bufnr <= 0 or not vim.api.nvim_buf_is_valid(target_bufnr) then - return - end - local start_row_0 = math.max(0, range.start_line - 1) - local col = 0 - local extmark_id = vim.api.nvim_buf_set_extmark(target_bufnr, ns_inline, start_row_0, col, { - virt_lines = { { { " Implementing", "Comment" } } }, - virt_lines_above = true, - }) - local Throbber = require("codetyper.adapters.nvim.ui.throbber") - local throb = Throbber.new(function(icon) - if not inline_status[event.id] then - return - end - local ent = inline_status[event.id] - if not ent.bufnr or not vim.api.nvim_buf_is_valid(ent.bufnr) then - return - end - local text = ent.status_text or "Implementing" - local ok = pcall(vim.api.nvim_buf_set_extmark, ent.bufnr, ns_inline, start_row_0, col, { - id = ent.extmark_id, - virt_lines = { { { icon .. " " .. text, "Comment" } } }, - virt_lines_above = true, - }) - if not ok then - M.clear_inline(event.id) - end - end) - inline_status[event.id] = { - bufnr = target_bufnr, - nsid = ns_inline, - extmark_id = extmark_id, - throbber = throb, - start_row_0 = start_row_0, - col = col, - status_text = "Implementing", - } - throb:start() + if not event or not event.id or not event.range then + return + end + local range = event.range + local target_bufnr = vim.fn.bufnr(event.target_path) + if target_bufnr == -1 then + for _, buf in ipairs(vim.api.nvim_list_bufs()) do + if vim.api.nvim_buf_get_name(buf) == event.target_path then + target_bufnr = buf + break + end + end + end + if target_bufnr <= 0 or not vim.api.nvim_buf_is_valid(target_bufnr) then + return + end + local start_row_0 = math.max(0, range.start_line - 1) + local col = 0 + local extmark_id = vim.api.nvim_buf_set_extmark(target_bufnr, ns_inline, start_row_0, col, { + virt_lines = { { { " Implementing", "Comment" } } }, + virt_lines_above = true, + }) + local Throbber = require("codetyper.adapters.nvim.ui.throbber") + local throb = Throbber.new(function(icon) + if not inline_status[event.id] then + return + end + local ent = inline_status[event.id] + if not ent.bufnr or not vim.api.nvim_buf_is_valid(ent.bufnr) then + return + end + local text = ent.status_text or "Implementing" + local ok = pcall(vim.api.nvim_buf_set_extmark, ent.bufnr, ns_inline, start_row_0, col, { + id = ent.extmark_id, + virt_lines = { { { icon .. " " .. text, "Comment" } } }, + virt_lines_above = true, + }) + if not ok then + M.clear_inline(event.id) + end + end) + inline_status[event.id] = { + bufnr = target_bufnr, + nsid = ns_inline, + extmark_id = extmark_id, + throbber = throb, + start_row_0 = start_row_0, + col = col, + status_text = "Implementing", + } + throb:start() end --- Update the inline status text for a running event. ---@param event_id string ---@param text string New status text (e.g. "Reading context...", "Sending to LLM...") function M.update_inline_status(event_id, text) - local ent = inline_status[event_id] - if ent then - ent.status_text = text - end + local ent = inline_status[event_id] + if ent then + ent.status_text = text + end end --- Clear 99-style inline virtual text (call when worker completes). ---@param event_id string function M.clear_inline(event_id) - local ent = inline_status[event_id] - if not ent then - return - end - if ent.throbber then - ent.throbber:stop() - end - if ent.bufnr and vim.api.nvim_buf_is_valid(ent.bufnr) and ent.extmark_id then - pcall(vim.api.nvim_buf_del_extmark, ent.bufnr, ns_inline, ent.extmark_id) - end - inline_status[event_id] = nil + local ent = inline_status[event_id] + if not ent then + return + end + if ent.throbber then + ent.throbber:stop() + end + if ent.bufnr and vim.api.nvim_buf_is_valid(ent.bufnr) and ent.extmark_id then + pcall(vim.api.nvim_buf_del_extmark, ent.bufnr, ns_inline, ent.extmark_id) + end + inline_status[event_id] = nil end return M diff --git a/lua/codetyper/core/transform.lua b/lua/codetyper/core/transform.lua index 2b796ce..f20b618 100644 --- a/lua/codetyper/core/transform.lua +++ b/lua/codetyper/core/transform.lua @@ -1,368 +1,394 @@ local M = {} local EXPLAIN_PATTERNS = { - "explain", "what does", "what is", "how does", "how is", - "why does", "why is", "tell me", "walk through", "understand", - "question", "what's this", "what this", "about this", "help me understand", + "explain", + "what does", + "what is", + "how does", + "how is", + "why does", + "why is", + "tell me", + "walk through", + "understand", + "question", + "what's this", + "what this", + "about this", + "help me understand", } ---@param input string ---@return boolean local function is_explain_intent(input) - local lower = input:lower() - for _, pat in ipairs(EXPLAIN_PATTERNS) do - if lower:find(pat, 1, true) then - return true - end - end - return false + local lower = input:lower() + for _, pat in ipairs(EXPLAIN_PATTERNS) do + if lower:find(pat, 1, true) then + return true + end + end + return false end --- Return editor dimensions (from UI, like 99 plugin) ---@return number width ---@return number height local function get_ui_dimensions() - local ui = vim.api.nvim_list_uis()[1] - if ui then - return ui.width, ui.height - end - return vim.o.columns, vim.o.lines + local ui = vim.api.nvim_list_uis()[1] + if ui then + return ui.width, ui.height + end + return vim.o.columns, vim.o.lines end --- Centered floating window config for prompt (2/3 width, 1/3 height) ---@return table { width, height, row, col, border } local function create_centered_window() - local width, height = get_ui_dimensions() - local win_width = math.floor(width * 2 / 3) - local win_height = math.floor(height / 3) - return { - width = win_width, - height = win_height, - row = math.floor((height - win_height) / 2), - col = math.floor((width - win_width) / 2), - border = "rounded", - } + local width, height = get_ui_dimensions() + local win_width = math.floor(width * 2 / 3) + local win_height = math.floor(height / 3) + return { + width = win_width, + height = win_height, + row = math.floor((height - win_height) / 2), + col = math.floor((width - win_width) / 2), + border = "rounded", + } end --- Get visual selection text and range ---@return table|nil { text: string, start_line: number, end_line: number } local function get_visual_selection() - local mode = vim.api.nvim_get_mode().mode - -- Check if in visual mode - local is_visual = mode == "v" or mode == "V" or mode == "\22" - if not is_visual then - return nil - end - -- Get selection range BEFORE any mode changes - local start_line = vim.fn.line("'<") - local end_line = vim.fn.line("'>") - -- Check if marks are valid (might be 0 if not in visual mode) - if start_line <= 0 or end_line <= 0 then - return nil - end - -- Third argument must be a Vim dictionary; empty Lua table can be treated as list - local opts = { type = mode } - -- Protect against invalid column numbers returned by getpos (can happen with virtual/long multibyte lines) - local ok, selection = pcall(function() - local s_pos = vim.fn.getpos("'<") - local e_pos = vim.fn.getpos("'>") - local bufnr = vim.api.nvim_get_current_buf() - -- clamp columns to the actual line length + 1 to avoid E964 - local function clamp_pos(pos) - local lnum = pos[2] - local col = pos[3] - local line = (vim.api.nvim_buf_get_lines(bufnr, lnum - 1, lnum, false) or {""})[1] or "" - local maxcol = #line + 1 - pos[3] = math.max(1, math.min(col, maxcol)) - return pos - end - s_pos = clamp_pos(s_pos) - e_pos = clamp_pos(e_pos) - return vim.fn.getregion(s_pos, e_pos, opts) - end) - if not ok then - -- Fallback: grab whole lines between start_line and end_line - local lines = vim.api.nvim_buf_get_lines(0, start_line - 1, end_line, false) - selection = lines + local mode = vim.api.nvim_get_mode().mode + -- Check if in visual mode + local is_visual = mode == "v" or mode == "V" or mode == "\22" + if not is_visual then + return nil + end + -- Get selection range BEFORE any mode changes + local start_line = vim.fn.line("'<") + local end_line = vim.fn.line("'>") + -- Check if marks are valid (might be 0 if not in visual mode) + if start_line <= 0 or end_line <= 0 then + return nil + end + -- Third argument must be a Vim dictionary; empty Lua table can be treated as list + local opts = { type = mode } + -- Protect against invalid column numbers returned by getpos (can happen with virtual/long multibyte lines) + local ok, selection = pcall(function() + local s_pos = vim.fn.getpos("'<") + local e_pos = vim.fn.getpos("'>") + local bufnr = vim.api.nvim_get_current_buf() + -- clamp columns to the actual line length + 1 to avoid E964 + local function clamp_pos(pos) + local lnum = pos[2] + local col = pos[3] + local line = (vim.api.nvim_buf_get_lines(bufnr, lnum - 1, lnum, false) or { "" })[1] or "" + local maxcol = #line + 1 + pos[3] = math.max(1, math.min(col, maxcol)) + return pos end - local text = type(selection) == "table" and table.concat(selection, "\n") or tostring(selection or "") - return { - text = text, - start_line = start_line, - end_line = end_line, - } + s_pos = clamp_pos(s_pos) + e_pos = clamp_pos(e_pos) + return vim.fn.getregion(s_pos, e_pos, opts) + end) + if not ok then + -- Fallback: grab whole lines between start_line and end_line + local lines = vim.api.nvim_buf_get_lines(0, start_line - 1, end_line, false) + selection = lines + end + local text = type(selection) == "table" and table.concat(selection, "\n") or tostring(selection or "") + return { + text = text, + start_line = start_line, + end_line = end_line, + } end --- Transform visual selection with custom prompt input --- Opens input window for prompt, processes selection on confirm. --- When nothing is selected (e.g. from Normal mode), only the prompt is requested. function M.cmd_transform_selection() - local logger = require("codetyper.support.logger") - logger.func_entry("commands", "cmd_transform_selection", {}) - -- Get visual selection (returns table with text, start_line, end_line or nil) - local selection_data = get_visual_selection() - local selection_text = selection_data and selection_data.text or "" - local has_selection = selection_text and #selection_text >= 4 + local logger = require("codetyper.support.logger") + logger.func_entry("commands", "cmd_transform_selection", {}) + -- Get visual selection (returns table with text, start_line, end_line or nil) + local selection_data = get_visual_selection() + local selection_text = selection_data and selection_data.text or "" + local has_selection = selection_text and #selection_text >= 4 - local bufnr = vim.api.nvim_get_current_buf() - local filepath = vim.fn.expand("%:p") - local line_count = vim.api.nvim_buf_line_count(bufnr) - line_count = math.max(1, line_count) + local bufnr = vim.api.nvim_get_current_buf() + local filepath = vim.fn.expand("%:p") + local line_count = vim.api.nvim_buf_line_count(bufnr) + line_count = math.max(1, line_count) - -- Range for injection: selection, cursor line when no selection - local start_line, end_line - local is_cursor_insert = false - if has_selection and selection_data then - start_line = selection_data.start_line - end_line = selection_data.end_line - logger.info("commands", string.format("Visual selection: start=%d end=%d selected_text_lines=%d", - start_line, end_line, #vim.split(selection_text, "\n", { plain = true }))) - else - -- No selection: insert at current cursor line (not replace whole file) - start_line = vim.fn.line(".") - end_line = start_line - is_cursor_insert = true - end - -- Clamp to valid 1-based range (avoid 0 or out-of-bounds) - start_line = math.max(1, math.min(start_line, line_count)) - end_line = math.max(1, math.min(end_line, line_count)) - if end_line < start_line then - end_line = start_line - end + -- Range for injection: selection, cursor line when no selection + local start_line, end_line + local is_cursor_insert = false + if has_selection and selection_data then + start_line = selection_data.start_line + end_line = selection_data.end_line + logger.info( + "commands", + string.format( + "Visual selection: start=%d end=%d selected_text_lines=%d", + start_line, + end_line, + #vim.split(selection_text, "\n", { plain = true }) + ) + ) + else + -- No selection: insert at current cursor line (not replace whole file) + start_line = vim.fn.line(".") + end_line = start_line + is_cursor_insert = true + end + -- Clamp to valid 1-based range (avoid 0 or out-of-bounds) + start_line = math.max(1, math.min(start_line, line_count)) + end_line = math.max(1, math.min(end_line, line_count)) + if end_line < start_line then + end_line = start_line + end - -- Capture injection range so we know exactly where to apply the generated code later - local injection_range = { start_line = start_line, end_line = end_line } - local range_line_count = end_line - start_line + 1 + -- Capture injection range so we know exactly where to apply the generated code later + local injection_range = { start_line = start_line, end_line = end_line } + local range_line_count = end_line - start_line + 1 - -- Open centered prompt window (pattern from 99: acwrite + BufWriteCmd to submit, BufLeave to keep focus) - local prompt_buf = vim.api.nvim_create_buf(false, true) - vim.bo[prompt_buf].buftype = "acwrite" - vim.bo[prompt_buf].bufhidden = "wipe" - vim.bo[prompt_buf].filetype = "markdown" - vim.bo[prompt_buf].swapfile = false - vim.api.nvim_buf_set_name(prompt_buf, "codetyper-prompt") + -- Open centered prompt window (pattern from 99: acwrite + BufWriteCmd to submit, BufLeave to keep focus) + local prompt_buf = vim.api.nvim_create_buf(false, true) + vim.bo[prompt_buf].buftype = "acwrite" + vim.bo[prompt_buf].bufhidden = "wipe" + vim.bo[prompt_buf].filetype = "markdown" + vim.bo[prompt_buf].swapfile = false + vim.api.nvim_buf_set_name(prompt_buf, "codetyper-prompt") - local win_opts = create_centered_window() - local prompt_win = vim.api.nvim_open_win(prompt_buf, true, { - relative = "editor", - row = win_opts.row, - col = win_opts.col, - width = win_opts.width, - height = win_opts.height, - style = "minimal", - border = win_opts.border, - title = has_selection and " Enter prompt for selection " or " Enter prompt ", - title_pos = "center", - }) - vim.wo[prompt_win].wrap = true - vim.api.nvim_set_current_win(prompt_win) + local win_opts = create_centered_window() + local prompt_win = vim.api.nvim_open_win(prompt_buf, true, { + relative = "editor", + row = win_opts.row, + col = win_opts.col, + width = win_opts.width, + height = win_opts.height, + style = "minimal", + border = win_opts.border, + title = has_selection and " Enter prompt for selection " or " Enter prompt ", + title_pos = "center", + }) + vim.wo[prompt_win].wrap = true + vim.api.nvim_set_current_win(prompt_win) - local function close_prompt() - if prompt_win and vim.api.nvim_win_is_valid(prompt_win) then - vim.api.nvim_win_close(prompt_win, true) - end - if prompt_buf and vim.api.nvim_buf_is_valid(prompt_buf) then - vim.api.nvim_buf_delete(prompt_buf, { force = true }) - end - prompt_win = nil - prompt_buf = nil - end + local function close_prompt() + if prompt_win and vim.api.nvim_win_is_valid(prompt_win) then + vim.api.nvim_win_close(prompt_win, true) + end + if prompt_buf and vim.api.nvim_buf_is_valid(prompt_buf) then + vim.api.nvim_buf_delete(prompt_buf, { force = true }) + end + prompt_win = nil + prompt_buf = nil + end - local submitted = false + local submitted = false - -- Resolve enclosing context for the selection (handles all cases: - -- partial inside function, whole function, spanning multiple functions, indentation fallback) - local scope_mod = require("codetyper.core.scope") - local sel_context = nil - local is_whole_file = false + -- Resolve enclosing context for the selection (handles all cases: + -- partial inside function, whole function, spanning multiple functions, indentation fallback) + local scope_mod = require("codetyper.core.scope") + local sel_context = nil + local is_whole_file = false - if has_selection and selection_data then - sel_context = scope_mod.resolve_selection_context(bufnr, start_line, end_line) - is_whole_file = sel_context.type == "file" + if has_selection and selection_data then + sel_context = scope_mod.resolve_selection_context(bufnr, start_line, end_line) + is_whole_file = sel_context.type == "file" - -- Expand injection range to cover full enclosing scopes when needed - if sel_context.type == "whole_function" or sel_context.type == "multi_function" then - injection_range.start_line = sel_context.expanded_start - injection_range.end_line = sel_context.expanded_end - start_line = sel_context.expanded_start - end_line = sel_context.expanded_end - -- Re-read the expanded selection text - local exp_lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false) - selection_text = table.concat(exp_lines, "\n") - end - end + -- Expand injection range to cover full enclosing scopes when needed + if sel_context.type == "whole_function" or sel_context.type == "multi_function" then + injection_range.start_line = sel_context.expanded_start + injection_range.end_line = sel_context.expanded_end + start_line = sel_context.expanded_start + end_line = sel_context.expanded_end + -- Re-read the expanded selection text + local exp_lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false) + selection_text = table.concat(exp_lines, "\n") + end + end - local function submit_prompt() - if not prompt_buf or not vim.api.nvim_buf_is_valid(prompt_buf) then - close_prompt() - return - end - submitted = true - local lines_input = vim.api.nvim_buf_get_lines(prompt_buf, 0, -1, false) - local input = table.concat(lines_input, "\n"):gsub("^%s+", ""):gsub("%s+$", "") - close_prompt() - if input == "" then - logger.info("commands", "User cancelled prompt input") - return - end + local function submit_prompt() + if not prompt_buf or not vim.api.nvim_buf_is_valid(prompt_buf) then + close_prompt() + return + end + submitted = true + local lines_input = vim.api.nvim_buf_get_lines(prompt_buf, 0, -1, false) + local input = table.concat(lines_input, "\n"):gsub("^%s+", ""):gsub("%s+$", "") + close_prompt() + if input == "" then + logger.info("commands", "User cancelled prompt input") + return + end - local is_explain = is_explain_intent(input) + local is_explain = is_explain_intent(input) - -- Explain intent requires a selection — notify and bail if none - if is_explain and not has_selection then - vim.notify("Nothing selected to explain — select code first", vim.log.levels.WARN) - return - end + -- Explain intent requires a selection — notify and bail if none + if is_explain and not has_selection then + vim.notify("Nothing selected to explain — select code first", vim.log.levels.WARN) + return + end - local content - local doc_injection_range = injection_range - local doc_intent_override = has_selection and { action = "replace" } or (is_cursor_insert and { action = "insert" } or nil) + local content + local doc_injection_range = injection_range + local doc_intent_override = has_selection and { action = "replace" } + or (is_cursor_insert and { action = "insert" } or nil) - if is_explain and has_selection and sel_context then - -- Build a prompt that asks the LLM to generate documentation comments only - local ft = vim.bo[bufnr].filetype or "text" - local context_block = "" - if sel_context.type == "partial_function" and #sel_context.scopes > 0 then - local scope = sel_context.scopes[1] - context_block = string.format( - "\n\nEnclosing %s \"%s\":\n```%s\n%s\n```", - scope.type, scope.name or "anonymous", ft, scope.text - ) - elseif sel_context.type == "multi_function" and #sel_context.scopes > 0 then - local parts = {} - for _, s in ipairs(sel_context.scopes) do - table.insert(parts, string.format("-- %s \"%s\":\n%s", s.type, s.name or "anonymous", s.text)) - end - context_block = "\n\nRelated scopes:\n```" .. ft .. "\n" .. table.concat(parts, "\n\n") .. "\n```" - elseif sel_context.type == "indent_block" and #sel_context.scopes > 0 then - context_block = string.format( - "\n\nEnclosing block:\n```%s\n%s\n```", - ft, sel_context.scopes[1].text - ) - end + if is_explain and has_selection and sel_context then + -- Build a prompt that asks the LLM to generate documentation comments only + local ft = vim.bo[bufnr].filetype or "text" + local context_block = "" + if sel_context.type == "partial_function" and #sel_context.scopes > 0 then + local scope = sel_context.scopes[1] + context_block = + string.format('\n\nEnclosing %s "%s":\n```%s\n%s\n```', scope.type, scope.name or "anonymous", ft, scope.text) + elseif sel_context.type == "multi_function" and #sel_context.scopes > 0 then + local parts = {} + for _, s in ipairs(sel_context.scopes) do + table.insert(parts, string.format('-- %s "%s":\n%s', s.type, s.name or "anonymous", s.text)) + end + context_block = "\n\nRelated scopes:\n```" .. ft .. "\n" .. table.concat(parts, "\n\n") .. "\n```" + elseif sel_context.type == "indent_block" and #sel_context.scopes > 0 then + context_block = string.format("\n\nEnclosing block:\n```%s\n%s\n```", ft, sel_context.scopes[1].text) + end - content = string.format( - "%s\n\nGenerate documentation comments for the following %s code. " - .. "Output ONLY the comment block using the correct comment syntax for %s. " - .. "Do NOT include the code itself.%s\n\nCode to document:\n```%s\n%s\n```", - input, ft, ft, context_block, ft, selection_text - ) + content = string.format( + "%s\n\nGenerate documentation comments for the following %s code. " + .. "Output ONLY the comment block using the correct comment syntax for %s. " + .. "Do NOT include the code itself.%s\n\nCode to document:\n```%s\n%s\n```", + input, + ft, + ft, + context_block, + ft, + selection_text + ) - -- Insert above the selection instead of replacing it - doc_injection_range = { start_line = start_line, end_line = start_line } - doc_intent_override = { action = "insert", type = "explain" } + -- Insert above the selection instead of replacing it + doc_injection_range = { start_line = start_line, end_line = start_line } + doc_intent_override = { action = "insert", type = "explain" } + elseif has_selection and sel_context then + if sel_context.type == "partial_function" and #sel_context.scopes > 0 then + local scope = sel_context.scopes[1] + content = string.format( + '%s\n\nEnclosing %s "%s" (lines %d-%d):\n```\n%s\n```\n\nSelected code to modify (lines %d-%d):\n%s', + input, + scope.type, + scope.name or "anonymous", + scope.range.start_row, + scope.range.end_row, + scope.text, + start_line, + end_line, + selection_text + ) + elseif sel_context.type == "multi_function" and #sel_context.scopes > 0 then + local scope_descs = {} + for _, s in ipairs(sel_context.scopes) do + table.insert( + scope_descs, + string.format('- %s "%s" (lines %d-%d)', s.type, s.name or "anonymous", s.range.start_row, s.range.end_row) + ) + end + content = string.format( + "%s\n\nAffected scopes:\n%s\n\nCode to replace (lines %d-%d):\n%s", + input, + table.concat(scope_descs, "\n"), + start_line, + end_line, + selection_text + ) + elseif sel_context.type == "indent_block" and #sel_context.scopes > 0 then + local block = sel_context.scopes[1] + content = string.format( + "%s\n\nEnclosing block (lines %d-%d):\n```\n%s\n```\n\nSelected code to modify (lines %d-%d):\n%s", + input, + block.range.start_row, + block.range.end_row, + block.text, + start_line, + end_line, + selection_text + ) + else + content = input .. "\n\nCode to replace (replace this code):\n" .. selection_text + end + elseif is_cursor_insert then + content = "Insert at line " .. start_line .. ":\n" .. input + else + content = input + end - elseif has_selection and sel_context then - if sel_context.type == "partial_function" and #sel_context.scopes > 0 then - local scope = sel_context.scopes[1] - content = string.format( - "%s\n\nEnclosing %s \"%s\" (lines %d-%d):\n```\n%s\n```\n\nSelected code to modify (lines %d-%d):\n%s", - input, - scope.type, - scope.name or "anonymous", - scope.range.start_row, scope.range.end_row, - scope.text, - start_line, end_line, - selection_text - ) - elseif sel_context.type == "multi_function" and #sel_context.scopes > 0 then - local scope_descs = {} - for _, s in ipairs(sel_context.scopes) do - table.insert(scope_descs, string.format("- %s \"%s\" (lines %d-%d)", - s.type, s.name or "anonymous", s.range.start_row, s.range.end_row)) - end - content = string.format( - "%s\n\nAffected scopes:\n%s\n\nCode to replace (lines %d-%d):\n%s", - input, - table.concat(scope_descs, "\n"), - start_line, end_line, - selection_text - ) - elseif sel_context.type == "indent_block" and #sel_context.scopes > 0 then - local block = sel_context.scopes[1] - content = string.format( - "%s\n\nEnclosing block (lines %d-%d):\n```\n%s\n```\n\nSelected code to modify (lines %d-%d):\n%s", - input, - block.range.start_row, block.range.end_row, - block.text, - start_line, end_line, - selection_text - ) - else - content = input .. "\n\nCode to replace (replace this code):\n" .. selection_text - end - elseif is_cursor_insert then - content = "Insert at line " .. start_line .. ":\n" .. input - else - content = input - end + local prompt = { + content = content, + start_line = doc_injection_range.start_line, + end_line = doc_injection_range.end_line, + start_col = 1, + end_col = 1, + user_prompt = input, + injection_range = doc_injection_range, + intent_override = doc_intent_override, + is_whole_file = is_whole_file, + } + local autocmds = require("codetyper.adapters.nvim.autocmds") + autocmds.process_single_prompt(bufnr, prompt, filepath, true) + end - local prompt = { - content = content, - start_line = doc_injection_range.start_line, - end_line = doc_injection_range.end_line, - start_col = 1, - end_col = 1, - user_prompt = input, - injection_range = doc_injection_range, - intent_override = doc_intent_override, - is_whole_file = is_whole_file, - } - local autocmds = require("codetyper.adapters.nvim.autocmds") - autocmds.process_single_prompt(bufnr, prompt, filepath, true) - end + local augroup = vim.api.nvim_create_augroup("CodetyperPrompt_" .. prompt_buf, { clear = true }) - local augroup = vim.api.nvim_create_augroup("CodetyperPrompt_" .. prompt_buf, { clear = true }) + -- Submit on :w (acwrite buffer triggers BufWriteCmd) + vim.api.nvim_create_autocmd("BufWriteCmd", { + group = augroup, + buffer = prompt_buf, + callback = function() + if prompt_win and vim.api.nvim_win_is_valid(prompt_win) then + submitted = true + submit_prompt() + end + end, + }) - -- Submit on :w (acwrite buffer triggers BufWriteCmd) - vim.api.nvim_create_autocmd("BufWriteCmd", { - group = augroup, - buffer = prompt_buf, - callback = function() - if prompt_win and vim.api.nvim_win_is_valid(prompt_win) then - submitted = true - submit_prompt() - end - end, - }) + -- Keep focus in prompt window (prevent leaving to other buffers) + vim.api.nvim_create_autocmd("BufLeave", { + group = augroup, + buffer = prompt_buf, + callback = function() + if prompt_win and vim.api.nvim_win_is_valid(prompt_win) then + vim.api.nvim_set_current_win(prompt_win) + end + end, + }) - -- Keep focus in prompt window (prevent leaving to other buffers) - vim.api.nvim_create_autocmd("BufLeave", { - group = augroup, - buffer = prompt_buf, - callback = function() - if prompt_win and vim.api.nvim_win_is_valid(prompt_win) then - vim.api.nvim_set_current_win(prompt_win) - end - end, - }) + -- Clean up when window is closed (e.g. :q or close button) + vim.api.nvim_create_autocmd("WinClosed", { + group = augroup, + pattern = tostring(prompt_win), + callback = function() + if not submitted then + logger.info("commands", "User cancelled prompt input") + end + close_prompt() + end, + }) - -- Clean up when window is closed (e.g. :q or close button) - vim.api.nvim_create_autocmd("WinClosed", { - group = augroup, - pattern = tostring(prompt_win), - callback = function() - if not submitted then - logger.info("commands", "User cancelled prompt input") - end - close_prompt() - end, - }) + local map_opts = { buffer = prompt_buf, noremap = true, silent = true } + -- Normal mode: Enter, :w, or Ctrl+Enter to submit + vim.keymap.set("n", "", submit_prompt, map_opts) + vim.keymap.set("n", "", submit_prompt, map_opts) + vim.keymap.set("n", "", submit_prompt, map_opts) + vim.keymap.set("n", "w", "w", vim.tbl_extend("force", map_opts, { desc = "Submit prompt" })) + -- Insert mode: Ctrl+Enter to submit + vim.keymap.set("i", "", submit_prompt, map_opts) + vim.keymap.set("i", "", submit_prompt, map_opts) + -- Close/cancel: Esc (in normal), q, or :q + vim.keymap.set("n", "", close_prompt, map_opts) + vim.keymap.set("n", "q", close_prompt, map_opts) - local map_opts = { buffer = prompt_buf, noremap = true, silent = true } - -- Normal mode: Enter, :w, or Ctrl+Enter to submit - vim.keymap.set("n", "", submit_prompt, map_opts) - vim.keymap.set("n", "", submit_prompt, map_opts) - vim.keymap.set("n", "", submit_prompt, map_opts) - vim.keymap.set("n", "w", "w", vim.tbl_extend("force", map_opts, { desc = "Submit prompt" })) - -- Insert mode: Ctrl+Enter to submit - vim.keymap.set("i", "", submit_prompt, map_opts) - vim.keymap.set("i", "", submit_prompt, map_opts) - -- Close/cancel: Esc (in normal), q, or :q - vim.keymap.set("n", "", close_prompt, map_opts) - vim.keymap.set("n", "q", close_prompt, map_opts) - - vim.cmd("startinsert") + vim.cmd("startinsert") end return M diff --git a/lua/codetyper/features/completion/inline.lua b/lua/codetyper/features/completion/inline.lua index d9d4a61..a9533b8 100644 --- a/lua/codetyper/features/completion/inline.lua +++ b/lua/codetyper/features/completion/inline.lua @@ -9,183 +9,183 @@ local utils = require("codetyper.support.utils") ---@param prefix string Prefix to filter files ---@return table[] List of completion items local function get_file_completions(prefix) - local cwd = vim.fn.getcwd() - local current_file = vim.fn.expand("%:p") - local current_dir = vim.fn.fnamemodify(current_file, ":h") - local files = {} + local cwd = vim.fn.getcwd() + local current_file = vim.fn.expand("%:p") + local current_dir = vim.fn.fnamemodify(current_file, ":h") + local files = {} - -- Use vim.fn.glob to find files matching the prefix - local pattern = prefix .. "*" + -- Use vim.fn.glob to find files matching the prefix + local pattern = prefix .. "*" - -- Determine base directory - use current file's directory if outside cwd - local base_dir = cwd - if current_dir ~= "" and not current_dir:find(cwd, 1, true) then - -- File is outside project, use its directory as base - base_dir = current_dir - end + -- Determine base directory - use current file's directory if outside cwd + local base_dir = cwd + if current_dir ~= "" and not current_dir:find(cwd, 1, true) then + -- File is outside project, use its directory as base + base_dir = current_dir + end - -- Search in base directory - local matches = vim.fn.glob(base_dir .. "/" .. pattern, false, true) + -- Search in base directory + local matches = vim.fn.glob(base_dir .. "/" .. pattern, false, true) - -- Search with ** for all subdirectories - local deep_matches = vim.fn.glob(base_dir .. "/**/" .. pattern, false, true) - for _, m in ipairs(deep_matches) do - table.insert(matches, m) - end + -- Search with ** for all subdirectories + local deep_matches = vim.fn.glob(base_dir .. "/**/" .. pattern, false, true) + for _, m in ipairs(deep_matches) do + table.insert(matches, m) + end - -- Also search in cwd if different from base_dir - if base_dir ~= cwd then - local cwd_matches = vim.fn.glob(cwd .. "/" .. pattern, false, true) - for _, m in ipairs(cwd_matches) do - table.insert(matches, m) - end - local cwd_deep = vim.fn.glob(cwd .. "/**/" .. pattern, false, true) - for _, m in ipairs(cwd_deep) do - table.insert(matches, m) - end - end + -- Also search in cwd if different from base_dir + if base_dir ~= cwd then + local cwd_matches = vim.fn.glob(cwd .. "/" .. pattern, false, true) + for _, m in ipairs(cwd_matches) do + table.insert(matches, m) + end + local cwd_deep = vim.fn.glob(cwd .. "/**/" .. pattern, false, true) + for _, m in ipairs(cwd_deep) do + table.insert(matches, m) + end + end - -- Also search specific directories if prefix doesn't have path - if not prefix:find("/") then - local search_dirs = { "src", "lib", "lua", "app", "components", "utils", "tests" } - for _, dir in ipairs(search_dirs) do - local dir_path = base_dir .. "/" .. dir - if vim.fn.isdirectory(dir_path) == 1 then - local dir_matches = vim.fn.glob(dir_path .. "/**/" .. pattern, false, true) - for _, m in ipairs(dir_matches) do - table.insert(matches, m) - end - end - end - end + -- Also search specific directories if prefix doesn't have path + if not prefix:find("/") then + local search_dirs = { "src", "lib", "lua", "app", "components", "utils", "tests" } + for _, dir in ipairs(search_dirs) do + local dir_path = base_dir .. "/" .. dir + if vim.fn.isdirectory(dir_path) == 1 then + local dir_matches = vim.fn.glob(dir_path .. "/**/" .. pattern, false, true) + for _, m in ipairs(dir_matches) do + table.insert(matches, m) + end + end + end + end - -- Convert to relative paths and deduplicate - local seen = {} - for _, match in ipairs(matches) do - -- Convert to relative path based on which base it came from - local rel_path - if match:find(base_dir, 1, true) == 1 then - rel_path = match:sub(#base_dir + 2) - elseif match:find(cwd, 1, true) == 1 then - rel_path = match:sub(#cwd + 2) - else - rel_path = vim.fn.fnamemodify(match, ":t") -- Just filename if can't make relative - end + -- Convert to relative paths and deduplicate + local seen = {} + for _, match in ipairs(matches) do + -- Convert to relative path based on which base it came from + local rel_path + if match:find(base_dir, 1, true) == 1 then + rel_path = match:sub(#base_dir + 2) + elseif match:find(cwd, 1, true) == 1 then + rel_path = match:sub(#cwd + 2) + else + rel_path = vim.fn.fnamemodify(match, ":t") -- Just filename if can't make relative + end - -- Skip directories, coder files, and hidden/generated files - if - vim.fn.isdirectory(match) == 0 - and not utils.is_coder_file(match) - and not rel_path:match("^%.") - and not rel_path:match("node_modules") - and not rel_path:match("%.git/") - and not rel_path:match("dist/") - and not rel_path:match("build/") - and not seen[rel_path] - then - seen[rel_path] = true - table.insert(files, { - word = rel_path, - abbr = rel_path, - kind = "File", - menu = "[ref]", - }) - end - end + -- Skip directories, coder files, and hidden/generated files + if + vim.fn.isdirectory(match) == 0 + and not utils.is_coder_file(match) + and not rel_path:match("^%.") + and not rel_path:match("node_modules") + and not rel_path:match("%.git/") + and not rel_path:match("dist/") + and not rel_path:match("build/") + and not seen[rel_path] + then + seen[rel_path] = true + table.insert(files, { + word = rel_path, + abbr = rel_path, + kind = "File", + menu = "[ref]", + }) + end + end - -- Sort by length (shorter paths first) - table.sort(files, function(a, b) - return #a.word < #b.word - end) + -- Sort by length (shorter paths first) + table.sort(files, function(a, b) + return #a.word < #b.word + end) - -- Limit results - local result = {} - for i = 1, math.min(#files, 15) do - result[i] = files[i] - end + -- Limit results + local result = {} + for i = 1, math.min(#files, 15) do + result[i] = files[i] + end - return result + return result end --- Show file completion popup function M.show_file_completion() - -- Check if we're in an open prompt tag - local is_inside = parser.is_cursor_in_open_tag() - if not is_inside then - return false - end + -- Check if we're in an open prompt tag + local is_inside = parser.is_cursor_in_open_tag() + if not is_inside then + return false + end - -- Get the prefix being typed - local prefix = parser.get_file_ref_prefix() - if prefix == nil then - return false - end + -- Get the prefix being typed + local prefix = parser.get_file_ref_prefix() + if prefix == nil then + return false + end - -- Get completions - local items = get_file_completions(prefix) + -- Get completions + local items = get_file_completions(prefix) - if #items == 0 then - -- Try with empty prefix to show all files - items = get_file_completions("") - end + if #items == 0 then + -- Try with empty prefix to show all files + items = get_file_completions("") + end - if #items > 0 then - -- Calculate start column (position right after @) - local cursor = vim.api.nvim_win_get_cursor(0) - local col = cursor[2] - #prefix + 1 -- 1-indexed for complete() + if #items > 0 then + -- Calculate start column (position right after @) + local cursor = vim.api.nvim_win_get_cursor(0) + local col = cursor[2] - #prefix + 1 -- 1-indexed for complete() - -- Show completion popup - vim.fn.complete(col, items) - return true - end + -- Show completion popup + vim.fn.complete(col, items) + return true + end - return false + return false end --- Setup completion for file references (works on ALL files) function M.setup() - local group = vim.api.nvim_create_augroup("CoderCompletion", { clear = true }) + local group = vim.api.nvim_create_augroup("CoderCompletion", { clear = true }) - -- Trigger completion on @ in insert mode (works on ALL files) - vim.api.nvim_create_autocmd("InsertCharPre", { - group = group, - pattern = "*", - callback = function() - -- Skip special buffers - if vim.bo.buftype ~= "" then - return - end + -- Trigger completion on @ in insert mode (works on ALL files) + vim.api.nvim_create_autocmd("InsertCharPre", { + group = group, + pattern = "*", + callback = function() + -- Skip special buffers + if vim.bo.buftype ~= "" then + return + end - if vim.v.char == "@" then - -- Schedule completion popup after the @ is inserted - vim.schedule(function() - -- Check we're in an open tag - local is_inside = parser.is_cursor_in_open_tag() - if not is_inside then - return - end + if vim.v.char == "@" then + -- Schedule completion popup after the @ is inserted + vim.schedule(function() + -- Check we're in an open tag + local is_inside = parser.is_cursor_in_open_tag() + if not is_inside then + return + end - -- Check we're not typing @/ (closing tag) - local cursor = vim.api.nvim_win_get_cursor(0) - local line = vim.api.nvim_get_current_line() - local next_char = line:sub(cursor[2] + 2, cursor[2] + 2) + -- Check we're not typing @/ (closing tag) + local cursor = vim.api.nvim_win_get_cursor(0) + local line = vim.api.nvim_get_current_line() + local next_char = line:sub(cursor[2] + 2, cursor[2] + 2) - if next_char == "/" then - return - end + if next_char == "/" then + return + end - -- Show file completion - M.show_file_completion() - end) - end - end, - desc = "Trigger file completion on @ inside prompt tags", - }) + -- Show file completion + M.show_file_completion() + end) + end + end, + desc = "Trigger file completion on @ inside prompt tags", + }) - -- Also allow manual trigger with style keybinding in insert mode - vim.keymap.set("i", "@", function() - M.show_file_completion() - end, { silent = true, desc = "Coder: Complete file reference" }) + -- Also allow manual trigger with style keybinding in insert mode + vim.keymap.set("i", "@", function() + M.show_file_completion() + end, { silent = true, desc = "Coder: Complete file reference" }) end return M diff --git a/lua/codetyper/features/completion/suggestion.lua b/lua/codetyper/features/completion/suggestion.lua index 98dda81..b427133 100644 --- a/lua/codetyper/features/completion/suggestion.lua +++ b/lua/codetyper/features/completion/suggestion.lua @@ -19,15 +19,15 @@ local M = {} ---@field using_copilot boolean Whether currently using copilot local state = { - current_suggestion = nil, - suggestions = {}, - current_index = 0, - extmark_id = nil, - bufnr = nil, - line = nil, - col = nil, - timer = nil, - using_copilot = false, + current_suggestion = nil, + suggestions = {}, + current_index = 0, + extmark_id = nil, + bufnr = nil, + line = nil, + col = nil, + timer = nil, + using_copilot = false, } --- Namespace for virtual text @@ -38,221 +38,221 @@ local hl_group = "CmpGhostText" --- Configuration local config = { - enabled = true, - auto_trigger = true, - debounce = 150, - use_copilot = true, -- Use copilot when available - keymap = { - accept = "", - next = "", - prev = "", - dismiss = "", - }, + enabled = true, + auto_trigger = true, + debounce = 150, + use_copilot = true, -- Use copilot when available + keymap = { + accept = "", + next = "", + prev = "", + dismiss = "", + }, } --- Check if copilot is available and enabled ---@return boolean, table|nil available, copilot_suggestion module local function get_copilot() - if not config.use_copilot then - return false, nil - end + if not config.use_copilot then + return false, nil + end - local ok, copilot_suggestion = pcall(require, "copilot.suggestion") - if not ok then - return false, nil - end + local ok, copilot_suggestion = pcall(require, "copilot.suggestion") + if not ok then + return false, nil + end - -- Check if copilot suggestion is enabled - local ok_client, copilot_client = pcall(require, "copilot.client") - if ok_client and copilot_client.is_disabled and copilot_client.is_disabled() then - return false, nil - end + -- Check if copilot suggestion is enabled + local ok_client, copilot_client = pcall(require, "copilot.client") + if ok_client and copilot_client.is_disabled and copilot_client.is_disabled() then + return false, nil + end - return true, copilot_suggestion + return true, copilot_suggestion end --- Check if suggestion is visible (copilot or codetyper) ---@return boolean function M.is_visible() - -- Check copilot first - local copilot_ok, copilot_suggestion = get_copilot() - if copilot_ok and copilot_suggestion.is_visible() then - state.using_copilot = true - return true - end + -- Check copilot first + local copilot_ok, copilot_suggestion = get_copilot() + if copilot_ok and copilot_suggestion.is_visible() then + state.using_copilot = true + return true + end - -- Check codetyper's own suggestion - state.using_copilot = false - return state.extmark_id ~= nil and state.current_suggestion ~= nil + -- Check codetyper's own suggestion + state.using_copilot = false + return state.extmark_id ~= nil and state.current_suggestion ~= nil end --- Clear the current suggestion function M.dismiss() - -- Dismiss copilot if active - local copilot_ok, copilot_suggestion = get_copilot() - if copilot_ok and copilot_suggestion.is_visible() then - copilot_suggestion.dismiss() - end + -- Dismiss copilot if active + local copilot_ok, copilot_suggestion = get_copilot() + if copilot_ok and copilot_suggestion.is_visible() then + copilot_suggestion.dismiss() + end - -- Clear codetyper's suggestion - if state.extmark_id and state.bufnr then - pcall(vim.api.nvim_buf_del_extmark, state.bufnr, ns, state.extmark_id) - end + -- Clear codetyper's suggestion + if state.extmark_id and state.bufnr then + pcall(vim.api.nvim_buf_del_extmark, state.bufnr, ns, state.extmark_id) + end - state.current_suggestion = nil - state.suggestions = {} - state.current_index = 0 - state.extmark_id = nil - state.bufnr = nil - state.line = nil - state.col = nil - state.using_copilot = false + state.current_suggestion = nil + state.suggestions = {} + state.current_index = 0 + state.extmark_id = nil + state.bufnr = nil + state.line = nil + state.col = nil + state.using_copilot = false end --- Display suggestion as ghost text ---@param suggestion string The suggestion to display local function display_suggestion(suggestion) - if not suggestion or suggestion == "" then - return - end + if not suggestion or suggestion == "" then + return + end - M.dismiss() + M.dismiss() - local bufnr = vim.api.nvim_get_current_buf() - local cursor = vim.api.nvim_win_get_cursor(0) - local line = cursor[1] - 1 - local col = cursor[2] + local bufnr = vim.api.nvim_get_current_buf() + local cursor = vim.api.nvim_win_get_cursor(0) + local line = cursor[1] - 1 + local col = cursor[2] - -- Split suggestion into lines - local lines = vim.split(suggestion, "\n", { plain = true }) + -- Split suggestion into lines + local lines = vim.split(suggestion, "\n", { plain = true }) - -- Build virtual text - local virt_text = {} - local virt_lines = {} + -- Build virtual text + local virt_text = {} + local virt_lines = {} - -- First line goes inline - if #lines > 0 then - virt_text = { { lines[1], hl_group } } - end + -- First line goes inline + if #lines > 0 then + virt_text = { { lines[1], hl_group } } + end - -- Remaining lines go below - for i = 2, #lines do - table.insert(virt_lines, { { lines[i], hl_group } }) - end + -- Remaining lines go below + for i = 2, #lines do + table.insert(virt_lines, { { lines[i], hl_group } }) + end - -- Create extmark with virtual text - local opts = { - virt_text = virt_text, - virt_text_pos = "overlay", - hl_mode = "combine", - } + -- Create extmark with virtual text + local opts = { + virt_text = virt_text, + virt_text_pos = "overlay", + hl_mode = "combine", + } - if #virt_lines > 0 then - opts.virt_lines = virt_lines - end + if #virt_lines > 0 then + opts.virt_lines = virt_lines + end - state.extmark_id = vim.api.nvim_buf_set_extmark(bufnr, ns, line, col, opts) - state.bufnr = bufnr - state.line = line - state.col = col - state.current_suggestion = suggestion + state.extmark_id = vim.api.nvim_buf_set_extmark(bufnr, ns, line, col, opts) + state.bufnr = bufnr + state.line = line + state.col = col + state.current_suggestion = suggestion end --- Accept the current suggestion ---@return boolean Whether a suggestion was accepted function M.accept() - -- Check copilot first - local copilot_ok, copilot_suggestion = get_copilot() - if copilot_ok and copilot_suggestion.is_visible() then - copilot_suggestion.accept() - state.using_copilot = false - return true - end + -- Check copilot first + local copilot_ok, copilot_suggestion = get_copilot() + if copilot_ok and copilot_suggestion.is_visible() then + copilot_suggestion.accept() + state.using_copilot = false + return true + end - -- Accept codetyper's suggestion - if not M.is_visible() then - return false - end + -- Accept codetyper's suggestion + if not M.is_visible() then + return false + end - local suggestion = state.current_suggestion - local bufnr = state.bufnr - local line = state.line - local col = state.col + local suggestion = state.current_suggestion + local bufnr = state.bufnr + local line = state.line + local col = state.col - M.dismiss() + M.dismiss() - if suggestion and bufnr and line ~= nil and col ~= nil then - -- Get current line content - local current_line = vim.api.nvim_buf_get_lines(bufnr, line, line + 1, false)[1] or "" + if suggestion and bufnr and line ~= nil and col ~= nil then + -- Get current line content + local current_line = vim.api.nvim_buf_get_lines(bufnr, line, line + 1, false)[1] or "" - -- Split suggestion into lines - local suggestion_lines = vim.split(suggestion, "\n", { plain = true }) + -- Split suggestion into lines + local suggestion_lines = vim.split(suggestion, "\n", { plain = true }) - if #suggestion_lines == 1 then - -- Single line - insert at cursor - local new_line = current_line:sub(1, col) .. suggestion .. current_line:sub(col + 1) - vim.api.nvim_buf_set_lines(bufnr, line, line + 1, false, { new_line }) - -- Move cursor to end of inserted text - vim.api.nvim_win_set_cursor(0, { line + 1, col + #suggestion }) - else - -- Multi-line - insert at cursor - local first_line = current_line:sub(1, col) .. suggestion_lines[1] - local last_line = suggestion_lines[#suggestion_lines] .. current_line:sub(col + 1) + if #suggestion_lines == 1 then + -- Single line - insert at cursor + local new_line = current_line:sub(1, col) .. suggestion .. current_line:sub(col + 1) + vim.api.nvim_buf_set_lines(bufnr, line, line + 1, false, { new_line }) + -- Move cursor to end of inserted text + vim.api.nvim_win_set_cursor(0, { line + 1, col + #suggestion }) + else + -- Multi-line - insert at cursor + local first_line = current_line:sub(1, col) .. suggestion_lines[1] + local last_line = suggestion_lines[#suggestion_lines] .. current_line:sub(col + 1) - local new_lines = { first_line } - for i = 2, #suggestion_lines - 1 do - table.insert(new_lines, suggestion_lines[i]) - end - table.insert(new_lines, last_line) + local new_lines = { first_line } + for i = 2, #suggestion_lines - 1 do + table.insert(new_lines, suggestion_lines[i]) + end + table.insert(new_lines, last_line) - vim.api.nvim_buf_set_lines(bufnr, line, line + 1, false, new_lines) - -- Move cursor to end of last line - vim.api.nvim_win_set_cursor(0, { line + #new_lines, #suggestion_lines[#suggestion_lines] }) - end + vim.api.nvim_buf_set_lines(bufnr, line, line + 1, false, new_lines) + -- Move cursor to end of last line + vim.api.nvim_win_set_cursor(0, { line + #new_lines, #suggestion_lines[#suggestion_lines] }) + end - return true - end + return true + end - return false + return false end --- Show next suggestion function M.next() - -- Check copilot first - local copilot_ok, copilot_suggestion = get_copilot() - if copilot_ok and copilot_suggestion.is_visible() then - copilot_suggestion.next() - return - end + -- Check copilot first + local copilot_ok, copilot_suggestion = get_copilot() + if copilot_ok and copilot_suggestion.is_visible() then + copilot_suggestion.next() + return + end - -- Codetyper's suggestions - if #state.suggestions <= 1 then - return - end + -- Codetyper's suggestions + if #state.suggestions <= 1 then + return + end - state.current_index = (state.current_index % #state.suggestions) + 1 - display_suggestion(state.suggestions[state.current_index]) + state.current_index = (state.current_index % #state.suggestions) + 1 + display_suggestion(state.suggestions[state.current_index]) end --- Show previous suggestion function M.prev() - -- Check copilot first - local copilot_ok, copilot_suggestion = get_copilot() - if copilot_ok and copilot_suggestion.is_visible() then - copilot_suggestion.prev() - return - end + -- Check copilot first + local copilot_ok, copilot_suggestion = get_copilot() + if copilot_ok and copilot_suggestion.is_visible() then + copilot_suggestion.prev() + return + end - -- Codetyper's suggestions - if #state.suggestions <= 1 then - return - end + -- Codetyper's suggestions + if #state.suggestions <= 1 then + return + end - state.current_index = state.current_index - 1 - if state.current_index < 1 then - state.current_index = #state.suggestions - end - display_suggestion(state.suggestions[state.current_index]) + state.current_index = state.current_index - 1 + if state.current_index < 1 then + state.current_index = #state.suggestions + end + display_suggestion(state.suggestions[state.current_index]) end --- Get suggestions from brain/indexer @@ -260,232 +260,227 @@ end ---@param context table Context info ---@return string[] suggestions local function get_suggestions(prefix, context) - local suggestions = {} + local suggestions = {} - -- Get completions from brain - local ok_brain, brain = pcall(require, "codetyper.brain") - if ok_brain and brain.is_initialized and brain.is_initialized() then - local result = brain.query({ - query = prefix, - max_results = 5, - types = { "pattern" }, - }) + -- Get completions from brain + local ok_brain, brain = pcall(require, "codetyper.brain") + if ok_brain and brain.is_initialized and brain.is_initialized() then + local result = brain.query({ + query = prefix, + max_results = 5, + types = { "pattern" }, + }) - if result and result.nodes then - for _, node in ipairs(result.nodes) do - if node.c and node.c.code then - table.insert(suggestions, node.c.code) - end - end - end - end + if result and result.nodes then + for _, node in ipairs(result.nodes) do + if node.c and node.c.code then + table.insert(suggestions, node.c.code) + end + end + end + end - -- Get completions from indexer - local ok_indexer, indexer = pcall(require, "codetyper.indexer") - if ok_indexer then - local index = indexer.load_index() - if index and index.symbols then - for symbol, _ in pairs(index.symbols) do - if symbol:lower():find(prefix:lower(), 1, true) and symbol ~= prefix then - -- Just complete the symbol name - local completion = symbol:sub(#prefix + 1) - if completion ~= "" then - table.insert(suggestions, completion) - end - end - end - end - end + -- Get completions from indexer + local ok_indexer, indexer = pcall(require, "codetyper.indexer") + if ok_indexer then + local index = indexer.load_index() + if index and index.symbols then + for symbol, _ in pairs(index.symbols) do + if symbol:lower():find(prefix:lower(), 1, true) and symbol ~= prefix then + -- Just complete the symbol name + local completion = symbol:sub(#prefix + 1) + if completion ~= "" then + table.insert(suggestions, completion) + end + end + end + end + end - -- Buffer-based completions - local bufnr = vim.api.nvim_get_current_buf() - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local seen = {} + -- Buffer-based completions + local bufnr = vim.api.nvim_get_current_buf() + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local seen = {} - for _, line in ipairs(lines) do - for word in line:gmatch("[%a_][%w_]*") do - if - #word > #prefix - and word:lower():find(prefix:lower(), 1, true) == 1 - and not seen[word] - and word ~= prefix - then - seen[word] = true - local completion = word:sub(#prefix + 1) - if completion ~= "" then - table.insert(suggestions, completion) - end - end - end - end + for _, line in ipairs(lines) do + for word in line:gmatch("[%a_][%w_]*") do + if #word > #prefix and word:lower():find(prefix:lower(), 1, true) == 1 and not seen[word] and word ~= prefix then + seen[word] = true + local completion = word:sub(#prefix + 1) + if completion ~= "" then + table.insert(suggestions, completion) + end + end + end + end - return suggestions + return suggestions end --- Trigger suggestion generation function M.trigger() - if not config.enabled then - return - end + if not config.enabled then + return + end - -- If copilot is available and has a suggestion, don't show codetyper's - local copilot_ok, copilot_suggestion = get_copilot() - if copilot_ok and copilot_suggestion.is_visible() then - -- Copilot is handling suggestions - state.using_copilot = true - return - end + -- If copilot is available and has a suggestion, don't show codetyper's + local copilot_ok, copilot_suggestion = get_copilot() + if copilot_ok and copilot_suggestion.is_visible() then + -- Copilot is handling suggestions + state.using_copilot = true + return + end - -- Cancel existing timer - if state.timer then - state.timer:stop() - state.timer = nil - end + -- Cancel existing timer + if state.timer then + state.timer:stop() + state.timer = nil + end - -- Get current context - local cursor = vim.api.nvim_win_get_cursor(0) - local line = vim.api.nvim_get_current_line() - local col = cursor[2] - local before_cursor = line:sub(1, col) + -- Get current context + local cursor = vim.api.nvim_win_get_cursor(0) + local line = vim.api.nvim_get_current_line() + local col = cursor[2] + local before_cursor = line:sub(1, col) - -- Extract prefix (word being typed) - local prefix = before_cursor:match("[%a_][%w_]*$") or "" + -- Extract prefix (word being typed) + local prefix = before_cursor:match("[%a_][%w_]*$") or "" - if #prefix < 2 then - M.dismiss() - return - end + if #prefix < 2 then + M.dismiss() + return + end - -- Debounce - wait a bit longer to let copilot try first - local debounce_time = copilot_ok and (config.debounce + 200) or config.debounce + -- Debounce - wait a bit longer to let copilot try first + local debounce_time = copilot_ok and (config.debounce + 200) or config.debounce - state.timer = vim.defer_fn(function() - -- Check again if copilot has shown something - if copilot_ok and copilot_suggestion.is_visible() then - state.using_copilot = true - state.timer = nil - return - end + state.timer = vim.defer_fn(function() + -- Check again if copilot has shown something + if copilot_ok and copilot_suggestion.is_visible() then + state.using_copilot = true + state.timer = nil + return + end - local suggestions = get_suggestions(prefix, { - line = line, - col = col, - bufnr = vim.api.nvim_get_current_buf(), - }) + local suggestions = get_suggestions(prefix, { + line = line, + col = col, + bufnr = vim.api.nvim_get_current_buf(), + }) - if #suggestions > 0 then - state.suggestions = suggestions - state.current_index = 1 - display_suggestion(suggestions[1]) - else - M.dismiss() - end + if #suggestions > 0 then + state.suggestions = suggestions + state.current_index = 1 + display_suggestion(suggestions[1]) + else + M.dismiss() + end - state.timer = nil - end, debounce_time) + state.timer = nil + end, debounce_time) end --- Setup keymaps local function setup_keymaps() - -- Accept with Tab (only when suggestion visible) - vim.keymap.set("i", config.keymap.accept, function() - if M.is_visible() then - M.accept() - return "" - end - -- Fallback to normal Tab behavior - return vim.api.nvim_replace_termcodes("", true, false, true) - end, { expr = true, silent = true, desc = "Accept codetyper suggestion" }) + -- Accept with Tab (only when suggestion visible) + vim.keymap.set("i", config.keymap.accept, function() + if M.is_visible() then + M.accept() + return "" + end + -- Fallback to normal Tab behavior + return vim.api.nvim_replace_termcodes("", true, false, true) + end, { expr = true, silent = true, desc = "Accept codetyper suggestion" }) - -- Next suggestion - vim.keymap.set("i", config.keymap.next, function() - M.next() - end, { silent = true, desc = "Next codetyper suggestion" }) + -- Next suggestion + vim.keymap.set("i", config.keymap.next, function() + M.next() + end, { silent = true, desc = "Next codetyper suggestion" }) - -- Previous suggestion - vim.keymap.set("i", config.keymap.prev, function() - M.prev() - end, { silent = true, desc = "Previous codetyper suggestion" }) + -- Previous suggestion + vim.keymap.set("i", config.keymap.prev, function() + M.prev() + end, { silent = true, desc = "Previous codetyper suggestion" }) - -- Dismiss - vim.keymap.set("i", config.keymap.dismiss, function() - M.dismiss() - end, { silent = true, desc = "Dismiss codetyper suggestion" }) + -- Dismiss + vim.keymap.set("i", config.keymap.dismiss, function() + M.dismiss() + end, { silent = true, desc = "Dismiss codetyper suggestion" }) end --- Setup autocmds for auto-trigger local function setup_autocmds() - local group = vim.api.nvim_create_augroup("CodetypeSuggestion", { clear = true }) + local group = vim.api.nvim_create_augroup("CodetypeSuggestion", { clear = true }) - -- Trigger on text change in insert mode - if config.auto_trigger then - vim.api.nvim_create_autocmd("TextChangedI", { - group = group, - callback = function() - M.trigger() - end, - }) - end + -- Trigger on text change in insert mode + if config.auto_trigger then + vim.api.nvim_create_autocmd("TextChangedI", { + group = group, + callback = function() + M.trigger() + end, + }) + end - -- Dismiss on leaving insert mode - vim.api.nvim_create_autocmd("InsertLeave", { - group = group, - callback = function() - M.dismiss() - end, - }) + -- Dismiss on leaving insert mode + vim.api.nvim_create_autocmd("InsertLeave", { + group = group, + callback = function() + M.dismiss() + end, + }) - -- Dismiss on cursor move (not from typing) - vim.api.nvim_create_autocmd("CursorMovedI", { - group = group, - callback = function() - -- Only dismiss if cursor moved significantly - if state.line ~= nil then - local cursor = vim.api.nvim_win_get_cursor(0) - if cursor[1] - 1 ~= state.line then - M.dismiss() - end - end - end, - }) + -- Dismiss on cursor move (not from typing) + vim.api.nvim_create_autocmd("CursorMovedI", { + group = group, + callback = function() + -- Only dismiss if cursor moved significantly + if state.line ~= nil then + local cursor = vim.api.nvim_win_get_cursor(0) + if cursor[1] - 1 ~= state.line then + M.dismiss() + end + end + end, + }) end --- Setup highlight group local function setup_highlights() - -- Use Comment highlight or define custom ghost text style - vim.api.nvim_set_hl(0, hl_group, { link = "Comment" }) + -- Use Comment highlight or define custom ghost text style + vim.api.nvim_set_hl(0, hl_group, { link = "Comment" }) end --- Setup the suggestion system ---@param opts? table Configuration options function M.setup(opts) - if opts then - config = vim.tbl_deep_extend("force", config, opts) - end + if opts then + config = vim.tbl_deep_extend("force", config, opts) + end - setup_highlights() - setup_keymaps() - setup_autocmds() + setup_highlights() + setup_keymaps() + setup_autocmds() end --- Enable suggestions function M.enable() - config.enabled = true + config.enabled = true end --- Disable suggestions function M.disable() - config.enabled = false - M.dismiss() + config.enabled = false + M.dismiss() end --- Toggle suggestions function M.toggle() - if config.enabled then - M.disable() - else - M.enable() - end + if config.enabled then + M.disable() + else + M.enable() + end end return M diff --git a/lua/codetyper/features/indexer/analyzer.lua b/lua/codetyper/features/indexer/analyzer.lua index dfc5de5..55b7801 100644 --- a/lua/codetyper/features/indexer/analyzer.lua +++ b/lua/codetyper/features/indexer/analyzer.lua @@ -11,8 +11,8 @@ local scanner = require("codetyper.features.indexer.scanner") --- Language-specific query patterns for Tree-sitter local TS_QUERIES = { - lua = { - functions = [[ + lua = { + functions = [[ (function_declaration name: (identifier) @name) @func (function_definition) @func (local_function name: (identifier) @name) @func @@ -20,67 +20,67 @@ local TS_QUERIES = { (variable_list name: (identifier) @name) (expression_list value: (function_definition) @func)) ]], - exports = [[ + exports = [[ (return_statement (expression_list (table_constructor))) @export ]], - }, - typescript = { - functions = [[ + }, + typescript = { + functions = [[ (function_declaration name: (identifier) @name) @func (method_definition name: (property_identifier) @name) @func (arrow_function) @func (lexical_declaration (variable_declarator name: (identifier) @name value: (arrow_function) @func)) ]], - exports = [[ + exports = [[ (export_statement) @export ]], - imports = [[ + imports = [[ (import_statement) @import ]], - }, - javascript = { - functions = [[ + }, + javascript = { + functions = [[ (function_declaration name: (identifier) @name) @func (method_definition name: (property_identifier) @name) @func (arrow_function) @func ]], - exports = [[ + exports = [[ (export_statement) @export ]], - imports = [[ + imports = [[ (import_statement) @import ]], - }, - python = { - functions = [[ + }, + python = { + functions = [[ (function_definition name: (identifier) @name) @func ]], - classes = [[ + classes = [[ (class_definition name: (identifier) @name) @class ]], - imports = [[ + imports = [[ (import_statement) @import (import_from_statement) @import ]], - }, - go = { - functions = [[ + }, + go = { + functions = [[ (function_declaration name: (identifier) @name) @func (method_declaration name: (field_identifier) @name) @func ]], - imports = [[ + imports = [[ (import_declaration) @import ]], - }, - rust = { - functions = [[ + }, + rust = { + functions = [[ (function_item name: (identifier) @name) @func ]], - imports = [[ + imports = [[ (use_declaration) @import ]], - }, + }, } -- Forward declaration for analyze_tree_generic (defined below) @@ -90,19 +90,19 @@ local analyze_tree_generic ---@param content string ---@return string local function hash_content(content) - local hash = 0 - for i = 1, math.min(#content, 10000) do - hash = (hash * 31 + string.byte(content, i)) % 2147483647 - end - return string.format("%08x", hash) + local hash = 0 + for i = 1, math.min(#content, 10000) do + hash = (hash * 31 + string.byte(content, i)) % 2147483647 + end + return string.format("%08x", hash) end --- Try to get Tree-sitter parser for a language ---@param lang string ---@return boolean local function has_ts_parser(lang) - local ok = pcall(vim.treesitter.language.inspect, lang) - return ok + local ok = pcall(vim.treesitter.language.inspect, lang) + return ok end --- Analyze file using Tree-sitter @@ -111,148 +111,154 @@ end ---@param content string ---@return table|nil local function analyze_with_treesitter(filepath, lang, content) - if not has_ts_parser(lang) then - return nil - end + if not has_ts_parser(lang) then + return nil + end - local result = { - functions = {}, - classes = {}, - exports = {}, - imports = {}, - } + local result = { + functions = {}, + classes = {}, + exports = {}, + imports = {}, + } - -- Create a temporary buffer for parsing - local bufnr = vim.api.nvim_create_buf(false, true) - vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, vim.split(content, "\n")) + -- Create a temporary buffer for parsing + local bufnr = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, vim.split(content, "\n")) - local ok, parser = pcall(vim.treesitter.get_parser, bufnr, lang) - if not ok or not parser then - vim.api.nvim_buf_delete(bufnr, { force = true }) - return nil - end + local ok, parser = pcall(vim.treesitter.get_parser, bufnr, lang) + if not ok or not parser then + vim.api.nvim_buf_delete(bufnr, { force = true }) + return nil + end - local tree = parser:parse()[1] - if not tree then - vim.api.nvim_buf_delete(bufnr, { force = true }) - return nil - end + local tree = parser:parse()[1] + if not tree then + vim.api.nvim_buf_delete(bufnr, { force = true }) + return nil + end - local root = tree:root() - local queries = TS_QUERIES[lang] + local root = tree:root() + local queries = TS_QUERIES[lang] - if not queries then - -- Fallback: walk tree manually for common patterns - result = analyze_tree_generic(root, bufnr) - else - -- Use language-specific queries - if queries.functions then - local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.functions) - if query_ok then - for id, node in query:iter_captures(root, bufnr, 0, -1) do - local capture_name = query.captures[id] - if capture_name == "func" or capture_name == "name" then - local start_row, _, end_row, _ = node:range() - local name = nil + if not queries then + -- Fallback: walk tree manually for common patterns + result = analyze_tree_generic(root, bufnr) + else + -- Use language-specific queries + if queries.functions then + local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.functions) + if query_ok then + for id, node in query:iter_captures(root, bufnr, 0, -1) do + local capture_name = query.captures[id] + if capture_name == "func" or capture_name == "name" then + local start_row, _, end_row, _ = node:range() + local name = nil - -- Try to get name from sibling capture or child - if capture_name == "func" then - local name_node = node:field("name")[1] - if name_node then - name = vim.treesitter.get_node_text(name_node, bufnr) - end - else - name = vim.treesitter.get_node_text(node, bufnr) - end + -- Try to get name from sibling capture or child + if capture_name == "func" then + local name_node = node:field("name")[1] + if name_node then + name = vim.treesitter.get_node_text(name_node, bufnr) + end + else + name = vim.treesitter.get_node_text(node, bufnr) + end - if name and not vim.tbl_contains(vim.tbl_map(function(f) - return f.name - end, result.functions), name) then - table.insert(result.functions, { - name = name, - line = start_row + 1, - end_line = end_row + 1, - params = {}, - }) - end - end - end - end - end + if + name + and not vim.tbl_contains( + vim.tbl_map(function(f) + return f.name + end, result.functions), + name + ) + then + table.insert(result.functions, { + name = name, + line = start_row + 1, + end_line = end_row + 1, + params = {}, + }) + end + end + end + end + end - if queries.classes then - local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.classes) - if query_ok then - for id, node in query:iter_captures(root, bufnr, 0, -1) do - local capture_name = query.captures[id] - if capture_name == "class" then - local start_row, _, end_row, _ = node:range() - local name_node = node:field("name")[1] - local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous" + if queries.classes then + local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.classes) + if query_ok then + for id, node in query:iter_captures(root, bufnr, 0, -1) do + local capture_name = query.captures[id] + if capture_name == "class" then + local start_row, _, end_row, _ = node:range() + local name_node = node:field("name")[1] + local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous" - table.insert(result.classes, { - name = name, - line = start_row + 1, - end_line = end_row + 1, - methods = {}, - }) - end - end - end - end + table.insert(result.classes, { + name = name, + line = start_row + 1, + end_line = end_row + 1, + methods = {}, + }) + end + end + end + end - if queries.exports then - local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.exports) - if query_ok then - for _, node in query:iter_captures(root, bufnr, 0, -1) do - local text = vim.treesitter.get_node_text(node, bufnr) - local start_row, _, _, _ = node:range() + if queries.exports then + local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.exports) + if query_ok then + for _, node in query:iter_captures(root, bufnr, 0, -1) do + local text = vim.treesitter.get_node_text(node, bufnr) + local start_row, _, _, _ = node:range() - -- Extract export names (simplified) - local names = {} - for name in text:gmatch("export%s+[%w_]+%s+([%w_]+)") do - table.insert(names, name) - end - for name in text:gmatch("export%s*{([^}]+)}") do - for n in name:gmatch("([%w_]+)") do - table.insert(names, n) - end - end + -- Extract export names (simplified) + local names = {} + for name in text:gmatch("export%s+[%w_]+%s+([%w_]+)") do + table.insert(names, name) + end + for name in text:gmatch("export%s*{([^}]+)}") do + for n in name:gmatch("([%w_]+)") do + table.insert(names, n) + end + end - for _, name in ipairs(names) do - table.insert(result.exports, { - name = name, - type = "unknown", - line = start_row + 1, - }) - end - end - end - end + for _, name in ipairs(names) do + table.insert(result.exports, { + name = name, + type = "unknown", + line = start_row + 1, + }) + end + end + end + end - if queries.imports then - local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.imports) - if query_ok then - for _, node in query:iter_captures(root, bufnr, 0, -1) do - local text = vim.treesitter.get_node_text(node, bufnr) - local start_row, _, _, _ = node:range() + if queries.imports then + local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.imports) + if query_ok then + for _, node in query:iter_captures(root, bufnr, 0, -1) do + local text = vim.treesitter.get_node_text(node, bufnr) + local start_row, _, _, _ = node:range() - -- Extract import source - local source = text:match('["\']([^"\']+)["\']') - if source then - table.insert(result.imports, { - source = source, - names = {}, - line = start_row + 1, - }) - end - end - end - end - end + -- Extract import source + local source = text:match("[\"']([^\"']+)[\"']") + if source then + table.insert(result.imports, { + source = source, + names = {}, + line = start_row + 1, + }) + end + end + end + end + end - vim.api.nvim_buf_delete(bufnr, { force = true }) - return result + vim.api.nvim_buf_delete(bufnr, { force = true }) + return result end --- Generic tree analysis for unsupported languages @@ -260,57 +266,57 @@ end ---@param bufnr number ---@return table analyze_tree_generic = function(root, bufnr) - local result = { - functions = {}, - classes = {}, - exports = {}, - imports = {}, - } + local result = { + functions = {}, + classes = {}, + exports = {}, + imports = {}, + } - local function visit(node) - local node_type = node:type() + local function visit(node) + local node_type = node:type() - -- Common function patterns - if - node_type:match("function") - or node_type:match("method") - or node_type == "arrow_function" - or node_type == "func_literal" - then - local start_row, _, end_row, _ = node:range() - local name_node = node:field("name")[1] - local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous" + -- Common function patterns + if + node_type:match("function") + or node_type:match("method") + or node_type == "arrow_function" + or node_type == "func_literal" + then + local start_row, _, end_row, _ = node:range() + local name_node = node:field("name")[1] + local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous" - table.insert(result.functions, { - name = name, - line = start_row + 1, - end_line = end_row + 1, - params = {}, - }) - end + table.insert(result.functions, { + name = name, + line = start_row + 1, + end_line = end_row + 1, + params = {}, + }) + end - -- Common class patterns - if node_type:match("class") or node_type == "struct_item" or node_type == "impl_item" then - local start_row, _, end_row, _ = node:range() - local name_node = node:field("name")[1] - local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous" + -- Common class patterns + if node_type:match("class") or node_type == "struct_item" or node_type == "impl_item" then + local start_row, _, end_row, _ = node:range() + local name_node = node:field("name")[1] + local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous" - table.insert(result.classes, { - name = name, - line = start_row + 1, - end_line = end_row + 1, - methods = {}, - }) - end + table.insert(result.classes, { + name = name, + line = start_row + 1, + end_line = end_row + 1, + methods = {}, + }) + end - -- Recurse into children - for child in node:iter_children() do - visit(child) - end - end + -- Recurse into children + for child in node:iter_children() do + visit(child) + end + end - visit(root) - return result + visit(root) + return result end --- Analyze file using pattern matching (fallback) @@ -318,268 +324,268 @@ end ---@param lang string ---@return table local function analyze_with_patterns(content, lang) - local result = { - functions = {}, - classes = {}, - exports = {}, - imports = {}, - } + local result = { + functions = {}, + classes = {}, + exports = {}, + imports = {}, + } - local lines = vim.split(content, "\n") + local lines = vim.split(content, "\n") - -- Language-specific patterns - local patterns = { - lua = { - func_start = "^%s*local?%s*function%s+([%w_%.]+)", - func_assign = "^%s*([%w_%.]+)%s*=%s*function", - module_return = "^return%s+M", - }, - javascript = { - func_start = "^%s*function%s+([%w_]+)", - func_arrow = "^%s*const%s+([%w_]+)%s*=%s*", - class_start = "^%s*class%s+([%w_]+)", - export_line = "^%s*export%s+", - import_line = "^%s*import%s+", - }, - typescript = { - func_start = "^%s*function%s+([%w_]+)", - func_arrow = "^%s*const%s+([%w_]+)%s*=%s*", - class_start = "^%s*class%s+([%w_]+)", - export_line = "^%s*export%s+", - import_line = "^%s*import%s+", - }, - python = { - func_start = "^%s*def%s+([%w_]+)", - class_start = "^%s*class%s+([%w_]+)", - import_line = "^%s*import%s+", - from_import = "^%s*from%s+", - }, - go = { - func_start = "^func%s+([%w_]+)", - method_start = "^func%s+%([^%)]+%)%s+([%w_]+)", - import_line = "^import%s+", - }, - rust = { - func_start = "^%s*pub?%s*fn%s+([%w_]+)", - struct_start = "^%s*pub?%s*struct%s+([%w_]+)", - impl_start = "^%s*impl%s+([%w_<>]+)", - use_line = "^%s*use%s+", - }, - } + -- Language-specific patterns + local patterns = { + lua = { + func_start = "^%s*local?%s*function%s+([%w_%.]+)", + func_assign = "^%s*([%w_%.]+)%s*=%s*function", + module_return = "^return%s+M", + }, + javascript = { + func_start = "^%s*function%s+([%w_]+)", + func_arrow = "^%s*const%s+([%w_]+)%s*=%s*", + class_start = "^%s*class%s+([%w_]+)", + export_line = "^%s*export%s+", + import_line = "^%s*import%s+", + }, + typescript = { + func_start = "^%s*function%s+([%w_]+)", + func_arrow = "^%s*const%s+([%w_]+)%s*=%s*", + class_start = "^%s*class%s+([%w_]+)", + export_line = "^%s*export%s+", + import_line = "^%s*import%s+", + }, + python = { + func_start = "^%s*def%s+([%w_]+)", + class_start = "^%s*class%s+([%w_]+)", + import_line = "^%s*import%s+", + from_import = "^%s*from%s+", + }, + go = { + func_start = "^func%s+([%w_]+)", + method_start = "^func%s+%([^%)]+%)%s+([%w_]+)", + import_line = "^import%s+", + }, + rust = { + func_start = "^%s*pub?%s*fn%s+([%w_]+)", + struct_start = "^%s*pub?%s*struct%s+([%w_]+)", + impl_start = "^%s*impl%s+([%w_<>]+)", + use_line = "^%s*use%s+", + }, + } - local lang_patterns = patterns[lang] or patterns.javascript + local lang_patterns = patterns[lang] or patterns.javascript - for i, line in ipairs(lines) do - -- Functions - if lang_patterns.func_start then - local name = line:match(lang_patterns.func_start) - if name then - table.insert(result.functions, { - name = name, - line = i, - end_line = i, - params = {}, - }) - end - end + for i, line in ipairs(lines) do + -- Functions + if lang_patterns.func_start then + local name = line:match(lang_patterns.func_start) + if name then + table.insert(result.functions, { + name = name, + line = i, + end_line = i, + params = {}, + }) + end + end - if lang_patterns.func_arrow then - local name = line:match(lang_patterns.func_arrow) - if name and line:match("=>") then - table.insert(result.functions, { - name = name, - line = i, - end_line = i, - params = {}, - }) - end - end + if lang_patterns.func_arrow then + local name = line:match(lang_patterns.func_arrow) + if name and line:match("=>") then + table.insert(result.functions, { + name = name, + line = i, + end_line = i, + params = {}, + }) + end + end - if lang_patterns.func_assign then - local name = line:match(lang_patterns.func_assign) - if name then - table.insert(result.functions, { - name = name, - line = i, - end_line = i, - params = {}, - }) - end - end + if lang_patterns.func_assign then + local name = line:match(lang_patterns.func_assign) + if name then + table.insert(result.functions, { + name = name, + line = i, + end_line = i, + params = {}, + }) + end + end - if lang_patterns.method_start then - local name = line:match(lang_patterns.method_start) - if name then - table.insert(result.functions, { - name = name, - line = i, - end_line = i, - params = {}, - }) - end - end + if lang_patterns.method_start then + local name = line:match(lang_patterns.method_start) + if name then + table.insert(result.functions, { + name = name, + line = i, + end_line = i, + params = {}, + }) + end + end - -- Classes - if lang_patterns.class_start then - local name = line:match(lang_patterns.class_start) - if name then - table.insert(result.classes, { - name = name, - line = i, - end_line = i, - methods = {}, - }) - end - end + -- Classes + if lang_patterns.class_start then + local name = line:match(lang_patterns.class_start) + if name then + table.insert(result.classes, { + name = name, + line = i, + end_line = i, + methods = {}, + }) + end + end - if lang_patterns.struct_start then - local name = line:match(lang_patterns.struct_start) - if name then - table.insert(result.classes, { - name = name, - line = i, - end_line = i, - methods = {}, - }) - end - end + if lang_patterns.struct_start then + local name = line:match(lang_patterns.struct_start) + if name then + table.insert(result.classes, { + name = name, + line = i, + end_line = i, + methods = {}, + }) + end + end - -- Exports - if lang_patterns.export_line and line:match(lang_patterns.export_line) then - local name = line:match("export%s+[%w_]+%s+([%w_]+)") - or line:match("export%s+default%s+([%w_]+)") - or line:match("export%s+{%s*([%w_]+)") - if name then - table.insert(result.exports, { - name = name, - type = "unknown", - line = i, - }) - end - end + -- Exports + if lang_patterns.export_line and line:match(lang_patterns.export_line) then + local name = line:match("export%s+[%w_]+%s+([%w_]+)") + or line:match("export%s+default%s+([%w_]+)") + or line:match("export%s+{%s*([%w_]+)") + if name then + table.insert(result.exports, { + name = name, + type = "unknown", + line = i, + }) + end + end - -- Imports - if lang_patterns.import_line and line:match(lang_patterns.import_line) then - local source = line:match('["\']([^"\']+)["\']') - if source then - table.insert(result.imports, { - source = source, - names = {}, - line = i, - }) - end - end + -- Imports + if lang_patterns.import_line and line:match(lang_patterns.import_line) then + local source = line:match("[\"']([^\"']+)[\"']") + if source then + table.insert(result.imports, { + source = source, + names = {}, + line = i, + }) + end + end - if lang_patterns.from_import and line:match(lang_patterns.from_import) then - local source = line:match("from%s+([%w_%.]+)") - if source then - table.insert(result.imports, { - source = source, - names = {}, - line = i, - }) - end - end + if lang_patterns.from_import and line:match(lang_patterns.from_import) then + local source = line:match("from%s+([%w_%.]+)") + if source then + table.insert(result.imports, { + source = source, + names = {}, + line = i, + }) + end + end - if lang_patterns.use_line and line:match(lang_patterns.use_line) then - local source = line:match("use%s+([%w_:]+)") - if source then - table.insert(result.imports, { - source = source, - names = {}, - line = i, - }) - end - end - end + if lang_patterns.use_line and line:match(lang_patterns.use_line) then + local source = line:match("use%s+([%w_:]+)") + if source then + table.insert(result.imports, { + source = source, + names = {}, + line = i, + }) + end + end + end - -- For Lua, infer exports from module table - if lang == "lua" then - for _, func in ipairs(result.functions) do - if func.name:match("^M%.") then - local name = func.name:gsub("^M%.", "") - table.insert(result.exports, { - name = name, - type = "function", - line = func.line, - }) - end - end - end + -- For Lua, infer exports from module table + if lang == "lua" then + for _, func in ipairs(result.functions) do + if func.name:match("^M%.") then + local name = func.name:gsub("^M%.", "") + table.insert(result.exports, { + name = name, + type = "function", + line = func.line, + }) + end + end + end - return result + return result end --- Analyze a single file ---@param filepath string Full path to file ---@return FileIndex|nil function M.analyze_file(filepath) - local content = utils.read_file(filepath) - if not content then - return nil - end + local content = utils.read_file(filepath) + if not content then + return nil + end - local lang = scanner.get_language(filepath) + local lang = scanner.get_language(filepath) - -- Map to Tree-sitter language names - local ts_lang_map = { - typescript = "typescript", - typescriptreact = "tsx", - javascript = "javascript", - javascriptreact = "javascript", - python = "python", - go = "go", - rust = "rust", - lua = "lua", - } + -- Map to Tree-sitter language names + local ts_lang_map = { + typescript = "typescript", + typescriptreact = "tsx", + javascript = "javascript", + javascriptreact = "javascript", + python = "python", + go = "go", + rust = "rust", + lua = "lua", + } - local ts_lang = ts_lang_map[lang] or lang + local ts_lang = ts_lang_map[lang] or lang - -- Try Tree-sitter first - local analysis = analyze_with_treesitter(filepath, ts_lang, content) + -- Try Tree-sitter first + local analysis = analyze_with_treesitter(filepath, ts_lang, content) - -- Fallback to pattern matching - if not analysis then - analysis = analyze_with_patterns(content, lang) - end + -- Fallback to pattern matching + if not analysis then + analysis = analyze_with_patterns(content, lang) + end - return { - path = filepath, - language = lang, - hash = hash_content(content), - exports = analysis.exports, - imports = analysis.imports, - functions = analysis.functions, - classes = analysis.classes, - last_indexed = os.time(), - } + return { + path = filepath, + language = lang, + hash = hash_content(content), + exports = analysis.exports, + imports = analysis.imports, + functions = analysis.functions, + classes = analysis.classes, + last_indexed = os.time(), + } end --- Extract exports from a buffer ---@param bufnr number ---@return Export[] function M.extract_exports(bufnr) - local filepath = vim.api.nvim_buf_get_name(bufnr) - local analysis = M.analyze_file(filepath) - return analysis and analysis.exports or {} + local filepath = vim.api.nvim_buf_get_name(bufnr) + local analysis = M.analyze_file(filepath) + return analysis and analysis.exports or {} end --- Extract functions from a buffer ---@param bufnr number ---@return FunctionInfo[] function M.extract_functions(bufnr) - local filepath = vim.api.nvim_buf_get_name(bufnr) - local analysis = M.analyze_file(filepath) - return analysis and analysis.functions or {} + local filepath = vim.api.nvim_buf_get_name(bufnr) + local analysis = M.analyze_file(filepath) + return analysis and analysis.functions or {} end --- Extract imports from a buffer ---@param bufnr number ---@return Import[] function M.extract_imports(bufnr) - local filepath = vim.api.nvim_buf_get_name(bufnr) - local analysis = M.analyze_file(filepath) - return analysis and analysis.imports or {} + local filepath = vim.api.nvim_buf_get_name(bufnr) + local analysis = M.analyze_file(filepath) + return analysis and analysis.imports or {} end return M diff --git a/lua/codetyper/features/indexer/init.lua b/lua/codetyper/features/indexer/init.lua index 8570241..d216120 100644 --- a/lua/codetyper/features/indexer/init.lua +++ b/lua/codetyper/features/indexer/init.lua @@ -20,17 +20,17 @@ local INDEX_DEBOUNCE_MS = 500 --- Default indexer configuration local default_config = { - enabled = true, - auto_index = true, - index_on_open = false, - max_file_size = 100000, - excluded_dirs = { "node_modules", "dist", "build", ".git", ".codetyper", "__pycache__", "vendor", "target" }, - index_extensions = { "lua", "ts", "tsx", "js", "jsx", "py", "go", "rs", "rb", "java", "c", "cpp", "h", "hpp" }, - memory = { - enabled = true, - max_memories = 1000, - prune_threshold = 0.1, - }, + enabled = true, + auto_index = true, + index_on_open = false, + max_file_size = 100000, + excluded_dirs = { "node_modules", "dist", "build", ".git", ".codetyper", "__pycache__", "vendor", "target" }, + index_extensions = { "lua", "ts", "tsx", "js", "jsx", "py", "go", "rs", "rb", "java", "c", "cpp", "h", "hpp" }, + memory = { + enabled = true, + max_memories = 1000, + prune_threshold = 0.1, + }, } --- Current configuration @@ -90,183 +90,183 @@ local index_cache = {} --- Get the index file path ---@return string|nil local function get_index_path() - local root = utils.get_project_root() - if not root then - return nil - end - return root .. "/.codetyper/" .. INDEX_FILE + local root = utils.get_project_root() + if not root then + return nil + end + return root .. "/.codetyper/" .. INDEX_FILE end --- Create empty index structure ---@return ProjectIndex local function create_empty_index() - local root = utils.get_project_root() - return { - version = INDEX_VERSION, - project_root = root or "", - project_name = root and vim.fn.fnamemodify(root, ":t") or "", - project_type = "unknown", - dependencies = {}, - dev_dependencies = {}, - files = {}, - symbols = {}, - last_indexed = os.time(), - stats = { - files = 0, - functions = 0, - classes = 0, - exports = 0, - }, - } + local root = utils.get_project_root() + return { + version = INDEX_VERSION, + project_root = root or "", + project_name = root and vim.fn.fnamemodify(root, ":t") or "", + project_type = "unknown", + dependencies = {}, + dev_dependencies = {}, + files = {}, + symbols = {}, + last_indexed = os.time(), + stats = { + files = 0, + functions = 0, + classes = 0, + exports = 0, + }, + } end --- Load index from disk ---@return ProjectIndex|nil function M.load_index() - local root = utils.get_project_root() - if not root then - return nil - end + local root = utils.get_project_root() + if not root then + return nil + end - -- Check cache first - if index_cache[root] then - return index_cache[root] - end + -- Check cache first + if index_cache[root] then + return index_cache[root] + end - local path = get_index_path() - if not path then - return nil - end + local path = get_index_path() + if not path then + return nil + end - local content = utils.read_file(path) - if not content then - return nil - end + local content = utils.read_file(path) + if not content then + return nil + end - local ok, index = pcall(vim.json.decode, content) - if not ok or not index then - return nil - end + local ok, index = pcall(vim.json.decode, content) + if not ok or not index then + return nil + end - -- Validate version - if index.version ~= INDEX_VERSION then - -- Index needs migration or rebuild - return nil - end + -- Validate version + if index.version ~= INDEX_VERSION then + -- Index needs migration or rebuild + return nil + end - -- Cache it - index_cache[root] = index - return index + -- Cache it + index_cache[root] = index + return index end --- Save index to disk ---@param index ProjectIndex ---@return boolean function M.save_index(index) - local root = utils.get_project_root() - if not root then - return false - end + local root = utils.get_project_root() + if not root then + return false + end - -- Ensure .codetyper directory exists - local coder_dir = root .. "/.codetyper" - utils.ensure_dir(coder_dir) + -- Ensure .codetyper directory exists + local coder_dir = root .. "/.codetyper" + utils.ensure_dir(coder_dir) - local path = get_index_path() - if not path then - return false - end + local path = get_index_path() + if not path then + return false + end - local ok, encoded = pcall(vim.json.encode, index) - if not ok then - return false - end + local ok, encoded = pcall(vim.json.encode, index) + if not ok then + return false + end - local success = utils.write_file(path, encoded) - if success then - -- Update cache - index_cache[root] = index - end - return success + local success = utils.write_file(path, encoded) + if success then + -- Update cache + index_cache[root] = index + end + return success end --- Index the entire project ---@param callback? fun(index: ProjectIndex) ---@return ProjectIndex|nil function M.index_project(callback) - local scanner = require("codetyper.features.indexer.scanner") - local analyzer = require("codetyper.features.indexer.analyzer") + local scanner = require("codetyper.features.indexer.scanner") + local analyzer = require("codetyper.features.indexer.analyzer") - local index = create_empty_index() - local root = utils.get_project_root() + local index = create_empty_index() + local root = utils.get_project_root() - if not root then - if callback then - callback(index) - end - return index - end + if not root then + if callback then + callback(index) + end + return index + end - -- Detect project type and parse dependencies - index.project_type = scanner.detect_project_type(root) - local deps = scanner.parse_dependencies(root, index.project_type) - index.dependencies = deps.dependencies or {} - index.dev_dependencies = deps.dev_dependencies or {} + -- Detect project type and parse dependencies + index.project_type = scanner.detect_project_type(root) + local deps = scanner.parse_dependencies(root, index.project_type) + index.dependencies = deps.dependencies or {} + index.dev_dependencies = deps.dev_dependencies or {} - -- Get all indexable files - local files = scanner.get_indexable_files(root, config) + -- Get all indexable files + local files = scanner.get_indexable_files(root, config) - -- Index each file - local total_functions = 0 - local total_classes = 0 - local total_exports = 0 + -- Index each file + local total_functions = 0 + local total_classes = 0 + local total_exports = 0 - for _, filepath in ipairs(files) do - local relative_path = filepath:gsub("^" .. vim.pesc(root) .. "/", "") - local file_index = analyzer.analyze_file(filepath) + for _, filepath in ipairs(files) do + local relative_path = filepath:gsub("^" .. vim.pesc(root) .. "/", "") + local file_index = analyzer.analyze_file(filepath) - if file_index then - file_index.path = relative_path - index.files[relative_path] = file_index + if file_index then + file_index.path = relative_path + index.files[relative_path] = file_index - -- Update symbol index - for _, exp in ipairs(file_index.exports or {}) do - if not index.symbols[exp.name] then - index.symbols[exp.name] = {} - end - table.insert(index.symbols[exp.name], relative_path) - total_exports = total_exports + 1 - end + -- Update symbol index + for _, exp in ipairs(file_index.exports or {}) do + if not index.symbols[exp.name] then + index.symbols[exp.name] = {} + end + table.insert(index.symbols[exp.name], relative_path) + total_exports = total_exports + 1 + end - total_functions = total_functions + #(file_index.functions or {}) - total_classes = total_classes + #(file_index.classes or {}) - end - end + total_functions = total_functions + #(file_index.functions or {}) + total_classes = total_classes + #(file_index.classes or {}) + end + end - -- Update stats - index.stats = { - files = #files, - functions = total_functions, - classes = total_classes, - exports = total_exports, - } - index.last_indexed = os.time() + -- Update stats + index.stats = { + files = #files, + functions = total_functions, + classes = total_classes, + exports = total_exports, + } + index.last_indexed = os.time() - -- Save to disk - M.save_index(index) + -- Save to disk + M.save_index(index) - -- Store memories - local memory = require("codetyper.features.indexer.memory") - memory.store_index_summary(index) + -- Store memories + local memory = require("codetyper.features.indexer.memory") + memory.store_index_summary(index) - -- Sync project summary to brain - M.sync_project_to_brain(index, files, root) + -- Sync project summary to brain + M.sync_project_to_brain(index, files, root) - if callback then - callback(index) - end + if callback then + callback(index) + end - return index + return index end --- Sync project index to brain @@ -274,331 +274,331 @@ end ---@param files string[] List of file paths ---@param root string Project root function M.sync_project_to_brain(index, files, root) - local ok_brain, brain = pcall(require, "codetyper.brain") - if not ok_brain or not brain.is_initialized or not brain.is_initialized() then - return - end + local ok_brain, brain = pcall(require, "codetyper.brain") + if not ok_brain or not brain.is_initialized or not brain.is_initialized() then + return + end - -- Store project-level pattern - brain.learn({ - type = "pattern", - file = root, - content = { - summary = "Project: " - .. index.project_name - .. " (" - .. index.project_type - .. ") - " - .. index.stats.files - .. " files", - detail = string.format( - "%d functions, %d classes, %d exports", - index.stats.functions, - index.stats.classes, - index.stats.exports - ), - }, - context = { - file = root, - project_type = index.project_type, - dependencies = index.dependencies, - }, - }) + -- Store project-level pattern + brain.learn({ + type = "pattern", + file = root, + content = { + summary = "Project: " + .. index.project_name + .. " (" + .. index.project_type + .. ") - " + .. index.stats.files + .. " files", + detail = string.format( + "%d functions, %d classes, %d exports", + index.stats.functions, + index.stats.classes, + index.stats.exports + ), + }, + context = { + file = root, + project_type = index.project_type, + dependencies = index.dependencies, + }, + }) - -- Store key file patterns (files with most functions/classes) - local key_files = {} - for path, file_index in pairs(index.files) do - local score = #(file_index.functions or {}) + (#(file_index.classes or {}) * 2) - if score >= 3 then - table.insert(key_files, { path = path, index = file_index, score = score }) - end - end + -- Store key file patterns (files with most functions/classes) + local key_files = {} + for path, file_index in pairs(index.files) do + local score = #(file_index.functions or {}) + (#(file_index.classes or {}) * 2) + if score >= 3 then + table.insert(key_files, { path = path, index = file_index, score = score }) + end + end - table.sort(key_files, function(a, b) - return a.score > b.score - end) + table.sort(key_files, function(a, b) + return a.score > b.score + end) - -- Store top 20 key files in brain - for i, kf in ipairs(key_files) do - if i > 20 then - break - end - M.sync_to_brain(root .. "/" .. kf.path, kf.index) - end + -- Store top 20 key files in brain + for i, kf in ipairs(key_files) do + if i > 20 then + break + end + M.sync_to_brain(root .. "/" .. kf.path, kf.index) + end end --- Index a single file (incremental update) ---@param filepath string ---@return FileIndex|nil function M.index_file(filepath) - local analyzer = require("codetyper.features.indexer.analyzer") - local memory = require("codetyper.features.indexer.memory") - local root = utils.get_project_root() + local analyzer = require("codetyper.features.indexer.analyzer") + local memory = require("codetyper.features.indexer.memory") + local root = utils.get_project_root() - if not root then - return nil - end + if not root then + return nil + end - -- Load existing index - local index = M.load_index() or create_empty_index() + -- Load existing index + local index = M.load_index() or create_empty_index() - -- Analyze file - local file_index = analyzer.analyze_file(filepath) - if not file_index then - return nil - end + -- Analyze file + local file_index = analyzer.analyze_file(filepath) + if not file_index then + return nil + end - local relative_path = filepath:gsub("^" .. vim.pesc(root) .. "/", "") - file_index.path = relative_path + local relative_path = filepath:gsub("^" .. vim.pesc(root) .. "/", "") + file_index.path = relative_path - -- Remove old symbol references for this file - for symbol, paths in pairs(index.symbols) do - for i = #paths, 1, -1 do - if paths[i] == relative_path then - table.remove(paths, i) - end - end - if #paths == 0 then - index.symbols[symbol] = nil - end - end + -- Remove old symbol references for this file + for symbol, paths in pairs(index.symbols) do + for i = #paths, 1, -1 do + if paths[i] == relative_path then + table.remove(paths, i) + end + end + if #paths == 0 then + index.symbols[symbol] = nil + end + end - -- Add new file index - index.files[relative_path] = file_index + -- Add new file index + index.files[relative_path] = file_index - -- Update symbol index - for _, exp in ipairs(file_index.exports or {}) do - if not index.symbols[exp.name] then - index.symbols[exp.name] = {} - end - table.insert(index.symbols[exp.name], relative_path) - end + -- Update symbol index + for _, exp in ipairs(file_index.exports or {}) do + if not index.symbols[exp.name] then + index.symbols[exp.name] = {} + end + table.insert(index.symbols[exp.name], relative_path) + end - -- Recalculate stats - local total_functions = 0 - local total_classes = 0 - local total_exports = 0 - local file_count = 0 + -- Recalculate stats + local total_functions = 0 + local total_classes = 0 + local total_exports = 0 + local file_count = 0 - for _, f in pairs(index.files) do - file_count = file_count + 1 - total_functions = total_functions + #(f.functions or {}) - total_classes = total_classes + #(f.classes or {}) - total_exports = total_exports + #(f.exports or {}) - end + for _, f in pairs(index.files) do + file_count = file_count + 1 + total_functions = total_functions + #(f.functions or {}) + total_classes = total_classes + #(f.classes or {}) + total_exports = total_exports + #(f.exports or {}) + end - index.stats = { - files = file_count, - functions = total_functions, - classes = total_classes, - exports = total_exports, - } - index.last_indexed = os.time() + index.stats = { + files = file_count, + functions = total_functions, + classes = total_classes, + exports = total_exports, + } + index.last_indexed = os.time() - -- Save to disk - M.save_index(index) + -- Save to disk + M.save_index(index) - -- Store file memory - memory.store_file_memory(relative_path, file_index) + -- Store file memory + memory.store_file_memory(relative_path, file_index) - -- Sync to brain if available - M.sync_to_brain(filepath, file_index) + -- Sync to brain if available + M.sync_to_brain(filepath, file_index) - return file_index + return file_index end --- Sync file analysis to brain system ---@param filepath string Full file path ---@param file_index FileIndex File analysis function M.sync_to_brain(filepath, file_index) - local ok_brain, brain = pcall(require, "codetyper.brain") - if not ok_brain or not brain.is_initialized or not brain.is_initialized() then - return - end + local ok_brain, brain = pcall(require, "codetyper.brain") + if not ok_brain or not brain.is_initialized or not brain.is_initialized() then + return + end - -- Only store if file has meaningful content - local funcs = file_index.functions or {} - local classes = file_index.classes or {} - if #funcs == 0 and #classes == 0 then - return - end + -- Only store if file has meaningful content + local funcs = file_index.functions or {} + local classes = file_index.classes or {} + if #funcs == 0 and #classes == 0 then + return + end - -- Build summary - local parts = {} - if #funcs > 0 then - local func_names = {} - for i, f in ipairs(funcs) do - if i <= 5 then - table.insert(func_names, f.name) - end - end - table.insert(parts, "functions: " .. table.concat(func_names, ", ")) - if #funcs > 5 then - table.insert(parts, "(+" .. (#funcs - 5) .. " more)") - end - end - if #classes > 0 then - local class_names = {} - for _, c in ipairs(classes) do - table.insert(class_names, c.name) - end - table.insert(parts, "classes: " .. table.concat(class_names, ", ")) - end + -- Build summary + local parts = {} + if #funcs > 0 then + local func_names = {} + for i, f in ipairs(funcs) do + if i <= 5 then + table.insert(func_names, f.name) + end + end + table.insert(parts, "functions: " .. table.concat(func_names, ", ")) + if #funcs > 5 then + table.insert(parts, "(+" .. (#funcs - 5) .. " more)") + end + end + if #classes > 0 then + local class_names = {} + for _, c in ipairs(classes) do + table.insert(class_names, c.name) + end + table.insert(parts, "classes: " .. table.concat(class_names, ", ")) + end - local filename = vim.fn.fnamemodify(filepath, ":t") - local summary = filename .. " - " .. table.concat(parts, "; ") + local filename = vim.fn.fnamemodify(filepath, ":t") + local summary = filename .. " - " .. table.concat(parts, "; ") - -- Learn this pattern in brain - brain.learn({ - type = "pattern", - file = filepath, - content = { - summary = summary, - detail = #funcs .. " functions, " .. #classes .. " classes", - }, - context = { - file = file_index.path or filepath, - language = file_index.language, - functions = funcs, - classes = classes, - exports = file_index.exports, - imports = file_index.imports, - }, - }) + -- Learn this pattern in brain + brain.learn({ + type = "pattern", + file = filepath, + content = { + summary = summary, + detail = #funcs .. " functions, " .. #classes .. " classes", + }, + context = { + file = file_index.path or filepath, + language = file_index.language, + functions = funcs, + classes = classes, + exports = file_index.exports, + imports = file_index.imports, + }, + }) end --- Schedule file indexing with debounce ---@param filepath string function M.schedule_index_file(filepath) - if not config.enabled or not config.auto_index then - return - end + if not config.enabled or not config.auto_index then + return + end - -- Check if file should be indexed - local scanner = require("codetyper.features.indexer.scanner") - if not scanner.should_index(filepath, config) then - return - end + -- Check if file should be indexed + local scanner = require("codetyper.features.indexer.scanner") + if not scanner.should_index(filepath, config) then + return + end - -- Cancel existing timer - if index_timer then - index_timer:stop() - end + -- Cancel existing timer + if index_timer then + index_timer:stop() + end - -- Schedule new index - index_timer = vim.defer_fn(function() - M.index_file(filepath) - index_timer = nil - end, INDEX_DEBOUNCE_MS) + -- Schedule new index + index_timer = vim.defer_fn(function() + M.index_file(filepath) + index_timer = nil + end, INDEX_DEBOUNCE_MS) end --- Get relevant context for a prompt ---@param opts {file: string, intent: table|nil, prompt: string, scope: string|nil} ---@return table Context information function M.get_context_for(opts) - local memory = require("codetyper.features.indexer.memory") - local index = M.load_index() + local memory = require("codetyper.features.indexer.memory") + local index = M.load_index() - local context = { - project_type = "unknown", - dependencies = {}, - relevant_files = {}, - relevant_symbols = {}, - patterns = {}, - } + local context = { + project_type = "unknown", + dependencies = {}, + relevant_files = {}, + relevant_symbols = {}, + patterns = {}, + } - if not index then - return context - end + if not index then + return context + end - context.project_type = index.project_type - context.dependencies = index.dependencies + context.project_type = index.project_type + context.dependencies = index.dependencies - -- Find relevant symbols from prompt - local words = {} - for word in opts.prompt:gmatch("%w+") do - if #word > 2 then - words[word:lower()] = true - end - end + -- Find relevant symbols from prompt + local words = {} + for word in opts.prompt:gmatch("%w+") do + if #word > 2 then + words[word:lower()] = true + end + end - -- Match symbols - for symbol, files in pairs(index.symbols) do - if words[symbol:lower()] then - context.relevant_symbols[symbol] = files - end - end + -- Match symbols + for symbol, files in pairs(index.symbols) do + if words[symbol:lower()] then + context.relevant_symbols[symbol] = files + end + end - -- Get file context if available - if opts.file then - local root = utils.get_project_root() - if root then - local relative_path = opts.file:gsub("^" .. vim.pesc(root) .. "/", "") - local file_index = index.files[relative_path] - if file_index then - context.current_file = file_index - end - end - end + -- Get file context if available + if opts.file then + local root = utils.get_project_root() + if root then + local relative_path = opts.file:gsub("^" .. vim.pesc(root) .. "/", "") + local file_index = index.files[relative_path] + if file_index then + context.current_file = file_index + end + end + end - -- Get relevant memories - context.patterns = memory.get_relevant(opts.prompt, 5) + -- Get relevant memories + context.patterns = memory.get_relevant(opts.prompt, 5) - return context + return context end --- Get index status ---@return table Status information function M.get_status() - local index = M.load_index() - if not index then - return { - indexed = false, - stats = nil, - last_indexed = nil, - } - end + local index = M.load_index() + if not index then + return { + indexed = false, + stats = nil, + last_indexed = nil, + } + end - return { - indexed = true, - stats = index.stats, - last_indexed = index.last_indexed, - project_type = index.project_type, - } + return { + indexed = true, + stats = index.stats, + last_indexed = index.last_indexed, + project_type = index.project_type, + } end --- Clear the project index function M.clear() - local root = utils.get_project_root() - if root then - index_cache[root] = nil - end + local root = utils.get_project_root() + if root then + index_cache[root] = nil + end - local path = get_index_path() - if path and utils.file_exists(path) then - os.remove(path) - end + local path = get_index_path() + if path and utils.file_exists(path) then + os.remove(path) + end end --- Setup the indexer with configuration ---@param opts? table Configuration options function M.setup(opts) - if opts then - config = vim.tbl_deep_extend("force", config, opts) - end + if opts then + config = vim.tbl_deep_extend("force", config, opts) + end - -- Index on startup if configured - if config.index_on_open then - vim.defer_fn(function() - M.index_project() - end, 1000) - end + -- Index on startup if configured + if config.index_on_open then + vim.defer_fn(function() + M.index_project() + end, 1000) + end end --- Get current configuration ---@return table function M.get_config() - return vim.deepcopy(config) + return vim.deepcopy(config) end return M diff --git a/lua/codetyper/features/indexer/memory.lua b/lua/codetyper/features/indexer/memory.lua index bc65f23..4ac90c2 100644 --- a/lua/codetyper/features/indexer/memory.lua +++ b/lua/codetyper/features/indexer/memory.lua @@ -20,9 +20,9 @@ local SYMBOLS_FILE = "symbols.json" --- In-memory cache local cache = { - patterns = nil, - conventions = nil, - symbols = nil, + patterns = nil, + conventions = nil, + symbols = nil, } ---@class Memory @@ -38,72 +38,72 @@ local cache = { --- Get the memories base directory ---@return string|nil local function get_memories_dir() - local root = utils.get_project_root() - if not root then - return nil - end - return root .. "/.codetyper/" .. MEMORIES_DIR + local root = utils.get_project_root() + if not root then + return nil + end + return root .. "/.codetyper/" .. MEMORIES_DIR end --- Get the sessions directory ---@return string|nil local function get_sessions_dir() - local root = utils.get_project_root() - if not root then - return nil - end - return root .. "/.codetyper/" .. SESSIONS_DIR + local root = utils.get_project_root() + if not root then + return nil + end + return root .. "/.codetyper/" .. SESSIONS_DIR end --- Ensure memories directory exists ---@return boolean local function ensure_memories_dir() - local dir = get_memories_dir() - if not dir then - return false - end - utils.ensure_dir(dir) - utils.ensure_dir(dir .. "/" .. FILES_DIR) - return true + local dir = get_memories_dir() + if not dir then + return false + end + utils.ensure_dir(dir) + utils.ensure_dir(dir .. "/" .. FILES_DIR) + return true end --- Ensure sessions directory exists ---@return boolean local function ensure_sessions_dir() - local dir = get_sessions_dir() - if not dir then - return false - end - return utils.ensure_dir(dir) + local dir = get_sessions_dir() + if not dir then + return false + end + return utils.ensure_dir(dir) end --- Generate a unique ID ---@return string local function generate_id() - return string.format("mem_%d_%s", os.time(), string.sub(tostring(math.random()), 3, 8)) + return string.format("mem_%d_%s", os.time(), string.sub(tostring(math.random()), 3, 8)) end --- Load a memory file ---@param filename string ---@return table local function load_memory_file(filename) - local dir = get_memories_dir() - if not dir then - return {} - end + local dir = get_memories_dir() + if not dir then + return {} + end - local path = dir .. "/" .. filename - local content = utils.read_file(path) - if not content then - return {} - end + local path = dir .. "/" .. filename + local content = utils.read_file(path) + if not content then + return {} + end - local ok, data = pcall(vim.json.decode, content) - if not ok or not data then - return {} - end + local ok, data = pcall(vim.json.decode, content) + if not ok or not data then + return {} + end - return data + return data end --- Save a memory file @@ -111,91 +111,91 @@ end ---@param data table ---@return boolean local function save_memory_file(filename, data) - if not ensure_memories_dir() then - return false - end + if not ensure_memories_dir() then + return false + end - local dir = get_memories_dir() - if not dir then - return false - end + local dir = get_memories_dir() + if not dir then + return false + end - local path = dir .. "/" .. filename - local ok, encoded = pcall(vim.json.encode, data) - if not ok then - return false - end + local path = dir .. "/" .. filename + local ok, encoded = pcall(vim.json.encode, data) + if not ok then + return false + end - return utils.write_file(path, encoded) + return utils.write_file(path, encoded) end --- Hash a file path for storage ---@param filepath string ---@return string local function hash_path(filepath) - local hash = 0 - for i = 1, #filepath do - hash = (hash * 31 + string.byte(filepath, i)) % 2147483647 - end - return string.format("%08x", hash) + local hash = 0 + for i = 1, #filepath do + hash = (hash * 31 + string.byte(filepath, i)) % 2147483647 + end + return string.format("%08x", hash) end --- Load patterns from cache or disk ---@return table function M.load_patterns() - if cache.patterns then - return cache.patterns - end - cache.patterns = load_memory_file(PATTERNS_FILE) - return cache.patterns + if cache.patterns then + return cache.patterns + end + cache.patterns = load_memory_file(PATTERNS_FILE) + return cache.patterns end --- Load conventions from cache or disk ---@return table function M.load_conventions() - if cache.conventions then - return cache.conventions - end - cache.conventions = load_memory_file(CONVENTIONS_FILE) - return cache.conventions + if cache.conventions then + return cache.conventions + end + cache.conventions = load_memory_file(CONVENTIONS_FILE) + return cache.conventions end --- Load symbols from cache or disk ---@return table function M.load_symbols() - if cache.symbols then - return cache.symbols - end - cache.symbols = load_memory_file(SYMBOLS_FILE) - return cache.symbols + if cache.symbols then + return cache.symbols + end + cache.symbols = load_memory_file(SYMBOLS_FILE) + return cache.symbols end --- Store a new memory ---@param memory Memory ---@return boolean function M.store_memory(memory) - memory.id = memory.id or generate_id() - memory.created_at = memory.created_at or os.time() - memory.updated_at = os.time() - memory.used_count = memory.used_count or 0 - memory.weight = memory.weight or 0.5 + memory.id = memory.id or generate_id() + memory.created_at = memory.created_at or os.time() + memory.updated_at = os.time() + memory.used_count = memory.used_count or 0 + memory.weight = memory.weight or 0.5 - local filename - if memory.type == "pattern" then - filename = PATTERNS_FILE - cache.patterns = nil - elseif memory.type == "convention" then - filename = CONVENTIONS_FILE - cache.conventions = nil - else - filename = PATTERNS_FILE - cache.patterns = nil - end + local filename + if memory.type == "pattern" then + filename = PATTERNS_FILE + cache.patterns = nil + elseif memory.type == "convention" then + filename = CONVENTIONS_FILE + cache.conventions = nil + else + filename = PATTERNS_FILE + cache.patterns = nil + end - local data = load_memory_file(filename) - data[memory.id] = memory + local data = load_memory_file(filename) + data[memory.id] = memory - return save_memory_file(filename, data) + return save_memory_file(filename, data) end --- Store file-specific memory @@ -203,145 +203,145 @@ end ---@param file_index table FileIndex data ---@return boolean function M.store_file_memory(relative_path, file_index) - if not ensure_memories_dir() then - return false - end + if not ensure_memories_dir() then + return false + end - local dir = get_memories_dir() - if not dir then - return false - end + local dir = get_memories_dir() + if not dir then + return false + end - local hash = hash_path(relative_path) - local path = dir .. "/" .. FILES_DIR .. "/" .. hash .. ".json" + local hash = hash_path(relative_path) + local path = dir .. "/" .. FILES_DIR .. "/" .. hash .. ".json" - local data = { - path = relative_path, - indexed_at = os.time(), - functions = file_index.functions or {}, - classes = file_index.classes or {}, - exports = file_index.exports or {}, - imports = file_index.imports or {}, - } + local data = { + path = relative_path, + indexed_at = os.time(), + functions = file_index.functions or {}, + classes = file_index.classes or {}, + exports = file_index.exports or {}, + imports = file_index.imports or {}, + } - local ok, encoded = pcall(vim.json.encode, data) - if not ok then - return false - end + local ok, encoded = pcall(vim.json.encode, data) + if not ok then + return false + end - return utils.write_file(path, encoded) + return utils.write_file(path, encoded) end --- Load file-specific memory ---@param relative_path string ---@return table|nil function M.load_file_memory(relative_path) - local dir = get_memories_dir() - if not dir then - return nil - end + local dir = get_memories_dir() + if not dir then + return nil + end - local hash = hash_path(relative_path) - local path = dir .. "/" .. FILES_DIR .. "/" .. hash .. ".json" + local hash = hash_path(relative_path) + local path = dir .. "/" .. FILES_DIR .. "/" .. hash .. ".json" - local content = utils.read_file(path) - if not content then - return nil - end + local content = utils.read_file(path) + if not content then + return nil + end - local ok, data = pcall(vim.json.decode, content) - if not ok then - return nil - end + local ok, data = pcall(vim.json.decode, content) + if not ok then + return nil + end - return data + return data end --- Store index summary as memories ---@param index ProjectIndex function M.store_index_summary(index) - -- Store project type convention - if index.project_type and index.project_type ~= "unknown" then - M.store_memory({ - type = "convention", - content = "Project uses " .. index.project_type .. " ecosystem", - context = { - project_root = index.project_root, - detected_at = os.time(), - }, - weight = 0.9, - }) - end + -- Store project type convention + if index.project_type and index.project_type ~= "unknown" then + M.store_memory({ + type = "convention", + content = "Project uses " .. index.project_type .. " ecosystem", + context = { + project_root = index.project_root, + detected_at = os.time(), + }, + weight = 0.9, + }) + end - -- Store dependency patterns - local dep_count = 0 - for _ in pairs(index.dependencies or {}) do - dep_count = dep_count + 1 - end + -- Store dependency patterns + local dep_count = 0 + for _ in pairs(index.dependencies or {}) do + dep_count = dep_count + 1 + end - if dep_count > 0 then - local deps_list = {} - for name, _ in pairs(index.dependencies) do - table.insert(deps_list, name) - end + if dep_count > 0 then + local deps_list = {} + for name, _ in pairs(index.dependencies) do + table.insert(deps_list, name) + end - M.store_memory({ - type = "pattern", - content = "Project dependencies: " .. table.concat(deps_list, ", "), - context = { - dependency_count = dep_count, - }, - weight = 0.7, - }) - end + M.store_memory({ + type = "pattern", + content = "Project dependencies: " .. table.concat(deps_list, ", "), + context = { + dependency_count = dep_count, + }, + weight = 0.7, + }) + end - -- Update symbol cache - cache.symbols = nil - save_memory_file(SYMBOLS_FILE, index.symbols or {}) + -- Update symbol cache + cache.symbols = nil + save_memory_file(SYMBOLS_FILE, index.symbols or {}) end --- Store session interaction ---@param interaction {prompt: string, response: string, file: string|nil, success: boolean} function M.store_session(interaction) - if not ensure_sessions_dir() then - return - end + if not ensure_sessions_dir() then + return + end - local dir = get_sessions_dir() - if not dir then - return - end + local dir = get_sessions_dir() + if not dir then + return + end - -- Use date-based session files - local date = os.date("%Y-%m-%d") - local path = dir .. "/" .. date .. ".json" + -- Use date-based session files + local date = os.date("%Y-%m-%d") + local path = dir .. "/" .. date .. ".json" - local sessions = {} - local content = utils.read_file(path) - if content then - local ok, data = pcall(vim.json.decode, content) - if ok and data then - sessions = data - end - end + local sessions = {} + local content = utils.read_file(path) + if content then + local ok, data = pcall(vim.json.decode, content) + if ok and data then + sessions = data + end + end - table.insert(sessions, { - timestamp = os.time(), - prompt = interaction.prompt, - response = string.sub(interaction.response or "", 1, 500), -- Truncate - file = interaction.file, - success = interaction.success, - }) + table.insert(sessions, { + timestamp = os.time(), + prompt = interaction.prompt, + response = string.sub(interaction.response or "", 1, 500), -- Truncate + file = interaction.file, + success = interaction.success, + }) - -- Limit session size - if #sessions > 100 then - sessions = { unpack(sessions, #sessions - 99) } - end + -- Limit session size + if #sessions > 100 then + sessions = { unpack(sessions, #sessions - 99) } + end - local ok, encoded = pcall(vim.json.encode, sessions) - if ok then - utils.write_file(path, encoded) - end + local ok, encoded = pcall(vim.json.encode, sessions) + if ok then + utils.write_file(path, encoded) + end end --- Get relevant memories for a query @@ -349,191 +349,191 @@ end ---@param limit number Maximum results ---@return Memory[] function M.get_relevant(query, limit) - limit = limit or 10 - local results = {} + limit = limit or 10 + local results = {} - -- Tokenize query - local query_words = {} - for word in query:lower():gmatch("%w+") do - if #word > 2 then - query_words[word] = true - end - end + -- Tokenize query + local query_words = {} + for word in query:lower():gmatch("%w+") do + if #word > 2 then + query_words[word] = true + end + end - -- Search patterns - local patterns = M.load_patterns() - for _, memory in pairs(patterns) do - local score = 0 - local content_lower = (memory.content or ""):lower() + -- Search patterns + local patterns = M.load_patterns() + for _, memory in pairs(patterns) do + local score = 0 + local content_lower = (memory.content or ""):lower() - for word in pairs(query_words) do - if content_lower:find(word, 1, true) then - score = score + 1 - end - end + for word in pairs(query_words) do + if content_lower:find(word, 1, true) then + score = score + 1 + end + end - if score > 0 then - memory.relevance_score = score * (memory.weight or 0.5) - table.insert(results, memory) - end - end + if score > 0 then + memory.relevance_score = score * (memory.weight or 0.5) + table.insert(results, memory) + end + end - -- Search conventions - local conventions = M.load_conventions() - for _, memory in pairs(conventions) do - local score = 0 - local content_lower = (memory.content or ""):lower() + -- Search conventions + local conventions = M.load_conventions() + for _, memory in pairs(conventions) do + local score = 0 + local content_lower = (memory.content or ""):lower() - for word in pairs(query_words) do - if content_lower:find(word, 1, true) then - score = score + 1 - end - end + for word in pairs(query_words) do + if content_lower:find(word, 1, true) then + score = score + 1 + end + end - if score > 0 then - memory.relevance_score = score * (memory.weight or 0.5) - table.insert(results, memory) - end - end + if score > 0 then + memory.relevance_score = score * (memory.weight or 0.5) + table.insert(results, memory) + end + end - -- Sort by relevance - table.sort(results, function(a, b) - return (a.relevance_score or 0) > (b.relevance_score or 0) - end) + -- Sort by relevance + table.sort(results, function(a, b) + return (a.relevance_score or 0) > (b.relevance_score or 0) + end) - -- Limit results - local limited = {} - for i = 1, math.min(limit, #results) do - limited[i] = results[i] - end + -- Limit results + local limited = {} + for i = 1, math.min(limit, #results) do + limited[i] = results[i] + end - return limited + return limited end --- Update memory usage count ---@param memory_id string function M.update_usage(memory_id) - local patterns = M.load_patterns() - if patterns[memory_id] then - patterns[memory_id].used_count = (patterns[memory_id].used_count or 0) + 1 - patterns[memory_id].updated_at = os.time() - save_memory_file(PATTERNS_FILE, patterns) - cache.patterns = nil - return - end + local patterns = M.load_patterns() + if patterns[memory_id] then + patterns[memory_id].used_count = (patterns[memory_id].used_count or 0) + 1 + patterns[memory_id].updated_at = os.time() + save_memory_file(PATTERNS_FILE, patterns) + cache.patterns = nil + return + end - local conventions = M.load_conventions() - if conventions[memory_id] then - conventions[memory_id].used_count = (conventions[memory_id].used_count or 0) + 1 - conventions[memory_id].updated_at = os.time() - save_memory_file(CONVENTIONS_FILE, conventions) - cache.conventions = nil - end + local conventions = M.load_conventions() + if conventions[memory_id] then + conventions[memory_id].used_count = (conventions[memory_id].used_count or 0) + 1 + conventions[memory_id].updated_at = os.time() + save_memory_file(CONVENTIONS_FILE, conventions) + cache.conventions = nil + end end --- Get all memories ---@return {patterns: table, conventions: table, symbols: table} function M.get_all() - return { - patterns = M.load_patterns(), - conventions = M.load_conventions(), - symbols = M.load_symbols(), - } + return { + patterns = M.load_patterns(), + conventions = M.load_conventions(), + symbols = M.load_symbols(), + } end --- Clear all memories ---@param pattern? string Optional pattern to match memory IDs function M.clear(pattern) - if not pattern then - -- Clear all - cache = { patterns = nil, conventions = nil, symbols = nil } - save_memory_file(PATTERNS_FILE, {}) - save_memory_file(CONVENTIONS_FILE, {}) - save_memory_file(SYMBOLS_FILE, {}) - return - end + if not pattern then + -- Clear all + cache = { patterns = nil, conventions = nil, symbols = nil } + save_memory_file(PATTERNS_FILE, {}) + save_memory_file(CONVENTIONS_FILE, {}) + save_memory_file(SYMBOLS_FILE, {}) + return + end - -- Clear matching pattern - local patterns = M.load_patterns() - for id in pairs(patterns) do - if id:match(pattern) then - patterns[id] = nil - end - end - save_memory_file(PATTERNS_FILE, patterns) - cache.patterns = nil + -- Clear matching pattern + local patterns = M.load_patterns() + for id in pairs(patterns) do + if id:match(pattern) then + patterns[id] = nil + end + end + save_memory_file(PATTERNS_FILE, patterns) + cache.patterns = nil - local conventions = M.load_conventions() - for id in pairs(conventions) do - if id:match(pattern) then - conventions[id] = nil - end - end - save_memory_file(CONVENTIONS_FILE, conventions) - cache.conventions = nil + local conventions = M.load_conventions() + for id in pairs(conventions) do + if id:match(pattern) then + conventions[id] = nil + end + end + save_memory_file(CONVENTIONS_FILE, conventions) + cache.conventions = nil end --- Prune low-weight memories ---@param threshold number Weight threshold (default: 0.1) function M.prune(threshold) - threshold = threshold or 0.1 + threshold = threshold or 0.1 - local patterns = M.load_patterns() - local pruned = 0 - for id, memory in pairs(patterns) do - if (memory.weight or 0) < threshold and (memory.used_count or 0) == 0 then - patterns[id] = nil - pruned = pruned + 1 - end - end - if pruned > 0 then - save_memory_file(PATTERNS_FILE, patterns) - cache.patterns = nil - end + local patterns = M.load_patterns() + local pruned = 0 + for id, memory in pairs(patterns) do + if (memory.weight or 0) < threshold and (memory.used_count or 0) == 0 then + patterns[id] = nil + pruned = pruned + 1 + end + end + if pruned > 0 then + save_memory_file(PATTERNS_FILE, patterns) + cache.patterns = nil + end - local conventions = M.load_conventions() - for id, memory in pairs(conventions) do - if (memory.weight or 0) < threshold and (memory.used_count or 0) == 0 then - conventions[id] = nil - pruned = pruned + 1 - end - end - if pruned > 0 then - save_memory_file(CONVENTIONS_FILE, conventions) - cache.conventions = nil - end + local conventions = M.load_conventions() + for id, memory in pairs(conventions) do + if (memory.weight or 0) < threshold and (memory.used_count or 0) == 0 then + conventions[id] = nil + pruned = pruned + 1 + end + end + if pruned > 0 then + save_memory_file(CONVENTIONS_FILE, conventions) + cache.conventions = nil + end - return pruned + return pruned end --- Get memory statistics ---@return table function M.get_stats() - local patterns = M.load_patterns() - local conventions = M.load_conventions() - local symbols = M.load_symbols() + local patterns = M.load_patterns() + local conventions = M.load_conventions() + local symbols = M.load_symbols() - local pattern_count = 0 - for _ in pairs(patterns) do - pattern_count = pattern_count + 1 - end + local pattern_count = 0 + for _ in pairs(patterns) do + pattern_count = pattern_count + 1 + end - local convention_count = 0 - for _ in pairs(conventions) do - convention_count = convention_count + 1 - end + local convention_count = 0 + for _ in pairs(conventions) do + convention_count = convention_count + 1 + end - local symbol_count = 0 - for _ in pairs(symbols) do - symbol_count = symbol_count + 1 - end + local symbol_count = 0 + for _ in pairs(symbols) do + symbol_count = symbol_count + 1 + end - return { - patterns = pattern_count, - conventions = convention_count, - symbols = symbol_count, - total = pattern_count + convention_count, - } + return { + patterns = pattern_count, + conventions = convention_count, + symbols = symbol_count, + total = pattern_count + convention_count, + } end return M diff --git a/lua/codetyper/features/indexer/scanner.lua b/lua/codetyper/features/indexer/scanner.lua index a789ecf..8fa297d 100644 --- a/lua/codetyper/features/indexer/scanner.lua +++ b/lua/codetyper/features/indexer/scanner.lua @@ -9,78 +9,78 @@ local utils = require("codetyper.support.utils") --- Project type markers local PROJECT_MARKERS = { - node = { "package.json" }, - rust = { "Cargo.toml" }, - go = { "go.mod" }, - python = { "pyproject.toml", "setup.py", "requirements.txt" }, - lua = { "init.lua", ".luarc.json" }, - ruby = { "Gemfile" }, - java = { "pom.xml", "build.gradle" }, - csharp = { "*.csproj", "*.sln" }, + node = { "package.json" }, + rust = { "Cargo.toml" }, + go = { "go.mod" }, + python = { "pyproject.toml", "setup.py", "requirements.txt" }, + lua = { "init.lua", ".luarc.json" }, + ruby = { "Gemfile" }, + java = { "pom.xml", "build.gradle" }, + csharp = { "*.csproj", "*.sln" }, } --- File extension to language mapping local EXTENSION_LANGUAGE = { - lua = "lua", - ts = "typescript", - tsx = "typescriptreact", - js = "javascript", - jsx = "javascriptreact", - py = "python", - go = "go", - rs = "rust", - rb = "ruby", - java = "java", - c = "c", - cpp = "cpp", - h = "c", - hpp = "cpp", - cs = "csharp", + lua = "lua", + ts = "typescript", + tsx = "typescriptreact", + js = "javascript", + jsx = "javascriptreact", + py = "python", + go = "go", + rs = "rust", + rb = "ruby", + java = "java", + c = "c", + cpp = "cpp", + h = "c", + hpp = "cpp", + cs = "csharp", } --- Default ignore patterns local DEFAULT_IGNORES = { - "^%.", -- Hidden files/folders - "^node_modules$", - "^__pycache__$", - "^%.git$", - "^%.codetyper$", - "^dist$", - "^build$", - "^target$", - "^vendor$", - "^%.next$", - "^%.nuxt$", - "^coverage$", - "%.min%.js$", - "%.min%.css$", - "%.map$", - "%.lock$", - "%-lock%.json$", + "^%.", -- Hidden files/folders + "^node_modules$", + "^__pycache__$", + "^%.git$", + "^%.codetyper$", + "^dist$", + "^build$", + "^target$", + "^vendor$", + "^%.next$", + "^%.nuxt$", + "^coverage$", + "%.min%.js$", + "%.min%.css$", + "%.map$", + "%.lock$", + "%-lock%.json$", } --- Detect project type from root markers ---@param root string Project root path ---@return string Project type function M.detect_project_type(root) - for project_type, markers in pairs(PROJECT_MARKERS) do - for _, marker in ipairs(markers) do - local path = root .. "/" .. marker - if marker:match("^%*") then - -- Glob pattern - local pattern = marker:gsub("^%*", "") - local entries = vim.fn.glob(root .. "/*" .. pattern, false, true) - if #entries > 0 then - return project_type - end - else - if utils.file_exists(path) then - return project_type - end - end - end - end - return "unknown" + for project_type, markers in pairs(PROJECT_MARKERS) do + for _, marker in ipairs(markers) do + local path = root .. "/" .. marker + if marker:match("^%*") then + -- Glob pattern + local pattern = marker:gsub("^%*", "") + local entries = vim.fn.glob(root .. "/*" .. pattern, false, true) + if #entries > 0 then + return project_type + end + else + if utils.file_exists(path) then + return project_type + end + end + end + end + return "unknown" end --- Parse project dependencies @@ -88,182 +88,182 @@ end ---@param project_type string Project type ---@return {dependencies: table, dev_dependencies: table} function M.parse_dependencies(root, project_type) - local deps = { - dependencies = {}, - dev_dependencies = {}, - } + local deps = { + dependencies = {}, + dev_dependencies = {}, + } - if project_type == "node" then - deps = M.parse_package_json(root) - elseif project_type == "rust" then - deps = M.parse_cargo_toml(root) - elseif project_type == "go" then - deps = M.parse_go_mod(root) - elseif project_type == "python" then - deps = M.parse_python_deps(root) - end + if project_type == "node" then + deps = M.parse_package_json(root) + elseif project_type == "rust" then + deps = M.parse_cargo_toml(root) + elseif project_type == "go" then + deps = M.parse_go_mod(root) + elseif project_type == "python" then + deps = M.parse_python_deps(root) + end - return deps + return deps end --- Parse package.json for Node.js projects ---@param root string Project root path ---@return {dependencies: table, dev_dependencies: table} function M.parse_package_json(root) - local path = root .. "/package.json" - local content = utils.read_file(path) - if not content then - return { dependencies = {}, dev_dependencies = {} } - end + local path = root .. "/package.json" + local content = utils.read_file(path) + if not content then + return { dependencies = {}, dev_dependencies = {} } + end - local ok, pkg = pcall(vim.json.decode, content) - if not ok or not pkg then - return { dependencies = {}, dev_dependencies = {} } - end + local ok, pkg = pcall(vim.json.decode, content) + if not ok or not pkg then + return { dependencies = {}, dev_dependencies = {} } + end - return { - dependencies = pkg.dependencies or {}, - dev_dependencies = pkg.devDependencies or {}, - } + return { + dependencies = pkg.dependencies or {}, + dev_dependencies = pkg.devDependencies or {}, + } end --- Parse Cargo.toml for Rust projects ---@param root string Project root path ---@return {dependencies: table, dev_dependencies: table} function M.parse_cargo_toml(root) - local path = root .. "/Cargo.toml" - local content = utils.read_file(path) - if not content then - return { dependencies = {}, dev_dependencies = {} } - end + local path = root .. "/Cargo.toml" + local content = utils.read_file(path) + if not content then + return { dependencies = {}, dev_dependencies = {} } + end - local deps = {} - local dev_deps = {} - local in_deps = false - local in_dev_deps = false + local deps = {} + local dev_deps = {} + local in_deps = false + local in_dev_deps = false - for line in content:gmatch("[^\n]+") do - if line:match("^%[dependencies%]") then - in_deps = true - in_dev_deps = false - elseif line:match("^%[dev%-dependencies%]") then - in_deps = false - in_dev_deps = true - elseif line:match("^%[") then - in_deps = false - in_dev_deps = false - elseif in_deps or in_dev_deps then - local name, version = line:match('^([%w_%-]+)%s*=%s*"([^"]+)"') - if not name then - name = line:match("^([%w_%-]+)%s*=") - version = "workspace" - end - if name then - if in_deps then - deps[name] = version or "unknown" - else - dev_deps[name] = version or "unknown" - end - end - end - end + for line in content:gmatch("[^\n]+") do + if line:match("^%[dependencies%]") then + in_deps = true + in_dev_deps = false + elseif line:match("^%[dev%-dependencies%]") then + in_deps = false + in_dev_deps = true + elseif line:match("^%[") then + in_deps = false + in_dev_deps = false + elseif in_deps or in_dev_deps then + local name, version = line:match('^([%w_%-]+)%s*=%s*"([^"]+)"') + if not name then + name = line:match("^([%w_%-]+)%s*=") + version = "workspace" + end + if name then + if in_deps then + deps[name] = version or "unknown" + else + dev_deps[name] = version or "unknown" + end + end + end + end - return { dependencies = deps, dev_dependencies = dev_deps } + return { dependencies = deps, dev_dependencies = dev_deps } end --- Parse go.mod for Go projects ---@param root string Project root path ---@return {dependencies: table, dev_dependencies: table} function M.parse_go_mod(root) - local path = root .. "/go.mod" - local content = utils.read_file(path) - if not content then - return { dependencies = {}, dev_dependencies = {} } - end + local path = root .. "/go.mod" + local content = utils.read_file(path) + if not content then + return { dependencies = {}, dev_dependencies = {} } + end - local deps = {} - local in_require = false + local deps = {} + local in_require = false - for line in content:gmatch("[^\n]+") do - if line:match("^require%s*%(") then - in_require = true - elseif line:match("^%)") then - in_require = false - elseif in_require then - local module, version = line:match("^%s*([%w%.%-%_/]+)%s+([%w%.%-]+)") - if module then - deps[module] = version - end - else - local module, version = line:match("^require%s+([%w%.%-%_/]+)%s+([%w%.%-]+)") - if module then - deps[module] = version - end - end - end + for line in content:gmatch("[^\n]+") do + if line:match("^require%s*%(") then + in_require = true + elseif line:match("^%)") then + in_require = false + elseif in_require then + local module, version = line:match("^%s*([%w%.%-%_/]+)%s+([%w%.%-]+)") + if module then + deps[module] = version + end + else + local module, version = line:match("^require%s+([%w%.%-%_/]+)%s+([%w%.%-]+)") + if module then + deps[module] = version + end + end + end - return { dependencies = deps, dev_dependencies = {} } + return { dependencies = deps, dev_dependencies = {} } end --- Parse Python dependencies (pyproject.toml or requirements.txt) ---@param root string Project root path ---@return {dependencies: table, dev_dependencies: table} function M.parse_python_deps(root) - local deps = {} - local dev_deps = {} + local deps = {} + local dev_deps = {} - -- Try pyproject.toml first - local pyproject = root .. "/pyproject.toml" - local content = utils.read_file(pyproject) + -- Try pyproject.toml first + local pyproject = root .. "/pyproject.toml" + local content = utils.read_file(pyproject) - if content then - -- Simple parsing for dependencies - local in_deps = false - local in_dev = false + if content then + -- Simple parsing for dependencies + local in_deps = false + local in_dev = false - for line in content:gmatch("[^\n]+") do - if line:match("^%[project%.dependencies%]") or line:match("^dependencies%s*=") then - in_deps = true - in_dev = false - elseif line:match("dev") and line:match("dependencies") then - in_deps = false - in_dev = true - elseif line:match("^%[") then - in_deps = false - in_dev = false - elseif in_deps or in_dev then - local name = line:match('"([%w_%-]+)') - if name then - if in_deps then - deps[name] = "latest" - else - dev_deps[name] = "latest" - end - end - end - end - end + for line in content:gmatch("[^\n]+") do + if line:match("^%[project%.dependencies%]") or line:match("^dependencies%s*=") then + in_deps = true + in_dev = false + elseif line:match("dev") and line:match("dependencies") then + in_deps = false + in_dev = true + elseif line:match("^%[") then + in_deps = false + in_dev = false + elseif in_deps or in_dev then + local name = line:match('"([%w_%-]+)') + if name then + if in_deps then + deps[name] = "latest" + else + dev_deps[name] = "latest" + end + end + end + end + end - -- Fallback to requirements.txt - local req_file = root .. "/requirements.txt" - content = utils.read_file(req_file) + -- Fallback to requirements.txt + local req_file = root .. "/requirements.txt" + content = utils.read_file(req_file) - if content then - for line in content:gmatch("[^\n]+") do - if not line:match("^#") and not line:match("^%s*$") then - local name, version = line:match("^([%w_%-]+)==([%d%.]+)") - if not name then - name = line:match("^([%w_%-]+)") - version = "latest" - end - if name then - deps[name] = version or "latest" - end - end - end - end + if content then + for line in content:gmatch("[^\n]+") do + if not line:match("^#") and not line:match("^%s*$") then + local name, version = line:match("^([%w_%-]+)==([%d%.]+)") + if not name then + name = line:match("^([%w_%-]+)") + version = "latest" + end + if name then + deps[name] = version or "latest" + end + end + end + end - return { dependencies = deps, dev_dependencies = dev_deps } + return { dependencies = deps, dev_dependencies = dev_deps } end --- Check if a file/directory should be ignored @@ -271,23 +271,23 @@ end ---@param config table Indexer configuration ---@return boolean function M.should_ignore(name, config) - -- Check default patterns - for _, pattern in ipairs(DEFAULT_IGNORES) do - if name:match(pattern) then - return true - end - end + -- Check default patterns + for _, pattern in ipairs(DEFAULT_IGNORES) do + if name:match(pattern) then + return true + end + end - -- Check config excluded dirs - if config and config.excluded_dirs then - for _, dir in ipairs(config.excluded_dirs) do - if name == dir then - return true - end - end - end + -- Check config excluded dirs + if config and config.excluded_dirs then + for _, dir in ipairs(config.excluded_dirs) do + if name == dir then + return true + end + end + end - return false + return false end --- Check if a file should be indexed @@ -295,42 +295,42 @@ end ---@param config table Indexer configuration ---@return boolean function M.should_index(filepath, config) - local name = vim.fn.fnamemodify(filepath, ":t") - local ext = vim.fn.fnamemodify(filepath, ":e") + local name = vim.fn.fnamemodify(filepath, ":t") + local ext = vim.fn.fnamemodify(filepath, ":e") - -- Check if it's a coder file - if utils.is_coder_file(filepath) then - return false - end + -- Check if it's a coder file + if utils.is_coder_file(filepath) then + return false + end - -- Check file size - if config and config.max_file_size then - local stat = vim.loop.fs_stat(filepath) - if stat and stat.size > config.max_file_size then - return false - end - end + -- Check file size + if config and config.max_file_size then + local stat = vim.loop.fs_stat(filepath) + if stat and stat.size > config.max_file_size then + return false + end + end - -- Check extension - if config and config.index_extensions then - local valid_ext = false - for _, allowed_ext in ipairs(config.index_extensions) do - if ext == allowed_ext then - valid_ext = true - break - end - end - if not valid_ext then - return false - end - end + -- Check extension + if config and config.index_extensions then + local valid_ext = false + for _, allowed_ext in ipairs(config.index_extensions) do + if ext == allowed_ext then + valid_ext = true + break + end + end + if not valid_ext then + return false + end + end - -- Check ignore patterns - if M.should_ignore(name, config) then - return false - end + -- Check ignore patterns + if M.should_ignore(name, config) then + return false + end - return true + return true end --- Get all indexable files in the project @@ -338,72 +338,72 @@ end ---@param config table Indexer configuration ---@return string[] List of file paths function M.get_indexable_files(root, config) - local files = {} + local files = {} - local function scan_dir(path) - local handle = vim.loop.fs_scandir(path) - if not handle then - return - end + local function scan_dir(path) + local handle = vim.loop.fs_scandir(path) + if not handle then + return + end - while true do - local name, type = vim.loop.fs_scandir_next(handle) - if not name then - break - end + while true do + local name, type = vim.loop.fs_scandir_next(handle) + if not name then + break + end - local full_path = path .. "/" .. name + local full_path = path .. "/" .. name - if M.should_ignore(name, config) then - goto continue - end + if M.should_ignore(name, config) then + goto continue + end - if type == "directory" then - scan_dir(full_path) - elseif type == "file" then - if M.should_index(full_path, config) then - table.insert(files, full_path) - end - end + if type == "directory" then + scan_dir(full_path) + elseif type == "file" then + if M.should_index(full_path, config) then + table.insert(files, full_path) + end + end - ::continue:: - end - end + ::continue:: + end + end - scan_dir(root) - return files + scan_dir(root) + return files end --- Get language from file extension ---@param filepath string File path ---@return string Language name function M.get_language(filepath) - local ext = vim.fn.fnamemodify(filepath, ":e") - return EXTENSION_LANGUAGE[ext] or ext + local ext = vim.fn.fnamemodify(filepath, ":e") + return EXTENSION_LANGUAGE[ext] or ext end --- Read .gitignore patterns ---@param root string Project root ---@return string[] Patterns function M.read_gitignore(root) - local patterns = {} - local path = root .. "/.gitignore" - local content = utils.read_file(path) + local patterns = {} + local path = root .. "/.gitignore" + local content = utils.read_file(path) - if not content then - return patterns - end + if not content then + return patterns + end - for line in content:gmatch("[^\n]+") do - -- Skip comments and empty lines - if not line:match("^#") and not line:match("^%s*$") then - -- Convert gitignore pattern to Lua pattern (simplified) - local pattern = line:gsub("^/", "^"):gsub("%*%*", ".*"):gsub("%*", "[^/]*"):gsub("%?", ".") - table.insert(patterns, pattern) - end - end + for line in content:gmatch("[^\n]+") do + -- Skip comments and empty lines + if not line:match("^#") and not line:match("^%s*$") then + -- Convert gitignore pattern to Lua pattern (simplified) + local pattern = line:gsub("^/", "^"):gsub("%*%*", ".*"):gsub("%*", "[^/]*"):gsub("%?", ".") + table.insert(patterns, pattern) + end + end - return patterns + return patterns end return M diff --git a/lua/codetyper/init.lua b/lua/codetyper/init.lua index 1ac1766..5bf557b 100644 --- a/lua/codetyper/init.lua +++ b/lua/codetyper/init.lua @@ -20,66 +20,66 @@ M._initialized = false --- Setup the plugin with user configuration ---@param opts? CoderConfig User configuration options function M.setup(opts) - if M._initialized then - return - end + if M._initialized then + return + end - local config = require("codetyper.config.defaults") - M.config = config.setup(opts) + local config = require("codetyper.config.defaults") + M.config = config.setup(opts) - -- Initialize modules - local commands = require("codetyper.adapters.nvim.commands") - local gitignore = require("codetyper.support.gitignore") - local autocmds = require("codetyper.adapters.nvim.autocmds") - local tree = require("codetyper.support.tree") - local completion = require("codetyper.features.completion.inline") + -- Initialize modules + local commands = require("codetyper.adapters.nvim.commands") + local gitignore = require("codetyper.support.gitignore") + local autocmds = require("codetyper.adapters.nvim.autocmds") + local tree = require("codetyper.support.tree") + local completion = require("codetyper.features.completion.inline") - -- Register commands - commands.setup() + -- Register commands + commands.setup() - -- Setup autocommands - autocmds.setup() + -- Setup autocommands + autocmds.setup() - -- Setup file reference completion - completion.setup() + -- Setup file reference completion + completion.setup() - -- Ensure .gitignore has coder files excluded - gitignore.ensure_ignored() + -- Ensure .gitignore has coder files excluded + gitignore.ensure_ignored() - -- Initialize tree logging (creates .codetyper folder and initial tree.log) - tree.setup() + -- Initialize tree logging (creates .codetyper folder and initial tree.log) + tree.setup() - -- Initialize project indexer if enabled - if M.config.indexer and M.config.indexer.enabled then - local indexer = require("codetyper.features.indexer") - indexer.setup(M.config.indexer) - end + -- Initialize project indexer if enabled + if M.config.indexer and M.config.indexer.enabled then + local indexer = require("codetyper.features.indexer") + indexer.setup(M.config.indexer) + end - -- Initialize brain learning system if enabled - if M.config.brain and M.config.brain.enabled then - local brain = require("codetyper.core.memory") - brain.setup(M.config.brain) - end + -- Initialize brain learning system if enabled + if M.config.brain and M.config.brain.enabled then + local brain = require("codetyper.core.memory") + brain.setup(M.config.brain) + end - -- Setup inline ghost text suggestions (Copilot-style) - if M.config.suggestion and M.config.suggestion.enabled then - local suggestion = require("codetyper.features.completion.suggestion") - suggestion.setup(M.config.suggestion) - end + -- Setup inline ghost text suggestions (Copilot-style) + if M.config.suggestion and M.config.suggestion.enabled then + local suggestion = require("codetyper.features.completion.suggestion") + suggestion.setup(M.config.suggestion) + end - M._initialized = true + M._initialized = true end --- Get current configuration ---@return CoderConfig function M.get_config() - return M.config + return M.config end --- Check if plugin is initialized ---@return boolean function M.is_initialized() - return M._initialized + return M._initialized end return M diff --git a/lua/codetyper/inject.lua b/lua/codetyper/inject.lua index 5c8f654..25f12c7 100644 --- a/lua/codetyper/inject.lua +++ b/lua/codetyper/inject.lua @@ -53,10 +53,10 @@ function M.inject_code(target_path, code, prompt_type) -- For generic, auto-add instead of prompting M.inject_add(target_buf, code) end - + -- Mark buffer as modified and save vim.bo[target_buf].modified = true - + -- Auto-save the target file vim.schedule(function() if vim.api.nvim_buf_is_valid(target_buf) then @@ -76,34 +76,34 @@ end ---@param opts table|nil { strategy = "replace"|"insert"|"append", range = { start_line, end_line } (1-based) } ---@return table { imports_added: number, body_lines: number, imports_merged: boolean } function M.inject(bufnr, code, opts) - opts = opts or {} - local strategy = opts.strategy or "replace" - local range = opts.range - local lines = vim.split(code, "\n", { plain = true }) - local body_lines = #lines + opts = opts or {} + local strategy = opts.strategy or "replace" + local range = opts.range + local lines = vim.split(code, "\n", { plain = true }) + local body_lines = #lines - if not vim.api.nvim_buf_is_valid(bufnr) then - return { imports_added = 0, body_lines = 0, imports_merged = false } - end + if not vim.api.nvim_buf_is_valid(bufnr) then + return { imports_added = 0, body_lines = 0, imports_merged = false } + end - local line_count = vim.api.nvim_buf_line_count(bufnr) + local line_count = vim.api.nvim_buf_line_count(bufnr) - if strategy == "replace" and range and range.start_line and range.end_line then - local start_0 = math.max(0, range.start_line - 1) - local end_0 = math.min(line_count, range.end_line) - if end_0 < start_0 then - end_0 = start_0 - end - vim.api.nvim_buf_set_lines(bufnr, start_0, end_0, false, lines) - elseif strategy == "insert" and range and range.start_line then - local at_0 = math.max(0, math.min(range.start_line - 1, line_count)) - vim.api.nvim_buf_set_lines(bufnr, at_0, at_0, false, lines) - else - -- append - vim.api.nvim_buf_set_lines(bufnr, line_count, line_count, false, lines) - end + if strategy == "replace" and range and range.start_line and range.end_line then + local start_0 = math.max(0, range.start_line - 1) + local end_0 = math.min(line_count, range.end_line) + if end_0 < start_0 then + end_0 = start_0 + end + vim.api.nvim_buf_set_lines(bufnr, start_0, end_0, false, lines) + elseif strategy == "insert" and range and range.start_line then + local at_0 = math.max(0, math.min(range.start_line - 1, line_count)) + vim.api.nvim_buf_set_lines(bufnr, at_0, at_0, false, lines) + else + -- append + vim.api.nvim_buf_set_lines(bufnr, line_count, line_count, false, lines) + end - return { imports_added = 0, body_lines = body_lines, imports_merged = false } + return { imports_added = 0, body_lines = body_lines, imports_merged = false } end --- Inject code for refactor (replace entire file) diff --git a/lua/codetyper/params/agents/bash.lua b/lua/codetyper/params/agents/bash.lua index 5a81e86..bc19454 100644 --- a/lua/codetyper/params/agents/bash.lua +++ b/lua/codetyper/params/agents/bash.lua @@ -1,35 +1,37 @@ +local M = {} + M.params = { - { - name = "command", - description = "The shell command to execute", - type = "string", - }, - { - name = "cwd", - description = "Working directory for the command (optional)", - type = "string", - optional = true, - }, - { - name = "timeout", - description = "Timeout in milliseconds (default: 120000)", - type = "integer", - optional = true, - }, + { + name = "command", + description = "The shell command to execute", + type = "string", + }, + { + name = "cwd", + description = "Working directory for the command (optional)", + type = "string", + optional = true, + }, + { + name = "timeout", + description = "Timeout in milliseconds (default: 120000)", + type = "integer", + optional = true, + }, } M.returns = { - { - name = "stdout", - description = "Command output", - type = "string", - }, - { - name = "error", - description = "Error message if command failed", - type = "string", - optional = true, - }, + { + name = "stdout", + description = "Command output", + type = "string", + }, + { + name = "error", + description = "Error message if command failed", + type = "string", + optional = true, + }, } return M diff --git a/lua/codetyper/params/agents/confidence.lua b/lua/codetyper/params/agents/confidence.lua index 2cb192b..ce105e7 100644 --- a/lua/codetyper/params/agents/confidence.lua +++ b/lua/codetyper/params/agents/confidence.lua @@ -3,38 +3,38 @@ local M = {} --- Heuristic weights (must sum to 1.0) M.weights = { - length = 0.15, -- Response length relative to prompt - uncertainty = 0.30, -- Uncertainty phrases - syntax = 0.25, -- Syntax completeness - repetition = 0.15, -- Duplicate lines - truncation = 0.15, -- Incomplete ending + length = 0.15, -- Response length relative to prompt + uncertainty = 0.30, -- Uncertainty phrases + syntax = 0.25, -- Syntax completeness + repetition = 0.15, -- Duplicate lines + truncation = 0.15, -- Incomplete ending } --- Uncertainty phrases that indicate low confidence M.uncertainty_phrases = { - -- English - "i'm not sure", - "i am not sure", - "maybe", - "perhaps", - "might work", - "could work", - "not certain", - "uncertain", - "i think", - "possibly", - "TODO", - "FIXME", - "XXX", - "placeholder", - "implement this", - "fill in", - "your code here", - "...", -- Ellipsis as placeholder - "# TODO", - "// TODO", - "-- TODO", - "/* TODO", + -- English + "i'm not sure", + "i am not sure", + "maybe", + "perhaps", + "might work", + "could work", + "not certain", + "uncertain", + "i think", + "possibly", + "TODO", + "FIXME", + "XXX", + "placeholder", + "implement this", + "fill in", + "your code here", + "...", -- Ellipsis as placeholder + "# TODO", + "// TODO", + "-- TODO", + "/* TODO", } return M diff --git a/lua/codetyper/params/agents/conflict.lua b/lua/codetyper/params/agents/conflict.lua index 5c77656..f205c22 100644 --- a/lua/codetyper/params/agents/conflict.lua +++ b/lua/codetyper/params/agents/conflict.lua @@ -3,31 +3,31 @@ local M = {} --- Configuration defaults M.config = { - -- Run linter check after accepting AI suggestions - lint_after_accept = true, - -- Auto-fix lint errors without prompting - auto_fix_lint_errors = true, - -- Auto-show menu after injecting conflict - auto_show_menu = true, - -- Auto-show menu for next conflict after resolving one - auto_show_next_menu = true, + -- Run linter check after accepting AI suggestions + lint_after_accept = true, + -- Auto-fix lint errors without prompting + auto_fix_lint_errors = true, + -- Auto-show menu after injecting conflict + auto_show_menu = true, + -- Auto-show menu for next conflict after resolving one + auto_show_next_menu = true, } --- Highlight groups M.hl_groups = { - current = "CoderConflictCurrent", - current_label = "CoderConflictCurrentLabel", - incoming = "CoderConflictIncoming", - incoming_label = "CoderConflictIncomingLabel", - separator = "CoderConflictSeparator", - hint = "CoderConflictHint", + current = "CoderConflictCurrent", + current_label = "CoderConflictCurrentLabel", + incoming = "CoderConflictIncoming", + incoming_label = "CoderConflictIncomingLabel", + separator = "CoderConflictSeparator", + hint = "CoderConflictHint", } --- Conflict markers M.markers = { - current_start = "<<<<<<< CURRENT", - separator = "=======", - incoming_end = ">>>>>>> INCOMING", + current_start = "<<<<<<< CURRENT", + separator = "=======", + incoming_end = ">>>>>>> INCOMING", } return M diff --git a/lua/codetyper/params/agents/context.lua b/lua/codetyper/params/agents/context.lua index cfa4acb..84bd697 100644 --- a/lua/codetyper/params/agents/context.lua +++ b/lua/codetyper/params/agents/context.lua @@ -3,46 +3,46 @@ local M = {} --- Common ignore patterns M.ignore_patterns = { - "^%.", -- Hidden files/dirs - "node_modules", - "%.git$", - "__pycache__", - "%.pyc$", - "target", -- Rust - "build", - "dist", - "%.o$", - "%.a$", - "%.so$", - "%.min%.", - "%.map$", + "^%.", -- Hidden files/dirs + "node_modules", + "%.git$", + "__pycache__", + "%.pyc$", + "target", -- Rust + "build", + "dist", + "%.o$", + "%.a$", + "%.so$", + "%.min%.", + "%.map$", } --- Key files that are important for understanding the project M.important_files = { - ["package.json"] = "Node.js project config", - ["Cargo.toml"] = "Rust project config", - ["go.mod"] = "Go module config", - ["pyproject.toml"] = "Python project config", - ["setup.py"] = "Python setup config", - ["Makefile"] = "Build configuration", - ["CMakeLists.txt"] = "CMake config", - [".gitignore"] = "Git ignore patterns", - ["README.md"] = "Project documentation", - ["init.lua"] = "Neovim plugin entry", - ["plugin.lua"] = "Neovim plugin config", + ["package.json"] = "Node.js project config", + ["Cargo.toml"] = "Rust project config", + ["go.mod"] = "Go module config", + ["pyproject.toml"] = "Python project config", + ["setup.py"] = "Python setup config", + ["Makefile"] = "Build configuration", + ["CMakeLists.txt"] = "CMake config", + [".gitignore"] = "Git ignore patterns", + ["README.md"] = "Project documentation", + ["init.lua"] = "Neovim plugin entry", + ["plugin.lua"] = "Neovim plugin config", } --- Project type detection indicators M.indicators = { - ["package.json"] = { type = "node", language = "javascript/typescript" }, - ["Cargo.toml"] = { type = "rust", language = "rust" }, - ["go.mod"] = { type = "go", language = "go" }, - ["pyproject.toml"] = { type = "python", language = "python" }, - ["setup.py"] = { type = "python", language = "python" }, - ["Gemfile"] = { type = "ruby", language = "ruby" }, - ["pom.xml"] = { type = "maven", language = "java" }, - ["build.gradle"] = { type = "gradle", language = "java/kotlin" }, + ["package.json"] = { type = "node", language = "javascript/typescript" }, + ["Cargo.toml"] = { type = "rust", language = "rust" }, + ["go.mod"] = { type = "go", language = "go" }, + ["pyproject.toml"] = { type = "python", language = "python" }, + ["setup.py"] = { type = "python", language = "python" }, + ["Gemfile"] = { type = "ruby", language = "ruby" }, + ["pom.xml"] = { type = "maven", language = "java" }, + ["build.gradle"] = { type = "gradle", language = "java/kotlin" }, } return M diff --git a/lua/codetyper/params/agents/edit.lua b/lua/codetyper/params/agents/edit.lua index 3f353cd..81b6232 100644 --- a/lua/codetyper/params/agents/edit.lua +++ b/lua/codetyper/params/agents/edit.lua @@ -1,33 +1,35 @@ +local M = {} + M.params = { - { - name = "path", - description = "Path to the file to edit", - type = "string", - }, - { - name = "old_string", - description = "Text to find and replace (empty string to create new file or append)", - type = "string", - }, - { - name = "new_string", - description = "Text to replace with", - type = "string", - }, + { + name = "path", + description = "Path to the file to edit", + type = "string", + }, + { + name = "old_string", + description = "Text to find and replace (empty string to create new file or append)", + type = "string", + }, + { + name = "new_string", + description = "Text to replace with", + type = "string", + }, } M.returns = { - { - name = "success", - description = "Whether the edit was applied", - type = "boolean", - }, - { - name = "error", - description = "Error message if edit failed", - type = "string", - optional = true, - }, + { + name = "success", + description = "Whether the edit was applied", + type = "boolean", + }, + { + name = "error", + description = "Error message if edit failed", + type = "string", + optional = true, + }, } return M diff --git a/lua/codetyper/params/agents/grep.lua b/lua/codetyper/params/agents/grep.lua index b977a46..5165be3 100644 --- a/lua/codetyper/params/agents/grep.lua +++ b/lua/codetyper/params/agents/grep.lua @@ -1,3 +1,5 @@ +local M = {} + M.description = [[Searches for a pattern in files using ripgrep. Returns file paths and matching lines. Use this to find code by content. diff --git a/lua/codetyper/params/agents/intent.lua b/lua/codetyper/params/agents/intent.lua index cea91ad..dd53e89 100644 --- a/lua/codetyper/params/agents/intent.lua +++ b/lua/codetyper/params/agents/intent.lua @@ -3,167 +3,167 @@ local M = {} --- Intent patterns with associated metadata M.intent_patterns = { - -- Complete: fill in missing implementation - complete = { - patterns = { - "complete", - "finish", - "implement", - "fill in", - "fill out", - "stub", - "todo", - "fixme", - }, - scope_hint = "function", - action = "replace", - priority = 1, - }, + -- Complete: fill in missing implementation + complete = { + patterns = { + "complete", + "finish", + "implement", + "fill in", + "fill out", + "stub", + "todo", + "fixme", + }, + scope_hint = "function", + action = "replace", + priority = 1, + }, - -- Refactor: rewrite existing code - refactor = { - patterns = { - "refactor", - "rewrite", - "restructure", - "reorganize", - "clean up", - "cleanup", - "simplify", - "improve", - }, - scope_hint = "function", - action = "replace", - priority = 2, - }, + -- Refactor: rewrite existing code + refactor = { + patterns = { + "refactor", + "rewrite", + "restructure", + "reorganize", + "clean up", + "cleanup", + "simplify", + "improve", + }, + scope_hint = "function", + action = "replace", + priority = 2, + }, - -- Fix: repair bugs or issues - fix = { - patterns = { - "fix", - "repair", - "correct", - "debug", - "solve", - "resolve", - "patch", - "bug", - "error", - "issue", - "update", - "modify", - "change", - "adjust", - "tweak", - }, - scope_hint = "function", - action = "replace", - priority = 1, - }, + -- Fix: repair bugs or issues + fix = { + patterns = { + "fix", + "repair", + "correct", + "debug", + "solve", + "resolve", + "patch", + "bug", + "error", + "issue", + "update", + "modify", + "change", + "adjust", + "tweak", + }, + scope_hint = "function", + action = "replace", + priority = 1, + }, - -- Add: insert new code - add = { - patterns = { - "add", - "create", - "insert", - "include", - "append", - "new", - "generate", - "write", - }, - scope_hint = nil, -- Could be anywhere - action = "insert", - priority = 3, - }, + -- Add: insert new code + add = { + patterns = { + "add", + "create", + "insert", + "include", + "append", + "new", + "generate", + "write", + }, + scope_hint = nil, -- Could be anywhere + action = "insert", + priority = 3, + }, - -- Document: add documentation - document = { - patterns = { - "document", - "comment", - "jsdoc", - "docstring", - "describe", - "annotate", - "type hint", - "typehint", - }, - scope_hint = "function", - action = "replace", -- Replace with documented version - priority = 2, - }, + -- Document: add documentation + document = { + patterns = { + "document", + "comment", + "jsdoc", + "docstring", + "describe", + "annotate", + "type hint", + "typehint", + }, + scope_hint = "function", + action = "replace", -- Replace with documented version + priority = 2, + }, - -- Test: generate tests - test = { - patterns = { - "test", - "spec", - "unit test", - "integration test", - "coverage", - }, - scope_hint = "file", - action = "append", - priority = 3, - }, + -- Test: generate tests + test = { + patterns = { + "test", + "spec", + "unit test", + "integration test", + "coverage", + }, + scope_hint = "file", + action = "append", + priority = 3, + }, - -- Optimize: improve performance - optimize = { - patterns = { - "optimize", - "performance", - "faster", - "efficient", - "speed up", - "reduce", - "minimize", - }, - scope_hint = "function", - action = "replace", - priority = 2, - }, + -- Optimize: improve performance + optimize = { + patterns = { + "optimize", + "performance", + "faster", + "efficient", + "speed up", + "reduce", + "minimize", + }, + scope_hint = "function", + action = "replace", + priority = 2, + }, - -- Explain: generate documentation for selected code - explain = { - patterns = { - "explain", - "what does", - "what is", - "how does", - "how is", - "why does", - "why is", - "tell me", - "walk through", - "understand", - "question", - "what's this", - "what this", - "about this", - "help me understand", - }, - scope_hint = "function", - action = "insert", - priority = 4, - }, + -- Explain: generate documentation for selected code + explain = { + patterns = { + "explain", + "what does", + "what is", + "how does", + "how is", + "why does", + "why is", + "tell me", + "walk through", + "understand", + "question", + "what's this", + "what this", + "about this", + "help me understand", + }, + scope_hint = "function", + action = "insert", + priority = 4, + }, } --- Scope hint patterns M.scope_patterns = { - ["this function"] = "function", - ["this method"] = "function", - ["the function"] = "function", - ["the method"] = "function", - ["this class"] = "class", - ["the class"] = "class", - ["this file"] = "file", - ["the file"] = "file", - ["this block"] = "block", - ["the block"] = "block", - ["this"] = nil, -- Use Tree-sitter to determine - ["here"] = nil, + ["this function"] = "function", + ["this method"] = "function", + ["the function"] = "function", + ["the method"] = "function", + ["this class"] = "class", + ["the class"] = "class", + ["this file"] = "file", + ["the file"] = "file", + ["this block"] = "block", + ["the block"] = "block", + ["this"] = nil, -- Use Tree-sitter to determine + ["here"] = nil, } return M diff --git a/lua/codetyper/params/agents/languages.lua b/lua/codetyper/params/agents/languages.lua index 7aee62c..5e8707f 100644 --- a/lua/codetyper/params/agents/languages.lua +++ b/lua/codetyper/params/agents/languages.lua @@ -3,57 +3,57 @@ local M = {} --- Language-specific import patterns M.import_patterns = { - -- JavaScript/TypeScript - javascript = { - { pattern = "^%s*import%s+.+%s+from%s+['\"]", multi_line = true }, - { pattern = "^%s*import%s+['\"]", multi_line = false }, - { pattern = "^%s*import%s*{", multi_line = true }, - { pattern = "^%s*import%s*%*", multi_line = true }, - { pattern = "^%s*export%s+{.+}%s+from%s+['\"]", multi_line = true }, - { pattern = "^%s*const%s+%w+%s*=%s*require%(['\"]", multi_line = false }, - { pattern = "^%s*let%s+%w+%s*=%s*require%(['\"]", multi_line = false }, - { pattern = "^%s*var%s+%w+%s*=%s*require%(['\"]", multi_line = false }, - }, - -- Python - python = { - { pattern = "^%s*import%s+%w", multi_line = false }, - { pattern = "^%s*from%s+[%w%.]+%s+import%s+", multi_line = true }, - }, - -- Lua - lua = { - { pattern = "^%s*local%s+%w+%s*=%s*require%s*%(?['\"]", multi_line = false }, - { pattern = "^%s*require%s*%(?['\"]", multi_line = false }, - }, - -- Go - go = { - { pattern = "^%s*import%s+%(?", multi_line = true }, - }, - -- Rust - rust = { - { pattern = "^%s*use%s+", multi_line = true }, - { pattern = "^%s*extern%s+crate%s+", multi_line = false }, - }, - -- C/C++ - c = { - { pattern = "^%s*#include%s*[<\"]", multi_line = false }, - }, - -- Java/Kotlin - java = { - { pattern = "^%s*import%s+", multi_line = false }, - }, - -- Ruby - ruby = { - { pattern = "^%s*require%s+['\"]", multi_line = false }, - { pattern = "^%s*require_relative%s+['\"]", multi_line = false }, - }, - -- PHP - php = { - { pattern = "^%s*use%s+", multi_line = false }, - { pattern = "^%s*require%s+['\"]", multi_line = false }, - { pattern = "^%s*require_once%s+['\"]", multi_line = false }, - { pattern = "^%s*include%s+['\"]", multi_line = false }, - { pattern = "^%s*include_once%s+['\"]", multi_line = false }, - }, + -- JavaScript/TypeScript + javascript = { + { pattern = "^%s*import%s+.+%s+from%s+['\"]", multi_line = true }, + { pattern = "^%s*import%s+['\"]", multi_line = false }, + { pattern = "^%s*import%s*{", multi_line = true }, + { pattern = "^%s*import%s*%*", multi_line = true }, + { pattern = "^%s*export%s+{.+}%s+from%s+['\"]", multi_line = true }, + { pattern = "^%s*const%s+%w+%s*=%s*require%(['\"]", multi_line = false }, + { pattern = "^%s*let%s+%w+%s*=%s*require%(['\"]", multi_line = false }, + { pattern = "^%s*var%s+%w+%s*=%s*require%(['\"]", multi_line = false }, + }, + -- Python + python = { + { pattern = "^%s*import%s+%w", multi_line = false }, + { pattern = "^%s*from%s+[%w%.]+%s+import%s+", multi_line = true }, + }, + -- Lua + lua = { + { pattern = "^%s*local%s+%w+%s*=%s*require%s*%(?['\"]", multi_line = false }, + { pattern = "^%s*require%s*%(?['\"]", multi_line = false }, + }, + -- Go + go = { + { pattern = "^%s*import%s+%(?", multi_line = true }, + }, + -- Rust + rust = { + { pattern = "^%s*use%s+", multi_line = true }, + { pattern = "^%s*extern%s+crate%s+", multi_line = false }, + }, + -- C/C++ + c = { + { pattern = '^%s*#include%s*[<"]', multi_line = false }, + }, + -- Java/Kotlin + java = { + { pattern = "^%s*import%s+", multi_line = false }, + }, + -- Ruby + ruby = { + { pattern = "^%s*require%s+['\"]", multi_line = false }, + { pattern = "^%s*require_relative%s+['\"]", multi_line = false }, + }, + -- PHP + php = { + { pattern = "^%s*use%s+", multi_line = false }, + { pattern = "^%s*require%s+['\"]", multi_line = false }, + { pattern = "^%s*require_once%s+['\"]", multi_line = false }, + { pattern = "^%s*include%s+['\"]", multi_line = false }, + { pattern = "^%s*include_once%s+['\"]", multi_line = false }, + }, } -- Alias common extensions to language configs @@ -72,16 +72,16 @@ M.import_patterns.rb = M.import_patterns.ruby --- Language-specific comment patterns M.comment_patterns = { - lua = { "^%-%-" }, - python = { "^#" }, - javascript = { "^//", "^/%*", "^%*" }, - typescript = { "^//", "^/%*", "^%*" }, - go = { "^//", "^/%*", "^%*" }, - rust = { "^//", "^/%*", "^%*" }, - c = { "^//", "^/%*", "^%*", "^#" }, - java = { "^//", "^/%*", "^%*" }, - ruby = { "^#" }, - php = { "^//", "^/%*", "^%*", "^#" }, + lua = { "^%-%-" }, + python = { "^#" }, + javascript = { "^//", "^/%*", "^%*" }, + typescript = { "^//", "^/%*", "^%*" }, + go = { "^//", "^/%*", "^%*" }, + rust = { "^//", "^/%*", "^%*" }, + c = { "^//", "^/%*", "^%*", "^#" }, + java = { "^//", "^/%*", "^%*" }, + ruby = { "^#" }, + php = { "^//", "^/%*", "^%*", "^#" }, } return M diff --git a/lua/codetyper/params/agents/linter.lua b/lua/codetyper/params/agents/linter.lua index 88644b1..1aed1ad 100644 --- a/lua/codetyper/params/agents/linter.lua +++ b/lua/codetyper/params/agents/linter.lua @@ -2,14 +2,14 @@ local M = {} M.config = { - -- Auto-save file after code injection - auto_save = true, - -- Delay in ms to wait for LSP diagnostics to update - diagnostic_delay_ms = 500, - -- Severity levels to check (1=Error, 2=Warning, 3=Info, 4=Hint) - min_severity = vim.diagnostic.severity.WARN, - -- Auto-offer to fix lint errors - auto_offer_fix = true, + -- Auto-save file after code injection + auto_save = true, + -- Delay in ms to wait for LSP diagnostics to update + diagnostic_delay_ms = 500, + -- Severity levels to check (1=Error, 2=Warning, 3=Info, 4=Hint) + min_severity = vim.diagnostic.severity.WARN, + -- Auto-offer to fix lint errors + auto_offer_fix = true, } return M diff --git a/lua/codetyper/params/agents/logs.lua b/lua/codetyper/params/agents/logs.lua index 6346f73..4747aa7 100644 --- a/lua/codetyper/params/agents/logs.lua +++ b/lua/codetyper/params/agents/logs.lua @@ -2,35 +2,35 @@ local M = {} M.icons = { - start = "->", - success = "OK", - error = "ERR", - approval = "??", - approved = "YES", - rejected = "NO", + start = "->", + success = "OK", + error = "ERR", + approval = "??", + approved = "YES", + rejected = "NO", } M.level_icons = { - info = "i", - debug = ".", - request = ">", - response = "<", - tool = "T", - error = "!", - warning = "?", - success = "i", - queue = "Q", - patch = "P", + info = "i", + debug = ".", + request = ">", + response = "<", + tool = "T", + error = "!", + warning = "?", + success = "i", + queue = "Q", + patch = "P", } M.thinking_types = { "thinking", "reason", "action", "task", "result" } M.thinking_prefixes = { - thinking = "⏺", - reason = "⏺", - action = "⏺", - task = "✶", - result = "", + thinking = "⏺", + reason = "⏺", + action = "⏺", + task = "✶", + result = "", } -return M \ No newline at end of file +return M diff --git a/lua/codetyper/params/agents/parser.lua b/lua/codetyper/params/agents/parser.lua index d014634..c6216c7 100644 --- a/lua/codetyper/params/agents/parser.lua +++ b/lua/codetyper/params/agents/parser.lua @@ -2,14 +2,14 @@ local M = {} M.patterns = { - fenced_json = "```json%s*(%b{})%s*```", - inline_json = '(%{"tool"%s*:%s*"[^"]+"%s*,%s*"parameters"%s*:%s*%b{}%})', + fenced_json = "```json%s*(%b{})%s*```", + inline_json = '(%{"tool"%s*:%s*"[^"]+"%s*,%s*"parameters"%s*:%s*%b{}%})', } M.defaults = { - stop_reason = "end_turn", - tool_stop_reason = "tool_use", - replacement_text = "[Tool call]", + stop_reason = "end_turn", + tool_stop_reason = "tool_use", + replacement_text = "[Tool call]", } -return M \ No newline at end of file +return M diff --git a/lua/codetyper/params/agents/patch.lua b/lua/codetyper/params/agents/patch.lua index c328e0e..4ad1622 100644 --- a/lua/codetyper/params/agents/patch.lua +++ b/lua/codetyper/params/agents/patch.lua @@ -2,11 +2,11 @@ local M = {} M.config = { - snapshot_range = 5, -- Lines above/below prompt to snapshot - clean_interval_ms = 60000, -- Check for stale patches every minute - max_age_ms = 3600000, -- 1 hour TTL - staleness_check = true, - use_search_replace_parser = true, -- Enable new parsing logic + snapshot_range = 5, -- Lines above/below prompt to snapshot + clean_interval_ms = 60000, -- Check for stale patches every minute + max_age_ms = 3600000, -- 1 hour TTL + staleness_check = true, + use_search_replace_parser = true, -- Enable new parsing logic } return M diff --git a/lua/codetyper/params/agents/permissions.lua b/lua/codetyper/params/agents/permissions.lua index bf49d34..373b9fb 100644 --- a/lua/codetyper/params/agents/permissions.lua +++ b/lua/codetyper/params/agents/permissions.lua @@ -3,45 +3,45 @@ local M = {} --- Dangerous command patterns that should never be auto-allowed M.dangerous_patterns = { - "^rm%s+%-rf", - "^rm%s+%-r%s+/", - "^rm%s+/", - "^sudo%s+rm", - "^chmod%s+777", - "^chmod%s+%-R", - "^chown%s+%-R", - "^dd%s+", - "^mkfs", - "^fdisk", - "^format", - ":.*>%s*/dev/", - "^curl.*|.*sh", - "^wget.*|.*sh", - "^eval%s+", - "`;.*`", - "%$%(.*%)", - "fork%s*bomb", + "^rm%s+%-rf", + "^rm%s+%-r%s+/", + "^rm%s+/", + "^sudo%s+rm", + "^chmod%s+777", + "^chmod%s+%-R", + "^chown%s+%-R", + "^dd%s+", + "^mkfs", + "^fdisk", + "^format", + ":.*>%s*/dev/", + "^curl.*|.*sh", + "^wget.*|.*sh", + "^eval%s+", + "`;.*`", + "%$%(.*%)", + "fork%s*bomb", } --- Safe command patterns that can be auto-allowed M.safe_patterns = { - "^ls%s", - "^ls$", - "^cat%s", - "^head%s", - "^tail%s", - "^grep%s", - "^find%s", - "^pwd$", - "^echo%s", - "^wc%s", - "^git%s+status", - "^git%s+diff", - "^git%s+log", - "^git%s+show", - "^git%s+branch", - "^git%s+checkout", - "^git%s+add", -- Generally safe if reviewing changes + "^ls%s", + "^ls$", + "^cat%s", + "^head%s", + "^tail%s", + "^grep%s", + "^find%s", + "^pwd$", + "^echo%s", + "^wc%s", + "^git%s+status", + "^git%s+diff", + "^git%s+log", + "^git%s+show", + "^git%s+branch", + "^git%s+checkout", + "^git%s+add", -- Generally safe if reviewing changes } return M diff --git a/lua/codetyper/params/agents/scheduler.lua b/lua/codetyper/params/agents/scheduler.lua index 1c9d135..8d8ffb1 100644 --- a/lua/codetyper/params/agents/scheduler.lua +++ b/lua/codetyper/params/agents/scheduler.lua @@ -4,13 +4,13 @@ local M = {} M.config = { - enabled = true, - ollama_scout = true, - escalation_threshold = 0.7, - max_concurrent = 5, -- Allow multiple in-flight requests (like 99); user can type while thinking - completion_delay_ms = 100, - apply_delay_ms = 5000, -- Wait before applying code - remote_provider = "copilot", -- Default fallback provider + enabled = true, + ollama_scout = true, + escalation_threshold = 0.7, + max_concurrent = 5, -- Allow multiple in-flight requests (like 99); user can type while thinking + completion_delay_ms = 100, + apply_delay_ms = 5000, -- Wait before applying code + remote_provider = "copilot", -- Default fallback provider } return M diff --git a/lua/codetyper/params/agents/scope.lua b/lua/codetyper/params/agents/scope.lua index aa15eae..160d4bc 100644 --- a/lua/codetyper/params/agents/scope.lua +++ b/lua/codetyper/params/agents/scope.lua @@ -3,70 +3,70 @@ local M = {} --- Node types that represent function-like scopes per language M.function_nodes = { - -- Lua - ["function_declaration"] = "function", - ["function_definition"] = "function", - ["local_function"] = "function", - ["function"] = "function", + -- Lua + ["function_declaration"] = "function", + ["function_definition"] = "function", + ["local_function"] = "function", + ["function"] = "function", - -- JavaScript/TypeScript - ["function_declaration"] = "function", - ["function_expression"] = "function", - ["arrow_function"] = "function", - ["method_definition"] = "method", - ["function"] = "function", + -- JavaScript/TypeScript + ["function_declaration"] = "function", + ["function_expression"] = "function", + ["arrow_function"] = "function", + ["method_definition"] = "method", + ["function"] = "function", - -- Python - ["function_definition"] = "function", - ["lambda"] = "function", + -- Python + ["function_definition"] = "function", + ["lambda"] = "function", - -- Go - ["function_declaration"] = "function", - ["method_declaration"] = "method", - ["func_literal"] = "function", + -- Go + ["function_declaration"] = "function", + ["method_declaration"] = "method", + ["func_literal"] = "function", - -- Rust - ["function_item"] = "function", - ["closure_expression"] = "function", + -- Rust + ["function_item"] = "function", + ["closure_expression"] = "function", - -- C/C++ - ["function_definition"] = "function", - ["lambda_expression"] = "function", + -- C/C++ + ["function_definition"] = "function", + ["lambda_expression"] = "function", - -- Java - ["method_declaration"] = "method", - ["constructor_declaration"] = "method", - ["lambda_expression"] = "function", + -- Java + ["method_declaration"] = "method", + ["constructor_declaration"] = "method", + ["lambda_expression"] = "function", - -- Ruby - ["method"] = "method", - ["singleton_method"] = "method", - ["lambda"] = "function", - ["block"] = "function", + -- Ruby + ["method"] = "method", + ["singleton_method"] = "method", + ["lambda"] = "function", + ["block"] = "function", - -- PHP - ["function_definition"] = "function", - ["method_declaration"] = "method", - ["arrow_function"] = "function", + -- PHP + ["function_definition"] = "function", + ["method_declaration"] = "method", + ["arrow_function"] = "function", } --- Node types that represent class-like scopes M.class_nodes = { - ["class_declaration"] = "class", - ["class_definition"] = "class", - ["struct_declaration"] = "class", - ["impl_item"] = "class", -- Rust config - ["interface_declaration"] = "class", - ["trait_item"] = "class", + ["class_declaration"] = "class", + ["class_definition"] = "class", + ["struct_declaration"] = "class", + ["impl_item"] = "class", -- Rust config + ["interface_declaration"] = "class", + ["trait_item"] = "class", } --- Node types that represent block scopes M.block_nodes = { - ["block"] = "block", - ["do_statement"] = "block", -- Lua - ["if_statement"] = "block", - ["for_statement"] = "block", - ["while_statement"] = "block", + ["block"] = "block", + ["do_statement"] = "block", -- Lua + ["if_statement"] = "block", + ["for_statement"] = "block", + ["while_statement"] = "block", } return M diff --git a/lua/codetyper/params/agents/search_replace.lua b/lua/codetyper/params/agents/search_replace.lua index 54ede10..e40e0a2 100644 --- a/lua/codetyper/params/agents/search_replace.lua +++ b/lua/codetyper/params/agents/search_replace.lua @@ -2,10 +2,10 @@ local M = {} M.patterns = { - dash_style = "%-%-%-%-%-%-%-?%s*SEARCH%s*\n(.-)\n=======%s*\n(.-)\n%+%+%+%+%+%+%+?%s*REPLACE", - claude_style = "<<<<<<<[%s]*SEARCH%s*\n(.-)\n=======%s*\n(.-)\n>>>>>>>[%s]*REPLACE", - simple_style = "%[SEARCH%]%s*\n(.-)\n%[REPLACE%]%s*\n(.-)\n%[END%]", - diff_block = "```diff\n(.-)\n```", + dash_style = "%-%-%-%-%-%-%-?%s*SEARCH%s*\n(.-)\n=======%s*\n(.-)\n%+%+%+%+%+%+%+?%s*REPLACE", + claude_style = "<<<<<<<[%s]*SEARCH%s*\n(.-)\n=======%s*\n(.-)\n>>>>>>>[%s]*REPLACE", + simple_style = "%[SEARCH%]%s*\n(.-)\n%[REPLACE%]%s*\n(.-)\n%[END%]", + diff_block = "```diff\n(.-)\n```", } -return M \ No newline at end of file +return M diff --git a/lua/codetyper/params/agents/tools.lua b/lua/codetyper/params/agents/tools.lua index 94379dd..ba55984 100644 --- a/lua/codetyper/params/agents/tools.lua +++ b/lua/codetyper/params/agents/tools.lua @@ -3,145 +3,145 @@ local M = {} --- Tool definitions in a provider-agnostic format M.definitions = { - read_file = { - name = "read_file", - description = "Read the contents of a file at the specified path", - parameters = { - type = "object", - properties = { - path = { - type = "string", - description = "The path to the file to read", - }, - start_line = { - type = "number", - description = "Optional start line number (1-indexed)", - }, - end_line = { - type = "number", - description = "Optional end line number (1-indexed)", - }, - }, - required = { "path" }, - }, - }, + read_file = { + name = "read_file", + description = "Read the contents of a file at the specified path", + parameters = { + type = "object", + properties = { + path = { + type = "string", + description = "The path to the file to read", + }, + start_line = { + type = "number", + description = "Optional start line number (1-indexed)", + }, + end_line = { + type = "number", + description = "Optional end line number (1-indexed)", + }, + }, + required = { "path" }, + }, + }, - edit_file = { - name = "edit_file", - description = "Edit a file by replacing specific content. Provide the exact content to find and the replacement.", - parameters = { - type = "object", - properties = { - path = { - type = "string", - description = "The path to the file to edit", - }, - find = { - type = "string", - description = "The exact content to replace", - }, - replace = { - type = "string", - description = "The new content", - }, - }, - required = { "path", "find", "replace" }, - }, - }, + edit_file = { + name = "edit_file", + description = "Edit a file by replacing specific content. Provide the exact content to find and the replacement.", + parameters = { + type = "object", + properties = { + path = { + type = "string", + description = "The path to the file to edit", + }, + find = { + type = "string", + description = "The exact content to replace", + }, + replace = { + type = "string", + description = "The new content", + }, + }, + required = { "path", "find", "replace" }, + }, + }, - write_file = { - name = "write_file", - description = "Write content to a file, creating it if it doesn't exist or overwriting if it does", - parameters = { - type = "object", - properties = { - path = { - type = "string", - description = "The path to the file to write", - }, - content = { - type = "string", - description = "The content to write", - }, - }, - required = { "path", "content" }, - }, - }, + write_file = { + name = "write_file", + description = "Write content to a file, creating it if it doesn't exist or overwriting if it does", + parameters = { + type = "object", + properties = { + path = { + type = "string", + description = "The path to the file to write", + }, + content = { + type = "string", + description = "The content to write", + }, + }, + required = { "path", "content" }, + }, + }, - bash = { - name = "bash", - description = "Execute a bash command and return the output. Use for git, npm, build tools, etc.", - parameters = { - type = "object", - properties = { - command = { - type = "string", - description = "The bash command to execute", - }, - }, - required = { "command" }, - }, - }, + bash = { + name = "bash", + description = "Execute a bash command and return the output. Use for git, npm, build tools, etc.", + parameters = { + type = "object", + properties = { + command = { + type = "string", + description = "The bash command to execute", + }, + }, + required = { "command" }, + }, + }, - delete_file = { - name = "delete_file", - description = "Delete a file", - parameters = { - type = "object", - properties = { - path = { - type = "string", - description = "The path to the file to delete", - }, - reason = { - type = "string", - description = "Reason for deletion", - }, - }, - required = { "path", "reason" }, - }, - }, + delete_file = { + name = "delete_file", + description = "Delete a file", + parameters = { + type = "object", + properties = { + path = { + type = "string", + description = "The path to the file to delete", + }, + reason = { + type = "string", + description = "Reason for deletion", + }, + }, + required = { "path", "reason" }, + }, + }, - list_directory = { - name = "list_directory", - description = "List files and directories in a path", - parameters = { - type = "object", - properties = { - path = { - type = "string", - description = "The path to list", - }, - recursive = { - type = "boolean", - description = "Whether to list recursively", - }, - }, - required = { "path" }, - }, - }, + list_directory = { + name = "list_directory", + description = "List files and directories in a path", + parameters = { + type = "object", + properties = { + path = { + type = "string", + description = "The path to list", + }, + recursive = { + type = "boolean", + description = "Whether to list recursively", + }, + }, + required = { "path" }, + }, + }, - search_files = { - name = "search_files", - description = "Search for files by name/glob pattern or content", - parameters = { - type = "object", - properties = { - pattern = { - type = "string", - description = "Glob pattern to search for filenames", - }, - content = { - type = "string", - description = "Content string to search for within files", - }, - path = { - type = "string", - description = "The root path to start search", - }, - }, - }, - }, + search_files = { + name = "search_files", + description = "Search for files by name/glob pattern or content", + parameters = { + type = "object", + properties = { + pattern = { + type = "string", + description = "Glob pattern to search for filenames", + }, + content = { + type = "string", + description = "Content string to search for within files", + }, + path = { + type = "string", + description = "The root path to start search", + }, + }, + }, + }, } return M diff --git a/lua/codetyper/params/agents/view.lua b/lua/codetyper/params/agents/view.lua index 64d2d4e..73e08a1 100644 --- a/lua/codetyper/params/agents/view.lua +++ b/lua/codetyper/params/agents/view.lua @@ -1,37 +1,37 @@ local M = {} M.params = { - { - name = "path", - description = "Path to the file (relative to project root or absolute)", - type = "string", - }, - { - name = "start_line", - description = "Line number to start reading (1-indexed)", - type = "integer", - optional = true, - }, - { - name = "end_line", - description = "Line number to end reading (1-indexed, inclusive)", - type = "integer", - optional = true, - }, + { + name = "path", + description = "Path to the file (relative to project root or absolute)", + type = "string", + }, + { + name = "start_line", + description = "Line number to start reading (1-indexed)", + type = "integer", + optional = true, + }, + { + name = "end_line", + description = "Line number to end reading (1-indexed, inclusive)", + type = "integer", + optional = true, + }, } M.returns = { - { - name = "content", - description = "File contents as JSON with content, total_line_count, is_truncated", - type = "string", - }, - { - name = "error", - description = "Error message if file could not be read", - type = "string", - optional = true, - }, + { + name = "content", + description = "File contents as JSON with content, total_line_count, is_truncated", + type = "string", + }, + { + name = "error", + description = "Error message if file could not be read", + type = "string", + optional = true, + }, } -return M \ No newline at end of file +return M diff --git a/lua/codetyper/params/agents/worker.lua b/lua/codetyper/params/agents/worker.lua index 261676f..3a730e8 100644 --- a/lua/codetyper/params/agents/worker.lua +++ b/lua/codetyper/params/agents/worker.lua @@ -3,24 +3,24 @@ local M = {} --- Patterns that indicate LLM needs more context (must be near start of response) M.context_needed_patterns = { - "I need to see", - "Could you provide", - "Please provide", - "Can you show", - "don't have enough context", - "need more information", - "cannot see the definition", - "missing the implementation", - "I would need to check", - "please share", - "Please upload", - "could not find", + "I need to see", + "Could you provide", + "Please provide", + "Can you show", + "don't have enough context", + "need more information", + "cannot see the definition", + "missing the implementation", + "I would need to check", + "please share", + "Please upload", + "could not find", } --- Default timeouts by provider type M.default_timeouts = { - ollama = 120000, -- 120s (local models can be slower) - copilot = 60000, -- 60s + ollama = 120000, -- 120s (local models can be slower) + copilot = 60000, -- 60s } return M diff --git a/lua/codetyper/params/agents/write.lua b/lua/codetyper/params/agents/write.lua index 60a71b4..c07092a 100644 --- a/lua/codetyper/params/agents/write.lua +++ b/lua/codetyper/params/agents/write.lua @@ -1,30 +1,30 @@ local M = {} M.params = { - { - name = "path", - description = "Path to the file to write", - type = "string", - }, - { - name = "content", - description = "Content to write to the file", - type = "string", - }, + { + name = "path", + description = "Path to the file to write", + type = "string", + }, + { + name = "content", + description = "Content to write to the file", + type = "string", + }, } M.returns = { - { - name = "success", - description = "Whether the file was written successfully", - type = "boolean", - }, - { - name = "error", - description = "Error message if write failed", - type = "string", - optional = true, - }, + { + name = "success", + description = "Whether the file was written successfully", + type = "boolean", + }, + { + name = "error", + description = "Error message if write failed", + type = "string", + optional = true, + }, } -return M \ No newline at end of file +return M diff --git a/lua/codetyper/parser.lua b/lua/codetyper/parser.lua index 911f4ba..42dce05 100644 --- a/lua/codetyper/parser.lua +++ b/lua/codetyper/parser.lua @@ -7,13 +7,13 @@ local logger = require("codetyper.support.logger") -- Get current codetyper configuration at call time local function get_config() - local ok, codetyper = pcall(require, "codetyper") - if ok and codetyper.get_config then - return codetyper.get_config() or {} - end - -- Fall back to defaults if codetyper isn't available - local defaults = require("codetyper.config.defaults") - return defaults.get_defaults() + local ok, codetyper = pcall(require, "codetyper") + if ok and codetyper.get_config then + return codetyper.get_config() or {} + end + -- Fall back to defaults if codetyper isn't available + local defaults = require("codetyper.config.defaults") + return defaults.get_defaults() end --- Find all prompts in buffer content @@ -22,231 +22,222 @@ end ---@param close_tag string Closing tag ---@return CoderPrompt[] List of found prompts function M.find_prompts(content, open_tag, close_tag) - logger.func_entry("parser", "find_prompts", { - content_length = #content, - open_tag = open_tag, - close_tag = close_tag, - }) + logger.func_entry("parser", "find_prompts", { + content_length = #content, + open_tag = open_tag, + close_tag = close_tag, + }) - local prompts = {} - local escaped_open = utils.escape_pattern(open_tag) - local escaped_close = utils.escape_pattern(close_tag) + local prompts = {} + local escaped_open = utils.escape_pattern(open_tag) + local escaped_close = utils.escape_pattern(close_tag) - local lines = vim.split(content, "\n", { plain = true }) - local in_prompt = false - local current_prompt = nil - local prompt_content = {} + local lines = vim.split(content, "\n", { plain = true }) + local in_prompt = false + local current_prompt = nil + local prompt_content = {} - logger.debug("parser", "find_prompts: parsing " .. #lines .. " lines") + logger.debug("parser", "find_prompts: parsing " .. #lines .. " lines") - for line_num, line in ipairs(lines) do - if not in_prompt then - -- Look for opening tag - local start_col = line:find(escaped_open) - if start_col then - logger.debug("parser", "find_prompts: found opening tag at line " .. line_num .. ", col " .. start_col) - in_prompt = true - current_prompt = { - start_line = line_num, - start_col = start_col, - content = "", - } - -- Get content after opening tag on same line - local after_tag = line:sub(start_col + #open_tag) - local end_col = after_tag:find(escaped_close) - if end_col then - -- Single line prompt - current_prompt.content = after_tag:sub(1, end_col - 1) - current_prompt.end_line = line_num - current_prompt.end_col = start_col + #open_tag + end_col + #close_tag - 2 - table.insert(prompts, current_prompt) - logger.debug("parser", "find_prompts: single-line prompt completed at line " .. line_num) - in_prompt = false - current_prompt = nil - else - table.insert(prompt_content, after_tag) - end - end - else - -- Look for closing tag - local end_col = line:find(escaped_close) - if end_col then - -- Found closing tag - local before_tag = line:sub(1, end_col - 1) - table.insert(prompt_content, before_tag) - current_prompt.content = table.concat(prompt_content, "\n") - current_prompt.end_line = line_num - current_prompt.end_col = end_col + #close_tag - 1 - table.insert(prompts, current_prompt) - logger.debug( - "parser", - "find_prompts: multi-line prompt completed at line " - .. line_num - .. ", total lines: " - .. #prompt_content - ) - in_prompt = false - current_prompt = nil - prompt_content = {} - else - table.insert(prompt_content, line) - end - end - end + for line_num, line in ipairs(lines) do + if not in_prompt then + -- Look for opening tag + local start_col = line:find(escaped_open) + if start_col then + logger.debug("parser", "find_prompts: found opening tag at line " .. line_num .. ", col " .. start_col) + in_prompt = true + current_prompt = { + start_line = line_num, + start_col = start_col, + content = "", + } + -- Get content after opening tag on same line + local after_tag = line:sub(start_col + #open_tag) + local end_col = after_tag:find(escaped_close) + if end_col then + -- Single line prompt + current_prompt.content = after_tag:sub(1, end_col - 1) + current_prompt.end_line = line_num + current_prompt.end_col = start_col + #open_tag + end_col + #close_tag - 2 + table.insert(prompts, current_prompt) + logger.debug("parser", "find_prompts: single-line prompt completed at line " .. line_num) + in_prompt = false + current_prompt = nil + else + table.insert(prompt_content, after_tag) + end + end + else + -- Look for closing tag + local end_col = line:find(escaped_close) + if end_col then + -- Found closing tag + local before_tag = line:sub(1, end_col - 1) + table.insert(prompt_content, before_tag) + current_prompt.content = table.concat(prompt_content, "\n") + current_prompt.end_line = line_num + current_prompt.end_col = end_col + #close_tag - 1 + table.insert(prompts, current_prompt) + logger.debug( + "parser", + "find_prompts: multi-line prompt completed at line " .. line_num .. ", total lines: " .. #prompt_content + ) + in_prompt = false + current_prompt = nil + prompt_content = {} + else + table.insert(prompt_content, line) + end + end + end - logger.debug("parser", "find_prompts: found " .. #prompts .. " prompts total") - logger.func_exit("parser", "find_prompts", "found " .. #prompts .. " prompts") + logger.debug("parser", "find_prompts: found " .. #prompts .. " prompts total") + logger.func_exit("parser", "find_prompts", "found " .. #prompts .. " prompts") - return prompts + return prompts end --- Find prompts in a buffer ---@param bufnr number Buffer number ---@return CoderPrompt[] List of found prompts function M.find_prompts_in_buffer(bufnr) - logger.func_entry("parser", "find_prompts_in_buffer", { bufnr = bufnr }) + logger.func_entry("parser", "find_prompts_in_buffer", { bufnr = bufnr }) - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local content = table.concat(lines, "\n") + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local content = table.concat(lines, "\n") - logger.debug( - "parser", - "find_prompts_in_buffer: bufnr=" .. bufnr .. ", lines=" .. #lines .. ", content_length=" .. #content - ) + logger.debug( + "parser", + "find_prompts_in_buffer: bufnr=" .. bufnr .. ", lines=" .. #lines .. ", content_length=" .. #content + ) - local cfg = get_config() - local result = M.find_prompts(content, cfg.patterns.open_tag, cfg.patterns.close_tag) + local cfg = get_config() + local result = M.find_prompts(content, cfg.patterns.open_tag, cfg.patterns.close_tag) - logger.func_exit("parser", "find_prompts_in_buffer", "found " .. #result .. " prompts") - return result + logger.func_exit("parser", "find_prompts_in_buffer", "found " .. #result .. " prompts") + return result end --- Get prompt at cursor position ---@param bufnr? number Buffer number (default: current) ---@return CoderPrompt|nil Prompt at cursor or nil function M.get_prompt_at_cursor(bufnr) - bufnr = bufnr or vim.api.nvim_get_current_buf() - local cursor = vim.api.nvim_win_get_cursor(0) - local line = cursor[1] - local col = cursor[2] + 1 -- Convert to 1-indexed + bufnr = bufnr or vim.api.nvim_get_current_buf() + local cursor = vim.api.nvim_win_get_cursor(0) + local line = cursor[1] + local col = cursor[2] + 1 -- Convert to 1-indexed - logger.func_entry("parser", "get_prompt_at_cursor", { - bufnr = bufnr, - line = line, - col = col, - }) + logger.func_entry("parser", "get_prompt_at_cursor", { + bufnr = bufnr, + line = line, + col = col, + }) - local prompts = M.find_prompts_in_buffer(bufnr) + local prompts = M.find_prompts_in_buffer(bufnr) - logger.debug("parser", "get_prompt_at_cursor: checking " .. #prompts .. " prompts") + logger.debug("parser", "get_prompt_at_cursor: checking " .. #prompts .. " prompts") - for i, prompt in ipairs(prompts) do - logger.debug( - "parser", - "get_prompt_at_cursor: checking prompt " - .. i - .. " (lines " - .. prompt.start_line - .. "-" - .. prompt.end_line - .. ")" - ) - if line >= prompt.start_line and line <= prompt.end_line then - logger.debug("parser", "get_prompt_at_cursor: cursor line " .. line .. " is within prompt line range") - if line == prompt.start_line and col < prompt.start_col then - logger.debug( - "parser", - "get_prompt_at_cursor: cursor col " .. col .. " is before prompt start_col " .. prompt.start_col - ) - goto continue - end - if line == prompt.end_line and col > prompt.end_col then - logger.debug( - "parser", - "get_prompt_at_cursor: cursor col " .. col .. " is after prompt end_col " .. prompt.end_col - ) - goto continue - end - logger.debug("parser", "get_prompt_at_cursor: found prompt at cursor") - logger.func_exit("parser", "get_prompt_at_cursor", "prompt found") - return prompt - end - ::continue:: - end + for i, prompt in ipairs(prompts) do + logger.debug( + "parser", + "get_prompt_at_cursor: checking prompt " .. i .. " (lines " .. prompt.start_line .. "-" .. prompt.end_line .. ")" + ) + if line >= prompt.start_line and line <= prompt.end_line then + logger.debug("parser", "get_prompt_at_cursor: cursor line " .. line .. " is within prompt line range") + if line == prompt.start_line and col < prompt.start_col then + logger.debug( + "parser", + "get_prompt_at_cursor: cursor col " .. col .. " is before prompt start_col " .. prompt.start_col + ) + goto continue + end + if line == prompt.end_line and col > prompt.end_col then + logger.debug( + "parser", + "get_prompt_at_cursor: cursor col " .. col .. " is after prompt end_col " .. prompt.end_col + ) + goto continue + end + logger.debug("parser", "get_prompt_at_cursor: found prompt at cursor") + logger.func_exit("parser", "get_prompt_at_cursor", "prompt found") + return prompt + end + ::continue:: + end - logger.debug("parser", "get_prompt_at_cursor: no prompt found at cursor") - logger.func_exit("parser", "get_prompt_at_cursor", nil) - return nil + logger.debug("parser", "get_prompt_at_cursor: no prompt found at cursor") + logger.func_exit("parser", "get_prompt_at_cursor", nil) + return nil end --- Get the last closed prompt in buffer ---@param bufnr? number Buffer number (default: current) ---@return CoderPrompt|nil Last prompt or nil function M.get_last_prompt(bufnr) - bufnr = bufnr or vim.api.nvim_get_current_buf() + bufnr = bufnr or vim.api.nvim_get_current_buf() - logger.func_entry("parser", "get_last_prompt", { bufnr = bufnr }) + logger.func_entry("parser", "get_last_prompt", { bufnr = bufnr }) - local prompts = M.find_prompts_in_buffer(bufnr) + local prompts = M.find_prompts_in_buffer(bufnr) - if #prompts > 0 then - local last = prompts[#prompts] - logger.debug("parser", "get_last_prompt: returning prompt at line " .. last.start_line) - logger.func_exit("parser", "get_last_prompt", "prompt at line " .. last.start_line) - return last - end + if #prompts > 0 then + local last = prompts[#prompts] + logger.debug("parser", "get_last_prompt: returning prompt at line " .. last.start_line) + logger.func_exit("parser", "get_last_prompt", "prompt at line " .. last.start_line) + return last + end - logger.debug("parser", "get_last_prompt: no prompts found") - logger.func_exit("parser", "get_last_prompt", nil) - return nil + logger.debug("parser", "get_last_prompt: no prompts found") + logger.func_exit("parser", "get_last_prompt", nil) + return nil end --- Extract the prompt type from content ---@param content string Prompt content ---@return "refactor" | "add" | "document" | "explain" | "generic" Prompt type function M.detect_prompt_type(content) - logger.func_entry("parser", "detect_prompt_type", { content_preview = content:sub(1, 50) }) + logger.func_entry("parser", "detect_prompt_type", { content_preview = content:sub(1, 50) }) - local lower = content:lower() + local lower = content:lower() - if lower:match("refactor") then - logger.debug("parser", "detect_prompt_type: detected 'refactor'") - logger.func_exit("parser", "detect_prompt_type", "refactor") - return "refactor" - elseif lower:match("add") or lower:match("create") or lower:match("implement") then - logger.debug("parser", "detect_prompt_type: detected 'add'") - logger.func_exit("parser", "detect_prompt_type", "add") - return "add" - elseif lower:match("document") or lower:match("comment") or lower:match("jsdoc") then - logger.debug("parser", "detect_prompt_type: detected 'document'") - logger.func_exit("parser", "detect_prompt_type", "document") - return "document" - elseif lower:match("explain") or lower:match("what") or lower:match("how") then - logger.debug("parser", "detect_prompt_type: detected 'explain'") - logger.func_exit("parser", "detect_prompt_type", "explain") - return "explain" - end + if lower:match("refactor") then + logger.debug("parser", "detect_prompt_type: detected 'refactor'") + logger.func_exit("parser", "detect_prompt_type", "refactor") + return "refactor" + elseif lower:match("add") or lower:match("create") or lower:match("implement") then + logger.debug("parser", "detect_prompt_type: detected 'add'") + logger.func_exit("parser", "detect_prompt_type", "add") + return "add" + elseif lower:match("document") or lower:match("comment") or lower:match("jsdoc") then + logger.debug("parser", "detect_prompt_type: detected 'document'") + logger.func_exit("parser", "detect_prompt_type", "document") + return "document" + elseif lower:match("explain") or lower:match("what") or lower:match("how") then + logger.debug("parser", "detect_prompt_type: detected 'explain'") + logger.func_exit("parser", "detect_prompt_type", "explain") + return "explain" + end - logger.debug("parser", "detect_prompt_type: detected 'generic'") - logger.func_exit("parser", "detect_prompt_type", "generic") - return "generic" + logger.debug("parser", "detect_prompt_type: detected 'generic'") + logger.func_exit("parser", "detect_prompt_type", "generic") + return "generic" end --- Clean prompt content (trim whitespace, normalize newlines) ---@param content string Raw prompt content ---@return string Cleaned content function M.clean_prompt(content) - logger.func_entry("parser", "clean_prompt", { content_length = #content }) + logger.func_entry("parser", "clean_prompt", { content_length = #content }) - -- Trim leading/trailing whitespace - content = content:match("^%s*(.-)%s*$") - -- Normalize multiple newlines - content = content:gsub("\n\n\n+", "\n\n") + -- Trim leading/trailing whitespace + content = content:match("^%s*(.-)%s*$") + -- Normalize multiple newlines + content = content:gsub("\n\n\n+", "\n\n") - logger.debug("parser", "clean_prompt: cleaned from " .. #content .. " chars") - logger.func_exit("parser", "clean_prompt", "length=" .. #content) + logger.debug("parser", "clean_prompt: cleaned from " .. #content .. " chars") + logger.func_exit("parser", "clean_prompt", "length=" .. #content) - return content + return content end --- Check if line contains a closing tag @@ -254,48 +245,43 @@ end ---@param close_tag string Closing tag ---@return boolean function M.has_closing_tag(line, close_tag) - logger.func_entry("parser", "has_closing_tag", { line_preview = line:sub(1, 30), close_tag = close_tag }) + logger.func_entry("parser", "has_closing_tag", { line_preview = line:sub(1, 30), close_tag = close_tag }) - local result = line:find(utils.escape_pattern(close_tag)) ~= nil + local result = line:find(utils.escape_pattern(close_tag)) ~= nil - logger.debug("parser", "has_closing_tag: result=" .. tostring(result)) - logger.func_exit("parser", "has_closing_tag", result) + logger.debug("parser", "has_closing_tag: result=" .. tostring(result)) + logger.func_exit("parser", "has_closing_tag", result) - return result + return result end --- Check if buffer has any unclosed prompts ---@param bufnr? number Buffer number (default: current) ---@return boolean function M.has_unclosed_prompts(bufnr) - bufnr = bufnr or vim.api.nvim_get_current_buf() + bufnr = bufnr or vim.api.nvim_get_current_buf() - logger.func_entry("parser", "has_unclosed_prompts", { bufnr = bufnr }) + logger.func_entry("parser", "has_unclosed_prompts", { bufnr = bufnr }) - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local content = table.concat(lines, "\n") + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local content = table.concat(lines, "\n") - local cfg = get_config() - local escaped_open = utils.escape_pattern(cfg.patterns.open_tag) - local escaped_close = utils.escape_pattern(cfg.patterns.close_tag) + local cfg = get_config() + local escaped_open = utils.escape_pattern(cfg.patterns.open_tag) + local escaped_close = utils.escape_pattern(cfg.patterns.close_tag) - local _, open_count = content:gsub(escaped_open, "") - local _, close_count = content:gsub(escaped_close, "") + local _, open_count = content:gsub(escaped_open, "") + local _, close_count = content:gsub(escaped_close, "") - local has_unclosed = open_count > close_count + local has_unclosed = open_count > close_count - logger.debug( - "parser", - "has_unclosed_prompts: open=" - .. open_count - .. ", close=" - .. close_count - .. ", unclosed=" - .. tostring(has_unclosed) - ) - logger.func_exit("parser", "has_unclosed_prompts", has_unclosed) + logger.debug( + "parser", + "has_unclosed_prompts: open=" .. open_count .. ", close=" .. close_count .. ", unclosed=" .. tostring(has_unclosed) + ) + logger.func_exit("parser", "has_unclosed_prompts", has_unclosed) - return has_unclosed + return has_unclosed end --- Extract file references from prompt content @@ -303,39 +289,39 @@ end ---@param content string Prompt content ---@return string[] List of file references function M.extract_file_references(content) - logger.func_entry("parser", "extract_file_references", { content_length = #content }) + logger.func_entry("parser", "extract_file_references", { content_length = #content }) - local files = {} - -- Pattern: @ followed by word char, dot, underscore, or dash as FIRST char - -- Then optionally more path characters including / - -- This ensures @/ is NOT matched (/ cannot be first char) - for file in content:gmatch("@([%w%._%-][%w%._%-/]*)") do - if file ~= "" then - table.insert(files, file) - logger.debug("parser", "extract_file_references: found file reference: " .. file) - end - end + local files = {} + -- Pattern: @ followed by word char, dot, underscore, or dash as FIRST char + -- Then optionally more path characters including / + -- This ensures @/ is NOT matched (/ cannot be first char) + for file in content:gmatch("@([%w%._%-][%w%._%-/]*)") do + if file ~= "" then + table.insert(files, file) + logger.debug("parser", "extract_file_references: found file reference: " .. file) + end + end - logger.debug("parser", "extract_file_references: found " .. #files .. " file references") - logger.func_exit("parser", "extract_file_references", files) + logger.debug("parser", "extract_file_references: found " .. #files .. " file references") + logger.func_exit("parser", "extract_file_references", files) - return files + return files end --- Remove file references from prompt content (for clean prompt text) ---@param content string Prompt content ---@return string Cleaned content without file references function M.strip_file_references(content) - logger.func_entry("parser", "strip_file_references", { content_length = #content }) + logger.func_entry("parser", "strip_file_references", { content_length = #content }) - -- Remove @filename patterns but preserve @/ closing tag - -- Pattern requires first char after @ to be word char, dot, underscore, or dash (NOT /) - local result = content:gsub("@([%w%._%-][%w%._%-/]*)", "") + -- Remove @filename patterns but preserve @/ closing tag + -- Pattern requires first char after @ to be word char, dot, underscore, or dash (NOT /) + local result = content:gsub("@([%w%._%-][%w%._%-/]*)", "") - logger.debug("parser", "strip_file_references: stripped " .. (#content - #result) .. " chars") - logger.func_exit("parser", "strip_file_references", "length=" .. #result) + logger.debug("parser", "strip_file_references: stripped " .. (#content - #result) .. " chars") + logger.func_exit("parser", "strip_file_references", "length=" .. #result) - return result + return result end --- Check if cursor is inside an unclosed prompt tag @@ -343,88 +329,88 @@ end ---@return boolean is_inside Whether cursor is inside an open tag ---@return number|nil start_line Line where the open tag starts function M.is_cursor_in_open_tag(bufnr) - bufnr = bufnr or vim.api.nvim_get_current_buf() + bufnr = bufnr or vim.api.nvim_get_current_buf() - logger.func_entry("parser", "is_cursor_in_open_tag", { bufnr = bufnr }) + logger.func_entry("parser", "is_cursor_in_open_tag", { bufnr = bufnr }) - local cursor = vim.api.nvim_win_get_cursor(0) - local cursor_line = cursor[1] + local cursor = vim.api.nvim_win_get_cursor(0) + local cursor_line = cursor[1] - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, cursor_line, false) - local cfg = get_config() - local escaped_open = utils.escape_pattern(cfg.patterns.open_tag) - local escaped_close = utils.escape_pattern(cfg.patterns.close_tag) + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, cursor_line, false) + local cfg = get_config() + local escaped_open = utils.escape_pattern(cfg.patterns.open_tag) + local escaped_close = utils.escape_pattern(cfg.patterns.close_tag) - local open_count = 0 - local close_count = 0 - local last_open_line = nil + local open_count = 0 + local close_count = 0 + local last_open_line = nil - for line_num, line in ipairs(lines) do - -- Count opens on this line - for _ in line:gmatch(escaped_open) do - open_count = open_count + 1 - last_open_line = line_num - logger.debug("parser", "is_cursor_in_open_tag: found open tag at line " .. line_num) - end - -- Count closes on this line - for _ in line:gmatch(escaped_close) do - close_count = close_count + 1 - logger.debug("parser", "is_cursor_in_open_tag: found close tag at line " .. line_num) - end - end + for line_num, line in ipairs(lines) do + -- Count opens on this line + for _ in line:gmatch(escaped_open) do + open_count = open_count + 1 + last_open_line = line_num + logger.debug("parser", "is_cursor_in_open_tag: found open tag at line " .. line_num) + end + -- Count closes on this line + for _ in line:gmatch(escaped_close) do + close_count = close_count + 1 + logger.debug("parser", "is_cursor_in_open_tag: found close tag at line " .. line_num) + end + end - local is_inside = open_count > close_count + local is_inside = open_count > close_count - logger.debug( - "parser", - "is_cursor_in_open_tag: open=" - .. open_count - .. ", close=" - .. close_count - .. ", is_inside=" - .. tostring(is_inside) - .. ", last_open_line=" - .. tostring(last_open_line) - ) - logger.func_exit("parser", "is_cursor_in_open_tag", { is_inside = is_inside, last_open_line = last_open_line }) + logger.debug( + "parser", + "is_cursor_in_open_tag: open=" + .. open_count + .. ", close=" + .. close_count + .. ", is_inside=" + .. tostring(is_inside) + .. ", last_open_line=" + .. tostring(last_open_line) + ) + logger.func_exit("parser", "is_cursor_in_open_tag", { is_inside = is_inside, last_open_line = last_open_line }) - return is_inside, is_inside and last_open_line or nil + return is_inside, is_inside and last_open_line or nil end --- Get the word being typed after @ symbol ---@param bufnr? number Buffer number ---@return string|nil prefix The text after @ being typed, or nil if not typing a file ref function M.get_file_ref_prefix(bufnr) - bufnr = bufnr or vim.api.nvim_get_current_buf() + bufnr = bufnr or vim.api.nvim_get_current_buf() - logger.func_entry("parser", "get_file_ref_prefix", { bufnr = bufnr }) + logger.func_entry("parser", "get_file_ref_prefix", { bufnr = bufnr }) - local cursor = vim.api.nvim_win_get_cursor(0) - local line = vim.api.nvim_buf_get_lines(bufnr, cursor[1] - 1, cursor[1], false)[1] - if not line then - logger.debug("parser", "get_file_ref_prefix: no line at cursor") - logger.func_exit("parser", "get_file_ref_prefix", nil) - return nil - end + local cursor = vim.api.nvim_win_get_cursor(0) + local line = vim.api.nvim_buf_get_lines(bufnr, cursor[1] - 1, cursor[1], false)[1] + if not line then + logger.debug("parser", "get_file_ref_prefix: no line at cursor") + logger.func_exit("parser", "get_file_ref_prefix", nil) + return nil + end - local col = cursor[2] - local before_cursor = line:sub(1, col) + local col = cursor[2] + local before_cursor = line:sub(1, col) - -- Check if we're typing after @ but not @/ - -- Match @ followed by optional path characters at end of string - local prefix = before_cursor:match("@([%w%._%-/]*)$") + -- Check if we're typing after @ but not @/ + -- Match @ followed by optional path characters at end of string + local prefix = before_cursor:match("@([%w%._%-/]*)$") - -- Make sure it's not the closing tag pattern - if prefix and before_cursor:sub(-2) == "@/" then - logger.debug("parser", "get_file_ref_prefix: closing tag detected, returning nil") - logger.func_exit("parser", "get_file_ref_prefix", nil) - return nil - end + -- Make sure it's not the closing tag pattern + if prefix and before_cursor:sub(-2) == "@/" then + logger.debug("parser", "get_file_ref_prefix: closing tag detected, returning nil") + logger.func_exit("parser", "get_file_ref_prefix", nil) + return nil + end - logger.debug("parser", "get_file_ref_prefix: prefix=" .. tostring(prefix)) - logger.func_exit("parser", "get_file_ref_prefix", prefix) + logger.debug("parser", "get_file_ref_prefix: prefix=" .. tostring(prefix)) + logger.func_exit("parser", "get_file_ref_prefix", prefix) - return prefix + return prefix end logger.info("parser", "Parser module loaded") diff --git a/lua/codetyper/prompts/agents/bash.lua b/lua/codetyper/prompts/agents/bash.lua index db10d83..b7d2248 100644 --- a/lua/codetyper/prompts/agents/bash.lua +++ b/lua/codetyper/prompts/agents/bash.lua @@ -1,3 +1,5 @@ +local M = {} + M.description = [[Executes a bash command in a shell. IMPORTANT RULES: diff --git a/lua/codetyper/prompts/agents/diff.lua b/lua/codetyper/prompts/agents/diff.lua index 5c9bb5e..08ef74c 100644 --- a/lua/codetyper/prompts/agents/diff.lua +++ b/lua/codetyper/prompts/agents/diff.lua @@ -3,64 +3,63 @@ local M = {} --- Bash approval dialog strings M.bash_approval = { - title = " BASH COMMAND APPROVAL", - divider = " " .. string.rep("─", 56), - command_label = " Command:", - warning_prefix = " ⚠️ WARNING: ", - options = { - " [y] Allow once - Execute this command", - " [s] Allow this session - Auto-allow until restart", - " [a] Add to allow list - Always allow this command", - " [n] Reject - Cancel execution", - }, - cancel_hint = " Press key to choose | [q] or [Esc] to cancel", + title = " BASH COMMAND APPROVAL", + divider = " " .. string.rep("─", 56), + command_label = " Command:", + warning_prefix = " ⚠️ WARNING: ", + options = { + " [y] Allow once - Execute this command", + " [s] Allow this session - Auto-allow until restart", + " [a] Add to allow list - Always allow this command", + " [n] Reject - Cancel execution", + }, + cancel_hint = " Press key to choose | [q] or [Esc] to cancel", } --- Diff view help message M.diff_help = { - { "Diff: ", "Normal" }, - { "{path}", "Directory" }, - { " | ", "Normal" }, - { "y/", "Keyword" }, - { " approve ", "Normal" }, - { "n/q/", "Keyword" }, - { " reject ", "Normal" }, - { "", "Keyword" }, - { " switch panes", "Normal" }, + { "Diff: ", "Normal" }, + { "{path}", "Directory" }, + { " | ", "Normal" }, + { "y/", "Keyword" }, + { " approve ", "Normal" }, + { "n/q/", "Keyword" }, + { " reject ", "Normal" }, + { "", "Keyword" }, + { " switch panes", "Normal" }, } - --- Review UI interface strings M.review = { - diff_header = { - top = "╭─ %s %s %s ─────────────────────────────────────", - path = "│ %s", - op = "│ Operation: %s", - status = "│ Status: %s", - bottom = "╰────────────────────────────────────────────────────", - }, - list_menu = { - top = "╭─ Changes (%s) ──────────╮", - items = { - "│ │", - "│ j/k: navigate │", - "│ Enter: view diff │", - "│ a: approve r: reject │", - "│ A: approve all │", - "│ q: close │", - }, - bottom = "╰──────────────────────────────╯", - }, - status = { - applied = "Applied", - approved = "Approved", - pending = "Pending", - }, - messages = { - no_changes = " No changes to review", - no_changes_short = "No changes to review", - applied_count = "Applied %d change(s)", - }, + diff_header = { + top = "╭─ %s %s %s ─────────────────────────────────────", + path = "│ %s", + op = "│ Operation: %s", + status = "│ Status: %s", + bottom = "╰────────────────────────────────────────────────────", + }, + list_menu = { + top = "╭─ Changes (%s) ──────────╮", + items = { + "│ │", + "│ j/k: navigate │", + "│ Enter: view diff │", + "│ a: approve r: reject │", + "│ A: approve all │", + "│ q: close │", + }, + bottom = "╰──────────────────────────────╯", + }, + status = { + applied = "Applied", + approved = "Approved", + pending = "Pending", + }, + messages = { + no_changes = " No changes to review", + no_changes_short = "No changes to review", + applied_count = "Applied %d change(s)", + }, } return M diff --git a/lua/codetyper/prompts/agents/edit.lua b/lua/codetyper/prompts/agents/edit.lua index 2603a83..d230586 100644 --- a/lua/codetyper/prompts/agents/edit.lua +++ b/lua/codetyper/prompts/agents/edit.lua @@ -1,3 +1,5 @@ +local M = {} + M.description = [[Makes a targeted edit to a file by replacing text. The old_string should match the content you want to replace. The tool uses multiple diff --git a/lua/codetyper/prompts/agents/grep.lua b/lua/codetyper/prompts/agents/grep.lua index 9713c41..eaac880 100644 --- a/lua/codetyper/prompts/agents/grep.lua +++ b/lua/codetyper/prompts/agents/grep.lua @@ -1,41 +1,43 @@ +local M = {} + M.params = { - { - name = "pattern", - description = "Regular expression pattern to search for", - type = "string", - }, - { - name = "path", - description = "Directory or file to search in (default: project root)", - type = "string", - optional = true, - }, - { - name = "include", - description = "File glob pattern to include (e.g., '*.lua')", - type = "string", - optional = true, - }, - { - name = "max_results", - description = "Maximum number of results (default: 50)", - type = "integer", - optional = true, - }, + { + name = "pattern", + description = "Regular expression pattern to search for", + type = "string", + }, + { + name = "path", + description = "Directory or file to search in (default: project root)", + type = "string", + optional = true, + }, + { + name = "include", + description = "File glob pattern to include (e.g., '*.lua')", + type = "string", + optional = true, + }, + { + name = "max_results", + description = "Maximum number of results (default: 50)", + type = "integer", + optional = true, + }, } M.returns = { - { - name = "matches", - description = "JSON array of matches with file, line_number, and content", - type = "string", - }, - { - name = "error", - description = "Error message if search failed", - type = "string", - optional = true, - }, + { + name = "matches", + description = "JSON array of matches with file, line_number, and content", + type = "string", + }, + { + name = "error", + description = "Error message if search failed", + type = "string", + optional = true, + }, } return M diff --git a/lua/codetyper/prompts/agents/init.lua b/lua/codetyper/prompts/agents/init.lua index 2b98333..48d0f77 100644 --- a/lua/codetyper/prompts/agents/init.lua +++ b/lua/codetyper/prompts/agents/init.lua @@ -23,7 +23,7 @@ end --- System prompt for agent mode M.system = - [[You are an expert AI coding assistant integrated into Neovim. You MUST use the provided tools to accomplish tasks. + [[You are an expert AI coding assistant integrated into Neovim. You MUST use the provided tools to accomplish tasks. ## CRITICAL: YOU MUST USE TOOLS diff --git a/lua/codetyper/prompts/agents/intent.lua b/lua/codetyper/prompts/agents/intent.lua index 94a2f47..ac57222 100644 --- a/lua/codetyper/prompts/agents/intent.lua +++ b/lua/codetyper/prompts/agents/intent.lua @@ -2,48 +2,48 @@ local M = {} M.modifiers = { - complete = [[ + complete = [[ You are completing an incomplete function. Return the complete function with all missing parts filled in. Keep the existing signature unless changes are required. Output only the code, no explanations.]], - refactor = [[ + refactor = [[ You are refactoring existing code. Improve the code structure while maintaining the same behavior. Keep the function signature unchanged. Output only the refactored code, no explanations.]], - fix = [[ + fix = [[ You are fixing a bug in the code. Identify and correct the issue while minimizing changes. Preserve the original intent of the code. Output only the fixed code, no explanations.]], - add = [[ + add = [[ You are adding new code. Follow the existing code style and conventions. Output only the new code to be inserted, no explanations.]], - document = [[ + document = [[ You are adding documentation to the code. Add appropriate comments/docstrings for the function. Include parameter types, return types, and description. Output the complete function with documentation.]], - test = [[ + test = [[ You are generating tests for the code. Create comprehensive unit tests covering edge cases. Follow the testing conventions of the project. Output only the test code, no explanations.]], - optimize = [[ + optimize = [[ You are optimizing code for performance. Improve efficiency while maintaining correctness. Document any significant algorithmic changes. Output only the optimized code, no explanations.]], - explain = [[ + explain = [[ You are documenting code by adding documentation comments above it. Generate ONLY the documentation comment block (using the correct comment syntax for the file's language). Include: a brief description of what the code does, parameter types and descriptions, return type and description, and any important notes about edge cases or side effects. diff --git a/lua/codetyper/prompts/agents/loop.lua b/lua/codetyper/prompts/agents/loop.lua index 39bb09e..204f557 100644 --- a/lua/codetyper/prompts/agents/loop.lua +++ b/lua/codetyper/prompts/agents/loop.lua @@ -20,32 +20,6 @@ When you need to perform a task: Always explain your reasoning before using tools. When you're done, provide a clear summary of what was accomplished.]] -M.dispatch_prompt = [[You are a research assistant. Your task is to find information and report back. -You have access to: view (read files), grep (search content), glob (find files). -Be thorough and report your findings clearly.]] - -### File Operations -- **read_file**: Read any file. Parameters: path (string) -- **write_file**: Create or overwrite files. Parameters: path (string), content (string) -- **edit_file**: Modify existing files. Parameters: path (string), find (string), replace (string) -- **list_directory**: List files and directories. Parameters: path (string, optional), recursive (boolean, optional) -- **search_files**: Find files. Parameters: pattern (string), content (string), path (string) -- **delete_file**: Delete a file. Parameters: path (string), reason (string) - -### Shell Commands -- **bash**: Run shell commands. Parameters: command (string) - -## WORKFLOW - -1. **Analyze**: Understand the user's request. -2. **Explore**: Use `list_directory`, `search_files`, or `read_file` to find relevant files. -3. **Plan**: Think about what needs to be changed. -4. **Execute**: Use `edit_file`, `write_file`, or `bash` to apply changes. -5. **Verify**: You can check files after editing. - -Always verify context before making changes. -]] - M.dispatch_prompt = [[ You are a research assistant. Your job is to explore the codebase and answer the user's question or find specific information. You have access to: view (read files), grep (search content), glob (find files). diff --git a/lua/codetyper/prompts/agents/modal.lua b/lua/codetyper/prompts/agents/modal.lua index 9d205c5..15b4c4e 100644 --- a/lua/codetyper/prompts/agents/modal.lua +++ b/lua/codetyper/prompts/agents/modal.lua @@ -3,12 +3,12 @@ local M = {} --- Modal UI strings M.ui = { - files_header = { "", "-- No files detected in LLM response --" }, - llm_response_header = "-- LLM Response: --", - suggested_commands_header = "-- Suggested commands: --", - commands_hint = "-- Press to run a command, or r to run all --", - input_header = "-- Enter additional context below (Ctrl-Enter to submit, Esc to cancel) --", - project_inspect_header = { "", "-- Project inspection results --" }, + files_header = { "", "-- No files detected in LLM response --" }, + llm_response_header = "-- LLM Response: --", + suggested_commands_header = "-- Suggested commands: --", + commands_hint = "-- Press to run a command, or r to run all --", + input_header = "-- Enter additional context below (Ctrl-Enter to submit, Esc to cancel) --", + project_inspect_header = { "", "-- Project inspection results --" }, } return M diff --git a/lua/codetyper/prompts/agents/personas.lua b/lua/codetyper/prompts/agents/personas.lua index 4c8255f..1ff1fcd 100644 --- a/lua/codetyper/prompts/agents/personas.lua +++ b/lua/codetyper/prompts/agents/personas.lua @@ -2,10 +2,10 @@ local M = {} M.builtin = { - coder = { - name = "coder", - description = "Full-featured coding agent with file modification capabilities", - system_prompt = [[You are an expert software engineer. You have access to tools to read, write, and modify files. + coder = { + name = "coder", + description = "Full-featured coding agent with file modification capabilities", + system_prompt = [[You are an expert software engineer. You have access to tools to read, write, and modify files. ## Your Capabilities - Read files to understand the codebase @@ -26,12 +26,12 @@ M.builtin = { - Make precise edits using exact string matching - Explain your reasoning before making changes - If unsure, ask for clarification]], - tools = { "view", "edit", "write", "grep", "glob", "bash" }, - }, - planner = { - name = "planner", - description = "Planning agent - read-only, helps design implementations", - system_prompt = [[You are a software architect. Analyze codebases and create implementation plans. + tools = { "view", "edit", "write", "grep", "glob", "bash" }, + }, + planner = { + name = "planner", + description = "Planning agent - read-only, helps design implementations", + system_prompt = [[You are a software architect. Analyze codebases and create implementation plans. You can read files and search the codebase, but cannot modify files. Your role is to: @@ -41,18 +41,18 @@ Your role is to: 4. Suggest which files to modify and how Be thorough in your analysis before making recommendations.]], - tools = { "view", "grep", "glob" }, - }, - explorer = { - name = "explorer", - description = "Exploration agent - quickly find information in codebase", - system_prompt = [[You are a codebase exploration assistant. Find information quickly and report back. + tools = { "view", "grep", "glob" }, + }, + explorer = { + name = "explorer", + description = "Exploration agent - quickly find information in codebase", + system_prompt = [[You are a codebase exploration assistant. Find information quickly and report back. Your goal is to efficiently search and summarize findings. Use glob to find files, grep to search content, and view to read specific files. Be concise and focused in your responses.]], - tools = { "view", "grep", "glob" }, - }, + tools = { "view", "grep", "glob" }, + }, } return M diff --git a/lua/codetyper/prompts/agents/tools.lua b/lua/codetyper/prompts/agents/tools.lua index 4a7d5eb..df14bc6 100644 --- a/lua/codetyper/prompts/agents/tools.lua +++ b/lua/codetyper/prompts/agents/tools.lua @@ -2,14 +2,14 @@ local M = {} M.instructions = { - intro = "You have access to the following tools. To use a tool, respond with a JSON block.", - header = "To call a tool, output a JSON block like this:", - example = [[ + intro = "You have access to the following tools. To use a tool, respond with a JSON block.", + header = "To call a tool, output a JSON block like this:", + example = [[ ```json {"tool": "tool_name", "parameters": {"param1": "value1"}} ``` ]], - footer = [[ + footer = [[ After receiving tool results, continue your response or call another tool. When you're done, just respond normally without any tool calls. ]], diff --git a/lua/codetyper/prompts/agents/view.lua b/lua/codetyper/prompts/agents/view.lua index 7f38e4c..cafb2e8 100644 --- a/lua/codetyper/prompts/agents/view.lua +++ b/lua/codetyper/prompts/agents/view.lua @@ -8,4 +8,4 @@ Usage notes: - If content is truncated, use line ranges to read in chunks - Returns JSON with content, total_line_count, and is_truncated]] -return M \ No newline at end of file +return M diff --git a/lua/codetyper/prompts/agents/write.lua b/lua/codetyper/prompts/agents/write.lua index d15b87b..f2ecc1b 100644 --- a/lua/codetyper/prompts/agents/write.lua +++ b/lua/codetyper/prompts/agents/write.lua @@ -1,3 +1,5 @@ +local M = {} + M.description = [[Creates or overwrites a file with new content. IMPORTANT: diff --git a/lua/codetyper/support/gitignore.lua b/lua/codetyper/support/gitignore.lua index 856f51d..8df78af 100644 --- a/lua/codetyper/support/gitignore.lua +++ b/lua/codetyper/support/gitignore.lua @@ -6,8 +6,8 @@ local utils = require("codetyper.support.utils") --- Patterns to add to .gitignore local IGNORE_PATTERNS = { - "*.codetyper/*", - ".codetyper/", + "*.codetyper/*", + ".codetyper/", } --- Comment to identify codetyper entries @@ -18,102 +18,102 @@ local CODER_COMMENT = "# Codetyper.nvim - AI coding partner files" ---@param pattern string Pattern to check ---@return boolean local function pattern_exists(content, pattern) - local escaped = utils.escape_pattern(pattern) - return content:match("\n" .. escaped .. "\n") ~= nil - or content:match("^" .. escaped .. "\n") ~= nil - or content:match("\n" .. escaped .. "$") ~= nil - or content == pattern + local escaped = utils.escape_pattern(pattern) + return content:match("\n" .. escaped .. "\n") ~= nil + or content:match("^" .. escaped .. "\n") ~= nil + or content:match("\n" .. escaped .. "$") ~= nil + or content == pattern end --- Check if all patterns exist in gitignore content ---@param content string Gitignore content ---@return boolean, string[] All exist status and list of missing patterns local function all_patterns_exist(content) - local missing = {} - for _, pattern in ipairs(IGNORE_PATTERNS) do - if not pattern_exists(content, pattern) then - table.insert(missing, pattern) - end - end - return #missing == 0, missing + local missing = {} + for _, pattern in ipairs(IGNORE_PATTERNS) do + if not pattern_exists(content, pattern) then + table.insert(missing, pattern) + end + end + return #missing == 0, missing end --- Get the path to .gitignore in project root ---@return string|nil Path to .gitignore or nil function M.get_gitignore_path() - local root = utils.get_project_root() - if not root then - return nil - end - return root .. "/.gitignore" + local root = utils.get_project_root() + if not root then + return nil + end + return root .. "/.gitignore" end --- Check if coder files are already ignored ---@return boolean function M.is_ignored() - local gitignore_path = M.get_gitignore_path() - if not gitignore_path then - return false - end + local gitignore_path = M.get_gitignore_path() + if not gitignore_path then + return false + end - local content = utils.read_file(gitignore_path) - if not content then - return false - end + local content = utils.read_file(gitignore_path) + if not content then + return false + end - local all_exist, _ = all_patterns_exist(content) - return all_exist + local all_exist, _ = all_patterns_exist(content) + return all_exist end --- Add coder patterns to .gitignore ---@return boolean Success status function M.add_to_gitignore() - local gitignore_path = M.get_gitignore_path() - if not gitignore_path then - utils.notify("Could not determine project root", vim.log.levels.WARN) - return false - end + local gitignore_path = M.get_gitignore_path() + if not gitignore_path then + utils.notify("Could not determine project root", vim.log.levels.WARN) + return false + end - local content = utils.read_file(gitignore_path) - local patterns_to_add = {} + local content = utils.read_file(gitignore_path) + local patterns_to_add = {} - if content then - -- File exists, check which patterns are missing - local _, missing = all_patterns_exist(content) - if #missing == 0 then - return true -- All already ignored - end - patterns_to_add = missing - else - -- Create new .gitignore with all patterns - content = "" - patterns_to_add = IGNORE_PATTERNS - end + if content then + -- File exists, check which patterns are missing + local _, missing = all_patterns_exist(content) + if #missing == 0 then + return true -- All already ignored + end + patterns_to_add = missing + else + -- Create new .gitignore with all patterns + content = "" + patterns_to_add = IGNORE_PATTERNS + end - -- Build the patterns string - local patterns_str = table.concat(patterns_to_add, "\n") + -- Build the patterns string + local patterns_str = table.concat(patterns_to_add, "\n") - if content == "" then - -- New file - content = CODER_COMMENT .. "\n" .. patterns_str .. "\n" - else - -- Append to existing - local newline = content:sub(-1) == "\n" and "" or "\n" - -- Check if comment already exists - if not content:match(utils.escape_pattern(CODER_COMMENT)) then - content = content .. newline .. "\n" .. CODER_COMMENT .. "\n" .. patterns_str .. "\n" - else - content = content .. newline .. patterns_str .. "\n" - end - end + if content == "" then + -- New file + content = CODER_COMMENT .. "\n" .. patterns_str .. "\n" + else + -- Append to existing + local newline = content:sub(-1) == "\n" and "" or "\n" + -- Check if comment already exists + if not content:match(utils.escape_pattern(CODER_COMMENT)) then + content = content .. newline .. "\n" .. CODER_COMMENT .. "\n" .. patterns_str .. "\n" + else + content = content .. newline .. patterns_str .. "\n" + end + end - if utils.write_file(gitignore_path, content) then - utils.notify("Added coder patterns to .gitignore") - return true - else - utils.notify("Failed to update .gitignore", vim.log.levels.ERROR) - return false - end + if utils.write_file(gitignore_path, content) then + utils.notify("Added coder patterns to .gitignore") + return true + else + utils.notify("Failed to update .gitignore", vim.log.levels.ERROR) + return false + end end --- Ensure coder files are in .gitignore (called on setup) @@ -122,116 +122,116 @@ end ---@param auto_gitignore? boolean Override auto_gitignore setting (default: true) ---@return boolean Success status function M.ensure_ignored(auto_gitignore) - -- Only add to gitignore if this is a git project - if not utils.is_git_project() then - return false -- Not a git project, skip - end + -- Only add to gitignore if this is a git project + if not utils.is_git_project() then + return false -- Not a git project, skip + end - if not auto_gitignore then - return true - end + if not auto_gitignore then + return true + end - if M.is_ignored() then - return true - end + if M.is_ignored() then + return true + end - -- Default to true if not specified - if auto_gitignore == nil then - -- Try to get from config if available - local ok, codetyper = pcall(require, "codetyper") - if ok and codetyper.is_initialized and codetyper.is_initialized() then - local config = codetyper.get_config() - auto_gitignore = config and config.auto_gitignore - else - auto_gitignore = true -- Default to true - end - end + -- Default to true if not specified + if auto_gitignore == nil then + -- Try to get from config if available + local ok, codetyper = pcall(require, "codetyper") + if ok and codetyper.is_initialized and codetyper.is_initialized() then + local config = codetyper.get_config() + auto_gitignore = config and config.auto_gitignore + else + auto_gitignore = true -- Default to true + end + end - -- Silently add to gitignore (no notifications unless there's an error) - return M.add_to_gitignore_silent() + -- Silently add to gitignore (no notifications unless there's an error) + return M.add_to_gitignore_silent() end -- /@ @/ --- Add coder patterns to .gitignore silently (no notifications) ---@return boolean Success status function M.add_to_gitignore_silent() - local gitignore_path = M.get_gitignore_path() - if not gitignore_path then - return false - end + local gitignore_path = M.get_gitignore_path() + if not gitignore_path then + return false + end - local content = utils.read_file(gitignore_path) - local patterns_to_add = {} + local content = utils.read_file(gitignore_path) + local patterns_to_add = {} - if content then - local _, missing = all_patterns_exist(content) - if #missing == 0 then - return true - end - patterns_to_add = missing - else - content = "" - patterns_to_add = IGNORE_PATTERNS - end + if content then + local _, missing = all_patterns_exist(content) + if #missing == 0 then + return true + end + patterns_to_add = missing + else + content = "" + patterns_to_add = IGNORE_PATTERNS + end - local patterns_str = table.concat(patterns_to_add, "\n") + local patterns_str = table.concat(patterns_to_add, "\n") - if content == "" then - content = CODER_COMMENT .. "\n" .. patterns_str .. "\n" - else - local newline = content:sub(-1) == "\n" and "" or "\n" - if not content:match(utils.escape_pattern(CODER_COMMENT)) then - content = content .. newline .. "\n" .. CODER_COMMENT .. "\n" .. patterns_str .. "\n" - else - content = content .. newline .. patterns_str .. "\n" - end - end + if content == "" then + content = CODER_COMMENT .. "\n" .. patterns_str .. "\n" + else + local newline = content:sub(-1) == "\n" and "" or "\n" + if not content:match(utils.escape_pattern(CODER_COMMENT)) then + content = content .. newline .. "\n" .. CODER_COMMENT .. "\n" .. patterns_str .. "\n" + else + content = content .. newline .. patterns_str .. "\n" + end + end - return utils.write_file(gitignore_path, content) + return utils.write_file(gitignore_path, content) end --- Remove coder patterns from .gitignore ---@return boolean Success status function M.remove_from_gitignore() - local gitignore_path = M.get_gitignore_path() - if not gitignore_path then - return false - end + local gitignore_path = M.get_gitignore_path() + if not gitignore_path then + return false + end - local content = utils.read_file(gitignore_path) - if not content then - return false - end + local content = utils.read_file(gitignore_path) + if not content then + return false + end - -- Remove the comment and all patterns - content = content:gsub(CODER_COMMENT .. "\n", "") - for _, pattern in ipairs(IGNORE_PATTERNS) do - content = content:gsub(utils.escape_pattern(pattern) .. "\n?", "") - end + -- Remove the comment and all patterns + content = content:gsub(CODER_COMMENT .. "\n", "") + for _, pattern in ipairs(IGNORE_PATTERNS) do + content = content:gsub(utils.escape_pattern(pattern) .. "\n?", "") + end - -- Clean up extra newlines - content = content:gsub("\n\n\n+", "\n\n") + -- Clean up extra newlines + content = content:gsub("\n\n\n+", "\n\n") - return utils.write_file(gitignore_path, content) + return utils.write_file(gitignore_path, content) end --- Get list of patterns being ignored ---@return string[] List of patterns function M.get_ignore_patterns() - return vim.deepcopy(IGNORE_PATTERNS) + return vim.deepcopy(IGNORE_PATTERNS) end --- Force update gitignore (manual trigger) ---@return boolean Success status function M.force_update() - local gitignore_path = M.get_gitignore_path() - if not gitignore_path then - utils.notify("Could not determine project root for .gitignore", vim.log.levels.WARN) - return false - end + local gitignore_path = M.get_gitignore_path() + if not gitignore_path then + utils.notify("Could not determine project root for .gitignore", vim.log.levels.WARN) + return false + end - utils.notify("Updating .gitignore at: " .. gitignore_path) - return M.add_to_gitignore() + utils.notify("Updating .gitignore at: " .. gitignore_path) + return M.add_to_gitignore() end return M diff --git a/lua/codetyper/support/langmap.lua b/lua/codetyper/support/langmap.lua index 351622b..8313fd3 100644 --- a/lua/codetyper/support/langmap.lua +++ b/lua/codetyper/support/langmap.lua @@ -1,75 +1,75 @@ local lang_map = { - -- JavaScript/TypeScript - ts = "TypeScript", - tsx = "TypeScript React (TSX)", - js = "JavaScript", - jsx = "JavaScript React (JSX)", - mjs = "JavaScript (ESM)", - cjs = "JavaScript (CommonJS)", - -- Python - py = "Python", - pyw = "Python", - pyx = "Cython", - -- Systems languages - c = "C", - h = "C Header", - cpp = "C++", - hpp = "C++ Header", - cc = "C++", - cxx = "C++", - rs = "Rust", - go = "Go", - -- JVM languages - java = "Java", - kt = "Kotlin", - kts = "Kotlin Script", - scala = "Scala", - clj = "Clojure", - -- Web - html = "HTML", - css = "CSS", - scss = "SCSS", - sass = "Sass", - less = "Less", - vue = "Vue", - svelte = "Svelte", - -- Scripting - lua = "Lua", - rb = "Ruby", - php = "PHP", - pl = "Perl", - sh = "Shell (Bash)", - bash = "Bash", - zsh = "Zsh", - fish = "Fish", - ps1 = "PowerShell", - -- .NET - cs = "C#", - fs = "F#", - vb = "Visual Basic", - -- Data/Config - json = "JSON", - yaml = "YAML", - yml = "YAML", - toml = "TOML", - xml = "XML", - sql = "SQL", - graphql = "GraphQL", - -- Other - swift = "Swift", - dart = "Dart", - ex = "Elixir", - exs = "Elixir Script", - erl = "Erlang", - hs = "Haskell", - ml = "OCaml", - r = "R", - jl = "Julia", - nim = "Nim", - zig = "Zig", - v = "V", - md = "Markdown", - mdx = "MDX", + -- JavaScript/TypeScript + ts = "TypeScript", + tsx = "TypeScript React (TSX)", + js = "JavaScript", + jsx = "JavaScript React (JSX)", + mjs = "JavaScript (ESM)", + cjs = "JavaScript (CommonJS)", + -- Python + py = "Python", + pyw = "Python", + pyx = "Cython", + -- Systems languages + c = "C", + h = "C Header", + cpp = "C++", + hpp = "C++ Header", + cc = "C++", + cxx = "C++", + rs = "Rust", + go = "Go", + -- JVM languages + java = "Java", + kt = "Kotlin", + kts = "Kotlin Script", + scala = "Scala", + clj = "Clojure", + -- Web + html = "HTML", + css = "CSS", + scss = "SCSS", + sass = "Sass", + less = "Less", + vue = "Vue", + svelte = "Svelte", + -- Scripting + lua = "Lua", + rb = "Ruby", + php = "PHP", + pl = "Perl", + sh = "Shell (Bash)", + bash = "Bash", + zsh = "Zsh", + fish = "Fish", + ps1 = "PowerShell", + -- .NET + cs = "C#", + fs = "F#", + vb = "Visual Basic", + -- Data/Config + json = "JSON", + yaml = "YAML", + yml = "YAML", + toml = "TOML", + xml = "XML", + sql = "SQL", + graphql = "GraphQL", + -- Other + swift = "Swift", + dart = "Dart", + ex = "Elixir", + exs = "Elixir Script", + erl = "Erlang", + hs = "Haskell", + ml = "OCaml", + r = "R", + jl = "Julia", + nim = "Nim", + zig = "Zig", + v = "V", + md = "Markdown", + mdx = "MDX", } return lang_map diff --git a/lua/codetyper/support/logger.lua b/lua/codetyper/support/logger.lua index 3071012..ccb2f11 100644 --- a/lua/codetyper/support/logger.lua +++ b/lua/codetyper/support/logger.lua @@ -6,48 +6,48 @@ local M = {} local logger = nil local function get_logger() - if logger then - return logger - end - - -- Try to get codetyper module for config - local ok, codetyper = pcall(require, "codetyper") - local config = {} - if ok and codetyper.get_config then - config = codetyper.get_config() or {} - end - - -- Use ~/.config/nvim/logs/ directory - local log_dir = vim.fn.expand("~/.config/nvim/logs") - vim.fn.mkdir(log_dir, "p") - - logger = { - debug_enabled = config.debug_logging or false, - log_file = config.log_file or log_dir .. "/codetyper.log", - } - - return logger + if logger then + return logger + end + + -- Try to get codetyper module for config + local ok, codetyper = pcall(require, "codetyper") + local config = {} + if ok and codetyper.get_config then + config = codetyper.get_config() or {} + end + + -- Use ~/.config/nvim/logs/ directory + local log_dir = vim.fn.expand("~/.config/nvim/logs") + vim.fn.mkdir(log_dir, "p") + + logger = { + debug_enabled = config.debug_logging or false, + log_file = config.log_file or log_dir .. "/codetyper.log", + } + + return logger end --- Get current timestamp ---@return string timestamp ISO 8601 format local function get_timestamp() - return os.date("%Y-%m-%d %H:%M:%S") + return os.date("%Y-%m-%d %H:%M:%S") end --- Get calling function info ---@return string caller_info local function get_caller_info() - local info = debug.getinfo(3, "Sn") - if not info then - return "unknown" - end - - local name = info.name or "anonymous" - local source = info.source and info.source:gsub("^@", "") or "unknown" - local line = info.linedefined or 0 - - return string.format("%s:%d [%s]", source, line, name) + local info = debug.getinfo(3, "Sn") + if not info then + return "unknown" + end + + local name = info.name or "anonymous" + local source = info.source and info.source:gsub("^@", "") or "unknown" + local line = info.linedefined or 0 + + return string.format("%s:%d [%s]", source, line, name) end --- Format log message @@ -56,63 +56,63 @@ end ---@param message string Log message ---@return string formatted local function format_log(level, module, message) - local timestamp = get_timestamp() - local caller = get_caller_info() - return string.format("[%s] [%s] [%s] %s | %s", timestamp, level, module, caller, message) + local timestamp = get_timestamp() + local caller = get_caller_info() + return string.format("[%s] [%s] [%s] %s | %s", timestamp, level, module, caller, message) end --- Write log to file ---@param message string Log message local function write_to_file(message) - local log = get_logger() - local f = io.open(log.log_file, "a") - if f then - f:write(message .. "\n") - f:close() - end + local log = get_logger() + local f = io.open(log.log_file, "a") + if f then + f:write(message .. "\n") + f:close() + end end --- Log debug message ---@param module string Module name ---@param message string Log message function M.debug(module, message) - local log = get_logger() - if not log.debug_enabled then - return - end - - local formatted = format_log("DEBUG", module, message) - write_to_file(formatted) - - -- Also use vim.notify for visibility - vim.notify("[codetyper] " .. message, vim.log.levels.DEBUG) + local log = get_logger() + if not log.debug_enabled then + return + end + + local formatted = format_log("DEBUG", module, message) + write_to_file(formatted) + + -- Also use vim.notify for visibility + vim.notify("[codetyper] " .. message, vim.log.levels.DEBUG) end --- Log info message ---@param module string Module name ---@param message string Log message function M.info(module, message) - local formatted = format_log("INFO", module, message) - write_to_file(formatted) - vim.notify("[codetyper] " .. message, vim.log.levels.INFO) + local formatted = format_log("INFO", module, message) + write_to_file(formatted) + vim.notify("[codetyper] " .. message, vim.log.levels.INFO) end --- Log warning message ---@param module string Module name ---@param message string Log message function M.warn(module, message) - local formatted = format_log("WARN", module, message) - write_to_file(formatted) - vim.notify("[codetyper] " .. message, vim.log.levels.WARN) + local formatted = format_log("WARN", module, message) + write_to_file(formatted) + vim.notify("[codetyper] " .. message, vim.log.levels.WARN) end --- Log error message ---@param module string Module name ---@param message string Log message function M.error(module, message) - local formatted = format_log("ERROR", module, message) - write_to_file(formatted) - vim.notify("[codetyper] " .. message, vim.log.levels.ERROR) + local formatted = format_log("ERROR", module, message) + write_to_file(formatted) + vim.notify("[codetyper] " .. message, vim.log.levels.ERROR) end --- Log function entry with parameters @@ -120,26 +120,26 @@ end ---@param func_name string Function name ---@param params table|nil Parameters (will be inspected) function M.func_entry(module, func_name, params) - local log = get_logger() - if not log.debug_enabled then - return - end - - local param_str = "" - if params then - local parts = {} - for k, v in pairs(params) do - local val_str = tostring(v) - if #val_str > 50 then - val_str = val_str:sub(1, 47) .. "..." - end - table.insert(parts, k .. "=" .. val_str) - end - param_str = table.concat(parts, ", ") - end - - local message = string.format("ENTER %s(%s)", func_name, param_str) - M.debug(module, message) + local log = get_logger() + if not log.debug_enabled then + return + end + + local param_str = "" + if params then + local parts = {} + for k, v in pairs(params) do + local val_str = tostring(v) + if #val_str > 50 then + val_str = val_str:sub(1, 47) .. "..." + end + table.insert(parts, k .. "=" .. val_str) + end + param_str = table.concat(parts, ", ") + end + + local message = string.format("ENTER %s(%s)", func_name, param_str) + M.debug(module, message) end --- Log function exit with return value @@ -147,75 +147,75 @@ end ---@param func_name string Function name ---@param result any Return value (will be inspected) function M.func_exit(module, func_name, result) - local log = get_logger() - if not log.debug_enabled then - return - end - - local result_str = tostring(result) - if type(result) == "table" then - result_str = vim.inspect(result) - end - if #result_str > 100 then - result_str = result_str:sub(1, 97) .. "..." - end - - local message = string.format("EXIT %s -> %s", func_name, result_str) - M.debug(module, message) + local log = get_logger() + if not log.debug_enabled then + return + end + + local result_str = tostring(result) + if type(result) == "table" then + result_str = vim.inspect(result) + end + if #result_str > 100 then + result_str = result_str:sub(1, 97) .. "..." + end + + local message = string.format("EXIT %s -> %s", func_name, result_str) + M.debug(module, message) end --- Enable or disable debug logging ---@param enabled boolean function M.set_debug(enabled) - local log = get_logger() - log.debug_enabled = enabled - M.info("logger", "Debug logging " .. (enabled and "enabled" or "disabled")) + local log = get_logger() + log.debug_enabled = enabled + M.info("logger", "Debug logging " .. (enabled and "enabled" or "disabled")) end --- Get log file path ---@return string log_file path function M.get_log_file() - local log = get_logger() - return log.log_file + local log = get_logger() + return log.log_file end --- Clear log file function M.clear() - local log = get_logger() - local f = io.open(log.log_file, "w") - if f then - f:write("") - f:close() - end - M.info("logger", "Log file cleared") + local log = get_logger() + local f = io.open(log.log_file, "w") + if f then + f:write("") + f:close() + end + M.info("logger", "Log file cleared") end --- Show logs in a buffer function M.show() - local log = get_logger() - local lines = {} - - local f = io.open(log.log_file, "r") - if f then - for line in f:lines() do - table.insert(lines, line) - end - f:close() - end - - -- Create a new buffer - local bufnr = vim.api.nvim_create_buf(false, true) - vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, lines) - vim.bo[bufnr].filetype = "log" - vim.bo[bufnr].modifiable = false - vim.bo[bufnr].readonly = true - - -- Open in a split - vim.cmd("vsplit") - local win = vim.api.nvim_get_current_win() - vim.api.nvim_win_set_buf(win, bufnr) - - return bufnr + local log = get_logger() + local lines = {} + + local f = io.open(log.log_file, "r") + if f then + for line in f:lines() do + table.insert(lines, line) + end + f:close() + end + + -- Create a new buffer + local bufnr = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, lines) + vim.bo[bufnr].filetype = "log" + vim.bo[bufnr].modifiable = false + vim.bo[bufnr].readonly = true + + -- Open in a split + vim.cmd("vsplit") + local win = vim.api.nvim_get_current_win() + vim.api.nvim_win_set_buf(win, bufnr) + + return bufnr end return M diff --git a/lua/codetyper/support/tree.lua b/lua/codetyper/support/tree.lua index 5fab0b0..e7ad413 100644 --- a/lua/codetyper/support/tree.lua +++ b/lua/codetyper/support/tree.lua @@ -208,7 +208,7 @@ function M.generate_tree() -- Patterns to ignore local ignore_patterns = { - "^%.", -- Hidden files/folders + "^%.", -- Hidden files/folders "^node_modules$", "^__pycache__$", "^%.git$", @@ -217,7 +217,7 @@ function M.generate_tree() "^build$", "^target$", "^vendor$", - "%.codetyper%.", -- Coder files + "%.codetyper%.", -- Coder files } local lines = { diff --git a/lua/codetyper/support/utils.lua b/lua/codetyper/support/utils.lua index 5409626..725e5ad 100644 --- a/lua/codetyper/support/utils.lua +++ b/lua/codetyper/support/utils.lua @@ -6,8 +6,8 @@ local M = {} ---@param prefix? string Prefix for the ID (default: "id") ---@return string Unique ID function M.generate_id(prefix) - prefix = prefix or "id" - return prefix .. "_" .. string.format("%x", os.time()) .. "_" .. string.format("%x", math.random(0, 0xFFFF)) + prefix = prefix or "id" + return prefix .. "_" .. string.format("%x", os.time()) .. "_" .. string.format("%x", math.random(0, 0xFFFF)) end --- Get the project root directory @@ -132,7 +132,7 @@ end ---@param level? number Vim log level (default: INFO) function M.notify(msg, level) level = level or vim.log.levels.INFO - + -- Also log to file local logger = require("codetyper.support.logger") local level_name = "INFO" @@ -143,20 +143,20 @@ function M.notify(msg, level) elseif level == vim.log.levels.ERROR then level_name = "ERROR" end - + -- Write to log file local log_dir = vim.fn.expand("~/.config/nvim/logs") vim.fn.mkdir(log_dir, "p") local log_file = log_dir .. "/codetyper.log" local timestamp = os.date("%Y-%m-%d %H:%M:%S") local log_entry = string.format("[%s] [%s] [utils.notify] %s\n", timestamp, level_name, msg) - + local f = io.open(log_file, "a") if f then f:write(log_entry) f:close() end - + vim.notify("[Codetyper] " .. msg, level) end @@ -231,36 +231,36 @@ end ---@param response string ---@return boolean balanced function M.check_brackets(response) - local pairs = { - ["{"] = "}", - ["["] = "]", - ["("] = ")", - } + local pairs = { + ["{"] = "}", + ["["] = "]", + ["("] = ")", + } - local stack = {} + local stack = {} - for char in response:gmatch(".") do - if pairs[char] then - table.insert(stack, pairs[char]) - elseif char == "}" or char == "]" or char == ")" then - if #stack == 0 or stack[#stack] ~= char then - return false - end - table.remove(stack) - end - end + for char in response:gmatch(".") do + if pairs[char] then + table.insert(stack, pairs[char]) + elseif char == "}" or char == "]" or char == ")" then + if #stack == 0 or stack[#stack] ~= char then + return false + end + table.remove(stack) + end + end - return #stack == 0 + return #stack == 0 end --- Simple hash function for content ---@param content string ---@return string function M.hash_content(content) - local hash = vim.fn.sha256(content) - -- If sha256 returns hex string, format %x might be wrong if it expects number? - -- vim.fn.sha256 returns a hex string already. - return hash + local hash = vim.fn.sha256(content) + -- If sha256 returns hex string, format %x might be wrong if it expects number? + -- vim.fn.sha256 returns a hex string already. + return hash end --- Check if a line is empty or a comment @@ -269,22 +269,24 @@ end ---@param filetype string ---@return boolean function M.is_empty_or_comment(line, filetype) - local trimmed = line:match("^%s*(.-)%s*$") - if trimmed == "" then - return true - end + local trimmed = line:match("^%s*(.-)%s*$") + if trimmed == "" then + return true + end - local ok, languages = pcall(require, "codetyper.params.agent.languages") - if not ok then return false end + local ok, languages = pcall(require, "codetyper.params.agent.languages") + if not ok then + return false + end - local patterns = languages.comment_patterns[filetype] or languages.comment_patterns.javascript - for _, pattern in ipairs(patterns) do - if trimmed:match(pattern) then - return true - end - end + local patterns = languages.comment_patterns[filetype] or languages.comment_patterns.javascript + for _, pattern in ipairs(patterns) do + if trimmed:match(pattern) then + return true + end + end - return false + return false end --- Classify an import as "builtin", "local", or "third_party" @@ -292,37 +294,42 @@ end ---@param filetype string The filetype ---@return string category "builtin"|"local"|"third_party" function M.classify_import(imp, filetype) - local is_local = false - local is_builtin = false + local is_local = false + local is_builtin = false - if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then - -- Local: starts with . or .. - is_local = imp:match("from%s+['\"]%.") or imp:match("require%(['\"]%.") - -- Node builtin modules - is_builtin = imp:match("from%s+['\"]node:") or imp:match("from%s+['\"]fs['\"]") - or imp:match("from%s+['\"]path['\"]") or imp:match("from%s+['\"]http['\"]") - elseif filetype == "python" or filetype == "py" then - -- Local: relative imports - is_local = imp:match("^from%s+%.") or imp:match("^import%s+%.") - -- Python stdlib (simplified check) - is_builtin = imp:match("^import%s+os") or imp:match("^import%s+sys") - or imp:match("^from%s+os%s+") or imp:match("^from%s+sys%s+") - or imp:match("^import%s+re") or imp:match("^import%s+json") - elseif filetype == "lua" then - -- Local: relative requires - is_local = imp:match("require%(['\"]%.") or imp:match("require%s+['\"]%.") - elseif filetype == "go" then - -- Local: project imports (contain /) - is_local = imp:match("['\"][^'\"]+/[^'\"]+['\"]") and not imp:match("github%.com") - end + if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then + -- Local: starts with . or .. + is_local = imp:match("from%s+['\"]%.") or imp:match("require%(['\"]%.") + -- Node builtin modules + is_builtin = imp:match("from%s+['\"]node:") + or imp:match("from%s+['\"]fs['\"]") + or imp:match("from%s+['\"]path['\"]") + or imp:match("from%s+['\"]http['\"]") + elseif filetype == "python" or filetype == "py" then + -- Local: relative imports + is_local = imp:match("^from%s+%.") or imp:match("^import%s+%.") + -- Python stdlib (simplified check) + is_builtin = imp:match("^import%s+os") + or imp:match("^import%s+sys") + or imp:match("^from%s+os%s+") + or imp:match("^from%s+sys%s+") + or imp:match("^import%s+re") + or imp:match("^import%s+json") + elseif filetype == "lua" then + -- Local: relative requires + is_local = imp:match("require%(['\"]%.") or imp:match("require%s+['\"]%.") + elseif filetype == "go" then + -- Local: project imports (contain /) + is_local = imp:match("['\"][^'\"]+/[^'\"]+['\"]") and not imp:match("github%.com") + end - if is_builtin then - return "builtin" - elseif is_local then - return "local" - else - return "third_party" - end + if is_builtin then + return "builtin" + elseif is_local then + return "local" + else + return "third_party" + end end --- Check if a line ends a multi-line import @@ -330,44 +337,44 @@ end ---@param filetype string ---@return boolean function M.ends_multiline_import(line, filetype) - -- Check for closing patterns - if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then - -- ES6 imports end with 'from "..." ;' or just ';' or a line with just '}' - if line:match("from%s+['\"][^'\"]+['\"]%s*;?%s*$") then - return true - end - if line:match("}%s*from%s+['\"]") then - return true - end - if line:match("^%s*}%s*;?%s*$") then - return true - end - if line:match(";%s*$") then - return true - end - elseif filetype == "python" or filetype == "py" then - -- Python single-line import: doesn't end with \, (, or , - -- Examples: "from typing import List, Dict" or "import os" - if not line:match("\\%s*$") and not line:match("%(%s*$") and not line:match(",%s*$") then - return true - end - -- Python multiline imports end with closing paren - if line:match("%)%s*$") then - return true - end - elseif filetype == "go" then - -- Go multi-line imports end with ')' - if line:match("%)%s*$") then - return true - end - elseif filetype == "rust" or filetype == "rs" then - -- Rust use statements end with ';' - if line:match(";%s*$") then - return true - end - end + -- Check for closing patterns + if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then + -- ES6 imports end with 'from "..." ;' or just ';' or a line with just '}' + if line:match("from%s+['\"][^'\"]+['\"]%s*;?%s*$") then + return true + end + if line:match("}%s*from%s+['\"]") then + return true + end + if line:match("^%s*}%s*;?%s*$") then + return true + end + if line:match(";%s*$") then + return true + end + elseif filetype == "python" or filetype == "py" then + -- Python single-line import: doesn't end with \, (, or , + -- Examples: "from typing import List, Dict" or "import os" + if not line:match("\\%s*$") and not line:match("%(%s*$") and not line:match(",%s*$") then + return true + end + -- Python multiline imports end with closing paren + if line:match("%)%s*$") then + return true + end + elseif filetype == "go" then + -- Go multi-line imports end with ')' + if line:match("%)%s*$") then + return true + end + elseif filetype == "rust" or filetype == "rs" then + -- Rust use statements end with ';' + if line:match(";%s*$") then + return true + end + end - return false + return false end return M diff --git a/plugin/codetyper.lua b/plugin/codetyper.lua index 1869e91..3e5b1c6 100644 --- a/plugin/codetyper.lua +++ b/plugin/codetyper.lua @@ -8,14 +8,14 @@ local cmd = vim.cmd -- Prevent loading twice if g.loaded_codetyper then - return + return end g.loaded_codetyper = true -- Minimum Neovim version check if fn.has("nvim-0.8.0") == 0 then - api.nvim_err_writeln("Codetyper.nvim requires Neovim 0.8.0 or higher") - return + api.nvim_err_writeln("Codetyper.nvim requires Neovim 0.8.0 or higher") + return end --- Initialize codetyper plugin fully @@ -23,81 +23,79 @@ end --- Also registers autocmds for /@ @/ prompt detection ---@return boolean success local function init_coder_files() - local ok, err = pcall(function() - -- Full plugin initialization (includes config, commands, autocmds, tree, gitignore) - local codetyper = require("codetyper") - if not codetyper.is_initialized() then - codetyper.setup() - end - end) + local ok, err = pcall(function() + -- Full plugin initialization (includes config, commands, autocmds, tree, gitignore) + local codetyper = require("codetyper") + if not codetyper.is_initialized() then + codetyper.setup() + end + end) - if not ok then - vim.notify("[Codetyper] Failed to initialize: " .. tostring(err), vim.log.levels.ERROR) - return false - end - return true + if not ok then + vim.notify("[Codetyper] Failed to initialize: " .. tostring(err), vim.log.levels.ERROR) + return false + end + return true end -- Initialize .codetyper folder and tree.log on project open api.nvim_create_autocmd("VimEnter", { - callback = function() - -- Delay slightly to ensure cwd is set - vim.defer_fn(function() - init_coder_files() - end, 100) - end, - desc = "Initialize Codetyper .codetyper folder on startup", + callback = function() + -- Delay slightly to ensure cwd is set + vim.defer_fn(function() + init_coder_files() + end, 100) + end, + desc = "Initialize Codetyper .codetyper folder on startup", }) -- Also initialize on directory change api.nvim_create_autocmd("DirChanged", { - callback = function() - vim.defer_fn(function() - init_coder_files() - end, 100) - end, - desc = "Initialize Codetyper .codetyper folder on directory change", + callback = function() + vim.defer_fn(function() + init_coder_files() + end, 100) + end, + desc = "Initialize Codetyper .codetyper folder on directory change", }) -- Auto-initialize when opening a coder file (for nvim-tree, telescope, etc.) api.nvim_create_autocmd({ "BufRead", "BufNewFile", "BufEnter" }, { - pattern = "*.codetyper/*", - callback = function() - -- Initialize plugin if not already done - local codetyper = require("codetyper") - if not codetyper.is_initialized() then - codetyper.setup() - end - end, - desc = "Auto-initialize Codetyper when opening coder files", + pattern = "*.codetyper/*", + callback = function() + -- Initialize plugin if not already done + local codetyper = require("codetyper") + if not codetyper.is_initialized() then + codetyper.setup() + end + end, + desc = "Auto-initialize Codetyper when opening coder files", }) -- Lazy-load the plugin on first command usage api.nvim_create_user_command("Coder", function(opts) - require("codetyper").setup() - cmd("Coder " .. (opts.args or "")) + require("codetyper").setup() + cmd("Coder " .. (opts.args or "")) end, { - nargs = "?", - complete = function() - return { - "tree", - "tree-view", - "reset", - "gitignore", - } - end, - desc = "Codetyper.nvim commands", + nargs = "?", + complete = function() + return { + "tree", + "tree-view", + "reset", + "gitignore", + } + end, + desc = "Codetyper.nvim commands", }) -- Lazy-load aliases api.nvim_create_user_command("CoderTree", function() - require("codetyper").setup() - cmd("CoderTree") + require("codetyper").setup() + cmd("CoderTree") end, { desc = "Refresh tree.log" }) api.nvim_create_user_command("CoderTreeView", function() - require("codetyper").setup() - cmd("CoderTreeView") + require("codetyper").setup() + cmd("CoderTreeView") end, { desc = "View tree.log" }) - -