Deleting unnecesary features
This commit is contained in:
@@ -1,268 +0,0 @@
|
||||
---@mod codetyper.agent.context_builder Context builder for agent prompts
|
||||
---
|
||||
--- Builds rich context including project structure, memories, and conventions
|
||||
--- to help the LLM understand the codebase.
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
local params = require("codetyper.params.agents.context")
|
||||
|
||||
--- Get project structure as a tree string
|
||||
---@param max_depth? number Maximum depth to traverse (default: 3)
|
||||
---@param max_files? number Maximum files to show (default: 50)
|
||||
---@return string Project tree
|
||||
function M.get_project_structure(max_depth, max_files)
|
||||
max_depth = max_depth or 3
|
||||
max_files = max_files or 50
|
||||
|
||||
local root = utils.get_project_root() or vim.fn.getcwd()
|
||||
local lines = { "PROJECT STRUCTURE:", root, "" }
|
||||
local file_count = 0
|
||||
|
||||
-- Common ignore patterns
|
||||
local ignore_patterns = params.ignore_patterns
|
||||
|
||||
local function should_ignore(name)
|
||||
for _, pattern in ipairs(ignore_patterns) do
|
||||
if name:match(pattern) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
local function traverse(path, depth, prefix)
|
||||
if depth > max_depth or file_count >= max_files then
|
||||
return
|
||||
end
|
||||
|
||||
local entries = {}
|
||||
local handle = vim.loop.fs_scandir(path)
|
||||
if not handle then
|
||||
return
|
||||
end
|
||||
|
||||
while true do
|
||||
local name, type = vim.loop.fs_scandir_next(handle)
|
||||
if not name then
|
||||
break
|
||||
end
|
||||
if not should_ignore(name) then
|
||||
table.insert(entries, { name = name, type = type })
|
||||
end
|
||||
end
|
||||
|
||||
-- Sort: directories first, then alphabetically
|
||||
table.sort(entries, function(a, b)
|
||||
if a.type == "directory" and b.type ~= "directory" then
|
||||
return true
|
||||
elseif a.type ~= "directory" and b.type == "directory" then
|
||||
return false
|
||||
else
|
||||
return a.name < b.name
|
||||
end
|
||||
end)
|
||||
|
||||
for i, entry in ipairs(entries) do
|
||||
if file_count >= max_files then
|
||||
table.insert(lines, prefix .. "... (truncated)")
|
||||
return
|
||||
end
|
||||
|
||||
local is_last = (i == #entries)
|
||||
local branch = is_last and "└── " or "├── "
|
||||
local new_prefix = prefix .. (is_last and " " or "│ ")
|
||||
|
||||
local icon = entry.type == "directory" and "/" or ""
|
||||
table.insert(lines, prefix .. branch .. entry.name .. icon)
|
||||
file_count = file_count + 1
|
||||
|
||||
if entry.type == "directory" then
|
||||
traverse(path .. "/" .. entry.name, depth + 1, new_prefix)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
traverse(root, 1, "")
|
||||
|
||||
if file_count >= max_files then
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, "(Structure truncated at " .. max_files .. " entries)")
|
||||
end
|
||||
|
||||
return table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
--- Get key files that are important for understanding the project
|
||||
---@return table<string, string> Map of filename to description
|
||||
function M.get_key_files()
|
||||
local root = utils.get_project_root() or vim.fn.getcwd()
|
||||
local key_files = {}
|
||||
|
||||
local important_files = {
|
||||
["package.json"] = "Node.js project config",
|
||||
["Cargo.toml"] = "Rust project config",
|
||||
["go.mod"] = "Go module config",
|
||||
["pyproject.toml"] = "Python project config",
|
||||
["setup.py"] = "Python setup config",
|
||||
["Makefile"] = "Build configuration",
|
||||
["CMakeLists.txt"] = "CMake config",
|
||||
[".gitignore"] = "Git ignore patterns",
|
||||
["README.md"] = "Project documentation",
|
||||
["init.lua"] = "Neovim plugin entry",
|
||||
["plugin.lua"] = "Neovim plugin config",
|
||||
}
|
||||
|
||||
for filename, desc in paparams.important_filesnd
|
||||
|
||||
return key_files
|
||||
end
|
||||
|
||||
--- Detect project type and language
|
||||
---@return table { type: string, language: string, framework?: string }
|
||||
function M.detect_project_type()
|
||||
local root = utils.get_project_root() or vim.fn.getcwd()
|
||||
|
||||
local indicators = {
|
||||
["package.json"] = { type = "node", language = "javascript/typescript" },
|
||||
["Cargo.toml"] = { type = "rust", language = "rust" },
|
||||
["go.mod"] = { type = "go", language = "go" },
|
||||
["pyproject.toml"] = { type = "python", language = "python" },
|
||||
["setup.py"] = { type = "python", language = "python" },
|
||||
["Gemfile"] = { type = "ruby", language = "ruby" },
|
||||
["pom.xml"] = { type = "maven", language = "java" },
|
||||
["build.gradle"] = { type = "gradle", language = "java/kotlin" },
|
||||
}
|
||||
|
||||
-- Check for Neovim plugin specifically
|
||||
if vim.fn.isdirectoparams.indicators return info
|
||||
end
|
||||
end
|
||||
|
||||
return { type = "unknown", language = "unknown" }
|
||||
end
|
||||
|
||||
--- Get memories/patterns from the brain system
|
||||
---@return string Formatted memories context
|
||||
function M.get_memories_context()
|
||||
local ok_memory, memory = pcall(require, "codetyper.indexer.memory")
|
||||
if not ok_memory then
|
||||
return ""
|
||||
end
|
||||
|
||||
local all = memory.get_all()
|
||||
if not all then
|
||||
return ""
|
||||
end
|
||||
|
||||
local lines = {}
|
||||
|
||||
-- Add patterns
|
||||
if all.patterns and next(all.patterns) then
|
||||
table.insert(lines, "LEARNED PATTERNS:")
|
||||
local count = 0
|
||||
for _, mem in pairs(all.patterns) do
|
||||
if count >= 5 then
|
||||
break
|
||||
end
|
||||
if mem.content then
|
||||
table.insert(lines, " - " .. mem.content:sub(1, 100))
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
table.insert(lines, "")
|
||||
end
|
||||
|
||||
-- Add conventions
|
||||
if all.conventions and next(all.conventions) then
|
||||
table.insert(lines, "CODING CONVENTIONS:")
|
||||
local count = 0
|
||||
for _, mem in pairs(all.conventions) do
|
||||
if count >= 5 then
|
||||
break
|
||||
end
|
||||
if mem.content then
|
||||
table.insert(lines, " - " .. mem.content:sub(1, 100))
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
table.insert(lines, "")
|
||||
end
|
||||
|
||||
return table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
--- Build the full context for agent prompts
|
||||
---@return string Full context string
|
||||
function M.build_full_context()
|
||||
local sections = {}
|
||||
|
||||
-- Project info
|
||||
local project_type = M.detect_project_type()
|
||||
table.insert(sections, string.format(
|
||||
"PROJECT INFO:\n Type: %s\n Language: %s%s\n",
|
||||
project_type.type,
|
||||
project_type.language,
|
||||
project_type.framework and ("\n Framework: " .. project_type.framework) or ""
|
||||
))
|
||||
|
||||
-- Project structure
|
||||
local structure = M.get_project_structure(3, 40)
|
||||
table.insert(sections, structure)
|
||||
|
||||
-- Key files
|
||||
local key_files = M.get_key_files()
|
||||
if next(key_files) then
|
||||
local key_lines = { "", "KEY FILES:" }
|
||||
for name, info in pairs(key_files) do
|
||||
table.insert(key_lines, string.format(" %s - %s", name, info.description))
|
||||
end
|
||||
table.insert(sections, table.concat(key_lines, "\n"))
|
||||
end
|
||||
|
||||
-- Memories
|
||||
local memories = M.get_memories_context()
|
||||
if memories ~= "" then
|
||||
table.insert(sections, "\n" .. memories)
|
||||
end
|
||||
|
||||
return table.concat(sections, "\n")
|
||||
end
|
||||
|
||||
--- Get a compact context summary for token efficiency
|
||||
---@return string Compact context
|
||||
function M.build_compact_context()
|
||||
local root = utils.get_project_root() or vim.fn.getcwd()
|
||||
local project_type = M.detect_project_type()
|
||||
|
||||
local lines = {
|
||||
"CONTEXT:",
|
||||
" Root: " .. root,
|
||||
" Type: " .. project_type.type .. " (" .. project_type.language .. ")",
|
||||
}
|
||||
|
||||
-- Add main directories
|
||||
local main_dirs = {}
|
||||
local handle = vim.loop.fs_scandir(root)
|
||||
if handle then
|
||||
while true do
|
||||
local name, type = vim.loop.fs_scandir_next(handle)
|
||||
if not name then
|
||||
break
|
||||
end
|
||||
if type == "directory" and not name:match("^%.") and not name:match("node_modules") then
|
||||
table.insert(main_dirs, name .. "/")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if #main_dirs > 0 then
|
||||
table.sort(main_dirs)
|
||||
table.insert(lines, " Main dirs: " .. table.concat(main_dirs, ", "))
|
||||
end
|
||||
|
||||
return table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,754 +0,0 @@
|
||||
---@mod codetyper.agent.agentic Agentic loop with proper tool calling
|
||||
---@brief [[
|
||||
--- Full agentic system that handles multi-file changes via tool calling.
|
||||
--- Multi-file agent system with tool orchestration.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class AgenticMessage
|
||||
---@field role "system"|"user"|"assistant"|"tool"
|
||||
---@field content string|table
|
||||
---@field tool_calls? table[] For assistant messages with tool calls
|
||||
---@field tool_call_id? string For tool result messages
|
||||
---@field name? string Tool name for tool results
|
||||
|
||||
---@class AgenticToolCall
|
||||
---@field id string Unique tool call ID
|
||||
---@field type "function"
|
||||
---@field function {name: string, arguments: string|table}
|
||||
|
||||
---@class AgenticOpts
|
||||
---@field task string The task to accomplish
|
||||
---@field files? string[] Initial files to include as context
|
||||
---@field agent? string Agent name to use (default: "coder")
|
||||
---@field model? string Model override
|
||||
---@field max_iterations? number Max tool call rounds (default: 20)
|
||||
---@field on_message? fun(msg: AgenticMessage) Called for each message
|
||||
---@field on_tool_start? fun(name: string, args: table) Called before tool execution
|
||||
---@field on_tool_end? fun(name: string, result: any, error: string|nil) Called after tool execution
|
||||
---@field on_file_change? fun(path: string, action: string) Called when file is modified
|
||||
---@field on_complete? fun(result: string|nil, error: string|nil) Called when done
|
||||
---@field on_status? fun(status: string) Status updates
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Load agent definition
|
||||
---@param name string Agent name
|
||||
---@return table|nil agent definition
|
||||
local function load_agent(name)
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
local agent_file = agents_dir .. "/" .. name .. ".md"
|
||||
|
||||
-- Check if custom agent exists
|
||||
if vim.fn.filereadable(agent_file) == 1 then
|
||||
local content = table.concat(vim.fn.readfile(agent_file), "\n")
|
||||
-- Parse frontmatter and content
|
||||
local frontmatter = {}
|
||||
local body = content
|
||||
|
||||
local fm_match = content:match("^%-%-%-\n(.-)%-%-%-\n(.*)$")
|
||||
if fm_match then
|
||||
-- Parse YAML-like frontmatter
|
||||
for line in content:match("^%-%-%-\n(.-)%-%-%-"):gmatch("[^\n]+") do
|
||||
local key, value = line:match("^(%w+):%s*(.+)$")
|
||||
if key and value then
|
||||
frontmatter[key] = value
|
||||
end
|
||||
end
|
||||
body = content:match("%-%-%-\n.-%-%-%-%s*\n(.*)$") or content
|
||||
end
|
||||
|
||||
return {
|
||||
name = name,
|
||||
description = frontmatter.description or "Custom agent: " .. name,
|
||||
system_prompt = body,
|
||||
tools = frontmatter.tools and vim.split(frontmatter.tools, ",") or nil,
|
||||
model = frontmatter.model,
|
||||
}
|
||||
end
|
||||
|
||||
-- Built-in agents
|
||||
local builtin_agents = require("codetyper.prompts.agents.personas").builtin
|
||||
|
||||
return builtin_agents[name]
|
||||
end
|
||||
|
||||
--- Load rules from .coder/rules/
|
||||
---@return string Combined rules content
|
||||
local function load_rules()
|
||||
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
|
||||
local rules = {}
|
||||
|
||||
if vim.fn.isdirectory(rules_dir) == 1 then
|
||||
local files = vim.fn.glob(rules_dir .. "/*.md", false, true)
|
||||
for _, file in ipairs(files) do
|
||||
local content = table.concat(vim.fn.readfile(file), "\n")
|
||||
local filename = vim.fn.fnamemodify(file, ":t:r")
|
||||
table.insert(rules, string.format("## Rule: %s\n%s", filename, content))
|
||||
end
|
||||
end
|
||||
|
||||
if #rules > 0 then
|
||||
return "\n\n# Project Rules\n" .. table.concat(rules, "\n\n")
|
||||
end
|
||||
return ""
|
||||
end
|
||||
|
||||
--- Build messages array for API request
|
||||
---@param history AgenticMessage[]
|
||||
---@param provider string "openai"|"claude"
|
||||
---@return table[] Formatted messages
|
||||
local function build_messages(history, provider)
|
||||
local messages = {}
|
||||
|
||||
for _, msg in ipairs(history) do
|
||||
if msg.role == "system" then
|
||||
if provider == "claude" then
|
||||
-- Claude uses system parameter, not message
|
||||
-- Skip system messages in array
|
||||
else
|
||||
table.insert(messages, {
|
||||
role = "system",
|
||||
content = msg.content,
|
||||
})
|
||||
end
|
||||
elseif msg.role == "user" then
|
||||
table.insert(messages, {
|
||||
role = "user",
|
||||
content = msg.content,
|
||||
})
|
||||
elseif msg.role == "assistant" then
|
||||
local message = {
|
||||
role = "assistant",
|
||||
content = msg.content,
|
||||
}
|
||||
if msg.tool_calls then
|
||||
message.tool_calls = msg.tool_calls
|
||||
if provider == "claude" then
|
||||
-- Claude format: content is array of blocks
|
||||
message.content = {}
|
||||
if msg.content and msg.content ~= "" then
|
||||
table.insert(message.content, {
|
||||
type = "text",
|
||||
text = msg.content,
|
||||
})
|
||||
end
|
||||
for _, tc in ipairs(msg.tool_calls) do
|
||||
table.insert(message.content, {
|
||||
type = "tool_use",
|
||||
id = tc.id,
|
||||
name = tc["function"].name,
|
||||
input = type(tc["function"].arguments) == "string"
|
||||
and vim.json.decode(tc["function"].arguments)
|
||||
or tc["function"].arguments,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
table.insert(messages, message)
|
||||
elseif msg.role == "tool" then
|
||||
if provider == "claude" then
|
||||
table.insert(messages, {
|
||||
role = "user",
|
||||
content = {
|
||||
{
|
||||
type = "tool_result",
|
||||
tool_use_id = msg.tool_call_id,
|
||||
content = msg.content,
|
||||
},
|
||||
},
|
||||
})
|
||||
else
|
||||
table.insert(messages, {
|
||||
role = "tool",
|
||||
tool_call_id = msg.tool_call_id,
|
||||
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return messages
|
||||
end
|
||||
|
||||
--- Build tools array for API request
|
||||
---@param tool_names string[] Tool names to include
|
||||
---@param provider string "openai"|"claude"
|
||||
---@return table[] Formatted tools
|
||||
local function build_tools(tool_names, provider)
|
||||
local tools_mod = require("codetyper.core.tools")
|
||||
local tools = {}
|
||||
|
||||
for _, name in ipairs(tool_names) do
|
||||
local tool = tools_mod.get(name)
|
||||
if tool then
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(tool.params or {}) do
|
||||
properties[param.name] = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
local description = type(tool.description) == "function" and tool.description() or tool.description
|
||||
|
||||
if provider == "claude" then
|
||||
table.insert(tools, {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
input_schema = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
})
|
||||
else
|
||||
table.insert(tools, {
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return tools
|
||||
end
|
||||
|
||||
--- Execute a tool call
|
||||
---@param tool_call AgenticToolCall
|
||||
---@param opts AgenticOpts
|
||||
---@return string result
|
||||
---@return string|nil error
|
||||
local function execute_tool(tool_call, opts)
|
||||
local tools_mod = require("codetyper.core.tools")
|
||||
local name = tool_call["function"].name
|
||||
local args = tool_call["function"].arguments
|
||||
|
||||
-- Parse arguments if string
|
||||
if type(args) == "string" then
|
||||
local ok, parsed = pcall(vim.json.decode, args)
|
||||
if ok then
|
||||
args = parsed
|
||||
else
|
||||
return "", "Failed to parse tool arguments: " .. args
|
||||
end
|
||||
end
|
||||
|
||||
-- Notify tool start
|
||||
if opts.on_tool_start then
|
||||
opts.on_tool_start(name, args)
|
||||
end
|
||||
|
||||
if opts.on_status then
|
||||
opts.on_status("Executing: " .. name)
|
||||
end
|
||||
|
||||
-- Execute the tool
|
||||
local tool = tools_mod.get(name)
|
||||
if not tool then
|
||||
local err = "Unknown tool: " .. name
|
||||
if opts.on_tool_end then
|
||||
opts.on_tool_end(name, nil, err)
|
||||
end
|
||||
return "", err
|
||||
end
|
||||
|
||||
local result, err = tool.func(args, {
|
||||
on_log = function(msg)
|
||||
if opts.on_status then
|
||||
opts.on_status(msg)
|
||||
end
|
||||
end,
|
||||
})
|
||||
|
||||
-- Notify tool end
|
||||
if opts.on_tool_end then
|
||||
opts.on_tool_end(name, result, err)
|
||||
end
|
||||
|
||||
-- Track file changes
|
||||
if opts.on_file_change and (name == "write" or name == "edit") and not err then
|
||||
opts.on_file_change(args.path, name == "write" and "created" or "modified")
|
||||
end
|
||||
|
||||
if err then
|
||||
return "", err
|
||||
end
|
||||
|
||||
return type(result) == "string" and result or vim.json.encode(result), nil
|
||||
end
|
||||
|
||||
--- Parse tool calls from LLM response (unified Claude-like format)
|
||||
---@param response table Raw API response in unified format
|
||||
---@param provider string Provider name (unused, kept for signature compatibility)
|
||||
---@return AgenticToolCall[]
|
||||
local function parse_tool_calls(response, provider)
|
||||
local tool_calls = {}
|
||||
|
||||
-- Unified format: content array with tool_use blocks
|
||||
local content = response.content or {}
|
||||
for _, block in ipairs(content) do
|
||||
if block.type == "tool_use" then
|
||||
-- OpenAI expects arguments as JSON string, not table
|
||||
local args = block.input
|
||||
if type(args) == "table" then
|
||||
args = vim.json.encode(args)
|
||||
end
|
||||
|
||||
table.insert(tool_calls, {
|
||||
id = block.id or utils.generate_id("call"),
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = block.name,
|
||||
arguments = args,
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return tool_calls
|
||||
end
|
||||
|
||||
--- Extract text content from response (unified Claude-like format)
|
||||
---@param response table Raw API response in unified format
|
||||
---@param provider string Provider name (unused, kept for signature compatibility)
|
||||
---@return string
|
||||
local function extract_content(response, provider)
|
||||
local parts = {}
|
||||
for _, block in ipairs(response.content or {}) do
|
||||
if block.type == "text" then
|
||||
table.insert(parts, block.text)
|
||||
end
|
||||
end
|
||||
return table.concat(parts, "\n")
|
||||
end
|
||||
|
||||
--- Check if response indicates completion (unified Claude-like format)
|
||||
---@param response table Raw API response in unified format
|
||||
---@param provider string Provider name (unused, kept for signature compatibility)
|
||||
---@return boolean
|
||||
local function is_complete(response, provider)
|
||||
return response.stop_reason == "end_turn"
|
||||
end
|
||||
|
||||
--- Make API request to LLM with native tool calling support
|
||||
---@param messages table[] Formatted messages
|
||||
---@param tools table[] Formatted tools
|
||||
---@param system_prompt string System prompt
|
||||
---@param provider string "openai"|"claude"|"copilot"
|
||||
---@param model string Model name
|
||||
---@param callback fun(response: table|nil, error: string|nil)
|
||||
local function call_llm(messages, tools, system_prompt, provider, model, callback)
|
||||
local context = {
|
||||
language = "lua",
|
||||
file_content = "",
|
||||
prompt_type = "agent",
|
||||
project_root = vim.fn.getcwd(),
|
||||
cwd = vim.fn.getcwd(),
|
||||
}
|
||||
|
||||
-- Use native tool calling APIs
|
||||
if provider == "copilot" then
|
||||
local client = require("codetyper.core.llm.copilot")
|
||||
|
||||
-- Copilot's generate_with_tools expects messages in a specific format
|
||||
-- Convert to the format it expects
|
||||
local converted_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(converted_messages, msg)
|
||||
end
|
||||
end
|
||||
|
||||
client.generate_with_tools(converted_messages, context, tools, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Response is already in Claude-like format from the provider
|
||||
-- Convert to our internal format
|
||||
local result = {
|
||||
content = {},
|
||||
stop_reason = "end_turn",
|
||||
}
|
||||
|
||||
if response and response.content then
|
||||
for _, block in ipairs(response.content) do
|
||||
if block.type == "text" then
|
||||
table.insert(result.content, { type = "text", text = block.text })
|
||||
elseif block.type == "tool_use" then
|
||||
table.insert(result.content, {
|
||||
type = "tool_use",
|
||||
id = block.id or utils.generate_id("call"),
|
||||
name = block.name,
|
||||
input = block.input,
|
||||
})
|
||||
result.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
callback(result, nil)
|
||||
end)
|
||||
elseif provider == "openai" then
|
||||
local client = require("codetyper.core.llm.openai")
|
||||
|
||||
-- OpenAI's generate_with_tools
|
||||
local converted_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(converted_messages, msg)
|
||||
end
|
||||
end
|
||||
|
||||
client.generate_with_tools(converted_messages, context, tools, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Response is already in Claude-like format from the provider
|
||||
local result = {
|
||||
content = {},
|
||||
stop_reason = "end_turn",
|
||||
}
|
||||
|
||||
if response and response.content then
|
||||
for _, block in ipairs(response.content) do
|
||||
if block.type == "text" then
|
||||
table.insert(result.content, { type = "text", text = block.text })
|
||||
elseif block.type == "tool_use" then
|
||||
table.insert(result.content, {
|
||||
type = "tool_use",
|
||||
id = block.id or utils.generate_id("call"),
|
||||
name = block.name,
|
||||
input = block.input,
|
||||
})
|
||||
result.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
callback(result, nil)
|
||||
end)
|
||||
elseif provider == "ollama" then
|
||||
local client = require("codetyper.core.llm.ollama")
|
||||
|
||||
-- Ollama's generate_with_tools (text-based tool calling)
|
||||
local converted_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(converted_messages, msg)
|
||||
end
|
||||
end
|
||||
|
||||
client.generate_with_tools(converted_messages, context, tools, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Response is already in Claude-like format from the provider
|
||||
callback(response, nil)
|
||||
end)
|
||||
else
|
||||
-- Fallback for other providers (ollama, etc.) - use text-based parsing
|
||||
local client = require("codetyper.core.llm." .. provider)
|
||||
|
||||
-- Build prompt from messages
|
||||
local prompts = require("codetyper.prompts.agents")
|
||||
local prompt_parts = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role == "user" then
|
||||
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
|
||||
table.insert(prompt_parts, prompts.text_user_prefix .. content)
|
||||
elseif msg.role == "assistant" then
|
||||
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
|
||||
table.insert(prompt_parts, prompts.text_assistant_prefix .. content)
|
||||
end
|
||||
end
|
||||
|
||||
-- Add tool descriptions to prompt for text-based providers
|
||||
local tool_desc = require("codetyper.prompts.agents").tool_instructions_text
|
||||
for _, tool in ipairs(tools) do
|
||||
local name = tool.name or (tool["function"] and tool["function"].name)
|
||||
local desc = tool.description or (tool["function"] and tool["function"].description)
|
||||
if name then
|
||||
tool_desc = tool_desc .. string.format("- **%s**: %s\n", name, desc or "")
|
||||
end
|
||||
end
|
||||
|
||||
context.file_content = system_prompt .. tool_desc
|
||||
|
||||
client.generate(table.concat(prompt_parts, "\n\n"), context, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Parse response for tool calls (text-based fallback)
|
||||
local result = {
|
||||
content = {},
|
||||
stop_reason = "end_turn",
|
||||
}
|
||||
|
||||
-- Extract text content
|
||||
local text_content = response
|
||||
|
||||
-- Try to extract JSON tool calls from response
|
||||
local json_match = response:match("```json%s*(%b{})%s*```")
|
||||
if json_match then
|
||||
local ok, parsed = pcall(vim.json.decode, json_match)
|
||||
if ok and parsed.tool then
|
||||
table.insert(result.content, {
|
||||
type = "tool_use",
|
||||
id = utils.generate_id("call"),
|
||||
name = parsed.tool,
|
||||
input = parsed.arguments or {},
|
||||
})
|
||||
text_content = response:gsub("```json.-```", ""):gsub("^%s+", ""):gsub("%s+$", "")
|
||||
result.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
|
||||
if text_content and text_content ~= "" then
|
||||
table.insert(result.content, 1, { type = "text", text = text_content })
|
||||
end
|
||||
|
||||
callback(result, nil)
|
||||
end)
|
||||
end
|
||||
end
|
||||
|
||||
--- Run the agentic loop
|
||||
---@param opts AgenticOpts
|
||||
function M.run(opts)
|
||||
-- Load agent
|
||||
local agent = load_agent(opts.agent or "coder")
|
||||
if not agent then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "Unknown agent: " .. (opts.agent or "coder"))
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Load rules
|
||||
local rules = load_rules()
|
||||
|
||||
-- Build system prompt
|
||||
local system_prompt = agent.system_prompt .. rules
|
||||
|
||||
-- Initialize message history
|
||||
---@type AgenticMessage[]
|
||||
local history = {
|
||||
{ role = "system", content = system_prompt },
|
||||
}
|
||||
|
||||
-- Add initial file context if provided
|
||||
if opts.files and #opts.files > 0 then
|
||||
local file_context = require("codetyper.prompts.agents").format_file_context(opts.files)
|
||||
table.insert(history, { role = "user", content = file_context })
|
||||
table.insert(history, { role = "assistant", content = "I've reviewed the provided files. What would you like me to do?" })
|
||||
end
|
||||
|
||||
-- Add the task
|
||||
table.insert(history, { role = "user", content = opts.task })
|
||||
|
||||
-- Determine provider
|
||||
local config = require("codetyper").get_config()
|
||||
local provider = config.llm.provider or "copilot"
|
||||
-- Note: Ollama has its own handler in call_llm, don't change it
|
||||
|
||||
-- Get tools for this agent
|
||||
local tool_names = agent.tools or { "view", "edit", "write", "grep", "glob", "bash" }
|
||||
|
||||
-- Ensure tools are loaded
|
||||
local tools_mod = require("codetyper.core.tools")
|
||||
tools_mod.setup()
|
||||
|
||||
-- Build tools for API
|
||||
local tools = build_tools(tool_names, provider)
|
||||
|
||||
-- Iteration tracking
|
||||
local iteration = 0
|
||||
local max_iterations = opts.max_iterations or 20
|
||||
|
||||
--- Process one iteration
|
||||
local function process_iteration()
|
||||
iteration = iteration + 1
|
||||
|
||||
if iteration > max_iterations then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "Max iterations reached")
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
if opts.on_status then
|
||||
opts.on_status(string.format("Thinking... (iteration %d)", iteration))
|
||||
end
|
||||
|
||||
-- Build messages for API
|
||||
local messages = build_messages(history, provider)
|
||||
|
||||
-- Call LLM
|
||||
call_llm(messages, tools, system_prompt, provider, opts.model, function(response, err)
|
||||
if err then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, err)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Extract content and tool calls
|
||||
local content = extract_content(response, provider)
|
||||
local tool_calls = parse_tool_calls(response, provider)
|
||||
|
||||
-- Add assistant message to history
|
||||
local assistant_msg = {
|
||||
role = "assistant",
|
||||
content = content,
|
||||
tool_calls = #tool_calls > 0 and tool_calls or nil,
|
||||
}
|
||||
table.insert(history, assistant_msg)
|
||||
|
||||
if opts.on_message then
|
||||
opts.on_message(assistant_msg)
|
||||
end
|
||||
|
||||
-- Process tool calls if any
|
||||
if #tool_calls > 0 then
|
||||
for _, tc in ipairs(tool_calls) do
|
||||
local result, tool_err = execute_tool(tc, opts)
|
||||
|
||||
-- Add tool result to history
|
||||
local tool_msg = {
|
||||
role = "tool",
|
||||
tool_call_id = tc.id,
|
||||
name = tc["function"].name,
|
||||
content = tool_err or result,
|
||||
}
|
||||
table.insert(history, tool_msg)
|
||||
|
||||
if opts.on_message then
|
||||
opts.on_message(tool_msg)
|
||||
end
|
||||
end
|
||||
|
||||
-- Continue the loop
|
||||
vim.schedule(process_iteration)
|
||||
else
|
||||
-- No tool calls - check if complete
|
||||
if is_complete(response, provider) or content ~= "" then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(content, nil)
|
||||
end
|
||||
else
|
||||
-- Continue if not explicitly complete
|
||||
vim.schedule(process_iteration)
|
||||
end
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
-- Start the loop
|
||||
process_iteration()
|
||||
end
|
||||
|
||||
--- Create default agent files in .coder/agents/
|
||||
function M.init_agents_dir()
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
vim.fn.mkdir(agents_dir, "p")
|
||||
|
||||
-- Create example agent
|
||||
local example_agent = require("codetyper.prompts.agents.templates").agent
|
||||
|
||||
local example_path = agents_dir .. "/example.md"
|
||||
if vim.fn.filereadable(example_path) ~= 1 then
|
||||
vim.fn.writefile(vim.split(example_agent, "\n"), example_path)
|
||||
end
|
||||
|
||||
return agents_dir
|
||||
end
|
||||
|
||||
--- Create default rules in .coder/rules/
|
||||
function M.init_rules_dir()
|
||||
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
|
||||
vim.fn.mkdir(rules_dir, "p")
|
||||
|
||||
-- Create example rule
|
||||
local example_rule = require("codetyper.prompts.agents.templates").rule
|
||||
|
||||
local example_path = rules_dir .. "/code-style.md"
|
||||
if vim.fn.filereadable(example_path) ~= 1 then
|
||||
vim.fn.writefile(vim.split(example_rule, "\n"), example_path)
|
||||
end
|
||||
|
||||
return rules_dir
|
||||
end
|
||||
|
||||
--- Initialize both agents and rules directories
|
||||
function M.init()
|
||||
M.init_agents_dir()
|
||||
M.init_rules_dir()
|
||||
end
|
||||
|
||||
--- List available agents
|
||||
---@return table[] List of {name, description, builtin}
|
||||
function M.list_agents()
|
||||
local agents = {}
|
||||
|
||||
-- Built-in agents
|
||||
local personas = require("codetyper.prompts.agents.personas").builtin
|
||||
local builtins = vim.tbl_keys(personas)
|
||||
table.sort(builtins)
|
||||
|
||||
for _, name in ipairs(builtins) do
|
||||
local agent = load_agent(name)
|
||||
if agent then
|
||||
table.insert(agents, {
|
||||
name = agent.name,
|
||||
description = agent.description,
|
||||
builtin = true,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Custom agents from .coder/agents/
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
if vim.fn.isdirectory(agents_dir) == 1 then
|
||||
local files = vim.fn.glob(agents_dir .. "/*.md", false, true)
|
||||
for _, file in ipairs(files) do
|
||||
local name = vim.fn.fnamemodify(file, ":t:r")
|
||||
if not vim.tbl_contains(builtins, name) then
|
||||
local agent = load_agent(name)
|
||||
if agent then
|
||||
table.insert(agents, {
|
||||
name = agent.name,
|
||||
description = agent.description,
|
||||
builtin = false,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return agents
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,455 +0,0 @@
|
||||
---@mod codetyper.agent Agent orchestration for Codetyper.nvim
|
||||
---
|
||||
--- Manages the agentic conversation loop with tool execution.
|
||||
|
||||
local M = {}
|
||||
|
||||
local tools = require("codetyper.core.tools")
|
||||
local executor = require("codetyper.core.scheduler.executor")
|
||||
local parser = require("codetyper.core.llm.parser")
|
||||
local diff = require("codetyper.core.diff.diff")
|
||||
local diff_review = require("codetyper.adapters.nvim.ui.diff_review")
|
||||
local resume = require("codetyper.core.scheduler.resume")
|
||||
local utils = require("codetyper.support.utils")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
|
||||
---@class AgentState
|
||||
---@field conversation table[] Message history for multi-turn
|
||||
---@field pending_tool_results table[] Results waiting to be sent back
|
||||
---@field is_running boolean Whether agent loop is active
|
||||
---@field max_iterations number Maximum tool call iterations
|
||||
|
||||
local state = {
|
||||
conversation = {},
|
||||
pending_tool_results = {},
|
||||
is_running = false,
|
||||
max_iterations = 25, -- Increased for complex tasks (env setup, tests, fixes)
|
||||
current_iteration = 0,
|
||||
original_prompt = "", -- Store for resume functionality
|
||||
current_context = nil, -- Store context for resume
|
||||
current_callbacks = nil, -- Store callbacks for continue
|
||||
}
|
||||
|
||||
---@class AgentCallbacks
|
||||
---@field on_text fun(text: string) Called when text content is received
|
||||
---@field on_tool_start fun(name: string) Called when a tool starts
|
||||
---@field on_tool_result fun(name: string, result: string) Called when a tool completes
|
||||
---@field on_complete fun() Called when agent finishes
|
||||
---@field on_error fun(err: string) Called on error
|
||||
|
||||
--- Reset agent state for new conversation
|
||||
function M.reset()
|
||||
state.conversation = {}
|
||||
state.pending_tool_results = {}
|
||||
state.is_running = false
|
||||
state.current_iteration = 0
|
||||
-- Clear collected diffs
|
||||
diff_review.clear()
|
||||
end
|
||||
|
||||
--- Check if agent is currently running
|
||||
---@return boolean
|
||||
function M.is_running()
|
||||
return state.is_running
|
||||
end
|
||||
|
||||
--- Stop the agent
|
||||
function M.stop()
|
||||
state.is_running = false
|
||||
utils.notify("Agent stopped")
|
||||
end
|
||||
|
||||
--- Main agent entry point
|
||||
---@param prompt string User's request
|
||||
---@param context table File context
|
||||
---@param callbacks AgentCallbacks Callback functions
|
||||
function M.run(prompt, context, callbacks)
|
||||
if state.is_running then
|
||||
callbacks.on_error("Agent is already running")
|
||||
return
|
||||
end
|
||||
|
||||
logs.info("Starting agent run")
|
||||
logs.debug("Prompt length: " .. #prompt .. " chars")
|
||||
|
||||
state.is_running = true
|
||||
state.current_iteration = 0
|
||||
state.original_prompt = prompt
|
||||
state.current_context = context
|
||||
state.current_callbacks = callbacks
|
||||
|
||||
-- Add user message to conversation
|
||||
table.insert(state.conversation, {
|
||||
role = "user",
|
||||
content = prompt,
|
||||
})
|
||||
|
||||
-- Start the agent loop
|
||||
M.agent_loop(context, callbacks)
|
||||
end
|
||||
|
||||
--- The core agent loop
|
||||
---@param context table File context
|
||||
---@param callbacks AgentCallbacks
|
||||
function M.agent_loop(context, callbacks)
|
||||
if not state.is_running then
|
||||
callbacks.on_complete()
|
||||
return
|
||||
end
|
||||
|
||||
state.current_iteration = state.current_iteration + 1
|
||||
logs.info(string.format("Agent loop iteration %d/%d", state.current_iteration, state.max_iterations))
|
||||
|
||||
if state.current_iteration > state.max_iterations then
|
||||
logs.info("Max iterations reached, asking user to continue or stop")
|
||||
-- Ask user if they want to continue
|
||||
M.prompt_continue(context, callbacks)
|
||||
return
|
||||
end
|
||||
|
||||
local llm = require("codetyper.core.llm")
|
||||
local client = llm.get_client()
|
||||
|
||||
-- Check if client supports tools
|
||||
if not client.generate_with_tools then
|
||||
logs.error("Provider does not support agent mode")
|
||||
callbacks.on_error("Current LLM provider does not support agent mode")
|
||||
state.is_running = false
|
||||
return
|
||||
end
|
||||
|
||||
logs.thinking("Calling LLM with " .. #state.conversation .. " messages...")
|
||||
|
||||
-- Generate with tools enabled
|
||||
-- Ensure tools are loaded and get definitions
|
||||
tools.setup()
|
||||
local tool_defs = tools.to_openai_format()
|
||||
|
||||
client.generate_with_tools(state.conversation, context, tool_defs, function(response, err)
|
||||
if err then
|
||||
state.is_running = false
|
||||
callbacks.on_error(err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Parse response based on provider
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
local parsed
|
||||
|
||||
-- Copilot uses Claude-like response format
|
||||
if config.llm.provider == "copilot" then
|
||||
parsed = parser.parse_claude_response(response)
|
||||
table.insert(state.conversation, {
|
||||
role = "assistant",
|
||||
content = parsed.text or "",
|
||||
tool_calls = parsed.tool_calls,
|
||||
_raw_content = response.content,
|
||||
})
|
||||
else
|
||||
-- For Ollama, response is the text directly
|
||||
if type(response) == "string" then
|
||||
parsed = parser.parse_ollama_response(response)
|
||||
else
|
||||
parsed = parser.parse_ollama_response(response.response or "")
|
||||
end
|
||||
-- Add assistant response to conversation
|
||||
table.insert(state.conversation, {
|
||||
role = "assistant",
|
||||
content = parsed.text,
|
||||
tool_calls = parsed.tool_calls,
|
||||
})
|
||||
end
|
||||
|
||||
-- Display any text content
|
||||
if parsed.text and parsed.text ~= "" then
|
||||
local clean_text = parser.clean_text(parsed.text)
|
||||
if clean_text ~= "" then
|
||||
callbacks.on_text(clean_text)
|
||||
end
|
||||
end
|
||||
|
||||
-- Check for tool calls
|
||||
if #parsed.tool_calls > 0 then
|
||||
logs.info(string.format("Processing %d tool call(s)", #parsed.tool_calls))
|
||||
-- Process tool calls sequentially
|
||||
M.process_tool_calls(parsed.tool_calls, 1, context, callbacks)
|
||||
else
|
||||
-- No more tool calls, agent is done
|
||||
logs.info("No tool calls, finishing agent loop")
|
||||
state.is_running = false
|
||||
callbacks.on_complete()
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Process tool calls one at a time
|
||||
---@param tool_calls table[] List of tool calls
|
||||
---@param index number Current index
|
||||
---@param context table File context
|
||||
---@param callbacks AgentCallbacks
|
||||
function M.process_tool_calls(tool_calls, index, context, callbacks)
|
||||
if not state.is_running then
|
||||
callbacks.on_complete()
|
||||
return
|
||||
end
|
||||
|
||||
if index > #tool_calls then
|
||||
-- All tools processed, continue agent loop with results
|
||||
M.continue_with_results(context, callbacks)
|
||||
return
|
||||
end
|
||||
|
||||
local tool_call = tool_calls[index]
|
||||
callbacks.on_tool_start(tool_call.name)
|
||||
|
||||
executor.execute(tool_call.name, tool_call.parameters, function(result)
|
||||
if result.requires_approval then
|
||||
logs.tool(tool_call.name, "approval", "Waiting for user approval")
|
||||
-- Show diff preview and wait for user decision
|
||||
local show_fn
|
||||
if result.diff_data.operation == "bash" then
|
||||
show_fn = function(_, cb)
|
||||
diff.show_bash_approval(result.diff_data.modified:gsub("^%$ ", ""), cb)
|
||||
end
|
||||
else
|
||||
show_fn = diff.show_diff
|
||||
end
|
||||
|
||||
show_fn(result.diff_data, function(approval_result)
|
||||
-- Handle both old (boolean) and new (table) approval result formats
|
||||
local approved = type(approval_result) == "table" and approval_result.approved or approval_result
|
||||
local permission_level = type(approval_result) == "table" and approval_result.permission_level or nil
|
||||
|
||||
if approved then
|
||||
local log_msg = "User approved"
|
||||
if permission_level == "allow_session" then
|
||||
log_msg = "Allowed for session"
|
||||
elseif permission_level == "allow_list" then
|
||||
log_msg = "Added to allow list"
|
||||
elseif permission_level == "auto" then
|
||||
log_msg = "Auto-approved"
|
||||
end
|
||||
logs.tool(tool_call.name, "approved", log_msg)
|
||||
|
||||
-- Apply the change and collect for review
|
||||
executor.apply_change(result.diff_data, function(apply_result)
|
||||
-- Collect the diff for end-of-session review
|
||||
if result.diff_data.operation ~= "bash" then
|
||||
diff_review.add({
|
||||
path = result.diff_data.path,
|
||||
operation = result.diff_data.operation,
|
||||
original = result.diff_data.original,
|
||||
modified = result.diff_data.modified,
|
||||
approved = true,
|
||||
applied = true,
|
||||
})
|
||||
end
|
||||
|
||||
-- Store result for sending back to LLM
|
||||
table.insert(state.pending_tool_results, {
|
||||
tool_use_id = tool_call.id,
|
||||
name = tool_call.name,
|
||||
result = apply_result.result,
|
||||
})
|
||||
callbacks.on_tool_result(tool_call.name, apply_result.result)
|
||||
-- Process next tool call
|
||||
M.process_tool_calls(tool_calls, index + 1, context, callbacks)
|
||||
end)
|
||||
else
|
||||
logs.tool(tool_call.name, "rejected", "User rejected")
|
||||
-- User rejected
|
||||
table.insert(state.pending_tool_results, {
|
||||
tool_use_id = tool_call.id,
|
||||
name = tool_call.name,
|
||||
result = "User rejected this change",
|
||||
})
|
||||
callbacks.on_tool_result(tool_call.name, "Rejected by user")
|
||||
M.process_tool_calls(tool_calls, index + 1, context, callbacks)
|
||||
end
|
||||
end)
|
||||
else
|
||||
-- No approval needed (read_file), store result immediately
|
||||
table.insert(state.pending_tool_results, {
|
||||
tool_use_id = tool_call.id,
|
||||
name = tool_call.name,
|
||||
result = result.result,
|
||||
})
|
||||
|
||||
-- For read_file, just show a brief confirmation
|
||||
local display_result = result.result
|
||||
if tool_call.name == "read_file" and result.success then
|
||||
display_result = "[Read " .. #result.result .. " bytes]"
|
||||
end
|
||||
callbacks.on_tool_result(tool_call.name, display_result)
|
||||
|
||||
M.process_tool_calls(tool_calls, index + 1, context, callbacks)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Continue the loop after tool execution
|
||||
---@param context table File context
|
||||
---@param callbacks AgentCallbacks
|
||||
function M.continue_with_results(context, callbacks)
|
||||
if #state.pending_tool_results == 0 then
|
||||
state.is_running = false
|
||||
callbacks.on_complete()
|
||||
return
|
||||
end
|
||||
|
||||
-- Build tool results message
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
-- Copilot uses OpenAI format for tool results (role: "tool")
|
||||
if config.llm.provider == "copilot" then
|
||||
-- OpenAI-style tool messages - each result is a separate message
|
||||
for _, result in ipairs(state.pending_tool_results) do
|
||||
table.insert(state.conversation, {
|
||||
role = "tool",
|
||||
tool_call_id = result.tool_use_id,
|
||||
content = result.result,
|
||||
})
|
||||
end
|
||||
else
|
||||
-- Ollama format: plain text describing results
|
||||
local result_text = "Tool results:\n"
|
||||
for _, result in ipairs(state.pending_tool_results) do
|
||||
result_text = result_text .. "\n[" .. result.name .. "]: " .. result.result .. "\n"
|
||||
end
|
||||
table.insert(state.conversation, {
|
||||
role = "user",
|
||||
content = result_text,
|
||||
})
|
||||
end
|
||||
|
||||
state.pending_tool_results = {}
|
||||
|
||||
-- Continue the loop
|
||||
M.agent_loop(context, callbacks)
|
||||
end
|
||||
|
||||
--- Get conversation history
|
||||
---@return table[]
|
||||
function M.get_conversation()
|
||||
return state.conversation
|
||||
end
|
||||
|
||||
--- Set max iterations
|
||||
---@param max number Maximum iterations
|
||||
function M.set_max_iterations(max)
|
||||
state.max_iterations = max
|
||||
end
|
||||
|
||||
--- Get the count of collected changes
|
||||
---@return number
|
||||
function M.get_changes_count()
|
||||
return diff_review.count()
|
||||
end
|
||||
|
||||
--- Show the diff review UI for all collected changes
|
||||
function M.show_diff_review()
|
||||
diff_review.open()
|
||||
end
|
||||
|
||||
--- Check if diff review is open
|
||||
---@return boolean
|
||||
function M.is_review_open()
|
||||
return diff_review.is_open()
|
||||
end
|
||||
|
||||
--- Prompt user to continue or stop at max iterations
|
||||
---@param context table File context
|
||||
---@param callbacks AgentCallbacks
|
||||
function M.prompt_continue(context, callbacks)
|
||||
vim.schedule(function()
|
||||
vim.ui.select({ "Continue (25 more iterations)", "Stop and save for later" }, {
|
||||
prompt = string.format("Agent reached %d iterations. Continue?", state.max_iterations),
|
||||
}, function(choice)
|
||||
if choice and choice:match("^Continue") then
|
||||
-- Reset iteration counter and continue
|
||||
state.current_iteration = 0
|
||||
logs.info("User chose to continue, resetting iteration counter")
|
||||
M.agent_loop(context, callbacks)
|
||||
else
|
||||
-- Save state for later resume
|
||||
logs.info("User chose to stop, saving state for resume")
|
||||
resume.save(
|
||||
state.conversation,
|
||||
state.pending_tool_results,
|
||||
state.current_iteration,
|
||||
state.original_prompt
|
||||
)
|
||||
state.is_running = false
|
||||
callbacks.on_text("Agent paused. Use /continue to resume later.")
|
||||
callbacks.on_complete()
|
||||
end
|
||||
end)
|
||||
end)
|
||||
end
|
||||
|
||||
--- Continue a previously stopped agent session
|
||||
---@param callbacks AgentCallbacks
|
||||
---@return boolean Success
|
||||
function M.continue_session(callbacks)
|
||||
if state.is_running then
|
||||
utils.notify("Agent is already running", vim.log.levels.WARN)
|
||||
return false
|
||||
end
|
||||
|
||||
local saved = resume.load()
|
||||
if not saved then
|
||||
utils.notify("No saved agent session to continue", vim.log.levels.WARN)
|
||||
return false
|
||||
end
|
||||
|
||||
logs.info("Resuming agent session")
|
||||
logs.info(string.format("Loaded %d messages, iteration %d", #saved.conversation, saved.iteration))
|
||||
|
||||
-- Restore state
|
||||
state.conversation = saved.conversation
|
||||
state.pending_tool_results = saved.pending_tool_results or {}
|
||||
state.current_iteration = 0 -- Reset for fresh iterations
|
||||
state.original_prompt = saved.original_prompt
|
||||
state.is_running = true
|
||||
state.current_callbacks = callbacks
|
||||
|
||||
-- Build context from current state
|
||||
local llm = require("codetyper.core.llm")
|
||||
local context = {}
|
||||
local current_file = vim.fn.expand("%:p")
|
||||
if current_file ~= "" and vim.fn.filereadable(current_file) == 1 then
|
||||
context = llm.build_context(current_file, "agent")
|
||||
end
|
||||
state.current_context = context
|
||||
|
||||
-- Clear saved state
|
||||
resume.clear()
|
||||
|
||||
-- Add continuation message
|
||||
table.insert(state.conversation, {
|
||||
role = "user",
|
||||
content = "Continue where you left off. Complete the remaining tasks.",
|
||||
})
|
||||
|
||||
-- Continue the loop
|
||||
callbacks.on_text("Resuming agent session...")
|
||||
M.agent_loop(context, callbacks)
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Check if there's a saved session to continue
|
||||
---@return boolean
|
||||
function M.has_saved_session()
|
||||
return resume.has_saved_state()
|
||||
end
|
||||
|
||||
--- Get info about saved session
|
||||
---@return table|nil
|
||||
function M.get_saved_session_info()
|
||||
return resume.get_info()
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,425 +0,0 @@
|
||||
---@mod codetyper.agent.linter Linter validation for generated code
|
||||
---@brief [[
|
||||
--- Validates generated code by checking LSP diagnostics after injection.
|
||||
--- Automatically saves the file and waits for LSP to update before checking.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local config_params = require("codetyper.params.agents.linter")
|
||||
local prompts = require("codetyper.prompts.agents.linter")
|
||||
|
||||
--- Configuration
|
||||
local config = config_params.config
|
||||
|
||||
--- Diagnostic results for tracking
|
||||
---@type table<number, table>
|
||||
local validation_results = {}
|
||||
|
||||
--- Configure linter behavior
|
||||
---@param opts table Configuration options
|
||||
function M.configure(opts)
|
||||
for k, v in pairs(opts) do
|
||||
if config[k] ~= nil then
|
||||
config[k] = v
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Get current configuration
|
||||
---@return table
|
||||
function M.get_config()
|
||||
return vim.deepcopy(config)
|
||||
end
|
||||
|
||||
--- Save buffer if modified
|
||||
---@param bufnr number Buffer number
|
||||
---@return boolean success
|
||||
local function save_buffer(bufnr)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Skip if buffer is not modified
|
||||
if not vim.bo[bufnr].modified then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Skip if buffer has no name (unsaved file)
|
||||
local bufname = vim.api.nvim_buf_get_name(bufnr)
|
||||
if bufname == "" then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Save the buffer
|
||||
local ok, err = pcall(function()
|
||||
vim.api.nvim_buf_call(bufnr, function()
|
||||
vim.cmd("silent! write")
|
||||
end)
|
||||
end)
|
||||
|
||||
if not ok then
|
||||
pcall(function()
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "warning",
|
||||
message = "Failed to save buffer: " .. tostring(err),
|
||||
})
|
||||
end)
|
||||
return false
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Get LSP diagnostics for a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@param start_line? number Start line (1-indexed)
|
||||
---@param end_line? number End line (1-indexed)
|
||||
---@return table[] diagnostics List of diagnostics
|
||||
function M.get_diagnostics(bufnr, start_line, end_line)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return {}
|
||||
end
|
||||
|
||||
local all_diagnostics = vim.diagnostic.get(bufnr)
|
||||
local filtered = {}
|
||||
|
||||
for _, diag in ipairs(all_diagnostics) do
|
||||
-- Filter by severity
|
||||
if diag.severity <= config.min_severity then
|
||||
-- Filter by line range if specified
|
||||
if start_line and end_line then
|
||||
local diag_line = diag.lnum + 1 -- Convert to 1-indexed
|
||||
if diag_line >= start_line and diag_line <= end_line then
|
||||
table.insert(filtered, diag)
|
||||
end
|
||||
else
|
||||
table.insert(filtered, diag)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return filtered
|
||||
end
|
||||
|
||||
--- Format a diagnostic for display
|
||||
---@param diag table Diagnostic object
|
||||
---@return string
|
||||
local function format_diagnostic(diag)
|
||||
local severity_names = {
|
||||
[vim.diagnostic.severity.ERROR] = "ERROR",
|
||||
[vim.diagnostic.severity.WARN] = "WARN",
|
||||
[vim.diagnostic.severity.INFO] = "INFO",
|
||||
[vim.diagnostic.severity.HINT] = "HINT",
|
||||
}
|
||||
local severity = severity_names[diag.severity] or "UNKNOWN"
|
||||
local line = diag.lnum + 1
|
||||
local source = diag.source or "lsp"
|
||||
return string.format("[%s] Line %d (%s): %s", severity, line, source, diag.message)
|
||||
end
|
||||
|
||||
--- Check if there are errors in generated code region
|
||||
---@param bufnr number Buffer number
|
||||
---@param start_line number Start line (1-indexed)
|
||||
---@param end_line number End line (1-indexed)
|
||||
---@return table result {has_errors, has_warnings, diagnostics, summary}
|
||||
function M.check_region(bufnr, start_line, end_line)
|
||||
local diagnostics = M.get_diagnostics(bufnr, start_line, end_line)
|
||||
|
||||
local errors = 0
|
||||
local warnings = 0
|
||||
|
||||
for _, diag in ipairs(diagnostics) do
|
||||
if diag.severity == vim.diagnostic.severity.ERROR then
|
||||
errors = errors + 1
|
||||
elseif diag.severity == vim.diagnostic.severity.WARN then
|
||||
warnings = warnings + 1
|
||||
end
|
||||
end
|
||||
|
||||
return {
|
||||
has_errors = errors > 0,
|
||||
has_warnings = warnings > 0,
|
||||
error_count = errors,
|
||||
warning_count = warnings,
|
||||
diagnostics = diagnostics,
|
||||
summary = string.format("%d error(s), %d warning(s)", errors, warnings),
|
||||
}
|
||||
end
|
||||
|
||||
--- Validate code after injection and report issues
|
||||
---@param bufnr number Buffer number
|
||||
---@param start_line? number Start line of injected code (1-indexed)
|
||||
---@param end_line? number End line of injected code (1-indexed)
|
||||
---@param callback? function Callback with (result) when validation completes
|
||||
function M.validate_after_injection(bufnr, start_line, end_line, callback)
|
||||
-- Save the file first
|
||||
if config.auto_save then
|
||||
save_buffer(bufnr)
|
||||
end
|
||||
|
||||
-- Wait for LSP to process changes
|
||||
vim.defer_fn(function()
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
if callback then callback(nil) end
|
||||
return
|
||||
end
|
||||
|
||||
local result
|
||||
if start_line and end_line then
|
||||
result = M.check_region(bufnr, start_line, end_line)
|
||||
else
|
||||
-- Check entire buffer
|
||||
local line_count = vim.api.nvim_buf_line_count(bufnr)
|
||||
result = M.check_region(bufnr, 1, line_count)
|
||||
end
|
||||
|
||||
-- Store result for this buffer
|
||||
validation_results[bufnr] = {
|
||||
timestamp = os.time(),
|
||||
result = result,
|
||||
start_line = start_line,
|
||||
end_line = end_line,
|
||||
}
|
||||
|
||||
-- Log results
|
||||
pcall(function()
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
if result.has_errors then
|
||||
logs.add({
|
||||
type = "error",
|
||||
message = string.format("Linter found issues: %s", result.summary),
|
||||
})
|
||||
-- Log individual errors
|
||||
for _, diag in ipairs(result.diagnostics) do
|
||||
if diag.severity == vim.diagnostic.severity.ERROR then
|
||||
logs.add({
|
||||
type = "error",
|
||||
message = format_diagnostic(diag),
|
||||
})
|
||||
end
|
||||
end
|
||||
elseif result.has_warnings then
|
||||
logs.add({
|
||||
type = "warning",
|
||||
message = string.format("Linter warnings: %s", result.summary),
|
||||
})
|
||||
else
|
||||
logs.add({
|
||||
type = "success",
|
||||
message = "Linter check passed - no errors or warnings",
|
||||
})
|
||||
end
|
||||
end)
|
||||
|
||||
-- Notify user
|
||||
if result.has_errors then
|
||||
vim.notify(
|
||||
string.format("Generated code has lint errors: %s", result.summary),
|
||||
vim.log.levels.ERROR
|
||||
)
|
||||
|
||||
-- Offer to fix if configured
|
||||
if config.auto_offer_fix and #result.diagnostics > 0 then
|
||||
M.offer_fix(bufnr, result)
|
||||
end
|
||||
elseif result.has_warnings then
|
||||
vim.notify(
|
||||
string.format("Generated code has warnings: %s", result.summary),
|
||||
vim.log.levels.WARN
|
||||
)
|
||||
end
|
||||
|
||||
if callback then
|
||||
callback(result)
|
||||
end
|
||||
end, config.diagnostic_delay_ms)
|
||||
end
|
||||
|
||||
--- Offer to fix lint errors using AI
|
||||
---@param bufnr number Buffer number
|
||||
---@param result table Validation result
|
||||
function M.offer_fix(bufnr, result)
|
||||
if not result.has_errors and not result.has_warnings then
|
||||
return
|
||||
end
|
||||
|
||||
-- Build error summary for prompt
|
||||
local error_messages = {}
|
||||
for _, diag in ipairs(result.diagnostics) do
|
||||
table.insert(error_messages, format_diagnostic(diag))
|
||||
end
|
||||
|
||||
vim.ui.select(
|
||||
{ "Yes - Auto-fix with AI", "No - I'll fix manually", "Show errors in quickfix" },
|
||||
{
|
||||
prompt = string.format("Found %d issue(s). Would you like AI to fix them?", #result.diagnostics),
|
||||
},
|
||||
function(choice)
|
||||
if not choice then return end
|
||||
|
||||
if choice:match("^Yes") then
|
||||
M.request_ai_fix(bufnr, result)
|
||||
elseif choice:match("quickfix") then
|
||||
M.show_in_quickfix(bufnr, result)
|
||||
end
|
||||
end
|
||||
)
|
||||
end
|
||||
|
||||
--- Show lint errors in quickfix list
|
||||
---@param bufnr number Buffer number
|
||||
---@param result table Validation result
|
||||
function M.show_in_quickfix(bufnr, result)
|
||||
local qf_items = {}
|
||||
local bufname = vim.api.nvim_buf_get_name(bufnr)
|
||||
|
||||
for _, diag in ipairs(result.diagnostics) do
|
||||
table.insert(qf_items, {
|
||||
bufnr = bufnr,
|
||||
filename = bufname,
|
||||
lnum = diag.lnum + 1,
|
||||
col = diag.col + 1,
|
||||
text = diag.message,
|
||||
type = diag.severity == vim.diagnostic.severity.ERROR and "E" or "W",
|
||||
})
|
||||
end
|
||||
|
||||
vim.fn.setqflist(qf_items, "r")
|
||||
vim.cmd("copen")
|
||||
end
|
||||
|
||||
--- Request AI to fix lint errors
|
||||
---@param bufnr number Buffer number
|
||||
---@param result table Validation result
|
||||
function M.request_ai_fix(bufnr, result)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return
|
||||
end
|
||||
|
||||
local filepath = vim.api.nvim_buf_get_name(bufnr)
|
||||
|
||||
-- Build fix prompt
|
||||
local error_list = {}
|
||||
for _, diag in ipairs(result.diagnostics) do
|
||||
table.insert(error_list, format_diagnostic(diag))
|
||||
end
|
||||
|
||||
-- Get the affected code region
|
||||
local start_line = result.diagnostics[1] and (result.diagnostics[1].lnum + 1) or 1
|
||||
local end_line = start_line
|
||||
for _, diag in ipairs(result.diagnostics) do
|
||||
local line = diag.lnum + 1
|
||||
if line < start_line then start_line = line end
|
||||
if line > end_line then end_line = line end
|
||||
end
|
||||
|
||||
-- Expand range by a few lines for context
|
||||
start_line = math.max(1, start_line - 5)
|
||||
end_line = math.min(vim.api.nvim_buf_line_count(bufnr), end_line + 5)
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false)
|
||||
local code_context = table.concat(lines, "\n")
|
||||
|
||||
-- Create fix prompt using inline tag
|
||||
local fix_prompt = string.format(
|
||||
prompts.fix_request,
|
||||
table.concat(error_list, "\n"),
|
||||
start_line,
|
||||
end_line,
|
||||
code_context
|
||||
)
|
||||
|
||||
-- Queue the fix through the scheduler
|
||||
local scheduler = require("codetyper.core.scheduler.scheduler")
|
||||
local queue = require("codetyper.core.events.queue")
|
||||
local patch_mod = require("codetyper.core.diff.patch")
|
||||
|
||||
-- Ensure scheduler is running
|
||||
if not scheduler.status().running then
|
||||
scheduler.start()
|
||||
end
|
||||
|
||||
-- Take snapshot
|
||||
local snapshot = patch_mod.snapshot_buffer(bufnr, {
|
||||
start_line = start_line,
|
||||
end_line = end_line,
|
||||
})
|
||||
|
||||
-- Enqueue fix request
|
||||
queue.enqueue({
|
||||
id = queue.generate_id(),
|
||||
bufnr = bufnr,
|
||||
range = { start_line = start_line, end_line = end_line },
|
||||
timestamp = os.clock(),
|
||||
changedtick = snapshot.changedtick,
|
||||
content_hash = snapshot.content_hash,
|
||||
prompt_content = fix_prompt,
|
||||
target_path = filepath,
|
||||
priority = 1, -- High priority for fixes
|
||||
status = "pending",
|
||||
attempt_count = 0,
|
||||
intent = {
|
||||
type = "fix",
|
||||
action = "replace",
|
||||
confidence = 0.9,
|
||||
},
|
||||
scope_range = { start_line = start_line, end_line = end_line },
|
||||
source = "linter_fix",
|
||||
})
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = "Queued AI fix request for lint errors",
|
||||
})
|
||||
end)
|
||||
|
||||
vim.notify("Queued AI fix request for lint errors", vim.log.levels.INFO)
|
||||
end
|
||||
|
||||
--- Get last validation result for a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@return table|nil result
|
||||
function M.get_last_result(bufnr)
|
||||
return validation_results[bufnr]
|
||||
end
|
||||
|
||||
--- Clear validation results for a buffer
|
||||
---@param bufnr number Buffer number
|
||||
function M.clear_result(bufnr)
|
||||
validation_results[bufnr] = nil
|
||||
end
|
||||
|
||||
--- Check if buffer has any lint errors currently
|
||||
---@param bufnr number Buffer number
|
||||
---@return boolean has_errors
|
||||
function M.has_errors(bufnr)
|
||||
local diagnostics = vim.diagnostic.get(bufnr, {
|
||||
severity = vim.diagnostic.severity.ERROR,
|
||||
})
|
||||
return #diagnostics > 0
|
||||
end
|
||||
|
||||
--- Check if buffer has any lint warnings currently
|
||||
---@param bufnr number Buffer number
|
||||
---@return boolean has_warnings
|
||||
function M.has_warnings(bufnr)
|
||||
local diagnostics = vim.diagnostic.get(bufnr, {
|
||||
severity = { min = vim.diagnostic.severity.WARN },
|
||||
})
|
||||
return #diagnostics > 0
|
||||
end
|
||||
|
||||
--- Validate all buffers with recent changes
|
||||
function M.validate_all_changed()
|
||||
for bufnr, data in pairs(validation_results) do
|
||||
if vim.api.nvim_buf_is_valid(bufnr) then
|
||||
M.validate_after_injection(bufnr, data.start_line, data.end_line)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,182 +0,0 @@
|
||||
---@mod codetyper.agent.permissions Permission manager for agent actions
|
||||
---
|
||||
--- Manages permissions for bash commands and file operations with
|
||||
--- allow, allow-session, allow-list, and reject options.
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class PermissionState
|
||||
---@field session_allowed table<string, boolean> Commands allowed for this session
|
||||
---@field allow_list table<string, boolean> Patterns always allowed
|
||||
---@field deny_list table<string, boolean> Patterns always denied
|
||||
|
||||
local params = require("codetyper.params.agents.permissions")
|
||||
|
||||
local state = {
|
||||
session_allowed = {},
|
||||
allow_list = {},
|
||||
deny_list = {},
|
||||
}
|
||||
|
||||
--- Dangerous command patterns that should never be auto-allowed
|
||||
local DANGEROUS_PATTERNS = params.dangerous_patterns
|
||||
|
||||
--- Safe command patterns that can be auto-allowed
|
||||
local SAFE_PATTERNS = params.safe_patterns
|
||||
|
||||
---@alias PermissionLevel "allow"|"allow_session"|"allow_list"|"reject"
|
||||
|
||||
---@class PermissionResult
|
||||
---@field allowed boolean Whether action is allowed
|
||||
---@field reason string Reason for the decision
|
||||
---@field auto boolean Whether this was an automatic decision
|
||||
|
||||
--- Check if a command matches a pattern
|
||||
---@param command string The command to check
|
||||
---@param pattern string The pattern to match
|
||||
---@return boolean
|
||||
local function matches_pattern(command, pattern)
|
||||
return command:match(pattern) ~= nil
|
||||
end
|
||||
|
||||
--- Check if command is dangerous
|
||||
---@param command string The command to check
|
||||
---@return boolean, string|nil dangerous, reason
|
||||
local function is_dangerous(command)
|
||||
for _, pattern in ipairs(DANGEROUS_PATTERNS) do
|
||||
if matches_pattern(command, pattern) then
|
||||
return true, "Matches dangerous pattern: " .. pattern
|
||||
end
|
||||
end
|
||||
return false, nil
|
||||
end
|
||||
|
||||
--- Check if command is safe
|
||||
---@param command string The command to check
|
||||
---@return boolean
|
||||
local function is_safe(command)
|
||||
for _, pattern in ipairs(SAFE_PATTERNS) do
|
||||
if matches_pattern(command, pattern) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Normalize command for comparison (trim, lowercase first word)
|
||||
---@param command string
|
||||
---@return string
|
||||
local function normalize_command(command)
|
||||
return vim.trim(command)
|
||||
end
|
||||
|
||||
--- Check permission for a bash command
|
||||
---@param command string The command to check
|
||||
---@return PermissionResult
|
||||
function M.check_bash_permission(command)
|
||||
local normalized = normalize_command(command)
|
||||
|
||||
-- Check deny list first
|
||||
for pattern, _ in pairs(state.deny_list) do
|
||||
if matches_pattern(normalized, pattern) then
|
||||
return {
|
||||
allowed = false,
|
||||
reason = "Command in deny list",
|
||||
auto = true,
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
-- Check if command is dangerous
|
||||
local dangerous, reason = is_dangerous(normalized)
|
||||
if dangerous then
|
||||
return {
|
||||
allowed = false,
|
||||
reason = reason,
|
||||
auto = false, -- Require explicit approval for dangerous commands
|
||||
}
|
||||
end
|
||||
|
||||
-- Check session allowed
|
||||
if state.session_allowed[normalized] then
|
||||
return {
|
||||
allowed = true,
|
||||
reason = "Allowed for this session",
|
||||
auto = true,
|
||||
}
|
||||
end
|
||||
|
||||
-- Check allow list patterns
|
||||
for pattern, _ in pairs(state.allow_list) do
|
||||
if matches_pattern(normalized, pattern) then
|
||||
return {
|
||||
allowed = true,
|
||||
reason = "Matches allow list pattern",
|
||||
auto = true,
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
-- Check if command is inherently safe
|
||||
if is_safe(normalized) then
|
||||
return {
|
||||
allowed = true,
|
||||
reason = "Safe read-only command",
|
||||
auto = true,
|
||||
}
|
||||
end
|
||||
|
||||
-- Otherwise, require explicit permission
|
||||
return {
|
||||
allowed = false,
|
||||
reason = "Requires approval",
|
||||
auto = false,
|
||||
}
|
||||
end
|
||||
|
||||
--- Grant permission for a command
|
||||
---@param command string The command
|
||||
---@param level PermissionLevel The permission level
|
||||
function M.grant_permission(command, level)
|
||||
local normalized = normalize_command(command)
|
||||
|
||||
if level == "allow_session" then
|
||||
state.session_allowed[normalized] = true
|
||||
elseif level == "allow_list" then
|
||||
-- Add as pattern (escape special chars for exact match)
|
||||
local pattern = "^" .. vim.pesc(normalized) .. "$"
|
||||
state.allow_list[pattern] = true
|
||||
end
|
||||
end
|
||||
|
||||
--- Add a pattern to the allow list
|
||||
---@param pattern string Lua pattern to allow
|
||||
function M.add_to_allow_list(pattern)
|
||||
state.allow_list[pattern] = true
|
||||
end
|
||||
|
||||
--- Add a pattern to the deny list
|
||||
---@param pattern string Lua pattern to deny
|
||||
function M.add_to_deny_list(pattern)
|
||||
state.deny_list[pattern] = true
|
||||
end
|
||||
|
||||
--- Clear session permissions
|
||||
function M.clear_session()
|
||||
state.session_allowed = {}
|
||||
end
|
||||
|
||||
--- Reset all permissions
|
||||
function M.reset()
|
||||
state.session_allowed = {}
|
||||
state.allow_list = {}
|
||||
state.deny_list = {}
|
||||
end
|
||||
|
||||
--- Get current permission state (for debugging)
|
||||
---@return PermissionState
|
||||
function M.get_state()
|
||||
return vim.deepcopy(state)
|
||||
end
|
||||
|
||||
return M
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,676 +0,0 @@
|
||||
---@mod codetyper.ask.explorer Project exploration for Ask mode
|
||||
---@brief [[
|
||||
--- Performs comprehensive project exploration when explaining a project.
|
||||
--- Shows progress, indexes files, and builds brain context.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
---@class ExplorationState
|
||||
---@field is_exploring boolean
|
||||
---@field files_scanned number
|
||||
---@field total_files number
|
||||
---@field current_file string|nil
|
||||
---@field findings table
|
||||
---@field on_log fun(msg: string, level: string)|nil
|
||||
|
||||
local state = {
|
||||
is_exploring = false,
|
||||
files_scanned = 0,
|
||||
total_files = 0,
|
||||
current_file = nil,
|
||||
findings = {},
|
||||
on_log = nil,
|
||||
}
|
||||
|
||||
--- File extensions to analyze
|
||||
local ANALYZABLE_EXTENSIONS = {
|
||||
lua = true,
|
||||
ts = true,
|
||||
tsx = true,
|
||||
js = true,
|
||||
jsx = true,
|
||||
py = true,
|
||||
go = true,
|
||||
rs = true,
|
||||
rb = true,
|
||||
java = true,
|
||||
c = true,
|
||||
cpp = true,
|
||||
h = true,
|
||||
hpp = true,
|
||||
json = true,
|
||||
yaml = true,
|
||||
yml = true,
|
||||
toml = true,
|
||||
md = true,
|
||||
xml = true,
|
||||
}
|
||||
|
||||
--- Directories to skip
|
||||
local SKIP_DIRS = {
|
||||
-- Version control
|
||||
[".git"] = true,
|
||||
[".svn"] = true,
|
||||
[".hg"] = true,
|
||||
|
||||
-- IDE/Editor
|
||||
[".idea"] = true,
|
||||
[".vscode"] = true,
|
||||
[".cursor"] = true,
|
||||
[".cursorignore"] = true,
|
||||
[".claude"] = true,
|
||||
[".zed"] = true,
|
||||
|
||||
-- Project tooling
|
||||
[".coder"] = true,
|
||||
[".github"] = true,
|
||||
[".gitlab"] = true,
|
||||
[".husky"] = true,
|
||||
|
||||
-- Build outputs
|
||||
dist = true,
|
||||
build = true,
|
||||
out = true,
|
||||
target = true,
|
||||
bin = true,
|
||||
obj = true,
|
||||
[".build"] = true,
|
||||
[".output"] = true,
|
||||
|
||||
-- Dependencies
|
||||
node_modules = true,
|
||||
vendor = true,
|
||||
[".vendor"] = true,
|
||||
packages = true,
|
||||
bower_components = true,
|
||||
jspm_packages = true,
|
||||
|
||||
-- Cache/temp
|
||||
[".cache"] = true,
|
||||
[".tmp"] = true,
|
||||
[".temp"] = true,
|
||||
__pycache__ = true,
|
||||
[".pytest_cache"] = true,
|
||||
[".mypy_cache"] = true,
|
||||
[".ruff_cache"] = true,
|
||||
[".tox"] = true,
|
||||
[".nox"] = true,
|
||||
[".eggs"] = true,
|
||||
["*.egg-info"] = true,
|
||||
|
||||
-- Framework specific
|
||||
[".next"] = true,
|
||||
[".nuxt"] = true,
|
||||
[".svelte-kit"] = true,
|
||||
[".vercel"] = true,
|
||||
[".netlify"] = true,
|
||||
[".serverless"] = true,
|
||||
[".turbo"] = true,
|
||||
|
||||
-- Testing/coverage
|
||||
coverage = true,
|
||||
[".nyc_output"] = true,
|
||||
htmlcov = true,
|
||||
|
||||
-- Logs
|
||||
logs = true,
|
||||
log = true,
|
||||
|
||||
-- OS files
|
||||
[".DS_Store"] = true,
|
||||
Thumbs_db = true,
|
||||
}
|
||||
|
||||
--- Files to skip (patterns)
|
||||
local SKIP_FILES = {
|
||||
-- Lock files
|
||||
"package%-lock%.json",
|
||||
"yarn%.lock",
|
||||
"pnpm%-lock%.yaml",
|
||||
"Gemfile%.lock",
|
||||
"Cargo%.lock",
|
||||
"poetry%.lock",
|
||||
"Pipfile%.lock",
|
||||
"composer%.lock",
|
||||
"go%.sum",
|
||||
"flake%.lock",
|
||||
"%.lock$",
|
||||
"%-lock%.json$",
|
||||
"%-lock%.yaml$",
|
||||
|
||||
-- Generated files
|
||||
"%.min%.js$",
|
||||
"%.min%.css$",
|
||||
"%.bundle%.js$",
|
||||
"%.chunk%.js$",
|
||||
"%.map$",
|
||||
"%.d%.ts$",
|
||||
|
||||
-- Binary/media (shouldn't match anyway but be safe)
|
||||
"%.png$",
|
||||
"%.jpg$",
|
||||
"%.jpeg$",
|
||||
"%.gif$",
|
||||
"%.ico$",
|
||||
"%.svg$",
|
||||
"%.woff",
|
||||
"%.ttf$",
|
||||
"%.eot$",
|
||||
"%.pdf$",
|
||||
"%.zip$",
|
||||
"%.tar",
|
||||
"%.gz$",
|
||||
|
||||
-- Config that's not useful
|
||||
"%.env",
|
||||
"%.env%.",
|
||||
}
|
||||
|
||||
--- Log a message during exploration
|
||||
---@param msg string
|
||||
---@param level? string "info"|"debug"|"file"|"progress"
|
||||
local function log(msg, level)
|
||||
level = level or "info"
|
||||
if state.on_log then
|
||||
state.on_log(msg, level)
|
||||
end
|
||||
end
|
||||
|
||||
--- Check if file should be skipped
|
||||
---@param filename string
|
||||
---@return boolean
|
||||
local function should_skip_file(filename)
|
||||
for _, pattern in ipairs(SKIP_FILES) do
|
||||
if filename:match(pattern) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Check if directory should be skipped
|
||||
---@param dirname string
|
||||
---@return boolean
|
||||
local function should_skip_dir(dirname)
|
||||
-- Direct match
|
||||
if SKIP_DIRS[dirname] then
|
||||
return true
|
||||
end
|
||||
-- Pattern match for .cursor* etc
|
||||
if dirname:match("^%.cursor") then
|
||||
return true
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Get all files in project
|
||||
---@param root string Project root
|
||||
---@return string[] files
|
||||
local function get_project_files(root)
|
||||
local files = {}
|
||||
|
||||
local function scan_dir(dir)
|
||||
local handle = vim.loop.fs_scandir(dir)
|
||||
if not handle then
|
||||
return
|
||||
end
|
||||
|
||||
while true do
|
||||
local name, type = vim.loop.fs_scandir_next(handle)
|
||||
if not name then
|
||||
break
|
||||
end
|
||||
|
||||
local full_path = dir .. "/" .. name
|
||||
|
||||
if type == "directory" then
|
||||
if not should_skip_dir(name) then
|
||||
scan_dir(full_path)
|
||||
end
|
||||
elseif type == "file" then
|
||||
if not should_skip_file(name) then
|
||||
local ext = name:match("%.([^%.]+)$")
|
||||
if ext and ANALYZABLE_EXTENSIONS[ext:lower()] then
|
||||
table.insert(files, full_path)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
scan_dir(root)
|
||||
return files
|
||||
end
|
||||
|
||||
--- Analyze a single file
|
||||
---@param filepath string
|
||||
---@return table|nil analysis
|
||||
local function analyze_file(filepath)
|
||||
local content = utils.read_file(filepath)
|
||||
if not content or content == "" then
|
||||
return nil
|
||||
end
|
||||
|
||||
local ext = filepath:match("%.([^%.]+)$") or ""
|
||||
local lines = vim.split(content, "\n")
|
||||
|
||||
local analysis = {
|
||||
path = filepath,
|
||||
extension = ext,
|
||||
lines = #lines,
|
||||
size = #content,
|
||||
imports = {},
|
||||
exports = {},
|
||||
functions = {},
|
||||
classes = {},
|
||||
summary = "",
|
||||
}
|
||||
|
||||
-- Extract key patterns based on file type
|
||||
for i, line in ipairs(lines) do
|
||||
-- Imports/requires
|
||||
local import = line:match('import%s+.*%s+from%s+["\']([^"\']+)["\']')
|
||||
or line:match('require%(["\']([^"\']+)["\']%)')
|
||||
or line:match("from%s+([%w_.]+)%s+import")
|
||||
if import then
|
||||
table.insert(analysis.imports, { source = import, line = i })
|
||||
end
|
||||
|
||||
-- Function definitions
|
||||
local func = line:match("^%s*function%s+([%w_:%.]+)%s*%(")
|
||||
or line:match("^%s*local%s+function%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*def%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*func%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*async%s+function%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*public%s+.*%s+([%w_]+)%s*%(")
|
||||
if func then
|
||||
table.insert(analysis.functions, { name = func, line = i })
|
||||
end
|
||||
|
||||
-- Class definitions
|
||||
local class = line:match("^%s*class%s+([%w_]+)")
|
||||
or line:match("^%s*public%s+class%s+([%w_]+)")
|
||||
or line:match("^%s*interface%s+([%w_]+)")
|
||||
if class then
|
||||
table.insert(analysis.classes, { name = class, line = i })
|
||||
end
|
||||
|
||||
-- Exports
|
||||
local exp = line:match("^%s*export%s+.*%s+([%w_]+)")
|
||||
or line:match("^%s*module%.exports%s*=")
|
||||
or line:match("^return%s+M")
|
||||
if exp then
|
||||
table.insert(analysis.exports, { name = exp, line = i })
|
||||
end
|
||||
end
|
||||
|
||||
-- Create summary
|
||||
local parts = {}
|
||||
if #analysis.functions > 0 then
|
||||
table.insert(parts, #analysis.functions .. " functions")
|
||||
end
|
||||
if #analysis.classes > 0 then
|
||||
table.insert(parts, #analysis.classes .. " classes")
|
||||
end
|
||||
if #analysis.imports > 0 then
|
||||
table.insert(parts, #analysis.imports .. " imports")
|
||||
end
|
||||
analysis.summary = table.concat(parts, ", ")
|
||||
|
||||
return analysis
|
||||
end
|
||||
|
||||
--- Detect project type from files
|
||||
---@param root string
|
||||
---@return string type, table info
|
||||
local function detect_project_type(root)
|
||||
local info = {
|
||||
name = vim.fn.fnamemodify(root, ":t"),
|
||||
type = "unknown",
|
||||
framework = nil,
|
||||
language = nil,
|
||||
}
|
||||
|
||||
-- Check for common project files
|
||||
if utils.file_exists(root .. "/package.json") then
|
||||
info.type = "node"
|
||||
info.language = "JavaScript/TypeScript"
|
||||
local content = utils.read_file(root .. "/package.json")
|
||||
if content then
|
||||
local ok, pkg = pcall(vim.json.decode, content)
|
||||
if ok then
|
||||
info.name = pkg.name or info.name
|
||||
if pkg.dependencies then
|
||||
if pkg.dependencies.react then
|
||||
info.framework = "React"
|
||||
elseif pkg.dependencies.vue then
|
||||
info.framework = "Vue"
|
||||
elseif pkg.dependencies.next then
|
||||
info.framework = "Next.js"
|
||||
elseif pkg.dependencies.express then
|
||||
info.framework = "Express"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
elseif utils.file_exists(root .. "/pom.xml") then
|
||||
info.type = "maven"
|
||||
info.language = "Java"
|
||||
local content = utils.read_file(root .. "/pom.xml")
|
||||
if content and content:match("spring%-boot") then
|
||||
info.framework = "Spring Boot"
|
||||
end
|
||||
elseif utils.file_exists(root .. "/Cargo.toml") then
|
||||
info.type = "rust"
|
||||
info.language = "Rust"
|
||||
elseif utils.file_exists(root .. "/go.mod") then
|
||||
info.type = "go"
|
||||
info.language = "Go"
|
||||
elseif utils.file_exists(root .. "/requirements.txt") or utils.file_exists(root .. "/pyproject.toml") then
|
||||
info.type = "python"
|
||||
info.language = "Python"
|
||||
elseif utils.file_exists(root .. "/init.lua") or utils.file_exists(root .. "/plugin/") then
|
||||
info.type = "neovim-plugin"
|
||||
info.language = "Lua"
|
||||
end
|
||||
|
||||
return info.type, info
|
||||
end
|
||||
|
||||
--- Build project structure summary
|
||||
---@param files string[]
|
||||
---@param root string
|
||||
---@return table structure
|
||||
local function build_structure(files, root)
|
||||
local structure = {
|
||||
directories = {},
|
||||
by_extension = {},
|
||||
total_files = #files,
|
||||
}
|
||||
|
||||
for _, file in ipairs(files) do
|
||||
local relative = file:gsub("^" .. vim.pesc(root) .. "/", "")
|
||||
local dir = vim.fn.fnamemodify(relative, ":h")
|
||||
local ext = file:match("%.([^%.]+)$") or "unknown"
|
||||
|
||||
structure.directories[dir] = (structure.directories[dir] or 0) + 1
|
||||
structure.by_extension[ext] = (structure.by_extension[ext] or 0) + 1
|
||||
end
|
||||
|
||||
return structure
|
||||
end
|
||||
|
||||
--- Explore project and build context
|
||||
---@param root string Project root
|
||||
---@param on_log fun(msg: string, level: string) Log callback
|
||||
---@param on_complete fun(result: table) Completion callback
|
||||
function M.explore(root, on_log, on_complete)
|
||||
if state.is_exploring then
|
||||
on_log("⚠️ Already exploring...", "warning")
|
||||
return
|
||||
end
|
||||
|
||||
state.is_exploring = true
|
||||
state.on_log = on_log
|
||||
state.findings = {}
|
||||
|
||||
-- Start exploration
|
||||
log("⏺ Exploring project structure...", "info")
|
||||
log("", "info")
|
||||
|
||||
-- Detect project type
|
||||
log(" Detect(Project type)", "progress")
|
||||
local project_type, project_info = detect_project_type(root)
|
||||
log(" ⎿ " .. project_info.language .. " (" .. (project_info.framework or project_type) .. ")", "debug")
|
||||
|
||||
state.findings.project = project_info
|
||||
|
||||
-- Get all files
|
||||
log("", "info")
|
||||
log(" Scan(Project files)", "progress")
|
||||
local files = get_project_files(root)
|
||||
state.total_files = #files
|
||||
log(" ⎿ Found " .. #files .. " analyzable files", "debug")
|
||||
|
||||
-- Build structure
|
||||
local structure = build_structure(files, root)
|
||||
state.findings.structure = structure
|
||||
|
||||
-- Show directory breakdown
|
||||
log("", "info")
|
||||
log(" Structure(Directories)", "progress")
|
||||
local sorted_dirs = {}
|
||||
for dir, count in pairs(structure.directories) do
|
||||
table.insert(sorted_dirs, { dir = dir, count = count })
|
||||
end
|
||||
table.sort(sorted_dirs, function(a, b)
|
||||
return a.count > b.count
|
||||
end)
|
||||
for i, entry in ipairs(sorted_dirs) do
|
||||
if i <= 5 then
|
||||
log(" ⎿ " .. entry.dir .. " (" .. entry.count .. " files)", "debug")
|
||||
end
|
||||
end
|
||||
if #sorted_dirs > 5 then
|
||||
log(" ⎿ +" .. (#sorted_dirs - 5) .. " more directories", "debug")
|
||||
end
|
||||
|
||||
-- Analyze files asynchronously
|
||||
log("", "info")
|
||||
log(" Analyze(Source files)", "progress")
|
||||
|
||||
state.files_scanned = 0
|
||||
local analyses = {}
|
||||
local key_files = {}
|
||||
|
||||
-- Process files in batches to avoid blocking
|
||||
local batch_size = 10
|
||||
local current_batch = 0
|
||||
|
||||
local function process_batch()
|
||||
local start_idx = current_batch * batch_size + 1
|
||||
local end_idx = math.min(start_idx + batch_size - 1, #files)
|
||||
|
||||
for i = start_idx, end_idx do
|
||||
local file = files[i]
|
||||
local relative = file:gsub("^" .. vim.pesc(root) .. "/", "")
|
||||
|
||||
state.files_scanned = state.files_scanned + 1
|
||||
state.current_file = relative
|
||||
|
||||
local analysis = analyze_file(file)
|
||||
if analysis then
|
||||
analysis.relative_path = relative
|
||||
table.insert(analyses, analysis)
|
||||
|
||||
-- Track key files (many functions/classes)
|
||||
if #analysis.functions >= 3 or #analysis.classes >= 1 then
|
||||
table.insert(key_files, {
|
||||
path = relative,
|
||||
functions = #analysis.functions,
|
||||
classes = #analysis.classes,
|
||||
summary = analysis.summary,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Log some files
|
||||
if i <= 3 or (i % 20 == 0) then
|
||||
log(" ⎿ " .. relative .. ": " .. (analysis and analysis.summary or "(empty)"), "file")
|
||||
end
|
||||
end
|
||||
|
||||
-- Progress update
|
||||
local progress = math.floor((state.files_scanned / state.total_files) * 100)
|
||||
if progress % 25 == 0 and progress > 0 then
|
||||
log(" ⎿ " .. progress .. "% complete (" .. state.files_scanned .. "/" .. state.total_files .. ")", "debug")
|
||||
end
|
||||
|
||||
current_batch = current_batch + 1
|
||||
|
||||
if end_idx < #files then
|
||||
-- Schedule next batch
|
||||
vim.defer_fn(process_batch, 10)
|
||||
else
|
||||
-- Complete
|
||||
finish_exploration(root, analyses, key_files, on_complete)
|
||||
end
|
||||
end
|
||||
|
||||
-- Start processing
|
||||
vim.defer_fn(process_batch, 10)
|
||||
end
|
||||
|
||||
--- Finish exploration and store results
|
||||
---@param root string
|
||||
---@param analyses table
|
||||
---@param key_files table
|
||||
---@param on_complete fun(result: table)
|
||||
function finish_exploration(root, analyses, key_files, on_complete)
|
||||
log(" ⎿ +" .. (#analyses - 3) .. " more files analyzed", "debug")
|
||||
|
||||
-- Show key files
|
||||
if #key_files > 0 then
|
||||
log("", "info")
|
||||
log(" KeyFiles(Important components)", "progress")
|
||||
table.sort(key_files, function(a, b)
|
||||
return (a.functions + a.classes * 2) > (b.functions + b.classes * 2)
|
||||
end)
|
||||
for i, kf in ipairs(key_files) do
|
||||
if i <= 5 then
|
||||
log(" ⎿ " .. kf.path .. ": " .. kf.summary, "file")
|
||||
end
|
||||
end
|
||||
if #key_files > 5 then
|
||||
log(" ⎿ +" .. (#key_files - 5) .. " more key files", "debug")
|
||||
end
|
||||
end
|
||||
|
||||
state.findings.analyses = analyses
|
||||
state.findings.key_files = key_files
|
||||
|
||||
-- Store in brain if available
|
||||
local ok_brain, brain = pcall(require, "codetyper.brain")
|
||||
if ok_brain and brain.is_initialized() then
|
||||
log("", "info")
|
||||
log(" Store(Brain context)", "progress")
|
||||
|
||||
-- Store project pattern
|
||||
brain.learn({
|
||||
type = "pattern",
|
||||
file = root,
|
||||
content = {
|
||||
summary = "Project: " .. state.findings.project.name,
|
||||
detail = state.findings.project.language
|
||||
.. " "
|
||||
.. (state.findings.project.framework or state.findings.project.type),
|
||||
code = nil,
|
||||
},
|
||||
context = {
|
||||
file = root,
|
||||
language = state.findings.project.language,
|
||||
},
|
||||
})
|
||||
|
||||
-- Store key file patterns
|
||||
for i, kf in ipairs(key_files) do
|
||||
if i <= 10 then
|
||||
brain.learn({
|
||||
type = "pattern",
|
||||
file = root .. "/" .. kf.path,
|
||||
content = {
|
||||
summary = kf.path .. " - " .. kf.summary,
|
||||
detail = kf.summary,
|
||||
},
|
||||
context = {
|
||||
file = kf.path,
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
log(" ⎿ Stored " .. math.min(#key_files, 10) + 1 .. " patterns in brain", "debug")
|
||||
end
|
||||
|
||||
-- Store in indexer if available
|
||||
local ok_indexer, indexer = pcall(require, "codetyper.indexer")
|
||||
if ok_indexer then
|
||||
log(" Index(Project index)", "progress")
|
||||
indexer.index_project(function(index)
|
||||
log(" ⎿ Indexed " .. (index.stats.files or 0) .. " files", "debug")
|
||||
end)
|
||||
end
|
||||
|
||||
log("", "info")
|
||||
log("✓ Exploration complete!", "info")
|
||||
log("", "info")
|
||||
|
||||
-- Build result
|
||||
local result = {
|
||||
project = state.findings.project,
|
||||
structure = state.findings.structure,
|
||||
key_files = key_files,
|
||||
total_files = state.total_files,
|
||||
analyses = analyses,
|
||||
}
|
||||
|
||||
state.is_exploring = false
|
||||
state.on_log = nil
|
||||
|
||||
on_complete(result)
|
||||
end
|
||||
|
||||
--- Check if exploration is in progress
|
||||
---@return boolean
|
||||
function M.is_exploring()
|
||||
return state.is_exploring
|
||||
end
|
||||
|
||||
--- Get exploration progress
|
||||
---@return number scanned, number total
|
||||
function M.get_progress()
|
||||
return state.files_scanned, state.total_files
|
||||
end
|
||||
|
||||
--- Build context string from exploration result
|
||||
---@param result table Exploration result
|
||||
---@return string context
|
||||
function M.build_context(result)
|
||||
local parts = {}
|
||||
|
||||
-- Project info
|
||||
table.insert(parts, "## Project: " .. result.project.name)
|
||||
table.insert(parts, "- Type: " .. result.project.type)
|
||||
table.insert(parts, "- Language: " .. (result.project.language or "Unknown"))
|
||||
if result.project.framework then
|
||||
table.insert(parts, "- Framework: " .. result.project.framework)
|
||||
end
|
||||
table.insert(parts, "- Files: " .. result.total_files)
|
||||
table.insert(parts, "")
|
||||
|
||||
-- Structure
|
||||
table.insert(parts, "## Structure")
|
||||
if result.structure and result.structure.by_extension then
|
||||
for ext, count in pairs(result.structure.by_extension) do
|
||||
table.insert(parts, "- ." .. ext .. ": " .. count .. " files")
|
||||
end
|
||||
end
|
||||
table.insert(parts, "")
|
||||
|
||||
-- Key components
|
||||
if result.key_files and #result.key_files > 0 then
|
||||
table.insert(parts, "## Key Components")
|
||||
for i, kf in ipairs(result.key_files) do
|
||||
if i <= 10 then
|
||||
table.insert(parts, "- " .. kf.path .. ": " .. kf.summary)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return table.concat(parts, "\n")
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,302 +0,0 @@
|
||||
---@mod codetyper.ask.intent Intent detection for Ask mode
|
||||
---@brief [[
|
||||
--- Analyzes user prompts to detect intent (ask/explain vs code generation).
|
||||
--- Routes to appropriate prompt type and context sources.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@alias IntentType "ask"|"explain"|"generate"|"refactor"|"document"|"test"
|
||||
|
||||
---@class Intent
|
||||
---@field type IntentType Detected intent type
|
||||
---@field confidence number 0-1 confidence score
|
||||
---@field needs_project_context boolean Whether project-wide context is needed
|
||||
---@field needs_brain_context boolean Whether brain/learned context is helpful
|
||||
---@field needs_exploration boolean Whether full project exploration is needed
|
||||
---@field keywords string[] Keywords that influenced detection
|
||||
|
||||
--- Patterns for detecting ask/explain intent (questions about code)
|
||||
local ASK_PATTERNS = {
|
||||
-- Question words
|
||||
{ pattern = "^what%s", weight = 0.9 },
|
||||
{ pattern = "^why%s", weight = 0.95 },
|
||||
{ pattern = "^how%s+does", weight = 0.9 },
|
||||
{ pattern = "^how%s+do%s+i", weight = 0.7 }, -- Could be asking for code
|
||||
{ pattern = "^where%s", weight = 0.85 },
|
||||
{ pattern = "^when%s", weight = 0.85 },
|
||||
{ pattern = "^which%s", weight = 0.8 },
|
||||
{ pattern = "^who%s", weight = 0.85 },
|
||||
{ pattern = "^can%s+you%s+explain", weight = 0.95 },
|
||||
{ pattern = "^could%s+you%s+explain", weight = 0.95 },
|
||||
{ pattern = "^please%s+explain", weight = 0.95 },
|
||||
|
||||
-- Explanation requests
|
||||
{ pattern = "explain%s", weight = 0.9 },
|
||||
{ pattern = "describe%s", weight = 0.85 },
|
||||
{ pattern = "tell%s+me%s+about", weight = 0.85 },
|
||||
{ pattern = "walk%s+me%s+through", weight = 0.9 },
|
||||
{ pattern = "help%s+me%s+understand", weight = 0.95 },
|
||||
{ pattern = "what%s+is%s+the%s+purpose", weight = 0.95 },
|
||||
{ pattern = "what%s+does%s+this", weight = 0.9 },
|
||||
{ pattern = "what%s+does%s+it", weight = 0.9 },
|
||||
{ pattern = "how%s+does%s+this%s+work", weight = 0.95 },
|
||||
{ pattern = "how%s+does%s+it%s+work", weight = 0.95 },
|
||||
|
||||
-- Understanding queries
|
||||
{ pattern = "understand", weight = 0.7 },
|
||||
{ pattern = "meaning%s+of", weight = 0.85 },
|
||||
{ pattern = "difference%s+between", weight = 0.9 },
|
||||
{ pattern = "compared%s+to", weight = 0.8 },
|
||||
{ pattern = "vs%s", weight = 0.7 },
|
||||
{ pattern = "versus", weight = 0.7 },
|
||||
{ pattern = "pros%s+and%s+cons", weight = 0.9 },
|
||||
{ pattern = "advantages", weight = 0.8 },
|
||||
{ pattern = "disadvantages", weight = 0.8 },
|
||||
{ pattern = "trade%-?offs?", weight = 0.85 },
|
||||
|
||||
-- Analysis requests
|
||||
{ pattern = "analyze", weight = 0.85 },
|
||||
{ pattern = "review", weight = 0.7 }, -- Could also be refactor
|
||||
{ pattern = "overview", weight = 0.9 },
|
||||
{ pattern = "summary", weight = 0.9 },
|
||||
{ pattern = "summarize", weight = 0.9 },
|
||||
|
||||
-- Question marks (weaker signal)
|
||||
{ pattern = "%?$", weight = 0.3 },
|
||||
{ pattern = "%?%s*$", weight = 0.3 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting code generation intent
|
||||
local GENERATE_PATTERNS = {
|
||||
-- Direct commands
|
||||
{ pattern = "^create%s", weight = 0.9 },
|
||||
{ pattern = "^make%s", weight = 0.85 },
|
||||
{ pattern = "^build%s", weight = 0.85 },
|
||||
{ pattern = "^write%s", weight = 0.9 },
|
||||
{ pattern = "^add%s", weight = 0.85 },
|
||||
{ pattern = "^implement%s", weight = 0.95 },
|
||||
{ pattern = "^generate%s", weight = 0.95 },
|
||||
{ pattern = "^code%s", weight = 0.8 },
|
||||
|
||||
-- Modification commands
|
||||
{ pattern = "^fix%s", weight = 0.9 },
|
||||
{ pattern = "^change%s", weight = 0.8 },
|
||||
{ pattern = "^update%s", weight = 0.75 },
|
||||
{ pattern = "^modify%s", weight = 0.8 },
|
||||
{ pattern = "^replace%s", weight = 0.85 },
|
||||
{ pattern = "^remove%s", weight = 0.85 },
|
||||
{ pattern = "^delete%s", weight = 0.85 },
|
||||
|
||||
-- Feature requests
|
||||
{ pattern = "i%s+need%s+a", weight = 0.8 },
|
||||
{ pattern = "i%s+want%s+a", weight = 0.8 },
|
||||
{ pattern = "give%s+me", weight = 0.7 },
|
||||
{ pattern = "show%s+me%s+how%s+to%s+code", weight = 0.9 },
|
||||
{ pattern = "how%s+do%s+i%s+implement", weight = 0.85 },
|
||||
{ pattern = "can%s+you%s+write", weight = 0.9 },
|
||||
{ pattern = "can%s+you%s+create", weight = 0.9 },
|
||||
{ pattern = "can%s+you%s+add", weight = 0.85 },
|
||||
{ pattern = "can%s+you%s+make", weight = 0.85 },
|
||||
|
||||
-- Code-specific terms
|
||||
{ pattern = "function%s+that", weight = 0.85 },
|
||||
{ pattern = "class%s+that", weight = 0.85 },
|
||||
{ pattern = "method%s+that", weight = 0.85 },
|
||||
{ pattern = "component%s+that", weight = 0.85 },
|
||||
{ pattern = "module%s+that", weight = 0.85 },
|
||||
{ pattern = "api%s+for", weight = 0.8 },
|
||||
{ pattern = "endpoint%s+for", weight = 0.8 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting refactor intent
|
||||
local REFACTOR_PATTERNS = {
|
||||
{ pattern = "^refactor%s", weight = 0.95 },
|
||||
{ pattern = "refactor%s+this", weight = 0.95 },
|
||||
{ pattern = "clean%s+up", weight = 0.85 },
|
||||
{ pattern = "improve%s+this%s+code", weight = 0.85 },
|
||||
{ pattern = "make%s+this%s+cleaner", weight = 0.85 },
|
||||
{ pattern = "simplify", weight = 0.8 },
|
||||
{ pattern = "optimize", weight = 0.75 }, -- Could be explain
|
||||
{ pattern = "reorganize", weight = 0.9 },
|
||||
{ pattern = "restructure", weight = 0.9 },
|
||||
{ pattern = "extract%s+to", weight = 0.9 },
|
||||
{ pattern = "split%s+into", weight = 0.85 },
|
||||
{ pattern = "dry%s+this", weight = 0.9 }, -- Don't repeat yourself
|
||||
{ pattern = "reduce%s+duplication", weight = 0.9 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting documentation intent
|
||||
local DOCUMENT_PATTERNS = {
|
||||
{ pattern = "^document%s", weight = 0.95 },
|
||||
{ pattern = "add%s+documentation", weight = 0.95 },
|
||||
{ pattern = "add%s+docs", weight = 0.95 },
|
||||
{ pattern = "add%s+comments", weight = 0.9 },
|
||||
{ pattern = "add%s+docstring", weight = 0.95 },
|
||||
{ pattern = "add%s+jsdoc", weight = 0.95 },
|
||||
{ pattern = "write%s+documentation", weight = 0.95 },
|
||||
{ pattern = "document%s+this", weight = 0.95 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting test generation intent
|
||||
local TEST_PATTERNS = {
|
||||
{ pattern = "^test%s", weight = 0.9 },
|
||||
{ pattern = "write%s+tests?%s+for", weight = 0.95 },
|
||||
{ pattern = "add%s+tests?%s+for", weight = 0.95 },
|
||||
{ pattern = "create%s+tests?%s+for", weight = 0.95 },
|
||||
{ pattern = "generate%s+tests?", weight = 0.95 },
|
||||
{ pattern = "unit%s+tests?", weight = 0.9 },
|
||||
{ pattern = "test%s+cases?%s+for", weight = 0.95 },
|
||||
{ pattern = "spec%s+for", weight = 0.85 },
|
||||
}
|
||||
|
||||
--- Patterns indicating project-wide context is needed
|
||||
local PROJECT_CONTEXT_PATTERNS = {
|
||||
{ pattern = "project", weight = 0.9 },
|
||||
{ pattern = "codebase", weight = 0.95 },
|
||||
{ pattern = "entire", weight = 0.7 },
|
||||
{ pattern = "whole", weight = 0.7 },
|
||||
{ pattern = "all%s+files", weight = 0.9 },
|
||||
{ pattern = "architecture", weight = 0.95 },
|
||||
{ pattern = "structure", weight = 0.85 },
|
||||
{ pattern = "how%s+is%s+.*%s+organized", weight = 0.95 },
|
||||
{ pattern = "where%s+is%s+.*%s+defined", weight = 0.9 },
|
||||
{ pattern = "dependencies", weight = 0.85 },
|
||||
{ pattern = "imports?%s+from", weight = 0.7 },
|
||||
{ pattern = "modules?", weight = 0.6 },
|
||||
{ pattern = "packages?", weight = 0.6 },
|
||||
}
|
||||
|
||||
--- Patterns indicating project exploration is needed (full indexing)
|
||||
local EXPLORE_PATTERNS = {
|
||||
{ pattern = "explain%s+.*%s*project", weight = 1.0 },
|
||||
{ pattern = "explain%s+.*%s*codebase", weight = 1.0 },
|
||||
{ pattern = "explain%s+me%s+the%s+project", weight = 1.0 },
|
||||
{ pattern = "tell%s+me%s+about%s+.*%s*project", weight = 0.95 },
|
||||
{ pattern = "what%s+is%s+this%s+project", weight = 0.95 },
|
||||
{ pattern = "overview%s+of%s+.*%s*project", weight = 0.95 },
|
||||
{ pattern = "understand%s+.*%s*project", weight = 0.9 },
|
||||
{ pattern = "analyze%s+.*%s*project", weight = 0.9 },
|
||||
{ pattern = "explore%s+.*%s*project", weight = 1.0 },
|
||||
{ pattern = "explore%s+.*%s*codebase", weight = 1.0 },
|
||||
{ pattern = "index%s+.*%s*project", weight = 1.0 },
|
||||
{ pattern = "scan%s+.*%s*project", weight = 0.95 },
|
||||
}
|
||||
|
||||
--- Match patterns against text
|
||||
---@param text string Lowercased text to match
|
||||
---@param patterns table Pattern list with weights
|
||||
---@return number Score, string[] Matched keywords
|
||||
local function match_patterns(text, patterns)
|
||||
local score = 0
|
||||
local matched = {}
|
||||
|
||||
for _, p in ipairs(patterns) do
|
||||
if text:match(p.pattern) then
|
||||
score = score + p.weight
|
||||
table.insert(matched, p.pattern)
|
||||
end
|
||||
end
|
||||
|
||||
return score, matched
|
||||
end
|
||||
|
||||
--- Detect intent from user prompt
|
||||
---@param prompt string User's question/request
|
||||
---@return Intent Detected intent
|
||||
function M.detect(prompt)
|
||||
local text = prompt:lower()
|
||||
|
||||
-- Calculate raw scores for each intent type (sum of matched weights)
|
||||
local ask_score, ask_kw = match_patterns(text, ASK_PATTERNS)
|
||||
local gen_score, gen_kw = match_patterns(text, GENERATE_PATTERNS)
|
||||
local ref_score, ref_kw = match_patterns(text, REFACTOR_PATTERNS)
|
||||
local doc_score, doc_kw = match_patterns(text, DOCUMENT_PATTERNS)
|
||||
local test_score, test_kw = match_patterns(text, TEST_PATTERNS)
|
||||
local proj_score, _ = match_patterns(text, PROJECT_CONTEXT_PATTERNS)
|
||||
local explore_score, _ = match_patterns(text, EXPLORE_PATTERNS)
|
||||
|
||||
-- Find the winner by raw score (highest accumulated weight)
|
||||
local scores = {
|
||||
{ type = "ask", score = ask_score, keywords = ask_kw },
|
||||
{ type = "generate", score = gen_score, keywords = gen_kw },
|
||||
{ type = "refactor", score = ref_score, keywords = ref_kw },
|
||||
{ type = "document", score = doc_score, keywords = doc_kw },
|
||||
{ type = "test", score = test_score, keywords = test_kw },
|
||||
}
|
||||
|
||||
table.sort(scores, function(a, b)
|
||||
return a.score > b.score
|
||||
end)
|
||||
|
||||
local winner = scores[1]
|
||||
|
||||
-- If top score is very low, default to ask (safer for Q&A)
|
||||
if winner.score < 0.3 then
|
||||
winner = { type = "ask", score = 0.5, keywords = {} }
|
||||
end
|
||||
|
||||
-- If ask and generate are close AND there's a question mark, prefer ask
|
||||
if winner.type == "generate" and ask_score > 0 then
|
||||
if text:match("%?%s*$") and ask_score >= gen_score * 0.5 then
|
||||
winner = { type = "ask", score = ask_score, keywords = ask_kw }
|
||||
end
|
||||
end
|
||||
|
||||
-- Determine if "explain" vs "ask" (explain needs more context)
|
||||
local intent_type = winner.type
|
||||
if intent_type == "ask" then
|
||||
-- "explain" if asking about how something works, otherwise "ask"
|
||||
if text:match("explain") or text:match("how%s+does") or text:match("walk%s+me%s+through") then
|
||||
intent_type = "explain"
|
||||
end
|
||||
end
|
||||
|
||||
-- Normalize confidence to 0-1 range (cap at reasonable max)
|
||||
local confidence = math.min(winner.score / 2, 1.0)
|
||||
|
||||
-- Check if exploration is needed (full project indexing)
|
||||
local needs_exploration = explore_score >= 0.9
|
||||
|
||||
---@type Intent
|
||||
local intent = {
|
||||
type = intent_type,
|
||||
confidence = confidence,
|
||||
needs_project_context = proj_score > 0.5 or needs_exploration,
|
||||
needs_brain_context = intent_type == "ask" or intent_type == "explain",
|
||||
needs_exploration = needs_exploration,
|
||||
keywords = winner.keywords,
|
||||
}
|
||||
|
||||
return intent
|
||||
end
|
||||
|
||||
--- Get prompt type for system prompt selection
|
||||
---@param intent Intent Detected intent
|
||||
---@return string Prompt type for prompts.system
|
||||
function M.get_prompt_type(intent)
|
||||
local mapping = {
|
||||
ask = "ask",
|
||||
explain = "ask", -- Uses same prompt as ask
|
||||
generate = "code_generation",
|
||||
refactor = "refactor",
|
||||
document = "document",
|
||||
test = "test",
|
||||
}
|
||||
return mapping[intent.type] or "ask"
|
||||
end
|
||||
|
||||
--- Check if intent requires code output
|
||||
---@param intent Intent
|
||||
---@return boolean
|
||||
function M.produces_code(intent)
|
||||
local code_intents = {
|
||||
generate = true,
|
||||
refactor = true,
|
||||
document = true, -- Documentation is code (comments)
|
||||
test = true,
|
||||
}
|
||||
return code_intents[intent.type] or false
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,456 +0,0 @@
|
||||
---@mod codetyper.agent.inject Smart code injection with import handling
|
||||
---@brief [[
|
||||
--- Intelligent code injection that properly handles imports, merging them
|
||||
--- into existing import sections instead of blindly appending.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class ImportConfig
|
||||
---@field pattern string Lua pattern to match import statements
|
||||
---@field multi_line boolean Whether imports can span multiple lines
|
||||
---@field sort_key function|nil Function to extract sort key from import
|
||||
---@field group_by function|nil Function to group imports
|
||||
|
||||
---@class ParsedCode
|
||||
---@field imports string[] Import statements
|
||||
---@field body string[] Non-import code lines
|
||||
---@field import_lines table<number, boolean> Map of line numbers that are imports
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
local languages = require("codetyper.params.agents.languages")
|
||||
local import_patterns = languages.import_patterns
|
||||
|
||||
--- Check if a line is an import statement for the given language
|
||||
---@param line string
|
||||
---@param patterns table[] Import patterns for the language
|
||||
---@return boolean is_import
|
||||
---@return boolean is_multi_line
|
||||
local function is_import_line(line, patterns)
|
||||
for _, p in ipairs(patterns) do
|
||||
if line:match(p.pattern) then
|
||||
return true, p.multi_line or false
|
||||
end
|
||||
end
|
||||
return false, false
|
||||
end
|
||||
|
||||
|
||||
--- Check if a line ends a multi-line import
|
||||
---@param line string
|
||||
---@param filetype string
|
||||
---@return boolean
|
||||
local function ends_multiline_import(line, filetype)
|
||||
return utils.ends_multiline_import(line, filetype)
|
||||
end
|
||||
|
||||
--- Parse code into imports and body
|
||||
---@param code string|string[] Code to parse
|
||||
---@param filetype string File type/extension
|
||||
---@return ParsedCode
|
||||
function M.parse_code(code, filetype)
|
||||
local lines
|
||||
if type(code) == "string" then
|
||||
lines = vim.split(code, "\n", { plain = true })
|
||||
else
|
||||
lines = code
|
||||
end
|
||||
|
||||
local patterns = import_patterns[filetype] or import_patterns.javascript
|
||||
|
||||
local result = {
|
||||
imports = {},
|
||||
body = {},
|
||||
import_lines = {},
|
||||
}
|
||||
|
||||
local in_multiline_import = false
|
||||
local current_import_lines = {}
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
if in_multiline_import then
|
||||
-- Continue collecting multi-line import
|
||||
table.insert(current_import_lines, line)
|
||||
|
||||
if ends_multiline_import(line, filetype) then
|
||||
-- Complete the multi-line import
|
||||
table.insert(result.imports, table.concat(current_import_lines, "\n"))
|
||||
for j = i - #current_import_lines + 1, i do
|
||||
result.import_lines[j] = true
|
||||
end
|
||||
current_import_lines = {}
|
||||
in_multiline_import = false
|
||||
end
|
||||
else
|
||||
local is_import, is_multi = is_import_line(line, patterns)
|
||||
|
||||
if is_import then
|
||||
result.import_lines[i] = true
|
||||
|
||||
if is_multi and not ends_multiline_import(line, filetype) then
|
||||
-- Start of multi-line import
|
||||
in_multiline_import = true
|
||||
current_import_lines = { line }
|
||||
else
|
||||
-- Single-line import
|
||||
table.insert(result.imports, line)
|
||||
end
|
||||
else
|
||||
-- Non-import line
|
||||
table.insert(result.body, line)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Handle unclosed multi-line import (shouldn't happen with well-formed code)
|
||||
if #current_import_lines > 0 then
|
||||
table.insert(result.imports, table.concat(current_import_lines, "\n"))
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Find the import section range in a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@param filetype string
|
||||
---@return number|nil start_line First import line (1-indexed)
|
||||
---@return number|nil end_line Last import line (1-indexed)
|
||||
function M.find_import_section(bufnr, filetype)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local patterns = import_patterns[filetype] or import_patterns.javascript
|
||||
|
||||
local first_import = nil
|
||||
local last_import = nil
|
||||
local in_multiline = false
|
||||
local consecutive_non_import = 0
|
||||
local max_gap = 3 -- Allow up to 3 blank/comment lines between imports
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
if in_multiline then
|
||||
last_import = i
|
||||
consecutive_non_import = 0
|
||||
|
||||
if ends_multiline_import(line, filetype) then
|
||||
in_multiline = false
|
||||
end
|
||||
else
|
||||
local is_import, is_multi = is_import_line(line, patterns)
|
||||
|
||||
if is_import then
|
||||
if not first_import then
|
||||
first_import = i
|
||||
end
|
||||
last_import = i
|
||||
consecutive_non_import = 0
|
||||
|
||||
if is_multi and not ends_multiline_import(line, filetype) then
|
||||
in_multiline = true
|
||||
end
|
||||
elseif utils.is_empty_or_comment(line, filetype) then
|
||||
-- Allow gaps in import section
|
||||
if first_import then
|
||||
consecutive_non_import = consecutive_non_import + 1
|
||||
if consecutive_non_import > max_gap then
|
||||
-- Too many non-import lines, import section has ended
|
||||
break
|
||||
end
|
||||
end
|
||||
else
|
||||
-- Non-import, non-empty line
|
||||
if first_import then
|
||||
-- Import section has ended
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return first_import, last_import
|
||||
end
|
||||
|
||||
--- Get existing imports from a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@param filetype string
|
||||
---@return string[] Existing import statements
|
||||
function M.get_existing_imports(bufnr, filetype)
|
||||
local start_line, end_line = M.find_import_section(bufnr, filetype)
|
||||
if not start_line then
|
||||
return {}
|
||||
end
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false)
|
||||
local parsed = M.parse_code(lines, filetype)
|
||||
return parsed.imports
|
||||
end
|
||||
|
||||
--- Normalize an import for comparison (remove whitespace variations)
|
||||
---@param import_str string
|
||||
---@return string
|
||||
local function normalize_import(import_str)
|
||||
-- Remove trailing semicolon for comparison
|
||||
local normalized = import_str:gsub(";%s*$", "")
|
||||
-- Remove all whitespace around braces, commas, colons
|
||||
normalized = normalized:gsub("%s*{%s*", "{")
|
||||
normalized = normalized:gsub("%s*}%s*", "}")
|
||||
normalized = normalized:gsub("%s*,%s*", ",")
|
||||
normalized = normalized:gsub("%s*:%s*", ":")
|
||||
-- Collapse multiple whitespace to single space
|
||||
normalized = normalized:gsub("%s+", " ")
|
||||
-- Trim leading/trailing whitespace
|
||||
normalized = normalized:match("^%s*(.-)%s*$")
|
||||
return normalized
|
||||
end
|
||||
|
||||
--- Check if two imports are duplicates
|
||||
---@param import1 string
|
||||
---@param import2 string
|
||||
---@return boolean
|
||||
local function are_duplicate_imports(import1, import2)
|
||||
return normalize_import(import1) == normalize_import(import2)
|
||||
end
|
||||
|
||||
--- Merge new imports with existing ones, avoiding duplicates
|
||||
---@param existing string[] Existing imports
|
||||
---@param new_imports string[] New imports to merge
|
||||
---@return string[] Merged imports
|
||||
function M.merge_imports(existing, new_imports)
|
||||
local merged = {}
|
||||
local seen = {}
|
||||
|
||||
-- Add existing imports
|
||||
for _, imp in ipairs(existing) do
|
||||
local normalized = normalize_import(imp)
|
||||
if not seen[normalized] then
|
||||
seen[normalized] = true
|
||||
table.insert(merged, imp)
|
||||
end
|
||||
end
|
||||
|
||||
-- Add new imports that aren't duplicates
|
||||
for _, imp in ipairs(new_imports) do
|
||||
local normalized = normalize_import(imp)
|
||||
if not seen[normalized] then
|
||||
seen[normalized] = true
|
||||
table.insert(merged, imp)
|
||||
end
|
||||
end
|
||||
|
||||
return merged
|
||||
end
|
||||
|
||||
--- Sort imports by their source/module
|
||||
---@param imports string[]
|
||||
---@param filetype string
|
||||
---@return string[]
|
||||
function M.sort_imports(imports, filetype)
|
||||
-- Group imports: stdlib/builtin first, then third-party, then local
|
||||
local builtin = {}
|
||||
local third_party = {}
|
||||
local local_imports = {}
|
||||
|
||||
for _, imp in ipairs(imports) do
|
||||
local category = utils.classify_import(imp, filetype)
|
||||
|
||||
if category == "builtin" then
|
||||
table.insert(builtin, imp)
|
||||
elseif category == "local" then
|
||||
table.insert(local_imports, imp)
|
||||
else
|
||||
table.insert(third_party, imp)
|
||||
end
|
||||
end
|
||||
|
||||
-- Sort each group alphabetically
|
||||
table.sort(builtin)
|
||||
table.sort(third_party)
|
||||
table.sort(local_imports)
|
||||
|
||||
-- Combine with proper spacing
|
||||
local result = {}
|
||||
|
||||
for _, imp in ipairs(builtin) do
|
||||
table.insert(result, imp)
|
||||
end
|
||||
if #builtin > 0 and (#third_party > 0 or #local_imports > 0) then
|
||||
table.insert(result, "") -- Blank line between groups
|
||||
end
|
||||
|
||||
for _, imp in ipairs(third_party) do
|
||||
table.insert(result, imp)
|
||||
end
|
||||
if #third_party > 0 and #local_imports > 0 then
|
||||
table.insert(result, "") -- Blank line between groups
|
||||
end
|
||||
|
||||
for _, imp in ipairs(local_imports) do
|
||||
table.insert(result, imp)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
---@class InjectResult
|
||||
---@field success boolean
|
||||
---@field imports_added number Number of new imports added
|
||||
---@field imports_merged boolean Whether imports were merged into existing section
|
||||
---@field body_lines number Number of body lines injected
|
||||
|
||||
--- Smart inject code into a buffer, properly handling imports
|
||||
---@param bufnr number Target buffer
|
||||
---@param code string|string[] Code to inject
|
||||
---@param opts table Options: { strategy: "append"|"replace"|"insert", range: {start_line, end_line}|nil, filetype: string|nil, sort_imports: boolean|nil }
|
||||
---@return InjectResult
|
||||
function M.inject(bufnr, code, opts)
|
||||
opts = opts or {}
|
||||
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return { success = false, imports_added = 0, imports_merged = false, body_lines = 0 }
|
||||
end
|
||||
|
||||
-- Get filetype
|
||||
local filetype = opts.filetype
|
||||
if not filetype then
|
||||
local bufname = vim.api.nvim_buf_get_name(bufnr)
|
||||
filetype = vim.fn.fnamemodify(bufname, ":e")
|
||||
end
|
||||
|
||||
-- Parse the code to separate imports from body
|
||||
local parsed = M.parse_code(code, filetype)
|
||||
|
||||
local result = {
|
||||
success = true,
|
||||
imports_added = 0,
|
||||
imports_merged = false,
|
||||
body_lines = #parsed.body,
|
||||
}
|
||||
|
||||
-- Handle imports first if there are any
|
||||
if #parsed.imports > 0 then
|
||||
local import_start, import_end = M.find_import_section(bufnr, filetype)
|
||||
|
||||
if import_start then
|
||||
-- Merge with existing import section
|
||||
local existing_imports = M.get_existing_imports(bufnr, filetype)
|
||||
local merged = M.merge_imports(existing_imports, parsed.imports)
|
||||
|
||||
-- Count how many new imports were actually added
|
||||
result.imports_added = #merged - #existing_imports
|
||||
result.imports_merged = true
|
||||
|
||||
-- Optionally sort imports
|
||||
if opts.sort_imports ~= false then
|
||||
merged = M.sort_imports(merged, filetype)
|
||||
end
|
||||
|
||||
-- Convert back to lines (handling multi-line imports)
|
||||
local import_lines = {}
|
||||
for _, imp in ipairs(merged) do
|
||||
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
|
||||
table.insert(import_lines, line)
|
||||
end
|
||||
end
|
||||
|
||||
-- Replace the import section
|
||||
vim.api.nvim_buf_set_lines(bufnr, import_start - 1, import_end, false, import_lines)
|
||||
|
||||
-- Adjust line numbers for body injection
|
||||
local lines_diff = #import_lines - (import_end - import_start + 1)
|
||||
if opts.range and opts.range.start_line and opts.range.start_line > import_end then
|
||||
opts.range.start_line = opts.range.start_line + lines_diff
|
||||
if opts.range.end_line then
|
||||
opts.range.end_line = opts.range.end_line + lines_diff
|
||||
end
|
||||
end
|
||||
else
|
||||
-- No existing import section, add imports at the top
|
||||
-- Find the first non-comment, non-empty line
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local insert_at = 0
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
local trimmed = line:match("^%s*(.-)%s*$")
|
||||
-- Skip shebang, docstrings, and initial comments
|
||||
if trimmed ~= "" and not trimmed:match("^#!")
|
||||
and not trimmed:match("^['\"]") and not utils.is_empty_or_comment(line, filetype) then
|
||||
insert_at = i - 1
|
||||
break
|
||||
end
|
||||
insert_at = i
|
||||
end
|
||||
|
||||
-- Add imports with a trailing blank line
|
||||
local import_lines = {}
|
||||
for _, imp in ipairs(parsed.imports) do
|
||||
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
|
||||
table.insert(import_lines, line)
|
||||
end
|
||||
end
|
||||
table.insert(import_lines, "") -- Blank line after imports
|
||||
|
||||
vim.api.nvim_buf_set_lines(bufnr, insert_at, insert_at, false, import_lines)
|
||||
result.imports_added = #parsed.imports
|
||||
result.imports_merged = false
|
||||
|
||||
-- Adjust body injection range
|
||||
if opts.range and opts.range.start_line then
|
||||
opts.range.start_line = opts.range.start_line + #import_lines
|
||||
if opts.range.end_line then
|
||||
opts.range.end_line = opts.range.end_line + #import_lines
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Handle body (non-import) code
|
||||
if #parsed.body > 0 then
|
||||
-- Filter out empty leading/trailing lines from body
|
||||
local body_lines = parsed.body
|
||||
while #body_lines > 0 and body_lines[1]:match("^%s*$") do
|
||||
table.remove(body_lines, 1)
|
||||
end
|
||||
while #body_lines > 0 and body_lines[#body_lines]:match("^%s*$") do
|
||||
table.remove(body_lines)
|
||||
end
|
||||
|
||||
if #body_lines > 0 then
|
||||
local line_count = vim.api.nvim_buf_line_count(bufnr)
|
||||
local strategy = opts.strategy or "append"
|
||||
|
||||
if strategy == "replace" and opts.range then
|
||||
local start_line = math.max(1, opts.range.start_line)
|
||||
local end_line = math.min(line_count, opts.range.end_line)
|
||||
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, body_lines)
|
||||
elseif strategy == "insert" and opts.range then
|
||||
local insert_line = math.max(0, math.min(line_count, opts.range.start_line - 1))
|
||||
vim.api.nvim_buf_set_lines(bufnr, insert_line, insert_line, false, body_lines)
|
||||
else
|
||||
-- Default: append
|
||||
local last_line = vim.api.nvim_buf_get_lines(bufnr, line_count - 1, line_count, false)[1] or ""
|
||||
if last_line:match("%S") then
|
||||
-- Add blank line for spacing
|
||||
table.insert(body_lines, 1, "")
|
||||
end
|
||||
vim.api.nvim_buf_set_lines(bufnr, line_count, line_count, false, body_lines)
|
||||
end
|
||||
|
||||
result.body_lines = #body_lines
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Check if code contains imports
|
||||
---@param code string|string[]
|
||||
---@param filetype string
|
||||
---@return boolean
|
||||
function M.has_imports(code, filetype)
|
||||
local parsed = M.parse_code(code, filetype)
|
||||
return #parsed.imports > 0
|
||||
end
|
||||
|
||||
return M
|
||||
Reference in New Issue
Block a user