From 84c8bcf92c2174b131e64ba554c33ea0e2904b0c Mon Sep 17 00:00:00 2001 From: Carlos Gutierrez Date: Wed, 14 Jan 2026 21:43:56 -0500 Subject: [PATCH] Adding autocomplete and copilot suggestions --- lua/codetyper/agent/diff.lua | 118 +++- lua/codetyper/agent/executor.lua | 196 ++++++ lua/codetyper/agent/init.lua | 30 +- lua/codetyper/agent/permissions.lua | 229 +++++++ lua/codetyper/agent/queue.lua | 2 +- lua/codetyper/agent/scheduler.lua | 8 +- lua/codetyper/agent/tools.lua | 61 ++ lua/codetyper/agent/ui.lua | 71 +- lua/codetyper/agent/worker.lua | 77 ++- lua/codetyper/ask.lua | 232 ++++++- lua/codetyper/ask/explorer.lua | 676 ++++++++++++++++++++ lua/codetyper/ask/intent.lua | 302 +++++++++ lua/codetyper/autocmds.lua | 209 +++++- lua/codetyper/brain/delta/commit.lua | 291 +++++++++ lua/codetyper/brain/delta/diff.lua | 261 ++++++++ lua/codetyper/brain/delta/init.lua | 278 ++++++++ lua/codetyper/brain/graph/edge.lua | 367 +++++++++++ lua/codetyper/brain/graph/init.lua | 213 ++++++ lua/codetyper/brain/graph/node.lua | 403 ++++++++++++ lua/codetyper/brain/graph/query.lua | 394 ++++++++++++ lua/codetyper/brain/hash.lua | 112 ++++ lua/codetyper/brain/init.lua | 276 ++++++++ lua/codetyper/brain/learners/convention.lua | 233 +++++++ lua/codetyper/brain/learners/correction.lua | 213 ++++++ lua/codetyper/brain/learners/init.lua | 232 +++++++ lua/codetyper/brain/learners/pattern.lua | 172 +++++ lua/codetyper/brain/output/formatter.lua | 279 ++++++++ lua/codetyper/brain/output/init.lua | 166 +++++ lua/codetyper/brain/storage.lua | 338 ++++++++++ lua/codetyper/brain/types.lua | 175 +++++ lua/codetyper/cmp_source/init.lua | 301 +++++++++ lua/codetyper/commands.lua | 168 ++++- lua/codetyper/config.lua | 57 +- lua/codetyper/health.lua | 10 +- lua/codetyper/indexer/analyzer.lua | 582 +++++++++++++++++ lua/codetyper/indexer/init.lua | 604 +++++++++++++++++ lua/codetyper/indexer/memory.lua | 539 ++++++++++++++++ lua/codetyper/indexer/scanner.lua | 409 ++++++++++++ 
lua/codetyper/init.lua | 20 +- lua/codetyper/llm/claude.lua | 364 ----------- lua/codetyper/llm/init.lua | 71 +- lua/codetyper/prompts/agent.lua | 114 ++-- lua/codetyper/prompts/init.lua | 1 + lua/codetyper/prompts/system.lua | 11 +- lua/codetyper/suggestion/init.lua | 491 ++++++++++++++ lua/codetyper/types.lua | 22 +- lua/codetyper/window.lua | 8 +- tests/spec/ask_intent_spec.lua | 229 +++++++ tests/spec/brain_delta_spec.lua | 252 ++++++++ tests/spec/brain_hash_spec.lua | 128 ++++ tests/spec/brain_node_spec.lua | 234 +++++++ tests/spec/brain_storage_spec.lua | 173 +++++ tests/spec/indexer_spec.lua | 345 ++++++++++ tests/spec/memory_spec.lua | 341 ++++++++++ tests/spec/scanner_spec.lua | 285 +++++++++ 55 files changed, 11823 insertions(+), 550 deletions(-) create mode 100644 lua/codetyper/agent/permissions.lua create mode 100644 lua/codetyper/ask/explorer.lua create mode 100644 lua/codetyper/ask/intent.lua create mode 100644 lua/codetyper/brain/delta/commit.lua create mode 100644 lua/codetyper/brain/delta/diff.lua create mode 100644 lua/codetyper/brain/delta/init.lua create mode 100644 lua/codetyper/brain/graph/edge.lua create mode 100644 lua/codetyper/brain/graph/init.lua create mode 100644 lua/codetyper/brain/graph/node.lua create mode 100644 lua/codetyper/brain/graph/query.lua create mode 100644 lua/codetyper/brain/hash.lua create mode 100644 lua/codetyper/brain/init.lua create mode 100644 lua/codetyper/brain/learners/convention.lua create mode 100644 lua/codetyper/brain/learners/correction.lua create mode 100644 lua/codetyper/brain/learners/init.lua create mode 100644 lua/codetyper/brain/learners/pattern.lua create mode 100644 lua/codetyper/brain/output/formatter.lua create mode 100644 lua/codetyper/brain/output/init.lua create mode 100644 lua/codetyper/brain/storage.lua create mode 100644 lua/codetyper/brain/types.lua create mode 100644 lua/codetyper/cmp_source/init.lua create mode 100644 lua/codetyper/indexer/analyzer.lua create mode 100644 
lua/codetyper/indexer/init.lua create mode 100644 lua/codetyper/indexer/memory.lua create mode 100644 lua/codetyper/indexer/scanner.lua delete mode 100644 lua/codetyper/llm/claude.lua create mode 100644 lua/codetyper/suggestion/init.lua create mode 100644 tests/spec/ask_intent_spec.lua create mode 100644 tests/spec/brain_delta_spec.lua create mode 100644 tests/spec/brain_hash_spec.lua create mode 100644 tests/spec/brain_node_spec.lua create mode 100644 tests/spec/brain_storage_spec.lua create mode 100644 tests/spec/indexer_spec.lua create mode 100644 tests/spec/memory_spec.lua create mode 100644 tests/spec/scanner_spec.lua diff --git a/lua/codetyper/agent/diff.lua b/lua/codetyper/agent/diff.lua index 5717df3..2f8bf50 100644 --- a/lua/codetyper/agent/diff.lua +++ b/lua/codetyper/agent/diff.lua @@ -9,7 +9,20 @@ local M = {} ---@param callback fun(approved: boolean) Called with user decision function M.show_diff(diff_data, callback) local original_lines = vim.split(diff_data.original, "\n", { plain = true }) - local modified_lines = vim.split(diff_data.modified, "\n", { plain = true }) + local modified_lines + + -- For delete operations, show a clear message + if diff_data.operation == "delete" then + modified_lines = { + "", + " FILE WILL BE DELETED", + "", + " Reason: " .. (diff_data.reason or "No reason provided"), + "", + } + else + modified_lines = vim.split(diff_data.modified, "\n", { plain = true }) + end -- Calculate window dimensions local width = math.floor(vim.o.columns * 0.8) @@ -59,7 +72,7 @@ function M.show_diff(diff_data, callback) col = col + half_width + 1, style = "minimal", border = "rounded", - title = " MODIFIED [" .. diff_data.operation .. "] ", + title = diff_data.operation == "delete" and " ⚠️ DELETE " or (" MODIFIED [" .. diff_data.operation .. 
"] "), title_pos = "center", }) @@ -157,26 +170,52 @@ function M.show_diff(diff_data, callback) }, false, {}) end ---- Show approval dialog for bash commands +---@alias BashApprovalResult {approved: boolean, permission_level: string|nil} + +--- Show approval dialog for bash commands with permission levels ---@param command string The bash command to approve ----@param callback fun(approved: boolean) Called with user decision +---@param callback fun(result: BashApprovalResult) Called with user decision function M.show_bash_approval(command, callback) - -- Create a simple floating window for bash approval + local permissions = require("codetyper.agent.permissions") + + -- Check if command is auto-approved + local perm_result = permissions.check_bash_permission(command) + if perm_result.auto and perm_result.allowed then + vim.schedule(function() + callback({ approved = true, permission_level = "auto" }) + end) + return + end + + -- Create approval dialog with options local lines = { "", " BASH COMMAND APPROVAL", - " " .. string.rep("-", 50), + " " .. string.rep("─", 56), "", " Command:", " $ " .. command, "", - " " .. string.rep("-", 50), - " Press [y] or [Enter] to execute", - " Press [n], [q], or [Esc] to cancel", - "", } - local width = math.max(60, #command + 10) + -- Add warning for dangerous commands + if not perm_result.allowed and perm_result.reason ~= "Requires approval" then + table.insert(lines, " ⚠️ WARNING: " .. perm_result.reason) + table.insert(lines, "") + end + + table.insert(lines, " " .. string.rep("─", 56)) + table.insert(lines, "") + table.insert(lines, " [y] Allow once - Execute this command") + table.insert(lines, " [s] Allow this session - Auto-allow until restart") + table.insert(lines, " [a] Add to allow list - Always allow this command") + table.insert(lines, " [n] Reject - Cancel execution") + table.insert(lines, "") + table.insert(lines, " " .. 
string.rep("─", 56)) + table.insert(lines, " Press key to choose | [q] or [Esc] to cancel") + table.insert(lines, "") + + local width = math.max(65, #command + 15) local height = #lines local buf = vim.api.nvim_create_buf(false, true) @@ -196,45 +235,84 @@ function M.show_bash_approval(command, callback) title_pos = "center", }) - -- Apply some highlighting + -- Apply highlighting vim.api.nvim_buf_add_highlight(buf, -1, "Title", 1, 0, -1) vim.api.nvim_buf_add_highlight(buf, -1, "String", 5, 0, -1) + -- Highlight options + for i, line in ipairs(lines) do + if line:match("^%s+%[y%]") then + vim.api.nvim_buf_add_highlight(buf, -1, "DiagnosticOk", i - 1, 0, -1) + elseif line:match("^%s+%[s%]") then + vim.api.nvim_buf_add_highlight(buf, -1, "DiagnosticInfo", i - 1, 0, -1) + elseif line:match("^%s+%[a%]") then + vim.api.nvim_buf_add_highlight(buf, -1, "DiagnosticHint", i - 1, 0, -1) + elseif line:match("^%s+%[n%]") then + vim.api.nvim_buf_add_highlight(buf, -1, "DiagnosticError", i - 1, 0, -1) + elseif line:match("⚠️") then + vim.api.nvim_buf_add_highlight(buf, -1, "DiagnosticWarn", i - 1, 0, -1) + end + end + local callback_called = false - local function close_and_respond(approved) + local function close_and_respond(approved, permission_level) if callback_called then return end callback_called = true + -- Grant permission if approved with session or list level + if approved and permission_level then + permissions.grant_permission(command, permission_level) + end + pcall(vim.api.nvim_win_close, win, true) vim.schedule(function() - callback(approved) + callback({ approved = approved, permission_level = permission_level }) end) end local keymap_opts = { buffer = buf, noremap = true, silent = true, nowait = true } - -- Approve + -- Allow once vim.keymap.set("n", "y", function() - close_and_respond(true) + close_and_respond(true, "allow") end, keymap_opts) vim.keymap.set("n", "", function() - close_and_respond(true) + close_and_respond(true, "allow") + end, keymap_opts) + + 
-- Allow this session + vim.keymap.set("n", "s", function() + close_and_respond(true, "allow_session") + end, keymap_opts) + + -- Add to allow list + vim.keymap.set("n", "a", function() + close_and_respond(true, "allow_list") end, keymap_opts) -- Reject vim.keymap.set("n", "n", function() - close_and_respond(false) + close_and_respond(false, nil) end, keymap_opts) vim.keymap.set("n", "q", function() - close_and_respond(false) + close_and_respond(false, nil) end, keymap_opts) vim.keymap.set("n", "", function() - close_and_respond(false) + close_and_respond(false, nil) end, keymap_opts) end +--- Show approval dialog for bash commands (simple version for backward compatibility) +---@param command string The bash command to approve +---@param callback fun(approved: boolean) Called with user decision +function M.show_bash_approval_simple(command, callback) + M.show_bash_approval(command, function(result) + callback(result.approved) + end) +end + return M diff --git a/lua/codetyper/agent/executor.lua b/lua/codetyper/agent/executor.lua index 28e5826..08f278e 100644 --- a/lua/codetyper/agent/executor.lua +++ b/lua/codetyper/agent/executor.lua @@ -27,6 +27,9 @@ function M.execute(tool_name, parameters, callback) edit_file = M.handle_edit_file, write_file = M.handle_write_file, bash = M.handle_bash, + delete_file = M.handle_delete_file, + list_directory = M.handle_list_directory, + search_files = M.handle_search_files, } local handler = handlers[tool_name] @@ -156,6 +159,165 @@ function M.handle_bash(params, callback) }) end +--- Handle delete_file tool +---@param params table { path: string, reason: string } +---@param callback fun(result: ExecutionResult) +function M.handle_delete_file(params, callback) + local path = M.resolve_path(params.path) + local reason = params.reason or "No reason provided" + + -- Check if file exists + if not utils.file_exists(path) then + callback({ + success = false, + result = "File not found: " .. 
path, + requires_approval = false, + }) + return + end + + -- Read content for showing in diff (so user knows what they're deleting) + local content = utils.read_file(path) or "[Could not read file]" + + callback({ + success = true, + result = "Delete: " .. path .. " (" .. reason .. ")", + requires_approval = true, + diff_data = { + path = path, + original = content, + modified = "", -- Empty = deletion + operation = "delete", + reason = reason, + }, + }) +end + +--- Handle list_directory tool +---@param params table { path?: string, recursive?: boolean } +---@param callback fun(result: ExecutionResult) +function M.handle_list_directory(params, callback) + local path = params.path and M.resolve_path(params.path) or (utils.get_project_root() or vim.fn.getcwd()) + local recursive = params.recursive or false + + -- Use vim.fn.readdir or glob for directory listing + local entries = {} + local function list_dir(dir, depth) + if depth > 3 then + return + end + + local ok, files = pcall(vim.fn.readdir, dir) + if not ok or not files then + return + end + + for _, name in ipairs(files) do + if name ~= "." and name ~= ".." and not name:match("^%.git$") and not name:match("^node_modules$") then + local full_path = dir .. "/" .. name + local stat = vim.loop.fs_stat(full_path) + if stat then + local prefix = string.rep(" ", depth) + local type_indicator = stat.type == "directory" and "/" or "" + table.insert(entries, prefix .. name .. type_indicator) + + if recursive and stat.type == "directory" then + list_dir(full_path, depth + 1) + end + end + end + end + end + + list_dir(path, 0) + + local result = "Directory: " .. path .. "\n\n" .. 
table.concat(entries, "\n") + + callback({ + success = true, + result = result, + requires_approval = false, + }) +end + +--- Handle search_files tool +---@param params table { pattern?: string, content?: string, path?: string } +---@param callback fun(result: ExecutionResult) +function M.handle_search_files(params, callback) + local search_path = params.path and M.resolve_path(params.path) or (utils.get_project_root() or vim.fn.getcwd()) + local pattern = params.pattern + local content_search = params.content + + local results = {} + + if pattern then + -- Search by file name pattern using glob + local glob_pattern = search_path .. "/**/" .. pattern + local files = vim.fn.glob(glob_pattern, false, true) + + for _, file in ipairs(files) do + -- Skip common ignore patterns + if not file:match("node_modules") and not file:match("%.git/") then + table.insert(results, file:gsub(search_path .. "/", "")) + end + end + end + + if content_search then + -- Search by content using grep + local grep_results = {} + local grep_cmd = string.format("grep -rl '%s' '%s' 2>/dev/null | head -20", content_search:gsub("'", "\\'"), search_path) + + local handle = io.popen(grep_cmd) + if handle then + for line in handle:lines() do + if not line:match("node_modules") and not line:match("%.git/") then + table.insert(grep_results, line:gsub(search_path .. "/", "")) + end + end + handle:close() + end + + -- Merge with pattern results or use as primary results + if #results == 0 then + results = grep_results + else + -- Intersection of pattern and content results + local pattern_set = {} + for _, f in ipairs(results) do + pattern_set[f] = true + end + results = {} + for _, f in ipairs(grep_results) do + if pattern_set[f] then + table.insert(results, f) + end + end + end + end + + local result_text = "Search results" + if pattern then + result_text = result_text .. " (pattern: " .. pattern .. ")" + end + if content_search then + result_text = result_text .. " (content: " .. content_search .. 
")" + end + result_text = result_text .. ":\n\n" + + if #results == 0 then + result_text = result_text .. "No files found." + else + result_text = result_text .. table.concat(results, "\n") + end + + callback({ + success = true, + result = result_text, + requires_approval = false, + }) +end + --- Actually apply an approved change ---@param diff_data DiffData The diff data to apply ---@param callback fun(result: ExecutionResult) @@ -164,6 +326,24 @@ function M.apply_change(diff_data, callback) -- Extract command from modified (remove "$ " prefix) local command = diff_data.modified:gsub("^%$ ", "") M.execute_bash_command(command, 30000, callback) + elseif diff_data.operation == "delete" then + -- Delete file + local ok, err = os.remove(diff_data.path) + if ok then + -- Close buffer if it's open + M.close_buffer_if_open(diff_data.path) + callback({ + success = true, + result = "Deleted: " .. diff_data.path, + requires_approval = false, + }) + else + callback({ + success = false, + result = "Failed to delete: " .. diff_data.path .. " (" .. (err or "unknown error") .. 
")", + requires_approval = false, + }) + end else -- Write file local success = utils.write_file(diff_data.path, diff_data.modified) @@ -275,6 +455,22 @@ function M.reload_buffer_if_open(filepath) end end +--- Close a buffer if it's currently open (for deleted files) +---@param filepath string Path to the file +function M.close_buffer_if_open(filepath) + local full_path = vim.fn.fnamemodify(filepath, ":p") + for _, buf in ipairs(vim.api.nvim_list_bufs()) do + if vim.api.nvim_buf_is_loaded(buf) then + local buf_name = vim.api.nvim_buf_get_name(buf) + if buf_name == full_path then + -- Force close the buffer + pcall(vim.api.nvim_buf_delete, buf, { force = true }) + break + end + end + end +end + --- Resolve a path (expand ~ and make absolute if needed) ---@param path string Path to resolve ---@return string Resolved path diff --git a/lua/codetyper/agent/init.lua b/lua/codetyper/agent/init.lua index c17add6..3751dd1 100644 --- a/lua/codetyper/agent/init.lua +++ b/lua/codetyper/agent/init.lua @@ -123,12 +123,14 @@ function M.agent_loop(context, callbacks) local config = codetyper.get_config() local parsed - if config.llm.provider == "claude" then + -- Copilot uses Claude-like response format + if config.llm.provider == "copilot" then parsed = parser.parse_claude_response(response) - -- For Claude, preserve the original content array for proper tool_use handling table.insert(state.conversation, { role = "assistant", - content = response.content, -- Keep original content blocks for Claude API + content = parsed.text or "", + tool_calls = parsed.tool_calls, + _raw_content = response.content, }) else -- For Ollama, response is the text directly @@ -200,9 +202,22 @@ function M.process_tool_calls(tool_calls, index, context, callbacks) show_fn = diff.show_diff end - show_fn(result.diff_data, function(approved) + show_fn(result.diff_data, function(approval_result) + -- Handle both old (boolean) and new (table) approval result formats + local approved = type(approval_result) == 
"table" and approval_result.approved or approval_result + local permission_level = type(approval_result) == "table" and approval_result.permission_level or nil + if approved then - logs.tool(tool_call.name, "approved", "User approved") + local log_msg = "User approved" + if permission_level == "allow_session" then + log_msg = "Allowed for session" + elseif permission_level == "allow_list" then + log_msg = "Added to allow list" + elseif permission_level == "auto" then + log_msg = "Auto-approved" + end + logs.tool(tool_call.name, "approved", log_msg) + -- Apply the change executor.apply_change(result.diff_data, function(apply_result) -- Store result for sending back to LLM @@ -261,8 +276,9 @@ function M.continue_with_results(context, callbacks) local codetyper = require("codetyper") local config = codetyper.get_config() - if config.llm.provider == "claude" then - -- Claude format: tool_result blocks + -- Copilot uses Claude-like format for tool results + if config.llm.provider == "copilot" then + -- Claude-style tool_result blocks local content = {} for _, result in ipairs(state.pending_tool_results) do table.insert(content, { diff --git a/lua/codetyper/agent/permissions.lua b/lua/codetyper/agent/permissions.lua new file mode 100644 index 0000000..421d0b2 --- /dev/null +++ b/lua/codetyper/agent/permissions.lua @@ -0,0 +1,229 @@ +---@mod codetyper.agent.permissions Permission manager for agent actions +--- +--- Manages permissions for bash commands and file operations with +--- allow, allow-session, allow-list, and reject options. 
+ +local M = {} + +---@class PermissionState +---@field session_allowed table Commands allowed for this session +---@field allow_list table Patterns always allowed +---@field deny_list table Patterns always denied + +local state = { + session_allowed = {}, + allow_list = {}, + deny_list = {}, +} + +--- Dangerous command patterns that should never be auto-allowed +local DANGEROUS_PATTERNS = { + "^rm%s+%-rf", + "^rm%s+%-r%s+/", + "^rm%s+/", + "^sudo%s+rm", + "^chmod%s+777", + "^chmod%s+%-R", + "^chown%s+%-R", + "^dd%s+", + "^mkfs", + "^fdisk", + "^format", + ":.*>%s*/dev/", + "^curl.*|.*sh", + "^wget.*|.*sh", + "^eval%s+", + "`;.*`", + "%$%(.*%)", + "fork%s*bomb", +} + +--- Safe command patterns that can be auto-allowed +local SAFE_PATTERNS = { + "^ls%s", + "^ls$", + "^cat%s", + "^head%s", + "^tail%s", + "^grep%s", + "^find%s", + "^pwd$", + "^echo%s", + "^wc%s", + "^which%s", + "^type%s", + "^file%s", + "^stat%s", + "^git%s+status", + "^git%s+log", + "^git%s+diff", + "^git%s+branch", + "^git%s+show", + "^npm%s+list", + "^npm%s+ls", + "^npm%s+outdated", + "^yarn%s+list", + "^cargo%s+check", + "^cargo%s+test", + "^go%s+test", + "^go%s+build", + "^make%s+test", + "^make%s+check", +} + +---@alias PermissionLevel "allow"|"allow_session"|"allow_list"|"reject" + +---@class PermissionResult +---@field allowed boolean Whether action is allowed +---@field reason string Reason for the decision +---@field auto boolean Whether this was an automatic decision + +--- Check if a command matches a pattern +---@param command string The command to check +---@param pattern string The pattern to match +---@return boolean +local function matches_pattern(command, pattern) + return command:match(pattern) ~= nil +end + +--- Check if command is dangerous +---@param command string The command to check +---@return boolean, string|nil dangerous, reason +local function is_dangerous(command) + for _, pattern in ipairs(DANGEROUS_PATTERNS) do + if matches_pattern(command, pattern) then + return true, 
"Matches dangerous pattern: " .. pattern + end + end + return false, nil +end + +--- Check if command is safe +---@param command string The command to check +---@return boolean +local function is_safe(command) + for _, pattern in ipairs(SAFE_PATTERNS) do + if matches_pattern(command, pattern) then + return true + end + end + return false +end + +--- Normalize command for comparison (trim, lowercase first word) +---@param command string +---@return string +local function normalize_command(command) + return vim.trim(command) +end + +--- Check permission for a bash command +---@param command string The command to check +---@return PermissionResult +function M.check_bash_permission(command) + local normalized = normalize_command(command) + + -- Check deny list first + for pattern, _ in pairs(state.deny_list) do + if matches_pattern(normalized, pattern) then + return { + allowed = false, + reason = "Command in deny list", + auto = true, + } + end + end + + -- Check if command is dangerous + local dangerous, reason = is_dangerous(normalized) + if dangerous then + return { + allowed = false, + reason = reason, + auto = false, -- Require explicit approval for dangerous commands + } + end + + -- Check session allowed + if state.session_allowed[normalized] then + return { + allowed = true, + reason = "Allowed for this session", + auto = true, + } + end + + -- Check allow list patterns + for pattern, _ in pairs(state.allow_list) do + if matches_pattern(normalized, pattern) then + return { + allowed = true, + reason = "Matches allow list pattern", + auto = true, + } + end + end + + -- Check if command is inherently safe + if is_safe(normalized) then + return { + allowed = true, + reason = "Safe read-only command", + auto = true, + } + end + + -- Otherwise, require explicit permission + return { + allowed = false, + reason = "Requires approval", + auto = false, + } +end + +--- Grant permission for a command +---@param command string The command +---@param level PermissionLevel 
The permission level +function M.grant_permission(command, level) + local normalized = normalize_command(command) + + if level == "allow_session" then + state.session_allowed[normalized] = true + elseif level == "allow_list" then + -- Add as pattern (escape special chars for exact match) + local pattern = "^" .. vim.pesc(normalized) .. "$" + state.allow_list[pattern] = true + end +end + +--- Add a pattern to the allow list +---@param pattern string Lua pattern to allow +function M.add_to_allow_list(pattern) + state.allow_list[pattern] = true +end + +--- Add a pattern to the deny list +---@param pattern string Lua pattern to deny +function M.add_to_deny_list(pattern) + state.deny_list[pattern] = true +end + +--- Clear session permissions +function M.clear_session() + state.session_allowed = {} +end + +--- Reset all permissions +function M.reset() + state.session_allowed = {} + state.allow_list = {} + state.deny_list = {} +end + +--- Get current permission state (for debugging) +---@return PermissionState +function M.get_state() + return vim.deepcopy(state) +end + +return M diff --git a/lua/codetyper/agent/queue.lua b/lua/codetyper/agent/queue.lua index a17e004..9c763c1 100644 --- a/lua/codetyper/agent/queue.lua +++ b/lua/codetyper/agent/queue.lua @@ -23,7 +23,7 @@ local M = {} ---@field priority number Priority (1=high, 2=normal, 3=low) ---@field status string "pending"|"processing"|"completed"|"escalated"|"cancelled"|"needs_context"|"failed" ---@field attempt_count number Number of processing attempts ----@field worker_type string|nil LLM provider used ("ollama"|"claude"|etc) +---@field worker_type string|nil LLM provider used ("ollama"|"openai"|"gemini"|"copilot") ---@field created_at number System time when created ---@field intent Intent|nil Detected intent from prompt ---@field scope ScopeInfo|nil Resolved scope (function/class/file) diff --git a/lua/codetyper/agent/scheduler.lua b/lua/codetyper/agent/scheduler.lua index 73483b7..c98f4c3 100644 --- 
a/lua/codetyper/agent/scheduler.lua +++ b/lua/codetyper/agent/scheduler.lua @@ -28,7 +28,7 @@ local state = { max_concurrent = 2, completion_delay_ms = 100, apply_delay_ms = 5000, -- Wait before applying code - remote_provider = "claude", -- Default fallback provider + remote_provider = "copilot", -- Default fallback provider }, } @@ -90,9 +90,7 @@ local function get_remote_provider() -- If current provider is ollama, use configured remote if config.llm.provider == "ollama" then -- Check which remote provider is configured - if config.llm.claude and config.llm.claude.api_key then - return "claude" - elseif config.llm.openai and config.llm.openai.api_key then + if config.llm.openai and config.llm.openai.api_key then return "openai" elseif config.llm.gemini and config.llm.gemini.api_key then return "gemini" @@ -120,7 +118,7 @@ local function get_primary_provider() return config.llm.provider end end - return "claude" + return "ollama" end --- Retry event with additional context diff --git a/lua/codetyper/agent/tools.lua b/lua/codetyper/agent/tools.lua index 21d9bbf..7ad9a26 100644 --- a/lua/codetyper/agent/tools.lua +++ b/lua/codetyper/agent/tools.lua @@ -81,6 +81,67 @@ M.definitions = { required = { "command" }, }, }, + + delete_file = { + name = "delete_file", + description = "Delete a file from the filesystem. Use with caution - requires explicit user approval.", + parameters = { + type = "object", + properties = { + path = { + type = "string", + description = "Path to the file to delete", + }, + reason = { + type = "string", + description = "Reason for deleting this file (shown to user for approval)", + }, + }, + required = { "path", "reason" }, + }, + }, + + list_directory = { + name = "list_directory", + description = "List files and directories in a path. 
Use to explore project structure.", + parameters = { + type = "object", + properties = { + path = { + type = "string", + description = "Path to the directory to list (defaults to current directory)", + }, + recursive = { + type = "boolean", + description = "Whether to list recursively (default: false, max depth: 3)", + }, + }, + required = {}, + }, + }, + + search_files = { + name = "search_files", + description = "Search for files by name pattern or content. Use to find relevant files in the project.", + parameters = { + type = "object", + properties = { + pattern = { + type = "string", + description = "Glob pattern for file names (e.g., '*.lua', 'test_*.py')", + }, + content = { + type = "string", + description = "Search for files containing this text", + }, + path = { + type = "string", + description = "Directory to search in (defaults to project root)", + }, + }, + required = {}, + }, + }, } --- Convert tool definitions to Claude API format diff --git a/lua/codetyper/agent/ui.lua b/lua/codetyper/agent/ui.lua index 1ec0b3e..a738955 100644 --- a/lua/codetyper/agent/ui.lua +++ b/lua/codetyper/agent/ui.lua @@ -35,14 +35,62 @@ local state = { local ns_chat = vim.api.nvim_create_namespace("codetyper_agent_chat") local ns_logs = vim.api.nvim_create_namespace("codetyper_agent_logs") ---- Fixed widths -local CHAT_WIDTH = 300 -local LOGS_WIDTH = 50 +--- Fixed heights local INPUT_HEIGHT = 5 +local LOGS_WIDTH = 50 + +--- Calculate dynamic width (1/4 of screen, minimum 30) +---@return number +local function get_panel_width() + return math.max(math.floor(vim.o.columns * 0.25), 30) +end --- Autocmd group local agent_augroup = nil +--- Autocmd group for width maintenance +local width_augroup = nil + +--- Store target width +local target_width = nil + +--- Setup autocmd to always maintain 1/4 window width +local function setup_width_autocmd() + -- Clear previous autocmd group if exists + if width_augroup then + pcall(vim.api.nvim_del_augroup_by_id, width_augroup) + end + + 
width_augroup = vim.api.nvim_create_augroup("CodetypeAgentWidth", { clear = true }) + + -- Always maintain 1/4 width on any window event + vim.api.nvim_create_autocmd({ "WinResized", "WinNew", "WinClosed", "VimResized" }, { + group = width_augroup, + callback = function() + if not state.is_open or not state.chat_win then + return + end + if not vim.api.nvim_win_is_valid(state.chat_win) then + return + end + + vim.schedule(function() + if state.chat_win and vim.api.nvim_win_is_valid(state.chat_win) then + -- Always calculate 1/4 of current screen width + local new_target = math.max(math.floor(vim.o.columns * 0.25), 30) + target_width = new_target + + local current_width = vim.api.nvim_win_get_width(state.chat_win) + if current_width ~= target_width then + pcall(vim.api.nvim_win_set_width, state.chat_win, target_width) + end + end + end) + end, + desc = "Maintain Agent panel at 1/4 window width", + }) +end + --- Add a log entry to the logs buffer ---@param entry table Log entry local function add_log_entry(entry) @@ -479,7 +527,7 @@ function M.open() vim.cmd("topleft vsplit") state.chat_win = vim.api.nvim_get_current_win() vim.api.nvim_win_set_buf(state.chat_win, state.chat_buf) - vim.api.nvim_win_set_width(state.chat_win, CHAT_WIDTH) + vim.api.nvim_win_set_width(state.chat_win, get_panel_width()) -- Window options for chat vim.wo[state.chat_win].number = false @@ -592,6 +640,10 @@ function M.open() end, }) + -- Setup autocmd to maintain 1/4 width + target_width = get_panel_width() + setup_width_autocmd() + state.is_open = true -- Focus input and log startup @@ -603,7 +655,16 @@ function M.open() if ok then local config = codetyper.get_config() local provider = config.llm.provider - local model = provider == "claude" and config.llm.claude.model or config.llm.ollama.model + local model = "unknown" + if provider == "ollama" then + model = config.llm.ollama.model + elseif provider == "openai" then + model = config.llm.openai.model + elseif provider == "gemini" then + 
--- Render indexed project context as a text section for inclusion in a prompt.
--- Returns the empty string when there is nothing worth adding.
---@param indexed_context table|nil Context from codetyper.indexer.get_context_for
---@return string
local function format_indexed_context(indexed_context)
  if indexed_context == nil then
    return ""
  end

  local sections = {}

  -- Project type (skip the "unknown" placeholder).
  local ptype = indexed_context.project_type
  if ptype and ptype ~= "unknown" then
    sections[#sections + 1] = "Project type: " .. ptype
  end

  -- Relevant symbols: "name (in first-file)" for every symbol with at least one file.
  local symbols = indexed_context.relevant_symbols
  if symbols then
    local entries = {}
    for name, files in pairs(symbols) do
      if #files > 0 then
        entries[#entries + 1] = name .. " (in " .. files[1] .. ")"
      end
    end
    if #entries > 0 then
      sections[#sections + 1] = "Relevant symbols: " .. table.concat(entries, ", ")
    end
  end

  -- Learned conventions: at most the first three patterns.
  local patterns = indexed_context.patterns
  if patterns and #patterns > 0 then
    local conventions = {}
    for i = 1, math.min(#patterns, 3) do
      conventions[i] = patterns[i].content or ""
    end
    sections[#sections + 1] = "Project conventions: " .. table.concat(conventions, "; ")
  end

  if #sections == 0 then
    return ""
  end

  return "\n\n--- Project Context ---\n" .. table.concat(sections, "\n")
end
Output only code, no explanations.]], filetype, filetype, event.scope_text, - attached_content, + extra_context, event.prompt_content, scope_type ) @@ -337,7 +402,7 @@ Output only the code to insert, no explanations.]], scope_name, filetype, event.scope_text, - attached_content, + extra_context, event.prompt_content ) end @@ -357,7 +422,7 @@ Output only code, no explanations.]], filetype, filetype, target_content:sub(1, 4000), -- Limit context size - attached_content, + extra_context, event.prompt_content ) end diff --git a/lua/codetyper/ask.lua b/lua/codetyper/ask.lua index 4b2b880..0282f0b 100644 --- a/lua/codetyper/ask.lua +++ b/lua/codetyper/ask.lua @@ -728,14 +728,17 @@ local function build_file_context() end --- Build context for the question +---@param intent? table Detected intent from intent module ---@return table Context object -local function build_context() +local function build_context(intent) local context = { project_root = utils.get_project_root(), current_file = nil, current_content = nil, language = nil, referenced_files = state.referenced_files, + brain_context = nil, + indexer_context = nil, } -- Try to get current file context from the non-ask window @@ -754,49 +757,140 @@ local function build_context() end end + -- Add brain context if intent needs it + if intent and intent.needs_brain_context then + local ok_brain, brain = pcall(require, "codetyper.brain") + if ok_brain and brain.is_initialized() then + context.brain_context = brain.get_context_for_llm({ + file = context.current_file, + max_tokens = 1000, + }) + end + end + + -- Add indexer context if intent needs project-wide context + if intent and intent.needs_project_context then + local ok_indexer, indexer = pcall(require, "codetyper.indexer") + if ok_indexer then + context.indexer_context = indexer.get_context_for({ + file = context.current_file, + prompt = "", -- Will be filled later + intent = intent, + }) + end + end + return context end ---- Submit the question to LLM -function 
--- Append one exploration log line to the Ask output buffer.
--- Safe to call from any context: the actual buffer write is deferred
--- via vim.schedule, with a re-validation inside the callback.
---@param msg string Single line of text (must not contain newlines)
---@param level string Log level ("info"|"progress"|"debug"|"file") --
--- currently only carried for callers/future styling; all levels render
--- identically (the original if/elseif chain assigned `formatted = msg`
--- in every branch, i.e. it was dead code, and has been removed).
local function append_exploration_log(msg, level)
  if not state.output_buf or not vim.api.nvim_buf_is_valid(state.output_buf) then
    return
  end

  vim.schedule(function()
    -- Re-check: the buffer may have been wiped before this callback runs.
    if not state.output_buf or not vim.api.nvim_buf_is_valid(state.output_buf) then
      return
    end

    vim.bo[state.output_buf].modifiable = true

    -- FIX: the original read the entire buffer, appended one element, and
    -- wrote the whole buffer back -- O(buffer size) per log line. Using the
    -- (-1, -1) range appends just the new line.
    vim.api.nvim_buf_set_lines(state.output_buf, -1, -1, false, { msg })

    vim.bo[state.output_buf].modifiable = false

    -- Keep the output window pinned to the latest line.
    if state.output_win and vim.api.nvim_win_is_valid(state.output_win) then
      local line_count = vim.api.nvim_buf_line_count(state.output_buf)
      pcall(vim.api.nvim_win_set_cursor, state.output_win, { line_count, 0 })
    end
  end)
end
ok_intent, intent_module = pcall(require, "codetyper.ask.intent") + local prompt_type = "ask" + if ok_intent then + prompt_type = intent_module.get_prompt_type(intent) end - -- Add user message to output - append_to_output(display_question, true) - - -- Clear input and references AFTER building context - M.clear_input() - - -- Build system prompt for ask mode using prompts module + -- Build system prompt using prompts module local prompts = require("codetyper.prompts") - local system_prompt = prompts.system.ask + local system_prompt = prompts.system[prompt_type] or prompts.system.ask if context.current_file then system_prompt = system_prompt .. "\n\nCurrent open file: " .. context.current_file system_prompt = system_prompt .. "\nLanguage: " .. (context.language or "unknown") end + -- Add exploration context if available + if exploration_result then + local ok_explorer, explorer = pcall(require, "codetyper.ask.explorer") + if ok_explorer then + local explore_context = explorer.build_context(exploration_result) + system_prompt = system_prompt .. "\n\n=== PROJECT EXPLORATION RESULTS ===\n" + system_prompt = system_prompt .. explore_context + system_prompt = system_prompt .. "\n=== END EXPLORATION ===\n" + end + end + + -- Add brain context (learned patterns, conventions) + if context.brain_context and context.brain_context ~= "" then + system_prompt = system_prompt .. "\n\n=== LEARNED PROJECT KNOWLEDGE ===\n" + system_prompt = system_prompt .. context.brain_context + system_prompt = system_prompt .. "\n=== END LEARNED KNOWLEDGE ===\n" + end + + -- Add indexer context (project structure, symbols) + if context.indexer_context then + local idx_ctx = context.indexer_context + if idx_ctx.project_type and idx_ctx.project_type ~= "unknown" then + system_prompt = system_prompt .. "\n\nProject type: " .. idx_ctx.project_type + end + if idx_ctx.relevant_symbols and next(idx_ctx.relevant_symbols) then + system_prompt = system_prompt .. 
"\n\nRelevant symbols in project:" + for symbol, files in pairs(idx_ctx.relevant_symbols) do + system_prompt = system_prompt .. "\n - " .. symbol .. " (in: " .. table.concat(files, ", ") .. ")" + end + end + if idx_ctx.patterns and #idx_ctx.patterns > 0 then + system_prompt = system_prompt .. "\n\nProject patterns/memories:" + for _, pattern in ipairs(idx_ctx.patterns) do + system_prompt = system_prompt .. "\n - " .. (pattern.summary or pattern.content or "") + end + end + end + -- Add to history table.insert(state.history, { role = "user", content = question }) -- Show loading indicator - append_to_output("⏳ Thinking...", false) + append_to_output("", false) + append_to_output("⏳ Generating response...", false) -- Get LLM client and generate response local ok, llm = pcall(require, "codetyper.llm") @@ -829,10 +923,23 @@ function M.submit() .. "\n```" end + -- Add exploration summary to prompt if available + if exploration_result then + full_prompt = full_prompt + .. "\n\nPROJECT EXPLORATION COMPLETE: " + .. exploration_result.total_files + .. " files analyzed. " + .. "Project type: " + .. exploration_result.project.language + .. " (" + .. (exploration_result.project.framework or exploration_result.project.type) + .. 
")" + end + local request_context = { file_content = file_context ~= "" and file_context or context.current_content, language = context.language, - prompt_type = "explain", + prompt_type = prompt_type, file_path = context.current_file, } @@ -844,9 +951,9 @@ function M.submit() -- Remove last few lines (the thinking message) local to_remove = 0 for i = #lines, 1, -1 do - if lines[i]:match("Thinking") or lines[i]:match("^[│└┌─]") then + if lines[i]:match("Generating") or lines[i]:match("^[│└┌─]") or lines[i] == "" then to_remove = to_remove + 1 - if lines[i]:match("┌") then + if lines[i]:match("┌") or to_remove >= 5 then break end else @@ -879,6 +986,77 @@ function M.submit() end) end +--- Submit the question to LLM +function M.submit() + local question = get_input_text() + + if not question or question:match("^%s*$") then + utils.notify("Please enter a question", vim.log.levels.WARN) + M.focus_input() + return + end + + -- Detect intent from prompt + local ok_intent, intent_module = pcall(require, "codetyper.ask.intent") + local intent = nil + if ok_intent then + intent = intent_module.detect(question) + else + -- Fallback intent + intent = { + type = "ask", + confidence = 0.5, + needs_project_context = false, + needs_brain_context = true, + needs_exploration = false, + } + end + + -- Build context BEFORE clearing input (to preserve file references) + local context = build_context(intent) + local file_context, file_count = build_file_context() + + -- Build display message (without full file contents) + local display_question = question + if file_count > 0 then + display_question = question .. "\n📎 " .. file_count .. " file(s) attached" + end + -- Show detected intent if not standard ask + if intent.type ~= "ask" then + display_question = display_question .. "\n🎯 " .. intent.type:upper() .. " mode" + end + -- Show exploration indicator + if intent.needs_exploration then + display_question = display_question .. 
"\n🔍 Project exploration required" + end + + -- Add user message to output + append_to_output(display_question, true) + + -- Clear input and references AFTER building context + M.clear_input() + + -- Check if exploration is needed + if intent.needs_exploration then + local ok_explorer, explorer = pcall(require, "codetyper.ask.explorer") + if ok_explorer then + local root = utils.get_project_root() + if root then + -- Start exploration with logging + append_to_output("", false) + explorer.explore(root, append_exploration_log, function(exploration_result) + -- After exploration completes, continue with LLM request + continue_submit(question, intent, context, file_context, file_count, exploration_result) + end) + return + end + end + end + + -- No exploration needed, continue directly + continue_submit(question, intent, context, file_context, file_count, nil) +end + --- Clear chat history function M.clear_history() state.history = {} diff --git a/lua/codetyper/ask/explorer.lua b/lua/codetyper/ask/explorer.lua new file mode 100644 index 0000000..cc03f4a --- /dev/null +++ b/lua/codetyper/ask/explorer.lua @@ -0,0 +1,676 @@ +---@mod codetyper.ask.explorer Project exploration for Ask mode +---@brief [[ +--- Performs comprehensive project exploration when explaining a project. +--- Shows progress, indexes files, and builds brain context. 
+---@brief ]] + +local M = {} + +local utils = require("codetyper.utils") + +---@class ExplorationState +---@field is_exploring boolean +---@field files_scanned number +---@field total_files number +---@field current_file string|nil +---@field findings table +---@field on_log fun(msg: string, level: string)|nil + +local state = { + is_exploring = false, + files_scanned = 0, + total_files = 0, + current_file = nil, + findings = {}, + on_log = nil, +} + +--- File extensions to analyze +local ANALYZABLE_EXTENSIONS = { + lua = true, + ts = true, + tsx = true, + js = true, + jsx = true, + py = true, + go = true, + rs = true, + rb = true, + java = true, + c = true, + cpp = true, + h = true, + hpp = true, + json = true, + yaml = true, + yml = true, + toml = true, + md = true, + xml = true, +} + +--- Directories to skip +local SKIP_DIRS = { + -- Version control + [".git"] = true, + [".svn"] = true, + [".hg"] = true, + + -- IDE/Editor + [".idea"] = true, + [".vscode"] = true, + [".cursor"] = true, + [".cursorignore"] = true, + [".claude"] = true, + [".zed"] = true, + + -- Project tooling + [".coder"] = true, + [".github"] = true, + [".gitlab"] = true, + [".husky"] = true, + + -- Build outputs + dist = true, + build = true, + out = true, + target = true, + bin = true, + obj = true, + [".build"] = true, + [".output"] = true, + + -- Dependencies + node_modules = true, + vendor = true, + [".vendor"] = true, + packages = true, + bower_components = true, + jspm_packages = true, + + -- Cache/temp + [".cache"] = true, + [".tmp"] = true, + [".temp"] = true, + __pycache__ = true, + [".pytest_cache"] = true, + [".mypy_cache"] = true, + [".ruff_cache"] = true, + [".tox"] = true, + [".nox"] = true, + [".eggs"] = true, + ["*.egg-info"] = true, + + -- Framework specific + [".next"] = true, + [".nuxt"] = true, + [".svelte-kit"] = true, + [".vercel"] = true, + [".netlify"] = true, + [".serverless"] = true, + [".turbo"] = true, + + -- Testing/coverage + coverage = true, + [".nyc_output"] = 
--- Filename patterns to skip during exploration.
--- These are Lua patterns (%-escaped), NOT shell globs. Entries without a
--- trailing `$` anchor match anywhere in the name on purpose (e.g. "%.woff"
--- also catches .woff2, "%.tar" also catches .tar.gz).
local SKIP_FILES = {
  -- Lock files
  "package%-lock%.json",
  "yarn%.lock",
  "pnpm%-lock%.yaml",
  "Gemfile%.lock",
  "Cargo%.lock",
  "poetry%.lock",
  "Pipfile%.lock",
  "composer%.lock",
  "go%.sum",
  "flake%.lock",
  "%.lock$",
  "%-lock%.json$",
  "%-lock%.yaml$",

  -- Generated files
  "%.min%.js$",
  "%.min%.css$",
  "%.bundle%.js$",
  "%.chunk%.js$",
  "%.map$",
  "%.d%.ts$",

  -- Binary/media (shouldn't match anyway but be safe)
  "%.png$",
  "%.jpg$",
  "%.jpeg$",
  "%.gif$",
  "%.ico$",
  "%.svg$",
  "%.woff",
  "%.ttf$",
  "%.eot$",
  "%.pdf$",
  "%.zip$",
  "%.tar",
  "%.gz$",

  -- Env files (never feed potential secrets into analysis).
  -- NOTE: "%.env%." is redundant with the unanchored "%.env" above; kept to
  -- preserve the original list verbatim.
  "%.env",
  "%.env%.",
}

--- Emit a progress message through the registered callback, if any.
---@param msg string
---@param level? string "info"|"debug"|"file"|"progress" (defaults to "info")
local function log(msg, level)
  level = level or "info"
  if state.on_log then
    state.on_log(msg, level)
  end
end

--- Check whether a filename matches any SKIP_FILES pattern.
---@param filename string Base name (no directory part)
---@return boolean
local function should_skip_file(filename)
  for _, pattern in ipairs(SKIP_FILES) do
    if filename:match(pattern) then
      return true
    end
  end
  return false
end

--- Check whether a directory should be excluded from the scan.
---@param dirname string Directory base name
---@return boolean
local function should_skip_dir(dirname)
  -- Pattern-based checks first (.cursor, .cursor-rules, ...).
  if dirname:match("^%.cursor") then
    return true
  end
  -- FIX: SKIP_DIRS contains a literal "*.egg-info" key, which is a shell glob
  -- and can never equal a real directory name; match Python egg-info
  -- directories by pattern instead so they are actually skipped.
  if dirname:match("%.egg%-info$") then
    return true
  end
  -- Exact-name lookup against the skip table.
  return SKIP_DIRS[dirname] == true
end
local full_path = dir .. "/" .. name + + if type == "directory" then + if not should_skip_dir(name) then + scan_dir(full_path) + end + elseif type == "file" then + if not should_skip_file(name) then + local ext = name:match("%.([^%.]+)$") + if ext and ANALYZABLE_EXTENSIONS[ext:lower()] then + table.insert(files, full_path) + end + end + end + end + end + + scan_dir(root) + return files +end + +--- Analyze a single file +---@param filepath string +---@return table|nil analysis +local function analyze_file(filepath) + local content = utils.read_file(filepath) + if not content or content == "" then + return nil + end + + local ext = filepath:match("%.([^%.]+)$") or "" + local lines = vim.split(content, "\n") + + local analysis = { + path = filepath, + extension = ext, + lines = #lines, + size = #content, + imports = {}, + exports = {}, + functions = {}, + classes = {}, + summary = "", + } + + -- Extract key patterns based on file type + for i, line in ipairs(lines) do + -- Imports/requires + local import = line:match('import%s+.*%s+from%s+["\']([^"\']+)["\']') + or line:match('require%(["\']([^"\']+)["\']%)') + or line:match("from%s+([%w_.]+)%s+import") + if import then + table.insert(analysis.imports, { source = import, line = i }) + end + + -- Function definitions + local func = line:match("^%s*function%s+([%w_:%.]+)%s*%(") + or line:match("^%s*local%s+function%s+([%w_]+)%s*%(") + or line:match("^%s*def%s+([%w_]+)%s*%(") + or line:match("^%s*func%s+([%w_]+)%s*%(") + or line:match("^%s*async%s+function%s+([%w_]+)%s*%(") + or line:match("^%s*public%s+.*%s+([%w_]+)%s*%(") + if func then + table.insert(analysis.functions, { name = func, line = i }) + end + + -- Class definitions + local class = line:match("^%s*class%s+([%w_]+)") + or line:match("^%s*public%s+class%s+([%w_]+)") + or line:match("^%s*interface%s+([%w_]+)") + if class then + table.insert(analysis.classes, { name = class, line = i }) + end + + -- Exports + local exp = line:match("^%s*export%s+.*%s+([%w_]+)") 
--- Detect the project type and basic metadata from well-known marker files.
--- Checks are ordered, first match wins:
--- node > maven > rust > go > python > neovim-plugin.
---@param root string Project root directory (no trailing slash)
---@return string type Project type id ("node", "maven", ... or "unknown")
---@return table info { name, type, framework|nil, language }
local function detect_project_type(root)
  local info = {
    name = vim.fn.fnamemodify(root, ":t"),
    type = "unknown",
    framework = nil,
    language = nil,
  }

  if utils.file_exists(root .. "/package.json") then
    info.type = "node"
    info.language = "JavaScript/TypeScript"
    local content = utils.read_file(root .. "/package.json")
    if content then
      local ok, pkg = pcall(vim.json.decode, content)
      -- Guard: valid JSON can decode to a non-table scalar.
      if ok and type(pkg) == "table" then
        info.name = pkg.name or info.name
        local deps = pkg.dependencies
        if deps then
          if deps.react then
            info.framework = "React"
          elseif deps.vue then
            info.framework = "Vue"
          elseif deps.next then
            info.framework = "Next.js"
          elseif deps.express then
            info.framework = "Express"
          end
        end
      end
    end
  elseif utils.file_exists(root .. "/pom.xml") then
    info.type = "maven"
    info.language = "Java"
    local content = utils.read_file(root .. "/pom.xml")
    if content and content:match("spring%-boot") then
      info.framework = "Spring Boot"
    end
  elseif utils.file_exists(root .. "/Cargo.toml") then
    info.type = "rust"
    info.language = "Rust"
  elseif utils.file_exists(root .. "/go.mod") then
    info.type = "go"
    info.language = "Go"
  elseif utils.file_exists(root .. "/requirements.txt") or utils.file_exists(root .. "/pyproject.toml") then
    info.type = "python"
    info.language = "Python"
  elseif utils.file_exists(root .. "/init.lua") or utils.file_exists(root .. "/plugin/") then
    -- NOTE(review): "plugin/" is a directory; whether utils.file_exists
    -- reports directories (and tolerates the trailing slash) depends on its
    -- implementation -- confirm.
    info.type = "neovim-plugin"
    info.language = "Lua"
  end

  -- FIX: never return a nil language. Callers concatenate it directly into
  -- log strings (e.g. M.explore's "⎿ <language> (...)" line), which would
  -- raise "attempt to concatenate a nil value" for unknown project types.
  info.language = info.language or "Unknown"

  return info.type, info
end
" analyzable files", "debug") + + -- Build structure + local structure = build_structure(files, root) + state.findings.structure = structure + + -- Show directory breakdown + log("", "info") + log(" Structure(Directories)", "progress") + local sorted_dirs = {} + for dir, count in pairs(structure.directories) do + table.insert(sorted_dirs, { dir = dir, count = count }) + end + table.sort(sorted_dirs, function(a, b) + return a.count > b.count + end) + for i, entry in ipairs(sorted_dirs) do + if i <= 5 then + log(" ⎿ " .. entry.dir .. " (" .. entry.count .. " files)", "debug") + end + end + if #sorted_dirs > 5 then + log(" ⎿ +" .. (#sorted_dirs - 5) .. " more directories", "debug") + end + + -- Analyze files asynchronously + log("", "info") + log(" Analyze(Source files)", "progress") + + state.files_scanned = 0 + local analyses = {} + local key_files = {} + + -- Process files in batches to avoid blocking + local batch_size = 10 + local current_batch = 0 + + local function process_batch() + local start_idx = current_batch * batch_size + 1 + local end_idx = math.min(start_idx + batch_size - 1, #files) + + for i = start_idx, end_idx do + local file = files[i] + local relative = file:gsub("^" .. vim.pesc(root) .. "/", "") + + state.files_scanned = state.files_scanned + 1 + state.current_file = relative + + local analysis = analyze_file(file) + if analysis then + analysis.relative_path = relative + table.insert(analyses, analysis) + + -- Track key files (many functions/classes) + if #analysis.functions >= 3 or #analysis.classes >= 1 then + table.insert(key_files, { + path = relative, + functions = #analysis.functions, + classes = #analysis.classes, + summary = analysis.summary, + }) + end + end + + -- Log some files + if i <= 3 or (i % 20 == 0) then + log(" ⎿ " .. relative .. ": " .. 
(analysis and analysis.summary or "(empty)"), "file") + end + end + + -- Progress update + local progress = math.floor((state.files_scanned / state.total_files) * 100) + if progress % 25 == 0 and progress > 0 then + log(" ⎿ " .. progress .. "% complete (" .. state.files_scanned .. "/" .. state.total_files .. ")", "debug") + end + + current_batch = current_batch + 1 + + if end_idx < #files then + -- Schedule next batch + vim.defer_fn(process_batch, 10) + else + -- Complete + finish_exploration(root, analyses, key_files, on_complete) + end + end + + -- Start processing + vim.defer_fn(process_batch, 10) +end + +--- Finish exploration and store results +---@param root string +---@param analyses table +---@param key_files table +---@param on_complete fun(result: table) +function finish_exploration(root, analyses, key_files, on_complete) + log(" ⎿ +" .. (#analyses - 3) .. " more files analyzed", "debug") + + -- Show key files + if #key_files > 0 then + log("", "info") + log(" KeyFiles(Important components)", "progress") + table.sort(key_files, function(a, b) + return (a.functions + a.classes * 2) > (b.functions + b.classes * 2) + end) + for i, kf in ipairs(key_files) do + if i <= 5 then + log(" ⎿ " .. kf.path .. ": " .. kf.summary, "file") + end + end + if #key_files > 5 then + log(" ⎿ +" .. (#key_files - 5) .. " more key files", "debug") + end + end + + state.findings.analyses = analyses + state.findings.key_files = key_files + + -- Store in brain if available + local ok_brain, brain = pcall(require, "codetyper.brain") + if ok_brain and brain.is_initialized() then + log("", "info") + log(" Store(Brain context)", "progress") + + -- Store project pattern + brain.learn({ + type = "pattern", + file = root, + content = { + summary = "Project: " .. state.findings.project.name, + detail = state.findings.project.language + .. " " + .. 
--- Check if exploration is in progress
---@return boolean
function M.is_exploring()
  return state.is_exploring
end

--- Get exploration progress
---@return number scanned, number total
function M.get_progress()
  return state.files_scanned, state.total_files
end

--- Render an exploration result into a markdown-style context string
--- suitable for embedding in an LLM system prompt.
---@param result table Exploration result produced by M.explore
---@return string context
function M.build_context(result)
  local parts = {}

  -- Project info
  table.insert(parts, "## Project: " .. result.project.name)
  table.insert(parts, "- Type: " .. result.project.type)
  table.insert(parts, "- Language: " .. (result.project.language or "Unknown"))
  if result.project.framework then
    table.insert(parts, "- Framework: " .. result.project.framework)
  end
  table.insert(parts, "- Files: " .. result.total_files)
  table.insert(parts, "")

  -- Structure: file counts per extension.
  -- FIX: pairs() iteration order is unspecified, which made this context
  -- string differ between runs for the same project; sort the extensions so
  -- the output is deterministic.
  table.insert(parts, "## Structure")
  if result.structure and result.structure.by_extension then
    local exts = {}
    for ext in pairs(result.structure.by_extension) do
      table.insert(exts, ext)
    end
    table.sort(exts)
    for _, ext in ipairs(exts) do
      table.insert(parts, "- ." .. ext .. ": " .. result.structure.by_extension[ext] .. " files")
    end
  end
  table.insert(parts, "")

  -- Key components (capped at 10, already scored/sorted upstream).
  if result.key_files and #result.key_files > 0 then
    table.insert(parts, "## Key Components")
    for i, kf in ipairs(result.key_files) do
      if i <= 10 then
        table.insert(parts, "- " .. kf.path .. ": " .. kf.summary)
      end
    end
  end

  return table.concat(parts, "\n")
end

return M
---@brief ]]

local M = {}

---@alias IntentType "ask"|"explain"|"generate"|"refactor"|"document"|"test"

---@class Intent
---@field type IntentType Detected intent type
---@field confidence number 0-1 confidence score
---@field needs_project_context boolean Whether project-wide context is needed
---@field needs_brain_context boolean Whether brain/learned context is helpful
---@field needs_exploration boolean Whether full project exploration is needed
---@field keywords string[] Keywords that influenced detection

--- Patterns for detecting ask/explain intent (questions about code).
--- Weights accumulate: a prompt matching several entries scores their sum.
local ASK_PATTERNS = {
  -- Question words
  { pattern = "^what%s", weight = 0.9 },
  { pattern = "^why%s", weight = 0.95 },
  { pattern = "^how%s+does", weight = 0.9 },
  { pattern = "^how%s+do%s+i", weight = 0.7 }, -- Could be asking for code
  { pattern = "^where%s", weight = 0.85 },
  { pattern = "^when%s", weight = 0.85 },
  { pattern = "^which%s", weight = 0.8 },
  { pattern = "^who%s", weight = 0.85 },
  { pattern = "^can%s+you%s+explain", weight = 0.95 },
  { pattern = "^could%s+you%s+explain", weight = 0.95 },
  { pattern = "^please%s+explain", weight = 0.95 },

  -- Explanation requests
  { pattern = "explain%s", weight = 0.9 },
  { pattern = "describe%s", weight = 0.85 },
  { pattern = "tell%s+me%s+about", weight = 0.85 },
  { pattern = "walk%s+me%s+through", weight = 0.9 },
  { pattern = "help%s+me%s+understand", weight = 0.95 },
  { pattern = "what%s+is%s+the%s+purpose", weight = 0.95 },
  { pattern = "what%s+does%s+this", weight = 0.9 },
  { pattern = "what%s+does%s+it", weight = 0.9 },
  { pattern = "how%s+does%s+this%s+work", weight = 0.95 },
  { pattern = "how%s+does%s+it%s+work", weight = 0.95 },

  -- Understanding queries
  { pattern = "understand", weight = 0.7 },
  { pattern = "meaning%s+of", weight = 0.85 },
  { pattern = "difference%s+between", weight = 0.9 },
  { pattern = "compared%s+to", weight = 0.8 },
  { pattern = "vs%s", weight = 0.7 },
  { pattern = "versus", weight = 0.7 },
  { pattern = "pros%s+and%s+cons", weight = 0.9 },
  { pattern = "advantages", weight = 0.8 },
  { pattern = "disadvantages", weight = 0.8 },
  { pattern = "trade%-?offs?", weight = 0.85 },

  -- Analysis requests
  { pattern = "analyze", weight = 0.85 },
  { pattern = "review", weight = 0.7 }, -- Could also be refactor
  { pattern = "overview", weight = 0.9 },
  { pattern = "summary", weight = 0.9 },
  { pattern = "summarize", weight = 0.9 },

  -- Question mark at the end (weaker signal).
  -- FIX: the previous "%?$" entry was removed. "%?%s*$" already matches a
  -- trailing "?" with or without trailing whitespace, so keeping both entries
  -- double-counted this signal (0.6 instead of the intended 0.3).
  { pattern = "%?%s*$", weight = 0.3 },
}

--- Patterns for detecting code generation intent
local GENERATE_PATTERNS = {
  -- Direct commands
  { pattern = "^create%s", weight = 0.9 },
  { pattern = "^make%s", weight = 0.85 },
  { pattern = "^build%s", weight = 0.85 },
  { pattern = "^write%s", weight = 0.9 },
  { pattern = "^add%s", weight = 0.85 },
  { pattern = "^implement%s", weight = 0.95 },
  { pattern = "^generate%s", weight = 0.95 },
  { pattern = "^code%s", weight = 0.8 },

  -- Modification commands
  { pattern = "^fix%s", weight = 0.9 },
  { pattern = "^change%s", weight = 0.8 },
  { pattern = "^update%s", weight = 0.75 },
  { pattern = "^modify%s", weight = 0.8 },
  { pattern = "^replace%s", weight = 0.85 },
  { pattern = "^remove%s", weight = 0.85 },
  { pattern = "^delete%s", weight = 0.85 },

  -- Feature requests
  { pattern = "i%s+need%s+a", weight = 0.8 },
  { pattern = "i%s+want%s+a", weight = 0.8 },
  { pattern = "give%s+me", weight = 0.7 },
  { pattern = "show%s+me%s+how%s+to%s+code", weight = 0.9 },
  { pattern = "how%s+do%s+i%s+implement", weight = 0.85 },
  { pattern = "can%s+you%s+write", weight = 0.9 },
  { pattern = "can%s+you%s+create", weight = 0.9 },
  { pattern = "can%s+you%s+add", weight = 0.85 },
  { pattern = "can%s+you%s+make", weight = 0.85 },

  -- Code-specific terms
  { pattern = "function%s+that", weight = 0.85 },
  { pattern = "class%s+that", weight = 0.85 },
  { pattern = "method%s+that", weight = 0.85 },
  { pattern = "component%s+that", weight = 0.85 },
  { pattern = "module%s+that", weight = 0.85 },
  { pattern = "api%s+for", weight = 0.8 },
  { pattern = "endpoint%s+for", weight = 0.8 },
}

--- Patterns for detecting refactor intent
local REFACTOR_PATTERNS = {
  { pattern = "^refactor%s", weight = 0.95 },
  { pattern = "refactor%s+this", weight = 0.95 },
  { pattern = "clean%s+up", weight = 0.85 },
  { pattern = "improve%s+this%s+code", weight = 0.85 },
  { pattern = "make%s+this%s+cleaner", weight = 0.85 },
  { pattern = "simplify", weight = 0.8 },
  { pattern = "optimize", weight = 0.75 }, -- Could be explain
  { pattern = "reorganize", weight = 0.9 },
  { pattern = "restructure", weight = 0.9 },
  { pattern = "extract%s+to", weight = 0.9 },
  { pattern = "split%s+into", weight = 0.85 },
  { pattern = "dry%s+this", weight = 0.9 }, -- Don't repeat yourself
  { pattern = "reduce%s+duplication", weight = 0.9 },
}

--- Patterns for detecting documentation intent
local DOCUMENT_PATTERNS = {
  { pattern = "^document%s", weight = 0.95 },
  { pattern = "add%s+documentation", weight = 0.95 },
  { pattern = "add%s+docs", weight = 0.95 },
  { pattern = "add%s+comments", weight = 0.9 },
  { pattern = "add%s+docstring", weight = 0.95 },
  { pattern = "add%s+jsdoc", weight = 0.95 },
  { pattern = "write%s+documentation", weight = 0.95 },
  { pattern = "document%s+this", weight = 0.95 },
}

--- Patterns for detecting test generation intent
local TEST_PATTERNS = {
  { pattern = "^test%s", weight = 0.9 },
  { pattern = "write%s+tests?%s+for", weight = 0.95 },
  { pattern = "add%s+tests?%s+for", weight = 0.95 },
  { pattern = "create%s+tests?%s+for", weight = 0.95 },
  { pattern = "generate%s+tests?", weight = 0.95 },
  { pattern = "unit%s+tests?", weight = 0.9 },
  { pattern = "test%s+cases?%s+for", weight = 0.95 },
  { pattern = "spec%s+for", weight = 0.85 },
}

--- Patterns indicating project-wide context is needed
local PROJECT_CONTEXT_PATTERNS = {
  { pattern = "project", weight = 0.9 },
  { pattern = "codebase", weight = 0.95 },
  { pattern = "entire", weight = 0.7 },
  { pattern = "whole", weight = 0.7 },
  { pattern = "all%s+files", weight = 0.9 },
  { pattern = "architecture", weight = 0.95 },
  { pattern = "structure", weight = 0.85 },
  { pattern = "how%s+is%s+.*%s+organized", weight = 0.95 },
  { pattern = "where%s+is%s+.*%s+defined", weight = 0.9 },
  { pattern = "dependencies", weight = 0.85 },
  { pattern = "imports?%s+from", weight = 0.7 },
  { pattern = "modules?", weight = 0.6 },
  { pattern = "packages?", weight = 0.6 },
}

--- Patterns indicating project exploration is needed (full indexing)
local EXPLORE_PATTERNS = {
  { pattern = "explain%s+.*%s*project", weight = 1.0 },
  { pattern = "explain%s+.*%s*codebase", weight = 1.0 },
  { pattern = "explain%s+me%s+the%s+project", weight = 1.0 },
  { pattern = "tell%s+me%s+about%s+.*%s*project", weight = 0.95 },
  { pattern = "what%s+is%s+this%s+project", weight = 0.95 },
  { pattern = "overview%s+of%s+.*%s*project", weight = 0.95 },
  { pattern = "understand%s+.*%s*project", weight = 0.9 },
  { pattern = "analyze%s+.*%s*project", weight = 0.9 },
  { pattern = "explore%s+.*%s*project", weight = 1.0 },
  { pattern = "explore%s+.*%s*codebase", weight = 1.0 },
  { pattern = "index%s+.*%s*project", weight = 1.0 },
  { pattern = "scan%s+.*%s*project", weight = 0.95 },
}

--- Score a text against a weighted pattern list.
--- The score is the sum of the weights of every pattern that matches; the
--- patterns that matched are returned alongside it for diagnostics.
---@param text string Lowercased text to match
---@param patterns table Pattern list with weights
---@return number Score, string[] Matched keywords
local function match_patterns(text, patterns)
  local total, hits = 0, {}

  for i = 1, #patterns do
    local entry = patterns[i]
    if text:match(entry.pattern) ~= nil then
      total = total + entry.weight
      hits[#hits + 1] = entry.pattern
    end
  end

  return total, hits
end

--- Detect intent from user prompt
---@param prompt string User's question/request
---@return Intent
Detected intent +function M.detect(prompt) + local text = prompt:lower() + + -- Calculate raw scores for each intent type (sum of matched weights) + local ask_score, ask_kw = match_patterns(text, ASK_PATTERNS) + local gen_score, gen_kw = match_patterns(text, GENERATE_PATTERNS) + local ref_score, ref_kw = match_patterns(text, REFACTOR_PATTERNS) + local doc_score, doc_kw = match_patterns(text, DOCUMENT_PATTERNS) + local test_score, test_kw = match_patterns(text, TEST_PATTERNS) + local proj_score, _ = match_patterns(text, PROJECT_CONTEXT_PATTERNS) + local explore_score, _ = match_patterns(text, EXPLORE_PATTERNS) + + -- Find the winner by raw score (highest accumulated weight) + local scores = { + { type = "ask", score = ask_score, keywords = ask_kw }, + { type = "generate", score = gen_score, keywords = gen_kw }, + { type = "refactor", score = ref_score, keywords = ref_kw }, + { type = "document", score = doc_score, keywords = doc_kw }, + { type = "test", score = test_score, keywords = test_kw }, + } + + table.sort(scores, function(a, b) + return a.score > b.score + end) + + local winner = scores[1] + + -- If top score is very low, default to ask (safer for Q&A) + if winner.score < 0.3 then + winner = { type = "ask", score = 0.5, keywords = {} } + end + + -- If ask and generate are close AND there's a question mark, prefer ask + if winner.type == "generate" and ask_score > 0 then + if text:match("%?%s*$") and ask_score >= gen_score * 0.5 then + winner = { type = "ask", score = ask_score, keywords = ask_kw } + end + end + + -- Determine if "explain" vs "ask" (explain needs more context) + local intent_type = winner.type + if intent_type == "ask" then + -- "explain" if asking about how something works, otherwise "ask" + if text:match("explain") or text:match("how%s+does") or text:match("walk%s+me%s+through") then + intent_type = "explain" + end + end + + -- Normalize confidence to 0-1 range (cap at reasonable max) + local confidence = math.min(winner.score / 2, 1.0) + + 
-- Check if exploration is needed (full project indexing) + local needs_exploration = explore_score >= 0.9 + + ---@type Intent + local intent = { + type = intent_type, + confidence = confidence, + needs_project_context = proj_score > 0.5 or needs_exploration, + needs_brain_context = intent_type == "ask" or intent_type == "explain", + needs_exploration = needs_exploration, + keywords = winner.keywords, + } + + return intent +end + +--- Get prompt type for system prompt selection +---@param intent Intent Detected intent +---@return string Prompt type for prompts.system +function M.get_prompt_type(intent) + local mapping = { + ask = "ask", + explain = "ask", -- Uses same prompt as ask + generate = "code_generation", + refactor = "refactor", + document = "document", + test = "test", + } + return mapping[intent.type] or "ask" +end + +--- Check if intent requires code output +---@param intent Intent +---@return boolean +function M.produces_code(intent) + local code_intents = { + generate = true, + refactor = true, + document = true, -- Documentation is code (comments) + test = true, + } + return code_intents[intent.type] or false +end + +return M diff --git a/lua/codetyper/autocmds.lua b/lua/codetyper/autocmds.lua index cf5d06b..26123f3 100644 --- a/lua/codetyper/autocmds.lua +++ b/lua/codetyper/autocmds.lua @@ -155,10 +155,28 @@ function M.setup() if filepath:match("%.coder%.") or filepath:match("tree%.log$") then return end + -- Skip non-project files + if filepath:match("node_modules") or filepath:match("%.git/") or filepath:match("%.coder/") then + return + end -- Schedule tree update with debounce schedule_tree_update() + + -- Trigger incremental indexing if enabled + local ok_indexer, indexer = pcall(require, "codetyper.indexer") + if ok_indexer then + indexer.schedule_index_file(filepath) + end + + -- Update brain with file patterns + local ok_brain, brain = pcall(require, "codetyper.brain") + if ok_brain and brain.is_initialized and brain.is_initialized() then + 
vim.defer_fn(function() + M.update_brain_from_file(filepath) + end, 500) -- Debounce brain updates + end end, - desc = "Update tree.log on file creation/save", + desc = "Update tree.log, index, and brain on file creation/save", }) -- Update tree.log when files are deleted (via netrw or file explorer) @@ -186,6 +204,19 @@ function M.setup() desc = "Update tree.log on directory change", }) + -- Shutdown brain on Vim exit + vim.api.nvim_create_autocmd("VimLeavePre", { + group = group, + pattern = "*", + callback = function() + local ok, brain = pcall(require, "codetyper.brain") + if ok and brain.is_initialized and brain.is_initialized() then + brain.shutdown() + end + end, + desc = "Shutdown brain and flush pending changes", + }) + -- Auto-index: Create/open coder companion file when opening source files vim.api.nvim_create_autocmd("BufEnter", { group = group, @@ -211,7 +242,7 @@ local function get_config_safe() open_tag = "/@", close_tag = "@/", file_pattern = "*.coder.*", - } + }, } end return config @@ -400,13 +431,11 @@ function M.check_for_closed_prompt() attached_files = attached_files, }) - local scope_info = scope and scope.type ~= "file" - and string.format(" [%s: %s]", scope.type, scope.name or "anonymous") + local scope_info = scope + and scope.type ~= "file" + and string.format(" [%s: %s]", scope.type, scope.name or "anonymous") or "" - utils.notify( - string.format("Prompt queued: %s%s", intent.type, scope_info), - vim.log.levels.INFO - ) + utils.notify(string.format("Prompt queued: %s%s", intent.type, scope_info), vim.log.levels.INFO) end) else -- Legacy: direct processing @@ -555,13 +584,11 @@ function M.check_all_prompts() attached_files = attached_files, }) - local scope_info = scope and scope.type ~= "file" - and string.format(" [%s: %s]", scope.type, scope.name or "anonymous") + local scope_info = scope + and scope.type ~= "file" + and string.format(" [%s: %s]", scope.type, scope.name or "anonymous") or "" - utils.notify( - string.format("Prompt 
queued: %s%s", intent.type, scope_info), - vim.log.levels.INFO - ) + utils.notify(string.format("Prompt queued: %s%s", intent.type, scope_info), vim.log.levels.INFO) end) ::continue:: @@ -822,15 +849,138 @@ function M.clear() vim.api.nvim_del_augroup_by_name(AUGROUP) end +--- Debounce timers for brain updates per file +---@type table +local brain_update_timers = {} + +--- Update brain with patterns from a file +---@param filepath string +function M.update_brain_from_file(filepath) + local ok_brain, brain = pcall(require, "codetyper.brain") + if not ok_brain or not brain.is_initialized() then + return + end + + -- Read file content + local content = utils.read_file(filepath) + if not content or content == "" then + return + end + + local ext = vim.fn.fnamemodify(filepath, ":e") + local lines = vim.split(content, "\n") + + -- Extract key patterns from the file + local functions = {} + local classes = {} + local imports = {} + + for i, line in ipairs(lines) do + -- Functions + local func = line:match("^%s*function%s+([%w_:%.]+)%s*%(") + or line:match("^%s*local%s+function%s+([%w_]+)%s*%(") + or line:match("^%s*def%s+([%w_]+)%s*%(") + or line:match("^%s*func%s+([%w_]+)%s*%(") + or line:match("^%s*async%s+function%s+([%w_]+)%s*%(") + or line:match("^%s*public%s+.*%s+([%w_]+)%s*%(") + or line:match("^%s*private%s+.*%s+([%w_]+)%s*%(") + if func then + table.insert(functions, { name = func, line = i }) + end + + -- Classes + local class = line:match("^%s*class%s+([%w_]+)") + or line:match("^%s*public%s+class%s+([%w_]+)") + or line:match("^%s*interface%s+([%w_]+)") + or line:match("^%s*struct%s+([%w_]+)") + if class then + table.insert(classes, { name = class, line = i }) + end + + -- Imports + local imp = line:match("import%s+.*%s+from%s+[\"']([^\"']+)[\"']") + or line:match("require%([\"']([^\"']+)[\"']%)") + or line:match("from%s+([%w_.]+)%s+import") + if imp then + table.insert(imports, imp) + end + end + + -- Only store if file has meaningful content + if #functions == 
0 and #classes == 0 then + return + end + + -- Build summary + local parts = {} + if #functions > 0 then + local func_names = {} + for i, f in ipairs(functions) do + if i <= 5 then + table.insert(func_names, f.name) + end + end + table.insert(parts, "functions: " .. table.concat(func_names, ", ")) + end + if #classes > 0 then + local class_names = {} + for _, c in ipairs(classes) do + table.insert(class_names, c.name) + end + table.insert(parts, "classes: " .. table.concat(class_names, ", ")) + end + + local summary = vim.fn.fnamemodify(filepath, ":t") .. " - " .. table.concat(parts, "; ") + + -- Learn this pattern + brain.learn({ + type = "pattern", + file = filepath, + content = { + summary = summary, + detail = #functions .. " functions, " .. #classes .. " classes", + code = nil, + }, + context = { + file = filepath, + language = ext, + functions = functions, + classes = classes, + }, + }) +end + --- Track buffers that have been auto-indexed ---@type table local auto_indexed_buffers = {} --- Supported file extensions for auto-indexing local supported_extensions = { - "ts", "tsx", "js", "jsx", "py", "lua", "go", "rs", "rb", - "java", "c", "cpp", "cs", "json", "yaml", "yml", "md", - "html", "css", "scss", "vue", "svelte", "php", "sh", "zsh", + "ts", + "tsx", + "js", + "jsx", + "py", + "lua", + "go", + "rs", + "rb", + "java", + "c", + "cpp", + "cs", + "json", + "yaml", + "yml", + "md", + "html", + "css", + "scss", + "vue", + "svelte", + "php", + "sh", + "zsh", } --- Check if extension is supported @@ -968,14 +1118,23 @@ function M.open_coder_companion(open_split) %s @/%s ]], - comment_prefix, filename, close_comment, - comment_prefix, close_comment, - comment_prefix, close_comment, - comment_prefix, close_comment, - comment_prefix, close_comment, - comment_prefix, close_comment, - comment_prefix, close_comment, - comment_prefix, close_comment + comment_prefix, + filename, + close_comment, + comment_prefix, + close_comment, + comment_prefix, + close_comment, + 
comment_prefix, + close_comment, + comment_prefix, + close_comment, + comment_prefix, + close_comment, + comment_prefix, + close_comment, + comment_prefix, + close_comment ) utils.write_file(coder_path, template) end diff --git a/lua/codetyper/brain/delta/commit.lua b/lua/codetyper/brain/delta/commit.lua new file mode 100644 index 0000000..9afd60d --- /dev/null +++ b/lua/codetyper/brain/delta/commit.lua @@ -0,0 +1,291 @@ +--- Brain Delta Commit Operations +--- Git-like commit creation and management + +local storage = require("codetyper.brain.storage") +local hash_mod = require("codetyper.brain.hash") +local diff_mod = require("codetyper.brain.delta.diff") +local types = require("codetyper.brain.types") + +local M = {} + +--- Create a new delta commit +---@param changes table[] Changes to commit +---@param message string Commit message +---@param trigger? string Trigger source +---@return Delta|nil Created delta +function M.create(changes, message, trigger) + if not changes or #changes == 0 then + return nil + end + + local now = os.time() + local head = storage.get_head() + + -- Create delta object + local delta = { + h = hash_mod.delta_hash(changes, head, now), + p = head, + ts = now, + ch = {}, + m = { + msg = message or "Unnamed commit", + trig = trigger or "manual", + }, + } + + -- Process changes + for _, change in ipairs(changes) do + table.insert(delta.ch, { + op = change.op, + path = change.path, + bh = change.bh, + ah = change.ah, + diff = change.diff, + }) + end + + -- Save delta + storage.save_delta(delta) + + -- Update HEAD + storage.set_head(delta.h) + + -- Update meta + local meta = storage.get_meta() + storage.update_meta({ dc = meta.dc + 1 }) + + return delta +end + +--- Get a delta by hash +---@param delta_hash string Delta hash +---@return Delta|nil +function M.get(delta_hash) + return storage.get_delta(delta_hash) +end + +--- Get the current HEAD delta +---@return Delta|nil +function M.get_head() + local head_hash = storage.get_head() + if not 
head_hash then + return nil + end + return M.get(head_hash) +end + +--- Get delta history (ancestry chain) +---@param limit? number Max entries +---@param from_hash? string Starting hash (default: HEAD) +---@return Delta[] +function M.get_history(limit, from_hash) + limit = limit or 50 + local history = {} + local current_hash = from_hash or storage.get_head() + + while current_hash and #history < limit do + local delta = M.get(current_hash) + if not delta then + break + end + + table.insert(history, delta) + current_hash = delta.p + end + + return history +end + +--- Check if a delta exists +---@param delta_hash string Delta hash +---@return boolean +function M.exists(delta_hash) + return M.get(delta_hash) ~= nil +end + +--- Get the path from one delta to another +---@param from_hash string Start delta hash +---@param to_hash string End delta hash +---@return Delta[]|nil Path of deltas, or nil if no path +function M.get_path(from_hash, to_hash) + -- Build ancestry from both sides + local from_ancestry = {} + local current = from_hash + while current do + from_ancestry[current] = true + local delta = M.get(current) + if not delta then + break + end + current = delta.p + end + + -- Walk from to_hash back to find common ancestor + local path = {} + current = to_hash + while current do + local delta = M.get(current) + if not delta then + break + end + + table.insert(path, 1, delta) + + if from_ancestry[current] then + -- Found common ancestor + return path + end + + current = delta.p + end + + return nil +end + +--- Get all changes between two deltas +---@param from_hash string|nil Start delta hash (nil = beginning) +---@param to_hash string End delta hash +---@return table[] Combined changes +function M.get_changes_between(from_hash, to_hash) + local path = {} + local current = to_hash + + while current and current ~= from_hash do + local delta = M.get(current) + if not delta then + break + end + table.insert(path, 1, delta) + current = delta.p + end + + -- Collect 
all changes + local changes = {} + for _, delta in ipairs(path) do + for _, change in ipairs(delta.ch) do + table.insert(changes, change) + end + end + + return changes +end + +--- Compute reverse changes for rollback +---@param delta Delta Delta to reverse +---@return table[] Reverse changes +function M.compute_reverse(delta) + local reversed = {} + + for i = #delta.ch, 1, -1 do + local change = delta.ch[i] + local rev = { + path = change.path, + } + + if change.op == types.DELTA_OPS.ADD then + rev.op = types.DELTA_OPS.DELETE + rev.bh = change.ah + elseif change.op == types.DELTA_OPS.DELETE then + rev.op = types.DELTA_OPS.ADD + rev.ah = change.bh + elseif change.op == types.DELTA_OPS.MODIFY then + rev.op = types.DELTA_OPS.MODIFY + rev.bh = change.ah + rev.ah = change.bh + if change.diff then + rev.diff = diff_mod.reverse(change.diff) + end + end + + table.insert(reversed, rev) + end + + return reversed +end + +--- Squash multiple deltas into one +---@param delta_hashes string[] Delta hashes to squash +---@param message string Squash commit message +---@return Delta|nil Squashed delta +function M.squash(delta_hashes, message) + if #delta_hashes == 0 then + return nil + end + + -- Collect all changes in order + local all_changes = {} + for _, delta_hash in ipairs(delta_hashes) do + local delta = M.get(delta_hash) + if delta then + for _, change in ipairs(delta.ch) do + table.insert(all_changes, change) + end + end + end + + -- Compact the changes + local compacted = diff_mod.compact(all_changes) + + return M.create(compacted, message, "squash") +end + +--- Get summary of a delta +---@param delta Delta Delta to summarize +---@return table Summary +function M.summarize(delta) + local adds = 0 + local mods = 0 + local dels = 0 + local paths = {} + + for _, change in ipairs(delta.ch) do + if change.op == types.DELTA_OPS.ADD then + adds = adds + 1 + elseif change.op == types.DELTA_OPS.MODIFY then + mods = mods + 1 + elseif change.op == types.DELTA_OPS.DELETE then + dels 
= dels + 1 + end + + -- Extract category from path + local parts = vim.split(change.path, ".", { plain = true }) + if parts[1] then + paths[parts[1]] = true + end + end + + return { + hash = delta.h, + parent = delta.p, + timestamp = delta.ts, + message = delta.m.msg, + trigger = delta.m.trig, + stats = { + adds = adds, + modifies = mods, + deletes = dels, + total = adds + mods + dels, + }, + categories = vim.tbl_keys(paths), + } +end + +--- Format delta for display +---@param delta Delta Delta to format +---@return string[] Lines +function M.format(delta) + local summary = M.summarize(delta) + local lines = { + string.format("commit %s", delta.h), + string.format("Date: %s", os.date("%Y-%m-%d %H:%M:%S", delta.ts)), + string.format("Parent: %s", delta.p or "(none)"), + "", + " " .. (delta.m.msg or "No message"), + "", + string.format(" %d additions, %d modifications, %d deletions", summary.stats.adds, summary.stats.modifies, summary.stats.deletes), + } + + return lines +end + +return M diff --git a/lua/codetyper/brain/delta/diff.lua b/lua/codetyper/brain/delta/diff.lua new file mode 100644 index 0000000..b1d63d6 --- /dev/null +++ b/lua/codetyper/brain/delta/diff.lua @@ -0,0 +1,261 @@ +--- Brain Delta Diff Computation +--- Field-level diff algorithms for delta versioning + +local hash = require("codetyper.brain.hash") + +local M = {} + +--- Compute diff between two values +---@param before any Before value +---@param after any After value +---@param path? 
string Current path
---@return table[] Diff entries
function M.compute(before, after, path)
  path = path or ""
  local diffs = {}

  local before_type = type(before)
  local after_type = type(after)

  -- Handle nil cases
  if before == nil and after == nil then
    return diffs
  end

  if before == nil then
    table.insert(diffs, {
      path = path,
      op = "add",
      value = after,
    })
    return diffs
  end

  if after == nil then
    table.insert(diffs, {
      path = path,
      op = "delete",
      value = before,
    })
    return diffs
  end

  -- Type change: treat as a whole-value replacement
  if before_type ~= after_type then
    table.insert(diffs, {
      path = path,
      op = "replace",
      from = before,
      to = after,
    })
    return diffs
  end

  -- Tables (recursive): diff the union of keys from both sides
  if before_type == "table" then
    local keys = {}
    for k in pairs(before) do
      keys[k] = true
    end
    for k in pairs(after) do
      keys[k] = true
    end

    for k in pairs(keys) do
      local sub_path = path == "" and tostring(k) or (path .. "." .. tostring(k))
      local sub_diffs = M.compute(before[k], after[k], sub_path)
      for _, d in ipairs(sub_diffs) do
        table.insert(diffs, d)
      end
    end

    return diffs
  end

  -- Primitive comparison
  if before ~= after then
    table.insert(diffs, {
      path = path,
      op = "replace",
      from = before,
      to = after,
    })
  end

  return diffs
end

--- Apply a diff to a value
---@param base any Base value
---@param diffs table[] Diff entries
---@return any Result value
function M.apply(base, diffs)
  local result = vim.deepcopy(base) or {}

  for _, diff in ipairs(diffs) do
    M.apply_single(result, diff)
  end

  return result
end

--- Apply a single diff entry
---@param target table Target table
---@param diff table Diff entry
function M.apply_single(target, diff)
  local path = diff.path
  local parts = vim.split(path, ".", { plain = true })

  if #parts == 0 or parts[1] == "" then
    -- Root-level change: merge the new value's keys into the target
    if diff.op == "add" or diff.op == "replace" then
      for k, v in pairs(diff.value or diff.to or {}) do
        target[k] = v
      end
    end
    return
  end

  -- Navigate to parent, creating intermediate tables as needed
  local current = target
  for i = 1, #parts - 1 do
    local key = parts[i]
    -- Path segments that look numeric address array indices
    local num_key = tonumber(key)
    key = num_key or key

    if current[key] == nil then
      current[key] = {}
    end
    current = current[key]
  end

  -- Apply to final key
  local final_key = parts[#parts]
  local num_key = tonumber(final_key)
  final_key = num_key or final_key

  if diff.op == "add" then
    current[final_key] = diff.value
  elseif diff.op == "delete" then
    current[final_key] = nil
  elseif diff.op == "replace" then
    current[final_key] = diff.to
  end
end

--- Reverse a diff (for rollback)
---@param diffs table[] Diff entries
---@return table[] Reversed diffs
function M.reverse(diffs)
  local reversed = {}

  for i = #diffs, 1, -1 do
    local diff = diffs[i]
    local rev = {
      path = diff.path,
    }

    if diff.op == "add" then
      rev.op = "delete"
      rev.value = diff.value
    elseif diff.op == "delete" then
      rev.op = "add"
      rev.value = diff.value
    elseif diff.op == "replace" then
      rev.op = "replace"
      rev.from = diff.to
      rev.to = diff.from
    end

    table.insert(reversed, rev)
  end

  return reversed
end

--- Compact diffs (combine related changes per path).
---
--- FIX: the previous filter dropped any entry with `op == "delete"` and no
--- `from` field. Plain delete diffs carry `value` (never `from`), so every
--- standalone delete was silently discarded. Only an add-then-delete pair for
--- the same path (a net no-op within this diff set) should be dropped, which
--- is now tracked explicitly. Also, an add followed by a replace now keeps
--- `op = "add"` with the newer value, so applying the compacted diff yields
--- the final state instead of the stale initial one.
---@param diffs table[] Diff entries
---@return table[] Compacted diffs
function M.compact(diffs)
  local by_path = {}
  -- Paths whose FIRST entry in this diff set was an "add"
  local added_here = {}

  for _, diff in ipairs(diffs) do
    local existing = by_path[diff.path]
    if existing then
      if diff.op == "replace" then
        existing.to = diff.to
        if existing.op == "add" then
          -- add then replace is still an add, but with the newer value
          existing.value = diff.to
        end
      elseif diff.op == "delete" then
        existing.op = "delete"
        existing.to = nil
      end
    else
      by_path[diff.path] = vim.deepcopy(diff)
      added_here[diff.path] = (diff.op == "add")
    end
  end

  -- Convert back to array, dropping add-then-delete no-ops
  local result = {}
  for path, diff in pairs(by_path) do
    if not (diff.op == "delete" and added_here[path]) then
      table.insert(result, diff)
    end
  end

  return result
end

--- Create a minimal diff summary for storage
---@param diffs table[] Diff entries
---@return table Summary
function M.summarize(diffs)
  local adds = 0
  local deletes = 0
  local replaces = 0
  local paths = {}

  for _, diff in ipairs(diffs) do
    if diff.op == "add" then
      adds = adds + 1
    elseif diff.op == "delete" then
      deletes = deletes + 1
    elseif diff.op == "replace" then
      replaces = replaces + 1
    end

    -- Extract top-level path segment as the category
    local parts = vim.split(diff.path, ".", { plain = true })
    if parts[1] then
      paths[parts[1]] = true
    end
  end

  return {
    adds = adds,
    deletes = deletes,
    replaces = replaces,
    paths = vim.tbl_keys(paths),
    total = adds + deletes + replaces,
  }
end

--- Check if two states are equal (no diff)
---@param state1 any First state
---@param state2 any Second state
---@return boolean
function M.equals(state1, state2)
  local diffs = M.compute(state1, state2)
  return #diffs == 0
end

--- Get hash of diff for deduplication
---@param diffs table[] Diff entries
---@return string Hash
function M.hash(diffs)
  return hash.compute_table(diffs)
end

return M

--- Brain Delta Coordinator
--- Git-like versioning system for brain state

local storage = require("codetyper.brain.storage")
local commit_mod = require("codetyper.brain.delta.commit")
local diff_mod = require("codetyper.brain.delta.diff")
local types = require("codetyper.brain.types")

local M = {}

-- Re-export submodules.
-- FIX: `M.commit = commit_mod` was removed. The `function M.commit(...)`
-- defined below immediately overwrote it, so the re-export never took effect
-- and only misled readers into thinking `delta.commit.get(...)` works.
M.diff = diff_mod

--- Create a commit from pending graph changes
---@param message string Commit message
---@param trigger?
string Trigger source +---@return string|nil Delta hash +function M.commit(message, trigger) + local graph = require("codetyper.brain.graph") + local changes = graph.get_pending_changes() + + if #changes == 0 then + return nil + end + + local delta = commit_mod.create(changes, message, trigger or "auto") + if delta then + return delta.h + end + + return nil +end + +--- Rollback to a specific delta +---@param target_hash string Target delta hash +---@return boolean Success +function M.rollback(target_hash) + local current_hash = storage.get_head() + if not current_hash then + return false + end + + if current_hash == target_hash then + return true -- Already at target + end + + -- Get path from target to current + local deltas_to_reverse = {} + local current = current_hash + + while current and current ~= target_hash do + local delta = commit_mod.get(current) + if not delta then + return false -- Broken chain + end + table.insert(deltas_to_reverse, delta) + current = delta.p + end + + if current ~= target_hash then + return false -- Target not in ancestry + end + + -- Apply reverse changes + for _, delta in ipairs(deltas_to_reverse) do + local reverse_changes = commit_mod.compute_reverse(delta) + M.apply_changes(reverse_changes) + end + + -- Update HEAD + storage.set_head(target_hash) + + -- Create a rollback commit + commit_mod.create({ + { + op = types.DELTA_OPS.MODIFY, + path = "meta.head", + bh = current_hash, + ah = target_hash, + }, + }, "Rollback to " .. 
target_hash:sub(1, 8), "rollback") + + return true +end + +--- Apply changes to current state +---@param changes table[] Changes to apply +function M.apply_changes(changes) + local node_mod = require("codetyper.brain.graph.node") + + for _, change in ipairs(changes) do + local parts = vim.split(change.path, ".", { plain = true }) + + if parts[1] == "nodes" and #parts >= 3 then + local node_type = parts[2] + local node_id = parts[3] + + if change.op == types.DELTA_OPS.ADD then + -- Node was added, need to delete for reverse + node_mod.delete(node_id) + elseif change.op == types.DELTA_OPS.DELETE then + -- Node was deleted, would need original data to restore + -- This is a limitation - we'd need content storage + elseif change.op == types.DELTA_OPS.MODIFY then + -- Apply diff if available + if change.diff then + local node = node_mod.get(node_id) + if node then + local updated = diff_mod.apply(node, change.diff) + -- Direct update without tracking + local nodes = storage.get_nodes(node_type) + nodes[node_id] = updated + storage.save_nodes(node_type, nodes) + end + end + end + elseif parts[1] == "graph" then + -- Handle graph/edge changes + local edge_mod = require("codetyper.brain.graph.edge") + if parts[2] == "edges" and #parts >= 3 then + local edge_id = parts[3] + if change.op == types.DELTA_OPS.ADD then + -- Edge was added, delete for reverse + -- Parse edge_id to get source/target + local graph = storage.get_graph() + if graph.edges and graph.edges[edge_id] then + local edge = graph.edges[edge_id] + edge_mod.delete(edge.s, edge.t, edge.ty) + end + end + end + end + end +end + +--- Get delta history +---@param limit? number Max entries +---@return Delta[] +function M.get_history(limit) + return commit_mod.get_history(limit) +end + +--- Get formatted log +---@param limit? 
number Max entries +---@return string[] Log lines +function M.log(limit) + local history = M.get_history(limit or 20) + local lines = {} + + for _, delta in ipairs(history) do + local formatted = commit_mod.format(delta) + for _, line in ipairs(formatted) do + table.insert(lines, line) + end + table.insert(lines, "") + end + + return lines +end + +--- Get current HEAD hash +---@return string|nil +function M.head() + return storage.get_head() +end + +--- Check if there are uncommitted changes +---@return boolean +function M.has_pending() + local graph = require("codetyper.brain.graph") + local node_pending = require("codetyper.brain.graph.node").pending + local edge_pending = require("codetyper.brain.graph.edge").pending + return #node_pending > 0 or #edge_pending > 0 +end + +--- Get status (like git status) +---@return table Status info +function M.status() + local node_pending = require("codetyper.brain.graph.node").pending + local edge_pending = require("codetyper.brain.graph.edge").pending + + local adds = 0 + local mods = 0 + local dels = 0 + + for _, change in ipairs(node_pending) do + if change.op == types.DELTA_OPS.ADD then + adds = adds + 1 + elseif change.op == types.DELTA_OPS.MODIFY then + mods = mods + 1 + elseif change.op == types.DELTA_OPS.DELETE then + dels = dels + 1 + end + end + + for _, change in ipairs(edge_pending) do + if change.op == types.DELTA_OPS.ADD then + adds = adds + 1 + elseif change.op == types.DELTA_OPS.DELETE then + dels = dels + 1 + end + end + + return { + head = storage.get_head(), + pending = { + adds = adds, + modifies = mods, + deletes = dels, + total = adds + mods + dels, + }, + clean = (adds + mods + dels) == 0, + } +end + +--- Prune old deltas +---@param keep number Number of recent deltas to keep +---@return number Number of pruned deltas +function M.prune_history(keep) + keep = keep or 100 + local history = M.get_history(1000) -- Get all + + if #history <= keep then + return 0 + end + + local pruned = 0 + local brain_dir 
--- Reset the brain to its initial, empty state (dangerous!).
--- Wipes all node collections, the graph, every index, the counters in
--- meta, and any uncommitted pending changes.
---@return boolean Success
function M.reset()
  -- Empty every node collection ("patterns", "corrections", ...)
  for _, node_type in pairs(types.NODE_TYPES) do
    storage.save_nodes(node_type .. "s", {})
  end

  -- Empty graph structure
  storage.save_graph({ adj = {}, radj = {}, edges = {} })

  -- Drop all secondary indices
  for _, index_name in ipairs({ "by_file", "by_time", "by_symbol" }) do
    storage.save_index(index_name, {})
  end

  -- Reset counters.
  -- NOTE(review): `head = nil` in a table constructor simply omits the key,
  -- so update_meta never receives it — confirm HEAD is actually cleared.
  storage.update_meta({
    head = nil,
    nc = 0,
    ec = 0,
    dc = 0,
  })

  -- Discard uncommitted delta tracking
  require("codetyper.brain.graph.node").pending = {}
  require("codetyper.brain.graph.edge").pending = {}

  storage.flush_all()
  return true
end
--- Create a new edge between two nodes.
--- If an edge of the same type already exists between the pair, it is
--- strengthened instead of duplicated.
---@param source_id string Source node ID
---@param target_id string Target node ID
---@param edge_type EdgeType Edge type
---@param props? EdgeProps Edge properties
---@return Edge|nil Created (or strengthened) edge
function M.create(source_id, target_id, edge_type, props)
  local opts = props or {}

  -- NOTE(review): edge ids derive from (source, target) only, so edges of
  -- different types between the same pair share one slot in graph.edges —
  -- confirm that this overwrite is intended.
  local new_edge = {
    id = hash.edge_id(source_id, target_id),
    s = source_id,
    t = target_id,
    ty = edge_type,
    p = {
      w = opts.w or 0.5,
      dir = opts.dir or "bi",
      r = opts.r,
    },
    ts = os.time(),
  }

  local graph = storage.get_graph()

  -- Ensure forward adjacency buckets exist
  local fwd = graph.adj[source_id] or {}
  graph.adj[source_id] = fwd
  fwd[edge_type] = fwd[edge_type] or {}

  -- Duplicate edge: bump its weight rather than re-adding it
  if vim.tbl_contains(fwd[edge_type], target_id) then
    return M.strengthen(source_id, target_id, edge_type)
  end

  table.insert(fwd[edge_type], target_id)

  -- Mirror into reverse adjacency
  local rev = graph.radj[target_id] or {}
  graph.radj[target_id] = rev
  rev[edge_type] = rev[edge_type] or {}
  table.insert(rev[edge_type], source_id)

  -- Edge payload (weight/metadata) lives in a separate map
  graph.edges = graph.edges or {}
  graph.edges[new_edge.id] = new_edge

  storage.save_graph(graph)

  -- Bump the global edge counter
  storage.update_meta({ ec = storage.get_meta().ec + 1 })

  -- Record the addition for delta tracking
  table.insert(M.pending, {
    op = types.DELTA_OPS.ADD,
    path = "graph.edges." .. new_edge.id,
    ah = hash.compute_table(new_edge),
  })

  return new_edge
end
--- Collect the edge records attached to a node.
---@param node_id string Node ID
---@param edge_types? EdgeType[] Edge types to include (default: all)
---@param direction? "out"|"in"|"both" Direction (default: "out")
---@return Edge[]
function M.get_edges(node_id, edge_types, direction)
  direction = direction or "out"
  local graph = storage.get_graph()
  local wanted = edge_types or vim.tbl_values(types.EDGE_TYPES)
  local found = {}

  -- Append the stored edge record for (src, dst) when it exists.
  local function collect(src, dst)
    local id = hash.edge_id(src, dst)
    local record = graph.edges and graph.edges[id]
    if record then
      found[#found + 1] = record
    end
  end

  -- Outgoing: follow the forward adjacency lists
  if direction == "out" or direction == "both" then
    local adj = graph.adj[node_id]
    if adj then
      for _, ty in ipairs(wanted) do
        for _, target_id in ipairs(adj[ty] or {}) do
          collect(node_id, target_id)
        end
      end
    end
  end

  -- Incoming: follow the reverse adjacency lists
  if direction == "in" or direction == "both" then
    local radj = graph.radj[node_id]
    if radj then
      for _, ty in ipairs(wanted) do
        for _, source_id in ipairs(radj[ty] or {}) do
          collect(source_id, node_id)
        end
      end
    end
  end

  return found
end
"out"|"in"|"both" Direction +---@return string[] Neighbor node IDs +function M.get_neighbors(node_id, edge_types, direction) + direction = direction or "out" + local graph = storage.get_graph() + local neighbors = {} + + edge_types = edge_types or vim.tbl_values(types.EDGE_TYPES) + + -- Outgoing + if direction == "out" or direction == "both" then + local adj = graph.adj[node_id] + if adj then + for _, edge_type in ipairs(edge_types) do + for _, target in ipairs(adj[edge_type] or {}) do + if not vim.tbl_contains(neighbors, target) then + table.insert(neighbors, target) + end + end + end + end + end + + -- Incoming + if direction == "in" or direction == "both" then + local radj = graph.radj[node_id] + if radj then + for _, edge_type in ipairs(edge_types) do + for _, source in ipairs(radj[edge_type] or {}) do + if not vim.tbl_contains(neighbors, source) then + table.insert(neighbors, source) + end + end + end + end + end + + return neighbors +end + +--- Delete an edge +---@param source_id string Source node ID +---@param target_id string Target node ID +---@param edge_type? 
--- Delete an edge.
---@param source_id string Source node ID
---@param target_id string Target node ID
---@param edge_type? EdgeType Edge type (matches any stored type when nil)
---@return boolean Success
function M.delete(source_id, target_id, edge_type)
  local graph = storage.get_graph()
  local edge_id = hash.edge_id(source_id, target_id)
  local edge = graph.edges and graph.edges[edge_id]

  -- Nothing stored, or the stored edge is of a different type
  if not edge or (edge_type and edge.ty ~= edge_type) then
    return false
  end

  local before_hash = hash.compute_table(edge)

  -- Drop target from the source's forward adjacency list
  local fwd = graph.adj[source_id] and graph.adj[source_id][edge.ty]
  if fwd then
    graph.adj[source_id][edge.ty] = vim.tbl_filter(function(id)
      return id ~= target_id
    end, fwd)
  end

  -- Drop source from the target's reverse adjacency list
  local rev = graph.radj[target_id] and graph.radj[target_id][edge.ty]
  if rev then
    graph.radj[target_id][edge.ty] = vim.tbl_filter(function(id)
      return id ~= source_id
    end, rev)
  end

  -- Remove the edge record itself
  graph.edges[edge_id] = nil
  storage.save_graph(graph)

  -- Decrement the global edge counter (never below zero)
  storage.update_meta({ ec = math.max(0, storage.get_meta().ec - 1) })

  -- Record the deletion for delta tracking
  table.insert(M.pending, {
    op = types.DELTA_OPS.DELETE,
    path = "graph.edges." .. edge_id,
    bh = before_hash,
  })

  return true
end
--- Find a path between two nodes via breadth-first search.
---@param from_id string Start node ID
---@param to_id string End node ID
---@param max_depth? number Maximum depth (default: 5)
---@return table|nil Path info {nodes: string[], edges: Edge[], found: boolean}
function M.find_path(from_id, to_id, max_depth)
  max_depth = max_depth or 5

  -- FIFO queue; a moving head index avoids shifting the array on each pop.
  local queue = { { id = from_id, path = {}, edges = {} } }
  local head = 1
  local visited = { [from_id] = true }

  while head <= #queue do
    local current = queue[head]
    head = head + 1

    if current.id == to_id then
      table.insert(current.path, to_id)
      return { nodes = current.path, edges = current.edges, found = true }
    end

    if #current.path < max_depth then
      -- Expand through edges in both directions
      for _, edge in ipairs(M.get_edges(current.id, nil, "both")) do
        local neighbor = edge.s == current.id and edge.t or edge.s

        if not visited[neighbor] then
          visited[neighbor] = true

          -- Copy the trails so sibling branches stay independent
          local next_path = vim.list_extend({}, current.path)
          next_path[#next_path + 1] = current.id

          local next_edges = vim.list_extend({}, current.edges)
          next_edges[#next_edges + 1] = edge

          queue[#queue + 1] = { id = neighbor, path = next_path, edges = next_edges }
        end
      end
    end
  end

  return { nodes = {}, edges = {}, found = false }
end
--- Check whether two nodes are connected in either direction.
---@param node_id_1 string First node ID
---@param node_id_2 string Second node ID
---@param edge_type? EdgeType Edge type filter
---@return boolean
function M.are_connected(node_id_1, node_id_2, edge_type)
  -- Probe forward first; only fall back to the reverse direction if needed.
  return M.get(node_id_1, node_id_2, edge_type) ~= nil
    or M.get(node_id_2, node_id_1, edge_type) ~= nil
end
--- Add a learning node and automatically wire it into the graph.
--- Creates edges to caller-supplied related nodes and to semantically
--- similar existing nodes.
---@param node_type NodeType Node type
---@param content NodeContent Content
---@param context? NodeContext Context
---@param related_ids? string[] Related node IDs
---@return Node Created node
function M.add_learning(node_type, content, context, related_ids)
  local created = node.create(node_type, content, context)

  -- Explicit relations supplied by the caller
  for _, related_id in ipairs(related_ids or {}) do
    local other = node.get(related_id)
    if other then
      -- Same-file relations get a FILE edge; everything else is SEMANTIC
      local edge_type = types.EDGE_TYPES.SEMANTIC
      if context and context.f and other.ctx and other.ctx.f == context.f then
        edge_type = types.EDGE_TYPES.FILE
      end

      edge.create(created.id, related_id, edge_type, {
        w = 0.5,
        r = "Related learning",
      })
    end
  end

  -- Auto-link to existing nodes with sufficiently similar content
  for _, candidate in ipairs(query.semantic_search(content.s, 5)) do
    if candidate.id ~= created.id then
      local score = query.compute_relevance(candidate, { query = content.s })
      if score > 0.5 then
        edge.create(created.id, candidate.id, types.EDGE_TYPES.SEMANTIC, {
          w = score,
          r = "Semantic similarity",
        })
      end
    end
  end

  return created
end
--- Prune low-value nodes from the graph.
---@param opts? table Prune options {threshold?: number, unused_days?: number}
---@return number Number of pruned nodes
function M.prune(opts)
  opts = opts or {}
  local threshold = opts.threshold or 0.1
  local unused_days = opts.unused_days or 90
  local cutoff = os.time() - unused_days * 86400

  -- A node is prunable when it is both light and stale, or never used and old.
  local function is_prunable(n)
    if n.sc.w < threshold and (n.ts.lu or n.ts.up) < cutoff then
      return true
    end
    return n.sc.u == 0 and n.ts.cr < cutoff
  end

  local pruned = 0
  for _, node_type in pairs(types.NODE_TYPES) do
    -- min_weight = 0 fetches every node of this type
    local candidates = node.find({ types = { node_type }, min_weight = 0 })
    for _, n in ipairs(candidates) do
      if is_prunable(n) and M.remove_learning(n.id) then
        pruned = pruned + 1
      end
    end
  end

  return pruned
end
"s") + by_type[node_type] = vim.tbl_count(nodes) + end + + -- Count edges by type + local graph = storage.get_graph() + local edges_by_type = {} + if graph.edges then + for _, e in pairs(graph.edges) do + edges_by_type[e.ty] = (edges_by_type[e.ty] or 0) + 1 + end + end + + return { + node_count = meta.nc, + edge_count = meta.ec, + delta_count = meta.dc, + nodes_by_type = by_type, + edges_by_type = edges_by_type, + } +end + +--- Create temporal edge between nodes created in sequence +---@param node_ids string[] Node IDs in temporal order +function M.link_temporal(node_ids) + for i = 1, #node_ids - 1 do + edge.create(node_ids[i], node_ids[i + 1], types.EDGE_TYPES.TEMPORAL, { + w = 0.7, + dir = "fwd", + r = "Temporal sequence", + }) + end +end + +--- Create causal edge (this caused that) +---@param cause_id string Cause node ID +---@param effect_id string Effect node ID +---@param reason? string Reason description +function M.link_causal(cause_id, effect_id, reason) + edge.create(cause_id, effect_id, types.EDGE_TYPES.CAUSAL, { + w = 0.8, + dir = "fwd", + r = reason or "Caused by", + }) +end + +--- Mark a node as superseded by another +---@param old_id string Old node ID +---@param new_id string New node ID +function M.supersede(old_id, new_id) + edge.create(old_id, new_id, types.EDGE_TYPES.SUPERSEDES, { + w = 1.0, + dir = "fwd", + r = "Superseded by newer learning", + }) + + -- Reduce weight of old node + local old_node = node.get(old_id) + if old_node then + node.update(old_id, { + sc = { w = old_node.sc.w * 0.5 }, + }) + end +end + +return M diff --git a/lua/codetyper/brain/graph/node.lua b/lua/codetyper/brain/graph/node.lua new file mode 100644 index 0000000..eb30ec6 --- /dev/null +++ b/lua/codetyper/brain/graph/node.lua @@ -0,0 +1,403 @@ +--- Brain Graph Node Operations +--- CRUD operations for learning nodes + +local storage = require("codetyper.brain.storage") +local hash = require("codetyper.brain.hash") +local types = require("codetyper.brain.types") + +local 
--- Node type to storage-file mapping.
--- Keyed by both the NODE_TYPES enum values and the plural collection names.
local TYPE_MAP = {
  [types.NODE_TYPES.PATTERN] = "patterns",
  [types.NODE_TYPES.CORRECTION] = "corrections",
  [types.NODE_TYPES.DECISION] = "decisions",
  [types.NODE_TYPES.CONVENTION] = "conventions",
  [types.NODE_TYPES.FEEDBACK] = "feedback",
  [types.NODE_TYPES.SESSION] = "sessions",
}

-- The plural collection names also map to themselves for convenience.
for _, key in ipairs({ "patterns", "corrections", "decisions", "conventions", "feedback", "sessions" }) do
  TYPE_MAP[key] = key
end

--- Resolve the storage key for a node type.
--- Unknown types fall back to "patterns".
---@param node_type string Node type
---@return string Storage key
local function get_storage_key(node_type)
  return TYPE_MAP[node_type] or "patterns"
end
(content.d or "")), + c = { + s = content.s or "", + d = content.d or content.s or "", + code = content.code, + lang = content.lang, + }, + ctx = context or {}, + sc = { + w = opts.weight or 0.5, + u = 0, + sr = 1.0, + }, + ts = { + cr = now, + up = now, + lu = now, + }, + m = { + src = opts.source or types.SOURCES.AUTO, + v = 1, + }, + } + + -- Store node + local storage_key = get_storage_key(node_type) + local nodes = storage.get_nodes(storage_key) + nodes[node.id] = node + storage.save_nodes(storage_key, nodes) + + -- Update meta + local meta = storage.get_meta() + storage.update_meta({ nc = meta.nc + 1 }) + + -- Update indices + M.update_indices(node, "add") + + -- Track pending change + table.insert(M.pending, { + op = types.DELTA_OPS.ADD, + path = "nodes." .. storage_key .. "." .. node.id, + ah = node.h, + }) + + return node +end + +--- Get a node by ID +---@param node_id string Node ID +---@return Node|nil +function M.get(node_id) + -- Parse node type from ID (n___) + local parts = vim.split(node_id, "_") + if #parts < 3 then + return nil + end + + local node_type = parts[2] + local storage_key = get_storage_key(node_type) + local nodes = storage.get_nodes(storage_key) + + return nodes[node_id] +end + +--- Update a node +---@param node_id string Node ID +---@param updates table Partial updates +---@return Node|nil Updated node +function M.update(node_id, updates) + local node = M.get(node_id) + if not node then + return nil + end + + local before_hash = node.h + + -- Apply updates + if updates.c then + node.c = vim.tbl_deep_extend("force", node.c, updates.c) + end + if updates.ctx then + node.ctx = vim.tbl_deep_extend("force", node.ctx, updates.ctx) + end + if updates.sc then + node.sc = vim.tbl_deep_extend("force", node.sc, updates.sc) + end + + -- Update timestamps and hash + node.ts.up = os.time() + node.h = hash.compute((node.c.s or "") .. 
(node.c.d or "")) + node.m.v = (node.m.v or 0) + 1 + + -- Save + local storage_key = get_storage_key(node.t) + local nodes = storage.get_nodes(storage_key) + nodes[node_id] = node + storage.save_nodes(storage_key, nodes) + + -- Update indices if context changed + if updates.ctx then + M.update_indices(node, "update") + end + + -- Track pending change + table.insert(M.pending, { + op = types.DELTA_OPS.MODIFY, + path = "nodes." .. storage_key .. "." .. node_id, + bh = before_hash, + ah = node.h, + }) + + return node +end + +--- Delete a node +---@param node_id string Node ID +---@return boolean Success +function M.delete(node_id) + local node = M.get(node_id) + if not node then + return false + end + + local storage_key = get_storage_key(node.t) + local nodes = storage.get_nodes(storage_key) + + if not nodes[node_id] then + return false + end + + local before_hash = node.h + nodes[node_id] = nil + storage.save_nodes(storage_key, nodes) + + -- Update meta + local meta = storage.get_meta() + storage.update_meta({ nc = math.max(0, meta.nc - 1) }) + + -- Update indices + M.update_indices(node, "delete") + + -- Track pending change + table.insert(M.pending, { + op = types.DELTA_OPS.DELETE, + path = "nodes." .. storage_key .. "." .. 
--- Find nodes by criteria.
---@param criteria table Search criteria
---  {types?: string[], file?: string, min_weight?: number,
---   since?: number, query?: string, limit?: number}
---@return Node[] Matching nodes, sorted by weight scaled by recency
function M.find(criteria)
  local results = {}

  local node_types = criteria.types or vim.tbl_values(types.NODE_TYPES)

  for _, node_type in ipairs(node_types) do
    local storage_key = get_storage_key(node_type)
    local nodes = storage.get_nodes(storage_key)

    for _, node in pairs(nodes) do
      local matches = true

      -- Filter by file
      if criteria.file and node.ctx.f ~= criteria.file then
        matches = false
      end

      -- Filter by minimum weight
      if criteria.min_weight and node.sc.w < criteria.min_weight then
        matches = false
      end

      -- Filter by creation time
      if criteria.since and node.ts.cr < criteria.since then
        matches = false
      end

      -- Plain-text substring match against summary or detail
      if matches and criteria.query then
        local query_lower = criteria.query:lower()
        local summary_lower = (node.c.s or ""):lower()
        local detail_lower = (node.c.d or ""):lower()
        if not summary_lower:find(query_lower, 1, true) and not detail_lower:find(query_lower, 1, true) then
          matches = false
        end
      end

      if matches then
        table.insert(results, node)
      end
    end
  end

  -- Sort by relevance (weight * recency decay).
  -- FIX: snapshot the clock once. The previous comparator called os.time()
  -- on every comparison; if the second ticks mid-sort the ordering becomes
  -- inconsistent, and table.sort requires a consistent strict order
  -- (an inconsistent comparator can raise "invalid order function").
  local now = os.time()
  table.sort(results, function(a, b)
    local score_a = a.sc.w * (1 / (1 + (now - a.ts.lu) / 86400))
    local score_b = b.sc.w * (1 / (1 + (now - b.ts.lu) / 86400))
    return score_a > score_b
  end)

  -- Apply limit
  if criteria.limit and #results > criteria.limit then
    local limited = {}
    for i = 1, criteria.limit do
      limited[i] = results[i]
    end
    return limited
  end

  return results
end
--- Record that a node was used, updating its usage stats and success rate.
--- Usage updates are saved directly and are NOT tracked as pending deltas.
---@param node_id string Node ID
---@param success? boolean Whether the usage succeeded (nil = unknown)
function M.record_usage(node_id, success)
  local node = M.get(node_id)
  if not node then
    return
  end

  node.sc.u = node.sc.u + 1
  node.ts.lu = os.time()

  -- Fold this outcome into the running success rate
  if success ~= nil then
    local total = node.sc.u
    local prior_successes = node.sc.sr * (total - 1)
    node.sc.sr = (prior_successes + (success and 1 or 0)) / total
  end

  -- Frequently used nodes drift slightly upward in weight (capped at 1.0)
  if node.sc.u > 5 then
    node.sc.w = math.min(1.0, node.sc.w + 0.01)
  end

  -- Persist directly (no pending-change tracking for usage)
  local storage_key = get_storage_key(node.t)
  local nodes = storage.get_nodes(storage_key)
  nodes[node_id] = node
  storage.save_nodes(storage_key, nodes)
end
storage.get_index("by_time") + + if op == "delete" then + if by_time[day] then + by_time[day] = vim.tbl_filter(function(id) + return id ~= node.id + end, by_time[day]) + end + elseif op == "add" then + by_time[day] = by_time[day] or {} + if not vim.tbl_contains(by_time[day], node.id) then + table.insert(by_time[day], node.id) + end + end + + storage.save_index("by_time", by_time) +end + +--- Get pending changes and clear +---@return table[] Pending changes +function M.get_and_clear_pending() + local changes = M.pending + M.pending = {} + return changes +end + +--- Merge two similar nodes +---@param node_id_1 string First node ID +---@param node_id_2 string Second node ID (will be deleted) +---@return Node|nil Merged node +function M.merge(node_id_1, node_id_2) + local node1 = M.get(node_id_1) + local node2 = M.get(node_id_2) + + if not node1 or not node2 then + return nil + end + + -- Merge content (keep longer detail) + local merged_detail = #node1.c.d > #node2.c.d and node1.c.d or node2.c.d + + -- Merge scores (combine weights and usage) + local merged_weight = (node1.sc.w + node2.sc.w) / 2 + local merged_usage = node1.sc.u + node2.sc.u + + M.update(node_id_1, { + c = { d = merged_detail }, + sc = { w = merged_weight, u = merged_usage }, + }) + + -- Delete the second node + M.delete(node_id_2) + + return M.get(node_id_1) +end + +return M diff --git a/lua/codetyper/brain/graph/query.lua b/lua/codetyper/brain/graph/query.lua new file mode 100644 index 0000000..a2e9f9a --- /dev/null +++ b/lua/codetyper/brain/graph/query.lua @@ -0,0 +1,394 @@ +--- Brain Graph Query Engine +--- Multi-dimensional traversal and relevance scoring + +local storage = require("codetyper.brain.storage") +local types = require("codetyper.brain.types") + +local M = {} + +--- Lazy load dependencies to avoid circular requires +local function get_node_module() + return require("codetyper.brain.graph.node") +end + +local function get_edge_module() + return require("codetyper.brain.graph.edge") 
--- Jaccard-style similarity between two texts based on shared words.
--- Case-insensitive; words are maximal `%w+` runs.
---@param text1 string First text
---@param text2 string Second text
---@return number Similarity score (0-1): |intersection| / |union|
local function text_similarity(text1, text2)
  if not text1 or not text2 then
    return 0
  end

  -- Build word sets for both texts
  local set1, set2 = {}, {}
  for word in text1:lower():gmatch("%w+") do
    set1[word] = true
  end
  for word in text2:lower():gmatch("%w+") do
    set2[word] = true
  end

  -- Count shared words and the size of the union
  local shared, union = 0, 0
  for word in pairs(set1) do
    union = union + 1
    if set2[word] then
      shared = shared + 1
    end
  end
  for word in pairs(set2) do
    if not set1[word] then
      union = union + 1
    end
  end

  if union == 0 then
    return 0
  end

  return shared / union
end
--- Breadth-first traversal from seed nodes up to a given depth.
--- Follows edges in both directions; each node is resolved at most once.
---@param seed_ids string[] Starting node IDs
---@param depth number Traversal depth
---@param edge_types? EdgeType[] Edge types to follow
---@return table Discovered nodes indexed by ID
local function traverse(seed_ids, depth, edge_types)
  local node_mod = get_node_module()
  local edge_mod = get_edge_module()
  local discovered = {}
  local frontier = seed_ids

  for _ = 1, depth do
    local next_frontier = {}

    for _, node_id in ipairs(frontier) do
      -- Visit each node at most once
      if not discovered[node_id] then
        local node = node_mod.get(node_id)
        if node then
          discovered[node_id] = node

          -- Queue unexplored neighbors (both directions) for the next hop
          for _, neighbor_id in ipairs(edge_mod.get_neighbors(node_id, edge_types, "both")) do
            if not discovered[neighbor_id] then
              next_frontier[#next_frontier + 1] = neighbor_id
            end
          end
        end
      end
    end

    if #next_frontier == 0 then
      break
    end
    frontier = next_frontier
  end

  return discovered
end
Semantic traversal (content similarity) + if opts.query then + local seed_nodes = node_mod.find({ + query = opts.query, + types = opts.types, + limit = 10, + }) + + local seed_ids = vim.tbl_map(function(n) + return n.id + end, seed_nodes) + local depth = opts.depth or 2 + + local discovered = traverse(seed_ids, depth, { types.EDGE_TYPES.SEMANTIC }) + for id, node in pairs(discovered) do + results.semantic[id] = node + end + end + + -- 2. File-based traversal + if opts.file then + local by_file = storage.get_index("by_file") + local file_node_ids = by_file[opts.file] or {} + + for _, node_id in ipairs(file_node_ids) do + local node = node_mod.get(node_id) + if node then + results.file[node.id] = node + end + end + + -- Also get nodes from related files via edges + local discovered = traverse(file_node_ids, 1, { types.EDGE_TYPES.FILE }) + for id, node in pairs(discovered) do + results.file[id] = node + end + end + + -- 3. Temporal traversal (recent context) + if opts.since then + local by_time = storage.get_index("by_time") + local now = os.time() + + for day, node_ids in pairs(by_time) do + -- Parse day to timestamp + local year, month, day_num = day:match("(%d+)-(%d+)-(%d+)") + if year then + local day_ts = os.time({ year = tonumber(year), month = tonumber(month), day = tonumber(day_num) }) + if day_ts >= opts.since then + for _, node_id in ipairs(node_ids) do + local node = node_mod.get(node_id) + if node then + results.temporal[node.id] = node + end + end + end + end + end + + -- Follow temporal edges + local temporal_ids = vim.tbl_keys(results.temporal) + local discovered = traverse(temporal_ids, 1, { types.EDGE_TYPES.TEMPORAL }) + for id, node in pairs(discovered) do + results.temporal[id] = node + end + end + + -- 4. Combine and deduplicate + local all_nodes = {} + for _, category in pairs(results) do + for id, node in pairs(category) do + if not all_nodes[id] then + all_nodes[id] = node + end + end + end + + -- 5. 
Score and rank + local scored = {} + for id, node in pairs(all_nodes) do + local relevance = M.compute_relevance(node, opts) + table.insert(scored, { node = node, relevance = relevance }) + end + + table.sort(scored, function(a, b) + return a.relevance > b.relevance + end) + + -- 6. Apply limit + local limit = opts.limit or 50 + local result_nodes = {} + local truncated = #scored > limit + + for i = 1, math.min(limit, #scored) do + table.insert(result_nodes, scored[i].node) + end + + -- 7. Get edges between result nodes + local edge_mod = get_edge_module() + local result_edges = {} + local node_ids = {} + for _, node in ipairs(result_nodes) do + node_ids[node.id] = true + end + + for _, node in ipairs(result_nodes) do + local edges = edge_mod.get_edges(node.id, nil, "out") + for _, edge in ipairs(edges) do + if node_ids[edge.t] then + table.insert(result_edges, edge) + end + end + end + + return { + nodes = result_nodes, + edges = result_edges, + stats = { + semantic_count = vim.tbl_count(results.semantic), + file_count = vim.tbl_count(results.file), + temporal_count = vim.tbl_count(results.temporal), + total_scored = #scored, + }, + truncated = truncated, + } +end + +--- Find nodes by file +---@param filepath string File path +---@param limit? number Max results +---@return Node[] +function M.by_file(filepath, limit) + local result = M.execute({ + file = filepath, + limit = limit or 20, + }) + return result.nodes +end + +--- Find nodes by time range +---@param since number Start timestamp +---@param until_ts? number End timestamp +---@param limit? 
number Max results +---@return Node[] +function M.by_time_range(since, until_ts, limit) + local node_mod = get_node_module() + local by_time = storage.get_index("by_time") + local results = {} + + until_ts = until_ts or os.time() + + for day, node_ids in pairs(by_time) do + local year, month, day_num = day:match("(%d+)-(%d+)-(%d+)") + if year then + local day_ts = os.time({ year = tonumber(year), month = tonumber(month), day = tonumber(day_num) }) + if day_ts >= since and day_ts <= until_ts then + for _, node_id in ipairs(node_ids) do + local node = node_mod.get(node_id) + if node then + table.insert(results, node) + end + end + end + end + end + + -- Sort by creation time + table.sort(results, function(a, b) + return a.ts.cr > b.ts.cr + end) + + if limit and #results > limit then + local limited = {} + for i = 1, limit do + limited[i] = results[i] + end + return limited + end + + return results +end + +--- Find semantically similar nodes +---@param query string Query text +---@param limit? 
number Max results +---@return Node[] +function M.semantic_search(query, limit) + local result = M.execute({ + query = query, + limit = limit or 10, + depth = 2, + }) + return result.nodes +end + +--- Get context chain (path) for explanation +---@param node_ids string[] Node IDs to chain +---@return string[] Chain descriptions +function M.get_context_chain(node_ids) + local node_mod = get_node_module() + local edge_mod = get_edge_module() + local chain = {} + + for i, node_id in ipairs(node_ids) do + local node = node_mod.get(node_id) + if node then + local entry = string.format("[%s] %s (w:%.2f)", node.t:upper(), node.c.s, node.sc.w) + table.insert(chain, entry) + + -- Add edge to next node if exists + if node_ids[i + 1] then + local edge = edge_mod.get(node_id, node_ids[i + 1]) + if edge then + table.insert(chain, string.format(" -> %s (w:%.2f)", edge.ty, edge.p.w)) + end + end + end + end + + return chain +end + +return M diff --git a/lua/codetyper/brain/hash.lua b/lua/codetyper/brain/hash.lua new file mode 100644 index 0000000..6d3dd1a --- /dev/null +++ b/lua/codetyper/brain/hash.lua @@ -0,0 +1,112 @@ +--- Brain Hashing Utilities +--- Content-addressable storage with 8-character hashes + +local M = {} + +--- Simple DJB2 hash algorithm (fast, good distribution) +---@param str string String to hash +---@return number Hash value +local function djb2(str) + local hash = 5381 + for i = 1, #str do + hash = ((hash * 33) + string.byte(str, i)) % 0x100000000 + end + return hash +end + +--- Convert number to hex string +---@param num number Number to convert +---@param len number Desired length +---@return string Hex string +local function to_hex(num, len) + local hex = string.format("%x", num) + if #hex < len then + hex = string.rep("0", len - #hex) .. 
hex
  end
  return hex:sub(-len)
end

--- Compute 8-character hash from string.
---@param content string Content to hash
---@return string 8-character hex hash
function M.compute(content)
  if not content or content == "" then
    return "00000000"
  end
  return to_hex(djb2(content), 8)
end

--- Compute hash from table (JSON-serialized).
--- NOTE(review): vim.json.encode key order is not guaranteed stable, so two
--- equal tables could in principle hash differently — confirm acceptable.
---@param tbl table Table to hash
---@return string 8-character hex hash
function M.compute_table(tbl)
  local ok, json = pcall(vim.json.encode, tbl)
  if not ok then
    return "00000000"
  end
  return M.compute(json)
end

--- Generate unique node ID of the form n_<type>_<ts>_<hash6>.
---@param node_type string Node type prefix
---@param content? string Optional content for hash
---@return string Node ID
function M.node_id(node_type, content)
  local ts = os.time()
  local hash_input = (content or "") .. tostring(ts) .. tostring(math.random(100000))
  return string.format("n_%s_%d_%s", node_type, ts, M.compute(hash_input):sub(1, 6))
end

--- Generate unique edge ID of the form e_<src4>_<tgt4>.
---@param source_id string Source node ID
---@param target_id string Target node ID
---@return string Edge ID
function M.edge_id(source_id, target_id)
  local src_hash = M.compute(source_id):sub(1, 4)
  local tgt_hash = M.compute(target_id):sub(1, 4)
  return string.format("e_%s_%s", src_hash, tgt_hash)
end

--- Generate delta hash from changes, parent hash and timestamp.
---@param changes table[] Delta changes
---@param parent string|nil Parent delta hash
---@param timestamp number Delta timestamp
---@return string 8-character delta hash
function M.delta_hash(changes, parent, timestamp)
  local content = (parent or "root") .. tostring(timestamp)
  for _, change in ipairs(changes or {}) do
    content = content .. (change.op or "") .. (change.path or "")
  end
  return M.compute(content)
end

--- Hash file path for storage.
---@param filepath string File path
---@return string 8-character hash
function M.path_hash(filepath)
  return M.compute(filepath)
end

--- Check if two hashes match.
---@param hash1 string First hash
---@param hash2 string Second hash
---@return boolean True if matching
function M.matches(hash1, hash2)
  return hash1 == hash2
end

--- Generate random hash (for testing/temporary IDs).
---@return string 8-character random hash
function M.random()
  local chars = "0123456789abcdef"
  local result = ""
  for _ = 1, 8 do
    local idx = math.random(1, #chars)
    result = result .. chars:sub(idx, idx)
  end
  return result
end

do return M end -- end of lua/codetyper/brain/hash.lua

-- ============================================================
-- lua/codetyper/brain/init.lua
-- ============================================================

--- Brain Learning System
--- Graph-based knowledge storage with delta versioning

local storage = require("codetyper.brain.storage")
local types = require("codetyper.brain.types")

local M = {}

---@type BrainConfig|nil
local config = nil

---@type boolean
local initialized = false

-- Learnings recorded since the last delta commit (drives auto-commit).
local pending_changes = 0

--- Default configuration (deep-merged with user opts in M.setup).
local DEFAULT_CONFIG = {
  enabled = true,
  auto_learn = true,
  auto_commit = true,
  commit_threshold = 10,
  max_nodes = 5000,
  max_deltas = 500,
  prune = {
    enabled = true,
    threshold = 0.1,
    unused_days = 90,
  },
  output = {
    max_tokens = 4000,
    format = "compact",
  },
}

--- Initialize brain system
---@param opts?
---@param opts? BrainConfig Configuration options
function M.setup(opts)
  config = vim.tbl_deep_extend("force", DEFAULT_CONFIG, opts or {})

  if not config.enabled then
    return
  end

  -- Ensure storage directories exist before any reads/writes.
  storage.ensure_dirs()

  -- Initialize meta if not exists (get_meta creates it on first access).
  storage.get_meta()

  initialized = true
end

--- Check if brain is initialized and enabled.
---@return boolean
function M.is_initialized()
  -- BUGFIX: previously returned `initialized and config and config.enabled`,
  -- which yields a table/nil rather than the annotated boolean. Coerce so
  -- callers comparing against true/false behave predictably.
  return not not (initialized and config and config.enabled)
end

--- Get current configuration
---@return BrainConfig|nil
function M.get_config()
  return config
end

--- Learn from an event
---@param event LearnEvent Learning event
---@return string|nil Node ID if created
function M.learn(event)
  if not M.is_initialized() or not config.auto_learn then
    return nil
  end

  local learners = require("codetyper.brain.learners")
  local node_id = learners.process(event)

  if node_id then
    pending_changes = pending_changes + 1

    -- Auto-commit once enough changes have accumulated.
    if config.auto_commit and pending_changes >= config.commit_threshold then
      M.commit("Auto-commit: " .. pending_changes .. " changes")
      pending_changes = 0
    end
  end

  return node_id
end

--- Query relevant knowledge for context
---@param opts QueryOpts Query options
---@return QueryResult
function M.query(opts)
  if not M.is_initialized() then
    return { nodes = {}, edges = {}, stats = {}, truncated = false }
  end

  local query_engine = require("codetyper.brain.graph.query")
  return query_engine.execute(opts)
end

--- Get LLM-optimized context string
---@param opts? QueryOpts Query options
---@return string Formatted context
function M.get_context_for_llm(opts)
  if not M.is_initialized() then
    return ""
  end

  opts = opts or {}
  opts.max_tokens = opts.max_tokens or config.output.max_tokens

  local result = M.query(opts)
  local formatter = require("codetyper.brain.output.formatter")

  if config.output.format == "json" then
    return formatter.to_json(result, opts)
  else
    return formatter.to_compact(result, opts)
  end
end

--- Create a delta commit
---@param message string Commit message
---@return string|nil Delta hash
function M.commit(message)
  if not M.is_initialized() then
    return nil
  end

  local delta_mgr = require("codetyper.brain.delta")
  return delta_mgr.commit(message)
end

--- Rollback to a previous delta
---@param delta_hash string Target delta hash
---@return boolean Success
function M.rollback(delta_hash)
  if not M.is_initialized() then
    return false
  end

  local delta_mgr = require("codetyper.brain.delta")
  return delta_mgr.rollback(delta_hash)
end

--- Get delta history
---@param limit? number Max entries
---@return Delta[]
function M.get_history(limit)
  if not M.is_initialized() then
    return {}
  end

  local delta_mgr = require("codetyper.brain.delta")
  return delta_mgr.get_history(limit or 50)
end

--- Prune low-value nodes
---@param opts?
---@param opts? table Prune options
---@return number Number of pruned nodes
function M.prune(opts)
  if not M.is_initialized() or not config.prune.enabled then
    return 0
  end

  opts = vim.tbl_extend("force", {
    threshold = config.prune.threshold,
    unused_days = config.prune.unused_days,
  }, opts or {})

  local graph = require("codetyper.brain.graph")
  return graph.prune(opts)
end

--- Export the full brain state as a plain table (backup/transfer).
---@return table|nil Exported data
function M.export()
  if not M.is_initialized() then
    return nil
  end

  return {
    schema = types.SCHEMA_VERSION,
    meta = storage.get_meta(),
    graph = storage.get_graph(),
    nodes = {
      patterns = storage.get_nodes("patterns"),
      corrections = storage.get_nodes("corrections"),
      decisions = storage.get_nodes("decisions"),
      conventions = storage.get_nodes("conventions"),
      feedback = storage.get_nodes("feedback"),
      sessions = storage.get_nodes("sessions"),
    },
    indices = {
      by_file = storage.get_index("by_file"),
      by_time = storage.get_index("by_time"),
      by_symbol = storage.get_index("by_symbol"),
    },
  }
end

--- Import brain state previously produced by M.export.
---@param data table Exported data
---@return boolean Success
function M.import(data)
  -- Refuse data from a different schema version.
  if not data or data.schema ~= types.SCHEMA_VERSION then
    return false
  end

  storage.ensure_dirs()

  if data.nodes then
    for node_type, nodes in pairs(data.nodes) do
      storage.save_nodes(node_type, nodes)
    end
  end

  if data.graph then
    storage.save_graph(data.graph)
  end

  if data.indices then
    for index_type, index_data in pairs(data.indices) do
      storage.save_index(index_type, index_data)
    end
  end

  -- Import meta last so counters reflect the imported content.
  if data.meta then
    for k, v in pairs(data.meta) do
      storage.update_meta({ [k] = v })
    end
  end

  storage.flush_all()
  return true
end

--- Get stats about the brain
---@return table Stats
function M.stats()
  if not M.is_initialized() then
    return {}
  end

  local meta = storage.get_meta()
  return {
    initialized = true,
    node_count = meta.nc,
    edge_count = meta.ec,
    delta_count = meta.dc,
    head = meta.head,
    pending_changes = pending_changes,
  }
end

--- Flush all pending writes to disk
function M.flush()
  storage.flush_all()
end

--- Shutdown brain (call before exit); commits any uncommitted learnings.
function M.shutdown()
  if pending_changes > 0 then
    M.commit("Session end: " .. pending_changes .. " changes")
  end
  storage.flush_all()
  initialized = false
end

do return M end -- end of lua/codetyper/brain/init.lua

-- ============================================================
-- lua/codetyper/brain/learners/convention.lua
-- ============================================================

--- Brain Convention Learner
--- Learns project conventions and coding standards

local types = require("codetyper.brain.types")

local M = {}

--- Detect if event contains convention info
---@param event LearnEvent Learning event
---@return boolean
function M.detect(event)
  local valid_types = {
    "convention_detected",
    "naming_pattern",
    "style_pattern",
    "project_structure",
    "config_change",
  }

  for _, t in ipairs(valid_types) do
    if event.type == t then
      return true
    end
  end

  return false
end

--- Extract convention data from event
---@param event LearnEvent Learning event
---@return table|nil Extracted data
function M.extract(event)
  local data = event.data or {}

  if event.type == "convention_detected" then
    return {
      summary = "Convention: " .. (data.name or "unnamed"),
      detail = data.description or data.name,
      rule = data.rule,
      examples = data.examples,
      category = data.category or "general",
      file = event.file,
    }
  end

  if event.type == "naming_pattern" then
    return {
      summary = "Naming: " .. (data.pattern_name or data.pattern),
      detail = "Naming convention: " ..
(data.description or data.pattern),
      rule = data.pattern,
      examples = data.examples,
      category = "naming",
      scope = data.scope, -- function, variable, class, file
    }
  end

  if event.type == "style_pattern" then
    return {
      summary = "Style: " .. (data.name or "unnamed"),
      detail = data.description or "Code style pattern",
      rule = data.rule,
      examples = data.examples,
      category = "style",
      lang = data.language,
    }
  end

  if event.type == "project_structure" then
    return {
      summary = "Structure: " .. (data.pattern or "project layout"),
      detail = data.description or "Project structure convention",
      rule = data.rule,
      category = "structure",
      paths = data.paths,
    }
  end

  if event.type == "config_change" then
    return {
      summary = "Config: " .. (data.setting or "setting change"),
      detail = "Configuration: " .. (data.description or data.setting),
      before = data.before,
      after = data.after,
      category = "config",
      file = event.file,
    }
  end

  return nil
end

--- Check if convention should be learned
---@param data table Extracted data
---@return boolean
function M.should_learn(data)
  if not data.summary then
    return false
  end

  -- Skip very vague conventions.
  if not data.detail or #data.detail < 5 then
    return false
  end

  return true
end

--- Create node from convention data
---@param data table Extracted data
---@return table Node creation params
function M.create_node_params(data)
  local detail = data.detail or ""

  -- Append examples when available.
  if data.examples and #data.examples > 0 then
    detail = detail .. "\n\nExamples:"
    for _, ex in ipairs(data.examples) do
      detail = detail .. "\n- " .. tostring(ex)
    end
  end

  -- Append the rule when available.
  if data.rule then
    detail = detail .. "\n\nRule: " .. tostring(data.rule)
  end

  return {
    node_type = types.NODE_TYPES.CONVENTION,
    content = {
      s = data.summary:sub(1, 200),
      d = detail,
      lang = data.lang,
    },
    context = {
      f = data.file,
      sym = data.scope and { data.scope } or nil,
    },
    opts = {
      weight = 0.6,
      source = types.SOURCES.AUTO,
    },
  }
end

--- Find related conventions
---@param data table Extracted data
---@param query_fn function Query function
---@return string[] Related node IDs
function M.find_related(data, query_fn)
  local related = {}

  -- Conventions in the same category.
  if data.category then
    local similar = query_fn({
      query = data.category,
      types = { types.NODE_TYPES.CONVENTION },
      limit = 5,
    })
    for _, node in ipairs(similar) do
      related[#related + 1] = node.id
    end
  end

  -- Patterns that follow this convention.
  if data.rule then
    local patterns = query_fn({
      query = data.rule,
      types = { types.NODE_TYPES.PATTERN },
      limit = 3,
    })
    for _, node in ipairs(patterns) do
      if not vim.tbl_contains(related, node.id) then
        related[#related + 1] = node.id
      end
    end
  end

  return related
end

--- Detect naming convention from symbol names
---@param symbols string[] Symbol names to analyze
---@return table|nil Detected convention
function M.detect_naming(symbols)
  -- Require a minimal sample before inferring anything.
  if not symbols or #symbols < 3 then
    return nil
  end

  local patterns = {
    snake_case = 0,
    camelCase = 0,
    PascalCase = 0,
    SCREAMING_SNAKE = 0,
    kebab_case = 0,
  }

  -- Order matters: the first matching character class wins per symbol.
  for _, sym in ipairs(symbols) do
    if sym:match("^[a-z][a-z0-9_]*$") then
      patterns.snake_case = patterns.snake_case + 1
    elseif sym:match("^[a-z][a-zA-Z0-9]*$") then
      patterns.camelCase = patterns.camelCase + 1
    elseif sym:match("^[A-Z][a-zA-Z0-9]*$") then
      patterns.PascalCase = patterns.PascalCase + 1
    elseif sym:match("^[A-Z][A-Z0-9_]*$") then
      patterns.SCREAMING_SNAKE = patterns.SCREAMING_SNAKE + 1
    elseif sym:match("^[a-z][a-z0-9%-]*$") then
      patterns.kebab_case = patterns.kebab_case +
1
    end
  end

  -- Pick the dominant pattern.
  local max_count = 0
  local dominant = nil
  for pattern, count in pairs(patterns) do
    if count > max_count then
      max_count = count
      dominant = pattern
    end
  end

  -- Only report a convention when at least 60% of symbols agree.
  if dominant and max_count >= #symbols * 0.6 then
    return {
      pattern = dominant,
      confidence = max_count / #symbols,
      sample_size = #symbols,
    }
  end

  return nil
end

do return M end -- end of lua/codetyper/brain/learners/convention.lua

-- ============================================================
-- lua/codetyper/brain/learners/correction.lua
-- ============================================================

--- Brain Correction Learner
--- Learns from user corrections and edits

local types = require("codetyper.brain.types")

local M = {}

--- Detect if event is a correction
---@param event LearnEvent Learning event
---@return boolean
function M.detect(event)
  local valid_types = {
    "user_correction",
    "code_rejected",
    "code_modified",
    "suggestion_rejected",
  }

  for _, t in ipairs(valid_types) do
    if event.type == t then
      return true
    end
  end

  return false
end

--- Extract correction data from event
---@param event LearnEvent Learning event
---@return table|nil Extracted data
function M.extract(event)
  local data = event.data or {}

  if event.type == "user_correction" then
    return {
      summary = "Correction: " .. (data.error_type or "user edit"),
      detail = data.description or "User corrected the generated code",
      before = data.before,
      after = data.after,
      error_type = data.error_type,
      file = event.file,
      function_name = data.function_name,
      lines = data.lines,
    }
  end

  if event.type == "code_rejected" then
    return {
      summary = "Rejected: " .. (data.reason or "not accepted"),
      detail = data.description or "User rejected generated code",
      rejected_code = data.code,
      reason = data.reason,
      file = event.file,
      intent = data.intent,
    }
  end

  if event.type == "code_modified" then
    local changes = M.analyze_changes(data.before, data.after)
    return {
      summary = "Modified: " .. changes.summary,
      detail = changes.detail,
      before = data.before,
      after = data.after,
      change_type = changes.type,
      file = event.file,
      lines = data.lines,
    }
  end

  return nil
end

--- Analyze changes between before/after code
---@param before string Before code
---@param after string After code
---@return table Change analysis
function M.analyze_changes(before, after)
  before = before or ""
  after = after or ""

  local before_lines = vim.split(before, "\n")
  local after_lines = vim.split(after, "\n")

  local added, removed, modified = 0, 0, 0

  -- Simple positional line diff (no alignment or move detection).
  local max_lines = math.max(#before_lines, #after_lines)
  for i = 1, max_lines do
    local b = before_lines[i]
    local a = after_lines[i]

    if b == nil and a ~= nil then
      added = added + 1
    elseif b ~= nil and a == nil then
      removed = removed + 1
    elseif b ~= a then
      modified = modified + 1
    end
  end

  local change_type = "mixed"
  if added > 0 and removed == 0 and modified == 0 then
    change_type = "addition"
  elseif removed > 0 and added == 0 and modified == 0 then
    change_type = "deletion"
  elseif modified > 0 and added == 0 and removed == 0 then
    change_type = "modification"
  end

  return {
    type = change_type,
    summary = string.format("+%d -%d ~%d lines", added, removed, modified),
    detail = string.format("Added %d, removed %d, modified %d lines", added, removed, modified),
    stats = {
      added = added,
      removed = removed,
      modified = modified,
    },
  }
end

--- Check if correction should be learned
---@param data table Extracted data
---@return boolean
function M.should_learn(data)
  -- Corrections are valuable; only drop malformed or trivial ones.
  if not data.summary then
    return false
  end

  -- Skip whitespace-only edits.
  if data.before and data.after then
    local before_trimmed = data.before:gsub("%s+", "")
    local after_trimmed = data.after:gsub("%s+", "")
    if before_trimmed == after_trimmed then
      return false
    end
  end

  return true
end

--- Create node from correction data
---@param data table Extracted data
---@return table Node creation params
function M.create_node_params(data)
  local detail = data.detail or ""

  -- Embed truncated before/after snippets so the delta itself is learnable.
  if data.before and data.after then
    detail = detail .. "\n\nBefore:\n" .. data.before:sub(1, 500)
    detail = detail .. "\n\nAfter:\n" .. data.after:sub(1, 500)
  end

  return {
    node_type = types.NODE_TYPES.CORRECTION,
    content = {
      s = data.summary:sub(1, 200),
      d = detail,
      code = data.after or data.rejected_code,
      lang = data.lang,
    },
    context = {
      f = data.file,
      fn = data.function_name,
      ln = data.lines,
    },
    opts = {
      weight = 0.7, -- Corrections are valuable
      source = types.SOURCES.USER,
    },
  }
end

--- Find related nodes for corrections
---@param data table Extracted data
---@param query_fn function Query function
---@return string[] Related node IDs
function M.find_related(data, query_fn)
  local related = {}

  -- Patterns resembling the code that was corrected.
  if data.before then
    local similar = query_fn({
      query = data.before:sub(1, 100),
      types = { types.NODE_TYPES.PATTERN },
      limit = 3,
    })
    for _, node in ipairs(similar) do
      related[#related + 1] = node.id
    end
  end

  -- Other corrections in the same file.
  if data.file then
    local file_corrections = query_fn({
      file = data.file,
      types = { types.NODE_TYPES.CORRECTION },
      limit = 3,
    })
    for _, node in ipairs(file_corrections) do
      related[#related + 1] = node.id
    end
  end

  return related
end

do return M end -- end of lua/codetyper/brain/learners/correction.lua
-- ============================================================
-- lua/codetyper/brain/learners/init.lua
-- ============================================================

--- Brain Learners Coordinator
--- Routes learning events to appropriate learners

local types = require("codetyper.brain.types")

local M = {}

-- Lazy-loaded learner modules (defer require cost until first event).
local function get_pattern_learner()
  return require("codetyper.brain.learners.pattern")
end

local function get_correction_learner()
  return require("codetyper.brain.learners.correction")
end

local function get_convention_learner()
  return require("codetyper.brain.learners.convention")
end

--- All available learners, in detection priority order.
local LEARNERS = {
  { name = "pattern", loader = get_pattern_learner },
  { name = "correction", loader = get_correction_learner },
  { name = "convention", loader = get_convention_learner },
}

--- Process a learning event
---@param event LearnEvent Learning event
---@return string|nil Created node ID
function M.process(event)
  if not event or not event.type then
    return nil
  end

  -- Add timestamp if missing.
  event.timestamp = event.timestamp or os.time()

  -- First learner whose detect() accepts the event wins.
  for _, learner_info in ipairs(LEARNERS) do
    local learner = learner_info.loader()
    if learner.detect(event) then
      return M.learn_with(learner, event)
    end
  end

  -- Handle generic feedback events.
  if event.type == "user_feedback" then
    return M.process_feedback(event)
  end

  -- Handle session events.
  if event.type == "session_start" or event.type == "session_end" then
    return M.process_session(event)
  end

  return nil
end

--- Learn using a specific learner
---@param learner table Learner module
---@param event LearnEvent Learning event
---@return string|nil Created node ID
function M.learn_with(learner, event)
  local extracted = learner.extract(event)
  if not extracted then
    return nil
  end

  -- FIX: vim.islist only exists on Neovim >= 0.10; fall back to the older
  -- vim.tbl_islist so this also works on earlier releases.
  local islist = vim.islist or vim.tbl_islist

  -- Handle multiple extractions (e.g., from file indexing).
  if islist(extracted) then
    local node_ids = {}
    for _, data in ipairs(extracted) do
      local node_id = M.create_learning(learner, data, event)
      if node_id then
        node_ids[#node_ids + 1] = node_id
      end
    end
    return node_ids[1] -- Return first for now
  end

  return M.create_learning(learner, extracted, event)
end

--- Create a learning from extracted data
---@param learner table Learner module
---@param data table Extracted data
---@param event LearnEvent Original event
---@return string|nil Created node ID
function M.create_learning(learner, data, event)
  if not learner.should_learn(data) then
    return nil
  end

  local params = learner.create_node_params(data)
  local graph = require("codetyper.brain.graph")

  -- Let the learner link the new node to existing knowledge.
  local related_ids = {}
  if learner.find_related then
    related_ids = learner.find_related(data, function(opts)
      return graph.query.execute(opts).nodes
    end)
  end

  local node = graph.add_learning(params.node_type, params.content, params.context, related_ids)

  -- Apply learner-specific initial weight.
  if params.opts and params.opts.weight then
    graph.node.update(node.id, { sc = { w = params.opts.weight } })
  end

  return node.id
end

--- Process feedback event
---@param event LearnEvent Feedback event
---@return string|nil Created node ID
function M.process_feedback(event)
  local data = event.data or {}
  local graph = require("codetyper.brain.graph")

  local content = {
    s = "Feedback: " .. (data.feedback or "unknown"),
    d = data.description or ("User " .. (data.feedback or "gave feedback")),
  }

  local context = {
    f = event.file,
  }

  -- Feedback about an existing node adjusts that node's weight.
  if data.node_id then
    local node = graph.node.get(data.node_id)
    if node then
      local weight_delta = data.feedback == "accepted" and 0.1 or -0.1
      local new_weight = math.max(0, math.min(1, node.sc.w + weight_delta))

      graph.node.update(data.node_id, {
        sc = { w = new_weight },
      })

      -- Record usage outcome on the referenced node.
      graph.node.record_usage(data.node_id, data.feedback == "accepted")

      -- Create feedback node linked to the original.
      local fb_node = graph.add_learning(types.NODE_TYPES.FEEDBACK, content, context, { data.node_id })
      return fb_node.id
    end
  end

  -- Otherwise create a standalone feedback node.
  local node = graph.add_learning(types.NODE_TYPES.FEEDBACK, content, context)
  return node.id
end

--- Process session event
---@param event LearnEvent Session event
---@return string|nil Created node ID
function M.process_session(event)
  local data = event.data or {}
  local graph = require("codetyper.brain.graph")

  local content = {
    s = event.type == "session_start" and "Session started" or "Session ended",
    d = data.description or event.type,
  }

  if event.type == "session_end" and data.stats then
    content.d = content.d .. "\n\nStats:"
    content.d = content.d .. "\n- Completions: " .. (data.stats.completions or 0)
    content.d = content.d .. "\n- Corrections: " .. (data.stats.corrections or 0)
    content.d = content.d .. "\n- Files: " .. (data.stats.files or 0)
  end

  local node = graph.add_learning(types.NODE_TYPES.SESSION, content, {})

  -- On session end, temporally link everything learned in the last hour.
  if event.type == "session_end" then
    local recent = graph.query.by_time_range(os.time() - 3600, os.time(), 20)
    local session_nodes = {}

    for _, n in ipairs(recent) do
      if n.id ~= node.id then
        session_nodes[#session_nodes + 1] = n.id
      end
    end

    if #session_nodes > 0 then
      graph.link_temporal(session_nodes)
    end
  end

  return node.id
end

--- Batch process multiple events
---@param events LearnEvent[] Events to process
---@return string[] Created node IDs
function M.batch_process(events)
  local node_ids = {}

  for _, event in ipairs(events) do
    local node_id = M.process(event)
    if node_id then
      node_ids[#node_ids + 1] = node_id
    end
  end

  return node_ids
end

--- Get learner names
---@return string[]
function M.get_learner_names()
  local names = {}
  for _, learner in ipairs(LEARNERS) do
    names[#names + 1] = learner.name
  end
  return names
end

do return M end -- end of lua/codetyper/brain/learners/init.lua

-- ============================================================
-- lua/codetyper/brain/learners/pattern.lua
-- ============================================================

--- Brain Pattern Learner
--- Detects and learns code patterns

local types = require("codetyper.brain.types")

local M = {}

--- Detect if event contains a learnable pattern
---@param event LearnEvent Learning event
---@return boolean
function M.detect(event)
  local valid_types = {
    "code_completion",
    "file_indexed",
    "code_analyzed",
    "pattern_detected",
  }

  for _, t in ipairs(valid_types) do
    if event.type == t then
      return true
    end
  end

  return false
end

--- Extract pattern data from event
---@param event LearnEvent Learning event
---@return table|nil Extracted data
function M.extract(event)
  local data = event.data or {}

  -- Extract from
-- Extract from code completion
  if event.type == "code_completion" then
    return {
      summary = "Code pattern: " .. (data.intent or "unknown"),
      detail = data.code or data.content or "",
      code = data.code,
      lang = data.language,
      file = event.file,
      function_name = data.function_name,
      symbols = data.symbols,
    }
  end

  -- Extract from file indexing: one pattern per function/class found.
  if event.type == "file_indexed" then
    local patterns = {}

    if data.functions then
      for _, func in ipairs(data.functions) do
        patterns[#patterns + 1] = {
          summary = "Function: " .. func.name,
          detail = func.signature or func.name,
          code = func.body,
          lang = data.language,
          file = event.file,
          function_name = func.name,
          lines = func.lines,
        }
      end
    end

    if data.classes then
      for _, class in ipairs(data.classes) do
        patterns[#patterns + 1] = {
          summary = "Class: " .. class.name,
          detail = class.description or class.name,
          lang = data.language,
          file = event.file,
          symbols = { class.name },
        }
      end
    end

    return #patterns > 0 and patterns or nil
  end

  -- Extract from explicit pattern detection.
  if event.type == "pattern_detected" then
    return {
      summary = data.name or "Unnamed pattern",
      detail = data.description or data.name or "",
      code = data.example,
      lang = data.language,
      file = event.file,
      symbols = data.symbols,
    }
  end

  return nil
end

--- Check if pattern should be learned
---@param data table Extracted data
---@return boolean
function M.should_learn(data)
  -- Skip if no meaningful content.
  if not data.summary or data.summary == "" then
    return false
  end

  -- Skip very short patterns.
  if data.detail and #data.detail < 10 then
    return false
  end

  -- Skip whitespace-only summaries.
  if data.summary:match("^%s*$") then
    return false
  end

  return true
end

--- Create node from pattern data
---@param data table Extracted data
---@return table Node creation params
function M.create_node_params(data)
  return {
    node_type = types.NODE_TYPES.PATTERN,
    content = {
      s = data.summary:sub(1, 200), -- Limit summary
      d = data.detail,
      code = data.code,
      lang = data.lang,
    },
    context = {
      f = data.file,
      fn = data.function_name,
      ln = data.lines,
      sym = data.symbols,
    },
    opts = {
      weight = 0.5,
      source = types.SOURCES.AUTO,
    },
  }
end

--- Find potentially related nodes
---@param data table Extracted data
---@param query_fn function Query function
---@return string[] Related node IDs
function M.find_related(data, query_fn)
  local related = {}

  -- Nodes in the same file.
  if data.file then
    local file_nodes = query_fn({ file = data.file, limit = 5 })
    for _, node in ipairs(file_nodes) do
      related[#related + 1] = node.id
    end
  end

  -- Semantically similar nodes.
  if data.summary then
    local similar = query_fn({ query = data.summary, limit = 3 })
    for _, node in ipairs(similar) do
      if not vim.tbl_contains(related, node.id) then
        related[#related + 1] = node.id
      end
    end
  end

  return related
end

do return M end -- end of lua/codetyper/brain/learners/pattern.lua

-- ============================================================
-- lua/codetyper/brain/output/formatter.lua
-- ============================================================

--- Brain Output Formatter
--- LLM-optimized output formatting

local types = require("codetyper.brain.types")

local M = {}

--- Estimate token count (rough approximation: ~4 characters per token).
---@param text string Text to estimate
---@return number Estimated tokens
function M.estimate_tokens(text)
  if not text then
    return 0
  end
  return math.ceil(#text / 4)
end

--- Format nodes to compact text format
---@param result QueryResult Query result
---@param opts?
table Options
---@return string Formatted output
function M.to_compact(result, opts)
  opts = opts or {}
  local budget = opts.max_tokens or 4000
  local out = {}
  local used = 0

  -- Header section.
  table.insert(out, "---BRAIN_CONTEXT---")
  if opts.query then
    table.insert(out, "Q: " .. opts.query)
  end
  table.insert(out, "")

  -- Nodes arrive pre-sorted by relevance; emit until the budget runs low.
  table.insert(out, "Learnings:")

  for idx, node in ipairs(result.nodes) do
    -- One row per node: [idx] TYPE | w:0.85 u:5 | Summary
    local row = string.format(
      "[%d] %s | w:%.2f u:%d | %s",
      idx,
      (node.t or "?"):upper(),
      node.sc.w or 0,
      node.sc.u or 0,
      (node.c.s or ""):sub(1, 100)
    )

    local cost = M.estimate_tokens(row)
    if used + cost > budget - 100 then
      table.insert(out, "... (truncated)")
      break
    end

    table.insert(out, row)
    used = used + cost

    -- File location line; short, so it is not budget-checked before insert.
    if node.ctx and node.ctx.f then
      local loc = " @ " .. node.ctx.f
      if node.ctx.fn then
        loc = loc .. ":" .. node.ctx.fn
      end
      if node.ctx.ln then
        loc = loc .. " L" .. node.ctx.ln[1]
      end
      table.insert(out, loc)
      used = used + M.estimate_tokens(loc)
    end
  end

  -- Edge section only when there is headroom left.
  if #result.edges > 0 and used < budget - 200 then
    table.insert(out, "")
    table.insert(out, "Connections:")

    for _, edge in ipairs(result.edges) do
      if used >= budget - 50 then
        break
      end

      local row = string.format(
        " %s --%s(%.2f)--> %s",
        edge.s:sub(-8),
        edge.ty,
        edge.p.w or 0.5,
        edge.t:sub(-8)
      )
      table.insert(out, row)
      used = used + M.estimate_tokens(row)
    end
  end

  table.insert(out, "---END_CONTEXT---")

  return table.concat(out, "\n")
end

--- Format nodes to JSON format
---@param result QueryResult Query result
---@param opts?
table Options
---@return string JSON output
function M.to_json(result, opts)
  opts = opts or {}
  local max_tokens = opts.max_tokens or 4000

  local output = {
    _s = "brain-v1", -- Schema
    q = opts.query,
    l = {}, -- Learnings
    c = {}, -- Connections
  }

  local current_tokens = 50 -- Base overhead

  -- Nodes first: they carry the actual learnings.
  for _, node in ipairs(result.nodes) do
    local entry = {
      t = node.t,
      s = (node.c.s or ""):sub(1, 150),
      w = node.sc.w,
      u = node.sc.u,
    }

    if node.ctx and node.ctx.f then
      entry.f = node.ctx.f
    end

    local entry_tokens = M.estimate_tokens(vim.json.encode(entry))
    if current_tokens + entry_tokens > max_tokens - 100 then
      break
    end

    table.insert(output.l, entry)
    current_tokens = current_tokens + entry_tokens
  end

  -- Add edges if space remains.
  if current_tokens < max_tokens - 200 then
    for _, edge in ipairs(result.edges) do
      if current_tokens >= max_tokens - 50 then
        break
      end

      table.insert(output.c, {
        s = edge.s:sub(-8),
        t = edge.t:sub(-8),
        r = edge.ty,
        w = edge.p.w,
      })
      -- Flat per-edge estimate; edges are small and uniform.
      current_tokens = current_tokens + 30
    end
  end

  return vim.json.encode(output)
end

--- Format as natural language
---@param result QueryResult Query result
---@param opts? table Options
---@return string Natural language output
function M.to_natural(result, opts)
  opts = opts or {}
  local max_tokens = opts.max_tokens or 4000
  local lines = {}
  local current_tokens = 0

  if #result.nodes == 0 then
    return "No relevant learnings found."
  end

  table.insert(lines, "Based on previous learnings:")
  table.insert(lines, "")

  -- Group nodes by type so related learnings render together.
  local by_type = {}
  for _, node in ipairs(result.nodes) do
    by_type[node.t] = by_type[node.t] or {}
    table.insert(by_type[node.t], node)
  end

  local type_names = {
    [types.NODE_TYPES.PATTERN] = "Code Patterns",
    [types.NODE_TYPES.CORRECTION] = "Previous Corrections",
    [types.NODE_TYPES.CONVENTION] = "Project Conventions",
    [types.NODE_TYPES.DECISION] = "Architectural Decisions",
    [types.NODE_TYPES.FEEDBACK] = "User Preferences",
    [types.NODE_TYPES.SESSION] = "Session Context",
  }

  for node_type, nodes in pairs(by_type) do
    local type_name = type_names[node_type] or node_type

    table.insert(lines, "**" .. type_name .. "**")

    for _, node in ipairs(nodes) do
      if current_tokens >= max_tokens - 100 then
        table.insert(lines, "...")
        goto done
      end

      local bullet = string.format("- %s (confidence: %.0f%%)", node.c.s or "?", (node.sc.w or 0) * 100)

      table.insert(lines, bullet)
      current_tokens = current_tokens + M.estimate_tokens(bullet)

      -- Expand the detail only for high-confidence nodes.
      -- FIX: node.sc.w may be nil (it is guarded with `or 0` everywhere
      -- else in this module); comparing nil > 0.7 raised an error here.
      if (node.sc.w or 0) > 0.7 and node.c.d and #node.c.d > #(node.c.s or "") then
        local detail = " " .. node.c.d:sub(1, 150)
        if #node.c.d > 150 then
          detail = detail .. "..."
        end
        table.insert(lines, detail)
        current_tokens = current_tokens + M.estimate_tokens(detail)
      end
    end

    table.insert(lines, "")
  end

  ::done::

  return table.concat(lines, "\n")
end

--- Format context chain for explanation
---@param chain table[] Chain of nodes and edges
---@return string Chain explanation
function M.format_chain(chain)
  local lines = {}

  for i, item in ipairs(chain) do
    if item.node then
      local prefix = i == 1 and "" or " -> "
      -- FIX: guard c.s and sc.w against nil, matching the conventions used
      -- by the rest of this module; previously a node without a summary or
      -- weight would crash string.format/:sub here.
      table.insert(
        lines,
        string.format("%s[%s] %s (w:%.2f)", prefix, item.node.t:upper(), (item.node.c.s or ""):sub(1, 50), item.node.sc.w or 0)
      )
    end
    if item.edge then
      -- FIX: edge weight may be absent; default to 0 rather than erroring.
      table.insert(lines, string.format(" via %s (w:%.2f)", item.edge.ty, item.edge.p.w or 0))
    end
  end

  return table.concat(lines, "\n")
end

--- Compress output to fit token budget
---@param text string Text to compress
---@param max_tokens number Token budget
---@return string Compressed text
function M.compress(text, max_tokens)
  local current = M.estimate_tokens(text)

  if current <= max_tokens then
    return text
  end

  -- Simple truncation with ellipsis; the 10% buffer leaves room for the
  -- truncation marker itself.
  local ratio = max_tokens / current
  local target_chars = math.floor(#text * ratio * 0.9)

  return text:sub(1, target_chars) .. "\n...(truncated)"
end

--- Get minimal context for quick lookups
---@param nodes Node[] Nodes to format
---@return string Minimal context
function M.minimal(nodes)
  local items = {}

  for _, node in ipairs(nodes) do
    table.insert(items, string.format("%s:%s", node.t, (node.c.s or ""):sub(1, 40)))
  end

  return table.concat(items, " | ")
end

return M
diff --git a/lua/codetyper/brain/output/init.lua b/lua/codetyper/brain/output/init.lua
new file mode 100644
index 0000000..1009561
--- /dev/null
+++ b/lua/codetyper/brain/output/init.lua
@@ -0,0 +1,166 @@
--- Brain Output Coordinator
--- Manages LLM context generation

local formatter = require("codetyper.brain.output.formatter")

local M = {}

-- Re-export formatter
M.formatter = formatter

--- Default token budget
local DEFAULT_MAX_TOKENS = 4000

--- Generate context for LLM prompt
---@param opts? table Options
---@return string Context string (empty when brain is uninitialized or empty)
function M.generate(opts)
  opts = opts or {}

  local brain = require("codetyper.brain")
  if not brain.is_initialized() then
    return ""
  end

  -- Build query opts
  local query_opts = {
    query = opts.query,
    file = opts.file,
    types = opts.types,
    since = opts.since,
    limit = opts.limit or 30,
    depth = opts.depth or 2,
    max_tokens = opts.max_tokens or DEFAULT_MAX_TOKENS,
  }

  -- Execute query
  local result = brain.query(query_opts)

  if #result.nodes == 0 then
    return ""
  end

  -- Format based on style
  local format = opts.format or "compact"

  if format == "json" then
    return formatter.to_json(result, query_opts)
  elseif format == "natural" then
    return formatter.to_natural(result, query_opts)
  else
    return formatter.to_compact(result, query_opts)
  end
end

--- Generate context for a specific file
---@param filepath string File path
---@param opts?
table Options +---@return string Context string +function M.for_file(filepath, opts) + opts = opts or {} + opts.file = filepath + return M.generate(opts) +end + +--- Generate context for current buffer +---@param opts? table Options +---@return string Context string +function M.for_current_buffer(opts) + local filepath = vim.fn.expand("%:p") + if filepath == "" then + return "" + end + return M.for_file(filepath, opts) +end + +--- Generate context for a query/prompt +---@param query string Query text +---@param opts? table Options +---@return string Context string +function M.for_query(query, opts) + opts = opts or {} + opts.query = query + return M.generate(opts) +end + +--- Get context for LLM system prompt +---@param opts? table Options +---@return string System context +function M.system_context(opts) + opts = opts or {} + opts.limit = opts.limit or 20 + opts.format = opts.format or "compact" + + local context = M.generate(opts) + + if context == "" then + return "" + end + + return [[ +The following context contains learned patterns and conventions from this project: + +]] .. context .. [[ + + +Use this context to inform your responses, following established patterns and conventions. 
+]] +end + +--- Get relevant context for code completion +---@param prefix string Code before cursor +---@param suffix string Code after cursor +---@param filepath string Current file +---@return string Context +function M.for_completion(prefix, suffix, filepath) + -- Extract relevant terms from code + local terms = {} + + -- Get function/class names + for word in prefix:gmatch("[A-Z][a-zA-Z0-9]+") do + table.insert(terms, word) + end + for word in prefix:gmatch("function%s+([a-zA-Z_][a-zA-Z0-9_]*)") do + table.insert(terms, word) + end + + local query = table.concat(terms, " ") + + return M.generate({ + query = query, + file = filepath, + limit = 15, + max_tokens = 2000, + format = "compact", + }) +end + +--- Check if context is available +---@return boolean +function M.has_context() + local brain = require("codetyper.brain") + if not brain.is_initialized() then + return false + end + + local stats = brain.stats() + return stats.node_count > 0 +end + +--- Get context stats +---@return table Stats +function M.stats() + local brain = require("codetyper.brain") + if not brain.is_initialized() then + return { available = false } + end + + local stats = brain.stats() + return { + available = true, + node_count = stats.node_count, + edge_count = stats.edge_count, + } +end + +return M diff --git a/lua/codetyper/brain/storage.lua b/lua/codetyper/brain/storage.lua new file mode 100644 index 0000000..6433f7b --- /dev/null +++ b/lua/codetyper/brain/storage.lua @@ -0,0 +1,338 @@ +--- Brain Storage Layer +--- Cache + disk persistence with lazy loading + +local utils = require("codetyper.utils") +local types = require("codetyper.brain.types") + +local M = {} + +--- In-memory cache keyed by project root +---@type table +local cache = {} + +--- Dirty flags for pending writes +---@type table> +local dirty = {} + +--- Debounce timers +---@type table +local timers = {} + +local DEBOUNCE_MS = 500 + +--- Get brain directory path for current project +---@param root? 
string Project root (defaults to current)
---@return string Brain directory path
function M.get_brain_dir(root)
  return (root or utils.get_project_root()) .. "/.coder/brain"
end

--- Ensure brain directory structure exists
--- Creates the brain dir and its nodes/indices/deltas subtrees.
---@param root? string Project root
---@return boolean Success
function M.ensure_dirs(root)
  local base = M.get_brain_dir(root)
  for _, sub in ipairs({ "", "/nodes", "/indices", "/deltas", "/deltas/objects" }) do
    if not utils.ensure_dir(base .. sub) then
      return false
    end
  end
  return true
end

--- Get file path for a storage key
--- A dotted key maps onto the directory tree: "meta" -> meta.json,
--- "nodes.patterns" -> nodes/patterns.json, etc.
---@param key string Storage key (e.g., "meta", "nodes.patterns", "deltas.objects.abc123")
---@param root? string Project root
---@return string File path
function M.get_path(key, root)
  local brain_dir = M.get_brain_dir(root)
  local parts = vim.split(key, ".", { plain = true })

  if #parts == 1 then
    return brain_dir .. "/" .. key .. ".json"
  end
  -- Joining all parts with "/" covers both the two-part and deeper cases
  -- identically to spelling them out branch by branch.
  return brain_dir .. "/" .. table.concat(parts, "/") .. ".json"
end

--- Get cache for project
---@param root?
string Project root +---@return table Project cache +local function get_cache(root) + root = root or utils.get_project_root() + if not cache[root] then + cache[root] = {} + dirty[root] = {} + end + return cache[root] +end + +--- Read JSON from disk +---@param filepath string File path +---@return table|nil Data or nil on error +local function read_json(filepath) + local content = utils.read_file(filepath) + if not content or content == "" then + return nil + end + local ok, data = pcall(vim.json.decode, content) + if not ok then + return nil + end + return data +end + +--- Write JSON to disk +---@param filepath string File path +---@param data table Data to write +---@return boolean Success +local function write_json(filepath, data) + local ok, json = pcall(vim.json.encode, data) + if not ok then + return false + end + return utils.write_file(filepath, json) +end + +--- Load data from disk into cache +---@param key string Storage key +---@param root? string Project root +---@return table|nil Data or nil +function M.load(key, root) + root = root or utils.get_project_root() + local project_cache = get_cache(root) + + -- Return cached if available + if project_cache[key] ~= nil then + return project_cache[key] + end + + -- Load from disk + local filepath = M.get_path(key, root) + local data = read_json(filepath) + + -- Cache the result (even nil to avoid repeated reads) + project_cache[key] = data or {} + + return project_cache[key] +end + +--- Save data to cache and schedule disk write +---@param key string Storage key +---@param data table Data to save +---@param root? string Project root +---@param immediate? boolean Skip debounce +function M.save(key, data, root, immediate) + root = root or utils.get_project_root() + local project_cache = get_cache(root) + + -- Update cache + project_cache[key] = data + dirty[root][key] = true + + if immediate then + M.flush(key, root) + return + end + + -- Debounced write + local timer_key = root .. ":" .. 
key + if timers[timer_key] then + timers[timer_key]:stop() + end + + timers[timer_key] = vim.defer_fn(function() + M.flush(key, root) + timers[timer_key] = nil + end, DEBOUNCE_MS) +end + +--- Flush a key to disk immediately +---@param key string Storage key +---@param root? string Project root +---@return boolean Success +function M.flush(key, root) + root = root or utils.get_project_root() + local project_cache = get_cache(root) + + if not dirty[root][key] then + return true + end + + M.ensure_dirs(root) + local filepath = M.get_path(key, root) + local data = project_cache[key] + + if data == nil then + -- Delete file if data is nil + os.remove(filepath) + dirty[root][key] = nil + return true + end + + local success = write_json(filepath, data) + if success then + dirty[root][key] = nil + end + return success +end + +--- Flush all dirty keys to disk +---@param root? string Project root +function M.flush_all(root) + root = root or utils.get_project_root() + if not dirty[root] then + return + end + + for key, is_dirty in pairs(dirty[root]) do + if is_dirty then + M.flush(key, root) + end + end +end + +--- Get meta.json data +---@param root? string Project root +---@return GraphMeta +function M.get_meta(root) + local meta = M.load("meta", root) + if not meta or not meta.v then + meta = { + v = types.SCHEMA_VERSION, + head = nil, + nc = 0, + ec = 0, + dc = 0, + } + M.save("meta", meta, root) + end + return meta +end + +--- Update meta.json +---@param updates table Partial updates +---@param root? string Project root +function M.update_meta(updates, root) + local meta = M.get_meta(root) + for k, v in pairs(updates) do + meta[k] = v + end + M.save("meta", meta, root) +end + +--- Get nodes by type +---@param node_type string Node type (e.g., "patterns", "corrections") +---@param root? string Project root +---@return table Nodes indexed by ID +function M.get_nodes(node_type, root) + return M.load("nodes." .. 
node_type, root) or {}
end

--- Save nodes by type
---@param node_type string Node type
---@param nodes table Nodes indexed by ID
---@param root? string Project root
function M.save_nodes(node_type, nodes, root)
  M.save("nodes." .. node_type, nodes, root)
end

--- Get graph adjacency
--- Lazily creates (and persists) an empty adjacency structure on first use.
---@param root? string Project root
---@return Graph Graph data
function M.get_graph(root)
  local graph = M.load("graph", root)
  if not graph or not graph.adj then
    graph = { adj = {}, radj = {} }
    M.save("graph", graph, root)
  end
  return graph
end

--- Save graph
---@param graph Graph Graph data
---@param root? string Project root
function M.save_graph(graph, root)
  M.save("graph", graph, root)
end

--- Get index by type
---@param index_type string Index type (e.g., "by_file", "by_time")
---@param root? string Project root
---@return table Index data
function M.get_index(index_type, root)
  return M.load("indices." .. index_type, root) or {}
end

--- Save index
---@param index_type string Index type
---@param data table Index data
---@param root? string Project root
function M.save_index(index_type, data, root)
  M.save("indices." .. index_type, data, root)
end

--- Get delta by hash
---@param hash string Delta hash
---@param root? string Project root
---@return Delta|nil Delta data
function M.get_delta(hash, root)
  return M.load("deltas.objects." .. hash, root)
end

--- Save delta
--- Deltas are content-addressed commit objects, so they are flushed to
--- disk immediately rather than debounced.
---@param delta Delta Delta data
---@param root? string Project root
function M.save_delta(delta, root)
  M.save("deltas.objects." .. delta.h, delta, root, true)
end

--- Get HEAD delta hash
---@param root? string Project root
---@return string|nil HEAD hash
function M.get_head(root)
  local meta = M.get_meta(root)
  return meta.head
end

--- Set HEAD delta hash
---@param hash string|nil Delta hash
---@param root?
string Project root +function M.set_head(hash, root) + M.update_meta({ head = hash }, root) +end + +--- Clear all caches (for testing) +function M.clear_cache() + cache = {} + dirty = {} + for _, timer in pairs(timers) do + if timer then + timer:stop() + end + end + timers = {} +end + +--- Check if brain exists for project +---@param root? string Project root +---@return boolean +function M.exists(root) + local brain_dir = M.get_brain_dir(root) + return vim.fn.isdirectory(brain_dir) == 1 +end + +return M diff --git a/lua/codetyper/brain/types.lua b/lua/codetyper/brain/types.lua new file mode 100644 index 0000000..28f2162 --- /dev/null +++ b/lua/codetyper/brain/types.lua @@ -0,0 +1,175 @@ +---@meta +--- Brain Learning System Type Definitions +--- Optimized for LLM consumption with compact field names + +local M = {} + +---@alias NodeType "pat"|"cor"|"dec"|"con"|"fbk"|"ses" +-- pat = pattern, cor = correction, dec = decision +-- con = convention, fbk = feedback, ses = session + +---@alias EdgeType "sem"|"file"|"temp"|"caus"|"sup" +-- sem = semantic, file = file-based, temp = temporal +-- caus = causal, sup = supersedes + +---@alias DeltaOp "add"|"mod"|"del" + +---@class NodeContent +---@field s string Summary (max 200 chars) +---@field d string Detail (full description) +---@field code? string Optional code snippet +---@field lang? string Language identifier + +---@class NodeContext +---@field f? string File path (relative) +---@field fn? string Function name +---@field ln? number[] Line range [start, end] +---@field sym? string[] Symbol references + +---@class NodeScores +---@field w number Weight (0-1) +---@field u number Usage count +---@field sr number Success rate (0-1) + +---@class NodeTimestamps +---@field cr number Created (unix timestamp) +---@field up number Updated (unix timestamp) +---@field lu? number Last used (unix timestamp) + +---@class NodeMeta +---@field src "auto"|"user"|"llm" Source of learning +---@field v number Version number +---@field dr? 
string[] Delta references + +---@class Node +---@field id string Unique identifier (n__) +---@field t NodeType Node type +---@field h string Content hash (8 chars) +---@field c NodeContent Content +---@field ctx NodeContext Context +---@field sc NodeScores Scores +---@field ts NodeTimestamps Timestamps +---@field m? NodeMeta Metadata + +---@class EdgeProps +---@field w number Weight (0-1) +---@field dir "bi"|"fwd"|"bwd" Direction +---@field r? string Reason/description + +---@class Edge +---@field id string Unique identifier (e__) +---@field s string Source node ID +---@field t string Target node ID +---@field ty EdgeType Edge type +---@field p EdgeProps Properties +---@field ts number Created timestamp + +---@class DeltaChange +---@field op DeltaOp Operation type +---@field path string JSON path (e.g., "nodes.pat.n_123") +---@field bh? string Before hash +---@field ah? string After hash +---@field diff? table Field-level diff + +---@class DeltaMeta +---@field msg string Commit message +---@field trig string Trigger source +---@field sid? string Session ID + +---@class Delta +---@field h string Hash (8 chars) +---@field p? string Parent hash +---@field ts number Timestamp +---@field ch DeltaChange[] Changes +---@field m DeltaMeta Metadata + +---@class GraphMeta +---@field v number Schema version +---@field head? string Current HEAD delta hash +---@field nc number Node count +---@field ec number Edge count +---@field dc number Delta count + +---@class AdjacencyEntry +---@field sem? string[] Semantic edges +---@field file? string[] File edges +---@field temp? string[] Temporal edges +---@field caus? string[] Causal edges +---@field sup? string[] Supersedes edges + +---@class Graph +---@field meta GraphMeta Metadata +---@field adj table Adjacency list +---@field radj table Reverse adjacency + +---@class QueryOpts +---@field query? string Text query +---@field file? string File path filter +---@field types? NodeType[] Node types to include +---@field since? 
number Timestamp filter +---@field limit? number Max results +---@field depth? number Traversal depth +---@field max_tokens? number Token budget + +---@class QueryResult +---@field nodes Node[] Matched nodes +---@field edges Edge[] Related edges +---@field stats table Query statistics +---@field truncated boolean Whether results were truncated + +---@class LLMContext +---@field schema string Schema version +---@field query string Original query +---@field learnings table[] Compact learning entries +---@field connections table[] Connection summaries +---@field tokens number Estimated token count + +---@class LearnEvent +---@field type string Event type +---@field data table Event data +---@field file? string Related file +---@field timestamp number Event timestamp + +---@class BrainConfig +---@field enabled boolean Enable brain system +---@field auto_learn boolean Auto-learn from events +---@field auto_commit boolean Auto-commit after threshold +---@field commit_threshold number Changes before auto-commit +---@field max_nodes number Max nodes before pruning +---@field max_deltas number Max delta history +---@field prune table Pruning config +---@field output table Output config + +-- Type constants for runtime use +M.NODE_TYPES = { + PATTERN = "pat", + CORRECTION = "cor", + DECISION = "dec", + CONVENTION = "con", + FEEDBACK = "fbk", + SESSION = "ses", +} + +M.EDGE_TYPES = { + SEMANTIC = "sem", + FILE = "file", + TEMPORAL = "temp", + CAUSAL = "caus", + SUPERSEDES = "sup", +} + +M.DELTA_OPS = { + ADD = "add", + MODIFY = "mod", + DELETE = "del", +} + +M.SOURCES = { + AUTO = "auto", + USER = "user", + LLM = "llm", +} + +M.SCHEMA_VERSION = 1 + +return M diff --git a/lua/codetyper/cmp_source/init.lua b/lua/codetyper/cmp_source/init.lua new file mode 100644 index 0000000..5e4c89d --- /dev/null +++ b/lua/codetyper/cmp_source/init.lua @@ -0,0 +1,301 @@ +---@mod codetyper.cmp_source Completion source for nvim-cmp +---@brief [[ +--- Provides intelligent code completions using 
the brain, indexer, and LLM.
--- Integrates with nvim-cmp as a custom source.
---@brief ]]

local M = {}

local source = {}

--- Check if cmp is available
---@return boolean
local function has_cmp()
  return pcall(require, "cmp")
end

--- Get completion items from brain context
--- Queries the brain for pattern nodes and mines their summaries for
--- function/class identifiers that match the current prefix.
---@param prefix string Current word prefix
---@return table[] items
local function get_brain_completions(prefix)
  local items = {}

  local ok_brain, brain = pcall(require, "codetyper.brain")
  if not ok_brain then
    return items
  end

  -- Check if brain is initialized safely
  local is_init = false
  if brain.is_initialized then
    local ok, result = pcall(brain.is_initialized)
    is_init = ok and result
  end

  if not is_init then
    return items
  end

  -- Query brain for relevant patterns
  local ok_query, result = pcall(brain.query, {
    query = prefix,
    max_results = 10,
    types = { "pattern" },
  })

  if not (ok_query and result and result.nodes) then
    return items
  end

  local needle = prefix:lower()

  -- Scan one "section: a, b; ..." segment of a summary and emit matching
  -- identifiers as completion items of the given LSP kind.
  local function collect(summary, section, kind)
    for group in summary:gmatch(section .. ":%s*([^;]+)") do
      for ident in group:gmatch("([%w_]+)") do
        if ident:lower():find(needle, 1, true) then
          table.insert(items, {
            label = ident,
            kind = kind,
            detail = "[brain]",
            documentation = summary,
          })
        end
      end
    end
  end

  for _, node in ipairs(result.nodes) do
    if node.c and node.c.s then
      collect(node.c.s, "functions", 3) -- 3 = Function
      collect(node.c.s, "classes", 7) -- 7 = Class
    end
  end

  return items
end

--- Get completion items from indexer symbols
---@param prefix string Current word prefix
---@return table[] items
local function get_indexer_completions(prefix)
  local items = {}

  local ok_indexer, indexer = pcall(require, "codetyper.indexer")
  if not ok_indexer then
+ return items + end + + local ok_load, index = pcall(indexer.load_index) + if not ok_load or not index then + return items + end + + -- Search symbols + if index.symbols then + for symbol, files in pairs(index.symbols) do + if symbol:lower():find(prefix:lower(), 1, true) then + local files_str = type(files) == "table" and table.concat(files, ", ") or tostring(files) + table.insert(items, { + label = symbol, + kind = 6, -- Variable (generic) + detail = "[index] " .. files_str:sub(1, 30), + documentation = "Symbol found in: " .. files_str, + }) + end + end + end + + -- Search functions in files + if index.files then + for filepath, file_index in pairs(index.files) do + if file_index and file_index.functions then + for _, func in ipairs(file_index.functions) do + if func.name and func.name:lower():find(prefix:lower(), 1, true) then + table.insert(items, { + label = func.name, + kind = 3, -- Function + detail = "[index] " .. vim.fn.fnamemodify(filepath, ":t"), + documentation = func.docstring or ("Function at line " .. (func.line or "?")), + }) + end + end + end + if file_index and file_index.classes then + for _, class in ipairs(file_index.classes) do + if class.name and class.name:lower():find(prefix:lower(), 1, true) then + table.insert(items, { + label = class.name, + kind = 7, -- Class + detail = "[index] " .. vim.fn.fnamemodify(filepath, ":t"), + documentation = class.docstring or ("Class at line " .. 
(class.line or "?")), + }) + end + end + end + end + end + + return items +end + +--- Get completion items from current buffer (fallback) +---@param prefix string Current word prefix +---@param bufnr number Buffer number +---@return table[] items +local function get_buffer_completions(prefix, bufnr) + local items = {} + local seen = {} + + -- Get all lines in buffer + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local prefix_lower = prefix:lower() + + for _, line in ipairs(lines) do + -- Extract words that could be identifiers + for word in line:gmatch("[%a_][%w_]*") do + if #word >= 3 and word:lower():find(prefix_lower, 1, true) and not seen[word] and word ~= prefix then + seen[word] = true + table.insert(items, { + label = word, + kind = 1, -- Text + detail = "[buffer]", + }) + end + end + end + + return items +end + +--- Create new cmp source instance +function source.new() + return setmetatable({}, { __index = source }) +end + +--- Get source name +function source:get_keyword_pattern() + return [[\k\+]] +end + +--- Check if source is available +function source:is_available() + return true +end + +--- Get debug name +function source:get_debug_name() + return "codetyper" +end + +--- Get trigger characters +function source:get_trigger_characters() + return { ".", ":", "_" } +end + +--- Complete +---@param params table +---@param callback fun(response: table|nil) +function source:complete(params, callback) + local prefix = params.context.cursor_before_line:match("[%w_]+$") or "" + + if #prefix < 2 then + callback({ items = {}, isIncomplete = true }) + return + end + + -- Collect completions from brain, indexer, and buffer + local items = {} + local seen = {} + + -- Get brain completions (highest priority) + local ok1, brain_items = pcall(get_brain_completions, prefix) + if ok1 and brain_items then + for _, item in ipairs(brain_items) do + if not seen[item.label] then + seen[item.label] = true + item.sortText = "1" .. 
item.label + table.insert(items, item) + end + end + end + + -- Get indexer completions + local ok2, indexer_items = pcall(get_indexer_completions, prefix) + if ok2 and indexer_items then + for _, item in ipairs(indexer_items) do + if not seen[item.label] then + seen[item.label] = true + item.sortText = "2" .. item.label + table.insert(items, item) + end + end + end + + -- Get buffer completions as fallback (lower priority) + local bufnr = params.context.bufnr + if bufnr then + local ok3, buffer_items = pcall(get_buffer_completions, prefix, bufnr) + if ok3 and buffer_items then + for _, item in ipairs(buffer_items) do + if not seen[item.label] then + seen[item.label] = true + item.sortText = "3" .. item.label + table.insert(items, item) + end + end + end + end + + callback({ + items = items, + isIncomplete = #items >= 50, + }) +end + +--- Setup the completion source +function M.setup() + if not has_cmp() then + return false + end + + local cmp = require("cmp") + local new_source = source.new() + + -- Register the source + cmp.register_source("codetyper", new_source) + + return true +end + +--- Check if source is registered +---@return boolean +function M.is_registered() + local ok, cmp = pcall(require, "cmp") + if not ok then + return false + end + + -- Try to get registered sources + local config = cmp.get_config() + if config and config.sources then + for _, src in ipairs(config.sources) do + if src.name == "codetyper" then + return true + end + end + end + + return false +end + +--- Get source for manual registration +function M.get_source() + return source +end + +return M diff --git a/lua/codetyper/commands.lua b/lua/codetyper/commands.lua index 1471489..7be1cb7 100644 --- a/lua/codetyper/commands.lua +++ b/lua/codetyper/commands.lua @@ -164,13 +164,19 @@ local function cmd_status() "Provider: " .. 
config.llm.provider, } - if config.llm.provider == "claude" then - local has_key = (config.llm.claude.api_key or vim.env.ANTHROPIC_API_KEY) ~= nil - table.insert(status, "Claude API Key: " .. (has_key and "configured" or "NOT SET")) - table.insert(status, "Claude Model: " .. config.llm.claude.model) - else + if config.llm.provider == "ollama" then table.insert(status, "Ollama Host: " .. config.llm.ollama.host) table.insert(status, "Ollama Model: " .. config.llm.ollama.model) + elseif config.llm.provider == "openai" then + local has_key = (config.llm.openai.api_key or vim.env.OPENAI_API_KEY) ~= nil + table.insert(status, "OpenAI API Key: " .. (has_key and "configured" or "NOT SET")) + table.insert(status, "OpenAI Model: " .. config.llm.openai.model) + elseif config.llm.provider == "gemini" then + local has_key = (config.llm.gemini.api_key or vim.env.GEMINI_API_KEY) ~= nil + table.insert(status, "Gemini API Key: " .. (has_key and "configured" or "NOT SET")) + table.insert(status, "Gemini Model: " .. config.llm.gemini.model) + elseif config.llm.provider == "copilot" then + table.insert(status, "Copilot Model: " .. 
config.llm.copilot.model) end table.insert(status, "") @@ -618,6 +624,131 @@ local function cmd_transform_visual() cmd_transform_range(start_line, end_line) end +--- Index the entire project +local function cmd_index_project() + local indexer = require("codetyper.indexer") + + utils.notify("Indexing project...", vim.log.levels.INFO) + + indexer.index_project(function(index) + if index then + local msg = string.format( + "Indexed: %d files, %d functions, %d classes, %d exports", + index.stats.files, + index.stats.functions, + index.stats.classes, + index.stats.exports + ) + utils.notify(msg, vim.log.levels.INFO) + else + utils.notify("Failed to index project", vim.log.levels.ERROR) + end + end) +end + +--- Show index status +local function cmd_index_status() + local indexer = require("codetyper.indexer") + local memory = require("codetyper.indexer.memory") + + local status = indexer.get_status() + local mem_stats = memory.get_stats() + + local lines = { + "Project Index Status", + "====================", + "", + } + + if status.indexed then + table.insert(lines, "Status: Indexed") + table.insert(lines, "Project Type: " .. (status.project_type or "unknown")) + table.insert(lines, "Last Indexed: " .. os.date("%Y-%m-%d %H:%M:%S", status.last_indexed)) + table.insert(lines, "") + table.insert(lines, "Stats:") + table.insert(lines, " Files: " .. (status.stats.files or 0)) + table.insert(lines, " Functions: " .. (status.stats.functions or 0)) + table.insert(lines, " Classes: " .. (status.stats.classes or 0)) + table.insert(lines, " Exports: " .. (status.stats.exports or 0)) + else + table.insert(lines, "Status: Not indexed") + table.insert(lines, "Run :CoderIndexProject to index") + end + + table.insert(lines, "") + table.insert(lines, "Memories:") + table.insert(lines, " Patterns: " .. mem_stats.patterns) + table.insert(lines, " Conventions: " .. mem_stats.conventions) + table.insert(lines, " Symbols: " .. 
mem_stats.symbols) + + utils.notify(table.concat(lines, "\n")) +end + +--- Show learned memories +local function cmd_memories() + local memory = require("codetyper.indexer.memory") + + local all = memory.get_all() + local lines = { + "Learned Memories", + "================", + "", + "Patterns:", + } + + local pattern_count = 0 + for _, mem in pairs(all.patterns) do + pattern_count = pattern_count + 1 + if pattern_count <= 10 then + table.insert(lines, " - " .. (mem.content or ""):sub(1, 60)) + end + end + if pattern_count > 10 then + table.insert(lines, " ... and " .. (pattern_count - 10) .. " more") + elseif pattern_count == 0 then + table.insert(lines, " (none)") + end + + table.insert(lines, "") + table.insert(lines, "Conventions:") + + local conv_count = 0 + for _, mem in pairs(all.conventions) do + conv_count = conv_count + 1 + if conv_count <= 10 then + table.insert(lines, " - " .. (mem.content or ""):sub(1, 60)) + end + end + if conv_count > 10 then + table.insert(lines, " ... and " .. (conv_count - 10) .. " more") + elseif conv_count == 0 then + table.insert(lines, " (none)") + end + + utils.notify(table.concat(lines, "\n")) +end + +--- Clear memories +---@param pattern string|nil Optional pattern to match +local function cmd_forget(pattern) + local memory = require("codetyper.indexer.memory") + + if not pattern or pattern == "" then + -- Confirm before clearing all + vim.ui.select({ "Yes", "No" }, { + prompt = "Clear all memories?", + }, function(choice) + if choice == "Yes" then + memory.clear() + utils.notify("All memories cleared", vim.log.levels.INFO) + end + end) + else + memory.clear(pattern) + utils.notify("Cleared memories matching: " .. 
pattern, vim.log.levels.INFO) + end +end + --- Transform a single prompt at cursor position local function cmd_transform_at_cursor() local parser = require("codetyper.parser") @@ -741,6 +872,12 @@ local function coder_cmd(args) ["logs-toggle"] = cmd_logs_toggle, ["queue-status"] = cmd_queue_status, ["queue-process"] = cmd_queue_process, + ["index-project"] = cmd_index_project, + ["index-status"] = cmd_index_status, + memories = cmd_memories, + forget = function(args) + cmd_forget(args.fargs[2]) + end, ["auto-toggle"] = function() local preferences = require("codetyper.preferences") preferences.toggle_auto_process() @@ -787,6 +924,7 @@ function M.setup() "agent", "agent-close", "agent-toggle", "agent-stop", "type-toggle", "logs-toggle", "queue-status", "queue-process", + "index-project", "index-status", "memories", "forget", "auto-toggle", "auto-set", } end, @@ -875,6 +1013,26 @@ function M.setup() autocmds.open_coder_companion() end, { desc = "Open coder companion for current file" }) + -- Project indexer commands + vim.api.nvim_create_user_command("CoderIndexProject", function() + cmd_index_project() + end, { desc = "Index the entire project" }) + + vim.api.nvim_create_user_command("CoderIndexStatus", function() + cmd_index_status() + end, { desc = "Show project index status" }) + + vim.api.nvim_create_user_command("CoderMemories", function() + cmd_memories() + end, { desc = "Show learned memories" }) + + vim.api.nvim_create_user_command("CoderForget", function(opts) + cmd_forget(opts.args ~= "" and opts.args or nil) + end, { + desc = "Clear memories (optionally matching pattern)", + nargs = "?", + }) + -- Queue commands vim.api.nvim_create_user_command("CoderQueueStatus", function() cmd_queue_status() diff --git a/lua/codetyper/config.lua b/lua/codetyper/config.lua index b443cbc..ec8ef8c 100644 --- a/lua/codetyper/config.lua +++ b/lua/codetyper/config.lua @@ -5,11 +5,7 @@ local M = {} ---@type CoderConfig local defaults = { llm = { - provider = "ollama", -- 
Options: "claude", "ollama", "openai", "gemini", "copilot" - claude = { - api_key = nil, -- Will use ANTHROPIC_API_KEY env var if nil - model = "claude-sonnet-4-20250514", - }, + provider = "ollama", -- Options: "ollama", "openai", "gemini", "copilot" ollama = { host = "http://localhost:11434", model = "deepseek-coder:6.7b", @@ -48,6 +44,48 @@ local defaults = { completion_delay_ms = 100, -- Wait after completion popup closes apply_delay_ms = 5000, -- Wait before removing tags and applying code (ms) }, + indexer = { + enabled = true, -- Enable project indexing + auto_index = true, -- Index files on save + index_on_open = false, -- Index project when opening + max_file_size = 100000, -- Skip files larger than 100KB + excluded_dirs = { "node_modules", "dist", "build", ".git", ".coder", "__pycache__", "vendor", "target" }, + index_extensions = { "lua", "ts", "tsx", "js", "jsx", "py", "go", "rs", "rb", "java", "c", "cpp", "h", "hpp" }, + memory = { + enabled = true, -- Enable memory persistence + max_memories = 1000, -- Maximum stored memories + prune_threshold = 0.1, -- Remove low-weight memories + }, + }, + brain = { + enabled = true, -- Enable brain learning system + auto_learn = true, -- Auto-learn from events + auto_commit = true, -- Auto-commit after threshold + commit_threshold = 10, -- Changes before auto-commit + max_nodes = 5000, -- Maximum nodes before pruning + max_deltas = 500, -- Maximum delta history + prune = { + enabled = true, -- Enable auto-pruning + threshold = 0.1, -- Remove nodes below this weight + unused_days = 90, -- Remove unused nodes after N days + }, + output = { + max_tokens = 4000, -- Token budget for LLM context + format = "compact", -- "compact"|"json"|"natural" + }, + }, + suggestion = { + enabled = true, -- Enable ghost text suggestions (Copilot-style) + auto_trigger = true, -- Auto-trigger on typing + debounce = 150, -- Debounce in milliseconds + use_copilot = true, -- Use copilot.lua suggestions when available, fallback to codetyper 
+ keymap = { + accept = "", -- Accept suggestion + next = "", -- Next suggestion (Alt+]) + prev = "", -- Previous suggestion (Alt+[) + dismiss = "", -- Dismiss suggestion (Ctrl+]) + }, + }, } --- Deep merge two tables @@ -88,7 +126,7 @@ function M.validate(config) return false, "Missing LLM configuration" end - local valid_providers = { "claude", "ollama", "openai", "gemini", "copilot" } + local valid_providers = { "ollama", "openai", "gemini", "copilot" } local is_valid_provider = false for _, p in ipairs(valid_providers) do if config.llm.provider == p then @@ -102,12 +140,7 @@ function M.validate(config) end -- Validate provider-specific configuration - if config.llm.provider == "claude" then - local api_key = config.llm.claude.api_key or vim.env.ANTHROPIC_API_KEY - if not api_key or api_key == "" then - return false, "Claude API key not configured. Set llm.claude.api_key or ANTHROPIC_API_KEY env var" - end - elseif config.llm.provider == "openai" then + if config.llm.provider == "openai" then local api_key = config.llm.openai.api_key or vim.env.OPENAI_API_KEY if not api_key or api_key == "" then return false, "OpenAI API key not configured. Set llm.openai.api_key or OPENAI_API_KEY env var" diff --git a/lua/codetyper/health.lua b/lua/codetyper/health.lua index 213cb9f..45dafd3 100644 --- a/lua/codetyper/health.lua +++ b/lua/codetyper/health.lua @@ -36,15 +36,7 @@ function M.check() health.info("LLM Provider: " .. config.llm.provider) - if config.llm.provider == "claude" then - local api_key = config.llm.claude.api_key or vim.env.ANTHROPIC_API_KEY - if api_key and api_key ~= "" then - health.ok("Claude API key configured") - else - health.warn("Claude API key not set. Set ANTHROPIC_API_KEY or llm.claude.api_key") - end - health.info("Claude model: " .. config.llm.claude.model) - elseif config.llm.provider == "ollama" then + if config.llm.provider == "ollama" then health.info("Ollama host: " .. config.llm.ollama.host) health.info("Ollama model: " .. 
config.llm.ollama.model) diff --git a/lua/codetyper/indexer/analyzer.lua b/lua/codetyper/indexer/analyzer.lua new file mode 100644 index 0000000..78aad8a --- /dev/null +++ b/lua/codetyper/indexer/analyzer.lua @@ -0,0 +1,582 @@ +---@mod codetyper.indexer.analyzer Code analyzer using Tree-sitter +---@brief [[ +--- Analyzes source files to extract functions, classes, exports, and imports. +--- Uses Tree-sitter when available, falls back to pattern matching. +---@brief ]] + +local M = {} + +local utils = require("codetyper.utils") +local scanner = require("codetyper.indexer.scanner") + +--- Language-specific query patterns for Tree-sitter +local TS_QUERIES = { + lua = { + functions = [[ + (function_declaration name: (identifier) @name) @func + (function_definition) @func + (local_function name: (identifier) @name) @func + (assignment_statement + (variable_list name: (identifier) @name) + (expression_list value: (function_definition) @func)) + ]], + exports = [[ + (return_statement (expression_list (table_constructor))) @export + ]], + }, + typescript = { + functions = [[ + (function_declaration name: (identifier) @name) @func + (method_definition name: (property_identifier) @name) @func + (arrow_function) @func + (lexical_declaration + (variable_declarator name: (identifier) @name value: (arrow_function) @func)) + ]], + exports = [[ + (export_statement) @export + ]], + imports = [[ + (import_statement) @import + ]], + }, + javascript = { + functions = [[ + (function_declaration name: (identifier) @name) @func + (method_definition name: (property_identifier) @name) @func + (arrow_function) @func + ]], + exports = [[ + (export_statement) @export + ]], + imports = [[ + (import_statement) @import + ]], + }, + python = { + functions = [[ + (function_definition name: (identifier) @name) @func + ]], + classes = [[ + (class_definition name: (identifier) @name) @class + ]], + imports = [[ + (import_statement) @import + (import_from_statement) @import + ]], + }, + go = { + 
functions = [[ + (function_declaration name: (identifier) @name) @func + (method_declaration name: (field_identifier) @name) @func + ]], + imports = [[ + (import_declaration) @import + ]], + }, + rust = { + functions = [[ + (function_item name: (identifier) @name) @func + ]], + imports = [[ + (use_declaration) @import + ]], + }, +} + +--- Hash file content for change detection +---@param content string +---@return string +local function hash_content(content) + local hash = 0 + for i = 1, math.min(#content, 10000) do + hash = (hash * 31 + string.byte(content, i)) % 2147483647 + end + return string.format("%08x", hash) +end + +--- Try to get Tree-sitter parser for a language +---@param lang string +---@return boolean +local function has_ts_parser(lang) + local ok = pcall(vim.treesitter.language.inspect, lang) + return ok +end + +--- Analyze file using Tree-sitter +---@param filepath string +---@param lang string +---@param content string +---@return table|nil +local function analyze_with_treesitter(filepath, lang, content) + if not has_ts_parser(lang) then + return nil + end + + local result = { + functions = {}, + classes = {}, + exports = {}, + imports = {}, + } + + -- Create a temporary buffer for parsing + local bufnr = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, vim.split(content, "\n")) + + local ok, parser = pcall(vim.treesitter.get_parser, bufnr, lang) + if not ok or not parser then + vim.api.nvim_buf_delete(bufnr, { force = true }) + return nil + end + + local tree = parser:parse()[1] + if not tree then + vim.api.nvim_buf_delete(bufnr, { force = true }) + return nil + end + + local root = tree:root() + local queries = TS_QUERIES[lang] + + if not queries then + -- Fallback: walk tree manually for common patterns + result = analyze_tree_generic(root, bufnr) + else + -- Use language-specific queries + if queries.functions then + local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.functions) + if 
query_ok then + for id, node in query:iter_captures(root, bufnr, 0, -1) do + local capture_name = query.captures[id] + if capture_name == "func" or capture_name == "name" then + local start_row, _, end_row, _ = node:range() + local name = nil + + -- Try to get name from sibling capture or child + if capture_name == "func" then + local name_node = node:field("name")[1] + if name_node then + name = vim.treesitter.get_node_text(name_node, bufnr) + end + else + name = vim.treesitter.get_node_text(node, bufnr) + end + + if name and not vim.tbl_contains(vim.tbl_map(function(f) + return f.name + end, result.functions), name) then + table.insert(result.functions, { + name = name, + line = start_row + 1, + end_line = end_row + 1, + params = {}, + }) + end + end + end + end + end + + if queries.classes then + local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.classes) + if query_ok then + for id, node in query:iter_captures(root, bufnr, 0, -1) do + local capture_name = query.captures[id] + if capture_name == "class" then + local start_row, _, end_row, _ = node:range() + local name_node = node:field("name")[1] + local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous" + + table.insert(result.classes, { + name = name, + line = start_row + 1, + end_line = end_row + 1, + methods = {}, + }) + end + end + end + end + + if queries.exports then + local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.exports) + if query_ok then + for _, node in query:iter_captures(root, bufnr, 0, -1) do + local text = vim.treesitter.get_node_text(node, bufnr) + local start_row, _, _, _ = node:range() + + -- Extract export names (simplified) + local names = {} + for name in text:gmatch("export%s+[%w_]+%s+([%w_]+)") do + table.insert(names, name) + end + for name in text:gmatch("export%s*{([^}]+)}") do + for n in name:gmatch("([%w_]+)") do + table.insert(names, n) + end + end + + for _, name in ipairs(names) do + 
table.insert(result.exports, { + name = name, + type = "unknown", + line = start_row + 1, + }) + end + end + end + end + + if queries.imports then + local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.imports) + if query_ok then + for _, node in query:iter_captures(root, bufnr, 0, -1) do + local text = vim.treesitter.get_node_text(node, bufnr) + local start_row, _, _, _ = node:range() + + -- Extract import source + local source = text:match('["\']([^"\']+)["\']') + if source then + table.insert(result.imports, { + source = source, + names = {}, + line = start_row + 1, + }) + end + end + end + end + end + + vim.api.nvim_buf_delete(bufnr, { force = true }) + return result +end + +--- Generic tree analysis for unsupported languages +---@param root TSNode +---@param bufnr number +---@return table +local function analyze_tree_generic(root, bufnr) + local result = { + functions = {}, + classes = {}, + exports = {}, + imports = {}, + } + + local function visit(node) + local node_type = node:type() + + -- Common function patterns + if + node_type:match("function") + or node_type:match("method") + or node_type == "arrow_function" + or node_type == "func_literal" + then + local start_row, _, end_row, _ = node:range() + local name_node = node:field("name")[1] + local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous" + + table.insert(result.functions, { + name = name, + line = start_row + 1, + end_line = end_row + 1, + params = {}, + }) + end + + -- Common class patterns + if node_type:match("class") or node_type == "struct_item" or node_type == "impl_item" then + local start_row, _, end_row, _ = node:range() + local name_node = node:field("name")[1] + local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous" + + table.insert(result.classes, { + name = name, + line = start_row + 1, + end_line = end_row + 1, + methods = {}, + }) + end + + -- Recurse into children + for child in 
node:iter_children() do + visit(child) + end + end + + visit(root) + return result +end + +--- Analyze file using pattern matching (fallback) +---@param content string +---@param lang string +---@return table +local function analyze_with_patterns(content, lang) + local result = { + functions = {}, + classes = {}, + exports = {}, + imports = {}, + } + + local lines = vim.split(content, "\n") + + -- Language-specific patterns + local patterns = { + lua = { + func_start = "^%s*local?%s*function%s+([%w_%.]+)", + func_assign = "^%s*([%w_%.]+)%s*=%s*function", + module_return = "^return%s+M", + }, + javascript = { + func_start = "^%s*function%s+([%w_]+)", + func_arrow = "^%s*const%s+([%w_]+)%s*=%s*", + class_start = "^%s*class%s+([%w_]+)", + export_line = "^%s*export%s+", + import_line = "^%s*import%s+", + }, + typescript = { + func_start = "^%s*function%s+([%w_]+)", + func_arrow = "^%s*const%s+([%w_]+)%s*=%s*", + class_start = "^%s*class%s+([%w_]+)", + export_line = "^%s*export%s+", + import_line = "^%s*import%s+", + }, + python = { + func_start = "^%s*def%s+([%w_]+)", + class_start = "^%s*class%s+([%w_]+)", + import_line = "^%s*import%s+", + from_import = "^%s*from%s+", + }, + go = { + func_start = "^func%s+([%w_]+)", + method_start = "^func%s+%([^%)]+%)%s+([%w_]+)", + import_line = "^import%s+", + }, + rust = { + func_start = "^%s*pub?%s*fn%s+([%w_]+)", + struct_start = "^%s*pub?%s*struct%s+([%w_]+)", + impl_start = "^%s*impl%s+([%w_<>]+)", + use_line = "^%s*use%s+", + }, + } + + local lang_patterns = patterns[lang] or patterns.javascript + + for i, line in ipairs(lines) do + -- Functions + if lang_patterns.func_start then + local name = line:match(lang_patterns.func_start) + if name then + table.insert(result.functions, { + name = name, + line = i, + end_line = i, + params = {}, + }) + end + end + + if lang_patterns.func_arrow then + local name = line:match(lang_patterns.func_arrow) + if name and line:match("=>") then + table.insert(result.functions, { + name = name, 
+ line = i, + end_line = i, + params = {}, + }) + end + end + + if lang_patterns.func_assign then + local name = line:match(lang_patterns.func_assign) + if name then + table.insert(result.functions, { + name = name, + line = i, + end_line = i, + params = {}, + }) + end + end + + if lang_patterns.method_start then + local name = line:match(lang_patterns.method_start) + if name then + table.insert(result.functions, { + name = name, + line = i, + end_line = i, + params = {}, + }) + end + end + + -- Classes + if lang_patterns.class_start then + local name = line:match(lang_patterns.class_start) + if name then + table.insert(result.classes, { + name = name, + line = i, + end_line = i, + methods = {}, + }) + end + end + + if lang_patterns.struct_start then + local name = line:match(lang_patterns.struct_start) + if name then + table.insert(result.classes, { + name = name, + line = i, + end_line = i, + methods = {}, + }) + end + end + + -- Exports + if lang_patterns.export_line and line:match(lang_patterns.export_line) then + local name = line:match("export%s+[%w_]+%s+([%w_]+)") + or line:match("export%s+default%s+([%w_]+)") + or line:match("export%s+{%s*([%w_]+)") + if name then + table.insert(result.exports, { + name = name, + type = "unknown", + line = i, + }) + end + end + + -- Imports + if lang_patterns.import_line and line:match(lang_patterns.import_line) then + local source = line:match('["\']([^"\']+)["\']') + if source then + table.insert(result.imports, { + source = source, + names = {}, + line = i, + }) + end + end + + if lang_patterns.from_import and line:match(lang_patterns.from_import) then + local source = line:match("from%s+([%w_%.]+)") + if source then + table.insert(result.imports, { + source = source, + names = {}, + line = i, + }) + end + end + + if lang_patterns.use_line and line:match(lang_patterns.use_line) then + local source = line:match("use%s+([%w_:]+)") + if source then + table.insert(result.imports, { + source = source, + names = {}, + line = 
i, + }) + end + end + end + + -- For Lua, infer exports from module table + if lang == "lua" then + for _, func in ipairs(result.functions) do + if func.name:match("^M%.") then + local name = func.name:gsub("^M%.", "") + table.insert(result.exports, { + name = name, + type = "function", + line = func.line, + }) + end + end + end + + return result +end + +--- Analyze a single file +---@param filepath string Full path to file +---@return FileIndex|nil +function M.analyze_file(filepath) + local content = utils.read_file(filepath) + if not content then + return nil + end + + local lang = scanner.get_language(filepath) + + -- Map to Tree-sitter language names + local ts_lang_map = { + typescript = "typescript", + typescriptreact = "tsx", + javascript = "javascript", + javascriptreact = "javascript", + python = "python", + go = "go", + rust = "rust", + lua = "lua", + } + + local ts_lang = ts_lang_map[lang] or lang + + -- Try Tree-sitter first + local analysis = analyze_with_treesitter(filepath, ts_lang, content) + + -- Fallback to pattern matching + if not analysis then + analysis = analyze_with_patterns(content, lang) + end + + return { + path = filepath, + language = lang, + hash = hash_content(content), + exports = analysis.exports, + imports = analysis.imports, + functions = analysis.functions, + classes = analysis.classes, + last_indexed = os.time(), + } +end + +--- Extract exports from a buffer +---@param bufnr number +---@return Export[] +function M.extract_exports(bufnr) + local filepath = vim.api.nvim_buf_get_name(bufnr) + local analysis = M.analyze_file(filepath) + return analysis and analysis.exports or {} +end + +--- Extract functions from a buffer +---@param bufnr number +---@return FunctionInfo[] +function M.extract_functions(bufnr) + local filepath = vim.api.nvim_buf_get_name(bufnr) + local analysis = M.analyze_file(filepath) + return analysis and analysis.functions or {} +end + +--- Extract imports from a buffer +---@param bufnr number +---@return 
Import[] +function M.extract_imports(bufnr) + local filepath = vim.api.nvim_buf_get_name(bufnr) + local analysis = M.analyze_file(filepath) + return analysis and analysis.imports or {} +end + +return M diff --git a/lua/codetyper/indexer/init.lua b/lua/codetyper/indexer/init.lua new file mode 100644 index 0000000..69a065e --- /dev/null +++ b/lua/codetyper/indexer/init.lua @@ -0,0 +1,604 @@ +---@mod codetyper.indexer Project indexer for Codetyper.nvim +---@brief [[ +--- Indexes project structure, dependencies, and code symbols. +--- Stores knowledge in .coder/ directory for enriching LLM context. +---@brief ]] + +local M = {} + +local utils = require("codetyper.utils") + +--- Index schema version for migrations +local INDEX_VERSION = 1 + +--- Index file name +local INDEX_FILE = "index.json" + +--- Debounce timer for file indexing +local index_timer = nil +local INDEX_DEBOUNCE_MS = 500 + +--- Default indexer configuration +local default_config = { + enabled = true, + auto_index = true, + index_on_open = false, + max_file_size = 100000, + excluded_dirs = { "node_modules", "dist", "build", ".git", ".coder", "__pycache__", "vendor", "target" }, + index_extensions = { "lua", "ts", "tsx", "js", "jsx", "py", "go", "rs", "rb", "java", "c", "cpp", "h", "hpp" }, + memory = { + enabled = true, + max_memories = 1000, + prune_threshold = 0.1, + }, +} + +--- Current configuration +---@type table +local config = vim.deepcopy(default_config) + +--- Cached project index +---@type table +local index_cache = {} + +---@class ProjectIndex +---@field version number Index schema version +---@field project_root string Absolute path to project +---@field project_name string Project name +---@field project_type string "node"|"rust"|"go"|"python"|"lua"|"unknown" +---@field dependencies table name -> version +---@field dev_dependencies table name -> version +---@field files table path -> FileIndex +---@field symbols table symbol -> [file paths] +---@field last_indexed number Timestamp 
+---@field stats {files: number, functions: number, classes: number, exports: number} + +---@class FileIndex +---@field path string Relative path from project root +---@field language string Detected language +---@field hash string Content hash for change detection +---@field exports Export[] Exported symbols +---@field imports Import[] Dependencies +---@field functions FunctionInfo[] +---@field classes ClassInfo[] +---@field last_indexed number Timestamp + +---@class Export +---@field name string Symbol name +---@field type string "function"|"class"|"constant"|"type"|"variable" +---@field line number Line number + +---@class Import +---@field source string Import source/module +---@field names string[] Imported names +---@field line number Line number + +---@class FunctionInfo +---@field name string Function name +---@field params string[] Parameter names +---@field line number Start line +---@field end_line number End line +---@field docstring string|nil Documentation + +---@class ClassInfo +---@field name string Class name +---@field methods string[] Method names +---@field line number Start line +---@field end_line number End line +---@field docstring string|nil Documentation + +--- Get the index file path +---@return string|nil +local function get_index_path() + local root = utils.get_project_root() + if not root then + return nil + end + return root .. "/.coder/" .. 
INDEX_FILE +end + +--- Create empty index structure +---@return ProjectIndex +local function create_empty_index() + local root = utils.get_project_root() + return { + version = INDEX_VERSION, + project_root = root or "", + project_name = root and vim.fn.fnamemodify(root, ":t") or "", + project_type = "unknown", + dependencies = {}, + dev_dependencies = {}, + files = {}, + symbols = {}, + last_indexed = os.time(), + stats = { + files = 0, + functions = 0, + classes = 0, + exports = 0, + }, + } +end + +--- Load index from disk +---@return ProjectIndex|nil +function M.load_index() + local root = utils.get_project_root() + if not root then + return nil + end + + -- Check cache first + if index_cache[root] then + return index_cache[root] + end + + local path = get_index_path() + if not path then + return nil + end + + local content = utils.read_file(path) + if not content then + return nil + end + + local ok, index = pcall(vim.json.decode, content) + if not ok or not index then + return nil + end + + -- Validate version + if index.version ~= INDEX_VERSION then + -- Index needs migration or rebuild + return nil + end + + -- Cache it + index_cache[root] = index + return index +end + +--- Save index to disk +---@param index ProjectIndex +---@return boolean +function M.save_index(index) + local root = utils.get_project_root() + if not root then + return false + end + + -- Ensure .coder directory exists + local coder_dir = root .. "/.coder" + utils.ensure_dir(coder_dir) + + local path = get_index_path() + if not path then + return false + end + + local ok, encoded = pcall(vim.json.encode, index) + if not ok then + return false + end + + local success = utils.write_file(path, encoded) + if success then + -- Update cache + index_cache[root] = index + end + return success +end + +--- Index the entire project +---@param callback? 
--- Build a full project index: detect project type, parse dependencies,
--- analyze every indexable file, persist the result, and sync summaries
--- to the memory/brain subsystems.
---@param callback? fun(index: ProjectIndex) Invoked with the finished index
---@return ProjectIndex|nil
function M.index_project(callback)
  local scanner = require("codetyper.indexer.scanner")
  local analyzer = require("codetyper.indexer.analyzer")

  local index = create_empty_index()
  local root = utils.get_project_root()

  if not root then
    -- Outside a project: hand back the empty index unchanged.
    if callback then
      callback(index)
    end
    return index
  end

  -- Detect project type and parse its dependency manifests.
  index.project_type = scanner.detect_project_type(root)
  local deps = scanner.parse_dependencies(root, index.project_type)
  index.dependencies = deps.dependencies or {}
  index.dev_dependencies = deps.dev_dependencies or {}

  local files = scanner.get_indexable_files(root, config)

  local total_functions = 0
  local total_classes = 0
  local total_exports = 0

  for _, filepath in ipairs(files) do
    local relative_path = filepath:gsub("^" .. vim.pesc(root) .. "/", "")
    local file_index = analyzer.analyze_file(filepath)

    if file_index then
      file_index.path = relative_path
      index.files[relative_path] = file_index

      -- Map each exported symbol name to the files that define it.
      for _, exp in ipairs(file_index.exports or {}) do
        if not index.symbols[exp.name] then
          index.symbols[exp.name] = {}
        end
        table.insert(index.symbols[exp.name], relative_path)
        total_exports = total_exports + 1
      end

      total_functions = total_functions + #(file_index.functions or {})
      total_classes = total_classes + #(file_index.classes or {})
    end
  end

  -- NOTE: stats.files counts scanned files, including any that failed
  -- analysis (matches prior behavior).
  index.stats = {
    files = #files,
    functions = total_functions,
    classes = total_classes,
    exports = total_exports,
  }
  index.last_indexed = os.time()

  M.save_index(index)

  -- Persist high-level summaries for later retrieval.
  local memory = require("codetyper.indexer.memory")
  memory.store_index_summary(index)

  M.sync_project_to_brain(index, files, root)

  if callback then
    callback(index)
  end

  return index
end

--- Sync a project-level summary (plus the highest-value files) to the brain.
--- No-op when the brain subsystem is absent or not initialized.
---@param index ProjectIndex
---@param files string[] List of file paths
---@param root string Project root
function M.sync_project_to_brain(index, files, root)
  local ok_brain, brain = pcall(require, "codetyper.brain")
  if not ok_brain or not brain.is_initialized or not brain.is_initialized() then
    return
  end

  -- project_name may be unset on a fresh index; fall back so the
  -- summary string cannot raise on nil concatenation.
  brain.learn({
    type = "pattern",
    file = root,
    content = {
      summary = string.format(
        "Project: %s (%s) - %d files",
        index.project_name or "unknown",
        index.project_type or "unknown",
        index.stats.files
      ),
      detail = string.format(
        "%d functions, %d classes, %d exports",
        index.stats.functions,
        index.stats.classes,
        index.stats.exports
      ),
    },
    context = {
      file = root,
      project_type = index.project_type,
      dependencies = index.dependencies,
    },
  })

  -- Rank files by structural content; classes weigh double.
  local key_files = {}
  for path, file_index in pairs(index.files) do
    local score = #(file_index.functions or {}) + (#(file_index.classes or {}) * 2)
    if score >= 3 then
      table.insert(key_files, { path = path, index = file_index, score = score })
    end
  end

  table.sort(key_files, function(a, b)
    return a.score > b.score
  end)

  -- Only the top 20 key files are synced.
  for i, kf in ipairs(key_files) do
    if i > 20 then
      break
    end
    M.sync_to_brain(root .. "/" .. kf.path, kf.index)
  end
end

--- Index a single file (incremental update of the persisted project index).
---@param filepath string
---@return FileIndex|nil
function M.index_file(filepath)
  local analyzer = require("codetyper.indexer.analyzer")
  local memory = require("codetyper.indexer.memory")
  local root = utils.get_project_root()

  if not root then
    return nil
  end

  local index = M.load_index() or create_empty_index()

  local file_index = analyzer.analyze_file(filepath)
  if not file_index then
    return nil
  end

  local relative_path = filepath:gsub("^" .. vim.pesc(root) .. "/", "")
  file_index.path = relative_path

  -- Drop stale symbol references that pointed at this file.
  for symbol, paths in pairs(index.symbols) do
    for i = #paths, 1, -1 do
      if paths[i] == relative_path then
        table.remove(paths, i)
      end
    end
    if #paths == 0 then
      index.symbols[symbol] = nil
    end
  end

  index.files[relative_path] = file_index

  -- Register the file's fresh exports.
  for _, exp in ipairs(file_index.exports or {}) do
    if not index.symbols[exp.name] then
      index.symbols[exp.name] = {}
    end
    table.insert(index.symbols[exp.name], relative_path)
  end

  -- Recalculate aggregate stats from scratch.
  local total_functions = 0
  local total_classes = 0
  local total_exports = 0
  local file_count = 0

  for _, f in pairs(index.files) do
    file_count = file_count + 1
    total_functions = total_functions + #(f.functions or {})
    total_classes = total_classes + #(f.classes or {})
    total_exports = total_exports + #(f.exports or {})
  end

  index.stats = {
    files = file_count,
    functions = total_functions,
    classes = total_classes,
    exports = total_exports,
  }
  index.last_indexed = os.time()

  M.save_index(index)

  memory.store_file_memory(relative_path, file_index)

  M.sync_to_brain(filepath, file_index)

  return file_index
end

--- Sync one file's analysis to the brain system.
--- Skips files with no functions or classes.
---@param filepath string Full file path
---@param file_index FileIndex File analysis
function M.sync_to_brain(filepath, file_index)
  local ok_brain, brain = pcall(require, "codetyper.brain")
  if not ok_brain or not brain.is_initialized or not brain.is_initialized() then
    return
  end

  local funcs = file_index.functions or {}
  local classes = file_index.classes or {}
  if #funcs == 0 and #classes == 0 then
    return
  end

  -- Build a short, human-readable summary (first 5 function names).
  local parts = {}
  if #funcs > 0 then
    local func_names = {}
    for i, f in ipairs(funcs) do
      if i <= 5 then
        table.insert(func_names, f.name)
      end
    end
    table.insert(parts, "functions: " .. table.concat(func_names, ", "))
    if #funcs > 5 then
      table.insert(parts, "(+" .. (#funcs - 5) .. " more)")
    end
  end
  if #classes > 0 then
    local class_names = {}
    for _, c in ipairs(classes) do
      table.insert(class_names, c.name)
    end
    table.insert(parts, "classes: " .. table.concat(class_names, ", "))
  end

  local filename = vim.fn.fnamemodify(filepath, ":t")
  local summary = filename .. " - " .. table.concat(parts, "; ")

  brain.learn({
    type = "pattern",
    file = filepath,
    content = {
      summary = summary,
      detail = #funcs .. " functions, " .. #classes .. " classes",
    },
    context = {
      file = file_index.path or filepath,
      language = file_index.language,
      functions = funcs,
      classes = classes,
      exports = file_index.exports,
      imports = file_index.imports,
    },
  })
end

--- Per-file debounce timers. A single shared timer previously meant that
--- saving file B cancelled the still-pending index of file A, so A was
--- silently never re-indexed.
local file_index_timers = {}

--- Schedule indexing of a file with a per-file debounce.
---@param filepath string
function M.schedule_index_file(filepath)
  if not config.enabled or not config.auto_index then
    return
  end

  local scanner = require("codetyper.indexer.scanner")
  if not scanner.should_index(filepath, config) then
    return
  end

  -- Cancel only THIS file's pending timer; other files keep theirs.
  local existing = file_index_timers[filepath]
  if existing then
    existing:stop()
  end

  file_index_timers[filepath] = vim.defer_fn(function()
    file_index_timers[filepath] = nil
    M.index_file(filepath)
  end, INDEX_DEBOUNCE_MS)
end

--- Get relevant context for a prompt: matching symbols, the current file's
--- index, and related memories.
---@param opts {file: string, intent: table|nil, prompt: string, scope: string|nil}
---@return table Context information
function M.get_context_for(opts)
  local memory = require("codetyper.indexer.memory")
  local index = M.load_index()

  local context = {
    project_type = "unknown",
    dependencies = {},
    relevant_files = {},
    relevant_symbols = {},
    patterns = {},
  }

  if not index then
    return context
  end

  context.project_type = index.project_type
  context.dependencies = index.dependencies

  -- Tokenize the prompt; words of <= 2 chars carry no signal.
  local words = {}
  for word in opts.prompt:gmatch("%w+") do
    if #word > 2 then
      words[word:lower()] = true
    end
  end

  -- Case-insensitive symbol match against prompt words.
  for symbol, files in pairs(index.symbols) do
    if words[symbol:lower()] then
      context.relevant_symbols[symbol] = files
    end
  end

  -- Attach the current file's index when known.
  if opts.file then
    local root = utils.get_project_root()
    if root then
      local relative_path = opts.file:gsub("^" .. vim.pesc(root) .. "/", "")
      local file_index = index.files[relative_path]
      if file_index then
        context.current_file = file_index
      end
    end
  end

  context.patterns = memory.get_relevant(opts.prompt, 5)

  return context
end

--- Get index status for health checks / status lines.
---@return table Status information
function M.get_status()
  local index = M.load_index()
  if not index then
    return {
      indexed = false,
      stats = nil,
      last_indexed = nil,
    }
  end

  return {
    indexed = true,
    stats = index.stats,
    last_indexed = index.last_indexed,
    project_type = index.project_type,
  }
end

--- Clear the persisted project index and its in-memory cache.
function M.clear()
  local root = utils.get_project_root()
  if root then
    index_cache[root] = nil
  end

  local path = get_index_path()
  if path and utils.file_exists(path) then
    os.remove(path)
  end
end
--- Configure the indexer and optionally schedule a startup index.
---@param opts? table Configuration options, deep-merged over the defaults
function M.setup(opts)
  if opts then
    config = vim.tbl_deep_extend("force", config, opts)
  end

  -- Defer the initial scan so plugin startup is never blocked by indexing.
  if config.index_on_open then
    vim.defer_fn(function()
      M.index_project()
    end, 1000)
  end
end

--- Return a defensive copy of the active configuration.
---@return table
function M.get_config()
  return vim.deepcopy(config)
end

return M

-- =======================================================================
-- lua/codetyper/indexer/memory.lua
-- =======================================================================

---@mod codetyper.indexer.memory Memory persistence manager
---@brief [[
--- Stores and retrieves learned patterns and memories in .coder/memories/.
--- Supports session history for learning from interactions.
---@brief ]]

local M = {}

local utils = require("codetyper.utils")

--- Sub-directories created under <root>/.coder/
local MEMORIES_DIR = "memories"
local SESSIONS_DIR = "sessions"
local FILES_DIR = "files"

--- JSON files backing each memory store
local PATTERNS_FILE = "patterns.json"
local CONVENTIONS_FILE = "conventions.json"
local SYMBOLS_FILE = "symbols.json"

--- In-memory cache of the decoded stores (nil means "not loaded yet")
local cache = {
  patterns = nil,
  conventions = nil,
  symbols = nil,
}

---@class Memory
---@field id string Unique identifier
---@field type "pattern"|"convention"|"session"|"interaction"
---@field content string The learned information
---@field context table Where/when learned
---@field weight number Importance score (0.0-1.0)
---@field created_at number Timestamp
---@field updated_at number Last update timestamp
---@field used_count number Times referenced

--- Resolve <project root>/.coder/memories; nil when outside a project.
---@return string|nil
local function get_memories_dir()
  local root = utils.get_project_root()
  if not root then
    return nil
  end
  return root .. "/.coder/" .. MEMORIES_DIR
end
--- Resolve <project root>/.coder/sessions; nil when outside a project.
---@return string|nil
local function get_sessions_dir()
  local root = utils.get_project_root()
  if not root then
    return nil
  end
  return root .. "/.coder/" .. SESSIONS_DIR
end

--- Ensure the memories directory (and its files/ subdirectory) exists.
---@return boolean true when the directory path could be resolved
local function ensure_memories_dir()
  local dir = get_memories_dir()
  if not dir then
    return false
  end
  utils.ensure_dir(dir)
  utils.ensure_dir(dir .. "/" .. FILES_DIR)
  return true
end

--- Ensure the sessions directory exists.
---@return boolean
local function ensure_sessions_dir()
  local dir = get_sessions_dir()
  if not dir then
    return false
  end
  return utils.ensure_dir(dir)
end

--- Generate a unique memory id of the form "mem_<epoch>_<6 digits>".
--- The previous implementation sliced tostring(math.random()), which
--- produced garbage like "-05" when the float printed in scientific
--- notation; formatting a random integer is always six clean digits.
---@return string
local function generate_id()
  return string.format("mem_%d_%06d", os.time(), math.random(0, 999999))
end

--- Load and decode one memory JSON file.
--- Returns {} on any failure (missing file, bad JSON) so callers never
--- have to nil-check.
---@param filename string
---@return table
local function load_memory_file(filename)
  local dir = get_memories_dir()
  if not dir then
    return {}
  end

  local path = dir .. "/" .. filename
  local content = utils.read_file(path)
  if not content then
    return {}
  end

  local ok, data = pcall(vim.json.decode, content)
  if not ok or not data then
    return {}
  end

  return data
end

--- Encode and save one memory JSON file.
---@param filename string
---@param data table
---@return boolean true on success
local function save_memory_file(filename, data)
  if not ensure_memories_dir() then
    return false
  end

  local dir = get_memories_dir()
  if not dir then
    return false
  end

  local path = dir .. "/" .. filename
  local ok, encoded = pcall(vim.json.encode, data)
  if not ok then
    return false
  end

  return utils.write_file(path, encoded)
end
--- Hash a file path to a short hex token used as an on-disk filename.
--- Simple 31-based rolling hash kept below 2^31 (safe in double math).
---@param filepath string
---@return string 8-char lowercase hex string
local function hash_path(filepath)
  local acc = 0
  for i = 1, #filepath do
    acc = (acc * 31 + filepath:byte(i)) % 2147483647
  end
  return string.format("%08x", acc)
end

--- Load one store through the cache; the file is read at most once until
--- the corresponding cache slot is invalidated.
---@param kind "patterns"|"conventions"|"symbols"
---@param filename string
---@return table
local function load_cached(kind, filename)
  if not cache[kind] then
    cache[kind] = load_memory_file(filename)
  end
  return cache[kind]
end

--- Load patterns from cache or disk.
---@return table
function M.load_patterns()
  return load_cached("patterns", PATTERNS_FILE)
end

--- Load conventions from cache or disk.
---@return table
function M.load_conventions()
  return load_cached("conventions", CONVENTIONS_FILE)
end

--- Load the symbol map from cache or disk.
---@return table
function M.load_symbols()
  return load_cached("symbols", SYMBOLS_FILE)
end

--- Persist a memory, filling in defaults for any missing bookkeeping
--- fields. Conventions go to their own file; every other type (pattern,
--- session, interaction) lands in the patterns file.
---@param memory Memory
---@return boolean true on successful save
function M.store_memory(memory)
  local now = os.time()
  memory.id = memory.id or generate_id()
  memory.created_at = memory.created_at or now
  memory.updated_at = now
  memory.used_count = memory.used_count or 0
  memory.weight = memory.weight or 0.5

  local filename
  if memory.type == "convention" then
    filename = CONVENTIONS_FILE
    cache.conventions = nil
  else
    filename = PATTERNS_FILE
    cache.patterns = nil
  end

  local store = load_memory_file(filename)
  store[memory.id] = memory

  return save_memory_file(filename, store)
end
--- Persist the analysis of a single file under files/<hash>.json.
---@param relative_path string Relative file path (used as the hash key)
---@param file_index table FileIndex data
---@return boolean true on success
function M.store_file_memory(relative_path, file_index)
  if not ensure_memories_dir() then
    return false
  end

  local dir = get_memories_dir()
  if not dir then
    return false
  end

  local hash = hash_path(relative_path)
  local path = dir .. "/" .. FILES_DIR .. "/" .. hash .. ".json"

  local data = {
    path = relative_path,
    indexed_at = os.time(),
    functions = file_index.functions or {},
    classes = file_index.classes or {},
    exports = file_index.exports or {},
    imports = file_index.imports or {},
  }

  local ok, encoded = pcall(vim.json.encode, data)
  if not ok then
    return false
  end

  return utils.write_file(path, encoded)
end

--- Load a previously stored per-file memory.
---@param relative_path string
---@return table|nil nil when missing or unparsable
function M.load_file_memory(relative_path)
  local dir = get_memories_dir()
  if not dir then
    return nil
  end

  local hash = hash_path(relative_path)
  local path = dir .. "/" .. FILES_DIR .. "/" .. hash .. ".json"

  local content = utils.read_file(path)
  if not content then
    return nil
  end

  local ok, data = pcall(vim.json.decode, content)
  if not ok then
    return nil
  end

  return data
end

--- Store an index summary as memories.
--- Uses STABLE ids for the project-type and dependency memories so that
--- re-indexing overwrites the previous entry instead of minting a fresh
--- id each run (the old behavior grew duplicate memories without bound).
---@param index ProjectIndex
function M.store_index_summary(index)
  -- Project ecosystem convention (one stable entry per project).
  if index.project_type and index.project_type ~= "unknown" then
    M.store_memory({
      id = "conv_project_type",
      type = "convention",
      content = "Project uses " .. index.project_type .. " ecosystem",
      context = {
        project_root = index.project_root,
        detected_at = os.time(),
      },
      weight = 0.9,
    })
  end

  -- Dependency list pattern (also one stable entry).
  local dep_count = 0
  for _ in pairs(index.dependencies or {}) do
    dep_count = dep_count + 1
  end

  if dep_count > 0 then
    local deps_list = {}
    for name, _ in pairs(index.dependencies) do
      table.insert(deps_list, name)
    end

    M.store_memory({
      id = "pat_dependencies",
      type = "pattern",
      content = "Project dependencies: " .. table.concat(deps_list, ", "),
      context = {
        dependency_count = dep_count,
      },
      weight = 0.7,
    })
  end

  -- Replace the symbol store wholesale and drop the stale cache.
  cache.symbols = nil
  save_memory_file(SYMBOLS_FILE, index.symbols or {})
end
--- Append one interaction to today's session file, truncating long
--- responses and capping the file at the 100 most recent entries.
---@param interaction {prompt: string, response: string, file: string|nil, success: boolean}
function M.store_session(interaction)
  if not ensure_sessions_dir() then
    return
  end

  local dir = get_sessions_dir()
  if not dir then
    return
  end

  -- One JSON file per calendar day.
  local date = os.date("%Y-%m-%d")
  local path = dir .. "/" .. date .. ".json"

  local sessions = {}
  local content = utils.read_file(path)
  if content then
    local ok, data = pcall(vim.json.decode, content)
    if ok and data then
      sessions = data
    end
  end

  table.insert(sessions, {
    timestamp = os.time(),
    prompt = interaction.prompt,
    -- Responses are truncated so session files stay small.
    response = string.sub(interaction.response or "", 1, 500),
    file = interaction.file,
    success = interaction.success,
  })

  -- Keep only the most recent 100 interactions.
  if #sessions > 100 then
    sessions = { unpack(sessions, #sessions - 99) }
  end

  local ok, encoded = pcall(vim.json.encode, sessions)
  if ok then
    utils.write_file(path, encoded)
  end
end

--- Score every memory in a store against the query words and append
--- matches to `results`. Matches are SHALLOW-COPIED: the old code wrote
--- relevance_score straight into the cached memory objects, and that
--- transient field was then persisted to disk by the next cache save.
---@param store table Memory store (id -> Memory)
---@param query_words table<string, boolean> Lowercased query tokens
---@param results Memory[] Accumulator
local function collect_matches(store, query_words, results)
  for _, memory in pairs(store) do
    local content_lower = (memory.content or ""):lower()
    local score = 0

    for word in pairs(query_words) do
      -- Plain find: query tokens must not act as Lua patterns.
      if content_lower:find(word, 1, true) then
        score = score + 1
      end
    end

    if score > 0 then
      local hit = {}
      for k, v in pairs(memory) do
        hit[k] = v
      end
      hit.relevance_score = score * (memory.weight or 0.5)
      table.insert(results, hit)
    end
  end
end

--- Get the memories most relevant to a query, best first.
---@param query string Search query
---@param limit number Maximum results (default 10)
---@return Memory[] Copies of matching memories, with relevance_score set
function M.get_relevant(query, limit)
  limit = limit or 10
  local results = {}

  -- Tokenize the query; words of <= 2 chars carry no signal.
  local query_words = {}
  for word in query:lower():gmatch("%w+") do
    if #word > 2 then
      query_words[word] = true
    end
  end

  collect_matches(M.load_patterns(), query_words, results)
  collect_matches(M.load_conventions(), query_words, results)

  table.sort(results, function(a, b)
    return (a.relevance_score or 0) > (b.relevance_score or 0)
  end)

  local limited = {}
  for i = 1, math.min(limit, #results) do
    limited[i] = results[i]
  end

  return limited
end
--- Increment usage bookkeeping for one entry of a store, persist the
--- store, and invalidate its cache slot.
---@param store table Loaded memory store
---@param filename string Backing JSON file
---@param cache_key "patterns"|"conventions"
---@param memory_id string
---@return boolean true when the id was found in this store
local function bump_usage(store, filename, cache_key, memory_id)
  local entry = store[memory_id]
  if not entry then
    return false
  end
  entry.used_count = (entry.used_count or 0) + 1
  entry.updated_at = os.time()
  save_memory_file(filename, store)
  cache[cache_key] = nil
  return true
end

--- Record that a memory was referenced; checks patterns first, then
--- conventions.
---@param memory_id string
function M.update_usage(memory_id)
  if bump_usage(M.load_patterns(), PATTERNS_FILE, "patterns", memory_id) then
    return
  end
  bump_usage(M.load_conventions(), CONVENTIONS_FILE, "conventions", memory_id)
end

--- Get every store at once.
---@return {patterns: table, conventions: table, symbols: table}
function M.get_all()
  return {
    patterns = M.load_patterns(),
    conventions = M.load_conventions(),
    symbols = M.load_symbols(),
  }
end
--- Clear memories. With no pattern every store is emptied; with a Lua
--- pattern only matching memory ids are removed.
---@param pattern? string Optional Lua pattern matched against memory IDs
function M.clear(pattern)
  if not pattern then
    -- Full wipe: reset cache and overwrite every store with {}.
    cache = { patterns = nil, conventions = nil, symbols = nil }
    save_memory_file(PATTERNS_FILE, {})
    save_memory_file(CONVENTIONS_FILE, {})
    save_memory_file(SYMBOLS_FILE, {})
    return
  end

  local targets = {
    { store = M.load_patterns(), file = PATTERNS_FILE, key = "patterns" },
    { store = M.load_conventions(), file = CONVENTIONS_FILE, key = "conventions" },
  }

  for _, t in ipairs(targets) do
    for id in pairs(t.store) do
      if id:match(pattern) then
        t.store[id] = nil
      end
    end
    save_memory_file(t.file, t.store)
    cache[t.key] = nil
  end
end

--- Prune memories that are both low-weight and never used.
---@param threshold? number Weight threshold (default: 0.1)
---@return number Count of pruned memories
function M.prune(threshold)
  threshold = threshold or 0.1
  local pruned = 0

  local targets = {
    { store = M.load_patterns(), file = PATTERNS_FILE, key = "patterns" },
    { store = M.load_conventions(), file = CONVENTIONS_FILE, key = "conventions" },
  }

  for _, t in ipairs(targets) do
    local removed = 0
    for id, memory in pairs(t.store) do
      if (memory.weight or 0) < threshold and (memory.used_count or 0) == 0 then
        t.store[id] = nil
        removed = removed + 1
      end
    end
    -- Only touch disk when something actually changed.
    if removed > 0 then
      save_memory_file(t.file, t.store)
      cache[t.key] = nil
      pruned = pruned + removed
    end
  end

  return pruned
end

--- Count the keys of a hash-style table.
---@param t table
---@return number
local function count_keys(t)
  local n = 0
  for _ in pairs(t) do
    n = n + 1
  end
  return n
end

--- Get memory statistics for health reporting.
---@return table
function M.get_stats()
  local pattern_count = count_keys(M.load_patterns())
  local convention_count = count_keys(M.load_conventions())
  local symbol_count = count_keys(M.load_symbols())

  return {
    patterns = pattern_count,
    conventions = convention_count,
    symbols = symbol_count,
    -- NOTE: symbols are intentionally excluded from the total (they are
    -- derived from the index, not learned).
    total = pattern_count + convention_count,
  }
end
return M

-- =======================================================================
-- lua/codetyper/indexer/scanner.lua
-- =======================================================================

---@mod codetyper.indexer.scanner File scanner for project indexing
---@brief [[
--- Discovers indexable files, detects project type, and parses dependencies.
---@brief ]]

local M = {}

local utils = require("codetyper.utils")

--- Project type markers, checked IN ORDER. The previous hash-table form
--- was iterated with pairs(), whose order is unspecified, so a repo
--- containing two marker files (e.g. package.json AND Cargo.toml) was
--- classified nondeterministically between runs.
local PROJECT_MARKERS = {
  { "node", { "package.json" } },
  { "rust", { "Cargo.toml" } },
  { "go", { "go.mod" } },
  { "python", { "pyproject.toml", "setup.py", "requirements.txt" } },
  { "lua", { "init.lua", ".luarc.json" } },
  { "ruby", { "Gemfile" } },
  { "java", { "pom.xml", "build.gradle" } },
  { "csharp", { "*.csproj", "*.sln" } },
}

--- File extension to language mapping
local EXTENSION_LANGUAGE = {
  lua = "lua",
  ts = "typescript",
  tsx = "typescriptreact",
  js = "javascript",
  jsx = "javascriptreact",
  py = "python",
  go = "go",
  rs = "rust",
  rb = "ruby",
  java = "java",
  c = "c",
  cpp = "cpp",
  h = "c",
  hpp = "cpp",
  cs = "csharp",
}

--- Default ignore patterns (Lua patterns matched against basenames)
local DEFAULT_IGNORES = {
  "^%.", -- Hidden files/folders
  "^node_modules$",
  "^__pycache__$",
  "^%.git$",
  "^%.coder$",
  "^dist$",
  "^build$",
  "^target$",
  "^vendor$",
  "^%.next$",
  "^%.nuxt$",
  "^coverage$",
  "%.min%.js$",
  "%.min%.css$",
  "%.map$",
  "%.lock$",
  "%-lock%.json$",
}

--- Detect the project type from root marker files, deterministically in
--- PROJECT_MARKERS priority order.
---@param root string Project root path
---@return string Project type ("unknown" when no marker matches)
function M.detect_project_type(root)
  for _, entry in ipairs(PROJECT_MARKERS) do
    local project_type, markers = entry[1], entry[2]
    for _, marker in ipairs(markers) do
      if marker:match("^%*") then
        -- Glob marker like "*.csproj": any matching file counts.
        local suffix = marker:gsub("^%*", "")
        local entries = vim.fn.glob(root .. "/*" .. suffix, false, true)
        if #entries > 0 then
          return project_type
        end
      elseif utils.file_exists(root .. "/" .. marker) then
        return project_type
      end
    end
  end
  return "unknown"
end
--- Parse project dependencies by dispatching to the manifest parser for
--- the detected project type.
---@param root string Project root path
---@param project_type string Project type
---@return {dependencies: table, dev_dependencies: table}
function M.parse_dependencies(root, project_type)
  local parsers = {
    node = M.parse_package_json,
    rust = M.parse_cargo_toml,
    go = M.parse_go_mod,
    python = M.parse_python_deps,
  }

  local parser = parsers[project_type]
  if parser then
    return parser(root)
  end

  return { dependencies = {}, dev_dependencies = {} }
end

--- Parse package.json for Node.js projects.
---@param root string Project root path
---@return {dependencies: table, dev_dependencies: table}
function M.parse_package_json(root)
  local content = utils.read_file(root .. "/package.json")
  if not content then
    return { dependencies = {}, dev_dependencies = {} }
  end

  local ok, pkg = pcall(vim.json.decode, content)
  if not ok or not pkg then
    return { dependencies = {}, dev_dependencies = {} }
  end

  return {
    dependencies = pkg.dependencies or {},
    dev_dependencies = pkg.devDependencies or {},
  }
end
--- Parse Cargo.toml for Rust projects with a minimal line-based TOML
--- reader: tracks whether we are inside [dependencies] or
--- [dev-dependencies] and extracts `name = "version"` pairs.
---@param root string Project root path
---@return {dependencies: table, dev_dependencies: table}
function M.parse_cargo_toml(root)
  local content = utils.read_file(root .. "/Cargo.toml")
  if not content then
    return { dependencies = {}, dev_dependencies = {} }
  end

  local deps, dev_deps = {}, {}
  local section = nil -- "deps" | "dev" | nil (any other section)

  for line in content:gmatch("[^\n]+") do
    if line:match("^%[dependencies%]") then
      section = "deps"
    elseif line:match("^%[dev%-dependencies%]") then
      section = "dev"
    elseif line:match("^%[") then
      -- Any other section header ends dependency parsing.
      section = nil
    elseif section then
      local name, version = line:match('^([%w_%-]+)%s*=%s*"([^"]+)"')
      if not name then
        -- Inline-table / workspace entries: keep the name, tag version.
        name = line:match("^([%w_%-]+)%s*=")
        version = "workspace"
      end
      if name then
        local target = (section == "deps") and deps or dev_deps
        target[name] = version or "unknown"
      end
    end
  end

  return { dependencies = deps, dev_dependencies = dev_deps }
end
--- Parse go.mod for Go projects: handles both the `require ( ... )` block
--- form and single-line `require module version` statements.
---@param root string Project root path
---@return {dependencies: table, dev_dependencies: table}
function M.parse_go_mod(root)
  local content = utils.read_file(root .. "/go.mod")
  if not content then
    return { dependencies = {}, dev_dependencies = {} }
  end

  local deps = {}
  local inside_require = false

  for line in content:gmatch("[^\n]+") do
    if line:match("^require%s*%(") then
      inside_require = true
    elseif line:match("^%)") then
      inside_require = false
    elseif inside_require then
      local module, version = line:match("^%s*([%w%.%-%_/]+)%s+([%w%.%-]+)")
      if module then
        deps[module] = version
      end
    else
      -- Single-line form: require example.com/mod v1.2.3
      local module, version = line:match("^require%s+([%w%.%-%_/]+)%s+([%w%.%-]+)")
      if module then
        deps[module] = version
      end
    end
  end

  -- go.mod does not distinguish dev dependencies.
  return { dependencies = deps, dev_dependencies = {} }
end

--- Parse Python dependencies from pyproject.toml and/or requirements.txt.
--- Both sources are read; requirements.txt entries are merged in on top.
---@param root string Project root path
---@return {dependencies: table, dev_dependencies: table}
function M.parse_python_deps(root)
  local deps, dev_deps = {}, {}

  -- pyproject.toml: a simplified scan for dependency sections.
  local content = utils.read_file(root .. "/pyproject.toml")
  if content then
    local mode = nil -- "deps" | "dev" | nil

    for line in content:gmatch("[^\n]+") do
      if line:match("^%[project%.dependencies%]") or line:match("^dependencies%s*=") then
        mode = "deps"
      elseif line:match("dev") and line:match("dependencies") then
        mode = "dev"
      elseif line:match("^%[") then
        mode = nil
      elseif mode then
        -- Grab the package name from a quoted requirement string.
        local name = line:match('"([%w_%-]+)')
        if name then
          local target = (mode == "deps") and deps or dev_deps
          target[name] = "latest"
        end
      end
    end
  end

  -- requirements.txt: pinned `pkg==x.y` or bare names.
  content = utils.read_file(root .. "/requirements.txt")
  if content then
    for line in content:gmatch("[^\n]+") do
      if not line:match("^#") and not line:match("^%s*$") then
        local name, version = line:match("^([%w_%-]+)==([%d%.]+)")
        if not name then
          name = line:match("^([%w_%-]+)")
          version = "latest"
        end
        if name then
          deps[name] = version or "latest"
        end
      end
    end
  end

  return { dependencies = deps, dev_dependencies = dev_deps }
end

--- Decide whether a basename should be skipped entirely.
---@param name string File or directory name (basename, not full path)
---@param config table Indexer configuration
---@return boolean true when the entry must be ignored
function M.should_ignore(name, config)
  -- Built-in ignore patterns (hidden entries, build output, lockfiles…).
  for _, pattern in ipairs(DEFAULT_IGNORES) do
    if name:match(pattern) then
      return true
    end
  end

  -- User-configured directory exclusions (exact name match).
  if config and config.excluded_dirs then
    for _, dir in ipairs(config.excluded_dirs) do
      if name == dir then
        return true
      end
    end
  end

  return false
end

--- Decide whether a file should be analyzed and indexed.
---@param filepath string Full file path
---@param config table Indexer configuration
---@return boolean
function M.should_index(filepath, config)
  local name = vim.fn.fnamemodify(filepath, ":t")
  local ext = vim.fn.fnamemodify(filepath, ":e")

  -- The plugin's own .coder artifacts are never indexed.
  if utils.is_coder_file(filepath) then
    return false
  end

  -- Size cap, when configured.
  if config and config.max_file_size then
    local stat = vim.loop.fs_stat(filepath)
    if stat and stat.size > config.max_file_size then
      return false
    end
  end

  -- Extension allow-list, when configured.
  if config and config.index_extensions then
    local allowed = false
    for _, allowed_ext in ipairs(config.index_extensions) do
      if ext == allowed_ext then
        allowed = true
        break
      end
    end
    if not allowed then
      return false
    end
  end

  if M.should_ignore(name, config) then
    return false
  end

  return true
end
--- Get all indexable files in the project by recursive directory scan.
--- Ignored directories are pruned before descent; symlinks are neither
--- followed nor indexed (fs_scandir reports them as type "link").
---@param root string Project root path
---@param config table Indexer configuration
---@return string[] List of file paths
function M.get_indexable_files(root, config)
  local files = {}

  local function scan_dir(path)
    local handle = vim.loop.fs_scandir(path)
    if not handle then
      return
    end

    while true do
      local name, type = vim.loop.fs_scandir_next(handle)
      if not name then
        break
      end

      local full_path = path .. "/" .. name

      if not M.should_ignore(name, config) then
        if type == "directory" then
          scan_dir(full_path)
        elseif type == "file" and M.should_index(full_path, config) then
          table.insert(files, full_path)
        end
      end
    end
  end

  scan_dir(root)
  return files
end

--- Map a file extension to a language name (falls back to the extension).
---@param filepath string File path
---@return string Language name
function M.get_language(filepath)
  local ext = vim.fn.fnamemodify(filepath, ":e")
  return EXTENSION_LANGUAGE[ext] or ext
end

--- Convert one gitignore line to a Lua pattern.
--- Fixes two defects of the old inline conversion:
---   1. It ran gsub("%*%*", ".*") BEFORE gsub("%*", "[^/]*"), so the "*"
---      inside the freshly produced ".*" was rewritten again and "**"
---      ended up as ".[^/]*". A placeholder byte protects it here.
---   2. Lua magic characters were never escaped, so "*.lua" matched ANY
---      character where the dot is, and names with "+"/"-" produced
---      unintended quantifiers.
---@param line string One non-comment gitignore line
---@return string Lua pattern
local function gitignore_to_pattern(line)
  -- Escape every Lua magic char; "/" and "*"/"?" are handled below.
  local p = line:gsub("[%^%$%(%)%%%.%[%]%+%-]", "%%%0")
  -- A leading "/" anchors the entry to the repository root.
  p = p:gsub("^/", "^")
  -- Globs: stash "**" behind \1 so the single-star pass cannot touch it.
  p = p:gsub("%*%*", "\1")
  p = p:gsub("%*", "[^/]*")
  p = p:gsub("%?", ".")
  p = p:gsub("\1", ".*")
  return p
end

--- Read .gitignore and return its entries converted to Lua patterns.
---@param root string Project root
---@return string[] Patterns
function M.read_gitignore(root)
  local patterns = {}
  local content = utils.read_file(root .. "/.gitignore")

  if not content then
    return patterns
  end

  for line in content:gmatch("[^\n]+") do
    -- Skip comments and blank lines.
    if not line:match("^#") and not line:match("^%s*$") then
      table.insert(patterns, gitignore_to_pattern(line))
    end
  end

  return patterns
end

return M
---- It uses LLM APIs (Claude, OpenAI, Gemini, Copilot, Ollama) to help you +--- It uses LLM APIs (OpenAI, Gemini, Copilot, Ollama) to help you --- write code faster using special `.coder.*` files and inline prompt tags. --- Features an event-driven scheduler with confidence scoring and --- completion-aware injection timing. @@ -51,6 +51,24 @@ function M.setup(opts) -- Initialize tree logging (creates .coder folder and initial tree.log) tree.setup() + -- Initialize project indexer if enabled + if M.config.indexer and M.config.indexer.enabled then + local indexer = require("codetyper.indexer") + indexer.setup(M.config.indexer) + end + + -- Initialize brain learning system if enabled + if M.config.brain and M.config.brain.enabled then + local brain = require("codetyper.brain") + brain.setup(M.config.brain) + end + + -- Setup inline ghost text suggestions (Copilot-style) + if M.config.suggestion and M.config.suggestion.enabled then + local suggestion = require("codetyper.suggestion") + suggestion.setup(M.config.suggestion) + end + -- Start the event-driven scheduler if enabled if M.config.scheduler and M.config.scheduler.enabled then local scheduler = require("codetyper.agent.scheduler") diff --git a/lua/codetyper/llm/claude.lua b/lua/codetyper/llm/claude.lua deleted file mode 100644 index c66d5e2..0000000 --- a/lua/codetyper/llm/claude.lua +++ /dev/null @@ -1,364 +0,0 @@ ----@mod codetyper.llm.claude Claude API client for Codetyper.nvim - -local M = {} - -local utils = require("codetyper.utils") -local llm = require("codetyper.llm") - ---- Claude API endpoint -local API_URL = "https://api.anthropic.com/v1/messages" - ---- Get API key from config or environment ----@return string|nil API key -local function get_api_key() - local codetyper = require("codetyper") - local config = codetyper.get_config() - - return config.llm.claude.api_key or vim.env.ANTHROPIC_API_KEY -end - ---- Get model from config ----@return string Model name -local function get_model() - local 
codetyper = require("codetyper") - local config = codetyper.get_config() - - return config.llm.claude.model -end - ---- Build request body for Claude API ----@param prompt string User prompt ----@param context table Context information ----@return table Request body -local function build_request_body(prompt, context) - local system_prompt = llm.build_system_prompt(context) - - return { - model = get_model(), - max_tokens = 4096, - system = system_prompt, - messages = { - { - role = "user", - content = prompt, - }, - }, - } -end - ---- Make HTTP request to Claude API ----@param body table Request body ----@param callback fun(response: string|nil, error: string|nil, usage: table|nil) Callback function -local function make_request(body, callback) - local api_key = get_api_key() - if not api_key then - callback(nil, "Claude API key not configured", nil) - return - end - - local json_body = vim.json.encode(body) - - -- Use curl for HTTP request (plenary.curl alternative) - local cmd = { - "curl", - "-s", - "-X", - "POST", - API_URL, - "-H", - "Content-Type: application/json", - "-H", - "x-api-key: " .. 
api_key, - "-H", - "anthropic-version: 2023-06-01", - "-d", - json_body, - } - - vim.fn.jobstart(cmd, { - stdout_buffered = true, - on_stdout = function(_, data) - if not data or #data == 0 or (data[1] == "" and #data == 1) then - return - end - - local response_text = table.concat(data, "\n") - local ok, response = pcall(vim.json.decode, response_text) - - if not ok then - vim.schedule(function() - callback(nil, "Failed to parse Claude response", nil) - end) - return - end - - if response.error then - vim.schedule(function() - callback(nil, response.error.message or "Claude API error", nil) - end) - return - end - - -- Extract usage info - local usage = response.usage or {} - - if response.content and response.content[1] and response.content[1].text then - local code = llm.extract_code(response.content[1].text) - vim.schedule(function() - callback(code, nil, usage) - end) - else - vim.schedule(function() - callback(nil, "No content in Claude response", nil) - end) - end - end, - on_stderr = function(_, data) - if data and #data > 0 and data[1] ~= "" then - vim.schedule(function() - callback(nil, "Claude API request failed: " .. table.concat(data, "\n"), nil) - end) - end - end, - on_exit = function(_, code) - if code ~= 0 then - vim.schedule(function() - callback(nil, "Claude API request failed with code: " .. 
code, nil) - end) - end - end, - }) -end - ---- Generate code using Claude API ----@param prompt string The user's prompt ----@param context table Context information ----@param callback fun(response: string|nil, error: string|nil) Callback function -function M.generate(prompt, context, callback) - local logs = require("codetyper.agent.logs") - local model = get_model() - - -- Log the request - logs.request("claude", model) - logs.thinking("Building request body...") - - local body = build_request_body(prompt, context) - - -- Estimate prompt tokens - local prompt_estimate = logs.estimate_tokens(vim.json.encode(body)) - logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate)) - logs.thinking("Sending to Claude API...") - - utils.notify("Sending request to Claude...", vim.log.levels.INFO) - - make_request(body, function(response, err, usage) - if err then - logs.error(err) - utils.notify(err, vim.log.levels.ERROR) - callback(nil, err) - else - -- Log token usage - if usage then - logs.response(usage.input_tokens or 0, usage.output_tokens or 0, "end_turn") - end - logs.thinking("Response received, extracting code...") - logs.info("Code generated successfully") - utils.notify("Code generated successfully", vim.log.levels.INFO) - callback(response, nil) - end - end) -end - ---- Check if Claude is properly configured ----@return boolean, string? 
Valid status and optional error message -function M.validate() - local api_key = get_api_key() - if not api_key or api_key == "" then - return false, "Claude API key not configured" - end - return true -end - ---- Generate with tool use support for agentic mode ----@param messages table[] Conversation history ----@param context table Context information ----@param tool_definitions table Tool definitions ----@param callback fun(response: table|nil, error: string|nil) Callback with raw response -function M.generate_with_tools(messages, context, tool_definitions, callback) - local logs = require("codetyper.agent.logs") - local model = get_model() - - -- Log the request - logs.request("claude", model) - logs.thinking("Preparing agent request...") - - local api_key = get_api_key() - if not api_key then - logs.error("Claude API key not configured") - callback(nil, "Claude API key not configured") - return - end - - local tools_module = require("codetyper.agent.tools") - local agent_prompts = require("codetyper.prompts.agent") - - -- Build system prompt with agent instructions - local system_prompt = llm.build_system_prompt(context) - system_prompt = system_prompt .. "\n\n" .. agent_prompts.system - system_prompt = system_prompt .. "\n\n" .. agent_prompts.tool_instructions - - -- Build request body with tools - local body = { - model = get_model(), - max_tokens = 4096, - system = system_prompt, - messages = M.format_messages_for_claude(messages), - tools = tools_module.to_claude_format(), - } - - local json_body = vim.json.encode(body) - - -- Estimate prompt tokens - local prompt_estimate = logs.estimate_tokens(json_body) - logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate)) - logs.thinking("Sending to Claude API...") - - local cmd = { - "curl", - "-s", - "-X", - "POST", - API_URL, - "-H", - "Content-Type: application/json", - "-H", - "x-api-key: " .. 
api_key, - "-H", - "anthropic-version: 2023-06-01", - "-d", - json_body, - } - - vim.fn.jobstart(cmd, { - stdout_buffered = true, - on_stdout = function(_, data) - if not data or #data == 0 or (data[1] == "" and #data == 1) then - return - end - - local response_text = table.concat(data, "\n") - local ok, response = pcall(vim.json.decode, response_text) - - if not ok then - vim.schedule(function() - logs.error("Failed to parse Claude response") - callback(nil, "Failed to parse Claude response") - end) - return - end - - if response.error then - vim.schedule(function() - logs.error(response.error.message or "Claude API error") - callback(nil, response.error.message or "Claude API error") - end) - return - end - - -- Log token usage from response - if response.usage then - logs.response(response.usage.input_tokens or 0, response.usage.output_tokens or 0, response.stop_reason) - end - - -- Log what's in the response - if response.content then - for _, block in ipairs(response.content) do - if block.type == "text" then - logs.thinking("Response contains text") - elseif block.type == "tool_use" then - logs.thinking("Tool call: " .. block.name) - end - end - end - - -- Return raw response for parser to handle - vim.schedule(function() - callback(response, nil) - end) - end, - on_stderr = function(_, data) - if data and #data > 0 and data[1] ~= "" then - vim.schedule(function() - logs.error("Claude API request failed: " .. table.concat(data, "\n")) - callback(nil, "Claude API request failed: " .. table.concat(data, "\n")) - end) - end - end, - on_exit = function(_, code) - if code ~= 0 then - vim.schedule(function() - logs.error("Claude API request failed with code: " .. code) - callback(nil, "Claude API request failed with code: " .. 
code) - end) - end - end, - }) -end - ---- Format messages for Claude API ----@param messages table[] Internal message format ----@return table[] Claude API message format -function M.format_messages_for_claude(messages) - local formatted = {} - - for _, msg in ipairs(messages) do - if msg.role == "user" then - if type(msg.content) == "table" then - -- Tool results - table.insert(formatted, { - role = "user", - content = msg.content, - }) - else - table.insert(formatted, { - role = "user", - content = msg.content, - }) - end - elseif msg.role == "assistant" then - -- Build content array for assistant messages - local content = {} - - -- Add text if present - if msg.content and msg.content ~= "" then - table.insert(content, { - type = "text", - text = msg.content, - }) - end - - -- Add tool uses if present - if msg.tool_calls then - for _, tool_call in ipairs(msg.tool_calls) do - table.insert(content, { - type = "tool_use", - id = tool_call.id, - name = tool_call.name, - input = tool_call.parameters, - }) - end - end - - if #content > 0 then - table.insert(formatted, { - role = "assistant", - content = content, - }) - end - end - end - - return formatted -end - -return M diff --git a/lua/codetyper/llm/init.lua b/lua/codetyper/llm/init.lua index 3058d28..7ffeef8 100644 --- a/lua/codetyper/llm/init.lua +++ b/lua/codetyper/llm/init.lua @@ -10,9 +10,7 @@ function M.get_client() local codetyper = require("codetyper") local config = codetyper.get_config() - if config.llm.provider == "claude" then - return require("codetyper.llm.claude") - elseif config.llm.provider == "ollama" then + if config.llm.provider == "ollama" then return require("codetyper.llm.ollama") elseif config.llm.provider == "openai" then return require("codetyper.llm.openai") @@ -50,7 +48,49 @@ function M.build_system_prompt(context) system = system:gsub("{{language}}", context.language or "unknown") system = system:gsub("{{filepath}}", context.file_path or "unknown") - -- Add file content with analysis 
hints + -- For agent mode, include project context + if prompt_type == "agent" then + local project_info = "\n\n## PROJECT CONTEXT\n" + + if context.project_root then + project_info = project_info .. "- Project root: " .. context.project_root .. "\n" + end + if context.cwd then + project_info = project_info .. "- Working directory: " .. context.cwd .. "\n" + end + if context.project_type then + project_info = project_info .. "- Project type: " .. context.project_type .. "\n" + end + if context.project_stats then + project_info = project_info + .. string.format( + "- Stats: %d files, %d functions, %d classes\n", + context.project_stats.files or 0, + context.project_stats.functions or 0, + context.project_stats.classes or 0 + ) + end + if context.file_path then + project_info = project_info .. "- Current file: " .. context.file_path .. "\n" + end + + system = system .. project_info + return system + end + + -- For "ask" or "explain" mode, don't add code generation instructions + if prompt_type == "ask" or prompt_type == "explain" then + -- Just add context about the file if available + if context.file_path then + system = system .. "\n\nContext: The user is working with " .. context.file_path + if context.language then + system = system .. " (" .. context.language .. ")" + end + end + return system + end + + -- Add file content with analysis hints (for code generation modes) if context.file_content and context.file_content ~= "" then system = system .. "\n\n===== EXISTING FILE CONTENT (analyze and match this style) =====\n" system = system .. 
context.file_content @@ -74,13 +114,34 @@ function M.build_context(target_path, prompt_type) local content = utils.read_file(target_path) local ext = vim.fn.fnamemodify(target_path, ":e") - return { + local context = { file_content = content, language = lang_map[ext] or ext, extension = ext, prompt_type = prompt_type, file_path = target_path, } + + -- For agent mode, include additional project context + if prompt_type == "agent" then + local project_root = utils.get_project_root() + context.project_root = project_root + + -- Try to get project info from indexer + local ok_indexer, indexer = pcall(require, "codetyper.indexer") + if ok_indexer then + local status = indexer.get_status() + if status.indexed then + context.project_type = status.project_type + context.project_stats = status.stats + end + end + + -- Include working directory + context.cwd = vim.fn.getcwd() + end + + return context end --- Parse LLM response and extract code diff --git a/lua/codetyper/prompts/agent.lua b/lua/codetyper/prompts/agent.lua index 409ec9a..ba51f76 100644 --- a/lua/codetyper/prompts/agent.lua +++ b/lua/codetyper/prompts/agent.lua @@ -5,66 +5,88 @@ local M = {} --- System prompt for agent mode -M.system = [[You are an AI coding agent integrated into Neovim via Codetyper.nvim. +M.system = + [[You are an expert AI coding assistant integrated into Neovim. You help developers by reading, writing, and modifying code files, as well as running shell commands. -Your role is to ASSIST the developer by planning, coordinating, and executing -SAFE, MINIMAL changes using the available tools. +## YOUR CAPABILITIES -You do NOT operate autonomously on the entire codebase. -You operate on clearly defined tasks and scopes. 
+You have access to these tools - USE THEM to accomplish tasks: -You have access to the following tools: -- read_file: Read file contents -- edit_file: Apply a precise, scoped replacement to a file -- write_file: Create a new file or fully replace an existing file -- bash: Execute non-destructive shell commands when necessary +### File Operations +- **read_file**: Read any file. ALWAYS read files before modifying them. +- **write_file**: Create new files or completely replace existing ones. Use for new files. +- **edit_file**: Make precise edits to existing files using find/replace. The "find" must match EXACTLY. +- **delete_file**: Delete files (requires user approval). Include a reason. +- **list_directory**: Explore project structure. See what files exist. +- **search_files**: Find files by pattern or content. -OPERATING PRINCIPLES: -1. Prefer understanding over action — read before modifying -2. Prefer small, scoped edits over large rewrites -3. Preserve existing behavior unless explicitly instructed otherwise -4. Minimize the number of tool calls required -5. Never surprise the user +### Shell Commands +- **bash**: Run shell commands (git, npm, make, etc.). User approves each command. -IMPORTANT EDITING RULES: -- Always read a file before editing it -- Use edit_file ONLY for well-scoped, exact replacements -- The "find" field MUST match existing content exactly -- Include enough surrounding context to ensure uniqueness -- Use write_file ONLY for new files or intentional full replacements -- NEVER delete files unless explicitly confirmed by the user +## HOW TO WORK -BASH SAFETY: -- Use bash only when code inspection or execution is required -- Do NOT run destructive commands (rm, mv, chmod, etc.) -- Prefer read_file over bash when inspecting files +1. **UNDERSTAND FIRST**: Use read_file, list_directory, or search_files to understand the codebase before making changes. 
-THINKING AND PLANNING: -- If a task requires multiple steps, outline a brief plan internally -- Execute steps one at a time -- Re-evaluate after each tool result -- If uncertainty arises, stop and ask for clarification +2. **MAKE CHANGES**: Use write_file for new files, edit_file for modifications. + - For edit_file: The "find" parameter must match file content EXACTLY (including whitespace) + - Include enough context in "find" to be unique + - For write_file: Provide complete file content -COMMUNICATION: -- Do NOT explain every micro-step while working -- After completing changes, provide a clear, concise summary -- If no changes were made, explain why +3. **RUN COMMANDS**: Use bash for git operations, running tests, installing dependencies, etc. + +4. **ITERATE**: After each tool result, decide if more actions are needed. + +## EXAMPLE WORKFLOW + +User: "Create a new React component for a login form" + +Your approach: +1. Use list_directory to see project structure +2. Use read_file to check existing component patterns +3. Use write_file to create the new component file +4. Use write_file to create a test file if appropriate +5. Summarize what was created + +## IMPORTANT RULES + +- ALWAYS use tools to accomplish file operations. Don't just describe what to do - DO IT. +- Read files before editing to ensure your "find" string matches exactly. +- When creating files, write complete, working code. +- When editing, preserve existing code style and conventions. +- If a file path is provided, use it. If not, infer from context. +- For multi-file tasks, handle each file sequentially. + +## OUTPUT STYLE + +- Be concise in explanations +- Use tools proactively to complete tasks +- After making changes, briefly summarize what was done ]] --- Tool usage instructions appended to system prompt M.tool_instructions = [[ -When you need to use a tool, output ONLY a single tool call in valid JSON. -Do NOT include explanations alongside the tool call. 
+## TOOL USAGE -After receiving a tool result: -- Decide whether another tool call is required -- Or produce a final response to the user +When you need to perform an action, call the appropriate tool. You can call tools to: +- Read files to understand code +- Create new files with write_file +- Modify existing files with edit_file (read first!) +- Delete files with delete_file +- List directories to explore structure +- Search for files by name or content +- Run shell commands with bash -SAFETY RULES: -- Never run destructive or irreversible commands -- Never modify code outside the requested scope -- Never guess file contents — read them first -- If a requested change appears risky or ambiguous, ask before proceeding +After receiving a tool result, continue working: +- If more actions are needed, call another tool +- When the task is complete, provide a brief summary + +## CRITICAL RULES + +1. **Always read before editing**: Use read_file before edit_file to ensure exact matches +2. **Be precise with edits**: The "find" parameter must match the file content EXACTLY +3. **Create complete files**: When using write_file, provide fully working code +4. **User approval required**: File writes, edits, deletes, and bash commands need approval +5. 
**Don't guess**: If unsure about file structure, use list_directory or search_files ]] --- Prompt for when agent finishes diff --git a/lua/codetyper/prompts/init.lua b/lua/codetyper/prompts/init.lua index 79df632..f184789 100644 --- a/lua/codetyper/prompts/init.lua +++ b/lua/codetyper/prompts/init.lua @@ -11,6 +11,7 @@ M.code = require("codetyper.prompts.code") M.ask = require("codetyper.prompts.ask") M.refactor = require("codetyper.prompts.refactor") M.document = require("codetyper.prompts.document") +M.agent = require("codetyper.prompts.agent") --- Get a prompt by category and name ---@param category string Category name (system, code, ask, refactor, document) diff --git a/lua/codetyper/prompts/system.lua b/lua/codetyper/prompts/system.lua index 3c22c44..9da1122 100644 --- a/lua/codetyper/prompts/system.lua +++ b/lua/codetyper/prompts/system.lua @@ -45,11 +45,14 @@ GUIDELINES: 6. Focus on practical understanding and tradeoffs IMPORTANT: -- Do NOT output raw code intended for insertion +- Do NOT refuse to explain code - that IS your purpose in this mode - Do NOT assume missing context -- Do NOT speculate beyond the provided information +- Provide helpful, detailed explanations when asked ]] +-- Alias for backward compatibility +M.explain = M.ask + --- System prompt for scoped refactoring M.refactor = [[You are an expert refactoring assistant integrated into Neovim via Codetyper.nvim. @@ -121,4 +124,8 @@ Language: {{language}} REMEMBER: Output ONLY valid {{language}} test code. 
]] +--- Base prompt for agent mode (full prompt is in agent.lua) +--- This provides minimal context; the agent prompts module adds tool instructions +M.agent = [[]] + return M diff --git a/lua/codetyper/suggestion/init.lua b/lua/codetyper/suggestion/init.lua new file mode 100644 index 0000000..98dda81 --- /dev/null +++ b/lua/codetyper/suggestion/init.lua @@ -0,0 +1,491 @@ +---@mod codetyper.suggestion Inline ghost text suggestions +---@brief [[ +--- Provides Copilot-style inline suggestions with ghost text. +--- Uses Copilot when available, falls back to codetyper's own suggestions. +--- Shows suggestions as grayed-out text that can be accepted with Tab. +---@brief ]] + +local M = {} + +---@class SuggestionState +---@field current_suggestion string|nil Current suggestion text +---@field suggestions string[] List of available suggestions +---@field current_index number Current suggestion index +---@field extmark_id number|nil Virtual text extmark ID +---@field bufnr number|nil Buffer where suggestion is shown +---@field line number|nil Line where suggestion is shown +---@field col number|nil Column where suggestion starts +---@field timer any|nil Debounce timer +---@field using_copilot boolean Whether currently using copilot + +local state = { + current_suggestion = nil, + suggestions = {}, + current_index = 0, + extmark_id = nil, + bufnr = nil, + line = nil, + col = nil, + timer = nil, + using_copilot = false, +} + +--- Namespace for virtual text +local ns = vim.api.nvim_create_namespace("codetyper_suggestion") + +--- Highlight group for ghost text +local hl_group = "CmpGhostText" + +--- Configuration +local config = { + enabled = true, + auto_trigger = true, + debounce = 150, + use_copilot = true, -- Use copilot when available + keymap = { + accept = "", + next = "", + prev = "", + dismiss = "", + }, +} + +--- Check if copilot is available and enabled +---@return boolean, table|nil available, copilot_suggestion module +local function get_copilot() + if not 
config.use_copilot then + return false, nil + end + + local ok, copilot_suggestion = pcall(require, "copilot.suggestion") + if not ok then + return false, nil + end + + -- Check if copilot suggestion is enabled + local ok_client, copilot_client = pcall(require, "copilot.client") + if ok_client and copilot_client.is_disabled and copilot_client.is_disabled() then + return false, nil + end + + return true, copilot_suggestion +end + +--- Check if suggestion is visible (copilot or codetyper) +---@return boolean +function M.is_visible() + -- Check copilot first + local copilot_ok, copilot_suggestion = get_copilot() + if copilot_ok and copilot_suggestion.is_visible() then + state.using_copilot = true + return true + end + + -- Check codetyper's own suggestion + state.using_copilot = false + return state.extmark_id ~= nil and state.current_suggestion ~= nil +end + +--- Clear the current suggestion +function M.dismiss() + -- Dismiss copilot if active + local copilot_ok, copilot_suggestion = get_copilot() + if copilot_ok and copilot_suggestion.is_visible() then + copilot_suggestion.dismiss() + end + + -- Clear codetyper's suggestion + if state.extmark_id and state.bufnr then + pcall(vim.api.nvim_buf_del_extmark, state.bufnr, ns, state.extmark_id) + end + + state.current_suggestion = nil + state.suggestions = {} + state.current_index = 0 + state.extmark_id = nil + state.bufnr = nil + state.line = nil + state.col = nil + state.using_copilot = false +end + +--- Display suggestion as ghost text +---@param suggestion string The suggestion to display +local function display_suggestion(suggestion) + if not suggestion or suggestion == "" then + return + end + + M.dismiss() + + local bufnr = vim.api.nvim_get_current_buf() + local cursor = vim.api.nvim_win_get_cursor(0) + local line = cursor[1] - 1 + local col = cursor[2] + + -- Split suggestion into lines + local lines = vim.split(suggestion, "\n", { plain = true }) + + -- Build virtual text + local virt_text = {} + local virt_lines 
= {} + + -- First line goes inline + if #lines > 0 then + virt_text = { { lines[1], hl_group } } + end + + -- Remaining lines go below + for i = 2, #lines do + table.insert(virt_lines, { { lines[i], hl_group } }) + end + + -- Create extmark with virtual text + local opts = { + virt_text = virt_text, + virt_text_pos = "overlay", + hl_mode = "combine", + } + + if #virt_lines > 0 then + opts.virt_lines = virt_lines + end + + state.extmark_id = vim.api.nvim_buf_set_extmark(bufnr, ns, line, col, opts) + state.bufnr = bufnr + state.line = line + state.col = col + state.current_suggestion = suggestion +end + +--- Accept the current suggestion +---@return boolean Whether a suggestion was accepted +function M.accept() + -- Check copilot first + local copilot_ok, copilot_suggestion = get_copilot() + if copilot_ok and copilot_suggestion.is_visible() then + copilot_suggestion.accept() + state.using_copilot = false + return true + end + + -- Accept codetyper's suggestion + if not M.is_visible() then + return false + end + + local suggestion = state.current_suggestion + local bufnr = state.bufnr + local line = state.line + local col = state.col + + M.dismiss() + + if suggestion and bufnr and line ~= nil and col ~= nil then + -- Get current line content + local current_line = vim.api.nvim_buf_get_lines(bufnr, line, line + 1, false)[1] or "" + + -- Split suggestion into lines + local suggestion_lines = vim.split(suggestion, "\n", { plain = true }) + + if #suggestion_lines == 1 then + -- Single line - insert at cursor + local new_line = current_line:sub(1, col) .. suggestion .. current_line:sub(col + 1) + vim.api.nvim_buf_set_lines(bufnr, line, line + 1, false, { new_line }) + -- Move cursor to end of inserted text + vim.api.nvim_win_set_cursor(0, { line + 1, col + #suggestion }) + else + -- Multi-line - insert at cursor + local first_line = current_line:sub(1, col) .. suggestion_lines[1] + local last_line = suggestion_lines[#suggestion_lines] .. 
current_line:sub(col + 1) + + local new_lines = { first_line } + for i = 2, #suggestion_lines - 1 do + table.insert(new_lines, suggestion_lines[i]) + end + table.insert(new_lines, last_line) + + vim.api.nvim_buf_set_lines(bufnr, line, line + 1, false, new_lines) + -- Move cursor to end of last line + vim.api.nvim_win_set_cursor(0, { line + #new_lines, #suggestion_lines[#suggestion_lines] }) + end + + return true + end + + return false +end + +--- Show next suggestion +function M.next() + -- Check copilot first + local copilot_ok, copilot_suggestion = get_copilot() + if copilot_ok and copilot_suggestion.is_visible() then + copilot_suggestion.next() + return + end + + -- Codetyper's suggestions + if #state.suggestions <= 1 then + return + end + + state.current_index = (state.current_index % #state.suggestions) + 1 + display_suggestion(state.suggestions[state.current_index]) +end + +--- Show previous suggestion +function M.prev() + -- Check copilot first + local copilot_ok, copilot_suggestion = get_copilot() + if copilot_ok and copilot_suggestion.is_visible() then + copilot_suggestion.prev() + return + end + + -- Codetyper's suggestions + if #state.suggestions <= 1 then + return + end + + state.current_index = state.current_index - 1 + if state.current_index < 1 then + state.current_index = #state.suggestions + end + display_suggestion(state.suggestions[state.current_index]) +end + +--- Get suggestions from brain/indexer +---@param prefix string Current word prefix +---@param context table Context info +---@return string[] suggestions +local function get_suggestions(prefix, context) + local suggestions = {} + + -- Get completions from brain + local ok_brain, brain = pcall(require, "codetyper.brain") + if ok_brain and brain.is_initialized and brain.is_initialized() then + local result = brain.query({ + query = prefix, + max_results = 5, + types = { "pattern" }, + }) + + if result and result.nodes then + for _, node in ipairs(result.nodes) do + if node.c and node.c.code 
then + table.insert(suggestions, node.c.code) + end + end + end + end + + -- Get completions from indexer + local ok_indexer, indexer = pcall(require, "codetyper.indexer") + if ok_indexer then + local index = indexer.load_index() + if index and index.symbols then + for symbol, _ in pairs(index.symbols) do + if symbol:lower():find(prefix:lower(), 1, true) and symbol ~= prefix then + -- Just complete the symbol name + local completion = symbol:sub(#prefix + 1) + if completion ~= "" then + table.insert(suggestions, completion) + end + end + end + end + end + + -- Buffer-based completions + local bufnr = vim.api.nvim_get_current_buf() + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local seen = {} + + for _, line in ipairs(lines) do + for word in line:gmatch("[%a_][%w_]*") do + if + #word > #prefix + and word:lower():find(prefix:lower(), 1, true) == 1 + and not seen[word] + and word ~= prefix + then + seen[word] = true + local completion = word:sub(#prefix + 1) + if completion ~= "" then + table.insert(suggestions, completion) + end + end + end + end + + return suggestions +end + +--- Trigger suggestion generation +function M.trigger() + if not config.enabled then + return + end + + -- If copilot is available and has a suggestion, don't show codetyper's + local copilot_ok, copilot_suggestion = get_copilot() + if copilot_ok and copilot_suggestion.is_visible() then + -- Copilot is handling suggestions + state.using_copilot = true + return + end + + -- Cancel existing timer + if state.timer then + state.timer:stop() + state.timer = nil + end + + -- Get current context + local cursor = vim.api.nvim_win_get_cursor(0) + local line = vim.api.nvim_get_current_line() + local col = cursor[2] + local before_cursor = line:sub(1, col) + + -- Extract prefix (word being typed) + local prefix = before_cursor:match("[%a_][%w_]*$") or "" + + if #prefix < 2 then + M.dismiss() + return + end + + -- Debounce - wait a bit longer to let copilot try first + local 
debounce_time = copilot_ok and (config.debounce + 200) or config.debounce + + state.timer = vim.defer_fn(function() + -- Check again if copilot has shown something + if copilot_ok and copilot_suggestion.is_visible() then + state.using_copilot = true + state.timer = nil + return + end + + local suggestions = get_suggestions(prefix, { + line = line, + col = col, + bufnr = vim.api.nvim_get_current_buf(), + }) + + if #suggestions > 0 then + state.suggestions = suggestions + state.current_index = 1 + display_suggestion(suggestions[1]) + else + M.dismiss() + end + + state.timer = nil + end, debounce_time) +end + +--- Setup keymaps +local function setup_keymaps() + -- Accept with Tab (only when suggestion visible) + vim.keymap.set("i", config.keymap.accept, function() + if M.is_visible() then + M.accept() + return "" + end + -- Fallback to normal Tab behavior + return vim.api.nvim_replace_termcodes("", true, false, true) + end, { expr = true, silent = true, desc = "Accept codetyper suggestion" }) + + -- Next suggestion + vim.keymap.set("i", config.keymap.next, function() + M.next() + end, { silent = true, desc = "Next codetyper suggestion" }) + + -- Previous suggestion + vim.keymap.set("i", config.keymap.prev, function() + M.prev() + end, { silent = true, desc = "Previous codetyper suggestion" }) + + -- Dismiss + vim.keymap.set("i", config.keymap.dismiss, function() + M.dismiss() + end, { silent = true, desc = "Dismiss codetyper suggestion" }) +end + +--- Setup autocmds for auto-trigger +local function setup_autocmds() + local group = vim.api.nvim_create_augroup("CodetypeSuggestion", { clear = true }) + + -- Trigger on text change in insert mode + if config.auto_trigger then + vim.api.nvim_create_autocmd("TextChangedI", { + group = group, + callback = function() + M.trigger() + end, + }) + end + + -- Dismiss on leaving insert mode + vim.api.nvim_create_autocmd("InsertLeave", { + group = group, + callback = function() + M.dismiss() + end, + }) + + -- Dismiss on cursor move 
(not from typing) + vim.api.nvim_create_autocmd("CursorMovedI", { + group = group, + callback = function() + -- Only dismiss if cursor moved significantly + if state.line ~= nil then + local cursor = vim.api.nvim_win_get_cursor(0) + if cursor[1] - 1 ~= state.line then + M.dismiss() + end + end + end, + }) +end + +--- Setup highlight group +local function setup_highlights() + -- Use Comment highlight or define custom ghost text style + vim.api.nvim_set_hl(0, hl_group, { link = "Comment" }) +end + +--- Setup the suggestion system +---@param opts? table Configuration options +function M.setup(opts) + if opts then + config = vim.tbl_deep_extend("force", config, opts) + end + + setup_highlights() + setup_keymaps() + setup_autocmds() +end + +--- Enable suggestions +function M.enable() + config.enabled = true +end + +--- Disable suggestions +function M.disable() + config.enabled = false + M.dismiss() +end + +--- Toggle suggestions +function M.toggle() + if config.enabled then + M.disable() + else + M.enable() + end +end + +return M diff --git a/lua/codetyper/types.lua b/lua/codetyper/types.lua index e49b4d4..a7a7d40 100644 --- a/lua/codetyper/types.lua +++ b/lua/codetyper/types.lua @@ -7,18 +7,28 @@ ---@field auto_gitignore boolean Auto-manage .gitignore ---@class LLMConfig ----@field provider "claude" | "ollama" The LLM provider to use ----@field claude ClaudeConfig Claude-specific configuration +---@field provider "ollama" | "openai" | "gemini" | "copilot" The LLM provider to use ---@field ollama OllamaConfig Ollama-specific configuration - ----@class ClaudeConfig ----@field api_key string | nil Claude API key (or env var ANTHROPIC_API_KEY) ----@field model string Claude model to use +---@field openai OpenAIConfig OpenAI-specific configuration +---@field gemini GeminiConfig Gemini-specific configuration +---@field copilot CopilotConfig Copilot-specific configuration ---@class OllamaConfig ---@field host string Ollama host URL ---@field model string Ollama model to use 
+---@class OpenAIConfig +---@field api_key string | nil OpenAI API key (or env var OPENAI_API_KEY) +---@field model string OpenAI model to use +---@field endpoint string | nil Custom endpoint (Azure, OpenRouter, etc.) + +---@class GeminiConfig +---@field api_key string | nil Gemini API key (or env var GEMINI_API_KEY) +---@field model string Gemini model to use + +---@class CopilotConfig +---@field model string Copilot model to use + ---@class WindowConfig ---@field width number Width of the coder window (percentage or columns) ---@field position "left" | "right" Position of the coder window diff --git a/lua/codetyper/window.lua b/lua/codetyper/window.lua index e7fa506..27d2ce4 100644 --- a/lua/codetyper/window.lua +++ b/lua/codetyper/window.lua @@ -18,14 +18,14 @@ M._target_buf = nil --- Calculate window width based on configuration ---@param config CoderConfig Plugin configuration ----@return number Width in columns +---@return number Width in columns (minimum 30) local function calculate_width(config) local width = config.window.width if width <= 1 then - -- Percentage of total width - return math.floor(vim.o.columns * width) + -- Percentage of total width (1/4 of screen with minimum 30) + return math.max(math.floor(vim.o.columns * width), 30) end - return math.floor(width) + return math.max(math.floor(width), 30) end --- Open the coder split view diff --git a/tests/spec/ask_intent_spec.lua b/tests/spec/ask_intent_spec.lua new file mode 100644 index 0000000..283e279 --- /dev/null +++ b/tests/spec/ask_intent_spec.lua @@ -0,0 +1,229 @@ +--- Tests for ask intent detection +local intent = require("codetyper.ask.intent") + +describe("ask.intent", function() + describe("detect", function() + -- Ask/Explain intent tests + describe("ask intent", function() + it("detects 'what' questions as ask", function() + local result = intent.detect("What does this function do?") + assert.equals("ask", result.type) + assert.is_true(result.confidence > 0.3) + end) + + it("detects 
'why' questions as ask", function() + local result = intent.detect("Why is this variable undefined?") + assert.equals("ask", result.type) + end) + + it("detects 'how does' as ask", function() + local result = intent.detect("How does this algorithm work?") + assert.is_true(result.type == "ask" or result.type == "explain") + end) + + it("detects 'explain' requests as explain", function() + local result = intent.detect("Explain me the project structure") + assert.equals("explain", result.type) + assert.is_true(result.confidence > 0.4) + end) + + it("detects 'walk me through' as explain", function() + local result = intent.detect("Walk me through this code") + assert.equals("explain", result.type) + end) + + it("detects questions ending with ? as likely ask", function() + local result = intent.detect("Is this the right approach?") + assert.equals("ask", result.type) + end) + + it("sets needs_brain_context for ask intent", function() + local result = intent.detect("What patterns are used here?") + assert.is_true(result.needs_brain_context) + end) + end) + + -- Generate intent tests + describe("generate intent", function() + it("detects 'create' commands as generate", function() + local result = intent.detect("Create a function to sort arrays") + assert.equals("generate", result.type) + end) + + it("detects 'write' commands as generate", function() + local result = intent.detect("Write a unit test for this module") + -- Could be generate or test + assert.is_true(result.type == "generate" or result.type == "test") + end) + + it("detects 'implement' as generate", function() + local result = intent.detect("Implement a binary search") + assert.equals("generate", result.type) + assert.is_true(result.confidence > 0.4) + end) + + it("detects 'add' commands as generate", function() + local result = intent.detect("Add error handling to this function") + assert.equals("generate", result.type) + end) + + it("detects 'fix' as generate", function() + local result = intent.detect("Fix 
the bug in line 42") + assert.equals("generate", result.type) + end) + end) + + -- Refactor intent tests + describe("refactor intent", function() + it("detects explicit 'refactor' as refactor", function() + local result = intent.detect("Refactor this function") + assert.equals("refactor", result.type) + end) + + it("detects 'clean up' as refactor", function() + local result = intent.detect("Clean up this messy code") + assert.equals("refactor", result.type) + end) + + it("detects 'simplify' as refactor", function() + local result = intent.detect("Simplify this logic") + assert.equals("refactor", result.type) + end) + end) + + -- Document intent tests + describe("document intent", function() + it("detects 'document' as document", function() + local result = intent.detect("Document this function") + assert.equals("document", result.type) + end) + + it("detects 'add documentation' as document", function() + local result = intent.detect("Add documentation to this class") + assert.equals("document", result.type) + end) + + it("detects 'add jsdoc' as document", function() + local result = intent.detect("Add jsdoc comments") + assert.equals("document", result.type) + end) + end) + + -- Test intent tests + describe("test intent", function() + it("detects 'write tests for' as test", function() + local result = intent.detect("Write tests for this module") + assert.equals("test", result.type) + end) + + it("detects 'add unit tests' as test", function() + local result = intent.detect("Add unit tests for the parser") + assert.equals("test", result.type) + end) + + it("detects 'generate tests' as test", function() + local result = intent.detect("Generate tests for the API") + assert.equals("test", result.type) + end) + end) + + -- Project context tests + describe("project context detection", function() + it("detects 'project' as needing project context", function() + local result = intent.detect("Explain the project architecture") + assert.is_true(result.needs_project_context) + 
end) + + it("detects 'codebase' as needing project context", function() + local result = intent.detect("How is the codebase organized?") + assert.is_true(result.needs_project_context) + end) + + it("does not need project context for simple questions", function() + local result = intent.detect("What does this variable mean?") + assert.is_false(result.needs_project_context) + end) + end) + + -- Exploration tests + describe("exploration detection", function() + it("detects 'explain me the project' as needing exploration", function() + local result = intent.detect("Explain me the project") + assert.is_true(result.needs_exploration) + end) + + it("detects 'explain the codebase' as needing exploration", function() + local result = intent.detect("Explain the codebase structure") + assert.is_true(result.needs_exploration) + end) + + it("detects 'explore project' as needing exploration", function() + local result = intent.detect("Explore this project") + assert.is_true(result.needs_exploration) + end) + + it("does not need exploration for simple questions", function() + local result = intent.detect("What does this function do?") + assert.is_false(result.needs_exploration) + end) + end) + end) + + describe("get_prompt_type", function() + it("maps ask to ask", function() + local result = intent.get_prompt_type({ type = "ask" }) + assert.equals("ask", result) + end) + + it("maps explain to ask", function() + local result = intent.get_prompt_type({ type = "explain" }) + assert.equals("ask", result) + end) + + it("maps generate to code_generation", function() + local result = intent.get_prompt_type({ type = "generate" }) + assert.equals("code_generation", result) + end) + + it("maps refactor to refactor", function() + local result = intent.get_prompt_type({ type = "refactor" }) + assert.equals("refactor", result) + end) + + it("maps document to document", function() + local result = intent.get_prompt_type({ type = "document" }) + assert.equals("document", result) + end) + + 
it("maps test to test", function() + local result = intent.get_prompt_type({ type = "test" }) + assert.equals("test", result) + end) + end) + + describe("produces_code", function() + it("returns false for ask", function() + assert.is_false(intent.produces_code({ type = "ask" })) + end) + + it("returns false for explain", function() + assert.is_false(intent.produces_code({ type = "explain" })) + end) + + it("returns true for generate", function() + assert.is_true(intent.produces_code({ type = "generate" })) + end) + + it("returns true for refactor", function() + assert.is_true(intent.produces_code({ type = "refactor" })) + end) + + it("returns true for document", function() + assert.is_true(intent.produces_code({ type = "document" })) + end) + + it("returns true for test", function() + assert.is_true(intent.produces_code({ type = "test" })) + end) + end) +end) diff --git a/tests/spec/brain_delta_spec.lua b/tests/spec/brain_delta_spec.lua new file mode 100644 index 0000000..58b89cf --- /dev/null +++ b/tests/spec/brain_delta_spec.lua @@ -0,0 +1,252 @@ +--- Tests for brain/delta modules +describe("brain.delta", function() + local diff + local commit + local storage + local types + local test_root = "/tmp/codetyper_test_" .. 
os.time() + + before_each(function() + -- Clear module cache + package.loaded["codetyper.brain.delta.diff"] = nil + package.loaded["codetyper.brain.delta.commit"] = nil + package.loaded["codetyper.brain.storage"] = nil + package.loaded["codetyper.brain.types"] = nil + + diff = require("codetyper.brain.delta.diff") + commit = require("codetyper.brain.delta.commit") + storage = require("codetyper.brain.storage") + types = require("codetyper.brain.types") + + storage.clear_cache() + vim.fn.mkdir(test_root, "p") + storage.ensure_dirs(test_root) + + -- Mock get_project_root + local utils = require("codetyper.utils") + utils.get_project_root = function() + return test_root + end + end) + + after_each(function() + vim.fn.delete(test_root, "rf") + storage.clear_cache() + end) + + describe("diff.compute", function() + it("detects added values", function() + local diffs = diff.compute(nil, { a = 1 }) + + assert.equals(1, #diffs) + assert.equals("add", diffs[1].op) + end) + + it("detects deleted values", function() + local diffs = diff.compute({ a = 1 }, nil) + + assert.equals(1, #diffs) + assert.equals("delete", diffs[1].op) + end) + + it("detects replaced values", function() + local diffs = diff.compute({ a = 1 }, { a = 2 }) + + assert.equals(1, #diffs) + assert.equals("replace", diffs[1].op) + assert.equals(1, diffs[1].from) + assert.equals(2, diffs[1].to) + end) + + it("detects nested changes", function() + local before = { a = { b = 1 } } + local after = { a = { b = 2 } } + + local diffs = diff.compute(before, after) + + assert.equals(1, #diffs) + assert.equals("a.b", diffs[1].path) + end) + + it("returns empty for identical values", function() + local diffs = diff.compute({ a = 1 }, { a = 1 }) + assert.equals(0, #diffs) + end) + end) + + describe("diff.apply", function() + it("applies add operation", function() + local base = { a = 1 } + local diffs = { { op = "add", path = "b", value = 2 } } + + local result = diff.apply(base, diffs) + + assert.equals(2, result.b) + 
end) + + it("applies replace operation", function() + local base = { a = 1 } + local diffs = { { op = "replace", path = "a", to = 2 } } + + local result = diff.apply(base, diffs) + + assert.equals(2, result.a) + end) + + it("applies delete operation", function() + local base = { a = 1, b = 2 } + local diffs = { { op = "delete", path = "a" } } + + local result = diff.apply(base, diffs) + + assert.is_nil(result.a) + assert.equals(2, result.b) + end) + + it("applies nested changes", function() + local base = { a = { b = 1 } } + local diffs = { { op = "replace", path = "a.b", to = 2 } } + + local result = diff.apply(base, diffs) + + assert.equals(2, result.a.b) + end) + end) + + describe("diff.reverse", function() + it("reverses add to delete", function() + local diffs = { { op = "add", path = "a", value = 1 } } + + local reversed = diff.reverse(diffs) + + assert.equals("delete", reversed[1].op) + end) + + it("reverses delete to add", function() + local diffs = { { op = "delete", path = "a", value = 1 } } + + local reversed = diff.reverse(diffs) + + assert.equals("add", reversed[1].op) + end) + + it("reverses replace", function() + local diffs = { { op = "replace", path = "a", from = 1, to = 2 } } + + local reversed = diff.reverse(diffs) + + assert.equals("replace", reversed[1].op) + assert.equals(2, reversed[1].from) + assert.equals(1, reversed[1].to) + end) + end) + + describe("diff.equals", function() + it("returns true for identical states", function() + assert.is_true(diff.equals({ a = 1 }, { a = 1 })) + end) + + it("returns false for different states", function() + assert.is_false(diff.equals({ a = 1 }, { a = 2 })) + end) + end) + + describe("commit.create", function() + it("creates a delta commit", function() + local changes = { + { op = "add", path = "test.node1", ah = "abc123" }, + } + + local delta = commit.create(changes, "Test commit", "test") + + assert.is_not_nil(delta) + assert.is_not_nil(delta.h) + assert.equals("Test commit", delta.m.msg) + 
assert.equals(1, #delta.ch) + end) + + it("updates HEAD", function() + local changes = { { op = "add", path = "test.node1", ah = "abc123" } } + + local delta = commit.create(changes, "Test", "test") + + local head = storage.get_head(test_root) + assert.equals(delta.h, head) + end) + + it("links to parent", function() + local changes1 = { { op = "add", path = "test.node1", ah = "abc123" } } + local delta1 = commit.create(changes1, "First", "test") + + local changes2 = { { op = "add", path = "test.node2", ah = "def456" } } + local delta2 = commit.create(changes2, "Second", "test") + + assert.equals(delta1.h, delta2.p) + end) + + it("returns nil for empty changes", function() + local delta = commit.create({}, "Empty") + assert.is_nil(delta) + end) + end) + + describe("commit.get", function() + it("retrieves created delta", function() + local changes = { { op = "add", path = "test.node1", ah = "abc123" } } + local created = commit.create(changes, "Test", "test") + + local retrieved = commit.get(created.h) + + assert.is_not_nil(retrieved) + assert.equals(created.h, retrieved.h) + end) + + it("returns nil for non-existent delta", function() + local retrieved = commit.get("nonexistent") + assert.is_nil(retrieved) + end) + end) + + describe("commit.get_history", function() + it("returns delta chain", function() + commit.create({ { op = "add", path = "node1", ah = "1" } }, "First", "test") + commit.create({ { op = "add", path = "node2", ah = "2" } }, "Second", "test") + commit.create({ { op = "add", path = "node3", ah = "3" } }, "Third", "test") + + local history = commit.get_history(10) + + assert.equals(3, #history) + assert.equals("Third", history[1].m.msg) + assert.equals("Second", history[2].m.msg) + assert.equals("First", history[3].m.msg) + end) + + it("respects limit", function() + for i = 1, 5 do + commit.create({ { op = "add", path = "node" .. i, ah = tostring(i) } }, "Commit " .. 
i, "test") + end + + local history = commit.get_history(3) + + assert.equals(3, #history) + end) + end) + + describe("commit.summarize", function() + it("summarizes delta statistics", function() + local changes = { + { op = "add", path = "nodes.a" }, + { op = "add", path = "nodes.b" }, + { op = "mod", path = "nodes.c" }, + { op = "del", path = "nodes.d" }, + } + local delta = commit.create(changes, "Test", "test") + + local summary = commit.summarize(delta) + + assert.equals(2, summary.stats.adds) + assert.equals(4, summary.stats.total) + assert.is_true(vim.tbl_contains(summary.categories, "nodes")) + end) + end) +end) diff --git a/tests/spec/brain_hash_spec.lua b/tests/spec/brain_hash_spec.lua new file mode 100644 index 0000000..3dc6d7c --- /dev/null +++ b/tests/spec/brain_hash_spec.lua @@ -0,0 +1,128 @@ +--- Tests for brain/hash.lua +describe("brain.hash", function() + local hash + + before_each(function() + package.loaded["codetyper.brain.hash"] = nil + hash = require("codetyper.brain.hash") + end) + + describe("compute", function() + it("returns 8-character hash", function() + local result = hash.compute("test string") + assert.equals(8, #result) + end) + + it("returns consistent hash for same input", function() + local result1 = hash.compute("test") + local result2 = hash.compute("test") + assert.equals(result1, result2) + end) + + it("returns different hash for different input", function() + local result1 = hash.compute("test1") + local result2 = hash.compute("test2") + assert.not_equals(result1, result2) + end) + + it("handles empty string", function() + local result = hash.compute("") + assert.equals("00000000", result) + end) + + it("handles nil", function() + local result = hash.compute(nil) + assert.equals("00000000", result) + end) + end) + + describe("compute_table", function() + it("hashes table as JSON", function() + local result = hash.compute_table({ a = 1, b = 2 }) + assert.equals(8, #result) + end) + + it("returns consistent hash for same table", 
function() + local result1 = hash.compute_table({ x = "y" }) + local result2 = hash.compute_table({ x = "y" }) + assert.equals(result1, result2) + end) + end) + + describe("node_id", function() + it("generates ID with correct format", function() + local id = hash.node_id("pat", "test content") + assert.truthy(id:match("^n_pat_%d+_%w+$")) + end) + + it("generates unique IDs", function() + local id1 = hash.node_id("pat", "test1") + local id2 = hash.node_id("pat", "test2") + assert.not_equals(id1, id2) + end) + end) + + describe("edge_id", function() + it("generates ID with correct format", function() + local id = hash.edge_id("source_node", "target_node") + assert.truthy(id:match("^e_%w+_%w+$")) + end) + + it("returns same ID for same source/target", function() + local id1 = hash.edge_id("s1", "t1") + local id2 = hash.edge_id("s1", "t1") + assert.equals(id1, id2) + end) + end) + + describe("delta_hash", function() + it("generates 8-character hash", function() + local changes = { { op = "add", path = "test" } } + local result = hash.delta_hash(changes, "parent", 12345) + assert.equals(8, #result) + end) + + it("includes parent in hash", function() + local changes = { { op = "add", path = "test" } } + local result1 = hash.delta_hash(changes, "parent1", 12345) + local result2 = hash.delta_hash(changes, "parent2", 12345) + assert.not_equals(result1, result2) + end) + end) + + describe("path_hash", function() + it("returns 8-character hash", function() + local result = hash.path_hash("/path/to/file.lua") + assert.equals(8, #result) + end) + end) + + describe("matches", function() + it("returns true for matching hashes", function() + assert.is_true(hash.matches("abc12345", "abc12345")) + end) + + it("returns false for different hashes", function() + assert.is_false(hash.matches("abc12345", "def67890")) + end) + end) + + describe("random", function() + it("returns 8-character string", function() + local result = hash.random() + assert.equals(8, #result) + end) + + 
it("generates different values", function() + local result1 = hash.random() + local result2 = hash.random() + -- Note: There's a tiny chance these could match, but very unlikely + assert.not_equals(result1, result2) + end) + + it("contains only hex characters", function() + local result = hash.random() + assert.truthy(result:match("^[0-9a-f]+$")) + end) + end) +end) diff --git a/tests/spec/brain_node_spec.lua b/tests/spec/brain_node_spec.lua new file mode 100644 index 0000000..ed63a15 --- /dev/null +++ b/tests/spec/brain_node_spec.lua @@ -0,0 +1,234 @@ +--- Tests for brain/graph/node.lua +describe("brain.graph.node", function() + local node + local storage + local types + local test_root = "/tmp/codetyper_test_" .. os.time() + + before_each(function() + -- Clear module cache + package.loaded["codetyper.brain.graph.node"] = nil + package.loaded["codetyper.brain.storage"] = nil + package.loaded["codetyper.brain.types"] = nil + package.loaded["codetyper.brain.hash"] = nil + + storage = require("codetyper.brain.storage") + types = require("codetyper.brain.types") + node = require("codetyper.brain.graph.node") + + storage.clear_cache() + vim.fn.mkdir(test_root, "p") + storage.ensure_dirs(test_root) + + -- Mock get_project_root + local utils = require("codetyper.utils") + utils.get_project_root = function() + return test_root + end + end) + + after_each(function() + vim.fn.delete(test_root, "rf") + storage.clear_cache() + node.pending = {} + end) + + describe("create", function() + it("creates a new node with correct structure", function() + local created = node.create(types.NODE_TYPES.PATTERN, { + s = "Test pattern summary", + d = "Test pattern detail", + }, { + f = "test.lua", + }) + + assert.is_not_nil(created.id) + assert.equals(types.NODE_TYPES.PATTERN, created.t) + assert.equals("Test pattern summary", created.c.s) + assert.equals("test.lua", created.ctx.f) + assert.equals(0.5, created.sc.w) + assert.equals(0, created.sc.u) + end) + + it("generates unique IDs", 
function() + local node1 = node.create(types.NODE_TYPES.PATTERN, { s = "First" }, {}) + local node2 = node.create(types.NODE_TYPES.PATTERN, { s = "Second" }, {}) + + assert.is_not_nil(node1.id) + assert.is_not_nil(node2.id) + assert.not_equals(node1.id, node2.id) + end) + + it("updates meta node count", function() + local meta_before = storage.get_meta(test_root) + local count_before = meta_before.nc + + node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {}) + + local meta_after = storage.get_meta(test_root) + assert.equals(count_before + 1, meta_after.nc) + end) + + it("tracks pending change", function() + node.pending = {} + node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {}) + + assert.equals(1, #node.pending) + assert.equals("add", node.pending[1].op) + end) + end) + + describe("get", function() + it("retrieves created node", function() + local created = node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {}) + + local retrieved = node.get(created.id) + + assert.is_not_nil(retrieved) + assert.equals(created.id, retrieved.id) + assert.equals("Test", retrieved.c.s) + end) + + it("returns nil for non-existent node", function() + local retrieved = node.get("n_pat_0_nonexistent") + assert.is_nil(retrieved) + end) + end) + + describe("update", function() + it("updates node content", function() + local created = node.create(types.NODE_TYPES.PATTERN, { s = "Original" }, {}) + + node.update(created.id, { c = { s = "Updated" } }) + + local updated = node.get(created.id) + assert.equals("Updated", updated.c.s) + end) + + it("updates node scores", function() + local created = node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {}) + + node.update(created.id, { sc = { w = 0.9 } }) + + local updated = node.get(created.id) + assert.equals(0.9, updated.sc.w) + end) + + it("increments version", function() + local created = node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {}) + local original_version = created.m.v + + node.update(created.id, { c = { s = 
"Updated" } }) + + local updated = node.get(created.id) + assert.equals(original_version + 1, updated.m.v) + end) + + it("returns nil for non-existent node", function() + local result = node.update("n_pat_0_nonexistent", { c = { s = "Test" } }) + assert.is_nil(result) + end) + end) + + describe("delete", function() + it("removes node", function() + local created = node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {}) + + local result = node.delete(created.id) + + assert.is_true(result) + assert.is_nil(node.get(created.id)) + end) + + it("decrements meta node count", function() + local created = node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {}) + local meta_before = storage.get_meta(test_root) + local count_before = meta_before.nc + + node.delete(created.id) + + local meta_after = storage.get_meta(test_root) + assert.equals(count_before - 1, meta_after.nc) + end) + + it("returns false for non-existent node", function() + local result = node.delete("n_pat_0_nonexistent") + assert.is_false(result) + end) + end) + + describe("find", function() + it("finds nodes by type", function() + node.create(types.NODE_TYPES.PATTERN, { s = "Pattern 1" }, {}) + node.create(types.NODE_TYPES.PATTERN, { s = "Pattern 2" }, {}) + node.create(types.NODE_TYPES.CORRECTION, { s = "Correction 1" }, {}) + + local patterns = node.find({ types = { types.NODE_TYPES.PATTERN } }) + + assert.equals(2, #patterns) + end) + + it("finds nodes by file", function() + node.create(types.NODE_TYPES.PATTERN, { s = "Test 1" }, { f = "file1.lua" }) + node.create(types.NODE_TYPES.PATTERN, { s = "Test 2" }, { f = "file2.lua" }) + node.create(types.NODE_TYPES.PATTERN, { s = "Test 3" }, { f = "file1.lua" }) + + local found = node.find({ file = "file1.lua" }) + + assert.equals(2, #found) + end) + + it("finds nodes by query", function() + node.create(types.NODE_TYPES.PATTERN, { s = "Foo bar baz" }, {}) + node.create(types.NODE_TYPES.PATTERN, { s = "Something else" }, {}) + + local found = node.find({ 
query = "foo" }) + + assert.equals(1, #found) + assert.equals("Foo bar baz", found[1].c.s) + end) + + it("respects limit", function() + for i = 1, 10 do + node.create(types.NODE_TYPES.PATTERN, { s = "Node " .. i }, {}) + end + + local found = node.find({ limit = 5 }) + + assert.equals(5, #found) + end) + end) + + describe("record_usage", function() + it("increments usage count", function() + local created = node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {}) + + node.record_usage(created.id, true) + + local updated = node.get(created.id) + assert.equals(1, updated.sc.u) + end) + + it("updates success rate", function() + local created = node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {}) + + node.record_usage(created.id, true) + node.record_usage(created.id, false) + + local updated = node.get(created.id) + assert.equals(0.5, updated.sc.sr) + end) + end) + + describe("get_and_clear_pending", function() + it("returns and clears pending changes", function() + node.pending = {} + node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {}) + + local pending = node.get_and_clear_pending() + + assert.equals(1, #pending) + assert.equals(0, #node.pending) + end) + end) +end) diff --git a/tests/spec/brain_storage_spec.lua b/tests/spec/brain_storage_spec.lua new file mode 100644 index 0000000..9297d2a --- /dev/null +++ b/tests/spec/brain_storage_spec.lua @@ -0,0 +1,173 @@ +--- Tests for brain/storage.lua +describe("brain.storage", function() + local storage + local test_root = "/tmp/codetyper_test_" .. 
os.time() + + before_each(function() + -- Clear module cache to get fresh state + package.loaded["codetyper.brain.storage"] = nil + package.loaded["codetyper.brain.types"] = nil + storage = require("codetyper.brain.storage") + + -- Clear cache before each test + storage.clear_cache() + + -- Create test directory + vim.fn.mkdir(test_root, "p") + end) + + after_each(function() + -- Clean up test directory + vim.fn.delete(test_root, "rf") + storage.clear_cache() + end) + + describe("get_brain_dir", function() + it("returns correct path", function() + local dir = storage.get_brain_dir(test_root) + assert.equals(test_root .. "/.coder/brain", dir) + end) + end) + + describe("ensure_dirs", function() + it("creates required directories", function() + local result = storage.ensure_dirs(test_root) + assert.is_true(result) + + -- Check directories exist + assert.equals(1, vim.fn.isdirectory(test_root .. "/.coder/brain")) + assert.equals(1, vim.fn.isdirectory(test_root .. "/.coder/brain/nodes")) + assert.equals(1, vim.fn.isdirectory(test_root .. "/.coder/brain/indices")) + assert.equals(1, vim.fn.isdirectory(test_root .. "/.coder/brain/deltas")) + assert.equals(1, vim.fn.isdirectory(test_root .. "/.coder/brain/deltas/objects")) + end) + end) + + describe("get_path", function() + it("returns correct path for simple key", function() + local path = storage.get_path("meta", test_root) + assert.equals(test_root .. "/.coder/brain/meta.json", path) + end) + + it("returns correct path for nested key", function() + local path = storage.get_path("nodes.patterns", test_root) + assert.equals(test_root .. "/.coder/brain/nodes/patterns.json", path) + end) + + it("returns correct path for deeply nested key", function() + local path = storage.get_path("deltas.objects.abc123", test_root) + assert.equals(test_root .. 
"/.coder/brain/deltas/objects/abc123.json", path) + end) + end) + + describe("save and load", function() + it("saves and loads data correctly", function() + storage.ensure_dirs(test_root) + + local data = { test = "value", count = 42 } + storage.save("meta", data, test_root, true) -- immediate + + -- Clear cache and reload + storage.clear_cache() + local loaded = storage.load("meta", test_root) + + assert.equals("value", loaded.test) + assert.equals(42, loaded.count) + end) + + it("returns empty table for missing files", function() + storage.ensure_dirs(test_root) + + local loaded = storage.load("nonexistent", test_root) + assert.same({}, loaded) + end) + end) + + describe("get_meta", function() + it("creates default meta if not exists", function() + storage.ensure_dirs(test_root) + + local meta = storage.get_meta(test_root) + + assert.is_not_nil(meta.v) + assert.equals(0, meta.nc) + assert.equals(0, meta.ec) + assert.equals(0, meta.dc) + end) + end) + + describe("update_meta", function() + it("updates meta values", function() + storage.ensure_dirs(test_root) + + storage.update_meta({ nc = 5 }, test_root) + local meta = storage.get_meta(test_root) + + assert.equals(5, meta.nc) + end) + end) + + describe("get/save_nodes", function() + it("saves and retrieves nodes by type", function() + storage.ensure_dirs(test_root) + + local nodes = { + ["n_pat_123_abc"] = { id = "n_pat_123_abc", t = "pat" }, + ["n_pat_456_def"] = { id = "n_pat_456_def", t = "pat" }, + } + + storage.save_nodes("patterns", nodes, test_root) + storage.flush("nodes.patterns", test_root) + + storage.clear_cache() + local loaded = storage.get_nodes("patterns", test_root) + + assert.equals(2, vim.tbl_count(loaded)) + assert.equals("n_pat_123_abc", loaded["n_pat_123_abc"].id) + end) + end) + + describe("get/save_graph", function() + it("saves and retrieves graph", function() + storage.ensure_dirs(test_root) + + local graph = { + adj = { node1 = { sem = { "node2" } } }, + radj = { node2 = { sem = { 
"node1" } } }, + } + + storage.save_graph(graph, test_root) + storage.flush("graph", test_root) + + storage.clear_cache() + local loaded = storage.get_graph(test_root) + + assert.same({ "node2" }, loaded.adj.node1.sem) + end) + end) + + describe("get/set_head", function() + it("stores and retrieves HEAD", function() + storage.ensure_dirs(test_root) + + storage.set_head("abc12345", test_root) + storage.flush("meta", test_root) -- Ensure written to disk + + storage.clear_cache() + local head = storage.get_head(test_root) + + assert.equals("abc12345", head) + end) + end) + + describe("exists", function() + it("returns false for non-existent brain", function() + assert.is_false(storage.exists(test_root)) + end) + + it("returns true after ensure_dirs", function() + storage.ensure_dirs(test_root) + assert.is_true(storage.exists(test_root)) + end) + end) +end) diff --git a/tests/spec/indexer_spec.lua b/tests/spec/indexer_spec.lua new file mode 100644 index 0000000..b1f6d59 --- /dev/null +++ b/tests/spec/indexer_spec.lua @@ -0,0 +1,345 @@ +---@diagnostic disable: undefined-global +-- Tests for lua/codetyper/indexer/init.lua + +describe("indexer", function() + local indexer + local utils + + -- Mock cwd for testing + local test_cwd = "/tmp/codetyper_test_indexer" + + before_each(function() + -- Reset modules + package.loaded["codetyper.indexer"] = nil + package.loaded["codetyper.indexer.scanner"] = nil + package.loaded["codetyper.indexer.analyzer"] = nil + package.loaded["codetyper.indexer.memory"] = nil + package.loaded["codetyper.utils"] = nil + + indexer = require("codetyper.indexer") + utils = require("codetyper.utils") + + -- Create test directory structure + vim.fn.mkdir(test_cwd, "p") + vim.fn.mkdir(test_cwd .. "/.coder", "p") + vim.fn.mkdir(test_cwd .. 
"/src", "p") + + -- Mock getcwd to return test directory + vim.fn.getcwd = function() + return test_cwd + end + + -- Mock get_project_root + package.loaded["codetyper.utils"].get_project_root = function() + return test_cwd + end + end) + + after_each(function() + -- Clean up test directory + vim.fn.delete(test_cwd, "rf") + end) + + describe("setup", function() + it("should accept configuration options", function() + indexer.setup({ + enabled = true, + auto_index = false, + }) + + local config = indexer.get_config() + assert.is_false(config.auto_index) + end) + + it("should use default configuration when no options provided", function() + indexer.setup() + + local config = indexer.get_config() + assert.is_true(config.enabled) + end) + end) + + describe("load_index", function() + it("should return nil when no index exists", function() + local index = indexer.load_index() + + assert.is_nil(index) + end) + + it("should load existing index from file", function() + -- Create a mock index file + local mock_index = { + version = 1, + project_root = test_cwd, + project_name = "test", + project_type = "node", + dependencies = {}, + dev_dependencies = {}, + files = {}, + symbols = {}, + last_indexed = os.time(), + stats = { files = 0, functions = 0, classes = 0, exports = 0 }, + } + utils.write_file(test_cwd .. "/.coder/index.json", vim.json.encode(mock_index)) + + local index = indexer.load_index() + + assert.is_table(index) + assert.equals("test", index.project_name) + assert.equals("node", index.project_type) + end) + + it("should cache loaded index", function() + local mock_index = { + version = 1, + project_root = test_cwd, + project_name = "cached_test", + project_type = "lua", + dependencies = {}, + dev_dependencies = {}, + files = {}, + symbols = {}, + last_indexed = os.time(), + stats = { files = 0, functions = 0, classes = 0, exports = 0 }, + } + utils.write_file(test_cwd .. 
"/.coder/index.json", vim.json.encode(mock_index)) + + local index1 = indexer.load_index() + local index2 = indexer.load_index() + + assert.equals(index1.project_name, index2.project_name) + end) + end) + + describe("save_index", function() + it("should save index to file", function() + local index = { + version = 1, + project_root = test_cwd, + project_name = "save_test", + project_type = "node", + dependencies = { express = "^4.18.0" }, + dev_dependencies = {}, + files = {}, + symbols = {}, + last_indexed = os.time(), + stats = { files = 0, functions = 0, classes = 0, exports = 0 }, + } + + local result = indexer.save_index(index) + + assert.is_true(result) + + -- Verify file was created + local content = utils.read_file(test_cwd .. "/.coder/index.json") + assert.is_truthy(content) + + local decoded = vim.json.decode(content) + assert.equals("save_test", decoded.project_name) + end) + + it("should create .coder directory if it does not exist", function() + vim.fn.delete(test_cwd .. "/.coder", "rf") + + local index = { + version = 1, + project_root = test_cwd, + project_name = "test", + project_type = "unknown", + dependencies = {}, + dev_dependencies = {}, + files = {}, + symbols = {}, + last_indexed = os.time(), + stats = { files = 0, functions = 0, classes = 0, exports = 0 }, + } + + indexer.save_index(index) + + assert.equals(1, vim.fn.isdirectory(test_cwd .. "/.coder")) + end) + end) + + describe("index_project", function() + it("should create an index for the project", function() + -- Create some test files + utils.write_file(test_cwd .. "/package.json", '{"name":"test","dependencies":{}}') + utils.write_file(test_cwd .. 
"/src/main.lua", [[ +local M = {} +function M.hello() + return "world" +end +return M +]]) + + indexer.setup({ index_extensions = { "lua" } }) + local index = indexer.index_project() + + assert.is_table(index) + assert.equals("node", index.project_type) + assert.is_truthy(index.stats.files >= 0) + end) + + it("should detect project dependencies", function() + utils.write_file(test_cwd .. "/package.json", [[{ + "name": "test", + "dependencies": { + "express": "^4.18.0", + "lodash": "^4.17.0" + } + }]]) + + indexer.setup() + local index = indexer.index_project() + + assert.is_table(index.dependencies) + assert.equals("^4.18.0", index.dependencies.express) + end) + + it("should call callback when complete", function() + local callback_called = false + local callback_index = nil + + indexer.setup() + indexer.index_project(function(index) + callback_called = true + callback_index = index + end) + + assert.is_true(callback_called) + assert.is_table(callback_index) + end) + end) + + describe("index_file", function() + it("should index a single file", function() + utils.write_file(test_cwd .. "/src/test.lua", [[ +local M = {} +function M.add(a, b) + return a + b +end +function M.subtract(a, b) + return a - b +end +return M +]]) + + indexer.setup({ index_extensions = { "lua" } }) + -- First create an initial index + indexer.index_project() + + local file_index = indexer.index_file(test_cwd .. "/src/test.lua") + + assert.is_table(file_index) + assert.equals("src/test.lua", file_index.path) + end) + + it("should update symbols in the main index", function() + utils.write_file(test_cwd .. "/src/utils.lua", [[ +local M = {} +function M.format_string(str) + return string.upper(str) +end +return M +]]) + + indexer.setup({ index_extensions = { "lua" } }) + indexer.index_project() + indexer.index_file(test_cwd .. 
"/src/utils.lua") + + local index = indexer.load_index() + assert.is_table(index.files) + end) + end) + + describe("get_status", function() + it("should return indexed: false when no index exists", function() + local status = indexer.get_status() + + assert.is_false(status.indexed) + assert.is_nil(status.stats) + end) + + it("should return status when index exists", function() + indexer.setup() + indexer.index_project() + + local status = indexer.get_status() + + assert.is_true(status.indexed) + assert.is_table(status.stats) + assert.is_truthy(status.last_indexed) + end) + end) + + describe("get_context_for", function() + it("should return context with project type", function() + utils.write_file(test_cwd .. "/package.json", '{"name":"test"}') + indexer.setup() + indexer.index_project() + + local context = indexer.get_context_for({ + file = test_cwd .. "/src/main.lua", + prompt = "add a function", + }) + + assert.is_table(context) + assert.equals("node", context.project_type) + end) + + it("should find relevant symbols", function() + utils.write_file(test_cwd .. "/src/utils.lua", [[ +local M = {} +function M.calculate_total(items) + return 0 +end +return M +]]) + indexer.setup({ index_extensions = { "lua" } }) + indexer.index_project() + + local context = indexer.get_context_for({ + file = test_cwd .. 
"/src/main.lua", + prompt = "use calculate_total function", + }) + + assert.is_table(context) + -- Should find the calculate symbol + if context.relevant_symbols and context.relevant_symbols.calculate then + assert.is_table(context.relevant_symbols.calculate) + end + end) + end) + + describe("clear", function() + it("should remove the index file", function() + indexer.setup() + indexer.index_project() + + -- Verify index exists + assert.is_true(indexer.get_status().indexed) + + indexer.clear() + + -- Verify index is gone + local status = indexer.get_status() + assert.is_false(status.indexed) + end) + end) + + describe("schedule_index_file", function() + it("should not index when disabled", function() + indexer.setup({ enabled = false }) + + -- This should not throw or cause issues + indexer.schedule_index_file(test_cwd .. "/src/test.lua") + end) + + it("should not index when auto_index is false", function() + indexer.setup({ enabled = true, auto_index = false }) + + -- This should not throw or cause issues + indexer.schedule_index_file(test_cwd .. "/src/test.lua") + end) + end) +end) diff --git a/tests/spec/memory_spec.lua b/tests/spec/memory_spec.lua new file mode 100644 index 0000000..7eafc70 --- /dev/null +++ b/tests/spec/memory_spec.lua @@ -0,0 +1,341 @@ +---@diagnostic disable: undefined-global +-- Tests for lua/codetyper/indexer/memory.lua + +describe("indexer.memory", function() + local memory + local utils + + -- Mock cwd for testing + local test_cwd = "/tmp/codetyper_test_memory" + + before_each(function() + -- Reset modules + package.loaded["codetyper.indexer.memory"] = nil + package.loaded["codetyper.utils"] = nil + + memory = require("codetyper.indexer.memory") + utils = require("codetyper.utils") + + -- Create test directory structure + vim.fn.mkdir(test_cwd, "p") + vim.fn.mkdir(test_cwd .. "/.coder", "p") + vim.fn.mkdir(test_cwd .. "/.coder/memories", "p") + vim.fn.mkdir(test_cwd .. "/.coder/memories/files", "p") + vim.fn.mkdir(test_cwd .. 
"/.coder/sessions", "p") + + -- Mock getcwd to return test directory + vim.fn.getcwd = function() + return test_cwd + end + + -- Mock get_project_root + package.loaded["codetyper.utils"].get_project_root = function() + return test_cwd + end + end) + + after_each(function() + -- Clean up test directory + vim.fn.delete(test_cwd, "rf") + end) + + describe("store_memory", function() + it("should store a pattern memory", function() + local mem = { + type = "pattern", + content = "Use snake_case for function names", + weight = 0.8, + } + + local result = memory.store_memory(mem) + + assert.is_true(result) + end) + + it("should store a convention memory", function() + local mem = { + type = "convention", + content = "Project uses TypeScript", + weight = 0.9, + } + + local result = memory.store_memory(mem) + + assert.is_true(result) + end) + + it("should assign an ID to the memory", function() + local mem = { + type = "pattern", + content = "Test memory", + } + + memory.store_memory(mem) + + assert.is_truthy(mem.id) + assert.is_true(mem.id:match("^mem_") ~= nil) + end) + + it("should set timestamps", function() + local mem = { + type = "pattern", + content = "Test memory", + } + + memory.store_memory(mem) + + assert.is_truthy(mem.created_at) + assert.is_truthy(mem.updated_at) + end) + end) + + describe("load_patterns", function() + it("should return empty table when no patterns exist", function() + local patterns = memory.load_patterns() + + assert.is_table(patterns) + end) + + it("should load stored patterns", function() + -- Store a pattern first + memory.store_memory({ + type = "pattern", + content = "Test pattern", + weight = 0.5, + }) + + -- Force reload + package.loaded["codetyper.indexer.memory"] = nil + memory = require("codetyper.indexer.memory") + + local patterns = memory.load_patterns() + + assert.is_table(patterns) + local count = 0 + for _ in pairs(patterns) do + count = count + 1 + end + assert.is_true(count >= 1) + end) + end) + + 
describe("load_conventions", function() + it("should return empty table when no conventions exist", function() + local conventions = memory.load_conventions() + + assert.is_table(conventions) + end) + end) + + describe("store_file_memory", function() + it("should store file-specific memory", function() + local file_index = { + functions = { + { name = "test_func", line = 10, end_line = 20 }, + }, + classes = {}, + exports = {}, + imports = {}, + } + + local result = memory.store_file_memory("src/main.lua", file_index) + + assert.is_true(result) + end) + end) + + describe("load_file_memory", function() + it("should return nil when file memory does not exist", function() + local result = memory.load_file_memory("nonexistent.lua") + + assert.is_nil(result) + end) + + it("should load stored file memory", function() + local file_index = { + functions = { + { name = "my_function", line = 5, end_line = 15 }, + }, + classes = {}, + exports = {}, + imports = {}, + } + + memory.store_file_memory("src/test.lua", file_index) + local loaded = memory.load_file_memory("src/test.lua") + + assert.is_table(loaded) + assert.equals("src/test.lua", loaded.path) + assert.equals(1, #loaded.functions) + assert.equals("my_function", loaded.functions[1].name) + end) + end) + + describe("get_relevant", function() + it("should return empty table when no memories exist", function() + local results = memory.get_relevant("test query", 10) + + assert.is_table(results) + assert.equals(0, #results) + end) + + it("should find relevant memories by keyword", function() + memory.store_memory({ + type = "pattern", + content = "Use TypeScript for type safety", + weight = 0.8, + }) + memory.store_memory({ + type = "pattern", + content = "Use Python for data processing", + weight = 0.7, + }) + + local results = memory.get_relevant("TypeScript", 10) + + assert.is_true(#results >= 1) + -- First result should contain TypeScript + local found = false + for _, r in ipairs(results) do + if 
r.content:find("TypeScript") then + found = true + break + end + end + assert.is_true(found) + end) + + it("should limit results", function() + -- Store multiple memories + for i = 1, 20 do + memory.store_memory({ + type = "pattern", + content = "Pattern number " .. i .. " about testing", + weight = 0.5, + }) + end + + local results = memory.get_relevant("testing", 5) + + assert.is_true(#results <= 5) + end) + end) + + describe("update_usage", function() + it("should increment used_count", function() + local mem = { + type = "pattern", + content = "Test pattern for usage tracking", + weight = 0.5, + } + memory.store_memory(mem) + + memory.update_usage(mem.id) + + -- Reload and check + package.loaded["codetyper.indexer.memory"] = nil + memory = require("codetyper.indexer.memory") + + local patterns = memory.load_patterns() + if patterns[mem.id] then + assert.equals(1, patterns[mem.id].used_count) + end + end) + end) + + describe("get_all", function() + it("should return all memory types", function() + memory.store_memory({ type = "pattern", content = "A pattern" }) + memory.store_memory({ type = "convention", content = "A convention" }) + + local all = memory.get_all() + + assert.is_table(all.patterns) + assert.is_table(all.conventions) + assert.is_table(all.symbols) + end) + end) + + describe("clear", function() + it("should clear all memories when no pattern provided", function() + memory.store_memory({ type = "pattern", content = "Pattern 1" }) + memory.store_memory({ type = "convention", content = "Convention 1" }) + + memory.clear() + + local all = memory.get_all() + assert.equals(0, vim.tbl_count(all.patterns)) + assert.equals(0, vim.tbl_count(all.conventions)) + end) + + it("should clear only matching memories when pattern provided", function() + local mem1 = { type = "pattern", content = "Pattern 1" } + local mem2 = { type = "pattern", content = "Pattern 2" } + memory.store_memory(mem1) + memory.store_memory(mem2) + + -- Clear memories matching the first ID 
+ memory.clear(mem1.id) + + local patterns = memory.load_patterns() + assert.is_nil(patterns[mem1.id]) + end) + end) + + describe("prune", function() + it("should remove low-weight unused memories", function() + -- Store some low-weight memories + memory.store_memory({ + type = "pattern", + content = "Low weight pattern", + weight = 0.05, + used_count = 0, + }) + memory.store_memory({ + type = "pattern", + content = "High weight pattern", + weight = 0.9, + used_count = 0, + }) + + local pruned = memory.prune(0.1) + + -- Should have pruned at least one + assert.is_true(pruned >= 0) + end) + + it("should not remove frequently used memories", function() + local mem = { + type = "pattern", + content = "Frequently used but low weight", + weight = 0.05, + used_count = 10, + } + memory.store_memory(mem) + + memory.prune(0.1) + + -- Memory should still exist because used_count > 0 + local patterns = memory.load_patterns() + -- Note: prune only removes if used_count == 0 AND weight < threshold + if patterns[mem.id] then + assert.is_truthy(patterns[mem.id]) + end + end) + end) + + describe("get_stats", function() + it("should return memory statistics", function() + memory.store_memory({ type = "pattern", content = "P1" }) + memory.store_memory({ type = "pattern", content = "P2" }) + memory.store_memory({ type = "convention", content = "C1" }) + + local stats = memory.get_stats() + + assert.is_table(stats) + assert.equals(2, stats.patterns) + assert.equals(1, stats.conventions) + assert.equals(3, stats.total) + end) + end) +end) diff --git a/tests/spec/scanner_spec.lua b/tests/spec/scanner_spec.lua new file mode 100644 index 0000000..c046c35 --- /dev/null +++ b/tests/spec/scanner_spec.lua @@ -0,0 +1,285 @@ +---@diagnostic disable: undefined-global +-- Tests for lua/codetyper/indexer/scanner.lua + +describe("indexer.scanner", function() + local scanner + local utils + + -- Mock cwd for testing + local test_cwd = "/tmp/codetyper_test_scanner" + + before_each(function() + -- 
Reset modules + package.loaded["codetyper.indexer.scanner"] = nil + package.loaded["codetyper.utils"] = nil + + scanner = require("codetyper.indexer.scanner") + utils = require("codetyper.utils") + + -- Create test directory + vim.fn.mkdir(test_cwd, "p") + + -- Mock getcwd to return test directory + vim.fn.getcwd = function() + return test_cwd + end + end) + + after_each(function() + -- Clean up test directory + vim.fn.delete(test_cwd, "rf") + end) + + describe("detect_project_type", function() + it("should detect node project from package.json", function() + utils.write_file(test_cwd .. "/package.json", '{"name":"test"}') + + local project_type = scanner.detect_project_type(test_cwd) + + assert.equals("node", project_type) + end) + + it("should detect rust project from Cargo.toml", function() + utils.write_file(test_cwd .. "/Cargo.toml", '[package]\nname = "test"') + + local project_type = scanner.detect_project_type(test_cwd) + + assert.equals("rust", project_type) + end) + + it("should detect go project from go.mod", function() + utils.write_file(test_cwd .. "/go.mod", "module example.com/test") + + local project_type = scanner.detect_project_type(test_cwd) + + assert.equals("go", project_type) + end) + + it("should detect python project from pyproject.toml", function() + utils.write_file(test_cwd .. "/pyproject.toml", '[project]\nname = "test"') + + local project_type = scanner.detect_project_type(test_cwd) + + assert.equals("python", project_type) + end) + + it("should return unknown for unrecognized project", function() + -- Empty directory + local project_type = scanner.detect_project_type(test_cwd) + + assert.equals("unknown", project_type) + end) + end) + + describe("parse_package_json", function() + it("should parse dependencies from package.json", function() + local pkg_content = [[{ + "name": "test", + "dependencies": { + "express": "^4.18.0", + "lodash": "^4.17.0" + }, + "devDependencies": { + "jest": "^29.0.0" + } + }]] + utils.write_file(test_cwd .. 
"/package.json", pkg_content) + + local result = scanner.parse_package_json(test_cwd) + + assert.is_table(result.dependencies) + assert.is_table(result.dev_dependencies) + assert.equals("^4.18.0", result.dependencies.express) + assert.equals("^4.17.0", result.dependencies.lodash) + assert.equals("^29.0.0", result.dev_dependencies.jest) + end) + + it("should return empty tables when package.json does not exist", function() + local result = scanner.parse_package_json(test_cwd) + + assert.is_table(result.dependencies) + assert.is_table(result.dev_dependencies) + assert.equals(0, vim.tbl_count(result.dependencies)) + end) + + it("should handle malformed JSON gracefully", function() + utils.write_file(test_cwd .. "/package.json", "not valid json") + + local result = scanner.parse_package_json(test_cwd) + + assert.is_table(result.dependencies) + assert.equals(0, vim.tbl_count(result.dependencies)) + end) + end) + + describe("parse_cargo_toml", function() + it("should parse dependencies from Cargo.toml", function() + local cargo_content = [[ +[package] +name = "test" + +[dependencies] +serde = "1.0" +tokio = "1.28" + +[dev-dependencies] +tempfile = "3.5" +]] + utils.write_file(test_cwd .. "/Cargo.toml", cargo_content) + + local result = scanner.parse_cargo_toml(test_cwd) + + assert.is_table(result.dependencies) + assert.equals("1.0", result.dependencies.serde) + assert.equals("1.28", result.dependencies.tokio) + assert.equals("3.5", result.dev_dependencies.tempfile) + end) + + it("should return empty tables when Cargo.toml does not exist", function() + local result = scanner.parse_cargo_toml(test_cwd) + + assert.equals(0, vim.tbl_count(result.dependencies)) + end) + end) + + describe("parse_go_mod", function() + it("should parse dependencies from go.mod", function() + local go_mod_content = [[ +module example.com/test + +go 1.21 + +require ( + github.com/gin-gonic/gin v1.9.1 + github.com/stretchr/testify v1.8.4 +) +]] + utils.write_file(test_cwd .. 
"/go.mod", go_mod_content) + + local result = scanner.parse_go_mod(test_cwd) + + assert.is_table(result.dependencies) + assert.equals("v1.9.1", result.dependencies["github.com/gin-gonic/gin"]) + assert.equals("v1.8.4", result.dependencies["github.com/stretchr/testify"]) + end) + end) + + describe("should_ignore", function() + it("should ignore hidden files", function() + local config = { excluded_dirs = {} } + + assert.is_true(scanner.should_ignore(".hidden", config)) + assert.is_true(scanner.should_ignore(".git", config)) + end) + + it("should ignore node_modules", function() + local config = { excluded_dirs = {} } + + assert.is_true(scanner.should_ignore("node_modules", config)) + end) + + it("should ignore configured directories", function() + local config = { excluded_dirs = { "custom_ignore" } } + + assert.is_true(scanner.should_ignore("custom_ignore", config)) + end) + + it("should not ignore regular files", function() + local config = { excluded_dirs = {} } + + assert.is_false(scanner.should_ignore("main.lua", config)) + assert.is_false(scanner.should_ignore("src", config)) + end) + end) + + describe("should_index", function() + it("should index files with allowed extensions", function() + vim.fn.mkdir(test_cwd .. "/src", "p") + utils.write_file(test_cwd .. "/src/main.lua", "-- test") + + local config = { + index_extensions = { "lua", "ts", "js" }, + max_file_size = 100000, + excluded_dirs = {}, + } + + assert.is_true(scanner.should_index(test_cwd .. "/src/main.lua", config)) + end) + + it("should not index coder files", function() + utils.write_file(test_cwd .. "/main.coder.lua", "-- test") + + local config = { + index_extensions = { "lua" }, + max_file_size = 100000, + excluded_dirs = {}, + } + + assert.is_false(scanner.should_index(test_cwd .. "/main.coder.lua", config)) + end) + + it("should not index files with disallowed extensions", function() + utils.write_file(test_cwd .. 
"/image.png", "binary") + + local config = { + index_extensions = { "lua", "ts", "js" }, + max_file_size = 100000, + excluded_dirs = {}, + } + + assert.is_false(scanner.should_index(test_cwd .. "/image.png", config)) + end) + end) + + describe("get_indexable_files", function() + it("should return list of indexable files", function() + vim.fn.mkdir(test_cwd .. "/src", "p") + utils.write_file(test_cwd .. "/src/main.lua", "-- main") + utils.write_file(test_cwd .. "/src/utils.lua", "-- utils") + utils.write_file(test_cwd .. "/README.md", "# Readme") + + local config = { + index_extensions = { "lua" }, + max_file_size = 100000, + excluded_dirs = { "node_modules" }, + } + + local files = scanner.get_indexable_files(test_cwd, config) + + assert.equals(2, #files) + end) + + it("should skip ignored directories", function() + vim.fn.mkdir(test_cwd .. "/src", "p") + vim.fn.mkdir(test_cwd .. "/node_modules", "p") + utils.write_file(test_cwd .. "/src/main.lua", "-- main") + utils.write_file(test_cwd .. "/node_modules/package.lua", "-- ignore") + + local config = { + index_extensions = { "lua" }, + max_file_size = 100000, + excluded_dirs = { "node_modules" }, + } + + local files = scanner.get_indexable_files(test_cwd, config) + + -- Should only include src/main.lua + assert.equals(1, #files) + end) + end) + + describe("get_language", function() + it("should return correct language for extensions", function() + assert.equals("lua", scanner.get_language("test.lua")) + assert.equals("typescript", scanner.get_language("test.ts")) + assert.equals("javascript", scanner.get_language("test.js")) + assert.equals("python", scanner.get_language("test.py")) + assert.equals("go", scanner.get_language("test.go")) + assert.equals("rust", scanner.get_language("test.rs")) + end) + + it("should return extension as fallback", function() + assert.equals("unknown", scanner.get_language("test.unknown")) + end) + end) +end)