Adding more features
This commit is contained in:
854
lua/codetyper/agent/agentic.lua
Normal file
854
lua/codetyper/agent/agentic.lua
Normal file
@@ -0,0 +1,854 @@
|
||||
---@mod codetyper.agent.agentic Agentic loop with proper tool calling
|
||||
---@brief [[
|
||||
--- Full agentic system that handles multi-file changes via tool calling.
|
||||
--- Inspired by avante.nvim and opencode patterns.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class AgenticMessage
|
||||
---@field role "system"|"user"|"assistant"|"tool"
|
||||
---@field content string|table
|
||||
---@field tool_calls? table[] For assistant messages with tool calls
|
||||
---@field tool_call_id? string For tool result messages
|
||||
---@field name? string Tool name for tool results
|
||||
|
||||
---@class AgenticToolCall
|
||||
---@field id string Unique tool call ID
|
||||
---@field type "function"
|
||||
---@field function {name: string, arguments: string|table}
|
||||
|
||||
---@class AgenticOpts
|
||||
---@field task string The task to accomplish
|
||||
---@field files? string[] Initial files to include as context
|
||||
---@field agent? string Agent name to use (default: "coder")
|
||||
---@field model? string Model override
|
||||
---@field max_iterations? number Max tool call rounds (default: 20)
|
||||
---@field on_message? fun(msg: AgenticMessage) Called for each message
|
||||
---@field on_tool_start? fun(name: string, args: table) Called before tool execution
|
||||
---@field on_tool_end? fun(name: string, result: any, error: string|nil) Called after tool execution
|
||||
---@field on_file_change? fun(path: string, action: string) Called when file is modified
|
||||
---@field on_complete? fun(result: string|nil, error: string|nil) Called when done
|
||||
---@field on_status? fun(status: string) Status updates
|
||||
|
||||
--- Generate unique tool call ID
|
||||
local function generate_tool_call_id()
|
||||
return "call_" .. string.format("%x", os.time()) .. "_" .. string.format("%x", math.random(0, 0xFFFF))
|
||||
end
|
||||
|
||||
--- Load agent definition
|
||||
---@param name string Agent name
|
||||
---@return table|nil agent definition
|
||||
local function load_agent(name)
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
local agent_file = agents_dir .. "/" .. name .. ".md"
|
||||
|
||||
-- Check if custom agent exists
|
||||
if vim.fn.filereadable(agent_file) == 1 then
|
||||
local content = table.concat(vim.fn.readfile(agent_file), "\n")
|
||||
-- Parse frontmatter and content
|
||||
local frontmatter = {}
|
||||
local body = content
|
||||
|
||||
local fm_match = content:match("^%-%-%-\n(.-)%-%-%-\n(.*)$")
|
||||
if fm_match then
|
||||
-- Parse YAML-like frontmatter
|
||||
for line in content:match("^%-%-%-\n(.-)%-%-%-"):gmatch("[^\n]+") do
|
||||
local key, value = line:match("^(%w+):%s*(.+)$")
|
||||
if key and value then
|
||||
frontmatter[key] = value
|
||||
end
|
||||
end
|
||||
body = content:match("%-%-%-\n.-%-%-%-%s*\n(.*)$") or content
|
||||
end
|
||||
|
||||
return {
|
||||
name = name,
|
||||
description = frontmatter.description or "Custom agent: " .. name,
|
||||
system_prompt = body,
|
||||
tools = frontmatter.tools and vim.split(frontmatter.tools, ",") or nil,
|
||||
model = frontmatter.model,
|
||||
}
|
||||
end
|
||||
|
||||
-- Built-in agents
|
||||
local builtin_agents = {
|
||||
coder = {
|
||||
name = "coder",
|
||||
description = "Full-featured coding agent with file modification capabilities",
|
||||
system_prompt = [[You are an expert software engineer. You have access to tools to read, write, and modify files.
|
||||
|
||||
## Your Capabilities
|
||||
- Read files to understand the codebase
|
||||
- Search for patterns with grep and glob
|
||||
- Create new files with write tool
|
||||
- Edit existing files with precise replacements
|
||||
- Execute shell commands for builds and tests
|
||||
|
||||
## Guidelines
|
||||
1. Always read relevant files before making changes
|
||||
2. Make minimal, focused changes
|
||||
3. Follow existing code style and patterns
|
||||
4. Create tests when adding new functionality
|
||||
5. Verify changes work by running tests or builds
|
||||
|
||||
## Important Rules
|
||||
- NEVER guess file contents - always read first
|
||||
- Make precise edits using exact string matching
|
||||
- Explain your reasoning before making changes
|
||||
- If unsure, ask for clarification]],
|
||||
tools = { "view", "edit", "write", "grep", "glob", "bash" },
|
||||
},
|
||||
planner = {
|
||||
name = "planner",
|
||||
description = "Planning agent - read-only, helps design implementations",
|
||||
system_prompt = [[You are a software architect. Analyze codebases and create implementation plans.
|
||||
|
||||
You can read files and search the codebase, but cannot modify files.
|
||||
Your role is to:
|
||||
1. Understand the existing architecture
|
||||
2. Identify relevant files and patterns
|
||||
3. Create step-by-step implementation plans
|
||||
4. Suggest which files to modify and how
|
||||
|
||||
Be thorough in your analysis before making recommendations.]],
|
||||
tools = { "view", "grep", "glob" },
|
||||
},
|
||||
explorer = {
|
||||
name = "explorer",
|
||||
description = "Exploration agent - quickly find information in codebase",
|
||||
system_prompt = [[You are a codebase exploration assistant. Find information quickly and report back.
|
||||
|
||||
Your goal is to efficiently search and summarize findings.
|
||||
Use glob to find files, grep to search content, and view to read specific files.
|
||||
Be concise and focused in your responses.]],
|
||||
tools = { "view", "grep", "glob" },
|
||||
},
|
||||
}
|
||||
|
||||
return builtin_agents[name]
|
||||
end
|
||||
|
||||
--- Load rules from .coder/rules/
|
||||
---@return string Combined rules content
|
||||
local function load_rules()
|
||||
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
|
||||
local rules = {}
|
||||
|
||||
if vim.fn.isdirectory(rules_dir) == 1 then
|
||||
local files = vim.fn.glob(rules_dir .. "/*.md", false, true)
|
||||
for _, file in ipairs(files) do
|
||||
local content = table.concat(vim.fn.readfile(file), "\n")
|
||||
local filename = vim.fn.fnamemodify(file, ":t:r")
|
||||
table.insert(rules, string.format("## Rule: %s\n%s", filename, content))
|
||||
end
|
||||
end
|
||||
|
||||
if #rules > 0 then
|
||||
return "\n\n# Project Rules\n" .. table.concat(rules, "\n\n")
|
||||
end
|
||||
return ""
|
||||
end
|
||||
|
||||
--- Build messages array for API request
|
||||
---@param history AgenticMessage[]
|
||||
---@param provider string "openai"|"claude"
|
||||
---@return table[] Formatted messages
|
||||
local function build_messages(history, provider)
|
||||
local messages = {}
|
||||
|
||||
for _, msg in ipairs(history) do
|
||||
if msg.role == "system" then
|
||||
if provider == "claude" then
|
||||
-- Claude uses system parameter, not message
|
||||
-- Skip system messages in array
|
||||
else
|
||||
table.insert(messages, {
|
||||
role = "system",
|
||||
content = msg.content,
|
||||
})
|
||||
end
|
||||
elseif msg.role == "user" then
|
||||
table.insert(messages, {
|
||||
role = "user",
|
||||
content = msg.content,
|
||||
})
|
||||
elseif msg.role == "assistant" then
|
||||
local message = {
|
||||
role = "assistant",
|
||||
content = msg.content,
|
||||
}
|
||||
if msg.tool_calls then
|
||||
message.tool_calls = msg.tool_calls
|
||||
if provider == "claude" then
|
||||
-- Claude format: content is array of blocks
|
||||
message.content = {}
|
||||
if msg.content and msg.content ~= "" then
|
||||
table.insert(message.content, {
|
||||
type = "text",
|
||||
text = msg.content,
|
||||
})
|
||||
end
|
||||
for _, tc in ipairs(msg.tool_calls) do
|
||||
table.insert(message.content, {
|
||||
type = "tool_use",
|
||||
id = tc.id,
|
||||
name = tc["function"].name,
|
||||
input = type(tc["function"].arguments) == "string"
|
||||
and vim.json.decode(tc["function"].arguments)
|
||||
or tc["function"].arguments,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
table.insert(messages, message)
|
||||
elseif msg.role == "tool" then
|
||||
if provider == "claude" then
|
||||
table.insert(messages, {
|
||||
role = "user",
|
||||
content = {
|
||||
{
|
||||
type = "tool_result",
|
||||
tool_use_id = msg.tool_call_id,
|
||||
content = msg.content,
|
||||
},
|
||||
},
|
||||
})
|
||||
else
|
||||
table.insert(messages, {
|
||||
role = "tool",
|
||||
tool_call_id = msg.tool_call_id,
|
||||
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return messages
|
||||
end
|
||||
|
||||
--- Build tools array for API request
|
||||
---@param tool_names string[] Tool names to include
|
||||
---@param provider string "openai"|"claude"
|
||||
---@return table[] Formatted tools
|
||||
local function build_tools(tool_names, provider)
|
||||
local tools_mod = require("codetyper.agent.tools")
|
||||
local tools = {}
|
||||
|
||||
for _, name in ipairs(tool_names) do
|
||||
local tool = tools_mod.get(name)
|
||||
if tool then
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(tool.params or {}) do
|
||||
properties[param.name] = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
local description = type(tool.description) == "function" and tool.description() or tool.description
|
||||
|
||||
if provider == "claude" then
|
||||
table.insert(tools, {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
input_schema = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
})
|
||||
else
|
||||
table.insert(tools, {
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return tools
|
||||
end
|
||||
|
||||
--- Execute a tool call
|
||||
---@param tool_call AgenticToolCall
|
||||
---@param opts AgenticOpts
|
||||
---@return string result
|
||||
---@return string|nil error
|
||||
local function execute_tool(tool_call, opts)
|
||||
local tools_mod = require("codetyper.agent.tools")
|
||||
local name = tool_call["function"].name
|
||||
local args = tool_call["function"].arguments
|
||||
|
||||
-- Parse arguments if string
|
||||
if type(args) == "string" then
|
||||
local ok, parsed = pcall(vim.json.decode, args)
|
||||
if ok then
|
||||
args = parsed
|
||||
else
|
||||
return "", "Failed to parse tool arguments: " .. args
|
||||
end
|
||||
end
|
||||
|
||||
-- Notify tool start
|
||||
if opts.on_tool_start then
|
||||
opts.on_tool_start(name, args)
|
||||
end
|
||||
|
||||
if opts.on_status then
|
||||
opts.on_status("Executing: " .. name)
|
||||
end
|
||||
|
||||
-- Execute the tool
|
||||
local tool = tools_mod.get(name)
|
||||
if not tool then
|
||||
local err = "Unknown tool: " .. name
|
||||
if opts.on_tool_end then
|
||||
opts.on_tool_end(name, nil, err)
|
||||
end
|
||||
return "", err
|
||||
end
|
||||
|
||||
local result, err = tool.func(args, {
|
||||
on_log = function(msg)
|
||||
if opts.on_status then
|
||||
opts.on_status(msg)
|
||||
end
|
||||
end,
|
||||
})
|
||||
|
||||
-- Notify tool end
|
||||
if opts.on_tool_end then
|
||||
opts.on_tool_end(name, result, err)
|
||||
end
|
||||
|
||||
-- Track file changes
|
||||
if opts.on_file_change and (name == "write" or name == "edit") and not err then
|
||||
opts.on_file_change(args.path, name == "write" and "created" or "modified")
|
||||
end
|
||||
|
||||
if err then
|
||||
return "", err
|
||||
end
|
||||
|
||||
return type(result) == "string" and result or vim.json.encode(result), nil
|
||||
end
|
||||
|
||||
--- Parse tool calls from LLM response (unified Claude-like format)
|
||||
---@param response table Raw API response in unified format
|
||||
---@param provider string Provider name (unused, kept for signature compatibility)
|
||||
---@return AgenticToolCall[]
|
||||
local function parse_tool_calls(response, provider)
|
||||
local tool_calls = {}
|
||||
|
||||
-- Unified format: content array with tool_use blocks
|
||||
local content = response.content or {}
|
||||
for _, block in ipairs(content) do
|
||||
if block.type == "tool_use" then
|
||||
-- OpenAI expects arguments as JSON string, not table
|
||||
local args = block.input
|
||||
if type(args) == "table" then
|
||||
args = vim.json.encode(args)
|
||||
end
|
||||
|
||||
table.insert(tool_calls, {
|
||||
id = block.id or generate_tool_call_id(),
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = block.name,
|
||||
arguments = args,
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return tool_calls
|
||||
end
|
||||
|
||||
--- Extract text content from response (unified Claude-like format)
|
||||
---@param response table Raw API response in unified format
|
||||
---@param provider string Provider name (unused, kept for signature compatibility)
|
||||
---@return string
|
||||
local function extract_content(response, provider)
|
||||
local parts = {}
|
||||
for _, block in ipairs(response.content or {}) do
|
||||
if block.type == "text" then
|
||||
table.insert(parts, block.text)
|
||||
end
|
||||
end
|
||||
return table.concat(parts, "\n")
|
||||
end
|
||||
|
||||
--- Check if response indicates completion (unified Claude-like format)
|
||||
---@param response table Raw API response in unified format
|
||||
---@param provider string Provider name (unused, kept for signature compatibility)
|
||||
---@return boolean
|
||||
local function is_complete(response, provider)
|
||||
return response.stop_reason == "end_turn"
|
||||
end
|
||||
|
||||
--- Make API request to LLM with native tool calling support
|
||||
---@param messages table[] Formatted messages
|
||||
---@param tools table[] Formatted tools
|
||||
---@param system_prompt string System prompt
|
||||
---@param provider string "openai"|"claude"|"copilot"
|
||||
---@param model string Model name
|
||||
---@param callback fun(response: table|nil, error: string|nil)
|
||||
local function call_llm(messages, tools, system_prompt, provider, model, callback)
|
||||
local context = {
|
||||
language = "lua",
|
||||
file_content = "",
|
||||
prompt_type = "agent",
|
||||
project_root = vim.fn.getcwd(),
|
||||
cwd = vim.fn.getcwd(),
|
||||
}
|
||||
|
||||
-- Use native tool calling APIs
|
||||
if provider == "copilot" then
|
||||
local client = require("codetyper.llm.copilot")
|
||||
|
||||
-- Copilot's generate_with_tools expects messages in a specific format
|
||||
-- Convert to the format it expects
|
||||
local converted_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(converted_messages, msg)
|
||||
end
|
||||
end
|
||||
|
||||
client.generate_with_tools(converted_messages, context, tools, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Response is already in Claude-like format from the provider
|
||||
-- Convert to our internal format
|
||||
local result = {
|
||||
content = {},
|
||||
stop_reason = "end_turn",
|
||||
}
|
||||
|
||||
if response and response.content then
|
||||
for _, block in ipairs(response.content) do
|
||||
if block.type == "text" then
|
||||
table.insert(result.content, { type = "text", text = block.text })
|
||||
elseif block.type == "tool_use" then
|
||||
table.insert(result.content, {
|
||||
type = "tool_use",
|
||||
id = block.id or generate_tool_call_id(),
|
||||
name = block.name,
|
||||
input = block.input,
|
||||
})
|
||||
result.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
callback(result, nil)
|
||||
end)
|
||||
elseif provider == "openai" then
|
||||
local client = require("codetyper.llm.openai")
|
||||
|
||||
-- OpenAI's generate_with_tools
|
||||
local converted_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(converted_messages, msg)
|
||||
end
|
||||
end
|
||||
|
||||
client.generate_with_tools(converted_messages, context, tools, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Response is already in Claude-like format from the provider
|
||||
local result = {
|
||||
content = {},
|
||||
stop_reason = "end_turn",
|
||||
}
|
||||
|
||||
if response and response.content then
|
||||
for _, block in ipairs(response.content) do
|
||||
if block.type == "text" then
|
||||
table.insert(result.content, { type = "text", text = block.text })
|
||||
elseif block.type == "tool_use" then
|
||||
table.insert(result.content, {
|
||||
type = "tool_use",
|
||||
id = block.id or generate_tool_call_id(),
|
||||
name = block.name,
|
||||
input = block.input,
|
||||
})
|
||||
result.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
callback(result, nil)
|
||||
end)
|
||||
elseif provider == "ollama" then
|
||||
local client = require("codetyper.llm.ollama")
|
||||
|
||||
-- Ollama's generate_with_tools (text-based tool calling)
|
||||
local converted_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(converted_messages, msg)
|
||||
end
|
||||
end
|
||||
|
||||
client.generate_with_tools(converted_messages, context, tools, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Response is already in Claude-like format from the provider
|
||||
callback(response, nil)
|
||||
end)
|
||||
else
|
||||
-- Fallback for other providers (ollama, etc.) - use text-based parsing
|
||||
local client = require("codetyper.llm." .. provider)
|
||||
|
||||
-- Build prompt from messages
|
||||
local prompt_parts = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role == "user" then
|
||||
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
|
||||
table.insert(prompt_parts, "User: " .. content)
|
||||
elseif msg.role == "assistant" then
|
||||
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
|
||||
table.insert(prompt_parts, "Assistant: " .. content)
|
||||
end
|
||||
end
|
||||
|
||||
-- Add tool descriptions to prompt for text-based providers
|
||||
local tool_desc = "\n\n## Available Tools\n"
|
||||
tool_desc = tool_desc .. "Call tools by outputting JSON in this format:\n"
|
||||
tool_desc = tool_desc .. '```json\n{"tool": "tool_name", "arguments": {...}}\n```\n\n'
|
||||
for _, tool in ipairs(tools) do
|
||||
local name = tool.name or (tool["function"] and tool["function"].name)
|
||||
local desc = tool.description or (tool["function"] and tool["function"].description)
|
||||
if name then
|
||||
tool_desc = tool_desc .. string.format("- **%s**: %s\n", name, desc or "")
|
||||
end
|
||||
end
|
||||
|
||||
context.file_content = system_prompt .. tool_desc
|
||||
|
||||
client.generate(table.concat(prompt_parts, "\n\n"), context, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Parse response for tool calls (text-based fallback)
|
||||
local result = {
|
||||
content = {},
|
||||
stop_reason = "end_turn",
|
||||
}
|
||||
|
||||
-- Extract text content
|
||||
local text_content = response
|
||||
|
||||
-- Try to extract JSON tool calls from response
|
||||
local json_match = response:match("```json%s*(%b{})%s*```")
|
||||
if json_match then
|
||||
local ok, parsed = pcall(vim.json.decode, json_match)
|
||||
if ok and parsed.tool then
|
||||
table.insert(result.content, {
|
||||
type = "tool_use",
|
||||
id = generate_tool_call_id(),
|
||||
name = parsed.tool,
|
||||
input = parsed.arguments or {},
|
||||
})
|
||||
text_content = response:gsub("```json.-```", ""):gsub("^%s+", ""):gsub("%s+$", "")
|
||||
result.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
|
||||
if text_content and text_content ~= "" then
|
||||
table.insert(result.content, 1, { type = "text", text = text_content })
|
||||
end
|
||||
|
||||
callback(result, nil)
|
||||
end)
|
||||
end
|
||||
end
|
||||
|
||||
--- Run the agentic loop
|
||||
---@param opts AgenticOpts
|
||||
function M.run(opts)
|
||||
-- Load agent
|
||||
local agent = load_agent(opts.agent or "coder")
|
||||
if not agent then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "Unknown agent: " .. (opts.agent or "coder"))
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Load rules
|
||||
local rules = load_rules()
|
||||
|
||||
-- Build system prompt
|
||||
local system_prompt = agent.system_prompt .. rules
|
||||
|
||||
-- Initialize message history
|
||||
---@type AgenticMessage[]
|
||||
local history = {
|
||||
{ role = "system", content = system_prompt },
|
||||
}
|
||||
|
||||
-- Add initial file context if provided
|
||||
if opts.files and #opts.files > 0 then
|
||||
local file_context = "# Initial Files\n"
|
||||
for _, file_path in ipairs(opts.files) do
|
||||
local content = table.concat(vim.fn.readfile(file_path) or {}, "\n")
|
||||
file_context = file_context .. string.format("\n## %s\n```\n%s\n```\n", file_path, content)
|
||||
end
|
||||
table.insert(history, { role = "user", content = file_context })
|
||||
table.insert(history, { role = "assistant", content = "I've reviewed the provided files. What would you like me to do?" })
|
||||
end
|
||||
|
||||
-- Add the task
|
||||
table.insert(history, { role = "user", content = opts.task })
|
||||
|
||||
-- Determine provider
|
||||
local config = require("codetyper").get_config()
|
||||
local provider = config.llm.provider or "copilot"
|
||||
-- Note: Ollama has its own handler in call_llm, don't change it
|
||||
|
||||
-- Get tools for this agent
|
||||
local tool_names = agent.tools or { "view", "edit", "write", "grep", "glob", "bash" }
|
||||
|
||||
-- Ensure tools are loaded
|
||||
local tools_mod = require("codetyper.agent.tools")
|
||||
tools_mod.setup()
|
||||
|
||||
-- Build tools for API
|
||||
local tools = build_tools(tool_names, provider)
|
||||
|
||||
-- Iteration tracking
|
||||
local iteration = 0
|
||||
local max_iterations = opts.max_iterations or 20
|
||||
|
||||
--- Process one iteration
|
||||
local function process_iteration()
|
||||
iteration = iteration + 1
|
||||
|
||||
if iteration > max_iterations then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "Max iterations reached")
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
if opts.on_status then
|
||||
opts.on_status(string.format("Thinking... (iteration %d)", iteration))
|
||||
end
|
||||
|
||||
-- Build messages for API
|
||||
local messages = build_messages(history, provider)
|
||||
|
||||
-- Call LLM
|
||||
call_llm(messages, tools, system_prompt, provider, opts.model, function(response, err)
|
||||
if err then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, err)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Extract content and tool calls
|
||||
local content = extract_content(response, provider)
|
||||
local tool_calls = parse_tool_calls(response, provider)
|
||||
|
||||
-- Add assistant message to history
|
||||
local assistant_msg = {
|
||||
role = "assistant",
|
||||
content = content,
|
||||
tool_calls = #tool_calls > 0 and tool_calls or nil,
|
||||
}
|
||||
table.insert(history, assistant_msg)
|
||||
|
||||
if opts.on_message then
|
||||
opts.on_message(assistant_msg)
|
||||
end
|
||||
|
||||
-- Process tool calls if any
|
||||
if #tool_calls > 0 then
|
||||
for _, tc in ipairs(tool_calls) do
|
||||
local result, tool_err = execute_tool(tc, opts)
|
||||
|
||||
-- Add tool result to history
|
||||
local tool_msg = {
|
||||
role = "tool",
|
||||
tool_call_id = tc.id,
|
||||
name = tc["function"].name,
|
||||
content = tool_err or result,
|
||||
}
|
||||
table.insert(history, tool_msg)
|
||||
|
||||
if opts.on_message then
|
||||
opts.on_message(tool_msg)
|
||||
end
|
||||
end
|
||||
|
||||
-- Continue the loop
|
||||
vim.schedule(process_iteration)
|
||||
else
|
||||
-- No tool calls - check if complete
|
||||
if is_complete(response, provider) or content ~= "" then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(content, nil)
|
||||
end
|
||||
else
|
||||
-- Continue if not explicitly complete
|
||||
vim.schedule(process_iteration)
|
||||
end
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
-- Start the loop
|
||||
process_iteration()
|
||||
end
|
||||
|
||||
--- Create default agent files in .coder/agents/
|
||||
function M.init_agents_dir()
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
vim.fn.mkdir(agents_dir, "p")
|
||||
|
||||
-- Create example agent
|
||||
local example_agent = [[---
|
||||
description: Example custom agent
|
||||
tools: view,grep,glob,edit,write
|
||||
model:
|
||||
---
|
||||
|
||||
# Custom Agent
|
||||
|
||||
You are a custom coding agent. Describe your specialized behavior here.
|
||||
|
||||
## Your Role
|
||||
- Define what this agent specializes in
|
||||
- List specific capabilities
|
||||
|
||||
## Guidelines
|
||||
- Add agent-specific rules
|
||||
- Define coding standards to follow
|
||||
|
||||
## Examples
|
||||
Provide examples of how to handle common tasks.
|
||||
]]
|
||||
|
||||
local example_path = agents_dir .. "/example.md"
|
||||
if vim.fn.filereadable(example_path) ~= 1 then
|
||||
vim.fn.writefile(vim.split(example_agent, "\n"), example_path)
|
||||
end
|
||||
|
||||
return agents_dir
|
||||
end
|
||||
|
||||
--- Create default rules in .coder/rules/
|
||||
function M.init_rules_dir()
|
||||
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
|
||||
vim.fn.mkdir(rules_dir, "p")
|
||||
|
||||
-- Create example rule
|
||||
local example_rule = [[# Code Style
|
||||
|
||||
Follow these coding standards:
|
||||
|
||||
## General
|
||||
- Use consistent indentation (tabs or spaces based on project)
|
||||
- Keep lines under 100 characters
|
||||
- Add comments for complex logic
|
||||
|
||||
## Naming Conventions
|
||||
- Use descriptive variable names
|
||||
- Functions should be verbs (e.g., getUserData, calculateTotal)
|
||||
- Constants in UPPER_SNAKE_CASE
|
||||
|
||||
## Testing
|
||||
- Write tests for new functionality
|
||||
- Aim for >80% code coverage
|
||||
- Test edge cases
|
||||
|
||||
## Documentation
|
||||
- Document public APIs
|
||||
- Include usage examples
|
||||
- Keep docs up to date with code
|
||||
]]
|
||||
|
||||
local example_path = rules_dir .. "/code-style.md"
|
||||
if vim.fn.filereadable(example_path) ~= 1 then
|
||||
vim.fn.writefile(vim.split(example_rule, "\n"), example_path)
|
||||
end
|
||||
|
||||
return rules_dir
|
||||
end
|
||||
|
||||
--- Initialize both agents and rules directories
|
||||
function M.init()
|
||||
M.init_agents_dir()
|
||||
M.init_rules_dir()
|
||||
end
|
||||
|
||||
--- List available agents
|
||||
---@return table[] List of {name, description, builtin}
|
||||
function M.list_agents()
|
||||
local agents = {}
|
||||
|
||||
-- Built-in agents
|
||||
local builtins = { "coder", "planner", "explorer" }
|
||||
for _, name in ipairs(builtins) do
|
||||
local agent = load_agent(name)
|
||||
if agent then
|
||||
table.insert(agents, {
|
||||
name = agent.name,
|
||||
description = agent.description,
|
||||
builtin = true,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Custom agents from .coder/agents/
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
if vim.fn.isdirectory(agents_dir) == 1 then
|
||||
local files = vim.fn.glob(agents_dir .. "/*.md", false, true)
|
||||
for _, file in ipairs(files) do
|
||||
local name = vim.fn.fnamemodify(file, ":t:r")
|
||||
if not vim.tbl_contains(builtins, name) then
|
||||
local agent = load_agent(name)
|
||||
if agent then
|
||||
table.insert(agents, {
|
||||
name = agent.name,
|
||||
description = agent.description,
|
||||
builtin = false,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return agents
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -8,11 +8,11 @@ local M = {}
|
||||
|
||||
--- Heuristic weights (must sum to 1.0)
|
||||
M.weights = {
|
||||
length = 0.15, -- Response length relative to prompt
|
||||
length = 0.15, -- Response length relative to prompt
|
||||
uncertainty = 0.30, -- Uncertainty phrases
|
||||
syntax = 0.25, -- Syntax completeness
|
||||
repetition = 0.15, -- Duplicate lines
|
||||
truncation = 0.15, -- Incomplete ending
|
||||
syntax = 0.25, -- Syntax completeness
|
||||
repetition = 0.15, -- Duplicate lines
|
||||
truncation = 0.15, -- Incomplete ending
|
||||
}
|
||||
|
||||
--- Uncertainty phrases that indicate low confidence
|
||||
@@ -255,14 +255,15 @@ function M.score(response, prompt, context)
|
||||
_ = context -- Reserved for future use
|
||||
|
||||
if not response or #response == 0 then
|
||||
return 0, {
|
||||
length = 0,
|
||||
uncertainty = 0,
|
||||
syntax = 0,
|
||||
repetition = 0,
|
||||
truncation = 0,
|
||||
weighted_total = 0,
|
||||
}
|
||||
return 0,
|
||||
{
|
||||
length = 0,
|
||||
uncertainty = 0,
|
||||
syntax = 0,
|
||||
repetition = 0,
|
||||
truncation = 0,
|
||||
weighted_total = 0,
|
||||
}
|
||||
end
|
||||
|
||||
local scores = {
|
||||
|
||||
@@ -111,7 +111,11 @@ function M.agent_loop(context, callbacks)
|
||||
logs.thinking("Calling LLM with " .. #state.conversation .. " messages...")
|
||||
|
||||
-- Generate with tools enabled
|
||||
client.generate_with_tools(state.conversation, context, tools.definitions, function(response, err)
|
||||
-- Ensure tools are loaded and get definitions
|
||||
tools.setup()
|
||||
local tool_defs = tools.to_openai_format()
|
||||
|
||||
client.generate_with_tools(state.conversation, context, tool_defs, function(response, err)
|
||||
if err then
|
||||
state.is_running = false
|
||||
callbacks.on_error(err)
|
||||
|
||||
614
lua/codetyper/agent/inject.lua
Normal file
614
lua/codetyper/agent/inject.lua
Normal file
@@ -0,0 +1,614 @@
|
||||
---@mod codetyper.agent.inject Smart code injection with import handling
|
||||
---@brief [[
|
||||
--- Intelligent code injection that properly handles imports, merging them
|
||||
--- into existing import sections instead of blindly appending.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class ImportConfig
|
||||
---@field pattern string Lua pattern to match import statements
|
||||
---@field multi_line boolean Whether imports can span multiple lines
|
||||
---@field sort_key function|nil Function to extract sort key from import
|
||||
---@field group_by function|nil Function to group imports
|
||||
|
||||
---@class ParsedCode
|
||||
---@field imports string[] Import statements
|
||||
---@field body string[] Non-import code lines
|
||||
---@field import_lines table<number, boolean> Map of line numbers that are imports
|
||||
|
||||
--- Language-specific import patterns
|
||||
local import_patterns = {
|
||||
-- JavaScript/TypeScript
|
||||
javascript = {
|
||||
{ pattern = "^%s*import%s+.+%s+from%s+['\"]", multi_line = true },
|
||||
{ pattern = "^%s*import%s+['\"]", multi_line = false },
|
||||
{ pattern = "^%s*import%s*{", multi_line = true },
|
||||
{ pattern = "^%s*import%s*%*", multi_line = true },
|
||||
{ pattern = "^%s*export%s+{.+}%s+from%s+['\"]", multi_line = true },
|
||||
{ pattern = "^%s*const%s+%w+%s*=%s*require%(['\"]", multi_line = false },
|
||||
{ pattern = "^%s*let%s+%w+%s*=%s*require%(['\"]", multi_line = false },
|
||||
{ pattern = "^%s*var%s+%w+%s*=%s*require%(['\"]", multi_line = false },
|
||||
},
|
||||
-- Python
|
||||
python = {
|
||||
{ pattern = "^%s*import%s+%w", multi_line = false },
|
||||
{ pattern = "^%s*from%s+[%w%.]+%s+import%s+", multi_line = true },
|
||||
},
|
||||
-- Lua
|
||||
lua = {
|
||||
{ pattern = "^%s*local%s+%w+%s*=%s*require%s*%(?['\"]", multi_line = false },
|
||||
{ pattern = "^%s*require%s*%(?['\"]", multi_line = false },
|
||||
},
|
||||
-- Go
|
||||
go = {
|
||||
{ pattern = "^%s*import%s+%(?", multi_line = true },
|
||||
},
|
||||
-- Rust
|
||||
rust = {
|
||||
{ pattern = "^%s*use%s+", multi_line = true },
|
||||
{ pattern = "^%s*extern%s+crate%s+", multi_line = false },
|
||||
},
|
||||
-- C/C++
|
||||
c = {
|
||||
{ pattern = "^%s*#include%s*[<\"]", multi_line = false },
|
||||
},
|
||||
-- Java/Kotlin
|
||||
java = {
|
||||
{ pattern = "^%s*import%s+", multi_line = false },
|
||||
},
|
||||
-- Ruby
|
||||
ruby = {
|
||||
{ pattern = "^%s*require%s+['\"]", multi_line = false },
|
||||
{ pattern = "^%s*require_relative%s+['\"]", multi_line = false },
|
||||
},
|
||||
-- PHP
|
||||
php = {
|
||||
{ pattern = "^%s*use%s+", multi_line = false },
|
||||
{ pattern = "^%s*require%s+['\"]", multi_line = false },
|
||||
{ pattern = "^%s*require_once%s+['\"]", multi_line = false },
|
||||
{ pattern = "^%s*include%s+['\"]", multi_line = false },
|
||||
{ pattern = "^%s*include_once%s+['\"]", multi_line = false },
|
||||
},
|
||||
}
|
||||
|
||||
-- Alias common extensions to language configs
|
||||
import_patterns.ts = import_patterns.javascript
|
||||
import_patterns.tsx = import_patterns.javascript
|
||||
import_patterns.jsx = import_patterns.javascript
|
||||
import_patterns.mjs = import_patterns.javascript
|
||||
import_patterns.cjs = import_patterns.javascript
|
||||
import_patterns.py = import_patterns.python
|
||||
import_patterns.cpp = import_patterns.c
|
||||
import_patterns.hpp = import_patterns.c
|
||||
import_patterns.h = import_patterns.c
|
||||
import_patterns.kt = import_patterns.java
|
||||
import_patterns.rs = import_patterns.rust
|
||||
import_patterns.rb = import_patterns.ruby
|
||||
|
||||
--- Check if a line is an import statement for the given language
|
||||
---@param line string
|
||||
---@param patterns table[] Import patterns for the language
|
||||
---@return boolean is_import
|
||||
---@return boolean is_multi_line
|
||||
local function is_import_line(line, patterns)
|
||||
for _, p in ipairs(patterns) do
|
||||
if line:match(p.pattern) then
|
||||
return true, p.multi_line or false
|
||||
end
|
||||
end
|
||||
return false, false
|
||||
end
|
||||
|
||||
--- Check if a line is empty or a comment
|
||||
---@param line string
|
||||
---@param filetype string
|
||||
---@return boolean
|
||||
local function is_empty_or_comment(line, filetype)
|
||||
local trimmed = line:match("^%s*(.-)%s*$")
|
||||
if trimmed == "" then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Language-specific comment patterns
|
||||
local comment_patterns = {
|
||||
lua = { "^%-%-" },
|
||||
python = { "^#" },
|
||||
javascript = { "^//", "^/%*", "^%*" },
|
||||
typescript = { "^//", "^/%*", "^%*" },
|
||||
go = { "^//", "^/%*", "^%*" },
|
||||
rust = { "^//", "^/%*", "^%*" },
|
||||
c = { "^//", "^/%*", "^%*", "^#" },
|
||||
java = { "^//", "^/%*", "^%*" },
|
||||
ruby = { "^#" },
|
||||
php = { "^//", "^/%*", "^%*", "^#" },
|
||||
}
|
||||
|
||||
local patterns = comment_patterns[filetype] or comment_patterns.javascript
|
||||
for _, pattern in ipairs(patterns) do
|
||||
if trimmed:match(pattern) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Check if a line ends a multi-line import
|
||||
---@param line string
|
||||
---@param filetype string
|
||||
---@return boolean
|
||||
local function ends_multiline_import(line, filetype)
|
||||
-- Check for closing patterns
|
||||
if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then
|
||||
-- ES6 imports end with 'from "..." ;' or just ';' or a line with just '}'
|
||||
if line:match("from%s+['\"][^'\"]+['\"]%s*;?%s*$") then
|
||||
return true
|
||||
end
|
||||
if line:match("}%s*from%s+['\"]") then
|
||||
return true
|
||||
end
|
||||
if line:match("^%s*}%s*;?%s*$") then
|
||||
return true
|
||||
end
|
||||
if line:match(";%s*$") then
|
||||
return true
|
||||
end
|
||||
elseif filetype == "python" or filetype == "py" then
|
||||
-- Python single-line import: doesn't end with \, (, or ,
|
||||
-- Examples: "from typing import List, Dict" or "import os"
|
||||
if not line:match("\\%s*$") and not line:match("%(%s*$") and not line:match(",%s*$") then
|
||||
return true
|
||||
end
|
||||
-- Python multiline imports end with closing paren
|
||||
if line:match("%)%s*$") then
|
||||
return true
|
||||
end
|
||||
elseif filetype == "go" then
|
||||
-- Go multi-line imports end with ')'
|
||||
if line:match("%)%s*$") then
|
||||
return true
|
||||
end
|
||||
elseif filetype == "rust" or filetype == "rs" then
|
||||
-- Rust use statements end with ';'
|
||||
if line:match(";%s*$") then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Parse code into imports and body
|
||||
---@param code string|string[] Code to parse
|
||||
---@param filetype string File type/extension
|
||||
---@return ParsedCode
|
||||
function M.parse_code(code, filetype)
|
||||
local lines
|
||||
if type(code) == "string" then
|
||||
lines = vim.split(code, "\n", { plain = true })
|
||||
else
|
||||
lines = code
|
||||
end
|
||||
|
||||
local patterns = import_patterns[filetype] or import_patterns.javascript
|
||||
|
||||
local result = {
|
||||
imports = {},
|
||||
body = {},
|
||||
import_lines = {},
|
||||
}
|
||||
|
||||
local in_multiline_import = false
|
||||
local current_import_lines = {}
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
if in_multiline_import then
|
||||
-- Continue collecting multi-line import
|
||||
table.insert(current_import_lines, line)
|
||||
|
||||
if ends_multiline_import(line, filetype) then
|
||||
-- Complete the multi-line import
|
||||
table.insert(result.imports, table.concat(current_import_lines, "\n"))
|
||||
for j = i - #current_import_lines + 1, i do
|
||||
result.import_lines[j] = true
|
||||
end
|
||||
current_import_lines = {}
|
||||
in_multiline_import = false
|
||||
end
|
||||
else
|
||||
local is_import, is_multi = is_import_line(line, patterns)
|
||||
|
||||
if is_import then
|
||||
result.import_lines[i] = true
|
||||
|
||||
if is_multi and not ends_multiline_import(line, filetype) then
|
||||
-- Start of multi-line import
|
||||
in_multiline_import = true
|
||||
current_import_lines = { line }
|
||||
else
|
||||
-- Single-line import
|
||||
table.insert(result.imports, line)
|
||||
end
|
||||
else
|
||||
-- Non-import line
|
||||
table.insert(result.body, line)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Handle unclosed multi-line import (shouldn't happen with well-formed code)
|
||||
if #current_import_lines > 0 then
|
||||
table.insert(result.imports, table.concat(current_import_lines, "\n"))
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Find the import section range in a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@param filetype string
|
||||
---@return number|nil start_line First import line (1-indexed)
|
||||
---@return number|nil end_line Last import line (1-indexed)
|
||||
function M.find_import_section(bufnr, filetype)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local patterns = import_patterns[filetype] or import_patterns.javascript
|
||||
|
||||
local first_import = nil
|
||||
local last_import = nil
|
||||
local in_multiline = false
|
||||
local consecutive_non_import = 0
|
||||
local max_gap = 3 -- Allow up to 3 blank/comment lines between imports
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
if in_multiline then
|
||||
last_import = i
|
||||
consecutive_non_import = 0
|
||||
|
||||
if ends_multiline_import(line, filetype) then
|
||||
in_multiline = false
|
||||
end
|
||||
else
|
||||
local is_import, is_multi = is_import_line(line, patterns)
|
||||
|
||||
if is_import then
|
||||
if not first_import then
|
||||
first_import = i
|
||||
end
|
||||
last_import = i
|
||||
consecutive_non_import = 0
|
||||
|
||||
if is_multi and not ends_multiline_import(line, filetype) then
|
||||
in_multiline = true
|
||||
end
|
||||
elseif is_empty_or_comment(line, filetype) then
|
||||
-- Allow gaps in import section
|
||||
if first_import then
|
||||
consecutive_non_import = consecutive_non_import + 1
|
||||
if consecutive_non_import > max_gap then
|
||||
-- Too many non-import lines, import section has ended
|
||||
break
|
||||
end
|
||||
end
|
||||
else
|
||||
-- Non-import, non-empty line
|
||||
if first_import then
|
||||
-- Import section has ended
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return first_import, last_import
|
||||
end
|
||||
|
||||
--- Get existing imports from a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@param filetype string
|
||||
---@return string[] Existing import statements
|
||||
function M.get_existing_imports(bufnr, filetype)
|
||||
local start_line, end_line = M.find_import_section(bufnr, filetype)
|
||||
if not start_line then
|
||||
return {}
|
||||
end
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false)
|
||||
local parsed = M.parse_code(lines, filetype)
|
||||
return parsed.imports
|
||||
end
|
||||
|
||||
--- Normalize an import for comparison (remove whitespace variations)
|
||||
---@param import_str string
|
||||
---@return string
|
||||
local function normalize_import(import_str)
|
||||
-- Remove trailing semicolon for comparison
|
||||
local normalized = import_str:gsub(";%s*$", "")
|
||||
-- Remove all whitespace around braces, commas, colons
|
||||
normalized = normalized:gsub("%s*{%s*", "{")
|
||||
normalized = normalized:gsub("%s*}%s*", "}")
|
||||
normalized = normalized:gsub("%s*,%s*", ",")
|
||||
normalized = normalized:gsub("%s*:%s*", ":")
|
||||
-- Collapse multiple whitespace to single space
|
||||
normalized = normalized:gsub("%s+", " ")
|
||||
-- Trim leading/trailing whitespace
|
||||
normalized = normalized:match("^%s*(.-)%s*$")
|
||||
return normalized
|
||||
end
|
||||
|
||||
--- Check if two imports are duplicates
|
||||
---@param import1 string
|
||||
---@param import2 string
|
||||
---@return boolean
|
||||
local function are_duplicate_imports(import1, import2)
|
||||
return normalize_import(import1) == normalize_import(import2)
|
||||
end
|
||||
|
||||
--- Merge new imports with existing ones, avoiding duplicates
|
||||
---@param existing string[] Existing imports
|
||||
---@param new_imports string[] New imports to merge
|
||||
---@return string[] Merged imports
|
||||
function M.merge_imports(existing, new_imports)
|
||||
local merged = {}
|
||||
local seen = {}
|
||||
|
||||
-- Add existing imports
|
||||
for _, imp in ipairs(existing) do
|
||||
local normalized = normalize_import(imp)
|
||||
if not seen[normalized] then
|
||||
seen[normalized] = true
|
||||
table.insert(merged, imp)
|
||||
end
|
||||
end
|
||||
|
||||
-- Add new imports that aren't duplicates
|
||||
for _, imp in ipairs(new_imports) do
|
||||
local normalized = normalize_import(imp)
|
||||
if not seen[normalized] then
|
||||
seen[normalized] = true
|
||||
table.insert(merged, imp)
|
||||
end
|
||||
end
|
||||
|
||||
return merged
|
||||
end
|
||||
|
||||
--- Sort imports by their source/module
|
||||
---@param imports string[]
|
||||
---@param filetype string
|
||||
---@return string[]
|
||||
function M.sort_imports(imports, filetype)
|
||||
-- Group imports: stdlib/builtin first, then third-party, then local
|
||||
local builtin = {}
|
||||
local third_party = {}
|
||||
local local_imports = {}
|
||||
|
||||
for _, imp in ipairs(imports) do
|
||||
-- Detect import type based on patterns
|
||||
local is_local = false
|
||||
local is_builtin = false
|
||||
|
||||
if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then
|
||||
-- Local: starts with . or ..
|
||||
is_local = imp:match("from%s+['\"]%.") or imp:match("require%(['\"]%.")
|
||||
-- Node builtin modules
|
||||
is_builtin = imp:match("from%s+['\"]node:") or imp:match("from%s+['\"]fs['\"]")
|
||||
or imp:match("from%s+['\"]path['\"]") or imp:match("from%s+['\"]http['\"]")
|
||||
elseif filetype == "python" or filetype == "py" then
|
||||
-- Local: relative imports
|
||||
is_local = imp:match("^from%s+%.") or imp:match("^import%s+%.")
|
||||
-- Python stdlib (simplified check)
|
||||
is_builtin = imp:match("^import%s+os") or imp:match("^import%s+sys")
|
||||
or imp:match("^from%s+os%s+") or imp:match("^from%s+sys%s+")
|
||||
or imp:match("^import%s+re") or imp:match("^import%s+json")
|
||||
elseif filetype == "lua" then
|
||||
-- Local: relative requires
|
||||
is_local = imp:match("require%(['\"]%.") or imp:match("require%s+['\"]%.")
|
||||
elseif filetype == "go" then
|
||||
-- Local: project imports (contain /)
|
||||
is_local = imp:match("['\"][^'\"]+/[^'\"]+['\"]") and not imp:match("github%.com")
|
||||
end
|
||||
|
||||
if is_builtin then
|
||||
table.insert(builtin, imp)
|
||||
elseif is_local then
|
||||
table.insert(local_imports, imp)
|
||||
else
|
||||
table.insert(third_party, imp)
|
||||
end
|
||||
end
|
||||
|
||||
-- Sort each group alphabetically
|
||||
table.sort(builtin)
|
||||
table.sort(third_party)
|
||||
table.sort(local_imports)
|
||||
|
||||
-- Combine with proper spacing
|
||||
local result = {}
|
||||
|
||||
for _, imp in ipairs(builtin) do
|
||||
table.insert(result, imp)
|
||||
end
|
||||
if #builtin > 0 and (#third_party > 0 or #local_imports > 0) then
|
||||
table.insert(result, "") -- Blank line between groups
|
||||
end
|
||||
|
||||
for _, imp in ipairs(third_party) do
|
||||
table.insert(result, imp)
|
||||
end
|
||||
if #third_party > 0 and #local_imports > 0 then
|
||||
table.insert(result, "") -- Blank line between groups
|
||||
end
|
||||
|
||||
for _, imp in ipairs(local_imports) do
|
||||
table.insert(result, imp)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
---@class InjectResult
|
||||
---@field success boolean
|
||||
---@field imports_added number Number of new imports added
|
||||
---@field imports_merged boolean Whether imports were merged into existing section
|
||||
---@field body_lines number Number of body lines injected
|
||||
|
||||
--- Smart inject code into a buffer, properly handling imports
|
||||
---@param bufnr number Target buffer
|
||||
---@param code string|string[] Code to inject
|
||||
---@param opts table Options: { strategy: "append"|"replace"|"insert", range: {start_line, end_line}|nil, filetype: string|nil, sort_imports: boolean|nil }
|
||||
---@return InjectResult
|
||||
function M.inject(bufnr, code, opts)
|
||||
opts = opts or {}
|
||||
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return { success = false, imports_added = 0, imports_merged = false, body_lines = 0 }
|
||||
end
|
||||
|
||||
-- Get filetype
|
||||
local filetype = opts.filetype
|
||||
if not filetype then
|
||||
local bufname = vim.api.nvim_buf_get_name(bufnr)
|
||||
filetype = vim.fn.fnamemodify(bufname, ":e")
|
||||
end
|
||||
|
||||
-- Parse the code to separate imports from body
|
||||
local parsed = M.parse_code(code, filetype)
|
||||
|
||||
local result = {
|
||||
success = true,
|
||||
imports_added = 0,
|
||||
imports_merged = false,
|
||||
body_lines = #parsed.body,
|
||||
}
|
||||
|
||||
-- Handle imports first if there are any
|
||||
if #parsed.imports > 0 then
|
||||
local import_start, import_end = M.find_import_section(bufnr, filetype)
|
||||
|
||||
if import_start then
|
||||
-- Merge with existing import section
|
||||
local existing_imports = M.get_existing_imports(bufnr, filetype)
|
||||
local merged = M.merge_imports(existing_imports, parsed.imports)
|
||||
|
||||
-- Count how many new imports were actually added
|
||||
result.imports_added = #merged - #existing_imports
|
||||
result.imports_merged = true
|
||||
|
||||
-- Optionally sort imports
|
||||
if opts.sort_imports ~= false then
|
||||
merged = M.sort_imports(merged, filetype)
|
||||
end
|
||||
|
||||
-- Convert back to lines (handling multi-line imports)
|
||||
local import_lines = {}
|
||||
for _, imp in ipairs(merged) do
|
||||
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
|
||||
table.insert(import_lines, line)
|
||||
end
|
||||
end
|
||||
|
||||
-- Replace the import section
|
||||
vim.api.nvim_buf_set_lines(bufnr, import_start - 1, import_end, false, import_lines)
|
||||
|
||||
-- Adjust line numbers for body injection
|
||||
local lines_diff = #import_lines - (import_end - import_start + 1)
|
||||
if opts.range and opts.range.start_line and opts.range.start_line > import_end then
|
||||
opts.range.start_line = opts.range.start_line + lines_diff
|
||||
if opts.range.end_line then
|
||||
opts.range.end_line = opts.range.end_line + lines_diff
|
||||
end
|
||||
end
|
||||
else
|
||||
-- No existing import section, add imports at the top
|
||||
-- Find the first non-comment, non-empty line
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local insert_at = 0
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
local trimmed = line:match("^%s*(.-)%s*$")
|
||||
-- Skip shebang, docstrings, and initial comments
|
||||
if trimmed ~= "" and not trimmed:match("^#!")
|
||||
and not trimmed:match("^['\"]") and not is_empty_or_comment(line, filetype) then
|
||||
insert_at = i - 1
|
||||
break
|
||||
end
|
||||
insert_at = i
|
||||
end
|
||||
|
||||
-- Add imports with a trailing blank line
|
||||
local import_lines = {}
|
||||
for _, imp in ipairs(parsed.imports) do
|
||||
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
|
||||
table.insert(import_lines, line)
|
||||
end
|
||||
end
|
||||
table.insert(import_lines, "") -- Blank line after imports
|
||||
|
||||
vim.api.nvim_buf_set_lines(bufnr, insert_at, insert_at, false, import_lines)
|
||||
result.imports_added = #parsed.imports
|
||||
result.imports_merged = false
|
||||
|
||||
-- Adjust body injection range
|
||||
if opts.range and opts.range.start_line then
|
||||
opts.range.start_line = opts.range.start_line + #import_lines
|
||||
if opts.range.end_line then
|
||||
opts.range.end_line = opts.range.end_line + #import_lines
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Handle body (non-import) code
|
||||
if #parsed.body > 0 then
|
||||
-- Filter out empty leading/trailing lines from body
|
||||
local body_lines = parsed.body
|
||||
while #body_lines > 0 and body_lines[1]:match("^%s*$") do
|
||||
table.remove(body_lines, 1)
|
||||
end
|
||||
while #body_lines > 0 and body_lines[#body_lines]:match("^%s*$") do
|
||||
table.remove(body_lines)
|
||||
end
|
||||
|
||||
if #body_lines > 0 then
|
||||
local line_count = vim.api.nvim_buf_line_count(bufnr)
|
||||
local strategy = opts.strategy or "append"
|
||||
|
||||
if strategy == "replace" and opts.range then
|
||||
local start_line = math.max(1, opts.range.start_line)
|
||||
local end_line = math.min(line_count, opts.range.end_line)
|
||||
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, body_lines)
|
||||
elseif strategy == "insert" and opts.range then
|
||||
local insert_line = math.max(0, math.min(line_count, opts.range.start_line - 1))
|
||||
vim.api.nvim_buf_set_lines(bufnr, insert_line, insert_line, false, body_lines)
|
||||
else
|
||||
-- Default: append
|
||||
local last_line = vim.api.nvim_buf_get_lines(bufnr, line_count - 1, line_count, false)[1] or ""
|
||||
if last_line:match("%S") then
|
||||
-- Add blank line for spacing
|
||||
table.insert(body_lines, 1, "")
|
||||
end
|
||||
vim.api.nvim_buf_set_lines(bufnr, line_count, line_count, false, body_lines)
|
||||
end
|
||||
|
||||
result.body_lines = #body_lines
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Check if code contains imports
|
||||
---@param code string|string[]
|
||||
---@param filetype string
|
||||
---@return boolean
|
||||
function M.has_imports(code, filetype)
|
||||
local parsed = M.parse_code(code, filetype)
|
||||
return #parsed.imports > 0
|
||||
end
|
||||
|
||||
return M
|
||||
398
lua/codetyper/agent/loop.lua
Normal file
398
lua/codetyper/agent/loop.lua
Normal file
@@ -0,0 +1,398 @@
|
||||
---@mod codetyper.agent.loop Agent loop with tool orchestration
|
||||
---@brief [[
|
||||
--- Main agent loop that handles multi-turn conversations with tool use.
|
||||
--- Inspired by avante.nvim's agent_loop pattern.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class AgentMessage
|
||||
---@field role "system"|"user"|"assistant"|"tool"
|
||||
---@field content string|table
|
||||
---@field tool_call_id? string For tool responses
|
||||
---@field tool_calls? table[] For assistant tool calls
|
||||
---@field name? string Tool name for tool responses
|
||||
|
||||
---@class AgentLoopOpts
|
||||
---@field system_prompt string System prompt
|
||||
---@field user_input string Initial user message
|
||||
---@field tools? CoderTool[] Available tools (default: all registered)
|
||||
---@field max_iterations? number Max tool call iterations (default: 10)
|
||||
---@field provider? string LLM provider to use
|
||||
---@field on_start? fun() Called when loop starts
|
||||
---@field on_chunk? fun(chunk: string) Called for each response chunk
|
||||
---@field on_tool_call? fun(name: string, input: table) Called before tool execution
|
||||
---@field on_tool_result? fun(name: string, result: any, error: string|nil) Called after tool execution
|
||||
---@field on_message? fun(message: AgentMessage) Called for each message added
|
||||
---@field on_complete? fun(result: string|nil, error: string|nil) Called when loop completes
|
||||
---@field session_ctx? table Session context shared across tools
|
||||
|
||||
--- Format tool definitions for OpenAI-compatible API
|
||||
---@param tools CoderTool[]
|
||||
---@return table[]
|
||||
local function format_tools_for_api(tools)
|
||||
local formatted = {}
|
||||
for _, tool in ipairs(tools) do
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(tool.params or {}) do
|
||||
properties[param.name] = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
table.insert(formatted, {
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool.name,
|
||||
description = type(tool.description) == "function" and tool.description() or tool.description,
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
},
|
||||
})
|
||||
end
|
||||
return formatted
|
||||
end
|
||||
|
||||
--- Parse tool calls from LLM response
|
||||
---@param response table LLM response
|
||||
---@return table[] tool_calls
|
||||
local function parse_tool_calls(response)
|
||||
local tool_calls = {}
|
||||
|
||||
-- Handle different response formats
|
||||
if response.tool_calls then
|
||||
-- OpenAI format
|
||||
for _, call in ipairs(response.tool_calls) do
|
||||
local args = call["function"].arguments
|
||||
if type(args) == "string" then
|
||||
local ok, parsed = pcall(vim.json.decode, args)
|
||||
if ok then
|
||||
args = parsed
|
||||
end
|
||||
end
|
||||
table.insert(tool_calls, {
|
||||
id = call.id,
|
||||
name = call["function"].name,
|
||||
input = args,
|
||||
})
|
||||
end
|
||||
elseif response.content and type(response.content) == "table" then
|
||||
-- Claude format (content blocks)
|
||||
for _, block in ipairs(response.content) do
|
||||
if block.type == "tool_use" then
|
||||
table.insert(tool_calls, {
|
||||
id = block.id,
|
||||
name = block.name,
|
||||
input = block.input,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return tool_calls
|
||||
end
|
||||
|
||||
--- Build messages for LLM request
|
||||
---@param history AgentMessage[]
|
||||
---@return table[]
|
||||
local function build_messages(history)
|
||||
local messages = {}
|
||||
|
||||
for _, msg in ipairs(history) do
|
||||
if msg.role == "system" then
|
||||
table.insert(messages, {
|
||||
role = "system",
|
||||
content = msg.content,
|
||||
})
|
||||
elseif msg.role == "user" then
|
||||
table.insert(messages, {
|
||||
role = "user",
|
||||
content = msg.content,
|
||||
})
|
||||
elseif msg.role == "assistant" then
|
||||
local message = {
|
||||
role = "assistant",
|
||||
content = msg.content,
|
||||
}
|
||||
if msg.tool_calls then
|
||||
message.tool_calls = msg.tool_calls
|
||||
end
|
||||
table.insert(messages, message)
|
||||
elseif msg.role == "tool" then
|
||||
table.insert(messages, {
|
||||
role = "tool",
|
||||
tool_call_id = msg.tool_call_id,
|
||||
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return messages
|
||||
end
|
||||
|
||||
--- Execute the agent loop
|
||||
---@param opts AgentLoopOpts
|
||||
function M.run(opts)
|
||||
local tools_mod = require("codetyper.agent.tools")
|
||||
local llm = require("codetyper.llm")
|
||||
|
||||
-- Get tools
|
||||
local tools = opts.tools or tools_mod.list()
|
||||
local tool_map = {}
|
||||
for _, tool in ipairs(tools) do
|
||||
tool_map[tool.name] = tool
|
||||
end
|
||||
|
||||
-- Initialize conversation history
|
||||
---@type AgentMessage[]
|
||||
local history = {
|
||||
{ role = "system", content = opts.system_prompt },
|
||||
{ role = "user", content = opts.user_input },
|
||||
}
|
||||
|
||||
local session_ctx = opts.session_ctx or {}
|
||||
local max_iterations = opts.max_iterations or 10
|
||||
local iteration = 0
|
||||
|
||||
-- Callback wrappers
|
||||
local function on_message(msg)
|
||||
if opts.on_message then
|
||||
opts.on_message(msg)
|
||||
end
|
||||
end
|
||||
|
||||
-- Notify of initial messages
|
||||
for _, msg in ipairs(history) do
|
||||
on_message(msg)
|
||||
end
|
||||
|
||||
-- Start notification
|
||||
if opts.on_start then
|
||||
opts.on_start()
|
||||
end
|
||||
|
||||
--- Process one iteration of the loop
|
||||
local function process_iteration()
|
||||
iteration = iteration + 1
|
||||
|
||||
if iteration > max_iterations then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "Max iterations reached")
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Build request
|
||||
local messages = build_messages(history)
|
||||
local formatted_tools = format_tools_for_api(tools)
|
||||
|
||||
-- Build context for LLM
|
||||
local context = {
|
||||
file_content = "",
|
||||
language = "lua",
|
||||
extension = "lua",
|
||||
prompt_type = "agent",
|
||||
tools = formatted_tools,
|
||||
}
|
||||
|
||||
-- Get LLM response
|
||||
local client = llm.get_client()
|
||||
if not client then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "No LLM client available")
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Build prompt from messages
|
||||
local prompt_parts = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(prompt_parts, string.format("[%s]: %s", msg.role, msg.content or ""))
|
||||
end
|
||||
end
|
||||
local prompt = table.concat(prompt_parts, "\n\n")
|
||||
|
||||
client.generate(prompt, context, function(response, error)
|
||||
if error then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, error)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Chunk callback
|
||||
if opts.on_chunk then
|
||||
opts.on_chunk(response)
|
||||
end
|
||||
|
||||
-- Parse response for tool calls
|
||||
-- For now, we'll use a simple heuristic to detect tool calls in the response
|
||||
-- In a full implementation, the LLM would return structured tool calls
|
||||
local tool_calls = {}
|
||||
|
||||
-- Try to parse JSON tool calls from response
|
||||
local json_match = response:match("```json%s*(%b{})%s*```")
|
||||
if json_match then
|
||||
local ok, parsed = pcall(vim.json.decode, json_match)
|
||||
if ok and parsed.tool_calls then
|
||||
tool_calls = parsed.tool_calls
|
||||
end
|
||||
end
|
||||
|
||||
-- Add assistant message
|
||||
local assistant_msg = {
|
||||
role = "assistant",
|
||||
content = response,
|
||||
tool_calls = #tool_calls > 0 and tool_calls or nil,
|
||||
}
|
||||
table.insert(history, assistant_msg)
|
||||
on_message(assistant_msg)
|
||||
|
||||
-- Process tool calls
|
||||
if #tool_calls > 0 then
|
||||
local pending = #tool_calls
|
||||
local results = {}
|
||||
|
||||
for i, call in ipairs(tool_calls) do
|
||||
local tool = tool_map[call.name]
|
||||
if not tool then
|
||||
results[i] = { error = "Unknown tool: " .. call.name }
|
||||
pending = pending - 1
|
||||
else
|
||||
-- Notify of tool call
|
||||
if opts.on_tool_call then
|
||||
opts.on_tool_call(call.name, call.input)
|
||||
end
|
||||
|
||||
-- Execute tool
|
||||
local tool_opts = {
|
||||
on_log = function(msg)
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
logs.add({ type = "tool", message = msg })
|
||||
end)
|
||||
end,
|
||||
on_complete = function(result, err)
|
||||
results[i] = { result = result, error = err }
|
||||
pending = pending - 1
|
||||
|
||||
-- Notify of tool result
|
||||
if opts.on_tool_result then
|
||||
opts.on_tool_result(call.name, result, err)
|
||||
end
|
||||
|
||||
-- Add tool response to history
|
||||
local tool_msg = {
|
||||
role = "tool",
|
||||
tool_call_id = call.id or tostring(i),
|
||||
name = call.name,
|
||||
content = err or result,
|
||||
}
|
||||
table.insert(history, tool_msg)
|
||||
on_message(tool_msg)
|
||||
|
||||
-- Continue loop when all tools complete
|
||||
if pending == 0 then
|
||||
vim.schedule(process_iteration)
|
||||
end
|
||||
end,
|
||||
session_ctx = session_ctx,
|
||||
}
|
||||
|
||||
-- Validate and execute
|
||||
local valid, validation_err = true, nil
|
||||
if tool.validate_input then
|
||||
valid, validation_err = tool:validate_input(call.input)
|
||||
end
|
||||
|
||||
if not valid then
|
||||
tool_opts.on_complete(nil, validation_err)
|
||||
else
|
||||
local result, err = tool.func(call.input, tool_opts)
|
||||
-- If sync result, call on_complete
|
||||
if result ~= nil or err ~= nil then
|
||||
tool_opts.on_complete(result, err)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
else
|
||||
-- No tool calls - loop complete
|
||||
if opts.on_complete then
|
||||
opts.on_complete(response, nil)
|
||||
end
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
-- Start the loop
|
||||
process_iteration()
|
||||
end
|
||||
|
||||
--- Create an agent with default settings
|
||||
---@param task string Task description
|
||||
---@param opts? AgentLoopOpts Additional options
|
||||
function M.create(task, opts)
|
||||
opts = opts or {}
|
||||
|
||||
local system_prompt = opts.system_prompt or [[You are a helpful coding assistant with access to tools.
|
||||
|
||||
Available tools:
|
||||
- view: Read file contents
|
||||
- grep: Search for patterns in files
|
||||
- glob: Find files by pattern
|
||||
- edit: Make targeted edits to files
|
||||
- write: Create or overwrite files
|
||||
- bash: Execute shell commands
|
||||
|
||||
When you need to perform a task:
|
||||
1. Use tools to gather information
|
||||
2. Plan your approach
|
||||
3. Execute changes using appropriate tools
|
||||
4. Verify the results
|
||||
|
||||
Always explain your reasoning before using tools.
|
||||
When you're done, provide a clear summary of what was accomplished.]]
|
||||
|
||||
M.run(vim.tbl_extend("force", opts, {
|
||||
system_prompt = system_prompt,
|
||||
user_input = task,
|
||||
}))
|
||||
end
|
||||
|
||||
--- Simple dispatch agent for sub-tasks
|
||||
---@param prompt string Task for the sub-agent
|
||||
---@param on_complete fun(result: string|nil, error: string|nil) Completion callback
|
||||
---@param opts? table Additional options
|
||||
function M.dispatch(prompt, on_complete, opts)
|
||||
opts = opts or {}
|
||||
|
||||
-- Sub-agents get limited tools by default
|
||||
local tools_mod = require("codetyper.agent.tools")
|
||||
local safe_tools = tools_mod.list(function(tool)
|
||||
return tool.name == "view" or tool.name == "grep" or tool.name == "glob"
|
||||
end)
|
||||
|
||||
M.run({
|
||||
system_prompt = [[You are a research assistant. Your task is to find information and report back.
|
||||
You have access to: view (read files), grep (search content), glob (find files).
|
||||
Be thorough and report your findings clearly.]],
|
||||
user_input = prompt,
|
||||
tools = opts.tools or safe_tools,
|
||||
max_iterations = opts.max_iterations or 5,
|
||||
on_complete = on_complete,
|
||||
session_ctx = opts.session_ctx,
|
||||
})
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -2,10 +2,16 @@
|
||||
---@brief [[
|
||||
--- Manages code patches with buffer snapshots for staleness detection.
|
||||
--- Patches are queued for safe injection when completion popup is not visible.
|
||||
--- Uses smart injection for intelligent import merging.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Lazy load inject module to avoid circular requires
|
||||
local function get_inject_module()
|
||||
return require("codetyper.agent.inject")
|
||||
end
|
||||
|
||||
---@class BufferSnapshot
|
||||
---@field bufnr number Buffer number
|
||||
---@field changedtick number vim.b.changedtick at snapshot time
|
||||
@@ -15,7 +21,8 @@ local M = {}
|
||||
---@class PatchCandidate
|
||||
---@field id string Unique patch ID
|
||||
---@field event_id string Related PromptEvent ID
|
||||
---@field target_bufnr number Target buffer for injection
|
||||
---@field source_bufnr number Source buffer where prompt tags are (coder file)
|
||||
---@field target_bufnr number Target buffer for injection (real file)
|
||||
---@field target_path string Target file path
|
||||
---@field original_snapshot BufferSnapshot Snapshot at event creation
|
||||
---@field generated_code string Code to inject
|
||||
@@ -171,7 +178,10 @@ end
|
||||
---@param strategy string|nil Injection strategy (overrides intent-based)
|
||||
---@return PatchCandidate
|
||||
function M.create_from_event(event, generated_code, confidence, strategy)
|
||||
-- Get target buffer
|
||||
-- Source buffer is where the prompt tags are (could be coder file)
|
||||
local source_bufnr = event.bufnr
|
||||
|
||||
-- Get target buffer (where code should be injected - the real file)
|
||||
local target_bufnr = vim.fn.bufnr(event.target_path)
|
||||
if target_bufnr == -1 then
|
||||
-- Try to find by filename
|
||||
@@ -220,7 +230,8 @@ function M.create_from_event(event, generated_code, confidence, strategy)
|
||||
return {
|
||||
id = M.generate_id(),
|
||||
event_id = event.id,
|
||||
target_bufnr = target_bufnr,
|
||||
source_bufnr = source_bufnr, -- Where prompt tags are (coder file)
|
||||
target_bufnr = target_bufnr, -- Where code goes (real file)
|
||||
target_path = event.target_path,
|
||||
original_snapshot = snapshot,
|
||||
generated_code = generated_code,
|
||||
@@ -453,39 +464,56 @@ function M.apply(patch)
|
||||
-- Prepare code lines
|
||||
local code_lines = vim.split(patch.generated_code, "\n", { plain = true })
|
||||
|
||||
-- FIRST: Remove the prompt tags from the buffer before applying code
|
||||
-- This prevents the infinite loop where tags stay and get re-detected
|
||||
local tags_removed = remove_prompt_tags(target_bufnr)
|
||||
-- FIRST: Remove the prompt tags from the SOURCE buffer (coder file), not target
|
||||
-- The tags are in the coder file where the user wrote the prompt
|
||||
-- Code goes to target file, tags get removed from source file
|
||||
local source_bufnr = patch.source_bufnr
|
||||
local tags_removed = 0
|
||||
|
||||
pcall(function()
|
||||
if tags_removed > 0 then
|
||||
local logs = require("codetyper.agent.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Removed %d prompt tag(s) from buffer", tags_removed),
|
||||
})
|
||||
end
|
||||
end)
|
||||
if source_bufnr and vim.api.nvim_buf_is_valid(source_bufnr) then
|
||||
tags_removed = remove_prompt_tags(source_bufnr)
|
||||
|
||||
-- Recalculate line count after tag removal
|
||||
local line_count = vim.api.nvim_buf_line_count(target_bufnr)
|
||||
pcall(function()
|
||||
if tags_removed > 0 then
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local source_name = vim.api.nvim_buf_get_name(source_bufnr)
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Removed %d prompt tag(s) from %s",
|
||||
tags_removed,
|
||||
vim.fn.fnamemodify(source_name, ":t")),
|
||||
})
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
-- Apply based on strategy
|
||||
-- Get filetype for smart injection
|
||||
local filetype = vim.fn.fnamemodify(patch.target_path or "", ":e")
|
||||
|
||||
-- Use smart injection module for intelligent import handling
|
||||
local inject = get_inject_module()
|
||||
local inject_result = nil
|
||||
|
||||
-- Apply based on strategy using smart injection
|
||||
local ok, err = pcall(function()
|
||||
-- Prepare injection options
|
||||
local inject_opts = {
|
||||
strategy = patch.injection_strategy,
|
||||
filetype = filetype,
|
||||
sort_imports = true,
|
||||
}
|
||||
|
||||
if patch.injection_strategy == "replace" and patch.injection_range then
|
||||
-- Replace the scope range with the new code
|
||||
-- The injection_range points to the function/method we're completing
|
||||
local start_line = patch.injection_range.start_line
|
||||
local end_line = patch.injection_range.end_line
|
||||
|
||||
-- Adjust for tag removal - find the new range by searching for the scope
|
||||
-- After removing tags, line numbers may have shifted
|
||||
-- Use the scope information to find the correct range
|
||||
if patch.scope and patch.scope.type then
|
||||
-- Try to find the scope using treesitter if available
|
||||
local found_range = nil
|
||||
pcall(function()
|
||||
local ts_utils = require("nvim-treesitter.ts_utils")
|
||||
local parsers = require("nvim-treesitter.parsers")
|
||||
local parser = parsers.get_parser(target_bufnr)
|
||||
if parser then
|
||||
@@ -528,34 +556,38 @@ function M.apply(patch)
|
||||
end
|
||||
|
||||
-- Clamp to valid range
|
||||
local line_count = vim.api.nvim_buf_line_count(target_bufnr)
|
||||
start_line = math.max(1, start_line)
|
||||
end_line = math.min(line_count, end_line)
|
||||
|
||||
-- Replace the range (0-indexed for nvim_buf_set_lines)
|
||||
vim.api.nvim_buf_set_lines(target_bufnr, start_line - 1, end_line, false, code_lines)
|
||||
inject_opts.range = { start_line = start_line, end_line = end_line }
|
||||
elseif patch.injection_strategy == "insert" and patch.injection_range then
|
||||
inject_opts.range = { start_line = patch.injection_range.start_line }
|
||||
end
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
-- Use smart injection - handles imports automatically
|
||||
inject_result = inject.inject(target_bufnr, patch.generated_code, inject_opts)
|
||||
|
||||
-- Log injection details
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
if inject_result.imports_added > 0 then
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Replacing lines %d-%d with %d lines of code", start_line, end_line, #code_lines),
|
||||
message = string.format(
|
||||
"%s %d import(s), injected %d body line(s)",
|
||||
inject_result.imports_merged and "Merged" or "Added",
|
||||
inject_result.imports_added,
|
||||
inject_result.body_lines
|
||||
),
|
||||
})
|
||||
else
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Injected %d line(s) of code", inject_result.body_lines),
|
||||
})
|
||||
end)
|
||||
elseif patch.injection_strategy == "insert" and patch.injection_range then
|
||||
-- Insert at the specified location
|
||||
local insert_line = patch.injection_range.start_line
|
||||
insert_line = math.max(1, math.min(line_count + 1, insert_line))
|
||||
vim.api.nvim_buf_set_lines(target_bufnr, insert_line - 1, insert_line - 1, false, code_lines)
|
||||
else
|
||||
-- Default: append to end
|
||||
-- Check if last line is empty, if not add a blank line for spacing
|
||||
local last_line = vim.api.nvim_buf_get_lines(target_bufnr, line_count - 1, line_count, false)[1] or ""
|
||||
if last_line:match("%S") then
|
||||
-- Last line has content, add blank line for spacing
|
||||
table.insert(code_lines, 1, "")
|
||||
end
|
||||
vim.api.nvim_buf_set_lines(target_bufnr, line_count, line_count, false, code_lines)
|
||||
end
|
||||
end)
|
||||
end)
|
||||
|
||||
if not ok then
|
||||
@@ -577,6 +609,41 @@ function M.apply(patch)
|
||||
})
|
||||
end)
|
||||
|
||||
-- Learn from successful code generation - this builds neural pathways
|
||||
-- The more code is successfully applied, the better the brain becomes
|
||||
pcall(function()
|
||||
local brain = require("codetyper.brain")
|
||||
if brain.is_initialized() then
|
||||
-- Learn the successful pattern
|
||||
local intent_type = patch.intent and patch.intent.type or "unknown"
|
||||
local scope_type = patch.scope and patch.scope.type or "file"
|
||||
local scope_name = patch.scope and patch.scope.name or ""
|
||||
|
||||
-- Create a meaningful summary for this learning
|
||||
local summary = string.format(
|
||||
"Generated %s: %s %s in %s",
|
||||
intent_type,
|
||||
scope_type,
|
||||
scope_name ~= "" and scope_name or "",
|
||||
vim.fn.fnamemodify(patch.target_path or "", ":t")
|
||||
)
|
||||
|
||||
brain.learn({
|
||||
type = "code_completion",
|
||||
file = patch.target_path,
|
||||
timestamp = os.time(),
|
||||
data = {
|
||||
intent = intent_type,
|
||||
code = patch.generated_code:sub(1, 500), -- Store first 500 chars
|
||||
language = vim.fn.fnamemodify(patch.target_path or "", ":e"),
|
||||
function_name = scope_name,
|
||||
prompt = patch.prompt_content,
|
||||
confidence = patch.confidence or 0.5,
|
||||
},
|
||||
})
|
||||
end
|
||||
end)
|
||||
|
||||
return true, nil
|
||||
end
|
||||
|
||||
|
||||
128
lua/codetyper/agent/tools/base.lua
Normal file
128
lua/codetyper/agent/tools/base.lua
Normal file
@@ -0,0 +1,128 @@
|
||||
---@mod codetyper.agent.tools.base Base tool definition
|
||||
---@brief [[
|
||||
--- Base metatable for all LLM tools.
|
||||
--- Tools extend this base to provide structured AI capabilities.
|
||||
---@brief ]]
|
||||
|
||||
---@class CoderToolParam
|
||||
---@field name string Parameter name
|
||||
---@field description string Parameter description
|
||||
---@field type string Parameter type ("string", "number", "boolean", "table")
|
||||
---@field optional? boolean Whether the parameter is optional
|
||||
---@field default? any Default value for optional parameters
|
||||
|
||||
---@class CoderToolReturn
|
||||
---@field name string Return value name
|
||||
---@field description string Return value description
|
||||
---@field type string Return type
|
||||
---@field optional? boolean Whether the return is optional
|
||||
|
||||
---@class CoderToolOpts
|
||||
---@field on_log? fun(message: string) Log callback
|
||||
---@field on_complete? fun(result: any, error: string|nil) Completion callback
|
||||
---@field session_ctx? table Session context
|
||||
---@field streaming? boolean Whether response is still streaming
|
||||
---@field confirm? fun(message: string, callback: fun(ok: boolean)) Confirmation callback
|
||||
|
||||
---@class CoderTool
|
||||
---@field name string Tool identifier
|
||||
---@field description string|fun(): string Tool description
|
||||
---@field params CoderToolParam[] Input parameters
|
||||
---@field returns CoderToolReturn[] Return values
|
||||
---@field requires_confirmation? boolean Whether tool needs user confirmation
|
||||
---@field func fun(input: table, opts: CoderToolOpts): any, string|nil Tool implementation
|
||||
|
||||
local M = {}
|
||||
M.__index = M
|
||||
|
||||
--- Call the tool function
|
||||
---@param opts CoderToolOpts Options for the tool call
|
||||
---@return any result
|
||||
---@return string|nil error
|
||||
function M:__call(opts, on_log, on_complete)
|
||||
return self.func(opts, on_log, on_complete)
|
||||
end
|
||||
|
||||
--- Get the tool description
|
||||
---@return string
|
||||
function M:get_description()
|
||||
if type(self.description) == "function" then
|
||||
return self.description()
|
||||
end
|
||||
return self.description
|
||||
end
|
||||
|
||||
--- Validate input against parameter schema
|
||||
---@param input table Input to validate
|
||||
---@return boolean valid
|
||||
---@return string|nil error
|
||||
function M:validate_input(input)
|
||||
if not self.params then
|
||||
return true
|
||||
end
|
||||
|
||||
for _, param in ipairs(self.params) do
|
||||
local value = input[param.name]
|
||||
|
||||
-- Check required parameters
|
||||
if not param.optional and value == nil then
|
||||
return false, string.format("Missing required parameter: %s", param.name)
|
||||
end
|
||||
|
||||
-- Type checking
|
||||
if value ~= nil then
|
||||
local actual_type = type(value)
|
||||
local expected_type = param.type
|
||||
|
||||
-- Handle special types
|
||||
if expected_type == "integer" and actual_type == "number" then
|
||||
if math.floor(value) ~= value then
|
||||
return false, string.format("Parameter %s must be an integer", param.name)
|
||||
end
|
||||
elseif expected_type ~= actual_type and expected_type ~= "any" then
|
||||
return false, string.format("Parameter %s must be %s, got %s", param.name, expected_type, actual_type)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Generate JSON schema for the tool (for LLM function calling)
|
||||
---@return table schema
|
||||
function M:to_schema()
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(self.params or {}) do
|
||||
local prop = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
|
||||
if param.default ~= nil then
|
||||
prop.default = param.default
|
||||
end
|
||||
|
||||
properties[param.name] = prop
|
||||
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
return {
|
||||
type = "function",
|
||||
function_def = {
|
||||
name = self.name,
|
||||
description = self:get_description(),
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
return M
|
||||
198
lua/codetyper/agent/tools/bash.lua
Normal file
198
lua/codetyper/agent/tools/bash.lua
Normal file
@@ -0,0 +1,198 @@
|
||||
---@mod codetyper.agent.tools.bash Shell command execution tool
|
||||
---@brief [[
|
||||
--- Tool for executing shell commands with safety checks.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "bash"
|
||||
|
||||
M.description = [[Executes a bash command in a shell.
|
||||
|
||||
IMPORTANT RULES:
|
||||
- Do NOT use bash to read files (use 'view' tool instead)
|
||||
- Do NOT use bash to modify files (use 'write' or 'edit' tools instead)
|
||||
- Do NOT use interactive commands (vim, nano, less, etc.)
|
||||
- Commands timeout after 2 minutes by default
|
||||
|
||||
Allowed uses:
|
||||
- Running builds (make, npm run build, cargo build)
|
||||
- Running tests (npm test, pytest, cargo test)
|
||||
- Git operations (git status, git diff, git commit)
|
||||
- Package management (npm install, pip install)
|
||||
- System info commands (ls, pwd, which)]]
|
||||
|
||||
M.params = {
|
||||
{
|
||||
name = "command",
|
||||
description = "The shell command to execute",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "cwd",
|
||||
description = "Working directory for the command (optional)",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
{
|
||||
name = "timeout",
|
||||
description = "Timeout in milliseconds (default: 120000)",
|
||||
type = "integer",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.returns = {
|
||||
{
|
||||
name = "stdout",
|
||||
description = "Command output",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if command failed",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.requires_confirmation = true
|
||||
|
||||
--- Banned commands for safety
|
||||
local BANNED_COMMANDS = {
|
||||
"rm -rf /",
|
||||
"rm -rf /*",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
":(){ :|:& };:",
|
||||
"> /dev/sda",
|
||||
}
|
||||
|
||||
--- Banned patterns
|
||||
local BANNED_PATTERNS = {
|
||||
"curl.*|.*sh",
|
||||
"wget.*|.*sh",
|
||||
"rm%s+%-rf%s+/",
|
||||
}
|
||||
|
||||
--- Check if command is safe
|
||||
---@param command string
|
||||
---@return boolean safe
|
||||
---@return string|nil reason
|
||||
local function is_safe_command(command)
|
||||
-- Check exact matches
|
||||
for _, banned in ipairs(BANNED_COMMANDS) do
|
||||
if command == banned then
|
||||
return false, "Command is banned for safety"
|
||||
end
|
||||
end
|
||||
|
||||
-- Check patterns
|
||||
for _, pattern in ipairs(BANNED_PATTERNS) do
|
||||
if command:match(pattern) then
|
||||
return false, "Command matches banned pattern"
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
---@param input {command: string, cwd?: string, timeout?: integer}
|
||||
---@param opts CoderToolOpts
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.command then
|
||||
return nil, "command is required"
|
||||
end
|
||||
|
||||
-- Safety check
|
||||
local safe, reason = is_safe_command(input.command)
|
||||
if not safe then
|
||||
return nil, reason
|
||||
end
|
||||
|
||||
-- Confirmation required
|
||||
if M.requires_confirmation and opts.confirm then
|
||||
local confirmed = false
|
||||
local confirm_error = nil
|
||||
|
||||
opts.confirm("Execute command: " .. input.command, function(ok)
|
||||
if not ok then
|
||||
confirm_error = "User declined command execution"
|
||||
end
|
||||
confirmed = ok
|
||||
end)
|
||||
|
||||
-- Wait for confirmation (in async context, this would be handled differently)
|
||||
if confirm_error then
|
||||
return nil, confirm_error
|
||||
end
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Executing: " .. input.command)
|
||||
end
|
||||
|
||||
-- Prepare command
|
||||
local cwd = input.cwd or vim.fn.getcwd()
|
||||
local timeout = input.timeout or 120000
|
||||
|
||||
-- Execute command
|
||||
local output = ""
|
||||
local exit_code = 0
|
||||
|
||||
local job_opts = {
|
||||
command = "bash",
|
||||
args = { "-c", input.command },
|
||||
cwd = cwd,
|
||||
on_stdout = function(_, data)
|
||||
if data then
|
||||
output = output .. table.concat(data, "\n")
|
||||
end
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data then
|
||||
output = output .. table.concat(data, "\n")
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
exit_code = code
|
||||
end,
|
||||
}
|
||||
|
||||
-- Run synchronously with timeout
|
||||
local Job = require("plenary.job")
|
||||
local job = Job:new(job_opts)
|
||||
|
||||
job:sync(timeout)
|
||||
exit_code = job.code or 0
|
||||
output = table.concat(job:result() or {}, "\n")
|
||||
|
||||
-- Also get stderr
|
||||
local stderr = table.concat(job:stderr_result() or {}, "\n")
|
||||
if stderr and stderr ~= "" then
|
||||
output = output .. "\n" .. stderr
|
||||
end
|
||||
|
||||
-- Check result
|
||||
if exit_code ~= 0 then
|
||||
local error_msg = string.format("Command failed with exit code %d: %s", exit_code, output)
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, error_msg)
|
||||
end
|
||||
return nil, error_msg
|
||||
end
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(output, nil)
|
||||
end
|
||||
|
||||
return output, nil
|
||||
end
|
||||
|
||||
return M
|
||||
429
lua/codetyper/agent/tools/edit.lua
Normal file
429
lua/codetyper/agent/tools/edit.lua
Normal file
@@ -0,0 +1,429 @@
|
||||
---@mod codetyper.agent.tools.edit File editing tool with fallback matching
|
||||
---@brief [[
|
||||
--- Tool for making targeted edits to files using search/replace.
|
||||
--- Implements multiple fallback strategies for robust matching.
|
||||
--- Inspired by opencode's 9-strategy approach.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "edit"
|
||||
|
||||
M.description = [[Makes a targeted edit to a file by replacing text.
|
||||
|
||||
The old_string should match the content you want to replace. The tool uses multiple
|
||||
matching strategies with fallbacks:
|
||||
1. Exact match
|
||||
2. Whitespace-normalized match
|
||||
3. Indentation-flexible match
|
||||
4. Line-trimmed match
|
||||
5. Fuzzy anchor-based match
|
||||
|
||||
For creating new files, use old_string="" and provide the full content in new_string.
|
||||
For large changes, consider using 'write' tool instead.]]
|
||||
|
||||
M.params = {
|
||||
{
|
||||
name = "path",
|
||||
description = "Path to the file to edit",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "old_string",
|
||||
description = "Text to find and replace (empty string to create new file or append)",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "new_string",
|
||||
description = "Text to replace with",
|
||||
type = "string",
|
||||
},
|
||||
}
|
||||
|
||||
M.returns = {
|
||||
{
|
||||
name = "success",
|
||||
description = "Whether the edit was applied",
|
||||
type = "boolean",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if edit failed",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.requires_confirmation = false
|
||||
|
||||
--- Normalize line endings to LF
|
||||
---@param str string
|
||||
---@return string
|
||||
local function normalize_line_endings(str)
|
||||
return str:gsub("\r\n", "\n"):gsub("\r", "\n")
|
||||
end
|
||||
|
||||
--- Strategy 1: Exact match
|
||||
---@param content string File content
|
||||
---@param old_str string String to find
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
local function exact_match(content, old_str)
|
||||
local pos = content:find(old_str, 1, true)
|
||||
if pos then
|
||||
return pos, pos + #old_str - 1
|
||||
end
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
--- Strategy 2: Whitespace-normalized match
|
||||
--- Collapses all whitespace to single spaces
|
||||
---@param content string
|
||||
---@param old_str string
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
local function whitespace_normalized_match(content, old_str)
|
||||
local function normalize_ws(s)
|
||||
return s:gsub("%s+", " "):gsub("^%s+", ""):gsub("%s+$", "")
|
||||
end
|
||||
|
||||
local norm_old = normalize_ws(old_str)
|
||||
local lines = vim.split(content, "\n")
|
||||
|
||||
-- Try to find matching block
|
||||
for i = 1, #lines do
|
||||
local block = {}
|
||||
local block_start = nil
|
||||
|
||||
for j = i, #lines do
|
||||
table.insert(block, lines[j])
|
||||
local block_text = table.concat(block, "\n")
|
||||
local norm_block = normalize_ws(block_text)
|
||||
|
||||
if norm_block == norm_old then
|
||||
-- Found match
|
||||
local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
|
||||
local start_pos = #before + (i > 1 and 2 or 1)
|
||||
local end_pos = start_pos + #block_text - 1
|
||||
return start_pos, end_pos
|
||||
end
|
||||
|
||||
-- If block is already longer than target, stop
|
||||
if #norm_block > #norm_old then
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
--- Strategy 3: Indentation-flexible match
|
||||
--- Ignores leading whitespace differences
|
||||
---@param content string
|
||||
---@param old_str string
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
local function indentation_flexible_match(content, old_str)
|
||||
local function strip_indent(s)
|
||||
local lines = vim.split(s, "\n")
|
||||
local result = {}
|
||||
for _, line in ipairs(lines) do
|
||||
table.insert(result, line:gsub("^%s+", ""))
|
||||
end
|
||||
return table.concat(result, "\n")
|
||||
end
|
||||
|
||||
local stripped_old = strip_indent(old_str)
|
||||
local lines = vim.split(content, "\n")
|
||||
local old_lines = vim.split(old_str, "\n")
|
||||
local num_old_lines = #old_lines
|
||||
|
||||
for i = 1, #lines - num_old_lines + 1 do
|
||||
local block = vim.list_slice(lines, i, i + num_old_lines - 1)
|
||||
local block_text = table.concat(block, "\n")
|
||||
|
||||
if strip_indent(block_text) == stripped_old then
|
||||
local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
|
||||
local start_pos = #before + (i > 1 and 2 or 1)
|
||||
local end_pos = start_pos + #block_text - 1
|
||||
return start_pos, end_pos
|
||||
end
|
||||
end
|
||||
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
--- Strategy 4: Line-trimmed match
|
||||
--- Trims each line before comparing
|
||||
---@param content string
|
||||
---@param old_str string
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
local function line_trimmed_match(content, old_str)
|
||||
local function trim_lines(s)
|
||||
local lines = vim.split(s, "\n")
|
||||
local result = {}
|
||||
for _, line in ipairs(lines) do
|
||||
table.insert(result, line:match("^%s*(.-)%s*$"))
|
||||
end
|
||||
return table.concat(result, "\n")
|
||||
end
|
||||
|
||||
local trimmed_old = trim_lines(old_str)
|
||||
local lines = vim.split(content, "\n")
|
||||
local old_lines = vim.split(old_str, "\n")
|
||||
local num_old_lines = #old_lines
|
||||
|
||||
for i = 1, #lines - num_old_lines + 1 do
|
||||
local block = vim.list_slice(lines, i, i + num_old_lines - 1)
|
||||
local block_text = table.concat(block, "\n")
|
||||
|
||||
if trim_lines(block_text) == trimmed_old then
|
||||
local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
|
||||
local start_pos = #before + (i > 1 and 2 or 1)
|
||||
local end_pos = start_pos + #block_text - 1
|
||||
return start_pos, end_pos
|
||||
end
|
||||
end
|
||||
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
--- Calculate Levenshtein distance between two strings
|
||||
---@param s1 string
|
||||
---@param s2 string
|
||||
---@return number
|
||||
local function levenshtein(s1, s2)
|
||||
local len1, len2 = #s1, #s2
|
||||
local matrix = {}
|
||||
|
||||
for i = 0, len1 do
|
||||
matrix[i] = { [0] = i }
|
||||
end
|
||||
for j = 0, len2 do
|
||||
matrix[0][j] = j
|
||||
end
|
||||
|
||||
for i = 1, len1 do
|
||||
for j = 1, len2 do
|
||||
local cost = s1:sub(i, i) == s2:sub(j, j) and 0 or 1
|
||||
matrix[i][j] = math.min(
|
||||
matrix[i - 1][j] + 1,
|
||||
matrix[i][j - 1] + 1,
|
||||
matrix[i - 1][j - 1] + cost
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
return matrix[len1][len2]
|
||||
end
|
||||
|
||||
--- Strategy 5: Fuzzy anchor-based match
|
||||
--- Uses first and last lines as anchors, allows fuzzy matching in between
|
||||
---@param content string
|
||||
---@param old_str string
|
||||
---@param threshold? number Similarity threshold (0-1), default 0.8
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
local function fuzzy_anchor_match(content, old_str, threshold)
|
||||
threshold = threshold or 0.8
|
||||
|
||||
local old_lines = vim.split(old_str, "\n")
|
||||
if #old_lines < 2 then
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
local first_line = old_lines[1]:match("^%s*(.-)%s*$")
|
||||
local last_line = old_lines[#old_lines]:match("^%s*(.-)%s*$")
|
||||
local content_lines = vim.split(content, "\n")
|
||||
|
||||
-- Find potential start positions
|
||||
local candidates = {}
|
||||
for i, line in ipairs(content_lines) do
|
||||
local trimmed = line:match("^%s*(.-)%s*$")
|
||||
if trimmed == first_line or (
|
||||
#first_line > 0 and
|
||||
1 - (levenshtein(trimmed, first_line) / math.max(#trimmed, #first_line)) >= threshold
|
||||
) then
|
||||
table.insert(candidates, i)
|
||||
end
|
||||
end
|
||||
|
||||
-- For each candidate, look for matching end
|
||||
for _, start_idx in ipairs(candidates) do
|
||||
local expected_end = start_idx + #old_lines - 1
|
||||
if expected_end <= #content_lines then
|
||||
local end_line = content_lines[expected_end]:match("^%s*(.-)%s*$")
|
||||
if end_line == last_line or (
|
||||
#last_line > 0 and
|
||||
1 - (levenshtein(end_line, last_line) / math.max(#end_line, #last_line)) >= threshold
|
||||
) then
|
||||
-- Calculate positions
|
||||
local before = table.concat(vim.list_slice(content_lines, 1, start_idx - 1), "\n")
|
||||
local block = table.concat(vim.list_slice(content_lines, start_idx, expected_end), "\n")
|
||||
local start_pos = #before + (start_idx > 1 and 2 or 1)
|
||||
local end_pos = start_pos + #block - 1
|
||||
return start_pos, end_pos
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
--- Try all matching strategies in order
|
||||
---@param content string File content
|
||||
---@param old_str string String to find
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
---@return string strategy_used
|
||||
local function find_match(content, old_str)
|
||||
-- Strategy 1: Exact match
|
||||
local start_pos, end_pos = exact_match(content, old_str)
|
||||
if start_pos then
|
||||
return start_pos, end_pos, "exact"
|
||||
end
|
||||
|
||||
-- Strategy 2: Whitespace-normalized
|
||||
start_pos, end_pos = whitespace_normalized_match(content, old_str)
|
||||
if start_pos then
|
||||
return start_pos, end_pos, "whitespace_normalized"
|
||||
end
|
||||
|
||||
-- Strategy 3: Indentation-flexible
|
||||
start_pos, end_pos = indentation_flexible_match(content, old_str)
|
||||
if start_pos then
|
||||
return start_pos, end_pos, "indentation_flexible"
|
||||
end
|
||||
|
||||
-- Strategy 4: Line-trimmed
|
||||
start_pos, end_pos = line_trimmed_match(content, old_str)
|
||||
if start_pos then
|
||||
return start_pos, end_pos, "line_trimmed"
|
||||
end
|
||||
|
||||
-- Strategy 5: Fuzzy anchor
|
||||
start_pos, end_pos = fuzzy_anchor_match(content, old_str)
|
||||
if start_pos then
|
||||
return start_pos, end_pos, "fuzzy_anchor"
|
||||
end
|
||||
|
||||
return nil, nil, "none"
|
||||
end
|
||||
|
||||
---@param input {path: string, old_string: string, new_string: string}
|
||||
---@param opts CoderToolOpts
|
||||
---@return boolean|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.path then
|
||||
return nil, "path is required"
|
||||
end
|
||||
if input.old_string == nil then
|
||||
return nil, "old_string is required"
|
||||
end
|
||||
if input.new_string == nil then
|
||||
return nil, "new_string is required"
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Editing file: " .. input.path)
|
||||
end
|
||||
|
||||
-- Resolve path
|
||||
local path = input.path
|
||||
if not vim.startswith(path, "/") then
|
||||
path = vim.fn.getcwd() .. "/" .. path
|
||||
end
|
||||
|
||||
-- Normalize inputs
|
||||
local old_str = normalize_line_endings(input.old_string)
|
||||
local new_str = normalize_line_endings(input.new_string)
|
||||
|
||||
-- Handle new file creation (empty old_string)
|
||||
if old_str == "" then
|
||||
-- Create parent directories
|
||||
local dir = vim.fn.fnamemodify(path, ":h")
|
||||
if vim.fn.isdirectory(dir) == 0 then
|
||||
vim.fn.mkdir(dir, "p")
|
||||
end
|
||||
|
||||
-- Write new file
|
||||
local lines = vim.split(new_str, "\n", { plain = true })
|
||||
local ok = pcall(vim.fn.writefile, lines, path)
|
||||
|
||||
if not ok then
|
||||
return nil, "Failed to create file: " .. input.path
|
||||
end
|
||||
|
||||
-- Reload buffer if open
|
||||
local bufnr = vim.fn.bufnr(path)
|
||||
if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
|
||||
vim.api.nvim_buf_call(bufnr, function()
|
||||
vim.cmd("edit!")
|
||||
end)
|
||||
end
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(true, nil)
|
||||
end
|
||||
|
||||
return true, nil
|
||||
end
|
||||
|
||||
-- Check if file exists
|
||||
if vim.fn.filereadable(path) ~= 1 then
|
||||
return nil, "File not found: " .. input.path
|
||||
end
|
||||
|
||||
-- Read current content
|
||||
local lines = vim.fn.readfile(path)
|
||||
if not lines then
|
||||
return nil, "Failed to read file: " .. input.path
|
||||
end
|
||||
|
||||
local content = normalize_line_endings(table.concat(lines, "\n"))
|
||||
|
||||
-- Find match using fallback strategies
|
||||
local start_pos, end_pos, strategy = find_match(content, old_str)
|
||||
|
||||
if not start_pos then
|
||||
return nil, "old_string not found in file (tried 5 matching strategies)"
|
||||
end
|
||||
|
||||
if opts.on_log then
|
||||
opts.on_log("Match found using strategy: " .. strategy)
|
||||
end
|
||||
|
||||
-- Perform replacement
|
||||
local new_content = content:sub(1, start_pos - 1) .. new_str .. content:sub(end_pos + 1)
|
||||
|
||||
-- Write back
|
||||
local new_lines = vim.split(new_content, "\n", { plain = true })
|
||||
local ok = pcall(vim.fn.writefile, new_lines, path)
|
||||
|
||||
if not ok then
|
||||
return nil, "Failed to write file: " .. input.path
|
||||
end
|
||||
|
||||
-- Reload buffer if open
|
||||
local bufnr = vim.fn.bufnr(path)
|
||||
if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
|
||||
vim.api.nvim_buf_call(bufnr, function()
|
||||
vim.cmd("edit!")
|
||||
end)
|
||||
end
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(true, nil)
|
||||
end
|
||||
|
||||
return true, nil
|
||||
end
|
||||
|
||||
return M
|
||||
146
lua/codetyper/agent/tools/glob.lua
Normal file
146
lua/codetyper/agent/tools/glob.lua
Normal file
@@ -0,0 +1,146 @@
|
||||
---@mod codetyper.agent.tools.glob File pattern matching tool
|
||||
---@brief [[
|
||||
--- Tool for finding files by glob pattern.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "glob"
|
||||
|
||||
M.description = [[Finds files matching a glob pattern.
|
||||
|
||||
Example patterns:
|
||||
- "**/*.lua" - All Lua files
|
||||
- "src/**/*.ts" - TypeScript files in src
|
||||
- "**/test_*.py" - Test files in Python]]
|
||||
|
||||
M.params = {
|
||||
{
|
||||
name = "pattern",
|
||||
description = "Glob pattern to match files",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "path",
|
||||
description = "Base directory to search in (default: project root)",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
{
|
||||
name = "max_results",
|
||||
description = "Maximum number of results (default: 100)",
|
||||
type = "integer",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.returns = {
|
||||
{
|
||||
name = "matches",
|
||||
description = "JSON array of matching file paths",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if glob failed",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.requires_confirmation = false
|
||||
|
||||
---@param input {pattern: string, path?: string, max_results?: integer}
|
||||
---@param opts CoderToolOpts
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.pattern then
|
||||
return nil, "pattern is required"
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Finding files: " .. input.pattern)
|
||||
end
|
||||
|
||||
-- Resolve base path
|
||||
local base_path = input.path or vim.fn.getcwd()
|
||||
if not vim.startswith(base_path, "/") then
|
||||
base_path = vim.fn.getcwd() .. "/" .. base_path
|
||||
end
|
||||
|
||||
local max_results = input.max_results or 100
|
||||
|
||||
-- Use vim.fn.glob or fd if available
|
||||
local matches = {}
|
||||
|
||||
if vim.fn.executable("fd") == 1 then
|
||||
-- Use fd for better performance
|
||||
local Job = require("plenary.job")
|
||||
|
||||
-- Convert glob to fd pattern
|
||||
local fd_pattern = input.pattern:gsub("%*%*/", ""):gsub("%*", ".*")
|
||||
|
||||
local job = Job:new({
|
||||
command = "fd",
|
||||
args = {
|
||||
"--type",
|
||||
"f",
|
||||
"--max-results",
|
||||
tostring(max_results),
|
||||
"--glob",
|
||||
input.pattern,
|
||||
base_path,
|
||||
},
|
||||
cwd = base_path,
|
||||
})
|
||||
|
||||
job:sync(30000)
|
||||
matches = job:result() or {}
|
||||
else
|
||||
-- Fallback to vim.fn.globpath
|
||||
local pattern = base_path .. "/" .. input.pattern
|
||||
local files = vim.fn.glob(pattern, false, true)
|
||||
|
||||
for i, file in ipairs(files) do
|
||||
if i > max_results then
|
||||
break
|
||||
end
|
||||
-- Make paths relative to base_path
|
||||
local relative = file:gsub("^" .. vim.pesc(base_path) .. "/", "")
|
||||
table.insert(matches, relative)
|
||||
end
|
||||
end
|
||||
|
||||
-- Clean up matches
|
||||
local cleaned = {}
|
||||
for _, match in ipairs(matches) do
|
||||
if match and match ~= "" then
|
||||
-- Make relative if absolute
|
||||
local relative = match
|
||||
if vim.startswith(match, base_path) then
|
||||
relative = match:sub(#base_path + 2)
|
||||
end
|
||||
table.insert(cleaned, relative)
|
||||
end
|
||||
end
|
||||
|
||||
-- Return as JSON
|
||||
local result = vim.json.encode({
|
||||
matches = cleaned,
|
||||
total = #cleaned,
|
||||
truncated = #cleaned >= max_results,
|
||||
})
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(result, nil)
|
||||
end
|
||||
|
||||
return result, nil
|
||||
end
|
||||
|
||||
return M
|
||||
150
lua/codetyper/agent/tools/grep.lua
Normal file
150
lua/codetyper/agent/tools/grep.lua
Normal file
@@ -0,0 +1,150 @@
|
||||
---@mod codetyper.agent.tools.grep Search tool
|
||||
---@brief [[
|
||||
--- Tool for searching file contents using ripgrep.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "grep"
|
||||
|
||||
M.description = [[Searches for a pattern in files using ripgrep.
|
||||
|
||||
Returns file paths and matching lines. Use this to find code by content.
|
||||
|
||||
Example patterns:
|
||||
- "function foo" - Find function definitions
|
||||
- "import.*react" - Find React imports
|
||||
- "TODO|FIXME" - Find todo comments]]
|
||||
|
||||
M.params = {
|
||||
{
|
||||
name = "pattern",
|
||||
description = "Regular expression pattern to search for",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "path",
|
||||
description = "Directory or file to search in (default: project root)",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
{
|
||||
name = "include",
|
||||
description = "File glob pattern to include (e.g., '*.lua')",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
{
|
||||
name = "max_results",
|
||||
description = "Maximum number of results (default: 50)",
|
||||
type = "integer",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.returns = {
|
||||
{
|
||||
name = "matches",
|
||||
description = "JSON array of matches with file, line_number, and content",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if search failed",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.requires_confirmation = false
|
||||
|
||||
---@param input {pattern: string, path?: string, include?: string, max_results?: integer}
|
||||
---@param opts CoderToolOpts
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.pattern then
|
||||
return nil, "pattern is required"
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Searching for: " .. input.pattern)
|
||||
end
|
||||
|
||||
-- Build ripgrep command
|
||||
local path = input.path or vim.fn.getcwd()
|
||||
local max_results = input.max_results or 50
|
||||
|
||||
-- Resolve path
|
||||
if not vim.startswith(path, "/") then
|
||||
path = vim.fn.getcwd() .. "/" .. path
|
||||
end
|
||||
|
||||
-- Check if ripgrep is available
|
||||
if vim.fn.executable("rg") ~= 1 then
|
||||
return nil, "ripgrep (rg) is not installed"
|
||||
end
|
||||
|
||||
-- Build command args
|
||||
local args = {
|
||||
"--json",
|
||||
"--max-count",
|
||||
tostring(max_results),
|
||||
"--no-heading",
|
||||
}
|
||||
|
||||
if input.include then
|
||||
table.insert(args, "--glob")
|
||||
table.insert(args, input.include)
|
||||
end
|
||||
|
||||
table.insert(args, input.pattern)
|
||||
table.insert(args, path)
|
||||
|
||||
-- Execute ripgrep
|
||||
local Job = require("plenary.job")
|
||||
local job = Job:new({
|
||||
command = "rg",
|
||||
args = args,
|
||||
cwd = vim.fn.getcwd(),
|
||||
})
|
||||
|
||||
job:sync(30000) -- 30 second timeout
|
||||
|
||||
local results = job:result() or {}
|
||||
local matches = {}
|
||||
|
||||
-- Parse JSON output
|
||||
for _, line in ipairs(results) do
|
||||
if line and line ~= "" then
|
||||
local ok, parsed = pcall(vim.json.decode, line)
|
||||
if ok and parsed.type == "match" then
|
||||
local data = parsed.data
|
||||
table.insert(matches, {
|
||||
file = data.path.text,
|
||||
line_number = data.line_number,
|
||||
content = data.lines.text:gsub("\n$", ""),
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Return as JSON
|
||||
local result = vim.json.encode({
|
||||
matches = matches,
|
||||
total = #matches,
|
||||
truncated = #matches >= max_results,
|
||||
})
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(result, nil)
|
||||
end
|
||||
|
||||
return result, nil
|
||||
end
|
||||
|
||||
return M
|
||||
308
lua/codetyper/agent/tools/init.lua
Normal file
308
lua/codetyper/agent/tools/init.lua
Normal file
@@ -0,0 +1,308 @@
|
||||
---@mod codetyper.agent.tools Tool registry and orchestration
|
||||
---@brief [[
|
||||
--- Registry for LLM tools with execution and schema generation.
|
||||
--- Inspired by avante.nvim's tool system.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Registered tools
|
||||
---@type table<string, CoderTool>
|
||||
local tools = {}
|
||||
|
||||
--- Tool execution history for current session
|
||||
---@type table[]
|
||||
local execution_history = {}
|
||||
|
||||
--- Register a tool
|
||||
---@param tool CoderTool Tool to register
|
||||
function M.register(tool)
|
||||
if not tool.name then
|
||||
error("Tool must have a name")
|
||||
end
|
||||
tools[tool.name] = tool
|
||||
end
|
||||
|
||||
--- Unregister a tool
|
||||
---@param name string Tool name
|
||||
function M.unregister(name)
|
||||
tools[name] = nil
|
||||
end
|
||||
|
||||
--- Get a tool by name
|
||||
---@param name string Tool name
|
||||
---@return CoderTool|nil
|
||||
function M.get(name)
|
||||
return tools[name]
|
||||
end
|
||||
|
||||
--- Get all registered tools
|
||||
---@return table<string, CoderTool>
|
||||
function M.get_all()
|
||||
return tools
|
||||
end
|
||||
|
||||
--- Get tools as a list
|
||||
---@param filter? fun(tool: CoderTool): boolean Optional filter function
|
||||
---@return CoderTool[]
|
||||
function M.list(filter)
|
||||
local result = {}
|
||||
for _, tool in pairs(tools) do
|
||||
if not filter or filter(tool) then
|
||||
table.insert(result, tool)
|
||||
end
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
--- Generate schemas for all tools (for LLM function calling)
|
||||
---@param filter? fun(tool: CoderTool): boolean Optional filter function
|
||||
---@return table[] schemas
|
||||
function M.get_schemas(filter)
|
||||
local schemas = {}
|
||||
for _, tool in pairs(tools) do
|
||||
if not filter or filter(tool) then
|
||||
if tool.to_schema then
|
||||
table.insert(schemas, tool:to_schema())
|
||||
end
|
||||
end
|
||||
end
|
||||
return schemas
|
||||
end
|
||||
|
||||
--- Execute a tool by name
|
||||
---@param name string Tool name
|
||||
---@param input table Input parameters
|
||||
---@param opts CoderToolOpts Execution options
|
||||
---@return any result
|
||||
---@return string|nil error
|
||||
function M.execute(name, input, opts)
|
||||
local tool = tools[name]
|
||||
if not tool then
|
||||
return nil, "Unknown tool: " .. name
|
||||
end
|
||||
|
||||
-- Validate input
|
||||
if tool.validate_input then
|
||||
local valid, err = tool:validate_input(input)
|
||||
if not valid then
|
||||
return nil, err
|
||||
end
|
||||
end
|
||||
|
||||
-- Log execution
|
||||
if opts.on_log then
|
||||
opts.on_log(string.format("Executing tool: %s", name))
|
||||
end
|
||||
|
||||
-- Track execution
|
||||
local execution = {
|
||||
tool = name,
|
||||
input = input,
|
||||
start_time = os.time(),
|
||||
status = "running",
|
||||
}
|
||||
table.insert(execution_history, execution)
|
||||
|
||||
-- Execute the tool
|
||||
local result, err = tool.func(input, opts)
|
||||
|
||||
-- Update execution record
|
||||
execution.end_time = os.time()
|
||||
execution.status = err and "error" or "completed"
|
||||
execution.result = result
|
||||
execution.error = err
|
||||
|
||||
return result, err
|
||||
end
|
||||
|
||||
--- Process a tool call from LLM response
|
||||
---@param tool_call table Tool call from LLM (name + input)
|
||||
---@param opts CoderToolOpts Execution options
|
||||
---@return any result
|
||||
---@return string|nil error
|
||||
function M.process_tool_call(tool_call, opts)
|
||||
local name = tool_call.name or tool_call.function_name
|
||||
local input = tool_call.input or tool_call.arguments or {}
|
||||
|
||||
-- Parse JSON arguments if string
|
||||
if type(input) == "string" then
|
||||
local ok, parsed = pcall(vim.json.decode, input)
|
||||
if ok then
|
||||
input = parsed
|
||||
else
|
||||
return nil, "Failed to parse tool arguments: " .. input
|
||||
end
|
||||
end
|
||||
|
||||
return M.execute(name, input, opts)
|
||||
end
|
||||
|
||||
--- Get execution history
|
||||
---@param limit? number Max entries to return
|
||||
---@return table[]
|
||||
function M.get_history(limit)
|
||||
if not limit then
|
||||
return execution_history
|
||||
end
|
||||
|
||||
local result = {}
|
||||
local start = math.max(1, #execution_history - limit + 1)
|
||||
for i = start, #execution_history do
|
||||
table.insert(result, execution_history[i])
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
--- Clear execution history
|
||||
function M.clear_history()
|
||||
execution_history = {}
|
||||
end
|
||||
|
||||
--- Load built-in tools
|
||||
function M.load_builtins()
|
||||
-- View file tool
|
||||
local view = require("codetyper.agent.tools.view")
|
||||
M.register(view)
|
||||
|
||||
-- Bash tool
|
||||
local bash = require("codetyper.agent.tools.bash")
|
||||
M.register(bash)
|
||||
|
||||
-- Grep tool
|
||||
local grep = require("codetyper.agent.tools.grep")
|
||||
M.register(grep)
|
||||
|
||||
-- Glob tool
|
||||
local glob = require("codetyper.agent.tools.glob")
|
||||
M.register(glob)
|
||||
|
||||
-- Write file tool
|
||||
local write = require("codetyper.agent.tools.write")
|
||||
M.register(write)
|
||||
|
||||
-- Edit tool
|
||||
local edit = require("codetyper.agent.tools.edit")
|
||||
M.register(edit)
|
||||
end
|
||||
|
||||
--- Initialize tools system
|
||||
function M.setup()
|
||||
M.load_builtins()
|
||||
end
|
||||
|
||||
--- Get tool definitions for LLM (lazy-loaded, OpenAI format)
|
||||
--- This is accessed as M.definitions property
|
||||
M.definitions = setmetatable({}, {
|
||||
__call = function()
|
||||
-- Ensure tools are loaded
|
||||
if vim.tbl_count(tools) == 0 then
|
||||
M.load_builtins()
|
||||
end
|
||||
return M.to_openai_format()
|
||||
end,
|
||||
__index = function(_, key)
|
||||
-- Make it work as both function and table
|
||||
if key == "get" then
|
||||
return function()
|
||||
if vim.tbl_count(tools) == 0 then
|
||||
M.load_builtins()
|
||||
end
|
||||
return M.to_openai_format()
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end,
|
||||
})
|
||||
|
||||
--- Get definitions as a function (for backwards compatibility)
|
||||
function M.get_definitions()
|
||||
if vim.tbl_count(tools) == 0 then
|
||||
M.load_builtins()
|
||||
end
|
||||
return M.to_openai_format()
|
||||
end
|
||||
|
||||
--- Convert all tools to OpenAI function calling format
|
||||
---@param filter? fun(tool: CoderTool): boolean Optional filter function
|
||||
---@return table[] OpenAI-compatible tool definitions
|
||||
function M.to_openai_format(filter)
|
||||
local openai_tools = {}
|
||||
|
||||
for _, tool in pairs(tools) do
|
||||
if not filter or filter(tool) then
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(tool.params or {}) do
|
||||
properties[param.name] = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
if param.default ~= nil then
|
||||
properties[param.name].default = param.default
|
||||
end
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
local description = type(tool.description) == "function" and tool.description() or tool.description
|
||||
|
||||
table.insert(openai_tools, {
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return openai_tools
|
||||
end
|
||||
|
||||
--- Convert all tools to Claude tool use format
|
||||
---@param filter? fun(tool: CoderTool): boolean Optional filter function
|
||||
---@return table[] Claude-compatible tool definitions
|
||||
function M.to_claude_format(filter)
|
||||
local claude_tools = {}
|
||||
|
||||
for _, tool in pairs(tools) do
|
||||
if not filter or filter(tool) then
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(tool.params or {}) do
|
||||
properties[param.name] = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
local description = type(tool.description) == "function" and tool.description() or tool.description
|
||||
|
||||
table.insert(claude_tools, {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
input_schema = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return claude_tools
|
||||
end
|
||||
|
||||
return M
|
||||
149
lua/codetyper/agent/tools/view.lua
Normal file
149
lua/codetyper/agent/tools/view.lua
Normal file
@@ -0,0 +1,149 @@
|
||||
---@mod codetyper.agent.tools.view File viewing tool
|
||||
---@brief [[
|
||||
--- Tool for reading file contents with line range support.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "view"
|
||||
|
||||
M.description = [[Reads the content of a file.
|
||||
|
||||
Usage notes:
|
||||
- Provide the file path relative to the project root
|
||||
- Use start_line and end_line to read specific sections
|
||||
- If content is truncated, use line ranges to read in chunks
|
||||
- Returns JSON with content, total_line_count, and is_truncated]]
|
||||
|
||||
M.params = {
|
||||
{
|
||||
name = "path",
|
||||
description = "Path to the file (relative to project root or absolute)",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "start_line",
|
||||
description = "Line number to start reading (1-indexed)",
|
||||
type = "integer",
|
||||
optional = true,
|
||||
},
|
||||
{
|
||||
name = "end_line",
|
||||
description = "Line number to end reading (1-indexed, inclusive)",
|
||||
type = "integer",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.returns = {
|
||||
{
|
||||
name = "content",
|
||||
description = "File contents as JSON with content, total_line_count, is_truncated",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if file could not be read",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.requires_confirmation = false
|
||||
|
||||
--- Maximum content size before truncation
|
||||
local MAX_CONTENT_SIZE = 200 * 1024 -- 200KB
|
||||
|
||||
---@param input {path: string, start_line?: integer, end_line?: integer}
|
||||
---@param opts CoderToolOpts
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.path then
|
||||
return nil, "path is required"
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Reading file: " .. input.path)
|
||||
end
|
||||
|
||||
-- Resolve path
|
||||
local path = input.path
|
||||
if not vim.startswith(path, "/") then
|
||||
-- Relative path - resolve from project root
|
||||
local root = vim.fn.getcwd()
|
||||
path = root .. "/" .. path
|
||||
end
|
||||
|
||||
-- Check if file exists
|
||||
local stat = vim.uv.fs_stat(path)
|
||||
if not stat then
|
||||
return nil, "File not found: " .. input.path
|
||||
end
|
||||
|
||||
if stat.type == "directory" then
|
||||
return nil, "Path is a directory: " .. input.path
|
||||
end
|
||||
|
||||
-- Read file
|
||||
local lines = vim.fn.readfile(path)
|
||||
if not lines then
|
||||
return nil, "Failed to read file: " .. input.path
|
||||
end
|
||||
|
||||
-- Apply line range
|
||||
local start_line = input.start_line or 1
|
||||
local end_line = input.end_line or #lines
|
||||
|
||||
start_line = math.max(1, start_line)
|
||||
end_line = math.min(#lines, end_line)
|
||||
|
||||
local total_lines = #lines
|
||||
local selected_lines = {}
|
||||
|
||||
for i = start_line, end_line do
|
||||
table.insert(selected_lines, lines[i])
|
||||
end
|
||||
|
||||
-- Check for truncation
|
||||
local content = table.concat(selected_lines, "\n")
|
||||
local is_truncated = false
|
||||
|
||||
if #content > MAX_CONTENT_SIZE then
|
||||
-- Truncate content
|
||||
local truncated_lines = {}
|
||||
local size = 0
|
||||
|
||||
for _, line in ipairs(selected_lines) do
|
||||
size = size + #line + 1
|
||||
if size > MAX_CONTENT_SIZE then
|
||||
is_truncated = true
|
||||
break
|
||||
end
|
||||
table.insert(truncated_lines, line)
|
||||
end
|
||||
|
||||
content = table.concat(truncated_lines, "\n")
|
||||
end
|
||||
|
||||
-- Return as JSON
|
||||
local result = vim.json.encode({
|
||||
content = content,
|
||||
total_line_count = total_lines,
|
||||
is_truncated = is_truncated,
|
||||
start_line = start_line,
|
||||
end_line = end_line,
|
||||
})
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(result, nil)
|
||||
end
|
||||
|
||||
return result, nil
|
||||
end
|
||||
|
||||
return M
|
||||
101
lua/codetyper/agent/tools/write.lua
Normal file
101
lua/codetyper/agent/tools/write.lua
Normal file
@@ -0,0 +1,101 @@
|
||||
---@mod codetyper.agent.tools.write File writing tool
|
||||
---@brief [[
|
||||
--- Tool for creating or overwriting files.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "write"
|
||||
|
||||
M.description = [[Creates or overwrites a file with new content.
|
||||
|
||||
IMPORTANT:
|
||||
- This will completely replace the file contents
|
||||
- Use 'edit' tool for partial modifications
|
||||
- Parent directories will be created if needed]]
|
||||
|
||||
M.params = {
|
||||
{
|
||||
name = "path",
|
||||
description = "Path to the file to write",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "content",
|
||||
description = "Content to write to the file",
|
||||
type = "string",
|
||||
},
|
||||
}
|
||||
|
||||
M.returns = {
|
||||
{
|
||||
name = "success",
|
||||
description = "Whether the file was written successfully",
|
||||
type = "boolean",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if write failed",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.requires_confirmation = true
|
||||
|
||||
---@param input {path: string, content: string}
|
||||
---@param opts CoderToolOpts
|
||||
---@return boolean|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.path then
|
||||
return nil, "path is required"
|
||||
end
|
||||
if not input.content then
|
||||
return nil, "content is required"
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Writing file: " .. input.path)
|
||||
end
|
||||
|
||||
-- Resolve path
|
||||
local path = input.path
|
||||
if not vim.startswith(path, "/") then
|
||||
path = vim.fn.getcwd() .. "/" .. path
|
||||
end
|
||||
|
||||
-- Create parent directories
|
||||
local dir = vim.fn.fnamemodify(path, ":h")
|
||||
if vim.fn.isdirectory(dir) == 0 then
|
||||
vim.fn.mkdir(dir, "p")
|
||||
end
|
||||
|
||||
-- Write the file
|
||||
local lines = vim.split(input.content, "\n", { plain = true })
|
||||
local ok = pcall(vim.fn.writefile, lines, path)
|
||||
|
||||
if not ok then
|
||||
return nil, "Failed to write file: " .. path
|
||||
end
|
||||
|
||||
-- Reload buffer if open
|
||||
local bufnr = vim.fn.bufnr(path)
|
||||
if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
|
||||
vim.api.nvim_buf_call(bufnr, function()
|
||||
vim.cmd("edit!")
|
||||
end)
|
||||
end
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(true, nil)
|
||||
end
|
||||
|
||||
return true, nil
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -224,6 +224,86 @@ local function format_attached_files(attached_files)
|
||||
return table.concat(parts, "")
|
||||
end
|
||||
|
||||
--- Get coder companion file path for a target file
|
||||
---@param target_path string Target file path
|
||||
---@return string|nil Coder file path if exists
|
||||
local function get_coder_companion_path(target_path)
|
||||
if not target_path or target_path == "" then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Skip if target is already a coder file
|
||||
if target_path:match("%.coder%.") then
|
||||
return nil
|
||||
end
|
||||
|
||||
local dir = vim.fn.fnamemodify(target_path, ":h")
|
||||
local name = vim.fn.fnamemodify(target_path, ":t:r") -- filename without extension
|
||||
local ext = vim.fn.fnamemodify(target_path, ":e")
|
||||
|
||||
local coder_path = dir .. "/" .. name .. ".coder." .. ext
|
||||
if vim.fn.filereadable(coder_path) == 1 then
|
||||
return coder_path
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Read and format coder companion context (business logic, pseudo-code)
|
||||
---@param target_path string Target file path
|
||||
---@return string Formatted coder context
|
||||
local function get_coder_context(target_path)
|
||||
local coder_path = get_coder_companion_path(target_path)
|
||||
if not coder_path then
|
||||
return ""
|
||||
end
|
||||
|
||||
local ok, lines = pcall(function()
|
||||
return vim.fn.readfile(coder_path)
|
||||
end)
|
||||
|
||||
if not ok or not lines or #lines == 0 then
|
||||
return ""
|
||||
end
|
||||
|
||||
local content = table.concat(lines, "\n")
|
||||
|
||||
-- Skip if only template comments (no actual content)
|
||||
local stripped = content:gsub("^%s*", ""):gsub("%s*$", "")
|
||||
if stripped == "" then
|
||||
return ""
|
||||
end
|
||||
|
||||
-- Check if there's meaningful content (not just template)
|
||||
local has_content = false
|
||||
for _, line in ipairs(lines) do
|
||||
-- Skip comment lines that are part of the template
|
||||
local trimmed = line:gsub("^%s*", "")
|
||||
if not trimmed:match("^[%-#/]+%s*Coder companion")
|
||||
and not trimmed:match("^[%-#/]+%s*Use /@ @/")
|
||||
and not trimmed:match("^[%-#/]+%s*Example:")
|
||||
and not trimmed:match("^<!%-%-")
|
||||
and trimmed ~= ""
|
||||
and not trimmed:match("^[%-#/]+%s*$") then
|
||||
has_content = true
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if not has_content then
|
||||
return ""
|
||||
end
|
||||
|
||||
local ext = vim.fn.fnamemodify(coder_path, ":e")
|
||||
return string.format(
|
||||
"\n\n--- Business Context / Pseudo-code ---\n" ..
|
||||
"The following describes the intended behavior and design for this file:\n" ..
|
||||
"```%s\n%s\n```",
|
||||
ext,
|
||||
content:sub(1, 4000) -- Limit to 4000 chars
|
||||
)
|
||||
end
|
||||
|
||||
--- Format indexed project context for inclusion in prompt
|
||||
---@param indexed_context table|nil
|
||||
---@return string
|
||||
@@ -309,8 +389,53 @@ local function build_prompt(event)
|
||||
-- Format attached files
|
||||
local attached_content = format_attached_files(event.attached_files)
|
||||
|
||||
-- Combine attached files and indexed context
|
||||
local extra_context = attached_content .. indexed_content
|
||||
-- Get coder companion context (business logic, pseudo-code)
|
||||
local coder_context = get_coder_context(event.target_path)
|
||||
|
||||
-- Get brain memories - contextual recall based on current task
|
||||
local brain_context = ""
|
||||
pcall(function()
|
||||
local brain = require("codetyper.brain")
|
||||
if brain.is_initialized() then
|
||||
-- Query brain for relevant memories based on:
|
||||
-- 1. Current file (file-specific patterns)
|
||||
-- 2. Prompt content (semantic similarity)
|
||||
-- 3. Intent type (relevant past generations)
|
||||
local query_text = event.prompt_content or ""
|
||||
if event.scope and event.scope.name then
|
||||
query_text = event.scope.name .. " " .. query_text
|
||||
end
|
||||
|
||||
local result = brain.query({
|
||||
query = query_text,
|
||||
file = event.target_path,
|
||||
max_results = 5,
|
||||
types = { "pattern", "correction", "convention" },
|
||||
})
|
||||
|
||||
if result and result.nodes and #result.nodes > 0 then
|
||||
local memories = { "\n\n--- Learned Patterns & Conventions ---" }
|
||||
for _, node in ipairs(result.nodes) do
|
||||
if node.c then
|
||||
local summary = node.c.s or ""
|
||||
local detail = node.c.d or ""
|
||||
if summary ~= "" then
|
||||
table.insert(memories, "• " .. summary)
|
||||
if detail ~= "" and #detail < 200 then
|
||||
table.insert(memories, " " .. detail)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
if #memories > 1 then
|
||||
brain_context = table.concat(memories, "\n")
|
||||
end
|
||||
end
|
||||
end
|
||||
end)
|
||||
|
||||
-- Combine all context sources: brain memories first, then coder context, attached files, indexed
|
||||
local extra_context = brain_context .. coder_context .. attached_content .. indexed_content
|
||||
|
||||
-- Build context with scope information
|
||||
local context = {
|
||||
@@ -502,21 +627,21 @@ function M.start(worker)
|
||||
end
|
||||
end, worker.timeout_ms)
|
||||
|
||||
-- Get client and execute
|
||||
local client, client_err = get_client(worker.worker_type)
|
||||
if not client then
|
||||
M.complete(worker, nil, client_err)
|
||||
return
|
||||
end
|
||||
|
||||
local prompt, context = build_prompt(worker.event)
|
||||
|
||||
-- Call the LLM
|
||||
client.generate(prompt, context, function(response, err, usage)
|
||||
-- Check if smart selection is enabled (memory-based provider selection)
|
||||
local use_smart_selection = false
|
||||
pcall(function()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
use_smart_selection = config.llm.smart_selection ~= false -- Default to true
|
||||
end)
|
||||
|
||||
-- Define the response handler
|
||||
local function handle_response(response, err, usage_or_metadata)
|
||||
-- Cancel timeout timer
|
||||
if worker.timer then
|
||||
pcall(function()
|
||||
-- Timer might have already fired
|
||||
if type(worker.timer) == "userdata" and worker.timer.stop then
|
||||
worker.timer:stop()
|
||||
end
|
||||
@@ -527,8 +652,45 @@ function M.start(worker)
|
||||
return -- Already timed out or cancelled
|
||||
end
|
||||
|
||||
-- Extract usage from metadata if smart_generate was used
|
||||
local usage = usage_or_metadata
|
||||
if type(usage_or_metadata) == "table" and usage_or_metadata.provider then
|
||||
-- This is metadata from smart_generate
|
||||
usage = nil
|
||||
-- Update worker type to reflect actual provider used
|
||||
worker.worker_type = usage_or_metadata.provider
|
||||
-- Log if pondering occurred
|
||||
if usage_or_metadata.pondered then
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format(
|
||||
"Pondering: %s (agreement: %.0f%%)",
|
||||
usage_or_metadata.corrected and "corrected" or "validated",
|
||||
(usage_or_metadata.agreement or 1) * 100
|
||||
),
|
||||
})
|
||||
end)
|
||||
end
|
||||
end
|
||||
|
||||
M.complete(worker, response, err, usage)
|
||||
end)
|
||||
end
|
||||
|
||||
-- Use smart selection or direct client
|
||||
if use_smart_selection then
|
||||
local llm = require("codetyper.llm")
|
||||
llm.smart_generate(prompt, context, handle_response)
|
||||
else
|
||||
-- Get client and execute directly
|
||||
local client, client_err = get_client(worker.worker_type)
|
||||
if not client then
|
||||
M.complete(worker, nil, client_err)
|
||||
return
|
||||
end
|
||||
client.generate(prompt, context, handle_response)
|
||||
end
|
||||
end
|
||||
|
||||
--- Complete worker execution
|
||||
|
||||
Reference in New Issue
Block a user