Refactor: Restructure project into core, features, adapters, and config modules

This commit is contained in:
2026-01-16 11:52:46 -05:00
parent 4fb52596e3
commit 10c1de8843
83 changed files with 337 additions and 337 deletions

View File

@@ -0,0 +1,268 @@
---@mod codetyper.agent.context_builder Context builder for agent prompts
---
--- Builds rich context including project structure, memories, and conventions
--- to help the LLM understand the codebase.
local M = {}
local utils = require("codetyper.support.utils")
local params = require("codetyper.params.agent.context")
--- Get project structure as a tree string
---@param max_depth? number Maximum depth to traverse (default: 3)
---@param max_files? number Maximum files to show (default: 50)
---@return string Project tree
function M.get_project_structure(max_depth, max_files)
max_depth = max_depth or 3
max_files = max_files or 50
local root = utils.get_project_root() or vim.fn.getcwd()
local lines = { "PROJECT STRUCTURE:", root, "" }
local file_count = 0
-- Common ignore patterns
local ignore_patterns = params.ignore_patterns
local function should_ignore(name)
for _, pattern in ipairs(ignore_patterns) do
if name:match(pattern) then
return true
end
end
return false
end
local function traverse(path, depth, prefix)
if depth > max_depth or file_count >= max_files then
return
end
local entries = {}
local handle = vim.loop.fs_scandir(path)
if not handle then
return
end
while true do
local name, type = vim.loop.fs_scandir_next(handle)
if not name then
break
end
if not should_ignore(name) then
table.insert(entries, { name = name, type = type })
end
end
-- Sort: directories first, then alphabetically
table.sort(entries, function(a, b)
if a.type == "directory" and b.type ~= "directory" then
return true
elseif a.type ~= "directory" and b.type == "directory" then
return false
else
return a.name < b.name
end
end)
for i, entry in ipairs(entries) do
if file_count >= max_files then
table.insert(lines, prefix .. "... (truncated)")
return
end
local is_last = (i == #entries)
local branch = is_last and "└── " or "├── "
local new_prefix = prefix .. (is_last and " " or "")
local icon = entry.type == "directory" and "/" or ""
table.insert(lines, prefix .. branch .. entry.name .. icon)
file_count = file_count + 1
if entry.type == "directory" then
traverse(path .. "/" .. entry.name, depth + 1, new_prefix)
end
end
end
traverse(root, 1, "")
if file_count >= max_files then
table.insert(lines, "")
table.insert(lines, "(Structure truncated at " .. max_files .. " entries)")
end
return table.concat(lines, "\n")
end
--- Get key files that are important for understanding the project
---@return table<string, string> Map of filename to description
function M.get_key_files()
local root = utils.get_project_root() or vim.fn.getcwd()
local key_files = {}
local important_files = {
["package.json"] = "Node.js project config",
["Cargo.toml"] = "Rust project config",
["go.mod"] = "Go module config",
["pyproject.toml"] = "Python project config",
["setup.py"] = "Python setup config",
["Makefile"] = "Build configuration",
["CMakeLists.txt"] = "CMake config",
[".gitignore"] = "Git ignore patterns",
["README.md"] = "Project documentation",
["init.lua"] = "Neovim plugin entry",
["plugin.lua"] = "Neovim plugin config",
}
for filename, desc in paparams.important_filesnd
return key_files
end
--- Detect project type and language
---@return table { type: string, language: string, framework?: string }
function M.detect_project_type()
local root = utils.get_project_root() or vim.fn.getcwd()
local indicators = {
["package.json"] = { type = "node", language = "javascript/typescript" },
["Cargo.toml"] = { type = "rust", language = "rust" },
["go.mod"] = { type = "go", language = "go" },
["pyproject.toml"] = { type = "python", language = "python" },
["setup.py"] = { type = "python", language = "python" },
["Gemfile"] = { type = "ruby", language = "ruby" },
["pom.xml"] = { type = "maven", language = "java" },
["build.gradle"] = { type = "gradle", language = "java/kotlin" },
}
-- Check for Neovim plugin specifically
if vim.fn.isdirectoparams.indicators return info
end
end
return { type = "unknown", language = "unknown" }
end
--- Get memories/patterns from the brain system
---@return string Formatted memories context
function M.get_memories_context()
local ok_memory, memory = pcall(require, "codetyper.indexer.memory")
if not ok_memory then
return ""
end
local all = memory.get_all()
if not all then
return ""
end
local lines = {}
-- Add patterns
if all.patterns and next(all.patterns) then
table.insert(lines, "LEARNED PATTERNS:")
local count = 0
for _, mem in pairs(all.patterns) do
if count >= 5 then
break
end
if mem.content then
table.insert(lines, " - " .. mem.content:sub(1, 100))
count = count + 1
end
end
table.insert(lines, "")
end
-- Add conventions
if all.conventions and next(all.conventions) then
table.insert(lines, "CODING CONVENTIONS:")
local count = 0
for _, mem in pairs(all.conventions) do
if count >= 5 then
break
end
if mem.content then
table.insert(lines, " - " .. mem.content:sub(1, 100))
count = count + 1
end
end
table.insert(lines, "")
end
return table.concat(lines, "\n")
end
--- Build the full context for agent prompts
---@return string Full context string
function M.build_full_context()
local sections = {}
-- Project info
local project_type = M.detect_project_type()
table.insert(sections, string.format(
"PROJECT INFO:\n Type: %s\n Language: %s%s\n",
project_type.type,
project_type.language,
project_type.framework and ("\n Framework: " .. project_type.framework) or ""
))
-- Project structure
local structure = M.get_project_structure(3, 40)
table.insert(sections, structure)
-- Key files
local key_files = M.get_key_files()
if next(key_files) then
local key_lines = { "", "KEY FILES:" }
for name, info in pairs(key_files) do
table.insert(key_lines, string.format(" %s - %s", name, info.description))
end
table.insert(sections, table.concat(key_lines, "\n"))
end
-- Memories
local memories = M.get_memories_context()
if memories ~= "" then
table.insert(sections, "\n" .. memories)
end
return table.concat(sections, "\n")
end
--- Get a compact context summary for token efficiency
---@return string Compact context
function M.build_compact_context()
local root = utils.get_project_root() or vim.fn.getcwd()
local project_type = M.detect_project_type()
local lines = {
"CONTEXT:",
" Root: " .. root,
" Type: " .. project_type.type .. " (" .. project_type.language .. ")",
}
-- Add main directories
local main_dirs = {}
local handle = vim.loop.fs_scandir(root)
if handle then
while true do
local name, type = vim.loop.fs_scandir_next(handle)
if not name then
break
end
if type == "directory" and not name:match("^%.") and not name:match("node_modules") then
table.insert(main_dirs, name .. "/")
end
end
end
if #main_dirs > 0 then
table.sort(main_dirs)
table.insert(lines, " Main dirs: " .. table.concat(main_dirs, ", "))
end
return table.concat(lines, "\n")
end
return M

View File

@@ -0,0 +1,754 @@
---@mod codetyper.agent.agentic Agentic loop with proper tool calling
---@brief [[
--- Full agentic system that handles multi-file changes via tool calling.
--- Multi-file agent system with tool orchestration.
---@brief ]]
local M = {}
---@class AgenticMessage
---@field role "system"|"user"|"assistant"|"tool"
---@field content string|table
---@field tool_calls? table[] For assistant messages with tool calls
---@field tool_call_id? string For tool result messages
---@field name? string Tool name for tool results
---@class AgenticToolCall
---@field id string Unique tool call ID
---@field type "function"
---@field function {name: string, arguments: string|table}
---@class AgenticOpts
---@field task string The task to accomplish
---@field files? string[] Initial files to include as context
---@field agent? string Agent name to use (default: "coder")
---@field model? string Model override
---@field max_iterations? number Max tool call rounds (default: 20)
---@field on_message? fun(msg: AgenticMessage) Called for each message
---@field on_tool_start? fun(name: string, args: table) Called before tool execution
---@field on_tool_end? fun(name: string, result: any, error: string|nil) Called after tool execution
---@field on_file_change? fun(path: string, action: string) Called when file is modified
---@field on_complete? fun(result: string|nil, error: string|nil) Called when done
---@field on_status? fun(status: string) Status updates
local utils = require("codetyper.support.utils")
--- Load agent definition
---@param name string Agent name
---@return table|nil agent definition
local function load_agent(name)
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
local agent_file = agents_dir .. "/" .. name .. ".md"
-- Check if custom agent exists
if vim.fn.filereadable(agent_file) == 1 then
local content = table.concat(vim.fn.readfile(agent_file), "\n")
-- Parse frontmatter and content
local frontmatter = {}
local body = content
local fm_match = content:match("^%-%-%-\n(.-)%-%-%-\n(.*)$")
if fm_match then
-- Parse YAML-like frontmatter
for line in content:match("^%-%-%-\n(.-)%-%-%-"):gmatch("[^\n]+") do
local key, value = line:match("^(%w+):%s*(.+)$")
if key and value then
frontmatter[key] = value
end
end
body = content:match("%-%-%-\n.-%-%-%-%s*\n(.*)$") or content
end
return {
name = name,
description = frontmatter.description or "Custom agent: " .. name,
system_prompt = body,
tools = frontmatter.tools and vim.split(frontmatter.tools, ",") or nil,
model = frontmatter.model,
}
end
-- Built-in agents
local builtin_agents = require("codetyper.prompts.agents.personas").builtin
return builtin_agents[name]
end
--- Load rules from .coder/rules/
---@return string Combined rules content
local function load_rules()
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
local rules = {}
if vim.fn.isdirectory(rules_dir) == 1 then
local files = vim.fn.glob(rules_dir .. "/*.md", false, true)
for _, file in ipairs(files) do
local content = table.concat(vim.fn.readfile(file), "\n")
local filename = vim.fn.fnamemodify(file, ":t:r")
table.insert(rules, string.format("## Rule: %s\n%s", filename, content))
end
end
if #rules > 0 then
return "\n\n# Project Rules\n" .. table.concat(rules, "\n\n")
end
return ""
end
--- Build messages array for API request
---@param history AgenticMessage[]
---@param provider string "openai"|"claude"
---@return table[] Formatted messages
local function build_messages(history, provider)
local messages = {}
for _, msg in ipairs(history) do
if msg.role == "system" then
if provider == "claude" then
-- Claude uses system parameter, not message
-- Skip system messages in array
else
table.insert(messages, {
role = "system",
content = msg.content,
})
end
elseif msg.role == "user" then
table.insert(messages, {
role = "user",
content = msg.content,
})
elseif msg.role == "assistant" then
local message = {
role = "assistant",
content = msg.content,
}
if msg.tool_calls then
message.tool_calls = msg.tool_calls
if provider == "claude" then
-- Claude format: content is array of blocks
message.content = {}
if msg.content and msg.content ~= "" then
table.insert(message.content, {
type = "text",
text = msg.content,
})
end
for _, tc in ipairs(msg.tool_calls) do
table.insert(message.content, {
type = "tool_use",
id = tc.id,
name = tc["function"].name,
input = type(tc["function"].arguments) == "string"
and vim.json.decode(tc["function"].arguments)
or tc["function"].arguments,
})
end
end
end
table.insert(messages, message)
elseif msg.role == "tool" then
if provider == "claude" then
table.insert(messages, {
role = "user",
content = {
{
type = "tool_result",
tool_use_id = msg.tool_call_id,
content = msg.content,
},
},
})
else
table.insert(messages, {
role = "tool",
tool_call_id = msg.tool_call_id,
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
})
end
end
end
return messages
end
--- Build tools array for API request
---@param tool_names string[] Tool names to include
---@param provider string "openai"|"claude"
---@return table[] Formatted tools
local function build_tools(tool_names, provider)
local tools_mod = require("codetyper.core.tools")
local tools = {}
for _, name in ipairs(tool_names) do
local tool = tools_mod.get(name)
if tool then
local properties = {}
local required = {}
for _, param in ipairs(tool.params or {}) do
properties[param.name] = {
type = param.type == "integer" and "number" or param.type,
description = param.description,
}
if not param.optional then
table.insert(required, param.name)
end
end
local description = type(tool.description) == "function" and tool.description() or tool.description
if provider == "claude" then
table.insert(tools, {
name = tool.name,
description = description,
input_schema = {
type = "object",
properties = properties,
required = required,
},
})
else
table.insert(tools, {
type = "function",
["function"] = {
name = tool.name,
description = description,
parameters = {
type = "object",
properties = properties,
required = required,
},
},
})
end
end
end
return tools
end
--- Execute a tool call
---@param tool_call AgenticToolCall
---@param opts AgenticOpts
---@return string result
---@return string|nil error
local function execute_tool(tool_call, opts)
local tools_mod = require("codetyper.core.tools")
local name = tool_call["function"].name
local args = tool_call["function"].arguments
-- Parse arguments if string
if type(args) == "string" then
local ok, parsed = pcall(vim.json.decode, args)
if ok then
args = parsed
else
return "", "Failed to parse tool arguments: " .. args
end
end
-- Notify tool start
if opts.on_tool_start then
opts.on_tool_start(name, args)
end
if opts.on_status then
opts.on_status("Executing: " .. name)
end
-- Execute the tool
local tool = tools_mod.get(name)
if not tool then
local err = "Unknown tool: " .. name
if opts.on_tool_end then
opts.on_tool_end(name, nil, err)
end
return "", err
end
local result, err = tool.func(args, {
on_log = function(msg)
if opts.on_status then
opts.on_status(msg)
end
end,
})
-- Notify tool end
if opts.on_tool_end then
opts.on_tool_end(name, result, err)
end
-- Track file changes
if opts.on_file_change and (name == "write" or name == "edit") and not err then
opts.on_file_change(args.path, name == "write" and "created" or "modified")
end
if err then
return "", err
end
return type(result) == "string" and result or vim.json.encode(result), nil
end
--- Parse tool calls from LLM response (unified Claude-like format)
---@param response table Raw API response in unified format
---@param provider string Provider name (unused, kept for signature compatibility)
---@return AgenticToolCall[]
local function parse_tool_calls(response, provider)
local tool_calls = {}
-- Unified format: content array with tool_use blocks
local content = response.content or {}
for _, block in ipairs(content) do
if block.type == "tool_use" then
-- OpenAI expects arguments as JSON string, not table
local args = block.input
if type(args) == "table" then
args = vim.json.encode(args)
end
table.insert(tool_calls, {
id = block.id or utils.generate_id("call"),
type = "function",
["function"] = {
name = block.name,
arguments = args,
},
})
end
end
return tool_calls
end
--- Extract text content from response (unified Claude-like format)
---@param response table Raw API response in unified format
---@param provider string Provider name (unused, kept for signature compatibility)
---@return string
local function extract_content(response, provider)
local parts = {}
for _, block in ipairs(response.content or {}) do
if block.type == "text" then
table.insert(parts, block.text)
end
end
return table.concat(parts, "\n")
end
--- Check if response indicates completion (unified Claude-like format)
---@param response table Raw API response in unified format
---@param provider string Provider name (unused, kept for signature compatibility)
---@return boolean
local function is_complete(response, provider)
return response.stop_reason == "end_turn"
end
--- Make API request to LLM with native tool calling support
---@param messages table[] Formatted messages
---@param tools table[] Formatted tools
---@param system_prompt string System prompt
---@param provider string "openai"|"claude"|"copilot"
---@param model string Model name
---@param callback fun(response: table|nil, error: string|nil)
local function call_llm(messages, tools, system_prompt, provider, model, callback)
local context = {
language = "lua",
file_content = "",
prompt_type = "agent",
project_root = vim.fn.getcwd(),
cwd = vim.fn.getcwd(),
}
-- Use native tool calling APIs
if provider == "copilot" then
local client = require("codetyper.core.llm.copilot")
-- Copilot's generate_with_tools expects messages in a specific format
-- Convert to the format it expects
local converted_messages = {}
for _, msg in ipairs(messages) do
if msg.role ~= "system" then
table.insert(converted_messages, msg)
end
end
client.generate_with_tools(converted_messages, context, tools, function(response, err)
if err then
callback(nil, err)
return
end
-- Response is already in Claude-like format from the provider
-- Convert to our internal format
local result = {
content = {},
stop_reason = "end_turn",
}
if response and response.content then
for _, block in ipairs(response.content) do
if block.type == "text" then
table.insert(result.content, { type = "text", text = block.text })
elseif block.type == "tool_use" then
table.insert(result.content, {
type = "tool_use",
id = block.id or utils.generate_id("call"),
name = block.name,
input = block.input,
})
result.stop_reason = "tool_use"
end
end
end
callback(result, nil)
end)
elseif provider == "openai" then
local client = require("codetyper.core.llm.openai")
-- OpenAI's generate_with_tools
local converted_messages = {}
for _, msg in ipairs(messages) do
if msg.role ~= "system" then
table.insert(converted_messages, msg)
end
end
client.generate_with_tools(converted_messages, context, tools, function(response, err)
if err then
callback(nil, err)
return
end
-- Response is already in Claude-like format from the provider
local result = {
content = {},
stop_reason = "end_turn",
}
if response and response.content then
for _, block in ipairs(response.content) do
if block.type == "text" then
table.insert(result.content, { type = "text", text = block.text })
elseif block.type == "tool_use" then
table.insert(result.content, {
type = "tool_use",
id = block.id or utils.generate_id("call"),
name = block.name,
input = block.input,
})
result.stop_reason = "tool_use"
end
end
end
callback(result, nil)
end)
elseif provider == "ollama" then
local client = require("codetyper.core.llm.ollama")
-- Ollama's generate_with_tools (text-based tool calling)
local converted_messages = {}
for _, msg in ipairs(messages) do
if msg.role ~= "system" then
table.insert(converted_messages, msg)
end
end
client.generate_with_tools(converted_messages, context, tools, function(response, err)
if err then
callback(nil, err)
return
end
-- Response is already in Claude-like format from the provider
callback(response, nil)
end)
else
-- Fallback for other providers (ollama, etc.) - use text-based parsing
local client = require("codetyper.core.llm." .. provider)
-- Build prompt from messages
local prompts = require("codetyper.prompts.agent")
local prompt_parts = {}
for _, msg in ipairs(messages) do
if msg.role == "user" then
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
table.insert(prompt_parts, prompts.text_user_prefix .. content)
elseif msg.role == "assistant" then
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
table.insert(prompt_parts, prompts.text_assistant_prefix .. content)
end
end
-- Add tool descriptions to prompt for text-based providers
local tool_desc = require("codetyper.prompts.agent").tool_instructions_text
for _, tool in ipairs(tools) do
local name = tool.name or (tool["function"] and tool["function"].name)
local desc = tool.description or (tool["function"] and tool["function"].description)
if name then
tool_desc = tool_desc .. string.format("- **%s**: %s\n", name, desc or "")
end
end
context.file_content = system_prompt .. tool_desc
client.generate(table.concat(prompt_parts, "\n\n"), context, function(response, err)
if err then
callback(nil, err)
return
end
-- Parse response for tool calls (text-based fallback)
local result = {
content = {},
stop_reason = "end_turn",
}
-- Extract text content
local text_content = response
-- Try to extract JSON tool calls from response
local json_match = response:match("```json%s*(%b{})%s*```")
if json_match then
local ok, parsed = pcall(vim.json.decode, json_match)
if ok and parsed.tool then
table.insert(result.content, {
type = "tool_use",
id = utils.generate_id("call"),
name = parsed.tool,
input = parsed.arguments or {},
})
text_content = response:gsub("```json.-```", ""):gsub("^%s+", ""):gsub("%s+$", "")
result.stop_reason = "tool_use"
end
end
if text_content and text_content ~= "" then
table.insert(result.content, 1, { type = "text", text = text_content })
end
callback(result, nil)
end)
end
end
--- Run the agentic loop
---@param opts AgenticOpts
function M.run(opts)
-- Load agent
local agent = load_agent(opts.agent or "coder")
if not agent then
if opts.on_complete then
opts.on_complete(nil, "Unknown agent: " .. (opts.agent or "coder"))
end
return
end
-- Load rules
local rules = load_rules()
-- Build system prompt
local system_prompt = agent.system_prompt .. rules
-- Initialize message history
---@type AgenticMessage[]
local history = {
{ role = "system", content = system_prompt },
}
-- Add initial file context if provided
if opts.files and #opts.files > 0 then
local file_context = require("codetyper.prompts.agent").format_file_context(opts.files)
table.insert(history, { role = "user", content = file_context })
table.insert(history, { role = "assistant", content = "I've reviewed the provided files. What would you like me to do?" })
end
-- Add the task
table.insert(history, { role = "user", content = opts.task })
-- Determine provider
local config = require("codetyper").get_config()
local provider = config.llm.provider or "copilot"
-- Note: Ollama has its own handler in call_llm, don't change it
-- Get tools for this agent
local tool_names = agent.tools or { "view", "edit", "write", "grep", "glob", "bash" }
-- Ensure tools are loaded
local tools_mod = require("codetyper.core.tools")
tools_mod.setup()
-- Build tools for API
local tools = build_tools(tool_names, provider)
-- Iteration tracking
local iteration = 0
local max_iterations = opts.max_iterations or 20
--- Process one iteration
local function process_iteration()
iteration = iteration + 1
if iteration > max_iterations then
if opts.on_complete then
opts.on_complete(nil, "Max iterations reached")
end
return
end
if opts.on_status then
opts.on_status(string.format("Thinking... (iteration %d)", iteration))
end
-- Build messages for API
local messages = build_messages(history, provider)
-- Call LLM
call_llm(messages, tools, system_prompt, provider, opts.model, function(response, err)
if err then
if opts.on_complete then
opts.on_complete(nil, err)
end
return
end
-- Extract content and tool calls
local content = extract_content(response, provider)
local tool_calls = parse_tool_calls(response, provider)
-- Add assistant message to history
local assistant_msg = {
role = "assistant",
content = content,
tool_calls = #tool_calls > 0 and tool_calls or nil,
}
table.insert(history, assistant_msg)
if opts.on_message then
opts.on_message(assistant_msg)
end
-- Process tool calls if any
if #tool_calls > 0 then
for _, tc in ipairs(tool_calls) do
local result, tool_err = execute_tool(tc, opts)
-- Add tool result to history
local tool_msg = {
role = "tool",
tool_call_id = tc.id,
name = tc["function"].name,
content = tool_err or result,
}
table.insert(history, tool_msg)
if opts.on_message then
opts.on_message(tool_msg)
end
end
-- Continue the loop
vim.schedule(process_iteration)
else
-- No tool calls - check if complete
if is_complete(response, provider) or content ~= "" then
if opts.on_complete then
opts.on_complete(content, nil)
end
else
-- Continue if not explicitly complete
vim.schedule(process_iteration)
end
end
end)
end
-- Start the loop
process_iteration()
end
--- Create default agent files in .coder/agents/
function M.init_agents_dir()
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
vim.fn.mkdir(agents_dir, "p")
-- Create example agent
local example_agent = require("codetyper.prompts.agents.templates").agent
local example_path = agents_dir .. "/example.md"
if vim.fn.filereadable(example_path) ~= 1 then
vim.fn.writefile(vim.split(example_agent, "\n"), example_path)
end
return agents_dir
end
--- Create default rules in .coder/rules/
function M.init_rules_dir()
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
vim.fn.mkdir(rules_dir, "p")
-- Create example rule
local example_rule = require("codetyper.prompts.agents.templates").rule
local example_path = rules_dir .. "/code-style.md"
if vim.fn.filereadable(example_path) ~= 1 then
vim.fn.writefile(vim.split(example_rule, "\n"), example_path)
end
return rules_dir
end
--- Initialize both agents and rules directories
function M.init()
M.init_agents_dir()
M.init_rules_dir()
end
--- List available agents
---@return table[] List of {name, description, builtin}
function M.list_agents()
local agents = {}
-- Built-in agents
local personas = require("codetyper.prompts.agents.personas").builtin
local builtins = vim.tbl_keys(personas)
table.sort(builtins)
for _, name in ipairs(builtins) do
local agent = load_agent(name)
if agent then
table.insert(agents, {
name = agent.name,
description = agent.description,
builtin = true,
})
end
end
-- Custom agents from .coder/agents/
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
if vim.fn.isdirectory(agents_dir) == 1 then
local files = vim.fn.glob(agents_dir .. "/*.md", false, true)
for _, file in ipairs(files) do
local name = vim.fn.fnamemodify(file, ":t:r")
if not vim.tbl_contains(builtins, name) then
local agent = load_agent(name)
if agent then
table.insert(agents, {
name = agent.name,
description = agent.description,
builtin = false,
})
end
end
end
end
return agents
end
return M

View File

@@ -0,0 +1,455 @@
---@mod codetyper.agent Agent orchestration for Codetyper.nvim
---
--- Manages the agentic conversation loop with tool execution.
local M = {}
local tools = require("codetyper.core.tools")
local executor = require("codetyper.core.scheduler.executor")
local parser = require("codetyper.core.llm.parser")
local diff = require("codetyper.core.diff.diff")
local diff_review = require("codetyper.adapters.nvim.ui.diff_review")
local resume = require("codetyper.core.scheduler.resume")
local utils = require("codetyper.support.utils")
local logs = require("codetyper.adapters.nvim.ui.logs")
---@class AgentState
---@field conversation table[] Message history for multi-turn
---@field pending_tool_results table[] Results waiting to be sent back
---@field is_running boolean Whether agent loop is active
---@field max_iterations number Maximum tool call iterations
local state = {
conversation = {},
pending_tool_results = {},
is_running = false,
max_iterations = 25, -- Increased for complex tasks (env setup, tests, fixes)
current_iteration = 0,
original_prompt = "", -- Store for resume functionality
current_context = nil, -- Store context for resume
current_callbacks = nil, -- Store callbacks for continue
}
---@class AgentCallbacks
---@field on_text fun(text: string) Called when text content is received
---@field on_tool_start fun(name: string) Called when a tool starts
---@field on_tool_result fun(name: string, result: string) Called when a tool completes
---@field on_complete fun() Called when agent finishes
---@field on_error fun(err: string) Called on error
--- Reset agent state for new conversation
function M.reset()
state.conversation = {}
state.pending_tool_results = {}
state.is_running = false
state.current_iteration = 0
-- Clear collected diffs
diff_review.clear()
end
--- Check if agent is currently running
---@return boolean
function M.is_running()
return state.is_running
end
--- Stop the agent
function M.stop()
state.is_running = false
utils.notify("Agent stopped")
end
--- Main agent entry point
---@param prompt string User's request
---@param context table File context
---@param callbacks AgentCallbacks Callback functions
function M.run(prompt, context, callbacks)
if state.is_running then
callbacks.on_error("Agent is already running")
return
end
logs.info("Starting agent run")
logs.debug("Prompt length: " .. #prompt .. " chars")
state.is_running = true
state.current_iteration = 0
state.original_prompt = prompt
state.current_context = context
state.current_callbacks = callbacks
-- Add user message to conversation
table.insert(state.conversation, {
role = "user",
content = prompt,
})
-- Start the agent loop
M.agent_loop(context, callbacks)
end
--- The core agent loop
---@param context table File context
---@param callbacks AgentCallbacks
function M.agent_loop(context, callbacks)
if not state.is_running then
callbacks.on_complete()
return
end
state.current_iteration = state.current_iteration + 1
logs.info(string.format("Agent loop iteration %d/%d", state.current_iteration, state.max_iterations))
if state.current_iteration > state.max_iterations then
logs.info("Max iterations reached, asking user to continue or stop")
-- Ask user if they want to continue
M.prompt_continue(context, callbacks)
return
end
local llm = require("codetyper.core.llm")
local client = llm.get_client()
-- Check if client supports tools
if not client.generate_with_tools then
logs.error("Provider does not support agent mode")
callbacks.on_error("Current LLM provider does not support agent mode")
state.is_running = false
return
end
logs.thinking("Calling LLM with " .. #state.conversation .. " messages...")
-- Generate with tools enabled
-- Ensure tools are loaded and get definitions
tools.setup()
local tool_defs = tools.to_openai_format()
client.generate_with_tools(state.conversation, context, tool_defs, function(response, err)
if err then
state.is_running = false
callbacks.on_error(err)
return
end
-- Parse response based on provider
local codetyper = require("codetyper")
local config = codetyper.get_config()
local parsed
-- Copilot uses Claude-like response format
if config.llm.provider == "copilot" then
parsed = parser.parse_claude_response(response)
table.insert(state.conversation, {
role = "assistant",
content = parsed.text or "",
tool_calls = parsed.tool_calls,
_raw_content = response.content,
})
else
-- For Ollama, response is the text directly
if type(response) == "string" then
parsed = parser.parse_ollama_response(response)
else
parsed = parser.parse_ollama_response(response.response or "")
end
-- Add assistant response to conversation
table.insert(state.conversation, {
role = "assistant",
content = parsed.text,
tool_calls = parsed.tool_calls,
})
end
-- Display any text content
if parsed.text and parsed.text ~= "" then
local clean_text = parser.clean_text(parsed.text)
if clean_text ~= "" then
callbacks.on_text(clean_text)
end
end
-- Check for tool calls
if #parsed.tool_calls > 0 then
logs.info(string.format("Processing %d tool call(s)", #parsed.tool_calls))
-- Process tool calls sequentially
M.process_tool_calls(parsed.tool_calls, 1, context, callbacks)
else
-- No more tool calls, agent is done
logs.info("No tool calls, finishing agent loop")
state.is_running = false
callbacks.on_complete()
end
end)
end
--- Process tool calls one at a time
---@param tool_calls table[] List of tool calls
---@param index number Current index
---@param context table File context
---@param callbacks AgentCallbacks
function M.process_tool_calls(tool_calls, index, context, callbacks)
if not state.is_running then
callbacks.on_complete()
return
end
if index > #tool_calls then
-- All tools processed, continue agent loop with results
M.continue_with_results(context, callbacks)
return
end
local tool_call = tool_calls[index]
callbacks.on_tool_start(tool_call.name)
executor.execute(tool_call.name, tool_call.parameters, function(result)
if result.requires_approval then
logs.tool(tool_call.name, "approval", "Waiting for user approval")
-- Show diff preview and wait for user decision
local show_fn
if result.diff_data.operation == "bash" then
show_fn = function(_, cb)
diff.show_bash_approval(result.diff_data.modified:gsub("^%$ ", ""), cb)
end
else
show_fn = diff.show_diff
end
show_fn(result.diff_data, function(approval_result)
-- Handle both old (boolean) and new (table) approval result formats
local approved = type(approval_result) == "table" and approval_result.approved or approval_result
local permission_level = type(approval_result) == "table" and approval_result.permission_level or nil
if approved then
local log_msg = "User approved"
if permission_level == "allow_session" then
log_msg = "Allowed for session"
elseif permission_level == "allow_list" then
log_msg = "Added to allow list"
elseif permission_level == "auto" then
log_msg = "Auto-approved"
end
logs.tool(tool_call.name, "approved", log_msg)
-- Apply the change and collect for review
executor.apply_change(result.diff_data, function(apply_result)
-- Collect the diff for end-of-session review
if result.diff_data.operation ~= "bash" then
diff_review.add({
path = result.diff_data.path,
operation = result.diff_data.operation,
original = result.diff_data.original,
modified = result.diff_data.modified,
approved = true,
applied = true,
})
end
-- Store result for sending back to LLM
table.insert(state.pending_tool_results, {
tool_use_id = tool_call.id,
name = tool_call.name,
result = apply_result.result,
})
callbacks.on_tool_result(tool_call.name, apply_result.result)
-- Process next tool call
M.process_tool_calls(tool_calls, index + 1, context, callbacks)
end)
else
logs.tool(tool_call.name, "rejected", "User rejected")
-- User rejected
table.insert(state.pending_tool_results, {
tool_use_id = tool_call.id,
name = tool_call.name,
result = "User rejected this change",
})
callbacks.on_tool_result(tool_call.name, "Rejected by user")
M.process_tool_calls(tool_calls, index + 1, context, callbacks)
end
end)
else
-- No approval needed (read_file), store result immediately
table.insert(state.pending_tool_results, {
tool_use_id = tool_call.id,
name = tool_call.name,
result = result.result,
})
-- For read_file, just show a brief confirmation
local display_result = result.result
if tool_call.name == "read_file" and result.success then
display_result = "[Read " .. #result.result .. " bytes]"
end
callbacks.on_tool_result(tool_call.name, display_result)
M.process_tool_calls(tool_calls, index + 1, context, callbacks)
end
end)
end
--- Continue the loop after tool execution
---@param context table File context
---@param callbacks AgentCallbacks
function M.continue_with_results(context, callbacks)
if #state.pending_tool_results == 0 then
state.is_running = false
callbacks.on_complete()
return
end
-- Build tool results message
local codetyper = require("codetyper")
local config = codetyper.get_config()
-- Copilot uses OpenAI format for tool results (role: "tool")
if config.llm.provider == "copilot" then
-- OpenAI-style tool messages - each result is a separate message
for _, result in ipairs(state.pending_tool_results) do
table.insert(state.conversation, {
role = "tool",
tool_call_id = result.tool_use_id,
content = result.result,
})
end
else
-- Ollama format: plain text describing results
local result_text = "Tool results:\n"
for _, result in ipairs(state.pending_tool_results) do
result_text = result_text .. "\n[" .. result.name .. "]: " .. result.result .. "\n"
end
table.insert(state.conversation, {
role = "user",
content = result_text,
})
end
state.pending_tool_results = {}
-- Continue the loop
M.agent_loop(context, callbacks)
end
--- Get conversation history
---@return table[]
function M.get_conversation()
return state.conversation
end
--- Set max iterations
---@param max number Maximum iterations
function M.set_max_iterations(max)
state.max_iterations = max
end
--- Get the count of collected changes
---@return number
function M.get_changes_count()
return diff_review.count()
end
--- Show the diff review UI for all collected changes
function M.show_diff_review()
diff_review.open()
end
--- Check if diff review is open
---@return boolean
function M.is_review_open()
return diff_review.is_open()
end
--- Prompt user to continue or stop at max iterations
---@param context table File context
---@param callbacks AgentCallbacks
function M.prompt_continue(context, callbacks)
vim.schedule(function()
vim.ui.select({ "Continue (25 more iterations)", "Stop and save for later" }, {
prompt = string.format("Agent reached %d iterations. Continue?", state.max_iterations),
}, function(choice)
if choice and choice:match("^Continue") then
-- Reset iteration counter and continue
state.current_iteration = 0
logs.info("User chose to continue, resetting iteration counter")
M.agent_loop(context, callbacks)
else
-- Save state for later resume
logs.info("User chose to stop, saving state for resume")
resume.save(
state.conversation,
state.pending_tool_results,
state.current_iteration,
state.original_prompt
)
state.is_running = false
callbacks.on_text("Agent paused. Use /continue to resume later.")
callbacks.on_complete()
end
end)
end)
end
--- Continue a previously stopped agent session
---@param callbacks AgentCallbacks
---@return boolean Success
function M.continue_session(callbacks)
if state.is_running then
utils.notify("Agent is already running", vim.log.levels.WARN)
return false
end
local saved = resume.load()
if not saved then
utils.notify("No saved agent session to continue", vim.log.levels.WARN)
return false
end
logs.info("Resuming agent session")
logs.info(string.format("Loaded %d messages, iteration %d", #saved.conversation, saved.iteration))
-- Restore state
state.conversation = saved.conversation
state.pending_tool_results = saved.pending_tool_results or {}
state.current_iteration = 0 -- Reset for fresh iterations
state.original_prompt = saved.original_prompt
state.is_running = true
state.current_callbacks = callbacks
-- Build context from current state
local llm = require("codetyper.core.llm")
local context = {}
local current_file = vim.fn.expand("%:p")
if current_file ~= "" and vim.fn.filereadable(current_file) == 1 then
context = llm.build_context(current_file, "agent")
end
state.current_context = context
-- Clear saved state
resume.clear()
-- Add continuation message
table.insert(state.conversation, {
role = "user",
content = "Continue where you left off. Complete the remaining tasks.",
})
-- Continue the loop
callbacks.on_text("Resuming agent session...")
M.agent_loop(context, callbacks)
return true
end
--- Check if there's a saved session to continue
---@return boolean
function M.has_saved_session()
return resume.has_saved_state()
end
--- Get info about saved session
---@return table|nil
function M.get_saved_session_info()
return resume.get_info()
end
return M

View File

@@ -0,0 +1,425 @@
---@mod codetyper.agent.linter Linter validation for generated code
---@brief [[
--- Validates generated code by checking LSP diagnostics after injection.
--- Automatically saves the file and waits for LSP to update before checking.
---@brief ]]
local M = {}
local config_params = require("codetyper.params.agent.linter")
local prompts = require("codetyper.prompts.agent.linter")
--- Configuration
local config = config_params.config
--- Diagnostic results for tracking
---@type table<number, table>
local validation_results = {}
--- Configure linter behavior
---@param opts table Configuration options
function M.configure(opts)
for k, v in pairs(opts) do
if config[k] ~= nil then
config[k] = v
end
end
end
--- Get current configuration
---@return table
function M.get_config()
return vim.deepcopy(config)
end
--- Save buffer if modified
---@param bufnr number Buffer number
---@return boolean success
local function save_buffer(bufnr)
if not vim.api.nvim_buf_is_valid(bufnr) then
return false
end
-- Skip if buffer is not modified
if not vim.bo[bufnr].modified then
return true
end
-- Skip if buffer has no name (unsaved file)
local bufname = vim.api.nvim_buf_get_name(bufnr)
if bufname == "" then
return false
end
-- Save the buffer
local ok, err = pcall(function()
vim.api.nvim_buf_call(bufnr, function()
vim.cmd("silent! write")
end)
end)
if not ok then
pcall(function()
local logs = require("codetyper.adapters.nvim.ui.logs")
logs.add({
type = "warning",
message = "Failed to save buffer: " .. tostring(err),
})
end)
return false
end
return true
end
--- Get LSP diagnostics for a buffer
---@param bufnr number Buffer number
---@param start_line? number Start line (1-indexed)
---@param end_line? number End line (1-indexed)
---@return table[] diagnostics List of diagnostics
function M.get_diagnostics(bufnr, start_line, end_line)
if not vim.api.nvim_buf_is_valid(bufnr) then
return {}
end
local all_diagnostics = vim.diagnostic.get(bufnr)
local filtered = {}
for _, diag in ipairs(all_diagnostics) do
-- Filter by severity
if diag.severity <= config.min_severity then
-- Filter by line range if specified
if start_line and end_line then
local diag_line = diag.lnum + 1 -- Convert to 1-indexed
if diag_line >= start_line and diag_line <= end_line then
table.insert(filtered, diag)
end
else
table.insert(filtered, diag)
end
end
end
return filtered
end
--- Format a diagnostic for display
---@param diag table Diagnostic object
---@return string
local function format_diagnostic(diag)
local severity_names = {
[vim.diagnostic.severity.ERROR] = "ERROR",
[vim.diagnostic.severity.WARN] = "WARN",
[vim.diagnostic.severity.INFO] = "INFO",
[vim.diagnostic.severity.HINT] = "HINT",
}
local severity = severity_names[diag.severity] or "UNKNOWN"
local line = diag.lnum + 1
local source = diag.source or "lsp"
return string.format("[%s] Line %d (%s): %s", severity, line, source, diag.message)
end
--- Check if there are errors in generated code region
---@param bufnr number Buffer number
---@param start_line number Start line (1-indexed)
---@param end_line number End line (1-indexed)
---@return table result {has_errors, has_warnings, diagnostics, summary}
function M.check_region(bufnr, start_line, end_line)
local diagnostics = M.get_diagnostics(bufnr, start_line, end_line)
local errors = 0
local warnings = 0
for _, diag in ipairs(diagnostics) do
if diag.severity == vim.diagnostic.severity.ERROR then
errors = errors + 1
elseif diag.severity == vim.diagnostic.severity.WARN then
warnings = warnings + 1
end
end
return {
has_errors = errors > 0,
has_warnings = warnings > 0,
error_count = errors,
warning_count = warnings,
diagnostics = diagnostics,
summary = string.format("%d error(s), %d warning(s)", errors, warnings),
}
end
--- Validate code after injection and report issues
---@param bufnr number Buffer number
---@param start_line? number Start line of injected code (1-indexed)
---@param end_line? number End line of injected code (1-indexed)
---@param callback? function Callback with (result) when validation completes
function M.validate_after_injection(bufnr, start_line, end_line, callback)
-- Save the file first
if config.auto_save then
save_buffer(bufnr)
end
-- Wait for LSP to process changes
vim.defer_fn(function()
if not vim.api.nvim_buf_is_valid(bufnr) then
if callback then callback(nil) end
return
end
local result
if start_line and end_line then
result = M.check_region(bufnr, start_line, end_line)
else
-- Check entire buffer
local line_count = vim.api.nvim_buf_line_count(bufnr)
result = M.check_region(bufnr, 1, line_count)
end
-- Store result for this buffer
validation_results[bufnr] = {
timestamp = os.time(),
result = result,
start_line = start_line,
end_line = end_line,
}
-- Log results
pcall(function()
local logs = require("codetyper.adapters.nvim.ui.logs")
if result.has_errors then
logs.add({
type = "error",
message = string.format("Linter found issues: %s", result.summary),
})
-- Log individual errors
for _, diag in ipairs(result.diagnostics) do
if diag.severity == vim.diagnostic.severity.ERROR then
logs.add({
type = "error",
message = format_diagnostic(diag),
})
end
end
elseif result.has_warnings then
logs.add({
type = "warning",
message = string.format("Linter warnings: %s", result.summary),
})
else
logs.add({
type = "success",
message = "Linter check passed - no errors or warnings",
})
end
end)
-- Notify user
if result.has_errors then
vim.notify(
string.format("Generated code has lint errors: %s", result.summary),
vim.log.levels.ERROR
)
-- Offer to fix if configured
if config.auto_offer_fix and #result.diagnostics > 0 then
M.offer_fix(bufnr, result)
end
elseif result.has_warnings then
vim.notify(
string.format("Generated code has warnings: %s", result.summary),
vim.log.levels.WARN
)
end
if callback then
callback(result)
end
end, config.diagnostic_delay_ms)
end
--- Offer to fix lint errors using AI
---@param bufnr number Buffer number
---@param result table Validation result
function M.offer_fix(bufnr, result)
if not result.has_errors and not result.has_warnings then
return
end
-- Build error summary for prompt
local error_messages = {}
for _, diag in ipairs(result.diagnostics) do
table.insert(error_messages, format_diagnostic(diag))
end
vim.ui.select(
{ "Yes - Auto-fix with AI", "No - I'll fix manually", "Show errors in quickfix" },
{
prompt = string.format("Found %d issue(s). Would you like AI to fix them?", #result.diagnostics),
},
function(choice)
if not choice then return end
if choice:match("^Yes") then
M.request_ai_fix(bufnr, result)
elseif choice:match("quickfix") then
M.show_in_quickfix(bufnr, result)
end
end
)
end
--- Show lint errors in quickfix list
---@param bufnr number Buffer number
---@param result table Validation result
function M.show_in_quickfix(bufnr, result)
local qf_items = {}
local bufname = vim.api.nvim_buf_get_name(bufnr)
for _, diag in ipairs(result.diagnostics) do
table.insert(qf_items, {
bufnr = bufnr,
filename = bufname,
lnum = diag.lnum + 1,
col = diag.col + 1,
text = diag.message,
type = diag.severity == vim.diagnostic.severity.ERROR and "E" or "W",
})
end
vim.fn.setqflist(qf_items, "r")
vim.cmd("copen")
end
--- Request AI to fix lint errors
---@param bufnr number Buffer number
---@param result table Validation result
function M.request_ai_fix(bufnr, result)
if not vim.api.nvim_buf_is_valid(bufnr) then
return
end
local filepath = vim.api.nvim_buf_get_name(bufnr)
-- Build fix prompt
local error_list = {}
for _, diag in ipairs(result.diagnostics) do
table.insert(error_list, format_diagnostic(diag))
end
-- Get the affected code region
local start_line = result.diagnostics[1] and (result.diagnostics[1].lnum + 1) or 1
local end_line = start_line
for _, diag in ipairs(result.diagnostics) do
local line = diag.lnum + 1
if line < start_line then start_line = line end
if line > end_line then end_line = line end
end
-- Expand range by a few lines for context
start_line = math.max(1, start_line - 5)
end_line = math.min(vim.api.nvim_buf_line_count(bufnr), end_line + 5)
local lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false)
local code_context = table.concat(lines, "\n")
-- Create fix prompt using inline tag
local fix_prompt = string.format(
prompts.fix_request,
table.concat(error_list, "\n"),
start_line,
end_line,
code_context
)
-- Queue the fix through the scheduler
local scheduler = require("codetyper.core.scheduler.scheduler")
local queue = require("codetyper.core.events.queue")
local patch_mod = require("codetyper.core.diff.patch")
-- Ensure scheduler is running
if not scheduler.status().running then
scheduler.start()
end
-- Take snapshot
local snapshot = patch_mod.snapshot_buffer(bufnr, {
start_line = start_line,
end_line = end_line,
})
-- Enqueue fix request
queue.enqueue({
id = queue.generate_id(),
bufnr = bufnr,
range = { start_line = start_line, end_line = end_line },
timestamp = os.clock(),
changedtick = snapshot.changedtick,
content_hash = snapshot.content_hash,
prompt_content = fix_prompt,
target_path = filepath,
priority = 1, -- High priority for fixes
status = "pending",
attempt_count = 0,
intent = {
type = "fix",
action = "replace",
confidence = 0.9,
},
scope_range = { start_line = start_line, end_line = end_line },
source = "linter_fix",
})
pcall(function()
local logs = require("codetyper.adapters.nvim.ui.logs")
logs.add({
type = "info",
message = "Queued AI fix request for lint errors",
})
end)
vim.notify("Queued AI fix request for lint errors", vim.log.levels.INFO)
end
--- Get last validation result for a buffer
---@param bufnr number Buffer number
---@return table|nil result
function M.get_last_result(bufnr)
return validation_results[bufnr]
end
--- Clear validation results for a buffer
---@param bufnr number Buffer number
function M.clear_result(bufnr)
validation_results[bufnr] = nil
end
--- Check if buffer has any lint errors currently
---@param bufnr number Buffer number
---@return boolean has_errors
function M.has_errors(bufnr)
local diagnostics = vim.diagnostic.get(bufnr, {
severity = vim.diagnostic.severity.ERROR,
})
return #diagnostics > 0
end
--- Check if buffer has any lint warnings currently
---@param bufnr number Buffer number
---@return boolean has_warnings
function M.has_warnings(bufnr)
local diagnostics = vim.diagnostic.get(bufnr, {
severity = { min = vim.diagnostic.severity.WARN },
})
return #diagnostics > 0
end
--- Validate all buffers with recent changes
function M.validate_all_changed()
for bufnr, data in pairs(validation_results) do
if vim.api.nvim_buf_is_valid(bufnr) then
M.validate_after_injection(bufnr, data.start_line, data.end_line)
end
end
end
return M

View File

@@ -0,0 +1,182 @@
---@mod codetyper.agent.permissions Permission manager for agent actions
---
--- Manages permissions for bash commands and file operations with
--- allow, allow-session, allow-list, and reject options.
local M = {}
---@class PermissionState
---@field session_allowed table<string, boolean> Commands allowed for this session
---@field allow_list table<string, boolean> Patterns always allowed
---@field deny_list table<string, boolean> Patterns always denied
local params = require("codetyper.params.agent.permissions")
local state = {
session_allowed = {},
allow_list = {},
deny_list = {},
}
--- Dangerous command patterns that should never be auto-allowed
local DANGEROUS_PATTERNS = params.dangerous_patterns
--- Safe command patterns that can be auto-allowed
local SAFE_PATTERNS = params.safe_patterns
---@alias PermissionLevel "allow"|"allow_session"|"allow_list"|"reject"
---@class PermissionResult
---@field allowed boolean Whether action is allowed
---@field reason string Reason for the decision
---@field auto boolean Whether this was an automatic decision
--- Check if a command matches a pattern
---@param command string The command to check
---@param pattern string The pattern to match
---@return boolean
local function matches_pattern(command, pattern)
return command:match(pattern) ~= nil
end
--- Check if command is dangerous
---@param command string The command to check
---@return boolean, string|nil dangerous, reason
local function is_dangerous(command)
for _, pattern in ipairs(DANGEROUS_PATTERNS) do
if matches_pattern(command, pattern) then
return true, "Matches dangerous pattern: " .. pattern
end
end
return false, nil
end
--- Check if command is safe
---@param command string The command to check
---@return boolean
local function is_safe(command)
for _, pattern in ipairs(SAFE_PATTERNS) do
if matches_pattern(command, pattern) then
return true
end
end
return false
end
--- Normalize command for comparison (trim, lowercase first word)
---@param command string
---@return string
local function normalize_command(command)
return vim.trim(command)
end
--- Check permission for a bash command
---@param command string The command to check
---@return PermissionResult
function M.check_bash_permission(command)
local normalized = normalize_command(command)
-- Check deny list first
for pattern, _ in pairs(state.deny_list) do
if matches_pattern(normalized, pattern) then
return {
allowed = false,
reason = "Command in deny list",
auto = true,
}
end
end
-- Check if command is dangerous
local dangerous, reason = is_dangerous(normalized)
if dangerous then
return {
allowed = false,
reason = reason,
auto = false, -- Require explicit approval for dangerous commands
}
end
-- Check session allowed
if state.session_allowed[normalized] then
return {
allowed = true,
reason = "Allowed for this session",
auto = true,
}
end
-- Check allow list patterns
for pattern, _ in pairs(state.allow_list) do
if matches_pattern(normalized, pattern) then
return {
allowed = true,
reason = "Matches allow list pattern",
auto = true,
}
end
end
-- Check if command is inherently safe
if is_safe(normalized) then
return {
allowed = true,
reason = "Safe read-only command",
auto = true,
}
end
-- Otherwise, require explicit permission
return {
allowed = false,
reason = "Requires approval",
auto = false,
}
end
--- Grant permission for a command
---@param command string The command
---@param level PermissionLevel The permission level
function M.grant_permission(command, level)
local normalized = normalize_command(command)
if level == "allow_session" then
state.session_allowed[normalized] = true
elseif level == "allow_list" then
-- Add as pattern (escape special chars for exact match)
local pattern = "^" .. vim.pesc(normalized) .. "$"
state.allow_list[pattern] = true
end
end
--- Add a pattern to the allow list
---@param pattern string Lua pattern to allow
function M.add_to_allow_list(pattern)
state.allow_list[pattern] = true
end
--- Add a pattern to the deny list
---@param pattern string Lua pattern to deny
function M.add_to_deny_list(pattern)
state.deny_list[pattern] = true
end
--- Clear session permissions
function M.clear_session()
state.session_allowed = {}
end
--- Reset all permissions
function M.reset()
state.session_allowed = {}
state.allow_list = {}
state.deny_list = {}
end
--- Get current permission state (for debugging)
---@return PermissionState
function M.get_state()
return vim.deepcopy(state)
end
return M

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,676 @@
---@mod codetyper.ask.explorer Project exploration for Ask mode
---@brief [[
--- Performs comprehensive project exploration when explaining a project.
--- Shows progress, indexes files, and builds brain context.
---@brief ]]
local M = {}
local utils = require("codetyper.support.utils")
---@class ExplorationState
---@field is_exploring boolean
---@field files_scanned number
---@field total_files number
---@field current_file string|nil
---@field findings table
---@field on_log fun(msg: string, level: string)|nil
local state = {
is_exploring = false,
files_scanned = 0,
total_files = 0,
current_file = nil,
findings = {},
on_log = nil,
}
--- File extensions to analyze
local ANALYZABLE_EXTENSIONS = {
lua = true,
ts = true,
tsx = true,
js = true,
jsx = true,
py = true,
go = true,
rs = true,
rb = true,
java = true,
c = true,
cpp = true,
h = true,
hpp = true,
json = true,
yaml = true,
yml = true,
toml = true,
md = true,
xml = true,
}
--- Directories to skip
local SKIP_DIRS = {
-- Version control
[".git"] = true,
[".svn"] = true,
[".hg"] = true,
-- IDE/Editor
[".idea"] = true,
[".vscode"] = true,
[".cursor"] = true,
[".cursorignore"] = true,
[".claude"] = true,
[".zed"] = true,
-- Project tooling
[".coder"] = true,
[".github"] = true,
[".gitlab"] = true,
[".husky"] = true,
-- Build outputs
dist = true,
build = true,
out = true,
target = true,
bin = true,
obj = true,
[".build"] = true,
[".output"] = true,
-- Dependencies
node_modules = true,
vendor = true,
[".vendor"] = true,
packages = true,
bower_components = true,
jspm_packages = true,
-- Cache/temp
[".cache"] = true,
[".tmp"] = true,
[".temp"] = true,
__pycache__ = true,
[".pytest_cache"] = true,
[".mypy_cache"] = true,
[".ruff_cache"] = true,
[".tox"] = true,
[".nox"] = true,
[".eggs"] = true,
["*.egg-info"] = true,
-- Framework specific
[".next"] = true,
[".nuxt"] = true,
[".svelte-kit"] = true,
[".vercel"] = true,
[".netlify"] = true,
[".serverless"] = true,
[".turbo"] = true,
-- Testing/coverage
coverage = true,
[".nyc_output"] = true,
htmlcov = true,
-- Logs
logs = true,
log = true,
-- OS files
[".DS_Store"] = true,
Thumbs_db = true,
}
--- Files to skip (patterns)
local SKIP_FILES = {
-- Lock files
"package%-lock%.json",
"yarn%.lock",
"pnpm%-lock%.yaml",
"Gemfile%.lock",
"Cargo%.lock",
"poetry%.lock",
"Pipfile%.lock",
"composer%.lock",
"go%.sum",
"flake%.lock",
"%.lock$",
"%-lock%.json$",
"%-lock%.yaml$",
-- Generated files
"%.min%.js$",
"%.min%.css$",
"%.bundle%.js$",
"%.chunk%.js$",
"%.map$",
"%.d%.ts$",
-- Binary/media (shouldn't match anyway but be safe)
"%.png$",
"%.jpg$",
"%.jpeg$",
"%.gif$",
"%.ico$",
"%.svg$",
"%.woff",
"%.ttf$",
"%.eot$",
"%.pdf$",
"%.zip$",
"%.tar",
"%.gz$",
-- Config that's not useful
"%.env",
"%.env%.",
}
--- Log a message during exploration
---@param msg string
---@param level? string "info"|"debug"|"file"|"progress"
local function log(msg, level)
level = level or "info"
if state.on_log then
state.on_log(msg, level)
end
end
--- Check if file should be skipped
---@param filename string
---@return boolean
local function should_skip_file(filename)
for _, pattern in ipairs(SKIP_FILES) do
if filename:match(pattern) then
return true
end
end
return false
end
--- Check if directory should be skipped
---@param dirname string
---@return boolean
local function should_skip_dir(dirname)
-- Direct match
if SKIP_DIRS[dirname] then
return true
end
-- Pattern match for .cursor* etc
if dirname:match("^%.cursor") then
return true
end
return false
end
--- Get all files in project
---@param root string Project root
---@return string[] files
local function get_project_files(root)
local files = {}
local function scan_dir(dir)
local handle = vim.loop.fs_scandir(dir)
if not handle then
return
end
while true do
local name, type = vim.loop.fs_scandir_next(handle)
if not name then
break
end
local full_path = dir .. "/" .. name
if type == "directory" then
if not should_skip_dir(name) then
scan_dir(full_path)
end
elseif type == "file" then
if not should_skip_file(name) then
local ext = name:match("%.([^%.]+)$")
if ext and ANALYZABLE_EXTENSIONS[ext:lower()] then
table.insert(files, full_path)
end
end
end
end
end
scan_dir(root)
return files
end
--- Analyze a single file
---@param filepath string
---@return table|nil analysis
local function analyze_file(filepath)
local content = utils.read_file(filepath)
if not content or content == "" then
return nil
end
local ext = filepath:match("%.([^%.]+)$") or ""
local lines = vim.split(content, "\n")
local analysis = {
path = filepath,
extension = ext,
lines = #lines,
size = #content,
imports = {},
exports = {},
functions = {},
classes = {},
summary = "",
}
-- Extract key patterns based on file type
for i, line in ipairs(lines) do
-- Imports/requires
local import = line:match('import%s+.*%s+from%s+["\']([^"\']+)["\']')
or line:match('require%(["\']([^"\']+)["\']%)')
or line:match("from%s+([%w_.]+)%s+import")
if import then
table.insert(analysis.imports, { source = import, line = i })
end
-- Function definitions
local func = line:match("^%s*function%s+([%w_:%.]+)%s*%(")
or line:match("^%s*local%s+function%s+([%w_]+)%s*%(")
or line:match("^%s*def%s+([%w_]+)%s*%(")
or line:match("^%s*func%s+([%w_]+)%s*%(")
or line:match("^%s*async%s+function%s+([%w_]+)%s*%(")
or line:match("^%s*public%s+.*%s+([%w_]+)%s*%(")
if func then
table.insert(analysis.functions, { name = func, line = i })
end
-- Class definitions
local class = line:match("^%s*class%s+([%w_]+)")
or line:match("^%s*public%s+class%s+([%w_]+)")
or line:match("^%s*interface%s+([%w_]+)")
if class then
table.insert(analysis.classes, { name = class, line = i })
end
-- Exports
local exp = line:match("^%s*export%s+.*%s+([%w_]+)")
or line:match("^%s*module%.exports%s*=")
or line:match("^return%s+M")
if exp then
table.insert(analysis.exports, { name = exp, line = i })
end
end
-- Create summary
local parts = {}
if #analysis.functions > 0 then
table.insert(parts, #analysis.functions .. " functions")
end
if #analysis.classes > 0 then
table.insert(parts, #analysis.classes .. " classes")
end
if #analysis.imports > 0 then
table.insert(parts, #analysis.imports .. " imports")
end
analysis.summary = table.concat(parts, ", ")
return analysis
end
--- Detect project type from files
---@param root string
---@return string type, table info
local function detect_project_type(root)
local info = {
name = vim.fn.fnamemodify(root, ":t"),
type = "unknown",
framework = nil,
language = nil,
}
-- Check for common project files
if utils.file_exists(root .. "/package.json") then
info.type = "node"
info.language = "JavaScript/TypeScript"
local content = utils.read_file(root .. "/package.json")
if content then
local ok, pkg = pcall(vim.json.decode, content)
if ok then
info.name = pkg.name or info.name
if pkg.dependencies then
if pkg.dependencies.react then
info.framework = "React"
elseif pkg.dependencies.vue then
info.framework = "Vue"
elseif pkg.dependencies.next then
info.framework = "Next.js"
elseif pkg.dependencies.express then
info.framework = "Express"
end
end
end
end
elseif utils.file_exists(root .. "/pom.xml") then
info.type = "maven"
info.language = "Java"
local content = utils.read_file(root .. "/pom.xml")
if content and content:match("spring%-boot") then
info.framework = "Spring Boot"
end
elseif utils.file_exists(root .. "/Cargo.toml") then
info.type = "rust"
info.language = "Rust"
elseif utils.file_exists(root .. "/go.mod") then
info.type = "go"
info.language = "Go"
elseif utils.file_exists(root .. "/requirements.txt") or utils.file_exists(root .. "/pyproject.toml") then
info.type = "python"
info.language = "Python"
elseif utils.file_exists(root .. "/init.lua") or utils.file_exists(root .. "/plugin/") then
info.type = "neovim-plugin"
info.language = "Lua"
end
return info.type, info
end
--- Build project structure summary
---@param files string[]
---@param root string
---@return table structure
local function build_structure(files, root)
local structure = {
directories = {},
by_extension = {},
total_files = #files,
}
for _, file in ipairs(files) do
local relative = file:gsub("^" .. vim.pesc(root) .. "/", "")
local dir = vim.fn.fnamemodify(relative, ":h")
local ext = file:match("%.([^%.]+)$") or "unknown"
structure.directories[dir] = (structure.directories[dir] or 0) + 1
structure.by_extension[ext] = (structure.by_extension[ext] or 0) + 1
end
return structure
end
--- Explore project and build context
---@param root string Project root
---@param on_log fun(msg: string, level: string) Log callback
---@param on_complete fun(result: table) Completion callback
function M.explore(root, on_log, on_complete)
if state.is_exploring then
on_log("⚠️ Already exploring...", "warning")
return
end
state.is_exploring = true
state.on_log = on_log
state.findings = {}
-- Start exploration
log("⏺ Exploring project structure...", "info")
log("", "info")
-- Detect project type
log(" Detect(Project type)", "progress")
local project_type, project_info = detect_project_type(root)
log("" .. project_info.language .. " (" .. (project_info.framework or project_type) .. ")", "debug")
state.findings.project = project_info
-- Get all files
log("", "info")
log(" Scan(Project files)", "progress")
local files = get_project_files(root)
state.total_files = #files
log(" ⎿ Found " .. #files .. " analyzable files", "debug")
-- Build structure
local structure = build_structure(files, root)
state.findings.structure = structure
-- Show directory breakdown
log("", "info")
log(" Structure(Directories)", "progress")
local sorted_dirs = {}
for dir, count in pairs(structure.directories) do
table.insert(sorted_dirs, { dir = dir, count = count })
end
table.sort(sorted_dirs, function(a, b)
return a.count > b.count
end)
for i, entry in ipairs(sorted_dirs) do
if i <= 5 then
log("" .. entry.dir .. " (" .. entry.count .. " files)", "debug")
end
end
if #sorted_dirs > 5 then
log(" ⎿ +" .. (#sorted_dirs - 5) .. " more directories", "debug")
end
-- Analyze files asynchronously
log("", "info")
log(" Analyze(Source files)", "progress")
state.files_scanned = 0
local analyses = {}
local key_files = {}
-- Process files in batches to avoid blocking
local batch_size = 10
local current_batch = 0
local function process_batch()
local start_idx = current_batch * batch_size + 1
local end_idx = math.min(start_idx + batch_size - 1, #files)
for i = start_idx, end_idx do
local file = files[i]
local relative = file:gsub("^" .. vim.pesc(root) .. "/", "")
state.files_scanned = state.files_scanned + 1
state.current_file = relative
local analysis = analyze_file(file)
if analysis then
analysis.relative_path = relative
table.insert(analyses, analysis)
-- Track key files (many functions/classes)
if #analysis.functions >= 3 or #analysis.classes >= 1 then
table.insert(key_files, {
path = relative,
functions = #analysis.functions,
classes = #analysis.classes,
summary = analysis.summary,
})
end
end
-- Log some files
if i <= 3 or (i % 20 == 0) then
log("" .. relative .. ": " .. (analysis and analysis.summary or "(empty)"), "file")
end
end
-- Progress update
local progress = math.floor((state.files_scanned / state.total_files) * 100)
if progress % 25 == 0 and progress > 0 then
log("" .. progress .. "% complete (" .. state.files_scanned .. "/" .. state.total_files .. ")", "debug")
end
current_batch = current_batch + 1
if end_idx < #files then
-- Schedule next batch
vim.defer_fn(process_batch, 10)
else
-- Complete
finish_exploration(root, analyses, key_files, on_complete)
end
end
-- Start processing
vim.defer_fn(process_batch, 10)
end
--- Finish exploration and store results
---@param root string
---@param analyses table
---@param key_files table
---@param on_complete fun(result: table)
function finish_exploration(root, analyses, key_files, on_complete)
log(" ⎿ +" .. (#analyses - 3) .. " more files analyzed", "debug")
-- Show key files
if #key_files > 0 then
log("", "info")
log(" KeyFiles(Important components)", "progress")
table.sort(key_files, function(a, b)
return (a.functions + a.classes * 2) > (b.functions + b.classes * 2)
end)
for i, kf in ipairs(key_files) do
if i <= 5 then
log("" .. kf.path .. ": " .. kf.summary, "file")
end
end
if #key_files > 5 then
log(" ⎿ +" .. (#key_files - 5) .. " more key files", "debug")
end
end
state.findings.analyses = analyses
state.findings.key_files = key_files
-- Store in brain if available
local ok_brain, brain = pcall(require, "codetyper.brain")
if ok_brain and brain.is_initialized() then
log("", "info")
log(" Store(Brain context)", "progress")
-- Store project pattern
brain.learn({
type = "pattern",
file = root,
content = {
summary = "Project: " .. state.findings.project.name,
detail = state.findings.project.language
.. " "
.. (state.findings.project.framework or state.findings.project.type),
code = nil,
},
context = {
file = root,
language = state.findings.project.language,
},
})
-- Store key file patterns
for i, kf in ipairs(key_files) do
if i <= 10 then
brain.learn({
type = "pattern",
file = root .. "/" .. kf.path,
content = {
summary = kf.path .. " - " .. kf.summary,
detail = kf.summary,
},
context = {
file = kf.path,
},
})
end
end
log(" ⎿ Stored " .. math.min(#key_files, 10) + 1 .. " patterns in brain", "debug")
end
-- Store in indexer if available
local ok_indexer, indexer = pcall(require, "codetyper.indexer")
if ok_indexer then
log(" Index(Project index)", "progress")
indexer.index_project(function(index)
log(" ⎿ Indexed " .. (index.stats.files or 0) .. " files", "debug")
end)
end
log("", "info")
log("✓ Exploration complete!", "info")
log("", "info")
-- Build result
local result = {
project = state.findings.project,
structure = state.findings.structure,
key_files = key_files,
total_files = state.total_files,
analyses = analyses,
}
state.is_exploring = false
state.on_log = nil
on_complete(result)
end
--- Check if exploration is in progress
---@return boolean
function M.is_exploring()
return state.is_exploring
end
--- Get exploration progress
---@return number scanned, number total
function M.get_progress()
return state.files_scanned, state.total_files
end
--- Build context string from exploration result
---@param result table Exploration result
---@return string context
function M.build_context(result)
local parts = {}
-- Project info
table.insert(parts, "## Project: " .. result.project.name)
table.insert(parts, "- Type: " .. result.project.type)
table.insert(parts, "- Language: " .. (result.project.language or "Unknown"))
if result.project.framework then
table.insert(parts, "- Framework: " .. result.project.framework)
end
table.insert(parts, "- Files: " .. result.total_files)
table.insert(parts, "")
-- Structure
table.insert(parts, "## Structure")
if result.structure and result.structure.by_extension then
for ext, count in pairs(result.structure.by_extension) do
table.insert(parts, "- ." .. ext .. ": " .. count .. " files")
end
end
table.insert(parts, "")
-- Key components
if result.key_files and #result.key_files > 0 then
table.insert(parts, "## Key Components")
for i, kf in ipairs(result.key_files) do
if i <= 10 then
table.insert(parts, "- " .. kf.path .. ": " .. kf.summary)
end
end
end
return table.concat(parts, "\n")
end
return M

View File

@@ -0,0 +1,302 @@
---@mod codetyper.ask.intent Intent detection for Ask mode
---@brief [[
--- Analyzes user prompts to detect intent (ask/explain vs code generation).
--- Routes to appropriate prompt type and context sources.
---@brief ]]
local M = {}
---@alias IntentType "ask"|"explain"|"generate"|"refactor"|"document"|"test"
---@class Intent
---@field type IntentType Detected intent type
---@field confidence number 0-1 confidence score
---@field needs_project_context boolean Whether project-wide context is needed
---@field needs_brain_context boolean Whether brain/learned context is helpful
---@field needs_exploration boolean Whether full project exploration is needed
---@field keywords string[] Keywords that influenced detection
--- Patterns for detecting ask/explain intent (questions about code)
local ASK_PATTERNS = {
-- Question words
{ pattern = "^what%s", weight = 0.9 },
{ pattern = "^why%s", weight = 0.95 },
{ pattern = "^how%s+does", weight = 0.9 },
{ pattern = "^how%s+do%s+i", weight = 0.7 }, -- Could be asking for code
{ pattern = "^where%s", weight = 0.85 },
{ pattern = "^when%s", weight = 0.85 },
{ pattern = "^which%s", weight = 0.8 },
{ pattern = "^who%s", weight = 0.85 },
{ pattern = "^can%s+you%s+explain", weight = 0.95 },
{ pattern = "^could%s+you%s+explain", weight = 0.95 },
{ pattern = "^please%s+explain", weight = 0.95 },
-- Explanation requests
{ pattern = "explain%s", weight = 0.9 },
{ pattern = "describe%s", weight = 0.85 },
{ pattern = "tell%s+me%s+about", weight = 0.85 },
{ pattern = "walk%s+me%s+through", weight = 0.9 },
{ pattern = "help%s+me%s+understand", weight = 0.95 },
{ pattern = "what%s+is%s+the%s+purpose", weight = 0.95 },
{ pattern = "what%s+does%s+this", weight = 0.9 },
{ pattern = "what%s+does%s+it", weight = 0.9 },
{ pattern = "how%s+does%s+this%s+work", weight = 0.95 },
{ pattern = "how%s+does%s+it%s+work", weight = 0.95 },
-- Understanding queries
{ pattern = "understand", weight = 0.7 },
{ pattern = "meaning%s+of", weight = 0.85 },
{ pattern = "difference%s+between", weight = 0.9 },
{ pattern = "compared%s+to", weight = 0.8 },
{ pattern = "vs%s", weight = 0.7 },
{ pattern = "versus", weight = 0.7 },
{ pattern = "pros%s+and%s+cons", weight = 0.9 },
{ pattern = "advantages", weight = 0.8 },
{ pattern = "disadvantages", weight = 0.8 },
{ pattern = "trade%-?offs?", weight = 0.85 },
-- Analysis requests
{ pattern = "analyze", weight = 0.85 },
{ pattern = "review", weight = 0.7 }, -- Could also be refactor
{ pattern = "overview", weight = 0.9 },
{ pattern = "summary", weight = 0.9 },
{ pattern = "summarize", weight = 0.9 },
-- Question marks (weaker signal)
{ pattern = "%?$", weight = 0.3 },
{ pattern = "%?%s*$", weight = 0.3 },
}
--- Patterns for detecting code generation intent
local GENERATE_PATTERNS = {
-- Direct commands
{ pattern = "^create%s", weight = 0.9 },
{ pattern = "^make%s", weight = 0.85 },
{ pattern = "^build%s", weight = 0.85 },
{ pattern = "^write%s", weight = 0.9 },
{ pattern = "^add%s", weight = 0.85 },
{ pattern = "^implement%s", weight = 0.95 },
{ pattern = "^generate%s", weight = 0.95 },
{ pattern = "^code%s", weight = 0.8 },
-- Modification commands
{ pattern = "^fix%s", weight = 0.9 },
{ pattern = "^change%s", weight = 0.8 },
{ pattern = "^update%s", weight = 0.75 },
{ pattern = "^modify%s", weight = 0.8 },
{ pattern = "^replace%s", weight = 0.85 },
{ pattern = "^remove%s", weight = 0.85 },
{ pattern = "^delete%s", weight = 0.85 },
-- Feature requests
{ pattern = "i%s+need%s+a", weight = 0.8 },
{ pattern = "i%s+want%s+a", weight = 0.8 },
{ pattern = "give%s+me", weight = 0.7 },
{ pattern = "show%s+me%s+how%s+to%s+code", weight = 0.9 },
{ pattern = "how%s+do%s+i%s+implement", weight = 0.85 },
{ pattern = "can%s+you%s+write", weight = 0.9 },
{ pattern = "can%s+you%s+create", weight = 0.9 },
{ pattern = "can%s+you%s+add", weight = 0.85 },
{ pattern = "can%s+you%s+make", weight = 0.85 },
-- Code-specific terms
{ pattern = "function%s+that", weight = 0.85 },
{ pattern = "class%s+that", weight = 0.85 },
{ pattern = "method%s+that", weight = 0.85 },
{ pattern = "component%s+that", weight = 0.85 },
{ pattern = "module%s+that", weight = 0.85 },
{ pattern = "api%s+for", weight = 0.8 },
{ pattern = "endpoint%s+for", weight = 0.8 },
}
--- Patterns for detecting refactor intent
local REFACTOR_PATTERNS = {
{ pattern = "^refactor%s", weight = 0.95 },
{ pattern = "refactor%s+this", weight = 0.95 },
{ pattern = "clean%s+up", weight = 0.85 },
{ pattern = "improve%s+this%s+code", weight = 0.85 },
{ pattern = "make%s+this%s+cleaner", weight = 0.85 },
{ pattern = "simplify", weight = 0.8 },
{ pattern = "optimize", weight = 0.75 }, -- Could be explain
{ pattern = "reorganize", weight = 0.9 },
{ pattern = "restructure", weight = 0.9 },
{ pattern = "extract%s+to", weight = 0.9 },
{ pattern = "split%s+into", weight = 0.85 },
{ pattern = "dry%s+this", weight = 0.9 }, -- Don't repeat yourself
{ pattern = "reduce%s+duplication", weight = 0.9 },
}
--- Patterns for detecting documentation intent
local DOCUMENT_PATTERNS = {
{ pattern = "^document%s", weight = 0.95 },
{ pattern = "add%s+documentation", weight = 0.95 },
{ pattern = "add%s+docs", weight = 0.95 },
{ pattern = "add%s+comments", weight = 0.9 },
{ pattern = "add%s+docstring", weight = 0.95 },
{ pattern = "add%s+jsdoc", weight = 0.95 },
{ pattern = "write%s+documentation", weight = 0.95 },
{ pattern = "document%s+this", weight = 0.95 },
}
--- Patterns for detecting test generation intent
local TEST_PATTERNS = {
{ pattern = "^test%s", weight = 0.9 },
{ pattern = "write%s+tests?%s+for", weight = 0.95 },
{ pattern = "add%s+tests?%s+for", weight = 0.95 },
{ pattern = "create%s+tests?%s+for", weight = 0.95 },
{ pattern = "generate%s+tests?", weight = 0.95 },
{ pattern = "unit%s+tests?", weight = 0.9 },
{ pattern = "test%s+cases?%s+for", weight = 0.95 },
{ pattern = "spec%s+for", weight = 0.85 },
}
--- Patterns indicating project-wide context is needed
local PROJECT_CONTEXT_PATTERNS = {
{ pattern = "project", weight = 0.9 },
{ pattern = "codebase", weight = 0.95 },
{ pattern = "entire", weight = 0.7 },
{ pattern = "whole", weight = 0.7 },
{ pattern = "all%s+files", weight = 0.9 },
{ pattern = "architecture", weight = 0.95 },
{ pattern = "structure", weight = 0.85 },
{ pattern = "how%s+is%s+.*%s+organized", weight = 0.95 },
{ pattern = "where%s+is%s+.*%s+defined", weight = 0.9 },
{ pattern = "dependencies", weight = 0.85 },
{ pattern = "imports?%s+from", weight = 0.7 },
{ pattern = "modules?", weight = 0.6 },
{ pattern = "packages?", weight = 0.6 },
}
--- Patterns indicating project exploration is needed (full indexing)
local EXPLORE_PATTERNS = {
{ pattern = "explain%s+.*%s*project", weight = 1.0 },
{ pattern = "explain%s+.*%s*codebase", weight = 1.0 },
{ pattern = "explain%s+me%s+the%s+project", weight = 1.0 },
{ pattern = "tell%s+me%s+about%s+.*%s*project", weight = 0.95 },
{ pattern = "what%s+is%s+this%s+project", weight = 0.95 },
{ pattern = "overview%s+of%s+.*%s*project", weight = 0.95 },
{ pattern = "understand%s+.*%s*project", weight = 0.9 },
{ pattern = "analyze%s+.*%s*project", weight = 0.9 },
{ pattern = "explore%s+.*%s*project", weight = 1.0 },
{ pattern = "explore%s+.*%s*codebase", weight = 1.0 },
{ pattern = "index%s+.*%s*project", weight = 1.0 },
{ pattern = "scan%s+.*%s*project", weight = 0.95 },
}
--- Match patterns against text
---@param text string Lowercased text to match
---@param patterns table Pattern list with weights
---@return number Score, string[] Matched keywords
local function match_patterns(text, patterns)
local score = 0
local matched = {}
for _, p in ipairs(patterns) do
if text:match(p.pattern) then
score = score + p.weight
table.insert(matched, p.pattern)
end
end
return score, matched
end
--- Detect intent from user prompt
---@param prompt string User's question/request
---@return Intent Detected intent
function M.detect(prompt)
local text = prompt:lower()
-- Calculate raw scores for each intent type (sum of matched weights)
local ask_score, ask_kw = match_patterns(text, ASK_PATTERNS)
local gen_score, gen_kw = match_patterns(text, GENERATE_PATTERNS)
local ref_score, ref_kw = match_patterns(text, REFACTOR_PATTERNS)
local doc_score, doc_kw = match_patterns(text, DOCUMENT_PATTERNS)
local test_score, test_kw = match_patterns(text, TEST_PATTERNS)
local proj_score, _ = match_patterns(text, PROJECT_CONTEXT_PATTERNS)
local explore_score, _ = match_patterns(text, EXPLORE_PATTERNS)
-- Find the winner by raw score (highest accumulated weight)
local scores = {
{ type = "ask", score = ask_score, keywords = ask_kw },
{ type = "generate", score = gen_score, keywords = gen_kw },
{ type = "refactor", score = ref_score, keywords = ref_kw },
{ type = "document", score = doc_score, keywords = doc_kw },
{ type = "test", score = test_score, keywords = test_kw },
}
table.sort(scores, function(a, b)
return a.score > b.score
end)
local winner = scores[1]
-- If top score is very low, default to ask (safer for Q&A)
if winner.score < 0.3 then
winner = { type = "ask", score = 0.5, keywords = {} }
end
-- If ask and generate are close AND there's a question mark, prefer ask
if winner.type == "generate" and ask_score > 0 then
if text:match("%?%s*$") and ask_score >= gen_score * 0.5 then
winner = { type = "ask", score = ask_score, keywords = ask_kw }
end
end
-- Determine if "explain" vs "ask" (explain needs more context)
local intent_type = winner.type
if intent_type == "ask" then
-- "explain" if asking about how something works, otherwise "ask"
if text:match("explain") or text:match("how%s+does") or text:match("walk%s+me%s+through") then
intent_type = "explain"
end
end
-- Normalize confidence to 0-1 range (cap at reasonable max)
local confidence = math.min(winner.score / 2, 1.0)
-- Check if exploration is needed (full project indexing)
local needs_exploration = explore_score >= 0.9
---@type Intent
local intent = {
type = intent_type,
confidence = confidence,
needs_project_context = proj_score > 0.5 or needs_exploration,
needs_brain_context = intent_type == "ask" or intent_type == "explain",
needs_exploration = needs_exploration,
keywords = winner.keywords,
}
return intent
end
--- Get prompt type for system prompt selection
---@param intent Intent Detected intent
---@return string Prompt type for prompts.system
function M.get_prompt_type(intent)
local mapping = {
ask = "ask",
explain = "ask", -- Uses same prompt as ask
generate = "code_generation",
refactor = "refactor",
document = "document",
test = "test",
}
return mapping[intent.type] or "ask"
end
--- Check if intent requires code output
---@param intent Intent
---@return boolean
function M.produces_code(intent)
local code_intents = {
generate = true,
refactor = true,
document = true, -- Documentation is code (comments)
test = true,
}
return code_intents[intent.type] or false
end
return M

View File

@@ -0,0 +1,456 @@
---@mod codetyper.agent.inject Smart code injection with import handling
---@brief [[
--- Intelligent code injection that properly handles imports, merging them
--- into existing import sections instead of blindly appending.
---@brief ]]
local M = {}
---@class ImportConfig
---@field pattern string Lua pattern to match import statements
---@field multi_line boolean Whether imports can span multiple lines
---@field sort_key function|nil Function to extract sort key from import
---@field group_by function|nil Function to group imports
---@class ParsedCode
---@field imports string[] Import statements
---@field body string[] Non-import code lines
---@field import_lines table<number, boolean> Map of line numbers that are imports
local utils = require("codetyper.support.utils")
local languages = require("codetyper.params.agent.languages")
local import_patterns = languages.import_patterns
--- Check if a line is an import statement for the given language
---@param line string
---@param patterns table[] Import patterns for the language
---@return boolean is_import
---@return boolean is_multi_line
local function is_import_line(line, patterns)
for _, p in ipairs(patterns) do
if line:match(p.pattern) then
return true, p.multi_line or false
end
end
return false, false
end
--- Check if a line ends a multi-line import
---@param line string
---@param filetype string
---@return boolean
local function ends_multiline_import(line, filetype)
return utils.ends_multiline_import(line, filetype)
end
--- Parse code into imports and body
---@param code string|string[] Code to parse
---@param filetype string File type/extension
---@return ParsedCode
function M.parse_code(code, filetype)
local lines
if type(code) == "string" then
lines = vim.split(code, "\n", { plain = true })
else
lines = code
end
local patterns = import_patterns[filetype] or import_patterns.javascript
local result = {
imports = {},
body = {},
import_lines = {},
}
local in_multiline_import = false
local current_import_lines = {}
for i, line in ipairs(lines) do
if in_multiline_import then
-- Continue collecting multi-line import
table.insert(current_import_lines, line)
if ends_multiline_import(line, filetype) then
-- Complete the multi-line import
table.insert(result.imports, table.concat(current_import_lines, "\n"))
for j = i - #current_import_lines + 1, i do
result.import_lines[j] = true
end
current_import_lines = {}
in_multiline_import = false
end
else
local is_import, is_multi = is_import_line(line, patterns)
if is_import then
result.import_lines[i] = true
if is_multi and not ends_multiline_import(line, filetype) then
-- Start of multi-line import
in_multiline_import = true
current_import_lines = { line }
else
-- Single-line import
table.insert(result.imports, line)
end
else
-- Non-import line
table.insert(result.body, line)
end
end
end
-- Handle unclosed multi-line import (shouldn't happen with well-formed code)
if #current_import_lines > 0 then
table.insert(result.imports, table.concat(current_import_lines, "\n"))
end
return result
end
--- Find the import section range in a buffer
---@param bufnr number Buffer number
---@param filetype string
---@return number|nil start_line First import line (1-indexed)
---@return number|nil end_line Last import line (1-indexed)
function M.find_import_section(bufnr, filetype)
if not vim.api.nvim_buf_is_valid(bufnr) then
return nil, nil
end
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local patterns = import_patterns[filetype] or import_patterns.javascript
local first_import = nil
local last_import = nil
local in_multiline = false
local consecutive_non_import = 0
local max_gap = 3 -- Allow up to 3 blank/comment lines between imports
for i, line in ipairs(lines) do
if in_multiline then
last_import = i
consecutive_non_import = 0
if ends_multiline_import(line, filetype) then
in_multiline = false
end
else
local is_import, is_multi = is_import_line(line, patterns)
if is_import then
if not first_import then
first_import = i
end
last_import = i
consecutive_non_import = 0
if is_multi and not ends_multiline_import(line, filetype) then
in_multiline = true
end
elseif utils.is_empty_or_comment(line, filetype) then
-- Allow gaps in import section
if first_import then
consecutive_non_import = consecutive_non_import + 1
if consecutive_non_import > max_gap then
-- Too many non-import lines, import section has ended
break
end
end
else
-- Non-import, non-empty line
if first_import then
-- Import section has ended
break
end
end
end
end
return first_import, last_import
end
--- Get existing imports from a buffer
---@param bufnr number Buffer number
---@param filetype string
---@return string[] Existing import statements
function M.get_existing_imports(bufnr, filetype)
local start_line, end_line = M.find_import_section(bufnr, filetype)
if not start_line then
return {}
end
local lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false)
local parsed = M.parse_code(lines, filetype)
return parsed.imports
end
--- Normalize an import for comparison (remove whitespace variations)
---@param import_str string
---@return string
local function normalize_import(import_str)
-- Remove trailing semicolon for comparison
local normalized = import_str:gsub(";%s*$", "")
-- Remove all whitespace around braces, commas, colons
normalized = normalized:gsub("%s*{%s*", "{")
normalized = normalized:gsub("%s*}%s*", "}")
normalized = normalized:gsub("%s*,%s*", ",")
normalized = normalized:gsub("%s*:%s*", ":")
-- Collapse multiple whitespace to single space
normalized = normalized:gsub("%s+", " ")
-- Trim leading/trailing whitespace
normalized = normalized:match("^%s*(.-)%s*$")
return normalized
end
--- Check if two imports are duplicates
---@param import1 string
---@param import2 string
---@return boolean
local function are_duplicate_imports(import1, import2)
return normalize_import(import1) == normalize_import(import2)
end
--- Merge new imports with existing ones, avoiding duplicates
---@param existing string[] Existing imports
---@param new_imports string[] New imports to merge
---@return string[] Merged imports
function M.merge_imports(existing, new_imports)
local merged = {}
local seen = {}
-- Add existing imports
for _, imp in ipairs(existing) do
local normalized = normalize_import(imp)
if not seen[normalized] then
seen[normalized] = true
table.insert(merged, imp)
end
end
-- Add new imports that aren't duplicates
for _, imp in ipairs(new_imports) do
local normalized = normalize_import(imp)
if not seen[normalized] then
seen[normalized] = true
table.insert(merged, imp)
end
end
return merged
end
--- Sort imports by their source/module
---@param imports string[]
---@param filetype string
---@return string[]
function M.sort_imports(imports, filetype)
-- Group imports: stdlib/builtin first, then third-party, then local
local builtin = {}
local third_party = {}
local local_imports = {}
for _, imp in ipairs(imports) do
local category = utils.classify_import(imp, filetype)
if category == "builtin" then
table.insert(builtin, imp)
elseif category == "local" then
table.insert(local_imports, imp)
else
table.insert(third_party, imp)
end
end
-- Sort each group alphabetically
table.sort(builtin)
table.sort(third_party)
table.sort(local_imports)
-- Combine with proper spacing
local result = {}
for _, imp in ipairs(builtin) do
table.insert(result, imp)
end
if #builtin > 0 and (#third_party > 0 or #local_imports > 0) then
table.insert(result, "") -- Blank line between groups
end
for _, imp in ipairs(third_party) do
table.insert(result, imp)
end
if #third_party > 0 and #local_imports > 0 then
table.insert(result, "") -- Blank line between groups
end
for _, imp in ipairs(local_imports) do
table.insert(result, imp)
end
return result
end
---@class InjectResult
---@field success boolean
---@field imports_added number Number of new imports added
---@field imports_merged boolean Whether imports were merged into existing section
---@field body_lines number Number of body lines injected
--- Smart inject code into a buffer, properly handling imports
---@param bufnr number Target buffer
---@param code string|string[] Code to inject
---@param opts table Options: { strategy: "append"|"replace"|"insert", range: {start_line, end_line}|nil, filetype: string|nil, sort_imports: boolean|nil }
---@return InjectResult
function M.inject(bufnr, code, opts)
opts = opts or {}
if not vim.api.nvim_buf_is_valid(bufnr) then
return { success = false, imports_added = 0, imports_merged = false, body_lines = 0 }
end
-- Get filetype
local filetype = opts.filetype
if not filetype then
local bufname = vim.api.nvim_buf_get_name(bufnr)
filetype = vim.fn.fnamemodify(bufname, ":e")
end
-- Parse the code to separate imports from body
local parsed = M.parse_code(code, filetype)
local result = {
success = true,
imports_added = 0,
imports_merged = false,
body_lines = #parsed.body,
}
-- Handle imports first if there are any
if #parsed.imports > 0 then
local import_start, import_end = M.find_import_section(bufnr, filetype)
if import_start then
-- Merge with existing import section
local existing_imports = M.get_existing_imports(bufnr, filetype)
local merged = M.merge_imports(existing_imports, parsed.imports)
-- Count how many new imports were actually added
result.imports_added = #merged - #existing_imports
result.imports_merged = true
-- Optionally sort imports
if opts.sort_imports ~= false then
merged = M.sort_imports(merged, filetype)
end
-- Convert back to lines (handling multi-line imports)
local import_lines = {}
for _, imp in ipairs(merged) do
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
table.insert(import_lines, line)
end
end
-- Replace the import section
vim.api.nvim_buf_set_lines(bufnr, import_start - 1, import_end, false, import_lines)
-- Adjust line numbers for body injection
local lines_diff = #import_lines - (import_end - import_start + 1)
if opts.range and opts.range.start_line and opts.range.start_line > import_end then
opts.range.start_line = opts.range.start_line + lines_diff
if opts.range.end_line then
opts.range.end_line = opts.range.end_line + lines_diff
end
end
else
-- No existing import section, add imports at the top
-- Find the first non-comment, non-empty line
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local insert_at = 0
for i, line in ipairs(lines) do
local trimmed = line:match("^%s*(.-)%s*$")
-- Skip shebang, docstrings, and initial comments
if trimmed ~= "" and not trimmed:match("^#!")
and not trimmed:match("^['\"]") and not utils.is_empty_or_comment(line, filetype) then
insert_at = i - 1
break
end
insert_at = i
end
-- Add imports with a trailing blank line
local import_lines = {}
for _, imp in ipairs(parsed.imports) do
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
table.insert(import_lines, line)
end
end
table.insert(import_lines, "") -- Blank line after imports
vim.api.nvim_buf_set_lines(bufnr, insert_at, insert_at, false, import_lines)
result.imports_added = #parsed.imports
result.imports_merged = false
-- Adjust body injection range
if opts.range and opts.range.start_line then
opts.range.start_line = opts.range.start_line + #import_lines
if opts.range.end_line then
opts.range.end_line = opts.range.end_line + #import_lines
end
end
end
end
-- Handle body (non-import) code
if #parsed.body > 0 then
-- Filter out empty leading/trailing lines from body
local body_lines = parsed.body
while #body_lines > 0 and body_lines[1]:match("^%s*$") do
table.remove(body_lines, 1)
end
while #body_lines > 0 and body_lines[#body_lines]:match("^%s*$") do
table.remove(body_lines)
end
if #body_lines > 0 then
local line_count = vim.api.nvim_buf_line_count(bufnr)
local strategy = opts.strategy or "append"
if strategy == "replace" and opts.range then
local start_line = math.max(1, opts.range.start_line)
local end_line = math.min(line_count, opts.range.end_line)
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, body_lines)
elseif strategy == "insert" and opts.range then
local insert_line = math.max(0, math.min(line_count, opts.range.start_line - 1))
vim.api.nvim_buf_set_lines(bufnr, insert_line, insert_line, false, body_lines)
else
-- Default: append
local last_line = vim.api.nvim_buf_get_lines(bufnr, line_count - 1, line_count, false)[1] or ""
if last_line:match("%S") then
-- Add blank line for spacing
table.insert(body_lines, 1, "")
end
vim.api.nvim_buf_set_lines(bufnr, line_count, line_count, false, body_lines)
end
result.body_lines = #body_lines
end
end
return result
end
--- Check if code contains imports
---@param code string|string[]
---@param filetype string
---@return boolean
function M.has_imports(code, filetype)
local parsed = M.parse_code(code, filetype)
return #parsed.imports > 0
end
return M

View File

@@ -0,0 +1,192 @@
---@mod codetyper.completion Insert mode completion for file references
---
--- Provides completion for @filename inside /@ @/ tags.
local M = {}
local parser = require("codetyper.parser")
local utils = require("codetyper.support.utils")
--- Get list of files for completion
---@param prefix string Prefix to filter files
---@return table[] List of completion items
local function get_file_completions(prefix)
local cwd = vim.fn.getcwd()
local current_file = vim.fn.expand("%:p")
local current_dir = vim.fn.fnamemodify(current_file, ":h")
local files = {}
-- Use vim.fn.glob to find files matching the prefix
local pattern = prefix .. "*"
-- Determine base directory - use current file's directory if outside cwd
local base_dir = cwd
if current_dir ~= "" and not current_dir:find(cwd, 1, true) then
-- File is outside project, use its directory as base
base_dir = current_dir
end
-- Search in base directory
local matches = vim.fn.glob(base_dir .. "/" .. pattern, false, true)
-- Search with ** for all subdirectories
local deep_matches = vim.fn.glob(base_dir .. "/**/" .. pattern, false, true)
for _, m in ipairs(deep_matches) do
table.insert(matches, m)
end
-- Also search in cwd if different from base_dir
if base_dir ~= cwd then
local cwd_matches = vim.fn.glob(cwd .. "/" .. pattern, false, true)
for _, m in ipairs(cwd_matches) do
table.insert(matches, m)
end
local cwd_deep = vim.fn.glob(cwd .. "/**/" .. pattern, false, true)
for _, m in ipairs(cwd_deep) do
table.insert(matches, m)
end
end
-- Also search specific directories if prefix doesn't have path
if not prefix:find("/") then
local search_dirs = { "src", "lib", "lua", "app", "components", "utils", "tests" }
for _, dir in ipairs(search_dirs) do
local dir_path = base_dir .. "/" .. dir
if vim.fn.isdirectory(dir_path) == 1 then
local dir_matches = vim.fn.glob(dir_path .. "/**/" .. pattern, false, true)
for _, m in ipairs(dir_matches) do
table.insert(matches, m)
end
end
end
end
-- Convert to relative paths and deduplicate
local seen = {}
for _, match in ipairs(matches) do
-- Convert to relative path based on which base it came from
local rel_path
if match:find(base_dir, 1, true) == 1 then
rel_path = match:sub(#base_dir + 2)
elseif match:find(cwd, 1, true) == 1 then
rel_path = match:sub(#cwd + 2)
else
rel_path = vim.fn.fnamemodify(match, ":t") -- Just filename if can't make relative
end
-- Skip directories, coder files, and hidden/generated files
if vim.fn.isdirectory(match) == 0
and not utils.is_coder_file(match)
and not rel_path:match("^%.")
and not rel_path:match("node_modules")
and not rel_path:match("%.git/")
and not rel_path:match("dist/")
and not rel_path:match("build/")
and not seen[rel_path]
then
seen[rel_path] = true
table.insert(files, {
word = rel_path,
abbr = rel_path,
kind = "File",
menu = "[ref]",
})
end
end
-- Sort by length (shorter paths first)
table.sort(files, function(a, b)
return #a.word < #b.word
end)
-- Limit results
local result = {}
for i = 1, math.min(#files, 15) do
result[i] = files[i]
end
return result
end
--- Show file completion popup
function M.show_file_completion()
-- Check if we're in an open prompt tag
local is_inside = parser.is_cursor_in_open_tag()
if not is_inside then
return false
end
-- Get the prefix being typed
local prefix = parser.get_file_ref_prefix()
if prefix == nil then
return false
end
-- Get completions
local items = get_file_completions(prefix)
if #items == 0 then
-- Try with empty prefix to show all files
items = get_file_completions("")
end
if #items > 0 then
-- Calculate start column (position right after @)
local cursor = vim.api.nvim_win_get_cursor(0)
local col = cursor[2] - #prefix + 1 -- 1-indexed for complete()
-- Show completion popup
vim.fn.complete(col, items)
return true
end
return false
end
--- Setup completion for file references (works on ALL files)
function M.setup()
local group = vim.api.nvim_create_augroup("CoderCompletion", { clear = true })
-- Trigger completion on @ in insert mode (works on ALL files)
vim.api.nvim_create_autocmd("InsertCharPre", {
group = group,
pattern = "*",
callback = function()
-- Skip special buffers
if vim.bo.buftype ~= "" then
return
end
if vim.v.char == "@" then
-- Schedule completion popup after the @ is inserted
vim.schedule(function()
-- Check we're in an open tag
local is_inside = parser.is_cursor_in_open_tag()
if not is_inside then
return
end
-- Check we're not typing @/ (closing tag)
local cursor = vim.api.nvim_win_get_cursor(0)
local line = vim.api.nvim_get_current_line()
local next_char = line:sub(cursor[2] + 2, cursor[2] + 2)
if next_char == "/" then
return
end
-- Show file completion
M.show_file_completion()
end)
end
end,
desc = "Trigger file completion on @ inside prompt tags",
})
-- Also allow manual trigger with <C-x><C-f> style keybinding in insert mode
vim.keymap.set("i", "<C-x>@", function()
M.show_file_completion()
end, { silent = true, desc = "Coder: Complete file reference" })
end
return M

View File

@@ -0,0 +1,491 @@
---@mod codetyper.suggestion Inline ghost text suggestions
---@brief [[
--- Provides Copilot-style inline suggestions with ghost text.
--- Uses Copilot when available, falls back to codetyper's own suggestions.
--- Shows suggestions as grayed-out text that can be accepted with Tab.
---@brief ]]
local M = {}
---@class SuggestionState
---@field current_suggestion string|nil Current suggestion text
---@field suggestions string[] List of available suggestions
---@field current_index number Current suggestion index
---@field extmark_id number|nil Virtual text extmark ID
---@field bufnr number|nil Buffer where suggestion is shown
---@field line number|nil Line where suggestion is shown
---@field col number|nil Column where suggestion starts
---@field timer any|nil Debounce timer
---@field using_copilot boolean Whether currently using copilot
local state = {
current_suggestion = nil,
suggestions = {},
current_index = 0,
extmark_id = nil,
bufnr = nil,
line = nil,
col = nil,
timer = nil,
using_copilot = false,
}
--- Namespace for virtual text
local ns = vim.api.nvim_create_namespace("codetyper_suggestion")
--- Highlight group for ghost text
local hl_group = "CmpGhostText"
--- Configuration
local config = {
enabled = true,
auto_trigger = true,
debounce = 150,
use_copilot = true, -- Use copilot when available
keymap = {
accept = "<Tab>",
next = "<M-]>",
prev = "<M-[>",
dismiss = "<C-]>",
},
}
--- Check if copilot is available and enabled
---@return boolean, table|nil available, copilot_suggestion module
local function get_copilot()
if not config.use_copilot then
return false, nil
end
local ok, copilot_suggestion = pcall(require, "copilot.suggestion")
if not ok then
return false, nil
end
-- Check if copilot suggestion is enabled
local ok_client, copilot_client = pcall(require, "copilot.client")
if ok_client and copilot_client.is_disabled and copilot_client.is_disabled() then
return false, nil
end
return true, copilot_suggestion
end
--- Check if suggestion is visible (copilot or codetyper)
---@return boolean
function M.is_visible()
-- Check copilot first
local copilot_ok, copilot_suggestion = get_copilot()
if copilot_ok and copilot_suggestion.is_visible() then
state.using_copilot = true
return true
end
-- Check codetyper's own suggestion
state.using_copilot = false
return state.extmark_id ~= nil and state.current_suggestion ~= nil
end
--- Clear the current suggestion
function M.dismiss()
-- Dismiss copilot if active
local copilot_ok, copilot_suggestion = get_copilot()
if copilot_ok and copilot_suggestion.is_visible() then
copilot_suggestion.dismiss()
end
-- Clear codetyper's suggestion
if state.extmark_id and state.bufnr then
pcall(vim.api.nvim_buf_del_extmark, state.bufnr, ns, state.extmark_id)
end
state.current_suggestion = nil
state.suggestions = {}
state.current_index = 0
state.extmark_id = nil
state.bufnr = nil
state.line = nil
state.col = nil
state.using_copilot = false
end
--- Display suggestion as ghost text
---@param suggestion string The suggestion to display
local function display_suggestion(suggestion)
if not suggestion or suggestion == "" then
return
end
M.dismiss()
local bufnr = vim.api.nvim_get_current_buf()
local cursor = vim.api.nvim_win_get_cursor(0)
local line = cursor[1] - 1
local col = cursor[2]
-- Split suggestion into lines
local lines = vim.split(suggestion, "\n", { plain = true })
-- Build virtual text
local virt_text = {}
local virt_lines = {}
-- First line goes inline
if #lines > 0 then
virt_text = { { lines[1], hl_group } }
end
-- Remaining lines go below
for i = 2, #lines do
table.insert(virt_lines, { { lines[i], hl_group } })
end
-- Create extmark with virtual text
local opts = {
virt_text = virt_text,
virt_text_pos = "overlay",
hl_mode = "combine",
}
if #virt_lines > 0 then
opts.virt_lines = virt_lines
end
state.extmark_id = vim.api.nvim_buf_set_extmark(bufnr, ns, line, col, opts)
state.bufnr = bufnr
state.line = line
state.col = col
state.current_suggestion = suggestion
end
--- Accept the current suggestion
---@return boolean Whether a suggestion was accepted
function M.accept()
-- Check copilot first
local copilot_ok, copilot_suggestion = get_copilot()
if copilot_ok and copilot_suggestion.is_visible() then
copilot_suggestion.accept()
state.using_copilot = false
return true
end
-- Accept codetyper's suggestion
if not M.is_visible() then
return false
end
local suggestion = state.current_suggestion
local bufnr = state.bufnr
local line = state.line
local col = state.col
M.dismiss()
if suggestion and bufnr and line ~= nil and col ~= nil then
-- Get current line content
local current_line = vim.api.nvim_buf_get_lines(bufnr, line, line + 1, false)[1] or ""
-- Split suggestion into lines
local suggestion_lines = vim.split(suggestion, "\n", { plain = true })
if #suggestion_lines == 1 then
-- Single line - insert at cursor
local new_line = current_line:sub(1, col) .. suggestion .. current_line:sub(col + 1)
vim.api.nvim_buf_set_lines(bufnr, line, line + 1, false, { new_line })
-- Move cursor to end of inserted text
vim.api.nvim_win_set_cursor(0, { line + 1, col + #suggestion })
else
-- Multi-line - insert at cursor
local first_line = current_line:sub(1, col) .. suggestion_lines[1]
local last_line = suggestion_lines[#suggestion_lines] .. current_line:sub(col + 1)
local new_lines = { first_line }
for i = 2, #suggestion_lines - 1 do
table.insert(new_lines, suggestion_lines[i])
end
table.insert(new_lines, last_line)
vim.api.nvim_buf_set_lines(bufnr, line, line + 1, false, new_lines)
-- Move cursor to end of last line
vim.api.nvim_win_set_cursor(0, { line + #new_lines, #suggestion_lines[#suggestion_lines] })
end
return true
end
return false
end
--- Show next suggestion
function M.next()
-- Check copilot first
local copilot_ok, copilot_suggestion = get_copilot()
if copilot_ok and copilot_suggestion.is_visible() then
copilot_suggestion.next()
return
end
-- Codetyper's suggestions
if #state.suggestions <= 1 then
return
end
state.current_index = (state.current_index % #state.suggestions) + 1
display_suggestion(state.suggestions[state.current_index])
end
--- Show previous suggestion
function M.prev()
-- Check copilot first
local copilot_ok, copilot_suggestion = get_copilot()
if copilot_ok and copilot_suggestion.is_visible() then
copilot_suggestion.prev()
return
end
-- Codetyper's suggestions
if #state.suggestions <= 1 then
return
end
state.current_index = state.current_index - 1
if state.current_index < 1 then
state.current_index = #state.suggestions
end
display_suggestion(state.suggestions[state.current_index])
end
--- Get suggestions from brain/indexer
---@param prefix string Current word prefix
---@param context table Context info
---@return string[] suggestions
local function get_suggestions(prefix, context)
local suggestions = {}
-- Get completions from brain
local ok_brain, brain = pcall(require, "codetyper.brain")
if ok_brain and brain.is_initialized and brain.is_initialized() then
local result = brain.query({
query = prefix,
max_results = 5,
types = { "pattern" },
})
if result and result.nodes then
for _, node in ipairs(result.nodes) do
if node.c and node.c.code then
table.insert(suggestions, node.c.code)
end
end
end
end
-- Get completions from indexer
local ok_indexer, indexer = pcall(require, "codetyper.indexer")
if ok_indexer then
local index = indexer.load_index()
if index and index.symbols then
for symbol, _ in pairs(index.symbols) do
if symbol:lower():find(prefix:lower(), 1, true) and symbol ~= prefix then
-- Just complete the symbol name
local completion = symbol:sub(#prefix + 1)
if completion ~= "" then
table.insert(suggestions, completion)
end
end
end
end
end
-- Buffer-based completions
local bufnr = vim.api.nvim_get_current_buf()
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local seen = {}
for _, line in ipairs(lines) do
for word in line:gmatch("[%a_][%w_]*") do
if
#word > #prefix
and word:lower():find(prefix:lower(), 1, true) == 1
and not seen[word]
and word ~= prefix
then
seen[word] = true
local completion = word:sub(#prefix + 1)
if completion ~= "" then
table.insert(suggestions, completion)
end
end
end
end
return suggestions
end
--- Trigger suggestion generation
function M.trigger()
if not config.enabled then
return
end
-- If copilot is available and has a suggestion, don't show codetyper's
local copilot_ok, copilot_suggestion = get_copilot()
if copilot_ok and copilot_suggestion.is_visible() then
-- Copilot is handling suggestions
state.using_copilot = true
return
end
-- Cancel existing timer
if state.timer then
state.timer:stop()
state.timer = nil
end
-- Get current context
local cursor = vim.api.nvim_win_get_cursor(0)
local line = vim.api.nvim_get_current_line()
local col = cursor[2]
local before_cursor = line:sub(1, col)
-- Extract prefix (word being typed)
local prefix = before_cursor:match("[%a_][%w_]*$") or ""
if #prefix < 2 then
M.dismiss()
return
end
-- Debounce - wait a bit longer to let copilot try first
local debounce_time = copilot_ok and (config.debounce + 200) or config.debounce
state.timer = vim.defer_fn(function()
-- Check again if copilot has shown something
if copilot_ok and copilot_suggestion.is_visible() then
state.using_copilot = true
state.timer = nil
return
end
local suggestions = get_suggestions(prefix, {
line = line,
col = col,
bufnr = vim.api.nvim_get_current_buf(),
})
if #suggestions > 0 then
state.suggestions = suggestions
state.current_index = 1
display_suggestion(suggestions[1])
else
M.dismiss()
end
state.timer = nil
end, debounce_time)
end
--- Setup keymaps
local function setup_keymaps()
-- Accept with Tab (only when suggestion visible)
vim.keymap.set("i", config.keymap.accept, function()
if M.is_visible() then
M.accept()
return ""
end
-- Fallback to normal Tab behavior
return vim.api.nvim_replace_termcodes("<Tab>", true, false, true)
end, { expr = true, silent = true, desc = "Accept codetyper suggestion" })
-- Next suggestion
vim.keymap.set("i", config.keymap.next, function()
M.next()
end, { silent = true, desc = "Next codetyper suggestion" })
-- Previous suggestion
vim.keymap.set("i", config.keymap.prev, function()
M.prev()
end, { silent = true, desc = "Previous codetyper suggestion" })
-- Dismiss
vim.keymap.set("i", config.keymap.dismiss, function()
M.dismiss()
end, { silent = true, desc = "Dismiss codetyper suggestion" })
end
--- Setup autocmds for auto-trigger
local function setup_autocmds()
local group = vim.api.nvim_create_augroup("CodetypeSuggestion", { clear = true })
-- Trigger on text change in insert mode
if config.auto_trigger then
vim.api.nvim_create_autocmd("TextChangedI", {
group = group,
callback = function()
M.trigger()
end,
})
end
-- Dismiss on leaving insert mode
vim.api.nvim_create_autocmd("InsertLeave", {
group = group,
callback = function()
M.dismiss()
end,
})
-- Dismiss on cursor move (not from typing)
vim.api.nvim_create_autocmd("CursorMovedI", {
group = group,
callback = function()
-- Only dismiss if cursor moved significantly
if state.line ~= nil then
local cursor = vim.api.nvim_win_get_cursor(0)
if cursor[1] - 1 ~= state.line then
M.dismiss()
end
end
end,
})
end
--- Setup highlight group
local function setup_highlights()
-- Use Comment highlight or define custom ghost text style
vim.api.nvim_set_hl(0, hl_group, { link = "Comment" })
end
--- Setup the suggestion system
---@param opts? table Configuration options
function M.setup(opts)
if opts then
config = vim.tbl_deep_extend("force", config, opts)
end
setup_highlights()
setup_keymaps()
setup_autocmds()
end
--- Enable suggestions
function M.enable()
config.enabled = true
end
--- Disable suggestions
function M.disable()
config.enabled = false
M.dismiss()
end
--- Toggle suggestions
function M.toggle()
if config.enabled then
M.disable()
else
M.enable()
end
end
return M

View File

@@ -0,0 +1,585 @@
---@mod codetyper.indexer.analyzer Code analyzer using Tree-sitter
---@brief [[
--- Analyzes source files to extract functions, classes, exports, and imports.
--- Uses Tree-sitter when available, falls back to pattern matching.
---@brief ]]
local M = {}
local utils = require("codetyper.support.utils")
local scanner = require("codetyper.features.indexer.scanner")
--- Language-specific query patterns for Tree-sitter
local TS_QUERIES = {
lua = {
functions = [[
(function_declaration name: (identifier) @name) @func
(function_definition) @func
(local_function name: (identifier) @name) @func
(assignment_statement
(variable_list name: (identifier) @name)
(expression_list value: (function_definition) @func))
]],
exports = [[
(return_statement (expression_list (table_constructor))) @export
]],
},
typescript = {
functions = [[
(function_declaration name: (identifier) @name) @func
(method_definition name: (property_identifier) @name) @func
(arrow_function) @func
(lexical_declaration
(variable_declarator name: (identifier) @name value: (arrow_function) @func))
]],
exports = [[
(export_statement) @export
]],
imports = [[
(import_statement) @import
]],
},
javascript = {
functions = [[
(function_declaration name: (identifier) @name) @func
(method_definition name: (property_identifier) @name) @func
(arrow_function) @func
]],
exports = [[
(export_statement) @export
]],
imports = [[
(import_statement) @import
]],
},
python = {
functions = [[
(function_definition name: (identifier) @name) @func
]],
classes = [[
(class_definition name: (identifier) @name) @class
]],
imports = [[
(import_statement) @import
(import_from_statement) @import
]],
},
go = {
functions = [[
(function_declaration name: (identifier) @name) @func
(method_declaration name: (field_identifier) @name) @func
]],
imports = [[
(import_declaration) @import
]],
},
rust = {
functions = [[
(function_item name: (identifier) @name) @func
]],
imports = [[
(use_declaration) @import
]],
},
}
-- Forward declaration for analyze_tree_generic (defined below)
local analyze_tree_generic
--- Hash file content for change detection
---@param content string
---@return string
local function hash_content(content)
local hash = 0
for i = 1, math.min(#content, 10000) do
hash = (hash * 31 + string.byte(content, i)) % 2147483647
end
return string.format("%08x", hash)
end
--- Try to get Tree-sitter parser for a language
---@param lang string
---@return boolean
local function has_ts_parser(lang)
local ok = pcall(vim.treesitter.language.inspect, lang)
return ok
end
--- Analyze file using Tree-sitter
---@param filepath string
---@param lang string
---@param content string
---@return table|nil
local function analyze_with_treesitter(filepath, lang, content)
if not has_ts_parser(lang) then
return nil
end
local result = {
functions = {},
classes = {},
exports = {},
imports = {},
}
-- Create a temporary buffer for parsing
local bufnr = vim.api.nvim_create_buf(false, true)
vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, vim.split(content, "\n"))
local ok, parser = pcall(vim.treesitter.get_parser, bufnr, lang)
if not ok or not parser then
vim.api.nvim_buf_delete(bufnr, { force = true })
return nil
end
local tree = parser:parse()[1]
if not tree then
vim.api.nvim_buf_delete(bufnr, { force = true })
return nil
end
local root = tree:root()
local queries = TS_QUERIES[lang]
if not queries then
-- Fallback: walk tree manually for common patterns
result = analyze_tree_generic(root, bufnr)
else
-- Use language-specific queries
if queries.functions then
local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.functions)
if query_ok then
for id, node in query:iter_captures(root, bufnr, 0, -1) do
local capture_name = query.captures[id]
if capture_name == "func" or capture_name == "name" then
local start_row, _, end_row, _ = node:range()
local name = nil
-- Try to get name from sibling capture or child
if capture_name == "func" then
local name_node = node:field("name")[1]
if name_node then
name = vim.treesitter.get_node_text(name_node, bufnr)
end
else
name = vim.treesitter.get_node_text(node, bufnr)
end
if name and not vim.tbl_contains(vim.tbl_map(function(f)
return f.name
end, result.functions), name) then
table.insert(result.functions, {
name = name,
line = start_row + 1,
end_line = end_row + 1,
params = {},
})
end
end
end
end
end
if queries.classes then
local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.classes)
if query_ok then
for id, node in query:iter_captures(root, bufnr, 0, -1) do
local capture_name = query.captures[id]
if capture_name == "class" then
local start_row, _, end_row, _ = node:range()
local name_node = node:field("name")[1]
local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous"
table.insert(result.classes, {
name = name,
line = start_row + 1,
end_line = end_row + 1,
methods = {},
})
end
end
end
end
if queries.exports then
local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.exports)
if query_ok then
for _, node in query:iter_captures(root, bufnr, 0, -1) do
local text = vim.treesitter.get_node_text(node, bufnr)
local start_row, _, _, _ = node:range()
-- Extract export names (simplified)
local names = {}
for name in text:gmatch("export%s+[%w_]+%s+([%w_]+)") do
table.insert(names, name)
end
for name in text:gmatch("export%s*{([^}]+)}") do
for n in name:gmatch("([%w_]+)") do
table.insert(names, n)
end
end
for _, name in ipairs(names) do
table.insert(result.exports, {
name = name,
type = "unknown",
line = start_row + 1,
})
end
end
end
end
if queries.imports then
local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.imports)
if query_ok then
for _, node in query:iter_captures(root, bufnr, 0, -1) do
local text = vim.treesitter.get_node_text(node, bufnr)
local start_row, _, _, _ = node:range()
-- Extract import source
local source = text:match('["\']([^"\']+)["\']')
if source then
table.insert(result.imports, {
source = source,
names = {},
line = start_row + 1,
})
end
end
end
end
end
vim.api.nvim_buf_delete(bufnr, { force = true })
return result
end
--- Generic tree analysis for unsupported languages
---@param root TSNode
---@param bufnr number
---@return table
analyze_tree_generic = function(root, bufnr)
local result = {
functions = {},
classes = {},
exports = {},
imports = {},
}
local function visit(node)
local node_type = node:type()
-- Common function patterns
if
node_type:match("function")
or node_type:match("method")
or node_type == "arrow_function"
or node_type == "func_literal"
then
local start_row, _, end_row, _ = node:range()
local name_node = node:field("name")[1]
local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous"
table.insert(result.functions, {
name = name,
line = start_row + 1,
end_line = end_row + 1,
params = {},
})
end
-- Common class patterns
if node_type:match("class") or node_type == "struct_item" or node_type == "impl_item" then
local start_row, _, end_row, _ = node:range()
local name_node = node:field("name")[1]
local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous"
table.insert(result.classes, {
name = name,
line = start_row + 1,
end_line = end_row + 1,
methods = {},
})
end
-- Recurse into children
for child in node:iter_children() do
visit(child)
end
end
visit(root)
return result
end
--- Analyze file using pattern matching (fallback)
---@param content string
---@param lang string
---@return table
local function analyze_with_patterns(content, lang)
local result = {
functions = {},
classes = {},
exports = {},
imports = {},
}
local lines = vim.split(content, "\n")
-- Language-specific patterns
local patterns = {
lua = {
func_start = "^%s*local?%s*function%s+([%w_%.]+)",
func_assign = "^%s*([%w_%.]+)%s*=%s*function",
module_return = "^return%s+M",
},
javascript = {
func_start = "^%s*function%s+([%w_]+)",
func_arrow = "^%s*const%s+([%w_]+)%s*=%s*",
class_start = "^%s*class%s+([%w_]+)",
export_line = "^%s*export%s+",
import_line = "^%s*import%s+",
},
typescript = {
func_start = "^%s*function%s+([%w_]+)",
func_arrow = "^%s*const%s+([%w_]+)%s*=%s*",
class_start = "^%s*class%s+([%w_]+)",
export_line = "^%s*export%s+",
import_line = "^%s*import%s+",
},
python = {
func_start = "^%s*def%s+([%w_]+)",
class_start = "^%s*class%s+([%w_]+)",
import_line = "^%s*import%s+",
from_import = "^%s*from%s+",
},
go = {
func_start = "^func%s+([%w_]+)",
method_start = "^func%s+%([^%)]+%)%s+([%w_]+)",
import_line = "^import%s+",
},
rust = {
func_start = "^%s*pub?%s*fn%s+([%w_]+)",
struct_start = "^%s*pub?%s*struct%s+([%w_]+)",
impl_start = "^%s*impl%s+([%w_<>]+)",
use_line = "^%s*use%s+",
},
}
local lang_patterns = patterns[lang] or patterns.javascript
for i, line in ipairs(lines) do
-- Functions
if lang_patterns.func_start then
local name = line:match(lang_patterns.func_start)
if name then
table.insert(result.functions, {
name = name,
line = i,
end_line = i,
params = {},
})
end
end
if lang_patterns.func_arrow then
local name = line:match(lang_patterns.func_arrow)
if name and line:match("=>") then
table.insert(result.functions, {
name = name,
line = i,
end_line = i,
params = {},
})
end
end
if lang_patterns.func_assign then
local name = line:match(lang_patterns.func_assign)
if name then
table.insert(result.functions, {
name = name,
line = i,
end_line = i,
params = {},
})
end
end
if lang_patterns.method_start then
local name = line:match(lang_patterns.method_start)
if name then
table.insert(result.functions, {
name = name,
line = i,
end_line = i,
params = {},
})
end
end
-- Classes
if lang_patterns.class_start then
local name = line:match(lang_patterns.class_start)
if name then
table.insert(result.classes, {
name = name,
line = i,
end_line = i,
methods = {},
})
end
end
if lang_patterns.struct_start then
local name = line:match(lang_patterns.struct_start)
if name then
table.insert(result.classes, {
name = name,
line = i,
end_line = i,
methods = {},
})
end
end
-- Exports
if lang_patterns.export_line and line:match(lang_patterns.export_line) then
local name = line:match("export%s+[%w_]+%s+([%w_]+)")
or line:match("export%s+default%s+([%w_]+)")
or line:match("export%s+{%s*([%w_]+)")
if name then
table.insert(result.exports, {
name = name,
type = "unknown",
line = i,
})
end
end
-- Imports
if lang_patterns.import_line and line:match(lang_patterns.import_line) then
local source = line:match('["\']([^"\']+)["\']')
if source then
table.insert(result.imports, {
source = source,
names = {},
line = i,
})
end
end
if lang_patterns.from_import and line:match(lang_patterns.from_import) then
local source = line:match("from%s+([%w_%.]+)")
if source then
table.insert(result.imports, {
source = source,
names = {},
line = i,
})
end
end
if lang_patterns.use_line and line:match(lang_patterns.use_line) then
local source = line:match("use%s+([%w_:]+)")
if source then
table.insert(result.imports, {
source = source,
names = {},
line = i,
})
end
end
end
-- For Lua, infer exports from module table
if lang == "lua" then
for _, func in ipairs(result.functions) do
if func.name:match("^M%.") then
local name = func.name:gsub("^M%.", "")
table.insert(result.exports, {
name = name,
type = "function",
line = func.line,
})
end
end
end
return result
end
--- Analyze a single file
---@param filepath string Full path to file
---@return FileIndex|nil
function M.analyze_file(filepath)
local content = utils.read_file(filepath)
if not content then
return nil
end
local lang = scanner.get_language(filepath)
-- Map to Tree-sitter language names
local ts_lang_map = {
typescript = "typescript",
typescriptreact = "tsx",
javascript = "javascript",
javascriptreact = "javascript",
python = "python",
go = "go",
rust = "rust",
lua = "lua",
}
local ts_lang = ts_lang_map[lang] or lang
-- Try Tree-sitter first
local analysis = analyze_with_treesitter(filepath, ts_lang, content)
-- Fallback to pattern matching
if not analysis then
analysis = analyze_with_patterns(content, lang)
end
return {
path = filepath,
language = lang,
hash = hash_content(content),
exports = analysis.exports,
imports = analysis.imports,
functions = analysis.functions,
classes = analysis.classes,
last_indexed = os.time(),
}
end
--- Extract exports from a buffer
---@param bufnr number
---@return Export[]
function M.extract_exports(bufnr)
local filepath = vim.api.nvim_buf_get_name(bufnr)
local analysis = M.analyze_file(filepath)
return analysis and analysis.exports or {}
end
--- Extract functions from a buffer
---@param bufnr number
---@return FunctionInfo[]
function M.extract_functions(bufnr)
local filepath = vim.api.nvim_buf_get_name(bufnr)
local analysis = M.analyze_file(filepath)
return analysis and analysis.functions or {}
end
--- Extract imports from a buffer
---@param bufnr number
---@return Import[]
function M.extract_imports(bufnr)
local filepath = vim.api.nvim_buf_get_name(bufnr)
local analysis = M.analyze_file(filepath)
return analysis and analysis.imports or {}
end
return M

View File

@@ -0,0 +1,604 @@
---@mod codetyper.indexer Project indexer for Codetyper.nvim
---@brief [[
--- Indexes project structure, dependencies, and code symbols.
--- Stores knowledge in .coder/ directory for enriching LLM context.
---@brief ]]
local M = {}
local utils = require("codetyper.support.utils")
--- Index schema version for migrations
local INDEX_VERSION = 1
--- Index file name
local INDEX_FILE = "index.json"
--- Debounce timer for file indexing
local index_timer = nil
local INDEX_DEBOUNCE_MS = 500
--- Default indexer configuration
local default_config = {
enabled = true,
auto_index = true,
index_on_open = false,
max_file_size = 100000,
excluded_dirs = { "node_modules", "dist", "build", ".git", ".coder", "__pycache__", "vendor", "target" },
index_extensions = { "lua", "ts", "tsx", "js", "jsx", "py", "go", "rs", "rb", "java", "c", "cpp", "h", "hpp" },
memory = {
enabled = true,
max_memories = 1000,
prune_threshold = 0.1,
},
}
--- Current configuration
---@type table
local config = vim.deepcopy(default_config)
--- Cached project index
---@type table<string, ProjectIndex>
local index_cache = {}
---@class ProjectIndex
---@field version number Index schema version
---@field project_root string Absolute path to project
---@field project_name string Project name
---@field project_type string "node"|"rust"|"go"|"python"|"lua"|"unknown"
---@field dependencies table<string, string> name -> version
---@field dev_dependencies table<string, string> name -> version
---@field files table<string, FileIndex> path -> FileIndex
---@field symbols table<string, string[]> symbol -> [file paths]
---@field last_indexed number Timestamp
---@field stats {files: number, functions: number, classes: number, exports: number}
---@class FileIndex
---@field path string Relative path from project root
---@field language string Detected language
---@field hash string Content hash for change detection
---@field exports Export[] Exported symbols
---@field imports Import[] Dependencies
---@field functions FunctionInfo[]
---@field classes ClassInfo[]
---@field last_indexed number Timestamp
---@class Export
---@field name string Symbol name
---@field type string "function"|"class"|"constant"|"type"|"variable"
---@field line number Line number
---@class Import
---@field source string Import source/module
---@field names string[] Imported names
---@field line number Line number
---@class FunctionInfo
---@field name string Function name
---@field params string[] Parameter names
---@field line number Start line
---@field end_line number End line
---@field docstring string|nil Documentation
---@class ClassInfo
---@field name string Class name
---@field methods string[] Method names
---@field line number Start line
---@field end_line number End line
---@field docstring string|nil Documentation
--- Get the index file path
---@return string|nil
local function get_index_path()
local root = utils.get_project_root()
if not root then
return nil
end
return root .. "/.coder/" .. INDEX_FILE
end
--- Create empty index structure
---@return ProjectIndex
local function create_empty_index()
local root = utils.get_project_root()
return {
version = INDEX_VERSION,
project_root = root or "",
project_name = root and vim.fn.fnamemodify(root, ":t") or "",
project_type = "unknown",
dependencies = {},
dev_dependencies = {},
files = {},
symbols = {},
last_indexed = os.time(),
stats = {
files = 0,
functions = 0,
classes = 0,
exports = 0,
},
}
end
--- Load index from disk
---@return ProjectIndex|nil
function M.load_index()
local root = utils.get_project_root()
if not root then
return nil
end
-- Check cache first
if index_cache[root] then
return index_cache[root]
end
local path = get_index_path()
if not path then
return nil
end
local content = utils.read_file(path)
if not content then
return nil
end
local ok, index = pcall(vim.json.decode, content)
if not ok or not index then
return nil
end
-- Validate version
if index.version ~= INDEX_VERSION then
-- Index needs migration or rebuild
return nil
end
-- Cache it
index_cache[root] = index
return index
end
--- Save index to disk
---@param index ProjectIndex
---@return boolean
function M.save_index(index)
local root = utils.get_project_root()
if not root then
return false
end
-- Ensure .coder directory exists
local coder_dir = root .. "/.coder"
utils.ensure_dir(coder_dir)
local path = get_index_path()
if not path then
return false
end
local ok, encoded = pcall(vim.json.encode, index)
if not ok then
return false
end
local success = utils.write_file(path, encoded)
if success then
-- Update cache
index_cache[root] = index
end
return success
end
--- Index the entire project
---@param callback? fun(index: ProjectIndex)
---@return ProjectIndex|nil
function M.index_project(callback)
local scanner = require("codetyper.features.indexer.scanner")
local analyzer = require("codetyper.features.indexer.analyzer")
local index = create_empty_index()
local root = utils.get_project_root()
if not root then
if callback then
callback(index)
end
return index
end
-- Detect project type and parse dependencies
index.project_type = scanner.detect_project_type(root)
local deps = scanner.parse_dependencies(root, index.project_type)
index.dependencies = deps.dependencies or {}
index.dev_dependencies = deps.dev_dependencies or {}
-- Get all indexable files
local files = scanner.get_indexable_files(root, config)
-- Index each file
local total_functions = 0
local total_classes = 0
local total_exports = 0
for _, filepath in ipairs(files) do
local relative_path = filepath:gsub("^" .. vim.pesc(root) .. "/", "")
local file_index = analyzer.analyze_file(filepath)
if file_index then
file_index.path = relative_path
index.files[relative_path] = file_index
-- Update symbol index
for _, exp in ipairs(file_index.exports or {}) do
if not index.symbols[exp.name] then
index.symbols[exp.name] = {}
end
table.insert(index.symbols[exp.name], relative_path)
total_exports = total_exports + 1
end
total_functions = total_functions + #(file_index.functions or {})
total_classes = total_classes + #(file_index.classes or {})
end
end
-- Update stats
index.stats = {
files = #files,
functions = total_functions,
classes = total_classes,
exports = total_exports,
}
index.last_indexed = os.time()
-- Save to disk
M.save_index(index)
-- Store memories
local memory = require("codetyper.features.indexer.memory")
memory.store_index_summary(index)
-- Sync project summary to brain
M.sync_project_to_brain(index, files, root)
if callback then
callback(index)
end
return index
end
--- Sync project index to brain
---@param index ProjectIndex
---@param files string[] List of file paths
---@param root string Project root
function M.sync_project_to_brain(index, files, root)
local ok_brain, brain = pcall(require, "codetyper.brain")
if not ok_brain or not brain.is_initialized or not brain.is_initialized() then
return
end
-- Store project-level pattern
brain.learn({
type = "pattern",
file = root,
content = {
summary = "Project: "
.. index.project_name
.. " ("
.. index.project_type
.. ") - "
.. index.stats.files
.. " files",
detail = string.format(
"%d functions, %d classes, %d exports",
index.stats.functions,
index.stats.classes,
index.stats.exports
),
},
context = {
file = root,
project_type = index.project_type,
dependencies = index.dependencies,
},
})
-- Store key file patterns (files with most functions/classes)
local key_files = {}
for path, file_index in pairs(index.files) do
local score = #(file_index.functions or {}) + (#(file_index.classes or {}) * 2)
if score >= 3 then
table.insert(key_files, { path = path, index = file_index, score = score })
end
end
table.sort(key_files, function(a, b)
return a.score > b.score
end)
-- Store top 20 key files in brain
for i, kf in ipairs(key_files) do
if i > 20 then
break
end
M.sync_to_brain(root .. "/" .. kf.path, kf.index)
end
end
--- Index a single file (incremental update)
---@param filepath string
---@return FileIndex|nil
function M.index_file(filepath)
local analyzer = require("codetyper.features.indexer.analyzer")
local memory = require("codetyper.features.indexer.memory")
local root = utils.get_project_root()
if not root then
return nil
end
-- Load existing index
local index = M.load_index() or create_empty_index()
-- Analyze file
local file_index = analyzer.analyze_file(filepath)
if not file_index then
return nil
end
local relative_path = filepath:gsub("^" .. vim.pesc(root) .. "/", "")
file_index.path = relative_path
-- Remove old symbol references for this file
for symbol, paths in pairs(index.symbols) do
for i = #paths, 1, -1 do
if paths[i] == relative_path then
table.remove(paths, i)
end
end
if #paths == 0 then
index.symbols[symbol] = nil
end
end
-- Add new file index
index.files[relative_path] = file_index
-- Update symbol index
for _, exp in ipairs(file_index.exports or {}) do
if not index.symbols[exp.name] then
index.symbols[exp.name] = {}
end
table.insert(index.symbols[exp.name], relative_path)
end
-- Recalculate stats
local total_functions = 0
local total_classes = 0
local total_exports = 0
local file_count = 0
for _, f in pairs(index.files) do
file_count = file_count + 1
total_functions = total_functions + #(f.functions or {})
total_classes = total_classes + #(f.classes or {})
total_exports = total_exports + #(f.exports or {})
end
index.stats = {
files = file_count,
functions = total_functions,
classes = total_classes,
exports = total_exports,
}
index.last_indexed = os.time()
-- Save to disk
M.save_index(index)
-- Store file memory
memory.store_file_memory(relative_path, file_index)
-- Sync to brain if available
M.sync_to_brain(filepath, file_index)
return file_index
end
--- Sync file analysis to brain system
---@param filepath string Full file path
---@param file_index FileIndex File analysis
function M.sync_to_brain(filepath, file_index)
local ok_brain, brain = pcall(require, "codetyper.brain")
if not ok_brain or not brain.is_initialized or not brain.is_initialized() then
return
end
-- Only store if file has meaningful content
local funcs = file_index.functions or {}
local classes = file_index.classes or {}
if #funcs == 0 and #classes == 0 then
return
end
-- Build summary
local parts = {}
if #funcs > 0 then
local func_names = {}
for i, f in ipairs(funcs) do
if i <= 5 then
table.insert(func_names, f.name)
end
end
table.insert(parts, "functions: " .. table.concat(func_names, ", "))
if #funcs > 5 then
table.insert(parts, "(+" .. (#funcs - 5) .. " more)")
end
end
if #classes > 0 then
local class_names = {}
for _, c in ipairs(classes) do
table.insert(class_names, c.name)
end
table.insert(parts, "classes: " .. table.concat(class_names, ", "))
end
local filename = vim.fn.fnamemodify(filepath, ":t")
local summary = filename .. " - " .. table.concat(parts, "; ")
-- Learn this pattern in brain
brain.learn({
type = "pattern",
file = filepath,
content = {
summary = summary,
detail = #funcs .. " functions, " .. #classes .. " classes",
},
context = {
file = file_index.path or filepath,
language = file_index.language,
functions = funcs,
classes = classes,
exports = file_index.exports,
imports = file_index.imports,
},
})
end
--- Schedule file indexing with debounce
---@param filepath string
function M.schedule_index_file(filepath)
if not config.enabled or not config.auto_index then
return
end
-- Check if file should be indexed
local scanner = require("codetyper.features.indexer.scanner")
if not scanner.should_index(filepath, config) then
return
end
-- Cancel existing timer
if index_timer then
index_timer:stop()
end
-- Schedule new index
index_timer = vim.defer_fn(function()
M.index_file(filepath)
index_timer = nil
end, INDEX_DEBOUNCE_MS)
end
--- Get relevant context for a prompt
---@param opts {file: string, intent: table|nil, prompt: string, scope: string|nil}
---@return table Context information
function M.get_context_for(opts)
local memory = require("codetyper.features.indexer.memory")
local index = M.load_index()
local context = {
project_type = "unknown",
dependencies = {},
relevant_files = {},
relevant_symbols = {},
patterns = {},
}
if not index then
return context
end
context.project_type = index.project_type
context.dependencies = index.dependencies
-- Find relevant symbols from prompt
local words = {}
for word in opts.prompt:gmatch("%w+") do
if #word > 2 then
words[word:lower()] = true
end
end
-- Match symbols
for symbol, files in pairs(index.symbols) do
if words[symbol:lower()] then
context.relevant_symbols[symbol] = files
end
end
-- Get file context if available
if opts.file then
local root = utils.get_project_root()
if root then
local relative_path = opts.file:gsub("^" .. vim.pesc(root) .. "/", "")
local file_index = index.files[relative_path]
if file_index then
context.current_file = file_index
end
end
end
-- Get relevant memories
context.patterns = memory.get_relevant(opts.prompt, 5)
return context
end
--- Get index status
---@return table Status information
function M.get_status()
local index = M.load_index()
if not index then
return {
indexed = false,
stats = nil,
last_indexed = nil,
}
end
return {
indexed = true,
stats = index.stats,
last_indexed = index.last_indexed,
project_type = index.project_type,
}
end
--- Clear the project index
function M.clear()
local root = utils.get_project_root()
if root then
index_cache[root] = nil
end
local path = get_index_path()
if path and utils.file_exists(path) then
os.remove(path)
end
end
--- Setup the indexer with configuration
---@param opts? table Configuration options
function M.setup(opts)
if opts then
config = vim.tbl_deep_extend("force", config, opts)
end
-- Index on startup if configured
if config.index_on_open then
vim.defer_fn(function()
M.index_project()
end, 1000)
end
end
--- Get current configuration
---@return table
function M.get_config()
return vim.deepcopy(config)
end
return M

View File

@@ -0,0 +1,539 @@
---@mod codetyper.indexer.memory Memory persistence manager
---@brief [[
--- Stores and retrieves learned patterns and memories in .coder/memories/.
--- Supports session history for learning from interactions.
---@brief ]]
local M = {}
local utils = require("codetyper.support.utils")
--- Memory directories
local MEMORIES_DIR = "memories"
local SESSIONS_DIR = "sessions"
local FILES_DIR = "files"
--- Memory files
local PATTERNS_FILE = "patterns.json"
local CONVENTIONS_FILE = "conventions.json"
local SYMBOLS_FILE = "symbols.json"
--- In-memory cache
local cache = {
patterns = nil,
conventions = nil,
symbols = nil,
}
---@class Memory
---@field id string Unique identifier
---@field type "pattern"|"convention"|"session"|"interaction"
---@field content string The learned information
---@field context table Where/when learned
---@field weight number Importance score (0.0-1.0)
---@field created_at number Timestamp
---@field updated_at number Last update timestamp
---@field used_count number Times referenced
--- Get the memories base directory
---@return string|nil
local function get_memories_dir()
local root = utils.get_project_root()
if not root then
return nil
end
return root .. "/.coder/" .. MEMORIES_DIR
end
--- Get the sessions directory
---@return string|nil
local function get_sessions_dir()
local root = utils.get_project_root()
if not root then
return nil
end
return root .. "/.coder/" .. SESSIONS_DIR
end
--- Ensure memories directory exists
---@return boolean
local function ensure_memories_dir()
local dir = get_memories_dir()
if not dir then
return false
end
utils.ensure_dir(dir)
utils.ensure_dir(dir .. "/" .. FILES_DIR)
return true
end
--- Ensure sessions directory exists
---@return boolean
local function ensure_sessions_dir()
local dir = get_sessions_dir()
if not dir then
return false
end
return utils.ensure_dir(dir)
end
--- Generate a unique ID
---@return string
local function generate_id()
return string.format("mem_%d_%s", os.time(), string.sub(tostring(math.random()), 3, 8))
end
--- Load a memory file
---@param filename string
---@return table
local function load_memory_file(filename)
local dir = get_memories_dir()
if not dir then
return {}
end
local path = dir .. "/" .. filename
local content = utils.read_file(path)
if not content then
return {}
end
local ok, data = pcall(vim.json.decode, content)
if not ok or not data then
return {}
end
return data
end
--- Save a memory file
---@param filename string
---@param data table
---@return boolean
local function save_memory_file(filename, data)
if not ensure_memories_dir() then
return false
end
local dir = get_memories_dir()
if not dir then
return false
end
local path = dir .. "/" .. filename
local ok, encoded = pcall(vim.json.encode, data)
if not ok then
return false
end
return utils.write_file(path, encoded)
end
--- Hash a file path for storage
---@param filepath string
---@return string
local function hash_path(filepath)
local hash = 0
for i = 1, #filepath do
hash = (hash * 31 + string.byte(filepath, i)) % 2147483647
end
return string.format("%08x", hash)
end
--- Load patterns from cache or disk
---@return table
function M.load_patterns()
if cache.patterns then
return cache.patterns
end
cache.patterns = load_memory_file(PATTERNS_FILE)
return cache.patterns
end
--- Load conventions from cache or disk
---@return table
function M.load_conventions()
if cache.conventions then
return cache.conventions
end
cache.conventions = load_memory_file(CONVENTIONS_FILE)
return cache.conventions
end
--- Load symbols from cache or disk
---@return table
function M.load_symbols()
if cache.symbols then
return cache.symbols
end
cache.symbols = load_memory_file(SYMBOLS_FILE)
return cache.symbols
end
--- Store a new memory
---@param memory Memory
---@return boolean
function M.store_memory(memory)
memory.id = memory.id or generate_id()
memory.created_at = memory.created_at or os.time()
memory.updated_at = os.time()
memory.used_count = memory.used_count or 0
memory.weight = memory.weight or 0.5
local filename
if memory.type == "pattern" then
filename = PATTERNS_FILE
cache.patterns = nil
elseif memory.type == "convention" then
filename = CONVENTIONS_FILE
cache.conventions = nil
else
filename = PATTERNS_FILE
cache.patterns = nil
end
local data = load_memory_file(filename)
data[memory.id] = memory
return save_memory_file(filename, data)
end
--- Store file-specific memory
---@param relative_path string Relative file path
---@param file_index table FileIndex data
---@return boolean
function M.store_file_memory(relative_path, file_index)
if not ensure_memories_dir() then
return false
end
local dir = get_memories_dir()
if not dir then
return false
end
local hash = hash_path(relative_path)
local path = dir .. "/" .. FILES_DIR .. "/" .. hash .. ".json"
local data = {
path = relative_path,
indexed_at = os.time(),
functions = file_index.functions or {},
classes = file_index.classes or {},
exports = file_index.exports or {},
imports = file_index.imports or {},
}
local ok, encoded = pcall(vim.json.encode, data)
if not ok then
return false
end
return utils.write_file(path, encoded)
end
--- Load file-specific memory
---@param relative_path string
---@return table|nil
function M.load_file_memory(relative_path)
local dir = get_memories_dir()
if not dir then
return nil
end
local hash = hash_path(relative_path)
local path = dir .. "/" .. FILES_DIR .. "/" .. hash .. ".json"
local content = utils.read_file(path)
if not content then
return nil
end
local ok, data = pcall(vim.json.decode, content)
if not ok then
return nil
end
return data
end
--- Store index summary as memories
---@param index ProjectIndex
function M.store_index_summary(index)
-- Store project type convention
if index.project_type and index.project_type ~= "unknown" then
M.store_memory({
type = "convention",
content = "Project uses " .. index.project_type .. " ecosystem",
context = {
project_root = index.project_root,
detected_at = os.time(),
},
weight = 0.9,
})
end
-- Store dependency patterns
local dep_count = 0
for _ in pairs(index.dependencies or {}) do
dep_count = dep_count + 1
end
if dep_count > 0 then
local deps_list = {}
for name, _ in pairs(index.dependencies) do
table.insert(deps_list, name)
end
M.store_memory({
type = "pattern",
content = "Project dependencies: " .. table.concat(deps_list, ", "),
context = {
dependency_count = dep_count,
},
weight = 0.7,
})
end
-- Update symbol cache
cache.symbols = nil
save_memory_file(SYMBOLS_FILE, index.symbols or {})
end
--- Store session interaction
---@param interaction {prompt: string, response: string, file: string|nil, success: boolean}
function M.store_session(interaction)
if not ensure_sessions_dir() then
return
end
local dir = get_sessions_dir()
if not dir then
return
end
-- Use date-based session files
local date = os.date("%Y-%m-%d")
local path = dir .. "/" .. date .. ".json"
local sessions = {}
local content = utils.read_file(path)
if content then
local ok, data = pcall(vim.json.decode, content)
if ok and data then
sessions = data
end
end
table.insert(sessions, {
timestamp = os.time(),
prompt = interaction.prompt,
response = string.sub(interaction.response or "", 1, 500), -- Truncate
file = interaction.file,
success = interaction.success,
})
-- Limit session size
if #sessions > 100 then
sessions = { unpack(sessions, #sessions - 99) }
end
local ok, encoded = pcall(vim.json.encode, sessions)
if ok then
utils.write_file(path, encoded)
end
end
--- Get relevant memories for a query
---@param query string Search query
---@param limit number Maximum results
---@return Memory[]
function M.get_relevant(query, limit)
limit = limit or 10
local results = {}
-- Tokenize query
local query_words = {}
for word in query:lower():gmatch("%w+") do
if #word > 2 then
query_words[word] = true
end
end
-- Search patterns
local patterns = M.load_patterns()
for _, memory in pairs(patterns) do
local score = 0
local content_lower = (memory.content or ""):lower()
for word in pairs(query_words) do
if content_lower:find(word, 1, true) then
score = score + 1
end
end
if score > 0 then
memory.relevance_score = score * (memory.weight or 0.5)
table.insert(results, memory)
end
end
-- Search conventions
local conventions = M.load_conventions()
for _, memory in pairs(conventions) do
local score = 0
local content_lower = (memory.content or ""):lower()
for word in pairs(query_words) do
if content_lower:find(word, 1, true) then
score = score + 1
end
end
if score > 0 then
memory.relevance_score = score * (memory.weight or 0.5)
table.insert(results, memory)
end
end
-- Sort by relevance
table.sort(results, function(a, b)
return (a.relevance_score or 0) > (b.relevance_score or 0)
end)
-- Limit results
local limited = {}
for i = 1, math.min(limit, #results) do
limited[i] = results[i]
end
return limited
end
--- Update memory usage count
---@param memory_id string
function M.update_usage(memory_id)
local patterns = M.load_patterns()
if patterns[memory_id] then
patterns[memory_id].used_count = (patterns[memory_id].used_count or 0) + 1
patterns[memory_id].updated_at = os.time()
save_memory_file(PATTERNS_FILE, patterns)
cache.patterns = nil
return
end
local conventions = M.load_conventions()
if conventions[memory_id] then
conventions[memory_id].used_count = (conventions[memory_id].used_count or 0) + 1
conventions[memory_id].updated_at = os.time()
save_memory_file(CONVENTIONS_FILE, conventions)
cache.conventions = nil
end
end
--- Get all memories
---@return {patterns: table, conventions: table, symbols: table}
function M.get_all()
return {
patterns = M.load_patterns(),
conventions = M.load_conventions(),
symbols = M.load_symbols(),
}
end
--- Clear all memories
---@param pattern? string Optional pattern to match memory IDs
function M.clear(pattern)
if not pattern then
-- Clear all
cache = { patterns = nil, conventions = nil, symbols = nil }
save_memory_file(PATTERNS_FILE, {})
save_memory_file(CONVENTIONS_FILE, {})
save_memory_file(SYMBOLS_FILE, {})
return
end
-- Clear matching pattern
local patterns = M.load_patterns()
for id in pairs(patterns) do
if id:match(pattern) then
patterns[id] = nil
end
end
save_memory_file(PATTERNS_FILE, patterns)
cache.patterns = nil
local conventions = M.load_conventions()
for id in pairs(conventions) do
if id:match(pattern) then
conventions[id] = nil
end
end
save_memory_file(CONVENTIONS_FILE, conventions)
cache.conventions = nil
end
--- Prune low-weight memories
---@param threshold number Weight threshold (default: 0.1)
function M.prune(threshold)
threshold = threshold or 0.1
local patterns = M.load_patterns()
local pruned = 0
for id, memory in pairs(patterns) do
if (memory.weight or 0) < threshold and (memory.used_count or 0) == 0 then
patterns[id] = nil
pruned = pruned + 1
end
end
if pruned > 0 then
save_memory_file(PATTERNS_FILE, patterns)
cache.patterns = nil
end
local conventions = M.load_conventions()
for id, memory in pairs(conventions) do
if (memory.weight or 0) < threshold and (memory.used_count or 0) == 0 then
conventions[id] = nil
pruned = pruned + 1
end
end
if pruned > 0 then
save_memory_file(CONVENTIONS_FILE, conventions)
cache.conventions = nil
end
return pruned
end
--- Get memory statistics
---@return table
function M.get_stats()
local patterns = M.load_patterns()
local conventions = M.load_conventions()
local symbols = M.load_symbols()
local pattern_count = 0
for _ in pairs(patterns) do
pattern_count = pattern_count + 1
end
local convention_count = 0
for _ in pairs(conventions) do
convention_count = convention_count + 1
end
local symbol_count = 0
for _ in pairs(symbols) do
symbol_count = symbol_count + 1
end
return {
patterns = pattern_count,
conventions = convention_count,
symbols = symbol_count,
total = pattern_count + convention_count,
}
end
return M

View File

@@ -0,0 +1,409 @@
---@mod codetyper.indexer.scanner File scanner for project indexing
---@brief [[
--- Discovers indexable files, detects project type, and parses dependencies.
---@brief ]]
local M = {}
local utils = require("codetyper.support.utils")
--- Project type markers
local PROJECT_MARKERS = {
node = { "package.json" },
rust = { "Cargo.toml" },
go = { "go.mod" },
python = { "pyproject.toml", "setup.py", "requirements.txt" },
lua = { "init.lua", ".luarc.json" },
ruby = { "Gemfile" },
java = { "pom.xml", "build.gradle" },
csharp = { "*.csproj", "*.sln" },
}
--- File extension to language mapping
local EXTENSION_LANGUAGE = {
lua = "lua",
ts = "typescript",
tsx = "typescriptreact",
js = "javascript",
jsx = "javascriptreact",
py = "python",
go = "go",
rs = "rust",
rb = "ruby",
java = "java",
c = "c",
cpp = "cpp",
h = "c",
hpp = "cpp",
cs = "csharp",
}
--- Default ignore patterns
local DEFAULT_IGNORES = {
"^%.", -- Hidden files/folders
"^node_modules$",
"^__pycache__$",
"^%.git$",
"^%.coder$",
"^dist$",
"^build$",
"^target$",
"^vendor$",
"^%.next$",
"^%.nuxt$",
"^coverage$",
"%.min%.js$",
"%.min%.css$",
"%.map$",
"%.lock$",
"%-lock%.json$",
}
--- Detect project type from root markers
---@param root string Project root path
---@return string Project type
function M.detect_project_type(root)
for project_type, markers in pairs(PROJECT_MARKERS) do
for _, marker in ipairs(markers) do
local path = root .. "/" .. marker
if marker:match("^%*") then
-- Glob pattern
local pattern = marker:gsub("^%*", "")
local entries = vim.fn.glob(root .. "/*" .. pattern, false, true)
if #entries > 0 then
return project_type
end
else
if utils.file_exists(path) then
return project_type
end
end
end
end
return "unknown"
end
--- Parse project dependencies
---@param root string Project root path
---@param project_type string Project type
---@return {dependencies: table<string, string>, dev_dependencies: table<string, string>}
function M.parse_dependencies(root, project_type)
local deps = {
dependencies = {},
dev_dependencies = {},
}
if project_type == "node" then
deps = M.parse_package_json(root)
elseif project_type == "rust" then
deps = M.parse_cargo_toml(root)
elseif project_type == "go" then
deps = M.parse_go_mod(root)
elseif project_type == "python" then
deps = M.parse_python_deps(root)
end
return deps
end
--- Parse package.json for Node.js projects
---@param root string Project root path
---@return {dependencies: table, dev_dependencies: table}
function M.parse_package_json(root)
local path = root .. "/package.json"
local content = utils.read_file(path)
if not content then
return { dependencies = {}, dev_dependencies = {} }
end
local ok, pkg = pcall(vim.json.decode, content)
if not ok or not pkg then
return { dependencies = {}, dev_dependencies = {} }
end
return {
dependencies = pkg.dependencies or {},
dev_dependencies = pkg.devDependencies or {},
}
end
--- Parse Cargo.toml for Rust projects
---@param root string Project root path
---@return {dependencies: table, dev_dependencies: table}
function M.parse_cargo_toml(root)
local path = root .. "/Cargo.toml"
local content = utils.read_file(path)
if not content then
return { dependencies = {}, dev_dependencies = {} }
end
local deps = {}
local dev_deps = {}
local in_deps = false
local in_dev_deps = false
for line in content:gmatch("[^\n]+") do
if line:match("^%[dependencies%]") then
in_deps = true
in_dev_deps = false
elseif line:match("^%[dev%-dependencies%]") then
in_deps = false
in_dev_deps = true
elseif line:match("^%[") then
in_deps = false
in_dev_deps = false
elseif in_deps or in_dev_deps then
local name, version = line:match('^([%w_%-]+)%s*=%s*"([^"]+)"')
if not name then
name = line:match("^([%w_%-]+)%s*=")
version = "workspace"
end
if name then
if in_deps then
deps[name] = version or "unknown"
else
dev_deps[name] = version or "unknown"
end
end
end
end
return { dependencies = deps, dev_dependencies = dev_deps }
end
--- Parse go.mod for Go projects
---@param root string Project root path
---@return {dependencies: table, dev_dependencies: table}
function M.parse_go_mod(root)
local path = root .. "/go.mod"
local content = utils.read_file(path)
if not content then
return { dependencies = {}, dev_dependencies = {} }
end
local deps = {}
local in_require = false
for line in content:gmatch("[^\n]+") do
if line:match("^require%s*%(") then
in_require = true
elseif line:match("^%)") then
in_require = false
elseif in_require then
local module, version = line:match("^%s*([%w%.%-%_/]+)%s+([%w%.%-]+)")
if module then
deps[module] = version
end
else
local module, version = line:match("^require%s+([%w%.%-%_/]+)%s+([%w%.%-]+)")
if module then
deps[module] = version
end
end
end
return { dependencies = deps, dev_dependencies = {} }
end
--- Parse Python dependencies (pyproject.toml or requirements.txt)
---@param root string Project root path
---@return {dependencies: table, dev_dependencies: table}
function M.parse_python_deps(root)
local deps = {}
local dev_deps = {}
-- Try pyproject.toml first
local pyproject = root .. "/pyproject.toml"
local content = utils.read_file(pyproject)
if content then
-- Simple parsing for dependencies
local in_deps = false
local in_dev = false
for line in content:gmatch("[^\n]+") do
if line:match("^%[project%.dependencies%]") or line:match("^dependencies%s*=") then
in_deps = true
in_dev = false
elseif line:match("dev") and line:match("dependencies") then
in_deps = false
in_dev = true
elseif line:match("^%[") then
in_deps = false
in_dev = false
elseif in_deps or in_dev then
local name = line:match('"([%w_%-]+)')
if name then
if in_deps then
deps[name] = "latest"
else
dev_deps[name] = "latest"
end
end
end
end
end
-- Fallback to requirements.txt
local req_file = root .. "/requirements.txt"
content = utils.read_file(req_file)
if content then
for line in content:gmatch("[^\n]+") do
if not line:match("^#") and not line:match("^%s*$") then
local name, version = line:match("^([%w_%-]+)==([%d%.]+)")
if not name then
name = line:match("^([%w_%-]+)")
version = "latest"
end
if name then
deps[name] = version or "latest"
end
end
end
end
return { dependencies = deps, dev_dependencies = dev_deps }
end
--- Check if a file/directory should be ignored
---@param name string File or directory name
---@param config table Indexer configuration
---@return boolean
function M.should_ignore(name, config)
-- Check default patterns
for _, pattern in ipairs(DEFAULT_IGNORES) do
if name:match(pattern) then
return true
end
end
-- Check config excluded dirs
if config and config.excluded_dirs then
for _, dir in ipairs(config.excluded_dirs) do
if name == dir then
return true
end
end
end
return false
end
--- Check if a file should be indexed
---@param filepath string Full file path
---@param config table Indexer configuration
---@return boolean
function M.should_index(filepath, config)
local name = vim.fn.fnamemodify(filepath, ":t")
local ext = vim.fn.fnamemodify(filepath, ":e")
-- Check if it's a coder file
if utils.is_coder_file(filepath) then
return false
end
-- Check file size
if config and config.max_file_size then
local stat = vim.loop.fs_stat(filepath)
if stat and stat.size > config.max_file_size then
return false
end
end
-- Check extension
if config and config.index_extensions then
local valid_ext = false
for _, allowed_ext in ipairs(config.index_extensions) do
if ext == allowed_ext then
valid_ext = true
break
end
end
if not valid_ext then
return false
end
end
-- Check ignore patterns
if M.should_ignore(name, config) then
return false
end
return true
end
--- Get all indexable files in the project
---@param root string Project root path
---@param config table Indexer configuration
---@return string[] List of file paths
function M.get_indexable_files(root, config)
local files = {}
local function scan_dir(path)
local handle = vim.loop.fs_scandir(path)
if not handle then
return
end
while true do
local name, type = vim.loop.fs_scandir_next(handle)
if not name then
break
end
local full_path = path .. "/" .. name
if M.should_ignore(name, config) then
goto continue
end
if type == "directory" then
scan_dir(full_path)
elseif type == "file" then
if M.should_index(full_path, config) then
table.insert(files, full_path)
end
end
::continue::
end
end
scan_dir(root)
return files
end
--- Get language from file extension
---@param filepath string File path
---@return string Language name
function M.get_language(filepath)
local ext = vim.fn.fnamemodify(filepath, ":e")
return EXTENSION_LANGUAGE[ext] or ext
end
--- Read .gitignore patterns
---@param root string Project root
---@return string[] Patterns
function M.read_gitignore(root)
local patterns = {}
local path = root .. "/.gitignore"
local content = utils.read_file(path)
if not content then
return patterns
end
for line in content:gmatch("[^\n]+") do
-- Skip comments and empty lines
if not line:match("^#") and not line:match("^%s*$") then
-- Convert gitignore pattern to Lua pattern (simplified)
local pattern = line:gsub("^/", "^"):gsub("%*%*", ".*"):gsub("%*", "[^/]*"):gsub("%?", ".")
table.insert(patterns, pattern)
end
end
return patterns
end
return M