Adding more features
This commit is contained in:
54
CHANGELOG.md
54
CHANGELOG.md
@@ -7,6 +7,57 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.5.0] - 2026-01-15
|
||||
|
||||
### Added
|
||||
|
||||
- **Cost Tracking System** - Track LLM API costs across sessions
|
||||
- New `:CoderCost` command opens cost estimation floating window
|
||||
- Session costs tracked in real-time
|
||||
- All-time costs persisted in `.coder/cost_history.json`
|
||||
- Per-model breakdown with token counts
|
||||
- Pricing database for 50+ models (GPT-4/5, Claude, O-series, Gemini)
|
||||
- Window keymaps: `q` close, `r` refresh, `c` clear session, `C` clear all
|
||||
|
||||
- **Automatic Ollama Fallback** - Graceful degradation when API limits hit
|
||||
- Automatically switches to Ollama when Copilot rate limits exceeded
|
||||
- Detects local Ollama availability before fallback
|
||||
- Notifies user of provider switch
|
||||
|
||||
- **Enhanced Error Handling** - Better error messages for API failures
|
||||
- Shows actual API response on parse errors (not generic "failed to parse")
|
||||
- Improved rate limit detection and messaging
|
||||
- Sanitized newlines in error notifications to prevent UI crashes
|
||||
|
||||
- **Agent Tools System Improvements**
|
||||
- New `to_openai_format()` and `to_claude_format()` functions
|
||||
- `get_definitions()` for generic tool access
|
||||
- Fixed tool call argument serialization (JSON strings vs tables)
|
||||
|
||||
- **Credentials Management System** - Store API keys outside of config files
|
||||
- New `:CoderAddApiKey` command for interactive credential setup
|
||||
- `:CoderRemoveApiKey` to remove stored credentials
|
||||
- `:CoderCredentials` to view credential status
|
||||
- `:CoderSwitchProvider` to switch active LLM provider
|
||||
- Credentials stored in `~/.local/share/nvim/codetyper/configuration.json`
|
||||
- Priority: stored credentials > config > environment variables
|
||||
- Supports all providers: Claude, OpenAI, Gemini, Copilot, Ollama
|
||||
|
||||
### Changed
|
||||
|
||||
- Cost window now shows both session and all-time statistics
|
||||
- Improved agent prompt templates with correct tool names
|
||||
- Better error context in LLM provider responses
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed "Failed to parse Copilot response" error showing instead of actual error
|
||||
- Fixed `nvim_buf_set_lines` crash from newlines in error messages
|
||||
- Fixed `tools.definitions` nil error in agent initialization
|
||||
- Fixed tool name mismatch in agent prompts (write_file vs write)
|
||||
|
||||
---
|
||||
|
||||
## [0.4.0] - 2026-01-13
|
||||
|
||||
### Added
|
||||
@@ -194,7 +245,8 @@ scheduler = {
|
||||
- **Fixed** - Bug fixes
|
||||
- **Security** - Vulnerability fixes
|
||||
|
||||
[Unreleased]: https://github.com/cargdev/codetyper.nvim/compare/v0.4.0...HEAD
|
||||
[Unreleased]: https://github.com/cargdev/codetyper.nvim/compare/v0.5.0...HEAD
|
||||
[0.5.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.4.0...v0.5.0
|
||||
[0.4.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.3.0...v0.4.0
|
||||
[0.3.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.2.0...v0.3.0
|
||||
[0.2.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.1.0...v0.2.0
|
||||
|
||||
214
README.md
214
README.md
@@ -20,8 +20,10 @@
|
||||
- 🛡️ **Completion-Aware**: Safe injection that doesn't fight with autocomplete
|
||||
- 📁 **Auto-Index**: Automatically create coder companion files on file open
|
||||
- 📜 **Logs Panel**: Real-time visibility into LLM requests and token usage
|
||||
- 💰 **Cost Tracking**: Persistent LLM cost estimation with session and all-time stats
|
||||
- 🔒 **Git Integration**: Automatically adds `.coder.*` files to `.gitignore`
|
||||
- 🌳 **Project Tree Logging**: Maintains a `tree.log` tracking your project structure
|
||||
- 🧠 **Brain System**: Knowledge graph that learns from your coding patterns
|
||||
|
||||
---
|
||||
|
||||
@@ -34,6 +36,8 @@
|
||||
- [LLM Providers](#-llm-providers)
|
||||
- [Commands Reference](#-commands-reference)
|
||||
- [Usage Guide](#-usage-guide)
|
||||
- [Logs Panel](#-logs-panel)
|
||||
- [Cost Tracking](#-cost-tracking)
|
||||
- [Agent Mode](#-agent-mode)
|
||||
- [Keymaps](#-keymaps)
|
||||
- [Health Check](#-health-check)
|
||||
@@ -196,6 +200,32 @@ require("codetyper").setup({
|
||||
| `OPENAI_API_KEY` | OpenAI API key |
|
||||
| `GEMINI_API_KEY` | Google Gemini API key |
|
||||
|
||||
### Credentials Management
|
||||
|
||||
Instead of storing API keys in your config (which may be committed to git), you can use the credentials system:
|
||||
|
||||
```vim
|
||||
:CoderAddApiKey
|
||||
```
|
||||
|
||||
This command interactively prompts for:
|
||||
1. Provider selection (Claude, OpenAI, Gemini, Copilot, Ollama)
|
||||
2. API key (for cloud providers)
|
||||
3. Model name
|
||||
4. Custom endpoint (for OpenAI-compatible APIs)
|
||||
|
||||
Credentials are stored securely in `~/.local/share/nvim/codetyper/configuration.json` (not in your config files).
|
||||
|
||||
**Priority order for credentials:**
|
||||
1. Stored credentials (via `:CoderAddApiKey`)
|
||||
2. Config file settings
|
||||
3. Environment variables
|
||||
|
||||
**Other credential commands:**
|
||||
- `:CoderCredentials` - View configured providers
|
||||
- `:CoderSwitchProvider` - Switch between configured providers
|
||||
- `:CoderRemoveApiKey` - Remove stored credentials
|
||||
|
||||
---
|
||||
|
||||
## 🔌 LLM Providers
|
||||
@@ -255,49 +285,129 @@ llm = {
|
||||
|
||||
## 📝 Commands Reference
|
||||
|
||||
### Main Commands
|
||||
All commands can be invoked via `:Coder {subcommand}` or their dedicated command aliases.
|
||||
|
||||
### Core Commands
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder open` | `:CoderOpen` | Open the coder split view |
|
||||
| `:Coder close` | `:CoderClose` | Close the coder split view |
|
||||
| `:Coder toggle` | `:CoderToggle` | Toggle the coder split view |
|
||||
| `:Coder process` | `:CoderProcess` | Process the last prompt in coder file |
|
||||
| `:Coder status` | - | Show plugin status and configuration |
|
||||
| `:Coder focus` | - | Switch focus between coder and target windows |
|
||||
| `:Coder reset` | - | Reset processed prompts to allow re-processing |
|
||||
| `:Coder gitignore` | - | Force update .gitignore with coder patterns |
|
||||
|
||||
### Ask Panel (Chat Interface)
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder ask` | `:CoderAsk` | Open the Ask panel |
|
||||
| `:Coder ask-toggle` | `:CoderAskToggle` | Toggle the Ask panel |
|
||||
| `:Coder ask-close` | - | Close the Ask panel |
|
||||
| `:Coder ask-clear` | `:CoderAskClear` | Clear chat history |
|
||||
|
||||
### Agent Mode (Autonomous Coding)
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder agent` | `:CoderAgent` | Open the Agent panel |
|
||||
| `:Coder agent-toggle` | `:CoderAgentToggle` | Toggle the Agent panel |
|
||||
| `:Coder agent-close` | - | Close the Agent panel |
|
||||
| `:Coder agent-stop` | `:CoderAgentStop` | Stop the running agent |
|
||||
|
||||
### Agentic Mode (IDE-like Multi-file Agent)
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder agentic-run <task>` | `:CoderAgenticRun <task>` | Run an agentic task (multi-file changes) |
|
||||
| `:Coder agentic-list` | `:CoderAgenticList` | List available agents |
|
||||
| `:Coder agentic-init` | `:CoderAgenticInit` | Initialize `.coder/agents/` and `.coder/rules/` |
|
||||
|
||||
### Transform Commands (Inline Tag Processing)
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder transform` | `:CoderTransform` | Transform all `/@ @/` tags in file |
|
||||
| `:Coder transform-cursor` | `:CoderTransformCursor` | Transform tag at cursor position |
|
||||
| - | `:CoderTransformVisual` | Transform selected tags (visual mode) |
|
||||
|
||||
### Project & Index Commands
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| - | `:CoderIndex` | Open coder companion for current file |
|
||||
| `:Coder index-project` | `:CoderIndexProject` | Index the entire project |
|
||||
| `:Coder index-status` | `:CoderIndexStatus` | Show project index status |
|
||||
|
||||
### Tree & Structure Commands
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder tree` | `:CoderTree` | Refresh `.coder/tree.log` |
|
||||
| `:Coder tree-view` | `:CoderTreeView` | View `.coder/tree.log` in split |
|
||||
|
||||
### Queue & Scheduler Commands
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder queue-status` | `:CoderQueueStatus` | Show scheduler and queue status |
|
||||
| `:Coder queue-process` | `:CoderQueueProcess` | Manually trigger queue processing |
|
||||
|
||||
### Processing Mode Commands
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder auto-toggle` | `:CoderAutoToggle` | Toggle automatic/manual prompt processing |
|
||||
| `:Coder auto-set <mode>` | `:CoderAutoSet <mode>` | Set processing mode (`auto`/`manual`) |
|
||||
|
||||
### Memory & Learning Commands
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder memories` | `:CoderMemories` | Show learned memories |
|
||||
| `:Coder forget [pattern]` | `:CoderForget [pattern]` | Clear memories (optionally matching pattern) |
|
||||
|
||||
### Brain Commands (Knowledge Graph)
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| - | `:CoderBrain [action]` | Brain management (`stats`/`commit`/`flush`/`prune`) |
|
||||
| - | `:CoderFeedback <type>` | Give feedback to brain (`good`/`bad`/`stats`) |
|
||||
|
||||
### LLM Statistics & Feedback
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:Coder {subcommand}` | Main command with subcommands |
|
||||
| `:CoderOpen` | Open the coder split view |
|
||||
| `:CoderClose` | Close the coder split view |
|
||||
| `:CoderToggle` | Toggle the coder split view |
|
||||
| `:CoderProcess` | Process the last prompt |
|
||||
| `:Coder llm-stats` | Show LLM provider accuracy statistics |
|
||||
| `:Coder llm-feedback-good` | Report positive feedback on last response |
|
||||
| `:Coder llm-feedback-bad` | Report negative feedback on last response |
|
||||
| `:Coder llm-reset-stats` | Reset LLM accuracy statistics |
|
||||
|
||||
### Ask Panel
|
||||
### Cost Tracking
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderAsk` | Open the Ask panel |
|
||||
| `:CoderAskToggle` | Toggle the Ask panel |
|
||||
| `:CoderAskClear` | Clear chat history |
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder cost` | `:CoderCost` | Show LLM cost estimation window |
|
||||
| `:Coder cost-clear` | - | Clear session cost tracking |
|
||||
|
||||
### Agent Mode
|
||||
### Credentials Management
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderAgent` | Open the Agent panel |
|
||||
| `:CoderAgentToggle` | Toggle the Agent panel |
|
||||
| `:CoderAgentStop` | Stop the running agent |
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder add-api-key` | `:CoderAddApiKey` | Add or update LLM provider API key |
|
||||
| `:Coder remove-api-key` | `:CoderRemoveApiKey` | Remove LLM provider credentials |
|
||||
| `:Coder credentials` | `:CoderCredentials` | Show credentials status |
|
||||
| `:Coder switch-provider` | `:CoderSwitchProvider` | Switch active LLM provider |
|
||||
|
||||
### Transform Commands
|
||||
### UI Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderTransform` | Transform all /@ @/ tags in file |
|
||||
| `:CoderTransformCursor` | Transform tag at cursor position |
|
||||
| `:CoderTransformVisual` | Transform selected tags (visual mode) |
|
||||
|
||||
### Utility Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderIndex` | Open coder companion for current file |
|
||||
| `:CoderLogs` | Toggle logs panel |
|
||||
| `:CoderType` | Switch between Ask/Agent modes |
|
||||
| `:CoderTree` | Refresh tree.log |
|
||||
| `:CoderTreeView` | View tree.log |
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder type-toggle` | `:CoderType` | Show Ask/Agent mode switcher |
|
||||
| `:Coder logs-toggle` | `:CoderLogs` | Toggle logs panel |
|
||||
|
||||
---
|
||||
|
||||
@@ -384,6 +494,42 @@ The logs panel opens automatically when processing prompts with the scheduler en
|
||||
|
||||
---
|
||||
|
||||
## 💰 Cost Tracking
|
||||
|
||||
Track your LLM API costs across sessions with the Cost Estimation window.
|
||||
|
||||
### Features
|
||||
|
||||
- **Session Tracking**: Monitor current session token usage and costs
|
||||
- **All-Time Tracking**: Persistent cost history stored per-project in `.coder/cost_history.json`
|
||||
- **Model Breakdown**: See costs by individual model
|
||||
- **Pricing Database**: Built-in pricing for 50+ models (GPT, Claude, Gemini, O-series, etc.)
|
||||
|
||||
### Opening the Cost Window
|
||||
|
||||
```vim
|
||||
:CoderCost
|
||||
```
|
||||
|
||||
### Cost Window Keymaps
|
||||
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `q` / `<Esc>` | Close window |
|
||||
| `r` | Refresh display |
|
||||
| `c` | Clear session costs |
|
||||
| `C` | Clear all history |
|
||||
|
||||
### Supported Models
|
||||
|
||||
The cost tracker includes pricing for:
|
||||
- **OpenAI**: GPT-4, GPT-4o, GPT-4o-mini, O1, O3, O4-mini, and more
|
||||
- **Anthropic**: Claude 3 Opus, Sonnet, Haiku, Claude 3.5 Sonnet/Haiku
|
||||
- **Local**: Ollama models (free, but usage tracked)
|
||||
- **Copilot**: Usage tracked (included in subscription)
|
||||
|
||||
---
|
||||
|
||||
## 🤖 Agent Mode
|
||||
|
||||
The Agent mode provides an autonomous coding assistant with tool access:
|
||||
|
||||
188
llms.txt
188
llms.txt
@@ -33,6 +33,8 @@ lua/codetyper/
|
||||
├── health.lua # Health check for :checkhealth
|
||||
├── tree.lua # Project tree logging (.coder/tree.log)
|
||||
├── logs_panel.lua # Standalone logs panel UI
|
||||
├── cost.lua # LLM cost tracking with persistent history
|
||||
├── credentials.lua # Secure credential storage (API keys, models)
|
||||
├── llm/
|
||||
│ ├── init.lua # LLM interface, provider selection
|
||||
│ ├── claude.lua # Claude API client (Anthropic)
|
||||
@@ -68,7 +70,14 @@ The plugin automatically creates and maintains a `.coder/` folder in your projec
|
||||
|
||||
```
|
||||
.coder/
|
||||
└── tree.log # Project structure, auto-updated on file changes
|
||||
├── tree.log # Project structure, auto-updated on file changes
|
||||
├── cost_history.json # LLM cost tracking history (persistent)
|
||||
├── brain/ # Knowledge graph storage
|
||||
│ ├── nodes/ # Learning nodes by type
|
||||
│ ├── indices/ # Search indices
|
||||
│ └── deltas/ # Version history
|
||||
├── agents/ # Custom agent definitions
|
||||
└── rules/ # Project-specific rules
|
||||
```
|
||||
|
||||
## Key Features
|
||||
@@ -115,7 +124,49 @@ auto_index = true -- disabled by default
|
||||
|
||||
Real-time visibility into LLM operations with token usage tracking.
|
||||
|
||||
### 6. Event-Driven Scheduler
|
||||
### 6. Cost Tracking
|
||||
|
||||
Track LLM API costs across sessions:
|
||||
|
||||
- **Session tracking**: Monitor current session costs in real-time
|
||||
- **All-time tracking**: Persistent history in `.coder/cost_history.json`
|
||||
- **Per-model breakdown**: See costs by individual model
|
||||
- **50+ models**: Built-in pricing for GPT, Claude, O-series, Gemini
|
||||
|
||||
Cost window keymaps:
|
||||
- `q`/`<Esc>` - Close window
|
||||
- `r` - Refresh display
|
||||
- `c` - Clear session costs
|
||||
- `C` - Clear all history
|
||||
|
||||
### 7. Automatic Ollama Fallback
|
||||
|
||||
When API rate limits are hit (e.g., Copilot free tier), the plugin:
|
||||
1. Detects the rate limit error
|
||||
2. Checks if local Ollama is available
|
||||
3. Automatically switches provider to Ollama
|
||||
4. Notifies user of the provider change
|
||||
|
||||
### 8. Credentials Management
|
||||
|
||||
Store API keys securely outside of config files:
|
||||
|
||||
```vim
|
||||
:CoderAddApiKey
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Interactive prompts for provider, API key, model, endpoint
|
||||
- Stored in `~/.local/share/nvim/codetyper/configuration.json`
|
||||
- Supports all providers: Claude, OpenAI, Gemini, Copilot, Ollama
|
||||
- Switch providers at runtime with `:CoderSwitchProvider`
|
||||
|
||||
**Credential priority:**
|
||||
1. Stored credentials (via `:CoderAddApiKey`)
|
||||
2. Config file settings (`require("codetyper").setup({...})`)
|
||||
3. Environment variables (`OPENAI_API_KEY`, etc.)
|
||||
|
||||
### 9. Event-Driven Scheduler
|
||||
|
||||
Prompts are treated as events, not commands:
|
||||
|
||||
@@ -143,7 +194,7 @@ scheduler = {
|
||||
}
|
||||
```
|
||||
|
||||
### 7. Tree-sitter Scope Resolution
|
||||
### 10. Tree-sitter Scope Resolution
|
||||
|
||||
Prompts automatically resolve to their enclosing function/method/class:
|
||||
|
||||
@@ -158,7 +209,7 @@ end
|
||||
For replacement intents (complete, refactor, fix), the entire scope is extracted
|
||||
and sent to the LLM, then replaced with the transformed version.
|
||||
|
||||
### 8. Intent Detection
|
||||
### 11. Intent Detection
|
||||
|
||||
The system parses prompts to detect user intent:
|
||||
|
||||
@@ -173,7 +224,7 @@ The system parses prompts to detect user intent:
|
||||
| optimize | optimize, performance, faster | replace |
|
||||
| explain | explain, what, how, why | none |
|
||||
|
||||
### 9. Tag Precedence
|
||||
### 12. Tag Precedence
|
||||
|
||||
Multiple tags in the same scope follow "first tag wins" rule:
|
||||
- Earlier (by line number) unresolved tag processes first
|
||||
@@ -182,33 +233,114 @@ Multiple tags in the same scope follow "first tag wins" rule:
|
||||
|
||||
## Commands
|
||||
|
||||
### Main Commands
|
||||
- `:Coder open` - Opens split view with coder file
|
||||
- `:Coder close` - Closes the split
|
||||
- `:Coder toggle` - Toggles the view
|
||||
- `:Coder process` - Manually triggers code generation
|
||||
All commands can be invoked via `:Coder {subcommand}` or dedicated aliases.
|
||||
|
||||
### Ask Panel
|
||||
- `:CoderAsk` - Open Ask panel
|
||||
- `:CoderAskToggle` - Toggle Ask panel
|
||||
- `:CoderAskClear` - Clear chat history
|
||||
### Core Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder open` | `:CoderOpen` | Open coder split view |
|
||||
| `:Coder close` | `:CoderClose` | Close coder split view |
|
||||
| `:Coder toggle` | `:CoderToggle` | Toggle coder split view |
|
||||
| `:Coder process` | `:CoderProcess` | Process last prompt in coder file |
|
||||
| `:Coder status` | - | Show plugin status and configuration |
|
||||
| `:Coder focus` | - | Switch focus between coder/target windows |
|
||||
| `:Coder reset` | - | Reset processed prompts |
|
||||
| `:Coder gitignore` | - | Force update .gitignore |
|
||||
|
||||
### Agent Mode
|
||||
- `:CoderAgent` - Open Agent panel
|
||||
- `:CoderAgentToggle` - Toggle Agent panel
|
||||
- `:CoderAgentStop` - Stop running agent
|
||||
### Ask Panel (Chat Interface)
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder ask` | `:CoderAsk` | Open Ask panel |
|
||||
| `:Coder ask-toggle` | `:CoderAskToggle` | Toggle Ask panel |
|
||||
| `:Coder ask-close` | - | Close Ask panel |
|
||||
| `:Coder ask-clear` | `:CoderAskClear` | Clear chat history |
|
||||
|
||||
### Transform
|
||||
- `:CoderTransform` - Transform all tags
|
||||
- `:CoderTransformCursor` - Transform at cursor
|
||||
- `:CoderTransformVisual` - Transform selection
|
||||
### Agent Mode (Autonomous Coding)
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder agent` | `:CoderAgent` | Open Agent panel |
|
||||
| `:Coder agent-toggle` | `:CoderAgentToggle` | Toggle Agent panel |
|
||||
| `:Coder agent-close` | - | Close Agent panel |
|
||||
| `:Coder agent-stop` | `:CoderAgentStop` | Stop running agent |
|
||||
|
||||
### Utility
|
||||
- `:CoderIndex` - Open coder companion
|
||||
- `:CoderLogs` - Toggle logs panel
|
||||
- `:CoderType` - Switch Ask/Agent mode
|
||||
- `:CoderTree` - Refresh tree.log
|
||||
- `:CoderTreeView` - View tree.log
|
||||
### Agentic Mode (IDE-like Multi-file Agent)
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder agentic-run <task>` | `:CoderAgenticRun <task>` | Run agentic task |
|
||||
| `:Coder agentic-list` | `:CoderAgenticList` | List available agents |
|
||||
| `:Coder agentic-init` | `:CoderAgenticInit` | Initialize .coder/agents/ and .coder/rules/ |
|
||||
|
||||
### Transform Commands (Inline Tag Processing)
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder transform` | `:CoderTransform` | Transform all /@ @/ tags in file |
|
||||
| `:Coder transform-cursor` | `:CoderTransformCursor` | Transform tag at cursor |
|
||||
| - | `:CoderTransformVisual` | Transform selected tags (visual mode) |
|
||||
|
||||
### Project & Index Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| - | `:CoderIndex` | Open coder companion for current file |
|
||||
| `:Coder index-project` | `:CoderIndexProject` | Index entire project |
|
||||
| `:Coder index-status` | `:CoderIndexStatus` | Show project index status |
|
||||
|
||||
### Tree & Structure Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder tree` | `:CoderTree` | Refresh .coder/tree.log |
|
||||
| `:Coder tree-view` | `:CoderTreeView` | View .coder/tree.log |
|
||||
|
||||
### Queue & Scheduler Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder queue-status` | `:CoderQueueStatus` | Show scheduler/queue status |
|
||||
| `:Coder queue-process` | `:CoderQueueProcess` | Manually trigger queue processing |
|
||||
|
||||
### Processing Mode Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder auto-toggle` | `:CoderAutoToggle` | Toggle automatic/manual processing |
|
||||
| `:Coder auto-set <mode>` | `:CoderAutoSet <mode>` | Set mode (auto/manual) |
|
||||
|
||||
### Memory & Learning Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder memories` | `:CoderMemories` | Show learned memories |
|
||||
| `:Coder forget [pattern]` | `:CoderForget [pattern]` | Clear memories |
|
||||
|
||||
### Brain Commands (Knowledge Graph)
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| - | `:CoderBrain [action]` | Brain management (stats/commit/flush/prune) |
|
||||
| - | `:CoderFeedback <type>` | Give feedback (good/bad/stats) |
|
||||
|
||||
### LLM Statistics & Feedback
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:Coder llm-stats` | Show LLM provider accuracy stats |
|
||||
| `:Coder llm-feedback-good` | Report positive feedback |
|
||||
| `:Coder llm-feedback-bad` | Report negative feedback |
|
||||
| `:Coder llm-reset-stats` | Reset LLM accuracy stats |
|
||||
|
||||
### Cost Tracking
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder cost` | `:CoderCost` | Show LLM cost estimation window |
|
||||
| `:Coder cost-clear` | - | Clear session cost tracking |
|
||||
|
||||
### Credentials Management
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder add-api-key` | `:CoderAddApiKey` | Add/update LLM provider credentials |
|
||||
| `:Coder remove-api-key` | `:CoderRemoveApiKey` | Remove provider credentials |
|
||||
| `:Coder credentials` | `:CoderCredentials` | Show credentials status |
|
||||
| `:Coder switch-provider` | `:CoderSwitchProvider` | Switch active provider |
|
||||
|
||||
### UI Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder type-toggle` | `:CoderType` | Show Ask/Agent mode switcher |
|
||||
| `:Coder logs-toggle` | `:CoderLogs` | Toggle logs panel |
|
||||
|
||||
## Configuration Schema
|
||||
|
||||
|
||||
854
lua/codetyper/agent/agentic.lua
Normal file
854
lua/codetyper/agent/agentic.lua
Normal file
@@ -0,0 +1,854 @@
|
||||
---@mod codetyper.agent.agentic Agentic loop with proper tool calling
|
||||
---@brief [[
|
||||
--- Full agentic system that handles multi-file changes via tool calling.
|
||||
--- Inspired by avante.nvim and opencode patterns.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class AgenticMessage
|
||||
---@field role "system"|"user"|"assistant"|"tool"
|
||||
---@field content string|table
|
||||
---@field tool_calls? table[] For assistant messages with tool calls
|
||||
---@field tool_call_id? string For tool result messages
|
||||
---@field name? string Tool name for tool results
|
||||
|
||||
---@class AgenticToolCall
|
||||
---@field id string Unique tool call ID
|
||||
---@field type "function"
|
||||
---@field function {name: string, arguments: string|table}
|
||||
|
||||
---@class AgenticOpts
|
||||
---@field task string The task to accomplish
|
||||
---@field files? string[] Initial files to include as context
|
||||
---@field agent? string Agent name to use (default: "coder")
|
||||
---@field model? string Model override
|
||||
---@field max_iterations? number Max tool call rounds (default: 20)
|
||||
---@field on_message? fun(msg: AgenticMessage) Called for each message
|
||||
---@field on_tool_start? fun(name: string, args: table) Called before tool execution
|
||||
---@field on_tool_end? fun(name: string, result: any, error: string|nil) Called after tool execution
|
||||
---@field on_file_change? fun(path: string, action: string) Called when file is modified
|
||||
---@field on_complete? fun(result: string|nil, error: string|nil) Called when done
|
||||
---@field on_status? fun(status: string) Status updates
|
||||
|
||||
--- Generate unique tool call ID
|
||||
local function generate_tool_call_id()
|
||||
return "call_" .. string.format("%x", os.time()) .. "_" .. string.format("%x", math.random(0, 0xFFFF))
|
||||
end
|
||||
|
||||
--- Load agent definition
|
||||
---@param name string Agent name
|
||||
---@return table|nil agent definition
|
||||
local function load_agent(name)
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
local agent_file = agents_dir .. "/" .. name .. ".md"
|
||||
|
||||
-- Check if custom agent exists
|
||||
if vim.fn.filereadable(agent_file) == 1 then
|
||||
local content = table.concat(vim.fn.readfile(agent_file), "\n")
|
||||
-- Parse frontmatter and content
|
||||
local frontmatter = {}
|
||||
local body = content
|
||||
|
||||
local fm_match = content:match("^%-%-%-\n(.-)%-%-%-\n(.*)$")
|
||||
if fm_match then
|
||||
-- Parse YAML-like frontmatter
|
||||
for line in content:match("^%-%-%-\n(.-)%-%-%-"):gmatch("[^\n]+") do
|
||||
local key, value = line:match("^(%w+):%s*(.+)$")
|
||||
if key and value then
|
||||
frontmatter[key] = value
|
||||
end
|
||||
end
|
||||
body = content:match("%-%-%-\n.-%-%-%-%s*\n(.*)$") or content
|
||||
end
|
||||
|
||||
return {
|
||||
name = name,
|
||||
description = frontmatter.description or "Custom agent: " .. name,
|
||||
system_prompt = body,
|
||||
tools = frontmatter.tools and vim.split(frontmatter.tools, ",") or nil,
|
||||
model = frontmatter.model,
|
||||
}
|
||||
end
|
||||
|
||||
-- Built-in agents
|
||||
local builtin_agents = {
|
||||
coder = {
|
||||
name = "coder",
|
||||
description = "Full-featured coding agent with file modification capabilities",
|
||||
system_prompt = [[You are an expert software engineer. You have access to tools to read, write, and modify files.
|
||||
|
||||
## Your Capabilities
|
||||
- Read files to understand the codebase
|
||||
- Search for patterns with grep and glob
|
||||
- Create new files with write tool
|
||||
- Edit existing files with precise replacements
|
||||
- Execute shell commands for builds and tests
|
||||
|
||||
## Guidelines
|
||||
1. Always read relevant files before making changes
|
||||
2. Make minimal, focused changes
|
||||
3. Follow existing code style and patterns
|
||||
4. Create tests when adding new functionality
|
||||
5. Verify changes work by running tests or builds
|
||||
|
||||
## Important Rules
|
||||
- NEVER guess file contents - always read first
|
||||
- Make precise edits using exact string matching
|
||||
- Explain your reasoning before making changes
|
||||
- If unsure, ask for clarification]],
|
||||
tools = { "view", "edit", "write", "grep", "glob", "bash" },
|
||||
},
|
||||
planner = {
|
||||
name = "planner",
|
||||
description = "Planning agent - read-only, helps design implementations",
|
||||
system_prompt = [[You are a software architect. Analyze codebases and create implementation plans.
|
||||
|
||||
You can read files and search the codebase, but cannot modify files.
|
||||
Your role is to:
|
||||
1. Understand the existing architecture
|
||||
2. Identify relevant files and patterns
|
||||
3. Create step-by-step implementation plans
|
||||
4. Suggest which files to modify and how
|
||||
|
||||
Be thorough in your analysis before making recommendations.]],
|
||||
tools = { "view", "grep", "glob" },
|
||||
},
|
||||
explorer = {
|
||||
name = "explorer",
|
||||
description = "Exploration agent - quickly find information in codebase",
|
||||
system_prompt = [[You are a codebase exploration assistant. Find information quickly and report back.
|
||||
|
||||
Your goal is to efficiently search and summarize findings.
|
||||
Use glob to find files, grep to search content, and view to read specific files.
|
||||
Be concise and focused in your responses.]],
|
||||
tools = { "view", "grep", "glob" },
|
||||
},
|
||||
}
|
||||
|
||||
return builtin_agents[name]
|
||||
end
|
||||
|
||||
--- Load rules from .coder/rules/
|
||||
---@return string Combined rules content
|
||||
local function load_rules()
|
||||
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
|
||||
local rules = {}
|
||||
|
||||
if vim.fn.isdirectory(rules_dir) == 1 then
|
||||
local files = vim.fn.glob(rules_dir .. "/*.md", false, true)
|
||||
for _, file in ipairs(files) do
|
||||
local content = table.concat(vim.fn.readfile(file), "\n")
|
||||
local filename = vim.fn.fnamemodify(file, ":t:r")
|
||||
table.insert(rules, string.format("## Rule: %s\n%s", filename, content))
|
||||
end
|
||||
end
|
||||
|
||||
if #rules > 0 then
|
||||
return "\n\n# Project Rules\n" .. table.concat(rules, "\n\n")
|
||||
end
|
||||
return ""
|
||||
end
|
||||
|
||||
--- Build messages array for API request
|
||||
---@param history AgenticMessage[]
|
||||
---@param provider string "openai"|"claude"
|
||||
---@return table[] Formatted messages
|
||||
local function build_messages(history, provider)
|
||||
local messages = {}
|
||||
|
||||
for _, msg in ipairs(history) do
|
||||
if msg.role == "system" then
|
||||
if provider == "claude" then
|
||||
-- Claude uses system parameter, not message
|
||||
-- Skip system messages in array
|
||||
else
|
||||
table.insert(messages, {
|
||||
role = "system",
|
||||
content = msg.content,
|
||||
})
|
||||
end
|
||||
elseif msg.role == "user" then
|
||||
table.insert(messages, {
|
||||
role = "user",
|
||||
content = msg.content,
|
||||
})
|
||||
elseif msg.role == "assistant" then
|
||||
local message = {
|
||||
role = "assistant",
|
||||
content = msg.content,
|
||||
}
|
||||
if msg.tool_calls then
|
||||
message.tool_calls = msg.tool_calls
|
||||
if provider == "claude" then
|
||||
-- Claude format: content is array of blocks
|
||||
message.content = {}
|
||||
if msg.content and msg.content ~= "" then
|
||||
table.insert(message.content, {
|
||||
type = "text",
|
||||
text = msg.content,
|
||||
})
|
||||
end
|
||||
for _, tc in ipairs(msg.tool_calls) do
|
||||
table.insert(message.content, {
|
||||
type = "tool_use",
|
||||
id = tc.id,
|
||||
name = tc["function"].name,
|
||||
input = type(tc["function"].arguments) == "string"
|
||||
and vim.json.decode(tc["function"].arguments)
|
||||
or tc["function"].arguments,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
table.insert(messages, message)
|
||||
elseif msg.role == "tool" then
|
||||
if provider == "claude" then
|
||||
table.insert(messages, {
|
||||
role = "user",
|
||||
content = {
|
||||
{
|
||||
type = "tool_result",
|
||||
tool_use_id = msg.tool_call_id,
|
||||
content = msg.content,
|
||||
},
|
||||
},
|
||||
})
|
||||
else
|
||||
table.insert(messages, {
|
||||
role = "tool",
|
||||
tool_call_id = msg.tool_call_id,
|
||||
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return messages
|
||||
end
|
||||
|
||||
--- Build tools array for API request
|
||||
---@param tool_names string[] Tool names to include
|
||||
---@param provider string "openai"|"claude"
|
||||
---@return table[] Formatted tools
|
||||
local function build_tools(tool_names, provider)
|
||||
local tools_mod = require("codetyper.agent.tools")
|
||||
local tools = {}
|
||||
|
||||
for _, name in ipairs(tool_names) do
|
||||
local tool = tools_mod.get(name)
|
||||
if tool then
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(tool.params or {}) do
|
||||
properties[param.name] = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
local description = type(tool.description) == "function" and tool.description() or tool.description
|
||||
|
||||
if provider == "claude" then
|
||||
table.insert(tools, {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
input_schema = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
})
|
||||
else
|
||||
table.insert(tools, {
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return tools
|
||||
end
|
||||
|
||||
--- Execute a tool call
|
||||
---@param tool_call AgenticToolCall
|
||||
---@param opts AgenticOpts
|
||||
---@return string result
|
||||
---@return string|nil error
|
||||
local function execute_tool(tool_call, opts)
|
||||
local tools_mod = require("codetyper.agent.tools")
|
||||
local name = tool_call["function"].name
|
||||
local args = tool_call["function"].arguments
|
||||
|
||||
-- Parse arguments if string
|
||||
if type(args) == "string" then
|
||||
local ok, parsed = pcall(vim.json.decode, args)
|
||||
if ok then
|
||||
args = parsed
|
||||
else
|
||||
return "", "Failed to parse tool arguments: " .. args
|
||||
end
|
||||
end
|
||||
|
||||
-- Notify tool start
|
||||
if opts.on_tool_start then
|
||||
opts.on_tool_start(name, args)
|
||||
end
|
||||
|
||||
if opts.on_status then
|
||||
opts.on_status("Executing: " .. name)
|
||||
end
|
||||
|
||||
-- Execute the tool
|
||||
local tool = tools_mod.get(name)
|
||||
if not tool then
|
||||
local err = "Unknown tool: " .. name
|
||||
if opts.on_tool_end then
|
||||
opts.on_tool_end(name, nil, err)
|
||||
end
|
||||
return "", err
|
||||
end
|
||||
|
||||
local result, err = tool.func(args, {
|
||||
on_log = function(msg)
|
||||
if opts.on_status then
|
||||
opts.on_status(msg)
|
||||
end
|
||||
end,
|
||||
})
|
||||
|
||||
-- Notify tool end
|
||||
if opts.on_tool_end then
|
||||
opts.on_tool_end(name, result, err)
|
||||
end
|
||||
|
||||
-- Track file changes
|
||||
if opts.on_file_change and (name == "write" or name == "edit") and not err then
|
||||
opts.on_file_change(args.path, name == "write" and "created" or "modified")
|
||||
end
|
||||
|
||||
if err then
|
||||
return "", err
|
||||
end
|
||||
|
||||
return type(result) == "string" and result or vim.json.encode(result), nil
|
||||
end
|
||||
|
||||
--- Parse tool calls from LLM response (unified Claude-like format)
|
||||
---@param response table Raw API response in unified format
|
||||
---@param provider string Provider name (unused, kept for signature compatibility)
|
||||
---@return AgenticToolCall[]
|
||||
local function parse_tool_calls(response, provider)
|
||||
local tool_calls = {}
|
||||
|
||||
-- Unified format: content array with tool_use blocks
|
||||
local content = response.content or {}
|
||||
for _, block in ipairs(content) do
|
||||
if block.type == "tool_use" then
|
||||
-- OpenAI expects arguments as JSON string, not table
|
||||
local args = block.input
|
||||
if type(args) == "table" then
|
||||
args = vim.json.encode(args)
|
||||
end
|
||||
|
||||
table.insert(tool_calls, {
|
||||
id = block.id or generate_tool_call_id(),
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = block.name,
|
||||
arguments = args,
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return tool_calls
|
||||
end
|
||||
|
||||
--- Extract text content from response (unified Claude-like format)
|
||||
---@param response table Raw API response in unified format
|
||||
---@param provider string Provider name (unused, kept for signature compatibility)
|
||||
---@return string
|
||||
local function extract_content(response, provider)
|
||||
local parts = {}
|
||||
for _, block in ipairs(response.content or {}) do
|
||||
if block.type == "text" then
|
||||
table.insert(parts, block.text)
|
||||
end
|
||||
end
|
||||
return table.concat(parts, "\n")
|
||||
end
|
||||
|
||||
--- Check if response indicates completion (unified Claude-like format)
|
||||
---@param response table Raw API response in unified format
|
||||
---@param provider string Provider name (unused, kept for signature compatibility)
|
||||
---@return boolean
|
||||
local function is_complete(response, provider)
|
||||
return response.stop_reason == "end_turn"
|
||||
end
|
||||
|
||||
--- Make API request to LLM with native tool calling support
|
||||
---@param messages table[] Formatted messages
|
||||
---@param tools table[] Formatted tools
|
||||
---@param system_prompt string System prompt
|
||||
---@param provider string "openai"|"claude"|"copilot"
|
||||
---@param model string Model name
|
||||
---@param callback fun(response: table|nil, error: string|nil)
|
||||
local function call_llm(messages, tools, system_prompt, provider, model, callback)
|
||||
local context = {
|
||||
language = "lua",
|
||||
file_content = "",
|
||||
prompt_type = "agent",
|
||||
project_root = vim.fn.getcwd(),
|
||||
cwd = vim.fn.getcwd(),
|
||||
}
|
||||
|
||||
-- Use native tool calling APIs
|
||||
if provider == "copilot" then
|
||||
local client = require("codetyper.llm.copilot")
|
||||
|
||||
-- Copilot's generate_with_tools expects messages in a specific format
|
||||
-- Convert to the format it expects
|
||||
local converted_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(converted_messages, msg)
|
||||
end
|
||||
end
|
||||
|
||||
client.generate_with_tools(converted_messages, context, tools, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Response is already in Claude-like format from the provider
|
||||
-- Convert to our internal format
|
||||
local result = {
|
||||
content = {},
|
||||
stop_reason = "end_turn",
|
||||
}
|
||||
|
||||
if response and response.content then
|
||||
for _, block in ipairs(response.content) do
|
||||
if block.type == "text" then
|
||||
table.insert(result.content, { type = "text", text = block.text })
|
||||
elseif block.type == "tool_use" then
|
||||
table.insert(result.content, {
|
||||
type = "tool_use",
|
||||
id = block.id or generate_tool_call_id(),
|
||||
name = block.name,
|
||||
input = block.input,
|
||||
})
|
||||
result.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
callback(result, nil)
|
||||
end)
|
||||
elseif provider == "openai" then
|
||||
local client = require("codetyper.llm.openai")
|
||||
|
||||
-- OpenAI's generate_with_tools
|
||||
local converted_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(converted_messages, msg)
|
||||
end
|
||||
end
|
||||
|
||||
client.generate_with_tools(converted_messages, context, tools, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Response is already in Claude-like format from the provider
|
||||
local result = {
|
||||
content = {},
|
||||
stop_reason = "end_turn",
|
||||
}
|
||||
|
||||
if response and response.content then
|
||||
for _, block in ipairs(response.content) do
|
||||
if block.type == "text" then
|
||||
table.insert(result.content, { type = "text", text = block.text })
|
||||
elseif block.type == "tool_use" then
|
||||
table.insert(result.content, {
|
||||
type = "tool_use",
|
||||
id = block.id or generate_tool_call_id(),
|
||||
name = block.name,
|
||||
input = block.input,
|
||||
})
|
||||
result.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
callback(result, nil)
|
||||
end)
|
||||
elseif provider == "ollama" then
|
||||
local client = require("codetyper.llm.ollama")
|
||||
|
||||
-- Ollama's generate_with_tools (text-based tool calling)
|
||||
local converted_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(converted_messages, msg)
|
||||
end
|
||||
end
|
||||
|
||||
client.generate_with_tools(converted_messages, context, tools, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Response is already in Claude-like format from the provider
|
||||
callback(response, nil)
|
||||
end)
|
||||
else
|
||||
-- Fallback for other providers (ollama, etc.) - use text-based parsing
|
||||
local client = require("codetyper.llm." .. provider)
|
||||
|
||||
-- Build prompt from messages
|
||||
local prompt_parts = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role == "user" then
|
||||
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
|
||||
table.insert(prompt_parts, "User: " .. content)
|
||||
elseif msg.role == "assistant" then
|
||||
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
|
||||
table.insert(prompt_parts, "Assistant: " .. content)
|
||||
end
|
||||
end
|
||||
|
||||
-- Add tool descriptions to prompt for text-based providers
|
||||
local tool_desc = "\n\n## Available Tools\n"
|
||||
tool_desc = tool_desc .. "Call tools by outputting JSON in this format:\n"
|
||||
tool_desc = tool_desc .. '```json\n{"tool": "tool_name", "arguments": {...}}\n```\n\n'
|
||||
for _, tool in ipairs(tools) do
|
||||
local name = tool.name or (tool["function"] and tool["function"].name)
|
||||
local desc = tool.description or (tool["function"] and tool["function"].description)
|
||||
if name then
|
||||
tool_desc = tool_desc .. string.format("- **%s**: %s\n", name, desc or "")
|
||||
end
|
||||
end
|
||||
|
||||
context.file_content = system_prompt .. tool_desc
|
||||
|
||||
client.generate(table.concat(prompt_parts, "\n\n"), context, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Parse response for tool calls (text-based fallback)
|
||||
local result = {
|
||||
content = {},
|
||||
stop_reason = "end_turn",
|
||||
}
|
||||
|
||||
-- Extract text content
|
||||
local text_content = response
|
||||
|
||||
-- Try to extract JSON tool calls from response
|
||||
local json_match = response:match("```json%s*(%b{})%s*```")
|
||||
if json_match then
|
||||
local ok, parsed = pcall(vim.json.decode, json_match)
|
||||
if ok and parsed.tool then
|
||||
table.insert(result.content, {
|
||||
type = "tool_use",
|
||||
id = generate_tool_call_id(),
|
||||
name = parsed.tool,
|
||||
input = parsed.arguments or {},
|
||||
})
|
||||
text_content = response:gsub("```json.-```", ""):gsub("^%s+", ""):gsub("%s+$", "")
|
||||
result.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
|
||||
if text_content and text_content ~= "" then
|
||||
table.insert(result.content, 1, { type = "text", text = text_content })
|
||||
end
|
||||
|
||||
callback(result, nil)
|
||||
end)
|
||||
end
|
||||
end
|
||||
|
||||
--- Run the agentic loop
|
||||
---@param opts AgenticOpts
|
||||
function M.run(opts)
|
||||
-- Load agent
|
||||
local agent = load_agent(opts.agent or "coder")
|
||||
if not agent then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "Unknown agent: " .. (opts.agent or "coder"))
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Load rules
|
||||
local rules = load_rules()
|
||||
|
||||
-- Build system prompt
|
||||
local system_prompt = agent.system_prompt .. rules
|
||||
|
||||
-- Initialize message history
|
||||
---@type AgenticMessage[]
|
||||
local history = {
|
||||
{ role = "system", content = system_prompt },
|
||||
}
|
||||
|
||||
-- Add initial file context if provided
|
||||
if opts.files and #opts.files > 0 then
|
||||
local file_context = "# Initial Files\n"
|
||||
for _, file_path in ipairs(opts.files) do
|
||||
local content = table.concat(vim.fn.readfile(file_path) or {}, "\n")
|
||||
file_context = file_context .. string.format("\n## %s\n```\n%s\n```\n", file_path, content)
|
||||
end
|
||||
table.insert(history, { role = "user", content = file_context })
|
||||
table.insert(history, { role = "assistant", content = "I've reviewed the provided files. What would you like me to do?" })
|
||||
end
|
||||
|
||||
-- Add the task
|
||||
table.insert(history, { role = "user", content = opts.task })
|
||||
|
||||
-- Determine provider
|
||||
local config = require("codetyper").get_config()
|
||||
local provider = config.llm.provider or "copilot"
|
||||
-- Note: Ollama has its own handler in call_llm, don't change it
|
||||
|
||||
-- Get tools for this agent
|
||||
local tool_names = agent.tools or { "view", "edit", "write", "grep", "glob", "bash" }
|
||||
|
||||
-- Ensure tools are loaded
|
||||
local tools_mod = require("codetyper.agent.tools")
|
||||
tools_mod.setup()
|
||||
|
||||
-- Build tools for API
|
||||
local tools = build_tools(tool_names, provider)
|
||||
|
||||
-- Iteration tracking
|
||||
local iteration = 0
|
||||
local max_iterations = opts.max_iterations or 20
|
||||
|
||||
--- Process one iteration
|
||||
local function process_iteration()
|
||||
iteration = iteration + 1
|
||||
|
||||
if iteration > max_iterations then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "Max iterations reached")
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
if opts.on_status then
|
||||
opts.on_status(string.format("Thinking... (iteration %d)", iteration))
|
||||
end
|
||||
|
||||
-- Build messages for API
|
||||
local messages = build_messages(history, provider)
|
||||
|
||||
-- Call LLM
|
||||
call_llm(messages, tools, system_prompt, provider, opts.model, function(response, err)
|
||||
if err then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, err)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Extract content and tool calls
|
||||
local content = extract_content(response, provider)
|
||||
local tool_calls = parse_tool_calls(response, provider)
|
||||
|
||||
-- Add assistant message to history
|
||||
local assistant_msg = {
|
||||
role = "assistant",
|
||||
content = content,
|
||||
tool_calls = #tool_calls > 0 and tool_calls or nil,
|
||||
}
|
||||
table.insert(history, assistant_msg)
|
||||
|
||||
if opts.on_message then
|
||||
opts.on_message(assistant_msg)
|
||||
end
|
||||
|
||||
-- Process tool calls if any
|
||||
if #tool_calls > 0 then
|
||||
for _, tc in ipairs(tool_calls) do
|
||||
local result, tool_err = execute_tool(tc, opts)
|
||||
|
||||
-- Add tool result to history
|
||||
local tool_msg = {
|
||||
role = "tool",
|
||||
tool_call_id = tc.id,
|
||||
name = tc["function"].name,
|
||||
content = tool_err or result,
|
||||
}
|
||||
table.insert(history, tool_msg)
|
||||
|
||||
if opts.on_message then
|
||||
opts.on_message(tool_msg)
|
||||
end
|
||||
end
|
||||
|
||||
-- Continue the loop
|
||||
vim.schedule(process_iteration)
|
||||
else
|
||||
-- No tool calls - check if complete
|
||||
if is_complete(response, provider) or content ~= "" then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(content, nil)
|
||||
end
|
||||
else
|
||||
-- Continue if not explicitly complete
|
||||
vim.schedule(process_iteration)
|
||||
end
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
-- Start the loop
|
||||
process_iteration()
|
||||
end
|
||||
|
||||
--- Create default agent files in .coder/agents/
|
||||
function M.init_agents_dir()
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
vim.fn.mkdir(agents_dir, "p")
|
||||
|
||||
-- Create example agent
|
||||
local example_agent = [[---
|
||||
description: Example custom agent
|
||||
tools: view,grep,glob,edit,write
|
||||
model:
|
||||
---
|
||||
|
||||
# Custom Agent
|
||||
|
||||
You are a custom coding agent. Describe your specialized behavior here.
|
||||
|
||||
## Your Role
|
||||
- Define what this agent specializes in
|
||||
- List specific capabilities
|
||||
|
||||
## Guidelines
|
||||
- Add agent-specific rules
|
||||
- Define coding standards to follow
|
||||
|
||||
## Examples
|
||||
Provide examples of how to handle common tasks.
|
||||
]]
|
||||
|
||||
local example_path = agents_dir .. "/example.md"
|
||||
if vim.fn.filereadable(example_path) ~= 1 then
|
||||
vim.fn.writefile(vim.split(example_agent, "\n"), example_path)
|
||||
end
|
||||
|
||||
return agents_dir
|
||||
end
|
||||
|
||||
--- Create default rules in .coder/rules/
|
||||
function M.init_rules_dir()
|
||||
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
|
||||
vim.fn.mkdir(rules_dir, "p")
|
||||
|
||||
-- Create example rule
|
||||
local example_rule = [[# Code Style
|
||||
|
||||
Follow these coding standards:
|
||||
|
||||
## General
|
||||
- Use consistent indentation (tabs or spaces based on project)
|
||||
- Keep lines under 100 characters
|
||||
- Add comments for complex logic
|
||||
|
||||
## Naming Conventions
|
||||
- Use descriptive variable names
|
||||
- Functions should be verbs (e.g., getUserData, calculateTotal)
|
||||
- Constants in UPPER_SNAKE_CASE
|
||||
|
||||
## Testing
|
||||
- Write tests for new functionality
|
||||
- Aim for >80% code coverage
|
||||
- Test edge cases
|
||||
|
||||
## Documentation
|
||||
- Document public APIs
|
||||
- Include usage examples
|
||||
- Keep docs up to date with code
|
||||
]]
|
||||
|
||||
local example_path = rules_dir .. "/code-style.md"
|
||||
if vim.fn.filereadable(example_path) ~= 1 then
|
||||
vim.fn.writefile(vim.split(example_rule, "\n"), example_path)
|
||||
end
|
||||
|
||||
return rules_dir
|
||||
end
|
||||
|
||||
--- Initialize both agents and rules directories
|
||||
function M.init()
|
||||
M.init_agents_dir()
|
||||
M.init_rules_dir()
|
||||
end
|
||||
|
||||
--- List available agents
|
||||
---@return table[] List of {name, description, builtin}
|
||||
function M.list_agents()
|
||||
local agents = {}
|
||||
|
||||
-- Built-in agents
|
||||
local builtins = { "coder", "planner", "explorer" }
|
||||
for _, name in ipairs(builtins) do
|
||||
local agent = load_agent(name)
|
||||
if agent then
|
||||
table.insert(agents, {
|
||||
name = agent.name,
|
||||
description = agent.description,
|
||||
builtin = true,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Custom agents from .coder/agents/
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
if vim.fn.isdirectory(agents_dir) == 1 then
|
||||
local files = vim.fn.glob(agents_dir .. "/*.md", false, true)
|
||||
for _, file in ipairs(files) do
|
||||
local name = vim.fn.fnamemodify(file, ":t:r")
|
||||
if not vim.tbl_contains(builtins, name) then
|
||||
local agent = load_agent(name)
|
||||
if agent then
|
||||
table.insert(agents, {
|
||||
name = agent.name,
|
||||
description = agent.description,
|
||||
builtin = false,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return agents
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -8,11 +8,11 @@ local M = {}
|
||||
|
||||
--- Heuristic weights (must sum to 1.0)
|
||||
M.weights = {
|
||||
length = 0.15, -- Response length relative to prompt
|
||||
length = 0.15, -- Response length relative to prompt
|
||||
uncertainty = 0.30, -- Uncertainty phrases
|
||||
syntax = 0.25, -- Syntax completeness
|
||||
repetition = 0.15, -- Duplicate lines
|
||||
truncation = 0.15, -- Incomplete ending
|
||||
syntax = 0.25, -- Syntax completeness
|
||||
repetition = 0.15, -- Duplicate lines
|
||||
truncation = 0.15, -- Incomplete ending
|
||||
}
|
||||
|
||||
--- Uncertainty phrases that indicate low confidence
|
||||
@@ -255,14 +255,15 @@ function M.score(response, prompt, context)
|
||||
_ = context -- Reserved for future use
|
||||
|
||||
if not response or #response == 0 then
|
||||
return 0, {
|
||||
length = 0,
|
||||
uncertainty = 0,
|
||||
syntax = 0,
|
||||
repetition = 0,
|
||||
truncation = 0,
|
||||
weighted_total = 0,
|
||||
}
|
||||
return 0,
|
||||
{
|
||||
length = 0,
|
||||
uncertainty = 0,
|
||||
syntax = 0,
|
||||
repetition = 0,
|
||||
truncation = 0,
|
||||
weighted_total = 0,
|
||||
}
|
||||
end
|
||||
|
||||
local scores = {
|
||||
|
||||
@@ -111,7 +111,11 @@ function M.agent_loop(context, callbacks)
|
||||
logs.thinking("Calling LLM with " .. #state.conversation .. " messages...")
|
||||
|
||||
-- Generate with tools enabled
|
||||
client.generate_with_tools(state.conversation, context, tools.definitions, function(response, err)
|
||||
-- Ensure tools are loaded and get definitions
|
||||
tools.setup()
|
||||
local tool_defs = tools.to_openai_format()
|
||||
|
||||
client.generate_with_tools(state.conversation, context, tool_defs, function(response, err)
|
||||
if err then
|
||||
state.is_running = false
|
||||
callbacks.on_error(err)
|
||||
|
||||
614
lua/codetyper/agent/inject.lua
Normal file
614
lua/codetyper/agent/inject.lua
Normal file
@@ -0,0 +1,614 @@
|
||||
---@mod codetyper.agent.inject Smart code injection with import handling
|
||||
---@brief [[
|
||||
--- Intelligent code injection that properly handles imports, merging them
|
||||
--- into existing import sections instead of blindly appending.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class ImportConfig
|
||||
---@field pattern string Lua pattern to match import statements
|
||||
---@field multi_line boolean Whether imports can span multiple lines
|
||||
---@field sort_key function|nil Function to extract sort key from import
|
||||
---@field group_by function|nil Function to group imports
|
||||
|
||||
---@class ParsedCode
|
||||
---@field imports string[] Import statements
|
||||
---@field body string[] Non-import code lines
|
||||
---@field import_lines table<number, boolean> Map of line numbers that are imports
|
||||
|
||||
--- Language-specific import patterns
|
||||
local import_patterns = {
|
||||
-- JavaScript/TypeScript
|
||||
javascript = {
|
||||
{ pattern = "^%s*import%s+.+%s+from%s+['\"]", multi_line = true },
|
||||
{ pattern = "^%s*import%s+['\"]", multi_line = false },
|
||||
{ pattern = "^%s*import%s*{", multi_line = true },
|
||||
{ pattern = "^%s*import%s*%*", multi_line = true },
|
||||
{ pattern = "^%s*export%s+{.+}%s+from%s+['\"]", multi_line = true },
|
||||
{ pattern = "^%s*const%s+%w+%s*=%s*require%(['\"]", multi_line = false },
|
||||
{ pattern = "^%s*let%s+%w+%s*=%s*require%(['\"]", multi_line = false },
|
||||
{ pattern = "^%s*var%s+%w+%s*=%s*require%(['\"]", multi_line = false },
|
||||
},
|
||||
-- Python
|
||||
python = {
|
||||
{ pattern = "^%s*import%s+%w", multi_line = false },
|
||||
{ pattern = "^%s*from%s+[%w%.]+%s+import%s+", multi_line = true },
|
||||
},
|
||||
-- Lua
|
||||
lua = {
|
||||
{ pattern = "^%s*local%s+%w+%s*=%s*require%s*%(?['\"]", multi_line = false },
|
||||
{ pattern = "^%s*require%s*%(?['\"]", multi_line = false },
|
||||
},
|
||||
-- Go
|
||||
go = {
|
||||
{ pattern = "^%s*import%s+%(?", multi_line = true },
|
||||
},
|
||||
-- Rust
|
||||
rust = {
|
||||
{ pattern = "^%s*use%s+", multi_line = true },
|
||||
{ pattern = "^%s*extern%s+crate%s+", multi_line = false },
|
||||
},
|
||||
-- C/C++
|
||||
c = {
|
||||
{ pattern = "^%s*#include%s*[<\"]", multi_line = false },
|
||||
},
|
||||
-- Java/Kotlin
|
||||
java = {
|
||||
{ pattern = "^%s*import%s+", multi_line = false },
|
||||
},
|
||||
-- Ruby
|
||||
ruby = {
|
||||
{ pattern = "^%s*require%s+['\"]", multi_line = false },
|
||||
{ pattern = "^%s*require_relative%s+['\"]", multi_line = false },
|
||||
},
|
||||
-- PHP
|
||||
php = {
|
||||
{ pattern = "^%s*use%s+", multi_line = false },
|
||||
{ pattern = "^%s*require%s+['\"]", multi_line = false },
|
||||
{ pattern = "^%s*require_once%s+['\"]", multi_line = false },
|
||||
{ pattern = "^%s*include%s+['\"]", multi_line = false },
|
||||
{ pattern = "^%s*include_once%s+['\"]", multi_line = false },
|
||||
},
|
||||
}
|
||||
|
||||
-- Alias common extensions to language configs
|
||||
import_patterns.ts = import_patterns.javascript
|
||||
import_patterns.tsx = import_patterns.javascript
|
||||
import_patterns.jsx = import_patterns.javascript
|
||||
import_patterns.mjs = import_patterns.javascript
|
||||
import_patterns.cjs = import_patterns.javascript
|
||||
import_patterns.py = import_patterns.python
|
||||
import_patterns.cpp = import_patterns.c
|
||||
import_patterns.hpp = import_patterns.c
|
||||
import_patterns.h = import_patterns.c
|
||||
import_patterns.kt = import_patterns.java
|
||||
import_patterns.rs = import_patterns.rust
|
||||
import_patterns.rb = import_patterns.ruby
|
||||
|
||||
--- Check if a line is an import statement for the given language
|
||||
---@param line string
|
||||
---@param patterns table[] Import patterns for the language
|
||||
---@return boolean is_import
|
||||
---@return boolean is_multi_line
|
||||
local function is_import_line(line, patterns)
|
||||
for _, p in ipairs(patterns) do
|
||||
if line:match(p.pattern) then
|
||||
return true, p.multi_line or false
|
||||
end
|
||||
end
|
||||
return false, false
|
||||
end
|
||||
|
||||
--- Check if a line is empty or a comment
|
||||
---@param line string
|
||||
---@param filetype string
|
||||
---@return boolean
|
||||
local function is_empty_or_comment(line, filetype)
|
||||
local trimmed = line:match("^%s*(.-)%s*$")
|
||||
if trimmed == "" then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Language-specific comment patterns
|
||||
local comment_patterns = {
|
||||
lua = { "^%-%-" },
|
||||
python = { "^#" },
|
||||
javascript = { "^//", "^/%*", "^%*" },
|
||||
typescript = { "^//", "^/%*", "^%*" },
|
||||
go = { "^//", "^/%*", "^%*" },
|
||||
rust = { "^//", "^/%*", "^%*" },
|
||||
c = { "^//", "^/%*", "^%*", "^#" },
|
||||
java = { "^//", "^/%*", "^%*" },
|
||||
ruby = { "^#" },
|
||||
php = { "^//", "^/%*", "^%*", "^#" },
|
||||
}
|
||||
|
||||
local patterns = comment_patterns[filetype] or comment_patterns.javascript
|
||||
for _, pattern in ipairs(patterns) do
|
||||
if trimmed:match(pattern) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Check if a line ends a multi-line import
|
||||
---@param line string
|
||||
---@param filetype string
|
||||
---@return boolean
|
||||
local function ends_multiline_import(line, filetype)
|
||||
-- Check for closing patterns
|
||||
if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then
|
||||
-- ES6 imports end with 'from "..." ;' or just ';' or a line with just '}'
|
||||
if line:match("from%s+['\"][^'\"]+['\"]%s*;?%s*$") then
|
||||
return true
|
||||
end
|
||||
if line:match("}%s*from%s+['\"]") then
|
||||
return true
|
||||
end
|
||||
if line:match("^%s*}%s*;?%s*$") then
|
||||
return true
|
||||
end
|
||||
if line:match(";%s*$") then
|
||||
return true
|
||||
end
|
||||
elseif filetype == "python" or filetype == "py" then
|
||||
-- Python single-line import: doesn't end with \, (, or ,
|
||||
-- Examples: "from typing import List, Dict" or "import os"
|
||||
if not line:match("\\%s*$") and not line:match("%(%s*$") and not line:match(",%s*$") then
|
||||
return true
|
||||
end
|
||||
-- Python multiline imports end with closing paren
|
||||
if line:match("%)%s*$") then
|
||||
return true
|
||||
end
|
||||
elseif filetype == "go" then
|
||||
-- Go multi-line imports end with ')'
|
||||
if line:match("%)%s*$") then
|
||||
return true
|
||||
end
|
||||
elseif filetype == "rust" or filetype == "rs" then
|
||||
-- Rust use statements end with ';'
|
||||
if line:match(";%s*$") then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Parse code into imports and body
|
||||
---@param code string|string[] Code to parse
|
||||
---@param filetype string File type/extension
|
||||
---@return ParsedCode
|
||||
function M.parse_code(code, filetype)
|
||||
local lines
|
||||
if type(code) == "string" then
|
||||
lines = vim.split(code, "\n", { plain = true })
|
||||
else
|
||||
lines = code
|
||||
end
|
||||
|
||||
local patterns = import_patterns[filetype] or import_patterns.javascript
|
||||
|
||||
local result = {
|
||||
imports = {},
|
||||
body = {},
|
||||
import_lines = {},
|
||||
}
|
||||
|
||||
local in_multiline_import = false
|
||||
local current_import_lines = {}
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
if in_multiline_import then
|
||||
-- Continue collecting multi-line import
|
||||
table.insert(current_import_lines, line)
|
||||
|
||||
if ends_multiline_import(line, filetype) then
|
||||
-- Complete the multi-line import
|
||||
table.insert(result.imports, table.concat(current_import_lines, "\n"))
|
||||
for j = i - #current_import_lines + 1, i do
|
||||
result.import_lines[j] = true
|
||||
end
|
||||
current_import_lines = {}
|
||||
in_multiline_import = false
|
||||
end
|
||||
else
|
||||
local is_import, is_multi = is_import_line(line, patterns)
|
||||
|
||||
if is_import then
|
||||
result.import_lines[i] = true
|
||||
|
||||
if is_multi and not ends_multiline_import(line, filetype) then
|
||||
-- Start of multi-line import
|
||||
in_multiline_import = true
|
||||
current_import_lines = { line }
|
||||
else
|
||||
-- Single-line import
|
||||
table.insert(result.imports, line)
|
||||
end
|
||||
else
|
||||
-- Non-import line
|
||||
table.insert(result.body, line)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Handle unclosed multi-line import (shouldn't happen with well-formed code)
|
||||
if #current_import_lines > 0 then
|
||||
table.insert(result.imports, table.concat(current_import_lines, "\n"))
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Find the import section range in a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@param filetype string
|
||||
---@return number|nil start_line First import line (1-indexed)
|
||||
---@return number|nil end_line Last import line (1-indexed)
|
||||
function M.find_import_section(bufnr, filetype)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local patterns = import_patterns[filetype] or import_patterns.javascript
|
||||
|
||||
local first_import = nil
|
||||
local last_import = nil
|
||||
local in_multiline = false
|
||||
local consecutive_non_import = 0
|
||||
local max_gap = 3 -- Allow up to 3 blank/comment lines between imports
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
if in_multiline then
|
||||
last_import = i
|
||||
consecutive_non_import = 0
|
||||
|
||||
if ends_multiline_import(line, filetype) then
|
||||
in_multiline = false
|
||||
end
|
||||
else
|
||||
local is_import, is_multi = is_import_line(line, patterns)
|
||||
|
||||
if is_import then
|
||||
if not first_import then
|
||||
first_import = i
|
||||
end
|
||||
last_import = i
|
||||
consecutive_non_import = 0
|
||||
|
||||
if is_multi and not ends_multiline_import(line, filetype) then
|
||||
in_multiline = true
|
||||
end
|
||||
elseif is_empty_or_comment(line, filetype) then
|
||||
-- Allow gaps in import section
|
||||
if first_import then
|
||||
consecutive_non_import = consecutive_non_import + 1
|
||||
if consecutive_non_import > max_gap then
|
||||
-- Too many non-import lines, import section has ended
|
||||
break
|
||||
end
|
||||
end
|
||||
else
|
||||
-- Non-import, non-empty line
|
||||
if first_import then
|
||||
-- Import section has ended
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return first_import, last_import
|
||||
end
|
||||
|
||||
--- Get existing imports from a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@param filetype string
|
||||
---@return string[] Existing import statements
|
||||
function M.get_existing_imports(bufnr, filetype)
|
||||
local start_line, end_line = M.find_import_section(bufnr, filetype)
|
||||
if not start_line then
|
||||
return {}
|
||||
end
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false)
|
||||
local parsed = M.parse_code(lines, filetype)
|
||||
return parsed.imports
|
||||
end
|
||||
|
||||
--- Normalize an import for comparison (remove whitespace variations)
|
||||
---@param import_str string
|
||||
---@return string
|
||||
local function normalize_import(import_str)
|
||||
-- Remove trailing semicolon for comparison
|
||||
local normalized = import_str:gsub(";%s*$", "")
|
||||
-- Remove all whitespace around braces, commas, colons
|
||||
normalized = normalized:gsub("%s*{%s*", "{")
|
||||
normalized = normalized:gsub("%s*}%s*", "}")
|
||||
normalized = normalized:gsub("%s*,%s*", ",")
|
||||
normalized = normalized:gsub("%s*:%s*", ":")
|
||||
-- Collapse multiple whitespace to single space
|
||||
normalized = normalized:gsub("%s+", " ")
|
||||
-- Trim leading/trailing whitespace
|
||||
normalized = normalized:match("^%s*(.-)%s*$")
|
||||
return normalized
|
||||
end
|
||||
|
||||
--- Check if two imports are duplicates
|
||||
---@param import1 string
|
||||
---@param import2 string
|
||||
---@return boolean
|
||||
local function are_duplicate_imports(import1, import2)
|
||||
return normalize_import(import1) == normalize_import(import2)
|
||||
end
|
||||
|
||||
--- Merge new imports with existing ones, avoiding duplicates
|
||||
---@param existing string[] Existing imports
|
||||
---@param new_imports string[] New imports to merge
|
||||
---@return string[] Merged imports
|
||||
function M.merge_imports(existing, new_imports)
|
||||
local merged = {}
|
||||
local seen = {}
|
||||
|
||||
-- Add existing imports
|
||||
for _, imp in ipairs(existing) do
|
||||
local normalized = normalize_import(imp)
|
||||
if not seen[normalized] then
|
||||
seen[normalized] = true
|
||||
table.insert(merged, imp)
|
||||
end
|
||||
end
|
||||
|
||||
-- Add new imports that aren't duplicates
|
||||
for _, imp in ipairs(new_imports) do
|
||||
local normalized = normalize_import(imp)
|
||||
if not seen[normalized] then
|
||||
seen[normalized] = true
|
||||
table.insert(merged, imp)
|
||||
end
|
||||
end
|
||||
|
||||
return merged
|
||||
end
|
||||
|
||||
--- Sort imports by their source/module
|
||||
---@param imports string[]
|
||||
---@param filetype string
|
||||
---@return string[]
|
||||
function M.sort_imports(imports, filetype)
|
||||
-- Group imports: stdlib/builtin first, then third-party, then local
|
||||
local builtin = {}
|
||||
local third_party = {}
|
||||
local local_imports = {}
|
||||
|
||||
for _, imp in ipairs(imports) do
|
||||
-- Detect import type based on patterns
|
||||
local is_local = false
|
||||
local is_builtin = false
|
||||
|
||||
if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then
|
||||
-- Local: starts with . or ..
|
||||
is_local = imp:match("from%s+['\"]%.") or imp:match("require%(['\"]%.")
|
||||
-- Node builtin modules
|
||||
is_builtin = imp:match("from%s+['\"]node:") or imp:match("from%s+['\"]fs['\"]")
|
||||
or imp:match("from%s+['\"]path['\"]") or imp:match("from%s+['\"]http['\"]")
|
||||
elseif filetype == "python" or filetype == "py" then
|
||||
-- Local: relative imports
|
||||
is_local = imp:match("^from%s+%.") or imp:match("^import%s+%.")
|
||||
-- Python stdlib (simplified check)
|
||||
is_builtin = imp:match("^import%s+os") or imp:match("^import%s+sys")
|
||||
or imp:match("^from%s+os%s+") or imp:match("^from%s+sys%s+")
|
||||
or imp:match("^import%s+re") or imp:match("^import%s+json")
|
||||
elseif filetype == "lua" then
|
||||
-- Local: relative requires
|
||||
is_local = imp:match("require%(['\"]%.") or imp:match("require%s+['\"]%.")
|
||||
elseif filetype == "go" then
|
||||
-- Local: project imports (contain /)
|
||||
is_local = imp:match("['\"][^'\"]+/[^'\"]+['\"]") and not imp:match("github%.com")
|
||||
end
|
||||
|
||||
if is_builtin then
|
||||
table.insert(builtin, imp)
|
||||
elseif is_local then
|
||||
table.insert(local_imports, imp)
|
||||
else
|
||||
table.insert(third_party, imp)
|
||||
end
|
||||
end
|
||||
|
||||
-- Sort each group alphabetically
|
||||
table.sort(builtin)
|
||||
table.sort(third_party)
|
||||
table.sort(local_imports)
|
||||
|
||||
-- Combine with proper spacing
|
||||
local result = {}
|
||||
|
||||
for _, imp in ipairs(builtin) do
|
||||
table.insert(result, imp)
|
||||
end
|
||||
if #builtin > 0 and (#third_party > 0 or #local_imports > 0) then
|
||||
table.insert(result, "") -- Blank line between groups
|
||||
end
|
||||
|
||||
for _, imp in ipairs(third_party) do
|
||||
table.insert(result, imp)
|
||||
end
|
||||
if #third_party > 0 and #local_imports > 0 then
|
||||
table.insert(result, "") -- Blank line between groups
|
||||
end
|
||||
|
||||
for _, imp in ipairs(local_imports) do
|
||||
table.insert(result, imp)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
---@class InjectResult
|
||||
---@field success boolean
|
||||
---@field imports_added number Number of new imports added
|
||||
---@field imports_merged boolean Whether imports were merged into existing section
|
||||
---@field body_lines number Number of body lines injected
|
||||
|
||||
--- Smart inject code into a buffer, properly handling imports
|
||||
---@param bufnr number Target buffer
|
||||
---@param code string|string[] Code to inject
|
||||
---@param opts table Options: { strategy: "append"|"replace"|"insert", range: {start_line, end_line}|nil, filetype: string|nil, sort_imports: boolean|nil }
|
||||
---@return InjectResult
|
||||
function M.inject(bufnr, code, opts)
|
||||
opts = opts or {}
|
||||
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return { success = false, imports_added = 0, imports_merged = false, body_lines = 0 }
|
||||
end
|
||||
|
||||
-- Get filetype
|
||||
local filetype = opts.filetype
|
||||
if not filetype then
|
||||
local bufname = vim.api.nvim_buf_get_name(bufnr)
|
||||
filetype = vim.fn.fnamemodify(bufname, ":e")
|
||||
end
|
||||
|
||||
-- Parse the code to separate imports from body
|
||||
local parsed = M.parse_code(code, filetype)
|
||||
|
||||
local result = {
|
||||
success = true,
|
||||
imports_added = 0,
|
||||
imports_merged = false,
|
||||
body_lines = #parsed.body,
|
||||
}
|
||||
|
||||
-- Handle imports first if there are any
|
||||
if #parsed.imports > 0 then
|
||||
local import_start, import_end = M.find_import_section(bufnr, filetype)
|
||||
|
||||
if import_start then
|
||||
-- Merge with existing import section
|
||||
local existing_imports = M.get_existing_imports(bufnr, filetype)
|
||||
local merged = M.merge_imports(existing_imports, parsed.imports)
|
||||
|
||||
-- Count how many new imports were actually added
|
||||
result.imports_added = #merged - #existing_imports
|
||||
result.imports_merged = true
|
||||
|
||||
-- Optionally sort imports
|
||||
if opts.sort_imports ~= false then
|
||||
merged = M.sort_imports(merged, filetype)
|
||||
end
|
||||
|
||||
-- Convert back to lines (handling multi-line imports)
|
||||
local import_lines = {}
|
||||
for _, imp in ipairs(merged) do
|
||||
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
|
||||
table.insert(import_lines, line)
|
||||
end
|
||||
end
|
||||
|
||||
-- Replace the import section
|
||||
vim.api.nvim_buf_set_lines(bufnr, import_start - 1, import_end, false, import_lines)
|
||||
|
||||
-- Adjust line numbers for body injection
|
||||
local lines_diff = #import_lines - (import_end - import_start + 1)
|
||||
if opts.range and opts.range.start_line and opts.range.start_line > import_end then
|
||||
opts.range.start_line = opts.range.start_line + lines_diff
|
||||
if opts.range.end_line then
|
||||
opts.range.end_line = opts.range.end_line + lines_diff
|
||||
end
|
||||
end
|
||||
else
|
||||
-- No existing import section, add imports at the top
|
||||
-- Find the first non-comment, non-empty line
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local insert_at = 0
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
local trimmed = line:match("^%s*(.-)%s*$")
|
||||
-- Skip shebang, docstrings, and initial comments
|
||||
if trimmed ~= "" and not trimmed:match("^#!")
|
||||
and not trimmed:match("^['\"]") and not is_empty_or_comment(line, filetype) then
|
||||
insert_at = i - 1
|
||||
break
|
||||
end
|
||||
insert_at = i
|
||||
end
|
||||
|
||||
-- Add imports with a trailing blank line
|
||||
local import_lines = {}
|
||||
for _, imp in ipairs(parsed.imports) do
|
||||
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
|
||||
table.insert(import_lines, line)
|
||||
end
|
||||
end
|
||||
table.insert(import_lines, "") -- Blank line after imports
|
||||
|
||||
vim.api.nvim_buf_set_lines(bufnr, insert_at, insert_at, false, import_lines)
|
||||
result.imports_added = #parsed.imports
|
||||
result.imports_merged = false
|
||||
|
||||
-- Adjust body injection range
|
||||
if opts.range and opts.range.start_line then
|
||||
opts.range.start_line = opts.range.start_line + #import_lines
|
||||
if opts.range.end_line then
|
||||
opts.range.end_line = opts.range.end_line + #import_lines
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Handle body (non-import) code
|
||||
if #parsed.body > 0 then
|
||||
-- Filter out empty leading/trailing lines from body
|
||||
local body_lines = parsed.body
|
||||
while #body_lines > 0 and body_lines[1]:match("^%s*$") do
|
||||
table.remove(body_lines, 1)
|
||||
end
|
||||
while #body_lines > 0 and body_lines[#body_lines]:match("^%s*$") do
|
||||
table.remove(body_lines)
|
||||
end
|
||||
|
||||
if #body_lines > 0 then
|
||||
local line_count = vim.api.nvim_buf_line_count(bufnr)
|
||||
local strategy = opts.strategy or "append"
|
||||
|
||||
if strategy == "replace" and opts.range then
|
||||
local start_line = math.max(1, opts.range.start_line)
|
||||
local end_line = math.min(line_count, opts.range.end_line)
|
||||
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, body_lines)
|
||||
elseif strategy == "insert" and opts.range then
|
||||
local insert_line = math.max(0, math.min(line_count, opts.range.start_line - 1))
|
||||
vim.api.nvim_buf_set_lines(bufnr, insert_line, insert_line, false, body_lines)
|
||||
else
|
||||
-- Default: append
|
||||
local last_line = vim.api.nvim_buf_get_lines(bufnr, line_count - 1, line_count, false)[1] or ""
|
||||
if last_line:match("%S") then
|
||||
-- Add blank line for spacing
|
||||
table.insert(body_lines, 1, "")
|
||||
end
|
||||
vim.api.nvim_buf_set_lines(bufnr, line_count, line_count, false, body_lines)
|
||||
end
|
||||
|
||||
result.body_lines = #body_lines
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Check if code contains imports
|
||||
---@param code string|string[]
|
||||
---@param filetype string
|
||||
---@return boolean
|
||||
function M.has_imports(code, filetype)
|
||||
local parsed = M.parse_code(code, filetype)
|
||||
return #parsed.imports > 0
|
||||
end
|
||||
|
||||
return M
|
||||
398
lua/codetyper/agent/loop.lua
Normal file
398
lua/codetyper/agent/loop.lua
Normal file
@@ -0,0 +1,398 @@
|
||||
---@mod codetyper.agent.loop Agent loop with tool orchestration
|
||||
---@brief [[
|
||||
--- Main agent loop that handles multi-turn conversations with tool use.
|
||||
--- Inspired by avante.nvim's agent_loop pattern.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class AgentMessage
|
||||
---@field role "system"|"user"|"assistant"|"tool"
|
||||
---@field content string|table
|
||||
---@field tool_call_id? string For tool responses
|
||||
---@field tool_calls? table[] For assistant tool calls
|
||||
---@field name? string Tool name for tool responses
|
||||
|
||||
---@class AgentLoopOpts
|
||||
---@field system_prompt string System prompt
|
||||
---@field user_input string Initial user message
|
||||
---@field tools? CoderTool[] Available tools (default: all registered)
|
||||
---@field max_iterations? number Max tool call iterations (default: 10)
|
||||
---@field provider? string LLM provider to use
|
||||
---@field on_start? fun() Called when loop starts
|
||||
---@field on_chunk? fun(chunk: string) Called for each response chunk
|
||||
---@field on_tool_call? fun(name: string, input: table) Called before tool execution
|
||||
---@field on_tool_result? fun(name: string, result: any, error: string|nil) Called after tool execution
|
||||
---@field on_message? fun(message: AgentMessage) Called for each message added
|
||||
---@field on_complete? fun(result: string|nil, error: string|nil) Called when loop completes
|
||||
---@field session_ctx? table Session context shared across tools
|
||||
|
||||
--- Format tool definitions for OpenAI-compatible API
|
||||
---@param tools CoderTool[]
|
||||
---@return table[]
|
||||
local function format_tools_for_api(tools)
|
||||
local formatted = {}
|
||||
for _, tool in ipairs(tools) do
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(tool.params or {}) do
|
||||
properties[param.name] = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
table.insert(formatted, {
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool.name,
|
||||
description = type(tool.description) == "function" and tool.description() or tool.description,
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
},
|
||||
})
|
||||
end
|
||||
return formatted
|
||||
end
|
||||
|
||||
--- Parse tool calls from LLM response
|
||||
---@param response table LLM response
|
||||
---@return table[] tool_calls
|
||||
local function parse_tool_calls(response)
|
||||
local tool_calls = {}
|
||||
|
||||
-- Handle different response formats
|
||||
if response.tool_calls then
|
||||
-- OpenAI format
|
||||
for _, call in ipairs(response.tool_calls) do
|
||||
local args = call["function"].arguments
|
||||
if type(args) == "string" then
|
||||
local ok, parsed = pcall(vim.json.decode, args)
|
||||
if ok then
|
||||
args = parsed
|
||||
end
|
||||
end
|
||||
table.insert(tool_calls, {
|
||||
id = call.id,
|
||||
name = call["function"].name,
|
||||
input = args,
|
||||
})
|
||||
end
|
||||
elseif response.content and type(response.content) == "table" then
|
||||
-- Claude format (content blocks)
|
||||
for _, block in ipairs(response.content) do
|
||||
if block.type == "tool_use" then
|
||||
table.insert(tool_calls, {
|
||||
id = block.id,
|
||||
name = block.name,
|
||||
input = block.input,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return tool_calls
|
||||
end
|
||||
|
||||
--- Build messages for LLM request
|
||||
---@param history AgentMessage[]
|
||||
---@return table[]
|
||||
local function build_messages(history)
|
||||
local messages = {}
|
||||
|
||||
for _, msg in ipairs(history) do
|
||||
if msg.role == "system" then
|
||||
table.insert(messages, {
|
||||
role = "system",
|
||||
content = msg.content,
|
||||
})
|
||||
elseif msg.role == "user" then
|
||||
table.insert(messages, {
|
||||
role = "user",
|
||||
content = msg.content,
|
||||
})
|
||||
elseif msg.role == "assistant" then
|
||||
local message = {
|
||||
role = "assistant",
|
||||
content = msg.content,
|
||||
}
|
||||
if msg.tool_calls then
|
||||
message.tool_calls = msg.tool_calls
|
||||
end
|
||||
table.insert(messages, message)
|
||||
elseif msg.role == "tool" then
|
||||
table.insert(messages, {
|
||||
role = "tool",
|
||||
tool_call_id = msg.tool_call_id,
|
||||
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return messages
|
||||
end
|
||||
|
||||
--- Execute the agent loop
|
||||
---@param opts AgentLoopOpts
|
||||
function M.run(opts)
|
||||
local tools_mod = require("codetyper.agent.tools")
|
||||
local llm = require("codetyper.llm")
|
||||
|
||||
-- Get tools
|
||||
local tools = opts.tools or tools_mod.list()
|
||||
local tool_map = {}
|
||||
for _, tool in ipairs(tools) do
|
||||
tool_map[tool.name] = tool
|
||||
end
|
||||
|
||||
-- Initialize conversation history
|
||||
---@type AgentMessage[]
|
||||
local history = {
|
||||
{ role = "system", content = opts.system_prompt },
|
||||
{ role = "user", content = opts.user_input },
|
||||
}
|
||||
|
||||
local session_ctx = opts.session_ctx or {}
|
||||
local max_iterations = opts.max_iterations or 10
|
||||
local iteration = 0
|
||||
|
||||
-- Callback wrappers
|
||||
local function on_message(msg)
|
||||
if opts.on_message then
|
||||
opts.on_message(msg)
|
||||
end
|
||||
end
|
||||
|
||||
-- Notify of initial messages
|
||||
for _, msg in ipairs(history) do
|
||||
on_message(msg)
|
||||
end
|
||||
|
||||
-- Start notification
|
||||
if opts.on_start then
|
||||
opts.on_start()
|
||||
end
|
||||
|
||||
--- Process one iteration of the loop
|
||||
local function process_iteration()
|
||||
iteration = iteration + 1
|
||||
|
||||
if iteration > max_iterations then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "Max iterations reached")
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Build request
|
||||
local messages = build_messages(history)
|
||||
local formatted_tools = format_tools_for_api(tools)
|
||||
|
||||
-- Build context for LLM
|
||||
local context = {
|
||||
file_content = "",
|
||||
language = "lua",
|
||||
extension = "lua",
|
||||
prompt_type = "agent",
|
||||
tools = formatted_tools,
|
||||
}
|
||||
|
||||
-- Get LLM response
|
||||
local client = llm.get_client()
|
||||
if not client then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "No LLM client available")
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Build prompt from messages
|
||||
local prompt_parts = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(prompt_parts, string.format("[%s]: %s", msg.role, msg.content or ""))
|
||||
end
|
||||
end
|
||||
local prompt = table.concat(prompt_parts, "\n\n")
|
||||
|
||||
client.generate(prompt, context, function(response, error)
|
||||
if error then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, error)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Chunk callback
|
||||
if opts.on_chunk then
|
||||
opts.on_chunk(response)
|
||||
end
|
||||
|
||||
-- Parse response for tool calls
|
||||
-- For now, we'll use a simple heuristic to detect tool calls in the response
|
||||
-- In a full implementation, the LLM would return structured tool calls
|
||||
local tool_calls = {}
|
||||
|
||||
-- Try to parse JSON tool calls from response
|
||||
local json_match = response:match("```json%s*(%b{})%s*```")
|
||||
if json_match then
|
||||
local ok, parsed = pcall(vim.json.decode, json_match)
|
||||
if ok and parsed.tool_calls then
|
||||
tool_calls = parsed.tool_calls
|
||||
end
|
||||
end
|
||||
|
||||
-- Add assistant message
|
||||
local assistant_msg = {
|
||||
role = "assistant",
|
||||
content = response,
|
||||
tool_calls = #tool_calls > 0 and tool_calls or nil,
|
||||
}
|
||||
table.insert(history, assistant_msg)
|
||||
on_message(assistant_msg)
|
||||
|
||||
-- Process tool calls
|
||||
if #tool_calls > 0 then
|
||||
local pending = #tool_calls
|
||||
local results = {}
|
||||
|
||||
for i, call in ipairs(tool_calls) do
|
||||
local tool = tool_map[call.name]
|
||||
if not tool then
|
||||
results[i] = { error = "Unknown tool: " .. call.name }
|
||||
pending = pending - 1
|
||||
else
|
||||
-- Notify of tool call
|
||||
if opts.on_tool_call then
|
||||
opts.on_tool_call(call.name, call.input)
|
||||
end
|
||||
|
||||
-- Execute tool
|
||||
local tool_opts = {
|
||||
on_log = function(msg)
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
logs.add({ type = "tool", message = msg })
|
||||
end)
|
||||
end,
|
||||
on_complete = function(result, err)
|
||||
results[i] = { result = result, error = err }
|
||||
pending = pending - 1
|
||||
|
||||
-- Notify of tool result
|
||||
if opts.on_tool_result then
|
||||
opts.on_tool_result(call.name, result, err)
|
||||
end
|
||||
|
||||
-- Add tool response to history
|
||||
local tool_msg = {
|
||||
role = "tool",
|
||||
tool_call_id = call.id or tostring(i),
|
||||
name = call.name,
|
||||
content = err or result,
|
||||
}
|
||||
table.insert(history, tool_msg)
|
||||
on_message(tool_msg)
|
||||
|
||||
-- Continue loop when all tools complete
|
||||
if pending == 0 then
|
||||
vim.schedule(process_iteration)
|
||||
end
|
||||
end,
|
||||
session_ctx = session_ctx,
|
||||
}
|
||||
|
||||
-- Validate and execute
|
||||
local valid, validation_err = true, nil
|
||||
if tool.validate_input then
|
||||
valid, validation_err = tool:validate_input(call.input)
|
||||
end
|
||||
|
||||
if not valid then
|
||||
tool_opts.on_complete(nil, validation_err)
|
||||
else
|
||||
local result, err = tool.func(call.input, tool_opts)
|
||||
-- If sync result, call on_complete
|
||||
if result ~= nil or err ~= nil then
|
||||
tool_opts.on_complete(result, err)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
else
|
||||
-- No tool calls - loop complete
|
||||
if opts.on_complete then
|
||||
opts.on_complete(response, nil)
|
||||
end
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
-- Start the loop
|
||||
process_iteration()
|
||||
end
|
||||
|
||||
--- Create an agent with default settings
|
||||
---@param task string Task description
|
||||
---@param opts? AgentLoopOpts Additional options
|
||||
function M.create(task, opts)
|
||||
opts = opts or {}
|
||||
|
||||
local system_prompt = opts.system_prompt or [[You are a helpful coding assistant with access to tools.
|
||||
|
||||
Available tools:
|
||||
- view: Read file contents
|
||||
- grep: Search for patterns in files
|
||||
- glob: Find files by pattern
|
||||
- edit: Make targeted edits to files
|
||||
- write: Create or overwrite files
|
||||
- bash: Execute shell commands
|
||||
|
||||
When you need to perform a task:
|
||||
1. Use tools to gather information
|
||||
2. Plan your approach
|
||||
3. Execute changes using appropriate tools
|
||||
4. Verify the results
|
||||
|
||||
Always explain your reasoning before using tools.
|
||||
When you're done, provide a clear summary of what was accomplished.]]
|
||||
|
||||
M.run(vim.tbl_extend("force", opts, {
|
||||
system_prompt = system_prompt,
|
||||
user_input = task,
|
||||
}))
|
||||
end
|
||||
|
||||
--- Simple dispatch agent for sub-tasks
|
||||
---@param prompt string Task for the sub-agent
|
||||
---@param on_complete fun(result: string|nil, error: string|nil) Completion callback
|
||||
---@param opts? table Additional options
|
||||
function M.dispatch(prompt, on_complete, opts)
|
||||
opts = opts or {}
|
||||
|
||||
-- Sub-agents get limited tools by default
|
||||
local tools_mod = require("codetyper.agent.tools")
|
||||
local safe_tools = tools_mod.list(function(tool)
|
||||
return tool.name == "view" or tool.name == "grep" or tool.name == "glob"
|
||||
end)
|
||||
|
||||
M.run({
|
||||
system_prompt = [[You are a research assistant. Your task is to find information and report back.
|
||||
You have access to: view (read files), grep (search content), glob (find files).
|
||||
Be thorough and report your findings clearly.]],
|
||||
user_input = prompt,
|
||||
tools = opts.tools or safe_tools,
|
||||
max_iterations = opts.max_iterations or 5,
|
||||
on_complete = on_complete,
|
||||
session_ctx = opts.session_ctx,
|
||||
})
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -2,10 +2,16 @@
|
||||
---@brief [[
|
||||
--- Manages code patches with buffer snapshots for staleness detection.
|
||||
--- Patches are queued for safe injection when completion popup is not visible.
|
||||
--- Uses smart injection for intelligent import merging.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Lazy load inject module to avoid circular requires
|
||||
local function get_inject_module()
|
||||
return require("codetyper.agent.inject")
|
||||
end
|
||||
|
||||
---@class BufferSnapshot
|
||||
---@field bufnr number Buffer number
|
||||
---@field changedtick number vim.b.changedtick at snapshot time
|
||||
@@ -15,7 +21,8 @@ local M = {}
|
||||
---@class PatchCandidate
|
||||
---@field id string Unique patch ID
|
||||
---@field event_id string Related PromptEvent ID
|
||||
---@field target_bufnr number Target buffer for injection
|
||||
---@field source_bufnr number Source buffer where prompt tags are (coder file)
|
||||
---@field target_bufnr number Target buffer for injection (real file)
|
||||
---@field target_path string Target file path
|
||||
---@field original_snapshot BufferSnapshot Snapshot at event creation
|
||||
---@field generated_code string Code to inject
|
||||
@@ -171,7 +178,10 @@ end
|
||||
---@param strategy string|nil Injection strategy (overrides intent-based)
|
||||
---@return PatchCandidate
|
||||
function M.create_from_event(event, generated_code, confidence, strategy)
|
||||
-- Get target buffer
|
||||
-- Source buffer is where the prompt tags are (could be coder file)
|
||||
local source_bufnr = event.bufnr
|
||||
|
||||
-- Get target buffer (where code should be injected - the real file)
|
||||
local target_bufnr = vim.fn.bufnr(event.target_path)
|
||||
if target_bufnr == -1 then
|
||||
-- Try to find by filename
|
||||
@@ -220,7 +230,8 @@ function M.create_from_event(event, generated_code, confidence, strategy)
|
||||
return {
|
||||
id = M.generate_id(),
|
||||
event_id = event.id,
|
||||
target_bufnr = target_bufnr,
|
||||
source_bufnr = source_bufnr, -- Where prompt tags are (coder file)
|
||||
target_bufnr = target_bufnr, -- Where code goes (real file)
|
||||
target_path = event.target_path,
|
||||
original_snapshot = snapshot,
|
||||
generated_code = generated_code,
|
||||
@@ -453,39 +464,56 @@ function M.apply(patch)
|
||||
-- Prepare code lines
|
||||
local code_lines = vim.split(patch.generated_code, "\n", { plain = true })
|
||||
|
||||
-- FIRST: Remove the prompt tags from the buffer before applying code
|
||||
-- This prevents the infinite loop where tags stay and get re-detected
|
||||
local tags_removed = remove_prompt_tags(target_bufnr)
|
||||
-- FIRST: Remove the prompt tags from the SOURCE buffer (coder file), not target
|
||||
-- The tags are in the coder file where the user wrote the prompt
|
||||
-- Code goes to target file, tags get removed from source file
|
||||
local source_bufnr = patch.source_bufnr
|
||||
local tags_removed = 0
|
||||
|
||||
pcall(function()
|
||||
if tags_removed > 0 then
|
||||
local logs = require("codetyper.agent.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Removed %d prompt tag(s) from buffer", tags_removed),
|
||||
})
|
||||
end
|
||||
end)
|
||||
if source_bufnr and vim.api.nvim_buf_is_valid(source_bufnr) then
|
||||
tags_removed = remove_prompt_tags(source_bufnr)
|
||||
|
||||
-- Recalculate line count after tag removal
|
||||
local line_count = vim.api.nvim_buf_line_count(target_bufnr)
|
||||
pcall(function()
|
||||
if tags_removed > 0 then
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local source_name = vim.api.nvim_buf_get_name(source_bufnr)
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Removed %d prompt tag(s) from %s",
|
||||
tags_removed,
|
||||
vim.fn.fnamemodify(source_name, ":t")),
|
||||
})
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
-- Apply based on strategy
|
||||
-- Get filetype for smart injection
|
||||
local filetype = vim.fn.fnamemodify(patch.target_path or "", ":e")
|
||||
|
||||
-- Use smart injection module for intelligent import handling
|
||||
local inject = get_inject_module()
|
||||
local inject_result = nil
|
||||
|
||||
-- Apply based on strategy using smart injection
|
||||
local ok, err = pcall(function()
|
||||
-- Prepare injection options
|
||||
local inject_opts = {
|
||||
strategy = patch.injection_strategy,
|
||||
filetype = filetype,
|
||||
sort_imports = true,
|
||||
}
|
||||
|
||||
if patch.injection_strategy == "replace" and patch.injection_range then
|
||||
-- Replace the scope range with the new code
|
||||
-- The injection_range points to the function/method we're completing
|
||||
local start_line = patch.injection_range.start_line
|
||||
local end_line = patch.injection_range.end_line
|
||||
|
||||
-- Adjust for tag removal - find the new range by searching for the scope
|
||||
-- After removing tags, line numbers may have shifted
|
||||
-- Use the scope information to find the correct range
|
||||
if patch.scope and patch.scope.type then
|
||||
-- Try to find the scope using treesitter if available
|
||||
local found_range = nil
|
||||
pcall(function()
|
||||
local ts_utils = require("nvim-treesitter.ts_utils")
|
||||
local parsers = require("nvim-treesitter.parsers")
|
||||
local parser = parsers.get_parser(target_bufnr)
|
||||
if parser then
|
||||
@@ -528,34 +556,38 @@ function M.apply(patch)
|
||||
end
|
||||
|
||||
-- Clamp to valid range
|
||||
local line_count = vim.api.nvim_buf_line_count(target_bufnr)
|
||||
start_line = math.max(1, start_line)
|
||||
end_line = math.min(line_count, end_line)
|
||||
|
||||
-- Replace the range (0-indexed for nvim_buf_set_lines)
|
||||
vim.api.nvim_buf_set_lines(target_bufnr, start_line - 1, end_line, false, code_lines)
|
||||
inject_opts.range = { start_line = start_line, end_line = end_line }
|
||||
elseif patch.injection_strategy == "insert" and patch.injection_range then
|
||||
inject_opts.range = { start_line = patch.injection_range.start_line }
|
||||
end
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
-- Use smart injection - handles imports automatically
|
||||
inject_result = inject.inject(target_bufnr, patch.generated_code, inject_opts)
|
||||
|
||||
-- Log injection details
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
if inject_result.imports_added > 0 then
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Replacing lines %d-%d with %d lines of code", start_line, end_line, #code_lines),
|
||||
message = string.format(
|
||||
"%s %d import(s), injected %d body line(s)",
|
||||
inject_result.imports_merged and "Merged" or "Added",
|
||||
inject_result.imports_added,
|
||||
inject_result.body_lines
|
||||
),
|
||||
})
|
||||
else
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Injected %d line(s) of code", inject_result.body_lines),
|
||||
})
|
||||
end)
|
||||
elseif patch.injection_strategy == "insert" and patch.injection_range then
|
||||
-- Insert at the specified location
|
||||
local insert_line = patch.injection_range.start_line
|
||||
insert_line = math.max(1, math.min(line_count + 1, insert_line))
|
||||
vim.api.nvim_buf_set_lines(target_bufnr, insert_line - 1, insert_line - 1, false, code_lines)
|
||||
else
|
||||
-- Default: append to end
|
||||
-- Check if last line is empty, if not add a blank line for spacing
|
||||
local last_line = vim.api.nvim_buf_get_lines(target_bufnr, line_count - 1, line_count, false)[1] or ""
|
||||
if last_line:match("%S") then
|
||||
-- Last line has content, add blank line for spacing
|
||||
table.insert(code_lines, 1, "")
|
||||
end
|
||||
vim.api.nvim_buf_set_lines(target_bufnr, line_count, line_count, false, code_lines)
|
||||
end
|
||||
end)
|
||||
end)
|
||||
|
||||
if not ok then
|
||||
@@ -577,6 +609,41 @@ function M.apply(patch)
|
||||
})
|
||||
end)
|
||||
|
||||
-- Learn from successful code generation - this builds neural pathways
|
||||
-- The more code is successfully applied, the better the brain becomes
|
||||
pcall(function()
|
||||
local brain = require("codetyper.brain")
|
||||
if brain.is_initialized() then
|
||||
-- Learn the successful pattern
|
||||
local intent_type = patch.intent and patch.intent.type or "unknown"
|
||||
local scope_type = patch.scope and patch.scope.type or "file"
|
||||
local scope_name = patch.scope and patch.scope.name or ""
|
||||
|
||||
-- Create a meaningful summary for this learning
|
||||
local summary = string.format(
|
||||
"Generated %s: %s %s in %s",
|
||||
intent_type,
|
||||
scope_type,
|
||||
scope_name ~= "" and scope_name or "",
|
||||
vim.fn.fnamemodify(patch.target_path or "", ":t")
|
||||
)
|
||||
|
||||
brain.learn({
|
||||
type = "code_completion",
|
||||
file = patch.target_path,
|
||||
timestamp = os.time(),
|
||||
data = {
|
||||
intent = intent_type,
|
||||
code = patch.generated_code:sub(1, 500), -- Store first 500 chars
|
||||
language = vim.fn.fnamemodify(patch.target_path or "", ":e"),
|
||||
function_name = scope_name,
|
||||
prompt = patch.prompt_content,
|
||||
confidence = patch.confidence or 0.5,
|
||||
},
|
||||
})
|
||||
end
|
||||
end)
|
||||
|
||||
return true, nil
|
||||
end
|
||||
|
||||
|
||||
128
lua/codetyper/agent/tools/base.lua
Normal file
128
lua/codetyper/agent/tools/base.lua
Normal file
@@ -0,0 +1,128 @@
|
||||
---@mod codetyper.agent.tools.base Base tool definition
|
||||
---@brief [[
|
||||
--- Base metatable for all LLM tools.
|
||||
--- Tools extend this base to provide structured AI capabilities.
|
||||
---@brief ]]
|
||||
|
||||
---@class CoderToolParam
|
||||
---@field name string Parameter name
|
||||
---@field description string Parameter description
|
||||
---@field type string Parameter type ("string", "number", "boolean", "table")
|
||||
---@field optional? boolean Whether the parameter is optional
|
||||
---@field default? any Default value for optional parameters
|
||||
|
||||
---@class CoderToolReturn
|
||||
---@field name string Return value name
|
||||
---@field description string Return value description
|
||||
---@field type string Return type
|
||||
---@field optional? boolean Whether the return is optional
|
||||
|
||||
---@class CoderToolOpts
|
||||
---@field on_log? fun(message: string) Log callback
|
||||
---@field on_complete? fun(result: any, error: string|nil) Completion callback
|
||||
---@field session_ctx? table Session context
|
||||
---@field streaming? boolean Whether response is still streaming
|
||||
---@field confirm? fun(message: string, callback: fun(ok: boolean)) Confirmation callback
|
||||
|
||||
---@class CoderTool
|
||||
---@field name string Tool identifier
|
||||
---@field description string|fun(): string Tool description
|
||||
---@field params CoderToolParam[] Input parameters
|
||||
---@field returns CoderToolReturn[] Return values
|
||||
---@field requires_confirmation? boolean Whether tool needs user confirmation
|
||||
---@field func fun(input: table, opts: CoderToolOpts): any, string|nil Tool implementation
|
||||
|
||||
local M = {}
|
||||
M.__index = M
|
||||
|
||||
--- Call the tool function
|
||||
---@param opts CoderToolOpts Options for the tool call
|
||||
---@return any result
|
||||
---@return string|nil error
|
||||
function M:__call(opts, on_log, on_complete)
|
||||
return self.func(opts, on_log, on_complete)
|
||||
end
|
||||
|
||||
--- Get the tool description
|
||||
---@return string
|
||||
function M:get_description()
|
||||
if type(self.description) == "function" then
|
||||
return self.description()
|
||||
end
|
||||
return self.description
|
||||
end
|
||||
|
||||
--- Validate input against parameter schema
|
||||
---@param input table Input to validate
|
||||
---@return boolean valid
|
||||
---@return string|nil error
|
||||
function M:validate_input(input)
|
||||
if not self.params then
|
||||
return true
|
||||
end
|
||||
|
||||
for _, param in ipairs(self.params) do
|
||||
local value = input[param.name]
|
||||
|
||||
-- Check required parameters
|
||||
if not param.optional and value == nil then
|
||||
return false, string.format("Missing required parameter: %s", param.name)
|
||||
end
|
||||
|
||||
-- Type checking
|
||||
if value ~= nil then
|
||||
local actual_type = type(value)
|
||||
local expected_type = param.type
|
||||
|
||||
-- Handle special types
|
||||
if expected_type == "integer" and actual_type == "number" then
|
||||
if math.floor(value) ~= value then
|
||||
return false, string.format("Parameter %s must be an integer", param.name)
|
||||
end
|
||||
elseif expected_type ~= actual_type and expected_type ~= "any" then
|
||||
return false, string.format("Parameter %s must be %s, got %s", param.name, expected_type, actual_type)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Generate JSON schema for the tool (for LLM function calling)
|
||||
---@return table schema
|
||||
function M:to_schema()
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(self.params or {}) do
|
||||
local prop = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
|
||||
if param.default ~= nil then
|
||||
prop.default = param.default
|
||||
end
|
||||
|
||||
properties[param.name] = prop
|
||||
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
return {
|
||||
type = "function",
|
||||
function_def = {
|
||||
name = self.name,
|
||||
description = self:get_description(),
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
return M
|
||||
198
lua/codetyper/agent/tools/bash.lua
Normal file
198
lua/codetyper/agent/tools/bash.lua
Normal file
@@ -0,0 +1,198 @@
|
||||
---@mod codetyper.agent.tools.bash Shell command execution tool
|
||||
---@brief [[
|
||||
--- Tool for executing shell commands with safety checks.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "bash"
|
||||
|
||||
M.description = [[Executes a bash command in a shell.
|
||||
|
||||
IMPORTANT RULES:
|
||||
- Do NOT use bash to read files (use 'view' tool instead)
|
||||
- Do NOT use bash to modify files (use 'write' or 'edit' tools instead)
|
||||
- Do NOT use interactive commands (vim, nano, less, etc.)
|
||||
- Commands timeout after 2 minutes by default
|
||||
|
||||
Allowed uses:
|
||||
- Running builds (make, npm run build, cargo build)
|
||||
- Running tests (npm test, pytest, cargo test)
|
||||
- Git operations (git status, git diff, git commit)
|
||||
- Package management (npm install, pip install)
|
||||
- System info commands (ls, pwd, which)]]
|
||||
|
||||
M.params = {
|
||||
{
|
||||
name = "command",
|
||||
description = "The shell command to execute",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "cwd",
|
||||
description = "Working directory for the command (optional)",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
{
|
||||
name = "timeout",
|
||||
description = "Timeout in milliseconds (default: 120000)",
|
||||
type = "integer",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.returns = {
|
||||
{
|
||||
name = "stdout",
|
||||
description = "Command output",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if command failed",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.requires_confirmation = true
|
||||
|
||||
--- Banned commands for safety
|
||||
local BANNED_COMMANDS = {
|
||||
"rm -rf /",
|
||||
"rm -rf /*",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
":(){ :|:& };:",
|
||||
"> /dev/sda",
|
||||
}
|
||||
|
||||
--- Banned patterns
|
||||
local BANNED_PATTERNS = {
|
||||
"curl.*|.*sh",
|
||||
"wget.*|.*sh",
|
||||
"rm%s+%-rf%s+/",
|
||||
}
|
||||
|
||||
--- Check if command is safe
|
||||
---@param command string
|
||||
---@return boolean safe
|
||||
---@return string|nil reason
|
||||
local function is_safe_command(command)
|
||||
-- Check exact matches
|
||||
for _, banned in ipairs(BANNED_COMMANDS) do
|
||||
if command == banned then
|
||||
return false, "Command is banned for safety"
|
||||
end
|
||||
end
|
||||
|
||||
-- Check patterns
|
||||
for _, pattern in ipairs(BANNED_PATTERNS) do
|
||||
if command:match(pattern) then
|
||||
return false, "Command matches banned pattern"
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
---@param input {command: string, cwd?: string, timeout?: integer}
|
||||
---@param opts CoderToolOpts
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.command then
|
||||
return nil, "command is required"
|
||||
end
|
||||
|
||||
-- Safety check
|
||||
local safe, reason = is_safe_command(input.command)
|
||||
if not safe then
|
||||
return nil, reason
|
||||
end
|
||||
|
||||
-- Confirmation required
|
||||
if M.requires_confirmation and opts.confirm then
|
||||
local confirmed = false
|
||||
local confirm_error = nil
|
||||
|
||||
opts.confirm("Execute command: " .. input.command, function(ok)
|
||||
if not ok then
|
||||
confirm_error = "User declined command execution"
|
||||
end
|
||||
confirmed = ok
|
||||
end)
|
||||
|
||||
-- Wait for confirmation (in async context, this would be handled differently)
|
||||
if confirm_error then
|
||||
return nil, confirm_error
|
||||
end
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Executing: " .. input.command)
|
||||
end
|
||||
|
||||
-- Prepare command
|
||||
local cwd = input.cwd or vim.fn.getcwd()
|
||||
local timeout = input.timeout or 120000
|
||||
|
||||
-- Execute command
|
||||
local output = ""
|
||||
local exit_code = 0
|
||||
|
||||
local job_opts = {
|
||||
command = "bash",
|
||||
args = { "-c", input.command },
|
||||
cwd = cwd,
|
||||
on_stdout = function(_, data)
|
||||
if data then
|
||||
output = output .. table.concat(data, "\n")
|
||||
end
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data then
|
||||
output = output .. table.concat(data, "\n")
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
exit_code = code
|
||||
end,
|
||||
}
|
||||
|
||||
-- Run synchronously with timeout
|
||||
local Job = require("plenary.job")
|
||||
local job = Job:new(job_opts)
|
||||
|
||||
job:sync(timeout)
|
||||
exit_code = job.code or 0
|
||||
output = table.concat(job:result() or {}, "\n")
|
||||
|
||||
-- Also get stderr
|
||||
local stderr = table.concat(job:stderr_result() or {}, "\n")
|
||||
if stderr and stderr ~= "" then
|
||||
output = output .. "\n" .. stderr
|
||||
end
|
||||
|
||||
-- Check result
|
||||
if exit_code ~= 0 then
|
||||
local error_msg = string.format("Command failed with exit code %d: %s", exit_code, output)
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, error_msg)
|
||||
end
|
||||
return nil, error_msg
|
||||
end
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(output, nil)
|
||||
end
|
||||
|
||||
return output, nil
|
||||
end
|
||||
|
||||
return M
|
||||
429
lua/codetyper/agent/tools/edit.lua
Normal file
429
lua/codetyper/agent/tools/edit.lua
Normal file
@@ -0,0 +1,429 @@
|
||||
---@mod codetyper.agent.tools.edit File editing tool with fallback matching
|
||||
---@brief [[
|
||||
--- Tool for making targeted edits to files using search/replace.
|
||||
--- Implements multiple fallback strategies for robust matching.
|
||||
--- Inspired by opencode's 9-strategy approach.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "edit"
|
||||
|
||||
M.description = [[Makes a targeted edit to a file by replacing text.
|
||||
|
||||
The old_string should match the content you want to replace. The tool uses multiple
|
||||
matching strategies with fallbacks:
|
||||
1. Exact match
|
||||
2. Whitespace-normalized match
|
||||
3. Indentation-flexible match
|
||||
4. Line-trimmed match
|
||||
5. Fuzzy anchor-based match
|
||||
|
||||
For creating new files, use old_string="" and provide the full content in new_string.
|
||||
For large changes, consider using 'write' tool instead.]]
|
||||
|
||||
M.params = {
|
||||
{
|
||||
name = "path",
|
||||
description = "Path to the file to edit",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "old_string",
|
||||
description = "Text to find and replace (empty string to create new file or append)",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "new_string",
|
||||
description = "Text to replace with",
|
||||
type = "string",
|
||||
},
|
||||
}
|
||||
|
||||
M.returns = {
|
||||
{
|
||||
name = "success",
|
||||
description = "Whether the edit was applied",
|
||||
type = "boolean",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if edit failed",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.requires_confirmation = false
|
||||
|
||||
--- Normalize line endings to LF
|
||||
---@param str string
|
||||
---@return string
|
||||
local function normalize_line_endings(str)
|
||||
return str:gsub("\r\n", "\n"):gsub("\r", "\n")
|
||||
end
|
||||
|
||||
--- Strategy 1: Exact match
|
||||
---@param content string File content
|
||||
---@param old_str string String to find
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
local function exact_match(content, old_str)
|
||||
local pos = content:find(old_str, 1, true)
|
||||
if pos then
|
||||
return pos, pos + #old_str - 1
|
||||
end
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
--- Strategy 2: Whitespace-normalized match
|
||||
--- Collapses all whitespace to single spaces
|
||||
---@param content string
|
||||
---@param old_str string
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
local function whitespace_normalized_match(content, old_str)
|
||||
local function normalize_ws(s)
|
||||
return s:gsub("%s+", " "):gsub("^%s+", ""):gsub("%s+$", "")
|
||||
end
|
||||
|
||||
local norm_old = normalize_ws(old_str)
|
||||
local lines = vim.split(content, "\n")
|
||||
|
||||
-- Try to find matching block
|
||||
for i = 1, #lines do
|
||||
local block = {}
|
||||
local block_start = nil
|
||||
|
||||
for j = i, #lines do
|
||||
table.insert(block, lines[j])
|
||||
local block_text = table.concat(block, "\n")
|
||||
local norm_block = normalize_ws(block_text)
|
||||
|
||||
if norm_block == norm_old then
|
||||
-- Found match
|
||||
local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
|
||||
local start_pos = #before + (i > 1 and 2 or 1)
|
||||
local end_pos = start_pos + #block_text - 1
|
||||
return start_pos, end_pos
|
||||
end
|
||||
|
||||
-- If block is already longer than target, stop
|
||||
if #norm_block > #norm_old then
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
--- Strategy 3: Indentation-flexible match
|
||||
--- Ignores leading whitespace differences
|
||||
---@param content string
|
||||
---@param old_str string
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
local function indentation_flexible_match(content, old_str)
|
||||
local function strip_indent(s)
|
||||
local lines = vim.split(s, "\n")
|
||||
local result = {}
|
||||
for _, line in ipairs(lines) do
|
||||
table.insert(result, line:gsub("^%s+", ""))
|
||||
end
|
||||
return table.concat(result, "\n")
|
||||
end
|
||||
|
||||
local stripped_old = strip_indent(old_str)
|
||||
local lines = vim.split(content, "\n")
|
||||
local old_lines = vim.split(old_str, "\n")
|
||||
local num_old_lines = #old_lines
|
||||
|
||||
for i = 1, #lines - num_old_lines + 1 do
|
||||
local block = vim.list_slice(lines, i, i + num_old_lines - 1)
|
||||
local block_text = table.concat(block, "\n")
|
||||
|
||||
if strip_indent(block_text) == stripped_old then
|
||||
local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
|
||||
local start_pos = #before + (i > 1 and 2 or 1)
|
||||
local end_pos = start_pos + #block_text - 1
|
||||
return start_pos, end_pos
|
||||
end
|
||||
end
|
||||
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
--- Strategy 4: Line-trimmed match
|
||||
--- Trims each line before comparing
|
||||
---@param content string
|
||||
---@param old_str string
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
local function line_trimmed_match(content, old_str)
|
||||
local function trim_lines(s)
|
||||
local lines = vim.split(s, "\n")
|
||||
local result = {}
|
||||
for _, line in ipairs(lines) do
|
||||
table.insert(result, line:match("^%s*(.-)%s*$"))
|
||||
end
|
||||
return table.concat(result, "\n")
|
||||
end
|
||||
|
||||
local trimmed_old = trim_lines(old_str)
|
||||
local lines = vim.split(content, "\n")
|
||||
local old_lines = vim.split(old_str, "\n")
|
||||
local num_old_lines = #old_lines
|
||||
|
||||
for i = 1, #lines - num_old_lines + 1 do
|
||||
local block = vim.list_slice(lines, i, i + num_old_lines - 1)
|
||||
local block_text = table.concat(block, "\n")
|
||||
|
||||
if trim_lines(block_text) == trimmed_old then
|
||||
local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
|
||||
local start_pos = #before + (i > 1 and 2 or 1)
|
||||
local end_pos = start_pos + #block_text - 1
|
||||
return start_pos, end_pos
|
||||
end
|
||||
end
|
||||
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
--- Calculate Levenshtein distance between two strings
|
||||
---@param s1 string
|
||||
---@param s2 string
|
||||
---@return number
|
||||
local function levenshtein(s1, s2)
|
||||
local len1, len2 = #s1, #s2
|
||||
local matrix = {}
|
||||
|
||||
for i = 0, len1 do
|
||||
matrix[i] = { [0] = i }
|
||||
end
|
||||
for j = 0, len2 do
|
||||
matrix[0][j] = j
|
||||
end
|
||||
|
||||
for i = 1, len1 do
|
||||
for j = 1, len2 do
|
||||
local cost = s1:sub(i, i) == s2:sub(j, j) and 0 or 1
|
||||
matrix[i][j] = math.min(
|
||||
matrix[i - 1][j] + 1,
|
||||
matrix[i][j - 1] + 1,
|
||||
matrix[i - 1][j - 1] + cost
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
return matrix[len1][len2]
|
||||
end
|
||||
|
||||
--- Strategy 5: Fuzzy anchor-based match
|
||||
--- Uses first and last lines as anchors, allows fuzzy matching in between
|
||||
---@param content string
|
||||
---@param old_str string
|
||||
---@param threshold? number Similarity threshold (0-1), default 0.8
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
local function fuzzy_anchor_match(content, old_str, threshold)
|
||||
threshold = threshold or 0.8
|
||||
|
||||
local old_lines = vim.split(old_str, "\n")
|
||||
if #old_lines < 2 then
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
local first_line = old_lines[1]:match("^%s*(.-)%s*$")
|
||||
local last_line = old_lines[#old_lines]:match("^%s*(.-)%s*$")
|
||||
local content_lines = vim.split(content, "\n")
|
||||
|
||||
-- Find potential start positions
|
||||
local candidates = {}
|
||||
for i, line in ipairs(content_lines) do
|
||||
local trimmed = line:match("^%s*(.-)%s*$")
|
||||
if trimmed == first_line or (
|
||||
#first_line > 0 and
|
||||
1 - (levenshtein(trimmed, first_line) / math.max(#trimmed, #first_line)) >= threshold
|
||||
) then
|
||||
table.insert(candidates, i)
|
||||
end
|
||||
end
|
||||
|
||||
-- For each candidate, look for matching end
|
||||
for _, start_idx in ipairs(candidates) do
|
||||
local expected_end = start_idx + #old_lines - 1
|
||||
if expected_end <= #content_lines then
|
||||
local end_line = content_lines[expected_end]:match("^%s*(.-)%s*$")
|
||||
if end_line == last_line or (
|
||||
#last_line > 0 and
|
||||
1 - (levenshtein(end_line, last_line) / math.max(#end_line, #last_line)) >= threshold
|
||||
) then
|
||||
-- Calculate positions
|
||||
local before = table.concat(vim.list_slice(content_lines, 1, start_idx - 1), "\n")
|
||||
local block = table.concat(vim.list_slice(content_lines, start_idx, expected_end), "\n")
|
||||
local start_pos = #before + (start_idx > 1 and 2 or 1)
|
||||
local end_pos = start_pos + #block - 1
|
||||
return start_pos, end_pos
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
--- Try all matching strategies in order
|
||||
---@param content string File content
|
||||
---@param old_str string String to find
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
---@return string strategy_used
|
||||
local function find_match(content, old_str)
|
||||
-- Strategy 1: Exact match
|
||||
local start_pos, end_pos = exact_match(content, old_str)
|
||||
if start_pos then
|
||||
return start_pos, end_pos, "exact"
|
||||
end
|
||||
|
||||
-- Strategy 2: Whitespace-normalized
|
||||
start_pos, end_pos = whitespace_normalized_match(content, old_str)
|
||||
if start_pos then
|
||||
return start_pos, end_pos, "whitespace_normalized"
|
||||
end
|
||||
|
||||
-- Strategy 3: Indentation-flexible
|
||||
start_pos, end_pos = indentation_flexible_match(content, old_str)
|
||||
if start_pos then
|
||||
return start_pos, end_pos, "indentation_flexible"
|
||||
end
|
||||
|
||||
-- Strategy 4: Line-trimmed
|
||||
start_pos, end_pos = line_trimmed_match(content, old_str)
|
||||
if start_pos then
|
||||
return start_pos, end_pos, "line_trimmed"
|
||||
end
|
||||
|
||||
-- Strategy 5: Fuzzy anchor
|
||||
start_pos, end_pos = fuzzy_anchor_match(content, old_str)
|
||||
if start_pos then
|
||||
return start_pos, end_pos, "fuzzy_anchor"
|
||||
end
|
||||
|
||||
return nil, nil, "none"
|
||||
end
|
||||
|
||||
---@param input {path: string, old_string: string, new_string: string}
|
||||
---@param opts CoderToolOpts
|
||||
---@return boolean|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.path then
|
||||
return nil, "path is required"
|
||||
end
|
||||
if input.old_string == nil then
|
||||
return nil, "old_string is required"
|
||||
end
|
||||
if input.new_string == nil then
|
||||
return nil, "new_string is required"
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Editing file: " .. input.path)
|
||||
end
|
||||
|
||||
-- Resolve path
|
||||
local path = input.path
|
||||
if not vim.startswith(path, "/") then
|
||||
path = vim.fn.getcwd() .. "/" .. path
|
||||
end
|
||||
|
||||
-- Normalize inputs
|
||||
local old_str = normalize_line_endings(input.old_string)
|
||||
local new_str = normalize_line_endings(input.new_string)
|
||||
|
||||
-- Handle new file creation (empty old_string)
|
||||
if old_str == "" then
|
||||
-- Create parent directories
|
||||
local dir = vim.fn.fnamemodify(path, ":h")
|
||||
if vim.fn.isdirectory(dir) == 0 then
|
||||
vim.fn.mkdir(dir, "p")
|
||||
end
|
||||
|
||||
-- Write new file
|
||||
local lines = vim.split(new_str, "\n", { plain = true })
|
||||
local ok = pcall(vim.fn.writefile, lines, path)
|
||||
|
||||
if not ok then
|
||||
return nil, "Failed to create file: " .. input.path
|
||||
end
|
||||
|
||||
-- Reload buffer if open
|
||||
local bufnr = vim.fn.bufnr(path)
|
||||
if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
|
||||
vim.api.nvim_buf_call(bufnr, function()
|
||||
vim.cmd("edit!")
|
||||
end)
|
||||
end
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(true, nil)
|
||||
end
|
||||
|
||||
return true, nil
|
||||
end
|
||||
|
||||
-- Check if file exists
|
||||
if vim.fn.filereadable(path) ~= 1 then
|
||||
return nil, "File not found: " .. input.path
|
||||
end
|
||||
|
||||
-- Read current content
|
||||
local lines = vim.fn.readfile(path)
|
||||
if not lines then
|
||||
return nil, "Failed to read file: " .. input.path
|
||||
end
|
||||
|
||||
local content = normalize_line_endings(table.concat(lines, "\n"))
|
||||
|
||||
-- Find match using fallback strategies
|
||||
local start_pos, end_pos, strategy = find_match(content, old_str)
|
||||
|
||||
if not start_pos then
|
||||
return nil, "old_string not found in file (tried 5 matching strategies)"
|
||||
end
|
||||
|
||||
if opts.on_log then
|
||||
opts.on_log("Match found using strategy: " .. strategy)
|
||||
end
|
||||
|
||||
-- Perform replacement
|
||||
local new_content = content:sub(1, start_pos - 1) .. new_str .. content:sub(end_pos + 1)
|
||||
|
||||
-- Write back
|
||||
local new_lines = vim.split(new_content, "\n", { plain = true })
|
||||
local ok = pcall(vim.fn.writefile, new_lines, path)
|
||||
|
||||
if not ok then
|
||||
return nil, "Failed to write file: " .. input.path
|
||||
end
|
||||
|
||||
-- Reload buffer if open
|
||||
local bufnr = vim.fn.bufnr(path)
|
||||
if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
|
||||
vim.api.nvim_buf_call(bufnr, function()
|
||||
vim.cmd("edit!")
|
||||
end)
|
||||
end
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(true, nil)
|
||||
end
|
||||
|
||||
return true, nil
|
||||
end
|
||||
|
||||
return M
|
||||
146
lua/codetyper/agent/tools/glob.lua
Normal file
146
lua/codetyper/agent/tools/glob.lua
Normal file
@@ -0,0 +1,146 @@
|
||||
---@mod codetyper.agent.tools.glob File pattern matching tool
|
||||
---@brief [[
|
||||
--- Tool for finding files by glob pattern.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "glob"
|
||||
|
||||
M.description = [[Finds files matching a glob pattern.
|
||||
|
||||
Example patterns:
|
||||
- "**/*.lua" - All Lua files
|
||||
- "src/**/*.ts" - TypeScript files in src
|
||||
- "**/test_*.py" - Test files in Python]]
|
||||
|
||||
M.params = {
|
||||
{
|
||||
name = "pattern",
|
||||
description = "Glob pattern to match files",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "path",
|
||||
description = "Base directory to search in (default: project root)",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
{
|
||||
name = "max_results",
|
||||
description = "Maximum number of results (default: 100)",
|
||||
type = "integer",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.returns = {
|
||||
{
|
||||
name = "matches",
|
||||
description = "JSON array of matching file paths",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if glob failed",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.requires_confirmation = false
|
||||
|
||||
---@param input {pattern: string, path?: string, max_results?: integer}
|
||||
---@param opts CoderToolOpts
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.pattern then
|
||||
return nil, "pattern is required"
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Finding files: " .. input.pattern)
|
||||
end
|
||||
|
||||
-- Resolve base path
|
||||
local base_path = input.path or vim.fn.getcwd()
|
||||
if not vim.startswith(base_path, "/") then
|
||||
base_path = vim.fn.getcwd() .. "/" .. base_path
|
||||
end
|
||||
|
||||
local max_results = input.max_results or 100
|
||||
|
||||
-- Use vim.fn.glob or fd if available
|
||||
local matches = {}
|
||||
|
||||
if vim.fn.executable("fd") == 1 then
|
||||
-- Use fd for better performance
|
||||
local Job = require("plenary.job")
|
||||
|
||||
-- Convert glob to fd pattern
|
||||
local fd_pattern = input.pattern:gsub("%*%*/", ""):gsub("%*", ".*")
|
||||
|
||||
local job = Job:new({
|
||||
command = "fd",
|
||||
args = {
|
||||
"--type",
|
||||
"f",
|
||||
"--max-results",
|
||||
tostring(max_results),
|
||||
"--glob",
|
||||
input.pattern,
|
||||
base_path,
|
||||
},
|
||||
cwd = base_path,
|
||||
})
|
||||
|
||||
job:sync(30000)
|
||||
matches = job:result() or {}
|
||||
else
|
||||
-- Fallback to vim.fn.globpath
|
||||
local pattern = base_path .. "/" .. input.pattern
|
||||
local files = vim.fn.glob(pattern, false, true)
|
||||
|
||||
for i, file in ipairs(files) do
|
||||
if i > max_results then
|
||||
break
|
||||
end
|
||||
-- Make paths relative to base_path
|
||||
local relative = file:gsub("^" .. vim.pesc(base_path) .. "/", "")
|
||||
table.insert(matches, relative)
|
||||
end
|
||||
end
|
||||
|
||||
-- Clean up matches
|
||||
local cleaned = {}
|
||||
for _, match in ipairs(matches) do
|
||||
if match and match ~= "" then
|
||||
-- Make relative if absolute
|
||||
local relative = match
|
||||
if vim.startswith(match, base_path) then
|
||||
relative = match:sub(#base_path + 2)
|
||||
end
|
||||
table.insert(cleaned, relative)
|
||||
end
|
||||
end
|
||||
|
||||
-- Return as JSON
|
||||
local result = vim.json.encode({
|
||||
matches = cleaned,
|
||||
total = #cleaned,
|
||||
truncated = #cleaned >= max_results,
|
||||
})
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(result, nil)
|
||||
end
|
||||
|
||||
return result, nil
|
||||
end
|
||||
|
||||
return M
|
||||
150
lua/codetyper/agent/tools/grep.lua
Normal file
150
lua/codetyper/agent/tools/grep.lua
Normal file
@@ -0,0 +1,150 @@
|
||||
---@mod codetyper.agent.tools.grep Search tool
|
||||
---@brief [[
|
||||
--- Tool for searching file contents using ripgrep.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "grep"
|
||||
|
||||
M.description = [[Searches for a pattern in files using ripgrep.
|
||||
|
||||
Returns file paths and matching lines. Use this to find code by content.
|
||||
|
||||
Example patterns:
|
||||
- "function foo" - Find function definitions
|
||||
- "import.*react" - Find React imports
|
||||
- "TODO|FIXME" - Find todo comments]]
|
||||
|
||||
M.params = {
|
||||
{
|
||||
name = "pattern",
|
||||
description = "Regular expression pattern to search for",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "path",
|
||||
description = "Directory or file to search in (default: project root)",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
{
|
||||
name = "include",
|
||||
description = "File glob pattern to include (e.g., '*.lua')",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
{
|
||||
name = "max_results",
|
||||
description = "Maximum number of results (default: 50)",
|
||||
type = "integer",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.returns = {
|
||||
{
|
||||
name = "matches",
|
||||
description = "JSON array of matches with file, line_number, and content",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if search failed",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.requires_confirmation = false
|
||||
|
||||
---@param input {pattern: string, path?: string, include?: string, max_results?: integer}
|
||||
---@param opts CoderToolOpts
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.pattern then
|
||||
return nil, "pattern is required"
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Searching for: " .. input.pattern)
|
||||
end
|
||||
|
||||
-- Build ripgrep command
|
||||
local path = input.path or vim.fn.getcwd()
|
||||
local max_results = input.max_results or 50
|
||||
|
||||
-- Resolve path
|
||||
if not vim.startswith(path, "/") then
|
||||
path = vim.fn.getcwd() .. "/" .. path
|
||||
end
|
||||
|
||||
-- Check if ripgrep is available
|
||||
if vim.fn.executable("rg") ~= 1 then
|
||||
return nil, "ripgrep (rg) is not installed"
|
||||
end
|
||||
|
||||
-- Build command args
|
||||
local args = {
|
||||
"--json",
|
||||
"--max-count",
|
||||
tostring(max_results),
|
||||
"--no-heading",
|
||||
}
|
||||
|
||||
if input.include then
|
||||
table.insert(args, "--glob")
|
||||
table.insert(args, input.include)
|
||||
end
|
||||
|
||||
table.insert(args, input.pattern)
|
||||
table.insert(args, path)
|
||||
|
||||
-- Execute ripgrep
|
||||
local Job = require("plenary.job")
|
||||
local job = Job:new({
|
||||
command = "rg",
|
||||
args = args,
|
||||
cwd = vim.fn.getcwd(),
|
||||
})
|
||||
|
||||
job:sync(30000) -- 30 second timeout
|
||||
|
||||
local results = job:result() or {}
|
||||
local matches = {}
|
||||
|
||||
-- Parse JSON output
|
||||
for _, line in ipairs(results) do
|
||||
if line and line ~= "" then
|
||||
local ok, parsed = pcall(vim.json.decode, line)
|
||||
if ok and parsed.type == "match" then
|
||||
local data = parsed.data
|
||||
table.insert(matches, {
|
||||
file = data.path.text,
|
||||
line_number = data.line_number,
|
||||
content = data.lines.text:gsub("\n$", ""),
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Return as JSON
|
||||
local result = vim.json.encode({
|
||||
matches = matches,
|
||||
total = #matches,
|
||||
truncated = #matches >= max_results,
|
||||
})
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(result, nil)
|
||||
end
|
||||
|
||||
return result, nil
|
||||
end
|
||||
|
||||
return M
|
||||
308
lua/codetyper/agent/tools/init.lua
Normal file
308
lua/codetyper/agent/tools/init.lua
Normal file
@@ -0,0 +1,308 @@
|
||||
---@mod codetyper.agent.tools Tool registry and orchestration
|
||||
---@brief [[
|
||||
--- Registry for LLM tools with execution and schema generation.
|
||||
--- Inspired by avante.nvim's tool system.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Registered tools
|
||||
---@type table<string, CoderTool>
|
||||
local tools = {}
|
||||
|
||||
--- Tool execution history for current session
|
||||
---@type table[]
|
||||
local execution_history = {}
|
||||
|
||||
--- Register a tool
|
||||
---@param tool CoderTool Tool to register
|
||||
function M.register(tool)
|
||||
if not tool.name then
|
||||
error("Tool must have a name")
|
||||
end
|
||||
tools[tool.name] = tool
|
||||
end
|
||||
|
||||
--- Unregister a tool
|
||||
---@param name string Tool name
|
||||
function M.unregister(name)
|
||||
tools[name] = nil
|
||||
end
|
||||
|
||||
--- Get a tool by name
|
||||
---@param name string Tool name
|
||||
---@return CoderTool|nil
|
||||
function M.get(name)
|
||||
return tools[name]
|
||||
end
|
||||
|
||||
--- Get all registered tools
|
||||
---@return table<string, CoderTool>
|
||||
function M.get_all()
|
||||
return tools
|
||||
end
|
||||
|
||||
--- Get tools as a list
|
||||
---@param filter? fun(tool: CoderTool): boolean Optional filter function
|
||||
---@return CoderTool[]
|
||||
function M.list(filter)
|
||||
local result = {}
|
||||
for _, tool in pairs(tools) do
|
||||
if not filter or filter(tool) then
|
||||
table.insert(result, tool)
|
||||
end
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
--- Generate schemas for all tools (for LLM function calling)
|
||||
---@param filter? fun(tool: CoderTool): boolean Optional filter function
|
||||
---@return table[] schemas
|
||||
function M.get_schemas(filter)
|
||||
local schemas = {}
|
||||
for _, tool in pairs(tools) do
|
||||
if not filter or filter(tool) then
|
||||
if tool.to_schema then
|
||||
table.insert(schemas, tool:to_schema())
|
||||
end
|
||||
end
|
||||
end
|
||||
return schemas
|
||||
end
|
||||
|
||||
--- Execute a tool by name
|
||||
---@param name string Tool name
|
||||
---@param input table Input parameters
|
||||
---@param opts CoderToolOpts Execution options
|
||||
---@return any result
|
||||
---@return string|nil error
|
||||
function M.execute(name, input, opts)
|
||||
local tool = tools[name]
|
||||
if not tool then
|
||||
return nil, "Unknown tool: " .. name
|
||||
end
|
||||
|
||||
-- Validate input
|
||||
if tool.validate_input then
|
||||
local valid, err = tool:validate_input(input)
|
||||
if not valid then
|
||||
return nil, err
|
||||
end
|
||||
end
|
||||
|
||||
-- Log execution
|
||||
if opts.on_log then
|
||||
opts.on_log(string.format("Executing tool: %s", name))
|
||||
end
|
||||
|
||||
-- Track execution
|
||||
local execution = {
|
||||
tool = name,
|
||||
input = input,
|
||||
start_time = os.time(),
|
||||
status = "running",
|
||||
}
|
||||
table.insert(execution_history, execution)
|
||||
|
||||
-- Execute the tool
|
||||
local result, err = tool.func(input, opts)
|
||||
|
||||
-- Update execution record
|
||||
execution.end_time = os.time()
|
||||
execution.status = err and "error" or "completed"
|
||||
execution.result = result
|
||||
execution.error = err
|
||||
|
||||
return result, err
|
||||
end
|
||||
|
||||
--- Process a tool call from LLM response
|
||||
---@param tool_call table Tool call from LLM (name + input)
|
||||
---@param opts CoderToolOpts Execution options
|
||||
---@return any result
|
||||
---@return string|nil error
|
||||
function M.process_tool_call(tool_call, opts)
|
||||
local name = tool_call.name or tool_call.function_name
|
||||
local input = tool_call.input or tool_call.arguments or {}
|
||||
|
||||
-- Parse JSON arguments if string
|
||||
if type(input) == "string" then
|
||||
local ok, parsed = pcall(vim.json.decode, input)
|
||||
if ok then
|
||||
input = parsed
|
||||
else
|
||||
return nil, "Failed to parse tool arguments: " .. input
|
||||
end
|
||||
end
|
||||
|
||||
return M.execute(name, input, opts)
|
||||
end
|
||||
|
||||
--- Get execution history
|
||||
---@param limit? number Max entries to return
|
||||
---@return table[]
|
||||
function M.get_history(limit)
|
||||
if not limit then
|
||||
return execution_history
|
||||
end
|
||||
|
||||
local result = {}
|
||||
local start = math.max(1, #execution_history - limit + 1)
|
||||
for i = start, #execution_history do
|
||||
table.insert(result, execution_history[i])
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
--- Clear execution history
|
||||
function M.clear_history()
|
||||
execution_history = {}
|
||||
end
|
||||
|
||||
--- Load built-in tools
|
||||
function M.load_builtins()
|
||||
-- View file tool
|
||||
local view = require("codetyper.agent.tools.view")
|
||||
M.register(view)
|
||||
|
||||
-- Bash tool
|
||||
local bash = require("codetyper.agent.tools.bash")
|
||||
M.register(bash)
|
||||
|
||||
-- Grep tool
|
||||
local grep = require("codetyper.agent.tools.grep")
|
||||
M.register(grep)
|
||||
|
||||
-- Glob tool
|
||||
local glob = require("codetyper.agent.tools.glob")
|
||||
M.register(glob)
|
||||
|
||||
-- Write file tool
|
||||
local write = require("codetyper.agent.tools.write")
|
||||
M.register(write)
|
||||
|
||||
-- Edit tool
|
||||
local edit = require("codetyper.agent.tools.edit")
|
||||
M.register(edit)
|
||||
end
|
||||
|
||||
--- Initialize tools system
|
||||
function M.setup()
|
||||
M.load_builtins()
|
||||
end
|
||||
|
||||
--- Get tool definitions for LLM (lazy-loaded, OpenAI format)
|
||||
--- This is accessed as M.definitions property
|
||||
M.definitions = setmetatable({}, {
|
||||
__call = function()
|
||||
-- Ensure tools are loaded
|
||||
if vim.tbl_count(tools) == 0 then
|
||||
M.load_builtins()
|
||||
end
|
||||
return M.to_openai_format()
|
||||
end,
|
||||
__index = function(_, key)
|
||||
-- Make it work as both function and table
|
||||
if key == "get" then
|
||||
return function()
|
||||
if vim.tbl_count(tools) == 0 then
|
||||
M.load_builtins()
|
||||
end
|
||||
return M.to_openai_format()
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end,
|
||||
})
|
||||
|
||||
--- Get definitions as a function (for backwards compatibility)
|
||||
function M.get_definitions()
|
||||
if vim.tbl_count(tools) == 0 then
|
||||
M.load_builtins()
|
||||
end
|
||||
return M.to_openai_format()
|
||||
end
|
||||
|
||||
--- Convert all tools to OpenAI function calling format
|
||||
---@param filter? fun(tool: CoderTool): boolean Optional filter function
|
||||
---@return table[] OpenAI-compatible tool definitions
|
||||
function M.to_openai_format(filter)
|
||||
local openai_tools = {}
|
||||
|
||||
for _, tool in pairs(tools) do
|
||||
if not filter or filter(tool) then
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(tool.params or {}) do
|
||||
properties[param.name] = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
if param.default ~= nil then
|
||||
properties[param.name].default = param.default
|
||||
end
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
local description = type(tool.description) == "function" and tool.description() or tool.description
|
||||
|
||||
table.insert(openai_tools, {
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return openai_tools
|
||||
end
|
||||
|
||||
--- Convert all tools to Claude tool use format
|
||||
---@param filter? fun(tool: CoderTool): boolean Optional filter function
|
||||
---@return table[] Claude-compatible tool definitions
|
||||
function M.to_claude_format(filter)
|
||||
local claude_tools = {}
|
||||
|
||||
for _, tool in pairs(tools) do
|
||||
if not filter or filter(tool) then
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(tool.params or {}) do
|
||||
properties[param.name] = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
local description = type(tool.description) == "function" and tool.description() or tool.description
|
||||
|
||||
table.insert(claude_tools, {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
input_schema = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return claude_tools
|
||||
end
|
||||
|
||||
return M
|
||||
149
lua/codetyper/agent/tools/view.lua
Normal file
149
lua/codetyper/agent/tools/view.lua
Normal file
@@ -0,0 +1,149 @@
|
||||
---@mod codetyper.agent.tools.view File viewing tool
|
||||
---@brief [[
|
||||
--- Tool for reading file contents with line range support.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "view"
|
||||
|
||||
M.description = [[Reads the content of a file.
|
||||
|
||||
Usage notes:
|
||||
- Provide the file path relative to the project root
|
||||
- Use start_line and end_line to read specific sections
|
||||
- If content is truncated, use line ranges to read in chunks
|
||||
- Returns JSON with content, total_line_count, and is_truncated]]
|
||||
|
||||
M.params = {
|
||||
{
|
||||
name = "path",
|
||||
description = "Path to the file (relative to project root or absolute)",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "start_line",
|
||||
description = "Line number to start reading (1-indexed)",
|
||||
type = "integer",
|
||||
optional = true,
|
||||
},
|
||||
{
|
||||
name = "end_line",
|
||||
description = "Line number to end reading (1-indexed, inclusive)",
|
||||
type = "integer",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.returns = {
|
||||
{
|
||||
name = "content",
|
||||
description = "File contents as JSON with content, total_line_count, is_truncated",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if file could not be read",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.requires_confirmation = false
|
||||
|
||||
--- Maximum content size before truncation
|
||||
local MAX_CONTENT_SIZE = 200 * 1024 -- 200KB
|
||||
|
||||
---@param input {path: string, start_line?: integer, end_line?: integer}
|
||||
---@param opts CoderToolOpts
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.path then
|
||||
return nil, "path is required"
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Reading file: " .. input.path)
|
||||
end
|
||||
|
||||
-- Resolve path
|
||||
local path = input.path
|
||||
if not vim.startswith(path, "/") then
|
||||
-- Relative path - resolve from project root
|
||||
local root = vim.fn.getcwd()
|
||||
path = root .. "/" .. path
|
||||
end
|
||||
|
||||
-- Check if file exists
|
||||
local stat = vim.uv.fs_stat(path)
|
||||
if not stat then
|
||||
return nil, "File not found: " .. input.path
|
||||
end
|
||||
|
||||
if stat.type == "directory" then
|
||||
return nil, "Path is a directory: " .. input.path
|
||||
end
|
||||
|
||||
-- Read file
|
||||
local lines = vim.fn.readfile(path)
|
||||
if not lines then
|
||||
return nil, "Failed to read file: " .. input.path
|
||||
end
|
||||
|
||||
-- Apply line range
|
||||
local start_line = input.start_line or 1
|
||||
local end_line = input.end_line or #lines
|
||||
|
||||
start_line = math.max(1, start_line)
|
||||
end_line = math.min(#lines, end_line)
|
||||
|
||||
local total_lines = #lines
|
||||
local selected_lines = {}
|
||||
|
||||
for i = start_line, end_line do
|
||||
table.insert(selected_lines, lines[i])
|
||||
end
|
||||
|
||||
-- Check for truncation
|
||||
local content = table.concat(selected_lines, "\n")
|
||||
local is_truncated = false
|
||||
|
||||
if #content > MAX_CONTENT_SIZE then
|
||||
-- Truncate content
|
||||
local truncated_lines = {}
|
||||
local size = 0
|
||||
|
||||
for _, line in ipairs(selected_lines) do
|
||||
size = size + #line + 1
|
||||
if size > MAX_CONTENT_SIZE then
|
||||
is_truncated = true
|
||||
break
|
||||
end
|
||||
table.insert(truncated_lines, line)
|
||||
end
|
||||
|
||||
content = table.concat(truncated_lines, "\n")
|
||||
end
|
||||
|
||||
-- Return as JSON
|
||||
local result = vim.json.encode({
|
||||
content = content,
|
||||
total_line_count = total_lines,
|
||||
is_truncated = is_truncated,
|
||||
start_line = start_line,
|
||||
end_line = end_line,
|
||||
})
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(result, nil)
|
||||
end
|
||||
|
||||
return result, nil
|
||||
end
|
||||
|
||||
return M
|
||||
101
lua/codetyper/agent/tools/write.lua
Normal file
101
lua/codetyper/agent/tools/write.lua
Normal file
@@ -0,0 +1,101 @@
|
||||
---@mod codetyper.agent.tools.write File writing tool
|
||||
---@brief [[
|
||||
--- Tool for creating or overwriting files.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "write"
|
||||
|
||||
M.description = [[Creates or overwrites a file with new content.
|
||||
|
||||
IMPORTANT:
|
||||
- This will completely replace the file contents
|
||||
- Use 'edit' tool for partial modifications
|
||||
- Parent directories will be created if needed]]
|
||||
|
||||
M.params = {
|
||||
{
|
||||
name = "path",
|
||||
description = "Path to the file to write",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "content",
|
||||
description = "Content to write to the file",
|
||||
type = "string",
|
||||
},
|
||||
}
|
||||
|
||||
M.returns = {
|
||||
{
|
||||
name = "success",
|
||||
description = "Whether the file was written successfully",
|
||||
type = "boolean",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if write failed",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.requires_confirmation = true
|
||||
|
||||
---@param input {path: string, content: string}
|
||||
---@param opts CoderToolOpts
|
||||
---@return boolean|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.path then
|
||||
return nil, "path is required"
|
||||
end
|
||||
if not input.content then
|
||||
return nil, "content is required"
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Writing file: " .. input.path)
|
||||
end
|
||||
|
||||
-- Resolve path
|
||||
local path = input.path
|
||||
if not vim.startswith(path, "/") then
|
||||
path = vim.fn.getcwd() .. "/" .. path
|
||||
end
|
||||
|
||||
-- Create parent directories
|
||||
local dir = vim.fn.fnamemodify(path, ":h")
|
||||
if vim.fn.isdirectory(dir) == 0 then
|
||||
vim.fn.mkdir(dir, "p")
|
||||
end
|
||||
|
||||
-- Write the file
|
||||
local lines = vim.split(input.content, "\n", { plain = true })
|
||||
local ok = pcall(vim.fn.writefile, lines, path)
|
||||
|
||||
if not ok then
|
||||
return nil, "Failed to write file: " .. path
|
||||
end
|
||||
|
||||
-- Reload buffer if open
|
||||
local bufnr = vim.fn.bufnr(path)
|
||||
if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
|
||||
vim.api.nvim_buf_call(bufnr, function()
|
||||
vim.cmd("edit!")
|
||||
end)
|
||||
end
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(true, nil)
|
||||
end
|
||||
|
||||
return true, nil
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -224,6 +224,86 @@ local function format_attached_files(attached_files)
|
||||
return table.concat(parts, "")
|
||||
end
|
||||
|
||||
--- Get coder companion file path for a target file
|
||||
---@param target_path string Target file path
|
||||
---@return string|nil Coder file path if exists
|
||||
local function get_coder_companion_path(target_path)
|
||||
if not target_path or target_path == "" then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Skip if target is already a coder file
|
||||
if target_path:match("%.coder%.") then
|
||||
return nil
|
||||
end
|
||||
|
||||
local dir = vim.fn.fnamemodify(target_path, ":h")
|
||||
local name = vim.fn.fnamemodify(target_path, ":t:r") -- filename without extension
|
||||
local ext = vim.fn.fnamemodify(target_path, ":e")
|
||||
|
||||
local coder_path = dir .. "/" .. name .. ".coder." .. ext
|
||||
if vim.fn.filereadable(coder_path) == 1 then
|
||||
return coder_path
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Read and format coder companion context (business logic, pseudo-code)
|
||||
---@param target_path string Target file path
|
||||
---@return string Formatted coder context
|
||||
local function get_coder_context(target_path)
|
||||
local coder_path = get_coder_companion_path(target_path)
|
||||
if not coder_path then
|
||||
return ""
|
||||
end
|
||||
|
||||
local ok, lines = pcall(function()
|
||||
return vim.fn.readfile(coder_path)
|
||||
end)
|
||||
|
||||
if not ok or not lines or #lines == 0 then
|
||||
return ""
|
||||
end
|
||||
|
||||
local content = table.concat(lines, "\n")
|
||||
|
||||
-- Skip if only template comments (no actual content)
|
||||
local stripped = content:gsub("^%s*", ""):gsub("%s*$", "")
|
||||
if stripped == "" then
|
||||
return ""
|
||||
end
|
||||
|
||||
-- Check if there's meaningful content (not just template)
|
||||
local has_content = false
|
||||
for _, line in ipairs(lines) do
|
||||
-- Skip comment lines that are part of the template
|
||||
local trimmed = line:gsub("^%s*", "")
|
||||
if not trimmed:match("^[%-#/]+%s*Coder companion")
|
||||
and not trimmed:match("^[%-#/]+%s*Use /@ @/")
|
||||
and not trimmed:match("^[%-#/]+%s*Example:")
|
||||
and not trimmed:match("^<!%-%-")
|
||||
and trimmed ~= ""
|
||||
and not trimmed:match("^[%-#/]+%s*$") then
|
||||
has_content = true
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if not has_content then
|
||||
return ""
|
||||
end
|
||||
|
||||
local ext = vim.fn.fnamemodify(coder_path, ":e")
|
||||
return string.format(
|
||||
"\n\n--- Business Context / Pseudo-code ---\n" ..
|
||||
"The following describes the intended behavior and design for this file:\n" ..
|
||||
"```%s\n%s\n```",
|
||||
ext,
|
||||
content:sub(1, 4000) -- Limit to 4000 chars
|
||||
)
|
||||
end
|
||||
|
||||
--- Format indexed project context for inclusion in prompt
|
||||
---@param indexed_context table|nil
|
||||
---@return string
|
||||
@@ -309,8 +389,53 @@ local function build_prompt(event)
|
||||
-- Format attached files
|
||||
local attached_content = format_attached_files(event.attached_files)
|
||||
|
||||
-- Combine attached files and indexed context
|
||||
local extra_context = attached_content .. indexed_content
|
||||
-- Get coder companion context (business logic, pseudo-code)
|
||||
local coder_context = get_coder_context(event.target_path)
|
||||
|
||||
-- Get brain memories - contextual recall based on current task
|
||||
local brain_context = ""
|
||||
pcall(function()
|
||||
local brain = require("codetyper.brain")
|
||||
if brain.is_initialized() then
|
||||
-- Query brain for relevant memories based on:
|
||||
-- 1. Current file (file-specific patterns)
|
||||
-- 2. Prompt content (semantic similarity)
|
||||
-- 3. Intent type (relevant past generations)
|
||||
local query_text = event.prompt_content or ""
|
||||
if event.scope and event.scope.name then
|
||||
query_text = event.scope.name .. " " .. query_text
|
||||
end
|
||||
|
||||
local result = brain.query({
|
||||
query = query_text,
|
||||
file = event.target_path,
|
||||
max_results = 5,
|
||||
types = { "pattern", "correction", "convention" },
|
||||
})
|
||||
|
||||
if result and result.nodes and #result.nodes > 0 then
|
||||
local memories = { "\n\n--- Learned Patterns & Conventions ---" }
|
||||
for _, node in ipairs(result.nodes) do
|
||||
if node.c then
|
||||
local summary = node.c.s or ""
|
||||
local detail = node.c.d or ""
|
||||
if summary ~= "" then
|
||||
table.insert(memories, "• " .. summary)
|
||||
if detail ~= "" and #detail < 200 then
|
||||
table.insert(memories, " " .. detail)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
if #memories > 1 then
|
||||
brain_context = table.concat(memories, "\n")
|
||||
end
|
||||
end
|
||||
end
|
||||
end)
|
||||
|
||||
-- Combine all context sources: brain memories first, then coder context, attached files, indexed
|
||||
local extra_context = brain_context .. coder_context .. attached_content .. indexed_content
|
||||
|
||||
-- Build context with scope information
|
||||
local context = {
|
||||
@@ -502,21 +627,21 @@ function M.start(worker)
|
||||
end
|
||||
end, worker.timeout_ms)
|
||||
|
||||
-- Get client and execute
|
||||
local client, client_err = get_client(worker.worker_type)
|
||||
if not client then
|
||||
M.complete(worker, nil, client_err)
|
||||
return
|
||||
end
|
||||
|
||||
local prompt, context = build_prompt(worker.event)
|
||||
|
||||
-- Call the LLM
|
||||
client.generate(prompt, context, function(response, err, usage)
|
||||
-- Check if smart selection is enabled (memory-based provider selection)
|
||||
local use_smart_selection = false
|
||||
pcall(function()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
use_smart_selection = config.llm.smart_selection ~= false -- Default to true
|
||||
end)
|
||||
|
||||
-- Define the response handler
|
||||
local function handle_response(response, err, usage_or_metadata)
|
||||
-- Cancel timeout timer
|
||||
if worker.timer then
|
||||
pcall(function()
|
||||
-- Timer might have already fired
|
||||
if type(worker.timer) == "userdata" and worker.timer.stop then
|
||||
worker.timer:stop()
|
||||
end
|
||||
@@ -527,8 +652,45 @@ function M.start(worker)
|
||||
return -- Already timed out or cancelled
|
||||
end
|
||||
|
||||
-- Extract usage from metadata if smart_generate was used
|
||||
local usage = usage_or_metadata
|
||||
if type(usage_or_metadata) == "table" and usage_or_metadata.provider then
|
||||
-- This is metadata from smart_generate
|
||||
usage = nil
|
||||
-- Update worker type to reflect actual provider used
|
||||
worker.worker_type = usage_or_metadata.provider
|
||||
-- Log if pondering occurred
|
||||
if usage_or_metadata.pondered then
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format(
|
||||
"Pondering: %s (agreement: %.0f%%)",
|
||||
usage_or_metadata.corrected and "corrected" or "validated",
|
||||
(usage_or_metadata.agreement or 1) * 100
|
||||
),
|
||||
})
|
||||
end)
|
||||
end
|
||||
end
|
||||
|
||||
M.complete(worker, response, err, usage)
|
||||
end)
|
||||
end
|
||||
|
||||
-- Use smart selection or direct client
|
||||
if use_smart_selection then
|
||||
local llm = require("codetyper.llm")
|
||||
llm.smart_generate(prompt, context, handle_response)
|
||||
else
|
||||
-- Get client and execute directly
|
||||
local client, client_err = get_client(worker.worker_type)
|
||||
if not client then
|
||||
M.complete(worker, nil, client_err)
|
||||
return
|
||||
end
|
||||
client.generate(prompt, context, handle_response)
|
||||
end
|
||||
end
|
||||
|
||||
--- Complete worker execution
|
||||
|
||||
@@ -312,7 +312,9 @@ local function append_log_to_output(entry)
|
||||
}
|
||||
|
||||
local icon = icons[entry.level] or "•"
|
||||
local formatted = string.format("[%s] %s %s", entry.timestamp, icon, entry.message)
|
||||
-- Sanitize message - replace newlines with spaces to prevent nvim_buf_set_lines error
|
||||
local sanitized_message = entry.message:gsub("\n", " "):gsub("\r", "")
|
||||
local formatted = string.format("[%s] %s %s", entry.timestamp, icon, sanitized_message)
|
||||
|
||||
vim.schedule(function()
|
||||
if not state.output_buf or not vim.api.nvim_buf_is_valid(state.output_buf) then
|
||||
|
||||
@@ -18,6 +18,16 @@ local processed_prompts = {}
|
||||
--- Track if we're currently asking for preferences
|
||||
local asking_preference = false
|
||||
|
||||
--- Track if we're currently processing prompts (busy flag)
|
||||
local is_processing = false
|
||||
|
||||
--- Track the previous mode for visual mode detection
|
||||
local previous_mode = "n"
|
||||
|
||||
--- Debounce timer for prompt processing
|
||||
local prompt_process_timer = nil
|
||||
local PROMPT_PROCESS_DEBOUNCE_MS = 200 -- Wait 200ms after mode change before processing
|
||||
|
||||
--- Generate a unique key for a prompt
|
||||
---@param bufnr number Buffer number
|
||||
---@param prompt table Prompt object
|
||||
@@ -64,6 +74,20 @@ function M.setup()
|
||||
desc = "Check for closed prompt tags on InsertLeave",
|
||||
})
|
||||
|
||||
-- Track mode changes for visual mode detection
|
||||
vim.api.nvim_create_autocmd("ModeChanged", {
|
||||
group = group,
|
||||
pattern = "*",
|
||||
callback = function(ev)
|
||||
-- Extract old mode from pattern (format: "old_mode:new_mode")
|
||||
local old_mode = ev.match:match("^(.-):")
|
||||
if old_mode then
|
||||
previous_mode = old_mode
|
||||
end
|
||||
end,
|
||||
desc = "Track previous mode for visual mode detection",
|
||||
})
|
||||
|
||||
-- Auto-process prompts when entering normal mode (works on ALL files)
|
||||
vim.api.nvim_create_autocmd("ModeChanged", {
|
||||
group = group,
|
||||
@@ -74,10 +98,33 @@ function M.setup()
|
||||
if buftype ~= "" then
|
||||
return
|
||||
end
|
||||
-- Slight delay to let buffer settle
|
||||
vim.defer_fn(function()
|
||||
|
||||
-- Skip if currently processing (avoid concurrent processing)
|
||||
if is_processing then
|
||||
return
|
||||
end
|
||||
|
||||
-- Skip if coming from visual mode (v, V, CTRL-V) - user is still editing
|
||||
if previous_mode == "v" or previous_mode == "V" or previous_mode == "\22" then
|
||||
return
|
||||
end
|
||||
|
||||
-- Cancel any pending processing timer
|
||||
if prompt_process_timer then
|
||||
prompt_process_timer:stop()
|
||||
prompt_process_timer = nil
|
||||
end
|
||||
|
||||
-- Debounced processing - wait for user to truly be idle
|
||||
prompt_process_timer = vim.defer_fn(function()
|
||||
prompt_process_timer = nil
|
||||
-- Double-check we're still in normal mode
|
||||
local mode = vim.api.nvim_get_mode().mode
|
||||
if mode ~= "n" then
|
||||
return
|
||||
end
|
||||
M.check_all_prompts_with_preference()
|
||||
end, 50)
|
||||
end, PROMPT_PROCESS_DEBOUNCE_MS)
|
||||
end,
|
||||
desc = "Auto-process closed prompts when entering normal mode",
|
||||
})
|
||||
@@ -92,6 +139,10 @@ function M.setup()
|
||||
if buftype ~= "" then
|
||||
return
|
||||
end
|
||||
-- Skip if currently processing
|
||||
if is_processing then
|
||||
return
|
||||
end
|
||||
local mode = vim.api.nvim_get_mode().mode
|
||||
if mode == "n" then
|
||||
M.check_all_prompts_with_preference()
|
||||
@@ -291,6 +342,12 @@ end
|
||||
|
||||
--- Check if the buffer has a newly closed prompt and auto-process (works on ANY file)
|
||||
function M.check_for_closed_prompt()
|
||||
-- Skip if already processing
|
||||
if is_processing then
|
||||
return
|
||||
end
|
||||
is_processing = true
|
||||
|
||||
local config = get_config_safe()
|
||||
local parser = require("codetyper.parser")
|
||||
|
||||
@@ -299,6 +356,7 @@ function M.check_for_closed_prompt()
|
||||
|
||||
-- Skip if no file
|
||||
if current_file == "" then
|
||||
is_processing = false
|
||||
return
|
||||
end
|
||||
|
||||
@@ -308,6 +366,7 @@ function M.check_for_closed_prompt()
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, line - 1, line, false)
|
||||
|
||||
if #lines == 0 then
|
||||
is_processing = false
|
||||
return
|
||||
end
|
||||
|
||||
@@ -323,6 +382,7 @@ function M.check_for_closed_prompt()
|
||||
|
||||
-- Check if already processed
|
||||
if processed_prompts[prompt_key] then
|
||||
is_processing = false
|
||||
return
|
||||
end
|
||||
|
||||
@@ -366,23 +426,38 @@ function M.check_for_closed_prompt()
|
||||
-- Clean prompt content (strip file references)
|
||||
local cleaned = parser.clean_prompt(parser.strip_file_references(prompt.content))
|
||||
|
||||
-- Resolve scope in target file FIRST (need it to adjust intent)
|
||||
local target_bufnr = vim.fn.bufnr(target_path)
|
||||
if target_bufnr == -1 then
|
||||
target_bufnr = bufnr
|
||||
end
|
||||
-- Check if we're working from a coder file
|
||||
local is_from_coder_file = utils.is_coder_file(current_file)
|
||||
|
||||
-- Resolve scope in target file FIRST (need it to adjust intent)
|
||||
-- Only resolve scope if NOT from coder file (line numbers don't apply)
|
||||
local target_bufnr = vim.fn.bufnr(target_path)
|
||||
local scope = nil
|
||||
local scope_text = nil
|
||||
local scope_range = nil
|
||||
|
||||
scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1)
|
||||
if scope and scope.type ~= "file" then
|
||||
scope_text = scope.text
|
||||
scope_range = {
|
||||
start_line = scope.range.start_row,
|
||||
end_line = scope.range.end_row,
|
||||
}
|
||||
if not is_from_coder_file then
|
||||
-- Prompt is in the actual source file, use line position for scope
|
||||
if target_bufnr == -1 then
|
||||
target_bufnr = bufnr
|
||||
end
|
||||
scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1)
|
||||
if scope and scope.type ~= "file" then
|
||||
scope_text = scope.text
|
||||
scope_range = {
|
||||
start_line = scope.range.start_row,
|
||||
end_line = scope.range.end_row,
|
||||
}
|
||||
end
|
||||
else
|
||||
-- Prompt is in coder file - load target if needed, but don't use scope
|
||||
-- Code from coder files should append to target by default
|
||||
if target_bufnr == -1 then
|
||||
target_bufnr = vim.fn.bufadd(target_path)
|
||||
if target_bufnr ~= 0 then
|
||||
vim.fn.bufload(target_bufnr)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Detect intent from prompt
|
||||
@@ -390,7 +465,8 @@ function M.check_for_closed_prompt()
|
||||
|
||||
-- IMPORTANT: If prompt is inside a function/method and intent is "add",
|
||||
-- override to "complete" since we're completing the function body
|
||||
if scope and (scope.type == "function" or scope.type == "method") then
|
||||
-- But NOT for coder files - they should use "add/append" by default
|
||||
if not is_from_coder_file and scope and (scope.type == "function" or scope.type == "method") then
|
||||
if intent.type == "add" or intent.action == "insert" or intent.action == "append" then
|
||||
-- Override to complete the function instead of adding new code
|
||||
intent = {
|
||||
@@ -403,6 +479,16 @@ function M.check_for_closed_prompt()
|
||||
end
|
||||
end
|
||||
|
||||
-- For coder files, default to "add" with "append" action
|
||||
if is_from_coder_file and (intent.action == "replace" or intent.type == "complete") then
|
||||
intent = {
|
||||
type = intent.type == "complete" and "add" or intent.type,
|
||||
confidence = intent.confidence,
|
||||
action = "append",
|
||||
keywords = intent.keywords,
|
||||
}
|
||||
end
|
||||
|
||||
-- Determine priority based on intent
|
||||
local priority = 2 -- Normal
|
||||
if intent.type == "fix" or intent.type == "complete" then
|
||||
@@ -446,6 +532,7 @@ function M.check_for_closed_prompt()
|
||||
end
|
||||
end
|
||||
end
|
||||
is_processing = false
|
||||
end
|
||||
|
||||
--- Check and process all closed prompts in the buffer (works on ANY file)
|
||||
@@ -507,7 +594,8 @@ function M.check_all_prompts()
|
||||
|
||||
-- Get target path - for coder files, get the target; for regular files, use self
|
||||
local target_path
|
||||
if utils.is_coder_file(current_file) then
|
||||
local is_from_coder_file = utils.is_coder_file(current_file)
|
||||
if is_from_coder_file then
|
||||
target_path = utils.get_target_path(current_file)
|
||||
else
|
||||
target_path = current_file
|
||||
@@ -520,22 +608,33 @@ function M.check_all_prompts()
|
||||
local cleaned = parser.clean_prompt(parser.strip_file_references(prompt.content))
|
||||
|
||||
-- Resolve scope in target file FIRST (need it to adjust intent)
|
||||
-- Only resolve scope if NOT from coder file (line numbers don't apply)
|
||||
local target_bufnr = vim.fn.bufnr(target_path)
|
||||
if target_bufnr == -1 then
|
||||
target_bufnr = bufnr -- Use current buffer if target not loaded
|
||||
end
|
||||
|
||||
local scope = nil
|
||||
local scope_text = nil
|
||||
local scope_range = nil
|
||||
|
||||
scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1)
|
||||
if scope and scope.type ~= "file" then
|
||||
scope_text = scope.text
|
||||
scope_range = {
|
||||
start_line = scope.range.start_row,
|
||||
end_line = scope.range.end_row,
|
||||
}
|
||||
if not is_from_coder_file then
|
||||
-- Prompt is in the actual source file, use line position for scope
|
||||
if target_bufnr == -1 then
|
||||
target_bufnr = bufnr
|
||||
end
|
||||
scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1)
|
||||
if scope and scope.type ~= "file" then
|
||||
scope_text = scope.text
|
||||
scope_range = {
|
||||
start_line = scope.range.start_row,
|
||||
end_line = scope.range.end_row,
|
||||
}
|
||||
end
|
||||
else
|
||||
-- Prompt is in coder file - load target if needed
|
||||
if target_bufnr == -1 then
|
||||
target_bufnr = vim.fn.bufadd(target_path)
|
||||
if target_bufnr ~= 0 then
|
||||
vim.fn.bufload(target_bufnr)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Detect intent from prompt
|
||||
@@ -543,7 +642,8 @@ function M.check_all_prompts()
|
||||
|
||||
-- IMPORTANT: If prompt is inside a function/method and intent is "add",
|
||||
-- override to "complete" since we're completing the function body
|
||||
if scope and (scope.type == "function" or scope.type == "method") then
|
||||
-- But NOT for coder files - they should use "add/append" by default
|
||||
if not is_from_coder_file and scope and (scope.type == "function" or scope.type == "method") then
|
||||
if intent.type == "add" or intent.action == "insert" or intent.action == "append" then
|
||||
-- Override to complete the function instead of adding new code
|
||||
intent = {
|
||||
@@ -556,6 +656,16 @@ function M.check_all_prompts()
|
||||
end
|
||||
end
|
||||
|
||||
-- For coder files, default to "add" with "append" action
|
||||
if is_from_coder_file and (intent.action == "replace" or intent.type == "complete") then
|
||||
intent = {
|
||||
type = intent.type == "complete" and "add" or intent.type,
|
||||
confidence = intent.confidence,
|
||||
action = "append",
|
||||
keywords = intent.keywords,
|
||||
}
|
||||
end
|
||||
|
||||
-- Determine priority based on intent
|
||||
local priority = 2
|
||||
if intent.type == "fix" or intent.type == "complete" then
|
||||
@@ -932,20 +1042,17 @@ function M.update_brain_from_file(filepath)
|
||||
|
||||
local summary = vim.fn.fnamemodify(filepath, ":t") .. " - " .. table.concat(parts, "; ")
|
||||
|
||||
-- Learn this pattern
|
||||
-- Learn this pattern - use "pattern_detected" type to match the pattern learner
|
||||
brain.learn({
|
||||
type = "pattern",
|
||||
type = "pattern_detected",
|
||||
file = filepath,
|
||||
content = {
|
||||
summary = summary,
|
||||
detail = #functions .. " functions, " .. #classes .. " classes",
|
||||
code = nil,
|
||||
},
|
||||
context = {
|
||||
file = filepath,
|
||||
timestamp = os.time(),
|
||||
data = {
|
||||
name = summary,
|
||||
description = #functions .. " functions, " .. #classes .. " classes",
|
||||
language = ext,
|
||||
functions = functions,
|
||||
classes = classes,
|
||||
symbols = vim.tbl_map(function(f) return f.name end, functions),
|
||||
example = nil,
|
||||
},
|
||||
})
|
||||
end
|
||||
@@ -997,6 +1104,126 @@ end
|
||||
|
||||
--- Auto-index a file by creating/opening its coder companion
|
||||
---@param bufnr number Buffer number
|
||||
--- Directories to ignore for coder file creation
|
||||
local ignored_directories = {
|
||||
".git",
|
||||
".coder",
|
||||
".claude",
|
||||
".vscode",
|
||||
".idea",
|
||||
"node_modules",
|
||||
"vendor",
|
||||
"dist",
|
||||
"build",
|
||||
"target",
|
||||
"__pycache__",
|
||||
".cache",
|
||||
".npm",
|
||||
".yarn",
|
||||
"coverage",
|
||||
".next",
|
||||
".nuxt",
|
||||
".svelte-kit",
|
||||
"out",
|
||||
"bin",
|
||||
"obj",
|
||||
}
|
||||
|
||||
--- Files to ignore for coder file creation (exact names or patterns)
|
||||
local ignored_files = {
|
||||
-- Git files
|
||||
".gitignore",
|
||||
".gitattributes",
|
||||
".gitmodules",
|
||||
-- Lock files
|
||||
"package-lock.json",
|
||||
"yarn.lock",
|
||||
"pnpm-lock.yaml",
|
||||
"Cargo.lock",
|
||||
"Gemfile.lock",
|
||||
"poetry.lock",
|
||||
"composer.lock",
|
||||
-- Config files that don't need coder companions
|
||||
".env",
|
||||
".env.local",
|
||||
".env.development",
|
||||
".env.production",
|
||||
".eslintrc",
|
||||
".eslintrc.json",
|
||||
".prettierrc",
|
||||
".prettierrc.json",
|
||||
".editorconfig",
|
||||
".dockerignore",
|
||||
"Dockerfile",
|
||||
"docker-compose.yml",
|
||||
"docker-compose.yaml",
|
||||
".npmrc",
|
||||
".yarnrc",
|
||||
".nvmrc",
|
||||
"tsconfig.json",
|
||||
"jsconfig.json",
|
||||
"babel.config.js",
|
||||
"webpack.config.js",
|
||||
"vite.config.js",
|
||||
"rollup.config.js",
|
||||
"jest.config.js",
|
||||
"vitest.config.js",
|
||||
".stylelintrc",
|
||||
"tailwind.config.js",
|
||||
"postcss.config.js",
|
||||
-- Other non-code files
|
||||
"README.md",
|
||||
"CHANGELOG.md",
|
||||
"LICENSE",
|
||||
"LICENSE.md",
|
||||
"CONTRIBUTING.md",
|
||||
"Makefile",
|
||||
"CMakeLists.txt",
|
||||
}
|
||||
|
||||
--- Check if a file path contains an ignored directory
|
||||
---@param filepath string Full file path
|
||||
---@return boolean
|
||||
local function is_in_ignored_directory(filepath)
|
||||
for _, dir in ipairs(ignored_directories) do
|
||||
-- Check for /dirname/ or /dirname at end
|
||||
if filepath:match("/" .. dir .. "/") or filepath:match("/" .. dir .. "$") then
|
||||
return true
|
||||
end
|
||||
-- Also check for dirname/ at start (relative paths)
|
||||
if filepath:match("^" .. dir .. "/") then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Check if a file should be ignored for coder companion creation
|
||||
---@param filepath string Full file path
|
||||
---@return boolean
|
||||
local function should_ignore_for_coder(filepath)
|
||||
local filename = vim.fn.fnamemodify(filepath, ":t")
|
||||
|
||||
-- Check exact filename matches
|
||||
for _, ignored in ipairs(ignored_files) do
|
||||
if filename == ignored then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
-- Check if file starts with dot (hidden/config files)
|
||||
if filename:match("^%.") then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Check if in ignored directory
|
||||
if is_in_ignored_directory(filepath) then
|
||||
return true
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
function M.auto_index_file(bufnr)
|
||||
-- Skip if buffer is invalid
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
@@ -1031,6 +1258,11 @@ function M.auto_index_file(bufnr)
|
||||
return
|
||||
end
|
||||
|
||||
-- Skip ignored directories and files (node_modules, .git, config files, etc.)
|
||||
if should_ignore_for_coder(filepath) then
|
||||
return
|
||||
end
|
||||
|
||||
-- Skip if auto_index is disabled in config
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
@@ -1047,23 +1279,137 @@ function M.auto_index_file(bufnr)
|
||||
-- Check if coder file already exists
|
||||
local coder_exists = utils.file_exists(coder_path)
|
||||
|
||||
-- Create coder file with template if it doesn't exist
|
||||
-- Create coder file with pseudo-code context if it doesn't exist
|
||||
if not coder_exists then
|
||||
local filename = vim.fn.fnamemodify(filepath, ":t")
|
||||
local template = string.format(
|
||||
[[-- Coder companion for %s
|
||||
-- Use /@ @/ tags to write pseudo-code prompts
|
||||
-- Example:
|
||||
-- /@
|
||||
-- Add a function that validates user input
|
||||
-- - Check for empty strings
|
||||
-- - Validate email format
|
||||
-- @/
|
||||
local ext = vim.fn.fnamemodify(filepath, ":e")
|
||||
|
||||
]],
|
||||
filename
|
||||
)
|
||||
utils.write_file(coder_path, template)
|
||||
-- Determine comment style based on extension
|
||||
local comment_prefix = "--"
|
||||
local comment_block_start = "--[["
|
||||
local comment_block_end = "]]"
|
||||
if ext == "ts" or ext == "tsx" or ext == "js" or ext == "jsx" or ext == "java" or ext == "c" or ext == "cpp" or ext == "cs" or ext == "go" or ext == "rs" then
|
||||
comment_prefix = "//"
|
||||
comment_block_start = "/*"
|
||||
comment_block_end = "*/"
|
||||
elseif ext == "py" or ext == "rb" or ext == "yaml" or ext == "yml" then
|
||||
comment_prefix = "#"
|
||||
comment_block_start = '"""'
|
||||
comment_block_end = '"""'
|
||||
end
|
||||
|
||||
-- Read target file to analyze its structure
|
||||
local content = ""
|
||||
pcall(function()
|
||||
local lines = vim.fn.readfile(filepath)
|
||||
if lines then
|
||||
content = table.concat(lines, "\n")
|
||||
end
|
||||
end)
|
||||
|
||||
-- Extract structure from the file
|
||||
local functions = extract_functions(content, ext)
|
||||
local classes = extract_classes(content, ext)
|
||||
local imports = extract_imports(content, ext)
|
||||
|
||||
-- Build pseudo-code context
|
||||
local pseudo_code = {}
|
||||
|
||||
-- Header
|
||||
table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════")
|
||||
table.insert(pseudo_code, comment_prefix .. " CODER COMPANION: " .. filename)
|
||||
table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════")
|
||||
table.insert(pseudo_code, comment_prefix .. " This file describes the business logic and behavior of " .. filename)
|
||||
table.insert(pseudo_code, comment_prefix .. " Edit this pseudo-code to guide code generation.")
|
||||
table.insert(pseudo_code, comment_prefix .. " Use /@ @/ tags for specific generation requests.")
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
|
||||
-- Module purpose
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " MODULE PURPOSE:")
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " TODO: Describe what this module/file is responsible for")
|
||||
table.insert(pseudo_code, comment_prefix .. " Example: \"Handles user authentication and session management\"")
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
|
||||
-- Dependencies section
|
||||
if #imports > 0 then
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " DEPENDENCIES:")
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
for _, imp in ipairs(imports) do
|
||||
table.insert(pseudo_code, comment_prefix .. " • " .. imp)
|
||||
end
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
end
|
||||
|
||||
-- Classes section
|
||||
if #classes > 0 then
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " CLASSES:")
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
for _, class in ipairs(classes) do
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
table.insert(pseudo_code, comment_prefix .. " class " .. class.name .. ":")
|
||||
table.insert(pseudo_code, comment_prefix .. " PURPOSE: TODO - describe what this class represents")
|
||||
table.insert(pseudo_code, comment_prefix .. " RESPONSIBILITIES:")
|
||||
table.insert(pseudo_code, comment_prefix .. " - TODO: list main responsibilities")
|
||||
end
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
end
|
||||
|
||||
-- Functions section
|
||||
if #functions > 0 then
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " FUNCTIONS:")
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
for _, func in ipairs(functions) do
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
table.insert(pseudo_code, comment_prefix .. " " .. func.name .. "():")
|
||||
table.insert(pseudo_code, comment_prefix .. " PURPOSE: TODO - what does this function do?")
|
||||
table.insert(pseudo_code, comment_prefix .. " INPUTS: TODO - describe parameters")
|
||||
table.insert(pseudo_code, comment_prefix .. " OUTPUTS: TODO - describe return value")
|
||||
table.insert(pseudo_code, comment_prefix .. " BEHAVIOR:")
|
||||
table.insert(pseudo_code, comment_prefix .. " - TODO: describe step-by-step logic")
|
||||
end
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
end
|
||||
|
||||
-- If empty file, provide starter template
|
||||
if #functions == 0 and #classes == 0 then
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " PLANNED STRUCTURE:")
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " TODO: Describe what you want to build in this file")
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
table.insert(pseudo_code, comment_prefix .. " Example pseudo-code:")
|
||||
table.insert(pseudo_code, comment_prefix .. " /@")
|
||||
table.insert(pseudo_code, comment_prefix .. " Create a module that:")
|
||||
table.insert(pseudo_code, comment_prefix .. " 1. Exports a main function")
|
||||
table.insert(pseudo_code, comment_prefix .. " 2. Handles errors gracefully")
|
||||
table.insert(pseudo_code, comment_prefix .. " 3. Returns structured data")
|
||||
table.insert(pseudo_code, comment_prefix .. " @/")
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
end
|
||||
|
||||
-- Business rules section
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " BUSINESS RULES:")
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " TODO: Document any business rules, constraints, or requirements")
|
||||
table.insert(pseudo_code, comment_prefix .. " Example:")
|
||||
table.insert(pseudo_code, comment_prefix .. " - Users must be authenticated before accessing this feature")
|
||||
table.insert(pseudo_code, comment_prefix .. " - Data must be validated before saving")
|
||||
table.insert(pseudo_code, comment_prefix .. " - Errors should be logged but not exposed to users")
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
|
||||
-- Footer with generation tags example
|
||||
table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════")
|
||||
table.insert(pseudo_code, comment_prefix .. " Use /@ @/ tags below to request code generation:")
|
||||
table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════")
|
||||
table.insert(pseudo_code, "")
|
||||
|
||||
utils.write_file(coder_path, table.concat(pseudo_code, "\n"))
|
||||
end
|
||||
|
||||
-- Notify user about the coder companion
|
||||
|
||||
@@ -111,7 +111,7 @@ function M.compute_relevance(node, opts)
|
||||
return score
|
||||
end
|
||||
|
||||
--- Traverse graph from seed nodes
|
||||
--- Traverse graph from seed nodes (basic traversal)
|
||||
---@param seed_ids string[] Starting node IDs
|
||||
---@param depth number Traversal depth
|
||||
---@param edge_types? EdgeType[] Edge types to follow
|
||||
@@ -157,6 +157,73 @@ local function traverse(seed_ids, depth, edge_types)
|
||||
return discovered
|
||||
end
|
||||
|
||||
--- Spreading activation - mimics human associative memory
|
||||
--- Activation spreads from seed nodes along edges, decaying by weight
|
||||
--- Nodes accumulate activation from multiple paths (like neural pathways)
|
||||
---@param seed_activations table<string, number> Initial activations {node_id: activation}
|
||||
---@param max_iterations number Max spread iterations (default 3)
|
||||
---@param decay number Activation decay per hop (default 0.5)
|
||||
---@param threshold number Minimum activation to continue spreading (default 0.1)
|
||||
---@return table<string, number> Final activations {node_id: accumulated_activation}
|
||||
local function spreading_activation(seed_activations, max_iterations, decay, threshold)
|
||||
local edge_mod = get_edge_module()
|
||||
max_iterations = max_iterations or 3
|
||||
decay = decay or 0.5
|
||||
threshold = threshold or 0.1
|
||||
|
||||
-- Accumulated activation for each node
|
||||
local activation = {}
|
||||
for node_id, act in pairs(seed_activations) do
|
||||
activation[node_id] = act
|
||||
end
|
||||
|
||||
-- Current frontier with their activation levels
|
||||
local frontier = {}
|
||||
for node_id, act in pairs(seed_activations) do
|
||||
frontier[node_id] = act
|
||||
end
|
||||
|
||||
-- Spread activation iteratively
|
||||
for _ = 1, max_iterations do
|
||||
local next_frontier = {}
|
||||
|
||||
for source_id, source_activation in pairs(frontier) do
|
||||
-- Get all outgoing edges
|
||||
local edges = edge_mod.get_edges(source_id, nil, "both")
|
||||
|
||||
for _, edge in ipairs(edges) do
|
||||
-- Determine target (could be source or target of edge)
|
||||
local target_id = edge.s == source_id and edge.t or edge.s
|
||||
|
||||
-- Calculate spreading activation
|
||||
-- Activation = source_activation * edge_weight * decay
|
||||
local edge_weight = edge.p and edge.p.w or 0.5
|
||||
local spread_amount = source_activation * edge_weight * decay
|
||||
|
||||
-- Only spread if above threshold
|
||||
if spread_amount >= threshold then
|
||||
-- Accumulate activation (multiple paths add up)
|
||||
activation[target_id] = (activation[target_id] or 0) + spread_amount
|
||||
|
||||
-- Add to next frontier if not already processed with higher activation
|
||||
if not next_frontier[target_id] or next_frontier[target_id] < spread_amount then
|
||||
next_frontier[target_id] = spread_amount
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Stop if no more spreading
|
||||
if vim.tbl_count(next_frontier) == 0 then
|
||||
break
|
||||
end
|
||||
|
||||
frontier = next_frontier
|
||||
end
|
||||
|
||||
return activation
|
||||
end
|
||||
|
||||
--- Execute a query across all dimensions
|
||||
---@param opts QueryOpts Query options
|
||||
---@return QueryResult
|
||||
@@ -236,28 +303,49 @@ function M.execute(opts)
|
||||
end
|
||||
end
|
||||
|
||||
-- 4. Combine and deduplicate
|
||||
-- 4. Combine all found nodes and compute seed activations
|
||||
local all_nodes = {}
|
||||
local seed_activations = {}
|
||||
|
||||
for _, category in pairs(results) do
|
||||
for id, node in pairs(category) do
|
||||
if not all_nodes[id] then
|
||||
all_nodes[id] = node
|
||||
-- Compute initial activation based on relevance
|
||||
local relevance = M.compute_relevance(node, opts)
|
||||
seed_activations[id] = relevance
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- 5. Score and rank
|
||||
-- 5. Apply spreading activation - like human associative memory
|
||||
-- Activation spreads from seed nodes along edges, accumulating
|
||||
-- Nodes connected to multiple relevant seeds get higher activation
|
||||
local final_activations = spreading_activation(
|
||||
seed_activations,
|
||||
opts.spread_iterations or 3, -- How far activation spreads
|
||||
opts.spread_decay or 0.5, -- How much activation decays per hop
|
||||
opts.spread_threshold or 0.05 -- Minimum activation to continue spreading
|
||||
)
|
||||
|
||||
-- 6. Score and rank by combined activation
|
||||
local scored = {}
|
||||
for id, node in pairs(all_nodes) do
|
||||
local relevance = M.compute_relevance(node, opts)
|
||||
table.insert(scored, { node = node, relevance = relevance })
|
||||
for id, activation in pairs(final_activations) do
|
||||
local node = all_nodes[id] or node_mod.get(id)
|
||||
if node then
|
||||
all_nodes[id] = node
|
||||
-- Final score = spreading activation + base relevance
|
||||
local base_relevance = M.compute_relevance(node, opts)
|
||||
local final_score = (activation * 0.6) + (base_relevance * 0.4)
|
||||
table.insert(scored, { node = node, relevance = final_score, activation = activation })
|
||||
end
|
||||
end
|
||||
|
||||
table.sort(scored, function(a, b)
|
||||
return a.relevance > b.relevance
|
||||
end)
|
||||
|
||||
-- 6. Apply limit
|
||||
-- 7. Apply limit
|
||||
local limit = opts.limit or 50
|
||||
local result_nodes = {}
|
||||
local truncated = #scored > limit
|
||||
@@ -266,7 +354,7 @@ function M.execute(opts)
|
||||
table.insert(result_nodes, scored[i].node)
|
||||
end
|
||||
|
||||
-- 7. Get edges between result nodes
|
||||
-- 8. Get edges between result nodes
|
||||
local edge_mod = get_edge_module()
|
||||
local result_edges = {}
|
||||
local node_ids = {}
|
||||
@@ -291,11 +379,17 @@ function M.execute(opts)
|
||||
file_count = vim.tbl_count(results.file),
|
||||
temporal_count = vim.tbl_count(results.temporal),
|
||||
total_scored = #scored,
|
||||
seed_nodes = vim.tbl_count(seed_activations),
|
||||
activated_nodes = vim.tbl_count(final_activations),
|
||||
},
|
||||
truncated = truncated,
|
||||
}
|
||||
end
|
||||
|
||||
--- Expose spreading activation for direct use
|
||||
--- Useful for custom activation patterns or debugging
|
||||
M.spreading_activation = spreading_activation
|
||||
|
||||
--- Find nodes by file
|
||||
---@param filepath string File path
|
||||
---@param limit? number Max results
|
||||
|
||||
@@ -9,6 +9,10 @@ local M = {}
|
||||
---@param event LearnEvent Learning event
|
||||
---@return boolean
|
||||
function M.detect(event)
|
||||
if not event or not event.type then
|
||||
return false
|
||||
end
|
||||
|
||||
local valid_types = {
|
||||
"code_completion",
|
||||
"file_indexed",
|
||||
|
||||
@@ -287,6 +287,118 @@ local function cmd_agent_stop()
|
||||
end
|
||||
end
|
||||
|
||||
--- Run the agentic loop with a task
|
||||
---@param task string The task to accomplish
|
||||
---@param agent_name? string Optional agent name
|
||||
local function cmd_agentic_run(task, agent_name)
|
||||
local agentic = require("codetyper.agent.agentic")
|
||||
local logs_panel = require("codetyper.logs_panel")
|
||||
local logs = require("codetyper.agent.logs")
|
||||
|
||||
-- Open logs panel
|
||||
logs_panel.open()
|
||||
|
||||
logs.info("Starting agentic task: " .. task:sub(1, 50) .. "...")
|
||||
utils.notify("Running agentic task...", vim.log.levels.INFO)
|
||||
|
||||
-- Get current file for context
|
||||
local current_file = vim.fn.expand("%:p")
|
||||
local files = {}
|
||||
if current_file ~= "" then
|
||||
table.insert(files, current_file)
|
||||
end
|
||||
|
||||
agentic.run({
|
||||
task = task,
|
||||
files = files,
|
||||
agent = agent_name or "coder",
|
||||
on_status = function(status)
|
||||
logs.thinking(status)
|
||||
end,
|
||||
on_tool_start = function(name, args)
|
||||
logs.info("Tool: " .. name)
|
||||
end,
|
||||
on_tool_end = function(name, result, err)
|
||||
if err then
|
||||
logs.error(name .. " failed: " .. err)
|
||||
else
|
||||
logs.debug(name .. " completed")
|
||||
end
|
||||
end,
|
||||
on_file_change = function(path, action)
|
||||
logs.info("File " .. action .. ": " .. path)
|
||||
end,
|
||||
on_message = function(msg)
|
||||
if msg.role == "assistant" and type(msg.content) == "string" and msg.content ~= "" then
|
||||
logs.thinking(msg.content:sub(1, 100) .. "...")
|
||||
end
|
||||
end,
|
||||
on_complete = function(result, err)
|
||||
if err then
|
||||
logs.error("Task failed: " .. err)
|
||||
utils.notify("Agentic task failed: " .. err, vim.log.levels.ERROR)
|
||||
else
|
||||
logs.info("Task completed successfully")
|
||||
utils.notify("Agentic task completed!", vim.log.levels.INFO)
|
||||
if result and result ~= "" then
|
||||
-- Show summary in a float
|
||||
vim.schedule(function()
|
||||
vim.notify("Result:\n" .. result:sub(1, 500), vim.log.levels.INFO)
|
||||
end)
|
||||
end
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- List available agents
|
||||
local function cmd_agentic_list()
|
||||
local agentic = require("codetyper.agent.agentic")
|
||||
local agents = agentic.list_agents()
|
||||
|
||||
local lines = {
|
||||
"Available Agents",
|
||||
"================",
|
||||
"",
|
||||
}
|
||||
|
||||
for _, agent in ipairs(agents) do
|
||||
local badge = agent.builtin and "[builtin]" or "[custom]"
|
||||
table.insert(lines, string.format(" %s %s", agent.name, badge))
|
||||
table.insert(lines, string.format(" %s", agent.description))
|
||||
table.insert(lines, "")
|
||||
end
|
||||
|
||||
table.insert(lines, "Use :CoderAgenticRun <task> [agent] to run a task")
|
||||
table.insert(lines, "Use :CoderAgenticInit to create custom agents")
|
||||
|
||||
utils.notify(table.concat(lines, "\n"))
|
||||
end
|
||||
|
||||
--- Initialize .coder/agents/ and .coder/rules/ directories
|
||||
local function cmd_agentic_init()
|
||||
local agentic = require("codetyper.agent.agentic")
|
||||
agentic.init()
|
||||
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
|
||||
|
||||
local lines = {
|
||||
"Initialized Coder directories:",
|
||||
"",
|
||||
" " .. agents_dir,
|
||||
" - example.md (template for custom agents)",
|
||||
"",
|
||||
" " .. rules_dir,
|
||||
" - code-style.md (template for project rules)",
|
||||
"",
|
||||
"Edit these files to customize agent behavior.",
|
||||
"Create new .md files to add more agents/rules.",
|
||||
}
|
||||
|
||||
utils.notify(table.concat(lines, "\n"))
|
||||
end
|
||||
|
||||
--- Show chat type switcher modal (Ask/Agent)
|
||||
local function cmd_type_toggle()
|
||||
local switcher = require("codetyper.chat_switcher")
|
||||
@@ -844,6 +956,65 @@ end
|
||||
|
||||
--- Main command dispatcher
|
||||
---@param args table Command arguments
|
||||
--- Show LLM accuracy statistics
|
||||
local function cmd_llm_stats()
|
||||
local llm = require("codetyper.llm")
|
||||
local stats = llm.get_accuracy_stats()
|
||||
|
||||
local lines = {
|
||||
"LLM Provider Accuracy Statistics",
|
||||
"================================",
|
||||
"",
|
||||
string.format("Ollama:"),
|
||||
string.format(" Total requests: %d", stats.ollama.total),
|
||||
string.format(" Correct: %d", stats.ollama.correct),
|
||||
string.format(" Accuracy: %.1f%%", stats.ollama.accuracy * 100),
|
||||
"",
|
||||
string.format("Copilot:"),
|
||||
string.format(" Total requests: %d", stats.copilot.total),
|
||||
string.format(" Correct: %d", stats.copilot.correct),
|
||||
string.format(" Accuracy: %.1f%%", stats.copilot.accuracy * 100),
|
||||
"",
|
||||
"Note: Smart selection prefers Ollama when brain memories",
|
||||
"provide enough context. Accuracy improves over time via",
|
||||
"pondering (verification with other LLMs).",
|
||||
}
|
||||
|
||||
vim.notify(table.concat(lines, "\n"), vim.log.levels.INFO)
|
||||
end
|
||||
|
||||
--- Report feedback on last LLM response
|
||||
---@param was_good boolean Whether the response was good
|
||||
local function cmd_llm_feedback(was_good)
|
||||
local llm = require("codetyper.llm")
|
||||
-- Get the last used provider from logs or default
|
||||
local provider = "ollama" -- Default assumption
|
||||
|
||||
-- Try to get actual last provider from logs
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local entries = logs.get(10)
|
||||
for i = #entries, 1, -1 do
|
||||
local entry = entries[i]
|
||||
if entry.message and entry.message:match("^LLM:") then
|
||||
provider = entry.message:match("LLM: (%w+)") or provider
|
||||
break
|
||||
end
|
||||
end
|
||||
end)
|
||||
|
||||
llm.report_feedback(provider, was_good)
|
||||
local feedback_type = was_good and "positive" or "negative"
|
||||
utils.notify(string.format("Reported %s feedback for %s", feedback_type, provider), vim.log.levels.INFO)
|
||||
end
|
||||
|
||||
--- Reset LLM accuracy statistics
|
||||
local function cmd_llm_reset_stats()
|
||||
local selector = require("codetyper.llm.selector")
|
||||
selector.reset_accuracy_stats()
|
||||
utils.notify("LLM accuracy statistics reset", vim.log.levels.INFO)
|
||||
end
|
||||
|
||||
local function coder_cmd(args)
|
||||
local subcommand = args.fargs[1] or "toggle"
|
||||
|
||||
@@ -872,6 +1043,17 @@ local function coder_cmd(args)
|
||||
["logs-toggle"] = cmd_logs_toggle,
|
||||
["queue-status"] = cmd_queue_status,
|
||||
["queue-process"] = cmd_queue_process,
|
||||
-- Agentic commands
|
||||
["agentic-run"] = function(args)
|
||||
local task = table.concat(vim.list_slice(args.fargs, 2), " ")
|
||||
if task == "" then
|
||||
utils.notify("Usage: Coder agentic-run <task> [agent]", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
cmd_agentic_run(task)
|
||||
end,
|
||||
["agentic-list"] = cmd_agentic_list,
|
||||
["agentic-init"] = cmd_agentic_init,
|
||||
["index-project"] = cmd_index_project,
|
||||
["index-status"] = cmd_index_status,
|
||||
memories = cmd_memories,
|
||||
@@ -901,6 +1083,41 @@ local function coder_cmd(args)
|
||||
end
|
||||
end
|
||||
end,
|
||||
-- LLM smart selection commands
|
||||
["llm-stats"] = cmd_llm_stats,
|
||||
["llm-feedback-good"] = function()
|
||||
cmd_llm_feedback(true)
|
||||
end,
|
||||
["llm-feedback-bad"] = function()
|
||||
cmd_llm_feedback(false)
|
||||
end,
|
||||
["llm-reset-stats"] = cmd_llm_reset_stats,
|
||||
-- Cost tracking commands
|
||||
["cost"] = function()
|
||||
local cost = require("codetyper.cost")
|
||||
cost.toggle()
|
||||
end,
|
||||
["cost-clear"] = function()
|
||||
local cost = require("codetyper.cost")
|
||||
cost.clear()
|
||||
end,
|
||||
-- Credentials management commands
|
||||
["add-api-key"] = function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.interactive_add()
|
||||
end,
|
||||
["remove-api-key"] = function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.interactive_remove()
|
||||
end,
|
||||
["credentials"] = function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.show_status()
|
||||
end,
|
||||
["switch-provider"] = function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.interactive_switch_provider()
|
||||
end,
|
||||
}
|
||||
|
||||
local cmd_fn = commands[subcommand]
|
||||
@@ -922,10 +1139,14 @@ function M.setup()
|
||||
"ask", "ask-close", "ask-toggle", "ask-clear",
|
||||
"transform", "transform-cursor",
|
||||
"agent", "agent-close", "agent-toggle", "agent-stop",
|
||||
"agentic-run", "agentic-list", "agentic-init",
|
||||
"type-toggle", "logs-toggle",
|
||||
"queue-status", "queue-process",
|
||||
"index-project", "index-status", "memories", "forget",
|
||||
"auto-toggle", "auto-set",
|
||||
"llm-stats", "llm-feedback-good", "llm-feedback-bad", "llm-reset-stats",
|
||||
"cost", "cost-clear",
|
||||
"add-api-key", "remove-api-key", "credentials", "switch-provider",
|
||||
}
|
||||
end,
|
||||
desc = "Codetyper.nvim commands",
|
||||
@@ -997,6 +1218,31 @@ function M.setup()
|
||||
cmd_agent_stop()
|
||||
end, { desc = "Stop running agent" })
|
||||
|
||||
-- Agentic commands (full IDE-like agent functionality)
|
||||
vim.api.nvim_create_user_command("CoderAgenticRun", function(opts)
|
||||
local task = opts.args
|
||||
if task == "" then
|
||||
vim.ui.input({ prompt = "Task: " }, function(input)
|
||||
if input and input ~= "" then
|
||||
cmd_agentic_run(input)
|
||||
end
|
||||
end)
|
||||
else
|
||||
cmd_agentic_run(task)
|
||||
end
|
||||
end, {
|
||||
desc = "Run agentic task (IDE-like multi-file changes)",
|
||||
nargs = "*",
|
||||
})
|
||||
|
||||
vim.api.nvim_create_user_command("CoderAgenticList", function()
|
||||
cmd_agentic_list()
|
||||
end, { desc = "List available agents" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderAgenticInit", function()
|
||||
cmd_agentic_init()
|
||||
end, { desc = "Initialize .coder/agents/ and .coder/rules/ directories" })
|
||||
|
||||
-- Chat type switcher command
|
||||
vim.api.nvim_create_user_command("CoderType", function()
|
||||
cmd_type_toggle()
|
||||
@@ -1075,6 +1321,147 @@ function M.setup()
|
||||
end,
|
||||
})
|
||||
|
||||
-- Brain feedback command - teach the brain from your experience
|
||||
vim.api.nvim_create_user_command("CoderFeedback", function(opts)
|
||||
local brain = require("codetyper.brain")
|
||||
if not brain.is_initialized() then
|
||||
vim.notify("Brain not initialized", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
local feedback_type = opts.args:lower()
|
||||
local current_file = vim.fn.expand("%:p")
|
||||
|
||||
if feedback_type == "good" or feedback_type == "accept" or feedback_type == "+" then
|
||||
-- Learn positive feedback
|
||||
brain.learn({
|
||||
type = "user_feedback",
|
||||
file = current_file,
|
||||
timestamp = os.time(),
|
||||
data = {
|
||||
feedback = "accepted",
|
||||
description = "User marked code as good/accepted",
|
||||
},
|
||||
})
|
||||
vim.notify("Brain: Learned positive feedback ✓", vim.log.levels.INFO)
|
||||
|
||||
elseif feedback_type == "bad" or feedback_type == "reject" or feedback_type == "-" then
|
||||
-- Learn negative feedback
|
||||
brain.learn({
|
||||
type = "user_feedback",
|
||||
file = current_file,
|
||||
timestamp = os.time(),
|
||||
data = {
|
||||
feedback = "rejected",
|
||||
description = "User marked code as bad/rejected",
|
||||
},
|
||||
})
|
||||
vim.notify("Brain: Learned negative feedback ✗", vim.log.levels.INFO)
|
||||
|
||||
elseif feedback_type == "stats" or feedback_type == "status" then
|
||||
-- Show brain stats
|
||||
local stats = brain.stats()
|
||||
local msg = string.format(
|
||||
"Brain Stats:\n• Nodes: %d\n• Edges: %d\n• Pending: %d\n• Deltas: %d",
|
||||
stats.node_count or 0,
|
||||
stats.edge_count or 0,
|
||||
stats.pending_changes or 0,
|
||||
stats.delta_count or 0
|
||||
)
|
||||
vim.notify(msg, vim.log.levels.INFO)
|
||||
|
||||
else
|
||||
vim.notify("Usage: CoderFeedback <good|bad|stats>", vim.log.levels.INFO)
|
||||
end
|
||||
end, {
|
||||
desc = "Give feedback to the brain (good/bad/stats)",
|
||||
nargs = "?",
|
||||
complete = function()
|
||||
return { "good", "bad", "stats" }
|
||||
end,
|
||||
})
|
||||
|
||||
-- Brain stats command
|
||||
vim.api.nvim_create_user_command("CoderBrain", function(opts)
|
||||
local brain = require("codetyper.brain")
|
||||
if not brain.is_initialized() then
|
||||
vim.notify("Brain not initialized", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
local action = opts.args:lower()
|
||||
|
||||
if action == "stats" or action == "" then
|
||||
local stats = brain.stats()
|
||||
local lines = {
|
||||
"╭─────────────────────────────────╮",
|
||||
"│ CODETYPER BRAIN │",
|
||||
"╰─────────────────────────────────╯",
|
||||
"",
|
||||
string.format(" Nodes: %d", stats.node_count or 0),
|
||||
string.format(" Edges: %d", stats.edge_count or 0),
|
||||
string.format(" Deltas: %d", stats.delta_count or 0),
|
||||
string.format(" Pending: %d", stats.pending_changes or 0),
|
||||
"",
|
||||
" The more you use Codetyper,",
|
||||
" the smarter it becomes!",
|
||||
}
|
||||
vim.notify(table.concat(lines, "\n"), vim.log.levels.INFO)
|
||||
|
||||
elseif action == "commit" then
|
||||
local hash = brain.commit("Manual commit")
|
||||
if hash then
|
||||
vim.notify("Brain: Committed changes (hash: " .. hash:sub(1, 8) .. ")", vim.log.levels.INFO)
|
||||
else
|
||||
vim.notify("Brain: Nothing to commit", vim.log.levels.INFO)
|
||||
end
|
||||
|
||||
elseif action == "flush" then
|
||||
brain.flush()
|
||||
vim.notify("Brain: Flushed to disk", vim.log.levels.INFO)
|
||||
|
||||
elseif action == "prune" then
|
||||
local pruned = brain.prune()
|
||||
vim.notify("Brain: Pruned " .. pruned .. " low-value nodes", vim.log.levels.INFO)
|
||||
|
||||
else
|
||||
vim.notify("Usage: CoderBrain <stats|commit|flush|prune>", vim.log.levels.INFO)
|
||||
end
|
||||
end, {
|
||||
desc = "Brain management commands",
|
||||
nargs = "?",
|
||||
complete = function()
|
||||
return { "stats", "commit", "flush", "prune" }
|
||||
end,
|
||||
})
|
||||
|
||||
-- Cost estimation command
|
||||
vim.api.nvim_create_user_command("CoderCost", function()
|
||||
local cost = require("codetyper.cost")
|
||||
cost.toggle()
|
||||
end, { desc = "Show LLM cost estimation window" })
|
||||
|
||||
-- Credentials management commands
|
||||
vim.api.nvim_create_user_command("CoderAddApiKey", function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.interactive_add()
|
||||
end, { desc = "Add or update LLM provider API key" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderRemoveApiKey", function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.interactive_remove()
|
||||
end, { desc = "Remove LLM provider credentials" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderCredentials", function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.show_status()
|
||||
end, { desc = "Show credentials status" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderSwitchProvider", function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.interactive_switch_provider()
|
||||
end, { desc = "Switch active LLM provider" })
|
||||
|
||||
-- Setup default keymaps
|
||||
M.setup_keymaps()
|
||||
end
|
||||
|
||||
750
lua/codetyper/cost.lua
Normal file
750
lua/codetyper/cost.lua
Normal file
@@ -0,0 +1,750 @@
|
||||
---@mod codetyper.cost Cost estimation for LLM usage
|
||||
---@brief [[
|
||||
--- Tracks token usage and estimates costs based on model pricing.
|
||||
--- Prices are per 1M tokens. Persists usage data in the brain.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
|
||||
--- Cost history file name
|
||||
local COST_HISTORY_FILE = "cost_history.json"
|
||||
|
||||
--- Get path to cost history file
|
||||
---@return string File path
|
||||
local function get_history_path()
|
||||
local root = utils.get_project_root()
|
||||
return root .. "/.coder/" .. COST_HISTORY_FILE
|
||||
end
|
||||
|
||||
--- Default model for savings comparison (what you'd pay if not using Ollama)
|
||||
M.comparison_model = "gpt-4o"
|
||||
|
||||
--- Models considered "free" (Ollama, local, Copilot subscription)
|
||||
M.free_models = {
|
||||
["ollama"] = true,
|
||||
["codellama"] = true,
|
||||
["llama2"] = true,
|
||||
["llama3"] = true,
|
||||
["mistral"] = true,
|
||||
["deepseek-coder"] = true,
|
||||
["copilot"] = true,
|
||||
}
|
||||
|
||||
--- Model pricing table (per 1M tokens in USD)
|
||||
---@type table<string, {input: number, cached_input: number|nil, output: number|nil}>
|
||||
M.pricing = {
|
||||
-- GPT-5.x series
|
||||
["gpt-5.2"] = { input = 1.75, cached_input = 0.175, output = 14.00 },
|
||||
["gpt-5.1"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
["gpt-5"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
["gpt-5-mini"] = { input = 0.25, cached_input = 0.025, output = 2.00 },
|
||||
["gpt-5-nano"] = { input = 0.05, cached_input = 0.005, output = 0.40 },
|
||||
["gpt-5.2-chat-latest"] = { input = 1.75, cached_input = 0.175, output = 14.00 },
|
||||
["gpt-5.1-chat-latest"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
["gpt-5-chat-latest"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
["gpt-5.2-codex"] = { input = 1.75, cached_input = 0.175, output = 14.00 },
|
||||
["gpt-5.1-codex-max"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
["gpt-5.1-codex"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
["gpt-5-codex"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
["gpt-5.2-pro"] = { input = 21.00, cached_input = nil, output = 168.00 },
|
||||
["gpt-5-pro"] = { input = 15.00, cached_input = nil, output = 120.00 },
|
||||
["gpt-5.1-codex-mini"] = { input = 0.25, cached_input = 0.025, output = 2.00 },
|
||||
["gpt-5-search-api"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
|
||||
-- GPT-4.x series
|
||||
["gpt-4.1"] = { input = 2.00, cached_input = 0.50, output = 8.00 },
|
||||
["gpt-4.1-mini"] = { input = 0.40, cached_input = 0.10, output = 1.60 },
|
||||
["gpt-4.1-nano"] = { input = 0.10, cached_input = 0.025, output = 0.40 },
|
||||
["gpt-4o"] = { input = 2.50, cached_input = 1.25, output = 10.00 },
|
||||
["gpt-4o-2024-05-13"] = { input = 5.00, cached_input = nil, output = 15.00 },
|
||||
["gpt-4o-mini"] = { input = 0.15, cached_input = 0.075, output = 0.60 },
|
||||
|
||||
-- Realtime models
|
||||
["gpt-realtime"] = { input = 4.00, cached_input = 0.40, output = 16.00 },
|
||||
["gpt-realtime-mini"] = { input = 0.60, cached_input = 0.06, output = 2.40 },
|
||||
["gpt-4o-realtime-preview"] = { input = 5.00, cached_input = 2.50, output = 20.00 },
|
||||
["gpt-4o-mini-realtime-preview"] = { input = 0.60, cached_input = 0.30, output = 2.40 },
|
||||
|
||||
-- Audio models
|
||||
["gpt-audio"] = { input = 2.50, cached_input = nil, output = 10.00 },
|
||||
["gpt-audio-mini"] = { input = 0.60, cached_input = nil, output = 2.40 },
|
||||
["gpt-4o-audio-preview"] = { input = 2.50, cached_input = nil, output = 10.00 },
|
||||
["gpt-4o-mini-audio-preview"] = { input = 0.15, cached_input = nil, output = 0.60 },
|
||||
|
||||
-- O-series reasoning models
|
||||
["o1"] = { input = 15.00, cached_input = 7.50, output = 60.00 },
|
||||
["o1-pro"] = { input = 150.00, cached_input = nil, output = 600.00 },
|
||||
["o3-pro"] = { input = 20.00, cached_input = nil, output = 80.00 },
|
||||
["o3"] = { input = 2.00, cached_input = 0.50, output = 8.00 },
|
||||
["o3-deep-research"] = { input = 10.00, cached_input = 2.50, output = 40.00 },
|
||||
["o4-mini"] = { input = 1.10, cached_input = 0.275, output = 4.40 },
|
||||
["o4-mini-deep-research"] = { input = 2.00, cached_input = 0.50, output = 8.00 },
|
||||
["o3-mini"] = { input = 1.10, cached_input = 0.55, output = 4.40 },
|
||||
["o1-mini"] = { input = 1.10, cached_input = 0.55, output = 4.40 },
|
||||
|
||||
-- Codex
|
||||
["codex-mini-latest"] = { input = 1.50, cached_input = 0.375, output = 6.00 },
|
||||
|
||||
-- Search models
|
||||
["gpt-4o-mini-search-preview"] = { input = 0.15, cached_input = nil, output = 0.60 },
|
||||
["gpt-4o-search-preview"] = { input = 2.50, cached_input = nil, output = 10.00 },
|
||||
|
||||
-- Computer use
|
||||
["computer-use-preview"] = { input = 3.00, cached_input = nil, output = 12.00 },
|
||||
|
||||
-- Image models
|
||||
["gpt-image-1.5"] = { input = 5.00, cached_input = 1.25, output = 10.00 },
|
||||
["chatgpt-image-latest"] = { input = 5.00, cached_input = 1.25, output = 10.00 },
|
||||
["gpt-image-1"] = { input = 5.00, cached_input = 1.25, output = nil },
|
||||
["gpt-image-1-mini"] = { input = 2.00, cached_input = 0.20, output = nil },
|
||||
|
||||
-- Claude models (Anthropic)
|
||||
["claude-3-opus"] = { input = 15.00, cached_input = 7.50, output = 75.00 },
|
||||
["claude-3-sonnet"] = { input = 3.00, cached_input = 1.50, output = 15.00 },
|
||||
["claude-3-haiku"] = { input = 0.25, cached_input = 0.125, output = 1.25 },
|
||||
["claude-3.5-sonnet"] = { input = 3.00, cached_input = 1.50, output = 15.00 },
|
||||
["claude-3.5-haiku"] = { input = 0.80, cached_input = 0.40, output = 4.00 },
|
||||
|
||||
-- Ollama/Local models (free)
|
||||
["ollama"] = { input = 0, cached_input = 0, output = 0 },
|
||||
["codellama"] = { input = 0, cached_input = 0, output = 0 },
|
||||
["llama2"] = { input = 0, cached_input = 0, output = 0 },
|
||||
["llama3"] = { input = 0, cached_input = 0, output = 0 },
|
||||
["mistral"] = { input = 0, cached_input = 0, output = 0 },
|
||||
["deepseek-coder"] = { input = 0, cached_input = 0, output = 0 },
|
||||
|
||||
-- Copilot (included in subscription, but tracking usage)
|
||||
["copilot"] = { input = 0, cached_input = 0, output = 0 },
|
||||
}
|
||||
|
||||
---@class CostUsage
|
||||
---@field model string Model name
|
||||
---@field input_tokens number Input tokens used
|
||||
---@field output_tokens number Output tokens used
|
||||
---@field cached_tokens number Cached input tokens
|
||||
---@field timestamp number Unix timestamp
|
||||
---@field cost number Calculated cost in USD
|
||||
|
||||
---@class CostState
|
||||
---@field usage CostUsage[] Current session usage
|
||||
---@field all_usage CostUsage[] All historical usage from brain
|
||||
---@field session_start number Session start timestamp
|
||||
---@field win number|nil Window handle
|
||||
---@field buf number|nil Buffer handle
|
||||
---@field loaded boolean Whether historical data has been loaded
|
||||
local state = {
|
||||
usage = {},
|
||||
all_usage = {},
|
||||
session_start = os.time(),
|
||||
win = nil,
|
||||
buf = nil,
|
||||
loaded = false,
|
||||
}
|
||||
|
||||
--- Load historical usage from disk
|
||||
function M.load_from_history()
|
||||
if state.loaded then
|
||||
return
|
||||
end
|
||||
|
||||
local history_path = get_history_path()
|
||||
local content = utils.read_file(history_path)
|
||||
|
||||
if content and content ~= "" then
|
||||
local ok, data = pcall(vim.json.decode, content)
|
||||
if ok and data and data.usage then
|
||||
state.all_usage = data.usage
|
||||
end
|
||||
end
|
||||
|
||||
state.loaded = true
|
||||
end
|
||||
|
||||
--- Save all usage to disk (debounced)
|
||||
local save_timer = nil
|
||||
local function save_to_disk()
|
||||
-- Cancel existing timer
|
||||
if save_timer then
|
||||
save_timer:stop()
|
||||
save_timer = nil
|
||||
end
|
||||
|
||||
-- Debounce writes (500ms)
|
||||
save_timer = vim.loop.new_timer()
|
||||
save_timer:start(500, 0, vim.schedule_wrap(function()
|
||||
local history_path = get_history_path()
|
||||
|
||||
-- Ensure directory exists
|
||||
local dir = vim.fn.fnamemodify(history_path, ":h")
|
||||
utils.ensure_dir(dir)
|
||||
|
||||
-- Merge session and historical usage
|
||||
local all_data = vim.deepcopy(state.all_usage)
|
||||
for _, usage in ipairs(state.usage) do
|
||||
table.insert(all_data, usage)
|
||||
end
|
||||
|
||||
-- Save to file
|
||||
local data = {
|
||||
version = 1,
|
||||
updated = os.time(),
|
||||
usage = all_data,
|
||||
}
|
||||
|
||||
local ok, json = pcall(vim.json.encode, data)
|
||||
if ok then
|
||||
utils.write_file(history_path, json)
|
||||
end
|
||||
|
||||
save_timer = nil
|
||||
end))
|
||||
end
|
||||
|
||||
--- Normalize model name for pricing lookup
|
||||
---@param model string Model name from API
|
||||
---@return string Normalized model name
|
||||
local function normalize_model(model)
|
||||
if not model then
|
||||
return "unknown"
|
||||
end
|
||||
|
||||
-- Convert to lowercase
|
||||
local normalized = model:lower()
|
||||
|
||||
-- Handle Copilot models
|
||||
if normalized:match("copilot") then
|
||||
return "copilot"
|
||||
end
|
||||
|
||||
-- Handle common prefixes
|
||||
normalized = normalized:gsub("^openai/", "")
|
||||
normalized = normalized:gsub("^anthropic/", "")
|
||||
|
||||
-- Try exact match first
|
||||
if M.pricing[normalized] then
|
||||
return normalized
|
||||
end
|
||||
|
||||
-- Try partial matches
|
||||
for price_model, _ in pairs(M.pricing) do
|
||||
if normalized:match(price_model) or price_model:match(normalized) then
|
||||
return price_model
|
||||
end
|
||||
end
|
||||
|
||||
return normalized
|
||||
end
|
||||
|
||||
--- Check if a model is considered "free" (local/Ollama/Copilot subscription)
|
||||
---@param model string Model name
|
||||
---@return boolean True if free
|
||||
function M.is_free_model(model)
|
||||
local normalized = normalize_model(model)
|
||||
|
||||
-- Check direct match
|
||||
if M.free_models[normalized] then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Check if it's an Ollama model (any model with : in name like deepseek-coder:6.7b)
|
||||
if model:match(":") then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Check pricing - if cost is 0, it's free
|
||||
local pricing = M.pricing[normalized]
|
||||
if pricing and pricing.input == 0 and pricing.output == 0 then
|
||||
return true
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Calculate cost for token usage
|
||||
---@param model string Model name
|
||||
---@param input_tokens number Input tokens
|
||||
---@param output_tokens number Output tokens
|
||||
---@param cached_tokens? number Cached input tokens
|
||||
---@return number Cost in USD
|
||||
function M.calculate_cost(model, input_tokens, output_tokens, cached_tokens)
|
||||
local normalized = normalize_model(model)
|
||||
local pricing = M.pricing[normalized]
|
||||
|
||||
if not pricing then
|
||||
-- Unknown model, return 0
|
||||
return 0
|
||||
end
|
||||
|
||||
cached_tokens = cached_tokens or 0
|
||||
local regular_input = input_tokens - cached_tokens
|
||||
|
||||
-- Calculate cost (prices are per 1M tokens)
|
||||
local input_cost = (regular_input / 1000000) * (pricing.input or 0)
|
||||
local cached_cost = (cached_tokens / 1000000) * (pricing.cached_input or pricing.input or 0)
|
||||
local output_cost = (output_tokens / 1000000) * (pricing.output or 0)
|
||||
|
||||
return input_cost + cached_cost + output_cost
|
||||
end
|
||||
|
||||
--- Calculate estimated savings (what would have been paid if using comparison model)
|
||||
---@param input_tokens number Input tokens
|
||||
---@param output_tokens number Output tokens
|
||||
---@param cached_tokens? number Cached input tokens
|
||||
---@return number Estimated savings in USD
|
||||
function M.calculate_savings(input_tokens, output_tokens, cached_tokens)
|
||||
-- Calculate what it would have cost with the comparison model
|
||||
return M.calculate_cost(M.comparison_model, input_tokens, output_tokens, cached_tokens)
|
||||
end
|
||||
|
||||
--- Record token usage
|
||||
---@param model string Model name
|
||||
---@param input_tokens number Input tokens
|
||||
---@param output_tokens number Output tokens
|
||||
---@param cached_tokens? number Cached input tokens
|
||||
function M.record_usage(model, input_tokens, output_tokens, cached_tokens)
|
||||
cached_tokens = cached_tokens or 0
|
||||
local cost = M.calculate_cost(model, input_tokens, output_tokens, cached_tokens)
|
||||
|
||||
-- Calculate savings if using a free model
|
||||
local savings = 0
|
||||
if M.is_free_model(model) then
|
||||
savings = M.calculate_savings(input_tokens, output_tokens, cached_tokens)
|
||||
end
|
||||
|
||||
table.insert(state.usage, {
|
||||
model = model,
|
||||
input_tokens = input_tokens,
|
||||
output_tokens = output_tokens,
|
||||
cached_tokens = cached_tokens,
|
||||
timestamp = os.time(),
|
||||
cost = cost,
|
||||
savings = savings,
|
||||
is_free = M.is_free_model(model),
|
||||
})
|
||||
|
||||
-- Save to disk (debounced)
|
||||
save_to_disk()
|
||||
|
||||
-- Update window if open
|
||||
if state.win and vim.api.nvim_win_is_valid(state.win) then
|
||||
M.refresh_window()
|
||||
end
|
||||
end
|
||||
|
||||
--- Aggregate usage data into stats
|
||||
---@param usage_list CostUsage[] List of usage records
|
||||
---@return table Stats
|
||||
local function aggregate_usage(usage_list)
|
||||
local stats = {
|
||||
total_input = 0,
|
||||
total_output = 0,
|
||||
total_cached = 0,
|
||||
total_cost = 0,
|
||||
total_savings = 0,
|
||||
free_requests = 0,
|
||||
paid_requests = 0,
|
||||
by_model = {},
|
||||
request_count = #usage_list,
|
||||
}
|
||||
|
||||
for _, usage in ipairs(usage_list) do
|
||||
stats.total_input = stats.total_input + (usage.input_tokens or 0)
|
||||
stats.total_output = stats.total_output + (usage.output_tokens or 0)
|
||||
stats.total_cached = stats.total_cached + (usage.cached_tokens or 0)
|
||||
stats.total_cost = stats.total_cost + (usage.cost or 0)
|
||||
|
||||
-- Track savings
|
||||
local usage_savings = usage.savings or 0
|
||||
-- For historical data without savings field, calculate it
|
||||
if usage_savings == 0 and usage.is_free == nil then
|
||||
local model = usage.model or "unknown"
|
||||
if M.is_free_model(model) then
|
||||
usage_savings = M.calculate_savings(
|
||||
usage.input_tokens or 0,
|
||||
usage.output_tokens or 0,
|
||||
usage.cached_tokens or 0
|
||||
)
|
||||
end
|
||||
end
|
||||
stats.total_savings = stats.total_savings + usage_savings
|
||||
|
||||
-- Track free vs paid
|
||||
local is_free = usage.is_free
|
||||
if is_free == nil then
|
||||
is_free = M.is_free_model(usage.model or "unknown")
|
||||
end
|
||||
if is_free then
|
||||
stats.free_requests = stats.free_requests + 1
|
||||
else
|
||||
stats.paid_requests = stats.paid_requests + 1
|
||||
end
|
||||
|
||||
local model = usage.model or "unknown"
|
||||
if not stats.by_model[model] then
|
||||
stats.by_model[model] = {
|
||||
input_tokens = 0,
|
||||
output_tokens = 0,
|
||||
cached_tokens = 0,
|
||||
cost = 0,
|
||||
savings = 0,
|
||||
requests = 0,
|
||||
is_free = is_free,
|
||||
}
|
||||
end
|
||||
|
||||
stats.by_model[model].input_tokens = stats.by_model[model].input_tokens + (usage.input_tokens or 0)
|
||||
stats.by_model[model].output_tokens = stats.by_model[model].output_tokens + (usage.output_tokens or 0)
|
||||
stats.by_model[model].cached_tokens = stats.by_model[model].cached_tokens + (usage.cached_tokens or 0)
|
||||
stats.by_model[model].cost = stats.by_model[model].cost + (usage.cost or 0)
|
||||
stats.by_model[model].savings = stats.by_model[model].savings + usage_savings
|
||||
stats.by_model[model].requests = stats.by_model[model].requests + 1
|
||||
end
|
||||
|
||||
return stats
|
||||
end
|
||||
|
||||
--- Get session statistics
|
||||
---@return table Statistics
|
||||
function M.get_stats()
|
||||
local stats = aggregate_usage(state.usage)
|
||||
stats.session_duration = os.time() - state.session_start
|
||||
return stats
|
||||
end
|
||||
|
||||
--- Get all-time statistics (session + historical)
|
||||
---@return table Statistics
|
||||
function M.get_all_time_stats()
|
||||
-- Load history if not loaded
|
||||
M.load_from_history()
|
||||
|
||||
-- Combine session and historical usage
|
||||
local all_usage = vim.deepcopy(state.all_usage)
|
||||
for _, usage in ipairs(state.usage) do
|
||||
table.insert(all_usage, usage)
|
||||
end
|
||||
|
||||
local stats = aggregate_usage(all_usage)
|
||||
|
||||
-- Calculate time span
|
||||
if #all_usage > 0 then
|
||||
local oldest = all_usage[1].timestamp or os.time()
|
||||
for _, usage in ipairs(all_usage) do
|
||||
if usage.timestamp and usage.timestamp < oldest then
|
||||
oldest = usage.timestamp
|
||||
end
|
||||
end
|
||||
stats.time_span = os.time() - oldest
|
||||
else
|
||||
stats.time_span = 0
|
||||
end
|
||||
|
||||
return stats
|
||||
end
|
||||
|
||||
--- Format cost as string
|
||||
---@param cost number Cost in USD
|
||||
---@return string Formatted cost
|
||||
local function format_cost(cost)
|
||||
if cost < 0.01 then
|
||||
return string.format("$%.4f", cost)
|
||||
elseif cost < 1 then
|
||||
return string.format("$%.3f", cost)
|
||||
else
|
||||
return string.format("$%.2f", cost)
|
||||
end
|
||||
end
|
||||
|
||||
--- Format token count
|
||||
---@param tokens number Token count
|
||||
---@return string Formatted count
|
||||
local function format_tokens(tokens)
|
||||
if tokens >= 1000000 then
|
||||
return string.format("%.2fM", tokens / 1000000)
|
||||
elseif tokens >= 1000 then
|
||||
return string.format("%.1fK", tokens / 1000)
|
||||
else
|
||||
return tostring(tokens)
|
||||
end
|
||||
end
|
||||
|
||||
--- Format duration
|
||||
---@param seconds number Duration in seconds
|
||||
---@return string Formatted duration
|
||||
local function format_duration(seconds)
|
||||
if seconds < 60 then
|
||||
return string.format("%ds", seconds)
|
||||
elseif seconds < 3600 then
|
||||
return string.format("%dm %ds", math.floor(seconds / 60), seconds % 60)
|
||||
else
|
||||
local hours = math.floor(seconds / 3600)
|
||||
local mins = math.floor((seconds % 3600) / 60)
|
||||
return string.format("%dh %dm", hours, mins)
|
||||
end
|
||||
end
|
||||
|
||||
--- Generate model breakdown section
|
||||
---@param stats table Stats with by_model
|
||||
---@return string[] Lines
|
||||
local function generate_model_breakdown(stats)
|
||||
local lines = {}
|
||||
|
||||
if next(stats.by_model) then
|
||||
-- Sort models by cost (descending)
|
||||
local models = {}
|
||||
for model, data in pairs(stats.by_model) do
|
||||
table.insert(models, { name = model, data = data })
|
||||
end
|
||||
table.sort(models, function(a, b)
|
||||
return a.data.cost > b.data.cost
|
||||
end)
|
||||
|
||||
for _, item in ipairs(models) do
|
||||
local model = item.name
|
||||
local data = item.data
|
||||
local pricing = M.pricing[normalize_model(model)]
|
||||
local is_free = data.is_free or M.is_free_model(model)
|
||||
|
||||
table.insert(lines, "")
|
||||
local model_icon = is_free and "🆓" or "💳"
|
||||
table.insert(lines, string.format(" %s %s", model_icon, model))
|
||||
table.insert(lines, string.format(" Requests: %d", data.requests))
|
||||
table.insert(lines, string.format(" Input: %s tokens", format_tokens(data.input_tokens)))
|
||||
table.insert(lines, string.format(" Output: %s tokens", format_tokens(data.output_tokens)))
|
||||
|
||||
if is_free then
|
||||
-- Show savings for free models
|
||||
if data.savings and data.savings > 0 then
|
||||
table.insert(lines, string.format(" Saved: %s", format_cost(data.savings)))
|
||||
end
|
||||
else
|
||||
table.insert(lines, string.format(" Cost: %s", format_cost(data.cost)))
|
||||
end
|
||||
|
||||
-- Show pricing info for paid models
|
||||
if pricing and not is_free then
|
||||
local price_info = string.format(
|
||||
" Rate: $%.2f/1M in, $%.2f/1M out",
|
||||
pricing.input or 0,
|
||||
pricing.output or 0
|
||||
)
|
||||
table.insert(lines, price_info)
|
||||
end
|
||||
end
|
||||
else
|
||||
table.insert(lines, " No usage recorded.")
|
||||
end
|
||||
|
||||
return lines
|
||||
end
|
||||
|
||||
--- Generate window content
|
||||
---@return string[] Lines for the buffer
|
||||
local function generate_content()
|
||||
local session_stats = M.get_stats()
|
||||
local all_time_stats = M.get_all_time_stats()
|
||||
local lines = {}
|
||||
|
||||
-- Header
|
||||
table.insert(lines, "╔══════════════════════════════════════════════════════╗")
|
||||
table.insert(lines, "║ 💰 LLM Cost Estimation ║")
|
||||
table.insert(lines, "╠══════════════════════════════════════════════════════╣")
|
||||
table.insert(lines, "")
|
||||
|
||||
-- All-time summary (prominent)
|
||||
table.insert(lines, "🌐 All-Time Summary (Project)")
|
||||
table.insert(lines, "───────────────────────────────────────────────────────")
|
||||
if all_time_stats.time_span > 0 then
|
||||
table.insert(lines, string.format(" Time span: %s", format_duration(all_time_stats.time_span)))
|
||||
end
|
||||
table.insert(lines, string.format(" Requests: %d total", all_time_stats.request_count))
|
||||
table.insert(lines, string.format(" Local/Free: %d requests", all_time_stats.free_requests or 0))
|
||||
table.insert(lines, string.format(" Paid API: %d requests", all_time_stats.paid_requests or 0))
|
||||
table.insert(lines, string.format(" Input tokens: %s", format_tokens(all_time_stats.total_input)))
|
||||
table.insert(lines, string.format(" Output tokens: %s", format_tokens(all_time_stats.total_output)))
|
||||
if all_time_stats.total_cached > 0 then
|
||||
table.insert(lines, string.format(" Cached tokens: %s", format_tokens(all_time_stats.total_cached)))
|
||||
end
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, string.format(" 💵 Total Cost: %s", format_cost(all_time_stats.total_cost)))
|
||||
|
||||
-- Show savings prominently if there are any
|
||||
if all_time_stats.total_savings and all_time_stats.total_savings > 0 then
|
||||
table.insert(lines, string.format(" 💚 Saved: %s (vs %s)", format_cost(all_time_stats.total_savings), M.comparison_model))
|
||||
end
|
||||
table.insert(lines, "")
|
||||
|
||||
-- Session summary
|
||||
table.insert(lines, "📊 Current Session")
|
||||
table.insert(lines, "───────────────────────────────────────────────────────")
|
||||
table.insert(lines, string.format(" Duration: %s", format_duration(session_stats.session_duration)))
|
||||
table.insert(lines, string.format(" Requests: %d (%d free, %d paid)",
|
||||
session_stats.request_count,
|
||||
session_stats.free_requests or 0,
|
||||
session_stats.paid_requests or 0))
|
||||
table.insert(lines, string.format(" Input tokens: %s", format_tokens(session_stats.total_input)))
|
||||
table.insert(lines, string.format(" Output tokens: %s", format_tokens(session_stats.total_output)))
|
||||
if session_stats.total_cached > 0 then
|
||||
table.insert(lines, string.format(" Cached tokens: %s", format_tokens(session_stats.total_cached)))
|
||||
end
|
||||
table.insert(lines, string.format(" Session Cost: %s", format_cost(session_stats.total_cost)))
|
||||
if session_stats.total_savings and session_stats.total_savings > 0 then
|
||||
table.insert(lines, string.format(" Session Saved: %s", format_cost(session_stats.total_savings)))
|
||||
end
|
||||
table.insert(lines, "")
|
||||
|
||||
-- Per-model breakdown (all-time)
|
||||
table.insert(lines, "📈 Cost by Model (All-Time)")
|
||||
table.insert(lines, "───────────────────────────────────────────────────────")
|
||||
local model_lines = generate_model_breakdown(all_time_stats)
|
||||
for _, line in ipairs(model_lines) do
|
||||
table.insert(lines, line)
|
||||
end
|
||||
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, "───────────────────────────────────────────────────────")
|
||||
table.insert(lines, " 'q' close | 'r' refresh | 'c' clear session | 'C' all")
|
||||
table.insert(lines, "╚══════════════════════════════════════════════════════╝")
|
||||
|
||||
return lines
|
||||
end
|
||||
|
||||
--- Refresh the cost window content
|
||||
function M.refresh_window()
|
||||
if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then
|
||||
return
|
||||
end
|
||||
|
||||
local lines = generate_content()
|
||||
|
||||
vim.bo[state.buf].modifiable = true
|
||||
vim.api.nvim_buf_set_lines(state.buf, 0, -1, false, lines)
|
||||
vim.bo[state.buf].modifiable = false
|
||||
end
|
||||
|
||||
--- Open the cost estimation window
|
||||
function M.open()
|
||||
-- Load historical data if not loaded
|
||||
M.load_from_history()
|
||||
|
||||
-- Close existing window if open
|
||||
if state.win and vim.api.nvim_win_is_valid(state.win) then
|
||||
vim.api.nvim_win_close(state.win, true)
|
||||
end
|
||||
|
||||
-- Create buffer
|
||||
state.buf = vim.api.nvim_create_buf(false, true)
|
||||
vim.bo[state.buf].buftype = "nofile"
|
||||
vim.bo[state.buf].bufhidden = "wipe"
|
||||
vim.bo[state.buf].swapfile = false
|
||||
vim.bo[state.buf].filetype = "codetyper-cost"
|
||||
|
||||
-- Calculate window size
|
||||
local width = 58
|
||||
local height = 40
|
||||
local row = math.floor((vim.o.lines - height) / 2)
|
||||
local col = math.floor((vim.o.columns - width) / 2)
|
||||
|
||||
-- Create floating window
|
||||
state.win = vim.api.nvim_open_win(state.buf, true, {
|
||||
relative = "editor",
|
||||
width = width,
|
||||
height = height,
|
||||
row = row,
|
||||
col = col,
|
||||
style = "minimal",
|
||||
border = "rounded",
|
||||
title = " Cost Estimation ",
|
||||
title_pos = "center",
|
||||
})
|
||||
|
||||
-- Set window options
|
||||
vim.wo[state.win].wrap = false
|
||||
vim.wo[state.win].cursorline = false
|
||||
|
||||
-- Populate content
|
||||
M.refresh_window()
|
||||
|
||||
-- Set up keymaps
|
||||
local opts = { buffer = state.buf, silent = true }
|
||||
vim.keymap.set("n", "q", function()
|
||||
M.close()
|
||||
end, opts)
|
||||
vim.keymap.set("n", "<Esc>", function()
|
||||
M.close()
|
||||
end, opts)
|
||||
vim.keymap.set("n", "r", function()
|
||||
M.refresh_window()
|
||||
end, opts)
|
||||
vim.keymap.set("n", "c", function()
|
||||
M.clear_session()
|
||||
M.refresh_window()
|
||||
end, opts)
|
||||
vim.keymap.set("n", "C", function()
|
||||
M.clear_all()
|
||||
M.refresh_window()
|
||||
end, opts)
|
||||
|
||||
-- Set up highlights
|
||||
vim.api.nvim_buf_call(state.buf, function()
|
||||
vim.fn.matchadd("Title", "LLM Cost Estimation")
|
||||
vim.fn.matchadd("Number", "\\$[0-9.]*")
|
||||
vim.fn.matchadd("Keyword", "[0-9.]*[KM]\\? tokens")
|
||||
vim.fn.matchadd("Special", "🤖\\|💰\\|📊\\|📈\\|💵")
|
||||
end)
|
||||
end
|
||||
|
||||
--- Close the cost window
|
||||
function M.close()
|
||||
if state.win and vim.api.nvim_win_is_valid(state.win) then
|
||||
vim.api.nvim_win_close(state.win, true)
|
||||
end
|
||||
state.win = nil
|
||||
state.buf = nil
|
||||
end
|
||||
|
||||
--- Toggle the cost window
|
||||
function M.toggle()
|
||||
if state.win and vim.api.nvim_win_is_valid(state.win) then
|
||||
M.close()
|
||||
else
|
||||
M.open()
|
||||
end
|
||||
end
|
||||
|
||||
--- Clear session usage (not history)
|
||||
function M.clear_session()
|
||||
state.usage = {}
|
||||
state.session_start = os.time()
|
||||
utils.notify("Session cost tracking cleared", vim.log.levels.INFO)
|
||||
end
|
||||
|
||||
--- Clear all history (session + saved)
|
||||
function M.clear_all()
|
||||
state.usage = {}
|
||||
state.all_usage = {}
|
||||
state.session_start = os.time()
|
||||
state.loaded = false
|
||||
|
||||
-- Delete history file
|
||||
local history_path = get_history_path()
|
||||
local ok, err = os.remove(history_path)
|
||||
if not ok and err and not err:match("No such file") then
|
||||
utils.notify("Failed to delete history: " .. err, vim.log.levels.WARN)
|
||||
end
|
||||
|
||||
utils.notify("All cost history cleared", vim.log.levels.INFO)
|
||||
end
|
||||
|
||||
--- Clear usage history (alias for clear_session)
|
||||
function M.clear()
|
||||
M.clear_session()
|
||||
end
|
||||
|
||||
--- Reset session
|
||||
function M.reset()
|
||||
M.clear_session()
|
||||
end
|
||||
|
||||
return M
|
||||
602
lua/codetyper/credentials.lua
Normal file
602
lua/codetyper/credentials.lua
Normal file
@@ -0,0 +1,602 @@
|
||||
---@mod codetyper.credentials Secure credential storage for Codetyper.nvim
|
||||
---@brief [[
|
||||
--- Manages API keys and model preferences stored outside of config files.
|
||||
--- Credentials are stored in ~/.local/share/nvim/codetyper/configuration.json
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
|
||||
--- Get the credentials file path
|
||||
---@return string Path to credentials file
|
||||
local function get_credentials_path()
|
||||
local data_dir = vim.fn.stdpath("data")
|
||||
return data_dir .. "/codetyper/configuration.json"
|
||||
end
|
||||
|
||||
--- Ensure the credentials directory exists
|
||||
---@return boolean Success
|
||||
local function ensure_dir()
|
||||
local data_dir = vim.fn.stdpath("data")
|
||||
local codetyper_dir = data_dir .. "/codetyper"
|
||||
return utils.ensure_dir(codetyper_dir)
|
||||
end
|
||||
|
||||
--- Load credentials from file
|
||||
---@return table Credentials data
|
||||
function M.load()
|
||||
local path = get_credentials_path()
|
||||
local content = utils.read_file(path)
|
||||
|
||||
if not content or content == "" then
|
||||
return {
|
||||
version = 1,
|
||||
providers = {},
|
||||
}
|
||||
end
|
||||
|
||||
local ok, data = pcall(vim.json.decode, content)
|
||||
if not ok or not data then
|
||||
return {
|
||||
version = 1,
|
||||
providers = {},
|
||||
}
|
||||
end
|
||||
|
||||
return data
|
||||
end
|
||||
|
||||
--- Save credentials to file
|
||||
---@param data table Credentials data
|
||||
---@return boolean Success
|
||||
function M.save(data)
|
||||
if not ensure_dir() then
|
||||
return false
|
||||
end
|
||||
|
||||
local path = get_credentials_path()
|
||||
local ok, json = pcall(vim.json.encode, data)
|
||||
if not ok then
|
||||
return false
|
||||
end
|
||||
|
||||
return utils.write_file(path, json)
|
||||
end
|
||||
|
||||
--- Get API key for a provider
|
||||
---@param provider string Provider name (claude, openai, gemini, copilot, ollama)
|
||||
---@return string|nil API key or nil if not found
|
||||
function M.get_api_key(provider)
|
||||
local data = M.load()
|
||||
local provider_data = data.providers and data.providers[provider]
|
||||
|
||||
if provider_data and provider_data.api_key then
|
||||
return provider_data.api_key
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Get model for a provider
|
||||
---@param provider string Provider name
|
||||
---@return string|nil Model name or nil if not found
|
||||
function M.get_model(provider)
|
||||
local data = M.load()
|
||||
local provider_data = data.providers and data.providers[provider]
|
||||
|
||||
if provider_data and provider_data.model then
|
||||
return provider_data.model
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Get endpoint for a provider (for custom OpenAI-compatible endpoints)
|
||||
---@param provider string Provider name
|
||||
---@return string|nil Endpoint URL or nil if not found
|
||||
function M.get_endpoint(provider)
|
||||
local data = M.load()
|
||||
local provider_data = data.providers and data.providers[provider]
|
||||
|
||||
if provider_data and provider_data.endpoint then
|
||||
return provider_data.endpoint
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Get host for Ollama
|
||||
---@return string|nil Host URL or nil if not found
|
||||
function M.get_ollama_host()
|
||||
local data = M.load()
|
||||
local provider_data = data.providers and data.providers.ollama
|
||||
|
||||
if provider_data and provider_data.host then
|
||||
return provider_data.host
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Set credentials for a provider
|
||||
---@param provider string Provider name
|
||||
---@param credentials table Credentials (api_key, model, endpoint, host)
|
||||
---@return boolean Success
|
||||
function M.set_credentials(provider, credentials)
|
||||
local data = M.load()
|
||||
|
||||
if not data.providers then
|
||||
data.providers = {}
|
||||
end
|
||||
|
||||
if not data.providers[provider] then
|
||||
data.providers[provider] = {}
|
||||
end
|
||||
|
||||
-- Merge credentials
|
||||
for key, value in pairs(credentials) do
|
||||
if value and value ~= "" then
|
||||
data.providers[provider][key] = value
|
||||
end
|
||||
end
|
||||
|
||||
data.updated = os.time()
|
||||
|
||||
return M.save(data)
|
||||
end
|
||||
|
||||
--- Remove credentials for a provider
|
||||
---@param provider string Provider name
|
||||
---@return boolean Success
|
||||
function M.remove_credentials(provider)
|
||||
local data = M.load()
|
||||
|
||||
if data.providers and data.providers[provider] then
|
||||
data.providers[provider] = nil
|
||||
data.updated = os.time()
|
||||
return M.save(data)
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- List all configured providers (checks both stored credentials AND config)
|
||||
---@return table List of provider names with their config status
|
||||
function M.list_providers()
|
||||
local data = M.load()
|
||||
local result = {}
|
||||
|
||||
local all_providers = { "claude", "openai", "gemini", "copilot", "ollama" }
|
||||
|
||||
for _, provider in ipairs(all_providers) do
|
||||
local provider_data = data.providers and data.providers[provider]
|
||||
local has_stored_key = provider_data and provider_data.api_key and provider_data.api_key ~= ""
|
||||
local has_model = provider_data and provider_data.model and provider_data.model ~= ""
|
||||
|
||||
-- Check if configured from config or environment
|
||||
local configured_from_config = false
|
||||
local config_model = nil
|
||||
local ok, codetyper = pcall(require, "codetyper")
|
||||
if ok then
|
||||
local config = codetyper.get_config()
|
||||
if config and config.llm and config.llm[provider] then
|
||||
local pc = config.llm[provider]
|
||||
config_model = pc.model
|
||||
|
||||
if provider == "claude" then
|
||||
configured_from_config = pc.api_key ~= nil or vim.env.ANTHROPIC_API_KEY ~= nil
|
||||
elseif provider == "openai" then
|
||||
configured_from_config = pc.api_key ~= nil or vim.env.OPENAI_API_KEY ~= nil
|
||||
elseif provider == "gemini" then
|
||||
configured_from_config = pc.api_key ~= nil or vim.env.GEMINI_API_KEY ~= nil
|
||||
elseif provider == "copilot" then
|
||||
configured_from_config = true -- Just needs copilot.lua
|
||||
elseif provider == "ollama" then
|
||||
configured_from_config = pc.host ~= nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local is_configured = has_stored_key
|
||||
or (provider == "ollama" and provider_data ~= nil)
|
||||
or (provider == "copilot" and (provider_data ~= nil or configured_from_config))
|
||||
or configured_from_config
|
||||
|
||||
table.insert(result, {
|
||||
name = provider,
|
||||
configured = is_configured,
|
||||
has_api_key = has_stored_key,
|
||||
has_model = has_model or config_model ~= nil,
|
||||
model = (provider_data and provider_data.model) or config_model,
|
||||
source = has_stored_key and "stored" or (configured_from_config and "config" or nil),
|
||||
})
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Default models for each provider
|
||||
M.default_models = {
|
||||
claude = "claude-sonnet-4-20250514",
|
||||
openai = "gpt-4o",
|
||||
gemini = "gemini-2.0-flash",
|
||||
copilot = "gpt-4o",
|
||||
ollama = "deepseek-coder:6.7b",
|
||||
}
|
||||
|
||||
--- Interactive command to add/update API key
|
||||
function M.interactive_add()
|
||||
local providers = { "claude", "openai", "gemini", "copilot", "ollama" }
|
||||
|
||||
-- Step 1: Select provider
|
||||
vim.ui.select(providers, {
|
||||
prompt = "Select LLM provider:",
|
||||
format_item = function(item)
|
||||
local display = item:sub(1, 1):upper() .. item:sub(2)
|
||||
local creds = M.load()
|
||||
local configured = creds.providers and creds.providers[item]
|
||||
if configured and (configured.api_key or item == "ollama") then
|
||||
return display .. " [configured]"
|
||||
end
|
||||
return display
|
||||
end,
|
||||
}, function(provider)
|
||||
if not provider then
|
||||
return
|
||||
end
|
||||
|
||||
-- Step 2: Get API key (skip for Ollama)
|
||||
if provider == "ollama" then
|
||||
M.interactive_ollama_config()
|
||||
else
|
||||
M.interactive_api_key(provider)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Interactive API key input
|
||||
---@param provider string Provider name
|
||||
function M.interactive_api_key(provider)
|
||||
-- Copilot uses OAuth from copilot.lua, no API key needed
|
||||
if provider == "copilot" then
|
||||
M.interactive_copilot_config()
|
||||
return
|
||||
end
|
||||
|
||||
local prompt = string.format("Enter %s API key (leave empty to skip): ", provider:upper())
|
||||
|
||||
vim.ui.input({ prompt = prompt }, function(api_key)
|
||||
if api_key == nil then
|
||||
return -- Cancelled
|
||||
end
|
||||
|
||||
-- Step 3: Get model
|
||||
M.interactive_model(provider, api_key)
|
||||
end)
|
||||
end
|
||||
|
||||
--- Interactive Copilot configuration (no API key, uses OAuth)
|
||||
function M.interactive_copilot_config()
|
||||
utils.notify("Copilot uses OAuth from copilot.lua/copilot.vim - no API key needed", vim.log.levels.INFO)
|
||||
|
||||
-- Just ask for model
|
||||
local default_model = M.default_models.copilot
|
||||
vim.ui.input({
|
||||
prompt = string.format("Copilot model (default: %s): ", default_model),
|
||||
default = default_model,
|
||||
}, function(model)
|
||||
if model == nil then
|
||||
return -- Cancelled
|
||||
end
|
||||
|
||||
if model == "" then
|
||||
model = default_model
|
||||
end
|
||||
|
||||
M.save_and_notify("copilot", {
|
||||
model = model,
|
||||
-- Mark as configured even without API key
|
||||
configured = true,
|
||||
})
|
||||
end)
|
||||
end
|
||||
|
||||
--- Interactive model selection
|
||||
---@param provider string Provider name
|
||||
---@param api_key string|nil API key
|
||||
function M.interactive_model(provider, api_key)
|
||||
local default_model = M.default_models[provider] or ""
|
||||
local prompt = string.format("Enter model (default: %s): ", default_model)
|
||||
|
||||
vim.ui.input({ prompt = prompt, default = default_model }, function(model)
|
||||
if model == nil then
|
||||
return -- Cancelled
|
||||
end
|
||||
|
||||
-- Use default if empty
|
||||
if model == "" then
|
||||
model = default_model
|
||||
end
|
||||
|
||||
-- Save credentials
|
||||
local credentials = {
|
||||
model = model,
|
||||
}
|
||||
|
||||
if api_key and api_key ~= "" then
|
||||
credentials.api_key = api_key
|
||||
end
|
||||
|
||||
-- For OpenAI, also ask for custom endpoint
|
||||
if provider == "openai" then
|
||||
M.interactive_endpoint(provider, credentials)
|
||||
else
|
||||
M.save_and_notify(provider, credentials)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Interactive endpoint input for OpenAI-compatible providers
|
||||
---@param provider string Provider name
|
||||
---@param credentials table Current credentials
|
||||
function M.interactive_endpoint(provider, credentials)
|
||||
vim.ui.input({
|
||||
prompt = "Custom endpoint (leave empty for default OpenAI): ",
|
||||
}, function(endpoint)
|
||||
if endpoint == nil then
|
||||
return -- Cancelled
|
||||
end
|
||||
|
||||
if endpoint ~= "" then
|
||||
credentials.endpoint = endpoint
|
||||
end
|
||||
|
||||
M.save_and_notify(provider, credentials)
|
||||
end)
|
||||
end
|
||||
|
||||
--- Interactive Ollama configuration
|
||||
function M.interactive_ollama_config()
|
||||
vim.ui.input({
|
||||
prompt = "Ollama host (default: http://localhost:11434): ",
|
||||
default = "http://localhost:11434",
|
||||
}, function(host)
|
||||
if host == nil then
|
||||
return -- Cancelled
|
||||
end
|
||||
|
||||
if host == "" then
|
||||
host = "http://localhost:11434"
|
||||
end
|
||||
|
||||
-- Get model
|
||||
local default_model = M.default_models.ollama
|
||||
vim.ui.input({
|
||||
prompt = string.format("Ollama model (default: %s): ", default_model),
|
||||
default = default_model,
|
||||
}, function(model)
|
||||
if model == nil then
|
||||
return -- Cancelled
|
||||
end
|
||||
|
||||
if model == "" then
|
||||
model = default_model
|
||||
end
|
||||
|
||||
M.save_and_notify("ollama", {
|
||||
host = host,
|
||||
model = model,
|
||||
})
|
||||
end)
|
||||
end)
|
||||
end
|
||||
|
||||
--- Save credentials and notify user
|
||||
---@param provider string Provider name
|
||||
---@param credentials table Credentials to save
|
||||
function M.save_and_notify(provider, credentials)
|
||||
if M.set_credentials(provider, credentials) then
|
||||
local msg = string.format("Saved %s configuration", provider:upper())
|
||||
if credentials.model then
|
||||
msg = msg .. " (model: " .. credentials.model .. ")"
|
||||
end
|
||||
utils.notify(msg, vim.log.levels.INFO)
|
||||
else
|
||||
utils.notify("Failed to save credentials", vim.log.levels.ERROR)
|
||||
end
|
||||
end
|
||||
|
||||
--- Show current credentials status
|
||||
function M.show_status()
|
||||
local providers = M.list_providers()
|
||||
|
||||
-- Get current active provider
|
||||
local codetyper = require("codetyper")
|
||||
local current = codetyper.get_config().llm.provider
|
||||
|
||||
local lines = {
|
||||
"Codetyper Credentials Status",
|
||||
"============================",
|
||||
"",
|
||||
"Storage: " .. get_credentials_path(),
|
||||
"Active: " .. current:upper(),
|
||||
"",
|
||||
}
|
||||
|
||||
for _, p in ipairs(providers) do
|
||||
local status_icon = p.configured and "✓" or "✗"
|
||||
local active_marker = p.name == current and " [ACTIVE]" or ""
|
||||
local source_info = ""
|
||||
if p.configured then
|
||||
source_info = p.source == "stored" and " (stored)" or " (config)"
|
||||
end
|
||||
local model_info = p.model and (" - " .. p.model) or ""
|
||||
|
||||
table.insert(lines, string.format(" %s %s%s%s%s",
|
||||
status_icon,
|
||||
p.name:upper(),
|
||||
active_marker,
|
||||
source_info,
|
||||
model_info))
|
||||
end
|
||||
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, "Commands:")
|
||||
table.insert(lines, " :CoderAddApiKey - Add/update credentials")
|
||||
table.insert(lines, " :CoderSwitchProvider - Switch active provider")
|
||||
table.insert(lines, " :CoderRemoveApiKey - Remove stored credentials")
|
||||
|
||||
utils.notify(table.concat(lines, "\n"))
|
||||
end
|
||||
|
||||
--- Interactive remove credentials
|
||||
function M.interactive_remove()
|
||||
local data = M.load()
|
||||
local configured = {}
|
||||
|
||||
for provider, _ in pairs(data.providers or {}) do
|
||||
table.insert(configured, provider)
|
||||
end
|
||||
|
||||
if #configured == 0 then
|
||||
utils.notify("No credentials configured", vim.log.levels.INFO)
|
||||
return
|
||||
end
|
||||
|
||||
vim.ui.select(configured, {
|
||||
prompt = "Select provider to remove:",
|
||||
}, function(provider)
|
||||
if not provider then
|
||||
return
|
||||
end
|
||||
|
||||
vim.ui.select({ "Yes", "No" }, {
|
||||
prompt = "Remove " .. provider:upper() .. " credentials?",
|
||||
}, function(choice)
|
||||
if choice == "Yes" then
|
||||
if M.remove_credentials(provider) then
|
||||
utils.notify("Removed " .. provider:upper() .. " credentials", vim.log.levels.INFO)
|
||||
else
|
||||
utils.notify("Failed to remove credentials", vim.log.levels.ERROR)
|
||||
end
|
||||
end
|
||||
end)
|
||||
end)
|
||||
end
|
||||
|
||||
--- Set the active provider
|
||||
---@param provider string Provider name
|
||||
function M.set_active_provider(provider)
|
||||
local data = M.load()
|
||||
data.active_provider = provider
|
||||
data.updated = os.time()
|
||||
M.save(data)
|
||||
|
||||
-- Also update the runtime config
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
config.llm.provider = provider
|
||||
|
||||
utils.notify("Active provider set to: " .. provider:upper(), vim.log.levels.INFO)
|
||||
end
|
||||
|
||||
--- Get the active provider from stored config
|
||||
---@return string|nil Active provider
|
||||
function M.get_active_provider()
|
||||
local data = M.load()
|
||||
return data.active_provider
|
||||
end
|
||||
|
||||
--- Check if a provider is configured (from stored credentials OR config)
|
||||
---@param provider string Provider name
|
||||
---@return boolean configured, string|nil source
|
||||
local function is_provider_configured(provider)
|
||||
-- Check stored credentials first
|
||||
local data = M.load()
|
||||
local stored = data.providers and data.providers[provider]
|
||||
if stored then
|
||||
if stored.configured or stored.api_key or provider == "ollama" or provider == "copilot" then
|
||||
return true, "stored"
|
||||
end
|
||||
end
|
||||
|
||||
-- Check codetyper config
|
||||
local ok, codetyper = pcall(require, "codetyper")
|
||||
if not ok then
|
||||
return false, nil
|
||||
end
|
||||
|
||||
local config = codetyper.get_config()
|
||||
if not config or not config.llm then
|
||||
return false, nil
|
||||
end
|
||||
|
||||
local provider_config = config.llm[provider]
|
||||
if not provider_config then
|
||||
return false, nil
|
||||
end
|
||||
|
||||
-- Check for API key in config or environment
|
||||
if provider == "claude" then
|
||||
if provider_config.api_key or vim.env.ANTHROPIC_API_KEY then
|
||||
return true, "config"
|
||||
end
|
||||
elseif provider == "openai" then
|
||||
if provider_config.api_key or vim.env.OPENAI_API_KEY then
|
||||
return true, "config"
|
||||
end
|
||||
elseif provider == "gemini" then
|
||||
if provider_config.api_key or vim.env.GEMINI_API_KEY then
|
||||
return true, "config"
|
||||
end
|
||||
elseif provider == "copilot" then
|
||||
-- Copilot just needs copilot.lua installed
|
||||
return true, "config"
|
||||
elseif provider == "ollama" then
|
||||
-- Ollama just needs host configured
|
||||
if provider_config.host then
|
||||
return true, "config"
|
||||
end
|
||||
end
|
||||
|
||||
return false, nil
|
||||
end
|
||||
|
||||
--- Interactive switch provider
|
||||
function M.interactive_switch_provider()
|
||||
local all_providers = { "claude", "openai", "gemini", "copilot", "ollama" }
|
||||
local available = {}
|
||||
local sources = {}
|
||||
|
||||
for _, provider in ipairs(all_providers) do
|
||||
local configured, source = is_provider_configured(provider)
|
||||
if configured then
|
||||
table.insert(available, provider)
|
||||
sources[provider] = source
|
||||
end
|
||||
end
|
||||
|
||||
if #available == 0 then
|
||||
utils.notify("No providers configured. Use :CoderAddApiKey or add to your config.", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local current = codetyper.get_config().llm.provider
|
||||
|
||||
vim.ui.select(available, {
|
||||
prompt = "Select provider (current: " .. current .. "):",
|
||||
format_item = function(item)
|
||||
local marker = item == current and " [active]" or ""
|
||||
local source_marker = sources[item] == "stored" and " (stored)" or " (config)"
|
||||
return item:upper() .. marker .. source_marker
|
||||
end,
|
||||
}, function(provider)
|
||||
if provider then
|
||||
M.set_active_provider(provider)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -83,6 +83,9 @@ local TS_QUERIES = {
|
||||
},
|
||||
}
|
||||
|
||||
-- Forward declaration for analyze_tree_generic (defined below)
|
||||
local analyze_tree_generic
|
||||
|
||||
--- Hash file content for change detection
|
||||
---@param content string
|
||||
---@return string
|
||||
@@ -256,7 +259,7 @@ end
|
||||
---@param root TSNode
|
||||
---@param bufnr number
|
||||
---@return table
|
||||
local function analyze_tree_generic(root, bufnr)
|
||||
analyze_tree_generic = function(root, bufnr)
|
||||
local result = {
|
||||
functions = {},
|
||||
classes = {},
|
||||
|
||||
@@ -14,6 +14,51 @@ local AUTH_URL = "https://api.github.com/copilot_internal/v2/token"
|
||||
---@field github_token table|nil
|
||||
M.state = nil
|
||||
|
||||
--- Track if we've already suggested Ollama fallback this session
|
||||
local ollama_fallback_suggested = false
|
||||
|
||||
--- Suggest switching to Ollama when rate limits are hit
|
||||
---@param error_msg string The error message that triggered this
|
||||
function M.suggest_ollama_fallback(error_msg)
|
||||
if ollama_fallback_suggested then
|
||||
return
|
||||
end
|
||||
|
||||
-- Check if Ollama is available
|
||||
local ollama_available = false
|
||||
vim.fn.jobstart({ "curl", "-s", "http://localhost:11434/api/tags" }, {
|
||||
on_exit = function(_, code)
|
||||
if code == 0 then
|
||||
ollama_available = true
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
if ollama_available then
|
||||
-- Switch to Ollama automatically
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
config.llm.provider = "ollama"
|
||||
|
||||
ollama_fallback_suggested = true
|
||||
utils.notify(
|
||||
"⚠️ Copilot rate limit reached. Switched to Ollama automatically.\n"
|
||||
.. "Original error: "
|
||||
.. error_msg:sub(1, 100),
|
||||
vim.log.levels.WARN
|
||||
)
|
||||
else
|
||||
utils.notify(
|
||||
"⚠️ Copilot rate limit reached. Ollama not available.\n"
|
||||
.. "Start Ollama with: ollama serve\n"
|
||||
.. "Or wait for Copilot limits to reset.",
|
||||
vim.log.levels.WARN
|
||||
)
|
||||
end
|
||||
end)
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- Get OAuth token from copilot.lua or copilot.vim config
|
||||
---@return string|nil OAuth token
|
||||
local function get_oauth_token()
|
||||
@@ -51,9 +96,16 @@ local function get_oauth_token()
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Get model from config
|
||||
--- Get model from stored credentials or config
|
||||
---@return string Model name
|
||||
local function get_model()
|
||||
-- Priority: stored credentials > config
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_model = credentials.get_model("copilot")
|
||||
if stored_model then
|
||||
return stored_model
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.copilot.model
|
||||
@@ -204,15 +256,37 @@ local function make_request(token, body, callback)
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
|
||||
if not ok then
|
||||
-- Show the actual response text as the error (truncated if too long)
|
||||
local error_msg = response_text
|
||||
if #error_msg > 200 then
|
||||
error_msg = error_msg:sub(1, 200) .. "..."
|
||||
end
|
||||
|
||||
-- Clean up common patterns
|
||||
if response_text:match("<!DOCTYPE") or response_text:match("<html") then
|
||||
error_msg = "Copilot API returned HTML error page. Service may be unavailable."
|
||||
end
|
||||
|
||||
-- Check for rate limit and suggest Ollama fallback
|
||||
if response_text:match("limit") or response_text:match("Upgrade") or response_text:match("quota") then
|
||||
M.suggest_ollama_fallback(error_msg)
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
callback(nil, "Failed to parse Copilot response", nil)
|
||||
callback(nil, error_msg, nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
local error_msg = response.error.message or "Copilot API error"
|
||||
if response.error.code == "rate_limit_exceeded" or (error_msg:match("limit") and error_msg:match("plan")) then
|
||||
error_msg = "Copilot rate limit: " .. error_msg
|
||||
M.suggest_ollama_fallback(error_msg)
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
callback(nil, response.error.message or "Copilot API error", nil)
|
||||
callback(nil, error_msg, nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
@@ -220,6 +294,17 @@ local function make_request(token, body, callback)
|
||||
-- Extract usage info
|
||||
local usage = response.usage or {}
|
||||
|
||||
-- Record usage for cost tracking
|
||||
if usage.prompt_tokens or usage.completion_tokens then
|
||||
local cost = require("codetyper.cost")
|
||||
cost.record_usage(
|
||||
get_model(),
|
||||
usage.prompt_tokens or 0,
|
||||
usage.completion_tokens or 0,
|
||||
usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens or 0
|
||||
)
|
||||
end
|
||||
|
||||
if response.choices and response.choices[1] and response.choices[1].message then
|
||||
local code = llm.extract_code(response.choices[1].message.content)
|
||||
vim.schedule(function()
|
||||
@@ -362,20 +447,46 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
-- Format messages for Copilot (OpenAI-compatible format)
|
||||
local copilot_messages = { { role = "system", content = system_prompt } }
|
||||
for _, msg in ipairs(messages) do
|
||||
if type(msg.content) == "string" then
|
||||
table.insert(copilot_messages, { role = msg.role, content = msg.content })
|
||||
elseif type(msg.content) == "table" then
|
||||
local text_parts = {}
|
||||
for _, part in ipairs(msg.content) do
|
||||
if part.type == "tool_result" then
|
||||
table.insert(text_parts, "[" .. (part.name or "tool") .. " result]: " .. (part.content or ""))
|
||||
elseif part.type == "text" then
|
||||
table.insert(text_parts, part.text or "")
|
||||
if msg.role == "user" then
|
||||
-- User messages - handle string or table content
|
||||
if type(msg.content) == "string" then
|
||||
table.insert(copilot_messages, { role = "user", content = msg.content })
|
||||
elseif type(msg.content) == "table" then
|
||||
-- Handle complex content (like tool results from user perspective)
|
||||
local text_parts = {}
|
||||
for _, part in ipairs(msg.content) do
|
||||
if part.type == "tool_result" then
|
||||
table.insert(text_parts, "[" .. (part.name or "tool") .. " result]: " .. (part.content or ""))
|
||||
elseif part.type == "text" then
|
||||
table.insert(text_parts, part.text or "")
|
||||
end
|
||||
end
|
||||
if #text_parts > 0 then
|
||||
table.insert(copilot_messages, { role = "user", content = table.concat(text_parts, "\n") })
|
||||
end
|
||||
end
|
||||
if #text_parts > 0 then
|
||||
table.insert(copilot_messages, { role = msg.role, content = table.concat(text_parts, "\n") })
|
||||
elseif msg.role == "assistant" then
|
||||
-- Assistant messages - must preserve tool_calls if present
|
||||
local assistant_msg = {
|
||||
role = "assistant",
|
||||
content = type(msg.content) == "string" and msg.content or nil,
|
||||
}
|
||||
-- Preserve tool_calls for the API
|
||||
if msg.tool_calls then
|
||||
assistant_msg.tool_calls = msg.tool_calls
|
||||
-- Ensure content is not nil when tool_calls present
|
||||
if assistant_msg.content == nil then
|
||||
assistant_msg.content = ""
|
||||
end
|
||||
end
|
||||
table.insert(copilot_messages, assistant_msg)
|
||||
elseif msg.role == "tool" then
|
||||
-- Tool result messages - must have tool_call_id
|
||||
table.insert(copilot_messages, {
|
||||
role = "tool",
|
||||
tool_call_id = msg.tool_call_id,
|
||||
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
@@ -396,6 +507,20 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
logs.thinking("Sending to Copilot API...")
|
||||
|
||||
-- Log request to debug file
|
||||
local debug_log_path = vim.fn.expand("~/.local/codetyper-debug.log")
|
||||
local debug_f = io.open(debug_log_path, "a")
|
||||
if debug_f then
|
||||
debug_f:write(os.date("[%Y-%m-%d %H:%M:%S] ") .. "COPILOT REQUEST\n")
|
||||
debug_f:write("Messages count: " .. #copilot_messages .. "\n")
|
||||
for i, m in ipairs(copilot_messages) do
|
||||
debug_f:write(string.format(" [%d] role=%s, has_tool_calls=%s, has_tool_call_id=%s\n",
|
||||
i, m.role, tostring(m.tool_calls ~= nil), tostring(m.tool_call_id ~= nil)))
|
||||
end
|
||||
debug_f:write("---\n")
|
||||
debug_f:close()
|
||||
end
|
||||
|
||||
local headers = build_headers(token)
|
||||
local cmd = {
|
||||
"curl",
|
||||
@@ -413,35 +538,97 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
table.insert(cmd, "-d")
|
||||
table.insert(cmd, json_body)
|
||||
|
||||
-- Debug logging helper
|
||||
local function debug_log(msg, data)
|
||||
local log_path = vim.fn.expand("~/.local/codetyper-debug.log")
|
||||
local f = io.open(log_path, "a")
|
||||
if f then
|
||||
f:write(os.date("[%Y-%m-%d %H:%M:%S] ") .. msg .. "\n")
|
||||
if data then
|
||||
f:write("DATA: " .. tostring(data):sub(1, 2000) .. "\n")
|
||||
end
|
||||
f:write("---\n")
|
||||
f:close()
|
||||
end
|
||||
end
|
||||
|
||||
-- Prevent double callback calls
|
||||
local callback_called = false
|
||||
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if callback_called then
|
||||
debug_log("on_stdout: callback already called, skipping")
|
||||
return
|
||||
end
|
||||
|
||||
if not data or #data == 0 or (data[1] == "" and #data == 1) then
|
||||
debug_log("on_stdout: empty data")
|
||||
return
|
||||
end
|
||||
|
||||
local response_text = table.concat(data, "\n")
|
||||
debug_log("on_stdout: received response", response_text)
|
||||
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
|
||||
if not ok then
|
||||
debug_log("JSON parse failed", response_text)
|
||||
callback_called = true
|
||||
|
||||
-- Show the actual response text as the error (truncated if too long)
|
||||
local error_msg = response_text
|
||||
if #error_msg > 200 then
|
||||
error_msg = error_msg:sub(1, 200) .. "..."
|
||||
end
|
||||
|
||||
-- Clean up common patterns
|
||||
if response_text:match("<!DOCTYPE") or response_text:match("<html") then
|
||||
error_msg = "Copilot API returned HTML error page. Service may be unavailable."
|
||||
end
|
||||
|
||||
-- Check for rate limit and suggest Ollama fallback
|
||||
if response_text:match("limit") or response_text:match("Upgrade") or response_text:match("quota") then
|
||||
M.suggest_ollama_fallback(error_msg)
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
logs.error("Failed to parse Copilot response")
|
||||
callback(nil, "Failed to parse Copilot response")
|
||||
logs.error(error_msg)
|
||||
callback(nil, error_msg)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
callback_called = true
|
||||
local error_msg = response.error.message or "Copilot API error"
|
||||
|
||||
-- Check for rate limit in structured error
|
||||
if response.error.code == "rate_limit_exceeded" or (error_msg:match("limit") and error_msg:match("plan")) then
|
||||
error_msg = "Copilot rate limit: " .. error_msg
|
||||
M.suggest_ollama_fallback(error_msg)
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
logs.error(response.error.message or "Copilot API error")
|
||||
callback(nil, response.error.message or "Copilot API error")
|
||||
logs.error(error_msg)
|
||||
callback(nil, error_msg)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Log token usage
|
||||
-- Log token usage and record cost
|
||||
if response.usage then
|
||||
logs.response(response.usage.prompt_tokens or 0, response.usage.completion_tokens or 0, "stop")
|
||||
|
||||
-- Record usage for cost tracking
|
||||
local cost_tracker = require("codetyper.cost")
|
||||
cost_tracker.record_usage(
|
||||
get_model(),
|
||||
response.usage.prompt_tokens or 0,
|
||||
response.usage.completion_tokens or 0,
|
||||
response.usage.prompt_tokens_details and response.usage.prompt_tokens_details.cached_tokens or 0
|
||||
)
|
||||
end
|
||||
|
||||
-- Convert to Claude-like format for parser compatibility
|
||||
@@ -474,12 +661,19 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
end
|
||||
end
|
||||
|
||||
callback_called = true
|
||||
debug_log("on_stdout: success, calling callback")
|
||||
vim.schedule(function()
|
||||
callback(converted, nil)
|
||||
end)
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if callback_called then
|
||||
return
|
||||
end
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
debug_log("on_stderr", table.concat(data, "\n"))
|
||||
callback_called = true
|
||||
vim.schedule(function()
|
||||
logs.error("Copilot API request failed: " .. table.concat(data, "\n"))
|
||||
callback(nil, "Copilot API request failed: " .. table.concat(data, "\n"))
|
||||
@@ -487,7 +681,12 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
debug_log("on_exit: code=" .. code .. ", callback_called=" .. tostring(callback_called))
|
||||
if callback_called then
|
||||
return
|
||||
end
|
||||
if code ~= 0 then
|
||||
callback_called = true
|
||||
vim.schedule(function()
|
||||
logs.error("Copilot API request failed with code: " .. code)
|
||||
callback(nil, "Copilot API request failed with code: " .. code)
|
||||
|
||||
@@ -8,17 +8,31 @@ local llm = require("codetyper.llm")
|
||||
--- Gemini API endpoint
|
||||
local API_URL = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
|
||||
--- Get API key from config or environment
|
||||
--- Get API key from stored credentials, config, or environment
|
||||
---@return string|nil API key
|
||||
local function get_api_key()
|
||||
-- Priority: stored credentials > config > environment
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_key = credentials.get_api_key("gemini")
|
||||
if stored_key then
|
||||
return stored_key
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.gemini.api_key or vim.env.GEMINI_API_KEY
|
||||
end
|
||||
|
||||
--- Get model from config
|
||||
--- Get model from stored credentials or config
|
||||
---@return string Model name
|
||||
local function get_model()
|
||||
-- Priority: stored credentials > config
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_model = credentials.get_model("gemini")
|
||||
if stored_model then
|
||||
return stored_model
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.gemini.model
|
||||
|
||||
@@ -32,6 +32,32 @@ function M.generate(prompt, context, callback)
|
||||
client.generate(prompt, context, callback)
|
||||
end
|
||||
|
||||
--- Smart generate with automatic provider selection based on brain memories
|
||||
--- Prefers Ollama when context is rich, falls back to Copilot otherwise.
|
||||
--- Implements verification pondering to reinforce Ollama accuracy over time.
|
||||
---@param prompt string The user's prompt
|
||||
---@param context table Context information
|
||||
---@param callback fun(response: string|nil, error: string|nil, metadata: table|nil) Callback
|
||||
function M.smart_generate(prompt, context, callback)
|
||||
local selector = require("codetyper.llm.selector")
|
||||
selector.smart_generate(prompt, context, callback)
|
||||
end
|
||||
|
||||
--- Get accuracy statistics for providers
|
||||
---@return table Statistics for each provider
|
||||
function M.get_accuracy_stats()
|
||||
local selector = require("codetyper.llm.selector")
|
||||
return selector.get_accuracy_stats()
|
||||
end
|
||||
|
||||
--- Report user feedback on response quality (for reinforcement learning)
|
||||
---@param provider string Which provider generated the response
|
||||
---@param was_correct boolean Whether the response was good
|
||||
function M.report_feedback(provider, was_correct)
|
||||
local selector = require("codetyper.llm.selector")
|
||||
selector.report_feedback(provider, was_correct)
|
||||
end
|
||||
|
||||
--- Build the system prompt for code generation
|
||||
---@param context table Context information
|
||||
---@return string System prompt
|
||||
|
||||
@@ -5,21 +5,33 @@ local M = {}
|
||||
local utils = require("codetyper.utils")
|
||||
local llm = require("codetyper.llm")
|
||||
|
||||
--- Get Ollama host from config
|
||||
--- Get Ollama host from stored credentials or config
|
||||
---@return string Host URL
|
||||
local function get_host()
|
||||
-- Priority: stored credentials > config
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_host = credentials.get_ollama_host()
|
||||
if stored_host then
|
||||
return stored_host
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
return config.llm.ollama.host
|
||||
end
|
||||
|
||||
--- Get model from config
|
||||
--- Get model from stored credentials or config
|
||||
---@return string Model name
|
||||
local function get_model()
|
||||
-- Priority: stored credentials > config
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_model = credentials.get_model("ollama")
|
||||
if stored_model then
|
||||
return stored_model
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
return config.llm.ollama.model
|
||||
end
|
||||
|
||||
@@ -199,47 +211,41 @@ function M.validate()
|
||||
return true
|
||||
end
|
||||
|
||||
--- Build system prompt for agent mode with tool instructions
|
||||
--- Generate with tool use support for agentic mode (text-based tool calling)
|
||||
---@param messages table[] Conversation history
|
||||
---@param context table Context information
|
||||
---@return string System prompt
|
||||
local function build_agent_system_prompt(context)
|
||||
---@param tool_definitions table Tool definitions
|
||||
---@param callback fun(response: table|nil, error: string|nil) Callback with Claude-like response format
|
||||
function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local agent_prompts = require("codetyper.prompts.agent")
|
||||
local tools_module = require("codetyper.agent.tools")
|
||||
|
||||
local system_prompt = agent_prompts.system .. "\n\n"
|
||||
system_prompt = system_prompt .. tools_module.to_prompt_format() .. "\n\n"
|
||||
system_prompt = system_prompt .. agent_prompts.tool_instructions
|
||||
logs.request("ollama", get_model())
|
||||
logs.thinking("Preparing agent request...")
|
||||
|
||||
-- Add context about current file if available
|
||||
if context.file_path then
|
||||
system_prompt = system_prompt .. "\n\nCurrent working context:\n"
|
||||
system_prompt = system_prompt .. "- File: " .. context.file_path .. "\n"
|
||||
if context.language then
|
||||
system_prompt = system_prompt .. "- Language: " .. context.language .. "\n"
|
||||
-- Build system prompt with tool instructions
|
||||
local system_prompt = llm.build_system_prompt(context)
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.system
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.tool_instructions
|
||||
|
||||
-- Add tool descriptions
|
||||
system_prompt = system_prompt .. "\n\n## Available Tools\n"
|
||||
system_prompt = system_prompt .. "Call tools by outputting JSON in this exact format:\n"
|
||||
system_prompt = system_prompt .. '```json\n{"tool": "tool_name", "arguments": {...}}\n```\n\n'
|
||||
|
||||
for _, tool in ipairs(tool_definitions) do
|
||||
local name = tool.name or (tool["function"] and tool["function"].name)
|
||||
local desc = tool.description or (tool["function"] and tool["function"].description)
|
||||
if name then
|
||||
system_prompt = system_prompt .. string.format("### %s\n%s\n\n", name, desc or "")
|
||||
end
|
||||
end
|
||||
|
||||
-- Add project root info
|
||||
local root = utils.get_project_root()
|
||||
if root then
|
||||
system_prompt = system_prompt .. "- Project root: " .. root .. "\n"
|
||||
end
|
||||
|
||||
return system_prompt
|
||||
end
|
||||
|
||||
--- Build request body for Ollama API with tools (chat format)
|
||||
---@param messages table[] Conversation messages
|
||||
---@param context table Context information
|
||||
---@return table Request body
|
||||
local function build_tools_request_body(messages, context)
|
||||
local system_prompt = build_agent_system_prompt(context)
|
||||
|
||||
-- Convert messages to Ollama chat format
|
||||
local ollama_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
local content = msg.content
|
||||
-- Handle complex content (like tool results)
|
||||
if type(content) == "table" then
|
||||
local text_parts = {}
|
||||
for _, part in ipairs(content) do
|
||||
@@ -251,14 +257,10 @@ local function build_tools_request_body(messages, context)
|
||||
end
|
||||
content = table.concat(text_parts, "\n")
|
||||
end
|
||||
|
||||
table.insert(ollama_messages, {
|
||||
role = msg.role,
|
||||
content = content,
|
||||
})
|
||||
table.insert(ollama_messages, { role = msg.role, content = content })
|
||||
end
|
||||
|
||||
return {
|
||||
local body = {
|
||||
model = get_model(),
|
||||
messages = ollama_messages,
|
||||
system = system_prompt,
|
||||
@@ -268,16 +270,15 @@ local function build_tools_request_body(messages, context)
|
||||
num_predict = 4096,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
--- Make HTTP request to Ollama chat API
|
||||
---@param body table Request body
|
||||
---@param callback fun(response: string|nil, error: string|nil, usage: table|nil) Callback function
|
||||
local function make_chat_request(body, callback)
|
||||
local host = get_host()
|
||||
local url = host .. "/api/chat"
|
||||
local json_body = vim.json.encode(body)
|
||||
|
||||
local prompt_estimate = logs.estimate_tokens(json_body)
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
logs.thinking("Sending to Ollama API...")
|
||||
|
||||
local cmd = {
|
||||
"curl",
|
||||
"-s",
|
||||
@@ -302,196 +303,82 @@ local function make_chat_request(body, callback)
|
||||
|
||||
if not ok then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Failed to parse Ollama response", nil)
|
||||
logs.error("Failed to parse Ollama response")
|
||||
callback(nil, "Failed to parse Ollama response")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
vim.schedule(function()
|
||||
callback(nil, response.error or "Ollama API error", nil)
|
||||
logs.error(response.error or "Ollama API error")
|
||||
callback(nil, response.error or "Ollama API error")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Extract usage info
|
||||
local usage = {
|
||||
prompt_tokens = response.prompt_eval_count or 0,
|
||||
response_tokens = response.eval_count or 0,
|
||||
}
|
||||
-- Log token usage and record cost (Ollama is free but we track usage)
|
||||
if response.prompt_eval_count or response.eval_count then
|
||||
logs.response(response.prompt_eval_count or 0, response.eval_count or 0, "stop")
|
||||
|
||||
-- Return the message content for agent parsing
|
||||
if response.message and response.message.content then
|
||||
vim.schedule(function()
|
||||
callback(response.message.content, nil, usage)
|
||||
end)
|
||||
else
|
||||
vim.schedule(function()
|
||||
callback(nil, "No response from Ollama", nil)
|
||||
end)
|
||||
-- Record usage for cost tracking (free for local models)
|
||||
local cost = require("codetyper.cost")
|
||||
cost.record_usage(
|
||||
get_model(),
|
||||
response.prompt_eval_count or 0,
|
||||
response.eval_count or 0,
|
||||
0 -- No cached tokens for Ollama
|
||||
)
|
||||
end
|
||||
|
||||
-- Parse the response text for tool calls
|
||||
local content_text = response.message and response.message.content or ""
|
||||
local converted = { content = {}, stop_reason = "end_turn" }
|
||||
|
||||
-- Try to extract JSON tool calls from response
|
||||
local json_match = content_text:match("```json%s*(%b{})%s*```")
|
||||
if json_match then
|
||||
local ok_json, parsed = pcall(vim.json.decode, json_match)
|
||||
if ok_json and parsed.tool then
|
||||
table.insert(converted.content, {
|
||||
type = "tool_use",
|
||||
id = "call_" .. string.format("%x", os.time()) .. "_" .. string.format("%x", math.random(0, 0xFFFF)),
|
||||
name = parsed.tool,
|
||||
input = parsed.arguments or {},
|
||||
})
|
||||
logs.thinking("Tool call: " .. parsed.tool)
|
||||
content_text = content_text:gsub("```json.-```", ""):gsub("^%s+", ""):gsub("%s+$", "")
|
||||
converted.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
|
||||
-- Add text content
|
||||
if content_text and content_text ~= "" then
|
||||
table.insert(converted.content, 1, { type = "text", text = content_text })
|
||||
logs.thinking("Response contains text")
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
callback(converted, nil)
|
||||
end)
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil)
|
||||
logs.error("Ollama API request failed: " .. table.concat(data, "\n"))
|
||||
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"))
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
-- Don't double-report errors
|
||||
vim.schedule(function()
|
||||
logs.error("Ollama API request failed with code: " .. code)
|
||||
callback(nil, "Ollama API request failed with code: " .. code)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- Generate response with tools using Ollama API
|
||||
---@param messages table[] Conversation history
|
||||
---@param context table Context information
|
||||
---@param tools table Tool definitions (embedded in prompt for Ollama)
|
||||
---@param callback fun(response: string|nil, error: string|nil) Callback function
|
||||
function M.generate_with_tools(messages, context, tools, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
|
||||
-- Log the request
|
||||
local model = get_model()
|
||||
logs.request("ollama", model)
|
||||
logs.thinking("Preparing API request...")
|
||||
|
||||
local body = build_tools_request_body(messages, context)
|
||||
|
||||
-- Estimate prompt tokens
|
||||
local prompt_estimate = logs.estimate_tokens(vim.json.encode(body))
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
|
||||
make_chat_request(body, function(response, err, usage)
|
||||
if err then
|
||||
logs.error(err)
|
||||
callback(nil, err)
|
||||
else
|
||||
-- Log token usage
|
||||
if usage then
|
||||
logs.response(usage.prompt_tokens or 0, usage.response_tokens or 0, "end_turn")
|
||||
end
|
||||
|
||||
-- Log if response contains tool calls
|
||||
if response then
|
||||
local parser = require("codetyper.agent.parser")
|
||||
local parsed = parser.parse_ollama_response(response)
|
||||
if #parsed.tool_calls > 0 then
|
||||
for _, tc in ipairs(parsed.tool_calls) do
|
||||
logs.thinking("Tool call: " .. tc.name)
|
||||
end
|
||||
end
|
||||
if parsed.text and parsed.text ~= "" then
|
||||
logs.thinking("Response contains text")
|
||||
end
|
||||
end
|
||||
|
||||
callback(response, nil)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Generate with tool use support for agentic mode (simulated via prompts)
|
||||
---@param messages table[] Conversation history
|
||||
---@param context table Context information
|
||||
---@param tool_definitions table Tool definitions
|
||||
---@param callback fun(response: string|nil, error: string|nil) Callback with response text
|
||||
function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
local tools_module = require("codetyper.agent.tools")
|
||||
local agent_prompts = require("codetyper.prompts.agent")
|
||||
|
||||
-- Build system prompt with agent instructions and tool definitions
|
||||
local system_prompt = llm.build_system_prompt(context)
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.system
|
||||
system_prompt = system_prompt .. "\n\n" .. tools_module.to_prompt_format()
|
||||
|
||||
-- Flatten messages to single prompt (Ollama's generate API)
|
||||
local prompt_parts = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if type(msg.content) == "string" then
|
||||
local role_prefix = msg.role == "user" and "User" or "Assistant"
|
||||
table.insert(prompt_parts, role_prefix .. ": " .. msg.content)
|
||||
elseif type(msg.content) == "table" then
|
||||
-- Handle tool results
|
||||
for _, item in ipairs(msg.content) do
|
||||
if item.type == "tool_result" then
|
||||
table.insert(prompt_parts, "Tool result: " .. item.content)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local body = {
|
||||
model = get_model(),
|
||||
system = system_prompt,
|
||||
prompt = table.concat(prompt_parts, "\n\n"),
|
||||
stream = false,
|
||||
options = {
|
||||
temperature = 0.2,
|
||||
num_predict = 4096,
|
||||
},
|
||||
}
|
||||
|
||||
local host = get_host()
|
||||
local url = host .. "/api/generate"
|
||||
local json_body = vim.json.encode(body)
|
||||
|
||||
local cmd = {
|
||||
"curl",
|
||||
"-s",
|
||||
"-X", "POST",
|
||||
url,
|
||||
"-H", "Content-Type: application/json",
|
||||
"-d", json_body,
|
||||
}
|
||||
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if not data or #data == 0 or (data[1] == "" and #data == 1) then
|
||||
return
|
||||
end
|
||||
|
||||
local response_text = table.concat(data, "\n")
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
|
||||
if not ok then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Failed to parse Ollama response")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
vim.schedule(function()
|
||||
callback(nil, response.error or "Ollama API error")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Return raw response text for parser to handle
|
||||
vim.schedule(function()
|
||||
callback(response.response or "", nil)
|
||||
end)
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"))
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Ollama API request failed with code: " .. code)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
@@ -8,25 +8,46 @@ local llm = require("codetyper.llm")
|
||||
--- OpenAI API endpoint
|
||||
local API_URL = "https://api.openai.com/v1/chat/completions"
|
||||
|
||||
--- Get API key from config or environment
|
||||
--- Get API key from stored credentials, config, or environment
|
||||
---@return string|nil API key
|
||||
local function get_api_key()
|
||||
-- Priority: stored credentials > config > environment
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_key = credentials.get_api_key("openai")
|
||||
if stored_key then
|
||||
return stored_key
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.openai.api_key or vim.env.OPENAI_API_KEY
|
||||
end
|
||||
|
||||
--- Get model from config
|
||||
--- Get model from stored credentials or config
|
||||
---@return string Model name
|
||||
local function get_model()
|
||||
-- Priority: stored credentials > config
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_model = credentials.get_model("openai")
|
||||
if stored_model then
|
||||
return stored_model
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.openai.model
|
||||
end
|
||||
|
||||
--- Get endpoint from config (allows custom endpoints like Azure, OpenRouter)
|
||||
--- Get endpoint from stored credentials or config (allows custom endpoints like Azure, OpenRouter)
|
||||
---@return string API endpoint
|
||||
local function get_endpoint()
|
||||
-- Priority: stored credentials > config > default
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_endpoint = credentials.get_endpoint("openai")
|
||||
if stored_endpoint then
|
||||
return stored_endpoint
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.openai.endpoint or API_URL
|
||||
@@ -284,9 +305,18 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
return
|
||||
end
|
||||
|
||||
-- Log token usage
|
||||
-- Log token usage and record cost
|
||||
if response.usage then
|
||||
logs.response(response.usage.prompt_tokens or 0, response.usage.completion_tokens or 0, "stop")
|
||||
|
||||
-- Record usage for cost tracking
|
||||
local cost = require("codetyper.cost")
|
||||
cost.record_usage(
|
||||
model,
|
||||
response.usage.prompt_tokens or 0,
|
||||
response.usage.completion_tokens or 0,
|
||||
response.usage.prompt_tokens_details and response.usage.prompt_tokens_details.cached_tokens or 0
|
||||
)
|
||||
end
|
||||
|
||||
-- Convert to Claude-like format for parser compatibility
|
||||
|
||||
514
lua/codetyper/llm/selector.lua
Normal file
514
lua/codetyper/llm/selector.lua
Normal file
@@ -0,0 +1,514 @@
|
||||
---@mod codetyper.llm.selector Smart LLM selection with memory-based confidence
|
||||
---@brief [[
|
||||
--- Intelligent LLM provider selection based on brain memories.
|
||||
--- Prefers local Ollama when context is rich, falls back to Copilot otherwise.
|
||||
--- Implements verification pondering to reinforce Ollama accuracy over time.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class SelectionResult
|
||||
---@field provider string Selected provider name
|
||||
---@field confidence number Confidence score (0-1)
|
||||
---@field memory_count number Number of relevant memories found
|
||||
---@field reason string Human-readable reason for selection
|
||||
|
||||
---@class PonderResult
|
||||
---@field ollama_response string Ollama's response
|
||||
---@field verifier_response string Verifier's response
|
||||
---@field agreement_score number How much they agree (0-1)
|
||||
---@field ollama_correct boolean Whether Ollama was deemed correct
|
||||
---@field feedback string Feedback for learning
|
||||
|
||||
--- Minimum memories required for high confidence
|
||||
local MIN_MEMORIES_FOR_LOCAL = 3
|
||||
|
||||
--- Minimum memory relevance score for local provider
|
||||
local MIN_RELEVANCE_FOR_LOCAL = 0.6
|
||||
|
||||
--- Agreement threshold for Ollama verification
|
||||
local AGREEMENT_THRESHOLD = 0.7
|
||||
|
||||
--- Pondering sample rate (0-1) - how often to verify Ollama
|
||||
local PONDER_SAMPLE_RATE = 0.2
|
||||
|
||||
--- Provider accuracy tracking (persisted in brain)
|
||||
local accuracy_cache = {
|
||||
ollama = { correct = 0, total = 0 },
|
||||
copilot = { correct = 0, total = 0 },
|
||||
}
|
||||
|
||||
--- Get the brain module safely
|
||||
---@return table|nil
|
||||
local function get_brain()
|
||||
local ok, brain = pcall(require, "codetyper.brain")
|
||||
if ok and brain.is_initialized and brain.is_initialized() then
|
||||
return brain
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Load accuracy stats from brain
|
||||
local function load_accuracy_stats()
|
||||
local brain = get_brain()
|
||||
if not brain then
|
||||
return
|
||||
end
|
||||
|
||||
-- Query for accuracy tracking nodes
|
||||
pcall(function()
|
||||
local result = brain.query({
|
||||
query = "provider_accuracy_stats",
|
||||
types = { "metric" },
|
||||
limit = 1,
|
||||
})
|
||||
|
||||
if result and result.nodes and #result.nodes > 0 then
|
||||
local node = result.nodes[1]
|
||||
if node.c and node.c.d then
|
||||
local ok, stats = pcall(vim.json.decode, node.c.d)
|
||||
if ok and stats then
|
||||
accuracy_cache = stats
|
||||
end
|
||||
end
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Save accuracy stats to brain
|
||||
local function save_accuracy_stats()
|
||||
local brain = get_brain()
|
||||
if not brain then
|
||||
return
|
||||
end
|
||||
|
||||
pcall(function()
|
||||
brain.learn({
|
||||
type = "metric",
|
||||
summary = "provider_accuracy_stats",
|
||||
detail = vim.json.encode(accuracy_cache),
|
||||
weight = 1.0,
|
||||
})
|
||||
end)
|
||||
end
|
||||
|
||||
--- Calculate Ollama confidence based on historical accuracy
|
||||
---@return number confidence (0-1)
|
||||
local function get_ollama_historical_confidence()
|
||||
local stats = accuracy_cache.ollama
|
||||
if stats.total < 5 then
|
||||
-- Not enough data, return neutral confidence
|
||||
return 0.5
|
||||
end
|
||||
|
||||
local accuracy = stats.correct / stats.total
|
||||
-- Boost confidence if accuracy is high
|
||||
return math.min(1.0, accuracy * 1.2)
|
||||
end
|
||||
|
||||
--- Query brain for relevant context
|
||||
---@param prompt string User prompt
|
||||
---@param file_path string|nil Current file path
|
||||
---@return table result {memories: table[], relevance: number, count: number}
|
||||
local function query_brain_context(prompt, file_path)
|
||||
local result = {
|
||||
memories = {},
|
||||
relevance = 0,
|
||||
count = 0,
|
||||
}
|
||||
|
||||
local brain = get_brain()
|
||||
if not brain then
|
||||
return result
|
||||
end
|
||||
|
||||
-- Query brain with multiple dimensions
|
||||
local ok, query_result = pcall(function()
|
||||
return brain.query({
|
||||
query = prompt,
|
||||
file = file_path,
|
||||
limit = 10,
|
||||
types = { "pattern", "correction", "convention", "fact" },
|
||||
})
|
||||
end)
|
||||
|
||||
if not ok or not query_result then
|
||||
return result
|
||||
end
|
||||
|
||||
result.memories = query_result.nodes or {}
|
||||
result.count = #result.memories
|
||||
|
||||
-- Calculate average relevance
|
||||
if result.count > 0 then
|
||||
local total_relevance = 0
|
||||
for _, node in ipairs(result.memories) do
|
||||
-- Use node weight and success rate as relevance indicators
|
||||
local node_relevance = (node.sc and node.sc.w or 0.5) * (node.sc and node.sc.sr or 0.5)
|
||||
total_relevance = total_relevance + node_relevance
|
||||
end
|
||||
result.relevance = total_relevance / result.count
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Select the best LLM provider based on context
|
||||
---@param prompt string User prompt
|
||||
---@param context table LLM context
|
||||
---@return SelectionResult
|
||||
function M.select_provider(prompt, context)
|
||||
-- Load accuracy stats on first call
|
||||
if accuracy_cache.ollama.total == 0 then
|
||||
load_accuracy_stats()
|
||||
end
|
||||
|
||||
local file_path = context.file_path
|
||||
|
||||
-- Query brain for relevant memories
|
||||
local brain_context = query_brain_context(prompt, file_path)
|
||||
|
||||
-- Calculate base confidence from memories
|
||||
local memory_confidence = 0
|
||||
if brain_context.count >= MIN_MEMORIES_FOR_LOCAL then
|
||||
memory_confidence = math.min(1.0, brain_context.count / 10) * brain_context.relevance
|
||||
end
|
||||
|
||||
-- Factor in historical Ollama accuracy
|
||||
local historical_confidence = get_ollama_historical_confidence()
|
||||
|
||||
-- Combined confidence score
|
||||
local combined_confidence = (memory_confidence * 0.6) + (historical_confidence * 0.4)
|
||||
|
||||
-- Decision logic
|
||||
local provider = "copilot" -- Default to more capable
|
||||
local reason = ""
|
||||
|
||||
if brain_context.count >= MIN_MEMORIES_FOR_LOCAL and combined_confidence >= MIN_RELEVANCE_FOR_LOCAL then
|
||||
provider = "ollama"
|
||||
reason = string.format(
|
||||
"Rich context: %d memories (%.1f%% relevance), historical accuracy: %.1f%%",
|
||||
brain_context.count,
|
||||
brain_context.relevance * 100,
|
||||
historical_confidence * 100
|
||||
)
|
||||
elseif brain_context.count > 0 and combined_confidence >= 0.4 then
|
||||
-- Medium confidence - use Ollama but with pondering
|
||||
provider = "ollama"
|
||||
reason = string.format(
|
||||
"Moderate context: %d memories, will verify with pondering",
|
||||
brain_context.count
|
||||
)
|
||||
else
|
||||
reason = string.format(
|
||||
"Insufficient context: %d memories (need %d), using capable provider",
|
||||
brain_context.count,
|
||||
MIN_MEMORIES_FOR_LOCAL
|
||||
)
|
||||
end
|
||||
|
||||
return {
|
||||
provider = provider,
|
||||
confidence = combined_confidence,
|
||||
memory_count = brain_context.count,
|
||||
reason = reason,
|
||||
memories = brain_context.memories,
|
||||
}
|
||||
end
|
||||
|
||||
--- Check if we should ponder (verify) this Ollama response
|
||||
---@param confidence number Current confidence level
|
||||
---@return boolean
|
||||
function M.should_ponder(confidence)
|
||||
-- Always ponder when confidence is medium
|
||||
if confidence >= 0.4 and confidence < 0.7 then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Random sampling for high confidence to keep learning
|
||||
if confidence >= 0.7 then
|
||||
return math.random() < PONDER_SAMPLE_RATE
|
||||
end
|
||||
|
||||
-- Low confidence shouldn't reach Ollama anyway
|
||||
return false
|
||||
end
|
||||
|
||||
--- Calculate agreement score between two responses
|
||||
---@param response1 string First response
|
||||
---@param response2 string Second response
|
||||
---@return number Agreement score (0-1)
|
||||
local function calculate_agreement(response1, response2)
|
||||
-- Normalize responses
|
||||
local norm1 = response1:lower():gsub("%s+", " "):gsub("[^%w%s]", "")
|
||||
local norm2 = response2:lower():gsub("%s+", " "):gsub("[^%w%s]", "")
|
||||
|
||||
-- Extract words
|
||||
local words1 = {}
|
||||
for word in norm1:gmatch("%w+") do
|
||||
words1[word] = (words1[word] or 0) + 1
|
||||
end
|
||||
|
||||
local words2 = {}
|
||||
for word in norm2:gmatch("%w+") do
|
||||
words2[word] = (words2[word] or 0) + 1
|
||||
end
|
||||
|
||||
-- Calculate Jaccard similarity
|
||||
local intersection = 0
|
||||
local union = 0
|
||||
|
||||
for word, count1 in pairs(words1) do
|
||||
local count2 = words2[word] or 0
|
||||
intersection = intersection + math.min(count1, count2)
|
||||
union = union + math.max(count1, count2)
|
||||
end
|
||||
|
||||
for word, count2 in pairs(words2) do
|
||||
if not words1[word] then
|
||||
union = union + count2
|
||||
end
|
||||
end
|
||||
|
||||
if union == 0 then
|
||||
return 1.0 -- Both empty
|
||||
end
|
||||
|
||||
-- Also check structural similarity (code structure)
|
||||
local struct_score = 0
|
||||
local function_count1 = select(2, response1:gsub("function", ""))
|
||||
local function_count2 = select(2, response2:gsub("function", ""))
|
||||
if function_count1 > 0 or function_count2 > 0 then
|
||||
struct_score = 1 - math.abs(function_count1 - function_count2) / math.max(function_count1, function_count2, 1)
|
||||
else
|
||||
struct_score = 1.0
|
||||
end
|
||||
|
||||
-- Combined score
|
||||
local jaccard = intersection / union
|
||||
return (jaccard * 0.7) + (struct_score * 0.3)
|
||||
end
|
||||
|
||||
--- Ponder (verify) Ollama's response with another LLM
|
||||
---@param prompt string Original prompt
|
||||
---@param context table LLM context
|
||||
---@param ollama_response string Ollama's response
|
||||
---@param callback fun(result: PonderResult) Callback with pondering result
|
||||
function M.ponder(prompt, context, ollama_response, callback)
|
||||
-- Use Copilot as verifier
|
||||
local copilot = require("codetyper.llm.copilot")
|
||||
|
||||
-- Build verification prompt
|
||||
local verify_prompt = prompt
|
||||
|
||||
copilot.generate(verify_prompt, context, function(verifier_response, error)
|
||||
if error or not verifier_response then
|
||||
-- Verification failed, assume Ollama is correct
|
||||
callback({
|
||||
ollama_response = ollama_response,
|
||||
verifier_response = "",
|
||||
agreement_score = 1.0,
|
||||
ollama_correct = true,
|
||||
feedback = "Verification unavailable, trusting Ollama",
|
||||
})
|
||||
return
|
||||
end
|
||||
|
||||
-- Calculate agreement
|
||||
local agreement = calculate_agreement(ollama_response, verifier_response)
|
||||
|
||||
-- Determine if Ollama was correct
|
||||
local ollama_correct = agreement >= AGREEMENT_THRESHOLD
|
||||
|
||||
-- Generate feedback
|
||||
local feedback
|
||||
if ollama_correct then
|
||||
feedback = string.format("Agreement: %.1f%% - Ollama response validated", agreement * 100)
|
||||
else
|
||||
feedback = string.format(
|
||||
"Disagreement: %.1f%% - Ollama may need correction",
|
||||
(1 - agreement) * 100
|
||||
)
|
||||
end
|
||||
|
||||
-- Update accuracy tracking
|
||||
accuracy_cache.ollama.total = accuracy_cache.ollama.total + 1
|
||||
if ollama_correct then
|
||||
accuracy_cache.ollama.correct = accuracy_cache.ollama.correct + 1
|
||||
end
|
||||
save_accuracy_stats()
|
||||
|
||||
-- Learn from this verification
|
||||
local brain = get_brain()
|
||||
if brain then
|
||||
pcall(function()
|
||||
if ollama_correct then
|
||||
-- Reinforce the pattern
|
||||
brain.learn({
|
||||
type = "correction",
|
||||
summary = "Ollama verified correct",
|
||||
detail = string.format(
|
||||
"Prompt: %s\nAgreement: %.1f%%",
|
||||
prompt:sub(1, 100),
|
||||
agreement * 100
|
||||
),
|
||||
weight = 0.8,
|
||||
file = context.file_path,
|
||||
})
|
||||
else
|
||||
-- Learn the correction
|
||||
brain.learn({
|
||||
type = "correction",
|
||||
summary = "Ollama needed correction",
|
||||
detail = string.format(
|
||||
"Prompt: %s\nOllama: %s\nCorrect: %s",
|
||||
prompt:sub(1, 100),
|
||||
ollama_response:sub(1, 200),
|
||||
verifier_response:sub(1, 200)
|
||||
),
|
||||
weight = 0.9,
|
||||
file = context.file_path,
|
||||
})
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
callback({
|
||||
ollama_response = ollama_response,
|
||||
verifier_response = verifier_response,
|
||||
agreement_score = agreement,
|
||||
ollama_correct = ollama_correct,
|
||||
feedback = feedback,
|
||||
})
|
||||
end)
|
||||
end
|
||||
|
||||
--- Smart generate with automatic provider selection and pondering
|
||||
---@param prompt string User prompt
|
||||
---@param context table LLM context
|
||||
---@param callback fun(response: string|nil, error: string|nil, metadata: table|nil) Callback
|
||||
function M.smart_generate(prompt, context, callback)
|
||||
-- Select provider
|
||||
local selection = M.select_provider(prompt, context)
|
||||
|
||||
-- Log selection
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format(
|
||||
"LLM: %s (confidence: %.1f%%, %s)",
|
||||
selection.provider,
|
||||
selection.confidence * 100,
|
||||
selection.reason
|
||||
),
|
||||
})
|
||||
end)
|
||||
|
||||
-- Get the selected client
|
||||
local client
|
||||
if selection.provider == "ollama" then
|
||||
client = require("codetyper.llm.ollama")
|
||||
else
|
||||
client = require("codetyper.llm.copilot")
|
||||
end
|
||||
|
||||
-- Generate response
|
||||
client.generate(prompt, context, function(response, error)
|
||||
if error then
|
||||
-- Fallback on error
|
||||
if selection.provider == "ollama" then
|
||||
-- Try Copilot as fallback
|
||||
local copilot = require("codetyper.llm.copilot")
|
||||
copilot.generate(prompt, context, function(fallback_response, fallback_error)
|
||||
callback(fallback_response, fallback_error, {
|
||||
provider = "copilot",
|
||||
fallback = true,
|
||||
original_provider = "ollama",
|
||||
original_error = error,
|
||||
})
|
||||
end)
|
||||
return
|
||||
end
|
||||
callback(nil, error, { provider = selection.provider })
|
||||
return
|
||||
end
|
||||
|
||||
-- Check if we should ponder
|
||||
if selection.provider == "ollama" and M.should_ponder(selection.confidence) then
|
||||
M.ponder(prompt, context, response, function(ponder_result)
|
||||
if ponder_result.ollama_correct then
|
||||
-- Ollama was correct, use its response
|
||||
callback(response, nil, {
|
||||
provider = "ollama",
|
||||
pondered = true,
|
||||
agreement = ponder_result.agreement_score,
|
||||
confidence = selection.confidence,
|
||||
})
|
||||
else
|
||||
-- Use verifier's response instead
|
||||
callback(ponder_result.verifier_response, nil, {
|
||||
provider = "copilot",
|
||||
pondered = true,
|
||||
agreement = ponder_result.agreement_score,
|
||||
original_provider = "ollama",
|
||||
corrected = true,
|
||||
})
|
||||
end
|
||||
end)
|
||||
else
|
||||
-- No pondering needed
|
||||
callback(response, nil, {
|
||||
provider = selection.provider,
|
||||
pondered = false,
|
||||
confidence = selection.confidence,
|
||||
})
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Get current accuracy statistics
|
||||
---@return table {ollama: {correct, total, accuracy}, copilot: {correct, total, accuracy}}
|
||||
function M.get_accuracy_stats()
|
||||
local stats = {
|
||||
ollama = {
|
||||
correct = accuracy_cache.ollama.correct,
|
||||
total = accuracy_cache.ollama.total,
|
||||
accuracy = accuracy_cache.ollama.total > 0
|
||||
and (accuracy_cache.ollama.correct / accuracy_cache.ollama.total)
|
||||
or 0,
|
||||
},
|
||||
copilot = {
|
||||
correct = accuracy_cache.copilot.correct,
|
||||
total = accuracy_cache.copilot.total,
|
||||
accuracy = accuracy_cache.copilot.total > 0
|
||||
and (accuracy_cache.copilot.correct / accuracy_cache.copilot.total)
|
||||
or 0,
|
||||
},
|
||||
}
|
||||
return stats
|
||||
end
|
||||
|
||||
--- Reset accuracy statistics
|
||||
function M.reset_accuracy_stats()
|
||||
accuracy_cache = {
|
||||
ollama = { correct = 0, total = 0 },
|
||||
copilot = { correct = 0, total = 0 },
|
||||
}
|
||||
save_accuracy_stats()
|
||||
end
|
||||
|
||||
--- Report user feedback on response quality
|
||||
---@param provider string Which provider generated the response
|
||||
---@param was_correct boolean Whether the response was good
|
||||
function M.report_feedback(provider, was_correct)
|
||||
if accuracy_cache[provider] then
|
||||
accuracy_cache[provider].total = accuracy_cache[provider].total + 1
|
||||
if was_correct then
|
||||
accuracy_cache[provider].correct = accuracy_cache[provider].correct + 1
|
||||
end
|
||||
save_accuracy_stats()
|
||||
end
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -13,24 +13,23 @@ M.system =
|
||||
You have access to these tools - USE THEM to accomplish tasks:
|
||||
|
||||
### File Operations
|
||||
- **read_file**: Read any file. ALWAYS read files before modifying them.
|
||||
- **write_file**: Create new files or completely replace existing ones. Use for new files.
|
||||
- **edit_file**: Make precise edits to existing files using find/replace. The "find" must match EXACTLY.
|
||||
- **delete_file**: Delete files (requires user approval). Include a reason.
|
||||
- **list_directory**: Explore project structure. See what files exist.
|
||||
- **search_files**: Find files by pattern or content.
|
||||
- **view**: Read any file. ALWAYS read files before modifying them. Parameters: path (string)
|
||||
- **write**: Create new files or completely replace existing ones. Use for new files. Parameters: path (string), content (string)
|
||||
- **edit**: Make precise edits to existing files using search/replace. Parameters: path (string), old_string (string), new_string (string)
|
||||
- **glob**: Find files by pattern (e.g., "**/*.lua"). Parameters: pattern (string), path (optional)
|
||||
- **grep**: Search file contents with regex. Parameters: pattern (string), path (optional)
|
||||
|
||||
### Shell Commands
|
||||
- **bash**: Run shell commands (git, npm, make, etc.). User approves each command.
|
||||
- **bash**: Run shell commands (git, npm, make, etc.). User approves each command. Parameters: command (string)
|
||||
|
||||
## HOW TO WORK
|
||||
|
||||
1. **UNDERSTAND FIRST**: Use read_file, list_directory, or search_files to understand the codebase before making changes.
|
||||
1. **UNDERSTAND FIRST**: Use view, glob, or grep to understand the codebase before making changes.
|
||||
|
||||
2. **MAKE CHANGES**: Use write_file for new files, edit_file for modifications.
|
||||
- For edit_file: The "find" parameter must match file content EXACTLY (including whitespace)
|
||||
- Include enough context in "find" to be unique
|
||||
- For write_file: Provide complete file content
|
||||
2. **MAKE CHANGES**: Use write for new files, edit for modifications.
|
||||
- For edit: The "old_string" parameter must match file content EXACTLY (including whitespace)
|
||||
- Include enough context in "old_string" to be unique
|
||||
- For write: Provide complete file content
|
||||
|
||||
3. **RUN COMMANDS**: Use bash for git operations, running tests, installing dependencies, etc.
|
||||
|
||||
@@ -41,16 +40,16 @@ You have access to these tools - USE THEM to accomplish tasks:
|
||||
User: "Create a new React component for a login form"
|
||||
|
||||
Your approach:
|
||||
1. Use list_directory to see project structure
|
||||
2. Use read_file to check existing component patterns
|
||||
3. Use write_file to create the new component file
|
||||
4. Use write_file to create a test file if appropriate
|
||||
1. Use glob to see project structure (glob pattern="**/*.tsx")
|
||||
2. Use view to check existing component patterns
|
||||
3. Use write to create the new component file
|
||||
4. Use write to create a test file if appropriate
|
||||
5. Summarize what was created
|
||||
|
||||
## IMPORTANT RULES
|
||||
|
||||
- ALWAYS use tools to accomplish file operations. Don't just describe what to do - DO IT.
|
||||
- Read files before editing to ensure your "find" string matches exactly.
|
||||
- Read files before editing to ensure your "old_string" matches exactly.
|
||||
- When creating files, write complete, working code.
|
||||
- When editing, preserve existing code style and conventions.
|
||||
- If a file path is provided, use it. If not, infer from context.
|
||||
@@ -68,13 +67,12 @@ M.tool_instructions = [[
|
||||
## TOOL USAGE
|
||||
|
||||
When you need to perform an action, call the appropriate tool. You can call tools to:
|
||||
- Read files to understand code
|
||||
- Create new files with write_file
|
||||
- Modify existing files with edit_file (read first!)
|
||||
- Delete files with delete_file
|
||||
- List directories to explore structure
|
||||
- Search for files by name or content
|
||||
- Run shell commands with bash
|
||||
- Read files with view (parameters: path)
|
||||
- Create new files with write (parameters: path, content)
|
||||
- Modify existing files with edit (parameters: path, old_string, new_string) - read first!
|
||||
- Find files by pattern with glob (parameters: pattern, path)
|
||||
- Search file contents with grep (parameters: pattern, path)
|
||||
- Run shell commands with bash (parameters: command)
|
||||
|
||||
After receiving a tool result, continue working:
|
||||
- If more actions are needed, call another tool
|
||||
@@ -82,11 +80,11 @@ After receiving a tool result, continue working:
|
||||
|
||||
## CRITICAL RULES
|
||||
|
||||
1. **Always read before editing**: Use read_file before edit_file to ensure exact matches
|
||||
2. **Be precise with edits**: The "find" parameter must match the file content EXACTLY
|
||||
3. **Create complete files**: When using write_file, provide fully working code
|
||||
4. **User approval required**: File writes, edits, deletes, and bash commands need approval
|
||||
5. **Don't guess**: If unsure about file structure, use list_directory or search_files
|
||||
1. **Always read before editing**: Use view before edit to ensure exact matches
|
||||
2. **Be precise with edits**: The "old_string" parameter must match the file content EXACTLY
|
||||
3. **Create complete files**: When using write, provide fully working code
|
||||
4. **User approval required**: File writes, edits, and bash commands need approval
|
||||
5. **Don't guess**: If unsure about file structure, use glob or grep
|
||||
]]
|
||||
|
||||
--- Prompt for when agent finishes
|
||||
|
||||
427
tests/spec/agent_tools_spec.lua
Normal file
427
tests/spec/agent_tools_spec.lua
Normal file
@@ -0,0 +1,427 @@
|
||||
--- Tests for agent tools system
|
||||
|
||||
describe("codetyper.agent.tools", function()
|
||||
local tools
|
||||
|
||||
before_each(function()
|
||||
tools = require("codetyper.agent.tools")
|
||||
-- Clear any existing registrations
|
||||
for name, _ in pairs(tools.get_all()) do
|
||||
tools.unregister(name)
|
||||
end
|
||||
end)
|
||||
|
||||
describe("tool registration", function()
|
||||
it("should register a tool", function()
|
||||
local test_tool = {
|
||||
name = "test_tool",
|
||||
description = "A test tool",
|
||||
params = {
|
||||
{ name = "input", type = "string", description = "Test input" },
|
||||
},
|
||||
func = function(input, opts)
|
||||
return "result", nil
|
||||
end,
|
||||
}
|
||||
|
||||
tools.register(test_tool)
|
||||
local retrieved = tools.get("test_tool")
|
||||
|
||||
assert.is_not_nil(retrieved)
|
||||
assert.equals("test_tool", retrieved.name)
|
||||
end)
|
||||
|
||||
it("should unregister a tool", function()
|
||||
local test_tool = {
|
||||
name = "temp_tool",
|
||||
description = "Temporary",
|
||||
func = function() end,
|
||||
}
|
||||
|
||||
tools.register(test_tool)
|
||||
assert.is_not_nil(tools.get("temp_tool"))
|
||||
|
||||
tools.unregister("temp_tool")
|
||||
assert.is_nil(tools.get("temp_tool"))
|
||||
end)
|
||||
|
||||
it("should list all tools", function()
|
||||
tools.register({ name = "tool1", func = function() end })
|
||||
tools.register({ name = "tool2", func = function() end })
|
||||
tools.register({ name = "tool3", func = function() end })
|
||||
|
||||
local list = tools.list()
|
||||
assert.equals(3, #list)
|
||||
end)
|
||||
|
||||
it("should filter tools with predicate", function()
|
||||
tools.register({ name = "safe_tool", requires_confirmation = false, func = function() end })
|
||||
tools.register({ name = "dangerous_tool", requires_confirmation = true, func = function() end })
|
||||
|
||||
local safe_list = tools.list(function(t)
|
||||
return not t.requires_confirmation
|
||||
end)
|
||||
|
||||
assert.equals(1, #safe_list)
|
||||
assert.equals("safe_tool", safe_list[1].name)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("tool execution", function()
|
||||
it("should execute a tool and return result", function()
|
||||
tools.register({
|
||||
name = "adder",
|
||||
params = {
|
||||
{ name = "a", type = "number" },
|
||||
{ name = "b", type = "number" },
|
||||
},
|
||||
func = function(input, opts)
|
||||
return input.a + input.b, nil
|
||||
end,
|
||||
})
|
||||
|
||||
local result, err = tools.execute("adder", { a = 5, b = 3 }, {})
|
||||
|
||||
assert.is_nil(err)
|
||||
assert.equals(8, result)
|
||||
end)
|
||||
|
||||
it("should return error for unknown tool", function()
|
||||
local result, err = tools.execute("nonexistent", {}, {})
|
||||
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("Unknown tool"))
|
||||
end)
|
||||
|
||||
it("should track execution history", function()
|
||||
tools.clear_history()
|
||||
tools.register({
|
||||
name = "tracked_tool",
|
||||
func = function()
|
||||
return "done", nil
|
||||
end,
|
||||
})
|
||||
|
||||
tools.execute("tracked_tool", {}, {})
|
||||
tools.execute("tracked_tool", {}, {})
|
||||
|
||||
local history = tools.get_history()
|
||||
assert.equals(2, #history)
|
||||
assert.equals("tracked_tool", history[1].tool)
|
||||
assert.equals("completed", history[1].status)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("tool schemas", function()
|
||||
it("should generate JSON schema for tools", function()
|
||||
tools.register({
|
||||
name = "schema_test",
|
||||
description = "Test schema generation",
|
||||
params = {
|
||||
{ name = "required_param", type = "string", description = "A required param" },
|
||||
{ name = "optional_param", type = "number", description = "Optional", optional = true },
|
||||
},
|
||||
returns = {
|
||||
{ name = "result", type = "string" },
|
||||
},
|
||||
to_schema = require("codetyper.agent.tools.base").to_schema,
|
||||
func = function() end,
|
||||
})
|
||||
|
||||
local schemas = tools.get_schemas()
|
||||
assert.equals(1, #schemas)
|
||||
|
||||
local schema = schemas[1]
|
||||
assert.equals("function", schema.type)
|
||||
assert.equals("schema_test", schema.function_def.name)
|
||||
assert.is_not_nil(schema.function_def.parameters.properties.required_param)
|
||||
assert.is_not_nil(schema.function_def.parameters.properties.optional_param)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("process_tool_call", function()
|
||||
it("should process tool call with name and input", function()
|
||||
tools.register({
|
||||
name = "processor_test",
|
||||
func = function(input, opts)
|
||||
return "processed: " .. input.value, nil
|
||||
end,
|
||||
})
|
||||
|
||||
local result, err = tools.process_tool_call({
|
||||
name = "processor_test",
|
||||
input = { value = "test" },
|
||||
}, {})
|
||||
|
||||
assert.is_nil(err)
|
||||
assert.equals("processed: test", result)
|
||||
end)
|
||||
|
||||
it("should parse JSON string arguments", function()
|
||||
tools.register({
|
||||
name = "json_parser_test",
|
||||
func = function(input, opts)
|
||||
return input.key, nil
|
||||
end,
|
||||
})
|
||||
|
||||
local result, err = tools.process_tool_call({
|
||||
name = "json_parser_test",
|
||||
arguments = '{"key": "value"}',
|
||||
}, {})
|
||||
|
||||
assert.is_nil(err)
|
||||
assert.equals("value", result)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("codetyper.agent.tools.base", function()
|
||||
local base
|
||||
|
||||
before_each(function()
|
||||
base = require("codetyper.agent.tools.base")
|
||||
end)
|
||||
|
||||
describe("validate_input", function()
|
||||
it("should validate required parameters", function()
|
||||
local tool = setmetatable({
|
||||
params = {
|
||||
{ name = "required", type = "string" },
|
||||
{ name = "optional", type = "string", optional = true },
|
||||
},
|
||||
}, base)
|
||||
|
||||
local valid, err = tool:validate_input({ required = "value" })
|
||||
assert.is_true(valid)
|
||||
assert.is_nil(err)
|
||||
end)
|
||||
|
||||
it("should fail on missing required parameter", function()
|
||||
local tool = setmetatable({
|
||||
params = {
|
||||
{ name = "required", type = "string" },
|
||||
},
|
||||
}, base)
|
||||
|
||||
local valid, err = tool:validate_input({})
|
||||
assert.is_false(valid)
|
||||
assert.truthy(err:match("Missing required parameter"))
|
||||
end)
|
||||
|
||||
it("should validate parameter types", function()
|
||||
local tool = setmetatable({
|
||||
params = {
|
||||
{ name = "num", type = "number" },
|
||||
},
|
||||
}, base)
|
||||
|
||||
local valid1, _ = tool:validate_input({ num = 42 })
|
||||
assert.is_true(valid1)
|
||||
|
||||
local valid2, err2 = tool:validate_input({ num = "not a number" })
|
||||
assert.is_false(valid2)
|
||||
assert.truthy(err2:match("must be number"))
|
||||
end)
|
||||
|
||||
it("should validate integer type", function()
|
||||
local tool = setmetatable({
|
||||
params = {
|
||||
{ name = "int", type = "integer" },
|
||||
},
|
||||
}, base)
|
||||
|
||||
local valid1, _ = tool:validate_input({ int = 42 })
|
||||
assert.is_true(valid1)
|
||||
|
||||
local valid2, err2 = tool:validate_input({ int = 42.5 })
|
||||
assert.is_false(valid2)
|
||||
assert.truthy(err2:match("must be an integer"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("get_description", function()
|
||||
it("should return string description", function()
|
||||
local tool = setmetatable({
|
||||
description = "Static description",
|
||||
}, base)
|
||||
|
||||
assert.equals("Static description", tool:get_description())
|
||||
end)
|
||||
|
||||
it("should call function description", function()
|
||||
local tool = setmetatable({
|
||||
description = function()
|
||||
return "Dynamic description"
|
||||
end,
|
||||
}, base)
|
||||
|
||||
assert.equals("Dynamic description", tool:get_description())
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("to_schema", function()
|
||||
it("should generate valid schema", function()
|
||||
local tool = setmetatable({
|
||||
name = "test",
|
||||
description = "Test tool",
|
||||
params = {
|
||||
{ name = "input", type = "string", description = "Input value" },
|
||||
{ name = "count", type = "integer", description = "Count", optional = true },
|
||||
},
|
||||
}, base)
|
||||
|
||||
local schema = tool:to_schema()
|
||||
|
||||
assert.equals("function", schema.type)
|
||||
assert.equals("test", schema.function_def.name)
|
||||
assert.equals("Test tool", schema.function_def.description)
|
||||
assert.equals("object", schema.function_def.parameters.type)
|
||||
assert.is_not_nil(schema.function_def.parameters.properties.input)
|
||||
assert.is_not_nil(schema.function_def.parameters.properties.count)
|
||||
assert.same({ "input" }, schema.function_def.parameters.required)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("built-in tools", function()
|
||||
describe("view tool", function()
|
||||
local view
|
||||
|
||||
before_each(function()
|
||||
view = require("codetyper.agent.tools.view")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("view", view.name)
|
||||
assert.is_string(view.description)
|
||||
assert.is_table(view.params)
|
||||
assert.is_function(view.func)
|
||||
end)
|
||||
|
||||
it("should require path parameter", function()
|
||||
local result, err = view.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("path is required"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("grep tool", function()
|
||||
local grep
|
||||
|
||||
before_each(function()
|
||||
grep = require("codetyper.agent.tools.grep")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("grep", grep.name)
|
||||
assert.is_string(grep.description)
|
||||
assert.is_table(grep.params)
|
||||
assert.is_function(grep.func)
|
||||
end)
|
||||
|
||||
it("should require pattern parameter", function()
|
||||
local result, err = grep.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("pattern is required"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("glob tool", function()
|
||||
local glob
|
||||
|
||||
before_each(function()
|
||||
glob = require("codetyper.agent.tools.glob")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("glob", glob.name)
|
||||
assert.is_string(glob.description)
|
||||
assert.is_table(glob.params)
|
||||
assert.is_function(glob.func)
|
||||
end)
|
||||
|
||||
it("should require pattern parameter", function()
|
||||
local result, err = glob.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("pattern is required"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("edit tool", function()
|
||||
local edit
|
||||
|
||||
before_each(function()
|
||||
edit = require("codetyper.agent.tools.edit")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("edit", edit.name)
|
||||
assert.is_string(edit.description)
|
||||
assert.is_table(edit.params)
|
||||
assert.is_function(edit.func)
|
||||
end)
|
||||
|
||||
it("should require path parameter", function()
|
||||
local result, err = edit.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("path is required"))
|
||||
end)
|
||||
|
||||
it("should require old_string parameter", function()
|
||||
local result, err = edit.func({ path = "/tmp/test" }, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("old_string is required"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("write tool", function()
|
||||
local write
|
||||
|
||||
before_each(function()
|
||||
write = require("codetyper.agent.tools.write")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("write", write.name)
|
||||
assert.is_string(write.description)
|
||||
assert.is_table(write.params)
|
||||
assert.is_function(write.func)
|
||||
end)
|
||||
|
||||
it("should require path parameter", function()
|
||||
local result, err = write.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("path is required"))
|
||||
end)
|
||||
|
||||
it("should require content parameter", function()
|
||||
local result, err = write.func({ path = "/tmp/test" }, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("content is required"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("bash tool", function()
|
||||
local bash
|
||||
|
||||
before_each(function()
|
||||
bash = require("codetyper.agent.tools.bash")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("bash", bash.name)
|
||||
assert.is_function(bash.func)
|
||||
end)
|
||||
|
||||
it("should require command parameter", function()
|
||||
local result, err = bash.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("command is required"))
|
||||
end)
|
||||
|
||||
it("should require confirmation by default", function()
|
||||
assert.is_true(bash.requires_confirmation)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
312
tests/spec/agentic_spec.lua
Normal file
312
tests/spec/agentic_spec.lua
Normal file
@@ -0,0 +1,312 @@
|
||||
---@diagnostic disable: undefined-global
|
||||
-- Unit tests for the agentic system
|
||||
|
||||
describe("agentic module", function()
|
||||
local agentic
|
||||
|
||||
before_each(function()
|
||||
-- Reset and reload
|
||||
package.loaded["codetyper.agent.agentic"] = nil
|
||||
agentic = require("codetyper.agent.agentic")
|
||||
end)
|
||||
|
||||
it("should list built-in agents", function()
|
||||
local agents = agentic.list_agents()
|
||||
assert.is_table(agents)
|
||||
assert.is_true(#agents >= 3) -- coder, planner, explorer
|
||||
|
||||
local names = {}
|
||||
for _, agent in ipairs(agents) do
|
||||
names[agent.name] = true
|
||||
end
|
||||
|
||||
assert.is_true(names["coder"])
|
||||
assert.is_true(names["planner"])
|
||||
assert.is_true(names["explorer"])
|
||||
end)
|
||||
|
||||
it("should have description for each agent", function()
|
||||
local agents = agentic.list_agents()
|
||||
for _, agent in ipairs(agents) do
|
||||
assert.is_string(agent.description)
|
||||
assert.is_true(#agent.description > 0)
|
||||
end
|
||||
end)
|
||||
|
||||
it("should mark built-in agents as builtin", function()
|
||||
local agents = agentic.list_agents()
|
||||
local coder = nil
|
||||
for _, agent in ipairs(agents) do
|
||||
if agent.name == "coder" then
|
||||
coder = agent
|
||||
break
|
||||
end
|
||||
end
|
||||
assert.is_not_nil(coder)
|
||||
assert.is_true(coder.builtin)
|
||||
end)
|
||||
|
||||
it("should have init function to create directories", function()
|
||||
assert.is_function(agentic.init)
|
||||
assert.is_function(agentic.init_agents_dir)
|
||||
assert.is_function(agentic.init_rules_dir)
|
||||
end)
|
||||
|
||||
it("should have run function for executing tasks", function()
|
||||
assert.is_function(agentic.run)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("tools format conversion", function()
|
||||
local tools_module
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools"] = nil
|
||||
tools_module = require("codetyper.agent.tools")
|
||||
-- Load tools
|
||||
if tools_module.load_builtins then
|
||||
pcall(tools_module.load_builtins)
|
||||
end
|
||||
end)
|
||||
|
||||
it("should have to_openai_format function", function()
|
||||
assert.is_function(tools_module.to_openai_format)
|
||||
end)
|
||||
|
||||
it("should have to_claude_format function", function()
|
||||
assert.is_function(tools_module.to_claude_format)
|
||||
end)
|
||||
|
||||
it("should convert tools to OpenAI format", function()
|
||||
local openai_tools = tools_module.to_openai_format()
|
||||
assert.is_table(openai_tools)
|
||||
|
||||
-- If tools are loaded, check format
|
||||
if #openai_tools > 0 then
|
||||
local first_tool = openai_tools[1]
|
||||
assert.equals("function", first_tool.type)
|
||||
assert.is_table(first_tool["function"])
|
||||
assert.is_string(first_tool["function"].name)
|
||||
end
|
||||
end)
|
||||
|
||||
it("should convert tools to Claude format", function()
|
||||
local claude_tools = tools_module.to_claude_format()
|
||||
assert.is_table(claude_tools)
|
||||
|
||||
-- If tools are loaded, check format
|
||||
if #claude_tools > 0 then
|
||||
local first_tool = claude_tools[1]
|
||||
assert.is_string(first_tool.name)
|
||||
assert.is_table(first_tool.input_schema)
|
||||
end
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("edit tool", function()
|
||||
local edit_tool
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.edit"] = nil
|
||||
edit_tool = require("codetyper.agent.tools.edit")
|
||||
end)
|
||||
|
||||
it("should have name 'edit'", function()
|
||||
assert.equals("edit", edit_tool.name)
|
||||
end)
|
||||
|
||||
it("should have description mentioning matching strategies", function()
|
||||
local desc = edit_tool:get_description()
|
||||
assert.is_string(desc)
|
||||
-- Should mention the matching capabilities
|
||||
assert.is_true(desc:lower():match("match") ~= nil or desc:lower():match("replac") ~= nil)
|
||||
end)
|
||||
|
||||
it("should have params defined", function()
|
||||
assert.is_table(edit_tool.params)
|
||||
assert.is_true(#edit_tool.params >= 3) -- path, old_string, new_string
|
||||
end)
|
||||
|
||||
it("should require path parameter", function()
|
||||
local valid, err = edit_tool:validate_input({
|
||||
old_string = "test",
|
||||
new_string = "test2",
|
||||
})
|
||||
assert.is_false(valid)
|
||||
assert.is_string(err)
|
||||
end)
|
||||
|
||||
it("should require old_string parameter", function()
|
||||
local valid, err = edit_tool:validate_input({
|
||||
path = "/test",
|
||||
new_string = "test",
|
||||
})
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should require new_string parameter", function()
|
||||
local valid, err = edit_tool:validate_input({
|
||||
path = "/test",
|
||||
old_string = "test",
|
||||
})
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should accept empty old_string for new file creation", function()
|
||||
local valid, err = edit_tool:validate_input({
|
||||
path = "/test/new_file.lua",
|
||||
old_string = "",
|
||||
new_string = "new content",
|
||||
})
|
||||
assert.is_true(valid)
|
||||
assert.is_nil(err)
|
||||
end)
|
||||
|
||||
it("should have func implementation", function()
|
||||
assert.is_function(edit_tool.func)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("view tool", function()
|
||||
local view_tool
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.view"] = nil
|
||||
view_tool = require("codetyper.agent.tools.view")
|
||||
end)
|
||||
|
||||
it("should have name 'view'", function()
|
||||
assert.equals("view", view_tool.name)
|
||||
end)
|
||||
|
||||
it("should require path parameter", function()
|
||||
local valid, err = view_tool:validate_input({})
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should accept valid path", function()
|
||||
local valid, err = view_tool:validate_input({
|
||||
path = "/test/file.lua",
|
||||
})
|
||||
assert.is_true(valid)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("write tool", function()
|
||||
local write_tool
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.write"] = nil
|
||||
write_tool = require("codetyper.agent.tools.write")
|
||||
end)
|
||||
|
||||
it("should have name 'write'", function()
|
||||
assert.equals("write", write_tool.name)
|
||||
end)
|
||||
|
||||
it("should require path and content parameters", function()
|
||||
local valid, err = write_tool:validate_input({})
|
||||
assert.is_false(valid)
|
||||
|
||||
valid, err = write_tool:validate_input({ path = "/test" })
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should accept valid input", function()
|
||||
local valid, err = write_tool:validate_input({
|
||||
path = "/test/file.lua",
|
||||
content = "test content",
|
||||
})
|
||||
assert.is_true(valid)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("grep tool", function()
|
||||
local grep_tool
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.grep"] = nil
|
||||
grep_tool = require("codetyper.agent.tools.grep")
|
||||
end)
|
||||
|
||||
it("should have name 'grep'", function()
|
||||
assert.equals("grep", grep_tool.name)
|
||||
end)
|
||||
|
||||
it("should require pattern parameter", function()
|
||||
local valid, err = grep_tool:validate_input({})
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should accept valid pattern", function()
|
||||
local valid, err = grep_tool:validate_input({
|
||||
pattern = "function.*test",
|
||||
})
|
||||
assert.is_true(valid)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("glob tool", function()
|
||||
local glob_tool
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.glob"] = nil
|
||||
glob_tool = require("codetyper.agent.tools.glob")
|
||||
end)
|
||||
|
||||
it("should have name 'glob'", function()
|
||||
assert.equals("glob", glob_tool.name)
|
||||
end)
|
||||
|
||||
it("should require pattern parameter", function()
|
||||
local valid, err = glob_tool:validate_input({})
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should accept valid pattern", function()
|
||||
local valid, err = glob_tool:validate_input({
|
||||
pattern = "**/*.lua",
|
||||
})
|
||||
assert.is_true(valid)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("base tool", function()
|
||||
local Base
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.base"] = nil
|
||||
Base = require("codetyper.agent.tools.base")
|
||||
end)
|
||||
|
||||
it("should have validate_input method", function()
|
||||
assert.is_function(Base.validate_input)
|
||||
end)
|
||||
|
||||
it("should have to_schema method", function()
|
||||
assert.is_function(Base.to_schema)
|
||||
end)
|
||||
|
||||
it("should have get_description method", function()
|
||||
assert.is_function(Base.get_description)
|
||||
end)
|
||||
|
||||
it("should generate valid schema", function()
|
||||
local test_tool = setmetatable({
|
||||
name = "test",
|
||||
description = "A test tool",
|
||||
params = {
|
||||
{ name = "arg1", type = "string", description = "First arg" },
|
||||
{ name = "arg2", type = "number", description = "Second arg", optional = true },
|
||||
},
|
||||
}, Base)
|
||||
|
||||
local schema = test_tool:to_schema()
|
||||
assert.equals("function", schema.type)
|
||||
assert.equals("test", schema.function_def.name)
|
||||
assert.is_table(schema.function_def.parameters.properties)
|
||||
assert.is_table(schema.function_def.parameters.required)
|
||||
assert.is_true(vim.tbl_contains(schema.function_def.parameters.required, "arg1"))
|
||||
assert.is_false(vim.tbl_contains(schema.function_def.parameters.required, "arg2"))
|
||||
end)
|
||||
end)
|
||||
153
tests/spec/brain_learners_spec.lua
Normal file
153
tests/spec/brain_learners_spec.lua
Normal file
@@ -0,0 +1,153 @@
|
||||
--- Tests for brain/learners pattern detection and extraction
|
||||
describe("brain.learners", function()
|
||||
local pattern_learner
|
||||
|
||||
before_each(function()
|
||||
-- Clear module cache
|
||||
package.loaded["codetyper.brain.learners.pattern"] = nil
|
||||
package.loaded["codetyper.brain.types"] = nil
|
||||
|
||||
pattern_learner = require("codetyper.brain.learners.pattern")
|
||||
end)
|
||||
|
||||
describe("pattern learner detection", function()
|
||||
it("should detect code_completion events", function()
|
||||
local event = { type = "code_completion", data = {} }
|
||||
assert.is_true(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should detect file_indexed events", function()
|
||||
local event = { type = "file_indexed", data = {} }
|
||||
assert.is_true(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should detect code_analyzed events", function()
|
||||
local event = { type = "code_analyzed", data = {} }
|
||||
assert.is_true(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should detect pattern_detected events", function()
|
||||
local event = { type = "pattern_detected", data = {} }
|
||||
assert.is_true(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should NOT detect plain 'pattern' type events", function()
|
||||
-- This was the bug - 'pattern' type was not in the valid_types list
|
||||
local event = { type = "pattern", data = {} }
|
||||
assert.is_false(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should NOT detect unknown event types", function()
|
||||
local event = { type = "unknown_type", data = {} }
|
||||
assert.is_false(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should NOT detect nil events", function()
|
||||
assert.is_false(pattern_learner.detect(nil))
|
||||
end)
|
||||
|
||||
it("should NOT detect events without type", function()
|
||||
local event = { data = {} }
|
||||
assert.is_false(pattern_learner.detect(event))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("pattern learner extraction", function()
|
||||
it("should extract from pattern_detected events", function()
|
||||
local event = {
|
||||
type = "pattern_detected",
|
||||
file = "/path/to/file.lua",
|
||||
data = {
|
||||
name = "Test pattern",
|
||||
description = "Pattern description",
|
||||
language = "lua",
|
||||
symbols = { "func1", "func2" },
|
||||
},
|
||||
}
|
||||
|
||||
local extracted = pattern_learner.extract(event)
|
||||
|
||||
assert.is_not_nil(extracted)
|
||||
assert.equals("Test pattern", extracted.summary)
|
||||
assert.equals("Pattern description", extracted.detail)
|
||||
assert.equals("lua", extracted.lang)
|
||||
assert.equals("/path/to/file.lua", extracted.file)
|
||||
end)
|
||||
|
||||
it("should handle pattern_detected with minimal data", function()
|
||||
local event = {
|
||||
type = "pattern_detected",
|
||||
file = "/path/to/file.lua",
|
||||
data = {
|
||||
name = "Minimal pattern",
|
||||
},
|
||||
}
|
||||
|
||||
local extracted = pattern_learner.extract(event)
|
||||
|
||||
assert.is_not_nil(extracted)
|
||||
assert.equals("Minimal pattern", extracted.summary)
|
||||
assert.equals("Minimal pattern", extracted.detail)
|
||||
end)
|
||||
|
||||
it("should extract from code_completion events", function()
|
||||
local event = {
|
||||
type = "code_completion",
|
||||
file = "/path/to/file.lua",
|
||||
data = {
|
||||
intent = "add function",
|
||||
code = "function test() end",
|
||||
language = "lua",
|
||||
},
|
||||
}
|
||||
|
||||
local extracted = pattern_learner.extract(event)
|
||||
|
||||
assert.is_not_nil(extracted)
|
||||
assert.is_true(extracted.summary:find("Code pattern") ~= nil)
|
||||
assert.equals("function test() end", extracted.detail)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("should_learn validation", function()
|
||||
it("should accept valid patterns", function()
|
||||
local data = {
|
||||
summary = "Valid pattern summary",
|
||||
detail = "This is a detailed description of the pattern",
|
||||
}
|
||||
assert.is_true(pattern_learner.should_learn(data))
|
||||
end)
|
||||
|
||||
it("should reject patterns without summary", function()
|
||||
local data = {
|
||||
summary = "",
|
||||
detail = "Some detail",
|
||||
}
|
||||
assert.is_false(pattern_learner.should_learn(data))
|
||||
end)
|
||||
|
||||
it("should reject patterns with nil summary", function()
|
||||
local data = {
|
||||
summary = nil,
|
||||
detail = "Some detail",
|
||||
}
|
||||
assert.is_false(pattern_learner.should_learn(data))
|
||||
end)
|
||||
|
||||
it("should reject patterns with very short detail", function()
|
||||
local data = {
|
||||
summary = "Valid summary",
|
||||
detail = "short", -- Less than 10 chars
|
||||
}
|
||||
assert.is_false(pattern_learner.should_learn(data))
|
||||
end)
|
||||
|
||||
it("should reject whitespace-only summaries", function()
|
||||
local data = {
|
||||
summary = " ",
|
||||
detail = "Some valid detail here",
|
||||
}
|
||||
assert.is_false(pattern_learner.should_learn(data))
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
194
tests/spec/coder_context_spec.lua
Normal file
194
tests/spec/coder_context_spec.lua
Normal file
@@ -0,0 +1,194 @@
|
||||
--- Tests for coder file context injection
|
||||
describe("coder context injection", function()
|
||||
local test_dir
|
||||
local original_filereadable
|
||||
|
||||
before_each(function()
|
||||
test_dir = "/tmp/codetyper_coder_test_" .. os.time()
|
||||
vim.fn.mkdir(test_dir, "p")
|
||||
|
||||
-- Store original function
|
||||
original_filereadable = vim.fn.filereadable
|
||||
end)
|
||||
|
||||
after_each(function()
|
||||
vim.fn.delete(test_dir, "rf")
|
||||
vim.fn.filereadable = original_filereadable
|
||||
end)
|
||||
|
||||
describe("get_coder_companion_path logic", function()
|
||||
-- Test the path generation logic (simulating the function behavior)
|
||||
local function get_coder_companion_path(target_path, file_exists_check)
|
||||
if not target_path or target_path == "" then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Skip if target is already a coder file
|
||||
if target_path:match("%.coder%.") then
|
||||
return nil
|
||||
end
|
||||
|
||||
local dir = vim.fn.fnamemodify(target_path, ":h")
|
||||
local name = vim.fn.fnamemodify(target_path, ":t:r")
|
||||
local ext = vim.fn.fnamemodify(target_path, ":e")
|
||||
|
||||
local coder_path = dir .. "/" .. name .. ".coder." .. ext
|
||||
if file_exists_check(coder_path) then
|
||||
return coder_path
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
it("should generate correct coder path for source file", function()
|
||||
local target = "/path/to/file.ts"
|
||||
local expected = "/path/to/file.coder.ts"
|
||||
|
||||
local path = get_coder_companion_path(target, function() return true end)
|
||||
|
||||
assert.equals(expected, path)
|
||||
end)
|
||||
|
||||
it("should return nil for empty path", function()
|
||||
local path = get_coder_companion_path("", function() return true end)
|
||||
assert.is_nil(path)
|
||||
end)
|
||||
|
||||
it("should return nil for nil path", function()
|
||||
local path = get_coder_companion_path(nil, function() return true end)
|
||||
assert.is_nil(path)
|
||||
end)
|
||||
|
||||
it("should return nil for coder files (avoid recursion)", function()
|
||||
local target = "/path/to/file.coder.ts"
|
||||
local path = get_coder_companion_path(target, function() return true end)
|
||||
assert.is_nil(path)
|
||||
end)
|
||||
|
||||
it("should return nil if coder file doesn't exist", function()
|
||||
local target = "/path/to/file.ts"
|
||||
local path = get_coder_companion_path(target, function() return false end)
|
||||
assert.is_nil(path)
|
||||
end)
|
||||
|
||||
it("should handle files with multiple dots", function()
|
||||
local target = "/path/to/my.component.ts"
|
||||
local expected = "/path/to/my.component.coder.ts"
|
||||
|
||||
local path = get_coder_companion_path(target, function() return true end)
|
||||
|
||||
assert.equals(expected, path)
|
||||
end)
|
||||
|
||||
it("should handle different extensions", function()
|
||||
local test_cases = {
|
||||
{ target = "/path/file.lua", expected = "/path/file.coder.lua" },
|
||||
{ target = "/path/file.py", expected = "/path/file.coder.py" },
|
||||
{ target = "/path/file.js", expected = "/path/file.coder.js" },
|
||||
{ target = "/path/file.go", expected = "/path/file.coder.go" },
|
||||
}
|
||||
|
||||
for _, tc in ipairs(test_cases) do
|
||||
local path = get_coder_companion_path(tc.target, function() return true end)
|
||||
assert.equals(tc.expected, path, "Failed for: " .. tc.target)
|
||||
end
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("coder content filtering", function()
|
||||
-- Test the filtering logic that skips template-only content
|
||||
local function has_meaningful_content(lines)
|
||||
for _, line in ipairs(lines) do
|
||||
local trimmed = line:gsub("^%s*", "")
|
||||
if not trimmed:match("^[%-#/]+%s*Coder companion")
|
||||
and not trimmed:match("^[%-#/]+%s*Use /@ @/")
|
||||
and not trimmed:match("^[%-#/]+%s*Example:")
|
||||
and not trimmed:match("^<!%-%-")
|
||||
and trimmed ~= ""
|
||||
and not trimmed:match("^[%-#/]+%s*$") then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
it("should detect meaningful content", function()
|
||||
local lines = {
|
||||
"-- Coder companion for test.lua",
|
||||
"-- This file handles authentication",
|
||||
"/@",
|
||||
"Add login function",
|
||||
"@/",
|
||||
}
|
||||
assert.is_true(has_meaningful_content(lines))
|
||||
end)
|
||||
|
||||
it("should reject template-only content", function()
|
||||
-- Template lines are filtered by specific patterns
|
||||
-- Only header comments that match the template format are filtered
|
||||
local lines = {
|
||||
"-- Coder companion for test.lua",
|
||||
"-- Use /@ @/ tags to write pseudo-code prompts",
|
||||
"-- Example:",
|
||||
"--",
|
||||
"",
|
||||
}
|
||||
assert.is_false(has_meaningful_content(lines))
|
||||
end)
|
||||
|
||||
it("should detect pseudo-code content", function()
|
||||
local lines = {
|
||||
"-- Authentication module",
|
||||
"",
|
||||
"-- This module should:",
|
||||
"-- 1. Validate user credentials",
|
||||
"-- 2. Generate JWT tokens",
|
||||
"-- 3. Handle session management",
|
||||
}
|
||||
-- "-- Authentication module" doesn't match template patterns
|
||||
assert.is_true(has_meaningful_content(lines))
|
||||
end)
|
||||
|
||||
it("should handle JavaScript style comments", function()
|
||||
local lines = {
|
||||
"// Coder companion for test.ts",
|
||||
"// Business logic for user authentication",
|
||||
"",
|
||||
"// The auth flow should:",
|
||||
"// 1. Check OAuth token",
|
||||
"// 2. Validate permissions",
|
||||
}
|
||||
-- "// Business logic..." doesn't match template patterns
|
||||
assert.is_true(has_meaningful_content(lines))
|
||||
end)
|
||||
|
||||
it("should handle empty lines", function()
|
||||
local lines = {
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
}
|
||||
assert.is_false(has_meaningful_content(lines))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("context format", function()
|
||||
it("should format context with proper header", function()
|
||||
local function format_coder_context(content, ext)
|
||||
return string.format(
|
||||
"\n\n--- Business Context / Pseudo-code ---\n" ..
|
||||
"The following describes the intended behavior and design for this file:\n" ..
|
||||
"```%s\n%s\n```",
|
||||
ext,
|
||||
content
|
||||
)
|
||||
end
|
||||
|
||||
local formatted = format_coder_context("-- Auth logic here", "lua")
|
||||
|
||||
assert.is_true(formatted:find("Business Context") ~= nil)
|
||||
assert.is_true(formatted:find("```lua") ~= nil)
|
||||
assert.is_true(formatted:find("Auth logic here") ~= nil)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
161
tests/spec/coder_ignore_spec.lua
Normal file
161
tests/spec/coder_ignore_spec.lua
Normal file
@@ -0,0 +1,161 @@
|
||||
--- Tests for coder file ignore logic
|
||||
describe("coder file ignore logic", function()
|
||||
-- Directories to ignore
|
||||
local ignored_directories = {
|
||||
".git",
|
||||
".coder",
|
||||
".claude",
|
||||
".vscode",
|
||||
".idea",
|
||||
"node_modules",
|
||||
"vendor",
|
||||
"dist",
|
||||
"build",
|
||||
"target",
|
||||
"__pycache__",
|
||||
".cache",
|
||||
".npm",
|
||||
".yarn",
|
||||
"coverage",
|
||||
".next",
|
||||
".nuxt",
|
||||
".svelte-kit",
|
||||
"out",
|
||||
"bin",
|
||||
"obj",
|
||||
}
|
||||
|
||||
-- Files to ignore
|
||||
local ignored_files = {
|
||||
".gitignore",
|
||||
".gitattributes",
|
||||
"package-lock.json",
|
||||
"yarn.lock",
|
||||
".env",
|
||||
".eslintrc",
|
||||
"tsconfig.json",
|
||||
"README.md",
|
||||
"LICENSE",
|
||||
"Makefile",
|
||||
}
|
||||
|
||||
local function is_in_ignored_directory(filepath)
|
||||
for _, dir in ipairs(ignored_directories) do
|
||||
if filepath:match("/" .. dir .. "/") or filepath:match("/" .. dir .. "$") then
|
||||
return true
|
||||
end
|
||||
if filepath:match("^" .. dir .. "/") then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
local function should_ignore_for_coder(filepath)
|
||||
local filename = vim.fn.fnamemodify(filepath, ":t")
|
||||
|
||||
for _, ignored in ipairs(ignored_files) do
|
||||
if filename == ignored then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
if filename:match("^%.") then
|
||||
return true
|
||||
end
|
||||
|
||||
if is_in_ignored_directory(filepath) then
|
||||
return true
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
describe("ignored directories", function()
|
||||
it("should ignore files in node_modules", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/node_modules/lodash/index.js"))
|
||||
assert.is_true(should_ignore_for_coder("/project/node_modules/react/index.js"))
|
||||
end)
|
||||
|
||||
it("should ignore files in .git", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/.git/config"))
|
||||
assert.is_true(should_ignore_for_coder("/project/.git/hooks/pre-commit"))
|
||||
end)
|
||||
|
||||
it("should ignore files in .coder", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/.coder/brain/meta.json"))
|
||||
end)
|
||||
|
||||
it("should ignore files in vendor", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/vendor/autoload.php"))
|
||||
end)
|
||||
|
||||
it("should ignore files in dist/build", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/dist/bundle.js"))
|
||||
assert.is_true(should_ignore_for_coder("/project/build/output.js"))
|
||||
end)
|
||||
|
||||
it("should ignore files in __pycache__", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/__pycache__/module.cpython-39.pyc"))
|
||||
end)
|
||||
|
||||
it("should NOT ignore regular source files", function()
|
||||
assert.is_false(should_ignore_for_coder("/project/src/index.ts"))
|
||||
assert.is_false(should_ignore_for_coder("/project/lib/utils.lua"))
|
||||
assert.is_false(should_ignore_for_coder("/project/app/main.py"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("ignored files", function()
|
||||
it("should ignore .gitignore", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/.gitignore"))
|
||||
end)
|
||||
|
||||
it("should ignore lock files", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/package-lock.json"))
|
||||
assert.is_true(should_ignore_for_coder("/project/yarn.lock"))
|
||||
end)
|
||||
|
||||
it("should ignore config files", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/tsconfig.json"))
|
||||
assert.is_true(should_ignore_for_coder("/project/.eslintrc"))
|
||||
end)
|
||||
|
||||
it("should ignore .env files", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/.env"))
|
||||
end)
|
||||
|
||||
it("should ignore README and LICENSE", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/README.md"))
|
||||
assert.is_true(should_ignore_for_coder("/project/LICENSE"))
|
||||
end)
|
||||
|
||||
it("should ignore hidden/dot files", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/.hidden"))
|
||||
assert.is_true(should_ignore_for_coder("/project/.secret"))
|
||||
end)
|
||||
|
||||
it("should NOT ignore regular source files", function()
|
||||
assert.is_false(should_ignore_for_coder("/project/src/app.ts"))
|
||||
assert.is_false(should_ignore_for_coder("/project/components/Button.tsx"))
|
||||
assert.is_false(should_ignore_for_coder("/project/utils/helpers.js"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("edge cases", function()
|
||||
it("should handle nested node_modules", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/packages/core/node_modules/dep/index.js"))
|
||||
end)
|
||||
|
||||
it("should handle files named like directories but not in them", function()
|
||||
-- A file named "node_modules.md" in root should be ignored (starts with .)
|
||||
-- But a file in a folder that contains "node" should NOT be ignored
|
||||
assert.is_false(should_ignore_for_coder("/project/src/node_utils.ts"))
|
||||
end)
|
||||
|
||||
it("should handle relative paths", function()
|
||||
assert.is_true(should_ignore_for_coder("node_modules/lodash/index.js"))
|
||||
assert.is_false(should_ignore_for_coder("src/index.ts"))
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
371
tests/spec/inject_spec.lua
Normal file
371
tests/spec/inject_spec.lua
Normal file
@@ -0,0 +1,371 @@
|
||||
--- Tests for smart code injection with import handling
|
||||
|
||||
describe("codetyper.agent.inject", function()
|
||||
local inject
|
||||
|
||||
before_each(function()
|
||||
inject = require("codetyper.agent.inject")
|
||||
end)
|
||||
|
||||
describe("parse_code", function()
|
||||
describe("JavaScript/TypeScript", function()
|
||||
it("should detect ES6 named imports", function()
|
||||
local code = [[import { useState, useEffect } from 'react';
|
||||
import { Button } from './components';
|
||||
|
||||
function App() {
|
||||
return <div>Hello</div>;
|
||||
}]]
|
||||
local result = inject.parse_code(code, "typescript")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("useState"))
|
||||
assert.truthy(result.imports[2]:match("Button"))
|
||||
assert.truthy(#result.body > 0)
|
||||
end)
|
||||
|
||||
it("should detect ES6 default imports", function()
|
||||
local code = [[import React from 'react';
|
||||
import axios from 'axios';
|
||||
|
||||
const api = axios.create();]]
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("React"))
|
||||
assert.truthy(result.imports[2]:match("axios"))
|
||||
end)
|
||||
|
||||
it("should detect require imports", function()
|
||||
local code = [[const fs = require('fs');
|
||||
const path = require('path');
|
||||
|
||||
module.exports = { fs, path };]]
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("fs"))
|
||||
assert.truthy(result.imports[2]:match("path"))
|
||||
end)
|
||||
|
||||
it("should detect multi-line imports", function()
|
||||
local code = [[import {
|
||||
useState,
|
||||
useEffect,
|
||||
useCallback
|
||||
} from 'react';
|
||||
|
||||
function Component() {}]]
|
||||
local result = inject.parse_code(code, "typescript")
|
||||
|
||||
assert.equals(1, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("useState"))
|
||||
assert.truthy(result.imports[1]:match("useCallback"))
|
||||
end)
|
||||
|
||||
it("should detect namespace imports", function()
|
||||
local code = [[import * as React from 'react';
|
||||
|
||||
export default React;]]
|
||||
local result = inject.parse_code(code, "tsx")
|
||||
|
||||
assert.equals(1, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("%* as React"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("Python", function()
|
||||
it("should detect simple imports", function()
|
||||
local code = [[import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
def main():
|
||||
pass]]
|
||||
local result = inject.parse_code(code, "python")
|
||||
|
||||
assert.equals(3, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("import os"))
|
||||
assert.truthy(result.imports[2]:match("import sys"))
|
||||
assert.truthy(result.imports[3]:match("import json"))
|
||||
end)
|
||||
|
||||
it("should detect from imports", function()
|
||||
local code = [[from typing import List, Dict
|
||||
from pathlib import Path
|
||||
|
||||
def process(items: List[str]) -> None:
|
||||
pass]]
|
||||
local result = inject.parse_code(code, "py")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("from typing"))
|
||||
assert.truthy(result.imports[2]:match("from pathlib"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("Lua", function()
|
||||
it("should detect require statements", function()
|
||||
local code = [[local M = {}
|
||||
local utils = require("codetyper.utils")
|
||||
local config = require('codetyper.config')
|
||||
|
||||
function M.setup()
|
||||
end
|
||||
|
||||
return M]]
|
||||
local result = inject.parse_code(code, "lua")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("utils"))
|
||||
assert.truthy(result.imports[2]:match("config"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("Go", function()
|
||||
it("should detect single imports", function()
|
||||
local code = [[package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
func main() {
|
||||
fmt.Println("Hello")
|
||||
}]]
|
||||
local result = inject.parse_code(code, "go")
|
||||
|
||||
assert.equals(1, #result.imports)
|
||||
assert.truthy(result.imports[1]:match('import "fmt"'))
|
||||
end)
|
||||
|
||||
it("should detect grouped imports", function()
|
||||
local code = [[package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func main() {}]]
|
||||
local result = inject.parse_code(code, "go")
|
||||
|
||||
assert.equals(1, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("fmt"))
|
||||
assert.truthy(result.imports[1]:match("os"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("Rust", function()
|
||||
it("should detect use statements", function()
|
||||
local code = [[use std::io;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn main() {
|
||||
let map = HashMap::new();
|
||||
}]]
|
||||
local result = inject.parse_code(code, "rs")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("std::io"))
|
||||
assert.truthy(result.imports[2]:match("HashMap"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("C/C++", function()
|
||||
it("should detect include statements", function()
|
||||
local code = [[#include <stdio.h>
|
||||
#include "myheader.h"
|
||||
|
||||
int main() {
|
||||
return 0;
|
||||
}]]
|
||||
local result = inject.parse_code(code, "c")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("stdio"))
|
||||
assert.truthy(result.imports[2]:match("myheader"))
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("merge_imports", function()
|
||||
it("should merge without duplicates", function()
|
||||
local existing = {
|
||||
"import { useState } from 'react';",
|
||||
"import { Button } from './components';",
|
||||
}
|
||||
local new_imports = {
|
||||
"import { useEffect } from 'react';",
|
||||
"import { useState } from 'react';", -- duplicate
|
||||
"import { Card } from './components';",
|
||||
}
|
||||
|
||||
local merged = inject.merge_imports(existing, new_imports)
|
||||
|
||||
assert.equals(4, #merged) -- Should not have duplicate useState
|
||||
end)
|
||||
|
||||
it("should handle empty existing imports", function()
|
||||
local existing = {}
|
||||
local new_imports = {
|
||||
"import os",
|
||||
"import sys",
|
||||
}
|
||||
|
||||
local merged = inject.merge_imports(existing, new_imports)
|
||||
|
||||
assert.equals(2, #merged)
|
||||
end)
|
||||
|
||||
it("should handle empty new imports", function()
|
||||
local existing = {
|
||||
"import os",
|
||||
"import sys",
|
||||
}
|
||||
local new_imports = {}
|
||||
|
||||
local merged = inject.merge_imports(existing, new_imports)
|
||||
|
||||
assert.equals(2, #merged)
|
||||
end)
|
||||
|
||||
it("should handle whitespace variations in duplicates", function()
|
||||
local existing = {
|
||||
"import { useState } from 'react';",
|
||||
}
|
||||
local new_imports = {
|
||||
"import {useState} from 'react';", -- Same but different spacing
|
||||
}
|
||||
|
||||
local merged = inject.merge_imports(existing, new_imports)
|
||||
|
||||
assert.equals(1, #merged) -- Should detect as duplicate
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("sort_imports", function()
|
||||
it("should group imports by type for JavaScript", function()
|
||||
local imports = {
|
||||
"import React from 'react';",
|
||||
"import { Button } from './components';",
|
||||
"import axios from 'axios';",
|
||||
"import path from 'path';",
|
||||
}
|
||||
|
||||
local sorted = inject.sort_imports(imports, "javascript")
|
||||
|
||||
-- Check ordering: builtin -> third-party -> local
|
||||
local found_builtin = false
|
||||
local found_local = false
|
||||
local builtin_pos = 0
|
||||
local local_pos = 0
|
||||
|
||||
for i, imp in ipairs(sorted) do
|
||||
if imp:match("path") then
|
||||
found_builtin = true
|
||||
builtin_pos = i
|
||||
end
|
||||
if imp:match("%.%/") then
|
||||
found_local = true
|
||||
local_pos = i
|
||||
end
|
||||
end
|
||||
|
||||
-- Local imports should come after third-party
|
||||
if found_local and found_builtin then
|
||||
assert.truthy(local_pos > builtin_pos)
|
||||
end
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("has_imports", function()
|
||||
it("should return true when code has imports", function()
|
||||
local code = [[import { useState } from 'react';
|
||||
function App() {}]]
|
||||
|
||||
assert.is_true(inject.has_imports(code, "typescript"))
|
||||
end)
|
||||
|
||||
it("should return false when code has no imports", function()
|
||||
local code = [[function App() {
|
||||
return <div>Hello</div>;
|
||||
}]]
|
||||
|
||||
assert.is_false(inject.has_imports(code, "typescript"))
|
||||
end)
|
||||
|
||||
it("should detect Python imports", function()
|
||||
local code = [[from typing import List
|
||||
|
||||
def process(items: List[str]):
|
||||
pass]]
|
||||
|
||||
assert.is_true(inject.has_imports(code, "python"))
|
||||
end)
|
||||
|
||||
it("should detect Lua requires", function()
|
||||
local code = [[local utils = require("utils")
|
||||
|
||||
local M = {}
|
||||
return M]]
|
||||
|
||||
assert.is_true(inject.has_imports(code, "lua"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("edge cases", function()
|
||||
it("should handle empty code", function()
|
||||
local result = inject.parse_code("", "javascript")
|
||||
|
||||
assert.equals(0, #result.imports)
|
||||
assert.equals(1, #result.body) -- Empty string becomes one empty line
|
||||
end)
|
||||
|
||||
it("should handle code with only imports", function()
|
||||
local code = [[import React from 'react';
|
||||
import { useState } from 'react';]]
|
||||
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.equals(0, #result.body)
|
||||
end)
|
||||
|
||||
it("should handle code with only body", function()
|
||||
local code = [[function hello() {
|
||||
console.log("Hello");
|
||||
}]]
|
||||
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
assert.equals(0, #result.imports)
|
||||
assert.truthy(#result.body > 0)
|
||||
end)
|
||||
|
||||
it("should handle imports in string literals (not detect as imports)", function()
|
||||
local code = [[const example = "import { fake } from 'not-real';";
|
||||
const config = { import: true };
|
||||
|
||||
function test() {}]]
|
||||
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
-- The first line looks like an import but is in a string
|
||||
-- This is a known limitation - we accept some false positives
|
||||
-- The important thing is we don't break the code
|
||||
assert.truthy(#result.body >= 0)
|
||||
end)
|
||||
|
||||
it("should handle mixed import styles in same file", function()
|
||||
local code = [[import React from 'react';
|
||||
const axios = require('axios');
|
||||
import { useState } from 'react';
|
||||
|
||||
function App() {}]]
|
||||
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
assert.equals(3, #result.imports)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
174
tests/spec/llm_selector_spec.lua
Normal file
174
tests/spec/llm_selector_spec.lua
Normal file
@@ -0,0 +1,174 @@
|
||||
--- Tests for smart LLM selection with memory-based confidence
|
||||
|
||||
describe("codetyper.llm.selector", function()
|
||||
local selector
|
||||
|
||||
before_each(function()
|
||||
selector = require("codetyper.llm.selector")
|
||||
-- Reset stats for clean tests
|
||||
selector.reset_accuracy_stats()
|
||||
end)
|
||||
|
||||
describe("select_provider", function()
|
||||
it("should return copilot when no brain memories exist", function()
|
||||
local result = selector.select_provider("write a function", {
|
||||
file_path = "/test/file.lua",
|
||||
})
|
||||
|
||||
assert.equals("copilot", result.provider)
|
||||
assert.equals(0, result.memory_count)
|
||||
assert.truthy(result.reason:match("Insufficient context"))
|
||||
end)
|
||||
|
||||
it("should return a valid selection result structure", function()
|
||||
local result = selector.select_provider("test prompt", {})
|
||||
|
||||
assert.is_string(result.provider)
|
||||
assert.is_number(result.confidence)
|
||||
assert.is_number(result.memory_count)
|
||||
assert.is_string(result.reason)
|
||||
end)
|
||||
|
||||
it("should have confidence between 0 and 1", function()
|
||||
local result = selector.select_provider("test", {})
|
||||
|
||||
assert.truthy(result.confidence >= 0)
|
||||
assert.truthy(result.confidence <= 1)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("should_ponder", function()
|
||||
it("should return true for medium confidence", function()
|
||||
assert.is_true(selector.should_ponder(0.5))
|
||||
assert.is_true(selector.should_ponder(0.6))
|
||||
end)
|
||||
|
||||
it("should return false for low confidence", function()
|
||||
assert.is_false(selector.should_ponder(0.2))
|
||||
assert.is_false(selector.should_ponder(0.3))
|
||||
end)
|
||||
|
||||
-- High confidence pondering is probabilistic, so we test the range
|
||||
it("should sometimes ponder for high confidence (sampling)", function()
|
||||
-- Run multiple times to test probabilistic behavior
|
||||
local pondered_count = 0
|
||||
for _ = 1, 100 do
|
||||
if selector.should_ponder(0.9) then
|
||||
pondered_count = pondered_count + 1
|
||||
end
|
||||
end
|
||||
-- Should ponder roughly 20% of the time (PONDER_SAMPLE_RATE = 0.2)
|
||||
-- Allow range of 5-40% due to randomness
|
||||
assert.truthy(pondered_count >= 5, "Should ponder at least sometimes")
|
||||
assert.truthy(pondered_count <= 40, "Should not ponder too often")
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("get_accuracy_stats", function()
|
||||
it("should return initial empty stats", function()
|
||||
local stats = selector.get_accuracy_stats()
|
||||
|
||||
assert.equals(0, stats.ollama.total)
|
||||
assert.equals(0, stats.ollama.correct)
|
||||
assert.equals(0, stats.ollama.accuracy)
|
||||
assert.equals(0, stats.copilot.total)
|
||||
assert.equals(0, stats.copilot.correct)
|
||||
assert.equals(0, stats.copilot.accuracy)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("report_feedback", function()
|
||||
it("should track positive feedback", function()
|
||||
selector.report_feedback("ollama", true)
|
||||
selector.report_feedback("ollama", true)
|
||||
selector.report_feedback("ollama", false)
|
||||
|
||||
local stats = selector.get_accuracy_stats()
|
||||
assert.equals(3, stats.ollama.total)
|
||||
assert.equals(2, stats.ollama.correct)
|
||||
end)
|
||||
|
||||
it("should track copilot feedback separately", function()
|
||||
selector.report_feedback("ollama", true)
|
||||
selector.report_feedback("copilot", true)
|
||||
selector.report_feedback("copilot", false)
|
||||
|
||||
local stats = selector.get_accuracy_stats()
|
||||
assert.equals(1, stats.ollama.total)
|
||||
assert.equals(2, stats.copilot.total)
|
||||
end)
|
||||
|
||||
it("should calculate accuracy correctly", function()
|
||||
selector.report_feedback("ollama", true)
|
||||
selector.report_feedback("ollama", true)
|
||||
selector.report_feedback("ollama", true)
|
||||
selector.report_feedback("ollama", false)
|
||||
|
||||
local stats = selector.get_accuracy_stats()
|
||||
assert.equals(0.75, stats.ollama.accuracy)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("reset_accuracy_stats", function()
|
||||
it("should clear all stats", function()
|
||||
selector.report_feedback("ollama", true)
|
||||
selector.report_feedback("copilot", true)
|
||||
|
||||
selector.reset_accuracy_stats()
|
||||
|
||||
local stats = selector.get_accuracy_stats()
|
||||
assert.equals(0, stats.ollama.total)
|
||||
assert.equals(0, stats.copilot.total)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("agreement calculation", function()
|
||||
-- Test the internal agreement calculation through pondering behavior
|
||||
-- Since calculate_agreement is local, we test its effects indirectly
|
||||
|
||||
it("should detect high agreement for similar responses", function()
|
||||
-- This is tested through the pondering system
|
||||
-- When responses are similar, agreement should be high
|
||||
local selector = require("codetyper.llm.selector")
|
||||
|
||||
-- Verify that should_ponder returns predictable results
|
||||
-- for medium confidence (where pondering always happens)
|
||||
assert.is_true(selector.should_ponder(0.5))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("provider selection with accuracy history", function()
|
||||
local selector
|
||||
|
||||
before_each(function()
|
||||
selector = require("codetyper.llm.selector")
|
||||
selector.reset_accuracy_stats()
|
||||
end)
|
||||
|
||||
it("should factor in historical accuracy for selection", function()
|
||||
-- Simulate high Ollama accuracy
|
||||
for _ = 1, 10 do
|
||||
selector.report_feedback("ollama", true)
|
||||
end
|
||||
|
||||
-- Even with no brain context, historical accuracy should influence confidence
|
||||
local result = selector.select_provider("test", {})
|
||||
|
||||
-- Confidence should be higher due to historical accuracy
|
||||
-- but provider might still be copilot if no memories
|
||||
assert.is_number(result.confidence)
|
||||
end)
|
||||
|
||||
it("should have lower confidence for low historical accuracy", function()
|
||||
-- Simulate low Ollama accuracy
|
||||
for _ = 1, 10 do
|
||||
selector.report_feedback("ollama", false)
|
||||
end
|
||||
|
||||
local result = selector.select_provider("test", {})
|
||||
|
||||
-- With bad history and no memories, should definitely use copilot
|
||||
assert.equals("copilot", result.provider)
|
||||
end)
|
||||
end)
|
||||
Reference in New Issue
Block a user