Adding more features

This commit is contained in:
2026-01-15 20:58:56 -05:00
parent 84c8bcf92c
commit f5df1a9ac0
40 changed files with 9145 additions and 458 deletions

View File

@@ -7,6 +7,57 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.5.0] - 2026-01-15
### Added
- **Cost Tracking System** - Track LLM API costs across sessions
- New `:CoderCost` command opens cost estimation floating window
- Session costs tracked in real-time
- All-time costs persisted in `.coder/cost_history.json`
- Per-model breakdown with token counts
- Pricing database for 50+ models (GPT-4/5, Claude, O-series, Gemini)
- Window keymaps: `q` close, `r` refresh, `c` clear session, `C` clear all
- **Automatic Ollama Fallback** - Graceful degradation when API limits hit
- Automatically switches to Ollama when Copilot rate limits exceeded
- Detects local Ollama availability before fallback
- Notifies user of provider switch
- **Enhanced Error Handling** - Better error messages for API failures
- Shows actual API response on parse errors (not generic "failed to parse")
- Improved rate limit detection and messaging
- Sanitized newlines in error notifications to prevent UI crashes
- **Agent Tools System Improvements**
- New `to_openai_format()` and `to_claude_format()` functions
- `get_definitions()` for generic tool access
- Fixed tool call argument serialization (JSON strings vs tables)
- **Credentials Management System** - Store API keys outside of config files
- New `:CoderAddApiKey` command for interactive credential setup
- `:CoderRemoveApiKey` to remove stored credentials
- `:CoderCredentials` to view credential status
- `:CoderSwitchProvider` to switch active LLM provider
- Credentials stored in `~/.local/share/nvim/codetyper/configuration.json`
- Priority: stored credentials > config > environment variables
- Supports all providers: Claude, OpenAI, Gemini, Copilot, Ollama
### Changed
- Cost window now shows both session and all-time statistics
- Improved agent prompt templates with correct tool names
- Better error context in LLM provider responses
### Fixed
- Fixed "Failed to parse Copilot response" error showing instead of actual error
- Fixed `nvim_buf_set_lines` crash from newlines in error messages
- Fixed `tools.definitions` nil error in agent initialization
- Fixed tool name mismatch in agent prompts (write_file vs write)
---
## [0.4.0] - 2026-01-13
### Added
@@ -194,7 +245,8 @@ scheduler = {
- **Fixed** - Bug fixes
- **Security** - Vulnerability fixes
[Unreleased]: https://github.com/cargdev/codetyper.nvim/compare/v0.4.0...HEAD
[Unreleased]: https://github.com/cargdev/codetyper.nvim/compare/v0.5.0...HEAD
[0.5.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.4.0...v0.5.0
[0.4.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.3.0...v0.4.0
[0.3.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.2.0...v0.3.0
[0.2.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.1.0...v0.2.0

214
README.md
View File

@@ -20,8 +20,10 @@
- 🛡️ **Completion-Aware**: Safe injection that doesn't fight with autocomplete
- 📁 **Auto-Index**: Automatically create coder companion files on file open
- 📜 **Logs Panel**: Real-time visibility into LLM requests and token usage
- 💰 **Cost Tracking**: Persistent LLM cost estimation with session and all-time stats
- 🔒 **Git Integration**: Automatically adds `.coder.*` files to `.gitignore`
- 🌳 **Project Tree Logging**: Maintains a `tree.log` tracking your project structure
- 🧠 **Brain System**: Knowledge graph that learns from your coding patterns
---
@@ -34,6 +36,8 @@
- [LLM Providers](#-llm-providers)
- [Commands Reference](#-commands-reference)
- [Usage Guide](#-usage-guide)
- [Logs Panel](#-logs-panel)
- [Cost Tracking](#-cost-tracking)
- [Agent Mode](#-agent-mode)
- [Keymaps](#-keymaps)
- [Health Check](#-health-check)
@@ -196,6 +200,32 @@ require("codetyper").setup({
| `OPENAI_API_KEY` | OpenAI API key |
| `GEMINI_API_KEY` | Google Gemini API key |
### Credentials Management
Instead of storing API keys in your config (which may be committed to git), you can use the credentials system:
```vim
:CoderAddApiKey
```
This command interactively prompts for:
1. Provider selection (Claude, OpenAI, Gemini, Copilot, Ollama)
2. API key (for cloud providers)
3. Model name
4. Custom endpoint (for OpenAI-compatible APIs)
Credentials are stored securely in `~/.local/share/nvim/codetyper/configuration.json` (not in your config files).
**Priority order for credentials:**
1. Stored credentials (via `:CoderAddApiKey`)
2. Config file settings
3. Environment variables
**Other credential commands:**
- `:CoderCredentials` - View configured providers
- `:CoderSwitchProvider` - Switch between configured providers
- `:CoderRemoveApiKey` - Remove stored credentials
---
## 🔌 LLM Providers
@@ -255,49 +285,129 @@ llm = {
## 📝 Commands Reference
### Main Commands
All commands can be invoked via `:Coder {subcommand}` or their dedicated command aliases.
### Core Commands
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder open` | `:CoderOpen` | Open the coder split view |
| `:Coder close` | `:CoderClose` | Close the coder split view |
| `:Coder toggle` | `:CoderToggle` | Toggle the coder split view |
| `:Coder process` | `:CoderProcess` | Process the last prompt in coder file |
| `:Coder status` | - | Show plugin status and configuration |
| `:Coder focus` | - | Switch focus between coder and target windows |
| `:Coder reset` | - | Reset processed prompts to allow re-processing |
| `:Coder gitignore` | - | Force update .gitignore with coder patterns |
### Ask Panel (Chat Interface)
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder ask` | `:CoderAsk` | Open the Ask panel |
| `:Coder ask-toggle` | `:CoderAskToggle` | Toggle the Ask panel |
| `:Coder ask-close` | - | Close the Ask panel |
| `:Coder ask-clear` | `:CoderAskClear` | Clear chat history |
### Agent Mode (Autonomous Coding)
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder agent` | `:CoderAgent` | Open the Agent panel |
| `:Coder agent-toggle` | `:CoderAgentToggle` | Toggle the Agent panel |
| `:Coder agent-close` | - | Close the Agent panel |
| `:Coder agent-stop` | `:CoderAgentStop` | Stop the running agent |
### Agentic Mode (IDE-like Multi-file Agent)
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder agentic-run <task>` | `:CoderAgenticRun <task>` | Run an agentic task (multi-file changes) |
| `:Coder agentic-list` | `:CoderAgenticList` | List available agents |
| `:Coder agentic-init` | `:CoderAgenticInit` | Initialize `.coder/agents/` and `.coder/rules/` |
### Transform Commands (Inline Tag Processing)
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder transform` | `:CoderTransform` | Transform all `/@ @/` tags in file |
| `:Coder transform-cursor` | `:CoderTransformCursor` | Transform tag at cursor position |
| - | `:CoderTransformVisual` | Transform selected tags (visual mode) |
### Project & Index Commands
| Command | Alias | Description |
|---------|-------|-------------|
| - | `:CoderIndex` | Open coder companion for current file |
| `:Coder index-project` | `:CoderIndexProject` | Index the entire project |
| `:Coder index-status` | `:CoderIndexStatus` | Show project index status |
### Tree & Structure Commands
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder tree` | `:CoderTree` | Refresh `.coder/tree.log` |
| `:Coder tree-view` | `:CoderTreeView` | View `.coder/tree.log` in split |
### Queue & Scheduler Commands
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder queue-status` | `:CoderQueueStatus` | Show scheduler and queue status |
| `:Coder queue-process` | `:CoderQueueProcess` | Manually trigger queue processing |
### Processing Mode Commands
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder auto-toggle` | `:CoderAutoToggle` | Toggle automatic/manual prompt processing |
| `:Coder auto-set <mode>` | `:CoderAutoSet <mode>` | Set processing mode (`auto`/`manual`) |
### Memory & Learning Commands
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder memories` | `:CoderMemories` | Show learned memories |
| `:Coder forget [pattern]` | `:CoderForget [pattern]` | Clear memories (optionally matching pattern) |
### Brain Commands (Knowledge Graph)
| Command | Alias | Description |
|---------|-------|-------------|
| - | `:CoderBrain [action]` | Brain management (`stats`/`commit`/`flush`/`prune`) |
| - | `:CoderFeedback <type>` | Give feedback to brain (`good`/`bad`/`stats`) |
### LLM Statistics & Feedback
| Command | Description |
|---------|-------------|
| `:Coder {subcommand}` | Main command with subcommands |
| `:CoderOpen` | Open the coder split view |
| `:CoderClose` | Close the coder split view |
| `:CoderToggle` | Toggle the coder split view |
| `:CoderProcess` | Process the last prompt |
| `:Coder llm-stats` | Show LLM provider accuracy statistics |
| `:Coder llm-feedback-good` | Report positive feedback on last response |
| `:Coder llm-feedback-bad` | Report negative feedback on last response |
| `:Coder llm-reset-stats` | Reset LLM accuracy statistics |
### Ask Panel
### Cost Tracking
| Command | Description |
|---------|-------------|
| `:CoderAsk` | Open the Ask panel |
| `:CoderAskToggle` | Toggle the Ask panel |
| `:CoderAskClear` | Clear chat history |
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder cost` | `:CoderCost` | Show LLM cost estimation window |
| `:Coder cost-clear` | - | Clear session cost tracking |
### Agent Mode
### Credentials Management
| Command | Description |
|---------|-------------|
| `:CoderAgent` | Open the Agent panel |
| `:CoderAgentToggle` | Toggle the Agent panel |
| `:CoderAgentStop` | Stop the running agent |
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder add-api-key` | `:CoderAddApiKey` | Add or update LLM provider API key |
| `:Coder remove-api-key` | `:CoderRemoveApiKey` | Remove LLM provider credentials |
| `:Coder credentials` | `:CoderCredentials` | Show credentials status |
| `:Coder switch-provider` | `:CoderSwitchProvider` | Switch active LLM provider |
### Transform Commands
### UI Commands
| Command | Description |
|---------|-------------|
| `:CoderTransform` | Transform all /@ @/ tags in file |
| `:CoderTransformCursor` | Transform tag at cursor position |
| `:CoderTransformVisual` | Transform selected tags (visual mode) |
### Utility Commands
| Command | Description |
|---------|-------------|
| `:CoderIndex` | Open coder companion for current file |
| `:CoderLogs` | Toggle logs panel |
| `:CoderType` | Switch between Ask/Agent modes |
| `:CoderTree` | Refresh tree.log |
| `:CoderTreeView` | View tree.log |
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder type-toggle` | `:CoderType` | Show Ask/Agent mode switcher |
| `:Coder logs-toggle` | `:CoderLogs` | Toggle logs panel |
---
@@ -384,6 +494,42 @@ The logs panel opens automatically when processing prompts with the scheduler en
---
## 💰 Cost Tracking
Track your LLM API costs across sessions with the Cost Estimation window.
### Features
- **Session Tracking**: Monitor current session token usage and costs
- **All-Time Tracking**: Persistent cost history stored per-project in `.coder/cost_history.json`
- **Model Breakdown**: See costs by individual model
- **Pricing Database**: Built-in pricing for 50+ models (GPT, Claude, Gemini, O-series, etc.)
### Opening the Cost Window
```vim
:CoderCost
```
### Cost Window Keymaps
| Key | Description |
|-----|-------------|
| `q` / `<Esc>` | Close window |
| `r` | Refresh display |
| `c` | Clear session costs |
| `C` | Clear all history |
### Supported Models
The cost tracker includes pricing for:
- **OpenAI**: GPT-4, GPT-4o, GPT-4o-mini, O1, O3, O4-mini, and more
- **Anthropic**: Claude 3 Opus, Sonnet, Haiku, Claude 3.5 Sonnet/Haiku
- **Local**: Ollama models (free, but usage tracked)
- **Copilot**: Usage tracked (included in subscription)
---
## 🤖 Agent Mode
The Agent mode provides an autonomous coding assistant with tool access:

188
llms.txt
View File

@@ -33,6 +33,8 @@ lua/codetyper/
├── health.lua # Health check for :checkhealth
├── tree.lua # Project tree logging (.coder/tree.log)
├── logs_panel.lua # Standalone logs panel UI
├── cost.lua # LLM cost tracking with persistent history
├── credentials.lua # Secure credential storage (API keys, models)
├── llm/
│ ├── init.lua # LLM interface, provider selection
│ ├── claude.lua # Claude API client (Anthropic)
@@ -68,7 +70,14 @@ The plugin automatically creates and maintains a `.coder/` folder in your projec
```
.coder/
── tree.log # Project structure, auto-updated on file changes
── tree.log # Project structure, auto-updated on file changes
├── cost_history.json # LLM cost tracking history (persistent)
├── brain/ # Knowledge graph storage
│ ├── nodes/ # Learning nodes by type
│ ├── indices/ # Search indices
│ └── deltas/ # Version history
├── agents/ # Custom agent definitions
└── rules/ # Project-specific rules
```
## Key Features
@@ -115,7 +124,49 @@ auto_index = true -- disabled by default
Real-time visibility into LLM operations with token usage tracking.
### 6. Event-Driven Scheduler
### 6. Cost Tracking
Track LLM API costs across sessions:
- **Session tracking**: Monitor current session costs in real-time
- **All-time tracking**: Persistent history in `.coder/cost_history.json`
- **Per-model breakdown**: See costs by individual model
- **50+ models**: Built-in pricing for GPT, Claude, O-series, Gemini
Cost window keymaps:
- `q`/`<Esc>` - Close window
- `r` - Refresh display
- `c` - Clear session costs
- `C` - Clear all history
### 7. Automatic Ollama Fallback
When API rate limits are hit (e.g., Copilot free tier), the plugin:
1. Detects the rate limit error
2. Checks if local Ollama is available
3. Automatically switches provider to Ollama
4. Notifies user of the provider change
### 8. Credentials Management
Store API keys securely outside of config files:
```vim
:CoderAddApiKey
```
**Features:**
- Interactive prompts for provider, API key, model, endpoint
- Stored in `~/.local/share/nvim/codetyper/configuration.json`
- Supports all providers: Claude, OpenAI, Gemini, Copilot, Ollama
- Switch providers at runtime with `:CoderSwitchProvider`
**Credential priority:**
1. Stored credentials (via `:CoderAddApiKey`)
2. Config file settings (`require("codetyper").setup({...})`)
3. Environment variables (`OPENAI_API_KEY`, etc.)
### 9. Event-Driven Scheduler
Prompts are treated as events, not commands:
@@ -143,7 +194,7 @@ scheduler = {
}
```
### 7. Tree-sitter Scope Resolution
### 10. Tree-sitter Scope Resolution
Prompts automatically resolve to their enclosing function/method/class:
@@ -158,7 +209,7 @@ end
For replacement intents (complete, refactor, fix), the entire scope is extracted
and sent to the LLM, then replaced with the transformed version.
### 8. Intent Detection
### 11. Intent Detection
The system parses prompts to detect user intent:
@@ -173,7 +224,7 @@ The system parses prompts to detect user intent:
| optimize | optimize, performance, faster | replace |
| explain | explain, what, how, why | none |
### 9. Tag Precedence
### 12. Tag Precedence
Multiple tags in the same scope follow "first tag wins" rule:
- Earlier (by line number) unresolved tag processes first
@@ -182,33 +233,114 @@ Multiple tags in the same scope follow "first tag wins" rule:
## Commands
### Main Commands
- `:Coder open` - Opens split view with coder file
- `:Coder close` - Closes the split
- `:Coder toggle` - Toggles the view
- `:Coder process` - Manually triggers code generation
All commands can be invoked via `:Coder {subcommand}` or dedicated aliases.
### Ask Panel
- `:CoderAsk` - Open Ask panel
- `:CoderAskToggle` - Toggle Ask panel
- `:CoderAskClear` - Clear chat history
### Core Commands
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder open` | `:CoderOpen` | Open coder split view |
| `:Coder close` | `:CoderClose` | Close coder split view |
| `:Coder toggle` | `:CoderToggle` | Toggle coder split view |
| `:Coder process` | `:CoderProcess` | Process last prompt in coder file |
| `:Coder status` | - | Show plugin status and configuration |
| `:Coder focus` | - | Switch focus between coder/target windows |
| `:Coder reset` | - | Reset processed prompts |
| `:Coder gitignore` | - | Force update .gitignore |
### Agent Mode
- `:CoderAgent` - Open Agent panel
- `:CoderAgentToggle` - Toggle Agent panel
- `:CoderAgentStop` - Stop running agent
### Ask Panel (Chat Interface)
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder ask` | `:CoderAsk` | Open Ask panel |
| `:Coder ask-toggle` | `:CoderAskToggle` | Toggle Ask panel |
| `:Coder ask-close` | - | Close Ask panel |
| `:Coder ask-clear` | `:CoderAskClear` | Clear chat history |
### Transform
- `:CoderTransform` - Transform all tags
- `:CoderTransformCursor` - Transform at cursor
- `:CoderTransformVisual` - Transform selection
### Agent Mode (Autonomous Coding)
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder agent` | `:CoderAgent` | Open Agent panel |
| `:Coder agent-toggle` | `:CoderAgentToggle` | Toggle Agent panel |
| `:Coder agent-close` | - | Close Agent panel |
| `:Coder agent-stop` | `:CoderAgentStop` | Stop running agent |
### Utility
- `:CoderIndex` - Open coder companion
- `:CoderLogs` - Toggle logs panel
- `:CoderType` - Switch Ask/Agent mode
- `:CoderTree` - Refresh tree.log
- `:CoderTreeView` - View tree.log
### Agentic Mode (IDE-like Multi-file Agent)
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder agentic-run <task>` | `:CoderAgenticRun <task>` | Run agentic task |
| `:Coder agentic-list` | `:CoderAgenticList` | List available agents |
| `:Coder agentic-init` | `:CoderAgenticInit` | Initialize .coder/agents/ and .coder/rules/ |
### Transform Commands (Inline Tag Processing)
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder transform` | `:CoderTransform` | Transform all /@ @/ tags in file |
| `:Coder transform-cursor` | `:CoderTransformCursor` | Transform tag at cursor |
| - | `:CoderTransformVisual` | Transform selected tags (visual mode) |
### Project & Index Commands
| Command | Alias | Description |
|---------|-------|-------------|
| - | `:CoderIndex` | Open coder companion for current file |
| `:Coder index-project` | `:CoderIndexProject` | Index entire project |
| `:Coder index-status` | `:CoderIndexStatus` | Show project index status |
### Tree & Structure Commands
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder tree` | `:CoderTree` | Refresh .coder/tree.log |
| `:Coder tree-view` | `:CoderTreeView` | View .coder/tree.log |
### Queue & Scheduler Commands
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder queue-status` | `:CoderQueueStatus` | Show scheduler/queue status |
| `:Coder queue-process` | `:CoderQueueProcess` | Manually trigger queue processing |
### Processing Mode Commands
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder auto-toggle` | `:CoderAutoToggle` | Toggle automatic/manual processing |
| `:Coder auto-set <mode>` | `:CoderAutoSet <mode>` | Set mode (auto/manual) |
### Memory & Learning Commands
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder memories` | `:CoderMemories` | Show learned memories |
| `:Coder forget [pattern]` | `:CoderForget [pattern]` | Clear memories |
### Brain Commands (Knowledge Graph)
| Command | Alias | Description |
|---------|-------|-------------|
| - | `:CoderBrain [action]` | Brain management (stats/commit/flush/prune) |
| - | `:CoderFeedback <type>` | Give feedback (good/bad/stats) |
### LLM Statistics & Feedback
| Command | Description |
|---------|-------------|
| `:Coder llm-stats` | Show LLM provider accuracy stats |
| `:Coder llm-feedback-good` | Report positive feedback |
| `:Coder llm-feedback-bad` | Report negative feedback |
| `:Coder llm-reset-stats` | Reset LLM accuracy stats |
### Cost Tracking
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder cost` | `:CoderCost` | Show LLM cost estimation window |
| `:Coder cost-clear` | - | Clear session cost tracking |
### Credentials Management
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder add-api-key` | `:CoderAddApiKey` | Add/update LLM provider credentials |
| `:Coder remove-api-key` | `:CoderRemoveApiKey` | Remove provider credentials |
| `:Coder credentials` | `:CoderCredentials` | Show credentials status |
| `:Coder switch-provider` | `:CoderSwitchProvider` | Switch active provider |
### UI Commands
| Command | Alias | Description |
|---------|-------|-------------|
| `:Coder type-toggle` | `:CoderType` | Show Ask/Agent mode switcher |
| `:Coder logs-toggle` | `:CoderLogs` | Toggle logs panel |
## Configuration Schema

View File

@@ -0,0 +1,854 @@
---@mod codetyper.agent.agentic Agentic loop with proper tool calling
---@brief [[
--- Full agentic system that handles multi-file changes via tool calling.
--- Inspired by avante.nvim and opencode patterns.
---@brief ]]
local M = {}
---@class AgenticMessage
---@field role "system"|"user"|"assistant"|"tool"
---@field content string|table
---@field tool_calls? table[] For assistant messages with tool calls
---@field tool_call_id? string For tool result messages
---@field name? string Tool name for tool results
---@class AgenticToolCall
---@field id string Unique tool call ID
---@field type "function"
---@field function {name: string, arguments: string|table}
---@class AgenticOpts
---@field task string The task to accomplish
---@field files? string[] Initial files to include as context
---@field agent? string Agent name to use (default: "coder")
---@field model? string Model override
---@field max_iterations? number Max tool call rounds (default: 20)
---@field on_message? fun(msg: AgenticMessage) Called for each message
---@field on_tool_start? fun(name: string, args: table) Called before tool execution
---@field on_tool_end? fun(name: string, result: any, error: string|nil) Called after tool execution
---@field on_file_change? fun(path: string, action: string) Called when file is modified
---@field on_complete? fun(result: string|nil, error: string|nil) Called when done
---@field on_status? fun(status: string) Status updates
--- Generate unique tool call ID
local function generate_tool_call_id()
return "call_" .. string.format("%x", os.time()) .. "_" .. string.format("%x", math.random(0, 0xFFFF))
end
--- Load agent definition
---@param name string Agent name
---@return table|nil agent definition
local function load_agent(name)
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
local agent_file = agents_dir .. "/" .. name .. ".md"
-- Check if custom agent exists
if vim.fn.filereadable(agent_file) == 1 then
local content = table.concat(vim.fn.readfile(agent_file), "\n")
-- Parse frontmatter and content
local frontmatter = {}
local body = content
local fm_match = content:match("^%-%-%-\n(.-)%-%-%-\n(.*)$")
if fm_match then
-- Parse YAML-like frontmatter
for line in content:match("^%-%-%-\n(.-)%-%-%-"):gmatch("[^\n]+") do
local key, value = line:match("^(%w+):%s*(.+)$")
if key and value then
frontmatter[key] = value
end
end
body = content:match("%-%-%-\n.-%-%-%-%s*\n(.*)$") or content
end
return {
name = name,
description = frontmatter.description or "Custom agent: " .. name,
system_prompt = body,
tools = frontmatter.tools and vim.split(frontmatter.tools, ",") or nil,
model = frontmatter.model,
}
end
-- Built-in agents
local builtin_agents = {
coder = {
name = "coder",
description = "Full-featured coding agent with file modification capabilities",
system_prompt = [[You are an expert software engineer. You have access to tools to read, write, and modify files.
## Your Capabilities
- Read files to understand the codebase
- Search for patterns with grep and glob
- Create new files with write tool
- Edit existing files with precise replacements
- Execute shell commands for builds and tests
## Guidelines
1. Always read relevant files before making changes
2. Make minimal, focused changes
3. Follow existing code style and patterns
4. Create tests when adding new functionality
5. Verify changes work by running tests or builds
## Important Rules
- NEVER guess file contents - always read first
- Make precise edits using exact string matching
- Explain your reasoning before making changes
- If unsure, ask for clarification]],
tools = { "view", "edit", "write", "grep", "glob", "bash" },
},
planner = {
name = "planner",
description = "Planning agent - read-only, helps design implementations",
system_prompt = [[You are a software architect. Analyze codebases and create implementation plans.
You can read files and search the codebase, but cannot modify files.
Your role is to:
1. Understand the existing architecture
2. Identify relevant files and patterns
3. Create step-by-step implementation plans
4. Suggest which files to modify and how
Be thorough in your analysis before making recommendations.]],
tools = { "view", "grep", "glob" },
},
explorer = {
name = "explorer",
description = "Exploration agent - quickly find information in codebase",
system_prompt = [[You are a codebase exploration assistant. Find information quickly and report back.
Your goal is to efficiently search and summarize findings.
Use glob to find files, grep to search content, and view to read specific files.
Be concise and focused in your responses.]],
tools = { "view", "grep", "glob" },
},
}
return builtin_agents[name]
end
--- Load rules from .coder/rules/
---@return string Combined rules content
local function load_rules()
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
local rules = {}
if vim.fn.isdirectory(rules_dir) == 1 then
local files = vim.fn.glob(rules_dir .. "/*.md", false, true)
for _, file in ipairs(files) do
local content = table.concat(vim.fn.readfile(file), "\n")
local filename = vim.fn.fnamemodify(file, ":t:r")
table.insert(rules, string.format("## Rule: %s\n%s", filename, content))
end
end
if #rules > 0 then
return "\n\n# Project Rules\n" .. table.concat(rules, "\n\n")
end
return ""
end
--- Build messages array for API request
---@param history AgenticMessage[]
---@param provider string "openai"|"claude"
---@return table[] Formatted messages
local function build_messages(history, provider)
local messages = {}
for _, msg in ipairs(history) do
if msg.role == "system" then
if provider == "claude" then
-- Claude uses system parameter, not message
-- Skip system messages in array
else
table.insert(messages, {
role = "system",
content = msg.content,
})
end
elseif msg.role == "user" then
table.insert(messages, {
role = "user",
content = msg.content,
})
elseif msg.role == "assistant" then
local message = {
role = "assistant",
content = msg.content,
}
if msg.tool_calls then
message.tool_calls = msg.tool_calls
if provider == "claude" then
-- Claude format: content is array of blocks
message.content = {}
if msg.content and msg.content ~= "" then
table.insert(message.content, {
type = "text",
text = msg.content,
})
end
for _, tc in ipairs(msg.tool_calls) do
table.insert(message.content, {
type = "tool_use",
id = tc.id,
name = tc["function"].name,
input = type(tc["function"].arguments) == "string"
and vim.json.decode(tc["function"].arguments)
or tc["function"].arguments,
})
end
end
end
table.insert(messages, message)
elseif msg.role == "tool" then
if provider == "claude" then
table.insert(messages, {
role = "user",
content = {
{
type = "tool_result",
tool_use_id = msg.tool_call_id,
content = msg.content,
},
},
})
else
table.insert(messages, {
role = "tool",
tool_call_id = msg.tool_call_id,
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
})
end
end
end
return messages
end
--- Build tools array for API request
---@param tool_names string[] Tool names to include
---@param provider string "openai"|"claude"
---@return table[] Formatted tools
local function build_tools(tool_names, provider)
local tools_mod = require("codetyper.agent.tools")
local tools = {}
for _, name in ipairs(tool_names) do
local tool = tools_mod.get(name)
if tool then
local properties = {}
local required = {}
for _, param in ipairs(tool.params or {}) do
properties[param.name] = {
type = param.type == "integer" and "number" or param.type,
description = param.description,
}
if not param.optional then
table.insert(required, param.name)
end
end
local description = type(tool.description) == "function" and tool.description() or tool.description
if provider == "claude" then
table.insert(tools, {
name = tool.name,
description = description,
input_schema = {
type = "object",
properties = properties,
required = required,
},
})
else
table.insert(tools, {
type = "function",
["function"] = {
name = tool.name,
description = description,
parameters = {
type = "object",
properties = properties,
required = required,
},
},
})
end
end
end
return tools
end
--- Execute a tool call
---@param tool_call AgenticToolCall
---@param opts AgenticOpts
---@return string result
---@return string|nil error
local function execute_tool(tool_call, opts)
local tools_mod = require("codetyper.agent.tools")
local name = tool_call["function"].name
local args = tool_call["function"].arguments
-- Parse arguments if string
if type(args) == "string" then
local ok, parsed = pcall(vim.json.decode, args)
if ok then
args = parsed
else
return "", "Failed to parse tool arguments: " .. args
end
end
-- Notify tool start
if opts.on_tool_start then
opts.on_tool_start(name, args)
end
if opts.on_status then
opts.on_status("Executing: " .. name)
end
-- Execute the tool
local tool = tools_mod.get(name)
if not tool then
local err = "Unknown tool: " .. name
if opts.on_tool_end then
opts.on_tool_end(name, nil, err)
end
return "", err
end
local result, err = tool.func(args, {
on_log = function(msg)
if opts.on_status then
opts.on_status(msg)
end
end,
})
-- Notify tool end
if opts.on_tool_end then
opts.on_tool_end(name, result, err)
end
-- Track file changes
if opts.on_file_change and (name == "write" or name == "edit") and not err then
opts.on_file_change(args.path, name == "write" and "created" or "modified")
end
if err then
return "", err
end
return type(result) == "string" and result or vim.json.encode(result), nil
end
--- Parse tool calls from LLM response (unified Claude-like format)
---@param response table Raw API response in unified format
---@param provider string Provider name (unused, kept for signature compatibility)
---@return AgenticToolCall[]
local function parse_tool_calls(response, provider)
local tool_calls = {}
-- Unified format: content array with tool_use blocks
local content = response.content or {}
for _, block in ipairs(content) do
if block.type == "tool_use" then
-- OpenAI expects arguments as JSON string, not table
local args = block.input
if type(args) == "table" then
args = vim.json.encode(args)
end
table.insert(tool_calls, {
id = block.id or generate_tool_call_id(),
type = "function",
["function"] = {
name = block.name,
arguments = args,
},
})
end
end
return tool_calls
end
--- Extract text content from response (unified Claude-like format)
---@param response table Raw API response in unified format
---@param provider string Provider name (unused, kept for signature compatibility)
---@return string
local function extract_content(response, provider)
local parts = {}
for _, block in ipairs(response.content or {}) do
if block.type == "text" then
table.insert(parts, block.text)
end
end
return table.concat(parts, "\n")
end
--- Check if response indicates completion (unified Claude-like format)
---@param response table Raw API response in unified format
---@param provider string Provider name (unused, kept for signature compatibility)
---@return boolean
local function is_complete(response, provider)
return response.stop_reason == "end_turn"
end
--- Make API request to LLM with native tool calling support
---@param messages table[] Formatted messages
---@param tools table[] Formatted tools
---@param system_prompt string System prompt
---@param provider string "openai"|"claude"|"copilot"
---@param model string Model name
---@param callback fun(response: table|nil, error: string|nil)
local function call_llm(messages, tools, system_prompt, provider, model, callback)
local context = {
language = "lua",
file_content = "",
prompt_type = "agent",
project_root = vim.fn.getcwd(),
cwd = vim.fn.getcwd(),
}
-- Use native tool calling APIs
if provider == "copilot" then
local client = require("codetyper.llm.copilot")
-- Copilot's generate_with_tools expects messages in a specific format
-- Convert to the format it expects
local converted_messages = {}
for _, msg in ipairs(messages) do
if msg.role ~= "system" then
table.insert(converted_messages, msg)
end
end
client.generate_with_tools(converted_messages, context, tools, function(response, err)
if err then
callback(nil, err)
return
end
-- Response is already in Claude-like format from the provider
-- Convert to our internal format
local result = {
content = {},
stop_reason = "end_turn",
}
if response and response.content then
for _, block in ipairs(response.content) do
if block.type == "text" then
table.insert(result.content, { type = "text", text = block.text })
elseif block.type == "tool_use" then
table.insert(result.content, {
type = "tool_use",
id = block.id or generate_tool_call_id(),
name = block.name,
input = block.input,
})
result.stop_reason = "tool_use"
end
end
end
callback(result, nil)
end)
elseif provider == "openai" then
local client = require("codetyper.llm.openai")
-- OpenAI's generate_with_tools
local converted_messages = {}
for _, msg in ipairs(messages) do
if msg.role ~= "system" then
table.insert(converted_messages, msg)
end
end
client.generate_with_tools(converted_messages, context, tools, function(response, err)
if err then
callback(nil, err)
return
end
-- Response is already in Claude-like format from the provider
local result = {
content = {},
stop_reason = "end_turn",
}
if response and response.content then
for _, block in ipairs(response.content) do
if block.type == "text" then
table.insert(result.content, { type = "text", text = block.text })
elseif block.type == "tool_use" then
table.insert(result.content, {
type = "tool_use",
id = block.id or generate_tool_call_id(),
name = block.name,
input = block.input,
})
result.stop_reason = "tool_use"
end
end
end
callback(result, nil)
end)
elseif provider == "ollama" then
local client = require("codetyper.llm.ollama")
-- Ollama's generate_with_tools (text-based tool calling)
local converted_messages = {}
for _, msg in ipairs(messages) do
if msg.role ~= "system" then
table.insert(converted_messages, msg)
end
end
client.generate_with_tools(converted_messages, context, tools, function(response, err)
if err then
callback(nil, err)
return
end
-- Response is already in Claude-like format from the provider
callback(response, nil)
end)
else
-- Fallback for other providers (ollama, etc.) - use text-based parsing
local client = require("codetyper.llm." .. provider)
-- Build prompt from messages
local prompt_parts = {}
for _, msg in ipairs(messages) do
if msg.role == "user" then
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
table.insert(prompt_parts, "User: " .. content)
elseif msg.role == "assistant" then
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
table.insert(prompt_parts, "Assistant: " .. content)
end
end
-- Add tool descriptions to prompt for text-based providers
local tool_desc = "\n\n## Available Tools\n"
tool_desc = tool_desc .. "Call tools by outputting JSON in this format:\n"
tool_desc = tool_desc .. '```json\n{"tool": "tool_name", "arguments": {...}}\n```\n\n'
for _, tool in ipairs(tools) do
local name = tool.name or (tool["function"] and tool["function"].name)
local desc = tool.description or (tool["function"] and tool["function"].description)
if name then
tool_desc = tool_desc .. string.format("- **%s**: %s\n", name, desc or "")
end
end
context.file_content = system_prompt .. tool_desc
client.generate(table.concat(prompt_parts, "\n\n"), context, function(response, err)
if err then
callback(nil, err)
return
end
-- Parse response for tool calls (text-based fallback)
local result = {
content = {},
stop_reason = "end_turn",
}
-- Extract text content
local text_content = response
-- Try to extract JSON tool calls from response
local json_match = response:match("```json%s*(%b{})%s*```")
if json_match then
local ok, parsed = pcall(vim.json.decode, json_match)
if ok and parsed.tool then
table.insert(result.content, {
type = "tool_use",
id = generate_tool_call_id(),
name = parsed.tool,
input = parsed.arguments or {},
})
text_content = response:gsub("```json.-```", ""):gsub("^%s+", ""):gsub("%s+$", "")
result.stop_reason = "tool_use"
end
end
if text_content and text_content ~= "" then
table.insert(result.content, 1, { type = "text", text = text_content })
end
callback(result, nil)
end)
end
end
--- Run the agentic loop
---@param opts AgenticOpts
function M.run(opts)
-- Load agent
local agent = load_agent(opts.agent or "coder")
if not agent then
if opts.on_complete then
opts.on_complete(nil, "Unknown agent: " .. (opts.agent or "coder"))
end
return
end
-- Load rules
local rules = load_rules()
-- Build system prompt
local system_prompt = agent.system_prompt .. rules
-- Initialize message history
---@type AgenticMessage[]
local history = {
{ role = "system", content = system_prompt },
}
-- Add initial file context if provided
if opts.files and #opts.files > 0 then
local file_context = "# Initial Files\n"
for _, file_path in ipairs(opts.files) do
local content = table.concat(vim.fn.readfile(file_path) or {}, "\n")
file_context = file_context .. string.format("\n## %s\n```\n%s\n```\n", file_path, content)
end
table.insert(history, { role = "user", content = file_context })
table.insert(history, { role = "assistant", content = "I've reviewed the provided files. What would you like me to do?" })
end
-- Add the task
table.insert(history, { role = "user", content = opts.task })
-- Determine provider
local config = require("codetyper").get_config()
local provider = config.llm.provider or "copilot"
-- Note: Ollama has its own handler in call_llm, don't change it
-- Get tools for this agent
local tool_names = agent.tools or { "view", "edit", "write", "grep", "glob", "bash" }
-- Ensure tools are loaded
local tools_mod = require("codetyper.agent.tools")
tools_mod.setup()
-- Build tools for API
local tools = build_tools(tool_names, provider)
-- Iteration tracking
local iteration = 0
local max_iterations = opts.max_iterations or 20
--- Process one iteration
local function process_iteration()
iteration = iteration + 1
if iteration > max_iterations then
if opts.on_complete then
opts.on_complete(nil, "Max iterations reached")
end
return
end
if opts.on_status then
opts.on_status(string.format("Thinking... (iteration %d)", iteration))
end
-- Build messages for API
local messages = build_messages(history, provider)
-- Call LLM
call_llm(messages, tools, system_prompt, provider, opts.model, function(response, err)
if err then
if opts.on_complete then
opts.on_complete(nil, err)
end
return
end
-- Extract content and tool calls
local content = extract_content(response, provider)
local tool_calls = parse_tool_calls(response, provider)
-- Add assistant message to history
local assistant_msg = {
role = "assistant",
content = content,
tool_calls = #tool_calls > 0 and tool_calls or nil,
}
table.insert(history, assistant_msg)
if opts.on_message then
opts.on_message(assistant_msg)
end
-- Process tool calls if any
if #tool_calls > 0 then
for _, tc in ipairs(tool_calls) do
local result, tool_err = execute_tool(tc, opts)
-- Add tool result to history
local tool_msg = {
role = "tool",
tool_call_id = tc.id,
name = tc["function"].name,
content = tool_err or result,
}
table.insert(history, tool_msg)
if opts.on_message then
opts.on_message(tool_msg)
end
end
-- Continue the loop
vim.schedule(process_iteration)
else
-- No tool calls - check if complete
if is_complete(response, provider) or content ~= "" then
if opts.on_complete then
opts.on_complete(content, nil)
end
else
-- Continue if not explicitly complete
vim.schedule(process_iteration)
end
end
end)
end
-- Start the loop
process_iteration()
end
--- Create default agent files in .coder/agents/
function M.init_agents_dir()
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
vim.fn.mkdir(agents_dir, "p")
-- Create example agent
local example_agent = [[---
description: Example custom agent
tools: view,grep,glob,edit,write
model:
---
# Custom Agent
You are a custom coding agent. Describe your specialized behavior here.
## Your Role
- Define what this agent specializes in
- List specific capabilities
## Guidelines
- Add agent-specific rules
- Define coding standards to follow
## Examples
Provide examples of how to handle common tasks.
]]
local example_path = agents_dir .. "/example.md"
if vim.fn.filereadable(example_path) ~= 1 then
vim.fn.writefile(vim.split(example_agent, "\n"), example_path)
end
return agents_dir
end
--- Create default rules in .coder/rules/
function M.init_rules_dir()
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
vim.fn.mkdir(rules_dir, "p")
-- Create example rule
local example_rule = [[# Code Style
Follow these coding standards:
## General
- Use consistent indentation (tabs or spaces based on project)
- Keep lines under 100 characters
- Add comments for complex logic
## Naming Conventions
- Use descriptive variable names
- Functions should be verbs (e.g., getUserData, calculateTotal)
- Constants in UPPER_SNAKE_CASE
## Testing
- Write tests for new functionality
- Aim for >80% code coverage
- Test edge cases
## Documentation
- Document public APIs
- Include usage examples
- Keep docs up to date with code
]]
local example_path = rules_dir .. "/code-style.md"
if vim.fn.filereadable(example_path) ~= 1 then
vim.fn.writefile(vim.split(example_rule, "\n"), example_path)
end
return rules_dir
end
--- Initialize both agents and rules directories
function M.init()
M.init_agents_dir()
M.init_rules_dir()
end
--- List available agents
---@return table[] List of {name, description, builtin}
function M.list_agents()
local agents = {}
-- Built-in agents
local builtins = { "coder", "planner", "explorer" }
for _, name in ipairs(builtins) do
local agent = load_agent(name)
if agent then
table.insert(agents, {
name = agent.name,
description = agent.description,
builtin = true,
})
end
end
-- Custom agents from .coder/agents/
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
if vim.fn.isdirectory(agents_dir) == 1 then
local files = vim.fn.glob(agents_dir .. "/*.md", false, true)
for _, file in ipairs(files) do
local name = vim.fn.fnamemodify(file, ":t:r")
if not vim.tbl_contains(builtins, name) then
local agent = load_agent(name)
if agent then
table.insert(agents, {
name = agent.name,
description = agent.description,
builtin = false,
})
end
end
end
end
return agents
end
return M

View File

@@ -8,11 +8,11 @@ local M = {}
--- Heuristic weights (must sum to 1.0)
M.weights = {
length = 0.15, -- Response length relative to prompt
length = 0.15, -- Response length relative to prompt
uncertainty = 0.30, -- Uncertainty phrases
syntax = 0.25, -- Syntax completeness
repetition = 0.15, -- Duplicate lines
truncation = 0.15, -- Incomplete ending
syntax = 0.25, -- Syntax completeness
repetition = 0.15, -- Duplicate lines
truncation = 0.15, -- Incomplete ending
}
--- Uncertainty phrases that indicate low confidence
@@ -255,14 +255,15 @@ function M.score(response, prompt, context)
_ = context -- Reserved for future use
if not response or #response == 0 then
return 0, {
length = 0,
uncertainty = 0,
syntax = 0,
repetition = 0,
truncation = 0,
weighted_total = 0,
}
return 0,
{
length = 0,
uncertainty = 0,
syntax = 0,
repetition = 0,
truncation = 0,
weighted_total = 0,
}
end
local scores = {

View File

@@ -111,7 +111,11 @@ function M.agent_loop(context, callbacks)
logs.thinking("Calling LLM with " .. #state.conversation .. " messages...")
-- Generate with tools enabled
client.generate_with_tools(state.conversation, context, tools.definitions, function(response, err)
-- Ensure tools are loaded and get definitions
tools.setup()
local tool_defs = tools.to_openai_format()
client.generate_with_tools(state.conversation, context, tool_defs, function(response, err)
if err then
state.is_running = false
callbacks.on_error(err)

View File

@@ -0,0 +1,614 @@
---@mod codetyper.agent.inject Smart code injection with import handling
---@brief [[
--- Intelligent code injection that properly handles imports, merging them
--- into existing import sections instead of blindly appending.
---@brief ]]
local M = {}
---@class ImportConfig
---@field pattern string Lua pattern to match import statements
---@field multi_line boolean Whether imports can span multiple lines
---@field sort_key function|nil Function to extract sort key from import
---@field group_by function|nil Function to group imports
---@class ParsedCode
---@field imports string[] Import statements
---@field body string[] Non-import code lines
---@field import_lines table<number, boolean> Map of line numbers that are imports
--- Language-specific import patterns
local import_patterns = {
-- JavaScript/TypeScript
javascript = {
{ pattern = "^%s*import%s+.+%s+from%s+['\"]", multi_line = true },
{ pattern = "^%s*import%s+['\"]", multi_line = false },
{ pattern = "^%s*import%s*{", multi_line = true },
{ pattern = "^%s*import%s*%*", multi_line = true },
{ pattern = "^%s*export%s+{.+}%s+from%s+['\"]", multi_line = true },
{ pattern = "^%s*const%s+%w+%s*=%s*require%(['\"]", multi_line = false },
{ pattern = "^%s*let%s+%w+%s*=%s*require%(['\"]", multi_line = false },
{ pattern = "^%s*var%s+%w+%s*=%s*require%(['\"]", multi_line = false },
},
-- Python
python = {
{ pattern = "^%s*import%s+%w", multi_line = false },
{ pattern = "^%s*from%s+[%w%.]+%s+import%s+", multi_line = true },
},
-- Lua
lua = {
{ pattern = "^%s*local%s+%w+%s*=%s*require%s*%(?['\"]", multi_line = false },
{ pattern = "^%s*require%s*%(?['\"]", multi_line = false },
},
-- Go
go = {
{ pattern = "^%s*import%s+%(?", multi_line = true },
},
-- Rust
rust = {
{ pattern = "^%s*use%s+", multi_line = true },
{ pattern = "^%s*extern%s+crate%s+", multi_line = false },
},
-- C/C++
c = {
{ pattern = "^%s*#include%s*[<\"]", multi_line = false },
},
-- Java/Kotlin
java = {
{ pattern = "^%s*import%s+", multi_line = false },
},
-- Ruby
ruby = {
{ pattern = "^%s*require%s+['\"]", multi_line = false },
{ pattern = "^%s*require_relative%s+['\"]", multi_line = false },
},
-- PHP
php = {
{ pattern = "^%s*use%s+", multi_line = false },
{ pattern = "^%s*require%s+['\"]", multi_line = false },
{ pattern = "^%s*require_once%s+['\"]", multi_line = false },
{ pattern = "^%s*include%s+['\"]", multi_line = false },
{ pattern = "^%s*include_once%s+['\"]", multi_line = false },
},
}
-- Alias common extensions to language configs
import_patterns.ts = import_patterns.javascript
import_patterns.tsx = import_patterns.javascript
import_patterns.jsx = import_patterns.javascript
import_patterns.mjs = import_patterns.javascript
import_patterns.cjs = import_patterns.javascript
import_patterns.py = import_patterns.python
import_patterns.cpp = import_patterns.c
import_patterns.hpp = import_patterns.c
import_patterns.h = import_patterns.c
import_patterns.kt = import_patterns.java
import_patterns.rs = import_patterns.rust
import_patterns.rb = import_patterns.ruby
--- Check if a line is an import statement for the given language
---@param line string
---@param patterns table[] Import patterns for the language
---@return boolean is_import
---@return boolean is_multi_line
local function is_import_line(line, patterns)
for _, p in ipairs(patterns) do
if line:match(p.pattern) then
return true, p.multi_line or false
end
end
return false, false
end
--- Check if a line is empty or a comment
---@param line string
---@param filetype string
---@return boolean
local function is_empty_or_comment(line, filetype)
local trimmed = line:match("^%s*(.-)%s*$")
if trimmed == "" then
return true
end
-- Language-specific comment patterns
local comment_patterns = {
lua = { "^%-%-" },
python = { "^#" },
javascript = { "^//", "^/%*", "^%*" },
typescript = { "^//", "^/%*", "^%*" },
go = { "^//", "^/%*", "^%*" },
rust = { "^//", "^/%*", "^%*" },
c = { "^//", "^/%*", "^%*", "^#" },
java = { "^//", "^/%*", "^%*" },
ruby = { "^#" },
php = { "^//", "^/%*", "^%*", "^#" },
}
local patterns = comment_patterns[filetype] or comment_patterns.javascript
for _, pattern in ipairs(patterns) do
if trimmed:match(pattern) then
return true
end
end
return false
end
--- Check if a line ends a multi-line import
---@param line string
---@param filetype string
---@return boolean
local function ends_multiline_import(line, filetype)
-- Check for closing patterns
if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then
-- ES6 imports end with 'from "..." ;' or just ';' or a line with just '}'
if line:match("from%s+['\"][^'\"]+['\"]%s*;?%s*$") then
return true
end
if line:match("}%s*from%s+['\"]") then
return true
end
if line:match("^%s*}%s*;?%s*$") then
return true
end
if line:match(";%s*$") then
return true
end
elseif filetype == "python" or filetype == "py" then
-- Python single-line import: doesn't end with \, (, or ,
-- Examples: "from typing import List, Dict" or "import os"
if not line:match("\\%s*$") and not line:match("%(%s*$") and not line:match(",%s*$") then
return true
end
-- Python multiline imports end with closing paren
if line:match("%)%s*$") then
return true
end
elseif filetype == "go" then
-- Go multi-line imports end with ')'
if line:match("%)%s*$") then
return true
end
elseif filetype == "rust" or filetype == "rs" then
-- Rust use statements end with ';'
if line:match(";%s*$") then
return true
end
end
return false
end
--- Parse code into imports and body
---@param code string|string[] Code to parse
---@param filetype string File type/extension
---@return ParsedCode
function M.parse_code(code, filetype)
local lines
if type(code) == "string" then
lines = vim.split(code, "\n", { plain = true })
else
lines = code
end
local patterns = import_patterns[filetype] or import_patterns.javascript
local result = {
imports = {},
body = {},
import_lines = {},
}
local in_multiline_import = false
local current_import_lines = {}
for i, line in ipairs(lines) do
if in_multiline_import then
-- Continue collecting multi-line import
table.insert(current_import_lines, line)
if ends_multiline_import(line, filetype) then
-- Complete the multi-line import
table.insert(result.imports, table.concat(current_import_lines, "\n"))
for j = i - #current_import_lines + 1, i do
result.import_lines[j] = true
end
current_import_lines = {}
in_multiline_import = false
end
else
local is_import, is_multi = is_import_line(line, patterns)
if is_import then
result.import_lines[i] = true
if is_multi and not ends_multiline_import(line, filetype) then
-- Start of multi-line import
in_multiline_import = true
current_import_lines = { line }
else
-- Single-line import
table.insert(result.imports, line)
end
else
-- Non-import line
table.insert(result.body, line)
end
end
end
-- Handle unclosed multi-line import (shouldn't happen with well-formed code)
if #current_import_lines > 0 then
table.insert(result.imports, table.concat(current_import_lines, "\n"))
end
return result
end
--- Find the import section range in a buffer
---@param bufnr number Buffer number
---@param filetype string
---@return number|nil start_line First import line (1-indexed)
---@return number|nil end_line Last import line (1-indexed)
function M.find_import_section(bufnr, filetype)
if not vim.api.nvim_buf_is_valid(bufnr) then
return nil, nil
end
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local patterns = import_patterns[filetype] or import_patterns.javascript
local first_import = nil
local last_import = nil
local in_multiline = false
local consecutive_non_import = 0
local max_gap = 3 -- Allow up to 3 blank/comment lines between imports
for i, line in ipairs(lines) do
if in_multiline then
last_import = i
consecutive_non_import = 0
if ends_multiline_import(line, filetype) then
in_multiline = false
end
else
local is_import, is_multi = is_import_line(line, patterns)
if is_import then
if not first_import then
first_import = i
end
last_import = i
consecutive_non_import = 0
if is_multi and not ends_multiline_import(line, filetype) then
in_multiline = true
end
elseif is_empty_or_comment(line, filetype) then
-- Allow gaps in import section
if first_import then
consecutive_non_import = consecutive_non_import + 1
if consecutive_non_import > max_gap then
-- Too many non-import lines, import section has ended
break
end
end
else
-- Non-import, non-empty line
if first_import then
-- Import section has ended
break
end
end
end
end
return first_import, last_import
end
--- Get existing imports from a buffer
---@param bufnr number Buffer number
---@param filetype string
---@return string[] Existing import statements
function M.get_existing_imports(bufnr, filetype)
local start_line, end_line = M.find_import_section(bufnr, filetype)
if not start_line then
return {}
end
local lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false)
local parsed = M.parse_code(lines, filetype)
return parsed.imports
end
--- Normalize an import for comparison (remove whitespace variations)
---@param import_str string
---@return string
local function normalize_import(import_str)
-- Remove trailing semicolon for comparison
local normalized = import_str:gsub(";%s*$", "")
-- Remove all whitespace around braces, commas, colons
normalized = normalized:gsub("%s*{%s*", "{")
normalized = normalized:gsub("%s*}%s*", "}")
normalized = normalized:gsub("%s*,%s*", ",")
normalized = normalized:gsub("%s*:%s*", ":")
-- Collapse multiple whitespace to single space
normalized = normalized:gsub("%s+", " ")
-- Trim leading/trailing whitespace
normalized = normalized:match("^%s*(.-)%s*$")
return normalized
end
--- Check if two imports are duplicates
---@param import1 string
---@param import2 string
---@return boolean
local function are_duplicate_imports(import1, import2)
return normalize_import(import1) == normalize_import(import2)
end
--- Merge new imports with existing ones, avoiding duplicates
---@param existing string[] Existing imports
---@param new_imports string[] New imports to merge
---@return string[] Merged imports
function M.merge_imports(existing, new_imports)
local merged = {}
local seen = {}
-- Add existing imports
for _, imp in ipairs(existing) do
local normalized = normalize_import(imp)
if not seen[normalized] then
seen[normalized] = true
table.insert(merged, imp)
end
end
-- Add new imports that aren't duplicates
for _, imp in ipairs(new_imports) do
local normalized = normalize_import(imp)
if not seen[normalized] then
seen[normalized] = true
table.insert(merged, imp)
end
end
return merged
end
--- Sort imports by their source/module
---@param imports string[]
---@param filetype string
---@return string[]
function M.sort_imports(imports, filetype)
-- Group imports: stdlib/builtin first, then third-party, then local
local builtin = {}
local third_party = {}
local local_imports = {}
for _, imp in ipairs(imports) do
-- Detect import type based on patterns
local is_local = false
local is_builtin = false
if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then
-- Local: starts with . or ..
is_local = imp:match("from%s+['\"]%.") or imp:match("require%(['\"]%.")
-- Node builtin modules
is_builtin = imp:match("from%s+['\"]node:") or imp:match("from%s+['\"]fs['\"]")
or imp:match("from%s+['\"]path['\"]") or imp:match("from%s+['\"]http['\"]")
elseif filetype == "python" or filetype == "py" then
-- Local: relative imports
is_local = imp:match("^from%s+%.") or imp:match("^import%s+%.")
-- Python stdlib (simplified check)
is_builtin = imp:match("^import%s+os") or imp:match("^import%s+sys")
or imp:match("^from%s+os%s+") or imp:match("^from%s+sys%s+")
or imp:match("^import%s+re") or imp:match("^import%s+json")
elseif filetype == "lua" then
-- Local: relative requires
is_local = imp:match("require%(['\"]%.") or imp:match("require%s+['\"]%.")
elseif filetype == "go" then
-- Local: project imports (contain /)
is_local = imp:match("['\"][^'\"]+/[^'\"]+['\"]") and not imp:match("github%.com")
end
if is_builtin then
table.insert(builtin, imp)
elseif is_local then
table.insert(local_imports, imp)
else
table.insert(third_party, imp)
end
end
-- Sort each group alphabetically
table.sort(builtin)
table.sort(third_party)
table.sort(local_imports)
-- Combine with proper spacing
local result = {}
for _, imp in ipairs(builtin) do
table.insert(result, imp)
end
if #builtin > 0 and (#third_party > 0 or #local_imports > 0) then
table.insert(result, "") -- Blank line between groups
end
for _, imp in ipairs(third_party) do
table.insert(result, imp)
end
if #third_party > 0 and #local_imports > 0 then
table.insert(result, "") -- Blank line between groups
end
for _, imp in ipairs(local_imports) do
table.insert(result, imp)
end
return result
end
---@class InjectResult
---@field success boolean
---@field imports_added number Number of new imports added
---@field imports_merged boolean Whether imports were merged into existing section
---@field body_lines number Number of body lines injected
--- Smart inject code into a buffer, properly handling imports
---@param bufnr number Target buffer
---@param code string|string[] Code to inject
---@param opts table Options: { strategy: "append"|"replace"|"insert", range: {start_line, end_line}|nil, filetype: string|nil, sort_imports: boolean|nil }
---@return InjectResult
function M.inject(bufnr, code, opts)
opts = opts or {}
if not vim.api.nvim_buf_is_valid(bufnr) then
return { success = false, imports_added = 0, imports_merged = false, body_lines = 0 }
end
-- Get filetype
local filetype = opts.filetype
if not filetype then
local bufname = vim.api.nvim_buf_get_name(bufnr)
filetype = vim.fn.fnamemodify(bufname, ":e")
end
-- Parse the code to separate imports from body
local parsed = M.parse_code(code, filetype)
local result = {
success = true,
imports_added = 0,
imports_merged = false,
body_lines = #parsed.body,
}
-- Handle imports first if there are any
if #parsed.imports > 0 then
local import_start, import_end = M.find_import_section(bufnr, filetype)
if import_start then
-- Merge with existing import section
local existing_imports = M.get_existing_imports(bufnr, filetype)
local merged = M.merge_imports(existing_imports, parsed.imports)
-- Count how many new imports were actually added
result.imports_added = #merged - #existing_imports
result.imports_merged = true
-- Optionally sort imports
if opts.sort_imports ~= false then
merged = M.sort_imports(merged, filetype)
end
-- Convert back to lines (handling multi-line imports)
local import_lines = {}
for _, imp in ipairs(merged) do
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
table.insert(import_lines, line)
end
end
-- Replace the import section
vim.api.nvim_buf_set_lines(bufnr, import_start - 1, import_end, false, import_lines)
-- Adjust line numbers for body injection
local lines_diff = #import_lines - (import_end - import_start + 1)
if opts.range and opts.range.start_line and opts.range.start_line > import_end then
opts.range.start_line = opts.range.start_line + lines_diff
if opts.range.end_line then
opts.range.end_line = opts.range.end_line + lines_diff
end
end
else
-- No existing import section, add imports at the top
-- Find the first non-comment, non-empty line
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local insert_at = 0
for i, line in ipairs(lines) do
local trimmed = line:match("^%s*(.-)%s*$")
-- Skip shebang, docstrings, and initial comments
if trimmed ~= "" and not trimmed:match("^#!")
and not trimmed:match("^['\"]") and not is_empty_or_comment(line, filetype) then
insert_at = i - 1
break
end
insert_at = i
end
-- Add imports with a trailing blank line
local import_lines = {}
for _, imp in ipairs(parsed.imports) do
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
table.insert(import_lines, line)
end
end
table.insert(import_lines, "") -- Blank line after imports
vim.api.nvim_buf_set_lines(bufnr, insert_at, insert_at, false, import_lines)
result.imports_added = #parsed.imports
result.imports_merged = false
-- Adjust body injection range
if opts.range and opts.range.start_line then
opts.range.start_line = opts.range.start_line + #import_lines
if opts.range.end_line then
opts.range.end_line = opts.range.end_line + #import_lines
end
end
end
end
-- Handle body (non-import) code
if #parsed.body > 0 then
-- Filter out empty leading/trailing lines from body
local body_lines = parsed.body
while #body_lines > 0 and body_lines[1]:match("^%s*$") do
table.remove(body_lines, 1)
end
while #body_lines > 0 and body_lines[#body_lines]:match("^%s*$") do
table.remove(body_lines)
end
if #body_lines > 0 then
local line_count = vim.api.nvim_buf_line_count(bufnr)
local strategy = opts.strategy or "append"
if strategy == "replace" and opts.range then
local start_line = math.max(1, opts.range.start_line)
local end_line = math.min(line_count, opts.range.end_line)
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, body_lines)
elseif strategy == "insert" and opts.range then
local insert_line = math.max(0, math.min(line_count, opts.range.start_line - 1))
vim.api.nvim_buf_set_lines(bufnr, insert_line, insert_line, false, body_lines)
else
-- Default: append
local last_line = vim.api.nvim_buf_get_lines(bufnr, line_count - 1, line_count, false)[1] or ""
if last_line:match("%S") then
-- Add blank line for spacing
table.insert(body_lines, 1, "")
end
vim.api.nvim_buf_set_lines(bufnr, line_count, line_count, false, body_lines)
end
result.body_lines = #body_lines
end
end
return result
end
--- Check if code contains imports
---@param code string|string[]
---@param filetype string
---@return boolean
function M.has_imports(code, filetype)
local parsed = M.parse_code(code, filetype)
return #parsed.imports > 0
end
return M

View File

@@ -0,0 +1,398 @@
---@mod codetyper.agent.loop Agent loop with tool orchestration
---@brief [[
--- Main agent loop that handles multi-turn conversations with tool use.
--- Inspired by avante.nvim's agent_loop pattern.
---@brief ]]
local M = {}
---@class AgentMessage
---@field role "system"|"user"|"assistant"|"tool"
---@field content string|table
---@field tool_call_id? string For tool responses
---@field tool_calls? table[] For assistant tool calls
---@field name? string Tool name for tool responses
---@class AgentLoopOpts
---@field system_prompt string System prompt
---@field user_input string Initial user message
---@field tools? CoderTool[] Available tools (default: all registered)
---@field max_iterations? number Max tool call iterations (default: 10)
---@field provider? string LLM provider to use
---@field on_start? fun() Called when loop starts
---@field on_chunk? fun(chunk: string) Called for each response chunk
---@field on_tool_call? fun(name: string, input: table) Called before tool execution
---@field on_tool_result? fun(name: string, result: any, error: string|nil) Called after tool execution
---@field on_message? fun(message: AgentMessage) Called for each message added
---@field on_complete? fun(result: string|nil, error: string|nil) Called when loop completes
---@field session_ctx? table Session context shared across tools
--- Format tool definitions for OpenAI-compatible API
---@param tools CoderTool[]
---@return table[]
local function format_tools_for_api(tools)
local formatted = {}
for _, tool in ipairs(tools) do
local properties = {}
local required = {}
for _, param in ipairs(tool.params or {}) do
properties[param.name] = {
type = param.type == "integer" and "number" or param.type,
description = param.description,
}
if not param.optional then
table.insert(required, param.name)
end
end
table.insert(formatted, {
type = "function",
["function"] = {
name = tool.name,
description = type(tool.description) == "function" and tool.description() or tool.description,
parameters = {
type = "object",
properties = properties,
required = required,
},
},
})
end
return formatted
end
--- Parse tool calls from LLM response
---@param response table LLM response
---@return table[] tool_calls
local function parse_tool_calls(response)
local tool_calls = {}
-- Handle different response formats
if response.tool_calls then
-- OpenAI format
for _, call in ipairs(response.tool_calls) do
local args = call["function"].arguments
if type(args) == "string" then
local ok, parsed = pcall(vim.json.decode, args)
if ok then
args = parsed
end
end
table.insert(tool_calls, {
id = call.id,
name = call["function"].name,
input = args,
})
end
elseif response.content and type(response.content) == "table" then
-- Claude format (content blocks)
for _, block in ipairs(response.content) do
if block.type == "tool_use" then
table.insert(tool_calls, {
id = block.id,
name = block.name,
input = block.input,
})
end
end
end
return tool_calls
end
--- Build messages for LLM request
---@param history AgentMessage[]
---@return table[]
local function build_messages(history)
local messages = {}
for _, msg in ipairs(history) do
if msg.role == "system" then
table.insert(messages, {
role = "system",
content = msg.content,
})
elseif msg.role == "user" then
table.insert(messages, {
role = "user",
content = msg.content,
})
elseif msg.role == "assistant" then
local message = {
role = "assistant",
content = msg.content,
}
if msg.tool_calls then
message.tool_calls = msg.tool_calls
end
table.insert(messages, message)
elseif msg.role == "tool" then
table.insert(messages, {
role = "tool",
tool_call_id = msg.tool_call_id,
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
})
end
end
return messages
end
--- Execute the agent loop
---@param opts AgentLoopOpts
function M.run(opts)
local tools_mod = require("codetyper.agent.tools")
local llm = require("codetyper.llm")
-- Get tools
local tools = opts.tools or tools_mod.list()
local tool_map = {}
for _, tool in ipairs(tools) do
tool_map[tool.name] = tool
end
-- Initialize conversation history
---@type AgentMessage[]
local history = {
{ role = "system", content = opts.system_prompt },
{ role = "user", content = opts.user_input },
}
local session_ctx = opts.session_ctx or {}
local max_iterations = opts.max_iterations or 10
local iteration = 0
-- Callback wrappers
local function on_message(msg)
if opts.on_message then
opts.on_message(msg)
end
end
-- Notify of initial messages
for _, msg in ipairs(history) do
on_message(msg)
end
-- Start notification
if opts.on_start then
opts.on_start()
end
--- Process one iteration of the loop
local function process_iteration()
iteration = iteration + 1
if iteration > max_iterations then
if opts.on_complete then
opts.on_complete(nil, "Max iterations reached")
end
return
end
-- Build request
local messages = build_messages(history)
local formatted_tools = format_tools_for_api(tools)
-- Build context for LLM
local context = {
file_content = "",
language = "lua",
extension = "lua",
prompt_type = "agent",
tools = formatted_tools,
}
-- Get LLM response
local client = llm.get_client()
if not client then
if opts.on_complete then
opts.on_complete(nil, "No LLM client available")
end
return
end
-- Build prompt from messages
local prompt_parts = {}
for _, msg in ipairs(messages) do
if msg.role ~= "system" then
table.insert(prompt_parts, string.format("[%s]: %s", msg.role, msg.content or ""))
end
end
local prompt = table.concat(prompt_parts, "\n\n")
client.generate(prompt, context, function(response, error)
if error then
if opts.on_complete then
opts.on_complete(nil, error)
end
return
end
-- Chunk callback
if opts.on_chunk then
opts.on_chunk(response)
end
-- Parse response for tool calls
-- For now, we'll use a simple heuristic to detect tool calls in the response
-- In a full implementation, the LLM would return structured tool calls
local tool_calls = {}
-- Try to parse JSON tool calls from response
local json_match = response:match("```json%s*(%b{})%s*```")
if json_match then
local ok, parsed = pcall(vim.json.decode, json_match)
if ok and parsed.tool_calls then
tool_calls = parsed.tool_calls
end
end
-- Add assistant message
local assistant_msg = {
role = "assistant",
content = response,
tool_calls = #tool_calls > 0 and tool_calls or nil,
}
table.insert(history, assistant_msg)
on_message(assistant_msg)
-- Process tool calls
if #tool_calls > 0 then
local pending = #tool_calls
local results = {}
for i, call in ipairs(tool_calls) do
local tool = tool_map[call.name]
if not tool then
results[i] = { error = "Unknown tool: " .. call.name }
pending = pending - 1
else
-- Notify of tool call
if opts.on_tool_call then
opts.on_tool_call(call.name, call.input)
end
-- Execute tool
local tool_opts = {
on_log = function(msg)
pcall(function()
local logs = require("codetyper.agent.logs")
logs.add({ type = "tool", message = msg })
end)
end,
on_complete = function(result, err)
results[i] = { result = result, error = err }
pending = pending - 1
-- Notify of tool result
if opts.on_tool_result then
opts.on_tool_result(call.name, result, err)
end
-- Add tool response to history
local tool_msg = {
role = "tool",
tool_call_id = call.id or tostring(i),
name = call.name,
content = err or result,
}
table.insert(history, tool_msg)
on_message(tool_msg)
-- Continue loop when all tools complete
if pending == 0 then
vim.schedule(process_iteration)
end
end,
session_ctx = session_ctx,
}
-- Validate and execute
local valid, validation_err = true, nil
if tool.validate_input then
valid, validation_err = tool:validate_input(call.input)
end
if not valid then
tool_opts.on_complete(nil, validation_err)
else
local result, err = tool.func(call.input, tool_opts)
-- If sync result, call on_complete
if result ~= nil or err ~= nil then
tool_opts.on_complete(result, err)
end
end
end
end
else
-- No tool calls - loop complete
if opts.on_complete then
opts.on_complete(response, nil)
end
end
end)
end
-- Start the loop
process_iteration()
end
--- Create an agent with default settings
---@param task string Task description
---@param opts? AgentLoopOpts Additional options
function M.create(task, opts)
opts = opts or {}
local system_prompt = opts.system_prompt or [[You are a helpful coding assistant with access to tools.
Available tools:
- view: Read file contents
- grep: Search for patterns in files
- glob: Find files by pattern
- edit: Make targeted edits to files
- write: Create or overwrite files
- bash: Execute shell commands
When you need to perform a task:
1. Use tools to gather information
2. Plan your approach
3. Execute changes using appropriate tools
4. Verify the results
Always explain your reasoning before using tools.
When you're done, provide a clear summary of what was accomplished.]]
M.run(vim.tbl_extend("force", opts, {
system_prompt = system_prompt,
user_input = task,
}))
end
--- Simple dispatch agent for sub-tasks
---@param prompt string Task for the sub-agent
---@param on_complete fun(result: string|nil, error: string|nil) Completion callback
---@param opts? table Additional options
function M.dispatch(prompt, on_complete, opts)
opts = opts or {}
-- Sub-agents get limited tools by default
local tools_mod = require("codetyper.agent.tools")
local safe_tools = tools_mod.list(function(tool)
return tool.name == "view" or tool.name == "grep" or tool.name == "glob"
end)
M.run({
system_prompt = [[You are a research assistant. Your task is to find information and report back.
You have access to: view (read files), grep (search content), glob (find files).
Be thorough and report your findings clearly.]],
user_input = prompt,
tools = opts.tools or safe_tools,
max_iterations = opts.max_iterations or 5,
on_complete = on_complete,
session_ctx = opts.session_ctx,
})
end
return M

View File

@@ -2,10 +2,16 @@
---@brief [[
--- Manages code patches with buffer snapshots for staleness detection.
--- Patches are queued for safe injection when completion popup is not visible.
--- Uses smart injection for intelligent import merging.
---@brief ]]
local M = {}
--- Lazy load inject module to avoid circular requires
local function get_inject_module()
return require("codetyper.agent.inject")
end
---@class BufferSnapshot
---@field bufnr number Buffer number
---@field changedtick number vim.b.changedtick at snapshot time
@@ -15,7 +21,8 @@ local M = {}
---@class PatchCandidate
---@field id string Unique patch ID
---@field event_id string Related PromptEvent ID
---@field target_bufnr number Target buffer for injection
---@field source_bufnr number Source buffer where prompt tags are (coder file)
---@field target_bufnr number Target buffer for injection (real file)
---@field target_path string Target file path
---@field original_snapshot BufferSnapshot Snapshot at event creation
---@field generated_code string Code to inject
@@ -171,7 +178,10 @@ end
---@param strategy string|nil Injection strategy (overrides intent-based)
---@return PatchCandidate
function M.create_from_event(event, generated_code, confidence, strategy)
-- Get target buffer
-- Source buffer is where the prompt tags are (could be coder file)
local source_bufnr = event.bufnr
-- Get target buffer (where code should be injected - the real file)
local target_bufnr = vim.fn.bufnr(event.target_path)
if target_bufnr == -1 then
-- Try to find by filename
@@ -220,7 +230,8 @@ function M.create_from_event(event, generated_code, confidence, strategy)
return {
id = M.generate_id(),
event_id = event.id,
target_bufnr = target_bufnr,
source_bufnr = source_bufnr, -- Where prompt tags are (coder file)
target_bufnr = target_bufnr, -- Where code goes (real file)
target_path = event.target_path,
original_snapshot = snapshot,
generated_code = generated_code,
@@ -453,39 +464,56 @@ function M.apply(patch)
-- Prepare code lines
local code_lines = vim.split(patch.generated_code, "\n", { plain = true })
-- FIRST: Remove the prompt tags from the buffer before applying code
-- This prevents the infinite loop where tags stay and get re-detected
local tags_removed = remove_prompt_tags(target_bufnr)
-- FIRST: Remove the prompt tags from the SOURCE buffer (coder file), not target
-- The tags are in the coder file where the user wrote the prompt
-- Code goes to target file, tags get removed from source file
local source_bufnr = patch.source_bufnr
local tags_removed = 0
pcall(function()
if tags_removed > 0 then
local logs = require("codetyper.agent.logs")
logs.add({
type = "info",
message = string.format("Removed %d prompt tag(s) from buffer", tags_removed),
})
end
end)
if source_bufnr and vim.api.nvim_buf_is_valid(source_bufnr) then
tags_removed = remove_prompt_tags(source_bufnr)
-- Recalculate line count after tag removal
local line_count = vim.api.nvim_buf_line_count(target_bufnr)
pcall(function()
if tags_removed > 0 then
local logs = require("codetyper.agent.logs")
local source_name = vim.api.nvim_buf_get_name(source_bufnr)
logs.add({
type = "info",
message = string.format("Removed %d prompt tag(s) from %s",
tags_removed,
vim.fn.fnamemodify(source_name, ":t")),
})
end
end)
end
-- Apply based on strategy
-- Get filetype for smart injection
local filetype = vim.fn.fnamemodify(patch.target_path or "", ":e")
-- Use smart injection module for intelligent import handling
local inject = get_inject_module()
local inject_result = nil
-- Apply based on strategy using smart injection
local ok, err = pcall(function()
-- Prepare injection options
local inject_opts = {
strategy = patch.injection_strategy,
filetype = filetype,
sort_imports = true,
}
if patch.injection_strategy == "replace" and patch.injection_range then
-- Replace the scope range with the new code
-- The injection_range points to the function/method we're completing
local start_line = patch.injection_range.start_line
local end_line = patch.injection_range.end_line
-- Adjust for tag removal - find the new range by searching for the scope
-- After removing tags, line numbers may have shifted
-- Use the scope information to find the correct range
if patch.scope and patch.scope.type then
-- Try to find the scope using treesitter if available
local found_range = nil
pcall(function()
local ts_utils = require("nvim-treesitter.ts_utils")
local parsers = require("nvim-treesitter.parsers")
local parser = parsers.get_parser(target_bufnr)
if parser then
@@ -528,34 +556,38 @@ function M.apply(patch)
end
-- Clamp to valid range
local line_count = vim.api.nvim_buf_line_count(target_bufnr)
start_line = math.max(1, start_line)
end_line = math.min(line_count, end_line)
-- Replace the range (0-indexed for nvim_buf_set_lines)
vim.api.nvim_buf_set_lines(target_bufnr, start_line - 1, end_line, false, code_lines)
inject_opts.range = { start_line = start_line, end_line = end_line }
elseif patch.injection_strategy == "insert" and patch.injection_range then
inject_opts.range = { start_line = patch.injection_range.start_line }
end
pcall(function()
local logs = require("codetyper.agent.logs")
-- Use smart injection - handles imports automatically
inject_result = inject.inject(target_bufnr, patch.generated_code, inject_opts)
-- Log injection details
pcall(function()
local logs = require("codetyper.agent.logs")
if inject_result.imports_added > 0 then
logs.add({
type = "info",
message = string.format("Replacing lines %d-%d with %d lines of code", start_line, end_line, #code_lines),
message = string.format(
"%s %d import(s), injected %d body line(s)",
inject_result.imports_merged and "Merged" or "Added",
inject_result.imports_added,
inject_result.body_lines
),
})
else
logs.add({
type = "info",
message = string.format("Injected %d line(s) of code", inject_result.body_lines),
})
end)
elseif patch.injection_strategy == "insert" and patch.injection_range then
-- Insert at the specified location
local insert_line = patch.injection_range.start_line
insert_line = math.max(1, math.min(line_count + 1, insert_line))
vim.api.nvim_buf_set_lines(target_bufnr, insert_line - 1, insert_line - 1, false, code_lines)
else
-- Default: append to end
-- Check if last line is empty, if not add a blank line for spacing
local last_line = vim.api.nvim_buf_get_lines(target_bufnr, line_count - 1, line_count, false)[1] or ""
if last_line:match("%S") then
-- Last line has content, add blank line for spacing
table.insert(code_lines, 1, "")
end
vim.api.nvim_buf_set_lines(target_bufnr, line_count, line_count, false, code_lines)
end
end)
end)
if not ok then
@@ -577,6 +609,41 @@ function M.apply(patch)
})
end)
-- Learn from successful code generation - this builds neural pathways
-- The more code is successfully applied, the better the brain becomes
pcall(function()
local brain = require("codetyper.brain")
if brain.is_initialized() then
-- Learn the successful pattern
local intent_type = patch.intent and patch.intent.type or "unknown"
local scope_type = patch.scope and patch.scope.type or "file"
local scope_name = patch.scope and patch.scope.name or ""
-- Create a meaningful summary for this learning
local summary = string.format(
"Generated %s: %s %s in %s",
intent_type,
scope_type,
scope_name ~= "" and scope_name or "",
vim.fn.fnamemodify(patch.target_path or "", ":t")
)
brain.learn({
type = "code_completion",
file = patch.target_path,
timestamp = os.time(),
data = {
intent = intent_type,
code = patch.generated_code:sub(1, 500), -- Store first 500 chars
language = vim.fn.fnamemodify(patch.target_path or "", ":e"),
function_name = scope_name,
prompt = patch.prompt_content,
confidence = patch.confidence or 0.5,
},
})
end
end)
return true, nil
end

View File

@@ -0,0 +1,128 @@
---@mod codetyper.agent.tools.base Base tool definition
---@brief [[
--- Base metatable for all LLM tools.
--- Tools extend this base to provide structured AI capabilities.
---@brief ]]
---@class CoderToolParam
---@field name string Parameter name
---@field description string Parameter description
---@field type string Parameter type ("string", "number", "boolean", "table")
---@field optional? boolean Whether the parameter is optional
---@field default? any Default value for optional parameters
---@class CoderToolReturn
---@field name string Return value name
---@field description string Return value description
---@field type string Return type
---@field optional? boolean Whether the return is optional
---@class CoderToolOpts
---@field on_log? fun(message: string) Log callback
---@field on_complete? fun(result: any, error: string|nil) Completion callback
---@field session_ctx? table Session context
---@field streaming? boolean Whether response is still streaming
---@field confirm? fun(message: string, callback: fun(ok: boolean)) Confirmation callback
---@class CoderTool
---@field name string Tool identifier
---@field description string|fun(): string Tool description
---@field params CoderToolParam[] Input parameters
---@field returns CoderToolReturn[] Return values
---@field requires_confirmation? boolean Whether tool needs user confirmation
---@field func fun(input: table, opts: CoderToolOpts): any, string|nil Tool implementation
local M = {}
M.__index = M
--- Call the tool function
---@param opts CoderToolOpts Options for the tool call
---@return any result
---@return string|nil error
function M:__call(opts, on_log, on_complete)
return self.func(opts, on_log, on_complete)
end
--- Get the tool description
---@return string
function M:get_description()
if type(self.description) == "function" then
return self.description()
end
return self.description
end
--- Validate input against parameter schema
---@param input table Input to validate
---@return boolean valid
---@return string|nil error
function M:validate_input(input)
if not self.params then
return true
end
for _, param in ipairs(self.params) do
local value = input[param.name]
-- Check required parameters
if not param.optional and value == nil then
return false, string.format("Missing required parameter: %s", param.name)
end
-- Type checking
if value ~= nil then
local actual_type = type(value)
local expected_type = param.type
-- Handle special types
if expected_type == "integer" and actual_type == "number" then
if math.floor(value) ~= value then
return false, string.format("Parameter %s must be an integer", param.name)
end
elseif expected_type ~= actual_type and expected_type ~= "any" then
return false, string.format("Parameter %s must be %s, got %s", param.name, expected_type, actual_type)
end
end
end
return true
end
--- Generate JSON schema for the tool (for LLM function calling)
---@return table schema
function M:to_schema()
local properties = {}
local required = {}
for _, param in ipairs(self.params or {}) do
local prop = {
type = param.type == "integer" and "number" or param.type,
description = param.description,
}
if param.default ~= nil then
prop.default = param.default
end
properties[param.name] = prop
if not param.optional then
table.insert(required, param.name)
end
end
return {
type = "function",
function_def = {
name = self.name,
description = self:get_description(),
parameters = {
type = "object",
properties = properties,
required = required,
},
},
}
end
return M

View File

@@ -0,0 +1,198 @@
---@mod codetyper.agent.tools.bash Shell command execution tool
---@brief [[
--- Tool for executing shell commands with safety checks.
---@brief ]]
local Base = require("codetyper.agent.tools.base")
---@class CoderTool
local M = setmetatable({}, Base)
M.name = "bash"
M.description = [[Executes a bash command in a shell.
IMPORTANT RULES:
- Do NOT use bash to read files (use 'view' tool instead)
- Do NOT use bash to modify files (use 'write' or 'edit' tools instead)
- Do NOT use interactive commands (vim, nano, less, etc.)
- Commands timeout after 2 minutes by default
Allowed uses:
- Running builds (make, npm run build, cargo build)
- Running tests (npm test, pytest, cargo test)
- Git operations (git status, git diff, git commit)
- Package management (npm install, pip install)
- System info commands (ls, pwd, which)]]
M.params = {
{
name = "command",
description = "The shell command to execute",
type = "string",
},
{
name = "cwd",
description = "Working directory for the command (optional)",
type = "string",
optional = true,
},
{
name = "timeout",
description = "Timeout in milliseconds (default: 120000)",
type = "integer",
optional = true,
},
}
M.returns = {
{
name = "stdout",
description = "Command output",
type = "string",
},
{
name = "error",
description = "Error message if command failed",
type = "string",
optional = true,
},
}
M.requires_confirmation = true
--- Banned commands for safety
local BANNED_COMMANDS = {
"rm -rf /",
"rm -rf /*",
"dd if=/dev/zero",
"mkfs",
":(){ :|:& };:",
"> /dev/sda",
}
--- Banned patterns
local BANNED_PATTERNS = {
"curl.*|.*sh",
"wget.*|.*sh",
"rm%s+%-rf%s+/",
}
--- Check if command is safe
---@param command string
---@return boolean safe
---@return string|nil reason
local function is_safe_command(command)
-- Check exact matches
for _, banned in ipairs(BANNED_COMMANDS) do
if command == banned then
return false, "Command is banned for safety"
end
end
-- Check patterns
for _, pattern in ipairs(BANNED_PATTERNS) do
if command:match(pattern) then
return false, "Command matches banned pattern"
end
end
return true
end
---@param input {command: string, cwd?: string, timeout?: integer}
---@param opts CoderToolOpts
---@return string|nil result
---@return string|nil error
function M.func(input, opts)
if not input.command then
return nil, "command is required"
end
-- Safety check
local safe, reason = is_safe_command(input.command)
if not safe then
return nil, reason
end
-- Confirmation required
if M.requires_confirmation and opts.confirm then
local confirmed = false
local confirm_error = nil
opts.confirm("Execute command: " .. input.command, function(ok)
if not ok then
confirm_error = "User declined command execution"
end
confirmed = ok
end)
-- Wait for confirmation (in async context, this would be handled differently)
if confirm_error then
return nil, confirm_error
end
end
-- Log the operation
if opts.on_log then
opts.on_log("Executing: " .. input.command)
end
-- Prepare command
local cwd = input.cwd or vim.fn.getcwd()
local timeout = input.timeout or 120000
-- Execute command
local output = ""
local exit_code = 0
local job_opts = {
command = "bash",
args = { "-c", input.command },
cwd = cwd,
on_stdout = function(_, data)
if data then
output = output .. table.concat(data, "\n")
end
end,
on_stderr = function(_, data)
if data then
output = output .. table.concat(data, "\n")
end
end,
on_exit = function(_, code)
exit_code = code
end,
}
-- Run synchronously with timeout
local Job = require("plenary.job")
local job = Job:new(job_opts)
job:sync(timeout)
exit_code = job.code or 0
output = table.concat(job:result() or {}, "\n")
-- Also get stderr
local stderr = table.concat(job:stderr_result() or {}, "\n")
if stderr and stderr ~= "" then
output = output .. "\n" .. stderr
end
-- Check result
if exit_code ~= 0 then
local error_msg = string.format("Command failed with exit code %d: %s", exit_code, output)
if opts.on_complete then
opts.on_complete(nil, error_msg)
end
return nil, error_msg
end
if opts.on_complete then
opts.on_complete(output, nil)
end
return output, nil
end
return M

View File

@@ -0,0 +1,429 @@
---@mod codetyper.agent.tools.edit File editing tool with fallback matching
---@brief [[
--- Tool for making targeted edits to files using search/replace.
--- Implements multiple fallback strategies for robust matching.
--- Inspired by opencode's 9-strategy approach.
---@brief ]]
local Base = require("codetyper.agent.tools.base")
---@class CoderTool
local M = setmetatable({}, Base)
M.name = "edit"
M.description = [[Makes a targeted edit to a file by replacing text.
The old_string should match the content you want to replace. The tool uses multiple
matching strategies with fallbacks:
1. Exact match
2. Whitespace-normalized match
3. Indentation-flexible match
4. Line-trimmed match
5. Fuzzy anchor-based match
For creating new files, use old_string="" and provide the full content in new_string.
For large changes, consider using 'write' tool instead.]]
M.params = {
{
name = "path",
description = "Path to the file to edit",
type = "string",
},
{
name = "old_string",
description = "Text to find and replace (empty string to create new file or append)",
type = "string",
},
{
name = "new_string",
description = "Text to replace with",
type = "string",
},
}
M.returns = {
{
name = "success",
description = "Whether the edit was applied",
type = "boolean",
},
{
name = "error",
description = "Error message if edit failed",
type = "string",
optional = true,
},
}
M.requires_confirmation = false
--- Normalize line endings to LF
---@param str string
---@return string
local function normalize_line_endings(str)
return str:gsub("\r\n", "\n"):gsub("\r", "\n")
end
--- Strategy 1: Exact match
---@param content string File content
---@param old_str string String to find
---@return number|nil start_pos
---@return number|nil end_pos
local function exact_match(content, old_str)
local pos = content:find(old_str, 1, true)
if pos then
return pos, pos + #old_str - 1
end
return nil, nil
end
--- Strategy 2: Whitespace-normalized match
--- Collapses all whitespace to single spaces
---@param content string
---@param old_str string
---@return number|nil start_pos
---@return number|nil end_pos
local function whitespace_normalized_match(content, old_str)
local function normalize_ws(s)
return s:gsub("%s+", " "):gsub("^%s+", ""):gsub("%s+$", "")
end
local norm_old = normalize_ws(old_str)
local lines = vim.split(content, "\n")
-- Try to find matching block
for i = 1, #lines do
local block = {}
local block_start = nil
for j = i, #lines do
table.insert(block, lines[j])
local block_text = table.concat(block, "\n")
local norm_block = normalize_ws(block_text)
if norm_block == norm_old then
-- Found match
local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
local start_pos = #before + (i > 1 and 2 or 1)
local end_pos = start_pos + #block_text - 1
return start_pos, end_pos
end
-- If block is already longer than target, stop
if #norm_block > #norm_old then
break
end
end
end
return nil, nil
end
--- Strategy 3: Indentation-flexible match
--- Ignores leading whitespace differences
---@param content string
---@param old_str string
---@return number|nil start_pos
---@return number|nil end_pos
local function indentation_flexible_match(content, old_str)
local function strip_indent(s)
local lines = vim.split(s, "\n")
local result = {}
for _, line in ipairs(lines) do
table.insert(result, line:gsub("^%s+", ""))
end
return table.concat(result, "\n")
end
local stripped_old = strip_indent(old_str)
local lines = vim.split(content, "\n")
local old_lines = vim.split(old_str, "\n")
local num_old_lines = #old_lines
for i = 1, #lines - num_old_lines + 1 do
local block = vim.list_slice(lines, i, i + num_old_lines - 1)
local block_text = table.concat(block, "\n")
if strip_indent(block_text) == stripped_old then
local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
local start_pos = #before + (i > 1 and 2 or 1)
local end_pos = start_pos + #block_text - 1
return start_pos, end_pos
end
end
return nil, nil
end
--- Strategy 4: Line-trimmed match
--- Trims each line before comparing
---@param content string
---@param old_str string
---@return number|nil start_pos
---@return number|nil end_pos
local function line_trimmed_match(content, old_str)
local function trim_lines(s)
local lines = vim.split(s, "\n")
local result = {}
for _, line in ipairs(lines) do
table.insert(result, line:match("^%s*(.-)%s*$"))
end
return table.concat(result, "\n")
end
local trimmed_old = trim_lines(old_str)
local lines = vim.split(content, "\n")
local old_lines = vim.split(old_str, "\n")
local num_old_lines = #old_lines
for i = 1, #lines - num_old_lines + 1 do
local block = vim.list_slice(lines, i, i + num_old_lines - 1)
local block_text = table.concat(block, "\n")
if trim_lines(block_text) == trimmed_old then
local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
local start_pos = #before + (i > 1 and 2 or 1)
local end_pos = start_pos + #block_text - 1
return start_pos, end_pos
end
end
return nil, nil
end
--- Calculate Levenshtein distance between two strings
---@param s1 string
---@param s2 string
---@return number
local function levenshtein(s1, s2)
local len1, len2 = #s1, #s2
local matrix = {}
for i = 0, len1 do
matrix[i] = { [0] = i }
end
for j = 0, len2 do
matrix[0][j] = j
end
for i = 1, len1 do
for j = 1, len2 do
local cost = s1:sub(i, i) == s2:sub(j, j) and 0 or 1
matrix[i][j] = math.min(
matrix[i - 1][j] + 1,
matrix[i][j - 1] + 1,
matrix[i - 1][j - 1] + cost
)
end
end
return matrix[len1][len2]
end
--- Strategy 5: Fuzzy anchor-based match
--- Uses first and last lines as anchors, allows fuzzy matching in between
---@param content string
---@param old_str string
---@param threshold? number Similarity threshold (0-1), default 0.8
---@return number|nil start_pos
---@return number|nil end_pos
local function fuzzy_anchor_match(content, old_str, threshold)
threshold = threshold or 0.8
local old_lines = vim.split(old_str, "\n")
if #old_lines < 2 then
return nil, nil
end
local first_line = old_lines[1]:match("^%s*(.-)%s*$")
local last_line = old_lines[#old_lines]:match("^%s*(.-)%s*$")
local content_lines = vim.split(content, "\n")
-- Find potential start positions
local candidates = {}
for i, line in ipairs(content_lines) do
local trimmed = line:match("^%s*(.-)%s*$")
if trimmed == first_line or (
#first_line > 0 and
1 - (levenshtein(trimmed, first_line) / math.max(#trimmed, #first_line)) >= threshold
) then
table.insert(candidates, i)
end
end
-- For each candidate, look for matching end
for _, start_idx in ipairs(candidates) do
local expected_end = start_idx + #old_lines - 1
if expected_end <= #content_lines then
local end_line = content_lines[expected_end]:match("^%s*(.-)%s*$")
if end_line == last_line or (
#last_line > 0 and
1 - (levenshtein(end_line, last_line) / math.max(#end_line, #last_line)) >= threshold
) then
-- Calculate positions
local before = table.concat(vim.list_slice(content_lines, 1, start_idx - 1), "\n")
local block = table.concat(vim.list_slice(content_lines, start_idx, expected_end), "\n")
local start_pos = #before + (start_idx > 1 and 2 or 1)
local end_pos = start_pos + #block - 1
return start_pos, end_pos
end
end
end
return nil, nil
end
--- Try all matching strategies in order
---@param content string File content
---@param old_str string String to find
---@return number|nil start_pos
---@return number|nil end_pos
---@return string strategy_used
local function find_match(content, old_str)
-- Strategy 1: Exact match
local start_pos, end_pos = exact_match(content, old_str)
if start_pos then
return start_pos, end_pos, "exact"
end
-- Strategy 2: Whitespace-normalized
start_pos, end_pos = whitespace_normalized_match(content, old_str)
if start_pos then
return start_pos, end_pos, "whitespace_normalized"
end
-- Strategy 3: Indentation-flexible
start_pos, end_pos = indentation_flexible_match(content, old_str)
if start_pos then
return start_pos, end_pos, "indentation_flexible"
end
-- Strategy 4: Line-trimmed
start_pos, end_pos = line_trimmed_match(content, old_str)
if start_pos then
return start_pos, end_pos, "line_trimmed"
end
-- Strategy 5: Fuzzy anchor
start_pos, end_pos = fuzzy_anchor_match(content, old_str)
if start_pos then
return start_pos, end_pos, "fuzzy_anchor"
end
return nil, nil, "none"
end
---@param input {path: string, old_string: string, new_string: string}
---@param opts CoderToolOpts
---@return boolean|nil result
---@return string|nil error
function M.func(input, opts)
if not input.path then
return nil, "path is required"
end
if input.old_string == nil then
return nil, "old_string is required"
end
if input.new_string == nil then
return nil, "new_string is required"
end
-- Log the operation
if opts.on_log then
opts.on_log("Editing file: " .. input.path)
end
-- Resolve path
local path = input.path
if not vim.startswith(path, "/") then
path = vim.fn.getcwd() .. "/" .. path
end
-- Normalize inputs
local old_str = normalize_line_endings(input.old_string)
local new_str = normalize_line_endings(input.new_string)
-- Handle new file creation (empty old_string)
if old_str == "" then
-- Create parent directories
local dir = vim.fn.fnamemodify(path, ":h")
if vim.fn.isdirectory(dir) == 0 then
vim.fn.mkdir(dir, "p")
end
-- Write new file
local lines = vim.split(new_str, "\n", { plain = true })
local ok = pcall(vim.fn.writefile, lines, path)
if not ok then
return nil, "Failed to create file: " .. input.path
end
-- Reload buffer if open
local bufnr = vim.fn.bufnr(path)
if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
vim.api.nvim_buf_call(bufnr, function()
vim.cmd("edit!")
end)
end
if opts.on_complete then
opts.on_complete(true, nil)
end
return true, nil
end
-- Check if file exists
if vim.fn.filereadable(path) ~= 1 then
return nil, "File not found: " .. input.path
end
-- Read current content
local lines = vim.fn.readfile(path)
if not lines then
return nil, "Failed to read file: " .. input.path
end
local content = normalize_line_endings(table.concat(lines, "\n"))
-- Find match using fallback strategies
local start_pos, end_pos, strategy = find_match(content, old_str)
if not start_pos then
return nil, "old_string not found in file (tried 5 matching strategies)"
end
if opts.on_log then
opts.on_log("Match found using strategy: " .. strategy)
end
-- Perform replacement
local new_content = content:sub(1, start_pos - 1) .. new_str .. content:sub(end_pos + 1)
-- Write back
local new_lines = vim.split(new_content, "\n", { plain = true })
local ok = pcall(vim.fn.writefile, new_lines, path)
if not ok then
return nil, "Failed to write file: " .. input.path
end
-- Reload buffer if open
local bufnr = vim.fn.bufnr(path)
if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
vim.api.nvim_buf_call(bufnr, function()
vim.cmd("edit!")
end)
end
if opts.on_complete then
opts.on_complete(true, nil)
end
return true, nil
end
return M

View File

@@ -0,0 +1,146 @@
---@mod codetyper.agent.tools.glob File pattern matching tool
---@brief [[
--- Tool for finding files by glob pattern.
---@brief ]]
local Base = require("codetyper.agent.tools.base")
---@class CoderTool
local M = setmetatable({}, Base)
M.name = "glob"
M.description = [[Finds files matching a glob pattern.
Example patterns:
- "**/*.lua" - All Lua files
- "src/**/*.ts" - TypeScript files in src
- "**/test_*.py" - Test files in Python]]
M.params = {
{
name = "pattern",
description = "Glob pattern to match files",
type = "string",
},
{
name = "path",
description = "Base directory to search in (default: project root)",
type = "string",
optional = true,
},
{
name = "max_results",
description = "Maximum number of results (default: 100)",
type = "integer",
optional = true,
},
}
M.returns = {
{
name = "matches",
description = "JSON array of matching file paths",
type = "string",
},
{
name = "error",
description = "Error message if glob failed",
type = "string",
optional = true,
},
}
M.requires_confirmation = false
---@param input {pattern: string, path?: string, max_results?: integer}
---@param opts CoderToolOpts
---@return string|nil result
---@return string|nil error
function M.func(input, opts)
if not input.pattern then
return nil, "pattern is required"
end
-- Log the operation
if opts.on_log then
opts.on_log("Finding files: " .. input.pattern)
end
-- Resolve base path
local base_path = input.path or vim.fn.getcwd()
if not vim.startswith(base_path, "/") then
base_path = vim.fn.getcwd() .. "/" .. base_path
end
local max_results = input.max_results or 100
-- Use vim.fn.glob or fd if available
local matches = {}
if vim.fn.executable("fd") == 1 then
-- Use fd for better performance
local Job = require("plenary.job")
-- Convert glob to fd pattern
local fd_pattern = input.pattern:gsub("%*%*/", ""):gsub("%*", ".*")
local job = Job:new({
command = "fd",
args = {
"--type",
"f",
"--max-results",
tostring(max_results),
"--glob",
input.pattern,
base_path,
},
cwd = base_path,
})
job:sync(30000)
matches = job:result() or {}
else
-- Fallback to vim.fn.globpath
local pattern = base_path .. "/" .. input.pattern
local files = vim.fn.glob(pattern, false, true)
for i, file in ipairs(files) do
if i > max_results then
break
end
-- Make paths relative to base_path
local relative = file:gsub("^" .. vim.pesc(base_path) .. "/", "")
table.insert(matches, relative)
end
end
-- Clean up matches
local cleaned = {}
for _, match in ipairs(matches) do
if match and match ~= "" then
-- Make relative if absolute
local relative = match
if vim.startswith(match, base_path) then
relative = match:sub(#base_path + 2)
end
table.insert(cleaned, relative)
end
end
-- Return as JSON
local result = vim.json.encode({
matches = cleaned,
total = #cleaned,
truncated = #cleaned >= max_results,
})
if opts.on_complete then
opts.on_complete(result, nil)
end
return result, nil
end
return M

View File

@@ -0,0 +1,150 @@
---@mod codetyper.agent.tools.grep Search tool
---@brief [[
--- Tool for searching file contents using ripgrep.
---@brief ]]
local Base = require("codetyper.agent.tools.base")
---@class CoderTool
local M = setmetatable({}, Base)
M.name = "grep"
M.description = [[Searches for a pattern in files using ripgrep.
Returns file paths and matching lines. Use this to find code by content.
Example patterns:
- "function foo" - Find function definitions
- "import.*react" - Find React imports
- "TODO|FIXME" - Find todo comments]]
M.params = {
{
name = "pattern",
description = "Regular expression pattern to search for",
type = "string",
},
{
name = "path",
description = "Directory or file to search in (default: project root)",
type = "string",
optional = true,
},
{
name = "include",
description = "File glob pattern to include (e.g., '*.lua')",
type = "string",
optional = true,
},
{
name = "max_results",
description = "Maximum number of results (default: 50)",
type = "integer",
optional = true,
},
}
M.returns = {
{
name = "matches",
description = "JSON array of matches with file, line_number, and content",
type = "string",
},
{
name = "error",
description = "Error message if search failed",
type = "string",
optional = true,
},
}
M.requires_confirmation = false
---@param input {pattern: string, path?: string, include?: string, max_results?: integer}
---@param opts CoderToolOpts
---@return string|nil result
---@return string|nil error
function M.func(input, opts)
if not input.pattern then
return nil, "pattern is required"
end
-- Log the operation
if opts.on_log then
opts.on_log("Searching for: " .. input.pattern)
end
-- Build ripgrep command
local path = input.path or vim.fn.getcwd()
local max_results = input.max_results or 50
-- Resolve path
if not vim.startswith(path, "/") then
path = vim.fn.getcwd() .. "/" .. path
end
-- Check if ripgrep is available
if vim.fn.executable("rg") ~= 1 then
return nil, "ripgrep (rg) is not installed"
end
-- Build command args
local args = {
"--json",
"--max-count",
tostring(max_results),
"--no-heading",
}
if input.include then
table.insert(args, "--glob")
table.insert(args, input.include)
end
table.insert(args, input.pattern)
table.insert(args, path)
-- Execute ripgrep
local Job = require("plenary.job")
local job = Job:new({
command = "rg",
args = args,
cwd = vim.fn.getcwd(),
})
job:sync(30000) -- 30 second timeout
local results = job:result() or {}
local matches = {}
-- Parse JSON output
for _, line in ipairs(results) do
if line and line ~= "" then
local ok, parsed = pcall(vim.json.decode, line)
if ok and parsed.type == "match" then
local data = parsed.data
table.insert(matches, {
file = data.path.text,
line_number = data.line_number,
content = data.lines.text:gsub("\n$", ""),
})
end
end
end
-- Return as JSON
local result = vim.json.encode({
matches = matches,
total = #matches,
truncated = #matches >= max_results,
})
if opts.on_complete then
opts.on_complete(result, nil)
end
return result, nil
end
return M

View File

@@ -0,0 +1,308 @@
---@mod codetyper.agent.tools Tool registry and orchestration
---@brief [[
--- Registry for LLM tools with execution and schema generation.
--- Inspired by avante.nvim's tool system.
---@brief ]]
local M = {}
--- Registered tools
---@type table<string, CoderTool>
local tools = {}
--- Tool execution history for current session
---@type table[]
local execution_history = {}
--- Register a tool
---@param tool CoderTool Tool to register
function M.register(tool)
if not tool.name then
error("Tool must have a name")
end
tools[tool.name] = tool
end
--- Unregister a tool
---@param name string Tool name
function M.unregister(name)
tools[name] = nil
end
--- Get a tool by name
---@param name string Tool name
---@return CoderTool|nil
function M.get(name)
return tools[name]
end
--- Get all registered tools
---@return table<string, CoderTool>
function M.get_all()
return tools
end
--- Get tools as a list
---@param filter? fun(tool: CoderTool): boolean Optional filter function
---@return CoderTool[]
function M.list(filter)
local result = {}
for _, tool in pairs(tools) do
if not filter or filter(tool) then
table.insert(result, tool)
end
end
return result
end
--- Generate schemas for all tools (for LLM function calling)
---@param filter? fun(tool: CoderTool): boolean Optional filter function
---@return table[] schemas
function M.get_schemas(filter)
local schemas = {}
for _, tool in pairs(tools) do
if not filter or filter(tool) then
if tool.to_schema then
table.insert(schemas, tool:to_schema())
end
end
end
return schemas
end
--- Execute a tool by name
---@param name string Tool name
---@param input table Input parameters
---@param opts CoderToolOpts Execution options
---@return any result
---@return string|nil error
function M.execute(name, input, opts)
local tool = tools[name]
if not tool then
return nil, "Unknown tool: " .. name
end
-- Validate input
if tool.validate_input then
local valid, err = tool:validate_input(input)
if not valid then
return nil, err
end
end
-- Log execution
if opts.on_log then
opts.on_log(string.format("Executing tool: %s", name))
end
-- Track execution
local execution = {
tool = name,
input = input,
start_time = os.time(),
status = "running",
}
table.insert(execution_history, execution)
-- Execute the tool
local result, err = tool.func(input, opts)
-- Update execution record
execution.end_time = os.time()
execution.status = err and "error" or "completed"
execution.result = result
execution.error = err
return result, err
end
--- Process a tool call from LLM response
---@param tool_call table Tool call from LLM (name + input)
---@param opts CoderToolOpts Execution options
---@return any result
---@return string|nil error
function M.process_tool_call(tool_call, opts)
local name = tool_call.name or tool_call.function_name
local input = tool_call.input or tool_call.arguments or {}
-- Parse JSON arguments if string
if type(input) == "string" then
local ok, parsed = pcall(vim.json.decode, input)
if ok then
input = parsed
else
return nil, "Failed to parse tool arguments: " .. input
end
end
return M.execute(name, input, opts)
end
--- Get execution history
---@param limit? number Max entries to return
---@return table[]
function M.get_history(limit)
if not limit then
return execution_history
end
local result = {}
local start = math.max(1, #execution_history - limit + 1)
for i = start, #execution_history do
table.insert(result, execution_history[i])
end
return result
end
--- Clear execution history
function M.clear_history()
execution_history = {}
end
--- Load built-in tools
function M.load_builtins()
-- View file tool
local view = require("codetyper.agent.tools.view")
M.register(view)
-- Bash tool
local bash = require("codetyper.agent.tools.bash")
M.register(bash)
-- Grep tool
local grep = require("codetyper.agent.tools.grep")
M.register(grep)
-- Glob tool
local glob = require("codetyper.agent.tools.glob")
M.register(glob)
-- Write file tool
local write = require("codetyper.agent.tools.write")
M.register(write)
-- Edit tool
local edit = require("codetyper.agent.tools.edit")
M.register(edit)
end
--- Initialize tools system
function M.setup()
M.load_builtins()
end
--- Get tool definitions for LLM (lazy-loaded, OpenAI format)
--- This is accessed as M.definitions property
M.definitions = setmetatable({}, {
__call = function()
-- Ensure tools are loaded
if vim.tbl_count(tools) == 0 then
M.load_builtins()
end
return M.to_openai_format()
end,
__index = function(_, key)
-- Make it work as both function and table
if key == "get" then
return function()
if vim.tbl_count(tools) == 0 then
M.load_builtins()
end
return M.to_openai_format()
end
end
return nil
end,
})
--- Get definitions as a function (for backwards compatibility)
function M.get_definitions()
if vim.tbl_count(tools) == 0 then
M.load_builtins()
end
return M.to_openai_format()
end
--- Convert all tools to OpenAI function calling format
---@param filter? fun(tool: CoderTool): boolean Optional filter function
---@return table[] OpenAI-compatible tool definitions
function M.to_openai_format(filter)
local openai_tools = {}
for _, tool in pairs(tools) do
if not filter or filter(tool) then
local properties = {}
local required = {}
for _, param in ipairs(tool.params or {}) do
properties[param.name] = {
type = param.type == "integer" and "number" or param.type,
description = param.description,
}
if param.default ~= nil then
properties[param.name].default = param.default
end
if not param.optional then
table.insert(required, param.name)
end
end
local description = type(tool.description) == "function" and tool.description() or tool.description
table.insert(openai_tools, {
type = "function",
["function"] = {
name = tool.name,
description = description,
parameters = {
type = "object",
properties = properties,
required = required,
},
},
})
end
end
return openai_tools
end
--- Convert all tools to Claude tool use format
---@param filter? fun(tool: CoderTool): boolean Optional filter function
---@return table[] Claude-compatible tool definitions
function M.to_claude_format(filter)
local claude_tools = {}
for _, tool in pairs(tools) do
if not filter or filter(tool) then
local properties = {}
local required = {}
for _, param in ipairs(tool.params or {}) do
properties[param.name] = {
type = param.type == "integer" and "number" or param.type,
description = param.description,
}
if not param.optional then
table.insert(required, param.name)
end
end
local description = type(tool.description) == "function" and tool.description() or tool.description
table.insert(claude_tools, {
name = tool.name,
description = description,
input_schema = {
type = "object",
properties = properties,
required = required,
},
})
end
end
return claude_tools
end
return M

View File

@@ -0,0 +1,149 @@
---@mod codetyper.agent.tools.view File viewing tool
---@brief [[
--- Tool for reading file contents with line range support.
---@brief ]]
local Base = require("codetyper.agent.tools.base")
---@class CoderTool
local M = setmetatable({}, Base)
M.name = "view"
M.description = [[Reads the content of a file.
Usage notes:
- Provide the file path relative to the project root
- Use start_line and end_line to read specific sections
- If content is truncated, use line ranges to read in chunks
- Returns JSON with content, total_line_count, and is_truncated]]
M.params = {
{
name = "path",
description = "Path to the file (relative to project root or absolute)",
type = "string",
},
{
name = "start_line",
description = "Line number to start reading (1-indexed)",
type = "integer",
optional = true,
},
{
name = "end_line",
description = "Line number to end reading (1-indexed, inclusive)",
type = "integer",
optional = true,
},
}
M.returns = {
{
name = "content",
description = "File contents as JSON with content, total_line_count, is_truncated",
type = "string",
},
{
name = "error",
description = "Error message if file could not be read",
type = "string",
optional = true,
},
}
M.requires_confirmation = false
--- Maximum content size before truncation
local MAX_CONTENT_SIZE = 200 * 1024 -- 200KB
---@param input {path: string, start_line?: integer, end_line?: integer}
---@param opts CoderToolOpts
---@return string|nil result
---@return string|nil error
function M.func(input, opts)
if not input.path then
return nil, "path is required"
end
-- Log the operation
if opts.on_log then
opts.on_log("Reading file: " .. input.path)
end
-- Resolve path
local path = input.path
if not vim.startswith(path, "/") then
-- Relative path - resolve from project root
local root = vim.fn.getcwd()
path = root .. "/" .. path
end
-- Check if file exists
local stat = vim.uv.fs_stat(path)
if not stat then
return nil, "File not found: " .. input.path
end
if stat.type == "directory" then
return nil, "Path is a directory: " .. input.path
end
-- Read file
local lines = vim.fn.readfile(path)
if not lines then
return nil, "Failed to read file: " .. input.path
end
-- Apply line range
local start_line = input.start_line or 1
local end_line = input.end_line or #lines
start_line = math.max(1, start_line)
end_line = math.min(#lines, end_line)
local total_lines = #lines
local selected_lines = {}
for i = start_line, end_line do
table.insert(selected_lines, lines[i])
end
-- Check for truncation
local content = table.concat(selected_lines, "\n")
local is_truncated = false
if #content > MAX_CONTENT_SIZE then
-- Truncate content
local truncated_lines = {}
local size = 0
for _, line in ipairs(selected_lines) do
size = size + #line + 1
if size > MAX_CONTENT_SIZE then
is_truncated = true
break
end
table.insert(truncated_lines, line)
end
content = table.concat(truncated_lines, "\n")
end
-- Return as JSON
local result = vim.json.encode({
content = content,
total_line_count = total_lines,
is_truncated = is_truncated,
start_line = start_line,
end_line = end_line,
})
if opts.on_complete then
opts.on_complete(result, nil)
end
return result, nil
end
return M

View File

@@ -0,0 +1,101 @@
---@mod codetyper.agent.tools.write File writing tool
---@brief [[
--- Tool for creating or overwriting files.
---@brief ]]
local Base = require("codetyper.agent.tools.base")
---@class CoderTool
local M = setmetatable({}, Base)
M.name = "write"
M.description = [[Creates or overwrites a file with new content.
IMPORTANT:
- This will completely replace the file contents
- Use 'edit' tool for partial modifications
- Parent directories will be created if needed]]
M.params = {
{
name = "path",
description = "Path to the file to write",
type = "string",
},
{
name = "content",
description = "Content to write to the file",
type = "string",
},
}
M.returns = {
{
name = "success",
description = "Whether the file was written successfully",
type = "boolean",
},
{
name = "error",
description = "Error message if write failed",
type = "string",
optional = true,
},
}
M.requires_confirmation = true
---@param input {path: string, content: string}
---@param opts CoderToolOpts
---@return boolean|nil result
---@return string|nil error
function M.func(input, opts)
if not input.path then
return nil, "path is required"
end
if not input.content then
return nil, "content is required"
end
-- Log the operation
if opts.on_log then
opts.on_log("Writing file: " .. input.path)
end
-- Resolve path
local path = input.path
if not vim.startswith(path, "/") then
path = vim.fn.getcwd() .. "/" .. path
end
-- Create parent directories
local dir = vim.fn.fnamemodify(path, ":h")
if vim.fn.isdirectory(dir) == 0 then
vim.fn.mkdir(dir, "p")
end
-- Write the file
local lines = vim.split(input.content, "\n", { plain = true })
local ok = pcall(vim.fn.writefile, lines, path)
if not ok then
return nil, "Failed to write file: " .. path
end
-- Reload buffer if open
local bufnr = vim.fn.bufnr(path)
if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
vim.api.nvim_buf_call(bufnr, function()
vim.cmd("edit!")
end)
end
if opts.on_complete then
opts.on_complete(true, nil)
end
return true, nil
end
return M

View File

@@ -224,6 +224,86 @@ local function format_attached_files(attached_files)
return table.concat(parts, "")
end
--- Get coder companion file path for a target file
---@param target_path string Target file path
---@return string|nil Coder file path if exists
local function get_coder_companion_path(target_path)
if not target_path or target_path == "" then
return nil
end
-- Skip if target is already a coder file
if target_path:match("%.coder%.") then
return nil
end
local dir = vim.fn.fnamemodify(target_path, ":h")
local name = vim.fn.fnamemodify(target_path, ":t:r") -- filename without extension
local ext = vim.fn.fnamemodify(target_path, ":e")
local coder_path = dir .. "/" .. name .. ".coder." .. ext
if vim.fn.filereadable(coder_path) == 1 then
return coder_path
end
return nil
end
--- Read and format coder companion context (business logic, pseudo-code)
---@param target_path string Target file path
---@return string Formatted coder context
local function get_coder_context(target_path)
local coder_path = get_coder_companion_path(target_path)
if not coder_path then
return ""
end
local ok, lines = pcall(function()
return vim.fn.readfile(coder_path)
end)
if not ok or not lines or #lines == 0 then
return ""
end
local content = table.concat(lines, "\n")
-- Skip if only template comments (no actual content)
local stripped = content:gsub("^%s*", ""):gsub("%s*$", "")
if stripped == "" then
return ""
end
-- Check if there's meaningful content (not just template)
local has_content = false
for _, line in ipairs(lines) do
-- Skip comment lines that are part of the template
local trimmed = line:gsub("^%s*", "")
if not trimmed:match("^[%-#/]+%s*Coder companion")
and not trimmed:match("^[%-#/]+%s*Use /@ @/")
and not trimmed:match("^[%-#/]+%s*Example:")
and not trimmed:match("^<!%-%-")
and trimmed ~= ""
and not trimmed:match("^[%-#/]+%s*$") then
has_content = true
break
end
end
if not has_content then
return ""
end
local ext = vim.fn.fnamemodify(coder_path, ":e")
return string.format(
"\n\n--- Business Context / Pseudo-code ---\n" ..
"The following describes the intended behavior and design for this file:\n" ..
"```%s\n%s\n```",
ext,
content:sub(1, 4000) -- Limit to 4000 chars
)
end
--- Format indexed project context for inclusion in prompt
---@param indexed_context table|nil
---@return string
@@ -309,8 +389,53 @@ local function build_prompt(event)
-- Format attached files
local attached_content = format_attached_files(event.attached_files)
-- Combine attached files and indexed context
local extra_context = attached_content .. indexed_content
-- Get coder companion context (business logic, pseudo-code)
local coder_context = get_coder_context(event.target_path)
-- Get brain memories - contextual recall based on current task
local brain_context = ""
pcall(function()
local brain = require("codetyper.brain")
if brain.is_initialized() then
-- Query brain for relevant memories based on:
-- 1. Current file (file-specific patterns)
-- 2. Prompt content (semantic similarity)
-- 3. Intent type (relevant past generations)
local query_text = event.prompt_content or ""
if event.scope and event.scope.name then
query_text = event.scope.name .. " " .. query_text
end
local result = brain.query({
query = query_text,
file = event.target_path,
max_results = 5,
types = { "pattern", "correction", "convention" },
})
if result and result.nodes and #result.nodes > 0 then
local memories = { "\n\n--- Learned Patterns & Conventions ---" }
for _, node in ipairs(result.nodes) do
if node.c then
local summary = node.c.s or ""
local detail = node.c.d or ""
if summary ~= "" then
table.insert(memories, "" .. summary)
if detail ~= "" and #detail < 200 then
table.insert(memories, " " .. detail)
end
end
end
end
if #memories > 1 then
brain_context = table.concat(memories, "\n")
end
end
end
end)
-- Combine all context sources: brain memories first, then coder context, attached files, indexed
local extra_context = brain_context .. coder_context .. attached_content .. indexed_content
-- Build context with scope information
local context = {
@@ -502,21 +627,21 @@ function M.start(worker)
end
end, worker.timeout_ms)
-- Get client and execute
local client, client_err = get_client(worker.worker_type)
if not client then
M.complete(worker, nil, client_err)
return
end
local prompt, context = build_prompt(worker.event)
-- Call the LLM
client.generate(prompt, context, function(response, err, usage)
-- Check if smart selection is enabled (memory-based provider selection)
local use_smart_selection = false
pcall(function()
local codetyper = require("codetyper")
local config = codetyper.get_config()
use_smart_selection = config.llm.smart_selection ~= false -- Default to true
end)
-- Define the response handler
local function handle_response(response, err, usage_or_metadata)
-- Cancel timeout timer
if worker.timer then
pcall(function()
-- Timer might have already fired
if type(worker.timer) == "userdata" and worker.timer.stop then
worker.timer:stop()
end
@@ -527,8 +652,45 @@ function M.start(worker)
return -- Already timed out or cancelled
end
-- Extract usage from metadata if smart_generate was used
local usage = usage_or_metadata
if type(usage_or_metadata) == "table" and usage_or_metadata.provider then
-- This is metadata from smart_generate
usage = nil
-- Update worker type to reflect actual provider used
worker.worker_type = usage_or_metadata.provider
-- Log if pondering occurred
if usage_or_metadata.pondered then
pcall(function()
local logs = require("codetyper.agent.logs")
logs.add({
type = "info",
message = string.format(
"Pondering: %s (agreement: %.0f%%)",
usage_or_metadata.corrected and "corrected" or "validated",
(usage_or_metadata.agreement or 1) * 100
),
})
end)
end
end
M.complete(worker, response, err, usage)
end)
end
-- Use smart selection or direct client
if use_smart_selection then
local llm = require("codetyper.llm")
llm.smart_generate(prompt, context, handle_response)
else
-- Get client and execute directly
local client, client_err = get_client(worker.worker_type)
if not client then
M.complete(worker, nil, client_err)
return
end
client.generate(prompt, context, handle_response)
end
end
--- Complete worker execution

View File

@@ -312,7 +312,9 @@ local function append_log_to_output(entry)
}
local icon = icons[entry.level] or ""
local formatted = string.format("[%s] %s %s", entry.timestamp, icon, entry.message)
-- Sanitize message - replace newlines with spaces to prevent nvim_buf_set_lines error
local sanitized_message = entry.message:gsub("\n", " "):gsub("\r", "")
local formatted = string.format("[%s] %s %s", entry.timestamp, icon, sanitized_message)
vim.schedule(function()
if not state.output_buf or not vim.api.nvim_buf_is_valid(state.output_buf) then

View File

@@ -18,6 +18,16 @@ local processed_prompts = {}
--- Track if we're currently asking for preferences
local asking_preference = false
--- Track if we're currently processing prompts (busy flag)
local is_processing = false
--- Track the previous mode for visual mode detection
local previous_mode = "n"
--- Debounce timer for prompt processing
local prompt_process_timer = nil
local PROMPT_PROCESS_DEBOUNCE_MS = 200 -- Wait 200ms after mode change before processing
--- Generate a unique key for a prompt
---@param bufnr number Buffer number
---@param prompt table Prompt object
@@ -64,6 +74,20 @@ function M.setup()
desc = "Check for closed prompt tags on InsertLeave",
})
-- Track mode changes for visual mode detection
vim.api.nvim_create_autocmd("ModeChanged", {
group = group,
pattern = "*",
callback = function(ev)
-- Extract old mode from pattern (format: "old_mode:new_mode")
local old_mode = ev.match:match("^(.-):")
if old_mode then
previous_mode = old_mode
end
end,
desc = "Track previous mode for visual mode detection",
})
-- Auto-process prompts when entering normal mode (works on ALL files)
vim.api.nvim_create_autocmd("ModeChanged", {
group = group,
@@ -74,10 +98,33 @@ function M.setup()
if buftype ~= "" then
return
end
-- Slight delay to let buffer settle
vim.defer_fn(function()
-- Skip if currently processing (avoid concurrent processing)
if is_processing then
return
end
-- Skip if coming from visual mode (v, V, CTRL-V) - user is still editing
if previous_mode == "v" or previous_mode == "V" or previous_mode == "\22" then
return
end
-- Cancel any pending processing timer
if prompt_process_timer then
prompt_process_timer:stop()
prompt_process_timer = nil
end
-- Debounced processing - wait for user to truly be idle
prompt_process_timer = vim.defer_fn(function()
prompt_process_timer = nil
-- Double-check we're still in normal mode
local mode = vim.api.nvim_get_mode().mode
if mode ~= "n" then
return
end
M.check_all_prompts_with_preference()
end, 50)
end, PROMPT_PROCESS_DEBOUNCE_MS)
end,
desc = "Auto-process closed prompts when entering normal mode",
})
@@ -92,6 +139,10 @@ function M.setup()
if buftype ~= "" then
return
end
-- Skip if currently processing
if is_processing then
return
end
local mode = vim.api.nvim_get_mode().mode
if mode == "n" then
M.check_all_prompts_with_preference()
@@ -291,6 +342,12 @@ end
--- Check if the buffer has a newly closed prompt and auto-process (works on ANY file)
function M.check_for_closed_prompt()
-- Skip if already processing
if is_processing then
return
end
is_processing = true
local config = get_config_safe()
local parser = require("codetyper.parser")
@@ -299,6 +356,7 @@ function M.check_for_closed_prompt()
-- Skip if no file
if current_file == "" then
is_processing = false
return
end
@@ -308,6 +366,7 @@ function M.check_for_closed_prompt()
local lines = vim.api.nvim_buf_get_lines(bufnr, line - 1, line, false)
if #lines == 0 then
is_processing = false
return
end
@@ -323,6 +382,7 @@ function M.check_for_closed_prompt()
-- Check if already processed
if processed_prompts[prompt_key] then
is_processing = false
return
end
@@ -366,23 +426,38 @@ function M.check_for_closed_prompt()
-- Clean prompt content (strip file references)
local cleaned = parser.clean_prompt(parser.strip_file_references(prompt.content))
-- Resolve scope in target file FIRST (need it to adjust intent)
local target_bufnr = vim.fn.bufnr(target_path)
if target_bufnr == -1 then
target_bufnr = bufnr
end
-- Check if we're working from a coder file
local is_from_coder_file = utils.is_coder_file(current_file)
-- Resolve scope in target file FIRST (need it to adjust intent)
-- Only resolve scope if NOT from coder file (line numbers don't apply)
local target_bufnr = vim.fn.bufnr(target_path)
local scope = nil
local scope_text = nil
local scope_range = nil
scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1)
if scope and scope.type ~= "file" then
scope_text = scope.text
scope_range = {
start_line = scope.range.start_row,
end_line = scope.range.end_row,
}
if not is_from_coder_file then
-- Prompt is in the actual source file, use line position for scope
if target_bufnr == -1 then
target_bufnr = bufnr
end
scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1)
if scope and scope.type ~= "file" then
scope_text = scope.text
scope_range = {
start_line = scope.range.start_row,
end_line = scope.range.end_row,
}
end
else
-- Prompt is in coder file - load target if needed, but don't use scope
-- Code from coder files should append to target by default
if target_bufnr == -1 then
target_bufnr = vim.fn.bufadd(target_path)
if target_bufnr ~= 0 then
vim.fn.bufload(target_bufnr)
end
end
end
-- Detect intent from prompt
@@ -390,7 +465,8 @@ function M.check_for_closed_prompt()
-- IMPORTANT: If prompt is inside a function/method and intent is "add",
-- override to "complete" since we're completing the function body
if scope and (scope.type == "function" or scope.type == "method") then
-- But NOT for coder files - they should use "add/append" by default
if not is_from_coder_file and scope and (scope.type == "function" or scope.type == "method") then
if intent.type == "add" or intent.action == "insert" or intent.action == "append" then
-- Override to complete the function instead of adding new code
intent = {
@@ -403,6 +479,16 @@ function M.check_for_closed_prompt()
end
end
-- For coder files, default to "add" with "append" action
if is_from_coder_file and (intent.action == "replace" or intent.type == "complete") then
intent = {
type = intent.type == "complete" and "add" or intent.type,
confidence = intent.confidence,
action = "append",
keywords = intent.keywords,
}
end
-- Determine priority based on intent
local priority = 2 -- Normal
if intent.type == "fix" or intent.type == "complete" then
@@ -446,6 +532,7 @@ function M.check_for_closed_prompt()
end
end
end
is_processing = false
end
--- Check and process all closed prompts in the buffer (works on ANY file)
@@ -507,7 +594,8 @@ function M.check_all_prompts()
-- Get target path - for coder files, get the target; for regular files, use self
local target_path
if utils.is_coder_file(current_file) then
local is_from_coder_file = utils.is_coder_file(current_file)
if is_from_coder_file then
target_path = utils.get_target_path(current_file)
else
target_path = current_file
@@ -520,22 +608,33 @@ function M.check_all_prompts()
local cleaned = parser.clean_prompt(parser.strip_file_references(prompt.content))
-- Resolve scope in target file FIRST (need it to adjust intent)
-- Only resolve scope if NOT from coder file (line numbers don't apply)
local target_bufnr = vim.fn.bufnr(target_path)
if target_bufnr == -1 then
target_bufnr = bufnr -- Use current buffer if target not loaded
end
local scope = nil
local scope_text = nil
local scope_range = nil
scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1)
if scope and scope.type ~= "file" then
scope_text = scope.text
scope_range = {
start_line = scope.range.start_row,
end_line = scope.range.end_row,
}
if not is_from_coder_file then
-- Prompt is in the actual source file, use line position for scope
if target_bufnr == -1 then
target_bufnr = bufnr
end
scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1)
if scope and scope.type ~= "file" then
scope_text = scope.text
scope_range = {
start_line = scope.range.start_row,
end_line = scope.range.end_row,
}
end
else
-- Prompt is in coder file - load target if needed
if target_bufnr == -1 then
target_bufnr = vim.fn.bufadd(target_path)
if target_bufnr ~= 0 then
vim.fn.bufload(target_bufnr)
end
end
end
-- Detect intent from prompt
@@ -543,7 +642,8 @@ function M.check_all_prompts()
-- IMPORTANT: If prompt is inside a function/method and intent is "add",
-- override to "complete" since we're completing the function body
if scope and (scope.type == "function" or scope.type == "method") then
-- But NOT for coder files - they should use "add/append" by default
if not is_from_coder_file and scope and (scope.type == "function" or scope.type == "method") then
if intent.type == "add" or intent.action == "insert" or intent.action == "append" then
-- Override to complete the function instead of adding new code
intent = {
@@ -556,6 +656,16 @@ function M.check_all_prompts()
end
end
-- For coder files, default to "add" with "append" action
if is_from_coder_file and (intent.action == "replace" or intent.type == "complete") then
intent = {
type = intent.type == "complete" and "add" or intent.type,
confidence = intent.confidence,
action = "append",
keywords = intent.keywords,
}
end
-- Determine priority based on intent
local priority = 2
if intent.type == "fix" or intent.type == "complete" then
@@ -932,20 +1042,17 @@ function M.update_brain_from_file(filepath)
local summary = vim.fn.fnamemodify(filepath, ":t") .. " - " .. table.concat(parts, "; ")
-- Learn this pattern
-- Learn this pattern - use "pattern_detected" type to match the pattern learner
brain.learn({
type = "pattern",
type = "pattern_detected",
file = filepath,
content = {
summary = summary,
detail = #functions .. " functions, " .. #classes .. " classes",
code = nil,
},
context = {
file = filepath,
timestamp = os.time(),
data = {
name = summary,
description = #functions .. " functions, " .. #classes .. " classes",
language = ext,
functions = functions,
classes = classes,
symbols = vim.tbl_map(function(f) return f.name end, functions),
example = nil,
},
})
end
@@ -997,6 +1104,126 @@ end
--- Auto-index a file by creating/opening its coder companion
---@param bufnr number Buffer number
--- Directories to ignore for coder file creation
local ignored_directories = {
".git",
".coder",
".claude",
".vscode",
".idea",
"node_modules",
"vendor",
"dist",
"build",
"target",
"__pycache__",
".cache",
".npm",
".yarn",
"coverage",
".next",
".nuxt",
".svelte-kit",
"out",
"bin",
"obj",
}
--- Files to ignore for coder file creation (exact names or patterns)
local ignored_files = {
-- Git files
".gitignore",
".gitattributes",
".gitmodules",
-- Lock files
"package-lock.json",
"yarn.lock",
"pnpm-lock.yaml",
"Cargo.lock",
"Gemfile.lock",
"poetry.lock",
"composer.lock",
-- Config files that don't need coder companions
".env",
".env.local",
".env.development",
".env.production",
".eslintrc",
".eslintrc.json",
".prettierrc",
".prettierrc.json",
".editorconfig",
".dockerignore",
"Dockerfile",
"docker-compose.yml",
"docker-compose.yaml",
".npmrc",
".yarnrc",
".nvmrc",
"tsconfig.json",
"jsconfig.json",
"babel.config.js",
"webpack.config.js",
"vite.config.js",
"rollup.config.js",
"jest.config.js",
"vitest.config.js",
".stylelintrc",
"tailwind.config.js",
"postcss.config.js",
-- Other non-code files
"README.md",
"CHANGELOG.md",
"LICENSE",
"LICENSE.md",
"CONTRIBUTING.md",
"Makefile",
"CMakeLists.txt",
}
--- Check if a file path contains an ignored directory
---@param filepath string Full file path
---@return boolean
local function is_in_ignored_directory(filepath)
for _, dir in ipairs(ignored_directories) do
-- Check for /dirname/ or /dirname at end
if filepath:match("/" .. dir .. "/") or filepath:match("/" .. dir .. "$") then
return true
end
-- Also check for dirname/ at start (relative paths)
if filepath:match("^" .. dir .. "/") then
return true
end
end
return false
end
--- Check if a file should be ignored for coder companion creation
---@param filepath string Full file path
---@return boolean
local function should_ignore_for_coder(filepath)
local filename = vim.fn.fnamemodify(filepath, ":t")
-- Check exact filename matches
for _, ignored in ipairs(ignored_files) do
if filename == ignored then
return true
end
end
-- Check if file starts with dot (hidden/config files)
if filename:match("^%.") then
return true
end
-- Check if in ignored directory
if is_in_ignored_directory(filepath) then
return true
end
return false
end
function M.auto_index_file(bufnr)
-- Skip if buffer is invalid
if not vim.api.nvim_buf_is_valid(bufnr) then
@@ -1031,6 +1258,11 @@ function M.auto_index_file(bufnr)
return
end
-- Skip ignored directories and files (node_modules, .git, config files, etc.)
if should_ignore_for_coder(filepath) then
return
end
-- Skip if auto_index is disabled in config
local codetyper = require("codetyper")
local config = codetyper.get_config()
@@ -1047,23 +1279,137 @@ function M.auto_index_file(bufnr)
-- Check if coder file already exists
local coder_exists = utils.file_exists(coder_path)
-- Create coder file with template if it doesn't exist
-- Create coder file with pseudo-code context if it doesn't exist
if not coder_exists then
local filename = vim.fn.fnamemodify(filepath, ":t")
local template = string.format(
[[-- Coder companion for %s
-- Use /@ @/ tags to write pseudo-code prompts
-- Example:
-- /@
-- Add a function that validates user input
-- - Check for empty strings
-- - Validate email format
-- @/
local ext = vim.fn.fnamemodify(filepath, ":e")
]],
filename
)
utils.write_file(coder_path, template)
-- Determine comment style based on extension
local comment_prefix = "--"
local comment_block_start = "--[["
local comment_block_end = "]]"
if ext == "ts" or ext == "tsx" or ext == "js" or ext == "jsx" or ext == "java" or ext == "c" or ext == "cpp" or ext == "cs" or ext == "go" or ext == "rs" then
comment_prefix = "//"
comment_block_start = "/*"
comment_block_end = "*/"
elseif ext == "py" or ext == "rb" or ext == "yaml" or ext == "yml" then
comment_prefix = "#"
comment_block_start = '"""'
comment_block_end = '"""'
end
-- Read target file to analyze its structure
local content = ""
pcall(function()
local lines = vim.fn.readfile(filepath)
if lines then
content = table.concat(lines, "\n")
end
end)
-- Extract structure from the file
local functions = extract_functions(content, ext)
local classes = extract_classes(content, ext)
local imports = extract_imports(content, ext)
-- Build pseudo-code context
local pseudo_code = {}
-- Header
table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════")
table.insert(pseudo_code, comment_prefix .. " CODER COMPANION: " .. filename)
table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════")
table.insert(pseudo_code, comment_prefix .. " This file describes the business logic and behavior of " .. filename)
table.insert(pseudo_code, comment_prefix .. " Edit this pseudo-code to guide code generation.")
table.insert(pseudo_code, comment_prefix .. " Use /@ @/ tags for specific generation requests.")
table.insert(pseudo_code, comment_prefix .. "")
-- Module purpose
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
table.insert(pseudo_code, comment_prefix .. " MODULE PURPOSE:")
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
table.insert(pseudo_code, comment_prefix .. " TODO: Describe what this module/file is responsible for")
table.insert(pseudo_code, comment_prefix .. " Example: \"Handles user authentication and session management\"")
table.insert(pseudo_code, comment_prefix .. "")
-- Dependencies section
if #imports > 0 then
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
table.insert(pseudo_code, comment_prefix .. " DEPENDENCIES:")
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
for _, imp in ipairs(imports) do
table.insert(pseudo_code, comment_prefix .. "" .. imp)
end
table.insert(pseudo_code, comment_prefix .. "")
end
-- Classes section
if #classes > 0 then
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
table.insert(pseudo_code, comment_prefix .. " CLASSES:")
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
for _, class in ipairs(classes) do
table.insert(pseudo_code, comment_prefix .. "")
table.insert(pseudo_code, comment_prefix .. " class " .. class.name .. ":")
table.insert(pseudo_code, comment_prefix .. " PURPOSE: TODO - describe what this class represents")
table.insert(pseudo_code, comment_prefix .. " RESPONSIBILITIES:")
table.insert(pseudo_code, comment_prefix .. " - TODO: list main responsibilities")
end
table.insert(pseudo_code, comment_prefix .. "")
end
-- Functions section
if #functions > 0 then
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
table.insert(pseudo_code, comment_prefix .. " FUNCTIONS:")
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
for _, func in ipairs(functions) do
table.insert(pseudo_code, comment_prefix .. "")
table.insert(pseudo_code, comment_prefix .. " " .. func.name .. "():")
table.insert(pseudo_code, comment_prefix .. " PURPOSE: TODO - what does this function do?")
table.insert(pseudo_code, comment_prefix .. " INPUTS: TODO - describe parameters")
table.insert(pseudo_code, comment_prefix .. " OUTPUTS: TODO - describe return value")
table.insert(pseudo_code, comment_prefix .. " BEHAVIOR:")
table.insert(pseudo_code, comment_prefix .. " - TODO: describe step-by-step logic")
end
table.insert(pseudo_code, comment_prefix .. "")
end
-- If empty file, provide starter template
if #functions == 0 and #classes == 0 then
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
table.insert(pseudo_code, comment_prefix .. " PLANNED STRUCTURE:")
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
table.insert(pseudo_code, comment_prefix .. " TODO: Describe what you want to build in this file")
table.insert(pseudo_code, comment_prefix .. "")
table.insert(pseudo_code, comment_prefix .. " Example pseudo-code:")
table.insert(pseudo_code, comment_prefix .. " /@")
table.insert(pseudo_code, comment_prefix .. " Create a module that:")
table.insert(pseudo_code, comment_prefix .. " 1. Exports a main function")
table.insert(pseudo_code, comment_prefix .. " 2. Handles errors gracefully")
table.insert(pseudo_code, comment_prefix .. " 3. Returns structured data")
table.insert(pseudo_code, comment_prefix .. " @/")
table.insert(pseudo_code, comment_prefix .. "")
end
-- Business rules section
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
table.insert(pseudo_code, comment_prefix .. " BUSINESS RULES:")
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
table.insert(pseudo_code, comment_prefix .. " TODO: Document any business rules, constraints, or requirements")
table.insert(pseudo_code, comment_prefix .. " Example:")
table.insert(pseudo_code, comment_prefix .. " - Users must be authenticated before accessing this feature")
table.insert(pseudo_code, comment_prefix .. " - Data must be validated before saving")
table.insert(pseudo_code, comment_prefix .. " - Errors should be logged but not exposed to users")
table.insert(pseudo_code, comment_prefix .. "")
-- Footer with generation tags example
table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════")
table.insert(pseudo_code, comment_prefix .. " Use /@ @/ tags below to request code generation:")
table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════")
table.insert(pseudo_code, "")
utils.write_file(coder_path, table.concat(pseudo_code, "\n"))
end
-- Notify user about the coder companion

View File

@@ -111,7 +111,7 @@ function M.compute_relevance(node, opts)
return score
end
--- Traverse graph from seed nodes
--- Traverse graph from seed nodes (basic traversal)
---@param seed_ids string[] Starting node IDs
---@param depth number Traversal depth
---@param edge_types? EdgeType[] Edge types to follow
@@ -157,6 +157,73 @@ local function traverse(seed_ids, depth, edge_types)
return discovered
end
--- Spreading activation - mimics human associative memory
--- Activation spreads from seed nodes along edges, decaying by weight
--- Nodes accumulate activation from multiple paths (like neural pathways)
---@param seed_activations table<string, number> Initial activations {node_id: activation}
---@param max_iterations number Max spread iterations (default 3)
---@param decay number Activation decay per hop (default 0.5)
---@param threshold number Minimum activation to continue spreading (default 0.1)
---@return table<string, number> Final activations {node_id: accumulated_activation}
local function spreading_activation(seed_activations, max_iterations, decay, threshold)
local edge_mod = get_edge_module()
max_iterations = max_iterations or 3
decay = decay or 0.5
threshold = threshold or 0.1
-- Accumulated activation for each node
local activation = {}
for node_id, act in pairs(seed_activations) do
activation[node_id] = act
end
-- Current frontier with their activation levels
local frontier = {}
for node_id, act in pairs(seed_activations) do
frontier[node_id] = act
end
-- Spread activation iteratively
for _ = 1, max_iterations do
local next_frontier = {}
for source_id, source_activation in pairs(frontier) do
-- Get all outgoing edges
local edges = edge_mod.get_edges(source_id, nil, "both")
for _, edge in ipairs(edges) do
-- Determine target (could be source or target of edge)
local target_id = edge.s == source_id and edge.t or edge.s
-- Calculate spreading activation
-- Activation = source_activation * edge_weight * decay
local edge_weight = edge.p and edge.p.w or 0.5
local spread_amount = source_activation * edge_weight * decay
-- Only spread if above threshold
if spread_amount >= threshold then
-- Accumulate activation (multiple paths add up)
activation[target_id] = (activation[target_id] or 0) + spread_amount
-- Add to next frontier if not already processed with higher activation
if not next_frontier[target_id] or next_frontier[target_id] < spread_amount then
next_frontier[target_id] = spread_amount
end
end
end
end
-- Stop if no more spreading
if vim.tbl_count(next_frontier) == 0 then
break
end
frontier = next_frontier
end
return activation
end
--- Execute a query across all dimensions
---@param opts QueryOpts Query options
---@return QueryResult
@@ -236,28 +303,49 @@ function M.execute(opts)
end
end
-- 4. Combine and deduplicate
-- 4. Combine all found nodes and compute seed activations
local all_nodes = {}
local seed_activations = {}
for _, category in pairs(results) do
for id, node in pairs(category) do
if not all_nodes[id] then
all_nodes[id] = node
-- Compute initial activation based on relevance
local relevance = M.compute_relevance(node, opts)
seed_activations[id] = relevance
end
end
end
-- 5. Score and rank
-- 5. Apply spreading activation - like human associative memory
-- Activation spreads from seed nodes along edges, accumulating
-- Nodes connected to multiple relevant seeds get higher activation
local final_activations = spreading_activation(
seed_activations,
opts.spread_iterations or 3, -- How far activation spreads
opts.spread_decay or 0.5, -- How much activation decays per hop
opts.spread_threshold or 0.05 -- Minimum activation to continue spreading
)
-- 6. Score and rank by combined activation
local scored = {}
for id, node in pairs(all_nodes) do
local relevance = M.compute_relevance(node, opts)
table.insert(scored, { node = node, relevance = relevance })
for id, activation in pairs(final_activations) do
local node = all_nodes[id] or node_mod.get(id)
if node then
all_nodes[id] = node
-- Final score = spreading activation + base relevance
local base_relevance = M.compute_relevance(node, opts)
local final_score = (activation * 0.6) + (base_relevance * 0.4)
table.insert(scored, { node = node, relevance = final_score, activation = activation })
end
end
table.sort(scored, function(a, b)
return a.relevance > b.relevance
end)
-- 6. Apply limit
-- 7. Apply limit
local limit = opts.limit or 50
local result_nodes = {}
local truncated = #scored > limit
@@ -266,7 +354,7 @@ function M.execute(opts)
table.insert(result_nodes, scored[i].node)
end
-- 7. Get edges between result nodes
-- 8. Get edges between result nodes
local edge_mod = get_edge_module()
local result_edges = {}
local node_ids = {}
@@ -291,11 +379,17 @@ function M.execute(opts)
file_count = vim.tbl_count(results.file),
temporal_count = vim.tbl_count(results.temporal),
total_scored = #scored,
seed_nodes = vim.tbl_count(seed_activations),
activated_nodes = vim.tbl_count(final_activations),
},
truncated = truncated,
}
end
--- Expose spreading activation for direct use
--- Useful for custom activation patterns or debugging
M.spreading_activation = spreading_activation
--- Find nodes by file
---@param filepath string File path
---@param limit? number Max results

View File

@@ -9,6 +9,10 @@ local M = {}
---@param event LearnEvent Learning event
---@return boolean
function M.detect(event)
if not event or not event.type then
return false
end
local valid_types = {
"code_completion",
"file_indexed",

View File

@@ -287,6 +287,118 @@ local function cmd_agent_stop()
end
end
--- Run the agentic loop with a task
---@param task string The task to accomplish
---@param agent_name? string Optional agent name
local function cmd_agentic_run(task, agent_name)
local agentic = require("codetyper.agent.agentic")
local logs_panel = require("codetyper.logs_panel")
local logs = require("codetyper.agent.logs")
-- Open logs panel
logs_panel.open()
logs.info("Starting agentic task: " .. task:sub(1, 50) .. "...")
utils.notify("Running agentic task...", vim.log.levels.INFO)
-- Get current file for context
local current_file = vim.fn.expand("%:p")
local files = {}
if current_file ~= "" then
table.insert(files, current_file)
end
agentic.run({
task = task,
files = files,
agent = agent_name or "coder",
on_status = function(status)
logs.thinking(status)
end,
on_tool_start = function(name, args)
logs.info("Tool: " .. name)
end,
on_tool_end = function(name, result, err)
if err then
logs.error(name .. " failed: " .. err)
else
logs.debug(name .. " completed")
end
end,
on_file_change = function(path, action)
logs.info("File " .. action .. ": " .. path)
end,
on_message = function(msg)
if msg.role == "assistant" and type(msg.content) == "string" and msg.content ~= "" then
logs.thinking(msg.content:sub(1, 100) .. "...")
end
end,
on_complete = function(result, err)
if err then
logs.error("Task failed: " .. err)
utils.notify("Agentic task failed: " .. err, vim.log.levels.ERROR)
else
logs.info("Task completed successfully")
utils.notify("Agentic task completed!", vim.log.levels.INFO)
if result and result ~= "" then
-- Show summary in a float
vim.schedule(function()
vim.notify("Result:\n" .. result:sub(1, 500), vim.log.levels.INFO)
end)
end
end
end,
})
end
--- List available agents
local function cmd_agentic_list()
local agentic = require("codetyper.agent.agentic")
local agents = agentic.list_agents()
local lines = {
"Available Agents",
"================",
"",
}
for _, agent in ipairs(agents) do
local badge = agent.builtin and "[builtin]" or "[custom]"
table.insert(lines, string.format(" %s %s", agent.name, badge))
table.insert(lines, string.format(" %s", agent.description))
table.insert(lines, "")
end
table.insert(lines, "Use :CoderAgenticRun <task> [agent] to run a task")
table.insert(lines, "Use :CoderAgenticInit to create custom agents")
utils.notify(table.concat(lines, "\n"))
end
--- Initialize .coder/agents/ and .coder/rules/ directories
local function cmd_agentic_init()
local agentic = require("codetyper.agent.agentic")
agentic.init()
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
local lines = {
"Initialized Coder directories:",
"",
" " .. agents_dir,
" - example.md (template for custom agents)",
"",
" " .. rules_dir,
" - code-style.md (template for project rules)",
"",
"Edit these files to customize agent behavior.",
"Create new .md files to add more agents/rules.",
}
utils.notify(table.concat(lines, "\n"))
end
--- Show chat type switcher modal (Ask/Agent)
local function cmd_type_toggle()
local switcher = require("codetyper.chat_switcher")
@@ -844,6 +956,65 @@ end
--- Main command dispatcher
---@param args table Command arguments
--- Show LLM accuracy statistics
local function cmd_llm_stats()
local llm = require("codetyper.llm")
local stats = llm.get_accuracy_stats()
local lines = {
"LLM Provider Accuracy Statistics",
"================================",
"",
string.format("Ollama:"),
string.format(" Total requests: %d", stats.ollama.total),
string.format(" Correct: %d", stats.ollama.correct),
string.format(" Accuracy: %.1f%%", stats.ollama.accuracy * 100),
"",
string.format("Copilot:"),
string.format(" Total requests: %d", stats.copilot.total),
string.format(" Correct: %d", stats.copilot.correct),
string.format(" Accuracy: %.1f%%", stats.copilot.accuracy * 100),
"",
"Note: Smart selection prefers Ollama when brain memories",
"provide enough context. Accuracy improves over time via",
"pondering (verification with other LLMs).",
}
vim.notify(table.concat(lines, "\n"), vim.log.levels.INFO)
end
--- Report feedback on last LLM response
---@param was_good boolean Whether the response was good
local function cmd_llm_feedback(was_good)
local llm = require("codetyper.llm")
-- Get the last used provider from logs or default
local provider = "ollama" -- Default assumption
-- Try to get actual last provider from logs
pcall(function()
local logs = require("codetyper.agent.logs")
local entries = logs.get(10)
for i = #entries, 1, -1 do
local entry = entries[i]
if entry.message and entry.message:match("^LLM:") then
provider = entry.message:match("LLM: (%w+)") or provider
break
end
end
end)
llm.report_feedback(provider, was_good)
local feedback_type = was_good and "positive" or "negative"
utils.notify(string.format("Reported %s feedback for %s", feedback_type, provider), vim.log.levels.INFO)
end
--- Reset LLM accuracy statistics
local function cmd_llm_reset_stats()
local selector = require("codetyper.llm.selector")
selector.reset_accuracy_stats()
utils.notify("LLM accuracy statistics reset", vim.log.levels.INFO)
end
local function coder_cmd(args)
local subcommand = args.fargs[1] or "toggle"
@@ -872,6 +1043,17 @@ local function coder_cmd(args)
["logs-toggle"] = cmd_logs_toggle,
["queue-status"] = cmd_queue_status,
["queue-process"] = cmd_queue_process,
-- Agentic commands
["agentic-run"] = function(args)
local task = table.concat(vim.list_slice(args.fargs, 2), " ")
if task == "" then
utils.notify("Usage: Coder agentic-run <task> [agent]", vim.log.levels.WARN)
return
end
cmd_agentic_run(task)
end,
["agentic-list"] = cmd_agentic_list,
["agentic-init"] = cmd_agentic_init,
["index-project"] = cmd_index_project,
["index-status"] = cmd_index_status,
memories = cmd_memories,
@@ -901,6 +1083,41 @@ local function coder_cmd(args)
end
end
end,
-- LLM smart selection commands
["llm-stats"] = cmd_llm_stats,
["llm-feedback-good"] = function()
cmd_llm_feedback(true)
end,
["llm-feedback-bad"] = function()
cmd_llm_feedback(false)
end,
["llm-reset-stats"] = cmd_llm_reset_stats,
-- Cost tracking commands
["cost"] = function()
local cost = require("codetyper.cost")
cost.toggle()
end,
["cost-clear"] = function()
local cost = require("codetyper.cost")
cost.clear()
end,
-- Credentials management commands
["add-api-key"] = function()
local credentials = require("codetyper.credentials")
credentials.interactive_add()
end,
["remove-api-key"] = function()
local credentials = require("codetyper.credentials")
credentials.interactive_remove()
end,
["credentials"] = function()
local credentials = require("codetyper.credentials")
credentials.show_status()
end,
["switch-provider"] = function()
local credentials = require("codetyper.credentials")
credentials.interactive_switch_provider()
end,
}
local cmd_fn = commands[subcommand]
@@ -922,10 +1139,14 @@ function M.setup()
"ask", "ask-close", "ask-toggle", "ask-clear",
"transform", "transform-cursor",
"agent", "agent-close", "agent-toggle", "agent-stop",
"agentic-run", "agentic-list", "agentic-init",
"type-toggle", "logs-toggle",
"queue-status", "queue-process",
"index-project", "index-status", "memories", "forget",
"auto-toggle", "auto-set",
"llm-stats", "llm-feedback-good", "llm-feedback-bad", "llm-reset-stats",
"cost", "cost-clear",
"add-api-key", "remove-api-key", "credentials", "switch-provider",
}
end,
desc = "Codetyper.nvim commands",
@@ -997,6 +1218,31 @@ function M.setup()
cmd_agent_stop()
end, { desc = "Stop running agent" })
-- Agentic commands (full IDE-like agent functionality)
vim.api.nvim_create_user_command("CoderAgenticRun", function(opts)
local task = opts.args
if task == "" then
vim.ui.input({ prompt = "Task: " }, function(input)
if input and input ~= "" then
cmd_agentic_run(input)
end
end)
else
cmd_agentic_run(task)
end
end, {
desc = "Run agentic task (IDE-like multi-file changes)",
nargs = "*",
})
vim.api.nvim_create_user_command("CoderAgenticList", function()
cmd_agentic_list()
end, { desc = "List available agents" })
vim.api.nvim_create_user_command("CoderAgenticInit", function()
cmd_agentic_init()
end, { desc = "Initialize .coder/agents/ and .coder/rules/ directories" })
-- Chat type switcher command
vim.api.nvim_create_user_command("CoderType", function()
cmd_type_toggle()
@@ -1075,6 +1321,147 @@ function M.setup()
end,
})
-- Brain feedback command - teach the brain from your experience
vim.api.nvim_create_user_command("CoderFeedback", function(opts)
local brain = require("codetyper.brain")
if not brain.is_initialized() then
vim.notify("Brain not initialized", vim.log.levels.WARN)
return
end
local feedback_type = opts.args:lower()
local current_file = vim.fn.expand("%:p")
if feedback_type == "good" or feedback_type == "accept" or feedback_type == "+" then
-- Learn positive feedback
brain.learn({
type = "user_feedback",
file = current_file,
timestamp = os.time(),
data = {
feedback = "accepted",
description = "User marked code as good/accepted",
},
})
vim.notify("Brain: Learned positive feedback ✓", vim.log.levels.INFO)
elseif feedback_type == "bad" or feedback_type == "reject" or feedback_type == "-" then
-- Learn negative feedback
brain.learn({
type = "user_feedback",
file = current_file,
timestamp = os.time(),
data = {
feedback = "rejected",
description = "User marked code as bad/rejected",
},
})
vim.notify("Brain: Learned negative feedback ✗", vim.log.levels.INFO)
elseif feedback_type == "stats" or feedback_type == "status" then
-- Show brain stats
local stats = brain.stats()
local msg = string.format(
"Brain Stats:\n• Nodes: %d\n• Edges: %d\n• Pending: %d\n• Deltas: %d",
stats.node_count or 0,
stats.edge_count or 0,
stats.pending_changes or 0,
stats.delta_count or 0
)
vim.notify(msg, vim.log.levels.INFO)
else
vim.notify("Usage: CoderFeedback <good|bad|stats>", vim.log.levels.INFO)
end
end, {
desc = "Give feedback to the brain (good/bad/stats)",
nargs = "?",
complete = function()
return { "good", "bad", "stats" }
end,
})
-- Brain stats command
vim.api.nvim_create_user_command("CoderBrain", function(opts)
local brain = require("codetyper.brain")
if not brain.is_initialized() then
vim.notify("Brain not initialized", vim.log.levels.WARN)
return
end
local action = opts.args:lower()
if action == "stats" or action == "" then
local stats = brain.stats()
local lines = {
"╭─────────────────────────────────╮",
"│ CODETYPER BRAIN │",
"╰─────────────────────────────────╯",
"",
string.format(" Nodes: %d", stats.node_count or 0),
string.format(" Edges: %d", stats.edge_count or 0),
string.format(" Deltas: %d", stats.delta_count or 0),
string.format(" Pending: %d", stats.pending_changes or 0),
"",
" The more you use Codetyper,",
" the smarter it becomes!",
}
vim.notify(table.concat(lines, "\n"), vim.log.levels.INFO)
elseif action == "commit" then
local hash = brain.commit("Manual commit")
if hash then
vim.notify("Brain: Committed changes (hash: " .. hash:sub(1, 8) .. ")", vim.log.levels.INFO)
else
vim.notify("Brain: Nothing to commit", vim.log.levels.INFO)
end
elseif action == "flush" then
brain.flush()
vim.notify("Brain: Flushed to disk", vim.log.levels.INFO)
elseif action == "prune" then
local pruned = brain.prune()
vim.notify("Brain: Pruned " .. pruned .. " low-value nodes", vim.log.levels.INFO)
else
vim.notify("Usage: CoderBrain <stats|commit|flush|prune>", vim.log.levels.INFO)
end
end, {
desc = "Brain management commands",
nargs = "?",
complete = function()
return { "stats", "commit", "flush", "prune" }
end,
})
-- Cost estimation command
vim.api.nvim_create_user_command("CoderCost", function()
local cost = require("codetyper.cost")
cost.toggle()
end, { desc = "Show LLM cost estimation window" })
-- Credentials management commands
vim.api.nvim_create_user_command("CoderAddApiKey", function()
local credentials = require("codetyper.credentials")
credentials.interactive_add()
end, { desc = "Add or update LLM provider API key" })
vim.api.nvim_create_user_command("CoderRemoveApiKey", function()
local credentials = require("codetyper.credentials")
credentials.interactive_remove()
end, { desc = "Remove LLM provider credentials" })
vim.api.nvim_create_user_command("CoderCredentials", function()
local credentials = require("codetyper.credentials")
credentials.show_status()
end, { desc = "Show credentials status" })
vim.api.nvim_create_user_command("CoderSwitchProvider", function()
local credentials = require("codetyper.credentials")
credentials.interactive_switch_provider()
end, { desc = "Switch active LLM provider" })
-- Setup default keymaps
M.setup_keymaps()
end

750
lua/codetyper/cost.lua Normal file
View File

@@ -0,0 +1,750 @@
---@mod codetyper.cost Cost estimation for LLM usage
---@brief [[
--- Tracks token usage and estimates costs based on model pricing.
--- Prices are per 1M tokens. Persists usage data in the brain.
---@brief ]]
local M = {}
local utils = require("codetyper.utils")
--- Cost history file name
local COST_HISTORY_FILE = "cost_history.json"
--- Get path to cost history file
---@return string File path
local function get_history_path()
local root = utils.get_project_root()
return root .. "/.coder/" .. COST_HISTORY_FILE
end
--- Default model for savings comparison (what you'd pay if not using Ollama)
M.comparison_model = "gpt-4o"
--- Models considered "free" (Ollama, local, Copilot subscription)
M.free_models = {
["ollama"] = true,
["codellama"] = true,
["llama2"] = true,
["llama3"] = true,
["mistral"] = true,
["deepseek-coder"] = true,
["copilot"] = true,
}
--- Model pricing table (per 1M tokens in USD)
---@type table<string, {input: number, cached_input: number|nil, output: number|nil}>
M.pricing = {
-- GPT-5.x series
["gpt-5.2"] = { input = 1.75, cached_input = 0.175, output = 14.00 },
["gpt-5.1"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
["gpt-5"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
["gpt-5-mini"] = { input = 0.25, cached_input = 0.025, output = 2.00 },
["gpt-5-nano"] = { input = 0.05, cached_input = 0.005, output = 0.40 },
["gpt-5.2-chat-latest"] = { input = 1.75, cached_input = 0.175, output = 14.00 },
["gpt-5.1-chat-latest"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
["gpt-5-chat-latest"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
["gpt-5.2-codex"] = { input = 1.75, cached_input = 0.175, output = 14.00 },
["gpt-5.1-codex-max"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
["gpt-5.1-codex"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
["gpt-5-codex"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
["gpt-5.2-pro"] = { input = 21.00, cached_input = nil, output = 168.00 },
["gpt-5-pro"] = { input = 15.00, cached_input = nil, output = 120.00 },
["gpt-5.1-codex-mini"] = { input = 0.25, cached_input = 0.025, output = 2.00 },
["gpt-5-search-api"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
-- GPT-4.x series
["gpt-4.1"] = { input = 2.00, cached_input = 0.50, output = 8.00 },
["gpt-4.1-mini"] = { input = 0.40, cached_input = 0.10, output = 1.60 },
["gpt-4.1-nano"] = { input = 0.10, cached_input = 0.025, output = 0.40 },
["gpt-4o"] = { input = 2.50, cached_input = 1.25, output = 10.00 },
["gpt-4o-2024-05-13"] = { input = 5.00, cached_input = nil, output = 15.00 },
["gpt-4o-mini"] = { input = 0.15, cached_input = 0.075, output = 0.60 },
-- Realtime models
["gpt-realtime"] = { input = 4.00, cached_input = 0.40, output = 16.00 },
["gpt-realtime-mini"] = { input = 0.60, cached_input = 0.06, output = 2.40 },
["gpt-4o-realtime-preview"] = { input = 5.00, cached_input = 2.50, output = 20.00 },
["gpt-4o-mini-realtime-preview"] = { input = 0.60, cached_input = 0.30, output = 2.40 },
-- Audio models
["gpt-audio"] = { input = 2.50, cached_input = nil, output = 10.00 },
["gpt-audio-mini"] = { input = 0.60, cached_input = nil, output = 2.40 },
["gpt-4o-audio-preview"] = { input = 2.50, cached_input = nil, output = 10.00 },
["gpt-4o-mini-audio-preview"] = { input = 0.15, cached_input = nil, output = 0.60 },
-- O-series reasoning models
["o1"] = { input = 15.00, cached_input = 7.50, output = 60.00 },
["o1-pro"] = { input = 150.00, cached_input = nil, output = 600.00 },
["o3-pro"] = { input = 20.00, cached_input = nil, output = 80.00 },
["o3"] = { input = 2.00, cached_input = 0.50, output = 8.00 },
["o3-deep-research"] = { input = 10.00, cached_input = 2.50, output = 40.00 },
["o4-mini"] = { input = 1.10, cached_input = 0.275, output = 4.40 },
["o4-mini-deep-research"] = { input = 2.00, cached_input = 0.50, output = 8.00 },
["o3-mini"] = { input = 1.10, cached_input = 0.55, output = 4.40 },
["o1-mini"] = { input = 1.10, cached_input = 0.55, output = 4.40 },
-- Codex
["codex-mini-latest"] = { input = 1.50, cached_input = 0.375, output = 6.00 },
-- Search models
["gpt-4o-mini-search-preview"] = { input = 0.15, cached_input = nil, output = 0.60 },
["gpt-4o-search-preview"] = { input = 2.50, cached_input = nil, output = 10.00 },
-- Computer use
["computer-use-preview"] = { input = 3.00, cached_input = nil, output = 12.00 },
-- Image models
["gpt-image-1.5"] = { input = 5.00, cached_input = 1.25, output = 10.00 },
["chatgpt-image-latest"] = { input = 5.00, cached_input = 1.25, output = 10.00 },
["gpt-image-1"] = { input = 5.00, cached_input = 1.25, output = nil },
["gpt-image-1-mini"] = { input = 2.00, cached_input = 0.20, output = nil },
-- Claude models (Anthropic)
["claude-3-opus"] = { input = 15.00, cached_input = 7.50, output = 75.00 },
["claude-3-sonnet"] = { input = 3.00, cached_input = 1.50, output = 15.00 },
["claude-3-haiku"] = { input = 0.25, cached_input = 0.125, output = 1.25 },
["claude-3.5-sonnet"] = { input = 3.00, cached_input = 1.50, output = 15.00 },
["claude-3.5-haiku"] = { input = 0.80, cached_input = 0.40, output = 4.00 },
-- Ollama/Local models (free)
["ollama"] = { input = 0, cached_input = 0, output = 0 },
["codellama"] = { input = 0, cached_input = 0, output = 0 },
["llama2"] = { input = 0, cached_input = 0, output = 0 },
["llama3"] = { input = 0, cached_input = 0, output = 0 },
["mistral"] = { input = 0, cached_input = 0, output = 0 },
["deepseek-coder"] = { input = 0, cached_input = 0, output = 0 },
-- Copilot (included in subscription, but tracking usage)
["copilot"] = { input = 0, cached_input = 0, output = 0 },
}
---@class CostUsage
---@field model string Model name
---@field input_tokens number Input tokens used
---@field output_tokens number Output tokens used
---@field cached_tokens number Cached input tokens
---@field timestamp number Unix timestamp
---@field cost number Calculated cost in USD
---@class CostState
---@field usage CostUsage[] Current session usage
---@field all_usage CostUsage[] All historical usage from brain
---@field session_start number Session start timestamp
---@field win number|nil Window handle
---@field buf number|nil Buffer handle
---@field loaded boolean Whether historical data has been loaded
local state = {
usage = {},
all_usage = {},
session_start = os.time(),
win = nil,
buf = nil,
loaded = false,
}
--- Load historical usage from disk
function M.load_from_history()
if state.loaded then
return
end
local history_path = get_history_path()
local content = utils.read_file(history_path)
if content and content ~= "" then
local ok, data = pcall(vim.json.decode, content)
if ok and data and data.usage then
state.all_usage = data.usage
end
end
state.loaded = true
end
--- Save all usage to disk (debounced)
local save_timer = nil
local function save_to_disk()
-- Cancel existing timer
if save_timer then
save_timer:stop()
save_timer = nil
end
-- Debounce writes (500ms)
save_timer = vim.loop.new_timer()
save_timer:start(500, 0, vim.schedule_wrap(function()
local history_path = get_history_path()
-- Ensure directory exists
local dir = vim.fn.fnamemodify(history_path, ":h")
utils.ensure_dir(dir)
-- Merge session and historical usage
local all_data = vim.deepcopy(state.all_usage)
for _, usage in ipairs(state.usage) do
table.insert(all_data, usage)
end
-- Save to file
local data = {
version = 1,
updated = os.time(),
usage = all_data,
}
local ok, json = pcall(vim.json.encode, data)
if ok then
utils.write_file(history_path, json)
end
save_timer = nil
end))
end
--- Normalize model name for pricing lookup
---@param model string Model name from API
---@return string Normalized model name
local function normalize_model(model)
if not model then
return "unknown"
end
-- Convert to lowercase
local normalized = model:lower()
-- Handle Copilot models
if normalized:match("copilot") then
return "copilot"
end
-- Handle common prefixes
normalized = normalized:gsub("^openai/", "")
normalized = normalized:gsub("^anthropic/", "")
-- Try exact match first
if M.pricing[normalized] then
return normalized
end
-- Try partial matches
for price_model, _ in pairs(M.pricing) do
if normalized:match(price_model) or price_model:match(normalized) then
return price_model
end
end
return normalized
end
--- Check if a model is considered "free" (local/Ollama/Copilot subscription)
---@param model string Model name
---@return boolean True if free
function M.is_free_model(model)
local normalized = normalize_model(model)
-- Check direct match
if M.free_models[normalized] then
return true
end
-- Check if it's an Ollama model (any model with : in name like deepseek-coder:6.7b)
if model:match(":") then
return true
end
-- Check pricing - if cost is 0, it's free
local pricing = M.pricing[normalized]
if pricing and pricing.input == 0 and pricing.output == 0 then
return true
end
return false
end
--- Calculate cost for token usage
---@param model string Model name
---@param input_tokens number Input tokens
---@param output_tokens number Output tokens
---@param cached_tokens? number Cached input tokens
---@return number Cost in USD
function M.calculate_cost(model, input_tokens, output_tokens, cached_tokens)
local normalized = normalize_model(model)
local pricing = M.pricing[normalized]
if not pricing then
-- Unknown model, return 0
return 0
end
cached_tokens = cached_tokens or 0
local regular_input = input_tokens - cached_tokens
-- Calculate cost (prices are per 1M tokens)
local input_cost = (regular_input / 1000000) * (pricing.input or 0)
local cached_cost = (cached_tokens / 1000000) * (pricing.cached_input or pricing.input or 0)
local output_cost = (output_tokens / 1000000) * (pricing.output or 0)
return input_cost + cached_cost + output_cost
end
--- Calculate estimated savings (what would have been paid if using comparison model)
---@param input_tokens number Input tokens
---@param output_tokens number Output tokens
---@param cached_tokens? number Cached input tokens
---@return number Estimated savings in USD
function M.calculate_savings(input_tokens, output_tokens, cached_tokens)
-- Calculate what it would have cost with the comparison model
return M.calculate_cost(M.comparison_model, input_tokens, output_tokens, cached_tokens)
end
--- Record token usage
---@param model string Model name
---@param input_tokens number Input tokens
---@param output_tokens number Output tokens
---@param cached_tokens? number Cached input tokens
function M.record_usage(model, input_tokens, output_tokens, cached_tokens)
cached_tokens = cached_tokens or 0
local cost = M.calculate_cost(model, input_tokens, output_tokens, cached_tokens)
-- Calculate savings if using a free model
local savings = 0
if M.is_free_model(model) then
savings = M.calculate_savings(input_tokens, output_tokens, cached_tokens)
end
table.insert(state.usage, {
model = model,
input_tokens = input_tokens,
output_tokens = output_tokens,
cached_tokens = cached_tokens,
timestamp = os.time(),
cost = cost,
savings = savings,
is_free = M.is_free_model(model),
})
-- Save to disk (debounced)
save_to_disk()
-- Update window if open
if state.win and vim.api.nvim_win_is_valid(state.win) then
M.refresh_window()
end
end
--- Aggregate usage data into stats
---@param usage_list CostUsage[] List of usage records
---@return table Stats
local function aggregate_usage(usage_list)
local stats = {
total_input = 0,
total_output = 0,
total_cached = 0,
total_cost = 0,
total_savings = 0,
free_requests = 0,
paid_requests = 0,
by_model = {},
request_count = #usage_list,
}
for _, usage in ipairs(usage_list) do
stats.total_input = stats.total_input + (usage.input_tokens or 0)
stats.total_output = stats.total_output + (usage.output_tokens or 0)
stats.total_cached = stats.total_cached + (usage.cached_tokens or 0)
stats.total_cost = stats.total_cost + (usage.cost or 0)
-- Track savings
local usage_savings = usage.savings or 0
-- For historical data without savings field, calculate it
if usage_savings == 0 and usage.is_free == nil then
local model = usage.model or "unknown"
if M.is_free_model(model) then
usage_savings = M.calculate_savings(
usage.input_tokens or 0,
usage.output_tokens or 0,
usage.cached_tokens or 0
)
end
end
stats.total_savings = stats.total_savings + usage_savings
-- Track free vs paid
local is_free = usage.is_free
if is_free == nil then
is_free = M.is_free_model(usage.model or "unknown")
end
if is_free then
stats.free_requests = stats.free_requests + 1
else
stats.paid_requests = stats.paid_requests + 1
end
local model = usage.model or "unknown"
if not stats.by_model[model] then
stats.by_model[model] = {
input_tokens = 0,
output_tokens = 0,
cached_tokens = 0,
cost = 0,
savings = 0,
requests = 0,
is_free = is_free,
}
end
stats.by_model[model].input_tokens = stats.by_model[model].input_tokens + (usage.input_tokens or 0)
stats.by_model[model].output_tokens = stats.by_model[model].output_tokens + (usage.output_tokens or 0)
stats.by_model[model].cached_tokens = stats.by_model[model].cached_tokens + (usage.cached_tokens or 0)
stats.by_model[model].cost = stats.by_model[model].cost + (usage.cost or 0)
stats.by_model[model].savings = stats.by_model[model].savings + usage_savings
stats.by_model[model].requests = stats.by_model[model].requests + 1
end
return stats
end
--- Get session statistics
---@return table Statistics
function M.get_stats()
local stats = aggregate_usage(state.usage)
stats.session_duration = os.time() - state.session_start
return stats
end
--- Get all-time statistics (session + historical)
---@return table Statistics
function M.get_all_time_stats()
-- Load history if not loaded
M.load_from_history()
-- Combine session and historical usage
local all_usage = vim.deepcopy(state.all_usage)
for _, usage in ipairs(state.usage) do
table.insert(all_usage, usage)
end
local stats = aggregate_usage(all_usage)
-- Calculate time span
if #all_usage > 0 then
local oldest = all_usage[1].timestamp or os.time()
for _, usage in ipairs(all_usage) do
if usage.timestamp and usage.timestamp < oldest then
oldest = usage.timestamp
end
end
stats.time_span = os.time() - oldest
else
stats.time_span = 0
end
return stats
end
--- Format cost as string
---@param cost number Cost in USD
---@return string Formatted cost
local function format_cost(cost)
if cost < 0.01 then
return string.format("$%.4f", cost)
elseif cost < 1 then
return string.format("$%.3f", cost)
else
return string.format("$%.2f", cost)
end
end
--- Format token count
---@param tokens number Token count
---@return string Formatted count
local function format_tokens(tokens)
if tokens >= 1000000 then
return string.format("%.2fM", tokens / 1000000)
elseif tokens >= 1000 then
return string.format("%.1fK", tokens / 1000)
else
return tostring(tokens)
end
end
--- Format duration
---@param seconds number Duration in seconds
---@return string Formatted duration
local function format_duration(seconds)
if seconds < 60 then
return string.format("%ds", seconds)
elseif seconds < 3600 then
return string.format("%dm %ds", math.floor(seconds / 60), seconds % 60)
else
local hours = math.floor(seconds / 3600)
local mins = math.floor((seconds % 3600) / 60)
return string.format("%dh %dm", hours, mins)
end
end
--- Generate model breakdown section
---@param stats table Stats with by_model
---@return string[] Lines
local function generate_model_breakdown(stats)
local lines = {}
if next(stats.by_model) then
-- Sort models by cost (descending)
local models = {}
for model, data in pairs(stats.by_model) do
table.insert(models, { name = model, data = data })
end
table.sort(models, function(a, b)
return a.data.cost > b.data.cost
end)
for _, item in ipairs(models) do
local model = item.name
local data = item.data
local pricing = M.pricing[normalize_model(model)]
local is_free = data.is_free or M.is_free_model(model)
table.insert(lines, "")
local model_icon = is_free and "🆓" or "💳"
table.insert(lines, string.format(" %s %s", model_icon, model))
table.insert(lines, string.format(" Requests: %d", data.requests))
table.insert(lines, string.format(" Input: %s tokens", format_tokens(data.input_tokens)))
table.insert(lines, string.format(" Output: %s tokens", format_tokens(data.output_tokens)))
if is_free then
-- Show savings for free models
if data.savings and data.savings > 0 then
table.insert(lines, string.format(" Saved: %s", format_cost(data.savings)))
end
else
table.insert(lines, string.format(" Cost: %s", format_cost(data.cost)))
end
-- Show pricing info for paid models
if pricing and not is_free then
local price_info = string.format(
" Rate: $%.2f/1M in, $%.2f/1M out",
pricing.input or 0,
pricing.output or 0
)
table.insert(lines, price_info)
end
end
else
table.insert(lines, " No usage recorded.")
end
return lines
end
--- Generate window content
---@return string[] Lines for the buffer
local function generate_content()
local session_stats = M.get_stats()
local all_time_stats = M.get_all_time_stats()
local lines = {}
-- Header
table.insert(lines, "╔══════════════════════════════════════════════════════╗")
table.insert(lines, "║ 💰 LLM Cost Estimation ║")
table.insert(lines, "╠══════════════════════════════════════════════════════╣")
table.insert(lines, "")
-- All-time summary (prominent)
table.insert(lines, "🌐 All-Time Summary (Project)")
table.insert(lines, "───────────────────────────────────────────────────────")
if all_time_stats.time_span > 0 then
table.insert(lines, string.format(" Time span: %s", format_duration(all_time_stats.time_span)))
end
table.insert(lines, string.format(" Requests: %d total", all_time_stats.request_count))
table.insert(lines, string.format(" Local/Free: %d requests", all_time_stats.free_requests or 0))
table.insert(lines, string.format(" Paid API: %d requests", all_time_stats.paid_requests or 0))
table.insert(lines, string.format(" Input tokens: %s", format_tokens(all_time_stats.total_input)))
table.insert(lines, string.format(" Output tokens: %s", format_tokens(all_time_stats.total_output)))
if all_time_stats.total_cached > 0 then
table.insert(lines, string.format(" Cached tokens: %s", format_tokens(all_time_stats.total_cached)))
end
table.insert(lines, "")
table.insert(lines, string.format(" 💵 Total Cost: %s", format_cost(all_time_stats.total_cost)))
-- Show savings prominently if there are any
if all_time_stats.total_savings and all_time_stats.total_savings > 0 then
table.insert(lines, string.format(" 💚 Saved: %s (vs %s)", format_cost(all_time_stats.total_savings), M.comparison_model))
end
table.insert(lines, "")
-- Session summary
table.insert(lines, "📊 Current Session")
table.insert(lines, "───────────────────────────────────────────────────────")
table.insert(lines, string.format(" Duration: %s", format_duration(session_stats.session_duration)))
table.insert(lines, string.format(" Requests: %d (%d free, %d paid)",
session_stats.request_count,
session_stats.free_requests or 0,
session_stats.paid_requests or 0))
table.insert(lines, string.format(" Input tokens: %s", format_tokens(session_stats.total_input)))
table.insert(lines, string.format(" Output tokens: %s", format_tokens(session_stats.total_output)))
if session_stats.total_cached > 0 then
table.insert(lines, string.format(" Cached tokens: %s", format_tokens(session_stats.total_cached)))
end
table.insert(lines, string.format(" Session Cost: %s", format_cost(session_stats.total_cost)))
if session_stats.total_savings and session_stats.total_savings > 0 then
table.insert(lines, string.format(" Session Saved: %s", format_cost(session_stats.total_savings)))
end
table.insert(lines, "")
-- Per-model breakdown (all-time)
table.insert(lines, "📈 Cost by Model (All-Time)")
table.insert(lines, "───────────────────────────────────────────────────────")
local model_lines = generate_model_breakdown(all_time_stats)
for _, line in ipairs(model_lines) do
table.insert(lines, line)
end
table.insert(lines, "")
table.insert(lines, "───────────────────────────────────────────────────────")
table.insert(lines, " 'q' close | 'r' refresh | 'c' clear session | 'C' all")
table.insert(lines, "╚══════════════════════════════════════════════════════╝")
return lines
end
--- Refresh the cost window content
function M.refresh_window()
if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then
return
end
local lines = generate_content()
vim.bo[state.buf].modifiable = true
vim.api.nvim_buf_set_lines(state.buf, 0, -1, false, lines)
vim.bo[state.buf].modifiable = false
end
--- Open the cost estimation window
function M.open()
-- Load historical data if not loaded
M.load_from_history()
-- Close existing window if open
if state.win and vim.api.nvim_win_is_valid(state.win) then
vim.api.nvim_win_close(state.win, true)
end
-- Create buffer
state.buf = vim.api.nvim_create_buf(false, true)
vim.bo[state.buf].buftype = "nofile"
vim.bo[state.buf].bufhidden = "wipe"
vim.bo[state.buf].swapfile = false
vim.bo[state.buf].filetype = "codetyper-cost"
-- Calculate window size
local width = 58
local height = 40
local row = math.floor((vim.o.lines - height) / 2)
local col = math.floor((vim.o.columns - width) / 2)
-- Create floating window
state.win = vim.api.nvim_open_win(state.buf, true, {
relative = "editor",
width = width,
height = height,
row = row,
col = col,
style = "minimal",
border = "rounded",
title = " Cost Estimation ",
title_pos = "center",
})
-- Set window options
vim.wo[state.win].wrap = false
vim.wo[state.win].cursorline = false
-- Populate content
M.refresh_window()
-- Set up keymaps
local opts = { buffer = state.buf, silent = true }
vim.keymap.set("n", "q", function()
M.close()
end, opts)
vim.keymap.set("n", "<Esc>", function()
M.close()
end, opts)
vim.keymap.set("n", "r", function()
M.refresh_window()
end, opts)
vim.keymap.set("n", "c", function()
M.clear_session()
M.refresh_window()
end, opts)
vim.keymap.set("n", "C", function()
M.clear_all()
M.refresh_window()
end, opts)
-- Set up highlights
vim.api.nvim_buf_call(state.buf, function()
vim.fn.matchadd("Title", "LLM Cost Estimation")
vim.fn.matchadd("Number", "\\$[0-9.]*")
vim.fn.matchadd("Keyword", "[0-9.]*[KM]\\? tokens")
vim.fn.matchadd("Special", "🤖\\|💰\\|📊\\|📈\\|💵")
end)
end
--- Close the cost window
function M.close()
if state.win and vim.api.nvim_win_is_valid(state.win) then
vim.api.nvim_win_close(state.win, true)
end
state.win = nil
state.buf = nil
end
--- Toggle the cost window
function M.toggle()
if state.win and vim.api.nvim_win_is_valid(state.win) then
M.close()
else
M.open()
end
end
--- Clear session usage (not history)
function M.clear_session()
state.usage = {}
state.session_start = os.time()
utils.notify("Session cost tracking cleared", vim.log.levels.INFO)
end
--- Clear all history (session + saved)
function M.clear_all()
state.usage = {}
state.all_usage = {}
state.session_start = os.time()
state.loaded = false
-- Delete history file
local history_path = get_history_path()
local ok, err = os.remove(history_path)
if not ok and err and not err:match("No such file") then
utils.notify("Failed to delete history: " .. err, vim.log.levels.WARN)
end
utils.notify("All cost history cleared", vim.log.levels.INFO)
end
--- Clear usage history (alias for clear_session)
function M.clear()
M.clear_session()
end
--- Reset session
function M.reset()
M.clear_session()
end
return M

View File

@@ -0,0 +1,602 @@
---@mod codetyper.credentials Secure credential storage for Codetyper.nvim
---@brief [[
--- Manages API keys and model preferences stored outside of config files.
--- Credentials are stored in ~/.local/share/nvim/codetyper/configuration.json
---@brief ]]
local M = {}
local utils = require("codetyper.utils")
--- Get the credentials file path
---@return string Path to credentials file
local function get_credentials_path()
local data_dir = vim.fn.stdpath("data")
return data_dir .. "/codetyper/configuration.json"
end
--- Ensure the credentials directory exists
---@return boolean Success
local function ensure_dir()
local data_dir = vim.fn.stdpath("data")
local codetyper_dir = data_dir .. "/codetyper"
return utils.ensure_dir(codetyper_dir)
end
--- Load credentials from file
---@return table Credentials data
function M.load()
local path = get_credentials_path()
local content = utils.read_file(path)
if not content or content == "" then
return {
version = 1,
providers = {},
}
end
local ok, data = pcall(vim.json.decode, content)
if not ok or not data then
return {
version = 1,
providers = {},
}
end
return data
end
--- Save credentials to file
---@param data table Credentials data
---@return boolean Success
function M.save(data)
if not ensure_dir() then
return false
end
local path = get_credentials_path()
local ok, json = pcall(vim.json.encode, data)
if not ok then
return false
end
return utils.write_file(path, json)
end
--- Get API key for a provider
---@param provider string Provider name (claude, openai, gemini, copilot, ollama)
---@return string|nil API key or nil if not found
function M.get_api_key(provider)
local data = M.load()
local provider_data = data.providers and data.providers[provider]
if provider_data and provider_data.api_key then
return provider_data.api_key
end
return nil
end
--- Get model for a provider
---@param provider string Provider name
---@return string|nil Model name or nil if not found
function M.get_model(provider)
local data = M.load()
local provider_data = data.providers and data.providers[provider]
if provider_data and provider_data.model then
return provider_data.model
end
return nil
end
--- Get endpoint for a provider (for custom OpenAI-compatible endpoints)
---@param provider string Provider name
---@return string|nil Endpoint URL or nil if not found
function M.get_endpoint(provider)
local data = M.load()
local provider_data = data.providers and data.providers[provider]
if provider_data and provider_data.endpoint then
return provider_data.endpoint
end
return nil
end
--- Get host for Ollama
---@return string|nil Host URL or nil if not found
function M.get_ollama_host()
local data = M.load()
local provider_data = data.providers and data.providers.ollama
if provider_data and provider_data.host then
return provider_data.host
end
return nil
end
--- Set credentials for a provider
---@param provider string Provider name
---@param credentials table Credentials (api_key, model, endpoint, host)
---@return boolean Success
function M.set_credentials(provider, credentials)
local data = M.load()
if not data.providers then
data.providers = {}
end
if not data.providers[provider] then
data.providers[provider] = {}
end
-- Merge credentials
for key, value in pairs(credentials) do
if value and value ~= "" then
data.providers[provider][key] = value
end
end
data.updated = os.time()
return M.save(data)
end
--- Remove credentials for a provider
---@param provider string Provider name
---@return boolean Success
function M.remove_credentials(provider)
local data = M.load()
if data.providers and data.providers[provider] then
data.providers[provider] = nil
data.updated = os.time()
return M.save(data)
end
return true
end
--- List all configured providers (checks both stored credentials AND config)
---@return table List of provider names with their config status
function M.list_providers()
local data = M.load()
local result = {}
local all_providers = { "claude", "openai", "gemini", "copilot", "ollama" }
for _, provider in ipairs(all_providers) do
local provider_data = data.providers and data.providers[provider]
local has_stored_key = provider_data and provider_data.api_key and provider_data.api_key ~= ""
local has_model = provider_data and provider_data.model and provider_data.model ~= ""
-- Check if configured from config or environment
local configured_from_config = false
local config_model = nil
local ok, codetyper = pcall(require, "codetyper")
if ok then
local config = codetyper.get_config()
if config and config.llm and config.llm[provider] then
local pc = config.llm[provider]
config_model = pc.model
if provider == "claude" then
configured_from_config = pc.api_key ~= nil or vim.env.ANTHROPIC_API_KEY ~= nil
elseif provider == "openai" then
configured_from_config = pc.api_key ~= nil or vim.env.OPENAI_API_KEY ~= nil
elseif provider == "gemini" then
configured_from_config = pc.api_key ~= nil or vim.env.GEMINI_API_KEY ~= nil
elseif provider == "copilot" then
configured_from_config = true -- Just needs copilot.lua
elseif provider == "ollama" then
configured_from_config = pc.host ~= nil
end
end
end
local is_configured = has_stored_key
or (provider == "ollama" and provider_data ~= nil)
or (provider == "copilot" and (provider_data ~= nil or configured_from_config))
or configured_from_config
table.insert(result, {
name = provider,
configured = is_configured,
has_api_key = has_stored_key,
has_model = has_model or config_model ~= nil,
model = (provider_data and provider_data.model) or config_model,
source = has_stored_key and "stored" or (configured_from_config and "config" or nil),
})
end
return result
end
--- Default models for each provider
M.default_models = {
claude = "claude-sonnet-4-20250514",
openai = "gpt-4o",
gemini = "gemini-2.0-flash",
copilot = "gpt-4o",
ollama = "deepseek-coder:6.7b",
}
--- Interactive command to add/update API key
function M.interactive_add()
local providers = { "claude", "openai", "gemini", "copilot", "ollama" }
-- Step 1: Select provider
vim.ui.select(providers, {
prompt = "Select LLM provider:",
format_item = function(item)
local display = item:sub(1, 1):upper() .. item:sub(2)
local creds = M.load()
local configured = creds.providers and creds.providers[item]
if configured and (configured.api_key or item == "ollama") then
return display .. " [configured]"
end
return display
end,
}, function(provider)
if not provider then
return
end
-- Step 2: Get API key (skip for Ollama)
if provider == "ollama" then
M.interactive_ollama_config()
else
M.interactive_api_key(provider)
end
end)
end
--- Interactive API key input
---@param provider string Provider name
function M.interactive_api_key(provider)
-- Copilot uses OAuth from copilot.lua, no API key needed
if provider == "copilot" then
M.interactive_copilot_config()
return
end
local prompt = string.format("Enter %s API key (leave empty to skip): ", provider:upper())
vim.ui.input({ prompt = prompt }, function(api_key)
if api_key == nil then
return -- Cancelled
end
-- Step 3: Get model
M.interactive_model(provider, api_key)
end)
end
--- Interactive Copilot configuration (no API key, uses OAuth)
function M.interactive_copilot_config()
utils.notify("Copilot uses OAuth from copilot.lua/copilot.vim - no API key needed", vim.log.levels.INFO)
-- Just ask for model
local default_model = M.default_models.copilot
vim.ui.input({
prompt = string.format("Copilot model (default: %s): ", default_model),
default = default_model,
}, function(model)
if model == nil then
return -- Cancelled
end
if model == "" then
model = default_model
end
M.save_and_notify("copilot", {
model = model,
-- Mark as configured even without API key
configured = true,
})
end)
end
--- Interactive model selection
---@param provider string Provider name
---@param api_key string|nil API key
function M.interactive_model(provider, api_key)
local default_model = M.default_models[provider] or ""
local prompt = string.format("Enter model (default: %s): ", default_model)
vim.ui.input({ prompt = prompt, default = default_model }, function(model)
if model == nil then
return -- Cancelled
end
-- Use default if empty
if model == "" then
model = default_model
end
-- Save credentials
local credentials = {
model = model,
}
if api_key and api_key ~= "" then
credentials.api_key = api_key
end
-- For OpenAI, also ask for custom endpoint
if provider == "openai" then
M.interactive_endpoint(provider, credentials)
else
M.save_and_notify(provider, credentials)
end
end)
end
--- Interactive endpoint input for OpenAI-compatible providers
---@param provider string Provider name
---@param credentials table Current credentials
function M.interactive_endpoint(provider, credentials)
vim.ui.input({
prompt = "Custom endpoint (leave empty for default OpenAI): ",
}, function(endpoint)
if endpoint == nil then
return -- Cancelled
end
if endpoint ~= "" then
credentials.endpoint = endpoint
end
M.save_and_notify(provider, credentials)
end)
end
--- Interactive Ollama configuration
function M.interactive_ollama_config()
vim.ui.input({
prompt = "Ollama host (default: http://localhost:11434): ",
default = "http://localhost:11434",
}, function(host)
if host == nil then
return -- Cancelled
end
if host == "" then
host = "http://localhost:11434"
end
-- Get model
local default_model = M.default_models.ollama
vim.ui.input({
prompt = string.format("Ollama model (default: %s): ", default_model),
default = default_model,
}, function(model)
if model == nil then
return -- Cancelled
end
if model == "" then
model = default_model
end
M.save_and_notify("ollama", {
host = host,
model = model,
})
end)
end)
end
--- Save credentials and notify user
---@param provider string Provider name
---@param credentials table Credentials to save
function M.save_and_notify(provider, credentials)
if M.set_credentials(provider, credentials) then
local msg = string.format("Saved %s configuration", provider:upper())
if credentials.model then
msg = msg .. " (model: " .. credentials.model .. ")"
end
utils.notify(msg, vim.log.levels.INFO)
else
utils.notify("Failed to save credentials", vim.log.levels.ERROR)
end
end
--- Show current credentials status
function M.show_status()
local providers = M.list_providers()
-- Get current active provider
local codetyper = require("codetyper")
local current = codetyper.get_config().llm.provider
local lines = {
"Codetyper Credentials Status",
"============================",
"",
"Storage: " .. get_credentials_path(),
"Active: " .. current:upper(),
"",
}
for _, p in ipairs(providers) do
local status_icon = p.configured and "" or ""
local active_marker = p.name == current and " [ACTIVE]" or ""
local source_info = ""
if p.configured then
source_info = p.source == "stored" and " (stored)" or " (config)"
end
local model_info = p.model and (" - " .. p.model) or ""
table.insert(lines, string.format(" %s %s%s%s%s",
status_icon,
p.name:upper(),
active_marker,
source_info,
model_info))
end
table.insert(lines, "")
table.insert(lines, "Commands:")
table.insert(lines, " :CoderAddApiKey - Add/update credentials")
table.insert(lines, " :CoderSwitchProvider - Switch active provider")
table.insert(lines, " :CoderRemoveApiKey - Remove stored credentials")
utils.notify(table.concat(lines, "\n"))
end
--- Interactive remove credentials
function M.interactive_remove()
local data = M.load()
local configured = {}
for provider, _ in pairs(data.providers or {}) do
table.insert(configured, provider)
end
if #configured == 0 then
utils.notify("No credentials configured", vim.log.levels.INFO)
return
end
vim.ui.select(configured, {
prompt = "Select provider to remove:",
}, function(provider)
if not provider then
return
end
vim.ui.select({ "Yes", "No" }, {
prompt = "Remove " .. provider:upper() .. " credentials?",
}, function(choice)
if choice == "Yes" then
if M.remove_credentials(provider) then
utils.notify("Removed " .. provider:upper() .. " credentials", vim.log.levels.INFO)
else
utils.notify("Failed to remove credentials", vim.log.levels.ERROR)
end
end
end)
end)
end
--- Set the active provider
---@param provider string Provider name
function M.set_active_provider(provider)
local data = M.load()
data.active_provider = provider
data.updated = os.time()
M.save(data)
-- Also update the runtime config
local codetyper = require("codetyper")
local config = codetyper.get_config()
config.llm.provider = provider
utils.notify("Active provider set to: " .. provider:upper(), vim.log.levels.INFO)
end
--- Get the active provider from stored config
---@return string|nil Active provider
function M.get_active_provider()
local data = M.load()
return data.active_provider
end
--- Check if a provider is configured (from stored credentials OR config)
---@param provider string Provider name
---@return boolean configured, string|nil source
local function is_provider_configured(provider)
-- Check stored credentials first
local data = M.load()
local stored = data.providers and data.providers[provider]
if stored then
if stored.configured or stored.api_key or provider == "ollama" or provider == "copilot" then
return true, "stored"
end
end
-- Check codetyper config
local ok, codetyper = pcall(require, "codetyper")
if not ok then
return false, nil
end
local config = codetyper.get_config()
if not config or not config.llm then
return false, nil
end
local provider_config = config.llm[provider]
if not provider_config then
return false, nil
end
-- Check for API key in config or environment
if provider == "claude" then
if provider_config.api_key or vim.env.ANTHROPIC_API_KEY then
return true, "config"
end
elseif provider == "openai" then
if provider_config.api_key or vim.env.OPENAI_API_KEY then
return true, "config"
end
elseif provider == "gemini" then
if provider_config.api_key or vim.env.GEMINI_API_KEY then
return true, "config"
end
elseif provider == "copilot" then
-- Copilot just needs copilot.lua installed
return true, "config"
elseif provider == "ollama" then
-- Ollama just needs host configured
if provider_config.host then
return true, "config"
end
end
return false, nil
end
--- Interactive switch provider
function M.interactive_switch_provider()
local all_providers = { "claude", "openai", "gemini", "copilot", "ollama" }
local available = {}
local sources = {}
for _, provider in ipairs(all_providers) do
local configured, source = is_provider_configured(provider)
if configured then
table.insert(available, provider)
sources[provider] = source
end
end
if #available == 0 then
utils.notify("No providers configured. Use :CoderAddApiKey or add to your config.", vim.log.levels.WARN)
return
end
local codetyper = require("codetyper")
local current = codetyper.get_config().llm.provider
vim.ui.select(available, {
prompt = "Select provider (current: " .. current .. "):",
format_item = function(item)
local marker = item == current and " [active]" or ""
local source_marker = sources[item] == "stored" and " (stored)" or " (config)"
return item:upper() .. marker .. source_marker
end,
}, function(provider)
if provider then
M.set_active_provider(provider)
end
end)
end
return M

View File

@@ -83,6 +83,9 @@ local TS_QUERIES = {
},
}
-- Forward declaration for analyze_tree_generic (defined below)
local analyze_tree_generic
--- Hash file content for change detection
---@param content string
---@return string
@@ -256,7 +259,7 @@ end
---@param root TSNode
---@param bufnr number
---@return table
local function analyze_tree_generic(root, bufnr)
analyze_tree_generic = function(root, bufnr)
local result = {
functions = {},
classes = {},

View File

@@ -14,6 +14,51 @@ local AUTH_URL = "https://api.github.com/copilot_internal/v2/token"
---@field github_token table|nil
M.state = nil
--- Track if we've already suggested Ollama fallback this session
local ollama_fallback_suggested = false
--- Suggest switching to Ollama when rate limits are hit
---@param error_msg string The error message that triggered this
function M.suggest_ollama_fallback(error_msg)
if ollama_fallback_suggested then
return
end
-- Check if Ollama is available
local ollama_available = false
vim.fn.jobstart({ "curl", "-s", "http://localhost:11434/api/tags" }, {
on_exit = function(_, code)
if code == 0 then
ollama_available = true
end
vim.schedule(function()
if ollama_available then
-- Switch to Ollama automatically
local codetyper = require("codetyper")
local config = codetyper.get_config()
config.llm.provider = "ollama"
ollama_fallback_suggested = true
utils.notify(
"⚠️ Copilot rate limit reached. Switched to Ollama automatically.\n"
.. "Original error: "
.. error_msg:sub(1, 100),
vim.log.levels.WARN
)
else
utils.notify(
"⚠️ Copilot rate limit reached. Ollama not available.\n"
.. "Start Ollama with: ollama serve\n"
.. "Or wait for Copilot limits to reset.",
vim.log.levels.WARN
)
end
end)
end,
})
end
--- Get OAuth token from copilot.lua or copilot.vim config
---@return string|nil OAuth token
local function get_oauth_token()
@@ -51,9 +96,16 @@ local function get_oauth_token()
return nil
end
--- Get model from config
--- Get model from stored credentials or config
---@return string Model name
local function get_model()
-- Priority: stored credentials > config
local credentials = require("codetyper.credentials")
local stored_model = credentials.get_model("copilot")
if stored_model then
return stored_model
end
local codetyper = require("codetyper")
local config = codetyper.get_config()
return config.llm.copilot.model
@@ -204,15 +256,37 @@ local function make_request(token, body, callback)
local ok, response = pcall(vim.json.decode, response_text)
if not ok then
-- Show the actual response text as the error (truncated if too long)
local error_msg = response_text
if #error_msg > 200 then
error_msg = error_msg:sub(1, 200) .. "..."
end
-- Clean up common patterns
if response_text:match("<!DOCTYPE") or response_text:match("<html") then
error_msg = "Copilot API returned HTML error page. Service may be unavailable."
end
-- Check for rate limit and suggest Ollama fallback
if response_text:match("limit") or response_text:match("Upgrade") or response_text:match("quota") then
M.suggest_ollama_fallback(error_msg)
end
vim.schedule(function()
callback(nil, "Failed to parse Copilot response", nil)
callback(nil, error_msg, nil)
end)
return
end
if response.error then
local error_msg = response.error.message or "Copilot API error"
if response.error.code == "rate_limit_exceeded" or (error_msg:match("limit") and error_msg:match("plan")) then
error_msg = "Copilot rate limit: " .. error_msg
M.suggest_ollama_fallback(error_msg)
end
vim.schedule(function()
callback(nil, response.error.message or "Copilot API error", nil)
callback(nil, error_msg, nil)
end)
return
end
@@ -220,6 +294,17 @@ local function make_request(token, body, callback)
-- Extract usage info
local usage = response.usage or {}
-- Record usage for cost tracking
if usage.prompt_tokens or usage.completion_tokens then
local cost = require("codetyper.cost")
cost.record_usage(
get_model(),
usage.prompt_tokens or 0,
usage.completion_tokens or 0,
usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens or 0
)
end
if response.choices and response.choices[1] and response.choices[1].message then
local code = llm.extract_code(response.choices[1].message.content)
vim.schedule(function()
@@ -362,20 +447,46 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
-- Format messages for Copilot (OpenAI-compatible format)
local copilot_messages = { { role = "system", content = system_prompt } }
for _, msg in ipairs(messages) do
if type(msg.content) == "string" then
table.insert(copilot_messages, { role = msg.role, content = msg.content })
elseif type(msg.content) == "table" then
local text_parts = {}
for _, part in ipairs(msg.content) do
if part.type == "tool_result" then
table.insert(text_parts, "[" .. (part.name or "tool") .. " result]: " .. (part.content or ""))
elseif part.type == "text" then
table.insert(text_parts, part.text or "")
if msg.role == "user" then
-- User messages - handle string or table content
if type(msg.content) == "string" then
table.insert(copilot_messages, { role = "user", content = msg.content })
elseif type(msg.content) == "table" then
-- Handle complex content (like tool results from user perspective)
local text_parts = {}
for _, part in ipairs(msg.content) do
if part.type == "tool_result" then
table.insert(text_parts, "[" .. (part.name or "tool") .. " result]: " .. (part.content or ""))
elseif part.type == "text" then
table.insert(text_parts, part.text or "")
end
end
if #text_parts > 0 then
table.insert(copilot_messages, { role = "user", content = table.concat(text_parts, "\n") })
end
end
if #text_parts > 0 then
table.insert(copilot_messages, { role = msg.role, content = table.concat(text_parts, "\n") })
elseif msg.role == "assistant" then
-- Assistant messages - must preserve tool_calls if present
local assistant_msg = {
role = "assistant",
content = type(msg.content) == "string" and msg.content or nil,
}
-- Preserve tool_calls for the API
if msg.tool_calls then
assistant_msg.tool_calls = msg.tool_calls
-- Ensure content is not nil when tool_calls present
if assistant_msg.content == nil then
assistant_msg.content = ""
end
end
table.insert(copilot_messages, assistant_msg)
elseif msg.role == "tool" then
-- Tool result messages - must have tool_call_id
table.insert(copilot_messages, {
role = "tool",
tool_call_id = msg.tool_call_id,
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
})
end
end
@@ -396,6 +507,20 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
logs.thinking("Sending to Copilot API...")
-- Log request to debug file
local debug_log_path = vim.fn.expand("~/.local/codetyper-debug.log")
local debug_f = io.open(debug_log_path, "a")
if debug_f then
debug_f:write(os.date("[%Y-%m-%d %H:%M:%S] ") .. "COPILOT REQUEST\n")
debug_f:write("Messages count: " .. #copilot_messages .. "\n")
for i, m in ipairs(copilot_messages) do
debug_f:write(string.format(" [%d] role=%s, has_tool_calls=%s, has_tool_call_id=%s\n",
i, m.role, tostring(m.tool_calls ~= nil), tostring(m.tool_call_id ~= nil)))
end
debug_f:write("---\n")
debug_f:close()
end
local headers = build_headers(token)
local cmd = {
"curl",
@@ -413,35 +538,97 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
table.insert(cmd, "-d")
table.insert(cmd, json_body)
-- Debug logging helper
local function debug_log(msg, data)
local log_path = vim.fn.expand("~/.local/codetyper-debug.log")
local f = io.open(log_path, "a")
if f then
f:write(os.date("[%Y-%m-%d %H:%M:%S] ") .. msg .. "\n")
if data then
f:write("DATA: " .. tostring(data):sub(1, 2000) .. "\n")
end
f:write("---\n")
f:close()
end
end
-- Prevent double callback calls
local callback_called = false
vim.fn.jobstart(cmd, {
stdout_buffered = true,
on_stdout = function(_, data)
if callback_called then
debug_log("on_stdout: callback already called, skipping")
return
end
if not data or #data == 0 or (data[1] == "" and #data == 1) then
debug_log("on_stdout: empty data")
return
end
local response_text = table.concat(data, "\n")
debug_log("on_stdout: received response", response_text)
local ok, response = pcall(vim.json.decode, response_text)
if not ok then
debug_log("JSON parse failed", response_text)
callback_called = true
-- Show the actual response text as the error (truncated if too long)
local error_msg = response_text
if #error_msg > 200 then
error_msg = error_msg:sub(1, 200) .. "..."
end
-- Clean up common patterns
if response_text:match("<!DOCTYPE") or response_text:match("<html") then
error_msg = "Copilot API returned HTML error page. Service may be unavailable."
end
-- Check for rate limit and suggest Ollama fallback
if response_text:match("limit") or response_text:match("Upgrade") or response_text:match("quota") then
M.suggest_ollama_fallback(error_msg)
end
vim.schedule(function()
logs.error("Failed to parse Copilot response")
callback(nil, "Failed to parse Copilot response")
logs.error(error_msg)
callback(nil, error_msg)
end)
return
end
if response.error then
callback_called = true
local error_msg = response.error.message or "Copilot API error"
-- Check for rate limit in structured error
if response.error.code == "rate_limit_exceeded" or (error_msg:match("limit") and error_msg:match("plan")) then
error_msg = "Copilot rate limit: " .. error_msg
M.suggest_ollama_fallback(error_msg)
end
vim.schedule(function()
logs.error(response.error.message or "Copilot API error")
callback(nil, response.error.message or "Copilot API error")
logs.error(error_msg)
callback(nil, error_msg)
end)
return
end
-- Log token usage
-- Log token usage and record cost
if response.usage then
logs.response(response.usage.prompt_tokens or 0, response.usage.completion_tokens or 0, "stop")
-- Record usage for cost tracking
local cost_tracker = require("codetyper.cost")
cost_tracker.record_usage(
get_model(),
response.usage.prompt_tokens or 0,
response.usage.completion_tokens or 0,
response.usage.prompt_tokens_details and response.usage.prompt_tokens_details.cached_tokens or 0
)
end
-- Convert to Claude-like format for parser compatibility
@@ -474,12 +661,19 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
end
end
callback_called = true
debug_log("on_stdout: success, calling callback")
vim.schedule(function()
callback(converted, nil)
end)
end,
on_stderr = function(_, data)
if callback_called then
return
end
if data and #data > 0 and data[1] ~= "" then
debug_log("on_stderr", table.concat(data, "\n"))
callback_called = true
vim.schedule(function()
logs.error("Copilot API request failed: " .. table.concat(data, "\n"))
callback(nil, "Copilot API request failed: " .. table.concat(data, "\n"))
@@ -487,7 +681,12 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
end
end,
on_exit = function(_, code)
debug_log("on_exit: code=" .. code .. ", callback_called=" .. tostring(callback_called))
if callback_called then
return
end
if code ~= 0 then
callback_called = true
vim.schedule(function()
logs.error("Copilot API request failed with code: " .. code)
callback(nil, "Copilot API request failed with code: " .. code)

View File

@@ -8,17 +8,31 @@ local llm = require("codetyper.llm")
--- Gemini API endpoint
local API_URL = "https://generativelanguage.googleapis.com/v1beta/models"
--- Get API key from config or environment
--- Get API key from stored credentials, config, or environment
---@return string|nil API key
local function get_api_key()
-- Priority: stored credentials > config > environment
local credentials = require("codetyper.credentials")
local stored_key = credentials.get_api_key("gemini")
if stored_key then
return stored_key
end
local codetyper = require("codetyper")
local config = codetyper.get_config()
return config.llm.gemini.api_key or vim.env.GEMINI_API_KEY
end
--- Get model from config
--- Get model from stored credentials or config
---@return string Model name
local function get_model()
-- Priority: stored credentials > config
local credentials = require("codetyper.credentials")
local stored_model = credentials.get_model("gemini")
if stored_model then
return stored_model
end
local codetyper = require("codetyper")
local config = codetyper.get_config()
return config.llm.gemini.model

View File

@@ -32,6 +32,32 @@ function M.generate(prompt, context, callback)
client.generate(prompt, context, callback)
end
--- Smart generate with automatic provider selection based on brain memories
--- Prefers Ollama when context is rich, falls back to Copilot otherwise.
--- Implements verification pondering to reinforce Ollama accuracy over time.
---@param prompt string The user's prompt
---@param context table Context information
---@param callback fun(response: string|nil, error: string|nil, metadata: table|nil) Callback
function M.smart_generate(prompt, context, callback)
local selector = require("codetyper.llm.selector")
selector.smart_generate(prompt, context, callback)
end
--- Get accuracy statistics for providers
---@return table Statistics for each provider
function M.get_accuracy_stats()
local selector = require("codetyper.llm.selector")
return selector.get_accuracy_stats()
end
--- Report user feedback on response quality (for reinforcement learning)
---@param provider string Which provider generated the response
---@param was_correct boolean Whether the response was good
function M.report_feedback(provider, was_correct)
local selector = require("codetyper.llm.selector")
selector.report_feedback(provider, was_correct)
end
--- Build the system prompt for code generation
---@param context table Context information
---@return string System prompt

View File

@@ -5,21 +5,33 @@ local M = {}
local utils = require("codetyper.utils")
local llm = require("codetyper.llm")
--- Get Ollama host from config
--- Get Ollama host from stored credentials or config
---@return string Host URL
local function get_host()
-- Priority: stored credentials > config
local credentials = require("codetyper.credentials")
local stored_host = credentials.get_ollama_host()
if stored_host then
return stored_host
end
local codetyper = require("codetyper")
local config = codetyper.get_config()
return config.llm.ollama.host
end
--- Get model from config
--- Get model from stored credentials or config
---@return string Model name
local function get_model()
-- Priority: stored credentials > config
local credentials = require("codetyper.credentials")
local stored_model = credentials.get_model("ollama")
if stored_model then
return stored_model
end
local codetyper = require("codetyper")
local config = codetyper.get_config()
return config.llm.ollama.model
end
@@ -199,47 +211,41 @@ function M.validate()
return true
end
--- Build system prompt for agent mode with tool instructions
--- Generate with tool use support for agentic mode (text-based tool calling)
---@param messages table[] Conversation history
---@param context table Context information
---@return string System prompt
local function build_agent_system_prompt(context)
---@param tool_definitions table Tool definitions
---@param callback fun(response: table|nil, error: string|nil) Callback with Claude-like response format
function M.generate_with_tools(messages, context, tool_definitions, callback)
local logs = require("codetyper.agent.logs")
local agent_prompts = require("codetyper.prompts.agent")
local tools_module = require("codetyper.agent.tools")
local system_prompt = agent_prompts.system .. "\n\n"
system_prompt = system_prompt .. tools_module.to_prompt_format() .. "\n\n"
system_prompt = system_prompt .. agent_prompts.tool_instructions
logs.request("ollama", get_model())
logs.thinking("Preparing agent request...")
-- Add context about current file if available
if context.file_path then
system_prompt = system_prompt .. "\n\nCurrent working context:\n"
system_prompt = system_prompt .. "- File: " .. context.file_path .. "\n"
if context.language then
system_prompt = system_prompt .. "- Language: " .. context.language .. "\n"
-- Build system prompt with tool instructions
local system_prompt = llm.build_system_prompt(context)
system_prompt = system_prompt .. "\n\n" .. agent_prompts.system
system_prompt = system_prompt .. "\n\n" .. agent_prompts.tool_instructions
-- Add tool descriptions
system_prompt = system_prompt .. "\n\n## Available Tools\n"
system_prompt = system_prompt .. "Call tools by outputting JSON in this exact format:\n"
system_prompt = system_prompt .. '```json\n{"tool": "tool_name", "arguments": {...}}\n```\n\n'
for _, tool in ipairs(tool_definitions) do
local name = tool.name or (tool["function"] and tool["function"].name)
local desc = tool.description or (tool["function"] and tool["function"].description)
if name then
system_prompt = system_prompt .. string.format("### %s\n%s\n\n", name, desc or "")
end
end
-- Add project root info
local root = utils.get_project_root()
if root then
system_prompt = system_prompt .. "- Project root: " .. root .. "\n"
end
return system_prompt
end
--- Build request body for Ollama API with tools (chat format)
---@param messages table[] Conversation messages
---@param context table Context information
---@return table Request body
local function build_tools_request_body(messages, context)
local system_prompt = build_agent_system_prompt(context)
-- Convert messages to Ollama chat format
local ollama_messages = {}
for _, msg in ipairs(messages) do
local content = msg.content
-- Handle complex content (like tool results)
if type(content) == "table" then
local text_parts = {}
for _, part in ipairs(content) do
@@ -251,14 +257,10 @@ local function build_tools_request_body(messages, context)
end
content = table.concat(text_parts, "\n")
end
table.insert(ollama_messages, {
role = msg.role,
content = content,
})
table.insert(ollama_messages, { role = msg.role, content = content })
end
return {
local body = {
model = get_model(),
messages = ollama_messages,
system = system_prompt,
@@ -268,16 +270,15 @@ local function build_tools_request_body(messages, context)
num_predict = 4096,
},
}
end
--- Make HTTP request to Ollama chat API
---@param body table Request body
---@param callback fun(response: string|nil, error: string|nil, usage: table|nil) Callback function
local function make_chat_request(body, callback)
local host = get_host()
local url = host .. "/api/chat"
local json_body = vim.json.encode(body)
local prompt_estimate = logs.estimate_tokens(json_body)
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
logs.thinking("Sending to Ollama API...")
local cmd = {
"curl",
"-s",
@@ -302,196 +303,82 @@ local function make_chat_request(body, callback)
if not ok then
vim.schedule(function()
callback(nil, "Failed to parse Ollama response", nil)
logs.error("Failed to parse Ollama response")
callback(nil, "Failed to parse Ollama response")
end)
return
end
if response.error then
vim.schedule(function()
callback(nil, response.error or "Ollama API error", nil)
logs.error(response.error or "Ollama API error")
callback(nil, response.error or "Ollama API error")
end)
return
end
-- Extract usage info
local usage = {
prompt_tokens = response.prompt_eval_count or 0,
response_tokens = response.eval_count or 0,
}
-- Log token usage and record cost (Ollama is free but we track usage)
if response.prompt_eval_count or response.eval_count then
logs.response(response.prompt_eval_count or 0, response.eval_count or 0, "stop")
-- Return the message content for agent parsing
if response.message and response.message.content then
vim.schedule(function()
callback(response.message.content, nil, usage)
end)
else
vim.schedule(function()
callback(nil, "No response from Ollama", nil)
end)
-- Record usage for cost tracking (free for local models)
local cost = require("codetyper.cost")
cost.record_usage(
get_model(),
response.prompt_eval_count or 0,
response.eval_count or 0,
0 -- No cached tokens for Ollama
)
end
-- Parse the response text for tool calls
local content_text = response.message and response.message.content or ""
local converted = { content = {}, stop_reason = "end_turn" }
-- Try to extract JSON tool calls from response
local json_match = content_text:match("```json%s*(%b{})%s*```")
if json_match then
local ok_json, parsed = pcall(vim.json.decode, json_match)
if ok_json and parsed.tool then
table.insert(converted.content, {
type = "tool_use",
id = "call_" .. string.format("%x", os.time()) .. "_" .. string.format("%x", math.random(0, 0xFFFF)),
name = parsed.tool,
input = parsed.arguments or {},
})
logs.thinking("Tool call: " .. parsed.tool)
content_text = content_text:gsub("```json.-```", ""):gsub("^%s+", ""):gsub("%s+$", "")
converted.stop_reason = "tool_use"
end
end
-- Add text content
if content_text and content_text ~= "" then
table.insert(converted.content, 1, { type = "text", text = content_text })
logs.thinking("Response contains text")
end
vim.schedule(function()
callback(converted, nil)
end)
end,
on_stderr = function(_, data)
if data and #data > 0 and data[1] ~= "" then
vim.schedule(function()
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil)
logs.error("Ollama API request failed: " .. table.concat(data, "\n"))
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"))
end)
end
end,
on_exit = function(_, code)
if code ~= 0 then
-- Don't double-report errors
vim.schedule(function()
logs.error("Ollama API request failed with code: " .. code)
callback(nil, "Ollama API request failed with code: " .. code)
end)
end
end,
})
end
--- Generate response with tools using Ollama API
---@param messages table[] Conversation history
---@param context table Context information
---@param tools table Tool definitions (embedded in prompt for Ollama)
---@param callback fun(response: string|nil, error: string|nil) Callback function
function M.generate_with_tools(messages, context, tools, callback)
local logs = require("codetyper.agent.logs")
-- Log the request
local model = get_model()
logs.request("ollama", model)
logs.thinking("Preparing API request...")
local body = build_tools_request_body(messages, context)
-- Estimate prompt tokens
local prompt_estimate = logs.estimate_tokens(vim.json.encode(body))
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
make_chat_request(body, function(response, err, usage)
if err then
logs.error(err)
callback(nil, err)
else
-- Log token usage
if usage then
logs.response(usage.prompt_tokens or 0, usage.response_tokens or 0, "end_turn")
end
-- Log if response contains tool calls
if response then
local parser = require("codetyper.agent.parser")
local parsed = parser.parse_ollama_response(response)
if #parsed.tool_calls > 0 then
for _, tc in ipairs(parsed.tool_calls) do
logs.thinking("Tool call: " .. tc.name)
end
end
if parsed.text and parsed.text ~= "" then
logs.thinking("Response contains text")
end
end
callback(response, nil)
end
end)
end
--- Generate with tool use support for agentic mode (simulated via prompts)
---@param messages table[] Conversation history
---@param context table Context information
---@param tool_definitions table Tool definitions
---@param callback fun(response: string|nil, error: string|nil) Callback with response text
function M.generate_with_tools(messages, context, tool_definitions, callback)
local tools_module = require("codetyper.agent.tools")
local agent_prompts = require("codetyper.prompts.agent")
-- Build system prompt with agent instructions and tool definitions
local system_prompt = llm.build_system_prompt(context)
system_prompt = system_prompt .. "\n\n" .. agent_prompts.system
system_prompt = system_prompt .. "\n\n" .. tools_module.to_prompt_format()
-- Flatten messages to single prompt (Ollama's generate API)
local prompt_parts = {}
for _, msg in ipairs(messages) do
if type(msg.content) == "string" then
local role_prefix = msg.role == "user" and "User" or "Assistant"
table.insert(prompt_parts, role_prefix .. ": " .. msg.content)
elseif type(msg.content) == "table" then
-- Handle tool results
for _, item in ipairs(msg.content) do
if item.type == "tool_result" then
table.insert(prompt_parts, "Tool result: " .. item.content)
end
end
end
end
local body = {
model = get_model(),
system = system_prompt,
prompt = table.concat(prompt_parts, "\n\n"),
stream = false,
options = {
temperature = 0.2,
num_predict = 4096,
},
}
local host = get_host()
local url = host .. "/api/generate"
local json_body = vim.json.encode(body)
local cmd = {
"curl",
"-s",
"-X", "POST",
url,
"-H", "Content-Type: application/json",
"-d", json_body,
}
vim.fn.jobstart(cmd, {
stdout_buffered = true,
on_stdout = function(_, data)
if not data or #data == 0 or (data[1] == "" and #data == 1) then
return
end
local response_text = table.concat(data, "\n")
local ok, response = pcall(vim.json.decode, response_text)
if not ok then
vim.schedule(function()
callback(nil, "Failed to parse Ollama response")
end)
return
end
if response.error then
vim.schedule(function()
callback(nil, response.error or "Ollama API error")
end)
return
end
-- Return raw response text for parser to handle
vim.schedule(function()
callback(response.response or "", nil)
end)
end,
on_stderr = function(_, data)
if data and #data > 0 and data[1] ~= "" then
vim.schedule(function()
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"))
end)
end
end,
on_exit = function(_, code)
if code ~= 0 then
vim.schedule(function()
callback(nil, "Ollama API request failed with code: " .. code)
end)
end
end,
})
end
return M

View File

@@ -8,25 +8,46 @@ local llm = require("codetyper.llm")
--- OpenAI API endpoint
local API_URL = "https://api.openai.com/v1/chat/completions"
--- Get API key from config or environment
--- Get API key from stored credentials, config, or environment
---@return string|nil API key
local function get_api_key()
-- Priority: stored credentials > config > environment
local credentials = require("codetyper.credentials")
local stored_key = credentials.get_api_key("openai")
if stored_key then
return stored_key
end
local codetyper = require("codetyper")
local config = codetyper.get_config()
return config.llm.openai.api_key or vim.env.OPENAI_API_KEY
end
--- Get model from config
--- Get model from stored credentials or config
---@return string Model name
local function get_model()
-- Priority: stored credentials > config
local credentials = require("codetyper.credentials")
local stored_model = credentials.get_model("openai")
if stored_model then
return stored_model
end
local codetyper = require("codetyper")
local config = codetyper.get_config()
return config.llm.openai.model
end
--- Get endpoint from config (allows custom endpoints like Azure, OpenRouter)
--- Get endpoint from stored credentials or config (allows custom endpoints like Azure, OpenRouter)
---@return string API endpoint
local function get_endpoint()
-- Priority: stored credentials > config > default
local credentials = require("codetyper.credentials")
local stored_endpoint = credentials.get_endpoint("openai")
if stored_endpoint then
return stored_endpoint
end
local codetyper = require("codetyper")
local config = codetyper.get_config()
return config.llm.openai.endpoint or API_URL
@@ -284,9 +305,18 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
return
end
-- Log token usage
-- Log token usage and record cost
if response.usage then
logs.response(response.usage.prompt_tokens or 0, response.usage.completion_tokens or 0, "stop")
-- Record usage for cost tracking
local cost = require("codetyper.cost")
cost.record_usage(
model,
response.usage.prompt_tokens or 0,
response.usage.completion_tokens or 0,
response.usage.prompt_tokens_details and response.usage.prompt_tokens_details.cached_tokens or 0
)
end
-- Convert to Claude-like format for parser compatibility

View File

@@ -0,0 +1,514 @@
---@mod codetyper.llm.selector Smart LLM selection with memory-based confidence
---@brief [[
--- Intelligent LLM provider selection based on brain memories.
--- Prefers local Ollama when context is rich, falls back to Copilot otherwise.
--- Implements verification pondering to reinforce Ollama accuracy over time.
---@brief ]]
local M = {}
---@class SelectionResult
---@field provider string Selected provider name
---@field confidence number Confidence score (0-1)
---@field memory_count number Number of relevant memories found
---@field reason string Human-readable reason for selection
---@class PonderResult
---@field ollama_response string Ollama's response
---@field verifier_response string Verifier's response
---@field agreement_score number How much they agree (0-1)
---@field ollama_correct boolean Whether Ollama was deemed correct
---@field feedback string Feedback for learning
--- Minimum memories required for high confidence
local MIN_MEMORIES_FOR_LOCAL = 3
--- Minimum memory relevance score for local provider
local MIN_RELEVANCE_FOR_LOCAL = 0.6
--- Agreement threshold for Ollama verification
local AGREEMENT_THRESHOLD = 0.7
--- Pondering sample rate (0-1) - how often to verify Ollama
local PONDER_SAMPLE_RATE = 0.2
--- Provider accuracy tracking (persisted in brain)
local accuracy_cache = {
ollama = { correct = 0, total = 0 },
copilot = { correct = 0, total = 0 },
}
--- Get the brain module safely
---@return table|nil
local function get_brain()
local ok, brain = pcall(require, "codetyper.brain")
if ok and brain.is_initialized and brain.is_initialized() then
return brain
end
return nil
end
--- Load accuracy stats from brain
local function load_accuracy_stats()
local brain = get_brain()
if not brain then
return
end
-- Query for accuracy tracking nodes
pcall(function()
local result = brain.query({
query = "provider_accuracy_stats",
types = { "metric" },
limit = 1,
})
if result and result.nodes and #result.nodes > 0 then
local node = result.nodes[1]
if node.c and node.c.d then
local ok, stats = pcall(vim.json.decode, node.c.d)
if ok and stats then
accuracy_cache = stats
end
end
end
end)
end
--- Save accuracy stats to brain
local function save_accuracy_stats()
local brain = get_brain()
if not brain then
return
end
pcall(function()
brain.learn({
type = "metric",
summary = "provider_accuracy_stats",
detail = vim.json.encode(accuracy_cache),
weight = 1.0,
})
end)
end
--- Calculate Ollama confidence based on historical accuracy
---@return number confidence (0-1)
local function get_ollama_historical_confidence()
local stats = accuracy_cache.ollama
if stats.total < 5 then
-- Not enough data, return neutral confidence
return 0.5
end
local accuracy = stats.correct / stats.total
-- Boost confidence if accuracy is high
return math.min(1.0, accuracy * 1.2)
end
--- Query brain for relevant context
---@param prompt string User prompt
---@param file_path string|nil Current file path
---@return table result {memories: table[], relevance: number, count: number}
local function query_brain_context(prompt, file_path)
local result = {
memories = {},
relevance = 0,
count = 0,
}
local brain = get_brain()
if not brain then
return result
end
-- Query brain with multiple dimensions
local ok, query_result = pcall(function()
return brain.query({
query = prompt,
file = file_path,
limit = 10,
types = { "pattern", "correction", "convention", "fact" },
})
end)
if not ok or not query_result then
return result
end
result.memories = query_result.nodes or {}
result.count = #result.memories
-- Calculate average relevance
if result.count > 0 then
local total_relevance = 0
for _, node in ipairs(result.memories) do
-- Use node weight and success rate as relevance indicators
local node_relevance = (node.sc and node.sc.w or 0.5) * (node.sc and node.sc.sr or 0.5)
total_relevance = total_relevance + node_relevance
end
result.relevance = total_relevance / result.count
end
return result
end
--- Select the best LLM provider based on context
---@param prompt string User prompt
---@param context table LLM context
---@return SelectionResult
function M.select_provider(prompt, context)
-- Load accuracy stats on first call
if accuracy_cache.ollama.total == 0 then
load_accuracy_stats()
end
local file_path = context.file_path
-- Query brain for relevant memories
local brain_context = query_brain_context(prompt, file_path)
-- Calculate base confidence from memories
local memory_confidence = 0
if brain_context.count >= MIN_MEMORIES_FOR_LOCAL then
memory_confidence = math.min(1.0, brain_context.count / 10) * brain_context.relevance
end
-- Factor in historical Ollama accuracy
local historical_confidence = get_ollama_historical_confidence()
-- Combined confidence score
local combined_confidence = (memory_confidence * 0.6) + (historical_confidence * 0.4)
-- Decision logic
local provider = "copilot" -- Default to more capable
local reason = ""
if brain_context.count >= MIN_MEMORIES_FOR_LOCAL and combined_confidence >= MIN_RELEVANCE_FOR_LOCAL then
provider = "ollama"
reason = string.format(
"Rich context: %d memories (%.1f%% relevance), historical accuracy: %.1f%%",
brain_context.count,
brain_context.relevance * 100,
historical_confidence * 100
)
elseif brain_context.count > 0 and combined_confidence >= 0.4 then
-- Medium confidence - use Ollama but with pondering
provider = "ollama"
reason = string.format(
"Moderate context: %d memories, will verify with pondering",
brain_context.count
)
else
reason = string.format(
"Insufficient context: %d memories (need %d), using capable provider",
brain_context.count,
MIN_MEMORIES_FOR_LOCAL
)
end
return {
provider = provider,
confidence = combined_confidence,
memory_count = brain_context.count,
reason = reason,
memories = brain_context.memories,
}
end
--- Check if we should ponder (verify) this Ollama response
---@param confidence number Current confidence level
---@return boolean
function M.should_ponder(confidence)
-- Always ponder when confidence is medium
if confidence >= 0.4 and confidence < 0.7 then
return true
end
-- Random sampling for high confidence to keep learning
if confidence >= 0.7 then
return math.random() < PONDER_SAMPLE_RATE
end
-- Low confidence shouldn't reach Ollama anyway
return false
end
--- Calculate agreement score between two responses
---@param response1 string First response
---@param response2 string Second response
---@return number Agreement score (0-1)
local function calculate_agreement(response1, response2)
-- Normalize responses
local norm1 = response1:lower():gsub("%s+", " "):gsub("[^%w%s]", "")
local norm2 = response2:lower():gsub("%s+", " "):gsub("[^%w%s]", "")
-- Extract words
local words1 = {}
for word in norm1:gmatch("%w+") do
words1[word] = (words1[word] or 0) + 1
end
local words2 = {}
for word in norm2:gmatch("%w+") do
words2[word] = (words2[word] or 0) + 1
end
-- Calculate Jaccard similarity
local intersection = 0
local union = 0
for word, count1 in pairs(words1) do
local count2 = words2[word] or 0
intersection = intersection + math.min(count1, count2)
union = union + math.max(count1, count2)
end
for word, count2 in pairs(words2) do
if not words1[word] then
union = union + count2
end
end
if union == 0 then
return 1.0 -- Both empty
end
-- Also check structural similarity (code structure)
local struct_score = 0
local function_count1 = select(2, response1:gsub("function", ""))
local function_count2 = select(2, response2:gsub("function", ""))
if function_count1 > 0 or function_count2 > 0 then
struct_score = 1 - math.abs(function_count1 - function_count2) / math.max(function_count1, function_count2, 1)
else
struct_score = 1.0
end
-- Combined score
local jaccard = intersection / union
return (jaccard * 0.7) + (struct_score * 0.3)
end
--- Ponder (verify) Ollama's response with another LLM
---@param prompt string Original prompt
---@param context table LLM context
---@param ollama_response string Ollama's response
---@param callback fun(result: PonderResult) Callback with pondering result
function M.ponder(prompt, context, ollama_response, callback)
-- Use Copilot as verifier
local copilot = require("codetyper.llm.copilot")
-- Build verification prompt
local verify_prompt = prompt
copilot.generate(verify_prompt, context, function(verifier_response, error)
if error or not verifier_response then
-- Verification failed, assume Ollama is correct
callback({
ollama_response = ollama_response,
verifier_response = "",
agreement_score = 1.0,
ollama_correct = true,
feedback = "Verification unavailable, trusting Ollama",
})
return
end
-- Calculate agreement
local agreement = calculate_agreement(ollama_response, verifier_response)
-- Determine if Ollama was correct
local ollama_correct = agreement >= AGREEMENT_THRESHOLD
-- Generate feedback
local feedback
if ollama_correct then
feedback = string.format("Agreement: %.1f%% - Ollama response validated", agreement * 100)
else
feedback = string.format(
"Disagreement: %.1f%% - Ollama may need correction",
(1 - agreement) * 100
)
end
-- Update accuracy tracking
accuracy_cache.ollama.total = accuracy_cache.ollama.total + 1
if ollama_correct then
accuracy_cache.ollama.correct = accuracy_cache.ollama.correct + 1
end
save_accuracy_stats()
-- Learn from this verification
local brain = get_brain()
if brain then
pcall(function()
if ollama_correct then
-- Reinforce the pattern
brain.learn({
type = "correction",
summary = "Ollama verified correct",
detail = string.format(
"Prompt: %s\nAgreement: %.1f%%",
prompt:sub(1, 100),
agreement * 100
),
weight = 0.8,
file = context.file_path,
})
else
-- Learn the correction
brain.learn({
type = "correction",
summary = "Ollama needed correction",
detail = string.format(
"Prompt: %s\nOllama: %s\nCorrect: %s",
prompt:sub(1, 100),
ollama_response:sub(1, 200),
verifier_response:sub(1, 200)
),
weight = 0.9,
file = context.file_path,
})
end
end)
end
callback({
ollama_response = ollama_response,
verifier_response = verifier_response,
agreement_score = agreement,
ollama_correct = ollama_correct,
feedback = feedback,
})
end)
end
--- Smart generate with automatic provider selection and pondering
---@param prompt string User prompt
---@param context table LLM context
---@param callback fun(response: string|nil, error: string|nil, metadata: table|nil) Callback
function M.smart_generate(prompt, context, callback)
-- Select provider
local selection = M.select_provider(prompt, context)
-- Log selection
pcall(function()
local logs = require("codetyper.agent.logs")
logs.add({
type = "info",
message = string.format(
"LLM: %s (confidence: %.1f%%, %s)",
selection.provider,
selection.confidence * 100,
selection.reason
),
})
end)
-- Get the selected client
local client
if selection.provider == "ollama" then
client = require("codetyper.llm.ollama")
else
client = require("codetyper.llm.copilot")
end
-- Generate response
client.generate(prompt, context, function(response, error)
if error then
-- Fallback on error
if selection.provider == "ollama" then
-- Try Copilot as fallback
local copilot = require("codetyper.llm.copilot")
copilot.generate(prompt, context, function(fallback_response, fallback_error)
callback(fallback_response, fallback_error, {
provider = "copilot",
fallback = true,
original_provider = "ollama",
original_error = error,
})
end)
return
end
callback(nil, error, { provider = selection.provider })
return
end
-- Check if we should ponder
if selection.provider == "ollama" and M.should_ponder(selection.confidence) then
M.ponder(prompt, context, response, function(ponder_result)
if ponder_result.ollama_correct then
-- Ollama was correct, use its response
callback(response, nil, {
provider = "ollama",
pondered = true,
agreement = ponder_result.agreement_score,
confidence = selection.confidence,
})
else
-- Use verifier's response instead
callback(ponder_result.verifier_response, nil, {
provider = "copilot",
pondered = true,
agreement = ponder_result.agreement_score,
original_provider = "ollama",
corrected = true,
})
end
end)
else
-- No pondering needed
callback(response, nil, {
provider = selection.provider,
pondered = false,
confidence = selection.confidence,
})
end
end)
end
--- Get current accuracy statistics
---@return table {ollama: {correct, total, accuracy}, copilot: {correct, total, accuracy}}
function M.get_accuracy_stats()
local stats = {
ollama = {
correct = accuracy_cache.ollama.correct,
total = accuracy_cache.ollama.total,
accuracy = accuracy_cache.ollama.total > 0
and (accuracy_cache.ollama.correct / accuracy_cache.ollama.total)
or 0,
},
copilot = {
correct = accuracy_cache.copilot.correct,
total = accuracy_cache.copilot.total,
accuracy = accuracy_cache.copilot.total > 0
and (accuracy_cache.copilot.correct / accuracy_cache.copilot.total)
or 0,
},
}
return stats
end
--- Reset accuracy statistics
function M.reset_accuracy_stats()
accuracy_cache = {
ollama = { correct = 0, total = 0 },
copilot = { correct = 0, total = 0 },
}
save_accuracy_stats()
end
--- Report user feedback on response quality
---@param provider string Which provider generated the response
---@param was_correct boolean Whether the response was good
function M.report_feedback(provider, was_correct)
if accuracy_cache[provider] then
accuracy_cache[provider].total = accuracy_cache[provider].total + 1
if was_correct then
accuracy_cache[provider].correct = accuracy_cache[provider].correct + 1
end
save_accuracy_stats()
end
end
return M

View File

@@ -13,24 +13,23 @@ M.system =
You have access to these tools - USE THEM to accomplish tasks:
### File Operations
- **read_file**: Read any file. ALWAYS read files before modifying them.
- **write_file**: Create new files or completely replace existing ones. Use for new files.
- **edit_file**: Make precise edits to existing files using find/replace. The "find" must match EXACTLY.
- **delete_file**: Delete files (requires user approval). Include a reason.
- **list_directory**: Explore project structure. See what files exist.
- **search_files**: Find files by pattern or content.
- **view**: Read any file. ALWAYS read files before modifying them. Parameters: path (string)
- **write**: Create new files or completely replace existing ones. Use for new files. Parameters: path (string), content (string)
- **edit**: Make precise edits to existing files using search/replace. Parameters: path (string), old_string (string), new_string (string)
- **glob**: Find files by pattern (e.g., "**/*.lua"). Parameters: pattern (string), path (optional)
- **grep**: Search file contents with regex. Parameters: pattern (string), path (optional)
### Shell Commands
- **bash**: Run shell commands (git, npm, make, etc.). User approves each command.
- **bash**: Run shell commands (git, npm, make, etc.). User approves each command. Parameters: command (string)
## HOW TO WORK
1. **UNDERSTAND FIRST**: Use read_file, list_directory, or search_files to understand the codebase before making changes.
1. **UNDERSTAND FIRST**: Use view, glob, or grep to understand the codebase before making changes.
2. **MAKE CHANGES**: Use write_file for new files, edit_file for modifications.
- For edit_file: The "find" parameter must match file content EXACTLY (including whitespace)
- Include enough context in "find" to be unique
- For write_file: Provide complete file content
2. **MAKE CHANGES**: Use write for new files, edit for modifications.
- For edit: The "old_string" parameter must match file content EXACTLY (including whitespace)
- Include enough context in "old_string" to be unique
- For write: Provide complete file content
3. **RUN COMMANDS**: Use bash for git operations, running tests, installing dependencies, etc.
@@ -41,16 +40,16 @@ You have access to these tools - USE THEM to accomplish tasks:
User: "Create a new React component for a login form"
Your approach:
1. Use list_directory to see project structure
2. Use read_file to check existing component patterns
3. Use write_file to create the new component file
4. Use write_file to create a test file if appropriate
1. Use glob to see project structure (glob pattern="**/*.tsx")
2. Use view to check existing component patterns
3. Use write to create the new component file
4. Use write to create a test file if appropriate
5. Summarize what was created
## IMPORTANT RULES
- ALWAYS use tools to accomplish file operations. Don't just describe what to do - DO IT.
- Read files before editing to ensure your "find" string matches exactly.
- Read files before editing to ensure your "old_string" matches exactly.
- When creating files, write complete, working code.
- When editing, preserve existing code style and conventions.
- If a file path is provided, use it. If not, infer from context.
@@ -68,13 +67,12 @@ M.tool_instructions = [[
## TOOL USAGE
When you need to perform an action, call the appropriate tool. You can call tools to:
- Read files to understand code
- Create new files with write_file
- Modify existing files with edit_file (read first!)
- Delete files with delete_file
- List directories to explore structure
- Search for files by name or content
- Run shell commands with bash
- Read files with view (parameters: path)
- Create new files with write (parameters: path, content)
- Modify existing files with edit (parameters: path, old_string, new_string) - read first!
- Find files by pattern with glob (parameters: pattern, path)
- Search file contents with grep (parameters: pattern, path)
- Run shell commands with bash (parameters: command)
After receiving a tool result, continue working:
- If more actions are needed, call another tool
@@ -82,11 +80,11 @@ After receiving a tool result, continue working:
## CRITICAL RULES
1. **Always read before editing**: Use read_file before edit_file to ensure exact matches
2. **Be precise with edits**: The "find" parameter must match the file content EXACTLY
3. **Create complete files**: When using write_file, provide fully working code
4. **User approval required**: File writes, edits, deletes, and bash commands need approval
5. **Don't guess**: If unsure about file structure, use list_directory or search_files
1. **Always read before editing**: Use view before edit to ensure exact matches
2. **Be precise with edits**: The "old_string" parameter must match the file content EXACTLY
3. **Create complete files**: When using write, provide fully working code
4. **User approval required**: File writes, edits, and bash commands need approval
5. **Don't guess**: If unsure about file structure, use glob or grep
]]
--- Prompt for when agent finishes

View File

@@ -0,0 +1,427 @@
--- Tests for agent tools system
describe("codetyper.agent.tools", function()
local tools
before_each(function()
tools = require("codetyper.agent.tools")
-- Clear any existing registrations
for name, _ in pairs(tools.get_all()) do
tools.unregister(name)
end
end)
describe("tool registration", function()
it("should register a tool", function()
local test_tool = {
name = "test_tool",
description = "A test tool",
params = {
{ name = "input", type = "string", description = "Test input" },
},
func = function(input, opts)
return "result", nil
end,
}
tools.register(test_tool)
local retrieved = tools.get("test_tool")
assert.is_not_nil(retrieved)
assert.equals("test_tool", retrieved.name)
end)
it("should unregister a tool", function()
local test_tool = {
name = "temp_tool",
description = "Temporary",
func = function() end,
}
tools.register(test_tool)
assert.is_not_nil(tools.get("temp_tool"))
tools.unregister("temp_tool")
assert.is_nil(tools.get("temp_tool"))
end)
it("should list all tools", function()
tools.register({ name = "tool1", func = function() end })
tools.register({ name = "tool2", func = function() end })
tools.register({ name = "tool3", func = function() end })
local list = tools.list()
assert.equals(3, #list)
end)
it("should filter tools with predicate", function()
tools.register({ name = "safe_tool", requires_confirmation = false, func = function() end })
tools.register({ name = "dangerous_tool", requires_confirmation = true, func = function() end })
local safe_list = tools.list(function(t)
return not t.requires_confirmation
end)
assert.equals(1, #safe_list)
assert.equals("safe_tool", safe_list[1].name)
end)
end)
describe("tool execution", function()
it("should execute a tool and return result", function()
tools.register({
name = "adder",
params = {
{ name = "a", type = "number" },
{ name = "b", type = "number" },
},
func = function(input, opts)
return input.a + input.b, nil
end,
})
local result, err = tools.execute("adder", { a = 5, b = 3 }, {})
assert.is_nil(err)
assert.equals(8, result)
end)
it("should return error for unknown tool", function()
local result, err = tools.execute("nonexistent", {}, {})
assert.is_nil(result)
assert.truthy(err:match("Unknown tool"))
end)
it("should track execution history", function()
tools.clear_history()
tools.register({
name = "tracked_tool",
func = function()
return "done", nil
end,
})
tools.execute("tracked_tool", {}, {})
tools.execute("tracked_tool", {}, {})
local history = tools.get_history()
assert.equals(2, #history)
assert.equals("tracked_tool", history[1].tool)
assert.equals("completed", history[1].status)
end)
end)
describe("tool schemas", function()
it("should generate JSON schema for tools", function()
tools.register({
name = "schema_test",
description = "Test schema generation",
params = {
{ name = "required_param", type = "string", description = "A required param" },
{ name = "optional_param", type = "number", description = "Optional", optional = true },
},
returns = {
{ name = "result", type = "string" },
},
to_schema = require("codetyper.agent.tools.base").to_schema,
func = function() end,
})
local schemas = tools.get_schemas()
assert.equals(1, #schemas)
local schema = schemas[1]
assert.equals("function", schema.type)
assert.equals("schema_test", schema.function_def.name)
assert.is_not_nil(schema.function_def.parameters.properties.required_param)
assert.is_not_nil(schema.function_def.parameters.properties.optional_param)
end)
end)
describe("process_tool_call", function()
it("should process tool call with name and input", function()
tools.register({
name = "processor_test",
func = function(input, opts)
return "processed: " .. input.value, nil
end,
})
local result, err = tools.process_tool_call({
name = "processor_test",
input = { value = "test" },
}, {})
assert.is_nil(err)
assert.equals("processed: test", result)
end)
it("should parse JSON string arguments", function()
tools.register({
name = "json_parser_test",
func = function(input, opts)
return input.key, nil
end,
})
local result, err = tools.process_tool_call({
name = "json_parser_test",
arguments = '{"key": "value"}',
}, {})
assert.is_nil(err)
assert.equals("value", result)
end)
end)
end)
describe("codetyper.agent.tools.base", function()
local base
before_each(function()
base = require("codetyper.agent.tools.base")
end)
describe("validate_input", function()
it("should validate required parameters", function()
local tool = setmetatable({
params = {
{ name = "required", type = "string" },
{ name = "optional", type = "string", optional = true },
},
}, base)
local valid, err = tool:validate_input({ required = "value" })
assert.is_true(valid)
assert.is_nil(err)
end)
it("should fail on missing required parameter", function()
local tool = setmetatable({
params = {
{ name = "required", type = "string" },
},
}, base)
local valid, err = tool:validate_input({})
assert.is_false(valid)
assert.truthy(err:match("Missing required parameter"))
end)
it("should validate parameter types", function()
local tool = setmetatable({
params = {
{ name = "num", type = "number" },
},
}, base)
local valid1, _ = tool:validate_input({ num = 42 })
assert.is_true(valid1)
local valid2, err2 = tool:validate_input({ num = "not a number" })
assert.is_false(valid2)
assert.truthy(err2:match("must be number"))
end)
it("should validate integer type", function()
local tool = setmetatable({
params = {
{ name = "int", type = "integer" },
},
}, base)
local valid1, _ = tool:validate_input({ int = 42 })
assert.is_true(valid1)
local valid2, err2 = tool:validate_input({ int = 42.5 })
assert.is_false(valid2)
assert.truthy(err2:match("must be an integer"))
end)
end)
describe("get_description", function()
it("should return string description", function()
local tool = setmetatable({
description = "Static description",
}, base)
assert.equals("Static description", tool:get_description())
end)
it("should call function description", function()
local tool = setmetatable({
description = function()
return "Dynamic description"
end,
}, base)
assert.equals("Dynamic description", tool:get_description())
end)
end)
describe("to_schema", function()
it("should generate valid schema", function()
local tool = setmetatable({
name = "test",
description = "Test tool",
params = {
{ name = "input", type = "string", description = "Input value" },
{ name = "count", type = "integer", description = "Count", optional = true },
},
}, base)
local schema = tool:to_schema()
assert.equals("function", schema.type)
assert.equals("test", schema.function_def.name)
assert.equals("Test tool", schema.function_def.description)
assert.equals("object", schema.function_def.parameters.type)
assert.is_not_nil(schema.function_def.parameters.properties.input)
assert.is_not_nil(schema.function_def.parameters.properties.count)
assert.same({ "input" }, schema.function_def.parameters.required)
end)
end)
end)
describe("built-in tools", function()
describe("view tool", function()
local view
before_each(function()
view = require("codetyper.agent.tools.view")
end)
it("should have required fields", function()
assert.equals("view", view.name)
assert.is_string(view.description)
assert.is_table(view.params)
assert.is_function(view.func)
end)
it("should require path parameter", function()
local result, err = view.func({}, {})
assert.is_nil(result)
assert.truthy(err:match("path is required"))
end)
end)
describe("grep tool", function()
local grep
before_each(function()
grep = require("codetyper.agent.tools.grep")
end)
it("should have required fields", function()
assert.equals("grep", grep.name)
assert.is_string(grep.description)
assert.is_table(grep.params)
assert.is_function(grep.func)
end)
it("should require pattern parameter", function()
local result, err = grep.func({}, {})
assert.is_nil(result)
assert.truthy(err:match("pattern is required"))
end)
end)
describe("glob tool", function()
local glob
before_each(function()
glob = require("codetyper.agent.tools.glob")
end)
it("should have required fields", function()
assert.equals("glob", glob.name)
assert.is_string(glob.description)
assert.is_table(glob.params)
assert.is_function(glob.func)
end)
it("should require pattern parameter", function()
local result, err = glob.func({}, {})
assert.is_nil(result)
assert.truthy(err:match("pattern is required"))
end)
end)
describe("edit tool", function()
local edit
before_each(function()
edit = require("codetyper.agent.tools.edit")
end)
it("should have required fields", function()
assert.equals("edit", edit.name)
assert.is_string(edit.description)
assert.is_table(edit.params)
assert.is_function(edit.func)
end)
it("should require path parameter", function()
local result, err = edit.func({}, {})
assert.is_nil(result)
assert.truthy(err:match("path is required"))
end)
it("should require old_string parameter", function()
local result, err = edit.func({ path = "/tmp/test" }, {})
assert.is_nil(result)
assert.truthy(err:match("old_string is required"))
end)
end)
describe("write tool", function()
local write
before_each(function()
write = require("codetyper.agent.tools.write")
end)
it("should have required fields", function()
assert.equals("write", write.name)
assert.is_string(write.description)
assert.is_table(write.params)
assert.is_function(write.func)
end)
it("should require path parameter", function()
local result, err = write.func({}, {})
assert.is_nil(result)
assert.truthy(err:match("path is required"))
end)
it("should require content parameter", function()
local result, err = write.func({ path = "/tmp/test" }, {})
assert.is_nil(result)
assert.truthy(err:match("content is required"))
end)
end)
describe("bash tool", function()
local bash
before_each(function()
bash = require("codetyper.agent.tools.bash")
end)
it("should have required fields", function()
assert.equals("bash", bash.name)
assert.is_function(bash.func)
end)
it("should require command parameter", function()
local result, err = bash.func({}, {})
assert.is_nil(result)
assert.truthy(err:match("command is required"))
end)
it("should require confirmation by default", function()
assert.is_true(bash.requires_confirmation)
end)
end)
end)

312
tests/spec/agentic_spec.lua Normal file
View File

@@ -0,0 +1,312 @@
---@diagnostic disable: undefined-global
-- Unit tests for the agentic system
describe("agentic module", function()
local agentic
before_each(function()
-- Reset and reload
package.loaded["codetyper.agent.agentic"] = nil
agentic = require("codetyper.agent.agentic")
end)
it("should list built-in agents", function()
local agents = agentic.list_agents()
assert.is_table(agents)
assert.is_true(#agents >= 3) -- coder, planner, explorer
local names = {}
for _, agent in ipairs(agents) do
names[agent.name] = true
end
assert.is_true(names["coder"])
assert.is_true(names["planner"])
assert.is_true(names["explorer"])
end)
it("should have description for each agent", function()
local agents = agentic.list_agents()
for _, agent in ipairs(agents) do
assert.is_string(agent.description)
assert.is_true(#agent.description > 0)
end
end)
it("should mark built-in agents as builtin", function()
local agents = agentic.list_agents()
local coder = nil
for _, agent in ipairs(agents) do
if agent.name == "coder" then
coder = agent
break
end
end
assert.is_not_nil(coder)
assert.is_true(coder.builtin)
end)
it("should have init function to create directories", function()
assert.is_function(agentic.init)
assert.is_function(agentic.init_agents_dir)
assert.is_function(agentic.init_rules_dir)
end)
it("should have run function for executing tasks", function()
assert.is_function(agentic.run)
end)
end)
describe("tools format conversion", function()
local tools_module
before_each(function()
package.loaded["codetyper.agent.tools"] = nil
tools_module = require("codetyper.agent.tools")
-- Load tools
if tools_module.load_builtins then
pcall(tools_module.load_builtins)
end
end)
it("should have to_openai_format function", function()
assert.is_function(tools_module.to_openai_format)
end)
it("should have to_claude_format function", function()
assert.is_function(tools_module.to_claude_format)
end)
it("should convert tools to OpenAI format", function()
local openai_tools = tools_module.to_openai_format()
assert.is_table(openai_tools)
-- If tools are loaded, check format
if #openai_tools > 0 then
local first_tool = openai_tools[1]
assert.equals("function", first_tool.type)
assert.is_table(first_tool["function"])
assert.is_string(first_tool["function"].name)
end
end)
it("should convert tools to Claude format", function()
local claude_tools = tools_module.to_claude_format()
assert.is_table(claude_tools)
-- If tools are loaded, check format
if #claude_tools > 0 then
local first_tool = claude_tools[1]
assert.is_string(first_tool.name)
assert.is_table(first_tool.input_schema)
end
end)
end)
describe("edit tool", function()
local edit_tool
before_each(function()
package.loaded["codetyper.agent.tools.edit"] = nil
edit_tool = require("codetyper.agent.tools.edit")
end)
it("should have name 'edit'", function()
assert.equals("edit", edit_tool.name)
end)
it("should have description mentioning matching strategies", function()
local desc = edit_tool:get_description()
assert.is_string(desc)
-- Should mention the matching capabilities
assert.is_true(desc:lower():match("match") ~= nil or desc:lower():match("replac") ~= nil)
end)
it("should have params defined", function()
assert.is_table(edit_tool.params)
assert.is_true(#edit_tool.params >= 3) -- path, old_string, new_string
end)
it("should require path parameter", function()
local valid, err = edit_tool:validate_input({
old_string = "test",
new_string = "test2",
})
assert.is_false(valid)
assert.is_string(err)
end)
it("should require old_string parameter", function()
local valid, err = edit_tool:validate_input({
path = "/test",
new_string = "test",
})
assert.is_false(valid)
end)
it("should require new_string parameter", function()
local valid, err = edit_tool:validate_input({
path = "/test",
old_string = "test",
})
assert.is_false(valid)
end)
it("should accept empty old_string for new file creation", function()
local valid, err = edit_tool:validate_input({
path = "/test/new_file.lua",
old_string = "",
new_string = "new content",
})
assert.is_true(valid)
assert.is_nil(err)
end)
it("should have func implementation", function()
assert.is_function(edit_tool.func)
end)
end)
describe("view tool", function()
local view_tool
before_each(function()
package.loaded["codetyper.agent.tools.view"] = nil
view_tool = require("codetyper.agent.tools.view")
end)
it("should have name 'view'", function()
assert.equals("view", view_tool.name)
end)
it("should require path parameter", function()
local valid, err = view_tool:validate_input({})
assert.is_false(valid)
end)
it("should accept valid path", function()
local valid, err = view_tool:validate_input({
path = "/test/file.lua",
})
assert.is_true(valid)
end)
end)
describe("write tool", function()
local write_tool
before_each(function()
package.loaded["codetyper.agent.tools.write"] = nil
write_tool = require("codetyper.agent.tools.write")
end)
it("should have name 'write'", function()
assert.equals("write", write_tool.name)
end)
it("should require path and content parameters", function()
local valid, err = write_tool:validate_input({})
assert.is_false(valid)
valid, err = write_tool:validate_input({ path = "/test" })
assert.is_false(valid)
end)
it("should accept valid input", function()
local valid, err = write_tool:validate_input({
path = "/test/file.lua",
content = "test content",
})
assert.is_true(valid)
end)
end)
describe("grep tool", function()
local grep_tool
before_each(function()
package.loaded["codetyper.agent.tools.grep"] = nil
grep_tool = require("codetyper.agent.tools.grep")
end)
it("should have name 'grep'", function()
assert.equals("grep", grep_tool.name)
end)
it("should require pattern parameter", function()
local valid, err = grep_tool:validate_input({})
assert.is_false(valid)
end)
it("should accept valid pattern", function()
local valid, err = grep_tool:validate_input({
pattern = "function.*test",
})
assert.is_true(valid)
end)
end)
describe("glob tool", function()
local glob_tool
before_each(function()
package.loaded["codetyper.agent.tools.glob"] = nil
glob_tool = require("codetyper.agent.tools.glob")
end)
it("should have name 'glob'", function()
assert.equals("glob", glob_tool.name)
end)
it("should require pattern parameter", function()
local valid, err = glob_tool:validate_input({})
assert.is_false(valid)
end)
it("should accept valid pattern", function()
local valid, err = glob_tool:validate_input({
pattern = "**/*.lua",
})
assert.is_true(valid)
end)
end)
describe("base tool", function()
local Base
before_each(function()
package.loaded["codetyper.agent.tools.base"] = nil
Base = require("codetyper.agent.tools.base")
end)
it("should have validate_input method", function()
assert.is_function(Base.validate_input)
end)
it("should have to_schema method", function()
assert.is_function(Base.to_schema)
end)
it("should have get_description method", function()
assert.is_function(Base.get_description)
end)
it("should generate valid schema", function()
local test_tool = setmetatable({
name = "test",
description = "A test tool",
params = {
{ name = "arg1", type = "string", description = "First arg" },
{ name = "arg2", type = "number", description = "Second arg", optional = true },
},
}, Base)
local schema = test_tool:to_schema()
assert.equals("function", schema.type)
assert.equals("test", schema.function_def.name)
assert.is_table(schema.function_def.parameters.properties)
assert.is_table(schema.function_def.parameters.required)
assert.is_true(vim.tbl_contains(schema.function_def.parameters.required, "arg1"))
assert.is_false(vim.tbl_contains(schema.function_def.parameters.required, "arg2"))
end)
end)

View File

@@ -0,0 +1,153 @@
--- Tests for brain/learners pattern detection and extraction
describe("brain.learners", function()
local pattern_learner
before_each(function()
-- Clear module cache
package.loaded["codetyper.brain.learners.pattern"] = nil
package.loaded["codetyper.brain.types"] = nil
pattern_learner = require("codetyper.brain.learners.pattern")
end)
describe("pattern learner detection", function()
it("should detect code_completion events", function()
local event = { type = "code_completion", data = {} }
assert.is_true(pattern_learner.detect(event))
end)
it("should detect file_indexed events", function()
local event = { type = "file_indexed", data = {} }
assert.is_true(pattern_learner.detect(event))
end)
it("should detect code_analyzed events", function()
local event = { type = "code_analyzed", data = {} }
assert.is_true(pattern_learner.detect(event))
end)
it("should detect pattern_detected events", function()
local event = { type = "pattern_detected", data = {} }
assert.is_true(pattern_learner.detect(event))
end)
it("should NOT detect plain 'pattern' type events", function()
-- This was the bug - 'pattern' type was not in the valid_types list
local event = { type = "pattern", data = {} }
assert.is_false(pattern_learner.detect(event))
end)
it("should NOT detect unknown event types", function()
local event = { type = "unknown_type", data = {} }
assert.is_false(pattern_learner.detect(event))
end)
it("should NOT detect nil events", function()
assert.is_false(pattern_learner.detect(nil))
end)
it("should NOT detect events without type", function()
local event = { data = {} }
assert.is_false(pattern_learner.detect(event))
end)
end)
describe("pattern learner extraction", function()
it("should extract from pattern_detected events", function()
local event = {
type = "pattern_detected",
file = "/path/to/file.lua",
data = {
name = "Test pattern",
description = "Pattern description",
language = "lua",
symbols = { "func1", "func2" },
},
}
local extracted = pattern_learner.extract(event)
assert.is_not_nil(extracted)
assert.equals("Test pattern", extracted.summary)
assert.equals("Pattern description", extracted.detail)
assert.equals("lua", extracted.lang)
assert.equals("/path/to/file.lua", extracted.file)
end)
it("should handle pattern_detected with minimal data", function()
local event = {
type = "pattern_detected",
file = "/path/to/file.lua",
data = {
name = "Minimal pattern",
},
}
local extracted = pattern_learner.extract(event)
assert.is_not_nil(extracted)
assert.equals("Minimal pattern", extracted.summary)
assert.equals("Minimal pattern", extracted.detail)
end)
it("should extract from code_completion events", function()
local event = {
type = "code_completion",
file = "/path/to/file.lua",
data = {
intent = "add function",
code = "function test() end",
language = "lua",
},
}
local extracted = pattern_learner.extract(event)
assert.is_not_nil(extracted)
assert.is_true(extracted.summary:find("Code pattern") ~= nil)
assert.equals("function test() end", extracted.detail)
end)
end)
describe("should_learn validation", function()
it("should accept valid patterns", function()
local data = {
summary = "Valid pattern summary",
detail = "This is a detailed description of the pattern",
}
assert.is_true(pattern_learner.should_learn(data))
end)
it("should reject patterns without summary", function()
local data = {
summary = "",
detail = "Some detail",
}
assert.is_false(pattern_learner.should_learn(data))
end)
it("should reject patterns with nil summary", function()
local data = {
summary = nil,
detail = "Some detail",
}
assert.is_false(pattern_learner.should_learn(data))
end)
it("should reject patterns with very short detail", function()
local data = {
summary = "Valid summary",
detail = "short", -- Less than 10 chars
}
assert.is_false(pattern_learner.should_learn(data))
end)
it("should reject whitespace-only summaries", function()
local data = {
summary = " ",
detail = "Some valid detail here",
}
assert.is_false(pattern_learner.should_learn(data))
end)
end)
end)

View File

@@ -0,0 +1,194 @@
--- Tests for coder file context injection
describe("coder context injection", function()
local test_dir
local original_filereadable
before_each(function()
test_dir = "/tmp/codetyper_coder_test_" .. os.time()
vim.fn.mkdir(test_dir, "p")
-- Store original function
original_filereadable = vim.fn.filereadable
end)
after_each(function()
vim.fn.delete(test_dir, "rf")
vim.fn.filereadable = original_filereadable
end)
describe("get_coder_companion_path logic", function()
-- Test the path generation logic (simulating the function behavior)
local function get_coder_companion_path(target_path, file_exists_check)
if not target_path or target_path == "" then
return nil
end
-- Skip if target is already a coder file
if target_path:match("%.coder%.") then
return nil
end
local dir = vim.fn.fnamemodify(target_path, ":h")
local name = vim.fn.fnamemodify(target_path, ":t:r")
local ext = vim.fn.fnamemodify(target_path, ":e")
local coder_path = dir .. "/" .. name .. ".coder." .. ext
if file_exists_check(coder_path) then
return coder_path
end
return nil
end
it("should generate correct coder path for source file", function()
local target = "/path/to/file.ts"
local expected = "/path/to/file.coder.ts"
local path = get_coder_companion_path(target, function() return true end)
assert.equals(expected, path)
end)
it("should return nil for empty path", function()
local path = get_coder_companion_path("", function() return true end)
assert.is_nil(path)
end)
it("should return nil for nil path", function()
local path = get_coder_companion_path(nil, function() return true end)
assert.is_nil(path)
end)
it("should return nil for coder files (avoid recursion)", function()
local target = "/path/to/file.coder.ts"
local path = get_coder_companion_path(target, function() return true end)
assert.is_nil(path)
end)
it("should return nil if coder file doesn't exist", function()
local target = "/path/to/file.ts"
local path = get_coder_companion_path(target, function() return false end)
assert.is_nil(path)
end)
it("should handle files with multiple dots", function()
local target = "/path/to/my.component.ts"
local expected = "/path/to/my.component.coder.ts"
local path = get_coder_companion_path(target, function() return true end)
assert.equals(expected, path)
end)
it("should handle different extensions", function()
local test_cases = {
{ target = "/path/file.lua", expected = "/path/file.coder.lua" },
{ target = "/path/file.py", expected = "/path/file.coder.py" },
{ target = "/path/file.js", expected = "/path/file.coder.js" },
{ target = "/path/file.go", expected = "/path/file.coder.go" },
}
for _, tc in ipairs(test_cases) do
local path = get_coder_companion_path(tc.target, function() return true end)
assert.equals(tc.expected, path, "Failed for: " .. tc.target)
end
end)
end)
describe("coder content filtering", function()
-- Test the filtering logic that skips template-only content
local function has_meaningful_content(lines)
for _, line in ipairs(lines) do
local trimmed = line:gsub("^%s*", "")
if not trimmed:match("^[%-#/]+%s*Coder companion")
and not trimmed:match("^[%-#/]+%s*Use /@ @/")
and not trimmed:match("^[%-#/]+%s*Example:")
and not trimmed:match("^<!%-%-")
and trimmed ~= ""
and not trimmed:match("^[%-#/]+%s*$") then
return true
end
end
return false
end
it("should detect meaningful content", function()
local lines = {
"-- Coder companion for test.lua",
"-- This file handles authentication",
"/@",
"Add login function",
"@/",
}
assert.is_true(has_meaningful_content(lines))
end)
it("should reject template-only content", function()
-- Template lines are filtered by specific patterns
-- Only header comments that match the template format are filtered
local lines = {
"-- Coder companion for test.lua",
"-- Use /@ @/ tags to write pseudo-code prompts",
"-- Example:",
"--",
"",
}
assert.is_false(has_meaningful_content(lines))
end)
it("should detect pseudo-code content", function()
local lines = {
"-- Authentication module",
"",
"-- This module should:",
"-- 1. Validate user credentials",
"-- 2. Generate JWT tokens",
"-- 3. Handle session management",
}
-- "-- Authentication module" doesn't match template patterns
assert.is_true(has_meaningful_content(lines))
end)
it("should handle JavaScript style comments", function()
local lines = {
"// Coder companion for test.ts",
"// Business logic for user authentication",
"",
"// The auth flow should:",
"// 1. Check OAuth token",
"// 2. Validate permissions",
}
-- "// Business logic..." doesn't match template patterns
assert.is_true(has_meaningful_content(lines))
end)
it("should handle empty lines", function()
local lines = {
"",
"",
"",
}
assert.is_false(has_meaningful_content(lines))
end)
end)
describe("context format", function()
it("should format context with proper header", function()
local function format_coder_context(content, ext)
return string.format(
"\n\n--- Business Context / Pseudo-code ---\n" ..
"The following describes the intended behavior and design for this file:\n" ..
"```%s\n%s\n```",
ext,
content
)
end
local formatted = format_coder_context("-- Auth logic here", "lua")
assert.is_true(formatted:find("Business Context") ~= nil)
assert.is_true(formatted:find("```lua") ~= nil)
assert.is_true(formatted:find("Auth logic here") ~= nil)
end)
end)
end)

View File

@@ -0,0 +1,161 @@
--- Tests for coder file ignore logic
describe("coder file ignore logic", function()
-- Directories to ignore
local ignored_directories = {
".git",
".coder",
".claude",
".vscode",
".idea",
"node_modules",
"vendor",
"dist",
"build",
"target",
"__pycache__",
".cache",
".npm",
".yarn",
"coverage",
".next",
".nuxt",
".svelte-kit",
"out",
"bin",
"obj",
}
-- Files to ignore
local ignored_files = {
".gitignore",
".gitattributes",
"package-lock.json",
"yarn.lock",
".env",
".eslintrc",
"tsconfig.json",
"README.md",
"LICENSE",
"Makefile",
}
local function is_in_ignored_directory(filepath)
for _, dir in ipairs(ignored_directories) do
if filepath:match("/" .. dir .. "/") or filepath:match("/" .. dir .. "$") then
return true
end
if filepath:match("^" .. dir .. "/") then
return true
end
end
return false
end
local function should_ignore_for_coder(filepath)
local filename = vim.fn.fnamemodify(filepath, ":t")
for _, ignored in ipairs(ignored_files) do
if filename == ignored then
return true
end
end
if filename:match("^%.") then
return true
end
if is_in_ignored_directory(filepath) then
return true
end
return false
end
describe("ignored directories", function()
it("should ignore files in node_modules", function()
assert.is_true(should_ignore_for_coder("/project/node_modules/lodash/index.js"))
assert.is_true(should_ignore_for_coder("/project/node_modules/react/index.js"))
end)
it("should ignore files in .git", function()
assert.is_true(should_ignore_for_coder("/project/.git/config"))
assert.is_true(should_ignore_for_coder("/project/.git/hooks/pre-commit"))
end)
it("should ignore files in .coder", function()
assert.is_true(should_ignore_for_coder("/project/.coder/brain/meta.json"))
end)
it("should ignore files in vendor", function()
assert.is_true(should_ignore_for_coder("/project/vendor/autoload.php"))
end)
it("should ignore files in dist/build", function()
assert.is_true(should_ignore_for_coder("/project/dist/bundle.js"))
assert.is_true(should_ignore_for_coder("/project/build/output.js"))
end)
it("should ignore files in __pycache__", function()
assert.is_true(should_ignore_for_coder("/project/__pycache__/module.cpython-39.pyc"))
end)
it("should NOT ignore regular source files", function()
assert.is_false(should_ignore_for_coder("/project/src/index.ts"))
assert.is_false(should_ignore_for_coder("/project/lib/utils.lua"))
assert.is_false(should_ignore_for_coder("/project/app/main.py"))
end)
end)
describe("ignored files", function()
it("should ignore .gitignore", function()
assert.is_true(should_ignore_for_coder("/project/.gitignore"))
end)
it("should ignore lock files", function()
assert.is_true(should_ignore_for_coder("/project/package-lock.json"))
assert.is_true(should_ignore_for_coder("/project/yarn.lock"))
end)
it("should ignore config files", function()
assert.is_true(should_ignore_for_coder("/project/tsconfig.json"))
assert.is_true(should_ignore_for_coder("/project/.eslintrc"))
end)
it("should ignore .env files", function()
assert.is_true(should_ignore_for_coder("/project/.env"))
end)
it("should ignore README and LICENSE", function()
assert.is_true(should_ignore_for_coder("/project/README.md"))
assert.is_true(should_ignore_for_coder("/project/LICENSE"))
end)
it("should ignore hidden/dot files", function()
assert.is_true(should_ignore_for_coder("/project/.hidden"))
assert.is_true(should_ignore_for_coder("/project/.secret"))
end)
it("should NOT ignore regular source files", function()
assert.is_false(should_ignore_for_coder("/project/src/app.ts"))
assert.is_false(should_ignore_for_coder("/project/components/Button.tsx"))
assert.is_false(should_ignore_for_coder("/project/utils/helpers.js"))
end)
end)
describe("edge cases", function()
it("should handle nested node_modules", function()
assert.is_true(should_ignore_for_coder("/project/packages/core/node_modules/dep/index.js"))
end)
it("should handle files named like directories but not in them", function()
-- A file named "node_modules.md" in root should be ignored (starts with .)
-- But a file in a folder that contains "node" should NOT be ignored
assert.is_false(should_ignore_for_coder("/project/src/node_utils.ts"))
end)
it("should handle relative paths", function()
assert.is_true(should_ignore_for_coder("node_modules/lodash/index.js"))
assert.is_false(should_ignore_for_coder("src/index.ts"))
end)
end)
end)

371
tests/spec/inject_spec.lua Normal file
View File

@@ -0,0 +1,371 @@
--- Tests for smart code injection with import handling
describe("codetyper.agent.inject", function()
local inject
before_each(function()
inject = require("codetyper.agent.inject")
end)
describe("parse_code", function()
describe("JavaScript/TypeScript", function()
it("should detect ES6 named imports", function()
local code = [[import { useState, useEffect } from 'react';
import { Button } from './components';
function App() {
return <div>Hello</div>;
}]]
local result = inject.parse_code(code, "typescript")
assert.equals(2, #result.imports)
assert.truthy(result.imports[1]:match("useState"))
assert.truthy(result.imports[2]:match("Button"))
assert.truthy(#result.body > 0)
end)
it("should detect ES6 default imports", function()
local code = [[import React from 'react';
import axios from 'axios';
const api = axios.create();]]
local result = inject.parse_code(code, "javascript")
assert.equals(2, #result.imports)
assert.truthy(result.imports[1]:match("React"))
assert.truthy(result.imports[2]:match("axios"))
end)
it("should detect require imports", function()
local code = [[const fs = require('fs');
const path = require('path');
module.exports = { fs, path };]]
local result = inject.parse_code(code, "javascript")
assert.equals(2, #result.imports)
assert.truthy(result.imports[1]:match("fs"))
assert.truthy(result.imports[2]:match("path"))
end)
it("should detect multi-line imports", function()
local code = [[import {
useState,
useEffect,
useCallback
} from 'react';
function Component() {}]]
local result = inject.parse_code(code, "typescript")
assert.equals(1, #result.imports)
assert.truthy(result.imports[1]:match("useState"))
assert.truthy(result.imports[1]:match("useCallback"))
end)
it("should detect namespace imports", function()
local code = [[import * as React from 'react';
export default React;]]
local result = inject.parse_code(code, "tsx")
assert.equals(1, #result.imports)
assert.truthy(result.imports[1]:match("%* as React"))
end)
end)
describe("Python", function()
it("should detect simple imports", function()
local code = [[import os
import sys
import json
def main():
pass]]
local result = inject.parse_code(code, "python")
assert.equals(3, #result.imports)
assert.truthy(result.imports[1]:match("import os"))
assert.truthy(result.imports[2]:match("import sys"))
assert.truthy(result.imports[3]:match("import json"))
end)
it("should detect from imports", function()
local code = [[from typing import List, Dict
from pathlib import Path
def process(items: List[str]) -> None:
pass]]
local result = inject.parse_code(code, "py")
assert.equals(2, #result.imports)
assert.truthy(result.imports[1]:match("from typing"))
assert.truthy(result.imports[2]:match("from pathlib"))
end)
end)
describe("Lua", function()
it("should detect require statements", function()
local code = [[local M = {}
local utils = require("codetyper.utils")
local config = require('codetyper.config')
function M.setup()
end
return M]]
local result = inject.parse_code(code, "lua")
assert.equals(2, #result.imports)
assert.truthy(result.imports[1]:match("utils"))
assert.truthy(result.imports[2]:match("config"))
end)
end)
describe("Go", function()
it("should detect single imports", function()
local code = [[package main
import "fmt"
func main() {
fmt.Println("Hello")
}]]
local result = inject.parse_code(code, "go")
assert.equals(1, #result.imports)
assert.truthy(result.imports[1]:match('import "fmt"'))
end)
it("should detect grouped imports", function()
local code = [[package main
import (
"fmt"
"os"
"strings"
)
func main() {}]]
local result = inject.parse_code(code, "go")
assert.equals(1, #result.imports)
assert.truthy(result.imports[1]:match("fmt"))
assert.truthy(result.imports[1]:match("os"))
end)
end)
describe("Rust", function()
it("should detect use statements", function()
local code = [[use std::io;
use std::collections::HashMap;
fn main() {
let map = HashMap::new();
}]]
local result = inject.parse_code(code, "rs")
assert.equals(2, #result.imports)
assert.truthy(result.imports[1]:match("std::io"))
assert.truthy(result.imports[2]:match("HashMap"))
end)
end)
describe("C/C++", function()
it("should detect include statements", function()
local code = [[#include <stdio.h>
#include "myheader.h"
int main() {
return 0;
}]]
local result = inject.parse_code(code, "c")
assert.equals(2, #result.imports)
assert.truthy(result.imports[1]:match("stdio"))
assert.truthy(result.imports[2]:match("myheader"))
end)
end)
end)
describe("merge_imports", function()
it("should merge without duplicates", function()
local existing = {
"import { useState } from 'react';",
"import { Button } from './components';",
}
local new_imports = {
"import { useEffect } from 'react';",
"import { useState } from 'react';", -- duplicate
"import { Card } from './components';",
}
local merged = inject.merge_imports(existing, new_imports)
assert.equals(4, #merged) -- Should not have duplicate useState
end)
it("should handle empty existing imports", function()
local existing = {}
local new_imports = {
"import os",
"import sys",
}
local merged = inject.merge_imports(existing, new_imports)
assert.equals(2, #merged)
end)
it("should handle empty new imports", function()
local existing = {
"import os",
"import sys",
}
local new_imports = {}
local merged = inject.merge_imports(existing, new_imports)
assert.equals(2, #merged)
end)
it("should handle whitespace variations in duplicates", function()
local existing = {
"import { useState } from 'react';",
}
local new_imports = {
"import {useState} from 'react';", -- Same but different spacing
}
local merged = inject.merge_imports(existing, new_imports)
assert.equals(1, #merged) -- Should detect as duplicate
end)
end)
describe("sort_imports", function()
it("should group imports by type for JavaScript", function()
local imports = {
"import React from 'react';",
"import { Button } from './components';",
"import axios from 'axios';",
"import path from 'path';",
}
local sorted = inject.sort_imports(imports, "javascript")
-- Check ordering: builtin -> third-party -> local
local found_builtin = false
local found_local = false
local builtin_pos = 0
local local_pos = 0
for i, imp in ipairs(sorted) do
if imp:match("path") then
found_builtin = true
builtin_pos = i
end
if imp:match("%.%/") then
found_local = true
local_pos = i
end
end
-- Local imports should come after third-party
if found_local and found_builtin then
assert.truthy(local_pos > builtin_pos)
end
end)
end)
describe("has_imports", function()
it("should return true when code has imports", function()
local code = [[import { useState } from 'react';
function App() {}]]
assert.is_true(inject.has_imports(code, "typescript"))
end)
it("should return false when code has no imports", function()
local code = [[function App() {
return <div>Hello</div>;
}]]
assert.is_false(inject.has_imports(code, "typescript"))
end)
it("should detect Python imports", function()
local code = [[from typing import List
def process(items: List[str]):
pass]]
assert.is_true(inject.has_imports(code, "python"))
end)
it("should detect Lua requires", function()
local code = [[local utils = require("utils")
local M = {}
return M]]
assert.is_true(inject.has_imports(code, "lua"))
end)
end)
describe("edge cases", function()
it("should handle empty code", function()
local result = inject.parse_code("", "javascript")
assert.equals(0, #result.imports)
assert.equals(1, #result.body) -- Empty string becomes one empty line
end)
it("should handle code with only imports", function()
local code = [[import React from 'react';
import { useState } from 'react';]]
local result = inject.parse_code(code, "javascript")
assert.equals(2, #result.imports)
assert.equals(0, #result.body)
end)
it("should handle code with only body", function()
local code = [[function hello() {
console.log("Hello");
}]]
local result = inject.parse_code(code, "javascript")
assert.equals(0, #result.imports)
assert.truthy(#result.body > 0)
end)
it("should handle imports in string literals (not detect as imports)", function()
local code = [[const example = "import { fake } from 'not-real';";
const config = { import: true };
function test() {}]]
local result = inject.parse_code(code, "javascript")
-- The first line looks like an import but is in a string
-- This is a known limitation - we accept some false positives
-- The important thing is we don't break the code
assert.truthy(#result.body >= 0)
end)
it("should handle mixed import styles in same file", function()
local code = [[import React from 'react';
const axios = require('axios');
import { useState } from 'react';
function App() {}]]
local result = inject.parse_code(code, "javascript")
assert.equals(3, #result.imports)
end)
end)
end)

View File

@@ -0,0 +1,174 @@
--- Tests for smart LLM selection with memory-based confidence
describe("codetyper.llm.selector", function()
local selector
before_each(function()
selector = require("codetyper.llm.selector")
-- Reset stats for clean tests
selector.reset_accuracy_stats()
end)
describe("select_provider", function()
it("should return copilot when no brain memories exist", function()
local result = selector.select_provider("write a function", {
file_path = "/test/file.lua",
})
assert.equals("copilot", result.provider)
assert.equals(0, result.memory_count)
assert.truthy(result.reason:match("Insufficient context"))
end)
it("should return a valid selection result structure", function()
local result = selector.select_provider("test prompt", {})
assert.is_string(result.provider)
assert.is_number(result.confidence)
assert.is_number(result.memory_count)
assert.is_string(result.reason)
end)
it("should have confidence between 0 and 1", function()
local result = selector.select_provider("test", {})
assert.truthy(result.confidence >= 0)
assert.truthy(result.confidence <= 1)
end)
end)
describe("should_ponder", function()
it("should return true for medium confidence", function()
assert.is_true(selector.should_ponder(0.5))
assert.is_true(selector.should_ponder(0.6))
end)
it("should return false for low confidence", function()
assert.is_false(selector.should_ponder(0.2))
assert.is_false(selector.should_ponder(0.3))
end)
-- High confidence pondering is probabilistic, so we test the range
it("should sometimes ponder for high confidence (sampling)", function()
-- Run multiple times to test probabilistic behavior
local pondered_count = 0
for _ = 1, 100 do
if selector.should_ponder(0.9) then
pondered_count = pondered_count + 1
end
end
-- Should ponder roughly 20% of the time (PONDER_SAMPLE_RATE = 0.2)
-- Allow range of 5-40% due to randomness
assert.truthy(pondered_count >= 5, "Should ponder at least sometimes")
assert.truthy(pondered_count <= 40, "Should not ponder too often")
end)
end)
describe("get_accuracy_stats", function()
it("should return initial empty stats", function()
local stats = selector.get_accuracy_stats()
assert.equals(0, stats.ollama.total)
assert.equals(0, stats.ollama.correct)
assert.equals(0, stats.ollama.accuracy)
assert.equals(0, stats.copilot.total)
assert.equals(0, stats.copilot.correct)
assert.equals(0, stats.copilot.accuracy)
end)
end)
describe("report_feedback", function()
it("should track positive feedback", function()
selector.report_feedback("ollama", true)
selector.report_feedback("ollama", true)
selector.report_feedback("ollama", false)
local stats = selector.get_accuracy_stats()
assert.equals(3, stats.ollama.total)
assert.equals(2, stats.ollama.correct)
end)
it("should track copilot feedback separately", function()
selector.report_feedback("ollama", true)
selector.report_feedback("copilot", true)
selector.report_feedback("copilot", false)
local stats = selector.get_accuracy_stats()
assert.equals(1, stats.ollama.total)
assert.equals(2, stats.copilot.total)
end)
it("should calculate accuracy correctly", function()
selector.report_feedback("ollama", true)
selector.report_feedback("ollama", true)
selector.report_feedback("ollama", true)
selector.report_feedback("ollama", false)
local stats = selector.get_accuracy_stats()
assert.equals(0.75, stats.ollama.accuracy)
end)
end)
describe("reset_accuracy_stats", function()
it("should clear all stats", function()
selector.report_feedback("ollama", true)
selector.report_feedback("copilot", true)
selector.reset_accuracy_stats()
local stats = selector.get_accuracy_stats()
assert.equals(0, stats.ollama.total)
assert.equals(0, stats.copilot.total)
end)
end)
end)
describe("agreement calculation", function()
-- Test the internal agreement calculation through pondering behavior
-- Since calculate_agreement is local, we test its effects indirectly
it("should detect high agreement for similar responses", function()
-- This is tested through the pondering system
-- When responses are similar, agreement should be high
local selector = require("codetyper.llm.selector")
-- Verify that should_ponder returns predictable results
-- for medium confidence (where pondering always happens)
assert.is_true(selector.should_ponder(0.5))
end)
end)
describe("provider selection with accuracy history", function()
local selector
before_each(function()
selector = require("codetyper.llm.selector")
selector.reset_accuracy_stats()
end)
it("should factor in historical accuracy for selection", function()
-- Simulate high Ollama accuracy
for _ = 1, 10 do
selector.report_feedback("ollama", true)
end
-- Even with no brain context, historical accuracy should influence confidence
local result = selector.select_provider("test", {})
-- Confidence should be higher due to historical accuracy
-- but provider might still be copilot if no memories
assert.is_number(result.confidence)
end)
it("should have lower confidence for low historical accuracy", function()
-- Simulate low Ollama accuracy
for _ = 1, 10 do
selector.report_feedback("ollama", false)
end
local result = selector.select_provider("test", {})
-- With bad history and no memories, should definitely use copilot
assert.equals("copilot", result.provider)
end)
end)