Refactor: Restructure project into core, features, adapters, and config modules
This commit is contained in:
268
lua/codetyper/features/agents/context_builder.lua
Normal file
268
lua/codetyper/features/agents/context_builder.lua
Normal file
@@ -0,0 +1,268 @@
|
||||
---@mod codetyper.agent.context_builder Context builder for agent prompts
|
||||
---
|
||||
--- Builds rich context including project structure, memories, and conventions
|
||||
--- to help the LLM understand the codebase.
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
local params = require("codetyper.params.agent.context")
|
||||
|
||||
--- Get project structure as a tree string
|
||||
---@param max_depth? number Maximum depth to traverse (default: 3)
|
||||
---@param max_files? number Maximum files to show (default: 50)
|
||||
---@return string Project tree
|
||||
function M.get_project_structure(max_depth, max_files)
|
||||
max_depth = max_depth or 3
|
||||
max_files = max_files or 50
|
||||
|
||||
local root = utils.get_project_root() or vim.fn.getcwd()
|
||||
local lines = { "PROJECT STRUCTURE:", root, "" }
|
||||
local file_count = 0
|
||||
|
||||
-- Common ignore patterns
|
||||
local ignore_patterns = params.ignore_patterns
|
||||
|
||||
local function should_ignore(name)
|
||||
for _, pattern in ipairs(ignore_patterns) do
|
||||
if name:match(pattern) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
local function traverse(path, depth, prefix)
|
||||
if depth > max_depth or file_count >= max_files then
|
||||
return
|
||||
end
|
||||
|
||||
local entries = {}
|
||||
local handle = vim.loop.fs_scandir(path)
|
||||
if not handle then
|
||||
return
|
||||
end
|
||||
|
||||
while true do
|
||||
local name, type = vim.loop.fs_scandir_next(handle)
|
||||
if not name then
|
||||
break
|
||||
end
|
||||
if not should_ignore(name) then
|
||||
table.insert(entries, { name = name, type = type })
|
||||
end
|
||||
end
|
||||
|
||||
-- Sort: directories first, then alphabetically
|
||||
table.sort(entries, function(a, b)
|
||||
if a.type == "directory" and b.type ~= "directory" then
|
||||
return true
|
||||
elseif a.type ~= "directory" and b.type == "directory" then
|
||||
return false
|
||||
else
|
||||
return a.name < b.name
|
||||
end
|
||||
end)
|
||||
|
||||
for i, entry in ipairs(entries) do
|
||||
if file_count >= max_files then
|
||||
table.insert(lines, prefix .. "... (truncated)")
|
||||
return
|
||||
end
|
||||
|
||||
local is_last = (i == #entries)
|
||||
local branch = is_last and "└── " or "├── "
|
||||
local new_prefix = prefix .. (is_last and " " or "│ ")
|
||||
|
||||
local icon = entry.type == "directory" and "/" or ""
|
||||
table.insert(lines, prefix .. branch .. entry.name .. icon)
|
||||
file_count = file_count + 1
|
||||
|
||||
if entry.type == "directory" then
|
||||
traverse(path .. "/" .. entry.name, depth + 1, new_prefix)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
traverse(root, 1, "")
|
||||
|
||||
if file_count >= max_files then
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, "(Structure truncated at " .. max_files .. " entries)")
|
||||
end
|
||||
|
||||
return table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
--- Get key files that are important for understanding the project
|
||||
---@return table<string, string> Map of filename to description
|
||||
function M.get_key_files()
|
||||
local root = utils.get_project_root() or vim.fn.getcwd()
|
||||
local key_files = {}
|
||||
|
||||
local important_files = {
|
||||
["package.json"] = "Node.js project config",
|
||||
["Cargo.toml"] = "Rust project config",
|
||||
["go.mod"] = "Go module config",
|
||||
["pyproject.toml"] = "Python project config",
|
||||
["setup.py"] = "Python setup config",
|
||||
["Makefile"] = "Build configuration",
|
||||
["CMakeLists.txt"] = "CMake config",
|
||||
[".gitignore"] = "Git ignore patterns",
|
||||
["README.md"] = "Project documentation",
|
||||
["init.lua"] = "Neovim plugin entry",
|
||||
["plugin.lua"] = "Neovim plugin config",
|
||||
}
|
||||
|
||||
for filename, desc in paparams.important_filesnd
|
||||
|
||||
return key_files
|
||||
end
|
||||
|
||||
--- Detect project type and language
|
||||
---@return table { type: string, language: string, framework?: string }
|
||||
function M.detect_project_type()
|
||||
local root = utils.get_project_root() or vim.fn.getcwd()
|
||||
|
||||
local indicators = {
|
||||
["package.json"] = { type = "node", language = "javascript/typescript" },
|
||||
["Cargo.toml"] = { type = "rust", language = "rust" },
|
||||
["go.mod"] = { type = "go", language = "go" },
|
||||
["pyproject.toml"] = { type = "python", language = "python" },
|
||||
["setup.py"] = { type = "python", language = "python" },
|
||||
["Gemfile"] = { type = "ruby", language = "ruby" },
|
||||
["pom.xml"] = { type = "maven", language = "java" },
|
||||
["build.gradle"] = { type = "gradle", language = "java/kotlin" },
|
||||
}
|
||||
|
||||
-- Check for Neovim plugin specifically
|
||||
if vim.fn.isdirectoparams.indicators return info
|
||||
end
|
||||
end
|
||||
|
||||
return { type = "unknown", language = "unknown" }
|
||||
end
|
||||
|
||||
--- Get memories/patterns from the brain system
|
||||
---@return string Formatted memories context
|
||||
function M.get_memories_context()
|
||||
local ok_memory, memory = pcall(require, "codetyper.indexer.memory")
|
||||
if not ok_memory then
|
||||
return ""
|
||||
end
|
||||
|
||||
local all = memory.get_all()
|
||||
if not all then
|
||||
return ""
|
||||
end
|
||||
|
||||
local lines = {}
|
||||
|
||||
-- Add patterns
|
||||
if all.patterns and next(all.patterns) then
|
||||
table.insert(lines, "LEARNED PATTERNS:")
|
||||
local count = 0
|
||||
for _, mem in pairs(all.patterns) do
|
||||
if count >= 5 then
|
||||
break
|
||||
end
|
||||
if mem.content then
|
||||
table.insert(lines, " - " .. mem.content:sub(1, 100))
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
table.insert(lines, "")
|
||||
end
|
||||
|
||||
-- Add conventions
|
||||
if all.conventions and next(all.conventions) then
|
||||
table.insert(lines, "CODING CONVENTIONS:")
|
||||
local count = 0
|
||||
for _, mem in pairs(all.conventions) do
|
||||
if count >= 5 then
|
||||
break
|
||||
end
|
||||
if mem.content then
|
||||
table.insert(lines, " - " .. mem.content:sub(1, 100))
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
table.insert(lines, "")
|
||||
end
|
||||
|
||||
return table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
--- Build the full context for agent prompts
|
||||
---@return string Full context string
|
||||
function M.build_full_context()
|
||||
local sections = {}
|
||||
|
||||
-- Project info
|
||||
local project_type = M.detect_project_type()
|
||||
table.insert(sections, string.format(
|
||||
"PROJECT INFO:\n Type: %s\n Language: %s%s\n",
|
||||
project_type.type,
|
||||
project_type.language,
|
||||
project_type.framework and ("\n Framework: " .. project_type.framework) or ""
|
||||
))
|
||||
|
||||
-- Project structure
|
||||
local structure = M.get_project_structure(3, 40)
|
||||
table.insert(sections, structure)
|
||||
|
||||
-- Key files
|
||||
local key_files = M.get_key_files()
|
||||
if next(key_files) then
|
||||
local key_lines = { "", "KEY FILES:" }
|
||||
for name, info in pairs(key_files) do
|
||||
table.insert(key_lines, string.format(" %s - %s", name, info.description))
|
||||
end
|
||||
table.insert(sections, table.concat(key_lines, "\n"))
|
||||
end
|
||||
|
||||
-- Memories
|
||||
local memories = M.get_memories_context()
|
||||
if memories ~= "" then
|
||||
table.insert(sections, "\n" .. memories)
|
||||
end
|
||||
|
||||
return table.concat(sections, "\n")
|
||||
end
|
||||
|
||||
--- Get a compact context summary for token efficiency
|
||||
---@return string Compact context
|
||||
function M.build_compact_context()
|
||||
local root = utils.get_project_root() or vim.fn.getcwd()
|
||||
local project_type = M.detect_project_type()
|
||||
|
||||
local lines = {
|
||||
"CONTEXT:",
|
||||
" Root: " .. root,
|
||||
" Type: " .. project_type.type .. " (" .. project_type.language .. ")",
|
||||
}
|
||||
|
||||
-- Add main directories
|
||||
local main_dirs = {}
|
||||
local handle = vim.loop.fs_scandir(root)
|
||||
if handle then
|
||||
while true do
|
||||
local name, type = vim.loop.fs_scandir_next(handle)
|
||||
if not name then
|
||||
break
|
||||
end
|
||||
if type == "directory" and not name:match("^%.") and not name:match("node_modules") then
|
||||
table.insert(main_dirs, name .. "/")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if #main_dirs > 0 then
|
||||
table.sort(main_dirs)
|
||||
table.insert(lines, " Main dirs: " .. table.concat(main_dirs, ", "))
|
||||
end
|
||||
|
||||
return table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
return M
|
||||
754
lua/codetyper/features/agents/engine.lua
Normal file
754
lua/codetyper/features/agents/engine.lua
Normal file
@@ -0,0 +1,754 @@
|
||||
---@mod codetyper.agent.agentic Agentic loop with proper tool calling
|
||||
---@brief [[
|
||||
--- Full agentic system that handles multi-file changes via tool calling.
|
||||
--- Multi-file agent system with tool orchestration.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class AgenticMessage
|
||||
---@field role "system"|"user"|"assistant"|"tool"
|
||||
---@field content string|table
|
||||
---@field tool_calls? table[] For assistant messages with tool calls
|
||||
---@field tool_call_id? string For tool result messages
|
||||
---@field name? string Tool name for tool results
|
||||
|
||||
---@class AgenticToolCall
|
||||
---@field id string Unique tool call ID
|
||||
---@field type "function"
|
||||
---@field function {name: string, arguments: string|table}
|
||||
|
||||
---@class AgenticOpts
|
||||
---@field task string The task to accomplish
|
||||
---@field files? string[] Initial files to include as context
|
||||
---@field agent? string Agent name to use (default: "coder")
|
||||
---@field model? string Model override
|
||||
---@field max_iterations? number Max tool call rounds (default: 20)
|
||||
---@field on_message? fun(msg: AgenticMessage) Called for each message
|
||||
---@field on_tool_start? fun(name: string, args: table) Called before tool execution
|
||||
---@field on_tool_end? fun(name: string, result: any, error: string|nil) Called after tool execution
|
||||
---@field on_file_change? fun(path: string, action: string) Called when file is modified
|
||||
---@field on_complete? fun(result: string|nil, error: string|nil) Called when done
|
||||
---@field on_status? fun(status: string) Status updates
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Load agent definition
|
||||
---@param name string Agent name
|
||||
---@return table|nil agent definition
|
||||
local function load_agent(name)
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
local agent_file = agents_dir .. "/" .. name .. ".md"
|
||||
|
||||
-- Check if custom agent exists
|
||||
if vim.fn.filereadable(agent_file) == 1 then
|
||||
local content = table.concat(vim.fn.readfile(agent_file), "\n")
|
||||
-- Parse frontmatter and content
|
||||
local frontmatter = {}
|
||||
local body = content
|
||||
|
||||
local fm_match = content:match("^%-%-%-\n(.-)%-%-%-\n(.*)$")
|
||||
if fm_match then
|
||||
-- Parse YAML-like frontmatter
|
||||
for line in content:match("^%-%-%-\n(.-)%-%-%-"):gmatch("[^\n]+") do
|
||||
local key, value = line:match("^(%w+):%s*(.+)$")
|
||||
if key and value then
|
||||
frontmatter[key] = value
|
||||
end
|
||||
end
|
||||
body = content:match("%-%-%-\n.-%-%-%-%s*\n(.*)$") or content
|
||||
end
|
||||
|
||||
return {
|
||||
name = name,
|
||||
description = frontmatter.description or "Custom agent: " .. name,
|
||||
system_prompt = body,
|
||||
tools = frontmatter.tools and vim.split(frontmatter.tools, ",") or nil,
|
||||
model = frontmatter.model,
|
||||
}
|
||||
end
|
||||
|
||||
-- Built-in agents
|
||||
local builtin_agents = require("codetyper.prompts.agents.personas").builtin
|
||||
|
||||
return builtin_agents[name]
|
||||
end
|
||||
|
||||
--- Load rules from .coder/rules/
|
||||
---@return string Combined rules content
|
||||
local function load_rules()
|
||||
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
|
||||
local rules = {}
|
||||
|
||||
if vim.fn.isdirectory(rules_dir) == 1 then
|
||||
local files = vim.fn.glob(rules_dir .. "/*.md", false, true)
|
||||
for _, file in ipairs(files) do
|
||||
local content = table.concat(vim.fn.readfile(file), "\n")
|
||||
local filename = vim.fn.fnamemodify(file, ":t:r")
|
||||
table.insert(rules, string.format("## Rule: %s\n%s", filename, content))
|
||||
end
|
||||
end
|
||||
|
||||
if #rules > 0 then
|
||||
return "\n\n# Project Rules\n" .. table.concat(rules, "\n\n")
|
||||
end
|
||||
return ""
|
||||
end
|
||||
|
||||
--- Build messages array for API request
|
||||
---@param history AgenticMessage[]
|
||||
---@param provider string "openai"|"claude"
|
||||
---@return table[] Formatted messages
|
||||
local function build_messages(history, provider)
|
||||
local messages = {}
|
||||
|
||||
for _, msg in ipairs(history) do
|
||||
if msg.role == "system" then
|
||||
if provider == "claude" then
|
||||
-- Claude uses system parameter, not message
|
||||
-- Skip system messages in array
|
||||
else
|
||||
table.insert(messages, {
|
||||
role = "system",
|
||||
content = msg.content,
|
||||
})
|
||||
end
|
||||
elseif msg.role == "user" then
|
||||
table.insert(messages, {
|
||||
role = "user",
|
||||
content = msg.content,
|
||||
})
|
||||
elseif msg.role == "assistant" then
|
||||
local message = {
|
||||
role = "assistant",
|
||||
content = msg.content,
|
||||
}
|
||||
if msg.tool_calls then
|
||||
message.tool_calls = msg.tool_calls
|
||||
if provider == "claude" then
|
||||
-- Claude format: content is array of blocks
|
||||
message.content = {}
|
||||
if msg.content and msg.content ~= "" then
|
||||
table.insert(message.content, {
|
||||
type = "text",
|
||||
text = msg.content,
|
||||
})
|
||||
end
|
||||
for _, tc in ipairs(msg.tool_calls) do
|
||||
table.insert(message.content, {
|
||||
type = "tool_use",
|
||||
id = tc.id,
|
||||
name = tc["function"].name,
|
||||
input = type(tc["function"].arguments) == "string"
|
||||
and vim.json.decode(tc["function"].arguments)
|
||||
or tc["function"].arguments,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
table.insert(messages, message)
|
||||
elseif msg.role == "tool" then
|
||||
if provider == "claude" then
|
||||
table.insert(messages, {
|
||||
role = "user",
|
||||
content = {
|
||||
{
|
||||
type = "tool_result",
|
||||
tool_use_id = msg.tool_call_id,
|
||||
content = msg.content,
|
||||
},
|
||||
},
|
||||
})
|
||||
else
|
||||
table.insert(messages, {
|
||||
role = "tool",
|
||||
tool_call_id = msg.tool_call_id,
|
||||
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return messages
|
||||
end
|
||||
|
||||
--- Build tools array for API request
|
||||
---@param tool_names string[] Tool names to include
|
||||
---@param provider string "openai"|"claude"
|
||||
---@return table[] Formatted tools
|
||||
local function build_tools(tool_names, provider)
|
||||
local tools_mod = require("codetyper.core.tools")
|
||||
local tools = {}
|
||||
|
||||
for _, name in ipairs(tool_names) do
|
||||
local tool = tools_mod.get(name)
|
||||
if tool then
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(tool.params or {}) do
|
||||
properties[param.name] = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
local description = type(tool.description) == "function" and tool.description() or tool.description
|
||||
|
||||
if provider == "claude" then
|
||||
table.insert(tools, {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
input_schema = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
})
|
||||
else
|
||||
table.insert(tools, {
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return tools
|
||||
end
|
||||
|
||||
--- Execute a tool call
|
||||
---@param tool_call AgenticToolCall
|
||||
---@param opts AgenticOpts
|
||||
---@return string result
|
||||
---@return string|nil error
|
||||
local function execute_tool(tool_call, opts)
|
||||
local tools_mod = require("codetyper.core.tools")
|
||||
local name = tool_call["function"].name
|
||||
local args = tool_call["function"].arguments
|
||||
|
||||
-- Parse arguments if string
|
||||
if type(args) == "string" then
|
||||
local ok, parsed = pcall(vim.json.decode, args)
|
||||
if ok then
|
||||
args = parsed
|
||||
else
|
||||
return "", "Failed to parse tool arguments: " .. args
|
||||
end
|
||||
end
|
||||
|
||||
-- Notify tool start
|
||||
if opts.on_tool_start then
|
||||
opts.on_tool_start(name, args)
|
||||
end
|
||||
|
||||
if opts.on_status then
|
||||
opts.on_status("Executing: " .. name)
|
||||
end
|
||||
|
||||
-- Execute the tool
|
||||
local tool = tools_mod.get(name)
|
||||
if not tool then
|
||||
local err = "Unknown tool: " .. name
|
||||
if opts.on_tool_end then
|
||||
opts.on_tool_end(name, nil, err)
|
||||
end
|
||||
return "", err
|
||||
end
|
||||
|
||||
local result, err = tool.func(args, {
|
||||
on_log = function(msg)
|
||||
if opts.on_status then
|
||||
opts.on_status(msg)
|
||||
end
|
||||
end,
|
||||
})
|
||||
|
||||
-- Notify tool end
|
||||
if opts.on_tool_end then
|
||||
opts.on_tool_end(name, result, err)
|
||||
end
|
||||
|
||||
-- Track file changes
|
||||
if opts.on_file_change and (name == "write" or name == "edit") and not err then
|
||||
opts.on_file_change(args.path, name == "write" and "created" or "modified")
|
||||
end
|
||||
|
||||
if err then
|
||||
return "", err
|
||||
end
|
||||
|
||||
return type(result) == "string" and result or vim.json.encode(result), nil
|
||||
end
|
||||
|
||||
--- Parse tool calls from LLM response (unified Claude-like format)
|
||||
---@param response table Raw API response in unified format
|
||||
---@param provider string Provider name (unused, kept for signature compatibility)
|
||||
---@return AgenticToolCall[]
|
||||
local function parse_tool_calls(response, provider)
|
||||
local tool_calls = {}
|
||||
|
||||
-- Unified format: content array with tool_use blocks
|
||||
local content = response.content or {}
|
||||
for _, block in ipairs(content) do
|
||||
if block.type == "tool_use" then
|
||||
-- OpenAI expects arguments as JSON string, not table
|
||||
local args = block.input
|
||||
if type(args) == "table" then
|
||||
args = vim.json.encode(args)
|
||||
end
|
||||
|
||||
table.insert(tool_calls, {
|
||||
id = block.id or utils.generate_id("call"),
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = block.name,
|
||||
arguments = args,
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return tool_calls
|
||||
end
|
||||
|
||||
--- Extract text content from response (unified Claude-like format)
|
||||
---@param response table Raw API response in unified format
|
||||
---@param provider string Provider name (unused, kept for signature compatibility)
|
||||
---@return string
|
||||
local function extract_content(response, provider)
|
||||
local parts = {}
|
||||
for _, block in ipairs(response.content or {}) do
|
||||
if block.type == "text" then
|
||||
table.insert(parts, block.text)
|
||||
end
|
||||
end
|
||||
return table.concat(parts, "\n")
|
||||
end
|
||||
|
||||
--- Check if response indicates completion (unified Claude-like format)
|
||||
---@param response table Raw API response in unified format
|
||||
---@param provider string Provider name (unused, kept for signature compatibility)
|
||||
---@return boolean
|
||||
local function is_complete(response, provider)
|
||||
return response.stop_reason == "end_turn"
|
||||
end
|
||||
|
||||
--- Make API request to LLM with native tool calling support
|
||||
---@param messages table[] Formatted messages
|
||||
---@param tools table[] Formatted tools
|
||||
---@param system_prompt string System prompt
|
||||
---@param provider string "openai"|"claude"|"copilot"
|
||||
---@param model string Model name
|
||||
---@param callback fun(response: table|nil, error: string|nil)
|
||||
local function call_llm(messages, tools, system_prompt, provider, model, callback)
|
||||
local context = {
|
||||
language = "lua",
|
||||
file_content = "",
|
||||
prompt_type = "agent",
|
||||
project_root = vim.fn.getcwd(),
|
||||
cwd = vim.fn.getcwd(),
|
||||
}
|
||||
|
||||
-- Use native tool calling APIs
|
||||
if provider == "copilot" then
|
||||
local client = require("codetyper.core.llm.copilot")
|
||||
|
||||
-- Copilot's generate_with_tools expects messages in a specific format
|
||||
-- Convert to the format it expects
|
||||
local converted_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(converted_messages, msg)
|
||||
end
|
||||
end
|
||||
|
||||
client.generate_with_tools(converted_messages, context, tools, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Response is already in Claude-like format from the provider
|
||||
-- Convert to our internal format
|
||||
local result = {
|
||||
content = {},
|
||||
stop_reason = "end_turn",
|
||||
}
|
||||
|
||||
if response and response.content then
|
||||
for _, block in ipairs(response.content) do
|
||||
if block.type == "text" then
|
||||
table.insert(result.content, { type = "text", text = block.text })
|
||||
elseif block.type == "tool_use" then
|
||||
table.insert(result.content, {
|
||||
type = "tool_use",
|
||||
id = block.id or utils.generate_id("call"),
|
||||
name = block.name,
|
||||
input = block.input,
|
||||
})
|
||||
result.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
callback(result, nil)
|
||||
end)
|
||||
elseif provider == "openai" then
|
||||
local client = require("codetyper.core.llm.openai")
|
||||
|
||||
-- OpenAI's generate_with_tools
|
||||
local converted_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(converted_messages, msg)
|
||||
end
|
||||
end
|
||||
|
||||
client.generate_with_tools(converted_messages, context, tools, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Response is already in Claude-like format from the provider
|
||||
local result = {
|
||||
content = {},
|
||||
stop_reason = "end_turn",
|
||||
}
|
||||
|
||||
if response and response.content then
|
||||
for _, block in ipairs(response.content) do
|
||||
if block.type == "text" then
|
||||
table.insert(result.content, { type = "text", text = block.text })
|
||||
elseif block.type == "tool_use" then
|
||||
table.insert(result.content, {
|
||||
type = "tool_use",
|
||||
id = block.id or utils.generate_id("call"),
|
||||
name = block.name,
|
||||
input = block.input,
|
||||
})
|
||||
result.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
callback(result, nil)
|
||||
end)
|
||||
elseif provider == "ollama" then
|
||||
local client = require("codetyper.core.llm.ollama")
|
||||
|
||||
-- Ollama's generate_with_tools (text-based tool calling)
|
||||
local converted_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(converted_messages, msg)
|
||||
end
|
||||
end
|
||||
|
||||
client.generate_with_tools(converted_messages, context, tools, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Response is already in Claude-like format from the provider
|
||||
callback(response, nil)
|
||||
end)
|
||||
else
|
||||
-- Fallback for other providers (ollama, etc.) - use text-based parsing
|
||||
local client = require("codetyper.core.llm." .. provider)
|
||||
|
||||
-- Build prompt from messages
|
||||
local prompts = require("codetyper.prompts.agent")
|
||||
local prompt_parts = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role == "user" then
|
||||
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
|
||||
table.insert(prompt_parts, prompts.text_user_prefix .. content)
|
||||
elseif msg.role == "assistant" then
|
||||
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
|
||||
table.insert(prompt_parts, prompts.text_assistant_prefix .. content)
|
||||
end
|
||||
end
|
||||
|
||||
-- Add tool descriptions to prompt for text-based providers
|
||||
local tool_desc = require("codetyper.prompts.agent").tool_instructions_text
|
||||
for _, tool in ipairs(tools) do
|
||||
local name = tool.name or (tool["function"] and tool["function"].name)
|
||||
local desc = tool.description or (tool["function"] and tool["function"].description)
|
||||
if name then
|
||||
tool_desc = tool_desc .. string.format("- **%s**: %s\n", name, desc or "")
|
||||
end
|
||||
end
|
||||
|
||||
context.file_content = system_prompt .. tool_desc
|
||||
|
||||
client.generate(table.concat(prompt_parts, "\n\n"), context, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Parse response for tool calls (text-based fallback)
|
||||
local result = {
|
||||
content = {},
|
||||
stop_reason = "end_turn",
|
||||
}
|
||||
|
||||
-- Extract text content
|
||||
local text_content = response
|
||||
|
||||
-- Try to extract JSON tool calls from response
|
||||
local json_match = response:match("```json%s*(%b{})%s*```")
|
||||
if json_match then
|
||||
local ok, parsed = pcall(vim.json.decode, json_match)
|
||||
if ok and parsed.tool then
|
||||
table.insert(result.content, {
|
||||
type = "tool_use",
|
||||
id = utils.generate_id("call"),
|
||||
name = parsed.tool,
|
||||
input = parsed.arguments or {},
|
||||
})
|
||||
text_content = response:gsub("```json.-```", ""):gsub("^%s+", ""):gsub("%s+$", "")
|
||||
result.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
|
||||
if text_content and text_content ~= "" then
|
||||
table.insert(result.content, 1, { type = "text", text = text_content })
|
||||
end
|
||||
|
||||
callback(result, nil)
|
||||
end)
|
||||
end
|
||||
end
|
||||
|
||||
--- Run the agentic loop
|
||||
---@param opts AgenticOpts
|
||||
function M.run(opts)
|
||||
-- Load agent
|
||||
local agent = load_agent(opts.agent or "coder")
|
||||
if not agent then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "Unknown agent: " .. (opts.agent or "coder"))
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Load rules
|
||||
local rules = load_rules()
|
||||
|
||||
-- Build system prompt
|
||||
local system_prompt = agent.system_prompt .. rules
|
||||
|
||||
-- Initialize message history
|
||||
---@type AgenticMessage[]
|
||||
local history = {
|
||||
{ role = "system", content = system_prompt },
|
||||
}
|
||||
|
||||
-- Add initial file context if provided
|
||||
if opts.files and #opts.files > 0 then
|
||||
local file_context = require("codetyper.prompts.agent").format_file_context(opts.files)
|
||||
table.insert(history, { role = "user", content = file_context })
|
||||
table.insert(history, { role = "assistant", content = "I've reviewed the provided files. What would you like me to do?" })
|
||||
end
|
||||
|
||||
-- Add the task
|
||||
table.insert(history, { role = "user", content = opts.task })
|
||||
|
||||
-- Determine provider
|
||||
local config = require("codetyper").get_config()
|
||||
local provider = config.llm.provider or "copilot"
|
||||
-- Note: Ollama has its own handler in call_llm, don't change it
|
||||
|
||||
-- Get tools for this agent
|
||||
local tool_names = agent.tools or { "view", "edit", "write", "grep", "glob", "bash" }
|
||||
|
||||
-- Ensure tools are loaded
|
||||
local tools_mod = require("codetyper.core.tools")
|
||||
tools_mod.setup()
|
||||
|
||||
-- Build tools for API
|
||||
local tools = build_tools(tool_names, provider)
|
||||
|
||||
-- Iteration tracking
|
||||
local iteration = 0
|
||||
local max_iterations = opts.max_iterations or 20
|
||||
|
||||
--- Process one iteration
|
||||
local function process_iteration()
|
||||
iteration = iteration + 1
|
||||
|
||||
if iteration > max_iterations then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "Max iterations reached")
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
if opts.on_status then
|
||||
opts.on_status(string.format("Thinking... (iteration %d)", iteration))
|
||||
end
|
||||
|
||||
-- Build messages for API
|
||||
local messages = build_messages(history, provider)
|
||||
|
||||
-- Call LLM
|
||||
call_llm(messages, tools, system_prompt, provider, opts.model, function(response, err)
|
||||
if err then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, err)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Extract content and tool calls
|
||||
local content = extract_content(response, provider)
|
||||
local tool_calls = parse_tool_calls(response, provider)
|
||||
|
||||
-- Add assistant message to history
|
||||
local assistant_msg = {
|
||||
role = "assistant",
|
||||
content = content,
|
||||
tool_calls = #tool_calls > 0 and tool_calls or nil,
|
||||
}
|
||||
table.insert(history, assistant_msg)
|
||||
|
||||
if opts.on_message then
|
||||
opts.on_message(assistant_msg)
|
||||
end
|
||||
|
||||
-- Process tool calls if any
|
||||
if #tool_calls > 0 then
|
||||
for _, tc in ipairs(tool_calls) do
|
||||
local result, tool_err = execute_tool(tc, opts)
|
||||
|
||||
-- Add tool result to history
|
||||
local tool_msg = {
|
||||
role = "tool",
|
||||
tool_call_id = tc.id,
|
||||
name = tc["function"].name,
|
||||
content = tool_err or result,
|
||||
}
|
||||
table.insert(history, tool_msg)
|
||||
|
||||
if opts.on_message then
|
||||
opts.on_message(tool_msg)
|
||||
end
|
||||
end
|
||||
|
||||
-- Continue the loop
|
||||
vim.schedule(process_iteration)
|
||||
else
|
||||
-- No tool calls - check if complete
|
||||
if is_complete(response, provider) or content ~= "" then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(content, nil)
|
||||
end
|
||||
else
|
||||
-- Continue if not explicitly complete
|
||||
vim.schedule(process_iteration)
|
||||
end
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
-- Start the loop
|
||||
process_iteration()
|
||||
end
|
||||
|
||||
--- Create default agent files in .coder/agents/
|
||||
function M.init_agents_dir()
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
vim.fn.mkdir(agents_dir, "p")
|
||||
|
||||
-- Create example agent
|
||||
local example_agent = require("codetyper.prompts.agents.templates").agent
|
||||
|
||||
local example_path = agents_dir .. "/example.md"
|
||||
if vim.fn.filereadable(example_path) ~= 1 then
|
||||
vim.fn.writefile(vim.split(example_agent, "\n"), example_path)
|
||||
end
|
||||
|
||||
return agents_dir
|
||||
end
|
||||
|
||||
--- Create default rules in .coder/rules/
|
||||
function M.init_rules_dir()
|
||||
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
|
||||
vim.fn.mkdir(rules_dir, "p")
|
||||
|
||||
-- Create example rule
|
||||
local example_rule = require("codetyper.prompts.agents.templates").rule
|
||||
|
||||
local example_path = rules_dir .. "/code-style.md"
|
||||
if vim.fn.filereadable(example_path) ~= 1 then
|
||||
vim.fn.writefile(vim.split(example_rule, "\n"), example_path)
|
||||
end
|
||||
|
||||
return rules_dir
|
||||
end
|
||||
|
||||
--- Initialize both agents and rules directories
|
||||
function M.init()
|
||||
M.init_agents_dir()
|
||||
M.init_rules_dir()
|
||||
end
|
||||
|
||||
--- List available agents
|
||||
---@return table[] List of {name, description, builtin}
|
||||
function M.list_agents()
|
||||
local agents = {}
|
||||
|
||||
-- Built-in agents
|
||||
local personas = require("codetyper.prompts.agents.personas").builtin
|
||||
local builtins = vim.tbl_keys(personas)
|
||||
table.sort(builtins)
|
||||
|
||||
for _, name in ipairs(builtins) do
|
||||
local agent = load_agent(name)
|
||||
if agent then
|
||||
table.insert(agents, {
|
||||
name = agent.name,
|
||||
description = agent.description,
|
||||
builtin = true,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Custom agents from .coder/agents/
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
if vim.fn.isdirectory(agents_dir) == 1 then
|
||||
local files = vim.fn.glob(agents_dir .. "/*.md", false, true)
|
||||
for _, file in ipairs(files) do
|
||||
local name = vim.fn.fnamemodify(file, ":t:r")
|
||||
if not vim.tbl_contains(builtins, name) then
|
||||
local agent = load_agent(name)
|
||||
if agent then
|
||||
table.insert(agents, {
|
||||
name = agent.name,
|
||||
description = agent.description,
|
||||
builtin = false,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return agents
|
||||
end
|
||||
|
||||
return M
|
||||
455
lua/codetyper/features/agents/init.lua
Normal file
455
lua/codetyper/features/agents/init.lua
Normal file
@@ -0,0 +1,455 @@
|
||||
---@mod codetyper.agent Agent orchestration for Codetyper.nvim
|
||||
---
|
||||
--- Manages the agentic conversation loop with tool execution.
|
||||
|
||||
local M = {}
|
||||
|
||||
local tools = require("codetyper.core.tools")
|
||||
local executor = require("codetyper.core.scheduler.executor")
|
||||
local parser = require("codetyper.core.llm.parser")
|
||||
local diff = require("codetyper.core.diff.diff")
|
||||
local diff_review = require("codetyper.adapters.nvim.ui.diff_review")
|
||||
local resume = require("codetyper.core.scheduler.resume")
|
||||
local utils = require("codetyper.support.utils")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
|
||||
---@class AgentState
|
||||
---@field conversation table[] Message history for multi-turn
|
||||
---@field pending_tool_results table[] Results waiting to be sent back
|
||||
---@field is_running boolean Whether agent loop is active
|
||||
---@field max_iterations number Maximum tool call iterations
|
||||
|
||||
local state = {
|
||||
conversation = {},
|
||||
pending_tool_results = {},
|
||||
is_running = false,
|
||||
max_iterations = 25, -- Increased for complex tasks (env setup, tests, fixes)
|
||||
current_iteration = 0,
|
||||
original_prompt = "", -- Store for resume functionality
|
||||
current_context = nil, -- Store context for resume
|
||||
current_callbacks = nil, -- Store callbacks for continue
|
||||
}
|
||||
|
||||
---@class AgentCallbacks
|
||||
---@field on_text fun(text: string) Called when text content is received
|
||||
---@field on_tool_start fun(name: string) Called when a tool starts
|
||||
---@field on_tool_result fun(name: string, result: string) Called when a tool completes
|
||||
---@field on_complete fun() Called when agent finishes
|
||||
---@field on_error fun(err: string) Called on error
|
||||
|
||||
--- Reset agent state for new conversation
|
||||
function M.reset()
|
||||
state.conversation = {}
|
||||
state.pending_tool_results = {}
|
||||
state.is_running = false
|
||||
state.current_iteration = 0
|
||||
-- Clear collected diffs
|
||||
diff_review.clear()
|
||||
end
|
||||
|
||||
--- Check if agent is currently running
|
||||
---@return boolean
|
||||
function M.is_running()
|
||||
return state.is_running
|
||||
end
|
||||
|
||||
--- Stop the agent
|
||||
function M.stop()
|
||||
state.is_running = false
|
||||
utils.notify("Agent stopped")
|
||||
end
|
||||
|
||||
--- Main agent entry point
|
||||
---@param prompt string User's request
|
||||
---@param context table File context
|
||||
---@param callbacks AgentCallbacks Callback functions
|
||||
function M.run(prompt, context, callbacks)
|
||||
if state.is_running then
|
||||
callbacks.on_error("Agent is already running")
|
||||
return
|
||||
end
|
||||
|
||||
logs.info("Starting agent run")
|
||||
logs.debug("Prompt length: " .. #prompt .. " chars")
|
||||
|
||||
state.is_running = true
|
||||
state.current_iteration = 0
|
||||
state.original_prompt = prompt
|
||||
state.current_context = context
|
||||
state.current_callbacks = callbacks
|
||||
|
||||
-- Add user message to conversation
|
||||
table.insert(state.conversation, {
|
||||
role = "user",
|
||||
content = prompt,
|
||||
})
|
||||
|
||||
-- Start the agent loop
|
||||
M.agent_loop(context, callbacks)
|
||||
end
|
||||
|
||||
--- The core agent loop
|
||||
---@param context table File context
|
||||
---@param callbacks AgentCallbacks
|
||||
function M.agent_loop(context, callbacks)
|
||||
if not state.is_running then
|
||||
callbacks.on_complete()
|
||||
return
|
||||
end
|
||||
|
||||
state.current_iteration = state.current_iteration + 1
|
||||
logs.info(string.format("Agent loop iteration %d/%d", state.current_iteration, state.max_iterations))
|
||||
|
||||
if state.current_iteration > state.max_iterations then
|
||||
logs.info("Max iterations reached, asking user to continue or stop")
|
||||
-- Ask user if they want to continue
|
||||
M.prompt_continue(context, callbacks)
|
||||
return
|
||||
end
|
||||
|
||||
local llm = require("codetyper.core.llm")
|
||||
local client = llm.get_client()
|
||||
|
||||
-- Check if client supports tools
|
||||
if not client.generate_with_tools then
|
||||
logs.error("Provider does not support agent mode")
|
||||
callbacks.on_error("Current LLM provider does not support agent mode")
|
||||
state.is_running = false
|
||||
return
|
||||
end
|
||||
|
||||
logs.thinking("Calling LLM with " .. #state.conversation .. " messages...")
|
||||
|
||||
-- Generate with tools enabled
|
||||
-- Ensure tools are loaded and get definitions
|
||||
tools.setup()
|
||||
local tool_defs = tools.to_openai_format()
|
||||
|
||||
client.generate_with_tools(state.conversation, context, tool_defs, function(response, err)
|
||||
if err then
|
||||
state.is_running = false
|
||||
callbacks.on_error(err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Parse response based on provider
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
local parsed
|
||||
|
||||
-- Copilot uses Claude-like response format
|
||||
if config.llm.provider == "copilot" then
|
||||
parsed = parser.parse_claude_response(response)
|
||||
table.insert(state.conversation, {
|
||||
role = "assistant",
|
||||
content = parsed.text or "",
|
||||
tool_calls = parsed.tool_calls,
|
||||
_raw_content = response.content,
|
||||
})
|
||||
else
|
||||
-- For Ollama, response is the text directly
|
||||
if type(response) == "string" then
|
||||
parsed = parser.parse_ollama_response(response)
|
||||
else
|
||||
parsed = parser.parse_ollama_response(response.response or "")
|
||||
end
|
||||
-- Add assistant response to conversation
|
||||
table.insert(state.conversation, {
|
||||
role = "assistant",
|
||||
content = parsed.text,
|
||||
tool_calls = parsed.tool_calls,
|
||||
})
|
||||
end
|
||||
|
||||
-- Display any text content
|
||||
if parsed.text and parsed.text ~= "" then
|
||||
local clean_text = parser.clean_text(parsed.text)
|
||||
if clean_text ~= "" then
|
||||
callbacks.on_text(clean_text)
|
||||
end
|
||||
end
|
||||
|
||||
-- Check for tool calls
|
||||
if #parsed.tool_calls > 0 then
|
||||
logs.info(string.format("Processing %d tool call(s)", #parsed.tool_calls))
|
||||
-- Process tool calls sequentially
|
||||
M.process_tool_calls(parsed.tool_calls, 1, context, callbacks)
|
||||
else
|
||||
-- No more tool calls, agent is done
|
||||
logs.info("No tool calls, finishing agent loop")
|
||||
state.is_running = false
|
||||
callbacks.on_complete()
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Process tool calls one at a time
|
||||
---@param tool_calls table[] List of tool calls
|
||||
---@param index number Current index
|
||||
---@param context table File context
|
||||
---@param callbacks AgentCallbacks
|
||||
function M.process_tool_calls(tool_calls, index, context, callbacks)
|
||||
if not state.is_running then
|
||||
callbacks.on_complete()
|
||||
return
|
||||
end
|
||||
|
||||
if index > #tool_calls then
|
||||
-- All tools processed, continue agent loop with results
|
||||
M.continue_with_results(context, callbacks)
|
||||
return
|
||||
end
|
||||
|
||||
local tool_call = tool_calls[index]
|
||||
callbacks.on_tool_start(tool_call.name)
|
||||
|
||||
executor.execute(tool_call.name, tool_call.parameters, function(result)
|
||||
if result.requires_approval then
|
||||
logs.tool(tool_call.name, "approval", "Waiting for user approval")
|
||||
-- Show diff preview and wait for user decision
|
||||
local show_fn
|
||||
if result.diff_data.operation == "bash" then
|
||||
show_fn = function(_, cb)
|
||||
diff.show_bash_approval(result.diff_data.modified:gsub("^%$ ", ""), cb)
|
||||
end
|
||||
else
|
||||
show_fn = diff.show_diff
|
||||
end
|
||||
|
||||
show_fn(result.diff_data, function(approval_result)
|
||||
-- Handle both old (boolean) and new (table) approval result formats
|
||||
local approved = type(approval_result) == "table" and approval_result.approved or approval_result
|
||||
local permission_level = type(approval_result) == "table" and approval_result.permission_level or nil
|
||||
|
||||
if approved then
|
||||
local log_msg = "User approved"
|
||||
if permission_level == "allow_session" then
|
||||
log_msg = "Allowed for session"
|
||||
elseif permission_level == "allow_list" then
|
||||
log_msg = "Added to allow list"
|
||||
elseif permission_level == "auto" then
|
||||
log_msg = "Auto-approved"
|
||||
end
|
||||
logs.tool(tool_call.name, "approved", log_msg)
|
||||
|
||||
-- Apply the change and collect for review
|
||||
executor.apply_change(result.diff_data, function(apply_result)
|
||||
-- Collect the diff for end-of-session review
|
||||
if result.diff_data.operation ~= "bash" then
|
||||
diff_review.add({
|
||||
path = result.diff_data.path,
|
||||
operation = result.diff_data.operation,
|
||||
original = result.diff_data.original,
|
||||
modified = result.diff_data.modified,
|
||||
approved = true,
|
||||
applied = true,
|
||||
})
|
||||
end
|
||||
|
||||
-- Store result for sending back to LLM
|
||||
table.insert(state.pending_tool_results, {
|
||||
tool_use_id = tool_call.id,
|
||||
name = tool_call.name,
|
||||
result = apply_result.result,
|
||||
})
|
||||
callbacks.on_tool_result(tool_call.name, apply_result.result)
|
||||
-- Process next tool call
|
||||
M.process_tool_calls(tool_calls, index + 1, context, callbacks)
|
||||
end)
|
||||
else
|
||||
logs.tool(tool_call.name, "rejected", "User rejected")
|
||||
-- User rejected
|
||||
table.insert(state.pending_tool_results, {
|
||||
tool_use_id = tool_call.id,
|
||||
name = tool_call.name,
|
||||
result = "User rejected this change",
|
||||
})
|
||||
callbacks.on_tool_result(tool_call.name, "Rejected by user")
|
||||
M.process_tool_calls(tool_calls, index + 1, context, callbacks)
|
||||
end
|
||||
end)
|
||||
else
|
||||
-- No approval needed (read_file), store result immediately
|
||||
table.insert(state.pending_tool_results, {
|
||||
tool_use_id = tool_call.id,
|
||||
name = tool_call.name,
|
||||
result = result.result,
|
||||
})
|
||||
|
||||
-- For read_file, just show a brief confirmation
|
||||
local display_result = result.result
|
||||
if tool_call.name == "read_file" and result.success then
|
||||
display_result = "[Read " .. #result.result .. " bytes]"
|
||||
end
|
||||
callbacks.on_tool_result(tool_call.name, display_result)
|
||||
|
||||
M.process_tool_calls(tool_calls, index + 1, context, callbacks)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Continue the loop after tool execution
|
||||
---@param context table File context
|
||||
---@param callbacks AgentCallbacks
|
||||
function M.continue_with_results(context, callbacks)
|
||||
if #state.pending_tool_results == 0 then
|
||||
state.is_running = false
|
||||
callbacks.on_complete()
|
||||
return
|
||||
end
|
||||
|
||||
-- Build tool results message
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
-- Copilot uses OpenAI format for tool results (role: "tool")
|
||||
if config.llm.provider == "copilot" then
|
||||
-- OpenAI-style tool messages - each result is a separate message
|
||||
for _, result in ipairs(state.pending_tool_results) do
|
||||
table.insert(state.conversation, {
|
||||
role = "tool",
|
||||
tool_call_id = result.tool_use_id,
|
||||
content = result.result,
|
||||
})
|
||||
end
|
||||
else
|
||||
-- Ollama format: plain text describing results
|
||||
local result_text = "Tool results:\n"
|
||||
for _, result in ipairs(state.pending_tool_results) do
|
||||
result_text = result_text .. "\n[" .. result.name .. "]: " .. result.result .. "\n"
|
||||
end
|
||||
table.insert(state.conversation, {
|
||||
role = "user",
|
||||
content = result_text,
|
||||
})
|
||||
end
|
||||
|
||||
state.pending_tool_results = {}
|
||||
|
||||
-- Continue the loop
|
||||
M.agent_loop(context, callbacks)
|
||||
end
|
||||
|
||||
--- Get conversation history
|
||||
---@return table[]
|
||||
function M.get_conversation()
|
||||
return state.conversation
|
||||
end
|
||||
|
||||
--- Set max iterations
|
||||
---@param max number Maximum iterations
|
||||
function M.set_max_iterations(max)
|
||||
state.max_iterations = max
|
||||
end
|
||||
|
||||
--- Get the count of collected changes
|
||||
---@return number
|
||||
function M.get_changes_count()
|
||||
return diff_review.count()
|
||||
end
|
||||
|
||||
--- Show the diff review UI for all collected changes
|
||||
function M.show_diff_review()
|
||||
diff_review.open()
|
||||
end
|
||||
|
||||
--- Check if diff review is open
|
||||
---@return boolean
|
||||
function M.is_review_open()
|
||||
return diff_review.is_open()
|
||||
end
|
||||
|
||||
--- Prompt user to continue or stop at max iterations
|
||||
---@param context table File context
|
||||
---@param callbacks AgentCallbacks
|
||||
function M.prompt_continue(context, callbacks)
|
||||
vim.schedule(function()
|
||||
vim.ui.select({ "Continue (25 more iterations)", "Stop and save for later" }, {
|
||||
prompt = string.format("Agent reached %d iterations. Continue?", state.max_iterations),
|
||||
}, function(choice)
|
||||
if choice and choice:match("^Continue") then
|
||||
-- Reset iteration counter and continue
|
||||
state.current_iteration = 0
|
||||
logs.info("User chose to continue, resetting iteration counter")
|
||||
M.agent_loop(context, callbacks)
|
||||
else
|
||||
-- Save state for later resume
|
||||
logs.info("User chose to stop, saving state for resume")
|
||||
resume.save(
|
||||
state.conversation,
|
||||
state.pending_tool_results,
|
||||
state.current_iteration,
|
||||
state.original_prompt
|
||||
)
|
||||
state.is_running = false
|
||||
callbacks.on_text("Agent paused. Use /continue to resume later.")
|
||||
callbacks.on_complete()
|
||||
end
|
||||
end)
|
||||
end)
|
||||
end
|
||||
|
||||
--- Continue a previously stopped agent session
|
||||
---@param callbacks AgentCallbacks
|
||||
---@return boolean Success
|
||||
function M.continue_session(callbacks)
|
||||
if state.is_running then
|
||||
utils.notify("Agent is already running", vim.log.levels.WARN)
|
||||
return false
|
||||
end
|
||||
|
||||
local saved = resume.load()
|
||||
if not saved then
|
||||
utils.notify("No saved agent session to continue", vim.log.levels.WARN)
|
||||
return false
|
||||
end
|
||||
|
||||
logs.info("Resuming agent session")
|
||||
logs.info(string.format("Loaded %d messages, iteration %d", #saved.conversation, saved.iteration))
|
||||
|
||||
-- Restore state
|
||||
state.conversation = saved.conversation
|
||||
state.pending_tool_results = saved.pending_tool_results or {}
|
||||
state.current_iteration = 0 -- Reset for fresh iterations
|
||||
state.original_prompt = saved.original_prompt
|
||||
state.is_running = true
|
||||
state.current_callbacks = callbacks
|
||||
|
||||
-- Build context from current state
|
||||
local llm = require("codetyper.core.llm")
|
||||
local context = {}
|
||||
local current_file = vim.fn.expand("%:p")
|
||||
if current_file ~= "" and vim.fn.filereadable(current_file) == 1 then
|
||||
context = llm.build_context(current_file, "agent")
|
||||
end
|
||||
state.current_context = context
|
||||
|
||||
-- Clear saved state
|
||||
resume.clear()
|
||||
|
||||
-- Add continuation message
|
||||
table.insert(state.conversation, {
|
||||
role = "user",
|
||||
content = "Continue where you left off. Complete the remaining tasks.",
|
||||
})
|
||||
|
||||
-- Continue the loop
|
||||
callbacks.on_text("Resuming agent session...")
|
||||
M.agent_loop(context, callbacks)
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Check if there's a saved session to continue
|
||||
---@return boolean
|
||||
function M.has_saved_session()
|
||||
return resume.has_saved_state()
|
||||
end
|
||||
|
||||
--- Get info about saved session
|
||||
---@return table|nil
|
||||
function M.get_saved_session_info()
|
||||
return resume.get_info()
|
||||
end
|
||||
|
||||
return M
|
||||
425
lua/codetyper/features/agents/linter.lua
Normal file
425
lua/codetyper/features/agents/linter.lua
Normal file
@@ -0,0 +1,425 @@
|
||||
---@mod codetyper.agent.linter Linter validation for generated code
|
||||
---@brief [[
|
||||
--- Validates generated code by checking LSP diagnostics after injection.
|
||||
--- Automatically saves the file and waits for LSP to update before checking.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local config_params = require("codetyper.params.agent.linter")
|
||||
local prompts = require("codetyper.prompts.agent.linter")
|
||||
|
||||
--- Configuration
|
||||
local config = config_params.config
|
||||
|
||||
--- Diagnostic results for tracking
|
||||
---@type table<number, table>
|
||||
local validation_results = {}
|
||||
|
||||
--- Configure linter behavior
|
||||
---@param opts table Configuration options
|
||||
function M.configure(opts)
|
||||
for k, v in pairs(opts) do
|
||||
if config[k] ~= nil then
|
||||
config[k] = v
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Get current configuration
|
||||
---@return table
|
||||
function M.get_config()
|
||||
return vim.deepcopy(config)
|
||||
end
|
||||
|
||||
--- Save buffer if modified
|
||||
---@param bufnr number Buffer number
|
||||
---@return boolean success
|
||||
local function save_buffer(bufnr)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Skip if buffer is not modified
|
||||
if not vim.bo[bufnr].modified then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Skip if buffer has no name (unsaved file)
|
||||
local bufname = vim.api.nvim_buf_get_name(bufnr)
|
||||
if bufname == "" then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Save the buffer
|
||||
local ok, err = pcall(function()
|
||||
vim.api.nvim_buf_call(bufnr, function()
|
||||
vim.cmd("silent! write")
|
||||
end)
|
||||
end)
|
||||
|
||||
if not ok then
|
||||
pcall(function()
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "warning",
|
||||
message = "Failed to save buffer: " .. tostring(err),
|
||||
})
|
||||
end)
|
||||
return false
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Get LSP diagnostics for a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@param start_line? number Start line (1-indexed)
|
||||
---@param end_line? number End line (1-indexed)
|
||||
---@return table[] diagnostics List of diagnostics
|
||||
function M.get_diagnostics(bufnr, start_line, end_line)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return {}
|
||||
end
|
||||
|
||||
local all_diagnostics = vim.diagnostic.get(bufnr)
|
||||
local filtered = {}
|
||||
|
||||
for _, diag in ipairs(all_diagnostics) do
|
||||
-- Filter by severity
|
||||
if diag.severity <= config.min_severity then
|
||||
-- Filter by line range if specified
|
||||
if start_line and end_line then
|
||||
local diag_line = diag.lnum + 1 -- Convert to 1-indexed
|
||||
if diag_line >= start_line and diag_line <= end_line then
|
||||
table.insert(filtered, diag)
|
||||
end
|
||||
else
|
||||
table.insert(filtered, diag)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return filtered
|
||||
end
|
||||
|
||||
--- Format a diagnostic for display
|
||||
---@param diag table Diagnostic object
|
||||
---@return string
|
||||
local function format_diagnostic(diag)
|
||||
local severity_names = {
|
||||
[vim.diagnostic.severity.ERROR] = "ERROR",
|
||||
[vim.diagnostic.severity.WARN] = "WARN",
|
||||
[vim.diagnostic.severity.INFO] = "INFO",
|
||||
[vim.diagnostic.severity.HINT] = "HINT",
|
||||
}
|
||||
local severity = severity_names[diag.severity] or "UNKNOWN"
|
||||
local line = diag.lnum + 1
|
||||
local source = diag.source or "lsp"
|
||||
return string.format("[%s] Line %d (%s): %s", severity, line, source, diag.message)
|
||||
end
|
||||
|
||||
--- Check if there are errors in generated code region
|
||||
---@param bufnr number Buffer number
|
||||
---@param start_line number Start line (1-indexed)
|
||||
---@param end_line number End line (1-indexed)
|
||||
---@return table result {has_errors, has_warnings, diagnostics, summary}
|
||||
function M.check_region(bufnr, start_line, end_line)
|
||||
local diagnostics = M.get_diagnostics(bufnr, start_line, end_line)
|
||||
|
||||
local errors = 0
|
||||
local warnings = 0
|
||||
|
||||
for _, diag in ipairs(diagnostics) do
|
||||
if diag.severity == vim.diagnostic.severity.ERROR then
|
||||
errors = errors + 1
|
||||
elseif diag.severity == vim.diagnostic.severity.WARN then
|
||||
warnings = warnings + 1
|
||||
end
|
||||
end
|
||||
|
||||
return {
|
||||
has_errors = errors > 0,
|
||||
has_warnings = warnings > 0,
|
||||
error_count = errors,
|
||||
warning_count = warnings,
|
||||
diagnostics = diagnostics,
|
||||
summary = string.format("%d error(s), %d warning(s)", errors, warnings),
|
||||
}
|
||||
end
|
||||
|
||||
--- Validate code after injection and report issues
|
||||
---@param bufnr number Buffer number
|
||||
---@param start_line? number Start line of injected code (1-indexed)
|
||||
---@param end_line? number End line of injected code (1-indexed)
|
||||
---@param callback? function Callback with (result) when validation completes
|
||||
function M.validate_after_injection(bufnr, start_line, end_line, callback)
|
||||
-- Save the file first
|
||||
if config.auto_save then
|
||||
save_buffer(bufnr)
|
||||
end
|
||||
|
||||
-- Wait for LSP to process changes
|
||||
vim.defer_fn(function()
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
if callback then callback(nil) end
|
||||
return
|
||||
end
|
||||
|
||||
local result
|
||||
if start_line and end_line then
|
||||
result = M.check_region(bufnr, start_line, end_line)
|
||||
else
|
||||
-- Check entire buffer
|
||||
local line_count = vim.api.nvim_buf_line_count(bufnr)
|
||||
result = M.check_region(bufnr, 1, line_count)
|
||||
end
|
||||
|
||||
-- Store result for this buffer
|
||||
validation_results[bufnr] = {
|
||||
timestamp = os.time(),
|
||||
result = result,
|
||||
start_line = start_line,
|
||||
end_line = end_line,
|
||||
}
|
||||
|
||||
-- Log results
|
||||
pcall(function()
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
if result.has_errors then
|
||||
logs.add({
|
||||
type = "error",
|
||||
message = string.format("Linter found issues: %s", result.summary),
|
||||
})
|
||||
-- Log individual errors
|
||||
for _, diag in ipairs(result.diagnostics) do
|
||||
if diag.severity == vim.diagnostic.severity.ERROR then
|
||||
logs.add({
|
||||
type = "error",
|
||||
message = format_diagnostic(diag),
|
||||
})
|
||||
end
|
||||
end
|
||||
elseif result.has_warnings then
|
||||
logs.add({
|
||||
type = "warning",
|
||||
message = string.format("Linter warnings: %s", result.summary),
|
||||
})
|
||||
else
|
||||
logs.add({
|
||||
type = "success",
|
||||
message = "Linter check passed - no errors or warnings",
|
||||
})
|
||||
end
|
||||
end)
|
||||
|
||||
-- Notify user
|
||||
if result.has_errors then
|
||||
vim.notify(
|
||||
string.format("Generated code has lint errors: %s", result.summary),
|
||||
vim.log.levels.ERROR
|
||||
)
|
||||
|
||||
-- Offer to fix if configured
|
||||
if config.auto_offer_fix and #result.diagnostics > 0 then
|
||||
M.offer_fix(bufnr, result)
|
||||
end
|
||||
elseif result.has_warnings then
|
||||
vim.notify(
|
||||
string.format("Generated code has warnings: %s", result.summary),
|
||||
vim.log.levels.WARN
|
||||
)
|
||||
end
|
||||
|
||||
if callback then
|
||||
callback(result)
|
||||
end
|
||||
end, config.diagnostic_delay_ms)
|
||||
end
|
||||
|
||||
--- Offer to fix lint errors using AI
|
||||
---@param bufnr number Buffer number
|
||||
---@param result table Validation result
|
||||
function M.offer_fix(bufnr, result)
|
||||
if not result.has_errors and not result.has_warnings then
|
||||
return
|
||||
end
|
||||
|
||||
-- Build error summary for prompt
|
||||
local error_messages = {}
|
||||
for _, diag in ipairs(result.diagnostics) do
|
||||
table.insert(error_messages, format_diagnostic(diag))
|
||||
end
|
||||
|
||||
vim.ui.select(
|
||||
{ "Yes - Auto-fix with AI", "No - I'll fix manually", "Show errors in quickfix" },
|
||||
{
|
||||
prompt = string.format("Found %d issue(s). Would you like AI to fix them?", #result.diagnostics),
|
||||
},
|
||||
function(choice)
|
||||
if not choice then return end
|
||||
|
||||
if choice:match("^Yes") then
|
||||
M.request_ai_fix(bufnr, result)
|
||||
elseif choice:match("quickfix") then
|
||||
M.show_in_quickfix(bufnr, result)
|
||||
end
|
||||
end
|
||||
)
|
||||
end
|
||||
|
||||
--- Show lint errors in quickfix list
|
||||
---@param bufnr number Buffer number
|
||||
---@param result table Validation result
|
||||
function M.show_in_quickfix(bufnr, result)
|
||||
local qf_items = {}
|
||||
local bufname = vim.api.nvim_buf_get_name(bufnr)
|
||||
|
||||
for _, diag in ipairs(result.diagnostics) do
|
||||
table.insert(qf_items, {
|
||||
bufnr = bufnr,
|
||||
filename = bufname,
|
||||
lnum = diag.lnum + 1,
|
||||
col = diag.col + 1,
|
||||
text = diag.message,
|
||||
type = diag.severity == vim.diagnostic.severity.ERROR and "E" or "W",
|
||||
})
|
||||
end
|
||||
|
||||
vim.fn.setqflist(qf_items, "r")
|
||||
vim.cmd("copen")
|
||||
end
|
||||
|
||||
--- Request AI to fix lint errors
|
||||
---@param bufnr number Buffer number
|
||||
---@param result table Validation result
|
||||
function M.request_ai_fix(bufnr, result)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return
|
||||
end
|
||||
|
||||
local filepath = vim.api.nvim_buf_get_name(bufnr)
|
||||
|
||||
-- Build fix prompt
|
||||
local error_list = {}
|
||||
for _, diag in ipairs(result.diagnostics) do
|
||||
table.insert(error_list, format_diagnostic(diag))
|
||||
end
|
||||
|
||||
-- Get the affected code region
|
||||
local start_line = result.diagnostics[1] and (result.diagnostics[1].lnum + 1) or 1
|
||||
local end_line = start_line
|
||||
for _, diag in ipairs(result.diagnostics) do
|
||||
local line = diag.lnum + 1
|
||||
if line < start_line then start_line = line end
|
||||
if line > end_line then end_line = line end
|
||||
end
|
||||
|
||||
-- Expand range by a few lines for context
|
||||
start_line = math.max(1, start_line - 5)
|
||||
end_line = math.min(vim.api.nvim_buf_line_count(bufnr), end_line + 5)
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false)
|
||||
local code_context = table.concat(lines, "\n")
|
||||
|
||||
-- Create fix prompt using inline tag
|
||||
local fix_prompt = string.format(
|
||||
prompts.fix_request,
|
||||
table.concat(error_list, "\n"),
|
||||
start_line,
|
||||
end_line,
|
||||
code_context
|
||||
)
|
||||
|
||||
-- Queue the fix through the scheduler
|
||||
local scheduler = require("codetyper.core.scheduler.scheduler")
|
||||
local queue = require("codetyper.core.events.queue")
|
||||
local patch_mod = require("codetyper.core.diff.patch")
|
||||
|
||||
-- Ensure scheduler is running
|
||||
if not scheduler.status().running then
|
||||
scheduler.start()
|
||||
end
|
||||
|
||||
-- Take snapshot
|
||||
local snapshot = patch_mod.snapshot_buffer(bufnr, {
|
||||
start_line = start_line,
|
||||
end_line = end_line,
|
||||
})
|
||||
|
||||
-- Enqueue fix request
|
||||
queue.enqueue({
|
||||
id = queue.generate_id(),
|
||||
bufnr = bufnr,
|
||||
range = { start_line = start_line, end_line = end_line },
|
||||
timestamp = os.clock(),
|
||||
changedtick = snapshot.changedtick,
|
||||
content_hash = snapshot.content_hash,
|
||||
prompt_content = fix_prompt,
|
||||
target_path = filepath,
|
||||
priority = 1, -- High priority for fixes
|
||||
status = "pending",
|
||||
attempt_count = 0,
|
||||
intent = {
|
||||
type = "fix",
|
||||
action = "replace",
|
||||
confidence = 0.9,
|
||||
},
|
||||
scope_range = { start_line = start_line, end_line = end_line },
|
||||
source = "linter_fix",
|
||||
})
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = "Queued AI fix request for lint errors",
|
||||
})
|
||||
end)
|
||||
|
||||
vim.notify("Queued AI fix request for lint errors", vim.log.levels.INFO)
|
||||
end
|
||||
|
||||
--- Get last validation result for a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@return table|nil result
|
||||
function M.get_last_result(bufnr)
|
||||
return validation_results[bufnr]
|
||||
end
|
||||
|
||||
--- Clear validation results for a buffer
|
||||
---@param bufnr number Buffer number
|
||||
function M.clear_result(bufnr)
|
||||
validation_results[bufnr] = nil
|
||||
end
|
||||
|
||||
--- Check if buffer has any lint errors currently
|
||||
---@param bufnr number Buffer number
|
||||
---@return boolean has_errors
|
||||
function M.has_errors(bufnr)
|
||||
local diagnostics = vim.diagnostic.get(bufnr, {
|
||||
severity = vim.diagnostic.severity.ERROR,
|
||||
})
|
||||
return #diagnostics > 0
|
||||
end
|
||||
|
||||
--- Check if buffer has any lint warnings currently
|
||||
---@param bufnr number Buffer number
|
||||
---@return boolean has_warnings
|
||||
function M.has_warnings(bufnr)
|
||||
local diagnostics = vim.diagnostic.get(bufnr, {
|
||||
severity = { min = vim.diagnostic.severity.WARN },
|
||||
})
|
||||
return #diagnostics > 0
|
||||
end
|
||||
|
||||
--- Validate all buffers with recent changes
|
||||
function M.validate_all_changed()
|
||||
for bufnr, data in pairs(validation_results) do
|
||||
if vim.api.nvim_buf_is_valid(bufnr) then
|
||||
M.validate_after_injection(bufnr, data.start_line, data.end_line)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return M
|
||||
182
lua/codetyper/features/agents/permissions.lua
Normal file
182
lua/codetyper/features/agents/permissions.lua
Normal file
@@ -0,0 +1,182 @@
|
||||
---@mod codetyper.agent.permissions Permission manager for agent actions
|
||||
---
|
||||
--- Manages permissions for bash commands and file operations with
|
||||
--- allow, allow-session, allow-list, and reject options.
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class PermissionState
|
||||
---@field session_allowed table<string, boolean> Commands allowed for this session
|
||||
---@field allow_list table<string, boolean> Patterns always allowed
|
||||
---@field deny_list table<string, boolean> Patterns always denied
|
||||
|
||||
local params = require("codetyper.params.agent.permissions")
|
||||
|
||||
local state = {
|
||||
session_allowed = {},
|
||||
allow_list = {},
|
||||
deny_list = {},
|
||||
}
|
||||
|
||||
--- Dangerous command patterns that should never be auto-allowed
|
||||
local DANGEROUS_PATTERNS = params.dangerous_patterns
|
||||
|
||||
--- Safe command patterns that can be auto-allowed
|
||||
local SAFE_PATTERNS = params.safe_patterns
|
||||
|
||||
---@alias PermissionLevel "allow"|"allow_session"|"allow_list"|"reject"
|
||||
|
||||
---@class PermissionResult
|
||||
---@field allowed boolean Whether action is allowed
|
||||
---@field reason string Reason for the decision
|
||||
---@field auto boolean Whether this was an automatic decision
|
||||
|
||||
--- Check if a command matches a pattern
|
||||
---@param command string The command to check
|
||||
---@param pattern string The pattern to match
|
||||
---@return boolean
|
||||
local function matches_pattern(command, pattern)
|
||||
return command:match(pattern) ~= nil
|
||||
end
|
||||
|
||||
--- Check if command is dangerous
|
||||
---@param command string The command to check
|
||||
---@return boolean, string|nil dangerous, reason
|
||||
local function is_dangerous(command)
|
||||
for _, pattern in ipairs(DANGEROUS_PATTERNS) do
|
||||
if matches_pattern(command, pattern) then
|
||||
return true, "Matches dangerous pattern: " .. pattern
|
||||
end
|
||||
end
|
||||
return false, nil
|
||||
end
|
||||
|
||||
--- Check if command is safe
|
||||
---@param command string The command to check
|
||||
---@return boolean
|
||||
local function is_safe(command)
|
||||
for _, pattern in ipairs(SAFE_PATTERNS) do
|
||||
if matches_pattern(command, pattern) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Normalize command for comparison (trim, lowercase first word)
|
||||
---@param command string
|
||||
---@return string
|
||||
local function normalize_command(command)
|
||||
return vim.trim(command)
|
||||
end
|
||||
|
||||
--- Check permission for a bash command
|
||||
---@param command string The command to check
|
||||
---@return PermissionResult
|
||||
function M.check_bash_permission(command)
|
||||
local normalized = normalize_command(command)
|
||||
|
||||
-- Check deny list first
|
||||
for pattern, _ in pairs(state.deny_list) do
|
||||
if matches_pattern(normalized, pattern) then
|
||||
return {
|
||||
allowed = false,
|
||||
reason = "Command in deny list",
|
||||
auto = true,
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
-- Check if command is dangerous
|
||||
local dangerous, reason = is_dangerous(normalized)
|
||||
if dangerous then
|
||||
return {
|
||||
allowed = false,
|
||||
reason = reason,
|
||||
auto = false, -- Require explicit approval for dangerous commands
|
||||
}
|
||||
end
|
||||
|
||||
-- Check session allowed
|
||||
if state.session_allowed[normalized] then
|
||||
return {
|
||||
allowed = true,
|
||||
reason = "Allowed for this session",
|
||||
auto = true,
|
||||
}
|
||||
end
|
||||
|
||||
-- Check allow list patterns
|
||||
for pattern, _ in pairs(state.allow_list) do
|
||||
if matches_pattern(normalized, pattern) then
|
||||
return {
|
||||
allowed = true,
|
||||
reason = "Matches allow list pattern",
|
||||
auto = true,
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
-- Check if command is inherently safe
|
||||
if is_safe(normalized) then
|
||||
return {
|
||||
allowed = true,
|
||||
reason = "Safe read-only command",
|
||||
auto = true,
|
||||
}
|
||||
end
|
||||
|
||||
-- Otherwise, require explicit permission
|
||||
return {
|
||||
allowed = false,
|
||||
reason = "Requires approval",
|
||||
auto = false,
|
||||
}
|
||||
end
|
||||
|
||||
--- Grant permission for a command
|
||||
---@param command string The command
|
||||
---@param level PermissionLevel The permission level
|
||||
function M.grant_permission(command, level)
|
||||
local normalized = normalize_command(command)
|
||||
|
||||
if level == "allow_session" then
|
||||
state.session_allowed[normalized] = true
|
||||
elseif level == "allow_list" then
|
||||
-- Add as pattern (escape special chars for exact match)
|
||||
local pattern = "^" .. vim.pesc(normalized) .. "$"
|
||||
state.allow_list[pattern] = true
|
||||
end
|
||||
end
|
||||
|
||||
--- Add a pattern to the allow list
|
||||
---@param pattern string Lua pattern to allow
|
||||
function M.add_to_allow_list(pattern)
|
||||
state.allow_list[pattern] = true
|
||||
end
|
||||
|
||||
--- Add a pattern to the deny list
|
||||
---@param pattern string Lua pattern to deny
|
||||
function M.add_to_deny_list(pattern)
|
||||
state.deny_list[pattern] = true
|
||||
end
|
||||
|
||||
--- Clear session permissions
|
||||
function M.clear_session()
|
||||
state.session_allowed = {}
|
||||
end
|
||||
|
||||
--- Reset all permissions
|
||||
function M.reset()
|
||||
state.session_allowed = {}
|
||||
state.allow_list = {}
|
||||
state.deny_list = {}
|
||||
end
|
||||
|
||||
--- Get current permission state (for debugging)
|
||||
---@return PermissionState
|
||||
function M.get_state()
|
||||
return vim.deepcopy(state)
|
||||
end
|
||||
|
||||
return M
|
||||
1282
lua/codetyper/features/ask/engine.lua
Normal file
1282
lua/codetyper/features/ask/engine.lua
Normal file
File diff suppressed because it is too large
Load Diff
676
lua/codetyper/features/ask/explorer.lua
Normal file
676
lua/codetyper/features/ask/explorer.lua
Normal file
@@ -0,0 +1,676 @@
|
||||
---@mod codetyper.ask.explorer Project exploration for Ask mode
|
||||
---@brief [[
|
||||
--- Performs comprehensive project exploration when explaining a project.
|
||||
--- Shows progress, indexes files, and builds brain context.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
---@class ExplorationState
|
||||
---@field is_exploring boolean
|
||||
---@field files_scanned number
|
||||
---@field total_files number
|
||||
---@field current_file string|nil
|
||||
---@field findings table
|
||||
---@field on_log fun(msg: string, level: string)|nil
|
||||
|
||||
local state = {
|
||||
is_exploring = false,
|
||||
files_scanned = 0,
|
||||
total_files = 0,
|
||||
current_file = nil,
|
||||
findings = {},
|
||||
on_log = nil,
|
||||
}
|
||||
|
||||
--- File extensions to analyze
|
||||
local ANALYZABLE_EXTENSIONS = {
|
||||
lua = true,
|
||||
ts = true,
|
||||
tsx = true,
|
||||
js = true,
|
||||
jsx = true,
|
||||
py = true,
|
||||
go = true,
|
||||
rs = true,
|
||||
rb = true,
|
||||
java = true,
|
||||
c = true,
|
||||
cpp = true,
|
||||
h = true,
|
||||
hpp = true,
|
||||
json = true,
|
||||
yaml = true,
|
||||
yml = true,
|
||||
toml = true,
|
||||
md = true,
|
||||
xml = true,
|
||||
}
|
||||
|
||||
--- Directories to skip
|
||||
local SKIP_DIRS = {
|
||||
-- Version control
|
||||
[".git"] = true,
|
||||
[".svn"] = true,
|
||||
[".hg"] = true,
|
||||
|
||||
-- IDE/Editor
|
||||
[".idea"] = true,
|
||||
[".vscode"] = true,
|
||||
[".cursor"] = true,
|
||||
[".cursorignore"] = true,
|
||||
[".claude"] = true,
|
||||
[".zed"] = true,
|
||||
|
||||
-- Project tooling
|
||||
[".coder"] = true,
|
||||
[".github"] = true,
|
||||
[".gitlab"] = true,
|
||||
[".husky"] = true,
|
||||
|
||||
-- Build outputs
|
||||
dist = true,
|
||||
build = true,
|
||||
out = true,
|
||||
target = true,
|
||||
bin = true,
|
||||
obj = true,
|
||||
[".build"] = true,
|
||||
[".output"] = true,
|
||||
|
||||
-- Dependencies
|
||||
node_modules = true,
|
||||
vendor = true,
|
||||
[".vendor"] = true,
|
||||
packages = true,
|
||||
bower_components = true,
|
||||
jspm_packages = true,
|
||||
|
||||
-- Cache/temp
|
||||
[".cache"] = true,
|
||||
[".tmp"] = true,
|
||||
[".temp"] = true,
|
||||
__pycache__ = true,
|
||||
[".pytest_cache"] = true,
|
||||
[".mypy_cache"] = true,
|
||||
[".ruff_cache"] = true,
|
||||
[".tox"] = true,
|
||||
[".nox"] = true,
|
||||
[".eggs"] = true,
|
||||
["*.egg-info"] = true,
|
||||
|
||||
-- Framework specific
|
||||
[".next"] = true,
|
||||
[".nuxt"] = true,
|
||||
[".svelte-kit"] = true,
|
||||
[".vercel"] = true,
|
||||
[".netlify"] = true,
|
||||
[".serverless"] = true,
|
||||
[".turbo"] = true,
|
||||
|
||||
-- Testing/coverage
|
||||
coverage = true,
|
||||
[".nyc_output"] = true,
|
||||
htmlcov = true,
|
||||
|
||||
-- Logs
|
||||
logs = true,
|
||||
log = true,
|
||||
|
||||
-- OS files
|
||||
[".DS_Store"] = true,
|
||||
Thumbs_db = true,
|
||||
}
|
||||
|
||||
--- Files to skip (patterns)
|
||||
local SKIP_FILES = {
|
||||
-- Lock files
|
||||
"package%-lock%.json",
|
||||
"yarn%.lock",
|
||||
"pnpm%-lock%.yaml",
|
||||
"Gemfile%.lock",
|
||||
"Cargo%.lock",
|
||||
"poetry%.lock",
|
||||
"Pipfile%.lock",
|
||||
"composer%.lock",
|
||||
"go%.sum",
|
||||
"flake%.lock",
|
||||
"%.lock$",
|
||||
"%-lock%.json$",
|
||||
"%-lock%.yaml$",
|
||||
|
||||
-- Generated files
|
||||
"%.min%.js$",
|
||||
"%.min%.css$",
|
||||
"%.bundle%.js$",
|
||||
"%.chunk%.js$",
|
||||
"%.map$",
|
||||
"%.d%.ts$",
|
||||
|
||||
-- Binary/media (shouldn't match anyway but be safe)
|
||||
"%.png$",
|
||||
"%.jpg$",
|
||||
"%.jpeg$",
|
||||
"%.gif$",
|
||||
"%.ico$",
|
||||
"%.svg$",
|
||||
"%.woff",
|
||||
"%.ttf$",
|
||||
"%.eot$",
|
||||
"%.pdf$",
|
||||
"%.zip$",
|
||||
"%.tar",
|
||||
"%.gz$",
|
||||
|
||||
-- Config that's not useful
|
||||
"%.env",
|
||||
"%.env%.",
|
||||
}
|
||||
|
||||
--- Log a message during exploration
|
||||
---@param msg string
|
||||
---@param level? string "info"|"debug"|"file"|"progress"
|
||||
local function log(msg, level)
|
||||
level = level or "info"
|
||||
if state.on_log then
|
||||
state.on_log(msg, level)
|
||||
end
|
||||
end
|
||||
|
||||
--- Check if file should be skipped
|
||||
---@param filename string
|
||||
---@return boolean
|
||||
local function should_skip_file(filename)
|
||||
for _, pattern in ipairs(SKIP_FILES) do
|
||||
if filename:match(pattern) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Check if directory should be skipped
|
||||
---@param dirname string
|
||||
---@return boolean
|
||||
local function should_skip_dir(dirname)
|
||||
-- Direct match
|
||||
if SKIP_DIRS[dirname] then
|
||||
return true
|
||||
end
|
||||
-- Pattern match for .cursor* etc
|
||||
if dirname:match("^%.cursor") then
|
||||
return true
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Get all files in project
|
||||
---@param root string Project root
|
||||
---@return string[] files
|
||||
local function get_project_files(root)
|
||||
local files = {}
|
||||
|
||||
local function scan_dir(dir)
|
||||
local handle = vim.loop.fs_scandir(dir)
|
||||
if not handle then
|
||||
return
|
||||
end
|
||||
|
||||
while true do
|
||||
local name, type = vim.loop.fs_scandir_next(handle)
|
||||
if not name then
|
||||
break
|
||||
end
|
||||
|
||||
local full_path = dir .. "/" .. name
|
||||
|
||||
if type == "directory" then
|
||||
if not should_skip_dir(name) then
|
||||
scan_dir(full_path)
|
||||
end
|
||||
elseif type == "file" then
|
||||
if not should_skip_file(name) then
|
||||
local ext = name:match("%.([^%.]+)$")
|
||||
if ext and ANALYZABLE_EXTENSIONS[ext:lower()] then
|
||||
table.insert(files, full_path)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
scan_dir(root)
|
||||
return files
|
||||
end
|
||||
|
||||
--- Analyze a single file
|
||||
---@param filepath string
|
||||
---@return table|nil analysis
|
||||
local function analyze_file(filepath)
|
||||
local content = utils.read_file(filepath)
|
||||
if not content or content == "" then
|
||||
return nil
|
||||
end
|
||||
|
||||
local ext = filepath:match("%.([^%.]+)$") or ""
|
||||
local lines = vim.split(content, "\n")
|
||||
|
||||
local analysis = {
|
||||
path = filepath,
|
||||
extension = ext,
|
||||
lines = #lines,
|
||||
size = #content,
|
||||
imports = {},
|
||||
exports = {},
|
||||
functions = {},
|
||||
classes = {},
|
||||
summary = "",
|
||||
}
|
||||
|
||||
-- Extract key patterns based on file type
|
||||
for i, line in ipairs(lines) do
|
||||
-- Imports/requires
|
||||
local import = line:match('import%s+.*%s+from%s+["\']([^"\']+)["\']')
|
||||
or line:match('require%(["\']([^"\']+)["\']%)')
|
||||
or line:match("from%s+([%w_.]+)%s+import")
|
||||
if import then
|
||||
table.insert(analysis.imports, { source = import, line = i })
|
||||
end
|
||||
|
||||
-- Function definitions
|
||||
local func = line:match("^%s*function%s+([%w_:%.]+)%s*%(")
|
||||
or line:match("^%s*local%s+function%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*def%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*func%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*async%s+function%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*public%s+.*%s+([%w_]+)%s*%(")
|
||||
if func then
|
||||
table.insert(analysis.functions, { name = func, line = i })
|
||||
end
|
||||
|
||||
-- Class definitions
|
||||
local class = line:match("^%s*class%s+([%w_]+)")
|
||||
or line:match("^%s*public%s+class%s+([%w_]+)")
|
||||
or line:match("^%s*interface%s+([%w_]+)")
|
||||
if class then
|
||||
table.insert(analysis.classes, { name = class, line = i })
|
||||
end
|
||||
|
||||
-- Exports
|
||||
local exp = line:match("^%s*export%s+.*%s+([%w_]+)")
|
||||
or line:match("^%s*module%.exports%s*=")
|
||||
or line:match("^return%s+M")
|
||||
if exp then
|
||||
table.insert(analysis.exports, { name = exp, line = i })
|
||||
end
|
||||
end
|
||||
|
||||
-- Create summary
|
||||
local parts = {}
|
||||
if #analysis.functions > 0 then
|
||||
table.insert(parts, #analysis.functions .. " functions")
|
||||
end
|
||||
if #analysis.classes > 0 then
|
||||
table.insert(parts, #analysis.classes .. " classes")
|
||||
end
|
||||
if #analysis.imports > 0 then
|
||||
table.insert(parts, #analysis.imports .. " imports")
|
||||
end
|
||||
analysis.summary = table.concat(parts, ", ")
|
||||
|
||||
return analysis
|
||||
end
|
||||
|
||||
--- Detect project type from files
|
||||
---@param root string
|
||||
---@return string type, table info
|
||||
local function detect_project_type(root)
|
||||
local info = {
|
||||
name = vim.fn.fnamemodify(root, ":t"),
|
||||
type = "unknown",
|
||||
framework = nil,
|
||||
language = nil,
|
||||
}
|
||||
|
||||
-- Check for common project files
|
||||
if utils.file_exists(root .. "/package.json") then
|
||||
info.type = "node"
|
||||
info.language = "JavaScript/TypeScript"
|
||||
local content = utils.read_file(root .. "/package.json")
|
||||
if content then
|
||||
local ok, pkg = pcall(vim.json.decode, content)
|
||||
if ok then
|
||||
info.name = pkg.name or info.name
|
||||
if pkg.dependencies then
|
||||
if pkg.dependencies.react then
|
||||
info.framework = "React"
|
||||
elseif pkg.dependencies.vue then
|
||||
info.framework = "Vue"
|
||||
elseif pkg.dependencies.next then
|
||||
info.framework = "Next.js"
|
||||
elseif pkg.dependencies.express then
|
||||
info.framework = "Express"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
elseif utils.file_exists(root .. "/pom.xml") then
|
||||
info.type = "maven"
|
||||
info.language = "Java"
|
||||
local content = utils.read_file(root .. "/pom.xml")
|
||||
if content and content:match("spring%-boot") then
|
||||
info.framework = "Spring Boot"
|
||||
end
|
||||
elseif utils.file_exists(root .. "/Cargo.toml") then
|
||||
info.type = "rust"
|
||||
info.language = "Rust"
|
||||
elseif utils.file_exists(root .. "/go.mod") then
|
||||
info.type = "go"
|
||||
info.language = "Go"
|
||||
elseif utils.file_exists(root .. "/requirements.txt") or utils.file_exists(root .. "/pyproject.toml") then
|
||||
info.type = "python"
|
||||
info.language = "Python"
|
||||
elseif utils.file_exists(root .. "/init.lua") or utils.file_exists(root .. "/plugin/") then
|
||||
info.type = "neovim-plugin"
|
||||
info.language = "Lua"
|
||||
end
|
||||
|
||||
return info.type, info
|
||||
end
|
||||
|
||||
--- Build project structure summary
|
||||
---@param files string[]
|
||||
---@param root string
|
||||
---@return table structure
|
||||
local function build_structure(files, root)
|
||||
local structure = {
|
||||
directories = {},
|
||||
by_extension = {},
|
||||
total_files = #files,
|
||||
}
|
||||
|
||||
for _, file in ipairs(files) do
|
||||
local relative = file:gsub("^" .. vim.pesc(root) .. "/", "")
|
||||
local dir = vim.fn.fnamemodify(relative, ":h")
|
||||
local ext = file:match("%.([^%.]+)$") or "unknown"
|
||||
|
||||
structure.directories[dir] = (structure.directories[dir] or 0) + 1
|
||||
structure.by_extension[ext] = (structure.by_extension[ext] or 0) + 1
|
||||
end
|
||||
|
||||
return structure
|
||||
end
|
||||
|
||||
--- Explore project and build context
|
||||
---@param root string Project root
|
||||
---@param on_log fun(msg: string, level: string) Log callback
|
||||
---@param on_complete fun(result: table) Completion callback
|
||||
function M.explore(root, on_log, on_complete)
|
||||
if state.is_exploring then
|
||||
on_log("⚠️ Already exploring...", "warning")
|
||||
return
|
||||
end
|
||||
|
||||
state.is_exploring = true
|
||||
state.on_log = on_log
|
||||
state.findings = {}
|
||||
|
||||
-- Start exploration
|
||||
log("⏺ Exploring project structure...", "info")
|
||||
log("", "info")
|
||||
|
||||
-- Detect project type
|
||||
log(" Detect(Project type)", "progress")
|
||||
local project_type, project_info = detect_project_type(root)
|
||||
log(" ⎿ " .. project_info.language .. " (" .. (project_info.framework or project_type) .. ")", "debug")
|
||||
|
||||
state.findings.project = project_info
|
||||
|
||||
-- Get all files
|
||||
log("", "info")
|
||||
log(" Scan(Project files)", "progress")
|
||||
local files = get_project_files(root)
|
||||
state.total_files = #files
|
||||
log(" ⎿ Found " .. #files .. " analyzable files", "debug")
|
||||
|
||||
-- Build structure
|
||||
local structure = build_structure(files, root)
|
||||
state.findings.structure = structure
|
||||
|
||||
-- Show directory breakdown
|
||||
log("", "info")
|
||||
log(" Structure(Directories)", "progress")
|
||||
local sorted_dirs = {}
|
||||
for dir, count in pairs(structure.directories) do
|
||||
table.insert(sorted_dirs, { dir = dir, count = count })
|
||||
end
|
||||
table.sort(sorted_dirs, function(a, b)
|
||||
return a.count > b.count
|
||||
end)
|
||||
for i, entry in ipairs(sorted_dirs) do
|
||||
if i <= 5 then
|
||||
log(" ⎿ " .. entry.dir .. " (" .. entry.count .. " files)", "debug")
|
||||
end
|
||||
end
|
||||
if #sorted_dirs > 5 then
|
||||
log(" ⎿ +" .. (#sorted_dirs - 5) .. " more directories", "debug")
|
||||
end
|
||||
|
||||
-- Analyze files asynchronously
|
||||
log("", "info")
|
||||
log(" Analyze(Source files)", "progress")
|
||||
|
||||
state.files_scanned = 0
|
||||
local analyses = {}
|
||||
local key_files = {}
|
||||
|
||||
-- Process files in batches to avoid blocking
|
||||
local batch_size = 10
|
||||
local current_batch = 0
|
||||
|
||||
local function process_batch()
|
||||
local start_idx = current_batch * batch_size + 1
|
||||
local end_idx = math.min(start_idx + batch_size - 1, #files)
|
||||
|
||||
for i = start_idx, end_idx do
|
||||
local file = files[i]
|
||||
local relative = file:gsub("^" .. vim.pesc(root) .. "/", "")
|
||||
|
||||
state.files_scanned = state.files_scanned + 1
|
||||
state.current_file = relative
|
||||
|
||||
local analysis = analyze_file(file)
|
||||
if analysis then
|
||||
analysis.relative_path = relative
|
||||
table.insert(analyses, analysis)
|
||||
|
||||
-- Track key files (many functions/classes)
|
||||
if #analysis.functions >= 3 or #analysis.classes >= 1 then
|
||||
table.insert(key_files, {
|
||||
path = relative,
|
||||
functions = #analysis.functions,
|
||||
classes = #analysis.classes,
|
||||
summary = analysis.summary,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Log some files
|
||||
if i <= 3 or (i % 20 == 0) then
|
||||
log(" ⎿ " .. relative .. ": " .. (analysis and analysis.summary or "(empty)"), "file")
|
||||
end
|
||||
end
|
||||
|
||||
-- Progress update
|
||||
local progress = math.floor((state.files_scanned / state.total_files) * 100)
|
||||
if progress % 25 == 0 and progress > 0 then
|
||||
log(" ⎿ " .. progress .. "% complete (" .. state.files_scanned .. "/" .. state.total_files .. ")", "debug")
|
||||
end
|
||||
|
||||
current_batch = current_batch + 1
|
||||
|
||||
if end_idx < #files then
|
||||
-- Schedule next batch
|
||||
vim.defer_fn(process_batch, 10)
|
||||
else
|
||||
-- Complete
|
||||
finish_exploration(root, analyses, key_files, on_complete)
|
||||
end
|
||||
end
|
||||
|
||||
-- Start processing
|
||||
vim.defer_fn(process_batch, 10)
|
||||
end
|
||||
|
||||
--- Finish exploration and store results
|
||||
---@param root string
|
||||
---@param analyses table
|
||||
---@param key_files table
|
||||
---@param on_complete fun(result: table)
|
||||
function finish_exploration(root, analyses, key_files, on_complete)
|
||||
log(" ⎿ +" .. (#analyses - 3) .. " more files analyzed", "debug")
|
||||
|
||||
-- Show key files
|
||||
if #key_files > 0 then
|
||||
log("", "info")
|
||||
log(" KeyFiles(Important components)", "progress")
|
||||
table.sort(key_files, function(a, b)
|
||||
return (a.functions + a.classes * 2) > (b.functions + b.classes * 2)
|
||||
end)
|
||||
for i, kf in ipairs(key_files) do
|
||||
if i <= 5 then
|
||||
log(" ⎿ " .. kf.path .. ": " .. kf.summary, "file")
|
||||
end
|
||||
end
|
||||
if #key_files > 5 then
|
||||
log(" ⎿ +" .. (#key_files - 5) .. " more key files", "debug")
|
||||
end
|
||||
end
|
||||
|
||||
state.findings.analyses = analyses
|
||||
state.findings.key_files = key_files
|
||||
|
||||
-- Store in brain if available
|
||||
local ok_brain, brain = pcall(require, "codetyper.brain")
|
||||
if ok_brain and brain.is_initialized() then
|
||||
log("", "info")
|
||||
log(" Store(Brain context)", "progress")
|
||||
|
||||
-- Store project pattern
|
||||
brain.learn({
|
||||
type = "pattern",
|
||||
file = root,
|
||||
content = {
|
||||
summary = "Project: " .. state.findings.project.name,
|
||||
detail = state.findings.project.language
|
||||
.. " "
|
||||
.. (state.findings.project.framework or state.findings.project.type),
|
||||
code = nil,
|
||||
},
|
||||
context = {
|
||||
file = root,
|
||||
language = state.findings.project.language,
|
||||
},
|
||||
})
|
||||
|
||||
-- Store key file patterns
|
||||
for i, kf in ipairs(key_files) do
|
||||
if i <= 10 then
|
||||
brain.learn({
|
||||
type = "pattern",
|
||||
file = root .. "/" .. kf.path,
|
||||
content = {
|
||||
summary = kf.path .. " - " .. kf.summary,
|
||||
detail = kf.summary,
|
||||
},
|
||||
context = {
|
||||
file = kf.path,
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
log(" ⎿ Stored " .. math.min(#key_files, 10) + 1 .. " patterns in brain", "debug")
|
||||
end
|
||||
|
||||
-- Store in indexer if available
|
||||
local ok_indexer, indexer = pcall(require, "codetyper.indexer")
|
||||
if ok_indexer then
|
||||
log(" Index(Project index)", "progress")
|
||||
indexer.index_project(function(index)
|
||||
log(" ⎿ Indexed " .. (index.stats.files or 0) .. " files", "debug")
|
||||
end)
|
||||
end
|
||||
|
||||
log("", "info")
|
||||
log("✓ Exploration complete!", "info")
|
||||
log("", "info")
|
||||
|
||||
-- Build result
|
||||
local result = {
|
||||
project = state.findings.project,
|
||||
structure = state.findings.structure,
|
||||
key_files = key_files,
|
||||
total_files = state.total_files,
|
||||
analyses = analyses,
|
||||
}
|
||||
|
||||
state.is_exploring = false
|
||||
state.on_log = nil
|
||||
|
||||
on_complete(result)
|
||||
end
|
||||
|
||||
--- Check if exploration is in progress
|
||||
---@return boolean
|
||||
function M.is_exploring()
|
||||
return state.is_exploring
|
||||
end
|
||||
|
||||
--- Get exploration progress
|
||||
---@return number scanned, number total
|
||||
function M.get_progress()
|
||||
return state.files_scanned, state.total_files
|
||||
end
|
||||
|
||||
--- Build context string from exploration result
|
||||
---@param result table Exploration result
|
||||
---@return string context
|
||||
function M.build_context(result)
|
||||
local parts = {}
|
||||
|
||||
-- Project info
|
||||
table.insert(parts, "## Project: " .. result.project.name)
|
||||
table.insert(parts, "- Type: " .. result.project.type)
|
||||
table.insert(parts, "- Language: " .. (result.project.language or "Unknown"))
|
||||
if result.project.framework then
|
||||
table.insert(parts, "- Framework: " .. result.project.framework)
|
||||
end
|
||||
table.insert(parts, "- Files: " .. result.total_files)
|
||||
table.insert(parts, "")
|
||||
|
||||
-- Structure
|
||||
table.insert(parts, "## Structure")
|
||||
if result.structure and result.structure.by_extension then
|
||||
for ext, count in pairs(result.structure.by_extension) do
|
||||
table.insert(parts, "- ." .. ext .. ": " .. count .. " files")
|
||||
end
|
||||
end
|
||||
table.insert(parts, "")
|
||||
|
||||
-- Key components
|
||||
if result.key_files and #result.key_files > 0 then
|
||||
table.insert(parts, "## Key Components")
|
||||
for i, kf in ipairs(result.key_files) do
|
||||
if i <= 10 then
|
||||
table.insert(parts, "- " .. kf.path .. ": " .. kf.summary)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return table.concat(parts, "\n")
|
||||
end
|
||||
|
||||
return M
|
||||
302
lua/codetyper/features/ask/intent.lua
Normal file
302
lua/codetyper/features/ask/intent.lua
Normal file
@@ -0,0 +1,302 @@
|
||||
---@mod codetyper.ask.intent Intent detection for Ask mode
|
||||
---@brief [[
|
||||
--- Analyzes user prompts to detect intent (ask/explain vs code generation).
|
||||
--- Routes to appropriate prompt type and context sources.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@alias IntentType "ask"|"explain"|"generate"|"refactor"|"document"|"test"
|
||||
|
||||
---@class Intent
|
||||
---@field type IntentType Detected intent type
|
||||
---@field confidence number 0-1 confidence score
|
||||
---@field needs_project_context boolean Whether project-wide context is needed
|
||||
---@field needs_brain_context boolean Whether brain/learned context is helpful
|
||||
---@field needs_exploration boolean Whether full project exploration is needed
|
||||
---@field keywords string[] Keywords that influenced detection
|
||||
|
||||
--- Patterns for detecting ask/explain intent (questions about code)
|
||||
local ASK_PATTERNS = {
|
||||
-- Question words
|
||||
{ pattern = "^what%s", weight = 0.9 },
|
||||
{ pattern = "^why%s", weight = 0.95 },
|
||||
{ pattern = "^how%s+does", weight = 0.9 },
|
||||
{ pattern = "^how%s+do%s+i", weight = 0.7 }, -- Could be asking for code
|
||||
{ pattern = "^where%s", weight = 0.85 },
|
||||
{ pattern = "^when%s", weight = 0.85 },
|
||||
{ pattern = "^which%s", weight = 0.8 },
|
||||
{ pattern = "^who%s", weight = 0.85 },
|
||||
{ pattern = "^can%s+you%s+explain", weight = 0.95 },
|
||||
{ pattern = "^could%s+you%s+explain", weight = 0.95 },
|
||||
{ pattern = "^please%s+explain", weight = 0.95 },
|
||||
|
||||
-- Explanation requests
|
||||
{ pattern = "explain%s", weight = 0.9 },
|
||||
{ pattern = "describe%s", weight = 0.85 },
|
||||
{ pattern = "tell%s+me%s+about", weight = 0.85 },
|
||||
{ pattern = "walk%s+me%s+through", weight = 0.9 },
|
||||
{ pattern = "help%s+me%s+understand", weight = 0.95 },
|
||||
{ pattern = "what%s+is%s+the%s+purpose", weight = 0.95 },
|
||||
{ pattern = "what%s+does%s+this", weight = 0.9 },
|
||||
{ pattern = "what%s+does%s+it", weight = 0.9 },
|
||||
{ pattern = "how%s+does%s+this%s+work", weight = 0.95 },
|
||||
{ pattern = "how%s+does%s+it%s+work", weight = 0.95 },
|
||||
|
||||
-- Understanding queries
|
||||
{ pattern = "understand", weight = 0.7 },
|
||||
{ pattern = "meaning%s+of", weight = 0.85 },
|
||||
{ pattern = "difference%s+between", weight = 0.9 },
|
||||
{ pattern = "compared%s+to", weight = 0.8 },
|
||||
{ pattern = "vs%s", weight = 0.7 },
|
||||
{ pattern = "versus", weight = 0.7 },
|
||||
{ pattern = "pros%s+and%s+cons", weight = 0.9 },
|
||||
{ pattern = "advantages", weight = 0.8 },
|
||||
{ pattern = "disadvantages", weight = 0.8 },
|
||||
{ pattern = "trade%-?offs?", weight = 0.85 },
|
||||
|
||||
-- Analysis requests
|
||||
{ pattern = "analyze", weight = 0.85 },
|
||||
{ pattern = "review", weight = 0.7 }, -- Could also be refactor
|
||||
{ pattern = "overview", weight = 0.9 },
|
||||
{ pattern = "summary", weight = 0.9 },
|
||||
{ pattern = "summarize", weight = 0.9 },
|
||||
|
||||
-- Question marks (weaker signal)
|
||||
{ pattern = "%?$", weight = 0.3 },
|
||||
{ pattern = "%?%s*$", weight = 0.3 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting code generation intent
|
||||
local GENERATE_PATTERNS = {
|
||||
-- Direct commands
|
||||
{ pattern = "^create%s", weight = 0.9 },
|
||||
{ pattern = "^make%s", weight = 0.85 },
|
||||
{ pattern = "^build%s", weight = 0.85 },
|
||||
{ pattern = "^write%s", weight = 0.9 },
|
||||
{ pattern = "^add%s", weight = 0.85 },
|
||||
{ pattern = "^implement%s", weight = 0.95 },
|
||||
{ pattern = "^generate%s", weight = 0.95 },
|
||||
{ pattern = "^code%s", weight = 0.8 },
|
||||
|
||||
-- Modification commands
|
||||
{ pattern = "^fix%s", weight = 0.9 },
|
||||
{ pattern = "^change%s", weight = 0.8 },
|
||||
{ pattern = "^update%s", weight = 0.75 },
|
||||
{ pattern = "^modify%s", weight = 0.8 },
|
||||
{ pattern = "^replace%s", weight = 0.85 },
|
||||
{ pattern = "^remove%s", weight = 0.85 },
|
||||
{ pattern = "^delete%s", weight = 0.85 },
|
||||
|
||||
-- Feature requests
|
||||
{ pattern = "i%s+need%s+a", weight = 0.8 },
|
||||
{ pattern = "i%s+want%s+a", weight = 0.8 },
|
||||
{ pattern = "give%s+me", weight = 0.7 },
|
||||
{ pattern = "show%s+me%s+how%s+to%s+code", weight = 0.9 },
|
||||
{ pattern = "how%s+do%s+i%s+implement", weight = 0.85 },
|
||||
{ pattern = "can%s+you%s+write", weight = 0.9 },
|
||||
{ pattern = "can%s+you%s+create", weight = 0.9 },
|
||||
{ pattern = "can%s+you%s+add", weight = 0.85 },
|
||||
{ pattern = "can%s+you%s+make", weight = 0.85 },
|
||||
|
||||
-- Code-specific terms
|
||||
{ pattern = "function%s+that", weight = 0.85 },
|
||||
{ pattern = "class%s+that", weight = 0.85 },
|
||||
{ pattern = "method%s+that", weight = 0.85 },
|
||||
{ pattern = "component%s+that", weight = 0.85 },
|
||||
{ pattern = "module%s+that", weight = 0.85 },
|
||||
{ pattern = "api%s+for", weight = 0.8 },
|
||||
{ pattern = "endpoint%s+for", weight = 0.8 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting refactor intent
|
||||
local REFACTOR_PATTERNS = {
|
||||
{ pattern = "^refactor%s", weight = 0.95 },
|
||||
{ pattern = "refactor%s+this", weight = 0.95 },
|
||||
{ pattern = "clean%s+up", weight = 0.85 },
|
||||
{ pattern = "improve%s+this%s+code", weight = 0.85 },
|
||||
{ pattern = "make%s+this%s+cleaner", weight = 0.85 },
|
||||
{ pattern = "simplify", weight = 0.8 },
|
||||
{ pattern = "optimize", weight = 0.75 }, -- Could be explain
|
||||
{ pattern = "reorganize", weight = 0.9 },
|
||||
{ pattern = "restructure", weight = 0.9 },
|
||||
{ pattern = "extract%s+to", weight = 0.9 },
|
||||
{ pattern = "split%s+into", weight = 0.85 },
|
||||
{ pattern = "dry%s+this", weight = 0.9 }, -- Don't repeat yourself
|
||||
{ pattern = "reduce%s+duplication", weight = 0.9 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting documentation intent
|
||||
local DOCUMENT_PATTERNS = {
|
||||
{ pattern = "^document%s", weight = 0.95 },
|
||||
{ pattern = "add%s+documentation", weight = 0.95 },
|
||||
{ pattern = "add%s+docs", weight = 0.95 },
|
||||
{ pattern = "add%s+comments", weight = 0.9 },
|
||||
{ pattern = "add%s+docstring", weight = 0.95 },
|
||||
{ pattern = "add%s+jsdoc", weight = 0.95 },
|
||||
{ pattern = "write%s+documentation", weight = 0.95 },
|
||||
{ pattern = "document%s+this", weight = 0.95 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting test generation intent
|
||||
local TEST_PATTERNS = {
|
||||
{ pattern = "^test%s", weight = 0.9 },
|
||||
{ pattern = "write%s+tests?%s+for", weight = 0.95 },
|
||||
{ pattern = "add%s+tests?%s+for", weight = 0.95 },
|
||||
{ pattern = "create%s+tests?%s+for", weight = 0.95 },
|
||||
{ pattern = "generate%s+tests?", weight = 0.95 },
|
||||
{ pattern = "unit%s+tests?", weight = 0.9 },
|
||||
{ pattern = "test%s+cases?%s+for", weight = 0.95 },
|
||||
{ pattern = "spec%s+for", weight = 0.85 },
|
||||
}
|
||||
|
||||
--- Patterns indicating project-wide context is needed
|
||||
local PROJECT_CONTEXT_PATTERNS = {
|
||||
{ pattern = "project", weight = 0.9 },
|
||||
{ pattern = "codebase", weight = 0.95 },
|
||||
{ pattern = "entire", weight = 0.7 },
|
||||
{ pattern = "whole", weight = 0.7 },
|
||||
{ pattern = "all%s+files", weight = 0.9 },
|
||||
{ pattern = "architecture", weight = 0.95 },
|
||||
{ pattern = "structure", weight = 0.85 },
|
||||
{ pattern = "how%s+is%s+.*%s+organized", weight = 0.95 },
|
||||
{ pattern = "where%s+is%s+.*%s+defined", weight = 0.9 },
|
||||
{ pattern = "dependencies", weight = 0.85 },
|
||||
{ pattern = "imports?%s+from", weight = 0.7 },
|
||||
{ pattern = "modules?", weight = 0.6 },
|
||||
{ pattern = "packages?", weight = 0.6 },
|
||||
}
|
||||
|
||||
--- Patterns indicating project exploration is needed (full indexing)
|
||||
local EXPLORE_PATTERNS = {
|
||||
{ pattern = "explain%s+.*%s*project", weight = 1.0 },
|
||||
{ pattern = "explain%s+.*%s*codebase", weight = 1.0 },
|
||||
{ pattern = "explain%s+me%s+the%s+project", weight = 1.0 },
|
||||
{ pattern = "tell%s+me%s+about%s+.*%s*project", weight = 0.95 },
|
||||
{ pattern = "what%s+is%s+this%s+project", weight = 0.95 },
|
||||
{ pattern = "overview%s+of%s+.*%s*project", weight = 0.95 },
|
||||
{ pattern = "understand%s+.*%s*project", weight = 0.9 },
|
||||
{ pattern = "analyze%s+.*%s*project", weight = 0.9 },
|
||||
{ pattern = "explore%s+.*%s*project", weight = 1.0 },
|
||||
{ pattern = "explore%s+.*%s*codebase", weight = 1.0 },
|
||||
{ pattern = "index%s+.*%s*project", weight = 1.0 },
|
||||
{ pattern = "scan%s+.*%s*project", weight = 0.95 },
|
||||
}
|
||||
|
||||
--- Match patterns against text
|
||||
---@param text string Lowercased text to match
|
||||
---@param patterns table Pattern list with weights
|
||||
---@return number Score, string[] Matched keywords
|
||||
local function match_patterns(text, patterns)
|
||||
local score = 0
|
||||
local matched = {}
|
||||
|
||||
for _, p in ipairs(patterns) do
|
||||
if text:match(p.pattern) then
|
||||
score = score + p.weight
|
||||
table.insert(matched, p.pattern)
|
||||
end
|
||||
end
|
||||
|
||||
return score, matched
|
||||
end
|
||||
|
||||
--- Detect intent from user prompt
|
||||
---@param prompt string User's question/request
|
||||
---@return Intent Detected intent
|
||||
function M.detect(prompt)
|
||||
local text = prompt:lower()
|
||||
|
||||
-- Calculate raw scores for each intent type (sum of matched weights)
|
||||
local ask_score, ask_kw = match_patterns(text, ASK_PATTERNS)
|
||||
local gen_score, gen_kw = match_patterns(text, GENERATE_PATTERNS)
|
||||
local ref_score, ref_kw = match_patterns(text, REFACTOR_PATTERNS)
|
||||
local doc_score, doc_kw = match_patterns(text, DOCUMENT_PATTERNS)
|
||||
local test_score, test_kw = match_patterns(text, TEST_PATTERNS)
|
||||
local proj_score, _ = match_patterns(text, PROJECT_CONTEXT_PATTERNS)
|
||||
local explore_score, _ = match_patterns(text, EXPLORE_PATTERNS)
|
||||
|
||||
-- Find the winner by raw score (highest accumulated weight)
|
||||
local scores = {
|
||||
{ type = "ask", score = ask_score, keywords = ask_kw },
|
||||
{ type = "generate", score = gen_score, keywords = gen_kw },
|
||||
{ type = "refactor", score = ref_score, keywords = ref_kw },
|
||||
{ type = "document", score = doc_score, keywords = doc_kw },
|
||||
{ type = "test", score = test_score, keywords = test_kw },
|
||||
}
|
||||
|
||||
table.sort(scores, function(a, b)
|
||||
return a.score > b.score
|
||||
end)
|
||||
|
||||
local winner = scores[1]
|
||||
|
||||
-- If top score is very low, default to ask (safer for Q&A)
|
||||
if winner.score < 0.3 then
|
||||
winner = { type = "ask", score = 0.5, keywords = {} }
|
||||
end
|
||||
|
||||
-- If ask and generate are close AND there's a question mark, prefer ask
|
||||
if winner.type == "generate" and ask_score > 0 then
|
||||
if text:match("%?%s*$") and ask_score >= gen_score * 0.5 then
|
||||
winner = { type = "ask", score = ask_score, keywords = ask_kw }
|
||||
end
|
||||
end
|
||||
|
||||
-- Determine if "explain" vs "ask" (explain needs more context)
|
||||
local intent_type = winner.type
|
||||
if intent_type == "ask" then
|
||||
-- "explain" if asking about how something works, otherwise "ask"
|
||||
if text:match("explain") or text:match("how%s+does") or text:match("walk%s+me%s+through") then
|
||||
intent_type = "explain"
|
||||
end
|
||||
end
|
||||
|
||||
-- Normalize confidence to 0-1 range (cap at reasonable max)
|
||||
local confidence = math.min(winner.score / 2, 1.0)
|
||||
|
||||
-- Check if exploration is needed (full project indexing)
|
||||
local needs_exploration = explore_score >= 0.9
|
||||
|
||||
---@type Intent
|
||||
local intent = {
|
||||
type = intent_type,
|
||||
confidence = confidence,
|
||||
needs_project_context = proj_score > 0.5 or needs_exploration,
|
||||
needs_brain_context = intent_type == "ask" or intent_type == "explain",
|
||||
needs_exploration = needs_exploration,
|
||||
keywords = winner.keywords,
|
||||
}
|
||||
|
||||
return intent
|
||||
end
|
||||
|
||||
--- Get prompt type for system prompt selection
|
||||
---@param intent Intent Detected intent
|
||||
---@return string Prompt type for prompts.system
|
||||
function M.get_prompt_type(intent)
|
||||
local mapping = {
|
||||
ask = "ask",
|
||||
explain = "ask", -- Uses same prompt as ask
|
||||
generate = "code_generation",
|
||||
refactor = "refactor",
|
||||
document = "document",
|
||||
test = "test",
|
||||
}
|
||||
return mapping[intent.type] or "ask"
|
||||
end
|
||||
|
||||
--- Check if intent requires code output
|
||||
---@param intent Intent
|
||||
---@return boolean
|
||||
function M.produces_code(intent)
|
||||
local code_intents = {
|
||||
generate = true,
|
||||
refactor = true,
|
||||
document = true, -- Documentation is code (comments)
|
||||
test = true,
|
||||
}
|
||||
return code_intents[intent.type] or false
|
||||
end
|
||||
|
||||
return M
|
||||
456
lua/codetyper/features/completion/inject.lua
Normal file
456
lua/codetyper/features/completion/inject.lua
Normal file
@@ -0,0 +1,456 @@
|
||||
---@mod codetyper.agent.inject Smart code injection with import handling
|
||||
---@brief [[
|
||||
--- Intelligent code injection that properly handles imports, merging them
|
||||
--- into existing import sections instead of blindly appending.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class ImportConfig
|
||||
---@field pattern string Lua pattern to match import statements
|
||||
---@field multi_line boolean Whether imports can span multiple lines
|
||||
---@field sort_key function|nil Function to extract sort key from import
|
||||
---@field group_by function|nil Function to group imports
|
||||
|
||||
---@class ParsedCode
|
||||
---@field imports string[] Import statements
|
||||
---@field body string[] Non-import code lines
|
||||
---@field import_lines table<number, boolean> Map of line numbers that are imports
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
local languages = require("codetyper.params.agent.languages")
|
||||
local import_patterns = languages.import_patterns
|
||||
|
||||
--- Check if a line is an import statement for the given language
|
||||
---@param line string
|
||||
---@param patterns table[] Import patterns for the language
|
||||
---@return boolean is_import
|
||||
---@return boolean is_multi_line
|
||||
local function is_import_line(line, patterns)
|
||||
for _, p in ipairs(patterns) do
|
||||
if line:match(p.pattern) then
|
||||
return true, p.multi_line or false
|
||||
end
|
||||
end
|
||||
return false, false
|
||||
end
|
||||
|
||||
|
||||
--- Check if a line ends a multi-line import
|
||||
---@param line string
|
||||
---@param filetype string
|
||||
---@return boolean
|
||||
local function ends_multiline_import(line, filetype)
|
||||
return utils.ends_multiline_import(line, filetype)
|
||||
end
|
||||
|
||||
--- Parse code into imports and body
|
||||
---@param code string|string[] Code to parse
|
||||
---@param filetype string File type/extension
|
||||
---@return ParsedCode
|
||||
function M.parse_code(code, filetype)
|
||||
local lines
|
||||
if type(code) == "string" then
|
||||
lines = vim.split(code, "\n", { plain = true })
|
||||
else
|
||||
lines = code
|
||||
end
|
||||
|
||||
local patterns = import_patterns[filetype] or import_patterns.javascript
|
||||
|
||||
local result = {
|
||||
imports = {},
|
||||
body = {},
|
||||
import_lines = {},
|
||||
}
|
||||
|
||||
local in_multiline_import = false
|
||||
local current_import_lines = {}
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
if in_multiline_import then
|
||||
-- Continue collecting multi-line import
|
||||
table.insert(current_import_lines, line)
|
||||
|
||||
if ends_multiline_import(line, filetype) then
|
||||
-- Complete the multi-line import
|
||||
table.insert(result.imports, table.concat(current_import_lines, "\n"))
|
||||
for j = i - #current_import_lines + 1, i do
|
||||
result.import_lines[j] = true
|
||||
end
|
||||
current_import_lines = {}
|
||||
in_multiline_import = false
|
||||
end
|
||||
else
|
||||
local is_import, is_multi = is_import_line(line, patterns)
|
||||
|
||||
if is_import then
|
||||
result.import_lines[i] = true
|
||||
|
||||
if is_multi and not ends_multiline_import(line, filetype) then
|
||||
-- Start of multi-line import
|
||||
in_multiline_import = true
|
||||
current_import_lines = { line }
|
||||
else
|
||||
-- Single-line import
|
||||
table.insert(result.imports, line)
|
||||
end
|
||||
else
|
||||
-- Non-import line
|
||||
table.insert(result.body, line)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Handle unclosed multi-line import (shouldn't happen with well-formed code)
|
||||
if #current_import_lines > 0 then
|
||||
table.insert(result.imports, table.concat(current_import_lines, "\n"))
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Find the import section range in a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@param filetype string
|
||||
---@return number|nil start_line First import line (1-indexed)
|
||||
---@return number|nil end_line Last import line (1-indexed)
|
||||
function M.find_import_section(bufnr, filetype)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local patterns = import_patterns[filetype] or import_patterns.javascript
|
||||
|
||||
local first_import = nil
|
||||
local last_import = nil
|
||||
local in_multiline = false
|
||||
local consecutive_non_import = 0
|
||||
local max_gap = 3 -- Allow up to 3 blank/comment lines between imports
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
if in_multiline then
|
||||
last_import = i
|
||||
consecutive_non_import = 0
|
||||
|
||||
if ends_multiline_import(line, filetype) then
|
||||
in_multiline = false
|
||||
end
|
||||
else
|
||||
local is_import, is_multi = is_import_line(line, patterns)
|
||||
|
||||
if is_import then
|
||||
if not first_import then
|
||||
first_import = i
|
||||
end
|
||||
last_import = i
|
||||
consecutive_non_import = 0
|
||||
|
||||
if is_multi and not ends_multiline_import(line, filetype) then
|
||||
in_multiline = true
|
||||
end
|
||||
elseif utils.is_empty_or_comment(line, filetype) then
|
||||
-- Allow gaps in import section
|
||||
if first_import then
|
||||
consecutive_non_import = consecutive_non_import + 1
|
||||
if consecutive_non_import > max_gap then
|
||||
-- Too many non-import lines, import section has ended
|
||||
break
|
||||
end
|
||||
end
|
||||
else
|
||||
-- Non-import, non-empty line
|
||||
if first_import then
|
||||
-- Import section has ended
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return first_import, last_import
|
||||
end
|
||||
|
||||
--- Get existing imports from a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@param filetype string
|
||||
---@return string[] Existing import statements
|
||||
function M.get_existing_imports(bufnr, filetype)
|
||||
local start_line, end_line = M.find_import_section(bufnr, filetype)
|
||||
if not start_line then
|
||||
return {}
|
||||
end
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false)
|
||||
local parsed = M.parse_code(lines, filetype)
|
||||
return parsed.imports
|
||||
end
|
||||
|
||||
--- Normalize an import for comparison (remove whitespace variations)
|
||||
---@param import_str string
|
||||
---@return string
|
||||
local function normalize_import(import_str)
|
||||
-- Remove trailing semicolon for comparison
|
||||
local normalized = import_str:gsub(";%s*$", "")
|
||||
-- Remove all whitespace around braces, commas, colons
|
||||
normalized = normalized:gsub("%s*{%s*", "{")
|
||||
normalized = normalized:gsub("%s*}%s*", "}")
|
||||
normalized = normalized:gsub("%s*,%s*", ",")
|
||||
normalized = normalized:gsub("%s*:%s*", ":")
|
||||
-- Collapse multiple whitespace to single space
|
||||
normalized = normalized:gsub("%s+", " ")
|
||||
-- Trim leading/trailing whitespace
|
||||
normalized = normalized:match("^%s*(.-)%s*$")
|
||||
return normalized
|
||||
end
|
||||
|
||||
--- Check if two imports are duplicates
|
||||
---@param import1 string
|
||||
---@param import2 string
|
||||
---@return boolean
|
||||
local function are_duplicate_imports(import1, import2)
|
||||
return normalize_import(import1) == normalize_import(import2)
|
||||
end
|
||||
|
||||
--- Merge new imports with existing ones, avoiding duplicates
|
||||
---@param existing string[] Existing imports
|
||||
---@param new_imports string[] New imports to merge
|
||||
---@return string[] Merged imports
|
||||
function M.merge_imports(existing, new_imports)
|
||||
local merged = {}
|
||||
local seen = {}
|
||||
|
||||
-- Add existing imports
|
||||
for _, imp in ipairs(existing) do
|
||||
local normalized = normalize_import(imp)
|
||||
if not seen[normalized] then
|
||||
seen[normalized] = true
|
||||
table.insert(merged, imp)
|
||||
end
|
||||
end
|
||||
|
||||
-- Add new imports that aren't duplicates
|
||||
for _, imp in ipairs(new_imports) do
|
||||
local normalized = normalize_import(imp)
|
||||
if not seen[normalized] then
|
||||
seen[normalized] = true
|
||||
table.insert(merged, imp)
|
||||
end
|
||||
end
|
||||
|
||||
return merged
|
||||
end
|
||||
|
||||
--- Sort imports by their source/module
|
||||
---@param imports string[]
|
||||
---@param filetype string
|
||||
---@return string[]
|
||||
function M.sort_imports(imports, filetype)
|
||||
-- Group imports: stdlib/builtin first, then third-party, then local
|
||||
local builtin = {}
|
||||
local third_party = {}
|
||||
local local_imports = {}
|
||||
|
||||
for _, imp in ipairs(imports) do
|
||||
local category = utils.classify_import(imp, filetype)
|
||||
|
||||
if category == "builtin" then
|
||||
table.insert(builtin, imp)
|
||||
elseif category == "local" then
|
||||
table.insert(local_imports, imp)
|
||||
else
|
||||
table.insert(third_party, imp)
|
||||
end
|
||||
end
|
||||
|
||||
-- Sort each group alphabetically
|
||||
table.sort(builtin)
|
||||
table.sort(third_party)
|
||||
table.sort(local_imports)
|
||||
|
||||
-- Combine with proper spacing
|
||||
local result = {}
|
||||
|
||||
for _, imp in ipairs(builtin) do
|
||||
table.insert(result, imp)
|
||||
end
|
||||
if #builtin > 0 and (#third_party > 0 or #local_imports > 0) then
|
||||
table.insert(result, "") -- Blank line between groups
|
||||
end
|
||||
|
||||
for _, imp in ipairs(third_party) do
|
||||
table.insert(result, imp)
|
||||
end
|
||||
if #third_party > 0 and #local_imports > 0 then
|
||||
table.insert(result, "") -- Blank line between groups
|
||||
end
|
||||
|
||||
for _, imp in ipairs(local_imports) do
|
||||
table.insert(result, imp)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
---@class InjectResult
|
||||
---@field success boolean
|
||||
---@field imports_added number Number of new imports added
|
||||
---@field imports_merged boolean Whether imports were merged into existing section
|
||||
---@field body_lines number Number of body lines injected
|
||||
|
||||
--- Smart inject code into a buffer, properly handling imports
|
||||
---@param bufnr number Target buffer
|
||||
---@param code string|string[] Code to inject
|
||||
---@param opts table Options: { strategy: "append"|"replace"|"insert", range: {start_line, end_line}|nil, filetype: string|nil, sort_imports: boolean|nil }
|
||||
---@return InjectResult
|
||||
function M.inject(bufnr, code, opts)
|
||||
opts = opts or {}
|
||||
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return { success = false, imports_added = 0, imports_merged = false, body_lines = 0 }
|
||||
end
|
||||
|
||||
-- Get filetype
|
||||
local filetype = opts.filetype
|
||||
if not filetype then
|
||||
local bufname = vim.api.nvim_buf_get_name(bufnr)
|
||||
filetype = vim.fn.fnamemodify(bufname, ":e")
|
||||
end
|
||||
|
||||
-- Parse the code to separate imports from body
|
||||
local parsed = M.parse_code(code, filetype)
|
||||
|
||||
local result = {
|
||||
success = true,
|
||||
imports_added = 0,
|
||||
imports_merged = false,
|
||||
body_lines = #parsed.body,
|
||||
}
|
||||
|
||||
-- Handle imports first if there are any
|
||||
if #parsed.imports > 0 then
|
||||
local import_start, import_end = M.find_import_section(bufnr, filetype)
|
||||
|
||||
if import_start then
|
||||
-- Merge with existing import section
|
||||
local existing_imports = M.get_existing_imports(bufnr, filetype)
|
||||
local merged = M.merge_imports(existing_imports, parsed.imports)
|
||||
|
||||
-- Count how many new imports were actually added
|
||||
result.imports_added = #merged - #existing_imports
|
||||
result.imports_merged = true
|
||||
|
||||
-- Optionally sort imports
|
||||
if opts.sort_imports ~= false then
|
||||
merged = M.sort_imports(merged, filetype)
|
||||
end
|
||||
|
||||
-- Convert back to lines (handling multi-line imports)
|
||||
local import_lines = {}
|
||||
for _, imp in ipairs(merged) do
|
||||
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
|
||||
table.insert(import_lines, line)
|
||||
end
|
||||
end
|
||||
|
||||
-- Replace the import section
|
||||
vim.api.nvim_buf_set_lines(bufnr, import_start - 1, import_end, false, import_lines)
|
||||
|
||||
-- Adjust line numbers for body injection
|
||||
local lines_diff = #import_lines - (import_end - import_start + 1)
|
||||
if opts.range and opts.range.start_line and opts.range.start_line > import_end then
|
||||
opts.range.start_line = opts.range.start_line + lines_diff
|
||||
if opts.range.end_line then
|
||||
opts.range.end_line = opts.range.end_line + lines_diff
|
||||
end
|
||||
end
|
||||
else
|
||||
-- No existing import section, add imports at the top
|
||||
-- Find the first non-comment, non-empty line
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local insert_at = 0
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
local trimmed = line:match("^%s*(.-)%s*$")
|
||||
-- Skip shebang, docstrings, and initial comments
|
||||
if trimmed ~= "" and not trimmed:match("^#!")
|
||||
and not trimmed:match("^['\"]") and not utils.is_empty_or_comment(line, filetype) then
|
||||
insert_at = i - 1
|
||||
break
|
||||
end
|
||||
insert_at = i
|
||||
end
|
||||
|
||||
-- Add imports with a trailing blank line
|
||||
local import_lines = {}
|
||||
for _, imp in ipairs(parsed.imports) do
|
||||
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
|
||||
table.insert(import_lines, line)
|
||||
end
|
||||
end
|
||||
table.insert(import_lines, "") -- Blank line after imports
|
||||
|
||||
vim.api.nvim_buf_set_lines(bufnr, insert_at, insert_at, false, import_lines)
|
||||
result.imports_added = #parsed.imports
|
||||
result.imports_merged = false
|
||||
|
||||
-- Adjust body injection range
|
||||
if opts.range and opts.range.start_line then
|
||||
opts.range.start_line = opts.range.start_line + #import_lines
|
||||
if opts.range.end_line then
|
||||
opts.range.end_line = opts.range.end_line + #import_lines
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Handle body (non-import) code
|
||||
if #parsed.body > 0 then
|
||||
-- Filter out empty leading/trailing lines from body
|
||||
local body_lines = parsed.body
|
||||
while #body_lines > 0 and body_lines[1]:match("^%s*$") do
|
||||
table.remove(body_lines, 1)
|
||||
end
|
||||
while #body_lines > 0 and body_lines[#body_lines]:match("^%s*$") do
|
||||
table.remove(body_lines)
|
||||
end
|
||||
|
||||
if #body_lines > 0 then
|
||||
local line_count = vim.api.nvim_buf_line_count(bufnr)
|
||||
local strategy = opts.strategy or "append"
|
||||
|
||||
if strategy == "replace" and opts.range then
|
||||
local start_line = math.max(1, opts.range.start_line)
|
||||
local end_line = math.min(line_count, opts.range.end_line)
|
||||
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, body_lines)
|
||||
elseif strategy == "insert" and opts.range then
|
||||
local insert_line = math.max(0, math.min(line_count, opts.range.start_line - 1))
|
||||
vim.api.nvim_buf_set_lines(bufnr, insert_line, insert_line, false, body_lines)
|
||||
else
|
||||
-- Default: append
|
||||
local last_line = vim.api.nvim_buf_get_lines(bufnr, line_count - 1, line_count, false)[1] or ""
|
||||
if last_line:match("%S") then
|
||||
-- Add blank line for spacing
|
||||
table.insert(body_lines, 1, "")
|
||||
end
|
||||
vim.api.nvim_buf_set_lines(bufnr, line_count, line_count, false, body_lines)
|
||||
end
|
||||
|
||||
result.body_lines = #body_lines
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Check if code contains imports
|
||||
---@param code string|string[]
|
||||
---@param filetype string
|
||||
---@return boolean
|
||||
function M.has_imports(code, filetype)
|
||||
local parsed = M.parse_code(code, filetype)
|
||||
return #parsed.imports > 0
|
||||
end
|
||||
|
||||
return M
|
||||
192
lua/codetyper/features/completion/inline.lua
Normal file
192
lua/codetyper/features/completion/inline.lua
Normal file
@@ -0,0 +1,192 @@
|
||||
---@mod codetyper.completion Insert mode completion for file references
|
||||
---
|
||||
--- Provides completion for @filename inside /@ @/ tags.
|
||||
|
||||
local M = {}
|
||||
|
||||
local parser = require("codetyper.parser")
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Get list of files for completion
|
||||
---@param prefix string Prefix to filter files
|
||||
---@return table[] List of completion items
|
||||
local function get_file_completions(prefix)
|
||||
local cwd = vim.fn.getcwd()
|
||||
local current_file = vim.fn.expand("%:p")
|
||||
local current_dir = vim.fn.fnamemodify(current_file, ":h")
|
||||
local files = {}
|
||||
|
||||
-- Use vim.fn.glob to find files matching the prefix
|
||||
local pattern = prefix .. "*"
|
||||
|
||||
-- Determine base directory - use current file's directory if outside cwd
|
||||
local base_dir = cwd
|
||||
if current_dir ~= "" and not current_dir:find(cwd, 1, true) then
|
||||
-- File is outside project, use its directory as base
|
||||
base_dir = current_dir
|
||||
end
|
||||
|
||||
-- Search in base directory
|
||||
local matches = vim.fn.glob(base_dir .. "/" .. pattern, false, true)
|
||||
|
||||
-- Search with ** for all subdirectories
|
||||
local deep_matches = vim.fn.glob(base_dir .. "/**/" .. pattern, false, true)
|
||||
for _, m in ipairs(deep_matches) do
|
||||
table.insert(matches, m)
|
||||
end
|
||||
|
||||
-- Also search in cwd if different from base_dir
|
||||
if base_dir ~= cwd then
|
||||
local cwd_matches = vim.fn.glob(cwd .. "/" .. pattern, false, true)
|
||||
for _, m in ipairs(cwd_matches) do
|
||||
table.insert(matches, m)
|
||||
end
|
||||
local cwd_deep = vim.fn.glob(cwd .. "/**/" .. pattern, false, true)
|
||||
for _, m in ipairs(cwd_deep) do
|
||||
table.insert(matches, m)
|
||||
end
|
||||
end
|
||||
|
||||
-- Also search specific directories if prefix doesn't have path
|
||||
if not prefix:find("/") then
|
||||
local search_dirs = { "src", "lib", "lua", "app", "components", "utils", "tests" }
|
||||
for _, dir in ipairs(search_dirs) do
|
||||
local dir_path = base_dir .. "/" .. dir
|
||||
if vim.fn.isdirectory(dir_path) == 1 then
|
||||
local dir_matches = vim.fn.glob(dir_path .. "/**/" .. pattern, false, true)
|
||||
for _, m in ipairs(dir_matches) do
|
||||
table.insert(matches, m)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Convert to relative paths and deduplicate
|
||||
local seen = {}
|
||||
for _, match in ipairs(matches) do
|
||||
-- Convert to relative path based on which base it came from
|
||||
local rel_path
|
||||
if match:find(base_dir, 1, true) == 1 then
|
||||
rel_path = match:sub(#base_dir + 2)
|
||||
elseif match:find(cwd, 1, true) == 1 then
|
||||
rel_path = match:sub(#cwd + 2)
|
||||
else
|
||||
rel_path = vim.fn.fnamemodify(match, ":t") -- Just filename if can't make relative
|
||||
end
|
||||
|
||||
-- Skip directories, coder files, and hidden/generated files
|
||||
if vim.fn.isdirectory(match) == 0
|
||||
and not utils.is_coder_file(match)
|
||||
and not rel_path:match("^%.")
|
||||
and not rel_path:match("node_modules")
|
||||
and not rel_path:match("%.git/")
|
||||
and not rel_path:match("dist/")
|
||||
and not rel_path:match("build/")
|
||||
and not seen[rel_path]
|
||||
then
|
||||
seen[rel_path] = true
|
||||
table.insert(files, {
|
||||
word = rel_path,
|
||||
abbr = rel_path,
|
||||
kind = "File",
|
||||
menu = "[ref]",
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Sort by length (shorter paths first)
|
||||
table.sort(files, function(a, b)
|
||||
return #a.word < #b.word
|
||||
end)
|
||||
|
||||
-- Limit results
|
||||
local result = {}
|
||||
for i = 1, math.min(#files, 15) do
|
||||
result[i] = files[i]
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Show file completion popup
|
||||
function M.show_file_completion()
|
||||
-- Check if we're in an open prompt tag
|
||||
local is_inside = parser.is_cursor_in_open_tag()
|
||||
if not is_inside then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Get the prefix being typed
|
||||
local prefix = parser.get_file_ref_prefix()
|
||||
if prefix == nil then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Get completions
|
||||
local items = get_file_completions(prefix)
|
||||
|
||||
if #items == 0 then
|
||||
-- Try with empty prefix to show all files
|
||||
items = get_file_completions("")
|
||||
end
|
||||
|
||||
if #items > 0 then
|
||||
-- Calculate start column (position right after @)
|
||||
local cursor = vim.api.nvim_win_get_cursor(0)
|
||||
local col = cursor[2] - #prefix + 1 -- 1-indexed for complete()
|
||||
|
||||
-- Show completion popup
|
||||
vim.fn.complete(col, items)
|
||||
return true
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Setup completion for file references (works on ALL files)
|
||||
function M.setup()
|
||||
local group = vim.api.nvim_create_augroup("CoderCompletion", { clear = true })
|
||||
|
||||
-- Trigger completion on @ in insert mode (works on ALL files)
|
||||
vim.api.nvim_create_autocmd("InsertCharPre", {
|
||||
group = group,
|
||||
pattern = "*",
|
||||
callback = function()
|
||||
-- Skip special buffers
|
||||
if vim.bo.buftype ~= "" then
|
||||
return
|
||||
end
|
||||
|
||||
if vim.v.char == "@" then
|
||||
-- Schedule completion popup after the @ is inserted
|
||||
vim.schedule(function()
|
||||
-- Check we're in an open tag
|
||||
local is_inside = parser.is_cursor_in_open_tag()
|
||||
if not is_inside then
|
||||
return
|
||||
end
|
||||
|
||||
-- Check we're not typing @/ (closing tag)
|
||||
local cursor = vim.api.nvim_win_get_cursor(0)
|
||||
local line = vim.api.nvim_get_current_line()
|
||||
local next_char = line:sub(cursor[2] + 2, cursor[2] + 2)
|
||||
|
||||
if next_char == "/" then
|
||||
return
|
||||
end
|
||||
|
||||
-- Show file completion
|
||||
M.show_file_completion()
|
||||
end)
|
||||
end
|
||||
end,
|
||||
desc = "Trigger file completion on @ inside prompt tags",
|
||||
})
|
||||
|
||||
-- Also allow manual trigger with <C-x><C-f> style keybinding in insert mode
|
||||
vim.keymap.set("i", "<C-x>@", function()
|
||||
M.show_file_completion()
|
||||
end, { silent = true, desc = "Coder: Complete file reference" })
|
||||
end
|
||||
|
||||
return M
|
||||
491
lua/codetyper/features/completion/suggestion.lua
Normal file
491
lua/codetyper/features/completion/suggestion.lua
Normal file
@@ -0,0 +1,491 @@
|
||||
---@mod codetyper.suggestion Inline ghost text suggestions
|
||||
---@brief [[
|
||||
--- Provides Copilot-style inline suggestions with ghost text.
|
||||
--- Uses Copilot when available, falls back to codetyper's own suggestions.
|
||||
--- Shows suggestions as grayed-out text that can be accepted with Tab.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class SuggestionState
|
||||
---@field current_suggestion string|nil Current suggestion text
|
||||
---@field suggestions string[] List of available suggestions
|
||||
---@field current_index number Current suggestion index
|
||||
---@field extmark_id number|nil Virtual text extmark ID
|
||||
---@field bufnr number|nil Buffer where suggestion is shown
|
||||
---@field line number|nil Line where suggestion is shown
|
||||
---@field col number|nil Column where suggestion starts
|
||||
---@field timer any|nil Debounce timer
|
||||
---@field using_copilot boolean Whether currently using copilot
|
||||
|
||||
local state = {
|
||||
current_suggestion = nil,
|
||||
suggestions = {},
|
||||
current_index = 0,
|
||||
extmark_id = nil,
|
||||
bufnr = nil,
|
||||
line = nil,
|
||||
col = nil,
|
||||
timer = nil,
|
||||
using_copilot = false,
|
||||
}
|
||||
|
||||
--- Namespace for virtual text
|
||||
local ns = vim.api.nvim_create_namespace("codetyper_suggestion")
|
||||
|
||||
--- Highlight group for ghost text
|
||||
local hl_group = "CmpGhostText"
|
||||
|
||||
--- Configuration
|
||||
local config = {
|
||||
enabled = true,
|
||||
auto_trigger = true,
|
||||
debounce = 150,
|
||||
use_copilot = true, -- Use copilot when available
|
||||
keymap = {
|
||||
accept = "<Tab>",
|
||||
next = "<M-]>",
|
||||
prev = "<M-[>",
|
||||
dismiss = "<C-]>",
|
||||
},
|
||||
}
|
||||
|
||||
--- Check if copilot is available and enabled
|
||||
---@return boolean, table|nil available, copilot_suggestion module
|
||||
local function get_copilot()
|
||||
if not config.use_copilot then
|
||||
return false, nil
|
||||
end
|
||||
|
||||
local ok, copilot_suggestion = pcall(require, "copilot.suggestion")
|
||||
if not ok then
|
||||
return false, nil
|
||||
end
|
||||
|
||||
-- Check if copilot suggestion is enabled
|
||||
local ok_client, copilot_client = pcall(require, "copilot.client")
|
||||
if ok_client and copilot_client.is_disabled and copilot_client.is_disabled() then
|
||||
return false, nil
|
||||
end
|
||||
|
||||
return true, copilot_suggestion
|
||||
end
|
||||
|
||||
--- Check if suggestion is visible (copilot or codetyper)
|
||||
---@return boolean
|
||||
function M.is_visible()
|
||||
-- Check copilot first
|
||||
local copilot_ok, copilot_suggestion = get_copilot()
|
||||
if copilot_ok and copilot_suggestion.is_visible() then
|
||||
state.using_copilot = true
|
||||
return true
|
||||
end
|
||||
|
||||
-- Check codetyper's own suggestion
|
||||
state.using_copilot = false
|
||||
return state.extmark_id ~= nil and state.current_suggestion ~= nil
|
||||
end
|
||||
|
||||
--- Clear the current suggestion
|
||||
function M.dismiss()
|
||||
-- Dismiss copilot if active
|
||||
local copilot_ok, copilot_suggestion = get_copilot()
|
||||
if copilot_ok and copilot_suggestion.is_visible() then
|
||||
copilot_suggestion.dismiss()
|
||||
end
|
||||
|
||||
-- Clear codetyper's suggestion
|
||||
if state.extmark_id and state.bufnr then
|
||||
pcall(vim.api.nvim_buf_del_extmark, state.bufnr, ns, state.extmark_id)
|
||||
end
|
||||
|
||||
state.current_suggestion = nil
|
||||
state.suggestions = {}
|
||||
state.current_index = 0
|
||||
state.extmark_id = nil
|
||||
state.bufnr = nil
|
||||
state.line = nil
|
||||
state.col = nil
|
||||
state.using_copilot = false
|
||||
end
|
||||
|
||||
--- Display suggestion as ghost text
|
||||
---@param suggestion string The suggestion to display
|
||||
local function display_suggestion(suggestion)
|
||||
if not suggestion or suggestion == "" then
|
||||
return
|
||||
end
|
||||
|
||||
M.dismiss()
|
||||
|
||||
local bufnr = vim.api.nvim_get_current_buf()
|
||||
local cursor = vim.api.nvim_win_get_cursor(0)
|
||||
local line = cursor[1] - 1
|
||||
local col = cursor[2]
|
||||
|
||||
-- Split suggestion into lines
|
||||
local lines = vim.split(suggestion, "\n", { plain = true })
|
||||
|
||||
-- Build virtual text
|
||||
local virt_text = {}
|
||||
local virt_lines = {}
|
||||
|
||||
-- First line goes inline
|
||||
if #lines > 0 then
|
||||
virt_text = { { lines[1], hl_group } }
|
||||
end
|
||||
|
||||
-- Remaining lines go below
|
||||
for i = 2, #lines do
|
||||
table.insert(virt_lines, { { lines[i], hl_group } })
|
||||
end
|
||||
|
||||
-- Create extmark with virtual text
|
||||
local opts = {
|
||||
virt_text = virt_text,
|
||||
virt_text_pos = "overlay",
|
||||
hl_mode = "combine",
|
||||
}
|
||||
|
||||
if #virt_lines > 0 then
|
||||
opts.virt_lines = virt_lines
|
||||
end
|
||||
|
||||
state.extmark_id = vim.api.nvim_buf_set_extmark(bufnr, ns, line, col, opts)
|
||||
state.bufnr = bufnr
|
||||
state.line = line
|
||||
state.col = col
|
||||
state.current_suggestion = suggestion
|
||||
end
|
||||
|
||||
--- Accept the current suggestion
|
||||
---@return boolean Whether a suggestion was accepted
|
||||
function M.accept()
|
||||
-- Check copilot first
|
||||
local copilot_ok, copilot_suggestion = get_copilot()
|
||||
if copilot_ok and copilot_suggestion.is_visible() then
|
||||
copilot_suggestion.accept()
|
||||
state.using_copilot = false
|
||||
return true
|
||||
end
|
||||
|
||||
-- Accept codetyper's suggestion
|
||||
if not M.is_visible() then
|
||||
return false
|
||||
end
|
||||
|
||||
local suggestion = state.current_suggestion
|
||||
local bufnr = state.bufnr
|
||||
local line = state.line
|
||||
local col = state.col
|
||||
|
||||
M.dismiss()
|
||||
|
||||
if suggestion and bufnr and line ~= nil and col ~= nil then
|
||||
-- Get current line content
|
||||
local current_line = vim.api.nvim_buf_get_lines(bufnr, line, line + 1, false)[1] or ""
|
||||
|
||||
-- Split suggestion into lines
|
||||
local suggestion_lines = vim.split(suggestion, "\n", { plain = true })
|
||||
|
||||
if #suggestion_lines == 1 then
|
||||
-- Single line - insert at cursor
|
||||
local new_line = current_line:sub(1, col) .. suggestion .. current_line:sub(col + 1)
|
||||
vim.api.nvim_buf_set_lines(bufnr, line, line + 1, false, { new_line })
|
||||
-- Move cursor to end of inserted text
|
||||
vim.api.nvim_win_set_cursor(0, { line + 1, col + #suggestion })
|
||||
else
|
||||
-- Multi-line - insert at cursor
|
||||
local first_line = current_line:sub(1, col) .. suggestion_lines[1]
|
||||
local last_line = suggestion_lines[#suggestion_lines] .. current_line:sub(col + 1)
|
||||
|
||||
local new_lines = { first_line }
|
||||
for i = 2, #suggestion_lines - 1 do
|
||||
table.insert(new_lines, suggestion_lines[i])
|
||||
end
|
||||
table.insert(new_lines, last_line)
|
||||
|
||||
vim.api.nvim_buf_set_lines(bufnr, line, line + 1, false, new_lines)
|
||||
-- Move cursor to end of last line
|
||||
vim.api.nvim_win_set_cursor(0, { line + #new_lines, #suggestion_lines[#suggestion_lines] })
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Show next suggestion
|
||||
function M.next()
|
||||
-- Check copilot first
|
||||
local copilot_ok, copilot_suggestion = get_copilot()
|
||||
if copilot_ok and copilot_suggestion.is_visible() then
|
||||
copilot_suggestion.next()
|
||||
return
|
||||
end
|
||||
|
||||
-- Codetyper's suggestions
|
||||
if #state.suggestions <= 1 then
|
||||
return
|
||||
end
|
||||
|
||||
state.current_index = (state.current_index % #state.suggestions) + 1
|
||||
display_suggestion(state.suggestions[state.current_index])
|
||||
end
|
||||
|
||||
--- Show previous suggestion
|
||||
function M.prev()
|
||||
-- Check copilot first
|
||||
local copilot_ok, copilot_suggestion = get_copilot()
|
||||
if copilot_ok and copilot_suggestion.is_visible() then
|
||||
copilot_suggestion.prev()
|
||||
return
|
||||
end
|
||||
|
||||
-- Codetyper's suggestions
|
||||
if #state.suggestions <= 1 then
|
||||
return
|
||||
end
|
||||
|
||||
state.current_index = state.current_index - 1
|
||||
if state.current_index < 1 then
|
||||
state.current_index = #state.suggestions
|
||||
end
|
||||
display_suggestion(state.suggestions[state.current_index])
|
||||
end
|
||||
|
||||
--- Get suggestions from brain/indexer
|
||||
---@param prefix string Current word prefix
|
||||
---@param context table Context info
|
||||
---@return string[] suggestions
|
||||
local function get_suggestions(prefix, context)
|
||||
local suggestions = {}
|
||||
|
||||
-- Get completions from brain
|
||||
local ok_brain, brain = pcall(require, "codetyper.brain")
|
||||
if ok_brain and brain.is_initialized and brain.is_initialized() then
|
||||
local result = brain.query({
|
||||
query = prefix,
|
||||
max_results = 5,
|
||||
types = { "pattern" },
|
||||
})
|
||||
|
||||
if result and result.nodes then
|
||||
for _, node in ipairs(result.nodes) do
|
||||
if node.c and node.c.code then
|
||||
table.insert(suggestions, node.c.code)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Get completions from indexer
|
||||
local ok_indexer, indexer = pcall(require, "codetyper.indexer")
|
||||
if ok_indexer then
|
||||
local index = indexer.load_index()
|
||||
if index and index.symbols then
|
||||
for symbol, _ in pairs(index.symbols) do
|
||||
if symbol:lower():find(prefix:lower(), 1, true) and symbol ~= prefix then
|
||||
-- Just complete the symbol name
|
||||
local completion = symbol:sub(#prefix + 1)
|
||||
if completion ~= "" then
|
||||
table.insert(suggestions, completion)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Buffer-based completions
|
||||
local bufnr = vim.api.nvim_get_current_buf()
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local seen = {}
|
||||
|
||||
for _, line in ipairs(lines) do
|
||||
for word in line:gmatch("[%a_][%w_]*") do
|
||||
if
|
||||
#word > #prefix
|
||||
and word:lower():find(prefix:lower(), 1, true) == 1
|
||||
and not seen[word]
|
||||
and word ~= prefix
|
||||
then
|
||||
seen[word] = true
|
||||
local completion = word:sub(#prefix + 1)
|
||||
if completion ~= "" then
|
||||
table.insert(suggestions, completion)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return suggestions
|
||||
end
|
||||
|
||||
--- Trigger suggestion generation
|
||||
function M.trigger()
|
||||
if not config.enabled then
|
||||
return
|
||||
end
|
||||
|
||||
-- If copilot is available and has a suggestion, don't show codetyper's
|
||||
local copilot_ok, copilot_suggestion = get_copilot()
|
||||
if copilot_ok and copilot_suggestion.is_visible() then
|
||||
-- Copilot is handling suggestions
|
||||
state.using_copilot = true
|
||||
return
|
||||
end
|
||||
|
||||
-- Cancel existing timer
|
||||
if state.timer then
|
||||
state.timer:stop()
|
||||
state.timer = nil
|
||||
end
|
||||
|
||||
-- Get current context
|
||||
local cursor = vim.api.nvim_win_get_cursor(0)
|
||||
local line = vim.api.nvim_get_current_line()
|
||||
local col = cursor[2]
|
||||
local before_cursor = line:sub(1, col)
|
||||
|
||||
-- Extract prefix (word being typed)
|
||||
local prefix = before_cursor:match("[%a_][%w_]*$") or ""
|
||||
|
||||
if #prefix < 2 then
|
||||
M.dismiss()
|
||||
return
|
||||
end
|
||||
|
||||
-- Debounce - wait a bit longer to let copilot try first
|
||||
local debounce_time = copilot_ok and (config.debounce + 200) or config.debounce
|
||||
|
||||
state.timer = vim.defer_fn(function()
|
||||
-- Check again if copilot has shown something
|
||||
if copilot_ok and copilot_suggestion.is_visible() then
|
||||
state.using_copilot = true
|
||||
state.timer = nil
|
||||
return
|
||||
end
|
||||
|
||||
local suggestions = get_suggestions(prefix, {
|
||||
line = line,
|
||||
col = col,
|
||||
bufnr = vim.api.nvim_get_current_buf(),
|
||||
})
|
||||
|
||||
if #suggestions > 0 then
|
||||
state.suggestions = suggestions
|
||||
state.current_index = 1
|
||||
display_suggestion(suggestions[1])
|
||||
else
|
||||
M.dismiss()
|
||||
end
|
||||
|
||||
state.timer = nil
|
||||
end, debounce_time)
|
||||
end
|
||||
|
||||
--- Setup keymaps
|
||||
local function setup_keymaps()
|
||||
-- Accept with Tab (only when suggestion visible)
|
||||
vim.keymap.set("i", config.keymap.accept, function()
|
||||
if M.is_visible() then
|
||||
M.accept()
|
||||
return ""
|
||||
end
|
||||
-- Fallback to normal Tab behavior
|
||||
return vim.api.nvim_replace_termcodes("<Tab>", true, false, true)
|
||||
end, { expr = true, silent = true, desc = "Accept codetyper suggestion" })
|
||||
|
||||
-- Next suggestion
|
||||
vim.keymap.set("i", config.keymap.next, function()
|
||||
M.next()
|
||||
end, { silent = true, desc = "Next codetyper suggestion" })
|
||||
|
||||
-- Previous suggestion
|
||||
vim.keymap.set("i", config.keymap.prev, function()
|
||||
M.prev()
|
||||
end, { silent = true, desc = "Previous codetyper suggestion" })
|
||||
|
||||
-- Dismiss
|
||||
vim.keymap.set("i", config.keymap.dismiss, function()
|
||||
M.dismiss()
|
||||
end, { silent = true, desc = "Dismiss codetyper suggestion" })
|
||||
end
|
||||
|
||||
--- Setup autocmds for auto-trigger
|
||||
local function setup_autocmds()
|
||||
local group = vim.api.nvim_create_augroup("CodetypeSuggestion", { clear = true })
|
||||
|
||||
-- Trigger on text change in insert mode
|
||||
if config.auto_trigger then
|
||||
vim.api.nvim_create_autocmd("TextChangedI", {
|
||||
group = group,
|
||||
callback = function()
|
||||
M.trigger()
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
-- Dismiss on leaving insert mode
|
||||
vim.api.nvim_create_autocmd("InsertLeave", {
|
||||
group = group,
|
||||
callback = function()
|
||||
M.dismiss()
|
||||
end,
|
||||
})
|
||||
|
||||
-- Dismiss on cursor move (not from typing)
|
||||
vim.api.nvim_create_autocmd("CursorMovedI", {
|
||||
group = group,
|
||||
callback = function()
|
||||
-- Only dismiss if cursor moved significantly
|
||||
if state.line ~= nil then
|
||||
local cursor = vim.api.nvim_win_get_cursor(0)
|
||||
if cursor[1] - 1 ~= state.line then
|
||||
M.dismiss()
|
||||
end
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- Setup highlight group
|
||||
local function setup_highlights()
|
||||
-- Use Comment highlight or define custom ghost text style
|
||||
vim.api.nvim_set_hl(0, hl_group, { link = "Comment" })
|
||||
end
|
||||
|
||||
--- Setup the suggestion system
|
||||
---@param opts? table Configuration options
|
||||
function M.setup(opts)
|
||||
if opts then
|
||||
config = vim.tbl_deep_extend("force", config, opts)
|
||||
end
|
||||
|
||||
setup_highlights()
|
||||
setup_keymaps()
|
||||
setup_autocmds()
|
||||
end
|
||||
|
||||
--- Enable suggestions
|
||||
function M.enable()
|
||||
config.enabled = true
|
||||
end
|
||||
|
||||
--- Disable suggestions
|
||||
function M.disable()
|
||||
config.enabled = false
|
||||
M.dismiss()
|
||||
end
|
||||
|
||||
--- Toggle suggestions
|
||||
function M.toggle()
|
||||
if config.enabled then
|
||||
M.disable()
|
||||
else
|
||||
M.enable()
|
||||
end
|
||||
end
|
||||
|
||||
return M
|
||||
585
lua/codetyper/features/indexer/analyzer.lua
Normal file
585
lua/codetyper/features/indexer/analyzer.lua
Normal file
@@ -0,0 +1,585 @@
|
||||
---@mod codetyper.indexer.analyzer Code analyzer using Tree-sitter
|
||||
---@brief [[
|
||||
--- Analyzes source files to extract functions, classes, exports, and imports.
|
||||
--- Uses Tree-sitter when available, falls back to pattern matching.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
local scanner = require("codetyper.features.indexer.scanner")
|
||||
|
||||
--- Language-specific query patterns for Tree-sitter
|
||||
local TS_QUERIES = {
|
||||
lua = {
|
||||
functions = [[
|
||||
(function_declaration name: (identifier) @name) @func
|
||||
(function_definition) @func
|
||||
(local_function name: (identifier) @name) @func
|
||||
(assignment_statement
|
||||
(variable_list name: (identifier) @name)
|
||||
(expression_list value: (function_definition) @func))
|
||||
]],
|
||||
exports = [[
|
||||
(return_statement (expression_list (table_constructor))) @export
|
||||
]],
|
||||
},
|
||||
typescript = {
|
||||
functions = [[
|
||||
(function_declaration name: (identifier) @name) @func
|
||||
(method_definition name: (property_identifier) @name) @func
|
||||
(arrow_function) @func
|
||||
(lexical_declaration
|
||||
(variable_declarator name: (identifier) @name value: (arrow_function) @func))
|
||||
]],
|
||||
exports = [[
|
||||
(export_statement) @export
|
||||
]],
|
||||
imports = [[
|
||||
(import_statement) @import
|
||||
]],
|
||||
},
|
||||
javascript = {
|
||||
functions = [[
|
||||
(function_declaration name: (identifier) @name) @func
|
||||
(method_definition name: (property_identifier) @name) @func
|
||||
(arrow_function) @func
|
||||
]],
|
||||
exports = [[
|
||||
(export_statement) @export
|
||||
]],
|
||||
imports = [[
|
||||
(import_statement) @import
|
||||
]],
|
||||
},
|
||||
python = {
|
||||
functions = [[
|
||||
(function_definition name: (identifier) @name) @func
|
||||
]],
|
||||
classes = [[
|
||||
(class_definition name: (identifier) @name) @class
|
||||
]],
|
||||
imports = [[
|
||||
(import_statement) @import
|
||||
(import_from_statement) @import
|
||||
]],
|
||||
},
|
||||
go = {
|
||||
functions = [[
|
||||
(function_declaration name: (identifier) @name) @func
|
||||
(method_declaration name: (field_identifier) @name) @func
|
||||
]],
|
||||
imports = [[
|
||||
(import_declaration) @import
|
||||
]],
|
||||
},
|
||||
rust = {
|
||||
functions = [[
|
||||
(function_item name: (identifier) @name) @func
|
||||
]],
|
||||
imports = [[
|
||||
(use_declaration) @import
|
||||
]],
|
||||
},
|
||||
}
|
||||
|
||||
-- Forward declaration for analyze_tree_generic (defined below)
|
||||
local analyze_tree_generic
|
||||
|
||||
--- Hash file content for change detection
|
||||
---@param content string
|
||||
---@return string
|
||||
local function hash_content(content)
|
||||
local hash = 0
|
||||
for i = 1, math.min(#content, 10000) do
|
||||
hash = (hash * 31 + string.byte(content, i)) % 2147483647
|
||||
end
|
||||
return string.format("%08x", hash)
|
||||
end
|
||||
|
||||
--- Try to get Tree-sitter parser for a language
|
||||
---@param lang string
|
||||
---@return boolean
|
||||
local function has_ts_parser(lang)
|
||||
local ok = pcall(vim.treesitter.language.inspect, lang)
|
||||
return ok
|
||||
end
|
||||
|
||||
--- Analyze file using Tree-sitter
|
||||
---@param filepath string
|
||||
---@param lang string
|
||||
---@param content string
|
||||
---@return table|nil
|
||||
local function analyze_with_treesitter(filepath, lang, content)
|
||||
if not has_ts_parser(lang) then
|
||||
return nil
|
||||
end
|
||||
|
||||
local result = {
|
||||
functions = {},
|
||||
classes = {},
|
||||
exports = {},
|
||||
imports = {},
|
||||
}
|
||||
|
||||
-- Create a temporary buffer for parsing
|
||||
local bufnr = vim.api.nvim_create_buf(false, true)
|
||||
vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, vim.split(content, "\n"))
|
||||
|
||||
local ok, parser = pcall(vim.treesitter.get_parser, bufnr, lang)
|
||||
if not ok or not parser then
|
||||
vim.api.nvim_buf_delete(bufnr, { force = true })
|
||||
return nil
|
||||
end
|
||||
|
||||
local tree = parser:parse()[1]
|
||||
if not tree then
|
||||
vim.api.nvim_buf_delete(bufnr, { force = true })
|
||||
return nil
|
||||
end
|
||||
|
||||
local root = tree:root()
|
||||
local queries = TS_QUERIES[lang]
|
||||
|
||||
if not queries then
|
||||
-- Fallback: walk tree manually for common patterns
|
||||
result = analyze_tree_generic(root, bufnr)
|
||||
else
|
||||
-- Use language-specific queries
|
||||
if queries.functions then
|
||||
local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.functions)
|
||||
if query_ok then
|
||||
for id, node in query:iter_captures(root, bufnr, 0, -1) do
|
||||
local capture_name = query.captures[id]
|
||||
if capture_name == "func" or capture_name == "name" then
|
||||
local start_row, _, end_row, _ = node:range()
|
||||
local name = nil
|
||||
|
||||
-- Try to get name from sibling capture or child
|
||||
if capture_name == "func" then
|
||||
local name_node = node:field("name")[1]
|
||||
if name_node then
|
||||
name = vim.treesitter.get_node_text(name_node, bufnr)
|
||||
end
|
||||
else
|
||||
name = vim.treesitter.get_node_text(node, bufnr)
|
||||
end
|
||||
|
||||
if name and not vim.tbl_contains(vim.tbl_map(function(f)
|
||||
return f.name
|
||||
end, result.functions), name) then
|
||||
table.insert(result.functions, {
|
||||
name = name,
|
||||
line = start_row + 1,
|
||||
end_line = end_row + 1,
|
||||
params = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if queries.classes then
|
||||
local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.classes)
|
||||
if query_ok then
|
||||
for id, node in query:iter_captures(root, bufnr, 0, -1) do
|
||||
local capture_name = query.captures[id]
|
||||
if capture_name == "class" then
|
||||
local start_row, _, end_row, _ = node:range()
|
||||
local name_node = node:field("name")[1]
|
||||
local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous"
|
||||
|
||||
table.insert(result.classes, {
|
||||
name = name,
|
||||
line = start_row + 1,
|
||||
end_line = end_row + 1,
|
||||
methods = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if queries.exports then
|
||||
local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.exports)
|
||||
if query_ok then
|
||||
for _, node in query:iter_captures(root, bufnr, 0, -1) do
|
||||
local text = vim.treesitter.get_node_text(node, bufnr)
|
||||
local start_row, _, _, _ = node:range()
|
||||
|
||||
-- Extract export names (simplified)
|
||||
local names = {}
|
||||
for name in text:gmatch("export%s+[%w_]+%s+([%w_]+)") do
|
||||
table.insert(names, name)
|
||||
end
|
||||
for name in text:gmatch("export%s*{([^}]+)}") do
|
||||
for n in name:gmatch("([%w_]+)") do
|
||||
table.insert(names, n)
|
||||
end
|
||||
end
|
||||
|
||||
for _, name in ipairs(names) do
|
||||
table.insert(result.exports, {
|
||||
name = name,
|
||||
type = "unknown",
|
||||
line = start_row + 1,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if queries.imports then
|
||||
local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.imports)
|
||||
if query_ok then
|
||||
for _, node in query:iter_captures(root, bufnr, 0, -1) do
|
||||
local text = vim.treesitter.get_node_text(node, bufnr)
|
||||
local start_row, _, _, _ = node:range()
|
||||
|
||||
-- Extract import source
|
||||
local source = text:match('["\']([^"\']+)["\']')
|
||||
if source then
|
||||
table.insert(result.imports, {
|
||||
source = source,
|
||||
names = {},
|
||||
line = start_row + 1,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
vim.api.nvim_buf_delete(bufnr, { force = true })
|
||||
return result
|
||||
end
|
||||
|
||||
--- Generic tree analysis for unsupported languages
|
||||
---@param root TSNode
|
||||
---@param bufnr number
|
||||
---@return table
|
||||
analyze_tree_generic = function(root, bufnr)
|
||||
local result = {
|
||||
functions = {},
|
||||
classes = {},
|
||||
exports = {},
|
||||
imports = {},
|
||||
}
|
||||
|
||||
local function visit(node)
|
||||
local node_type = node:type()
|
||||
|
||||
-- Common function patterns
|
||||
if
|
||||
node_type:match("function")
|
||||
or node_type:match("method")
|
||||
or node_type == "arrow_function"
|
||||
or node_type == "func_literal"
|
||||
then
|
||||
local start_row, _, end_row, _ = node:range()
|
||||
local name_node = node:field("name")[1]
|
||||
local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous"
|
||||
|
||||
table.insert(result.functions, {
|
||||
name = name,
|
||||
line = start_row + 1,
|
||||
end_line = end_row + 1,
|
||||
params = {},
|
||||
})
|
||||
end
|
||||
|
||||
-- Common class patterns
|
||||
if node_type:match("class") or node_type == "struct_item" or node_type == "impl_item" then
|
||||
local start_row, _, end_row, _ = node:range()
|
||||
local name_node = node:field("name")[1]
|
||||
local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous"
|
||||
|
||||
table.insert(result.classes, {
|
||||
name = name,
|
||||
line = start_row + 1,
|
||||
end_line = end_row + 1,
|
||||
methods = {},
|
||||
})
|
||||
end
|
||||
|
||||
-- Recurse into children
|
||||
for child in node:iter_children() do
|
||||
visit(child)
|
||||
end
|
||||
end
|
||||
|
||||
visit(root)
|
||||
return result
|
||||
end
|
||||
|
||||
--- Analyze file using pattern matching (fallback)
|
||||
---@param content string
|
||||
---@param lang string
|
||||
---@return table
|
||||
local function analyze_with_patterns(content, lang)
|
||||
local result = {
|
||||
functions = {},
|
||||
classes = {},
|
||||
exports = {},
|
||||
imports = {},
|
||||
}
|
||||
|
||||
local lines = vim.split(content, "\n")
|
||||
|
||||
-- Language-specific patterns
|
||||
local patterns = {
|
||||
lua = {
|
||||
func_start = "^%s*local?%s*function%s+([%w_%.]+)",
|
||||
func_assign = "^%s*([%w_%.]+)%s*=%s*function",
|
||||
module_return = "^return%s+M",
|
||||
},
|
||||
javascript = {
|
||||
func_start = "^%s*function%s+([%w_]+)",
|
||||
func_arrow = "^%s*const%s+([%w_]+)%s*=%s*",
|
||||
class_start = "^%s*class%s+([%w_]+)",
|
||||
export_line = "^%s*export%s+",
|
||||
import_line = "^%s*import%s+",
|
||||
},
|
||||
typescript = {
|
||||
func_start = "^%s*function%s+([%w_]+)",
|
||||
func_arrow = "^%s*const%s+([%w_]+)%s*=%s*",
|
||||
class_start = "^%s*class%s+([%w_]+)",
|
||||
export_line = "^%s*export%s+",
|
||||
import_line = "^%s*import%s+",
|
||||
},
|
||||
python = {
|
||||
func_start = "^%s*def%s+([%w_]+)",
|
||||
class_start = "^%s*class%s+([%w_]+)",
|
||||
import_line = "^%s*import%s+",
|
||||
from_import = "^%s*from%s+",
|
||||
},
|
||||
go = {
|
||||
func_start = "^func%s+([%w_]+)",
|
||||
method_start = "^func%s+%([^%)]+%)%s+([%w_]+)",
|
||||
import_line = "^import%s+",
|
||||
},
|
||||
rust = {
|
||||
func_start = "^%s*pub?%s*fn%s+([%w_]+)",
|
||||
struct_start = "^%s*pub?%s*struct%s+([%w_]+)",
|
||||
impl_start = "^%s*impl%s+([%w_<>]+)",
|
||||
use_line = "^%s*use%s+",
|
||||
},
|
||||
}
|
||||
|
||||
local lang_patterns = patterns[lang] or patterns.javascript
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
-- Functions
|
||||
if lang_patterns.func_start then
|
||||
local name = line:match(lang_patterns.func_start)
|
||||
if name then
|
||||
table.insert(result.functions, {
|
||||
name = name,
|
||||
line = i,
|
||||
end_line = i,
|
||||
params = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
if lang_patterns.func_arrow then
|
||||
local name = line:match(lang_patterns.func_arrow)
|
||||
if name and line:match("=>") then
|
||||
table.insert(result.functions, {
|
||||
name = name,
|
||||
line = i,
|
||||
end_line = i,
|
||||
params = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
if lang_patterns.func_assign then
|
||||
local name = line:match(lang_patterns.func_assign)
|
||||
if name then
|
||||
table.insert(result.functions, {
|
||||
name = name,
|
||||
line = i,
|
||||
end_line = i,
|
||||
params = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
if lang_patterns.method_start then
|
||||
local name = line:match(lang_patterns.method_start)
|
||||
if name then
|
||||
table.insert(result.functions, {
|
||||
name = name,
|
||||
line = i,
|
||||
end_line = i,
|
||||
params = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Classes
|
||||
if lang_patterns.class_start then
|
||||
local name = line:match(lang_patterns.class_start)
|
||||
if name then
|
||||
table.insert(result.classes, {
|
||||
name = name,
|
||||
line = i,
|
||||
end_line = i,
|
||||
methods = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
if lang_patterns.struct_start then
|
||||
local name = line:match(lang_patterns.struct_start)
|
||||
if name then
|
||||
table.insert(result.classes, {
|
||||
name = name,
|
||||
line = i,
|
||||
end_line = i,
|
||||
methods = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Exports
|
||||
if lang_patterns.export_line and line:match(lang_patterns.export_line) then
|
||||
local name = line:match("export%s+[%w_]+%s+([%w_]+)")
|
||||
or line:match("export%s+default%s+([%w_]+)")
|
||||
or line:match("export%s+{%s*([%w_]+)")
|
||||
if name then
|
||||
table.insert(result.exports, {
|
||||
name = name,
|
||||
type = "unknown",
|
||||
line = i,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Imports
|
||||
if lang_patterns.import_line and line:match(lang_patterns.import_line) then
|
||||
local source = line:match('["\']([^"\']+)["\']')
|
||||
if source then
|
||||
table.insert(result.imports, {
|
||||
source = source,
|
||||
names = {},
|
||||
line = i,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
if lang_patterns.from_import and line:match(lang_patterns.from_import) then
|
||||
local source = line:match("from%s+([%w_%.]+)")
|
||||
if source then
|
||||
table.insert(result.imports, {
|
||||
source = source,
|
||||
names = {},
|
||||
line = i,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
if lang_patterns.use_line and line:match(lang_patterns.use_line) then
|
||||
local source = line:match("use%s+([%w_:]+)")
|
||||
if source then
|
||||
table.insert(result.imports, {
|
||||
source = source,
|
||||
names = {},
|
||||
line = i,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- For Lua, infer exports from module table
|
||||
if lang == "lua" then
|
||||
for _, func in ipairs(result.functions) do
|
||||
if func.name:match("^M%.") then
|
||||
local name = func.name:gsub("^M%.", "")
|
||||
table.insert(result.exports, {
|
||||
name = name,
|
||||
type = "function",
|
||||
line = func.line,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Analyze a single file
|
||||
---@param filepath string Full path to file
|
||||
---@return FileIndex|nil
|
||||
function M.analyze_file(filepath)
|
||||
local content = utils.read_file(filepath)
|
||||
if not content then
|
||||
return nil
|
||||
end
|
||||
|
||||
local lang = scanner.get_language(filepath)
|
||||
|
||||
-- Map to Tree-sitter language names
|
||||
local ts_lang_map = {
|
||||
typescript = "typescript",
|
||||
typescriptreact = "tsx",
|
||||
javascript = "javascript",
|
||||
javascriptreact = "javascript",
|
||||
python = "python",
|
||||
go = "go",
|
||||
rust = "rust",
|
||||
lua = "lua",
|
||||
}
|
||||
|
||||
local ts_lang = ts_lang_map[lang] or lang
|
||||
|
||||
-- Try Tree-sitter first
|
||||
local analysis = analyze_with_treesitter(filepath, ts_lang, content)
|
||||
|
||||
-- Fallback to pattern matching
|
||||
if not analysis then
|
||||
analysis = analyze_with_patterns(content, lang)
|
||||
end
|
||||
|
||||
return {
|
||||
path = filepath,
|
||||
language = lang,
|
||||
hash = hash_content(content),
|
||||
exports = analysis.exports,
|
||||
imports = analysis.imports,
|
||||
functions = analysis.functions,
|
||||
classes = analysis.classes,
|
||||
last_indexed = os.time(),
|
||||
}
|
||||
end
|
||||
|
||||
--- Extract exports from a buffer
|
||||
---@param bufnr number
|
||||
---@return Export[]
|
||||
function M.extract_exports(bufnr)
|
||||
local filepath = vim.api.nvim_buf_get_name(bufnr)
|
||||
local analysis = M.analyze_file(filepath)
|
||||
return analysis and analysis.exports or {}
|
||||
end
|
||||
|
||||
--- Extract functions from a buffer
|
||||
---@param bufnr number
|
||||
---@return FunctionInfo[]
|
||||
function M.extract_functions(bufnr)
|
||||
local filepath = vim.api.nvim_buf_get_name(bufnr)
|
||||
local analysis = M.analyze_file(filepath)
|
||||
return analysis and analysis.functions or {}
|
||||
end
|
||||
|
||||
--- Extract imports from a buffer
|
||||
---@param bufnr number
|
||||
---@return Import[]
|
||||
function M.extract_imports(bufnr)
|
||||
local filepath = vim.api.nvim_buf_get_name(bufnr)
|
||||
local analysis = M.analyze_file(filepath)
|
||||
return analysis and analysis.imports or {}
|
||||
end
|
||||
|
||||
return M
|
||||
604
lua/codetyper/features/indexer/init.lua
Normal file
604
lua/codetyper/features/indexer/init.lua
Normal file
@@ -0,0 +1,604 @@
|
||||
---@mod codetyper.indexer Project indexer for Codetyper.nvim
|
||||
---@brief [[
|
||||
--- Indexes project structure, dependencies, and code symbols.
|
||||
--- Stores knowledge in .coder/ directory for enriching LLM context.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Index schema version for migrations
|
||||
local INDEX_VERSION = 1
|
||||
|
||||
--- Index file name
|
||||
local INDEX_FILE = "index.json"
|
||||
|
||||
--- Debounce timer for file indexing
|
||||
local index_timer = nil
|
||||
local INDEX_DEBOUNCE_MS = 500
|
||||
|
||||
--- Default indexer configuration
|
||||
local default_config = {
|
||||
enabled = true,
|
||||
auto_index = true,
|
||||
index_on_open = false,
|
||||
max_file_size = 100000,
|
||||
excluded_dirs = { "node_modules", "dist", "build", ".git", ".coder", "__pycache__", "vendor", "target" },
|
||||
index_extensions = { "lua", "ts", "tsx", "js", "jsx", "py", "go", "rs", "rb", "java", "c", "cpp", "h", "hpp" },
|
||||
memory = {
|
||||
enabled = true,
|
||||
max_memories = 1000,
|
||||
prune_threshold = 0.1,
|
||||
},
|
||||
}
|
||||
|
||||
--- Current configuration
|
||||
---@type table
|
||||
local config = vim.deepcopy(default_config)
|
||||
|
||||
--- Cached project index
|
||||
---@type table<string, ProjectIndex>
|
||||
local index_cache = {}
|
||||
|
||||
---@class ProjectIndex
|
||||
---@field version number Index schema version
|
||||
---@field project_root string Absolute path to project
|
||||
---@field project_name string Project name
|
||||
---@field project_type string "node"|"rust"|"go"|"python"|"lua"|"unknown"
|
||||
---@field dependencies table<string, string> name -> version
|
||||
---@field dev_dependencies table<string, string> name -> version
|
||||
---@field files table<string, FileIndex> path -> FileIndex
|
||||
---@field symbols table<string, string[]> symbol -> [file paths]
|
||||
---@field last_indexed number Timestamp
|
||||
---@field stats {files: number, functions: number, classes: number, exports: number}
|
||||
|
||||
---@class FileIndex
|
||||
---@field path string Relative path from project root
|
||||
---@field language string Detected language
|
||||
---@field hash string Content hash for change detection
|
||||
---@field exports Export[] Exported symbols
|
||||
---@field imports Import[] Dependencies
|
||||
---@field functions FunctionInfo[]
|
||||
---@field classes ClassInfo[]
|
||||
---@field last_indexed number Timestamp
|
||||
|
||||
---@class Export
|
||||
---@field name string Symbol name
|
||||
---@field type string "function"|"class"|"constant"|"type"|"variable"
|
||||
---@field line number Line number
|
||||
|
||||
---@class Import
|
||||
---@field source string Import source/module
|
||||
---@field names string[] Imported names
|
||||
---@field line number Line number
|
||||
|
||||
---@class FunctionInfo
|
||||
---@field name string Function name
|
||||
---@field params string[] Parameter names
|
||||
---@field line number Start line
|
||||
---@field end_line number End line
|
||||
---@field docstring string|nil Documentation
|
||||
|
||||
---@class ClassInfo
|
||||
---@field name string Class name
|
||||
---@field methods string[] Method names
|
||||
---@field line number Start line
|
||||
---@field end_line number End line
|
||||
---@field docstring string|nil Documentation
|
||||
|
||||
--- Get the index file path
|
||||
---@return string|nil
|
||||
local function get_index_path()
|
||||
local root = utils.get_project_root()
|
||||
if not root then
|
||||
return nil
|
||||
end
|
||||
return root .. "/.coder/" .. INDEX_FILE
|
||||
end
|
||||
|
||||
--- Create empty index structure
|
||||
---@return ProjectIndex
|
||||
local function create_empty_index()
|
||||
local root = utils.get_project_root()
|
||||
return {
|
||||
version = INDEX_VERSION,
|
||||
project_root = root or "",
|
||||
project_name = root and vim.fn.fnamemodify(root, ":t") or "",
|
||||
project_type = "unknown",
|
||||
dependencies = {},
|
||||
dev_dependencies = {},
|
||||
files = {},
|
||||
symbols = {},
|
||||
last_indexed = os.time(),
|
||||
stats = {
|
||||
files = 0,
|
||||
functions = 0,
|
||||
classes = 0,
|
||||
exports = 0,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
--- Load index from disk
|
||||
---@return ProjectIndex|nil
|
||||
function M.load_index()
|
||||
local root = utils.get_project_root()
|
||||
if not root then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Check cache first
|
||||
if index_cache[root] then
|
||||
return index_cache[root]
|
||||
end
|
||||
|
||||
local path = get_index_path()
|
||||
if not path then
|
||||
return nil
|
||||
end
|
||||
|
||||
local content = utils.read_file(path)
|
||||
if not content then
|
||||
return nil
|
||||
end
|
||||
|
||||
local ok, index = pcall(vim.json.decode, content)
|
||||
if not ok or not index then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Validate version
|
||||
if index.version ~= INDEX_VERSION then
|
||||
-- Index needs migration or rebuild
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Cache it
|
||||
index_cache[root] = index
|
||||
return index
|
||||
end
|
||||
|
||||
--- Save index to disk
|
||||
---@param index ProjectIndex
|
||||
---@return boolean
|
||||
function M.save_index(index)
|
||||
local root = utils.get_project_root()
|
||||
if not root then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Ensure .coder directory exists
|
||||
local coder_dir = root .. "/.coder"
|
||||
utils.ensure_dir(coder_dir)
|
||||
|
||||
local path = get_index_path()
|
||||
if not path then
|
||||
return false
|
||||
end
|
||||
|
||||
local ok, encoded = pcall(vim.json.encode, index)
|
||||
if not ok then
|
||||
return false
|
||||
end
|
||||
|
||||
local success = utils.write_file(path, encoded)
|
||||
if success then
|
||||
-- Update cache
|
||||
index_cache[root] = index
|
||||
end
|
||||
return success
|
||||
end
|
||||
|
||||
--- Index the entire project
|
||||
---@param callback? fun(index: ProjectIndex)
|
||||
---@return ProjectIndex|nil
|
||||
function M.index_project(callback)
|
||||
local scanner = require("codetyper.features.indexer.scanner")
|
||||
local analyzer = require("codetyper.features.indexer.analyzer")
|
||||
|
||||
local index = create_empty_index()
|
||||
local root = utils.get_project_root()
|
||||
|
||||
if not root then
|
||||
if callback then
|
||||
callback(index)
|
||||
end
|
||||
return index
|
||||
end
|
||||
|
||||
-- Detect project type and parse dependencies
|
||||
index.project_type = scanner.detect_project_type(root)
|
||||
local deps = scanner.parse_dependencies(root, index.project_type)
|
||||
index.dependencies = deps.dependencies or {}
|
||||
index.dev_dependencies = deps.dev_dependencies or {}
|
||||
|
||||
-- Get all indexable files
|
||||
local files = scanner.get_indexable_files(root, config)
|
||||
|
||||
-- Index each file
|
||||
local total_functions = 0
|
||||
local total_classes = 0
|
||||
local total_exports = 0
|
||||
|
||||
for _, filepath in ipairs(files) do
|
||||
local relative_path = filepath:gsub("^" .. vim.pesc(root) .. "/", "")
|
||||
local file_index = analyzer.analyze_file(filepath)
|
||||
|
||||
if file_index then
|
||||
file_index.path = relative_path
|
||||
index.files[relative_path] = file_index
|
||||
|
||||
-- Update symbol index
|
||||
for _, exp in ipairs(file_index.exports or {}) do
|
||||
if not index.symbols[exp.name] then
|
||||
index.symbols[exp.name] = {}
|
||||
end
|
||||
table.insert(index.symbols[exp.name], relative_path)
|
||||
total_exports = total_exports + 1
|
||||
end
|
||||
|
||||
total_functions = total_functions + #(file_index.functions or {})
|
||||
total_classes = total_classes + #(file_index.classes or {})
|
||||
end
|
||||
end
|
||||
|
||||
-- Update stats
|
||||
index.stats = {
|
||||
files = #files,
|
||||
functions = total_functions,
|
||||
classes = total_classes,
|
||||
exports = total_exports,
|
||||
}
|
||||
index.last_indexed = os.time()
|
||||
|
||||
-- Save to disk
|
||||
M.save_index(index)
|
||||
|
||||
-- Store memories
|
||||
local memory = require("codetyper.features.indexer.memory")
|
||||
memory.store_index_summary(index)
|
||||
|
||||
-- Sync project summary to brain
|
||||
M.sync_project_to_brain(index, files, root)
|
||||
|
||||
if callback then
|
||||
callback(index)
|
||||
end
|
||||
|
||||
return index
|
||||
end
|
||||
|
||||
--- Sync project index to brain
|
||||
---@param index ProjectIndex
|
||||
---@param files string[] List of file paths
|
||||
---@param root string Project root
|
||||
function M.sync_project_to_brain(index, files, root)
|
||||
local ok_brain, brain = pcall(require, "codetyper.brain")
|
||||
if not ok_brain or not brain.is_initialized or not brain.is_initialized() then
|
||||
return
|
||||
end
|
||||
|
||||
-- Store project-level pattern
|
||||
brain.learn({
|
||||
type = "pattern",
|
||||
file = root,
|
||||
content = {
|
||||
summary = "Project: "
|
||||
.. index.project_name
|
||||
.. " ("
|
||||
.. index.project_type
|
||||
.. ") - "
|
||||
.. index.stats.files
|
||||
.. " files",
|
||||
detail = string.format(
|
||||
"%d functions, %d classes, %d exports",
|
||||
index.stats.functions,
|
||||
index.stats.classes,
|
||||
index.stats.exports
|
||||
),
|
||||
},
|
||||
context = {
|
||||
file = root,
|
||||
project_type = index.project_type,
|
||||
dependencies = index.dependencies,
|
||||
},
|
||||
})
|
||||
|
||||
-- Store key file patterns (files with most functions/classes)
|
||||
local key_files = {}
|
||||
for path, file_index in pairs(index.files) do
|
||||
local score = #(file_index.functions or {}) + (#(file_index.classes or {}) * 2)
|
||||
if score >= 3 then
|
||||
table.insert(key_files, { path = path, index = file_index, score = score })
|
||||
end
|
||||
end
|
||||
|
||||
table.sort(key_files, function(a, b)
|
||||
return a.score > b.score
|
||||
end)
|
||||
|
||||
-- Store top 20 key files in brain
|
||||
for i, kf in ipairs(key_files) do
|
||||
if i > 20 then
|
||||
break
|
||||
end
|
||||
M.sync_to_brain(root .. "/" .. kf.path, kf.index)
|
||||
end
|
||||
end
|
||||
|
||||
--- Index a single file (incremental update)
|
||||
---@param filepath string
|
||||
---@return FileIndex|nil
|
||||
function M.index_file(filepath)
|
||||
local analyzer = require("codetyper.features.indexer.analyzer")
|
||||
local memory = require("codetyper.features.indexer.memory")
|
||||
local root = utils.get_project_root()
|
||||
|
||||
if not root then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Load existing index
|
||||
local index = M.load_index() or create_empty_index()
|
||||
|
||||
-- Analyze file
|
||||
local file_index = analyzer.analyze_file(filepath)
|
||||
if not file_index then
|
||||
return nil
|
||||
end
|
||||
|
||||
local relative_path = filepath:gsub("^" .. vim.pesc(root) .. "/", "")
|
||||
file_index.path = relative_path
|
||||
|
||||
-- Remove old symbol references for this file
|
||||
for symbol, paths in pairs(index.symbols) do
|
||||
for i = #paths, 1, -1 do
|
||||
if paths[i] == relative_path then
|
||||
table.remove(paths, i)
|
||||
end
|
||||
end
|
||||
if #paths == 0 then
|
||||
index.symbols[symbol] = nil
|
||||
end
|
||||
end
|
||||
|
||||
-- Add new file index
|
||||
index.files[relative_path] = file_index
|
||||
|
||||
-- Update symbol index
|
||||
for _, exp in ipairs(file_index.exports or {}) do
|
||||
if not index.symbols[exp.name] then
|
||||
index.symbols[exp.name] = {}
|
||||
end
|
||||
table.insert(index.symbols[exp.name], relative_path)
|
||||
end
|
||||
|
||||
-- Recalculate stats
|
||||
local total_functions = 0
|
||||
local total_classes = 0
|
||||
local total_exports = 0
|
||||
local file_count = 0
|
||||
|
||||
for _, f in pairs(index.files) do
|
||||
file_count = file_count + 1
|
||||
total_functions = total_functions + #(f.functions or {})
|
||||
total_classes = total_classes + #(f.classes or {})
|
||||
total_exports = total_exports + #(f.exports or {})
|
||||
end
|
||||
|
||||
index.stats = {
|
||||
files = file_count,
|
||||
functions = total_functions,
|
||||
classes = total_classes,
|
||||
exports = total_exports,
|
||||
}
|
||||
index.last_indexed = os.time()
|
||||
|
||||
-- Save to disk
|
||||
M.save_index(index)
|
||||
|
||||
-- Store file memory
|
||||
memory.store_file_memory(relative_path, file_index)
|
||||
|
||||
-- Sync to brain if available
|
||||
M.sync_to_brain(filepath, file_index)
|
||||
|
||||
return file_index
|
||||
end
|
||||
|
||||
--- Sync file analysis to brain system
|
||||
---@param filepath string Full file path
|
||||
---@param file_index FileIndex File analysis
|
||||
function M.sync_to_brain(filepath, file_index)
|
||||
local ok_brain, brain = pcall(require, "codetyper.brain")
|
||||
if not ok_brain or not brain.is_initialized or not brain.is_initialized() then
|
||||
return
|
||||
end
|
||||
|
||||
-- Only store if file has meaningful content
|
||||
local funcs = file_index.functions or {}
|
||||
local classes = file_index.classes or {}
|
||||
if #funcs == 0 and #classes == 0 then
|
||||
return
|
||||
end
|
||||
|
||||
-- Build summary
|
||||
local parts = {}
|
||||
if #funcs > 0 then
|
||||
local func_names = {}
|
||||
for i, f in ipairs(funcs) do
|
||||
if i <= 5 then
|
||||
table.insert(func_names, f.name)
|
||||
end
|
||||
end
|
||||
table.insert(parts, "functions: " .. table.concat(func_names, ", "))
|
||||
if #funcs > 5 then
|
||||
table.insert(parts, "(+" .. (#funcs - 5) .. " more)")
|
||||
end
|
||||
end
|
||||
if #classes > 0 then
|
||||
local class_names = {}
|
||||
for _, c in ipairs(classes) do
|
||||
table.insert(class_names, c.name)
|
||||
end
|
||||
table.insert(parts, "classes: " .. table.concat(class_names, ", "))
|
||||
end
|
||||
|
||||
local filename = vim.fn.fnamemodify(filepath, ":t")
|
||||
local summary = filename .. " - " .. table.concat(parts, "; ")
|
||||
|
||||
-- Learn this pattern in brain
|
||||
brain.learn({
|
||||
type = "pattern",
|
||||
file = filepath,
|
||||
content = {
|
||||
summary = summary,
|
||||
detail = #funcs .. " functions, " .. #classes .. " classes",
|
||||
},
|
||||
context = {
|
||||
file = file_index.path or filepath,
|
||||
language = file_index.language,
|
||||
functions = funcs,
|
||||
classes = classes,
|
||||
exports = file_index.exports,
|
||||
imports = file_index.imports,
|
||||
},
|
||||
})
|
||||
end
|
||||
|
||||
--- Schedule file indexing with debounce
|
||||
---@param filepath string
|
||||
function M.schedule_index_file(filepath)
|
||||
if not config.enabled or not config.auto_index then
|
||||
return
|
||||
end
|
||||
|
||||
-- Check if file should be indexed
|
||||
local scanner = require("codetyper.features.indexer.scanner")
|
||||
if not scanner.should_index(filepath, config) then
|
||||
return
|
||||
end
|
||||
|
||||
-- Cancel existing timer
|
||||
if index_timer then
|
||||
index_timer:stop()
|
||||
end
|
||||
|
||||
-- Schedule new index
|
||||
index_timer = vim.defer_fn(function()
|
||||
M.index_file(filepath)
|
||||
index_timer = nil
|
||||
end, INDEX_DEBOUNCE_MS)
|
||||
end
|
||||
|
||||
--- Get relevant context for a prompt
|
||||
---@param opts {file: string, intent: table|nil, prompt: string, scope: string|nil}
|
||||
---@return table Context information
|
||||
function M.get_context_for(opts)
|
||||
local memory = require("codetyper.features.indexer.memory")
|
||||
local index = M.load_index()
|
||||
|
||||
local context = {
|
||||
project_type = "unknown",
|
||||
dependencies = {},
|
||||
relevant_files = {},
|
||||
relevant_symbols = {},
|
||||
patterns = {},
|
||||
}
|
||||
|
||||
if not index then
|
||||
return context
|
||||
end
|
||||
|
||||
context.project_type = index.project_type
|
||||
context.dependencies = index.dependencies
|
||||
|
||||
-- Find relevant symbols from prompt
|
||||
local words = {}
|
||||
for word in opts.prompt:gmatch("%w+") do
|
||||
if #word > 2 then
|
||||
words[word:lower()] = true
|
||||
end
|
||||
end
|
||||
|
||||
-- Match symbols
|
||||
for symbol, files in pairs(index.symbols) do
|
||||
if words[symbol:lower()] then
|
||||
context.relevant_symbols[symbol] = files
|
||||
end
|
||||
end
|
||||
|
||||
-- Get file context if available
|
||||
if opts.file then
|
||||
local root = utils.get_project_root()
|
||||
if root then
|
||||
local relative_path = opts.file:gsub("^" .. vim.pesc(root) .. "/", "")
|
||||
local file_index = index.files[relative_path]
|
||||
if file_index then
|
||||
context.current_file = file_index
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Get relevant memories
|
||||
context.patterns = memory.get_relevant(opts.prompt, 5)
|
||||
|
||||
return context
|
||||
end
|
||||
|
||||
--- Get index status
|
||||
---@return table Status information
|
||||
function M.get_status()
|
||||
local index = M.load_index()
|
||||
if not index then
|
||||
return {
|
||||
indexed = false,
|
||||
stats = nil,
|
||||
last_indexed = nil,
|
||||
}
|
||||
end
|
||||
|
||||
return {
|
||||
indexed = true,
|
||||
stats = index.stats,
|
||||
last_indexed = index.last_indexed,
|
||||
project_type = index.project_type,
|
||||
}
|
||||
end
|
||||
|
||||
--- Clear the project index
|
||||
function M.clear()
|
||||
local root = utils.get_project_root()
|
||||
if root then
|
||||
index_cache[root] = nil
|
||||
end
|
||||
|
||||
local path = get_index_path()
|
||||
if path and utils.file_exists(path) then
|
||||
os.remove(path)
|
||||
end
|
||||
end
|
||||
|
||||
--- Setup the indexer with configuration
|
||||
---@param opts? table Configuration options
|
||||
function M.setup(opts)
|
||||
if opts then
|
||||
config = vim.tbl_deep_extend("force", config, opts)
|
||||
end
|
||||
|
||||
-- Index on startup if configured
|
||||
if config.index_on_open then
|
||||
vim.defer_fn(function()
|
||||
M.index_project()
|
||||
end, 1000)
|
||||
end
|
||||
end
|
||||
|
||||
--- Get current configuration
|
||||
---@return table
|
||||
function M.get_config()
|
||||
return vim.deepcopy(config)
|
||||
end
|
||||
|
||||
return M
|
||||
539
lua/codetyper/features/indexer/memory.lua
Normal file
539
lua/codetyper/features/indexer/memory.lua
Normal file
@@ -0,0 +1,539 @@
|
||||
---@mod codetyper.indexer.memory Memory persistence manager
|
||||
---@brief [[
|
||||
--- Stores and retrieves learned patterns and memories in .coder/memories/.
|
||||
--- Supports session history for learning from interactions.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Memory directories
|
||||
local MEMORIES_DIR = "memories"
|
||||
local SESSIONS_DIR = "sessions"
|
||||
local FILES_DIR = "files"
|
||||
|
||||
--- Memory files
|
||||
local PATTERNS_FILE = "patterns.json"
|
||||
local CONVENTIONS_FILE = "conventions.json"
|
||||
local SYMBOLS_FILE = "symbols.json"
|
||||
|
||||
--- In-memory cache
|
||||
local cache = {
|
||||
patterns = nil,
|
||||
conventions = nil,
|
||||
symbols = nil,
|
||||
}
|
||||
|
||||
---@class Memory
|
||||
---@field id string Unique identifier
|
||||
---@field type "pattern"|"convention"|"session"|"interaction"
|
||||
---@field content string The learned information
|
||||
---@field context table Where/when learned
|
||||
---@field weight number Importance score (0.0-1.0)
|
||||
---@field created_at number Timestamp
|
||||
---@field updated_at number Last update timestamp
|
||||
---@field used_count number Times referenced
|
||||
|
||||
--- Get the memories base directory
|
||||
---@return string|nil
|
||||
local function get_memories_dir()
|
||||
local root = utils.get_project_root()
|
||||
if not root then
|
||||
return nil
|
||||
end
|
||||
return root .. "/.coder/" .. MEMORIES_DIR
|
||||
end
|
||||
|
||||
--- Get the sessions directory
|
||||
---@return string|nil
|
||||
local function get_sessions_dir()
|
||||
local root = utils.get_project_root()
|
||||
if not root then
|
||||
return nil
|
||||
end
|
||||
return root .. "/.coder/" .. SESSIONS_DIR
|
||||
end
|
||||
|
||||
--- Ensure memories directory exists
|
||||
---@return boolean
|
||||
local function ensure_memories_dir()
|
||||
local dir = get_memories_dir()
|
||||
if not dir then
|
||||
return false
|
||||
end
|
||||
utils.ensure_dir(dir)
|
||||
utils.ensure_dir(dir .. "/" .. FILES_DIR)
|
||||
return true
|
||||
end
|
||||
|
||||
--- Ensure sessions directory exists
|
||||
---@return boolean
|
||||
local function ensure_sessions_dir()
|
||||
local dir = get_sessions_dir()
|
||||
if not dir then
|
||||
return false
|
||||
end
|
||||
return utils.ensure_dir(dir)
|
||||
end
|
||||
|
||||
--- Generate a unique ID
|
||||
---@return string
|
||||
local function generate_id()
|
||||
return string.format("mem_%d_%s", os.time(), string.sub(tostring(math.random()), 3, 8))
|
||||
end
|
||||
|
||||
--- Load a memory file
|
||||
---@param filename string
|
||||
---@return table
|
||||
local function load_memory_file(filename)
|
||||
local dir = get_memories_dir()
|
||||
if not dir then
|
||||
return {}
|
||||
end
|
||||
|
||||
local path = dir .. "/" .. filename
|
||||
local content = utils.read_file(path)
|
||||
if not content then
|
||||
return {}
|
||||
end
|
||||
|
||||
local ok, data = pcall(vim.json.decode, content)
|
||||
if not ok or not data then
|
||||
return {}
|
||||
end
|
||||
|
||||
return data
|
||||
end
|
||||
|
||||
--- Save a memory file
|
||||
---@param filename string
|
||||
---@param data table
|
||||
---@return boolean
|
||||
local function save_memory_file(filename, data)
|
||||
if not ensure_memories_dir() then
|
||||
return false
|
||||
end
|
||||
|
||||
local dir = get_memories_dir()
|
||||
if not dir then
|
||||
return false
|
||||
end
|
||||
|
||||
local path = dir .. "/" .. filename
|
||||
local ok, encoded = pcall(vim.json.encode, data)
|
||||
if not ok then
|
||||
return false
|
||||
end
|
||||
|
||||
return utils.write_file(path, encoded)
|
||||
end
|
||||
|
||||
--- Hash a file path for storage
|
||||
---@param filepath string
|
||||
---@return string
|
||||
local function hash_path(filepath)
|
||||
local hash = 0
|
||||
for i = 1, #filepath do
|
||||
hash = (hash * 31 + string.byte(filepath, i)) % 2147483647
|
||||
end
|
||||
return string.format("%08x", hash)
|
||||
end
|
||||
|
||||
--- Load patterns from cache or disk
|
||||
---@return table
|
||||
function M.load_patterns()
|
||||
if cache.patterns then
|
||||
return cache.patterns
|
||||
end
|
||||
cache.patterns = load_memory_file(PATTERNS_FILE)
|
||||
return cache.patterns
|
||||
end
|
||||
|
||||
--- Load conventions from cache or disk
|
||||
---@return table
|
||||
function M.load_conventions()
|
||||
if cache.conventions then
|
||||
return cache.conventions
|
||||
end
|
||||
cache.conventions = load_memory_file(CONVENTIONS_FILE)
|
||||
return cache.conventions
|
||||
end
|
||||
|
||||
--- Load symbols from cache or disk
|
||||
---@return table
|
||||
function M.load_symbols()
|
||||
if cache.symbols then
|
||||
return cache.symbols
|
||||
end
|
||||
cache.symbols = load_memory_file(SYMBOLS_FILE)
|
||||
return cache.symbols
|
||||
end
|
||||
|
||||
--- Store a new memory
|
||||
---@param memory Memory
|
||||
---@return boolean
|
||||
function M.store_memory(memory)
|
||||
memory.id = memory.id or generate_id()
|
||||
memory.created_at = memory.created_at or os.time()
|
||||
memory.updated_at = os.time()
|
||||
memory.used_count = memory.used_count or 0
|
||||
memory.weight = memory.weight or 0.5
|
||||
|
||||
local filename
|
||||
if memory.type == "pattern" then
|
||||
filename = PATTERNS_FILE
|
||||
cache.patterns = nil
|
||||
elseif memory.type == "convention" then
|
||||
filename = CONVENTIONS_FILE
|
||||
cache.conventions = nil
|
||||
else
|
||||
filename = PATTERNS_FILE
|
||||
cache.patterns = nil
|
||||
end
|
||||
|
||||
local data = load_memory_file(filename)
|
||||
data[memory.id] = memory
|
||||
|
||||
return save_memory_file(filename, data)
|
||||
end
|
||||
|
||||
--- Store file-specific memory
|
||||
---@param relative_path string Relative file path
|
||||
---@param file_index table FileIndex data
|
||||
---@return boolean
|
||||
function M.store_file_memory(relative_path, file_index)
|
||||
if not ensure_memories_dir() then
|
||||
return false
|
||||
end
|
||||
|
||||
local dir = get_memories_dir()
|
||||
if not dir then
|
||||
return false
|
||||
end
|
||||
|
||||
local hash = hash_path(relative_path)
|
||||
local path = dir .. "/" .. FILES_DIR .. "/" .. hash .. ".json"
|
||||
|
||||
local data = {
|
||||
path = relative_path,
|
||||
indexed_at = os.time(),
|
||||
functions = file_index.functions or {},
|
||||
classes = file_index.classes or {},
|
||||
exports = file_index.exports or {},
|
||||
imports = file_index.imports or {},
|
||||
}
|
||||
|
||||
local ok, encoded = pcall(vim.json.encode, data)
|
||||
if not ok then
|
||||
return false
|
||||
end
|
||||
|
||||
return utils.write_file(path, encoded)
|
||||
end
|
||||
|
||||
--- Load file-specific memory
|
||||
---@param relative_path string
|
||||
---@return table|nil
|
||||
function M.load_file_memory(relative_path)
|
||||
local dir = get_memories_dir()
|
||||
if not dir then
|
||||
return nil
|
||||
end
|
||||
|
||||
local hash = hash_path(relative_path)
|
||||
local path = dir .. "/" .. FILES_DIR .. "/" .. hash .. ".json"
|
||||
|
||||
local content = utils.read_file(path)
|
||||
if not content then
|
||||
return nil
|
||||
end
|
||||
|
||||
local ok, data = pcall(vim.json.decode, content)
|
||||
if not ok then
|
||||
return nil
|
||||
end
|
||||
|
||||
return data
|
||||
end
|
||||
|
||||
--- Store index summary as memories
|
||||
---@param index ProjectIndex
|
||||
function M.store_index_summary(index)
|
||||
-- Store project type convention
|
||||
if index.project_type and index.project_type ~= "unknown" then
|
||||
M.store_memory({
|
||||
type = "convention",
|
||||
content = "Project uses " .. index.project_type .. " ecosystem",
|
||||
context = {
|
||||
project_root = index.project_root,
|
||||
detected_at = os.time(),
|
||||
},
|
||||
weight = 0.9,
|
||||
})
|
||||
end
|
||||
|
||||
-- Store dependency patterns
|
||||
local dep_count = 0
|
||||
for _ in pairs(index.dependencies or {}) do
|
||||
dep_count = dep_count + 1
|
||||
end
|
||||
|
||||
if dep_count > 0 then
|
||||
local deps_list = {}
|
||||
for name, _ in pairs(index.dependencies) do
|
||||
table.insert(deps_list, name)
|
||||
end
|
||||
|
||||
M.store_memory({
|
||||
type = "pattern",
|
||||
content = "Project dependencies: " .. table.concat(deps_list, ", "),
|
||||
context = {
|
||||
dependency_count = dep_count,
|
||||
},
|
||||
weight = 0.7,
|
||||
})
|
||||
end
|
||||
|
||||
-- Update symbol cache
|
||||
cache.symbols = nil
|
||||
save_memory_file(SYMBOLS_FILE, index.symbols or {})
|
||||
end
|
||||
|
||||
--- Store session interaction
|
||||
---@param interaction {prompt: string, response: string, file: string|nil, success: boolean}
|
||||
function M.store_session(interaction)
|
||||
if not ensure_sessions_dir() then
|
||||
return
|
||||
end
|
||||
|
||||
local dir = get_sessions_dir()
|
||||
if not dir then
|
||||
return
|
||||
end
|
||||
|
||||
-- Use date-based session files
|
||||
local date = os.date("%Y-%m-%d")
|
||||
local path = dir .. "/" .. date .. ".json"
|
||||
|
||||
local sessions = {}
|
||||
local content = utils.read_file(path)
|
||||
if content then
|
||||
local ok, data = pcall(vim.json.decode, content)
|
||||
if ok and data then
|
||||
sessions = data
|
||||
end
|
||||
end
|
||||
|
||||
table.insert(sessions, {
|
||||
timestamp = os.time(),
|
||||
prompt = interaction.prompt,
|
||||
response = string.sub(interaction.response or "", 1, 500), -- Truncate
|
||||
file = interaction.file,
|
||||
success = interaction.success,
|
||||
})
|
||||
|
||||
-- Limit session size
|
||||
if #sessions > 100 then
|
||||
sessions = { unpack(sessions, #sessions - 99) }
|
||||
end
|
||||
|
||||
local ok, encoded = pcall(vim.json.encode, sessions)
|
||||
if ok then
|
||||
utils.write_file(path, encoded)
|
||||
end
|
||||
end
|
||||
|
||||
--- Get relevant memories for a query
|
||||
---@param query string Search query
|
||||
---@param limit number Maximum results
|
||||
---@return Memory[]
|
||||
function M.get_relevant(query, limit)
|
||||
limit = limit or 10
|
||||
local results = {}
|
||||
|
||||
-- Tokenize query
|
||||
local query_words = {}
|
||||
for word in query:lower():gmatch("%w+") do
|
||||
if #word > 2 then
|
||||
query_words[word] = true
|
||||
end
|
||||
end
|
||||
|
||||
-- Search patterns
|
||||
local patterns = M.load_patterns()
|
||||
for _, memory in pairs(patterns) do
|
||||
local score = 0
|
||||
local content_lower = (memory.content or ""):lower()
|
||||
|
||||
for word in pairs(query_words) do
|
||||
if content_lower:find(word, 1, true) then
|
||||
score = score + 1
|
||||
end
|
||||
end
|
||||
|
||||
if score > 0 then
|
||||
memory.relevance_score = score * (memory.weight or 0.5)
|
||||
table.insert(results, memory)
|
||||
end
|
||||
end
|
||||
|
||||
-- Search conventions
|
||||
local conventions = M.load_conventions()
|
||||
for _, memory in pairs(conventions) do
|
||||
local score = 0
|
||||
local content_lower = (memory.content or ""):lower()
|
||||
|
||||
for word in pairs(query_words) do
|
||||
if content_lower:find(word, 1, true) then
|
||||
score = score + 1
|
||||
end
|
||||
end
|
||||
|
||||
if score > 0 then
|
||||
memory.relevance_score = score * (memory.weight or 0.5)
|
||||
table.insert(results, memory)
|
||||
end
|
||||
end
|
||||
|
||||
-- Sort by relevance
|
||||
table.sort(results, function(a, b)
|
||||
return (a.relevance_score or 0) > (b.relevance_score or 0)
|
||||
end)
|
||||
|
||||
-- Limit results
|
||||
local limited = {}
|
||||
for i = 1, math.min(limit, #results) do
|
||||
limited[i] = results[i]
|
||||
end
|
||||
|
||||
return limited
|
||||
end
|
||||
|
||||
--- Update memory usage count
|
||||
---@param memory_id string
|
||||
function M.update_usage(memory_id)
|
||||
local patterns = M.load_patterns()
|
||||
if patterns[memory_id] then
|
||||
patterns[memory_id].used_count = (patterns[memory_id].used_count or 0) + 1
|
||||
patterns[memory_id].updated_at = os.time()
|
||||
save_memory_file(PATTERNS_FILE, patterns)
|
||||
cache.patterns = nil
|
||||
return
|
||||
end
|
||||
|
||||
local conventions = M.load_conventions()
|
||||
if conventions[memory_id] then
|
||||
conventions[memory_id].used_count = (conventions[memory_id].used_count or 0) + 1
|
||||
conventions[memory_id].updated_at = os.time()
|
||||
save_memory_file(CONVENTIONS_FILE, conventions)
|
||||
cache.conventions = nil
|
||||
end
|
||||
end
|
||||
|
||||
--- Get all memories
|
||||
---@return {patterns: table, conventions: table, symbols: table}
|
||||
function M.get_all()
|
||||
return {
|
||||
patterns = M.load_patterns(),
|
||||
conventions = M.load_conventions(),
|
||||
symbols = M.load_symbols(),
|
||||
}
|
||||
end
|
||||
|
||||
--- Clear all memories
|
||||
---@param pattern? string Optional pattern to match memory IDs
|
||||
function M.clear(pattern)
|
||||
if not pattern then
|
||||
-- Clear all
|
||||
cache = { patterns = nil, conventions = nil, symbols = nil }
|
||||
save_memory_file(PATTERNS_FILE, {})
|
||||
save_memory_file(CONVENTIONS_FILE, {})
|
||||
save_memory_file(SYMBOLS_FILE, {})
|
||||
return
|
||||
end
|
||||
|
||||
-- Clear matching pattern
|
||||
local patterns = M.load_patterns()
|
||||
for id in pairs(patterns) do
|
||||
if id:match(pattern) then
|
||||
patterns[id] = nil
|
||||
end
|
||||
end
|
||||
save_memory_file(PATTERNS_FILE, patterns)
|
||||
cache.patterns = nil
|
||||
|
||||
local conventions = M.load_conventions()
|
||||
for id in pairs(conventions) do
|
||||
if id:match(pattern) then
|
||||
conventions[id] = nil
|
||||
end
|
||||
end
|
||||
save_memory_file(CONVENTIONS_FILE, conventions)
|
||||
cache.conventions = nil
|
||||
end
|
||||
|
||||
--- Prune low-weight memories
|
||||
---@param threshold number Weight threshold (default: 0.1)
|
||||
function M.prune(threshold)
|
||||
threshold = threshold or 0.1
|
||||
|
||||
local patterns = M.load_patterns()
|
||||
local pruned = 0
|
||||
for id, memory in pairs(patterns) do
|
||||
if (memory.weight or 0) < threshold and (memory.used_count or 0) == 0 then
|
||||
patterns[id] = nil
|
||||
pruned = pruned + 1
|
||||
end
|
||||
end
|
||||
if pruned > 0 then
|
||||
save_memory_file(PATTERNS_FILE, patterns)
|
||||
cache.patterns = nil
|
||||
end
|
||||
|
||||
local conventions = M.load_conventions()
|
||||
for id, memory in pairs(conventions) do
|
||||
if (memory.weight or 0) < threshold and (memory.used_count or 0) == 0 then
|
||||
conventions[id] = nil
|
||||
pruned = pruned + 1
|
||||
end
|
||||
end
|
||||
if pruned > 0 then
|
||||
save_memory_file(CONVENTIONS_FILE, conventions)
|
||||
cache.conventions = nil
|
||||
end
|
||||
|
||||
return pruned
|
||||
end
|
||||
|
||||
--- Get memory statistics
|
||||
---@return table
|
||||
function M.get_stats()
|
||||
local patterns = M.load_patterns()
|
||||
local conventions = M.load_conventions()
|
||||
local symbols = M.load_symbols()
|
||||
|
||||
local pattern_count = 0
|
||||
for _ in pairs(patterns) do
|
||||
pattern_count = pattern_count + 1
|
||||
end
|
||||
|
||||
local convention_count = 0
|
||||
for _ in pairs(conventions) do
|
||||
convention_count = convention_count + 1
|
||||
end
|
||||
|
||||
local symbol_count = 0
|
||||
for _ in pairs(symbols) do
|
||||
symbol_count = symbol_count + 1
|
||||
end
|
||||
|
||||
return {
|
||||
patterns = pattern_count,
|
||||
conventions = convention_count,
|
||||
symbols = symbol_count,
|
||||
total = pattern_count + convention_count,
|
||||
}
|
||||
end
|
||||
|
||||
return M
|
||||
409
lua/codetyper/features/indexer/scanner.lua
Normal file
409
lua/codetyper/features/indexer/scanner.lua
Normal file
@@ -0,0 +1,409 @@
|
||||
---@mod codetyper.indexer.scanner File scanner for project indexing
|
||||
---@brief [[
|
||||
--- Discovers indexable files, detects project type, and parses dependencies.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Project type markers
|
||||
local PROJECT_MARKERS = {
|
||||
node = { "package.json" },
|
||||
rust = { "Cargo.toml" },
|
||||
go = { "go.mod" },
|
||||
python = { "pyproject.toml", "setup.py", "requirements.txt" },
|
||||
lua = { "init.lua", ".luarc.json" },
|
||||
ruby = { "Gemfile" },
|
||||
java = { "pom.xml", "build.gradle" },
|
||||
csharp = { "*.csproj", "*.sln" },
|
||||
}
|
||||
|
||||
--- File extension to language mapping
|
||||
local EXTENSION_LANGUAGE = {
|
||||
lua = "lua",
|
||||
ts = "typescript",
|
||||
tsx = "typescriptreact",
|
||||
js = "javascript",
|
||||
jsx = "javascriptreact",
|
||||
py = "python",
|
||||
go = "go",
|
||||
rs = "rust",
|
||||
rb = "ruby",
|
||||
java = "java",
|
||||
c = "c",
|
||||
cpp = "cpp",
|
||||
h = "c",
|
||||
hpp = "cpp",
|
||||
cs = "csharp",
|
||||
}
|
||||
|
||||
--- Default ignore patterns
|
||||
local DEFAULT_IGNORES = {
|
||||
"^%.", -- Hidden files/folders
|
||||
"^node_modules$",
|
||||
"^__pycache__$",
|
||||
"^%.git$",
|
||||
"^%.coder$",
|
||||
"^dist$",
|
||||
"^build$",
|
||||
"^target$",
|
||||
"^vendor$",
|
||||
"^%.next$",
|
||||
"^%.nuxt$",
|
||||
"^coverage$",
|
||||
"%.min%.js$",
|
||||
"%.min%.css$",
|
||||
"%.map$",
|
||||
"%.lock$",
|
||||
"%-lock%.json$",
|
||||
}
|
||||
|
||||
--- Detect project type from root markers
|
||||
---@param root string Project root path
|
||||
---@return string Project type
|
||||
function M.detect_project_type(root)
|
||||
for project_type, markers in pairs(PROJECT_MARKERS) do
|
||||
for _, marker in ipairs(markers) do
|
||||
local path = root .. "/" .. marker
|
||||
if marker:match("^%*") then
|
||||
-- Glob pattern
|
||||
local pattern = marker:gsub("^%*", "")
|
||||
local entries = vim.fn.glob(root .. "/*" .. pattern, false, true)
|
||||
if #entries > 0 then
|
||||
return project_type
|
||||
end
|
||||
else
|
||||
if utils.file_exists(path) then
|
||||
return project_type
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
return "unknown"
|
||||
end
|
||||
|
||||
--- Parse project dependencies
|
||||
---@param root string Project root path
|
||||
---@param project_type string Project type
|
||||
---@return {dependencies: table<string, string>, dev_dependencies: table<string, string>}
|
||||
function M.parse_dependencies(root, project_type)
|
||||
local deps = {
|
||||
dependencies = {},
|
||||
dev_dependencies = {},
|
||||
}
|
||||
|
||||
if project_type == "node" then
|
||||
deps = M.parse_package_json(root)
|
||||
elseif project_type == "rust" then
|
||||
deps = M.parse_cargo_toml(root)
|
||||
elseif project_type == "go" then
|
||||
deps = M.parse_go_mod(root)
|
||||
elseif project_type == "python" then
|
||||
deps = M.parse_python_deps(root)
|
||||
end
|
||||
|
||||
return deps
|
||||
end
|
||||
|
||||
--- Parse package.json for Node.js projects
|
||||
---@param root string Project root path
|
||||
---@return {dependencies: table, dev_dependencies: table}
|
||||
function M.parse_package_json(root)
|
||||
local path = root .. "/package.json"
|
||||
local content = utils.read_file(path)
|
||||
if not content then
|
||||
return { dependencies = {}, dev_dependencies = {} }
|
||||
end
|
||||
|
||||
local ok, pkg = pcall(vim.json.decode, content)
|
||||
if not ok or not pkg then
|
||||
return { dependencies = {}, dev_dependencies = {} }
|
||||
end
|
||||
|
||||
return {
|
||||
dependencies = pkg.dependencies or {},
|
||||
dev_dependencies = pkg.devDependencies or {},
|
||||
}
|
||||
end
|
||||
|
||||
--- Parse Cargo.toml for Rust projects
|
||||
---@param root string Project root path
|
||||
---@return {dependencies: table, dev_dependencies: table}
|
||||
function M.parse_cargo_toml(root)
|
||||
local path = root .. "/Cargo.toml"
|
||||
local content = utils.read_file(path)
|
||||
if not content then
|
||||
return { dependencies = {}, dev_dependencies = {} }
|
||||
end
|
||||
|
||||
local deps = {}
|
||||
local dev_deps = {}
|
||||
local in_deps = false
|
||||
local in_dev_deps = false
|
||||
|
||||
for line in content:gmatch("[^\n]+") do
|
||||
if line:match("^%[dependencies%]") then
|
||||
in_deps = true
|
||||
in_dev_deps = false
|
||||
elseif line:match("^%[dev%-dependencies%]") then
|
||||
in_deps = false
|
||||
in_dev_deps = true
|
||||
elseif line:match("^%[") then
|
||||
in_deps = false
|
||||
in_dev_deps = false
|
||||
elseif in_deps or in_dev_deps then
|
||||
local name, version = line:match('^([%w_%-]+)%s*=%s*"([^"]+)"')
|
||||
if not name then
|
||||
name = line:match("^([%w_%-]+)%s*=")
|
||||
version = "workspace"
|
||||
end
|
||||
if name then
|
||||
if in_deps then
|
||||
deps[name] = version or "unknown"
|
||||
else
|
||||
dev_deps[name] = version or "unknown"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return { dependencies = deps, dev_dependencies = dev_deps }
|
||||
end
|
||||
|
||||
--- Parse go.mod for Go projects
|
||||
---@param root string Project root path
|
||||
---@return {dependencies: table, dev_dependencies: table}
|
||||
function M.parse_go_mod(root)
|
||||
local path = root .. "/go.mod"
|
||||
local content = utils.read_file(path)
|
||||
if not content then
|
||||
return { dependencies = {}, dev_dependencies = {} }
|
||||
end
|
||||
|
||||
local deps = {}
|
||||
local in_require = false
|
||||
|
||||
for line in content:gmatch("[^\n]+") do
|
||||
if line:match("^require%s*%(") then
|
||||
in_require = true
|
||||
elseif line:match("^%)") then
|
||||
in_require = false
|
||||
elseif in_require then
|
||||
local module, version = line:match("^%s*([%w%.%-%_/]+)%s+([%w%.%-]+)")
|
||||
if module then
|
||||
deps[module] = version
|
||||
end
|
||||
else
|
||||
local module, version = line:match("^require%s+([%w%.%-%_/]+)%s+([%w%.%-]+)")
|
||||
if module then
|
||||
deps[module] = version
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return { dependencies = deps, dev_dependencies = {} }
|
||||
end
|
||||
|
||||
--- Parse Python dependencies (pyproject.toml or requirements.txt)
|
||||
---@param root string Project root path
|
||||
---@return {dependencies: table, dev_dependencies: table}
|
||||
function M.parse_python_deps(root)
|
||||
local deps = {}
|
||||
local dev_deps = {}
|
||||
|
||||
-- Try pyproject.toml first
|
||||
local pyproject = root .. "/pyproject.toml"
|
||||
local content = utils.read_file(pyproject)
|
||||
|
||||
if content then
|
||||
-- Simple parsing for dependencies
|
||||
local in_deps = false
|
||||
local in_dev = false
|
||||
|
||||
for line in content:gmatch("[^\n]+") do
|
||||
if line:match("^%[project%.dependencies%]") or line:match("^dependencies%s*=") then
|
||||
in_deps = true
|
||||
in_dev = false
|
||||
elseif line:match("dev") and line:match("dependencies") then
|
||||
in_deps = false
|
||||
in_dev = true
|
||||
elseif line:match("^%[") then
|
||||
in_deps = false
|
||||
in_dev = false
|
||||
elseif in_deps or in_dev then
|
||||
local name = line:match('"([%w_%-]+)')
|
||||
if name then
|
||||
if in_deps then
|
||||
deps[name] = "latest"
|
||||
else
|
||||
dev_deps[name] = "latest"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Fallback to requirements.txt
|
||||
local req_file = root .. "/requirements.txt"
|
||||
content = utils.read_file(req_file)
|
||||
|
||||
if content then
|
||||
for line in content:gmatch("[^\n]+") do
|
||||
if not line:match("^#") and not line:match("^%s*$") then
|
||||
local name, version = line:match("^([%w_%-]+)==([%d%.]+)")
|
||||
if not name then
|
||||
name = line:match("^([%w_%-]+)")
|
||||
version = "latest"
|
||||
end
|
||||
if name then
|
||||
deps[name] = version or "latest"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return { dependencies = deps, dev_dependencies = dev_deps }
|
||||
end
|
||||
|
||||
--- Check if a file/directory should be ignored
|
||||
---@param name string File or directory name
|
||||
---@param config table Indexer configuration
|
||||
---@return boolean
|
||||
function M.should_ignore(name, config)
|
||||
-- Check default patterns
|
||||
for _, pattern in ipairs(DEFAULT_IGNORES) do
|
||||
if name:match(pattern) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
-- Check config excluded dirs
|
||||
if config and config.excluded_dirs then
|
||||
for _, dir in ipairs(config.excluded_dirs) do
|
||||
if name == dir then
|
||||
return true
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Check if a file should be indexed
|
||||
---@param filepath string Full file path
|
||||
---@param config table Indexer configuration
|
||||
---@return boolean
|
||||
function M.should_index(filepath, config)
|
||||
local name = vim.fn.fnamemodify(filepath, ":t")
|
||||
local ext = vim.fn.fnamemodify(filepath, ":e")
|
||||
|
||||
-- Check if it's a coder file
|
||||
if utils.is_coder_file(filepath) then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Check file size
|
||||
if config and config.max_file_size then
|
||||
local stat = vim.loop.fs_stat(filepath)
|
||||
if stat and stat.size > config.max_file_size then
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
-- Check extension
|
||||
if config and config.index_extensions then
|
||||
local valid_ext = false
|
||||
for _, allowed_ext in ipairs(config.index_extensions) do
|
||||
if ext == allowed_ext then
|
||||
valid_ext = true
|
||||
break
|
||||
end
|
||||
end
|
||||
if not valid_ext then
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
-- Check ignore patterns
|
||||
if M.should_ignore(name, config) then
|
||||
return false
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Get all indexable files in the project
|
||||
---@param root string Project root path
|
||||
---@param config table Indexer configuration
|
||||
---@return string[] List of file paths
|
||||
function M.get_indexable_files(root, config)
|
||||
local files = {}
|
||||
|
||||
local function scan_dir(path)
|
||||
local handle = vim.loop.fs_scandir(path)
|
||||
if not handle then
|
||||
return
|
||||
end
|
||||
|
||||
while true do
|
||||
local name, type = vim.loop.fs_scandir_next(handle)
|
||||
if not name then
|
||||
break
|
||||
end
|
||||
|
||||
local full_path = path .. "/" .. name
|
||||
|
||||
if M.should_ignore(name, config) then
|
||||
goto continue
|
||||
end
|
||||
|
||||
if type == "directory" then
|
||||
scan_dir(full_path)
|
||||
elseif type == "file" then
|
||||
if M.should_index(full_path, config) then
|
||||
table.insert(files, full_path)
|
||||
end
|
||||
end
|
||||
|
||||
::continue::
|
||||
end
|
||||
end
|
||||
|
||||
scan_dir(root)
|
||||
return files
|
||||
end
|
||||
|
||||
--- Get language from file extension
|
||||
---@param filepath string File path
|
||||
---@return string Language name
|
||||
function M.get_language(filepath)
|
||||
local ext = vim.fn.fnamemodify(filepath, ":e")
|
||||
return EXTENSION_LANGUAGE[ext] or ext
|
||||
end
|
||||
|
||||
--- Read .gitignore patterns
|
||||
---@param root string Project root
|
||||
---@return string[] Patterns
|
||||
function M.read_gitignore(root)
|
||||
local patterns = {}
|
||||
local path = root .. "/.gitignore"
|
||||
local content = utils.read_file(path)
|
||||
|
||||
if not content then
|
||||
return patterns
|
||||
end
|
||||
|
||||
for line in content:gmatch("[^\n]+") do
|
||||
-- Skip comments and empty lines
|
||||
if not line:match("^#") and not line:match("^%s*$") then
|
||||
-- Convert gitignore pattern to Lua pattern (simplified)
|
||||
local pattern = line:gsub("^/", "^"):gsub("%*%*", ".*"):gsub("%*", "[^/]*"):gsub("%?", ".")
|
||||
table.insert(patterns, pattern)
|
||||
end
|
||||
end
|
||||
|
||||
return patterns
|
||||
end
|
||||
|
||||
return M
|
||||
Reference in New Issue
Block a user