Adding more features

This commit is contained in:
2026-01-15 20:58:56 -05:00
parent 84c8bcf92c
commit f5df1a9ac0
40 changed files with 9145 additions and 458 deletions

View File

@@ -0,0 +1,854 @@
---@mod codetyper.agent.agentic Agentic loop with proper tool calling
---@brief [[
--- Full agentic system that handles multi-file changes via tool calling.
--- Inspired by avante.nvim and opencode patterns.
---@brief ]]
local M = {}
---@class AgenticMessage
---@field role "system"|"user"|"assistant"|"tool"
---@field content string|table
---@field tool_calls? table[] For assistant messages with tool calls
---@field tool_call_id? string For tool result messages
---@field name? string Tool name for tool results
---@class AgenticToolCall
---@field id string Unique tool call ID
---@field type "function"
---@field function {name: string, arguments: string|table}
---@class AgenticOpts
---@field task string The task to accomplish
---@field files? string[] Initial files to include as context
---@field agent? string Agent name to use (default: "coder")
---@field model? string Model override
---@field max_iterations? number Max tool call rounds (default: 20)
---@field on_message? fun(msg: AgenticMessage) Called for each message
---@field on_tool_start? fun(name: string, args: table) Called before tool execution
---@field on_tool_end? fun(name: string, result: any, error: string|nil) Called after tool execution
---@field on_file_change? fun(path: string, action: string) Called when file is modified
---@field on_complete? fun(result: string|nil, error: string|nil) Called when done
---@field on_status? fun(status: string) Status updates
--- Generate unique tool call ID
local function generate_tool_call_id()
return "call_" .. string.format("%x", os.time()) .. "_" .. string.format("%x", math.random(0, 0xFFFF))
end
--- Load agent definition
---@param name string Agent name
---@return table|nil agent definition
local function load_agent(name)
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
local agent_file = agents_dir .. "/" .. name .. ".md"
-- Check if custom agent exists
if vim.fn.filereadable(agent_file) == 1 then
local content = table.concat(vim.fn.readfile(agent_file), "\n")
-- Parse frontmatter and content
local frontmatter = {}
local body = content
local fm_match = content:match("^%-%-%-\n(.-)%-%-%-\n(.*)$")
if fm_match then
-- Parse YAML-like frontmatter
for line in content:match("^%-%-%-\n(.-)%-%-%-"):gmatch("[^\n]+") do
local key, value = line:match("^(%w+):%s*(.+)$")
if key and value then
frontmatter[key] = value
end
end
body = content:match("%-%-%-\n.-%-%-%-%s*\n(.*)$") or content
end
return {
name = name,
description = frontmatter.description or "Custom agent: " .. name,
system_prompt = body,
tools = frontmatter.tools and vim.split(frontmatter.tools, ",") or nil,
model = frontmatter.model,
}
end
-- Built-in agents
local builtin_agents = {
coder = {
name = "coder",
description = "Full-featured coding agent with file modification capabilities",
system_prompt = [[You are an expert software engineer. You have access to tools to read, write, and modify files.
## Your Capabilities
- Read files to understand the codebase
- Search for patterns with grep and glob
- Create new files with write tool
- Edit existing files with precise replacements
- Execute shell commands for builds and tests
## Guidelines
1. Always read relevant files before making changes
2. Make minimal, focused changes
3. Follow existing code style and patterns
4. Create tests when adding new functionality
5. Verify changes work by running tests or builds
## Important Rules
- NEVER guess file contents - always read first
- Make precise edits using exact string matching
- Explain your reasoning before making changes
- If unsure, ask for clarification]],
tools = { "view", "edit", "write", "grep", "glob", "bash" },
},
planner = {
name = "planner",
description = "Planning agent - read-only, helps design implementations",
system_prompt = [[You are a software architect. Analyze codebases and create implementation plans.
You can read files and search the codebase, but cannot modify files.
Your role is to:
1. Understand the existing architecture
2. Identify relevant files and patterns
3. Create step-by-step implementation plans
4. Suggest which files to modify and how
Be thorough in your analysis before making recommendations.]],
tools = { "view", "grep", "glob" },
},
explorer = {
name = "explorer",
description = "Exploration agent - quickly find information in codebase",
system_prompt = [[You are a codebase exploration assistant. Find information quickly and report back.
Your goal is to efficiently search and summarize findings.
Use glob to find files, grep to search content, and view to read specific files.
Be concise and focused in your responses.]],
tools = { "view", "grep", "glob" },
},
}
return builtin_agents[name]
end
--- Load rules from .coder/rules/
---@return string Combined rules content
local function load_rules()
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
local rules = {}
if vim.fn.isdirectory(rules_dir) == 1 then
local files = vim.fn.glob(rules_dir .. "/*.md", false, true)
for _, file in ipairs(files) do
local content = table.concat(vim.fn.readfile(file), "\n")
local filename = vim.fn.fnamemodify(file, ":t:r")
table.insert(rules, string.format("## Rule: %s\n%s", filename, content))
end
end
if #rules > 0 then
return "\n\n# Project Rules\n" .. table.concat(rules, "\n\n")
end
return ""
end
--- Build messages array for API request
---@param history AgenticMessage[]
---@param provider string "openai"|"claude"
---@return table[] Formatted messages
local function build_messages(history, provider)
local messages = {}
for _, msg in ipairs(history) do
if msg.role == "system" then
if provider == "claude" then
-- Claude uses system parameter, not message
-- Skip system messages in array
else
table.insert(messages, {
role = "system",
content = msg.content,
})
end
elseif msg.role == "user" then
table.insert(messages, {
role = "user",
content = msg.content,
})
elseif msg.role == "assistant" then
local message = {
role = "assistant",
content = msg.content,
}
if msg.tool_calls then
message.tool_calls = msg.tool_calls
if provider == "claude" then
-- Claude format: content is array of blocks
message.content = {}
if msg.content and msg.content ~= "" then
table.insert(message.content, {
type = "text",
text = msg.content,
})
end
for _, tc in ipairs(msg.tool_calls) do
table.insert(message.content, {
type = "tool_use",
id = tc.id,
name = tc["function"].name,
input = type(tc["function"].arguments) == "string"
and vim.json.decode(tc["function"].arguments)
or tc["function"].arguments,
})
end
end
end
table.insert(messages, message)
elseif msg.role == "tool" then
if provider == "claude" then
table.insert(messages, {
role = "user",
content = {
{
type = "tool_result",
tool_use_id = msg.tool_call_id,
content = msg.content,
},
},
})
else
table.insert(messages, {
role = "tool",
tool_call_id = msg.tool_call_id,
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
})
end
end
end
return messages
end
--- Build tools array for API request
---@param tool_names string[] Tool names to include
---@param provider string "openai"|"claude"
---@return table[] Formatted tools
local function build_tools(tool_names, provider)
local tools_mod = require("codetyper.agent.tools")
local tools = {}
for _, name in ipairs(tool_names) do
local tool = tools_mod.get(name)
if tool then
local properties = {}
local required = {}
for _, param in ipairs(tool.params or {}) do
properties[param.name] = {
type = param.type == "integer" and "number" or param.type,
description = param.description,
}
if not param.optional then
table.insert(required, param.name)
end
end
local description = type(tool.description) == "function" and tool.description() or tool.description
if provider == "claude" then
table.insert(tools, {
name = tool.name,
description = description,
input_schema = {
type = "object",
properties = properties,
required = required,
},
})
else
table.insert(tools, {
type = "function",
["function"] = {
name = tool.name,
description = description,
parameters = {
type = "object",
properties = properties,
required = required,
},
},
})
end
end
end
return tools
end
--- Execute a tool call
---@param tool_call AgenticToolCall
---@param opts AgenticOpts
---@return string result
---@return string|nil error
local function execute_tool(tool_call, opts)
local tools_mod = require("codetyper.agent.tools")
local name = tool_call["function"].name
local args = tool_call["function"].arguments
-- Parse arguments if string
if type(args) == "string" then
local ok, parsed = pcall(vim.json.decode, args)
if ok then
args = parsed
else
return "", "Failed to parse tool arguments: " .. args
end
end
-- Notify tool start
if opts.on_tool_start then
opts.on_tool_start(name, args)
end
if opts.on_status then
opts.on_status("Executing: " .. name)
end
-- Execute the tool
local tool = tools_mod.get(name)
if not tool then
local err = "Unknown tool: " .. name
if opts.on_tool_end then
opts.on_tool_end(name, nil, err)
end
return "", err
end
local result, err = tool.func(args, {
on_log = function(msg)
if opts.on_status then
opts.on_status(msg)
end
end,
})
-- Notify tool end
if opts.on_tool_end then
opts.on_tool_end(name, result, err)
end
-- Track file changes
if opts.on_file_change and (name == "write" or name == "edit") and not err then
opts.on_file_change(args.path, name == "write" and "created" or "modified")
end
if err then
return "", err
end
return type(result) == "string" and result or vim.json.encode(result), nil
end
--- Parse tool calls from LLM response (unified Claude-like format)
---@param response table Raw API response in unified format
---@param provider string Provider name (unused, kept for signature compatibility)
---@return AgenticToolCall[]
local function parse_tool_calls(response, provider)
local tool_calls = {}
-- Unified format: content array with tool_use blocks
local content = response.content or {}
for _, block in ipairs(content) do
if block.type == "tool_use" then
-- OpenAI expects arguments as JSON string, not table
local args = block.input
if type(args) == "table" then
args = vim.json.encode(args)
end
table.insert(tool_calls, {
id = block.id or generate_tool_call_id(),
type = "function",
["function"] = {
name = block.name,
arguments = args,
},
})
end
end
return tool_calls
end
--- Extract text content from response (unified Claude-like format)
---@param response table Raw API response in unified format
---@param provider string Provider name (unused, kept for signature compatibility)
---@return string
local function extract_content(response, provider)
local parts = {}
for _, block in ipairs(response.content or {}) do
if block.type == "text" then
table.insert(parts, block.text)
end
end
return table.concat(parts, "\n")
end
--- Check if response indicates completion (unified Claude-like format)
---@param response table Raw API response in unified format
---@param provider string Provider name (unused, kept for signature compatibility)
---@return boolean
local function is_complete(response, provider)
return response.stop_reason == "end_turn"
end
--- Make API request to LLM with native tool calling support
---@param messages table[] Formatted messages
---@param tools table[] Formatted tools
---@param system_prompt string System prompt
---@param provider string "openai"|"claude"|"copilot"
---@param model string Model name
---@param callback fun(response: table|nil, error: string|nil)
local function call_llm(messages, tools, system_prompt, provider, model, callback)
local context = {
language = "lua",
file_content = "",
prompt_type = "agent",
project_root = vim.fn.getcwd(),
cwd = vim.fn.getcwd(),
}
-- Use native tool calling APIs
if provider == "copilot" then
local client = require("codetyper.llm.copilot")
-- Copilot's generate_with_tools expects messages in a specific format
-- Convert to the format it expects
local converted_messages = {}
for _, msg in ipairs(messages) do
if msg.role ~= "system" then
table.insert(converted_messages, msg)
end
end
client.generate_with_tools(converted_messages, context, tools, function(response, err)
if err then
callback(nil, err)
return
end
-- Response is already in Claude-like format from the provider
-- Convert to our internal format
local result = {
content = {},
stop_reason = "end_turn",
}
if response and response.content then
for _, block in ipairs(response.content) do
if block.type == "text" then
table.insert(result.content, { type = "text", text = block.text })
elseif block.type == "tool_use" then
table.insert(result.content, {
type = "tool_use",
id = block.id or generate_tool_call_id(),
name = block.name,
input = block.input,
})
result.stop_reason = "tool_use"
end
end
end
callback(result, nil)
end)
elseif provider == "openai" then
local client = require("codetyper.llm.openai")
-- OpenAI's generate_with_tools
local converted_messages = {}
for _, msg in ipairs(messages) do
if msg.role ~= "system" then
table.insert(converted_messages, msg)
end
end
client.generate_with_tools(converted_messages, context, tools, function(response, err)
if err then
callback(nil, err)
return
end
-- Response is already in Claude-like format from the provider
local result = {
content = {},
stop_reason = "end_turn",
}
if response and response.content then
for _, block in ipairs(response.content) do
if block.type == "text" then
table.insert(result.content, { type = "text", text = block.text })
elseif block.type == "tool_use" then
table.insert(result.content, {
type = "tool_use",
id = block.id or generate_tool_call_id(),
name = block.name,
input = block.input,
})
result.stop_reason = "tool_use"
end
end
end
callback(result, nil)
end)
elseif provider == "ollama" then
local client = require("codetyper.llm.ollama")
-- Ollama's generate_with_tools (text-based tool calling)
local converted_messages = {}
for _, msg in ipairs(messages) do
if msg.role ~= "system" then
table.insert(converted_messages, msg)
end
end
client.generate_with_tools(converted_messages, context, tools, function(response, err)
if err then
callback(nil, err)
return
end
-- Response is already in Claude-like format from the provider
callback(response, nil)
end)
else
-- Fallback for other providers (ollama, etc.) - use text-based parsing
local client = require("codetyper.llm." .. provider)
-- Build prompt from messages
local prompt_parts = {}
for _, msg in ipairs(messages) do
if msg.role == "user" then
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
table.insert(prompt_parts, "User: " .. content)
elseif msg.role == "assistant" then
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
table.insert(prompt_parts, "Assistant: " .. content)
end
end
-- Add tool descriptions to prompt for text-based providers
local tool_desc = "\n\n## Available Tools\n"
tool_desc = tool_desc .. "Call tools by outputting JSON in this format:\n"
tool_desc = tool_desc .. '```json\n{"tool": "tool_name", "arguments": {...}}\n```\n\n'
for _, tool in ipairs(tools) do
local name = tool.name or (tool["function"] and tool["function"].name)
local desc = tool.description or (tool["function"] and tool["function"].description)
if name then
tool_desc = tool_desc .. string.format("- **%s**: %s\n", name, desc or "")
end
end
context.file_content = system_prompt .. tool_desc
client.generate(table.concat(prompt_parts, "\n\n"), context, function(response, err)
if err then
callback(nil, err)
return
end
-- Parse response for tool calls (text-based fallback)
local result = {
content = {},
stop_reason = "end_turn",
}
-- Extract text content
local text_content = response
-- Try to extract JSON tool calls from response
local json_match = response:match("```json%s*(%b{})%s*```")
if json_match then
local ok, parsed = pcall(vim.json.decode, json_match)
if ok and parsed.tool then
table.insert(result.content, {
type = "tool_use",
id = generate_tool_call_id(),
name = parsed.tool,
input = parsed.arguments or {},
})
text_content = response:gsub("```json.-```", ""):gsub("^%s+", ""):gsub("%s+$", "")
result.stop_reason = "tool_use"
end
end
if text_content and text_content ~= "" then
table.insert(result.content, 1, { type = "text", text = text_content })
end
callback(result, nil)
end)
end
end
--- Run the agentic loop
---@param opts AgenticOpts
function M.run(opts)
-- Load agent
local agent = load_agent(opts.agent or "coder")
if not agent then
if opts.on_complete then
opts.on_complete(nil, "Unknown agent: " .. (opts.agent or "coder"))
end
return
end
-- Load rules
local rules = load_rules()
-- Build system prompt
local system_prompt = agent.system_prompt .. rules
-- Initialize message history
---@type AgenticMessage[]
local history = {
{ role = "system", content = system_prompt },
}
-- Add initial file context if provided
if opts.files and #opts.files > 0 then
local file_context = "# Initial Files\n"
for _, file_path in ipairs(opts.files) do
local content = table.concat(vim.fn.readfile(file_path) or {}, "\n")
file_context = file_context .. string.format("\n## %s\n```\n%s\n```\n", file_path, content)
end
table.insert(history, { role = "user", content = file_context })
table.insert(history, { role = "assistant", content = "I've reviewed the provided files. What would you like me to do?" })
end
-- Add the task
table.insert(history, { role = "user", content = opts.task })
-- Determine provider
local config = require("codetyper").get_config()
local provider = config.llm.provider or "copilot"
-- Note: Ollama has its own handler in call_llm, don't change it
-- Get tools for this agent
local tool_names = agent.tools or { "view", "edit", "write", "grep", "glob", "bash" }
-- Ensure tools are loaded
local tools_mod = require("codetyper.agent.tools")
tools_mod.setup()
-- Build tools for API
local tools = build_tools(tool_names, provider)
-- Iteration tracking
local iteration = 0
local max_iterations = opts.max_iterations or 20
--- Process one iteration
local function process_iteration()
iteration = iteration + 1
if iteration > max_iterations then
if opts.on_complete then
opts.on_complete(nil, "Max iterations reached")
end
return
end
if opts.on_status then
opts.on_status(string.format("Thinking... (iteration %d)", iteration))
end
-- Build messages for API
local messages = build_messages(history, provider)
-- Call LLM
call_llm(messages, tools, system_prompt, provider, opts.model, function(response, err)
if err then
if opts.on_complete then
opts.on_complete(nil, err)
end
return
end
-- Extract content and tool calls
local content = extract_content(response, provider)
local tool_calls = parse_tool_calls(response, provider)
-- Add assistant message to history
local assistant_msg = {
role = "assistant",
content = content,
tool_calls = #tool_calls > 0 and tool_calls or nil,
}
table.insert(history, assistant_msg)
if opts.on_message then
opts.on_message(assistant_msg)
end
-- Process tool calls if any
if #tool_calls > 0 then
for _, tc in ipairs(tool_calls) do
local result, tool_err = execute_tool(tc, opts)
-- Add tool result to history
local tool_msg = {
role = "tool",
tool_call_id = tc.id,
name = tc["function"].name,
content = tool_err or result,
}
table.insert(history, tool_msg)
if opts.on_message then
opts.on_message(tool_msg)
end
end
-- Continue the loop
vim.schedule(process_iteration)
else
-- No tool calls - check if complete
if is_complete(response, provider) or content ~= "" then
if opts.on_complete then
opts.on_complete(content, nil)
end
else
-- Continue if not explicitly complete
vim.schedule(process_iteration)
end
end
end)
end
-- Start the loop
process_iteration()
end
--- Create default agent files in .coder/agents/
function M.init_agents_dir()
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
vim.fn.mkdir(agents_dir, "p")
-- Create example agent
local example_agent = [[---
description: Example custom agent
tools: view,grep,glob,edit,write
model:
---
# Custom Agent
You are a custom coding agent. Describe your specialized behavior here.
## Your Role
- Define what this agent specializes in
- List specific capabilities
## Guidelines
- Add agent-specific rules
- Define coding standards to follow
## Examples
Provide examples of how to handle common tasks.
]]
local example_path = agents_dir .. "/example.md"
if vim.fn.filereadable(example_path) ~= 1 then
vim.fn.writefile(vim.split(example_agent, "\n"), example_path)
end
return agents_dir
end
--- Create default rules in .coder/rules/
function M.init_rules_dir()
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
vim.fn.mkdir(rules_dir, "p")
-- Create example rule
local example_rule = [[# Code Style
Follow these coding standards:
## General
- Use consistent indentation (tabs or spaces based on project)
- Keep lines under 100 characters
- Add comments for complex logic
## Naming Conventions
- Use descriptive variable names
- Functions should be verbs (e.g., getUserData, calculateTotal)
- Constants in UPPER_SNAKE_CASE
## Testing
- Write tests for new functionality
- Aim for >80% code coverage
- Test edge cases
## Documentation
- Document public APIs
- Include usage examples
- Keep docs up to date with code
]]
local example_path = rules_dir .. "/code-style.md"
if vim.fn.filereadable(example_path) ~= 1 then
vim.fn.writefile(vim.split(example_rule, "\n"), example_path)
end
return rules_dir
end
--- Initialize both agents and rules directories
function M.init()
M.init_agents_dir()
M.init_rules_dir()
end
--- List available agents
---@return table[] List of {name, description, builtin}
function M.list_agents()
local agents = {}
-- Built-in agents
local builtins = { "coder", "planner", "explorer" }
for _, name in ipairs(builtins) do
local agent = load_agent(name)
if agent then
table.insert(agents, {
name = agent.name,
description = agent.description,
builtin = true,
})
end
end
-- Custom agents from .coder/agents/
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
if vim.fn.isdirectory(agents_dir) == 1 then
local files = vim.fn.glob(agents_dir .. "/*.md", false, true)
for _, file in ipairs(files) do
local name = vim.fn.fnamemodify(file, ":t:r")
if not vim.tbl_contains(builtins, name) then
local agent = load_agent(name)
if agent then
table.insert(agents, {
name = agent.name,
description = agent.description,
builtin = false,
})
end
end
end
end
return agents
end
return M

View File

@@ -8,11 +8,11 @@ local M = {}
--- Heuristic weights (must sum to 1.0)
M.weights = {
length = 0.15, -- Response length relative to prompt
length = 0.15, -- Response length relative to prompt
uncertainty = 0.30, -- Uncertainty phrases
syntax = 0.25, -- Syntax completeness
repetition = 0.15, -- Duplicate lines
truncation = 0.15, -- Incomplete ending
syntax = 0.25, -- Syntax completeness
repetition = 0.15, -- Duplicate lines
truncation = 0.15, -- Incomplete ending
}
--- Uncertainty phrases that indicate low confidence
@@ -255,14 +255,15 @@ function M.score(response, prompt, context)
_ = context -- Reserved for future use
if not response or #response == 0 then
return 0, {
length = 0,
uncertainty = 0,
syntax = 0,
repetition = 0,
truncation = 0,
weighted_total = 0,
}
return 0,
{
length = 0,
uncertainty = 0,
syntax = 0,
repetition = 0,
truncation = 0,
weighted_total = 0,
}
end
local scores = {

View File

@@ -111,7 +111,11 @@ function M.agent_loop(context, callbacks)
logs.thinking("Calling LLM with " .. #state.conversation .. " messages...")
-- Generate with tools enabled
client.generate_with_tools(state.conversation, context, tools.definitions, function(response, err)
-- Ensure tools are loaded and get definitions
tools.setup()
local tool_defs = tools.to_openai_format()
client.generate_with_tools(state.conversation, context, tool_defs, function(response, err)
if err then
state.is_running = false
callbacks.on_error(err)

View File

@@ -0,0 +1,614 @@
---@mod codetyper.agent.inject Smart code injection with import handling
---@brief [[
--- Intelligent code injection that properly handles imports, merging them
--- into existing import sections instead of blindly appending.
---@brief ]]
local M = {}
---@class ImportConfig
---@field pattern string Lua pattern to match import statements
---@field multi_line boolean Whether imports can span multiple lines
---@field sort_key function|nil Function to extract sort key from import
---@field group_by function|nil Function to group imports
---@class ParsedCode
---@field imports string[] Import statements
---@field body string[] Non-import code lines
---@field import_lines table<number, boolean> Map of line numbers that are imports
--- Language-specific import patterns
local import_patterns = {
-- JavaScript/TypeScript
javascript = {
{ pattern = "^%s*import%s+.+%s+from%s+['\"]", multi_line = true },
{ pattern = "^%s*import%s+['\"]", multi_line = false },
{ pattern = "^%s*import%s*{", multi_line = true },
{ pattern = "^%s*import%s*%*", multi_line = true },
{ pattern = "^%s*export%s+{.+}%s+from%s+['\"]", multi_line = true },
{ pattern = "^%s*const%s+%w+%s*=%s*require%(['\"]", multi_line = false },
{ pattern = "^%s*let%s+%w+%s*=%s*require%(['\"]", multi_line = false },
{ pattern = "^%s*var%s+%w+%s*=%s*require%(['\"]", multi_line = false },
},
-- Python
python = {
{ pattern = "^%s*import%s+%w", multi_line = false },
{ pattern = "^%s*from%s+[%w%.]+%s+import%s+", multi_line = true },
},
-- Lua
lua = {
{ pattern = "^%s*local%s+%w+%s*=%s*require%s*%(?['\"]", multi_line = false },
{ pattern = "^%s*require%s*%(?['\"]", multi_line = false },
},
-- Go
go = {
{ pattern = "^%s*import%s+%(?", multi_line = true },
},
-- Rust
rust = {
{ pattern = "^%s*use%s+", multi_line = true },
{ pattern = "^%s*extern%s+crate%s+", multi_line = false },
},
-- C/C++
c = {
{ pattern = "^%s*#include%s*[<\"]", multi_line = false },
},
-- Java/Kotlin
java = {
{ pattern = "^%s*import%s+", multi_line = false },
},
-- Ruby
ruby = {
{ pattern = "^%s*require%s+['\"]", multi_line = false },
{ pattern = "^%s*require_relative%s+['\"]", multi_line = false },
},
-- PHP
php = {
{ pattern = "^%s*use%s+", multi_line = false },
{ pattern = "^%s*require%s+['\"]", multi_line = false },
{ pattern = "^%s*require_once%s+['\"]", multi_line = false },
{ pattern = "^%s*include%s+['\"]", multi_line = false },
{ pattern = "^%s*include_once%s+['\"]", multi_line = false },
},
}
-- Alias common extensions to language configs
import_patterns.ts = import_patterns.javascript
import_patterns.tsx = import_patterns.javascript
import_patterns.jsx = import_patterns.javascript
import_patterns.mjs = import_patterns.javascript
import_patterns.cjs = import_patterns.javascript
import_patterns.py = import_patterns.python
import_patterns.cpp = import_patterns.c
import_patterns.hpp = import_patterns.c
import_patterns.h = import_patterns.c
import_patterns.kt = import_patterns.java
import_patterns.rs = import_patterns.rust
import_patterns.rb = import_patterns.ruby
--- Check if a line is an import statement for the given language
---@param line string
---@param patterns table[] Import patterns for the language
---@return boolean is_import
---@return boolean is_multi_line
local function is_import_line(line, patterns)
for _, p in ipairs(patterns) do
if line:match(p.pattern) then
return true, p.multi_line or false
end
end
return false, false
end
--- Check if a line is empty or a comment
---@param line string
---@param filetype string
---@return boolean
local function is_empty_or_comment(line, filetype)
local trimmed = line:match("^%s*(.-)%s*$")
if trimmed == "" then
return true
end
-- Language-specific comment patterns
local comment_patterns = {
lua = { "^%-%-" },
python = { "^#" },
javascript = { "^//", "^/%*", "^%*" },
typescript = { "^//", "^/%*", "^%*" },
go = { "^//", "^/%*", "^%*" },
rust = { "^//", "^/%*", "^%*" },
c = { "^//", "^/%*", "^%*", "^#" },
java = { "^//", "^/%*", "^%*" },
ruby = { "^#" },
php = { "^//", "^/%*", "^%*", "^#" },
}
local patterns = comment_patterns[filetype] or comment_patterns.javascript
for _, pattern in ipairs(patterns) do
if trimmed:match(pattern) then
return true
end
end
return false
end
--- Check if a line ends a multi-line import
---@param line string
---@param filetype string
---@return boolean
local function ends_multiline_import(line, filetype)
-- Check for closing patterns
if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then
-- ES6 imports end with 'from "..." ;' or just ';' or a line with just '}'
if line:match("from%s+['\"][^'\"]+['\"]%s*;?%s*$") then
return true
end
if line:match("}%s*from%s+['\"]") then
return true
end
if line:match("^%s*}%s*;?%s*$") then
return true
end
if line:match(";%s*$") then
return true
end
elseif filetype == "python" or filetype == "py" then
-- Python single-line import: doesn't end with \, (, or ,
-- Examples: "from typing import List, Dict" or "import os"
if not line:match("\\%s*$") and not line:match("%(%s*$") and not line:match(",%s*$") then
return true
end
-- Python multiline imports end with closing paren
if line:match("%)%s*$") then
return true
end
elseif filetype == "go" then
-- Go multi-line imports end with ')'
if line:match("%)%s*$") then
return true
end
elseif filetype == "rust" or filetype == "rs" then
-- Rust use statements end with ';'
if line:match(";%s*$") then
return true
end
end
return false
end
--- Parse code into imports and body
---@param code string|string[] Code to parse
---@param filetype string File type/extension
---@return ParsedCode
function M.parse_code(code, filetype)
local lines
if type(code) == "string" then
lines = vim.split(code, "\n", { plain = true })
else
lines = code
end
local patterns = import_patterns[filetype] or import_patterns.javascript
local result = {
imports = {},
body = {},
import_lines = {},
}
local in_multiline_import = false
local current_import_lines = {}
for i, line in ipairs(lines) do
if in_multiline_import then
-- Continue collecting multi-line import
table.insert(current_import_lines, line)
if ends_multiline_import(line, filetype) then
-- Complete the multi-line import
table.insert(result.imports, table.concat(current_import_lines, "\n"))
for j = i - #current_import_lines + 1, i do
result.import_lines[j] = true
end
current_import_lines = {}
in_multiline_import = false
end
else
local is_import, is_multi = is_import_line(line, patterns)
if is_import then
result.import_lines[i] = true
if is_multi and not ends_multiline_import(line, filetype) then
-- Start of multi-line import
in_multiline_import = true
current_import_lines = { line }
else
-- Single-line import
table.insert(result.imports, line)
end
else
-- Non-import line
table.insert(result.body, line)
end
end
end
-- Handle unclosed multi-line import (shouldn't happen with well-formed code)
if #current_import_lines > 0 then
table.insert(result.imports, table.concat(current_import_lines, "\n"))
end
return result
end
--- Find the import section range in a buffer
---@param bufnr number Buffer number
---@param filetype string
---@return number|nil start_line First import line (1-indexed)
---@return number|nil end_line Last import line (1-indexed)
function M.find_import_section(bufnr, filetype)
if not vim.api.nvim_buf_is_valid(bufnr) then
return nil, nil
end
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local patterns = import_patterns[filetype] or import_patterns.javascript
local first_import = nil
local last_import = nil
local in_multiline = false
local consecutive_non_import = 0
local max_gap = 3 -- Allow up to 3 blank/comment lines between imports
for i, line in ipairs(lines) do
if in_multiline then
last_import = i
consecutive_non_import = 0
if ends_multiline_import(line, filetype) then
in_multiline = false
end
else
local is_import, is_multi = is_import_line(line, patterns)
if is_import then
if not first_import then
first_import = i
end
last_import = i
consecutive_non_import = 0
if is_multi and not ends_multiline_import(line, filetype) then
in_multiline = true
end
elseif is_empty_or_comment(line, filetype) then
-- Allow gaps in import section
if first_import then
consecutive_non_import = consecutive_non_import + 1
if consecutive_non_import > max_gap then
-- Too many non-import lines, import section has ended
break
end
end
else
-- Non-import, non-empty line
if first_import then
-- Import section has ended
break
end
end
end
end
return first_import, last_import
end
--- Get existing imports from a buffer
---@param bufnr number Buffer number
---@param filetype string
---@return string[] Existing import statements
function M.get_existing_imports(bufnr, filetype)
local start_line, end_line = M.find_import_section(bufnr, filetype)
if not start_line then
return {}
end
local lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false)
local parsed = M.parse_code(lines, filetype)
return parsed.imports
end
--- Normalize an import for comparison (remove whitespace variations)
---@param import_str string
---@return string
local function normalize_import(import_str)
-- Remove trailing semicolon for comparison
local normalized = import_str:gsub(";%s*$", "")
-- Remove all whitespace around braces, commas, colons
normalized = normalized:gsub("%s*{%s*", "{")
normalized = normalized:gsub("%s*}%s*", "}")
normalized = normalized:gsub("%s*,%s*", ",")
normalized = normalized:gsub("%s*:%s*", ":")
-- Collapse multiple whitespace to single space
normalized = normalized:gsub("%s+", " ")
-- Trim leading/trailing whitespace
normalized = normalized:match("^%s*(.-)%s*$")
return normalized
end
--- Check if two imports are duplicates
---@param import1 string
---@param import2 string
---@return boolean
local function are_duplicate_imports(import1, import2)
return normalize_import(import1) == normalize_import(import2)
end
--- Merge new imports with existing ones, avoiding duplicates
---@param existing string[] Existing imports
---@param new_imports string[] New imports to merge
---@return string[] Merged imports
function M.merge_imports(existing, new_imports)
local merged = {}
local seen = {}
-- Add existing imports
for _, imp in ipairs(existing) do
local normalized = normalize_import(imp)
if not seen[normalized] then
seen[normalized] = true
table.insert(merged, imp)
end
end
-- Add new imports that aren't duplicates
for _, imp in ipairs(new_imports) do
local normalized = normalize_import(imp)
if not seen[normalized] then
seen[normalized] = true
table.insert(merged, imp)
end
end
return merged
end
--- Sort imports by their source/module
---@param imports string[]
---@param filetype string
---@return string[]
function M.sort_imports(imports, filetype)
-- Group imports: stdlib/builtin first, then third-party, then local
local builtin = {}
local third_party = {}
local local_imports = {}
for _, imp in ipairs(imports) do
-- Detect import type based on patterns
local is_local = false
local is_builtin = false
if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then
-- Local: starts with . or ..
is_local = imp:match("from%s+['\"]%.") or imp:match("require%(['\"]%.")
-- Node builtin modules
is_builtin = imp:match("from%s+['\"]node:") or imp:match("from%s+['\"]fs['\"]")
or imp:match("from%s+['\"]path['\"]") or imp:match("from%s+['\"]http['\"]")
elseif filetype == "python" or filetype == "py" then
-- Local: relative imports
is_local = imp:match("^from%s+%.") or imp:match("^import%s+%.")
-- Python stdlib (simplified check)
is_builtin = imp:match("^import%s+os") or imp:match("^import%s+sys")
or imp:match("^from%s+os%s+") or imp:match("^from%s+sys%s+")
or imp:match("^import%s+re") or imp:match("^import%s+json")
elseif filetype == "lua" then
-- Local: relative requires
is_local = imp:match("require%(['\"]%.") or imp:match("require%s+['\"]%.")
elseif filetype == "go" then
-- Local: project imports (contain /)
is_local = imp:match("['\"][^'\"]+/[^'\"]+['\"]") and not imp:match("github%.com")
end
if is_builtin then
table.insert(builtin, imp)
elseif is_local then
table.insert(local_imports, imp)
else
table.insert(third_party, imp)
end
end
-- Sort each group alphabetically
table.sort(builtin)
table.sort(third_party)
table.sort(local_imports)
-- Combine with proper spacing
local result = {}
for _, imp in ipairs(builtin) do
table.insert(result, imp)
end
if #builtin > 0 and (#third_party > 0 or #local_imports > 0) then
table.insert(result, "") -- Blank line between groups
end
for _, imp in ipairs(third_party) do
table.insert(result, imp)
end
if #third_party > 0 and #local_imports > 0 then
table.insert(result, "") -- Blank line between groups
end
for _, imp in ipairs(local_imports) do
table.insert(result, imp)
end
return result
end
---@class InjectResult
---@field success boolean
---@field imports_added number Number of new imports added
---@field imports_merged boolean Whether imports were merged into existing section
---@field body_lines number Number of body lines injected
--- Smart inject code into a buffer, properly handling imports
---@param bufnr number Target buffer
---@param code string|string[] Code to inject
---@param opts table Options: { strategy: "append"|"replace"|"insert", range: {start_line, end_line}|nil, filetype: string|nil, sort_imports: boolean|nil }
---@return InjectResult
function M.inject(bufnr, code, opts)
opts = opts or {}
if not vim.api.nvim_buf_is_valid(bufnr) then
return { success = false, imports_added = 0, imports_merged = false, body_lines = 0 }
end
-- Get filetype
local filetype = opts.filetype
if not filetype then
local bufname = vim.api.nvim_buf_get_name(bufnr)
filetype = vim.fn.fnamemodify(bufname, ":e")
end
-- Parse the code to separate imports from body
local parsed = M.parse_code(code, filetype)
local result = {
success = true,
imports_added = 0,
imports_merged = false,
body_lines = #parsed.body,
}
-- Handle imports first if there are any
if #parsed.imports > 0 then
local import_start, import_end = M.find_import_section(bufnr, filetype)
if import_start then
-- Merge with existing import section
local existing_imports = M.get_existing_imports(bufnr, filetype)
local merged = M.merge_imports(existing_imports, parsed.imports)
-- Count how many new imports were actually added
result.imports_added = #merged - #existing_imports
result.imports_merged = true
-- Optionally sort imports
if opts.sort_imports ~= false then
merged = M.sort_imports(merged, filetype)
end
-- Convert back to lines (handling multi-line imports)
local import_lines = {}
for _, imp in ipairs(merged) do
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
table.insert(import_lines, line)
end
end
-- Replace the import section
vim.api.nvim_buf_set_lines(bufnr, import_start - 1, import_end, false, import_lines)
-- Adjust line numbers for body injection
local lines_diff = #import_lines - (import_end - import_start + 1)
if opts.range and opts.range.start_line and opts.range.start_line > import_end then
opts.range.start_line = opts.range.start_line + lines_diff
if opts.range.end_line then
opts.range.end_line = opts.range.end_line + lines_diff
end
end
else
-- No existing import section, add imports at the top
-- Find the first non-comment, non-empty line
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local insert_at = 0
for i, line in ipairs(lines) do
local trimmed = line:match("^%s*(.-)%s*$")
-- Skip shebang, docstrings, and initial comments
if trimmed ~= "" and not trimmed:match("^#!")
and not trimmed:match("^['\"]") and not is_empty_or_comment(line, filetype) then
insert_at = i - 1
break
end
insert_at = i
end
-- Add imports with a trailing blank line
local import_lines = {}
for _, imp in ipairs(parsed.imports) do
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
table.insert(import_lines, line)
end
end
table.insert(import_lines, "") -- Blank line after imports
vim.api.nvim_buf_set_lines(bufnr, insert_at, insert_at, false, import_lines)
result.imports_added = #parsed.imports
result.imports_merged = false
-- Adjust body injection range
if opts.range and opts.range.start_line then
opts.range.start_line = opts.range.start_line + #import_lines
if opts.range.end_line then
opts.range.end_line = opts.range.end_line + #import_lines
end
end
end
end
-- Handle body (non-import) code
if #parsed.body > 0 then
-- Filter out empty leading/trailing lines from body
local body_lines = parsed.body
while #body_lines > 0 and body_lines[1]:match("^%s*$") do
table.remove(body_lines, 1)
end
while #body_lines > 0 and body_lines[#body_lines]:match("^%s*$") do
table.remove(body_lines)
end
if #body_lines > 0 then
local line_count = vim.api.nvim_buf_line_count(bufnr)
local strategy = opts.strategy or "append"
if strategy == "replace" and opts.range then
local start_line = math.max(1, opts.range.start_line)
local end_line = math.min(line_count, opts.range.end_line)
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, body_lines)
elseif strategy == "insert" and opts.range then
local insert_line = math.max(0, math.min(line_count, opts.range.start_line - 1))
vim.api.nvim_buf_set_lines(bufnr, insert_line, insert_line, false, body_lines)
else
-- Default: append
local last_line = vim.api.nvim_buf_get_lines(bufnr, line_count - 1, line_count, false)[1] or ""
if last_line:match("%S") then
-- Add blank line for spacing
table.insert(body_lines, 1, "")
end
vim.api.nvim_buf_set_lines(bufnr, line_count, line_count, false, body_lines)
end
result.body_lines = #body_lines
end
end
return result
end
--- Check if code contains imports
---@param code string|string[]
---@param filetype string
---@return boolean
function M.has_imports(code, filetype)
local parsed = M.parse_code(code, filetype)
return #parsed.imports > 0
end
return M

View File

@@ -0,0 +1,398 @@
---@mod codetyper.agent.loop Agent loop with tool orchestration
---@brief [[
--- Main agent loop that handles multi-turn conversations with tool use.
--- Inspired by avante.nvim's agent_loop pattern.
---@brief ]]
local M = {}
---@class AgentMessage
---@field role "system"|"user"|"assistant"|"tool"
---@field content string|table
---@field tool_call_id? string For tool responses
---@field tool_calls? table[] For assistant tool calls
---@field name? string Tool name for tool responses
---@class AgentLoopOpts
---@field system_prompt string System prompt
---@field user_input string Initial user message
---@field tools? CoderTool[] Available tools (default: all registered)
---@field max_iterations? number Max tool call iterations (default: 10)
---@field provider? string LLM provider to use
---@field on_start? fun() Called when loop starts
---@field on_chunk? fun(chunk: string) Called for each response chunk
---@field on_tool_call? fun(name: string, input: table) Called before tool execution
---@field on_tool_result? fun(name: string, result: any, error: string|nil) Called after tool execution
---@field on_message? fun(message: AgentMessage) Called for each message added
---@field on_complete? fun(result: string|nil, error: string|nil) Called when loop completes
---@field session_ctx? table Session context shared across tools
--- Format tool definitions for OpenAI-compatible API
---@param tools CoderTool[]
---@return table[]
local function format_tools_for_api(tools)
local formatted = {}
for _, tool in ipairs(tools) do
local properties = {}
local required = {}
for _, param in ipairs(tool.params or {}) do
properties[param.name] = {
type = param.type == "integer" and "number" or param.type,
description = param.description,
}
if not param.optional then
table.insert(required, param.name)
end
end
table.insert(formatted, {
type = "function",
["function"] = {
name = tool.name,
description = type(tool.description) == "function" and tool.description() or tool.description,
parameters = {
type = "object",
properties = properties,
required = required,
},
},
})
end
return formatted
end
--- Parse tool calls from LLM response
---@param response table LLM response
---@return table[] tool_calls
local function parse_tool_calls(response)
local tool_calls = {}
-- Handle different response formats
if response.tool_calls then
-- OpenAI format
for _, call in ipairs(response.tool_calls) do
local args = call["function"].arguments
if type(args) == "string" then
local ok, parsed = pcall(vim.json.decode, args)
if ok then
args = parsed
end
end
table.insert(tool_calls, {
id = call.id,
name = call["function"].name,
input = args,
})
end
elseif response.content and type(response.content) == "table" then
-- Claude format (content blocks)
for _, block in ipairs(response.content) do
if block.type == "tool_use" then
table.insert(tool_calls, {
id = block.id,
name = block.name,
input = block.input,
})
end
end
end
return tool_calls
end
--- Build messages for LLM request
---@param history AgentMessage[]
---@return table[]
local function build_messages(history)
local messages = {}
for _, msg in ipairs(history) do
if msg.role == "system" then
table.insert(messages, {
role = "system",
content = msg.content,
})
elseif msg.role == "user" then
table.insert(messages, {
role = "user",
content = msg.content,
})
elseif msg.role == "assistant" then
local message = {
role = "assistant",
content = msg.content,
}
if msg.tool_calls then
message.tool_calls = msg.tool_calls
end
table.insert(messages, message)
elseif msg.role == "tool" then
table.insert(messages, {
role = "tool",
tool_call_id = msg.tool_call_id,
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
})
end
end
return messages
end
--- Execute the agent loop
---@param opts AgentLoopOpts
function M.run(opts)
local tools_mod = require("codetyper.agent.tools")
local llm = require("codetyper.llm")
-- Get tools
local tools = opts.tools or tools_mod.list()
local tool_map = {}
for _, tool in ipairs(tools) do
tool_map[tool.name] = tool
end
-- Initialize conversation history
---@type AgentMessage[]
local history = {
{ role = "system", content = opts.system_prompt },
{ role = "user", content = opts.user_input },
}
local session_ctx = opts.session_ctx or {}
local max_iterations = opts.max_iterations or 10
local iteration = 0
-- Callback wrappers
local function on_message(msg)
if opts.on_message then
opts.on_message(msg)
end
end
-- Notify of initial messages
for _, msg in ipairs(history) do
on_message(msg)
end
-- Start notification
if opts.on_start then
opts.on_start()
end
--- Process one iteration of the loop
local function process_iteration()
iteration = iteration + 1
if iteration > max_iterations then
if opts.on_complete then
opts.on_complete(nil, "Max iterations reached")
end
return
end
-- Build request
local messages = build_messages(history)
local formatted_tools = format_tools_for_api(tools)
-- Build context for LLM
local context = {
file_content = "",
language = "lua",
extension = "lua",
prompt_type = "agent",
tools = formatted_tools,
}
-- Get LLM response
local client = llm.get_client()
if not client then
if opts.on_complete then
opts.on_complete(nil, "No LLM client available")
end
return
end
-- Build prompt from messages
local prompt_parts = {}
for _, msg in ipairs(messages) do
if msg.role ~= "system" then
table.insert(prompt_parts, string.format("[%s]: %s", msg.role, msg.content or ""))
end
end
local prompt = table.concat(prompt_parts, "\n\n")
client.generate(prompt, context, function(response, error)
if error then
if opts.on_complete then
opts.on_complete(nil, error)
end
return
end
-- Chunk callback
if opts.on_chunk then
opts.on_chunk(response)
end
-- Parse response for tool calls
-- For now, we'll use a simple heuristic to detect tool calls in the response
-- In a full implementation, the LLM would return structured tool calls
local tool_calls = {}
-- Try to parse JSON tool calls from response
local json_match = response:match("```json%s*(%b{})%s*```")
if json_match then
local ok, parsed = pcall(vim.json.decode, json_match)
if ok and parsed.tool_calls then
tool_calls = parsed.tool_calls
end
end
-- Add assistant message
local assistant_msg = {
role = "assistant",
content = response,
tool_calls = #tool_calls > 0 and tool_calls or nil,
}
table.insert(history, assistant_msg)
on_message(assistant_msg)
-- Process tool calls
if #tool_calls > 0 then
local pending = #tool_calls
local results = {}
for i, call in ipairs(tool_calls) do
local tool = tool_map[call.name]
if not tool then
results[i] = { error = "Unknown tool: " .. call.name }
pending = pending - 1
else
-- Notify of tool call
if opts.on_tool_call then
opts.on_tool_call(call.name, call.input)
end
-- Execute tool
local tool_opts = {
on_log = function(msg)
pcall(function()
local logs = require("codetyper.agent.logs")
logs.add({ type = "tool", message = msg })
end)
end,
on_complete = function(result, err)
results[i] = { result = result, error = err }
pending = pending - 1
-- Notify of tool result
if opts.on_tool_result then
opts.on_tool_result(call.name, result, err)
end
-- Add tool response to history
local tool_msg = {
role = "tool",
tool_call_id = call.id or tostring(i),
name = call.name,
content = err or result,
}
table.insert(history, tool_msg)
on_message(tool_msg)
-- Continue loop when all tools complete
if pending == 0 then
vim.schedule(process_iteration)
end
end,
session_ctx = session_ctx,
}
-- Validate and execute
local valid, validation_err = true, nil
if tool.validate_input then
valid, validation_err = tool:validate_input(call.input)
end
if not valid then
tool_opts.on_complete(nil, validation_err)
else
local result, err = tool.func(call.input, tool_opts)
-- If sync result, call on_complete
if result ~= nil or err ~= nil then
tool_opts.on_complete(result, err)
end
end
end
end
else
-- No tool calls - loop complete
if opts.on_complete then
opts.on_complete(response, nil)
end
end
end)
end
-- Start the loop
process_iteration()
end
--- Create an agent with default settings
---@param task string Task description
---@param opts? AgentLoopOpts Additional options
function M.create(task, opts)
opts = opts or {}
local system_prompt = opts.system_prompt or [[You are a helpful coding assistant with access to tools.
Available tools:
- view: Read file contents
- grep: Search for patterns in files
- glob: Find files by pattern
- edit: Make targeted edits to files
- write: Create or overwrite files
- bash: Execute shell commands
When you need to perform a task:
1. Use tools to gather information
2. Plan your approach
3. Execute changes using appropriate tools
4. Verify the results
Always explain your reasoning before using tools.
When you're done, provide a clear summary of what was accomplished.]]
M.run(vim.tbl_extend("force", opts, {
system_prompt = system_prompt,
user_input = task,
}))
end
--- Simple dispatch agent for sub-tasks
---@param prompt string Task for the sub-agent
---@param on_complete fun(result: string|nil, error: string|nil) Completion callback
---@param opts? table Additional options
function M.dispatch(prompt, on_complete, opts)
opts = opts or {}
-- Sub-agents get limited tools by default
local tools_mod = require("codetyper.agent.tools")
local safe_tools = tools_mod.list(function(tool)
return tool.name == "view" or tool.name == "grep" or tool.name == "glob"
end)
M.run({
system_prompt = [[You are a research assistant. Your task is to find information and report back.
You have access to: view (read files), grep (search content), glob (find files).
Be thorough and report your findings clearly.]],
user_input = prompt,
tools = opts.tools or safe_tools,
max_iterations = opts.max_iterations or 5,
on_complete = on_complete,
session_ctx = opts.session_ctx,
})
end
return M

View File

@@ -2,10 +2,16 @@
---@brief [[
--- Manages code patches with buffer snapshots for staleness detection.
--- Patches are queued for safe injection when completion popup is not visible.
--- Uses smart injection for intelligent import merging.
---@brief ]]
local M = {}
--- Lazy load inject module to avoid circular requires
local function get_inject_module()
return require("codetyper.agent.inject")
end
---@class BufferSnapshot
---@field bufnr number Buffer number
---@field changedtick number vim.b.changedtick at snapshot time
@@ -15,7 +21,8 @@ local M = {}
---@class PatchCandidate
---@field id string Unique patch ID
---@field event_id string Related PromptEvent ID
---@field target_bufnr number Target buffer for injection
---@field source_bufnr number Source buffer where prompt tags are (coder file)
---@field target_bufnr number Target buffer for injection (real file)
---@field target_path string Target file path
---@field original_snapshot BufferSnapshot Snapshot at event creation
---@field generated_code string Code to inject
@@ -171,7 +178,10 @@ end
---@param strategy string|nil Injection strategy (overrides intent-based)
---@return PatchCandidate
function M.create_from_event(event, generated_code, confidence, strategy)
-- Get target buffer
-- Source buffer is where the prompt tags are (could be coder file)
local source_bufnr = event.bufnr
-- Get target buffer (where code should be injected - the real file)
local target_bufnr = vim.fn.bufnr(event.target_path)
if target_bufnr == -1 then
-- Try to find by filename
@@ -220,7 +230,8 @@ function M.create_from_event(event, generated_code, confidence, strategy)
return {
id = M.generate_id(),
event_id = event.id,
target_bufnr = target_bufnr,
source_bufnr = source_bufnr, -- Where prompt tags are (coder file)
target_bufnr = target_bufnr, -- Where code goes (real file)
target_path = event.target_path,
original_snapshot = snapshot,
generated_code = generated_code,
@@ -453,39 +464,56 @@ function M.apply(patch)
-- Prepare code lines
local code_lines = vim.split(patch.generated_code, "\n", { plain = true })
-- FIRST: Remove the prompt tags from the buffer before applying code
-- This prevents the infinite loop where tags stay and get re-detected
local tags_removed = remove_prompt_tags(target_bufnr)
-- FIRST: Remove the prompt tags from the SOURCE buffer (coder file), not target
-- The tags are in the coder file where the user wrote the prompt
-- Code goes to target file, tags get removed from source file
local source_bufnr = patch.source_bufnr
local tags_removed = 0
pcall(function()
if tags_removed > 0 then
local logs = require("codetyper.agent.logs")
logs.add({
type = "info",
message = string.format("Removed %d prompt tag(s) from buffer", tags_removed),
})
end
end)
if source_bufnr and vim.api.nvim_buf_is_valid(source_bufnr) then
tags_removed = remove_prompt_tags(source_bufnr)
-- Recalculate line count after tag removal
local line_count = vim.api.nvim_buf_line_count(target_bufnr)
pcall(function()
if tags_removed > 0 then
local logs = require("codetyper.agent.logs")
local source_name = vim.api.nvim_buf_get_name(source_bufnr)
logs.add({
type = "info",
message = string.format("Removed %d prompt tag(s) from %s",
tags_removed,
vim.fn.fnamemodify(source_name, ":t")),
})
end
end)
end
-- Apply based on strategy
-- Get filetype for smart injection
local filetype = vim.fn.fnamemodify(patch.target_path or "", ":e")
-- Use smart injection module for intelligent import handling
local inject = get_inject_module()
local inject_result = nil
-- Apply based on strategy using smart injection
local ok, err = pcall(function()
-- Prepare injection options
local inject_opts = {
strategy = patch.injection_strategy,
filetype = filetype,
sort_imports = true,
}
if patch.injection_strategy == "replace" and patch.injection_range then
-- Replace the scope range with the new code
-- The injection_range points to the function/method we're completing
local start_line = patch.injection_range.start_line
local end_line = patch.injection_range.end_line
-- Adjust for tag removal - find the new range by searching for the scope
-- After removing tags, line numbers may have shifted
-- Use the scope information to find the correct range
if patch.scope and patch.scope.type then
-- Try to find the scope using treesitter if available
local found_range = nil
pcall(function()
local ts_utils = require("nvim-treesitter.ts_utils")
local parsers = require("nvim-treesitter.parsers")
local parser = parsers.get_parser(target_bufnr)
if parser then
@@ -528,34 +556,38 @@ function M.apply(patch)
end
-- Clamp to valid range
local line_count = vim.api.nvim_buf_line_count(target_bufnr)
start_line = math.max(1, start_line)
end_line = math.min(line_count, end_line)
-- Replace the range (0-indexed for nvim_buf_set_lines)
vim.api.nvim_buf_set_lines(target_bufnr, start_line - 1, end_line, false, code_lines)
inject_opts.range = { start_line = start_line, end_line = end_line }
elseif patch.injection_strategy == "insert" and patch.injection_range then
inject_opts.range = { start_line = patch.injection_range.start_line }
end
pcall(function()
local logs = require("codetyper.agent.logs")
-- Use smart injection - handles imports automatically
inject_result = inject.inject(target_bufnr, patch.generated_code, inject_opts)
-- Log injection details
pcall(function()
local logs = require("codetyper.agent.logs")
if inject_result.imports_added > 0 then
logs.add({
type = "info",
message = string.format("Replacing lines %d-%d with %d lines of code", start_line, end_line, #code_lines),
message = string.format(
"%s %d import(s), injected %d body line(s)",
inject_result.imports_merged and "Merged" or "Added",
inject_result.imports_added,
inject_result.body_lines
),
})
else
logs.add({
type = "info",
message = string.format("Injected %d line(s) of code", inject_result.body_lines),
})
end)
elseif patch.injection_strategy == "insert" and patch.injection_range then
-- Insert at the specified location
local insert_line = patch.injection_range.start_line
insert_line = math.max(1, math.min(line_count + 1, insert_line))
vim.api.nvim_buf_set_lines(target_bufnr, insert_line - 1, insert_line - 1, false, code_lines)
else
-- Default: append to end
-- Check if last line is empty, if not add a blank line for spacing
local last_line = vim.api.nvim_buf_get_lines(target_bufnr, line_count - 1, line_count, false)[1] or ""
if last_line:match("%S") then
-- Last line has content, add blank line for spacing
table.insert(code_lines, 1, "")
end
vim.api.nvim_buf_set_lines(target_bufnr, line_count, line_count, false, code_lines)
end
end)
end)
if not ok then
@@ -577,6 +609,41 @@ function M.apply(patch)
})
end)
-- Learn from successful code generation - this builds neural pathways
-- The more code is successfully applied, the better the brain becomes
pcall(function()
local brain = require("codetyper.brain")
if brain.is_initialized() then
-- Learn the successful pattern
local intent_type = patch.intent and patch.intent.type or "unknown"
local scope_type = patch.scope and patch.scope.type or "file"
local scope_name = patch.scope and patch.scope.name or ""
-- Create a meaningful summary for this learning
local summary = string.format(
"Generated %s: %s %s in %s",
intent_type,
scope_type,
scope_name ~= "" and scope_name or "",
vim.fn.fnamemodify(patch.target_path or "", ":t")
)
brain.learn({
type = "code_completion",
file = patch.target_path,
timestamp = os.time(),
data = {
intent = intent_type,
code = patch.generated_code:sub(1, 500), -- Store first 500 chars
language = vim.fn.fnamemodify(patch.target_path or "", ":e"),
function_name = scope_name,
prompt = patch.prompt_content,
confidence = patch.confidence or 0.5,
},
})
end
end)
return true, nil
end

View File

@@ -0,0 +1,128 @@
---@mod codetyper.agent.tools.base Base tool definition
---@brief [[
--- Base metatable for all LLM tools.
--- Tools extend this base to provide structured AI capabilities.
---@brief ]]
---@class CoderToolParam
---@field name string Parameter name
---@field description string Parameter description
---@field type string Parameter type ("string", "number", "boolean", "table")
---@field optional? boolean Whether the parameter is optional
---@field default? any Default value for optional parameters
---@class CoderToolReturn
---@field name string Return value name
---@field description string Return value description
---@field type string Return type
---@field optional? boolean Whether the return is optional
---@class CoderToolOpts
---@field on_log? fun(message: string) Log callback
---@field on_complete? fun(result: any, error: string|nil) Completion callback
---@field session_ctx? table Session context
---@field streaming? boolean Whether response is still streaming
---@field confirm? fun(message: string, callback: fun(ok: boolean)) Confirmation callback
---@class CoderTool
---@field name string Tool identifier
---@field description string|fun(): string Tool description
---@field params CoderToolParam[] Input parameters
---@field returns CoderToolReturn[] Return values
---@field requires_confirmation? boolean Whether tool needs user confirmation
---@field func fun(input: table, opts: CoderToolOpts): any, string|nil Tool implementation
local M = {}
M.__index = M
--- Call the tool function
---@param opts CoderToolOpts Options for the tool call
---@return any result
---@return string|nil error
function M:__call(opts, on_log, on_complete)
return self.func(opts, on_log, on_complete)
end
--- Get the tool description
---@return string
function M:get_description()
if type(self.description) == "function" then
return self.description()
end
return self.description
end
--- Validate input against parameter schema
---@param input table Input to validate
---@return boolean valid
---@return string|nil error
function M:validate_input(input)
if not self.params then
return true
end
for _, param in ipairs(self.params) do
local value = input[param.name]
-- Check required parameters
if not param.optional and value == nil then
return false, string.format("Missing required parameter: %s", param.name)
end
-- Type checking
if value ~= nil then
local actual_type = type(value)
local expected_type = param.type
-- Handle special types
if expected_type == "integer" and actual_type == "number" then
if math.floor(value) ~= value then
return false, string.format("Parameter %s must be an integer", param.name)
end
elseif expected_type ~= actual_type and expected_type ~= "any" then
return false, string.format("Parameter %s must be %s, got %s", param.name, expected_type, actual_type)
end
end
end
return true
end
--- Generate JSON schema for the tool (for LLM function calling)
---@return table schema
function M:to_schema()
local properties = {}
local required = {}
for _, param in ipairs(self.params or {}) do
local prop = {
type = param.type == "integer" and "number" or param.type,
description = param.description,
}
if param.default ~= nil then
prop.default = param.default
end
properties[param.name] = prop
if not param.optional then
table.insert(required, param.name)
end
end
return {
type = "function",
function_def = {
name = self.name,
description = self:get_description(),
parameters = {
type = "object",
properties = properties,
required = required,
},
},
}
end
return M

View File

@@ -0,0 +1,198 @@
---@mod codetyper.agent.tools.bash Shell command execution tool
---@brief [[
--- Tool for executing shell commands with safety checks.
---@brief ]]
local Base = require("codetyper.agent.tools.base")
---@class CoderTool
local M = setmetatable({}, Base)
M.name = "bash"
M.description = [[Executes a bash command in a shell.
IMPORTANT RULES:
- Do NOT use bash to read files (use 'view' tool instead)
- Do NOT use bash to modify files (use 'write' or 'edit' tools instead)
- Do NOT use interactive commands (vim, nano, less, etc.)
- Commands timeout after 2 minutes by default
Allowed uses:
- Running builds (make, npm run build, cargo build)
- Running tests (npm test, pytest, cargo test)
- Git operations (git status, git diff, git commit)
- Package management (npm install, pip install)
- System info commands (ls, pwd, which)]]
M.params = {
{
name = "command",
description = "The shell command to execute",
type = "string",
},
{
name = "cwd",
description = "Working directory for the command (optional)",
type = "string",
optional = true,
},
{
name = "timeout",
description = "Timeout in milliseconds (default: 120000)",
type = "integer",
optional = true,
},
}
M.returns = {
{
name = "stdout",
description = "Command output",
type = "string",
},
{
name = "error",
description = "Error message if command failed",
type = "string",
optional = true,
},
}
M.requires_confirmation = true
--- Banned commands for safety
local BANNED_COMMANDS = {
"rm -rf /",
"rm -rf /*",
"dd if=/dev/zero",
"mkfs",
":(){ :|:& };:",
"> /dev/sda",
}
--- Banned patterns
local BANNED_PATTERNS = {
"curl.*|.*sh",
"wget.*|.*sh",
"rm%s+%-rf%s+/",
}
--- Check if command is safe
---@param command string
---@return boolean safe
---@return string|nil reason
local function is_safe_command(command)
-- Check exact matches
for _, banned in ipairs(BANNED_COMMANDS) do
if command == banned then
return false, "Command is banned for safety"
end
end
-- Check patterns
for _, pattern in ipairs(BANNED_PATTERNS) do
if command:match(pattern) then
return false, "Command matches banned pattern"
end
end
return true
end
---@param input {command: string, cwd?: string, timeout?: integer}
---@param opts CoderToolOpts
---@return string|nil result
---@return string|nil error
function M.func(input, opts)
if not input.command then
return nil, "command is required"
end
-- Safety check
local safe, reason = is_safe_command(input.command)
if not safe then
return nil, reason
end
-- Confirmation required
if M.requires_confirmation and opts.confirm then
local confirmed = false
local confirm_error = nil
opts.confirm("Execute command: " .. input.command, function(ok)
if not ok then
confirm_error = "User declined command execution"
end
confirmed = ok
end)
-- Wait for confirmation (in async context, this would be handled differently)
if confirm_error then
return nil, confirm_error
end
end
-- Log the operation
if opts.on_log then
opts.on_log("Executing: " .. input.command)
end
-- Prepare command
local cwd = input.cwd or vim.fn.getcwd()
local timeout = input.timeout or 120000
-- Execute command
local output = ""
local exit_code = 0
local job_opts = {
command = "bash",
args = { "-c", input.command },
cwd = cwd,
on_stdout = function(_, data)
if data then
output = output .. table.concat(data, "\n")
end
end,
on_stderr = function(_, data)
if data then
output = output .. table.concat(data, "\n")
end
end,
on_exit = function(_, code)
exit_code = code
end,
}
-- Run synchronously with timeout
local Job = require("plenary.job")
local job = Job:new(job_opts)
job:sync(timeout)
exit_code = job.code or 0
output = table.concat(job:result() or {}, "\n")
-- Also get stderr
local stderr = table.concat(job:stderr_result() or {}, "\n")
if stderr and stderr ~= "" then
output = output .. "\n" .. stderr
end
-- Check result
if exit_code ~= 0 then
local error_msg = string.format("Command failed with exit code %d: %s", exit_code, output)
if opts.on_complete then
opts.on_complete(nil, error_msg)
end
return nil, error_msg
end
if opts.on_complete then
opts.on_complete(output, nil)
end
return output, nil
end
return M

View File

@@ -0,0 +1,429 @@
---@mod codetyper.agent.tools.edit File editing tool with fallback matching
---@brief [[
--- Tool for making targeted edits to files using search/replace.
--- Implements multiple fallback strategies for robust matching.
--- Inspired by opencode's 9-strategy approach.
---@brief ]]
local Base = require("codetyper.agent.tools.base")
---@class CoderTool
local M = setmetatable({}, Base)
M.name = "edit"
M.description = [[Makes a targeted edit to a file by replacing text.
The old_string should match the content you want to replace. The tool uses multiple
matching strategies with fallbacks:
1. Exact match
2. Whitespace-normalized match
3. Indentation-flexible match
4. Line-trimmed match
5. Fuzzy anchor-based match
For creating new files, use old_string="" and provide the full content in new_string.
For large changes, consider using 'write' tool instead.]]
M.params = {
{
name = "path",
description = "Path to the file to edit",
type = "string",
},
{
name = "old_string",
description = "Text to find and replace (empty string to create new file or append)",
type = "string",
},
{
name = "new_string",
description = "Text to replace with",
type = "string",
},
}
M.returns = {
{
name = "success",
description = "Whether the edit was applied",
type = "boolean",
},
{
name = "error",
description = "Error message if edit failed",
type = "string",
optional = true,
},
}
M.requires_confirmation = false
--- Normalize line endings to LF
---@param str string
---@return string
local function normalize_line_endings(str)
return str:gsub("\r\n", "\n"):gsub("\r", "\n")
end
--- Strategy 1: Exact match
---@param content string File content
---@param old_str string String to find
---@return number|nil start_pos
---@return number|nil end_pos
local function exact_match(content, old_str)
local pos = content:find(old_str, 1, true)
if pos then
return pos, pos + #old_str - 1
end
return nil, nil
end
--- Strategy 2: Whitespace-normalized match
--- Collapses all whitespace to single spaces
---@param content string
---@param old_str string
---@return number|nil start_pos
---@return number|nil end_pos
local function whitespace_normalized_match(content, old_str)
local function normalize_ws(s)
return s:gsub("%s+", " "):gsub("^%s+", ""):gsub("%s+$", "")
end
local norm_old = normalize_ws(old_str)
local lines = vim.split(content, "\n")
-- Try to find matching block
for i = 1, #lines do
local block = {}
local block_start = nil
for j = i, #lines do
table.insert(block, lines[j])
local block_text = table.concat(block, "\n")
local norm_block = normalize_ws(block_text)
if norm_block == norm_old then
-- Found match
local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
local start_pos = #before + (i > 1 and 2 or 1)
local end_pos = start_pos + #block_text - 1
return start_pos, end_pos
end
-- If block is already longer than target, stop
if #norm_block > #norm_old then
break
end
end
end
return nil, nil
end
--- Strategy 3: Indentation-flexible match
--- Ignores leading whitespace differences
---@param content string
---@param old_str string
---@return number|nil start_pos
---@return number|nil end_pos
local function indentation_flexible_match(content, old_str)
local function strip_indent(s)
local lines = vim.split(s, "\n")
local result = {}
for _, line in ipairs(lines) do
table.insert(result, line:gsub("^%s+", ""))
end
return table.concat(result, "\n")
end
local stripped_old = strip_indent(old_str)
local lines = vim.split(content, "\n")
local old_lines = vim.split(old_str, "\n")
local num_old_lines = #old_lines
for i = 1, #lines - num_old_lines + 1 do
local block = vim.list_slice(lines, i, i + num_old_lines - 1)
local block_text = table.concat(block, "\n")
if strip_indent(block_text) == stripped_old then
local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
local start_pos = #before + (i > 1 and 2 or 1)
local end_pos = start_pos + #block_text - 1
return start_pos, end_pos
end
end
return nil, nil
end
--- Strategy 4: Line-trimmed match
--- Trims each line before comparing
---@param content string
---@param old_str string
---@return number|nil start_pos
---@return number|nil end_pos
local function line_trimmed_match(content, old_str)
local function trim_lines(s)
local lines = vim.split(s, "\n")
local result = {}
for _, line in ipairs(lines) do
table.insert(result, line:match("^%s*(.-)%s*$"))
end
return table.concat(result, "\n")
end
local trimmed_old = trim_lines(old_str)
local lines = vim.split(content, "\n")
local old_lines = vim.split(old_str, "\n")
local num_old_lines = #old_lines
for i = 1, #lines - num_old_lines + 1 do
local block = vim.list_slice(lines, i, i + num_old_lines - 1)
local block_text = table.concat(block, "\n")
if trim_lines(block_text) == trimmed_old then
local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
local start_pos = #before + (i > 1 and 2 or 1)
local end_pos = start_pos + #block_text - 1
return start_pos, end_pos
end
end
return nil, nil
end
--- Calculate Levenshtein distance between two strings
---@param s1 string
---@param s2 string
---@return number
local function levenshtein(s1, s2)
local len1, len2 = #s1, #s2
local matrix = {}
for i = 0, len1 do
matrix[i] = { [0] = i }
end
for j = 0, len2 do
matrix[0][j] = j
end
for i = 1, len1 do
for j = 1, len2 do
local cost = s1:sub(i, i) == s2:sub(j, j) and 0 or 1
matrix[i][j] = math.min(
matrix[i - 1][j] + 1,
matrix[i][j - 1] + 1,
matrix[i - 1][j - 1] + cost
)
end
end
return matrix[len1][len2]
end
--- Strategy 5: Fuzzy anchor-based match
--- Uses first and last lines as anchors, allows fuzzy matching in between
---@param content string
---@param old_str string
---@param threshold? number Similarity threshold (0-1), default 0.8
---@return number|nil start_pos
---@return number|nil end_pos
local function fuzzy_anchor_match(content, old_str, threshold)
threshold = threshold or 0.8
local old_lines = vim.split(old_str, "\n")
if #old_lines < 2 then
return nil, nil
end
local first_line = old_lines[1]:match("^%s*(.-)%s*$")
local last_line = old_lines[#old_lines]:match("^%s*(.-)%s*$")
local content_lines = vim.split(content, "\n")
-- Find potential start positions
local candidates = {}
for i, line in ipairs(content_lines) do
local trimmed = line:match("^%s*(.-)%s*$")
if trimmed == first_line or (
#first_line > 0 and
1 - (levenshtein(trimmed, first_line) / math.max(#trimmed, #first_line)) >= threshold
) then
table.insert(candidates, i)
end
end
-- For each candidate, look for matching end
for _, start_idx in ipairs(candidates) do
local expected_end = start_idx + #old_lines - 1
if expected_end <= #content_lines then
local end_line = content_lines[expected_end]:match("^%s*(.-)%s*$")
if end_line == last_line or (
#last_line > 0 and
1 - (levenshtein(end_line, last_line) / math.max(#end_line, #last_line)) >= threshold
) then
-- Calculate positions
local before = table.concat(vim.list_slice(content_lines, 1, start_idx - 1), "\n")
local block = table.concat(vim.list_slice(content_lines, start_idx, expected_end), "\n")
local start_pos = #before + (start_idx > 1 and 2 or 1)
local end_pos = start_pos + #block - 1
return start_pos, end_pos
end
end
end
return nil, nil
end
--- Try all matching strategies in order
---@param content string File content
---@param old_str string String to find
---@return number|nil start_pos
---@return number|nil end_pos
---@return string strategy_used
local function find_match(content, old_str)
-- Strategy 1: Exact match
local start_pos, end_pos = exact_match(content, old_str)
if start_pos then
return start_pos, end_pos, "exact"
end
-- Strategy 2: Whitespace-normalized
start_pos, end_pos = whitespace_normalized_match(content, old_str)
if start_pos then
return start_pos, end_pos, "whitespace_normalized"
end
-- Strategy 3: Indentation-flexible
start_pos, end_pos = indentation_flexible_match(content, old_str)
if start_pos then
return start_pos, end_pos, "indentation_flexible"
end
-- Strategy 4: Line-trimmed
start_pos, end_pos = line_trimmed_match(content, old_str)
if start_pos then
return start_pos, end_pos, "line_trimmed"
end
-- Strategy 5: Fuzzy anchor
start_pos, end_pos = fuzzy_anchor_match(content, old_str)
if start_pos then
return start_pos, end_pos, "fuzzy_anchor"
end
return nil, nil, "none"
end
---@param input {path: string, old_string: string, new_string: string}
---@param opts CoderToolOpts
---@return boolean|nil result
---@return string|nil error
function M.func(input, opts)
if not input.path then
return nil, "path is required"
end
if input.old_string == nil then
return nil, "old_string is required"
end
if input.new_string == nil then
return nil, "new_string is required"
end
-- Log the operation
if opts.on_log then
opts.on_log("Editing file: " .. input.path)
end
-- Resolve path
local path = input.path
if not vim.startswith(path, "/") then
path = vim.fn.getcwd() .. "/" .. path
end
-- Normalize inputs
local old_str = normalize_line_endings(input.old_string)
local new_str = normalize_line_endings(input.new_string)
-- Handle new file creation (empty old_string)
if old_str == "" then
-- Create parent directories
local dir = vim.fn.fnamemodify(path, ":h")
if vim.fn.isdirectory(dir) == 0 then
vim.fn.mkdir(dir, "p")
end
-- Write new file
local lines = vim.split(new_str, "\n", { plain = true })
local ok = pcall(vim.fn.writefile, lines, path)
if not ok then
return nil, "Failed to create file: " .. input.path
end
-- Reload buffer if open
local bufnr = vim.fn.bufnr(path)
if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
vim.api.nvim_buf_call(bufnr, function()
vim.cmd("edit!")
end)
end
if opts.on_complete then
opts.on_complete(true, nil)
end
return true, nil
end
-- Check if file exists
if vim.fn.filereadable(path) ~= 1 then
return nil, "File not found: " .. input.path
end
-- Read current content
local lines = vim.fn.readfile(path)
if not lines then
return nil, "Failed to read file: " .. input.path
end
local content = normalize_line_endings(table.concat(lines, "\n"))
-- Find match using fallback strategies
local start_pos, end_pos, strategy = find_match(content, old_str)
if not start_pos then
return nil, "old_string not found in file (tried 5 matching strategies)"
end
if opts.on_log then
opts.on_log("Match found using strategy: " .. strategy)
end
-- Perform replacement
local new_content = content:sub(1, start_pos - 1) .. new_str .. content:sub(end_pos + 1)
-- Write back
local new_lines = vim.split(new_content, "\n", { plain = true })
local ok = pcall(vim.fn.writefile, new_lines, path)
if not ok then
return nil, "Failed to write file: " .. input.path
end
-- Reload buffer if open
local bufnr = vim.fn.bufnr(path)
if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
vim.api.nvim_buf_call(bufnr, function()
vim.cmd("edit!")
end)
end
if opts.on_complete then
opts.on_complete(true, nil)
end
return true, nil
end
return M

View File

@@ -0,0 +1,146 @@
---@mod codetyper.agent.tools.glob File pattern matching tool
---@brief [[
--- Tool for finding files by glob pattern.
---@brief ]]
local Base = require("codetyper.agent.tools.base")
---@class CoderTool
local M = setmetatable({}, Base)
M.name = "glob"
M.description = [[Finds files matching a glob pattern.
Example patterns:
- "**/*.lua" - All Lua files
- "src/**/*.ts" - TypeScript files in src
- "**/test_*.py" - Test files in Python]]
M.params = {
{
name = "pattern",
description = "Glob pattern to match files",
type = "string",
},
{
name = "path",
description = "Base directory to search in (default: project root)",
type = "string",
optional = true,
},
{
name = "max_results",
description = "Maximum number of results (default: 100)",
type = "integer",
optional = true,
},
}
M.returns = {
{
name = "matches",
description = "JSON array of matching file paths",
type = "string",
},
{
name = "error",
description = "Error message if glob failed",
type = "string",
optional = true,
},
}
M.requires_confirmation = false
---@param input {pattern: string, path?: string, max_results?: integer}
---@param opts CoderToolOpts
---@return string|nil result
---@return string|nil error
function M.func(input, opts)
if not input.pattern then
return nil, "pattern is required"
end
-- Log the operation
if opts.on_log then
opts.on_log("Finding files: " .. input.pattern)
end
-- Resolve base path
local base_path = input.path or vim.fn.getcwd()
if not vim.startswith(base_path, "/") then
base_path = vim.fn.getcwd() .. "/" .. base_path
end
local max_results = input.max_results or 100
-- Use vim.fn.glob or fd if available
local matches = {}
if vim.fn.executable("fd") == 1 then
-- Use fd for better performance
local Job = require("plenary.job")
-- Convert glob to fd pattern
local fd_pattern = input.pattern:gsub("%*%*/", ""):gsub("%*", ".*")
local job = Job:new({
command = "fd",
args = {
"--type",
"f",
"--max-results",
tostring(max_results),
"--glob",
input.pattern,
base_path,
},
cwd = base_path,
})
job:sync(30000)
matches = job:result() or {}
else
-- Fallback to vim.fn.globpath
local pattern = base_path .. "/" .. input.pattern
local files = vim.fn.glob(pattern, false, true)
for i, file in ipairs(files) do
if i > max_results then
break
end
-- Make paths relative to base_path
local relative = file:gsub("^" .. vim.pesc(base_path) .. "/", "")
table.insert(matches, relative)
end
end
-- Clean up matches
local cleaned = {}
for _, match in ipairs(matches) do
if match and match ~= "" then
-- Make relative if absolute
local relative = match
if vim.startswith(match, base_path) then
relative = match:sub(#base_path + 2)
end
table.insert(cleaned, relative)
end
end
-- Return as JSON
local result = vim.json.encode({
matches = cleaned,
total = #cleaned,
truncated = #cleaned >= max_results,
})
if opts.on_complete then
opts.on_complete(result, nil)
end
return result, nil
end
return M

View File

@@ -0,0 +1,150 @@
---@mod codetyper.agent.tools.grep Search tool
---@brief [[
--- Tool for searching file contents using ripgrep.
---@brief ]]
local Base = require("codetyper.agent.tools.base")
---@class CoderTool
local M = setmetatable({}, Base)
M.name = "grep"
M.description = [[Searches for a pattern in files using ripgrep.
Returns file paths and matching lines. Use this to find code by content.
Example patterns:
- "function foo" - Find function definitions
- "import.*react" - Find React imports
- "TODO|FIXME" - Find todo comments]]
M.params = {
{
name = "pattern",
description = "Regular expression pattern to search for",
type = "string",
},
{
name = "path",
description = "Directory or file to search in (default: project root)",
type = "string",
optional = true,
},
{
name = "include",
description = "File glob pattern to include (e.g., '*.lua')",
type = "string",
optional = true,
},
{
name = "max_results",
description = "Maximum number of results (default: 50)",
type = "integer",
optional = true,
},
}
M.returns = {
{
name = "matches",
description = "JSON array of matches with file, line_number, and content",
type = "string",
},
{
name = "error",
description = "Error message if search failed",
type = "string",
optional = true,
},
}
M.requires_confirmation = false
---@param input {pattern: string, path?: string, include?: string, max_results?: integer}
---@param opts CoderToolOpts
---@return string|nil result
---@return string|nil error
function M.func(input, opts)
if not input.pattern then
return nil, "pattern is required"
end
-- Log the operation
if opts.on_log then
opts.on_log("Searching for: " .. input.pattern)
end
-- Build ripgrep command
local path = input.path or vim.fn.getcwd()
local max_results = input.max_results or 50
-- Resolve path
if not vim.startswith(path, "/") then
path = vim.fn.getcwd() .. "/" .. path
end
-- Check if ripgrep is available
if vim.fn.executable("rg") ~= 1 then
return nil, "ripgrep (rg) is not installed"
end
-- Build command args
local args = {
"--json",
"--max-count",
tostring(max_results),
"--no-heading",
}
if input.include then
table.insert(args, "--glob")
table.insert(args, input.include)
end
table.insert(args, input.pattern)
table.insert(args, path)
-- Execute ripgrep
local Job = require("plenary.job")
local job = Job:new({
command = "rg",
args = args,
cwd = vim.fn.getcwd(),
})
job:sync(30000) -- 30 second timeout
local results = job:result() or {}
local matches = {}
-- Parse JSON output
for _, line in ipairs(results) do
if line and line ~= "" then
local ok, parsed = pcall(vim.json.decode, line)
if ok and parsed.type == "match" then
local data = parsed.data
table.insert(matches, {
file = data.path.text,
line_number = data.line_number,
content = data.lines.text:gsub("\n$", ""),
})
end
end
end
-- Return as JSON
local result = vim.json.encode({
matches = matches,
total = #matches,
truncated = #matches >= max_results,
})
if opts.on_complete then
opts.on_complete(result, nil)
end
return result, nil
end
return M

View File

@@ -0,0 +1,308 @@
---@mod codetyper.agent.tools Tool registry and orchestration
---@brief [[
--- Registry for LLM tools with execution and schema generation.
--- Inspired by avante.nvim's tool system.
---@brief ]]
local M = {}
--- Registered tools
---@type table<string, CoderTool>
local tools = {}
--- Tool execution history for current session
---@type table[]
local execution_history = {}
--- Register a tool
---@param tool CoderTool Tool to register
function M.register(tool)
if not tool.name then
error("Tool must have a name")
end
tools[tool.name] = tool
end
--- Unregister a tool
---@param name string Tool name
function M.unregister(name)
tools[name] = nil
end
--- Get a tool by name
---@param name string Tool name
---@return CoderTool|nil
function M.get(name)
return tools[name]
end
--- Get all registered tools
---@return table<string, CoderTool>
function M.get_all()
return tools
end
--- Get tools as a list
---@param filter? fun(tool: CoderTool): boolean Optional filter function
---@return CoderTool[]
function M.list(filter)
local result = {}
for _, tool in pairs(tools) do
if not filter or filter(tool) then
table.insert(result, tool)
end
end
return result
end
--- Generate schemas for all tools (for LLM function calling)
---@param filter? fun(tool: CoderTool): boolean Optional filter function
---@return table[] schemas
function M.get_schemas(filter)
local schemas = {}
for _, tool in pairs(tools) do
if not filter or filter(tool) then
if tool.to_schema then
table.insert(schemas, tool:to_schema())
end
end
end
return schemas
end
--- Execute a tool by name
---@param name string Tool name
---@param input table Input parameters
---@param opts CoderToolOpts Execution options
---@return any result
---@return string|nil error
function M.execute(name, input, opts)
local tool = tools[name]
if not tool then
return nil, "Unknown tool: " .. name
end
-- Validate input
if tool.validate_input then
local valid, err = tool:validate_input(input)
if not valid then
return nil, err
end
end
-- Log execution
if opts.on_log then
opts.on_log(string.format("Executing tool: %s", name))
end
-- Track execution
local execution = {
tool = name,
input = input,
start_time = os.time(),
status = "running",
}
table.insert(execution_history, execution)
-- Execute the tool
local result, err = tool.func(input, opts)
-- Update execution record
execution.end_time = os.time()
execution.status = err and "error" or "completed"
execution.result = result
execution.error = err
return result, err
end
--- Process a tool call from LLM response
---@param tool_call table Tool call from LLM (name + input)
---@param opts CoderToolOpts Execution options
---@return any result
---@return string|nil error
function M.process_tool_call(tool_call, opts)
local name = tool_call.name or tool_call.function_name
local input = tool_call.input or tool_call.arguments or {}
-- Parse JSON arguments if string
if type(input) == "string" then
local ok, parsed = pcall(vim.json.decode, input)
if ok then
input = parsed
else
return nil, "Failed to parse tool arguments: " .. input
end
end
return M.execute(name, input, opts)
end
--- Get execution history
---@param limit? number Max entries to return
---@return table[]
function M.get_history(limit)
if not limit then
return execution_history
end
local result = {}
local start = math.max(1, #execution_history - limit + 1)
for i = start, #execution_history do
table.insert(result, execution_history[i])
end
return result
end
--- Clear execution history
function M.clear_history()
execution_history = {}
end
--- Load built-in tools
function M.load_builtins()
-- View file tool
local view = require("codetyper.agent.tools.view")
M.register(view)
-- Bash tool
local bash = require("codetyper.agent.tools.bash")
M.register(bash)
-- Grep tool
local grep = require("codetyper.agent.tools.grep")
M.register(grep)
-- Glob tool
local glob = require("codetyper.agent.tools.glob")
M.register(glob)
-- Write file tool
local write = require("codetyper.agent.tools.write")
M.register(write)
-- Edit tool
local edit = require("codetyper.agent.tools.edit")
M.register(edit)
end
--- Initialize tools system
function M.setup()
M.load_builtins()
end
--- Get tool definitions for LLM (lazy-loaded, OpenAI format)
--- This is accessed as M.definitions property
M.definitions = setmetatable({}, {
__call = function()
-- Ensure tools are loaded
if vim.tbl_count(tools) == 0 then
M.load_builtins()
end
return M.to_openai_format()
end,
__index = function(_, key)
-- Make it work as both function and table
if key == "get" then
return function()
if vim.tbl_count(tools) == 0 then
M.load_builtins()
end
return M.to_openai_format()
end
end
return nil
end,
})
--- Get definitions as a function (for backwards compatibility)
function M.get_definitions()
if vim.tbl_count(tools) == 0 then
M.load_builtins()
end
return M.to_openai_format()
end
--- Convert all tools to OpenAI function calling format
---@param filter? fun(tool: CoderTool): boolean Optional filter function
---@return table[] OpenAI-compatible tool definitions
function M.to_openai_format(filter)
local openai_tools = {}
for _, tool in pairs(tools) do
if not filter or filter(tool) then
local properties = {}
local required = {}
for _, param in ipairs(tool.params or {}) do
properties[param.name] = {
type = param.type == "integer" and "number" or param.type,
description = param.description,
}
if param.default ~= nil then
properties[param.name].default = param.default
end
if not param.optional then
table.insert(required, param.name)
end
end
local description = type(tool.description) == "function" and tool.description() or tool.description
table.insert(openai_tools, {
type = "function",
["function"] = {
name = tool.name,
description = description,
parameters = {
type = "object",
properties = properties,
required = required,
},
},
})
end
end
return openai_tools
end
--- Convert all tools to Claude tool use format
---@param filter? fun(tool: CoderTool): boolean Optional filter function
---@return table[] Claude-compatible tool definitions
function M.to_claude_format(filter)
local claude_tools = {}
for _, tool in pairs(tools) do
if not filter or filter(tool) then
local properties = {}
local required = {}
for _, param in ipairs(tool.params or {}) do
properties[param.name] = {
type = param.type == "integer" and "number" or param.type,
description = param.description,
}
if not param.optional then
table.insert(required, param.name)
end
end
local description = type(tool.description) == "function" and tool.description() or tool.description
table.insert(claude_tools, {
name = tool.name,
description = description,
input_schema = {
type = "object",
properties = properties,
required = required,
},
})
end
end
return claude_tools
end
return M

View File

@@ -0,0 +1,149 @@
---@mod codetyper.agent.tools.view File viewing tool
---@brief [[
--- Tool for reading file contents with line range support.
---@brief ]]
local Base = require("codetyper.agent.tools.base")
---@class CoderTool
local M = setmetatable({}, Base)
M.name = "view"
M.description = [[Reads the content of a file.
Usage notes:
- Provide the file path relative to the project root
- Use start_line and end_line to read specific sections
- If content is truncated, use line ranges to read in chunks
- Returns JSON with content, total_line_count, and is_truncated]]
M.params = {
{
name = "path",
description = "Path to the file (relative to project root or absolute)",
type = "string",
},
{
name = "start_line",
description = "Line number to start reading (1-indexed)",
type = "integer",
optional = true,
},
{
name = "end_line",
description = "Line number to end reading (1-indexed, inclusive)",
type = "integer",
optional = true,
},
}
M.returns = {
{
name = "content",
description = "File contents as JSON with content, total_line_count, is_truncated",
type = "string",
},
{
name = "error",
description = "Error message if file could not be read",
type = "string",
optional = true,
},
}
M.requires_confirmation = false
--- Maximum content size before truncation
local MAX_CONTENT_SIZE = 200 * 1024 -- 200KB
---@param input {path: string, start_line?: integer, end_line?: integer}
---@param opts CoderToolOpts
---@return string|nil result
---@return string|nil error
function M.func(input, opts)
if not input.path then
return nil, "path is required"
end
-- Log the operation
if opts.on_log then
opts.on_log("Reading file: " .. input.path)
end
-- Resolve path
local path = input.path
if not vim.startswith(path, "/") then
-- Relative path - resolve from project root
local root = vim.fn.getcwd()
path = root .. "/" .. path
end
-- Check if file exists
local stat = vim.uv.fs_stat(path)
if not stat then
return nil, "File not found: " .. input.path
end
if stat.type == "directory" then
return nil, "Path is a directory: " .. input.path
end
-- Read file
local lines = vim.fn.readfile(path)
if not lines then
return nil, "Failed to read file: " .. input.path
end
-- Apply line range
local start_line = input.start_line or 1
local end_line = input.end_line or #lines
start_line = math.max(1, start_line)
end_line = math.min(#lines, end_line)
local total_lines = #lines
local selected_lines = {}
for i = start_line, end_line do
table.insert(selected_lines, lines[i])
end
-- Check for truncation
local content = table.concat(selected_lines, "\n")
local is_truncated = false
if #content > MAX_CONTENT_SIZE then
-- Truncate content
local truncated_lines = {}
local size = 0
for _, line in ipairs(selected_lines) do
size = size + #line + 1
if size > MAX_CONTENT_SIZE then
is_truncated = true
break
end
table.insert(truncated_lines, line)
end
content = table.concat(truncated_lines, "\n")
end
-- Return as JSON
local result = vim.json.encode({
content = content,
total_line_count = total_lines,
is_truncated = is_truncated,
start_line = start_line,
end_line = end_line,
})
if opts.on_complete then
opts.on_complete(result, nil)
end
return result, nil
end
return M

View File

@@ -0,0 +1,101 @@
---@mod codetyper.agent.tools.write File writing tool
---@brief [[
--- Tool for creating or overwriting files.
---@brief ]]
local Base = require("codetyper.agent.tools.base")
---@class CoderTool
local M = setmetatable({}, Base)
M.name = "write"
M.description = [[Creates or overwrites a file with new content.
IMPORTANT:
- This will completely replace the file contents
- Use 'edit' tool for partial modifications
- Parent directories will be created if needed]]
M.params = {
{
name = "path",
description = "Path to the file to write",
type = "string",
},
{
name = "content",
description = "Content to write to the file",
type = "string",
},
}
M.returns = {
{
name = "success",
description = "Whether the file was written successfully",
type = "boolean",
},
{
name = "error",
description = "Error message if write failed",
type = "string",
optional = true,
},
}
M.requires_confirmation = true
---@param input {path: string, content: string}
---@param opts CoderToolOpts
---@return boolean|nil result
---@return string|nil error
function M.func(input, opts)
if not input.path then
return nil, "path is required"
end
if not input.content then
return nil, "content is required"
end
-- Log the operation
if opts.on_log then
opts.on_log("Writing file: " .. input.path)
end
-- Resolve path
local path = input.path
if not vim.startswith(path, "/") then
path = vim.fn.getcwd() .. "/" .. path
end
-- Create parent directories
local dir = vim.fn.fnamemodify(path, ":h")
if vim.fn.isdirectory(dir) == 0 then
vim.fn.mkdir(dir, "p")
end
-- Write the file
local lines = vim.split(input.content, "\n", { plain = true })
local ok = pcall(vim.fn.writefile, lines, path)
if not ok then
return nil, "Failed to write file: " .. path
end
-- Reload buffer if open
local bufnr = vim.fn.bufnr(path)
if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
vim.api.nvim_buf_call(bufnr, function()
vim.cmd("edit!")
end)
end
if opts.on_complete then
opts.on_complete(true, nil)
end
return true, nil
end
return M

View File

@@ -224,6 +224,86 @@ local function format_attached_files(attached_files)
return table.concat(parts, "")
end
--- Get coder companion file path for a target file
---@param target_path string Target file path
---@return string|nil Coder file path if exists
local function get_coder_companion_path(target_path)
if not target_path or target_path == "" then
return nil
end
-- Skip if target is already a coder file
if target_path:match("%.coder%.") then
return nil
end
local dir = vim.fn.fnamemodify(target_path, ":h")
local name = vim.fn.fnamemodify(target_path, ":t:r") -- filename without extension
local ext = vim.fn.fnamemodify(target_path, ":e")
local coder_path = dir .. "/" .. name .. ".coder." .. ext
if vim.fn.filereadable(coder_path) == 1 then
return coder_path
end
return nil
end
--- Read and format coder companion context (business logic, pseudo-code)
---@param target_path string Target file path
---@return string Formatted coder context
local function get_coder_context(target_path)
local coder_path = get_coder_companion_path(target_path)
if not coder_path then
return ""
end
local ok, lines = pcall(function()
return vim.fn.readfile(coder_path)
end)
if not ok or not lines or #lines == 0 then
return ""
end
local content = table.concat(lines, "\n")
-- Skip if only template comments (no actual content)
local stripped = content:gsub("^%s*", ""):gsub("%s*$", "")
if stripped == "" then
return ""
end
-- Check if there's meaningful content (not just template)
local has_content = false
for _, line in ipairs(lines) do
-- Skip comment lines that are part of the template
local trimmed = line:gsub("^%s*", "")
if not trimmed:match("^[%-#/]+%s*Coder companion")
and not trimmed:match("^[%-#/]+%s*Use /@ @/")
and not trimmed:match("^[%-#/]+%s*Example:")
and not trimmed:match("^<!%-%-")
and trimmed ~= ""
and not trimmed:match("^[%-#/]+%s*$") then
has_content = true
break
end
end
if not has_content then
return ""
end
local ext = vim.fn.fnamemodify(coder_path, ":e")
return string.format(
"\n\n--- Business Context / Pseudo-code ---\n" ..
"The following describes the intended behavior and design for this file:\n" ..
"```%s\n%s\n```",
ext,
content:sub(1, 4000) -- Limit to 4000 chars
)
end
--- Format indexed project context for inclusion in prompt
---@param indexed_context table|nil
---@return string
@@ -309,8 +389,53 @@ local function build_prompt(event)
-- Format attached files
local attached_content = format_attached_files(event.attached_files)
-- Combine attached files and indexed context
local extra_context = attached_content .. indexed_content
-- Get coder companion context (business logic, pseudo-code)
local coder_context = get_coder_context(event.target_path)
-- Get brain memories - contextual recall based on current task
local brain_context = ""
pcall(function()
local brain = require("codetyper.brain")
if brain.is_initialized() then
-- Query brain for relevant memories based on:
-- 1. Current file (file-specific patterns)
-- 2. Prompt content (semantic similarity)
-- 3. Intent type (relevant past generations)
local query_text = event.prompt_content or ""
if event.scope and event.scope.name then
query_text = event.scope.name .. " " .. query_text
end
local result = brain.query({
query = query_text,
file = event.target_path,
max_results = 5,
types = { "pattern", "correction", "convention" },
})
if result and result.nodes and #result.nodes > 0 then
local memories = { "\n\n--- Learned Patterns & Conventions ---" }
for _, node in ipairs(result.nodes) do
if node.c then
local summary = node.c.s or ""
local detail = node.c.d or ""
if summary ~= "" then
table.insert(memories, "" .. summary)
if detail ~= "" and #detail < 200 then
table.insert(memories, " " .. detail)
end
end
end
end
if #memories > 1 then
brain_context = table.concat(memories, "\n")
end
end
end
end)
-- Combine all context sources: brain memories first, then coder context, attached files, indexed
local extra_context = brain_context .. coder_context .. attached_content .. indexed_content
-- Build context with scope information
local context = {
@@ -502,21 +627,21 @@ function M.start(worker)
end
end, worker.timeout_ms)
-- Get client and execute
local client, client_err = get_client(worker.worker_type)
if not client then
M.complete(worker, nil, client_err)
return
end
local prompt, context = build_prompt(worker.event)
-- Call the LLM
client.generate(prompt, context, function(response, err, usage)
-- Check if smart selection is enabled (memory-based provider selection)
local use_smart_selection = false
pcall(function()
local codetyper = require("codetyper")
local config = codetyper.get_config()
use_smart_selection = config.llm.smart_selection ~= false -- Default to true
end)
-- Define the response handler
local function handle_response(response, err, usage_or_metadata)
-- Cancel timeout timer
if worker.timer then
pcall(function()
-- Timer might have already fired
if type(worker.timer) == "userdata" and worker.timer.stop then
worker.timer:stop()
end
@@ -527,8 +652,45 @@ function M.start(worker)
return -- Already timed out or cancelled
end
-- Extract usage from metadata if smart_generate was used
local usage = usage_or_metadata
if type(usage_or_metadata) == "table" and usage_or_metadata.provider then
-- This is metadata from smart_generate
usage = nil
-- Update worker type to reflect actual provider used
worker.worker_type = usage_or_metadata.provider
-- Log if pondering occurred
if usage_or_metadata.pondered then
pcall(function()
local logs = require("codetyper.agent.logs")
logs.add({
type = "info",
message = string.format(
"Pondering: %s (agreement: %.0f%%)",
usage_or_metadata.corrected and "corrected" or "validated",
(usage_or_metadata.agreement or 1) * 100
),
})
end)
end
end
M.complete(worker, response, err, usage)
end)
end
-- Use smart selection or direct client
if use_smart_selection then
local llm = require("codetyper.llm")
llm.smart_generate(prompt, context, handle_response)
else
-- Get client and execute directly
local client, client_err = get_client(worker.worker_type)
if not client then
M.complete(worker, nil, client_err)
return
end
client.generate(prompt, context, handle_response)
end
end
--- Complete worker execution