Compare commits
2 Commits
| Author | SHA1 | Date |
|---|---|---|
| f5df1a9ac0 | | |
| 84c8bcf92c | | |
54
CHANGELOG.md
54
CHANGELOG.md
@@ -7,6 +7,57 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.5.0] - 2026-01-15
|
||||
|
||||
### Added
|
||||
|
||||
- **Cost Tracking System** - Track LLM API costs across sessions
|
||||
- New `:CoderCost` command opens cost estimation floating window
|
||||
- Session costs tracked in real-time
|
||||
- All-time costs persisted in `.coder/cost_history.json`
|
||||
- Per-model breakdown with token counts
|
||||
- Pricing database for 50+ models (GPT-4/5, Claude, O-series, Gemini)
|
||||
- Window keymaps: `q` close, `r` refresh, `c` clear session, `C` clear all
|
||||
|
||||
- **Automatic Ollama Fallback** - Graceful degradation when API limits hit
|
||||
- Automatically switches to Ollama when Copilot rate limits exceeded
|
||||
- Detects local Ollama availability before fallback
|
||||
- Notifies user of provider switch
|
||||
|
||||
- **Enhanced Error Handling** - Better error messages for API failures
|
||||
- Shows actual API response on parse errors (not generic "failed to parse")
|
||||
- Improved rate limit detection and messaging
|
||||
- Sanitized newlines in error notifications to prevent UI crashes
|
||||
|
||||
- **Agent Tools System Improvements**
|
||||
- New `to_openai_format()` and `to_claude_format()` functions
|
||||
- `get_definitions()` for generic tool access
|
||||
- Fixed tool call argument serialization (JSON strings vs tables)
|
||||
|
||||
- **Credentials Management System** - Store API keys outside of config files
|
||||
- New `:CoderAddApiKey` command for interactive credential setup
|
||||
- `:CoderRemoveApiKey` to remove stored credentials
|
||||
- `:CoderCredentials` to view credential status
|
||||
- `:CoderSwitchProvider` to switch active LLM provider
|
||||
- Credentials stored in `~/.local/share/nvim/codetyper/configuration.json`
|
||||
- Priority: stored credentials > config > environment variables
|
||||
- Supports all providers: Claude, OpenAI, Gemini, Copilot, Ollama
|
||||
|
||||
### Changed
|
||||
|
||||
- Cost window now shows both session and all-time statistics
|
||||
- Improved agent prompt templates with correct tool names
|
||||
- Better error context in LLM provider responses
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed "Failed to parse Copilot response" error showing instead of actual error
|
||||
- Fixed `nvim_buf_set_lines` crash from newlines in error messages
|
||||
- Fixed `tools.definitions` nil error in agent initialization
|
||||
- Fixed tool name mismatch in agent prompts (write_file vs write)
|
||||
|
||||
---
|
||||
|
||||
## [0.4.0] - 2026-01-13
|
||||
|
||||
### Added
|
||||
@@ -194,7 +245,8 @@ scheduler = {
|
||||
- **Fixed** - Bug fixes
|
||||
- **Security** - Vulnerability fixes
|
||||
|
||||
[Unreleased]: https://github.com/cargdev/codetyper.nvim/compare/v0.4.0...HEAD
|
||||
[Unreleased]: https://github.com/cargdev/codetyper.nvim/compare/v0.5.0...HEAD
|
||||
[0.5.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.4.0...v0.5.0
|
||||
[0.4.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.3.0...v0.4.0
|
||||
[0.3.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.2.0...v0.3.0
|
||||
[0.2.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.1.0...v0.2.0
|
||||
|
||||
214
README.md
214
README.md
@@ -20,8 +20,10 @@
|
||||
- 🛡️ **Completion-Aware**: Safe injection that doesn't fight with autocomplete
|
||||
- 📁 **Auto-Index**: Automatically create coder companion files on file open
|
||||
- 📜 **Logs Panel**: Real-time visibility into LLM requests and token usage
|
||||
- 💰 **Cost Tracking**: Persistent LLM cost estimation with session and all-time stats
|
||||
- 🔒 **Git Integration**: Automatically adds `.coder.*` files to `.gitignore`
|
||||
- 🌳 **Project Tree Logging**: Maintains a `tree.log` tracking your project structure
|
||||
- 🧠 **Brain System**: Knowledge graph that learns from your coding patterns
|
||||
|
||||
---
|
||||
|
||||
@@ -34,6 +36,8 @@
|
||||
- [LLM Providers](#-llm-providers)
|
||||
- [Commands Reference](#-commands-reference)
|
||||
- [Usage Guide](#-usage-guide)
|
||||
- [Logs Panel](#-logs-panel)
|
||||
- [Cost Tracking](#-cost-tracking)
|
||||
- [Agent Mode](#-agent-mode)
|
||||
- [Keymaps](#-keymaps)
|
||||
- [Health Check](#-health-check)
|
||||
@@ -196,6 +200,32 @@ require("codetyper").setup({
|
||||
| `OPENAI_API_KEY` | OpenAI API key |
|
||||
| `GEMINI_API_KEY` | Google Gemini API key |
|
||||
|
||||
### Credentials Management
|
||||
|
||||
Instead of storing API keys in your config (which may be committed to git), you can use the credentials system:
|
||||
|
||||
```vim
|
||||
:CoderAddApiKey
|
||||
```
|
||||
|
||||
This command interactively prompts for:
|
||||
1. Provider selection (Claude, OpenAI, Gemini, Copilot, Ollama)
|
||||
2. API key (for cloud providers)
|
||||
3. Model name
|
||||
4. Custom endpoint (for OpenAI-compatible APIs)
|
||||
|
||||
Credentials are stored securely in `~/.local/share/nvim/codetyper/configuration.json` (not in your config files).
|
||||
|
||||
**Priority order for credentials:**
|
||||
1. Stored credentials (via `:CoderAddApiKey`)
|
||||
2. Config file settings
|
||||
3. Environment variables
|
||||
|
||||
**Other credential commands:**
|
||||
- `:CoderCredentials` - View configured providers
|
||||
- `:CoderSwitchProvider` - Switch between configured providers
|
||||
- `:CoderRemoveApiKey` - Remove stored credentials
|
||||
|
||||
---
|
||||
|
||||
## 🔌 LLM Providers
|
||||
@@ -255,49 +285,129 @@ llm = {
|
||||
|
||||
## 📝 Commands Reference
|
||||
|
||||
### Main Commands
|
||||
All commands can be invoked via `:Coder {subcommand}` or their dedicated command aliases.
|
||||
|
||||
### Core Commands
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder open` | `:CoderOpen` | Open the coder split view |
|
||||
| `:Coder close` | `:CoderClose` | Close the coder split view |
|
||||
| `:Coder toggle` | `:CoderToggle` | Toggle the coder split view |
|
||||
| `:Coder process` | `:CoderProcess` | Process the last prompt in coder file |
|
||||
| `:Coder status` | - | Show plugin status and configuration |
|
||||
| `:Coder focus` | - | Switch focus between coder and target windows |
|
||||
| `:Coder reset` | - | Reset processed prompts to allow re-processing |
|
||||
| `:Coder gitignore` | - | Force update .gitignore with coder patterns |
|
||||
|
||||
### Ask Panel (Chat Interface)
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder ask` | `:CoderAsk` | Open the Ask panel |
|
||||
| `:Coder ask-toggle` | `:CoderAskToggle` | Toggle the Ask panel |
|
||||
| `:Coder ask-close` | - | Close the Ask panel |
|
||||
| `:Coder ask-clear` | `:CoderAskClear` | Clear chat history |
|
||||
|
||||
### Agent Mode (Autonomous Coding)
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder agent` | `:CoderAgent` | Open the Agent panel |
|
||||
| `:Coder agent-toggle` | `:CoderAgentToggle` | Toggle the Agent panel |
|
||||
| `:Coder agent-close` | - | Close the Agent panel |
|
||||
| `:Coder agent-stop` | `:CoderAgentStop` | Stop the running agent |
|
||||
|
||||
### Agentic Mode (IDE-like Multi-file Agent)
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder agentic-run <task>` | `:CoderAgenticRun <task>` | Run an agentic task (multi-file changes) |
|
||||
| `:Coder agentic-list` | `:CoderAgenticList` | List available agents |
|
||||
| `:Coder agentic-init` | `:CoderAgenticInit` | Initialize `.coder/agents/` and `.coder/rules/` |
|
||||
|
||||
### Transform Commands (Inline Tag Processing)
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder transform` | `:CoderTransform` | Transform all `/@ @/` tags in file |
|
||||
| `:Coder transform-cursor` | `:CoderTransformCursor` | Transform tag at cursor position |
|
||||
| - | `:CoderTransformVisual` | Transform selected tags (visual mode) |
|
||||
|
||||
### Project & Index Commands
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| - | `:CoderIndex` | Open coder companion for current file |
|
||||
| `:Coder index-project` | `:CoderIndexProject` | Index the entire project |
|
||||
| `:Coder index-status` | `:CoderIndexStatus` | Show project index status |
|
||||
|
||||
### Tree & Structure Commands
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder tree` | `:CoderTree` | Refresh `.coder/tree.log` |
|
||||
| `:Coder tree-view` | `:CoderTreeView` | View `.coder/tree.log` in split |
|
||||
|
||||
### Queue & Scheduler Commands
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder queue-status` | `:CoderQueueStatus` | Show scheduler and queue status |
|
||||
| `:Coder queue-process` | `:CoderQueueProcess` | Manually trigger queue processing |
|
||||
|
||||
### Processing Mode Commands
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder auto-toggle` | `:CoderAutoToggle` | Toggle automatic/manual prompt processing |
|
||||
| `:Coder auto-set <mode>` | `:CoderAutoSet <mode>` | Set processing mode (`auto`/`manual`) |
|
||||
|
||||
### Memory & Learning Commands
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder memories` | `:CoderMemories` | Show learned memories |
|
||||
| `:Coder forget [pattern]` | `:CoderForget [pattern]` | Clear memories (optionally matching pattern) |
|
||||
|
||||
### Brain Commands (Knowledge Graph)
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| - | `:CoderBrain [action]` | Brain management (`stats`/`commit`/`flush`/`prune`) |
|
||||
| - | `:CoderFeedback <type>` | Give feedback to brain (`good`/`bad`/`stats`) |
|
||||
|
||||
### LLM Statistics & Feedback
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:Coder {subcommand}` | Main command with subcommands |
|
||||
| `:CoderOpen` | Open the coder split view |
|
||||
| `:CoderClose` | Close the coder split view |
|
||||
| `:CoderToggle` | Toggle the coder split view |
|
||||
| `:CoderProcess` | Process the last prompt |
|
||||
| `:Coder llm-stats` | Show LLM provider accuracy statistics |
|
||||
| `:Coder llm-feedback-good` | Report positive feedback on last response |
|
||||
| `:Coder llm-feedback-bad` | Report negative feedback on last response |
|
||||
| `:Coder llm-reset-stats` | Reset LLM accuracy statistics |
|
||||
|
||||
### Ask Panel
|
||||
### Cost Tracking
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderAsk` | Open the Ask panel |
|
||||
| `:CoderAskToggle` | Toggle the Ask panel |
|
||||
| `:CoderAskClear` | Clear chat history |
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder cost` | `:CoderCost` | Show LLM cost estimation window |
|
||||
| `:Coder cost-clear` | - | Clear session cost tracking |
|
||||
|
||||
### Agent Mode
|
||||
### Credentials Management
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderAgent` | Open the Agent panel |
|
||||
| `:CoderAgentToggle` | Toggle the Agent panel |
|
||||
| `:CoderAgentStop` | Stop the running agent |
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder add-api-key` | `:CoderAddApiKey` | Add or update LLM provider API key |
|
||||
| `:Coder remove-api-key` | `:CoderRemoveApiKey` | Remove LLM provider credentials |
|
||||
| `:Coder credentials` | `:CoderCredentials` | Show credentials status |
|
||||
| `:Coder switch-provider` | `:CoderSwitchProvider` | Switch active LLM provider |
|
||||
|
||||
### Transform Commands
|
||||
### UI Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderTransform` | Transform all /@ @/ tags in file |
|
||||
| `:CoderTransformCursor` | Transform tag at cursor position |
|
||||
| `:CoderTransformVisual` | Transform selected tags (visual mode) |
|
||||
|
||||
### Utility Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderIndex` | Open coder companion for current file |
|
||||
| `:CoderLogs` | Toggle logs panel |
|
||||
| `:CoderType` | Switch between Ask/Agent modes |
|
||||
| `:CoderTree` | Refresh tree.log |
|
||||
| `:CoderTreeView` | View tree.log |
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder type-toggle` | `:CoderType` | Show Ask/Agent mode switcher |
|
||||
| `:Coder logs-toggle` | `:CoderLogs` | Toggle logs panel |
|
||||
|
||||
---
|
||||
|
||||
@@ -384,6 +494,42 @@ The logs panel opens automatically when processing prompts with the scheduler en
|
||||
|
||||
---
|
||||
|
||||
## 💰 Cost Tracking
|
||||
|
||||
Track your LLM API costs across sessions with the Cost Estimation window.
|
||||
|
||||
### Features
|
||||
|
||||
- **Session Tracking**: Monitor current session token usage and costs
|
||||
- **All-Time Tracking**: Persistent cost history stored per-project in `.coder/cost_history.json`
|
||||
- **Model Breakdown**: See costs by individual model
|
||||
- **Pricing Database**: Built-in pricing for 50+ models (GPT, Claude, Gemini, O-series, etc.)
|
||||
|
||||
### Opening the Cost Window
|
||||
|
||||
```vim
|
||||
:CoderCost
|
||||
```
|
||||
|
||||
### Cost Window Keymaps
|
||||
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `q` / `<Esc>` | Close window |
|
||||
| `r` | Refresh display |
|
||||
| `c` | Clear session costs |
|
||||
| `C` | Clear all history |
|
||||
|
||||
### Supported Models
|
||||
|
||||
The cost tracker includes pricing for:
|
||||
- **OpenAI**: GPT-4, GPT-4o, GPT-4o-mini, O1, O3, O4-mini, and more
|
||||
- **Anthropic**: Claude 3 Opus, Sonnet, Haiku, Claude 3.5 Sonnet/Haiku
|
||||
- **Local**: Ollama models (free, but usage tracked)
|
||||
- **Copilot**: Usage tracked (included in subscription)
|
||||
|
||||
---
|
||||
|
||||
## 🤖 Agent Mode
|
||||
|
||||
The Agent mode provides an autonomous coding assistant with tool access:
|
||||
|
||||
188
llms.txt
188
llms.txt
@@ -33,6 +33,8 @@ lua/codetyper/
|
||||
├── health.lua # Health check for :checkhealth
|
||||
├── tree.lua # Project tree logging (.coder/tree.log)
|
||||
├── logs_panel.lua # Standalone logs panel UI
|
||||
├── cost.lua # LLM cost tracking with persistent history
|
||||
├── credentials.lua # Secure credential storage (API keys, models)
|
||||
├── llm/
|
||||
│ ├── init.lua # LLM interface, provider selection
|
||||
│ ├── claude.lua # Claude API client (Anthropic)
|
||||
@@ -68,7 +70,14 @@ The plugin automatically creates and maintains a `.coder/` folder in your projec
|
||||
|
||||
```
|
||||
.coder/
|
||||
└── tree.log # Project structure, auto-updated on file changes
|
||||
├── tree.log # Project structure, auto-updated on file changes
|
||||
├── cost_history.json # LLM cost tracking history (persistent)
|
||||
├── brain/ # Knowledge graph storage
|
||||
│ ├── nodes/ # Learning nodes by type
|
||||
│ ├── indices/ # Search indices
|
||||
│ └── deltas/ # Version history
|
||||
├── agents/ # Custom agent definitions
|
||||
└── rules/ # Project-specific rules
|
||||
```
|
||||
|
||||
## Key Features
|
||||
@@ -115,7 +124,49 @@ auto_index = true -- disabled by default
|
||||
|
||||
Real-time visibility into LLM operations with token usage tracking.
|
||||
|
||||
### 6. Event-Driven Scheduler
|
||||
### 6. Cost Tracking
|
||||
|
||||
Track LLM API costs across sessions:
|
||||
|
||||
- **Session tracking**: Monitor current session costs in real-time
|
||||
- **All-time tracking**: Persistent history in `.coder/cost_history.json`
|
||||
- **Per-model breakdown**: See costs by individual model
|
||||
- **50+ models**: Built-in pricing for GPT, Claude, O-series, Gemini
|
||||
|
||||
Cost window keymaps:
|
||||
- `q`/`<Esc>` - Close window
|
||||
- `r` - Refresh display
|
||||
- `c` - Clear session costs
|
||||
- `C` - Clear all history
|
||||
|
||||
### 7. Automatic Ollama Fallback
|
||||
|
||||
When API rate limits are hit (e.g., Copilot free tier), the plugin:
|
||||
1. Detects the rate limit error
|
||||
2. Checks if local Ollama is available
|
||||
3. Automatically switches provider to Ollama
|
||||
4. Notifies user of the provider change
|
||||
|
||||
### 8. Credentials Management
|
||||
|
||||
Store API keys securely outside of config files:
|
||||
|
||||
```vim
|
||||
:CoderAddApiKey
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Interactive prompts for provider, API key, model, endpoint
|
||||
- Stored in `~/.local/share/nvim/codetyper/configuration.json`
|
||||
- Supports all providers: Claude, OpenAI, Gemini, Copilot, Ollama
|
||||
- Switch providers at runtime with `:CoderSwitchProvider`
|
||||
|
||||
**Credential priority:**
|
||||
1. Stored credentials (via `:CoderAddApiKey`)
|
||||
2. Config file settings (`require("codetyper").setup({...})`)
|
||||
3. Environment variables (`OPENAI_API_KEY`, etc.)
|
||||
|
||||
### 9. Event-Driven Scheduler
|
||||
|
||||
Prompts are treated as events, not commands:
|
||||
|
||||
@@ -143,7 +194,7 @@ scheduler = {
|
||||
}
|
||||
```
|
||||
|
||||
### 7. Tree-sitter Scope Resolution
|
||||
### 10. Tree-sitter Scope Resolution
|
||||
|
||||
Prompts automatically resolve to their enclosing function/method/class:
|
||||
|
||||
@@ -158,7 +209,7 @@ end
|
||||
For replacement intents (complete, refactor, fix), the entire scope is extracted
|
||||
and sent to the LLM, then replaced with the transformed version.
|
||||
|
||||
### 8. Intent Detection
|
||||
### 11. Intent Detection
|
||||
|
||||
The system parses prompts to detect user intent:
|
||||
|
||||
@@ -173,7 +224,7 @@ The system parses prompts to detect user intent:
|
||||
| optimize | optimize, performance, faster | replace |
|
||||
| explain | explain, what, how, why | none |
|
||||
|
||||
### 9. Tag Precedence
|
||||
### 12. Tag Precedence
|
||||
|
||||
Multiple tags in the same scope follow "first tag wins" rule:
|
||||
- Earlier (by line number) unresolved tag processes first
|
||||
@@ -182,33 +233,114 @@ Multiple tags in the same scope follow "first tag wins" rule:
|
||||
|
||||
## Commands
|
||||
|
||||
### Main Commands
|
||||
- `:Coder open` - Opens split view with coder file
|
||||
- `:Coder close` - Closes the split
|
||||
- `:Coder toggle` - Toggles the view
|
||||
- `:Coder process` - Manually triggers code generation
|
||||
All commands can be invoked via `:Coder {subcommand}` or dedicated aliases.
|
||||
|
||||
### Ask Panel
|
||||
- `:CoderAsk` - Open Ask panel
|
||||
- `:CoderAskToggle` - Toggle Ask panel
|
||||
- `:CoderAskClear` - Clear chat history
|
||||
### Core Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder open` | `:CoderOpen` | Open coder split view |
|
||||
| `:Coder close` | `:CoderClose` | Close coder split view |
|
||||
| `:Coder toggle` | `:CoderToggle` | Toggle coder split view |
|
||||
| `:Coder process` | `:CoderProcess` | Process last prompt in coder file |
|
||||
| `:Coder status` | - | Show plugin status and configuration |
|
||||
| `:Coder focus` | - | Switch focus between coder/target windows |
|
||||
| `:Coder reset` | - | Reset processed prompts |
|
||||
| `:Coder gitignore` | - | Force update .gitignore |
|
||||
|
||||
### Agent Mode
|
||||
- `:CoderAgent` - Open Agent panel
|
||||
- `:CoderAgentToggle` - Toggle Agent panel
|
||||
- `:CoderAgentStop` - Stop running agent
|
||||
### Ask Panel (Chat Interface)
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder ask` | `:CoderAsk` | Open Ask panel |
|
||||
| `:Coder ask-toggle` | `:CoderAskToggle` | Toggle Ask panel |
|
||||
| `:Coder ask-close` | - | Close Ask panel |
|
||||
| `:Coder ask-clear` | `:CoderAskClear` | Clear chat history |
|
||||
|
||||
### Transform
|
||||
- `:CoderTransform` - Transform all tags
|
||||
- `:CoderTransformCursor` - Transform at cursor
|
||||
- `:CoderTransformVisual` - Transform selection
|
||||
### Agent Mode (Autonomous Coding)
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder agent` | `:CoderAgent` | Open Agent panel |
|
||||
| `:Coder agent-toggle` | `:CoderAgentToggle` | Toggle Agent panel |
|
||||
| `:Coder agent-close` | - | Close Agent panel |
|
||||
| `:Coder agent-stop` | `:CoderAgentStop` | Stop running agent |
|
||||
|
||||
### Utility
|
||||
- `:CoderIndex` - Open coder companion
|
||||
- `:CoderLogs` - Toggle logs panel
|
||||
- `:CoderType` - Switch Ask/Agent mode
|
||||
- `:CoderTree` - Refresh tree.log
|
||||
- `:CoderTreeView` - View tree.log
|
||||
### Agentic Mode (IDE-like Multi-file Agent)
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder agentic-run <task>` | `:CoderAgenticRun <task>` | Run agentic task |
|
||||
| `:Coder agentic-list` | `:CoderAgenticList` | List available agents |
|
||||
| `:Coder agentic-init` | `:CoderAgenticInit` | Initialize .coder/agents/ and .coder/rules/ |
|
||||
|
||||
### Transform Commands (Inline Tag Processing)
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder transform` | `:CoderTransform` | Transform all /@ @/ tags in file |
|
||||
| `:Coder transform-cursor` | `:CoderTransformCursor` | Transform tag at cursor |
|
||||
| - | `:CoderTransformVisual` | Transform selected tags (visual mode) |
|
||||
|
||||
### Project & Index Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| - | `:CoderIndex` | Open coder companion for current file |
|
||||
| `:Coder index-project` | `:CoderIndexProject` | Index entire project |
|
||||
| `:Coder index-status` | `:CoderIndexStatus` | Show project index status |
|
||||
|
||||
### Tree & Structure Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder tree` | `:CoderTree` | Refresh .coder/tree.log |
|
||||
| `:Coder tree-view` | `:CoderTreeView` | View .coder/tree.log |
|
||||
|
||||
### Queue & Scheduler Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder queue-status` | `:CoderQueueStatus` | Show scheduler/queue status |
|
||||
| `:Coder queue-process` | `:CoderQueueProcess` | Manually trigger queue processing |
|
||||
|
||||
### Processing Mode Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder auto-toggle` | `:CoderAutoToggle` | Toggle automatic/manual processing |
|
||||
| `:Coder auto-set <mode>` | `:CoderAutoSet <mode>` | Set mode (auto/manual) |
|
||||
|
||||
### Memory & Learning Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder memories` | `:CoderMemories` | Show learned memories |
|
||||
| `:Coder forget [pattern]` | `:CoderForget [pattern]` | Clear memories |
|
||||
|
||||
### Brain Commands (Knowledge Graph)
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| - | `:CoderBrain [action]` | Brain management (stats/commit/flush/prune) |
|
||||
| - | `:CoderFeedback <type>` | Give feedback (good/bad/stats) |
|
||||
|
||||
### LLM Statistics & Feedback
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:Coder llm-stats` | Show LLM provider accuracy stats |
|
||||
| `:Coder llm-feedback-good` | Report positive feedback |
|
||||
| `:Coder llm-feedback-bad` | Report negative feedback |
|
||||
| `:Coder llm-reset-stats` | Reset LLM accuracy stats |
|
||||
|
||||
### Cost Tracking
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder cost` | `:CoderCost` | Show LLM cost estimation window |
|
||||
| `:Coder cost-clear` | - | Clear session cost tracking |
|
||||
|
||||
### Credentials Management
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder add-api-key` | `:CoderAddApiKey` | Add/update LLM provider credentials |
|
||||
| `:Coder remove-api-key` | `:CoderRemoveApiKey` | Remove provider credentials |
|
||||
| `:Coder credentials` | `:CoderCredentials` | Show credentials status |
|
||||
| `:Coder switch-provider` | `:CoderSwitchProvider` | Switch active provider |
|
||||
|
||||
### UI Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder type-toggle` | `:CoderType` | Show Ask/Agent mode switcher |
|
||||
| `:Coder logs-toggle` | `:CoderLogs` | Toggle logs panel |
|
||||
|
||||
## Configuration Schema
|
||||
|
||||
|
||||
854
lua/codetyper/agent/agentic.lua
Normal file
854
lua/codetyper/agent/agentic.lua
Normal file
@@ -0,0 +1,854 @@
|
||||
---@mod codetyper.agent.agentic Agentic loop with proper tool calling
|
||||
---@brief [[
|
||||
--- Full agentic system that handles multi-file changes via tool calling.
|
||||
--- Inspired by avante.nvim and opencode patterns.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class AgenticMessage
|
||||
---@field role "system"|"user"|"assistant"|"tool"
|
||||
---@field content string|table
|
||||
---@field tool_calls? table[] For assistant messages with tool calls
|
||||
---@field tool_call_id? string For tool result messages
|
||||
---@field name? string Tool name for tool results
|
||||
|
||||
---@class AgenticToolCall
|
||||
---@field id string Unique tool call ID
|
||||
---@field type "function"
|
||||
---@field function {name: string, arguments: string|table}
|
||||
|
||||
---@class AgenticOpts
|
||||
---@field task string The task to accomplish
|
||||
---@field files? string[] Initial files to include as context
|
||||
---@field agent? string Agent name to use (default: "coder")
|
||||
---@field model? string Model override
|
||||
---@field max_iterations? number Max tool call rounds (default: 20)
|
||||
---@field on_message? fun(msg: AgenticMessage) Called for each message
|
||||
---@field on_tool_start? fun(name: string, args: table) Called before tool execution
|
||||
---@field on_tool_end? fun(name: string, result: any, error: string|nil) Called after tool execution
|
||||
---@field on_file_change? fun(path: string, action: string) Called when file is modified
|
||||
---@field on_complete? fun(result: string|nil, error: string|nil) Called when done
|
||||
---@field on_status? fun(status: string) Status updates
|
||||
|
||||
--- Generate a unique-ish tool call ID.
--- Combines the current epoch time with a random 24-bit hex suffix; the wider
--- suffix (previously 16-bit) reduces the chance of two calls issued within
--- the same second colliding. NOTE(review): math.random is not seeded here,
--- so the random sequence is deterministic per process unless seeded elsewhere.
---@return string ID of the form "call_<hex-time>_<hex-random>"
local function generate_tool_call_id()
  return "call_" .. string.format("%x", os.time()) .. "_" .. string.format("%x", math.random(0, 0xFFFFFF))
end
|
||||
|
||||
--- Load an agent definition by name.
---
--- Custom agents live in `.coder/agents/<name>.md` as Markdown with optional
--- YAML-like `---` frontmatter (keys: description, tools, model). A readable
--- custom file takes precedence over the built-in agents of the same name.
---@param name string Agent name
---@return table|nil agent Definition (name, description, system_prompt, tools, model) or nil if unknown
local function load_agent(name)
  local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
  local agent_file = agents_dir .. "/" .. name .. ".md"

  -- Check if custom agent exists
  if vim.fn.filereadable(agent_file) == 1 then
    local content = table.concat(vim.fn.readfile(agent_file), "\n")
    local frontmatter = {}
    local body = content

    -- Capture frontmatter text and remaining body in a single match
    -- (the original re-ran the same pattern three times).
    local fm_text, rest = content:match("^%-%-%-\n(.-)%-%-%-%s*\n(.*)$")
    if fm_text then
      -- Parse YAML-like "key: value" lines; anything else is ignored.
      for line in fm_text:gmatch("[^\n]+") do
        local key, value = line:match("^(%w+):%s*(.+)$")
        if key and value then
          frontmatter[key] = value
        end
      end
      body = rest
    end

    -- Split the comma-separated tool list, trimming whitespace so
    -- "view, edit" yields { "view", "edit" } rather than { "view", " edit" }
    -- (untrimmed names would never match a registered tool).
    local tools = nil
    if frontmatter.tools then
      tools = {}
      for item in frontmatter.tools:gmatch("[^,]+") do
        local trimmed = item:match("^%s*(.-)%s*$")
        if trimmed ~= "" then
          tools[#tools + 1] = trimmed
        end
      end
    end

    return {
      name = name,
      description = frontmatter.description or "Custom agent: " .. name,
      system_prompt = body,
      tools = tools,
      model = frontmatter.model,
    }
  end

  -- Built-in agents (fallback when no custom definition exists).
  local builtin_agents = {
    coder = {
      name = "coder",
      description = "Full-featured coding agent with file modification capabilities",
      system_prompt = [[You are an expert software engineer. You have access to tools to read, write, and modify files.

## Your Capabilities
- Read files to understand the codebase
- Search for patterns with grep and glob
- Create new files with write tool
- Edit existing files with precise replacements
- Execute shell commands for builds and tests

## Guidelines
1. Always read relevant files before making changes
2. Make minimal, focused changes
3. Follow existing code style and patterns
4. Create tests when adding new functionality
5. Verify changes work by running tests or builds

## Important Rules
- NEVER guess file contents - always read first
- Make precise edits using exact string matching
- Explain your reasoning before making changes
- If unsure, ask for clarification]],
      tools = { "view", "edit", "write", "grep", "glob", "bash" },
    },
    planner = {
      name = "planner",
      description = "Planning agent - read-only, helps design implementations",
      system_prompt = [[You are a software architect. Analyze codebases and create implementation plans.

You can read files and search the codebase, but cannot modify files.
Your role is to:
1. Understand the existing architecture
2. Identify relevant files and patterns
3. Create step-by-step implementation plans
4. Suggest which files to modify and how

Be thorough in your analysis before making recommendations.]],
      tools = { "view", "grep", "glob" },
    },
    explorer = {
      name = "explorer",
      description = "Exploration agent - quickly find information in codebase",
      system_prompt = [[You are a codebase exploration assistant. Find information quickly and report back.

Your goal is to efficiently search and summarize findings.
Use glob to find files, grep to search content, and view to read specific files.
Be concise and focused in your responses.]],
      tools = { "view", "grep", "glob" },
    },
  }

  return builtin_agents[name]
end
|
||||
|
||||
--- Collect project-specific rules from `.coder/rules/*.md`.
--- Each file becomes a "## Rule: <basename>" section; sections are joined
--- under a single "# Project Rules" heading.
---@return string Combined rules content ("" when the directory or files are absent)
local function load_rules()
  local dir = vim.fn.getcwd() .. "/.coder/rules"
  if vim.fn.isdirectory(dir) ~= 1 then
    return ""
  end

  local sections = {}
  local paths = vim.fn.glob(dir .. "/*.md", false, true)
  for i = 1, #paths do
    local path = paths[i]
    local text = table.concat(vim.fn.readfile(path), "\n")
    -- ":t:r" = tail (basename) without extension
    local title = vim.fn.fnamemodify(path, ":t:r")
    sections[#sections + 1] = string.format("## Rule: %s\n%s", title, text)
  end

  if #sections == 0 then
    return ""
  end
  return "\n\n# Project Rules\n" .. table.concat(sections, "\n\n")
end
|
||||
|
||||
--- Build the messages array for an LLM API request from chat history.
---
--- Providers disagree on wire format:
---  * OpenAI-style: system/user/assistant/tool roles are all plain messages;
---    tool results use role = "tool" with a `tool_call_id` and string content.
---  * Claude: the system prompt travels as a separate top-level parameter
---    (so system messages are skipped here), assistant tool calls become
---    `tool_use` content blocks, and tool results are `tool_result` blocks
---    wrapped in a user message.
---@param history AgenticMessage[]
---@param provider string "openai"|"claude"
---@return table[] Formatted messages
local function build_messages(history, provider)
  local messages = {}

  for _, msg in ipairs(history) do
    if msg.role == "system" then
      if provider == "claude" then
        -- Claude uses system parameter, not message
        -- Skip system messages in array
      else
        table.insert(messages, {
          role = "system",
          content = msg.content,
        })
      end
    elseif msg.role == "user" then
      table.insert(messages, {
        role = "user",
        content = msg.content,
      })
    elseif msg.role == "assistant" then
      local message = {
        role = "assistant",
        content = msg.content,
      }
      if msg.tool_calls then
        -- Preserve the tool_calls field for OpenAI-style payloads; for
        -- Claude they are re-expressed as content blocks below.
        message.tool_calls = msg.tool_calls
        if provider == "claude" then
          -- Claude format: content is array of blocks
          message.content = {}
          -- Only emit a text block when there is actual text, since an
          -- empty text block is invalid Claude content.
          if msg.content and msg.content ~= "" then
            table.insert(message.content, {
              type = "text",
              text = msg.content,
            })
          end
          for _, tc in ipairs(msg.tool_calls) do
            -- Arguments may arrive as a JSON string (OpenAI convention)
            -- or an already-decoded table; Claude's `input` needs a table.
            table.insert(message.content, {
              type = "tool_use",
              id = tc.id,
              name = tc["function"].name,
              input = type(tc["function"].arguments) == "string"
                and vim.json.decode(tc["function"].arguments)
                or tc["function"].arguments,
            })
          end
        end
      end
      table.insert(messages, message)
    elseif msg.role == "tool" then
      if provider == "claude" then
        -- Claude: tool results are tool_result blocks inside a user
        -- message, correlated via tool_use_id.
        table.insert(messages, {
          role = "user",
          content = {
            {
              type = "tool_result",
              tool_use_id = msg.tool_call_id,
              content = msg.content,
            },
          },
        })
      else
        -- OpenAI: role="tool" message; table content is serialized to
        -- JSON because the API expects string content here.
        table.insert(messages, {
          role = "tool",
          tool_call_id = msg.tool_call_id,
          content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
        })
      end
    end
  end

  return messages
end
|
||||
|
||||
--- Build tools array for API request
--
-- Looks up each named tool in the tools registry and renders its
-- parameter list as a JSON-schema object in the provider's tool format.
-- Unknown tool names are silently skipped.
---@param tool_names string[] Tool names to include
---@param provider string "openai"|"claude"
---@return table[] Formatted tools
local function build_tools(tool_names, provider)
  local tools_mod = require("codetyper.agent.tools")
  local formatted = {}

  for _, tool_name in ipairs(tool_names) do
    local tool = tools_mod.get(tool_name)
    if tool ~= nil then
      -- Translate the tool's parameter list into JSON-schema properties
      -- plus the list of required (non-optional) parameter names.
      local props = {}
      local mandatory = {}
      for _, p in ipairs(tool.params or {}) do
        local schema_type = p.type
        if schema_type == "integer" then
          -- JSON schema for these APIs uses "number" for integers.
          schema_type = "number"
        end
        props[p.name] = {
          type = schema_type,
          description = p.description,
        }
        if not p.optional then
          mandatory[#mandatory + 1] = p.name
        end
      end

      -- Descriptions may be lazy (a function) or a plain string.
      local desc = tool.description
      if type(desc) == "function" then
        desc = desc()
      end

      local schema = {
        type = "object",
        properties = props,
        required = mandatory,
      }

      if provider == "claude" then
        formatted[#formatted + 1] = {
          name = tool.name,
          description = desc,
          input_schema = schema,
        }
      else
        formatted[#formatted + 1] = {
          type = "function",
          ["function"] = {
            name = tool.name,
            description = desc,
            parameters = schema,
          },
        }
      end
    end
  end

  return formatted
end
|
||||
|
||||
--- Execute a tool call
--
-- Decodes the call's arguments, fires the lifecycle callbacks in opts
-- (on_tool_start / on_status / on_tool_end / on_file_change), runs the
-- tool, and returns its result as a string.
---@param tool_call AgenticToolCall
---@param opts AgenticOpts
---@return string result  Empty string on error
---@return string|nil error
local function execute_tool(tool_call, opts)
  local tools_mod = require("codetyper.agent.tools")
  local fn = tool_call["function"]
  local name = fn.name
  local args = fn.arguments

  -- Arguments may arrive as a JSON string (OpenAI style); decode to a table.
  if type(args) == "string" then
    local decoded_ok, decoded = pcall(vim.json.decode, args)
    if not decoded_ok then
      return "", "Failed to parse tool arguments: " .. args
    end
    args = decoded
  end

  if opts.on_tool_start then
    opts.on_tool_start(name, args)
  end
  if opts.on_status then
    opts.on_status("Executing: " .. name)
  end

  local tool = tools_mod.get(name)
  if tool == nil then
    local err = "Unknown tool: " .. name
    if opts.on_tool_end then
      opts.on_tool_end(name, nil, err)
    end
    return "", err
  end

  -- Run the tool, forwarding its log lines to the status callback.
  local result, err = tool.func(args, {
    on_log = function(line)
      if opts.on_status then
        opts.on_status(line)
      end
    end,
  })

  if opts.on_tool_end then
    opts.on_tool_end(name, result, err)
  end

  -- Report created/modified files so the caller can track changes.
  if opts.on_file_change and not err and (name == "write" or name == "edit") then
    opts.on_file_change(args.path, name == "write" and "created" or "modified")
  end

  if err then
    return "", err
  end
  if type(result) == "string" then
    return result, nil
  end
  return vim.json.encode(result), nil
end
|
||||
|
||||
--- Parse tool calls from LLM response (unified Claude-like format)
--
-- Collects every tool_use block from the response content and converts
-- it into an OpenAI-shaped tool call record.
---@param response table Raw API response in unified format
---@param provider string Provider name (unused, kept for signature compatibility)
---@return AgenticToolCall[]
local function parse_tool_calls(response, provider)
  local calls = {}

  for _, block in ipairs(response.content or {}) do
    if block.type == "tool_use" then
      local arguments = block.input
      -- OpenAI-compatible consumers expect arguments as a JSON string,
      -- not a table.
      if type(arguments) == "table" then
        arguments = vim.json.encode(arguments)
      end

      calls[#calls + 1] = {
        id = block.id or generate_tool_call_id(),
        type = "function",
        ["function"] = {
          name = block.name,
          arguments = arguments,
        },
      }
    end
  end

  return calls
end
|
||||
|
||||
--- Extract text content from response (unified Claude-like format)
--
-- Joins the text of every "text" block in the response content with
-- newlines; non-text blocks (e.g. tool_use) are ignored.
---@param response table Raw API response in unified format
---@param provider string Provider name (unused, kept for signature compatibility)
---@return string
local function extract_content(response, provider)
  local texts = {}
  local blocks = response.content or {}
  for i = 1, #blocks do
    local block = blocks[i]
    if block.type == "text" then
      texts[#texts + 1] = block.text
    end
  end
  return table.concat(texts, "\n")
end
|
||||
|
||||
--- Check if response indicates completion (unified Claude-like format)
--
-- A turn is complete when the model stopped on its own ("end_turn")
-- rather than stopping to request a tool.
---@param response table Raw API response in unified format
---@param provider string Provider name (unused, kept for signature compatibility)
---@return boolean
local function is_complete(response, provider)
  local reason = response.stop_reason
  return reason == "end_turn"
end
|
||||
|
||||
--- Make API request to LLM with native tool calling support
--
-- Dispatches on provider: copilot/openai use the same native
-- tool-calling flow (previously two identical ~40-line branches, now
-- shared helpers), ollama returns the unified format directly, and any
-- other provider falls back to text-based tool-call parsing.
---@param messages table[] Formatted messages
---@param tools table[] Formatted tools
---@param system_prompt string System prompt
---@param provider string "openai"|"claude"|"copilot"
---@param model string Model name
---@param callback fun(response: table|nil, error: string|nil)
local function call_llm(messages, tools, system_prompt, provider, model, callback)
  local context = {
    language = "lua",
    file_content = "",
    prompt_type = "agent",
    project_root = vim.fn.getcwd(),
    cwd = vim.fn.getcwd(),
  }

  --- Drop system messages: native-tool providers take the system prompt
  --- separately, not as an array entry.
  local function strip_system(msgs)
    local out = {}
    for _, msg in ipairs(msgs) do
      if msg.role ~= "system" then
        table.insert(out, msg)
      end
    end
    return out
  end

  --- Convert a provider response (already Claude-like) into our internal
  --- unified format: { content = {...}, stop_reason = "end_turn"|"tool_use" }.
  local function normalize(response)
    local result = {
      content = {},
      stop_reason = "end_turn",
    }
    if response and response.content then
      for _, block in ipairs(response.content) do
        if block.type == "text" then
          table.insert(result.content, { type = "text", text = block.text })
        elseif block.type == "tool_use" then
          table.insert(result.content, {
            type = "tool_use",
            id = block.id or generate_tool_call_id(),
            name = block.name,
            input = block.input,
          })
          result.stop_reason = "tool_use"
        end
      end
    end
    return result
  end

  if provider == "copilot" or provider == "openai" then
    -- Copilot and OpenAI share the same native tool-calling client API.
    local client = require("codetyper.llm." .. provider)
    client.generate_with_tools(strip_system(messages), context, tools, function(response, err)
      if err then
        callback(nil, err)
        return
      end
      callback(normalize(response), nil)
    end)
  elseif provider == "ollama" then
    local client = require("codetyper.llm.ollama")
    client.generate_with_tools(strip_system(messages), context, tools, function(response, err)
      if err then
        callback(nil, err)
        return
      end
      -- Ollama already returns the unified Claude-like format.
      callback(response, nil)
    end)
  else
    -- Fallback for providers without native tool calling: describe the
    -- tools in the prompt and parse JSON tool calls out of the text.
    local client = require("codetyper.llm." .. provider)

    -- Flatten the conversation into a plain prompt.
    local prompt_parts = {}
    for _, msg in ipairs(messages) do
      if msg.role == "user" or msg.role == "assistant" then
        local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
        local prefix = msg.role == "user" and "User: " or "Assistant: "
        table.insert(prompt_parts, prefix .. content)
      end
    end

    -- Add tool descriptions to prompt for text-based providers.
    local tool_desc = "\n\n## Available Tools\n"
    tool_desc = tool_desc .. "Call tools by outputting JSON in this format:\n"
    tool_desc = tool_desc .. '```json\n{"tool": "tool_name", "arguments": {...}}\n```\n\n'
    for _, tool in ipairs(tools) do
      local name = tool.name or (tool["function"] and tool["function"].name)
      local desc = tool.description or (tool["function"] and tool["function"].description)
      if name then
        tool_desc = tool_desc .. string.format("- **%s**: %s\n", name, desc or "")
      end
    end

    context.file_content = system_prompt .. tool_desc

    client.generate(table.concat(prompt_parts, "\n\n"), context, function(response, err)
      if err then
        callback(nil, err)
        return
      end

      local result = {
        content = {},
        stop_reason = "end_turn",
      }
      local text_content = response

      -- Try to extract a fenced JSON tool invocation from the raw text.
      local json_match = response:match("```json%s*(%b{})%s*```")
      if json_match then
        local ok, parsed = pcall(vim.json.decode, json_match)
        if ok and parsed.tool then
          table.insert(result.content, {
            type = "tool_use",
            id = generate_tool_call_id(),
            name = parsed.tool,
            input = parsed.arguments or {},
          })
          text_content = response:gsub("```json.-```", ""):gsub("^%s+", ""):gsub("%s+$", "")
          result.stop_reason = "tool_use"
        end
      end

      if text_content and text_content ~= "" then
        table.insert(result.content, 1, { type = "text", text = text_content })
      end

      callback(result, nil)
    end)
  end
end
|
||||
|
||||
--- Run the agentic loop
--
-- Drives the think -> tool-call -> observe cycle: builds the system
-- prompt (agent definition + project rules), sends the history to the
-- LLM, executes any requested tools, and repeats until the model stops
-- requesting tools (or produces text) or max_iterations is reached.
-- All outcomes are reported through the callbacks in opts.
---@param opts AgenticOpts
function M.run(opts)
  -- Load agent
  local agent = load_agent(opts.agent or "coder")
  if not agent then
    if opts.on_complete then
      opts.on_complete(nil, "Unknown agent: " .. (opts.agent or "coder"))
    end
    return
  end

  -- Load rules
  local rules = load_rules()

  -- Build system prompt
  local system_prompt = agent.system_prompt .. rules

  -- Initialize message history
  ---@type AgenticMessage[]
  local history = {
    { role = "system", content = system_prompt },
  }

  -- Add initial file context if provided
  if opts.files and #opts.files > 0 then
    local file_context = "# Initial Files\n"
    for _, file_path in ipairs(opts.files) do
      local content = table.concat(vim.fn.readfile(file_path) or {}, "\n")
      file_context = file_context .. string.format("\n## %s\n```\n%s\n```\n", file_path, content)
    end
    table.insert(history, { role = "user", content = file_context })
    -- Seed an assistant acknowledgement so the history alternates
    -- user/assistant before the real task message.
    table.insert(history, { role = "assistant", content = "I've reviewed the provided files. What would you like me to do?" })
  end

  -- Add the task
  table.insert(history, { role = "user", content = opts.task })

  -- Determine provider
  local config = require("codetyper").get_config()
  local provider = config.llm.provider or "copilot"
  -- Note: Ollama has its own handler in call_llm, don't change it

  -- Get tools for this agent
  local tool_names = agent.tools or { "view", "edit", "write", "grep", "glob", "bash" }

  -- Ensure tools are loaded
  local tools_mod = require("codetyper.agent.tools")
  tools_mod.setup()

  -- Build tools for API
  local tools = build_tools(tool_names, provider)

  -- Iteration tracking
  local iteration = 0
  local max_iterations = opts.max_iterations or 20

  --- Process one iteration: one LLM round-trip plus any tool execution.
  local function process_iteration()
    iteration = iteration + 1

    if iteration > max_iterations then
      if opts.on_complete then
        opts.on_complete(nil, "Max iterations reached")
      end
      return
    end

    if opts.on_status then
      opts.on_status(string.format("Thinking... (iteration %d)", iteration))
    end

    -- Build messages for API
    local messages = build_messages(history, provider)

    -- Call LLM
    call_llm(messages, tools, system_prompt, provider, opts.model, function(response, err)
      if err then
        if opts.on_complete then
          opts.on_complete(nil, err)
        end
        return
      end

      -- Extract content and tool calls
      local content = extract_content(response, provider)
      local tool_calls = parse_tool_calls(response, provider)

      -- Add assistant message to history
      local assistant_msg = {
        role = "assistant",
        content = content,
        tool_calls = #tool_calls > 0 and tool_calls or nil,
      }
      table.insert(history, assistant_msg)

      if opts.on_message then
        opts.on_message(assistant_msg)
      end

      -- Process tool calls if any
      if #tool_calls > 0 then
        -- Tools run synchronously, in the order the model requested them.
        for _, tc in ipairs(tool_calls) do
          local result, tool_err = execute_tool(tc, opts)

          -- Add tool result to history (an error string takes the place
          -- of the result so the model can see what went wrong)
          local tool_msg = {
            role = "tool",
            tool_call_id = tc.id,
            name = tc["function"].name,
            content = tool_err or result,
          }
          table.insert(history, tool_msg)

          if opts.on_message then
            opts.on_message(tool_msg)
          end
        end

        -- Continue the loop
        vim.schedule(process_iteration)
      else
        -- No tool calls - check if complete
        if is_complete(response, provider) or content ~= "" then
          if opts.on_complete then
            opts.on_complete(content, nil)
          end
        else
          -- Continue if not explicitly complete
          vim.schedule(process_iteration)
        end
      end
    end)
  end

  -- Start the loop
  process_iteration()
end
|
||||
|
||||
--- Create default agent files in .coder/agents/
--
-- Ensures the directory exists and seeds it with an example agent
-- definition. An existing example.md is never overwritten.
---@return string Path of the agents directory
function M.init_agents_dir()
  local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
  vim.fn.mkdir(agents_dir, "p")

  -- Create example agent
  local example_agent = [[---
description: Example custom agent
tools: view,grep,glob,edit,write
model:
---

# Custom Agent

You are a custom coding agent. Describe your specialized behavior here.

## Your Role
- Define what this agent specializes in
- List specific capabilities

## Guidelines
- Add agent-specific rules
- Define coding standards to follow

## Examples
Provide examples of how to handle common tasks.
]]

  local example_path = agents_dir .. "/example.md"
  -- Only write the example when it does not already exist, so a
  -- user-edited file is never clobbered.
  if vim.fn.filereadable(example_path) ~= 1 then
    vim.fn.writefile(vim.split(example_agent, "\n"), example_path)
  end

  return agents_dir
end
|
||||
|
||||
--- Create default rules in .coder/rules/
--
-- Ensures the directory exists and seeds it with an example code-style
-- rule. An existing code-style.md is never overwritten.
---@return string Path of the rules directory
function M.init_rules_dir()
  local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
  vim.fn.mkdir(rules_dir, "p")

  -- Create example rule
  local example_rule = [[# Code Style

Follow these coding standards:

## General
- Use consistent indentation (tabs or spaces based on project)
- Keep lines under 100 characters
- Add comments for complex logic

## Naming Conventions
- Use descriptive variable names
- Functions should be verbs (e.g., getUserData, calculateTotal)
- Constants in UPPER_SNAKE_CASE

## Testing
- Write tests for new functionality
- Aim for >80% code coverage
- Test edge cases

## Documentation
- Document public APIs
- Include usage examples
- Keep docs up to date with code
]]

  local example_path = rules_dir .. "/code-style.md"
  -- Only write the example when it does not already exist.
  if vim.fn.filereadable(example_path) ~= 1 then
    vim.fn.writefile(vim.split(example_rule, "\n"), example_path)
  end

  return rules_dir
end
|
||||
|
||||
--- Initialize both agents and rules directories
--
-- Scaffolds .coder/agents and .coder/rules with their example files.
function M.init()
  for _, initialize in ipairs({ M.init_agents_dir, M.init_rules_dir }) do
    initialize()
  end
end
|
||||
|
||||
--- List available agents
--
-- Returns the built-in agents followed by any custom agents found in
-- .coder/agents/ (custom files shadowing a builtin name are skipped).
---@return table[] List of {name, description, builtin}
function M.list_agents()
  local result = {}
  local builtins = { "coder", "planner", "explorer" }

  -- Append a loaded agent definition with the given builtin flag.
  local function add(name, is_builtin)
    local agent = load_agent(name)
    if agent then
      table.insert(result, {
        name = agent.name,
        description = agent.description,
        builtin = is_builtin,
      })
    end
  end

  -- Built-in agents first.
  for _, name in ipairs(builtins) do
    add(name, true)
  end

  -- Then custom agents from .coder/agents/.
  local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
  if vim.fn.isdirectory(agents_dir) == 1 then
    for _, file in ipairs(vim.fn.glob(agents_dir .. "/*.md", false, true)) do
      local name = vim.fn.fnamemodify(file, ":t:r")
      if not vim.tbl_contains(builtins, name) then
        add(name, false)
      end
    end
  end

  return result
end
|
||||
|
||||
return M
|
||||
@@ -8,11 +8,11 @@ local M = {}
|
||||
|
||||
--- Heuristic weights (must sum to 1.0)
-- Each key weights one sub-score in M.score; the table constructor
-- previously listed length/syntax/repetition/truncation twice (a merge
-- artifact) — each key now appears exactly once and the values sum to 1.0.
M.weights = {
  length = 0.15, -- Response length relative to prompt
  uncertainty = 0.30, -- Uncertainty phrases
  syntax = 0.25, -- Syntax completeness
  repetition = 0.15, -- Duplicate lines
  truncation = 0.15, -- Incomplete ending
}
|
||||
|
||||
--- Uncertainty phrases that indicate low confidence
|
||||
@@ -255,14 +255,15 @@ function M.score(response, prompt, context)
|
||||
_ = context -- Reserved for future use
|
||||
|
||||
if not response or #response == 0 then
|
||||
return 0, {
|
||||
length = 0,
|
||||
uncertainty = 0,
|
||||
syntax = 0,
|
||||
repetition = 0,
|
||||
truncation = 0,
|
||||
weighted_total = 0,
|
||||
}
|
||||
return 0,
|
||||
{
|
||||
length = 0,
|
||||
uncertainty = 0,
|
||||
syntax = 0,
|
||||
repetition = 0,
|
||||
truncation = 0,
|
||||
weighted_total = 0,
|
||||
}
|
||||
end
|
||||
|
||||
local scores = {
|
||||
|
||||
@@ -9,7 +9,20 @@ local M = {}
|
||||
---@param callback fun(approved: boolean) Called with user decision
|
||||
function M.show_diff(diff_data, callback)
|
||||
local original_lines = vim.split(diff_data.original, "\n", { plain = true })
|
||||
local modified_lines = vim.split(diff_data.modified, "\n", { plain = true })
|
||||
local modified_lines
|
||||
|
||||
-- For delete operations, show a clear message
|
||||
if diff_data.operation == "delete" then
|
||||
modified_lines = {
|
||||
"",
|
||||
" FILE WILL BE DELETED",
|
||||
"",
|
||||
" Reason: " .. (diff_data.reason or "No reason provided"),
|
||||
"",
|
||||
}
|
||||
else
|
||||
modified_lines = vim.split(diff_data.modified, "\n", { plain = true })
|
||||
end
|
||||
|
||||
-- Calculate window dimensions
|
||||
local width = math.floor(vim.o.columns * 0.8)
|
||||
@@ -59,7 +72,7 @@ function M.show_diff(diff_data, callback)
|
||||
col = col + half_width + 1,
|
||||
style = "minimal",
|
||||
border = "rounded",
|
||||
title = " MODIFIED [" .. diff_data.operation .. "] ",
|
||||
title = diff_data.operation == "delete" and " ⚠️ DELETE " or (" MODIFIED [" .. diff_data.operation .. "] "),
|
||||
title_pos = "center",
|
||||
})
|
||||
|
||||
@@ -157,26 +170,52 @@ function M.show_diff(diff_data, callback)
|
||||
}, false, {})
|
||||
end
|
||||
|
||||
--- Show approval dialog for bash commands
|
||||
---@alias BashApprovalResult {approved: boolean, permission_level: string|nil}
|
||||
|
||||
--- Show approval dialog for bash commands with permission levels
|
||||
---@param command string The bash command to approve
|
||||
---@param callback fun(approved: boolean) Called with user decision
|
||||
---@param callback fun(result: BashApprovalResult) Called with user decision
|
||||
function M.show_bash_approval(command, callback)
|
||||
-- Create a simple floating window for bash approval
|
||||
local permissions = require("codetyper.agent.permissions")
|
||||
|
||||
-- Check if command is auto-approved
|
||||
local perm_result = permissions.check_bash_permission(command)
|
||||
if perm_result.auto and perm_result.allowed then
|
||||
vim.schedule(function()
|
||||
callback({ approved = true, permission_level = "auto" })
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Create approval dialog with options
|
||||
local lines = {
|
||||
"",
|
||||
" BASH COMMAND APPROVAL",
|
||||
" " .. string.rep("-", 50),
|
||||
" " .. string.rep("─", 56),
|
||||
"",
|
||||
" Command:",
|
||||
" $ " .. command,
|
||||
"",
|
||||
" " .. string.rep("-", 50),
|
||||
" Press [y] or [Enter] to execute",
|
||||
" Press [n], [q], or [Esc] to cancel",
|
||||
"",
|
||||
}
|
||||
|
||||
local width = math.max(60, #command + 10)
|
||||
-- Add warning for dangerous commands
|
||||
if not perm_result.allowed and perm_result.reason ~= "Requires approval" then
|
||||
table.insert(lines, " ⚠️ WARNING: " .. perm_result.reason)
|
||||
table.insert(lines, "")
|
||||
end
|
||||
|
||||
table.insert(lines, " " .. string.rep("─", 56))
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, " [y] Allow once - Execute this command")
|
||||
table.insert(lines, " [s] Allow this session - Auto-allow until restart")
|
||||
table.insert(lines, " [a] Add to allow list - Always allow this command")
|
||||
table.insert(lines, " [n] Reject - Cancel execution")
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, " " .. string.rep("─", 56))
|
||||
table.insert(lines, " Press key to choose | [q] or [Esc] to cancel")
|
||||
table.insert(lines, "")
|
||||
|
||||
local width = math.max(65, #command + 15)
|
||||
local height = #lines
|
||||
|
||||
local buf = vim.api.nvim_create_buf(false, true)
|
||||
@@ -196,45 +235,84 @@ function M.show_bash_approval(command, callback)
|
||||
title_pos = "center",
|
||||
})
|
||||
|
||||
-- Apply some highlighting
|
||||
-- Apply highlighting
|
||||
vim.api.nvim_buf_add_highlight(buf, -1, "Title", 1, 0, -1)
|
||||
vim.api.nvim_buf_add_highlight(buf, -1, "String", 5, 0, -1)
|
||||
|
||||
-- Highlight options
|
||||
for i, line in ipairs(lines) do
|
||||
if line:match("^%s+%[y%]") then
|
||||
vim.api.nvim_buf_add_highlight(buf, -1, "DiagnosticOk", i - 1, 0, -1)
|
||||
elseif line:match("^%s+%[s%]") then
|
||||
vim.api.nvim_buf_add_highlight(buf, -1, "DiagnosticInfo", i - 1, 0, -1)
|
||||
elseif line:match("^%s+%[a%]") then
|
||||
vim.api.nvim_buf_add_highlight(buf, -1, "DiagnosticHint", i - 1, 0, -1)
|
||||
elseif line:match("^%s+%[n%]") then
|
||||
vim.api.nvim_buf_add_highlight(buf, -1, "DiagnosticError", i - 1, 0, -1)
|
||||
elseif line:match("⚠️") then
|
||||
vim.api.nvim_buf_add_highlight(buf, -1, "DiagnosticWarn", i - 1, 0, -1)
|
||||
end
|
||||
end
|
||||
|
||||
local callback_called = false
|
||||
|
||||
local function close_and_respond(approved)
|
||||
local function close_and_respond(approved, permission_level)
|
||||
if callback_called then
|
||||
return
|
||||
end
|
||||
callback_called = true
|
||||
|
||||
-- Grant permission if approved with session or list level
|
||||
if approved and permission_level then
|
||||
permissions.grant_permission(command, permission_level)
|
||||
end
|
||||
|
||||
pcall(vim.api.nvim_win_close, win, true)
|
||||
|
||||
vim.schedule(function()
|
||||
callback(approved)
|
||||
callback({ approved = approved, permission_level = permission_level })
|
||||
end)
|
||||
end
|
||||
|
||||
local keymap_opts = { buffer = buf, noremap = true, silent = true, nowait = true }
|
||||
|
||||
-- Approve
|
||||
-- Allow once
|
||||
vim.keymap.set("n", "y", function()
|
||||
close_and_respond(true)
|
||||
close_and_respond(true, "allow")
|
||||
end, keymap_opts)
|
||||
vim.keymap.set("n", "<CR>", function()
|
||||
close_and_respond(true)
|
||||
close_and_respond(true, "allow")
|
||||
end, keymap_opts)
|
||||
|
||||
-- Allow this session
|
||||
vim.keymap.set("n", "s", function()
|
||||
close_and_respond(true, "allow_session")
|
||||
end, keymap_opts)
|
||||
|
||||
-- Add to allow list
|
||||
vim.keymap.set("n", "a", function()
|
||||
close_and_respond(true, "allow_list")
|
||||
end, keymap_opts)
|
||||
|
||||
-- Reject
|
||||
vim.keymap.set("n", "n", function()
|
||||
close_and_respond(false)
|
||||
close_and_respond(false, nil)
|
||||
end, keymap_opts)
|
||||
vim.keymap.set("n", "q", function()
|
||||
close_and_respond(false)
|
||||
close_and_respond(false, nil)
|
||||
end, keymap_opts)
|
||||
vim.keymap.set("n", "<Esc>", function()
|
||||
close_and_respond(false)
|
||||
close_and_respond(false, nil)
|
||||
end, keymap_opts)
|
||||
end
|
||||
|
||||
--- Show approval dialog for bash commands (simple version for backward compatibility)
--
-- Delegates to the full dialog and reduces its structured result to the
-- legacy boolean callback shape.
---@param command string The bash command to approve
---@param callback fun(approved: boolean) Called with user decision
function M.show_bash_approval_simple(command, callback)
  M.show_bash_approval(command, function(outcome)
    callback(outcome.approved)
  end)
end
|
||||
|
||||
return M
|
||||
|
||||
@@ -27,6 +27,9 @@ function M.execute(tool_name, parameters, callback)
|
||||
edit_file = M.handle_edit_file,
|
||||
write_file = M.handle_write_file,
|
||||
bash = M.handle_bash,
|
||||
delete_file = M.handle_delete_file,
|
||||
list_directory = M.handle_list_directory,
|
||||
search_files = M.handle_search_files,
|
||||
}
|
||||
|
||||
local handler = handlers[tool_name]
|
||||
@@ -156,6 +159,165 @@ function M.handle_bash(params, callback)
|
||||
})
|
||||
end
|
||||
|
||||
--- Handle delete_file tool
--
-- Never deletes directly: reports a deletion request that requires user
-- approval, carrying the file's current contents as diff data so the
-- user can see exactly what would be removed.
---@param params table { path: string, reason: string }
---@param callback fun(result: ExecutionResult)
function M.handle_delete_file(params, callback)
  local target = M.resolve_path(params.path)
  local reason = params.reason or "No reason provided"

  if not utils.file_exists(target) then
    callback({
      success = false,
      result = "File not found: " .. target,
      requires_approval = false,
    })
    return
  end

  -- Read the current contents for the approval diff.
  local current = utils.read_file(target) or "[Could not read file]"

  callback({
    success = true,
    result = "Delete: " .. target .. " (" .. reason .. ")",
    requires_approval = true,
    diff_data = {
      path = target,
      original = current,
      modified = "", -- Empty = deletion
      operation = "delete",
      reason = reason,
    },
  })
end
|
||||
|
||||
--- Handle list_directory tool
--
-- Lists directory entries (optionally recursively, capped at depth 3),
-- indenting nested entries and suffixing directories with "/". The
-- entries ".", "..", ".git" and "node_modules" are skipped.
---@param params table { path?: string, recursive?: boolean }
---@param callback fun(result: ExecutionResult)
function M.handle_list_directory(params, callback)
  local root
  if params.path then
    root = M.resolve_path(params.path)
  else
    root = utils.get_project_root() or vim.fn.getcwd()
  end
  local recurse = params.recursive or false

  local lines = {}

  -- Walk `dir`, appending one indented line per entry; recursion depth
  -- is capped to keep output bounded.
  local function walk(dir, depth)
    if depth > 3 then
      return
    end

    local ok, names = pcall(vim.fn.readdir, dir)
    if not ok or not names then
      return
    end

    for _, entry in ipairs(names) do
      local skip = entry == "." or entry == ".." or entry:match("^%.git$") or entry:match("^node_modules$")
      if not skip then
        local child = dir .. "/" .. entry
        local stat = vim.loop.fs_stat(child)
        if stat then
          local suffix = stat.type == "directory" and "/" or ""
          table.insert(lines, string.rep("  ", depth) .. entry .. suffix)
          if recurse and stat.type == "directory" then
            walk(child, depth + 1)
          end
        end
      end
    end
  end

  walk(root, 0)

  callback({
    success = true,
    result = "Directory: " .. root .. "\n\n" .. table.concat(lines, "\n"),
    requires_approval = false,
  })
end
|
||||
|
||||
--- Handle search_files tool
--- Finds files by name pattern (glob) and/or by content (grep -rl). When both
--- are supplied, the result is the intersection of the two searches.
---@param params table { pattern?: string, content?: string, path?: string }
---@param callback fun(result: ExecutionResult)
function M.handle_search_files(params, callback)
  local search_path = params.path and M.resolve_path(params.path) or (utils.get_project_root() or vim.fn.getcwd())
  local pattern = params.pattern
  local content_search = params.content

  -- Strip the search root to produce project-relative paths. A plain prefix
  -- compare replaces the old gsub so pattern-magic characters in the path
  -- (e.g. '-') cannot corrupt the substitution.
  local prefix = search_path .. "/"
  local function relativize(p)
    return p:sub(1, #prefix) == prefix and p:sub(#prefix + 1) or p
  end

  local results = {}

  if pattern then
    -- Search by file name pattern using glob
    local files = vim.fn.glob(search_path .. "/**/" .. pattern, false, true)
    for _, file in ipairs(files) do
      -- Skip common ignore patterns
      if not file:match("node_modules") and not file:match("%.git/") then
        -- BUGFIX: previously `table.insert(results, file:gsub(...))` expanded
        -- gsub's second return value (the match count) into table.insert's
        -- position argument, invoking the 3-arg form with a string position.
        table.insert(results, relativize(file))
      end
    end
  end

  if content_search then
    -- Search by content using grep.
    -- BUGFIX: single quotes are escaped with the POSIX-safe '\'' sequence.
    -- The old gsub("'", "\\'") produced \' which does NOT escape inside a
    -- single-quoted shell string, breaking the command (shell injection risk).
    local function shq(s)
      return "'" .. s:gsub("'", "'\\''") .. "'"
    end
    local grep_cmd = string.format("grep -rl %s %s 2>/dev/null | head -20", shq(content_search), shq(search_path))

    local grep_results = {}
    local handle = io.popen(grep_cmd)
    if handle then
      for line in handle:lines() do
        if not line:match("node_modules") and not line:match("%.git/") then
          table.insert(grep_results, relativize(line))
        end
      end
      handle:close()
    end

    -- Merge with pattern results or use as primary results
    if #results == 0 then
      results = grep_results
    else
      -- Intersection of pattern and content results
      local pattern_set = {}
      for _, f in ipairs(results) do
        pattern_set[f] = true
      end
      results = {}
      for _, f in ipairs(grep_results) do
        if pattern_set[f] then
          table.insert(results, f)
        end
      end
    end
  end

  local result_text = "Search results"
  if pattern then
    result_text = result_text .. " (pattern: " .. pattern .. ")"
  end
  if content_search then
    result_text = result_text .. " (content: " .. content_search .. ")"
  end
  result_text = result_text .. ":\n\n"

  if #results == 0 then
    result_text = result_text .. "No files found."
  else
    result_text = result_text .. table.concat(results, "\n")
  end

  callback({
    success = true,
    result = result_text,
    requires_approval = false,
  })
end
|
||||
|
||||
--- Actually apply an approved change
|
||||
---@param diff_data DiffData The diff data to apply
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
@@ -164,6 +326,24 @@ function M.apply_change(diff_data, callback)
|
||||
-- Extract command from modified (remove "$ " prefix)
|
||||
local command = diff_data.modified:gsub("^%$ ", "")
|
||||
M.execute_bash_command(command, 30000, callback)
|
||||
elseif diff_data.operation == "delete" then
|
||||
-- Delete file
|
||||
local ok, err = os.remove(diff_data.path)
|
||||
if ok then
|
||||
-- Close buffer if it's open
|
||||
M.close_buffer_if_open(diff_data.path)
|
||||
callback({
|
||||
success = true,
|
||||
result = "Deleted: " .. diff_data.path,
|
||||
requires_approval = false,
|
||||
})
|
||||
else
|
||||
callback({
|
||||
success = false,
|
||||
result = "Failed to delete: " .. diff_data.path .. " (" .. (err or "unknown error") .. ")",
|
||||
requires_approval = false,
|
||||
})
|
||||
end
|
||||
else
|
||||
-- Write file
|
||||
local success = utils.write_file(diff_data.path, diff_data.modified)
|
||||
@@ -275,6 +455,22 @@ function M.reload_buffer_if_open(filepath)
|
||||
end
|
||||
end
|
||||
|
||||
--- Close a buffer if it's currently open (for deleted files)
--- Scans loaded buffers for one whose name matches the absolute form of
--- `filepath` and force-deletes it; at most one buffer is closed.
---@param filepath string Path to the file
function M.close_buffer_if_open(filepath)
  local target = vim.fn.fnamemodify(filepath, ":p")
  for _, bufnr in ipairs(vim.api.nvim_list_bufs()) do
    if vim.api.nvim_buf_is_loaded(bufnr) and vim.api.nvim_buf_get_name(bufnr) == target then
      -- pcall guards against races (buffer wiped between list and delete)
      pcall(vim.api.nvim_buf_delete, bufnr, { force = true })
      return
    end
  end
end
|
||||
|
||||
--- Resolve a path (expand ~ and make absolute if needed)
|
||||
---@param path string Path to resolve
|
||||
---@return string Resolved path
|
||||
|
||||
@@ -111,7 +111,11 @@ function M.agent_loop(context, callbacks)
|
||||
logs.thinking("Calling LLM with " .. #state.conversation .. " messages...")
|
||||
|
||||
-- Generate with tools enabled
|
||||
client.generate_with_tools(state.conversation, context, tools.definitions, function(response, err)
|
||||
-- Ensure tools are loaded and get definitions
|
||||
tools.setup()
|
||||
local tool_defs = tools.to_openai_format()
|
||||
|
||||
client.generate_with_tools(state.conversation, context, tool_defs, function(response, err)
|
||||
if err then
|
||||
state.is_running = false
|
||||
callbacks.on_error(err)
|
||||
@@ -123,12 +127,14 @@ function M.agent_loop(context, callbacks)
|
||||
local config = codetyper.get_config()
|
||||
local parsed
|
||||
|
||||
if config.llm.provider == "claude" then
|
||||
-- Copilot uses Claude-like response format
|
||||
if config.llm.provider == "copilot" then
|
||||
parsed = parser.parse_claude_response(response)
|
||||
-- For Claude, preserve the original content array for proper tool_use handling
|
||||
table.insert(state.conversation, {
|
||||
role = "assistant",
|
||||
content = response.content, -- Keep original content blocks for Claude API
|
||||
content = parsed.text or "",
|
||||
tool_calls = parsed.tool_calls,
|
||||
_raw_content = response.content,
|
||||
})
|
||||
else
|
||||
-- For Ollama, response is the text directly
|
||||
@@ -200,9 +206,22 @@ function M.process_tool_calls(tool_calls, index, context, callbacks)
|
||||
show_fn = diff.show_diff
|
||||
end
|
||||
|
||||
show_fn(result.diff_data, function(approved)
|
||||
show_fn(result.diff_data, function(approval_result)
|
||||
-- Handle both old (boolean) and new (table) approval result formats
|
||||
local approved = type(approval_result) == "table" and approval_result.approved or approval_result
|
||||
local permission_level = type(approval_result) == "table" and approval_result.permission_level or nil
|
||||
|
||||
if approved then
|
||||
logs.tool(tool_call.name, "approved", "User approved")
|
||||
local log_msg = "User approved"
|
||||
if permission_level == "allow_session" then
|
||||
log_msg = "Allowed for session"
|
||||
elseif permission_level == "allow_list" then
|
||||
log_msg = "Added to allow list"
|
||||
elseif permission_level == "auto" then
|
||||
log_msg = "Auto-approved"
|
||||
end
|
||||
logs.tool(tool_call.name, "approved", log_msg)
|
||||
|
||||
-- Apply the change
|
||||
executor.apply_change(result.diff_data, function(apply_result)
|
||||
-- Store result for sending back to LLM
|
||||
@@ -261,8 +280,9 @@ function M.continue_with_results(context, callbacks)
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
if config.llm.provider == "claude" then
|
||||
-- Claude format: tool_result blocks
|
||||
-- Copilot uses Claude-like format for tool results
|
||||
if config.llm.provider == "copilot" then
|
||||
-- Claude-style tool_result blocks
|
||||
local content = {}
|
||||
for _, result in ipairs(state.pending_tool_results) do
|
||||
table.insert(content, {
|
||||
|
||||
614
lua/codetyper/agent/inject.lua
Normal file
614
lua/codetyper/agent/inject.lua
Normal file
@@ -0,0 +1,614 @@
|
||||
---@mod codetyper.agent.inject Smart code injection with import handling
|
||||
---@brief [[
|
||||
--- Intelligent code injection that properly handles imports, merging them
|
||||
--- into existing import sections instead of blindly appending.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class ImportConfig
|
||||
---@field pattern string Lua pattern to match import statements
|
||||
---@field multi_line boolean Whether imports can span multiple lines
|
||||
---@field sort_key function|nil Function to extract sort key from import
|
||||
---@field group_by function|nil Function to group imports
|
||||
|
||||
---@class ParsedCode
|
||||
---@field imports string[] Import statements
|
||||
---@field body string[] Non-import code lines
|
||||
---@field import_lines table<number, boolean> Map of line numbers that are imports
|
||||
|
||||
--- Language-specific import patterns
--- Maps a filetype/extension key to a list of { pattern, multi_line } records.
--- `pattern` is a Lua pattern matched against a source line; `multi_line`
--- marks statements that may continue onto following lines (terminated later
--- by ends_multiline_import()).
local import_patterns = {
  -- JavaScript/TypeScript
  javascript = {
    { pattern = "^%s*import%s+.+%s+from%s+['\"]", multi_line = true },
    { pattern = "^%s*import%s+['\"]", multi_line = false }, -- side-effect import: import "x"
    { pattern = "^%s*import%s*{", multi_line = true }, -- named imports, may span lines
    { pattern = "^%s*import%s*%*", multi_line = true }, -- namespace import
    { pattern = "^%s*export%s+{.+}%s+from%s+['\"]", multi_line = true }, -- re-export
    { pattern = "^%s*const%s+%w+%s*=%s*require%(['\"]", multi_line = false }, -- CommonJS
    { pattern = "^%s*let%s+%w+%s*=%s*require%(['\"]", multi_line = false },
    { pattern = "^%s*var%s+%w+%s*=%s*require%(['\"]", multi_line = false },
  },
  -- Python
  python = {
    { pattern = "^%s*import%s+%w", multi_line = false },
    { pattern = "^%s*from%s+[%w%.]+%s+import%s+", multi_line = true }, -- may continue via parens/backslash
  },
  -- Lua
  lua = {
    { pattern = "^%s*local%s+%w+%s*=%s*require%s*%(?['\"]", multi_line = false },
    { pattern = "^%s*require%s*%(?['\"]", multi_line = false },
  },
  -- Go
  go = {
    { pattern = "^%s*import%s+%(?", multi_line = true }, -- covers grouped import ( ... ) blocks
  },
  -- Rust
  rust = {
    { pattern = "^%s*use%s+", multi_line = true },
    { pattern = "^%s*extern%s+crate%s+", multi_line = false },
  },
  -- C/C++
  c = {
    { pattern = "^%s*#include%s*[<\"]", multi_line = false },
  },
  -- Java/Kotlin
  java = {
    { pattern = "^%s*import%s+", multi_line = false },
  },
  -- Ruby
  ruby = {
    { pattern = "^%s*require%s+['\"]", multi_line = false },
    { pattern = "^%s*require_relative%s+['\"]", multi_line = false },
  },
  -- PHP
  php = {
    { pattern = "^%s*use%s+", multi_line = false },
    { pattern = "^%s*require%s+['\"]", multi_line = false },
    { pattern = "^%s*require_once%s+['\"]", multi_line = false },
    { pattern = "^%s*include%s+['\"]", multi_line = false },
    { pattern = "^%s*include_once%s+['\"]", multi_line = false },
  },
}

-- Alias common extensions to language configs
-- (keys match the extensions produced by fnamemodify(":e") in M.inject)
import_patterns.ts = import_patterns.javascript
import_patterns.tsx = import_patterns.javascript
import_patterns.jsx = import_patterns.javascript
import_patterns.mjs = import_patterns.javascript
import_patterns.cjs = import_patterns.javascript
import_patterns.py = import_patterns.python
import_patterns.cpp = import_patterns.c
import_patterns.hpp = import_patterns.c
import_patterns.h = import_patterns.c
import_patterns.kt = import_patterns.java
import_patterns.rs = import_patterns.rust
import_patterns.rb = import_patterns.ruby
||||
|
||||
--- Check if a line is an import statement for the given language
--- Tries each configured pattern in order; first match wins.
---@param line string
---@param patterns table[] Import patterns for the language
---@return boolean is_import
---@return boolean is_multi_line
local function is_import_line(line, patterns)
  for _, cfg in ipairs(patterns) do
    if line:match(cfg.pattern) then
      return true, cfg.multi_line or false
    end
  end
  return false, false
end
|
||||
|
||||
--- Check if a line is empty or a comment
--- Blank/whitespace-only lines are always "empty"; otherwise the trimmed line
--- is matched against the filetype's comment markers (JavaScript markers are
--- the fallback for unknown filetypes).
---@param line string
---@param filetype string
---@return boolean
local function is_empty_or_comment(line, filetype)
  local trimmed = line:match("^%s*(.-)%s*$")
  if trimmed == "" then
    return true
  end

  -- Language-specific comment patterns
  local comment_patterns = {
    lua = { "^%-%-" },
    python = { "^#" },
    javascript = { "^//", "^/%*", "^%*" },
    typescript = { "^//", "^/%*", "^%*" },
    go = { "^//", "^/%*", "^%*" },
    rust = { "^//", "^/%*", "^%*" },
    c = { "^//", "^/%*", "^%*", "^#" },
    java = { "^//", "^/%*", "^%*" },
    ruby = { "^#" },
    php = { "^//", "^/%*", "^%*", "^#" },
  }

  for _, marker in ipairs(comment_patterns[filetype] or comment_patterns.javascript) do
    if trimmed:match(marker) then
      return true
    end
  end

  return false
end
|
||||
|
||||
--- Check if a line ends a multi-line import
--- Per-language terminator heuristics: JS/TS look for a closing `from "..."`,
--- a bare `}`, or a trailing `;`; Python treats any line NOT continued by a
--- backslash, open paren, or trailing comma as complete (and `)` closes a
--- parenthesized import); Go closes on `)`; Rust on `;`.
---@param line string
---@param filetype string
---@return boolean
local function ends_multiline_import(line, filetype)
  if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then
    return (line:match("from%s+['\"][^'\"]+['\"]%s*;?%s*$")
      or line:match("}%s*from%s+['\"]")
      or line:match("^%s*}%s*;?%s*$")
      or line:match(";%s*$")) ~= nil
  end

  if filetype == "python" or filetype == "py" then
    -- Complete unless explicitly continued (backslash, open paren, comma)
    if not (line:match("\\%s*$") or line:match("%(%s*$") or line:match(",%s*$")) then
      return true
    end
    -- Parenthesized import lists close with ')'
    return line:match("%)%s*$") ~= nil
  end

  if filetype == "go" then
    return line:match("%)%s*$") ~= nil
  end

  if filetype == "rust" or filetype == "rs" then
    return line:match(";%s*$") ~= nil
  end

  return false
end
|
||||
|
||||
--- Parse code into imports and body
--- Splits the given code into import statements and remaining ("body") lines.
--- Multi-line imports are collected until their terminator and stored as a
--- single newline-joined entry in `imports`.
---@param code string|string[] Code to parse
---@param filetype string File type/extension
---@return ParsedCode
function M.parse_code(code, filetype)
  local lines = type(code) == "string" and vim.split(code, "\n", { plain = true }) or code
  local patterns = import_patterns[filetype] or import_patterns.javascript

  local parsed = {
    imports = {},
    body = {},
    import_lines = {},
  }

  local collecting = false -- inside a multi-line import?
  local pending = {}       -- lines of the in-progress multi-line import

  for lineno, text in ipairs(lines) do
    if collecting then
      pending[#pending + 1] = text
      if ends_multiline_import(text, filetype) then
        -- Terminator found: record the whole span as one import
        parsed.imports[#parsed.imports + 1] = table.concat(pending, "\n")
        for j = lineno - #pending + 1, lineno do
          parsed.import_lines[j] = true
        end
        pending = {}
        collecting = false
      end
    else
      local is_import, is_multi = is_import_line(text, patterns)
      if not is_import then
        parsed.body[#parsed.body + 1] = text
      elseif is_multi and not ends_multiline_import(text, filetype) then
        -- Opens a multi-line import
        parsed.import_lines[lineno] = true
        collecting = true
        pending = { text }
      else
        -- Self-contained single-line import
        parsed.import_lines[lineno] = true
        parsed.imports[#parsed.imports + 1] = text
      end
    end
  end

  -- Unclosed multi-line import at EOF (malformed code): keep what we have
  if #pending > 0 then
    parsed.imports[#parsed.imports + 1] = table.concat(pending, "\n")
  end

  return parsed
end
|
||||
|
||||
--- Find the import section range in a buffer
--- Scans buffer lines from the top, tracking the first and last line that
--- belongs to an import. Blank/comment lines inside the section are tolerated
--- up to `max_gap` in a row; any other non-import line ends the section.
---@param bufnr number Buffer number
---@param filetype string
---@return number|nil start_line First import line (1-indexed)
---@return number|nil end_line Last import line (1-indexed)
function M.find_import_section(bufnr, filetype)
  if not vim.api.nvim_buf_is_valid(bufnr) then
    return nil, nil
  end

  local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
  local patterns = import_patterns[filetype] or import_patterns.javascript

  local first_import = nil
  local last_import = nil
  local in_multiline = false            -- currently inside a multi-line import
  local consecutive_non_import = 0      -- run length of tolerated gap lines
  local max_gap = 3 -- Allow up to 3 blank/comment lines between imports

  for i, line in ipairs(lines) do
    if in_multiline then
      -- Continuation lines extend the section unconditionally
      last_import = i
      consecutive_non_import = 0

      if ends_multiline_import(line, filetype) then
        in_multiline = false
      end
    else
      local is_import, is_multi = is_import_line(line, patterns)

      if is_import then
        if not first_import then
          first_import = i
        end
        last_import = i
        consecutive_non_import = 0

        if is_multi and not ends_multiline_import(line, filetype) then
          in_multiline = true
        end
      elseif is_empty_or_comment(line, filetype) then
        -- Allow gaps in import section (only counted once the section started)
        if first_import then
          consecutive_non_import = consecutive_non_import + 1
          if consecutive_non_import > max_gap then
            -- Too many non-import lines, import section has ended
            break
          end
        end
      else
        -- Non-import, non-empty line
        if first_import then
          -- Import section has ended
          break
        end
        -- NOTE(review): before the first import, arbitrary code lines are
        -- skipped — the scan keeps looking for imports further down.
      end
    end
  end

  return first_import, last_import
end
|
||||
|
||||
--- Get existing imports from a buffer
--- Locates the buffer's import section and parses just that slice, returning
--- the import statements (multi-line imports joined with "\n"). Returns an
--- empty list when the buffer has no import section.
---@param bufnr number Buffer number
---@param filetype string
---@return string[] Existing import statements
function M.get_existing_imports(bufnr, filetype)
  local first, last = M.find_import_section(bufnr, filetype)
  if first == nil then
    return {}
  end

  local section = vim.api.nvim_buf_get_lines(bufnr, first - 1, last, false)
  return M.parse_code(section, filetype).imports
end
|
||||
|
||||
--- Normalize an import for comparison (remove whitespace variations)
--- Drops a trailing semicolon, tightens whitespace around braces, commas and
--- colons, collapses runs of whitespace, and trims the ends — so two imports
--- that differ only in formatting normalize to the same string.
---@param import_str string
---@return string
local function normalize_import(import_str)
  local s = import_str
    :gsub(";%s*$", "")
    :gsub("%s*{%s*", "{")
    :gsub("%s*}%s*", "}")
    :gsub("%s*,%s*", ",")
    :gsub("%s*:%s*", ":")
    :gsub("%s+", " ")
  return s:match("^%s*(.-)%s*$")
end
|
||||
|
||||
--- Check if two imports are duplicates
--- Two imports are duplicates when their normalized forms are equal.
---@param import1 string
---@param import2 string
---@return boolean
local function are_duplicate_imports(import1, import2)
  local a, b = normalize_import(import1), normalize_import(import2)
  return a == b
end
|
||||
|
||||
--- Merge new imports with existing ones, avoiding duplicates
--- Existing imports come first and keep their original text; new imports are
--- appended only when their normalized form has not been seen yet.
---@param existing string[] Existing imports
---@param new_imports string[] New imports to merge
---@return string[] Merged imports
function M.merge_imports(existing, new_imports)
  local merged = {}
  local seen = {}

  -- Append each import whose normalized form is new
  local function take(list)
    for _, imp in ipairs(list) do
      local key = normalize_import(imp)
      if not seen[key] then
        seen[key] = true
        merged[#merged + 1] = imp
      end
    end
  end

  take(existing)
  take(new_imports)

  return merged
end
|
||||
|
||||
--- Sort imports by their source/module
--- Classifies each import as builtin, third-party, or local (heuristics per
--- language; builtin takes precedence over local), sorts each group
--- alphabetically, and joins the groups with a blank line between them.
---@param imports string[]
---@param filetype string
---@return string[]
function M.sort_imports(imports, filetype)
  local builtin, third_party, local_imports = {}, {}, {}

  -- Heuristic classifier; returns the bucket for one import statement.
  local function classify(imp)
    if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then
      -- Node builtin modules
      if imp:match("from%s+['\"]node:") or imp:match("from%s+['\"]fs['\"]")
        or imp:match("from%s+['\"]path['\"]") or imp:match("from%s+['\"]http['\"]") then
        return builtin
      end
      -- Local: source starts with . or ..
      if imp:match("from%s+['\"]%.") or imp:match("require%(['\"]%.") then
        return local_imports
      end
    elseif filetype == "python" or filetype == "py" then
      -- Python stdlib (simplified check)
      if imp:match("^import%s+os") or imp:match("^import%s+sys")
        or imp:match("^from%s+os%s+") or imp:match("^from%s+sys%s+")
        or imp:match("^import%s+re") or imp:match("^import%s+json") then
        return builtin
      end
      -- Local: relative imports
      if imp:match("^from%s+%.") or imp:match("^import%s+%.") then
        return local_imports
      end
    elseif filetype == "lua" then
      -- Local: relative requires
      if imp:match("require%(['\"]%.") or imp:match("require%s+['\"]%.") then
        return local_imports
      end
    elseif filetype == "go" then
      -- Local: project imports (contain /) that aren't hosted packages
      if imp:match("['\"][^'\"]+/[^'\"]+['\"]") and not imp:match("github%.com") then
        return local_imports
      end
    end
    return third_party
  end

  for _, imp in ipairs(imports) do
    local bucket = classify(imp)
    bucket[#bucket + 1] = imp
  end

  -- Alphabetical order inside each group
  table.sort(builtin)
  table.sort(third_party)
  table.sort(local_imports)

  local ordered = {}
  local function emit(group, blank_after)
    for _, imp in ipairs(group) do
      ordered[#ordered + 1] = imp
    end
    if #group > 0 and blank_after then
      ordered[#ordered + 1] = "" -- Blank line between groups
    end
  end

  emit(builtin, #third_party > 0 or #local_imports > 0)
  emit(third_party, #local_imports > 0)
  emit(local_imports, false)

  return ordered
end
|
||||
|
||||
---@class InjectResult
---@field success boolean
---@field imports_added number Number of new imports added
---@field imports_merged boolean Whether imports were merged into existing section
---@field body_lines number Number of body lines injected

--- Smart inject code into a buffer, properly handling imports
--- Parses `code` into imports and body; imports are merged into the buffer's
--- existing import section (or inserted at the top), then the body is applied
--- per the chosen strategy ("append" default, or "replace"/"insert" with a
--- range). NOTE(review): mutates `opts.range` in place (line-number shifts
--- after import insertion) and trims `parsed.body` in place.
---@param bufnr number Target buffer
---@param code string|string[] Code to inject
---@param opts table Options: { strategy: "append"|"replace"|"insert", range: {start_line, end_line}|nil, filetype: string|nil, sort_imports: boolean|nil }
---@return InjectResult
function M.inject(bufnr, code, opts)
  opts = opts or {}

  if not vim.api.nvim_buf_is_valid(bufnr) then
    return { success = false, imports_added = 0, imports_merged = false, body_lines = 0 }
  end

  -- Get filetype (fall back to the buffer name's extension)
  local filetype = opts.filetype
  if not filetype then
    local bufname = vim.api.nvim_buf_get_name(bufnr)
    filetype = vim.fn.fnamemodify(bufname, ":e")
  end

  -- Parse the code to separate imports from body
  local parsed = M.parse_code(code, filetype)

  local result = {
    success = true,
    imports_added = 0,
    imports_merged = false,
    body_lines = #parsed.body,
  }

  -- Handle imports first if there are any
  if #parsed.imports > 0 then
    local import_start, import_end = M.find_import_section(bufnr, filetype)

    if import_start then
      -- Merge with existing import section
      local existing_imports = M.get_existing_imports(bufnr, filetype)
      local merged = M.merge_imports(existing_imports, parsed.imports)

      -- Count how many new imports were actually added
      result.imports_added = #merged - #existing_imports
      result.imports_merged = true

      -- Optionally sort imports (on by default; opt out with sort_imports = false)
      if opts.sort_imports ~= false then
        merged = M.sort_imports(merged, filetype)
      end

      -- Convert back to lines (handling multi-line imports)
      local import_lines = {}
      for _, imp in ipairs(merged) do
        for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
          table.insert(import_lines, line)
        end
      end

      -- Replace the import section
      vim.api.nvim_buf_set_lines(bufnr, import_start - 1, import_end, false, import_lines)

      -- Adjust line numbers for body injection: the import section may have
      -- grown or shrunk, shifting everything below it
      local lines_diff = #import_lines - (import_end - import_start + 1)
      if opts.range and opts.range.start_line and opts.range.start_line > import_end then
        opts.range.start_line = opts.range.start_line + lines_diff
        if opts.range.end_line then
          opts.range.end_line = opts.range.end_line + lines_diff
        end
      end
    else
      -- No existing import section, add imports at the top
      -- Find the first non-comment, non-empty line
      local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
      local insert_at = 0

      for i, line in ipairs(lines) do
        local trimmed = line:match("^%s*(.-)%s*$")
        -- Skip shebang, docstrings, and initial comments
        if trimmed ~= "" and not trimmed:match("^#!")
          and not trimmed:match("^['\"]") and not is_empty_or_comment(line, filetype) then
          insert_at = i - 1
          break
        end
        insert_at = i
      end

      -- Add imports with a trailing blank line
      local import_lines = {}
      for _, imp in ipairs(parsed.imports) do
        for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
          table.insert(import_lines, line)
        end
      end
      table.insert(import_lines, "") -- Blank line after imports

      vim.api.nvim_buf_set_lines(bufnr, insert_at, insert_at, false, import_lines)
      result.imports_added = #parsed.imports
      result.imports_merged = false

      -- Adjust body injection range (everything shifted down by the inserted lines)
      if opts.range and opts.range.start_line then
        opts.range.start_line = opts.range.start_line + #import_lines
        if opts.range.end_line then
          opts.range.end_line = opts.range.end_line + #import_lines
        end
      end
    end
  end

  -- Handle body (non-import) code
  if #parsed.body > 0 then
    -- Filter out empty leading/trailing lines from body (mutates parsed.body)
    local body_lines = parsed.body
    while #body_lines > 0 and body_lines[1]:match("^%s*$") do
      table.remove(body_lines, 1)
    end
    while #body_lines > 0 and body_lines[#body_lines]:match("^%s*$") do
      table.remove(body_lines)
    end

    if #body_lines > 0 then
      local line_count = vim.api.nvim_buf_line_count(bufnr)
      local strategy = opts.strategy or "append"

      if strategy == "replace" and opts.range then
        -- Replace the (clamped) range with the body
        local start_line = math.max(1, opts.range.start_line)
        local end_line = math.min(line_count, opts.range.end_line)
        vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, body_lines)
      elseif strategy == "insert" and opts.range then
        -- Insert before the (clamped) start line
        local insert_line = math.max(0, math.min(line_count, opts.range.start_line - 1))
        vim.api.nvim_buf_set_lines(bufnr, insert_line, insert_line, false, body_lines)
      else
        -- Default: append
        local last_line = vim.api.nvim_buf_get_lines(bufnr, line_count - 1, line_count, false)[1] or ""
        if last_line:match("%S") then
          -- Add blank line for spacing
          table.insert(body_lines, 1, "")
        end
        vim.api.nvim_buf_set_lines(bufnr, line_count, line_count, false, body_lines)
      end

      result.body_lines = #body_lines
    end
  end

  return result
end
|
||||
|
||||
--- Check if code contains imports
--- True when parsing the code for the given filetype yields at least one
--- import statement.
---@param code string|string[]
---@param filetype string
---@return boolean
function M.has_imports(code, filetype)
  return #M.parse_code(code, filetype).imports > 0
end
|
||||
|
||||
return M
|
||||
398
lua/codetyper/agent/loop.lua
Normal file
398
lua/codetyper/agent/loop.lua
Normal file
@@ -0,0 +1,398 @@
|
||||
---@mod codetyper.agent.loop Agent loop with tool orchestration
|
||||
---@brief [[
|
||||
--- Main agent loop that handles multi-turn conversations with tool use.
|
||||
--- Inspired by avante.nvim's agent_loop pattern.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class AgentMessage
|
||||
---@field role "system"|"user"|"assistant"|"tool"
|
||||
---@field content string|table
|
||||
---@field tool_call_id? string For tool responses
|
||||
---@field tool_calls? table[] For assistant tool calls
|
||||
---@field name? string Tool name for tool responses
|
||||
|
||||
---@class AgentLoopOpts
|
||||
---@field system_prompt string System prompt
|
||||
---@field user_input string Initial user message
|
||||
---@field tools? CoderTool[] Available tools (default: all registered)
|
||||
---@field max_iterations? number Max tool call iterations (default: 10)
|
||||
---@field provider? string LLM provider to use
|
||||
---@field on_start? fun() Called when loop starts
|
||||
---@field on_chunk? fun(chunk: string) Called for each response chunk
|
||||
---@field on_tool_call? fun(name: string, input: table) Called before tool execution
|
||||
---@field on_tool_result? fun(name: string, result: any, error: string|nil) Called after tool execution
|
||||
---@field on_message? fun(message: AgentMessage) Called for each message added
|
||||
---@field on_complete? fun(result: string|nil, error: string|nil) Called when loop completes
|
||||
---@field session_ctx? table Session context shared across tools
|
||||
|
||||
--- Format tool definitions for OpenAI-compatible API
--- Converts each CoderTool into the { type = "function", function = { ... } }
--- shape with a JSON-Schema-style parameters object. "integer" params are
--- emitted as "number"; params not marked optional go into `required`.
---@param tools CoderTool[]
---@return table[]
local function format_tools_for_api(tools)
  local formatted = {}

  for _, tool in ipairs(tools) do
    local properties, required = {}, {}

    for _, param in ipairs(tool.params or {}) do
      properties[param.name] = {
        type = param.type == "integer" and "number" or param.type,
        description = param.description,
      }
      if not param.optional then
        required[#required + 1] = param.name
      end
    end

    -- description may be a lazily-evaluated function
    local desc = tool.description
    if type(desc) == "function" then
      desc = desc()
    end

    formatted[#formatted + 1] = {
      type = "function",
      ["function"] = {
        name = tool.name,
        description = desc,
        parameters = {
          type = "object",
          properties = properties,
          required = required,
        },
      },
    }
  end

  return formatted
end
|
||||
|
||||
--- Parse tool calls out of an LLM response table.
---
--- Handles two wire formats:
---  * OpenAI: `response.tool_calls`, where each entry's arguments may be
---    a JSON-encoded string (decoded here; left as-is if decoding fails).
---  * Claude: `response.content` as a list of blocks, collecting those
---    with `type == "tool_use"`.
---@param response table LLM response
---@return table[] tool_calls Normalized entries of shape { id, name, input }
local function parse_tool_calls(response)
  local calls = {}

  if response.tool_calls then
    -- OpenAI format.
    for _, entry in ipairs(response.tool_calls) do
      local fn = entry["function"]
      local input = fn.arguments
      if type(input) == "string" then
        local ok, decoded = pcall(vim.json.decode, input)
        if ok then
          input = decoded
        end
      end
      calls[#calls + 1] = {
        id = entry.id,
        name = fn.name,
        input = input,
      }
    end
  elseif response.content and type(response.content) == "table" then
    -- Claude format (content blocks).
    for _, block in ipairs(response.content) do
      if block.type == "tool_use" then
        calls[#calls + 1] = {
          id = block.id,
          name = block.name,
          input = block.input,
        }
      end
    end
  end

  return calls
end
|
||||
|
||||
--- Build provider-ready messages from the conversation history.
---
--- System/user messages pass through, assistant messages carry their
--- tool_calls when present, and tool messages get their content
--- JSON-encoded when it is not already a string. Messages with an
--- unrecognized role are dropped.
---@param history AgentMessage[]
---@return table[]
local function build_messages(history)
  local out = {}

  for _, entry in ipairs(history) do
    local role = entry.role

    if role == "system" or role == "user" then
      out[#out + 1] = { role = role, content = entry.content }
    elseif role == "assistant" then
      local msg = { role = "assistant", content = entry.content }
      if entry.tool_calls then
        msg.tool_calls = entry.tool_calls
      end
      out[#out + 1] = msg
    elseif role == "tool" then
      -- Tool results must be serialized for the wire format.
      local content = entry.content
      if type(content) ~= "string" then
        content = vim.json.encode(content)
      end
      out[#out + 1] = {
        role = "tool",
        tool_call_id = entry.tool_call_id,
        content = content,
      }
    end
  end

  return out
end
|
||||
|
||||
--- Execute the agent loop.
---
--- Drives the iterative tool-use conversation: sends the accumulated
--- history to the LLM, extracts any tool calls from the response,
--- executes them, appends the results to the history, and repeats until
--- the model answers without tool calls or `max_iterations` is reached.
--- Every terminal path reports through `opts.on_complete`.
---@param opts AgentLoopOpts
function M.run(opts)
  local tools_mod = require("codetyper.agent.tools")
  local llm = require("codetyper.llm")

  -- Resolve the toolset and index it by name for dispatch.
  local tools = opts.tools or tools_mod.list()
  local tool_map = {}
  for _, tool in ipairs(tools) do
    tool_map[tool.name] = tool
  end

  -- Conversation history, seeded with system prompt and user input.
  ---@type AgentMessage[]
  local history = {
    { role = "system", content = opts.system_prompt },
    { role = "user", content = opts.user_input },
  }

  local session_ctx = opts.session_ctx or {}
  local max_iterations = opts.max_iterations or 10
  local iteration = 0

  -- Safe wrapper around the optional on_message observer.
  local function on_message(msg)
    if opts.on_message then
      opts.on_message(msg)
    end
  end

  -- Notify observers of the seed messages.
  for _, msg in ipairs(history) do
    on_message(msg)
  end

  if opts.on_start then
    opts.on_start()
  end

  --- Process one iteration of the loop.
  local function process_iteration()
    iteration = iteration + 1

    if iteration > max_iterations then
      if opts.on_complete then
        opts.on_complete(nil, "Max iterations reached")
      end
      return
    end

    local messages = build_messages(history)
    local formatted_tools = format_tools_for_api(tools)

    -- Context handed to the LLM client alongside the prompt.
    local context = {
      file_content = "",
      language = "lua",
      extension = "lua",
      prompt_type = "agent",
      tools = formatted_tools,
    }

    local client = llm.get_client()
    if not client then
      if opts.on_complete then
        opts.on_complete(nil, "No LLM client available")
      end
      return
    end

    -- Flatten the non-system messages into a plain-text prompt.
    local prompt_parts = {}
    for _, msg in ipairs(messages) do
      if msg.role ~= "system" then
        table.insert(prompt_parts, string.format("[%s]: %s", msg.role, msg.content or ""))
      end
    end
    local prompt = table.concat(prompt_parts, "\n\n")

    client.generate(prompt, context, function(response, error)
      if error then
        if opts.on_complete then
          opts.on_complete(nil, error)
        end
        return
      end

      if opts.on_chunk then
        opts.on_chunk(response)
      end

      -- Heuristic tool-call detection: look for a ```json fenced block
      -- containing a { "tool_calls": ... } object in the text response.
      -- A full implementation would consume structured tool calls.
      local tool_calls = {}
      local json_match = response:match("```json%s*(%b{})%s*```")
      if json_match then
        local ok, parsed = pcall(vim.json.decode, json_match)
        if ok and parsed.tool_calls then
          tool_calls = parsed.tool_calls
        end
      end

      local assistant_msg = {
        role = "assistant",
        content = response,
        tool_calls = #tool_calls > 0 and tool_calls or nil,
      }
      table.insert(history, assistant_msg)
      on_message(assistant_msg)

      if #tool_calls == 0 then
        -- No tool calls: the model is done.
        if opts.on_complete then
          opts.on_complete(response, nil)
        end
        return
      end

      local pending = #tool_calls
      local results = {}

      -- Record one tool outcome, notify observers, extend history, and
      -- continue the loop once every pending call has finished. Shared
      -- by the success, error, and unknown-tool paths so the loop can
      -- never reach pending == 0 without a continuation being scheduled.
      local function finish_call(i, call, result, err)
        results[i] = { result = result, error = err }
        pending = pending - 1

        if opts.on_tool_result then
          opts.on_tool_result(call.name, result, err)
        end

        -- Feed the tool outcome (or error text) back to the model.
        local tool_msg = {
          role = "tool",
          tool_call_id = call.id or tostring(i),
          name = call.name,
          content = err or result,
        }
        table.insert(history, tool_msg)
        on_message(tool_msg)

        if pending == 0 then
          vim.schedule(process_iteration)
        end
      end

      for i, call in ipairs(tool_calls) do
        local tool = tool_map[call.name]
        if not tool then
          -- BUGFIX: previously an unknown tool only decremented `pending`
          -- without ever scheduling the next iteration, so a response
          -- whose calls were all unknown stalled the loop forever (and
          -- the model never learned the tool name was invalid).
          finish_call(i, call, nil, "Unknown tool: " .. call.name)
        else
          if opts.on_tool_call then
            opts.on_tool_call(call.name, call.input)
          end

          local tool_opts = {
            on_log = function(msg)
              pcall(function()
                local logs = require("codetyper.agent.logs")
                logs.add({ type = "tool", message = msg })
              end)
            end,
            on_complete = function(result, err)
              finish_call(i, call, result, err)
            end,
            session_ctx = session_ctx,
          }

          -- Validate input against the tool's schema before executing.
          local valid, validation_err = true, nil
          if tool.validate_input then
            valid, validation_err = tool:validate_input(call.input)
          end

          if not valid then
            tool_opts.on_complete(nil, validation_err)
          else
            local result, err = tool.func(call.input, tool_opts)
            -- Synchronous tools return directly; async ones invoke
            -- tool_opts.on_complete themselves and return nothing.
            if result ~= nil or err ~= nil then
              tool_opts.on_complete(result, err)
            end
          end
        end
      end
    end)
  end

  -- Start the loop.
  process_iteration()
end
|
||||
|
||||
--- Create an agent with default settings.
---
--- Convenience wrapper around `M.run` that supplies a general-purpose
--- coding-assistant system prompt unless the caller provides one.
---@param task string Task description
---@param opts? AgentLoopOpts Additional options
function M.create(task, opts)
  opts = opts or {}

  local default_prompt = [[You are a helpful coding assistant with access to tools.

Available tools:
- view: Read file contents
- grep: Search for patterns in files
- glob: Find files by pattern
- edit: Make targeted edits to files
- write: Create or overwrite files
- bash: Execute shell commands

When you need to perform a task:
1. Use tools to gather information
2. Plan your approach
3. Execute changes using appropriate tools
4. Verify the results

Always explain your reasoning before using tools.
When you're done, provide a clear summary of what was accomplished.]]

  local run_opts = vim.tbl_extend("force", opts, {
    system_prompt = opts.system_prompt or default_prompt,
    user_input = task,
  })
  M.run(run_opts)
end
|
||||
|
||||
--- Simple dispatch agent for sub-tasks.
---
--- Runs a research-oriented sub-agent restricted (by default) to
--- read-only tools, with a lower iteration cap than the main agent.
---@param prompt string Task for the sub-agent
---@param on_complete fun(result: string|nil, error: string|nil) Completion callback
---@param opts? table Additional options
function M.dispatch(prompt, on_complete, opts)
  opts = opts or {}

  -- Sub-agents get a read-only toolset unless the caller overrides it.
  local tools_mod = require("codetyper.agent.tools")
  local read_only = { view = true, grep = true, glob = true }
  local safe_tools = tools_mod.list(function(tool)
    return read_only[tool.name] == true
  end)

  M.run({
    system_prompt = [[You are a research assistant. Your task is to find information and report back.
You have access to: view (read files), grep (search content), glob (find files).
Be thorough and report your findings clearly.]],
    user_input = prompt,
    tools = opts.tools or safe_tools,
    max_iterations = opts.max_iterations or 5,
    on_complete = on_complete,
    session_ctx = opts.session_ctx,
  })
end
|
||||
|
||||
return M
|
||||
@@ -2,10 +2,16 @@
|
||||
---@brief [[
|
||||
--- Manages code patches with buffer snapshots for staleness detection.
|
||||
--- Patches are queued for safe injection when completion popup is not visible.
|
||||
--- Uses smart injection for intelligent import merging.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Lazy load inject module to avoid circular requires
|
||||
local function get_inject_module()
|
||||
return require("codetyper.agent.inject")
|
||||
end
|
||||
|
||||
---@class BufferSnapshot
|
||||
---@field bufnr number Buffer number
|
||||
---@field changedtick number vim.b.changedtick at snapshot time
|
||||
@@ -15,7 +21,8 @@ local M = {}
|
||||
---@class PatchCandidate
|
||||
---@field id string Unique patch ID
|
||||
---@field event_id string Related PromptEvent ID
|
||||
---@field target_bufnr number Target buffer for injection
|
||||
---@field source_bufnr number Source buffer where prompt tags are (coder file)
|
||||
---@field target_bufnr number Target buffer for injection (real file)
|
||||
---@field target_path string Target file path
|
||||
---@field original_snapshot BufferSnapshot Snapshot at event creation
|
||||
---@field generated_code string Code to inject
|
||||
@@ -171,7 +178,10 @@ end
|
||||
---@param strategy string|nil Injection strategy (overrides intent-based)
|
||||
---@return PatchCandidate
|
||||
function M.create_from_event(event, generated_code, confidence, strategy)
|
||||
-- Get target buffer
|
||||
-- Source buffer is where the prompt tags are (could be coder file)
|
||||
local source_bufnr = event.bufnr
|
||||
|
||||
-- Get target buffer (where code should be injected - the real file)
|
||||
local target_bufnr = vim.fn.bufnr(event.target_path)
|
||||
if target_bufnr == -1 then
|
||||
-- Try to find by filename
|
||||
@@ -220,7 +230,8 @@ function M.create_from_event(event, generated_code, confidence, strategy)
|
||||
return {
|
||||
id = M.generate_id(),
|
||||
event_id = event.id,
|
||||
target_bufnr = target_bufnr,
|
||||
source_bufnr = source_bufnr, -- Where prompt tags are (coder file)
|
||||
target_bufnr = target_bufnr, -- Where code goes (real file)
|
||||
target_path = event.target_path,
|
||||
original_snapshot = snapshot,
|
||||
generated_code = generated_code,
|
||||
@@ -453,39 +464,56 @@ function M.apply(patch)
|
||||
-- Prepare code lines
|
||||
local code_lines = vim.split(patch.generated_code, "\n", { plain = true })
|
||||
|
||||
-- FIRST: Remove the prompt tags from the buffer before applying code
|
||||
-- This prevents the infinite loop where tags stay and get re-detected
|
||||
local tags_removed = remove_prompt_tags(target_bufnr)
|
||||
-- FIRST: Remove the prompt tags from the SOURCE buffer (coder file), not target
|
||||
-- The tags are in the coder file where the user wrote the prompt
|
||||
-- Code goes to target file, tags get removed from source file
|
||||
local source_bufnr = patch.source_bufnr
|
||||
local tags_removed = 0
|
||||
|
||||
pcall(function()
|
||||
if tags_removed > 0 then
|
||||
local logs = require("codetyper.agent.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Removed %d prompt tag(s) from buffer", tags_removed),
|
||||
})
|
||||
end
|
||||
end)
|
||||
if source_bufnr and vim.api.nvim_buf_is_valid(source_bufnr) then
|
||||
tags_removed = remove_prompt_tags(source_bufnr)
|
||||
|
||||
-- Recalculate line count after tag removal
|
||||
local line_count = vim.api.nvim_buf_line_count(target_bufnr)
|
||||
pcall(function()
|
||||
if tags_removed > 0 then
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local source_name = vim.api.nvim_buf_get_name(source_bufnr)
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Removed %d prompt tag(s) from %s",
|
||||
tags_removed,
|
||||
vim.fn.fnamemodify(source_name, ":t")),
|
||||
})
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
-- Apply based on strategy
|
||||
-- Get filetype for smart injection
|
||||
local filetype = vim.fn.fnamemodify(patch.target_path or "", ":e")
|
||||
|
||||
-- Use smart injection module for intelligent import handling
|
||||
local inject = get_inject_module()
|
||||
local inject_result = nil
|
||||
|
||||
-- Apply based on strategy using smart injection
|
||||
local ok, err = pcall(function()
|
||||
-- Prepare injection options
|
||||
local inject_opts = {
|
||||
strategy = patch.injection_strategy,
|
||||
filetype = filetype,
|
||||
sort_imports = true,
|
||||
}
|
||||
|
||||
if patch.injection_strategy == "replace" and patch.injection_range then
|
||||
-- Replace the scope range with the new code
|
||||
-- The injection_range points to the function/method we're completing
|
||||
local start_line = patch.injection_range.start_line
|
||||
local end_line = patch.injection_range.end_line
|
||||
|
||||
-- Adjust for tag removal - find the new range by searching for the scope
|
||||
-- After removing tags, line numbers may have shifted
|
||||
-- Use the scope information to find the correct range
|
||||
if patch.scope and patch.scope.type then
|
||||
-- Try to find the scope using treesitter if available
|
||||
local found_range = nil
|
||||
pcall(function()
|
||||
local ts_utils = require("nvim-treesitter.ts_utils")
|
||||
local parsers = require("nvim-treesitter.parsers")
|
||||
local parser = parsers.get_parser(target_bufnr)
|
||||
if parser then
|
||||
@@ -528,34 +556,38 @@ function M.apply(patch)
|
||||
end
|
||||
|
||||
-- Clamp to valid range
|
||||
local line_count = vim.api.nvim_buf_line_count(target_bufnr)
|
||||
start_line = math.max(1, start_line)
|
||||
end_line = math.min(line_count, end_line)
|
||||
|
||||
-- Replace the range (0-indexed for nvim_buf_set_lines)
|
||||
vim.api.nvim_buf_set_lines(target_bufnr, start_line - 1, end_line, false, code_lines)
|
||||
inject_opts.range = { start_line = start_line, end_line = end_line }
|
||||
elseif patch.injection_strategy == "insert" and patch.injection_range then
|
||||
inject_opts.range = { start_line = patch.injection_range.start_line }
|
||||
end
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
-- Use smart injection - handles imports automatically
|
||||
inject_result = inject.inject(target_bufnr, patch.generated_code, inject_opts)
|
||||
|
||||
-- Log injection details
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
if inject_result.imports_added > 0 then
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Replacing lines %d-%d with %d lines of code", start_line, end_line, #code_lines),
|
||||
message = string.format(
|
||||
"%s %d import(s), injected %d body line(s)",
|
||||
inject_result.imports_merged and "Merged" or "Added",
|
||||
inject_result.imports_added,
|
||||
inject_result.body_lines
|
||||
),
|
||||
})
|
||||
else
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Injected %d line(s) of code", inject_result.body_lines),
|
||||
})
|
||||
end)
|
||||
elseif patch.injection_strategy == "insert" and patch.injection_range then
|
||||
-- Insert at the specified location
|
||||
local insert_line = patch.injection_range.start_line
|
||||
insert_line = math.max(1, math.min(line_count + 1, insert_line))
|
||||
vim.api.nvim_buf_set_lines(target_bufnr, insert_line - 1, insert_line - 1, false, code_lines)
|
||||
else
|
||||
-- Default: append to end
|
||||
-- Check if last line is empty, if not add a blank line for spacing
|
||||
local last_line = vim.api.nvim_buf_get_lines(target_bufnr, line_count - 1, line_count, false)[1] or ""
|
||||
if last_line:match("%S") then
|
||||
-- Last line has content, add blank line for spacing
|
||||
table.insert(code_lines, 1, "")
|
||||
end
|
||||
vim.api.nvim_buf_set_lines(target_bufnr, line_count, line_count, false, code_lines)
|
||||
end
|
||||
end)
|
||||
end)
|
||||
|
||||
if not ok then
|
||||
@@ -577,6 +609,41 @@ function M.apply(patch)
|
||||
})
|
||||
end)
|
||||
|
||||
-- Learn from successful code generation - this builds neural pathways
|
||||
-- The more code is successfully applied, the better the brain becomes
|
||||
pcall(function()
|
||||
local brain = require("codetyper.brain")
|
||||
if brain.is_initialized() then
|
||||
-- Learn the successful pattern
|
||||
local intent_type = patch.intent and patch.intent.type or "unknown"
|
||||
local scope_type = patch.scope and patch.scope.type or "file"
|
||||
local scope_name = patch.scope and patch.scope.name or ""
|
||||
|
||||
-- Create a meaningful summary for this learning
|
||||
local summary = string.format(
|
||||
"Generated %s: %s %s in %s",
|
||||
intent_type,
|
||||
scope_type,
|
||||
scope_name ~= "" and scope_name or "",
|
||||
vim.fn.fnamemodify(patch.target_path or "", ":t")
|
||||
)
|
||||
|
||||
brain.learn({
|
||||
type = "code_completion",
|
||||
file = patch.target_path,
|
||||
timestamp = os.time(),
|
||||
data = {
|
||||
intent = intent_type,
|
||||
code = patch.generated_code:sub(1, 500), -- Store first 500 chars
|
||||
language = vim.fn.fnamemodify(patch.target_path or "", ":e"),
|
||||
function_name = scope_name,
|
||||
prompt = patch.prompt_content,
|
||||
confidence = patch.confidence or 0.5,
|
||||
},
|
||||
})
|
||||
end
|
||||
end)
|
||||
|
||||
return true, nil
|
||||
end
|
||||
|
||||
|
||||
229
lua/codetyper/agent/permissions.lua
Normal file
229
lua/codetyper/agent/permissions.lua
Normal file
@@ -0,0 +1,229 @@
|
||||
---@mod codetyper.agent.permissions Permission manager for agent actions
---
--- Manages permissions for bash commands and file operations with
--- allow, allow-session, allow-list, and reject options.

local M = {}

---@class PermissionState
---@field session_allowed table<string, boolean> Commands allowed for this session
---@field allow_list table<string, boolean> Patterns always allowed
---@field deny_list table<string, boolean> Patterns always denied

-- Module-local mutable state. Cleared via M.clear_session() / M.reset().
local state = {
  session_allowed = {},
  allow_list = {},
  deny_list = {},
}

--- Dangerous command patterns that should never be auto-allowed.
--- NOTE: these are Lua patterns (not regexes) matched against the
--- trimmed command string; `%-` escapes a literal dash, `%s` is
--- whitespace. `|` has no alternation meaning in Lua patterns — in
--- the curl/wget entries it matches a literal shell pipe.
local DANGEROUS_PATTERNS = {
  "^rm%s+%-rf",
  "^rm%s+%-r%s+/",
  "^rm%s+/",
  "^sudo%s+rm",
  "^chmod%s+777",
  "^chmod%s+%-R",
  "^chown%s+%-R",
  "^dd%s+",
  "^mkfs",
  "^fdisk",
  "^format",
  ":.*>%s*/dev/",
  "^curl.*|.*sh",
  "^wget.*|.*sh",
  "^eval%s+",
  "`;.*`",
  "%$%(.*%)",
  "fork%s*bomb",
}

--- Safe command patterns that can be auto-allowed.
--- These are all read-only or side-effect-light commands; matching any
--- one of them short-circuits the approval prompt.
local SAFE_PATTERNS = {
  "^ls%s",
  "^ls$",
  "^cat%s",
  "^head%s",
  "^tail%s",
  "^grep%s",
  "^find%s",
  "^pwd$",
  "^echo%s",
  "^wc%s",
  "^which%s",
  "^type%s",
  "^file%s",
  "^stat%s",
  "^git%s+status",
  "^git%s+log",
  "^git%s+diff",
  "^git%s+branch",
  "^git%s+show",
  "^npm%s+list",
  "^npm%s+ls",
  "^npm%s+outdated",
  "^yarn%s+list",
  "^cargo%s+check",
  "^cargo%s+test",
  "^go%s+test",
  "^go%s+build",
  "^make%s+test",
  "^make%s+check",
}

---@alias PermissionLevel "allow"|"allow_session"|"allow_list"|"reject"

---@class PermissionResult
---@field allowed boolean Whether action is allowed
---@field reason string Reason for the decision
---@field auto boolean Whether this was an automatic decision

--- Check if a command matches a pattern.
---@param command string The command to check
---@param pattern string The Lua pattern to match
---@return boolean
local function matches_pattern(command, pattern)
  return command:match(pattern) ~= nil
end

--- Check if command matches any dangerous pattern.
---@param command string The command to check
---@return boolean, string|nil dangerous, reason
local function is_dangerous(command)
  for _, pattern in ipairs(DANGEROUS_PATTERNS) do
    if matches_pattern(command, pattern) then
      return true, "Matches dangerous pattern: " .. pattern
    end
  end
  return false, nil
end

--- Check if command matches any safe (read-only) pattern.
---@param command string The command to check
---@return boolean
local function is_safe(command)
  for _, pattern in ipairs(SAFE_PATTERNS) do
    if matches_pattern(command, pattern) then
      return true
    end
  end
  return false
end

--- Normalize command for comparison.
--- Currently this only trims surrounding whitespace (no case folding),
--- so session grants are exact, case-sensitive matches.
---@param command string
---@return string
local function normalize_command(command)
  return vim.trim(command)
end

--- Check permission for a bash command.
--- Precedence: deny list > dangerous patterns > session grants >
--- allow-list patterns > built-in safe patterns > require approval.
---@param command string The command to check
---@return PermissionResult
function M.check_bash_permission(command)
  local normalized = normalize_command(command)

  -- Deny list wins over everything else.
  for pattern, _ in pairs(state.deny_list) do
    if matches_pattern(normalized, pattern) then
      return {
        allowed = false,
        reason = "Command in deny list",
        auto = true,
      }
    end
  end

  -- Dangerous commands are never auto-allowed, even by session grants.
  local dangerous, reason = is_dangerous(normalized)
  if dangerous then
    return {
      allowed = false,
      reason = reason,
      auto = false, -- Require explicit approval for dangerous commands
    }
  end

  -- Exact-match grant made earlier in this session.
  if state.session_allowed[normalized] then
    return {
      allowed = true,
      reason = "Allowed for this session",
      auto = true,
    }
  end

  -- User-configured allow-list patterns.
  for pattern, _ in pairs(state.allow_list) do
    if matches_pattern(normalized, pattern) then
      return {
        allowed = true,
        reason = "Matches allow list pattern",
        auto = true,
      }
    end
  end

  -- Built-in read-only commands pass without prompting.
  if is_safe(normalized) then
    return {
      allowed = true,
      reason = "Safe read-only command",
      auto = true,
    }
  end

  -- Everything else needs an explicit user decision.
  return {
    allowed = false,
    reason = "Requires approval",
    auto = false,
  }
end

--- Grant permission for a command.
--- "allow_session" stores the exact command; "allow_list" stores an
--- escaped exact-match pattern so future identical commands auto-pass.
--- Other levels (e.g. one-shot "allow" or "reject") store nothing.
---@param command string The command
---@param level PermissionLevel The permission level
function M.grant_permission(command, level)
  local normalized = normalize_command(command)

  if level == "allow_session" then
    state.session_allowed[normalized] = true
  elseif level == "allow_list" then
    -- Add as pattern (vim.pesc escapes Lua-pattern magic characters,
    -- anchored on both ends for an exact match).
    local pattern = "^" .. vim.pesc(normalized) .. "$"
    state.allow_list[pattern] = true
  end
end

--- Add a raw pattern to the allow list.
--- NOTE: the pattern is used as-is (no escaping) — callers must supply
--- a valid Lua pattern.
---@param pattern string Lua pattern to allow
function M.add_to_allow_list(pattern)
  state.allow_list[pattern] = true
end

--- Add a raw pattern to the deny list.
--- Deny-list entries take precedence over all allow mechanisms.
---@param pattern string Lua pattern to deny
function M.add_to_deny_list(pattern)
  state.deny_list[pattern] = true
end

--- Clear per-session command grants (allow/deny lists are kept).
function M.clear_session()
  state.session_allowed = {}
end

--- Reset all permissions, including allow and deny lists.
function M.reset()
  state.session_allowed = {}
  state.allow_list = {}
  state.deny_list = {}
end

--- Get a snapshot of the current permission state (for debugging).
--- Returns a deep copy so callers cannot mutate internal state.
---@return PermissionState
function M.get_state()
  return vim.deepcopy(state)
end

return M
|
||||
@@ -23,7 +23,7 @@ local M = {}
|
||||
---@field priority number Priority (1=high, 2=normal, 3=low)
|
||||
---@field status string "pending"|"processing"|"completed"|"escalated"|"cancelled"|"needs_context"|"failed"
|
||||
---@field attempt_count number Number of processing attempts
|
||||
---@field worker_type string|nil LLM provider used ("ollama"|"claude"|etc)
|
||||
---@field worker_type string|nil LLM provider used ("ollama"|"openai"|"gemini"|"copilot")
|
||||
---@field created_at number System time when created
|
||||
---@field intent Intent|nil Detected intent from prompt
|
||||
---@field scope ScopeInfo|nil Resolved scope (function/class/file)
|
||||
|
||||
@@ -28,7 +28,7 @@ local state = {
|
||||
max_concurrent = 2,
|
||||
completion_delay_ms = 100,
|
||||
apply_delay_ms = 5000, -- Wait before applying code
|
||||
remote_provider = "claude", -- Default fallback provider
|
||||
remote_provider = "copilot", -- Default fallback provider
|
||||
},
|
||||
}
|
||||
|
||||
@@ -90,9 +90,7 @@ local function get_remote_provider()
|
||||
-- If current provider is ollama, use configured remote
|
||||
if config.llm.provider == "ollama" then
|
||||
-- Check which remote provider is configured
|
||||
if config.llm.claude and config.llm.claude.api_key then
|
||||
return "claude"
|
||||
elseif config.llm.openai and config.llm.openai.api_key then
|
||||
if config.llm.openai and config.llm.openai.api_key then
|
||||
return "openai"
|
||||
elseif config.llm.gemini and config.llm.gemini.api_key then
|
||||
return "gemini"
|
||||
@@ -120,7 +118,7 @@ local function get_primary_provider()
|
||||
return config.llm.provider
|
||||
end
|
||||
end
|
||||
return "claude"
|
||||
return "ollama"
|
||||
end
|
||||
|
||||
--- Retry event with additional context
|
||||
|
||||
@@ -81,6 +81,67 @@ M.definitions = {
|
||||
required = { "command" },
|
||||
},
|
||||
},
|
||||
|
||||
delete_file = {
|
||||
name = "delete_file",
|
||||
description = "Delete a file from the filesystem. Use with caution - requires explicit user approval.",
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = {
|
||||
path = {
|
||||
type = "string",
|
||||
description = "Path to the file to delete",
|
||||
},
|
||||
reason = {
|
||||
type = "string",
|
||||
description = "Reason for deleting this file (shown to user for approval)",
|
||||
},
|
||||
},
|
||||
required = { "path", "reason" },
|
||||
},
|
||||
},
|
||||
|
||||
list_directory = {
|
||||
name = "list_directory",
|
||||
description = "List files and directories in a path. Use to explore project structure.",
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = {
|
||||
path = {
|
||||
type = "string",
|
||||
description = "Path to the directory to list (defaults to current directory)",
|
||||
},
|
||||
recursive = {
|
||||
type = "boolean",
|
||||
description = "Whether to list recursively (default: false, max depth: 3)",
|
||||
},
|
||||
},
|
||||
required = {},
|
||||
},
|
||||
},
|
||||
|
||||
search_files = {
|
||||
name = "search_files",
|
||||
description = "Search for files by name pattern or content. Use to find relevant files in the project.",
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = {
|
||||
pattern = {
|
||||
type = "string",
|
||||
description = "Glob pattern for file names (e.g., '*.lua', 'test_*.py')",
|
||||
},
|
||||
content = {
|
||||
type = "string",
|
||||
description = "Search for files containing this text",
|
||||
},
|
||||
path = {
|
||||
type = "string",
|
||||
description = "Directory to search in (defaults to project root)",
|
||||
},
|
||||
},
|
||||
required = {},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
--- Convert tool definitions to Claude API format
|
||||
|
||||
128
lua/codetyper/agent/tools/base.lua
Normal file
128
lua/codetyper/agent/tools/base.lua
Normal file
@@ -0,0 +1,128 @@
|
||||
---@mod codetyper.agent.tools.base Base tool definition
|
||||
---@brief [[
|
||||
--- Base metatable for all LLM tools.
|
||||
--- Tools extend this base to provide structured AI capabilities.
|
||||
---@brief ]]
|
||||
|
||||
---@class CoderToolParam
|
||||
---@field name string Parameter name
|
||||
---@field description string Parameter description
|
||||
---@field type string Parameter type ("string", "number", "boolean", "table")
|
||||
---@field optional? boolean Whether the parameter is optional
|
||||
---@field default? any Default value for optional parameters
|
||||
|
||||
---@class CoderToolReturn
|
||||
---@field name string Return value name
|
||||
---@field description string Return value description
|
||||
---@field type string Return type
|
||||
---@field optional? boolean Whether the return is optional
|
||||
|
||||
---@class CoderToolOpts
|
||||
---@field on_log? fun(message: string) Log callback
|
||||
---@field on_complete? fun(result: any, error: string|nil) Completion callback
|
||||
---@field session_ctx? table Session context
|
||||
---@field streaming? boolean Whether response is still streaming
|
||||
---@field confirm? fun(message: string, callback: fun(ok: boolean)) Confirmation callback
|
||||
|
||||
---@class CoderTool
|
||||
---@field name string Tool identifier
|
||||
---@field description string|fun(): string Tool description
|
||||
---@field params CoderToolParam[] Input parameters
|
||||
---@field returns CoderToolReturn[] Return values
|
||||
---@field requires_confirmation? boolean Whether tool needs user confirmation
|
||||
---@field func fun(input: table, opts: CoderToolOpts): any, string|nil Tool implementation
|
||||
|
||||
local M = {}
|
||||
M.__index = M
|
||||
|
||||
--- Call the tool function
|
||||
---@param opts CoderToolOpts Options for the tool call
|
||||
---@return any result
|
||||
---@return string|nil error
|
||||
function M:__call(opts, on_log, on_complete)
|
||||
return self.func(opts, on_log, on_complete)
|
||||
end
|
||||
|
||||
--- Resolve the tool description, invoking it when provided as a function.
---@return string
function M:get_description()
  local desc = self.description
  if type(desc) ~= "function" then
    return desc
  end
  return desc()
end
|
||||
|
||||
--- Check an input table against this tool's declared parameter schema.
--- Verifies presence of required parameters and type compatibility of
--- supplied values ("integer" accepts whole numbers, "any" accepts all).
---@param input table Input to validate
---@return boolean valid
---@return string|nil error
function M:validate_input(input)
  local params = self.params
  if params == nil then
    -- No schema declared: anything goes.
    return true
  end

  for _, spec in ipairs(params) do
    local value = input[spec.name]

    if value == nil then
      -- Absence is only acceptable for optional parameters.
      if not spec.optional then
        return false, string.format("Missing required parameter: %s", spec.name)
      end
    else
      local got = type(value)
      local want = spec.type

      if want == "integer" and got == "number" then
        -- Lua has no integer type; require the number to be whole.
        if math.floor(value) ~= value then
          return false, string.format("Parameter %s must be an integer", spec.name)
        end
      elseif want ~= "any" and want ~= got then
        return false, string.format("Parameter %s must be %s, got %s", spec.name, want, got)
      end
    end
  end

  return true
end
|
||||
|
||||
--- Build a JSON-schema-style descriptor of this tool for LLM function calling.
---@return table schema
function M:to_schema()
  local properties, required = {}, {}

  for _, p in ipairs(self.params or {}) do
    -- The schema vocabulary has no "integer" type; expose it as "number".
    local prop = {
      type = p.type == "integer" and "number" or p.type,
      description = p.description,
    }
    if p.default ~= nil then
      prop.default = p.default
    end
    properties[p.name] = prop

    if not p.optional then
      required[#required + 1] = p.name
    end
  end

  return {
    type = "function",
    function_def = {
      name = self.name,
      description = self:get_description(),
      parameters = {
        type = "object",
        properties = properties,
        required = required,
      },
    },
  }
end

-- Export the base tool metatable.
return M
|
||||
198
lua/codetyper/agent/tools/bash.lua
Normal file
198
lua/codetyper/agent/tools/bash.lua
Normal file
@@ -0,0 +1,198 @@
|
||||
---@mod codetyper.agent.tools.bash Shell command execution tool
|
||||
---@brief [[
|
||||
--- Tool for executing shell commands with safety checks.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
-- Tool table inheriting shared behavior (__call, validate_input, to_schema)
-- from the base metatable.
---@class CoderTool
local M = setmetatable({}, Base)

M.name = "bash"

-- Shown to the LLM so it knows when (and when not) to pick this tool.
M.description = [[Executes a bash command in a shell.

IMPORTANT RULES:
- Do NOT use bash to read files (use 'view' tool instead)
- Do NOT use bash to modify files (use 'write' or 'edit' tools instead)
- Do NOT use interactive commands (vim, nano, less, etc.)
- Commands timeout after 2 minutes by default

Allowed uses:
- Running builds (make, npm run build, cargo build)
- Running tests (npm test, pytest, cargo test)
- Git operations (git status, git diff, git commit)
- Package management (npm install, pip install)
- System info commands (ls, pwd, which)]]

-- Input schema consumed by Base:validate_input and Base:to_schema.
M.params = {
  {
    name = "command",
    description = "The shell command to execute",
    type = "string",
  },
  {
    name = "cwd",
    description = "Working directory for the command (optional)",
    type = "string",
    optional = true,
  },
  {
    name = "timeout",
    description = "Timeout in milliseconds (default: 120000)",
    type = "integer",
    optional = true,
  },
}

-- Declared outputs (informational; not enforced at runtime).
M.returns = {
  {
    name = "stdout",
    description = "Command output",
    type = "string",
  },
  {
    name = "error",
    description = "Error message if command failed",
    type = "string",
    optional = true,
  },
}

-- Shell execution is inherently destructive; always ask the user first.
M.requires_confirmation = true
|
||||
|
||||
-- Exact command strings that are never allowed to run.
local BANNED_COMMANDS = {
  "rm -rf /",
  "rm -rf /*",
  "dd if=/dev/zero",
  "mkfs",
  ":(){ :|:& };:",
  "> /dev/sda",
}

-- Lua patterns matched against the command; any hit blocks execution.
local BANNED_PATTERNS = {
  "curl.*|.*sh",
  "wget.*|.*sh",
  "rm%s+%-rf%s+/",
}

--- Decide whether a shell command may be executed.
---@param command string
---@return boolean safe
---@return string|nil reason
local function is_safe_command(command)
  -- Reject exact matches against the deny list first.
  for _, blocked in ipairs(BANNED_COMMANDS) do
    if blocked == command then
      return false, "Command is banned for safety"
    end
  end

  -- Then scan for dangerous patterns anywhere in the command.
  for _, pat in ipairs(BANNED_PATTERNS) do
    if string.match(command, pat) then
      return false, "Command matches banned pattern"
    end
  end

  return true
end
|
||||
|
||||
--- Execute a shell command via `bash -c`, synchronously, with safety checks.
--- Fixes: removes dead on_stdout/on_stderr/on_exit handlers whose results
--- were discarded and overwritten after job:sync; removes the unused
--- `confirmed` local; converts Job:sync's timeout exception into a normal
--- (nil, error) return.
---@param input {command: string, cwd?: string, timeout?: integer}
---@param opts CoderToolOpts
---@return string|nil result Combined stdout/stderr on success
---@return string|nil error
function M.func(input, opts)
  if not input.command then
    return nil, "command is required"
  end

  -- Refuse obviously destructive commands before doing anything else.
  local safe, reason = is_safe_command(input.command)
  if not safe then
    return nil, reason
  end

  -- Ask the user before running anything. NOTE(review): this assumes the
  -- confirm callback is invoked synchronously; an asynchronous confirm would
  -- fall through without waiting -- TODO confirm with the UI layer.
  if M.requires_confirmation and opts.confirm then
    local declined = false
    opts.confirm("Execute command: " .. input.command, function(ok)
      declined = not ok
    end)
    if declined then
      return nil, "User declined command execution"
    end
  end

  if opts.on_log then
    opts.on_log("Executing: " .. input.command)
  end

  local cwd = input.cwd or vim.fn.getcwd()
  local timeout = input.timeout or 120000

  -- Run the command synchronously through plenary. Stream callbacks are not
  -- needed: stdout/stderr are collected from the job after it finishes.
  local Job = require("plenary.job")
  local job = Job:new({
    command = "bash",
    args = { "-c", input.command },
    cwd = cwd,
  })

  -- Job:sync raises on timeout; convert that into a regular error return.
  local sync_ok = pcall(function()
    job:sync(timeout)
  end)
  if not sync_ok then
    local timeout_msg = string.format("Command timed out after %d ms", timeout)
    if opts.on_complete then
      opts.on_complete(nil, timeout_msg)
    end
    return nil, timeout_msg
  end

  local exit_code = job.code or 0
  local output = table.concat(job:result() or {}, "\n")

  -- Append stderr so failures carry their diagnostics.
  local stderr = table.concat(job:stderr_result() or {}, "\n")
  if stderr ~= "" then
    output = output .. "\n" .. stderr
  end

  if exit_code ~= 0 then
    local error_msg = string.format("Command failed with exit code %d: %s", exit_code, output)
    if opts.on_complete then
      opts.on_complete(nil, error_msg)
    end
    return nil, error_msg
  end

  if opts.on_complete then
    opts.on_complete(output, nil)
  end

  return output, nil
end

return M
|
||||
429
lua/codetyper/agent/tools/edit.lua
Normal file
429
lua/codetyper/agent/tools/edit.lua
Normal file
@@ -0,0 +1,429 @@
|
||||
---@mod codetyper.agent.tools.edit File editing tool with fallback matching
|
||||
---@brief [[
|
||||
--- Tool for making targeted edits to files using search/replace.
|
||||
--- Implements multiple fallback strategies for robust matching.
|
||||
--- Inspired by opencode's 9-strategy approach.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
-- Tool table inheriting shared behavior (__call, validate_input, to_schema)
-- from the base metatable.
---@class CoderTool
local M = setmetatable({}, Base)

M.name = "edit"

-- Shown to the LLM; documents the fallback matching pipeline implemented below.
M.description = [[Makes a targeted edit to a file by replacing text.

The old_string should match the content you want to replace. The tool uses multiple
matching strategies with fallbacks:
1. Exact match
2. Whitespace-normalized match
3. Indentation-flexible match
4. Line-trimmed match
5. Fuzzy anchor-based match

For creating new files, use old_string="" and provide the full content in new_string.
For large changes, consider using 'write' tool instead.]]

-- Input schema consumed by Base:validate_input and Base:to_schema.
M.params = {
  {
    name = "path",
    description = "Path to the file to edit",
    type = "string",
  },
  {
    name = "old_string",
    description = "Text to find and replace (empty string to create new file or append)",
    type = "string",
  },
  {
    name = "new_string",
    description = "Text to replace with",
    type = "string",
  },
}

-- Declared outputs (informational; not enforced at runtime).
M.returns = {
  {
    name = "success",
    description = "Whether the edit was applied",
    type = "boolean",
  },
  {
    name = "error",
    description = "Error message if edit failed",
    type = "string",
    optional = true,
  },
}

-- Edits are applied without prompting the user.
M.requires_confirmation = false
|
||||
|
||||
--- Normalize CRLF and lone CR line endings to LF.
---@param str string
---@return string
local function normalize_line_endings(str)
  -- gsub returns the string plus a substitution count; parenthesize so this
  -- function yields exactly one value and cannot leak the count into a
  -- caller's multiple-value context (e.g. a table constructor or call tail).
  return (str:gsub("\r\n", "\n"):gsub("\r", "\n"))
end
|
||||
|
||||
--- Strategy 1: literal substring search (pattern magic disabled).
---@param content string File content
---@param old_str string String to find
---@return number|nil start_pos
---@return number|nil end_pos
local function exact_match(content, old_str)
  -- Plain find already returns both byte offsets of the first occurrence.
  local s, e = string.find(content, old_str, 1, true)
  if s == nil then
    return nil, nil
  end
  return s, e
end
|
||||
|
||||
--- Strategy 2: Whitespace-normalized match
--- Collapses all whitespace runs to single spaces (and trims the ends) before
--- comparing, so the target is found even when spacing/newlines differ.
---@param content string
---@param old_str string
---@return number|nil start_pos 1-based byte offset of the match start, or nil
---@return number|nil end_pos 1-based byte offset of the match end, or nil
local function whitespace_normalized_match(content, old_str)
  -- Collapse internal whitespace and strip leading/trailing whitespace.
  local function normalize_ws(s)
    return s:gsub("%s+", " "):gsub("^%s+", ""):gsub("%s+$", "")
  end

  local norm_old = normalize_ws(old_str)
  local lines = vim.split(content, "\n")

  -- Grow a candidate block from each starting line until it either matches
  -- the normalized target or becomes longer than it. O(n^2) worst case.
  for i = 1, #lines do
    local block = {}
    local block_start = nil -- NOTE(review): never read; candidate for removal

    for j = i, #lines do
      table.insert(block, lines[j])
      local block_text = table.concat(block, "\n")
      local norm_block = normalize_ws(block_text)

      if norm_block == norm_old then
        -- Found match
        -- Convert line index i to a byte offset: bytes before line i plus 1,
        -- plus one more for the separating newline when i > 1.
        local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
        local start_pos = #before + (i > 1 and 2 or 1)
        local end_pos = start_pos + #block_text - 1
        return start_pos, end_pos
      end

      -- If block is already longer than target, stop
      if #norm_block > #norm_old then
        break
      end
    end
  end

  return nil, nil
end
|
||||
|
||||
--- Strategy 3: Indentation-flexible match
--- Compares sliding windows of the file against old_str with each line's
--- leading whitespace removed, so re-indented code still matches.
--- Fix: `table.insert(result, line:gsub(...))` expanded gsub's two return
--- values (string, count) into a 3-argument table.insert whose second
--- argument must be a numeric position — a guaranteed runtime error whenever
--- this strategy executed. The gsub call is now parenthesized to truncate it
--- to a single value.
---@param content string
---@param old_str string
---@return number|nil start_pos 1-based byte offset of the match start, or nil
---@return number|nil end_pos 1-based byte offset of the match end, or nil
local function indentation_flexible_match(content, old_str)
  -- Strip leading whitespace from every line of a multi-line string.
  local function strip_indent(s)
    local lines = vim.split(s, "\n")
    local result = {}
    for _, line in ipairs(lines) do
      -- Parenthesized: keep only gsub's first return value.
      table.insert(result, (line:gsub("^%s+", "")))
    end
    return table.concat(result, "\n")
  end

  local stripped_old = strip_indent(old_str)
  local lines = vim.split(content, "\n")
  local old_lines = vim.split(old_str, "\n")
  local num_old_lines = #old_lines

  -- Slide a window of num_old_lines lines over the file.
  for i = 1, #lines - num_old_lines + 1 do
    local block = vim.list_slice(lines, i, i + num_old_lines - 1)
    local block_text = table.concat(block, "\n")

    if strip_indent(block_text) == stripped_old then
      -- Byte offsets: bytes before line i plus 1, plus one more for the
      -- separating newline when i > 1.
      local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
      local start_pos = #before + (i > 1 and 2 or 1)
      local end_pos = start_pos + #block_text - 1
      return start_pos, end_pos
    end
  end

  return nil, nil
end
|
||||
|
||||
--- Strategy 4: Line-trimmed match
--- Trims leading and trailing whitespace from each line before comparing, so
--- indentation or trailing-space differences do not prevent a match.
---@param content string
---@param old_str string
---@return number|nil start_pos 1-based byte offset of the match start, or nil
---@return number|nil end_pos 1-based byte offset of the match end, or nil
local function line_trimmed_match(content, old_str)
  -- Produce a copy of `s` with every line stripped of surrounding whitespace.
  local function trim_lines(s)
    local lines = vim.split(s, "\n")
    local result = {}
    for _, line in ipairs(lines) do
      table.insert(result, line:match("^%s*(.-)%s*$"))
    end
    return table.concat(result, "\n")
  end

  local trimmed_old = trim_lines(old_str)
  local lines = vim.split(content, "\n")
  local old_lines = vim.split(old_str, "\n")
  local num_old_lines = #old_lines

  -- Slide a window of num_old_lines lines over the file and compare the
  -- trimmed window against the trimmed target.
  for i = 1, #lines - num_old_lines + 1 do
    local block = vim.list_slice(lines, i, i + num_old_lines - 1)
    local block_text = table.concat(block, "\n")

    if trim_lines(block_text) == trimmed_old then
      -- Convert line index i to a byte offset: bytes before line i plus 1,
      -- plus one more for the separating newline when i > 1.
      local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
      local start_pos = #before + (i > 1 and 2 or 1)
      local end_pos = start_pos + #block_text - 1
      return start_pos, end_pos
    end
  end

  return nil, nil
end
|
||||
|
||||
--- Compute the Levenshtein edit distance between two strings.
--- Uses a rolling two-row dynamic program instead of the full matrix.
---@param s1 string
---@param s2 string
---@return number
local function levenshtein(s1, s2)
  local len1, len2 = #s1, #s2

  -- prev holds DP row i-1; initialize to the distance from the empty prefix.
  local prev = {}
  for j = 0, len2 do
    prev[j] = j
  end

  for i = 1, len1 do
    local curr = { [0] = i }
    local c1 = s1:sub(i, i)
    for j = 1, len2 do
      local subst = prev[j - 1] + (c1 == s2:sub(j, j) and 0 or 1)
      local del = prev[j] + 1
      local ins = curr[j - 1] + 1
      curr[j] = math.min(subst, del, ins)
    end
    prev = curr
  end

  return prev[len2]
end
|
||||
|
||||
--- Strategy 5: Fuzzy anchor-based match
--- Uses the (trimmed) first and last lines of old_str as anchors; an anchor
--- line is accepted when identical or when its Levenshtein similarity to a
--- file line reaches `threshold`. Lines in between are not compared at all,
--- so this is the loosest strategy and runs only after the stricter ones.
---@param content string
---@param old_str string
---@param threshold? number Similarity threshold (0-1), default 0.8
---@return number|nil start_pos 1-based byte offset of the match start, or nil
---@return number|nil end_pos 1-based byte offset of the match end, or nil
local function fuzzy_anchor_match(content, old_str, threshold)
  threshold = threshold or 0.8

  -- A single line has no distinct anchors; let other strategies handle it.
  local old_lines = vim.split(old_str, "\n")
  if #old_lines < 2 then
    return nil, nil
  end

  local first_line = old_lines[1]:match("^%s*(.-)%s*$")
  local last_line = old_lines[#old_lines]:match("^%s*(.-)%s*$")
  local content_lines = vim.split(content, "\n")

  -- Find potential start positions
  -- A line qualifies if it equals the first anchor or is similar enough.
  -- The `#first_line > 0` guard prevents a 0/0 division on blank anchors.
  local candidates = {}
  for i, line in ipairs(content_lines) do
    local trimmed = line:match("^%s*(.-)%s*$")
    if trimmed == first_line or (
      #first_line > 0 and
      1 - (levenshtein(trimmed, first_line) / math.max(#trimmed, #first_line)) >= threshold
    ) then
      table.insert(candidates, i)
    end
  end

  -- For each candidate, look for matching end
  -- The block is assumed to span exactly #old_lines lines from the candidate.
  for _, start_idx in ipairs(candidates) do
    local expected_end = start_idx + #old_lines - 1
    if expected_end <= #content_lines then
      local end_line = content_lines[expected_end]:match("^%s*(.-)%s*$")
      if end_line == last_line or (
        #last_line > 0 and
        1 - (levenshtein(end_line, last_line) / math.max(#end_line, #last_line)) >= threshold
      ) then
        -- Calculate positions
        -- Byte offsets: bytes before the block plus 1, plus one more for the
        -- separating newline when the block does not start at line 1.
        local before = table.concat(vim.list_slice(content_lines, 1, start_idx - 1), "\n")
        local block = table.concat(vim.list_slice(content_lines, start_idx, expected_end), "\n")
        local start_pos = #before + (start_idx > 1 and 2 or 1)
        local end_pos = start_pos + #block - 1
        return start_pos, end_pos
      end
    end
  end

  return nil, nil
end
|
||||
|
||||
--- Run every matching strategy in priority order until one succeeds.
---@param content string File content
---@param old_str string String to find
---@return number|nil start_pos
---@return number|nil end_pos
---@return string strategy_used Name of the winning strategy, or "none"
local function find_match(content, old_str)
  -- Ordered from strictest to loosest; the first hit wins.
  local strategies = {
    { "exact", exact_match },
    { "whitespace_normalized", whitespace_normalized_match },
    { "indentation_flexible", indentation_flexible_match },
    { "line_trimmed", line_trimmed_match },
    { "fuzzy_anchor", fuzzy_anchor_match },
  }

  for _, entry in ipairs(strategies) do
    local label, matcher = entry[1], entry[2]
    local start_pos, end_pos = matcher(content, old_str)
    if start_pos then
      return start_pos, end_pos, label
    end
  end

  return nil, nil, "none"
end
|
||||
|
||||
--- Resolve a possibly-relative path against the current working directory.
---@param path string
---@return string absolute_path
local function resolve_path(path)
  if vim.startswith(path, "/") then
    return path
  end
  return vim.fn.getcwd() .. "/" .. path
end

--- Write lines to a file.
---@param path string
---@param lines string[]
---@return boolean ok false when writefile raised or reported failure
local function write_lines(path, lines)
  local ok, result = pcall(vim.fn.writefile, lines, path)
  -- writefile returns -1 on failure without raising, so check both.
  -- (Previously only the pcall status was checked, silently "succeeding"
  -- on unwritable paths.)
  return ok and result ~= -1
end

--- Reload the buffer for `path` if it is open, so the on-disk edit shows up.
---@param path string
local function reload_buffer(path)
  local bufnr = vim.fn.bufnr(path)
  if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
    vim.api.nvim_buf_call(bufnr, function()
      vim.cmd("edit!")
    end)
  end
end

--- Apply a search/replace edit to a file (or create it when old_string is "").
---@param input {path: string, old_string: string, new_string: string}
---@param opts CoderToolOpts
---@return boolean|nil result true on success
---@return string|nil error
function M.func(input, opts)
  if not input.path then
    return nil, "path is required"
  end
  if input.old_string == nil then
    return nil, "old_string is required"
  end
  if input.new_string == nil then
    return nil, "new_string is required"
  end

  if opts.on_log then
    opts.on_log("Editing file: " .. input.path)
  end

  local path = resolve_path(input.path)

  -- Normalize CRLF/CR so matching operates on consistent line endings.
  local old_str = normalize_line_endings(input.old_string)
  local new_str = normalize_line_endings(input.new_string)

  -- Empty old_string means "create the file with new_string as content".
  if old_str == "" then
    local dir = vim.fn.fnamemodify(path, ":h")
    if vim.fn.isdirectory(dir) == 0 then
      vim.fn.mkdir(dir, "p")
    end

    if not write_lines(path, vim.split(new_str, "\n", { plain = true })) then
      return nil, "Failed to create file: " .. input.path
    end

    reload_buffer(path)

    if opts.on_complete then
      opts.on_complete(true, nil)
    end
    return true, nil
  end

  if vim.fn.filereadable(path) ~= 1 then
    return nil, "File not found: " .. input.path
  end

  local lines = vim.fn.readfile(path)
  if not lines then
    return nil, "Failed to read file: " .. input.path
  end

  local content = normalize_line_endings(table.concat(lines, "\n"))

  -- Locate old_str with the fallback strategy chain.
  local start_pos, end_pos, strategy = find_match(content, old_str)
  if not start_pos then
    return nil, "old_string not found in file (tried 5 matching strategies)"
  end

  if opts.on_log then
    opts.on_log("Match found using strategy: " .. strategy)
  end

  -- Splice the replacement into the matched byte range.
  local new_content = content:sub(1, start_pos - 1) .. new_str .. content:sub(end_pos + 1)

  if not write_lines(path, vim.split(new_content, "\n", { plain = true })) then
    return nil, "Failed to write file: " .. input.path
  end

  reload_buffer(path)

  if opts.on_complete then
    opts.on_complete(true, nil)
  end
  return true, nil
end

return M
|
||||
146
lua/codetyper/agent/tools/glob.lua
Normal file
146
lua/codetyper/agent/tools/glob.lua
Normal file
@@ -0,0 +1,146 @@
|
||||
---@mod codetyper.agent.tools.glob File pattern matching tool
|
||||
---@brief [[
|
||||
--- Tool for finding files by glob pattern.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
-- Tool table inheriting shared behavior (__call, validate_input, to_schema)
-- from the base metatable.
---@class CoderTool
local M = setmetatable({}, Base)

M.name = "glob"

-- Shown to the LLM when selecting tools.
M.description = [[Finds files matching a glob pattern.

Example patterns:
- "**/*.lua" - All Lua files
- "src/**/*.ts" - TypeScript files in src
- "**/test_*.py" - Test files in Python]]

-- Input schema consumed by Base:validate_input and Base:to_schema.
M.params = {
  {
    name = "pattern",
    description = "Glob pattern to match files",
    type = "string",
  },
  {
    name = "path",
    description = "Base directory to search in (default: project root)",
    type = "string",
    optional = true,
  },
  {
    name = "max_results",
    description = "Maximum number of results (default: 100)",
    type = "integer",
    optional = true,
  },
}

-- Declared outputs (informational; not enforced at runtime).
M.returns = {
  {
    name = "matches",
    description = "JSON array of matching file paths",
    type = "string",
  },
  {
    name = "error",
    description = "Error message if glob failed",
    type = "string",
    optional = true,
  },
}

-- Read-only operation; no confirmation needed.
M.requires_confirmation = false
|
||||
|
||||
--- Find files matching a glob pattern, preferring `fd` with a vim.fn.glob
--- fallback. Fix: removed the dead `fd_pattern` local, which was computed
--- from input.pattern but never used (the fd invocation passes the raw
--- pattern via --glob).
---@param input {pattern: string, path?: string, max_results?: integer}
---@param opts CoderToolOpts
---@return string|nil result JSON: { matches, total, truncated }
---@return string|nil error
function M.func(input, opts)
  if not input.pattern then
    return nil, "pattern is required"
  end

  if opts.on_log then
    opts.on_log("Finding files: " .. input.pattern)
  end

  -- Resolve the search root to an absolute path.
  local base_path = input.path or vim.fn.getcwd()
  if not vim.startswith(base_path, "/") then
    base_path = vim.fn.getcwd() .. "/" .. base_path
  end

  local max_results = input.max_results or 100

  local matches = {}

  if vim.fn.executable("fd") == 1 then
    -- fd understands glob patterns natively and is much faster than
    -- vim.fn.glob on large trees.
    local Job = require("plenary.job")

    local job = Job:new({
      command = "fd",
      args = {
        "--type",
        "f",
        "--max-results",
        tostring(max_results),
        "--glob",
        input.pattern,
        base_path,
      },
      cwd = base_path,
    })

    job:sync(30000)
    matches = job:result() or {}
  else
    -- Fallback: let Vim expand the glob relative to the base path.
    local pattern = base_path .. "/" .. input.pattern
    local files = vim.fn.glob(pattern, false, true)

    for i, file in ipairs(files) do
      if i > max_results then
        break
      end
      -- Make paths relative to base_path. Parenthesized to drop gsub's count.
      local relative = (file:gsub("^" .. vim.pesc(base_path) .. "/", ""))
      table.insert(matches, relative)
    end
  end

  -- Normalize to relative paths and drop empty entries.
  local cleaned = {}
  for _, match in ipairs(matches) do
    if match and match ~= "" then
      local relative = match
      if vim.startswith(match, base_path) then
        relative = match:sub(#base_path + 2)
      end
      table.insert(cleaned, relative)
    end
  end

  -- Return as JSON
  local result = vim.json.encode({
    matches = cleaned,
    total = #cleaned,
    truncated = #cleaned >= max_results,
  })

  if opts.on_complete then
    opts.on_complete(result, nil)
  end

  return result, nil
end

return M
|
||||
150
lua/codetyper/agent/tools/grep.lua
Normal file
150
lua/codetyper/agent/tools/grep.lua
Normal file
@@ -0,0 +1,150 @@
|
||||
---@mod codetyper.agent.tools.grep Search tool
|
||||
---@brief [[
|
||||
--- Tool for searching file contents using ripgrep.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.agent.tools.base")
|
||||
|
||||
-- Tool table inheriting shared behavior (__call, validate_input, to_schema)
-- from the base metatable.
---@class CoderTool
local M = setmetatable({}, Base)

M.name = "grep"

-- Shown to the LLM when selecting tools.
M.description = [[Searches for a pattern in files using ripgrep.

Returns file paths and matching lines. Use this to find code by content.

Example patterns:
- "function foo" - Find function definitions
- "import.*react" - Find React imports
- "TODO|FIXME" - Find todo comments]]

-- Input schema consumed by Base:validate_input and Base:to_schema.
M.params = {
  {
    name = "pattern",
    description = "Regular expression pattern to search for",
    type = "string",
  },
  {
    name = "path",
    description = "Directory or file to search in (default: project root)",
    type = "string",
    optional = true,
  },
  {
    name = "include",
    description = "File glob pattern to include (e.g., '*.lua')",
    type = "string",
    optional = true,
  },
  {
    name = "max_results",
    description = "Maximum number of results (default: 50)",
    type = "integer",
    optional = true,
  },
}

-- Declared outputs (informational; not enforced at runtime).
M.returns = {
  {
    name = "matches",
    description = "JSON array of matches with file, line_number, and content",
    type = "string",
  },
  {
    name = "error",
    description = "Error message if search failed",
    type = "string",
    optional = true,
  },
}

-- Read-only operation; no confirmation needed.
M.requires_confirmation = false
|
||||
|
||||
--- Search file contents with ripgrep's JSON output.
--- Fix: rg's --max-count limits matches *per file*, so the total across files
--- could exceed max_results while `truncated` was computed as if it were a
--- global cap. The parse loop now enforces the overall limit.
---@param input {pattern: string, path?: string, include?: string, max_results?: integer}
---@param opts CoderToolOpts
---@return string|nil result JSON: { matches, total, truncated }
---@return string|nil error
function M.func(input, opts)
  if not input.pattern then
    return nil, "pattern is required"
  end

  if opts.on_log then
    opts.on_log("Searching for: " .. input.pattern)
  end

  local path = input.path or vim.fn.getcwd()
  local max_results = input.max_results or 50

  -- Resolve to an absolute path.
  if not vim.startswith(path, "/") then
    path = vim.fn.getcwd() .. "/" .. path
  end

  if vim.fn.executable("rg") ~= 1 then
    return nil, "ripgrep (rg) is not installed"
  end

  -- NOTE: --max-count is a per-file limit; the parse loop below enforces
  -- the overall max_results cap.
  local args = {
    "--json",
    "--max-count",
    tostring(max_results),
    "--no-heading",
  }

  if input.include then
    table.insert(args, "--glob")
    table.insert(args, input.include)
  end

  table.insert(args, input.pattern)
  table.insert(args, path)

  local Job = require("plenary.job")
  local job = Job:new({
    command = "rg",
    args = args,
    cwd = vim.fn.getcwd(),
  })

  job:sync(30000) -- 30 second timeout

  local results = job:result() or {}
  local matches = {}
  local truncated = false

  -- rg --json emits one JSON object per line; keep only "match" records.
  for _, line in ipairs(results) do
    if line and line ~= "" then
      local ok, parsed = pcall(vim.json.decode, line)
      if ok and parsed.type == "match" then
        if #matches >= max_results then
          -- More matches were available than we are allowed to return.
          truncated = true
          break
        end
        local data = parsed.data
        table.insert(matches, {
          file = data.path.text,
          line_number = data.line_number,
          -- Strip the trailing newline rg includes in the line text.
          content = (data.lines.text:gsub("\n$", "")),
        })
      end
    end
  end

  local result = vim.json.encode({
    matches = matches,
    total = #matches,
    truncated = truncated or #matches >= max_results,
  })

  if opts.on_complete then
    opts.on_complete(result, nil)
  end

  return result, nil
end

return M
|
||||
308
lua/codetyper/agent/tools/init.lua
Normal file
308
lua/codetyper/agent/tools/init.lua
Normal file
@@ -0,0 +1,308 @@
|
||||
---@mod codetyper.agent.tools Tool registry and orchestration
|
||||
---@brief [[
|
||||
--- Registry for LLM tools with execution and schema generation.
|
||||
--- Inspired by avante.nvim's tool system.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}

--- Registered tools, keyed by tool name. Private to this module.
---@type table<string, CoderTool>
local tools = {}

--- Tool execution history for current session. Each entry records the tool
--- name, input, start/end time, status ("running"/"completed"/"error"),
--- result, and error.
---@type table[]
local execution_history = {}
|
||||
|
||||
--- Add a tool to the registry, keyed by its name.
--- Registering a second tool with the same name replaces the first.
---@param tool CoderTool Tool to register
function M.register(tool)
  local key = tool.name
  if not key then
    error("Tool must have a name")
  end
  tools[key] = tool
end
|
||||
|
||||
--- Unregister a tool
---@param name string Tool name
function M.unregister(name)
  tools[name] = nil
end

--- Get a tool by name
---@param name string Tool name
---@return CoderTool|nil The registered tool, or nil when unknown
function M.get(name)
  return tools[name]
end

--- Get all registered tools
--- NOTE(review): returns the internal registry table by reference, so
--- callers mutating the result mutate the registry itself.
---@return table<string, CoderTool>
function M.get_all()
  return tools
end
|
||||
|
||||
--- Collect registered tools into a list, optionally filtered.
--- Order is unspecified (pairs iteration).
---@param filter? fun(tool: CoderTool): boolean Optional filter function
---@return CoderTool[]
function M.list(filter)
  local result = {}
  for _, tool in pairs(tools) do
    local keep = filter == nil or filter(tool)
    if keep then
      result[#result + 1] = tool
    end
  end
  return result
end
|
||||
|
||||
--- Build function-calling schemas for every registered tool that supports
--- schema generation.
---@param filter? fun(tool: CoderTool): boolean Optional filter function
---@return table[] schemas
function M.get_schemas(filter)
  local schemas = {}
  for _, tool in pairs(tools) do
    local wanted = filter == nil or filter(tool)
    if wanted and tool.to_schema ~= nil then
      schemas[#schemas + 1] = tool:to_schema()
    end
  end
  return schemas
end
|
||||
|
||||
--- Execute a tool by name.
---
--- Validates input (when the tool provides a validator), records the run in
--- the session execution history, then invokes the tool. Runtime errors
--- raised inside the tool are caught and reported via the (nil, err)
--- convention instead of propagating to the caller, so the history entry
--- is always finalized.
---@param name string Tool name
---@param input table Input parameters
---@param opts CoderToolOpts Execution options (on_log, on_complete, ...)
---@return any result
---@return string|nil error
function M.execute(name, input, opts)
  opts = opts or {} -- defensive: callers may omit options entirely

  local tool = tools[name]
  if not tool then
    -- tostring() so a nil/non-string name produces a message, not a concat error
    return nil, "Unknown tool: " .. tostring(name)
  end

  -- Validate input before running
  if tool.validate_input then
    local valid, err = tool:validate_input(input)
    if not valid then
      return nil, err
    end
  end

  -- Log execution
  if opts.on_log then
    opts.on_log(string.format("Executing tool: %s", name))
  end

  -- Track execution in the session history
  local execution = {
    tool = name,
    input = input,
    start_time = os.time(),
    status = "running",
  }
  table.insert(execution_history, execution)

  -- Run the tool; pcall so an error thrown inside the tool becomes a
  -- (nil, err) result rather than aborting the caller mid-flight.
  local result, err
  local ok, res_or_err, tool_err = pcall(tool.func, input, opts)
  if ok then
    result, err = res_or_err, tool_err
  else
    result, err = nil, tostring(res_or_err)
  end

  -- Finalize the history record
  execution.end_time = os.time()
  execution.status = err and "error" or "completed"
  execution.result = result
  execution.error = err

  return result, err
end
|
||||
|
||||
--- Process a tool call from an LLM response.
--- Accepts both Claude-style (`name` + `input` table) and OpenAI-style
--- (`function_name` + JSON-string `arguments`) call shapes.
---@param tool_call table Tool call from LLM (name + input)
---@param opts CoderToolOpts Execution options
---@return any result
---@return string|nil error
function M.process_tool_call(tool_call, opts)
  if type(tool_call) ~= "table" then
    return nil, "Invalid tool call"
  end

  local name = tool_call.name or tool_call.function_name
  if not name then
    -- Without this guard a nil name would surface as a confusing
    -- concatenation error further down the pipeline.
    return nil, "Tool call is missing a name"
  end

  local input = tool_call.input or tool_call.arguments or {}

  -- OpenAI-style responses serialize arguments as a JSON string; decode it.
  if type(input) == "string" then
    local ok, parsed = pcall(vim.json.decode, input)
    if ok then
      input = parsed
    else
      return nil, "Failed to parse tool arguments: " .. input
    end
  end

  return M.execute(name, input, opts)
end
|
||||
|
||||
--- Get execution history.
--- With no limit the live history table is returned; with a limit, a new
--- list containing the most recent `limit` entries (oldest first).
---@param limit? number Max entries to return
---@return table[]
function M.get_history(limit)
  if limit == nil then
    return execution_history
  end

  local total = #execution_history
  local tail = {}
  for i = math.max(1, total - limit + 1), total do
    tail[#tail + 1] = execution_history[i]
  end
  return tail
end
|
||||
|
||||
--- Clear execution history.
--- Replaces the table rather than emptying it, so references handed out by
--- M.get_history() keep pointing at the old (now frozen) history.
function M.clear_history()
  execution_history = {}
end
|
||||
|
||||
--- Load built-in tools.
--- Requires and registers each built-in tool module, in a fixed order.
function M.load_builtins()
  local builtin_names = {
    "view",  -- file reading
    "bash",  -- shell command execution
    "grep",  -- content search
    "glob",  -- filename matching
    "write", -- file creation / overwrite
    "edit",  -- partial file edits
  }

  for _, tool_name in ipairs(builtin_names) do
    M.register(require("codetyper.agent.tools." .. tool_name))
  end
end
|
||||
|
||||
--- Initialize tools system.
--- Currently just loads the built-in tool set; kept as a separate entry
--- point so callers have a stable setup hook.
function M.setup()
  M.load_builtins()
end
|
||||
|
||||
--- Get tool definitions for LLM (lazy-loaded, OpenAI format)
--- This is accessed as M.definitions property
--- Compatibility shim: works both as a callable (`M.definitions()`, via
--- __call) and as a table exposing `M.definitions.get()` (via __index).
--- Both paths lazily load the built-in tools on first use and return the
--- OpenAI-format definitions.
M.definitions = setmetatable({}, {
  __call = function()
    -- Ensure tools are loaded
    if vim.tbl_count(tools) == 0 then
      M.load_builtins()
    end
    return M.to_openai_format()
  end,
  __index = function(_, key)
    -- Make it work as both function and table
    if key == "get" then
      return function()
        if vim.tbl_count(tools) == 0 then
          M.load_builtins()
        end
        return M.to_openai_format()
      end
    end
    -- Any other key lookup intentionally resolves to nil
    return nil
  end,
})
|
||||
|
||||
--- Get definitions as a function (for backwards compatibility).
--- Lazily loads the built-in tools the first time it is called.
---@return table[] OpenAI-compatible tool definitions
function M.get_definitions()
  local registry_empty = next(tools) == nil
  if registry_empty then
    M.load_builtins()
  end
  return M.to_openai_format()
end
|
||||
|
||||
--- Convert all tools to OpenAI function calling format.
--- Builds a JSON-schema `parameters` object from each tool's declared
--- params; "integer" params are normalized to "number".
---@param filter? fun(tool: CoderTool): boolean Optional filter function
---@return table[] OpenAI-compatible tool definitions
function M.to_openai_format(filter)
  local converted = {}

  for _, registered in pairs(tools) do
    if filter == nil or filter(registered) then
      local props = {}
      local required_names = {}

      for _, param in ipairs(registered.params or {}) do
        local schema = {
          type = param.type == "integer" and "number" or param.type,
          description = param.description,
        }
        if param.default ~= nil then
          schema.default = param.default
        end
        props[param.name] = schema

        if not param.optional then
          required_names[#required_names + 1] = param.name
        end
      end

      -- Description may be a lazily-evaluated function
      local desc = type(registered.description) == "function" and registered.description()
        or registered.description

      converted[#converted + 1] = {
        type = "function",
        ["function"] = {
          name = registered.name,
          description = desc,
          parameters = {
            type = "object",
            properties = props,
            required = required_names,
          },
        },
      }
    end
  end

  return converted
end
|
||||
|
||||
--- Convert all tools to Claude tool use format.
--- Builds a JSON-schema `input_schema` from each tool's declared params;
--- "integer" params are normalized to "number".
---@param filter? fun(tool: CoderTool): boolean Optional filter function
---@return table[] Claude-compatible tool definitions
function M.to_claude_format(filter)
  local claude_tools = {}

  for _, tool in pairs(tools) do
    if not filter or filter(tool) then
      local properties = {}
      local required = {}

      for _, param in ipairs(tool.params or {}) do
        properties[param.name] = {
          type = param.type == "integer" and "number" or param.type,
          description = param.description,
        }
        -- Include declared defaults, matching to_openai_format (the two
        -- formatters previously disagreed: only the OpenAI one kept them).
        if param.default ~= nil then
          properties[param.name].default = param.default
        end
        if not param.optional then
          table.insert(required, param.name)
        end
      end

      -- Description may be a lazily-evaluated function
      local description = type(tool.description) == "function" and tool.description() or tool.description

      table.insert(claude_tools, {
        name = tool.name,
        description = description,
        input_schema = {
          type = "object",
          properties = properties,
          required = required,
        },
      })
    end
  end

  return claude_tools
end
|
||||
|
||||
return M
|
||||
149
lua/codetyper/agent/tools/view.lua
Normal file
149
lua/codetyper/agent/tools/view.lua
Normal file
@@ -0,0 +1,149 @@
|
||||
---@mod codetyper.agent.tools.view File viewing tool
---@brief [[
--- Tool for reading file contents with line range support.
---@brief ]]

local Base = require("codetyper.agent.tools.base")

-- Tool object; inherits shared behavior from the Base tool prototype.
---@class CoderTool
local M = setmetatable({}, Base)

M.name = "view"

-- Tool description shown to the LLM (part of the prompt, not a comment).
M.description = [[Reads the content of a file.

Usage notes:
- Provide the file path relative to the project root
- Use start_line and end_line to read specific sections
- If content is truncated, use line ranges to read in chunks
- Returns JSON with content, total_line_count, and is_truncated]]

-- Parameter schema consumed by the registry's format converters.
M.params = {
  {
    name = "path",
    description = "Path to the file (relative to project root or absolute)",
    type = "string",
  },
  {
    name = "start_line",
    description = "Line number to start reading (1-indexed)",
    type = "integer",
    optional = true,
  },
  {
    name = "end_line",
    description = "Line number to end reading (1-indexed, inclusive)",
    type = "integer",
    optional = true,
  },
}

-- Return-value schema (documentation for the LLM / registry).
M.returns = {
  {
    name = "content",
    description = "File contents as JSON with content, total_line_count, is_truncated",
    type = "string",
  },
  {
    name = "error",
    description = "Error message if file could not be read",
    type = "string",
    optional = true,
  },
}

-- Read-only tool: safe to run without asking the user.
M.requires_confirmation = false

--- Maximum content size before truncation
local MAX_CONTENT_SIZE = 200 * 1024 -- 200KB
|
||||
|
||||
--- Read a file (optionally a line range) and return it as a JSON payload.
---@param input {path: string, start_line?: integer, end_line?: integer}
---@param opts CoderToolOpts
---@return string|nil result JSON with content, total_line_count, is_truncated, start_line, end_line
---@return string|nil error
function M.func(input, opts)
  if not input.path then
    return nil, "path is required"
  end

  -- Log the operation
  if opts.on_log then
    opts.on_log("Reading file: " .. input.path)
  end

  -- Resolve relative paths against the project root (cwd)
  local path = input.path
  if not vim.startswith(path, "/") then
    local root = vim.fn.getcwd()
    path = root .. "/" .. path
  end

  -- Check if file exists
  local stat = vim.uv.fs_stat(path)
  if not stat then
    return nil, "File not found: " .. input.path
  end

  if stat.type == "directory" then
    return nil, "Path is a directory: " .. input.path
  end

  -- readfile() raises on unreadable files (e.g. permission denied) rather
  -- than returning nil; trap the error so the tool reports it cleanly
  -- instead of crashing the agent loop.
  local ok, lines = pcall(vim.fn.readfile, path)
  if not ok or not lines then
    return nil, "Failed to read file: " .. input.path
  end

  -- Apply line range, clamped to the actual file bounds
  local start_line = input.start_line or 1
  local end_line = input.end_line or #lines

  start_line = math.max(1, start_line)
  end_line = math.min(#lines, end_line)

  local total_lines = #lines
  local selected_lines = {}

  for i = start_line, end_line do
    table.insert(selected_lines, lines[i])
  end

  -- Truncate oversized content on whole-line boundaries
  local content = table.concat(selected_lines, "\n")
  local is_truncated = false

  if #content > MAX_CONTENT_SIZE then
    local truncated_lines = {}
    local size = 0

    for _, line in ipairs(selected_lines) do
      size = size + #line + 1 -- +1 accounts for the newline separator
      if size > MAX_CONTENT_SIZE then
        is_truncated = true
        break
      end
      table.insert(truncated_lines, line)
    end

    content = table.concat(truncated_lines, "\n")
  end

  -- Return as JSON so the LLM gets structured metadata with the content
  local result = vim.json.encode({
    content = content,
    total_line_count = total_lines,
    is_truncated = is_truncated,
    start_line = start_line,
    end_line = end_line,
  })

  if opts.on_complete then
    opts.on_complete(result, nil)
  end

  return result, nil
end
|
||||
|
||||
return M
|
||||
101
lua/codetyper/agent/tools/write.lua
Normal file
101
lua/codetyper/agent/tools/write.lua
Normal file
@@ -0,0 +1,101 @@
|
||||
---@mod codetyper.agent.tools.write File writing tool
---@brief [[
--- Tool for creating or overwriting files.
---@brief ]]

local Base = require("codetyper.agent.tools.base")

-- Tool object; inherits shared behavior from the Base tool prototype.
---@class CoderTool
local M = setmetatable({}, Base)

M.name = "write"

-- Tool description shown to the LLM (part of the prompt, not a comment).
M.description = [[Creates or overwrites a file with new content.

IMPORTANT:
- This will completely replace the file contents
- Use 'edit' tool for partial modifications
- Parent directories will be created if needed]]

-- Parameter schema consumed by the registry's format converters.
M.params = {
  {
    name = "path",
    description = "Path to the file to write",
    type = "string",
  },
  {
    name = "content",
    description = "Content to write to the file",
    type = "string",
  },
}

-- Return-value schema (documentation for the LLM / registry).
M.returns = {
  {
    name = "success",
    description = "Whether the file was written successfully",
    type = "boolean",
  },
  {
    name = "error",
    description = "Error message if write failed",
    type = "string",
    optional = true,
  },
}

-- Destructive tool: the user must confirm before it runs.
M.requires_confirmation = true
|
||||
|
||||
--- Create or overwrite a file with the given content.
---@param input {path: string, content: string}
---@param opts CoderToolOpts
---@return boolean|nil result true on success
---@return string|nil error
function M.func(input, opts)
  if not input.path then
    return nil, "path is required"
  end
  if not input.content then
    return nil, "content is required"
  end

  -- Log the operation
  if opts.on_log then
    opts.on_log("Writing file: " .. input.path)
  end

  -- Resolve relative paths against the project root (cwd)
  local path = input.path
  if not vim.startswith(path, "/") then
    path = vim.fn.getcwd() .. "/" .. path
  end

  -- Create parent directories if needed
  local dir = vim.fn.fnamemodify(path, ":h")
  if vim.fn.isdirectory(dir) == 0 then
    vim.fn.mkdir(dir, "p")
  end

  -- Write the file. writefile() returns -1 on failure without raising, so
  -- checking pcall success alone is not enough — inspect the return value too.
  local lines = vim.split(input.content, "\n", { plain = true })
  local ok, ret = pcall(vim.fn.writefile, lines, path)

  if not ok or ret == -1 then
    return nil, "Failed to write file: " .. path
  end

  -- Reload the buffer if the file is open, so the editor view stays in sync
  local bufnr = vim.fn.bufnr(path)
  if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
    vim.api.nvim_buf_call(bufnr, function()
      vim.cmd("edit!")
    end)
  end

  if opts.on_complete then
    opts.on_complete(true, nil)
  end

  return true, nil
end
|
||||
|
||||
return M
|
||||
@@ -35,14 +35,62 @@ local state = {
|
||||
local ns_chat = vim.api.nvim_create_namespace("codetyper_agent_chat")
|
||||
local ns_logs = vim.api.nvim_create_namespace("codetyper_agent_logs")
|
||||
|
||||
--- Fixed widths
|
||||
local CHAT_WIDTH = 300
|
||||
local LOGS_WIDTH = 50
|
||||
--- Fixed heights
|
||||
local INPUT_HEIGHT = 5
|
||||
local LOGS_WIDTH = 50
|
||||
|
||||
--- Calculate dynamic width (1/4 of screen, minimum 30)
|
||||
---@return number
|
||||
local function get_panel_width()
|
||||
return math.max(math.floor(vim.o.columns * 0.25), 30)
|
||||
end
|
||||
|
||||
--- Autocmd group
|
||||
local agent_augroup = nil
|
||||
|
||||
--- Autocmd group for width maintenance
|
||||
local width_augroup = nil
|
||||
|
||||
--- Store target width
|
||||
local target_width = nil
|
||||
|
||||
--- Setup autocmd to always maintain 1/4 window width
|
||||
local function setup_width_autocmd()
|
||||
-- Clear previous autocmd group if exists
|
||||
if width_augroup then
|
||||
pcall(vim.api.nvim_del_augroup_by_id, width_augroup)
|
||||
end
|
||||
|
||||
width_augroup = vim.api.nvim_create_augroup("CodetypeAgentWidth", { clear = true })
|
||||
|
||||
-- Always maintain 1/4 width on any window event
|
||||
vim.api.nvim_create_autocmd({ "WinResized", "WinNew", "WinClosed", "VimResized" }, {
|
||||
group = width_augroup,
|
||||
callback = function()
|
||||
if not state.is_open or not state.chat_win then
|
||||
return
|
||||
end
|
||||
if not vim.api.nvim_win_is_valid(state.chat_win) then
|
||||
return
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
if state.chat_win and vim.api.nvim_win_is_valid(state.chat_win) then
|
||||
-- Always calculate 1/4 of current screen width
|
||||
local new_target = math.max(math.floor(vim.o.columns * 0.25), 30)
|
||||
target_width = new_target
|
||||
|
||||
local current_width = vim.api.nvim_win_get_width(state.chat_win)
|
||||
if current_width ~= target_width then
|
||||
pcall(vim.api.nvim_win_set_width, state.chat_win, target_width)
|
||||
end
|
||||
end
|
||||
end)
|
||||
end,
|
||||
desc = "Maintain Agent panel at 1/4 window width",
|
||||
})
|
||||
end
|
||||
|
||||
--- Add a log entry to the logs buffer
|
||||
---@param entry table Log entry
|
||||
local function add_log_entry(entry)
|
||||
@@ -479,7 +527,7 @@ function M.open()
|
||||
vim.cmd("topleft vsplit")
|
||||
state.chat_win = vim.api.nvim_get_current_win()
|
||||
vim.api.nvim_win_set_buf(state.chat_win, state.chat_buf)
|
||||
vim.api.nvim_win_set_width(state.chat_win, CHAT_WIDTH)
|
||||
vim.api.nvim_win_set_width(state.chat_win, get_panel_width())
|
||||
|
||||
-- Window options for chat
|
||||
vim.wo[state.chat_win].number = false
|
||||
@@ -592,6 +640,10 @@ function M.open()
|
||||
end,
|
||||
})
|
||||
|
||||
-- Setup autocmd to maintain 1/4 width
|
||||
target_width = get_panel_width()
|
||||
setup_width_autocmd()
|
||||
|
||||
state.is_open = true
|
||||
|
||||
-- Focus input and log startup
|
||||
@@ -603,7 +655,16 @@ function M.open()
|
||||
if ok then
|
||||
local config = codetyper.get_config()
|
||||
local provider = config.llm.provider
|
||||
local model = provider == "claude" and config.llm.claude.model or config.llm.ollama.model
|
||||
local model = "unknown"
|
||||
if provider == "ollama" then
|
||||
model = config.llm.ollama.model
|
||||
elseif provider == "openai" then
|
||||
model = config.llm.openai.model
|
||||
elseif provider == "gemini" then
|
||||
model = config.llm.gemini.model
|
||||
elseif provider == "copilot" then
|
||||
model = config.llm.copilot.model
|
||||
end
|
||||
logs.info(string.format("%s (%s)", provider, model))
|
||||
end
|
||||
end
|
||||
|
||||
@@ -178,8 +178,7 @@ local active_workers = {}
|
||||
--- Default timeouts by provider type
|
||||
local default_timeouts = {
|
||||
ollama = 30000, -- 30s for local
|
||||
claude = 60000, -- 60s for remote
|
||||
openai = 60000,
|
||||
openai = 60000, -- 60s for remote
|
||||
gemini = 60000,
|
||||
copilot = 60000,
|
||||
}
|
||||
@@ -225,6 +224,134 @@ local function format_attached_files(attached_files)
|
||||
return table.concat(parts, "")
|
||||
end
|
||||
|
||||
--- Get coder companion file path for a target file
|
||||
---@param target_path string Target file path
|
||||
---@return string|nil Coder file path if exists
|
||||
local function get_coder_companion_path(target_path)
|
||||
if not target_path or target_path == "" then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Skip if target is already a coder file
|
||||
if target_path:match("%.coder%.") then
|
||||
return nil
|
||||
end
|
||||
|
||||
local dir = vim.fn.fnamemodify(target_path, ":h")
|
||||
local name = vim.fn.fnamemodify(target_path, ":t:r") -- filename without extension
|
||||
local ext = vim.fn.fnamemodify(target_path, ":e")
|
||||
|
||||
local coder_path = dir .. "/" .. name .. ".coder." .. ext
|
||||
if vim.fn.filereadable(coder_path) == 1 then
|
||||
return coder_path
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Read and format coder companion context (business logic, pseudo-code)
|
||||
---@param target_path string Target file path
|
||||
---@return string Formatted coder context
|
||||
local function get_coder_context(target_path)
|
||||
local coder_path = get_coder_companion_path(target_path)
|
||||
if not coder_path then
|
||||
return ""
|
||||
end
|
||||
|
||||
local ok, lines = pcall(function()
|
||||
return vim.fn.readfile(coder_path)
|
||||
end)
|
||||
|
||||
if not ok or not lines or #lines == 0 then
|
||||
return ""
|
||||
end
|
||||
|
||||
local content = table.concat(lines, "\n")
|
||||
|
||||
-- Skip if only template comments (no actual content)
|
||||
local stripped = content:gsub("^%s*", ""):gsub("%s*$", "")
|
||||
if stripped == "" then
|
||||
return ""
|
||||
end
|
||||
|
||||
-- Check if there's meaningful content (not just template)
|
||||
local has_content = false
|
||||
for _, line in ipairs(lines) do
|
||||
-- Skip comment lines that are part of the template
|
||||
local trimmed = line:gsub("^%s*", "")
|
||||
if not trimmed:match("^[%-#/]+%s*Coder companion")
|
||||
and not trimmed:match("^[%-#/]+%s*Use /@ @/")
|
||||
and not trimmed:match("^[%-#/]+%s*Example:")
|
||||
and not trimmed:match("^<!%-%-")
|
||||
and trimmed ~= ""
|
||||
and not trimmed:match("^[%-#/]+%s*$") then
|
||||
has_content = true
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if not has_content then
|
||||
return ""
|
||||
end
|
||||
|
||||
local ext = vim.fn.fnamemodify(coder_path, ":e")
|
||||
return string.format(
|
||||
"\n\n--- Business Context / Pseudo-code ---\n" ..
|
||||
"The following describes the intended behavior and design for this file:\n" ..
|
||||
"```%s\n%s\n```",
|
||||
ext,
|
||||
content:sub(1, 4000) -- Limit to 4000 chars
|
||||
)
|
||||
end
|
||||
|
||||
--- Format indexed project context for inclusion in prompt
|
||||
---@param indexed_context table|nil
|
||||
---@return string
|
||||
local function format_indexed_context(indexed_context)
|
||||
if not indexed_context then
|
||||
return ""
|
||||
end
|
||||
|
||||
local parts = {}
|
||||
|
||||
-- Project type
|
||||
if indexed_context.project_type and indexed_context.project_type ~= "unknown" then
|
||||
table.insert(parts, "Project type: " .. indexed_context.project_type)
|
||||
end
|
||||
|
||||
-- Relevant symbols
|
||||
if indexed_context.relevant_symbols then
|
||||
local symbol_list = {}
|
||||
for symbol, files in pairs(indexed_context.relevant_symbols) do
|
||||
if #files > 0 then
|
||||
table.insert(symbol_list, symbol .. " (in " .. files[1] .. ")")
|
||||
end
|
||||
end
|
||||
if #symbol_list > 0 then
|
||||
table.insert(parts, "Relevant symbols: " .. table.concat(symbol_list, ", "))
|
||||
end
|
||||
end
|
||||
|
||||
-- Learned patterns
|
||||
if indexed_context.patterns and #indexed_context.patterns > 0 then
|
||||
local pattern_list = {}
|
||||
for i, p in ipairs(indexed_context.patterns) do
|
||||
if i <= 3 then
|
||||
table.insert(pattern_list, p.content or "")
|
||||
end
|
||||
end
|
||||
if #pattern_list > 0 then
|
||||
table.insert(parts, "Project conventions: " .. table.concat(pattern_list, "; "))
|
||||
end
|
||||
end
|
||||
|
||||
if #parts == 0 then
|
||||
return ""
|
||||
end
|
||||
|
||||
return "\n\n--- Project Context ---\n" .. table.concat(parts, "\n")
|
||||
end
|
||||
|
||||
--- Build prompt for code generation
|
||||
---@param event table PromptEvent
|
||||
---@return string prompt
|
||||
@@ -245,9 +372,71 @@ local function build_prompt(event)
|
||||
|
||||
local filetype = vim.fn.fnamemodify(event.target_path or "", ":e")
|
||||
|
||||
-- Get indexed project context
|
||||
local indexed_context = nil
|
||||
local indexed_content = ""
|
||||
pcall(function()
|
||||
local indexer = require("codetyper.indexer")
|
||||
indexed_context = indexer.get_context_for({
|
||||
file = event.target_path,
|
||||
intent = event.intent,
|
||||
prompt = event.prompt_content,
|
||||
scope = event.scope_text,
|
||||
})
|
||||
indexed_content = format_indexed_context(indexed_context)
|
||||
end)
|
||||
|
||||
-- Format attached files
|
||||
local attached_content = format_attached_files(event.attached_files)
|
||||
|
||||
-- Get coder companion context (business logic, pseudo-code)
|
||||
local coder_context = get_coder_context(event.target_path)
|
||||
|
||||
-- Get brain memories - contextual recall based on current task
|
||||
local brain_context = ""
|
||||
pcall(function()
|
||||
local brain = require("codetyper.brain")
|
||||
if brain.is_initialized() then
|
||||
-- Query brain for relevant memories based on:
|
||||
-- 1. Current file (file-specific patterns)
|
||||
-- 2. Prompt content (semantic similarity)
|
||||
-- 3. Intent type (relevant past generations)
|
||||
local query_text = event.prompt_content or ""
|
||||
if event.scope and event.scope.name then
|
||||
query_text = event.scope.name .. " " .. query_text
|
||||
end
|
||||
|
||||
local result = brain.query({
|
||||
query = query_text,
|
||||
file = event.target_path,
|
||||
max_results = 5,
|
||||
types = { "pattern", "correction", "convention" },
|
||||
})
|
||||
|
||||
if result and result.nodes and #result.nodes > 0 then
|
||||
local memories = { "\n\n--- Learned Patterns & Conventions ---" }
|
||||
for _, node in ipairs(result.nodes) do
|
||||
if node.c then
|
||||
local summary = node.c.s or ""
|
||||
local detail = node.c.d or ""
|
||||
if summary ~= "" then
|
||||
table.insert(memories, "• " .. summary)
|
||||
if detail ~= "" and #detail < 200 then
|
||||
table.insert(memories, " " .. detail)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
if #memories > 1 then
|
||||
brain_context = table.concat(memories, "\n")
|
||||
end
|
||||
end
|
||||
end
|
||||
end)
|
||||
|
||||
-- Combine all context sources: brain memories first, then coder context, attached files, indexed
|
||||
local extra_context = brain_context .. coder_context .. attached_content .. indexed_content
|
||||
|
||||
-- Build context with scope information
|
||||
local context = {
|
||||
target_path = event.target_path,
|
||||
@@ -258,6 +447,7 @@ local function build_prompt(event)
|
||||
scope_range = event.scope_range,
|
||||
intent = event.intent,
|
||||
attached_files = event.attached_files,
|
||||
indexed_context = indexed_context,
|
||||
}
|
||||
|
||||
-- Build the actual prompt based on intent and scope
|
||||
@@ -296,7 +486,7 @@ Return ONLY the complete %s with implementation. No explanations, no duplicates.
|
||||
scope_type,
|
||||
filetype,
|
||||
event.scope_text,
|
||||
attached_content,
|
||||
extra_context,
|
||||
event.prompt_content,
|
||||
scope_type
|
||||
)
|
||||
@@ -317,7 +507,7 @@ Return the complete transformed %s. Output only code, no explanations.]],
|
||||
filetype,
|
||||
filetype,
|
||||
event.scope_text,
|
||||
attached_content,
|
||||
extra_context,
|
||||
event.prompt_content,
|
||||
scope_type
|
||||
)
|
||||
@@ -337,7 +527,7 @@ Output only the code to insert, no explanations.]],
|
||||
scope_name,
|
||||
filetype,
|
||||
event.scope_text,
|
||||
attached_content,
|
||||
extra_context,
|
||||
event.prompt_content
|
||||
)
|
||||
end
|
||||
@@ -357,7 +547,7 @@ Output only code, no explanations.]],
|
||||
filetype,
|
||||
filetype,
|
||||
target_content:sub(1, 4000), -- Limit context size
|
||||
attached_content,
|
||||
extra_context,
|
||||
event.prompt_content
|
||||
)
|
||||
end
|
||||
@@ -437,21 +627,21 @@ function M.start(worker)
|
||||
end
|
||||
end, worker.timeout_ms)
|
||||
|
||||
-- Get client and execute
|
||||
local client, client_err = get_client(worker.worker_type)
|
||||
if not client then
|
||||
M.complete(worker, nil, client_err)
|
||||
return
|
||||
end
|
||||
|
||||
local prompt, context = build_prompt(worker.event)
|
||||
|
||||
-- Call the LLM
|
||||
client.generate(prompt, context, function(response, err, usage)
|
||||
-- Check if smart selection is enabled (memory-based provider selection)
|
||||
local use_smart_selection = false
|
||||
pcall(function()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
use_smart_selection = config.llm.smart_selection ~= false -- Default to true
|
||||
end)
|
||||
|
||||
-- Define the response handler
|
||||
local function handle_response(response, err, usage_or_metadata)
|
||||
-- Cancel timeout timer
|
||||
if worker.timer then
|
||||
pcall(function()
|
||||
-- Timer might have already fired
|
||||
if type(worker.timer) == "userdata" and worker.timer.stop then
|
||||
worker.timer:stop()
|
||||
end
|
||||
@@ -462,8 +652,45 @@ function M.start(worker)
|
||||
return -- Already timed out or cancelled
|
||||
end
|
||||
|
||||
-- Extract usage from metadata if smart_generate was used
|
||||
local usage = usage_or_metadata
|
||||
if type(usage_or_metadata) == "table" and usage_or_metadata.provider then
|
||||
-- This is metadata from smart_generate
|
||||
usage = nil
|
||||
-- Update worker type to reflect actual provider used
|
||||
worker.worker_type = usage_or_metadata.provider
|
||||
-- Log if pondering occurred
|
||||
if usage_or_metadata.pondered then
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format(
|
||||
"Pondering: %s (agreement: %.0f%%)",
|
||||
usage_or_metadata.corrected and "corrected" or "validated",
|
||||
(usage_or_metadata.agreement or 1) * 100
|
||||
),
|
||||
})
|
||||
end)
|
||||
end
|
||||
end
|
||||
|
||||
M.complete(worker, response, err, usage)
|
||||
end)
|
||||
end
|
||||
|
||||
-- Use smart selection or direct client
|
||||
if use_smart_selection then
|
||||
local llm = require("codetyper.llm")
|
||||
llm.smart_generate(prompt, context, handle_response)
|
||||
else
|
||||
-- Get client and execute directly
|
||||
local client, client_err = get_client(worker.worker_type)
|
||||
if not client then
|
||||
M.complete(worker, nil, client_err)
|
||||
return
|
||||
end
|
||||
client.generate(prompt, context, handle_response)
|
||||
end
|
||||
end
|
||||
|
||||
--- Complete worker execution
|
||||
|
||||
@@ -312,7 +312,9 @@ local function append_log_to_output(entry)
|
||||
}
|
||||
|
||||
local icon = icons[entry.level] or "•"
|
||||
local formatted = string.format("[%s] %s %s", entry.timestamp, icon, entry.message)
|
||||
-- Sanitize message - replace newlines with spaces to prevent nvim_buf_set_lines error
|
||||
local sanitized_message = entry.message:gsub("\n", " "):gsub("\r", "")
|
||||
local formatted = string.format("[%s] %s %s", entry.timestamp, icon, sanitized_message)
|
||||
|
||||
vim.schedule(function()
|
||||
if not state.output_buf or not vim.api.nvim_buf_is_valid(state.output_buf) then
|
||||
@@ -728,14 +730,17 @@ local function build_file_context()
|
||||
end
|
||||
|
||||
--- Build context for the question
|
||||
---@param intent? table Detected intent from intent module
|
||||
---@return table Context object
|
||||
local function build_context()
|
||||
local function build_context(intent)
|
||||
local context = {
|
||||
project_root = utils.get_project_root(),
|
||||
current_file = nil,
|
||||
current_content = nil,
|
||||
language = nil,
|
||||
referenced_files = state.referenced_files,
|
||||
brain_context = nil,
|
||||
indexer_context = nil,
|
||||
}
|
||||
|
||||
-- Try to get current file context from the non-ask window
|
||||
@@ -754,49 +759,140 @@ local function build_context()
|
||||
end
|
||||
end
|
||||
|
||||
-- Add brain context if intent needs it
|
||||
if intent and intent.needs_brain_context then
|
||||
local ok_brain, brain = pcall(require, "codetyper.brain")
|
||||
if ok_brain and brain.is_initialized() then
|
||||
context.brain_context = brain.get_context_for_llm({
|
||||
file = context.current_file,
|
||||
max_tokens = 1000,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Add indexer context if intent needs project-wide context
|
||||
if intent and intent.needs_project_context then
|
||||
local ok_indexer, indexer = pcall(require, "codetyper.indexer")
|
||||
if ok_indexer then
|
||||
context.indexer_context = indexer.get_context_for({
|
||||
file = context.current_file,
|
||||
prompt = "", -- Will be filled later
|
||||
intent = intent,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return context
|
||||
end
|
||||
|
||||
--- Submit the question to LLM
|
||||
function M.submit()
|
||||
local question = get_input_text()
|
||||
|
||||
if not question or question:match("^%s*$") then
|
||||
utils.notify("Please enter a question", vim.log.levels.WARN)
|
||||
M.focus_input()
|
||||
--- Append exploration log to output buffer
|
||||
---@param msg string
|
||||
---@param level string
|
||||
local function append_exploration_log(msg, level)
|
||||
if not state.output_buf or not vim.api.nvim_buf_is_valid(state.output_buf) then
|
||||
return
|
||||
end
|
||||
|
||||
-- Build context BEFORE clearing input (to preserve file references)
|
||||
local context = build_context()
|
||||
local file_context, file_count = build_file_context()
|
||||
vim.schedule(function()
|
||||
if not state.output_buf or not vim.api.nvim_buf_is_valid(state.output_buf) then
|
||||
return
|
||||
end
|
||||
|
||||
-- Build display message (without full file contents)
|
||||
local display_question = question
|
||||
if file_count > 0 then
|
||||
display_question = question .. "\n📎 " .. file_count .. " file(s) attached"
|
||||
vim.bo[state.output_buf].modifiable = true
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(state.output_buf, 0, -1, false)
|
||||
|
||||
-- Format based on level
|
||||
local formatted = msg
|
||||
if level == "progress" then
|
||||
formatted = msg
|
||||
elseif level == "debug" then
|
||||
formatted = msg
|
||||
elseif level == "file" then
|
||||
formatted = msg
|
||||
end
|
||||
|
||||
table.insert(lines, formatted)
|
||||
|
||||
vim.api.nvim_buf_set_lines(state.output_buf, 0, -1, false, lines)
|
||||
vim.bo[state.output_buf].modifiable = false
|
||||
|
||||
-- Scroll to bottom
|
||||
if state.output_win and vim.api.nvim_win_is_valid(state.output_win) then
|
||||
local line_count = vim.api.nvim_buf_line_count(state.output_buf)
|
||||
pcall(vim.api.nvim_win_set_cursor, state.output_win, { line_count, 0 })
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Continue submission after exploration
|
||||
---@param question string
|
||||
---@param intent table
|
||||
---@param context table
|
||||
---@param file_context string
|
||||
---@param file_count number
|
||||
---@param exploration_result table|nil
|
||||
local function continue_submit(question, intent, context, file_context, file_count, exploration_result)
|
||||
-- Get prompt type based on intent
|
||||
local ok_intent, intent_module = pcall(require, "codetyper.ask.intent")
|
||||
local prompt_type = "ask"
|
||||
if ok_intent then
|
||||
prompt_type = intent_module.get_prompt_type(intent)
|
||||
end
|
||||
|
||||
-- Add user message to output
|
||||
append_to_output(display_question, true)
|
||||
|
||||
-- Clear input and references AFTER building context
|
||||
M.clear_input()
|
||||
|
||||
-- Build system prompt for ask mode using prompts module
|
||||
-- Build system prompt using prompts module
|
||||
local prompts = require("codetyper.prompts")
|
||||
local system_prompt = prompts.system.ask
|
||||
local system_prompt = prompts.system[prompt_type] or prompts.system.ask
|
||||
|
||||
if context.current_file then
|
||||
system_prompt = system_prompt .. "\n\nCurrent open file: " .. context.current_file
|
||||
system_prompt = system_prompt .. "\nLanguage: " .. (context.language or "unknown")
|
||||
end
|
||||
|
||||
-- Add exploration context if available
|
||||
if exploration_result then
|
||||
local ok_explorer, explorer = pcall(require, "codetyper.ask.explorer")
|
||||
if ok_explorer then
|
||||
local explore_context = explorer.build_context(exploration_result)
|
||||
system_prompt = system_prompt .. "\n\n=== PROJECT EXPLORATION RESULTS ===\n"
|
||||
system_prompt = system_prompt .. explore_context
|
||||
system_prompt = system_prompt .. "\n=== END EXPLORATION ===\n"
|
||||
end
|
||||
end
|
||||
|
||||
-- Add brain context (learned patterns, conventions)
|
||||
if context.brain_context and context.brain_context ~= "" then
|
||||
system_prompt = system_prompt .. "\n\n=== LEARNED PROJECT KNOWLEDGE ===\n"
|
||||
system_prompt = system_prompt .. context.brain_context
|
||||
system_prompt = system_prompt .. "\n=== END LEARNED KNOWLEDGE ===\n"
|
||||
end
|
||||
|
||||
-- Add indexer context (project structure, symbols)
|
||||
if context.indexer_context then
|
||||
local idx_ctx = context.indexer_context
|
||||
if idx_ctx.project_type and idx_ctx.project_type ~= "unknown" then
|
||||
system_prompt = system_prompt .. "\n\nProject type: " .. idx_ctx.project_type
|
||||
end
|
||||
if idx_ctx.relevant_symbols and next(idx_ctx.relevant_symbols) then
|
||||
system_prompt = system_prompt .. "\n\nRelevant symbols in project:"
|
||||
for symbol, files in pairs(idx_ctx.relevant_symbols) do
|
||||
system_prompt = system_prompt .. "\n - " .. symbol .. " (in: " .. table.concat(files, ", ") .. ")"
|
||||
end
|
||||
end
|
||||
if idx_ctx.patterns and #idx_ctx.patterns > 0 then
|
||||
system_prompt = system_prompt .. "\n\nProject patterns/memories:"
|
||||
for _, pattern in ipairs(idx_ctx.patterns) do
|
||||
system_prompt = system_prompt .. "\n - " .. (pattern.summary or pattern.content or "")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Add to history
|
||||
table.insert(state.history, { role = "user", content = question })
|
||||
|
||||
-- Show loading indicator
|
||||
append_to_output("⏳ Thinking...", false)
|
||||
append_to_output("", false)
|
||||
append_to_output("⏳ Generating response...", false)
|
||||
|
||||
-- Get LLM client and generate response
|
||||
local ok, llm = pcall(require, "codetyper.llm")
|
||||
@@ -829,10 +925,23 @@ function M.submit()
|
||||
.. "\n```"
|
||||
end
|
||||
|
||||
-- Add exploration summary to prompt if available
|
||||
if exploration_result then
|
||||
full_prompt = full_prompt
|
||||
.. "\n\nPROJECT EXPLORATION COMPLETE: "
|
||||
.. exploration_result.total_files
|
||||
.. " files analyzed. "
|
||||
.. "Project type: "
|
||||
.. exploration_result.project.language
|
||||
.. " ("
|
||||
.. (exploration_result.project.framework or exploration_result.project.type)
|
||||
.. ")"
|
||||
end
|
||||
|
||||
local request_context = {
|
||||
file_content = file_context ~= "" and file_context or context.current_content,
|
||||
language = context.language,
|
||||
prompt_type = "explain",
|
||||
prompt_type = prompt_type,
|
||||
file_path = context.current_file,
|
||||
}
|
||||
|
||||
@@ -844,9 +953,9 @@ function M.submit()
|
||||
-- Remove last few lines (the thinking message)
|
||||
local to_remove = 0
|
||||
for i = #lines, 1, -1 do
|
||||
if lines[i]:match("Thinking") or lines[i]:match("^[│└┌─]") then
|
||||
if lines[i]:match("Generating") or lines[i]:match("^[│└┌─]") or lines[i] == "" then
|
||||
to_remove = to_remove + 1
|
||||
if lines[i]:match("┌") then
|
||||
if lines[i]:match("┌") or to_remove >= 5 then
|
||||
break
|
||||
end
|
||||
else
|
||||
@@ -879,6 +988,77 @@ function M.submit()
|
||||
end)
|
||||
end
|
||||
|
||||
--- Submit the question to LLM
|
||||
function M.submit()
|
||||
local question = get_input_text()
|
||||
|
||||
if not question or question:match("^%s*$") then
|
||||
utils.notify("Please enter a question", vim.log.levels.WARN)
|
||||
M.focus_input()
|
||||
return
|
||||
end
|
||||
|
||||
-- Detect intent from prompt
|
||||
local ok_intent, intent_module = pcall(require, "codetyper.ask.intent")
|
||||
local intent = nil
|
||||
if ok_intent then
|
||||
intent = intent_module.detect(question)
|
||||
else
|
||||
-- Fallback intent
|
||||
intent = {
|
||||
type = "ask",
|
||||
confidence = 0.5,
|
||||
needs_project_context = false,
|
||||
needs_brain_context = true,
|
||||
needs_exploration = false,
|
||||
}
|
||||
end
|
||||
|
||||
-- Build context BEFORE clearing input (to preserve file references)
|
||||
local context = build_context(intent)
|
||||
local file_context, file_count = build_file_context()
|
||||
|
||||
-- Build display message (without full file contents)
|
||||
local display_question = question
|
||||
if file_count > 0 then
|
||||
display_question = question .. "\n📎 " .. file_count .. " file(s) attached"
|
||||
end
|
||||
-- Show detected intent if not standard ask
|
||||
if intent.type ~= "ask" then
|
||||
display_question = display_question .. "\n🎯 " .. intent.type:upper() .. " mode"
|
||||
end
|
||||
-- Show exploration indicator
|
||||
if intent.needs_exploration then
|
||||
display_question = display_question .. "\n🔍 Project exploration required"
|
||||
end
|
||||
|
||||
-- Add user message to output
|
||||
append_to_output(display_question, true)
|
||||
|
||||
-- Clear input and references AFTER building context
|
||||
M.clear_input()
|
||||
|
||||
-- Check if exploration is needed
|
||||
if intent.needs_exploration then
|
||||
local ok_explorer, explorer = pcall(require, "codetyper.ask.explorer")
|
||||
if ok_explorer then
|
||||
local root = utils.get_project_root()
|
||||
if root then
|
||||
-- Start exploration with logging
|
||||
append_to_output("", false)
|
||||
explorer.explore(root, append_exploration_log, function(exploration_result)
|
||||
-- After exploration completes, continue with LLM request
|
||||
continue_submit(question, intent, context, file_context, file_count, exploration_result)
|
||||
end)
|
||||
return
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- No exploration needed, continue directly
|
||||
continue_submit(question, intent, context, file_context, file_count, nil)
|
||||
end
|
||||
|
||||
--- Clear chat history
|
||||
function M.clear_history()
|
||||
state.history = {}
|
||||
|
||||
676
lua/codetyper/ask/explorer.lua
Normal file
676
lua/codetyper/ask/explorer.lua
Normal file
@@ -0,0 +1,676 @@
|
||||
---@mod codetyper.ask.explorer Project exploration for Ask mode
|
||||
---@brief [[
|
||||
--- Performs comprehensive project exploration when explaining a project.
|
||||
--- Shows progress, indexes files, and builds brain context.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
|
||||
---@class ExplorationState
|
||||
---@field is_exploring boolean
|
||||
---@field files_scanned number
|
||||
---@field total_files number
|
||||
---@field current_file string|nil
|
||||
---@field findings table
|
||||
---@field on_log fun(msg: string, level: string)|nil
|
||||
|
||||
local state = {
|
||||
is_exploring = false,
|
||||
files_scanned = 0,
|
||||
total_files = 0,
|
||||
current_file = nil,
|
||||
findings = {},
|
||||
on_log = nil,
|
||||
}
|
||||
|
||||
--- File extensions to analyze
|
||||
local ANALYZABLE_EXTENSIONS = {
|
||||
lua = true,
|
||||
ts = true,
|
||||
tsx = true,
|
||||
js = true,
|
||||
jsx = true,
|
||||
py = true,
|
||||
go = true,
|
||||
rs = true,
|
||||
rb = true,
|
||||
java = true,
|
||||
c = true,
|
||||
cpp = true,
|
||||
h = true,
|
||||
hpp = true,
|
||||
json = true,
|
||||
yaml = true,
|
||||
yml = true,
|
||||
toml = true,
|
||||
md = true,
|
||||
xml = true,
|
||||
}
|
||||
|
||||
--- Directories to skip
|
||||
local SKIP_DIRS = {
|
||||
-- Version control
|
||||
[".git"] = true,
|
||||
[".svn"] = true,
|
||||
[".hg"] = true,
|
||||
|
||||
-- IDE/Editor
|
||||
[".idea"] = true,
|
||||
[".vscode"] = true,
|
||||
[".cursor"] = true,
|
||||
[".cursorignore"] = true,
|
||||
[".claude"] = true,
|
||||
[".zed"] = true,
|
||||
|
||||
-- Project tooling
|
||||
[".coder"] = true,
|
||||
[".github"] = true,
|
||||
[".gitlab"] = true,
|
||||
[".husky"] = true,
|
||||
|
||||
-- Build outputs
|
||||
dist = true,
|
||||
build = true,
|
||||
out = true,
|
||||
target = true,
|
||||
bin = true,
|
||||
obj = true,
|
||||
[".build"] = true,
|
||||
[".output"] = true,
|
||||
|
||||
-- Dependencies
|
||||
node_modules = true,
|
||||
vendor = true,
|
||||
[".vendor"] = true,
|
||||
packages = true,
|
||||
bower_components = true,
|
||||
jspm_packages = true,
|
||||
|
||||
-- Cache/temp
|
||||
[".cache"] = true,
|
||||
[".tmp"] = true,
|
||||
[".temp"] = true,
|
||||
__pycache__ = true,
|
||||
[".pytest_cache"] = true,
|
||||
[".mypy_cache"] = true,
|
||||
[".ruff_cache"] = true,
|
||||
[".tox"] = true,
|
||||
[".nox"] = true,
|
||||
[".eggs"] = true,
|
||||
["*.egg-info"] = true,
|
||||
|
||||
-- Framework specific
|
||||
[".next"] = true,
|
||||
[".nuxt"] = true,
|
||||
[".svelte-kit"] = true,
|
||||
[".vercel"] = true,
|
||||
[".netlify"] = true,
|
||||
[".serverless"] = true,
|
||||
[".turbo"] = true,
|
||||
|
||||
-- Testing/coverage
|
||||
coverage = true,
|
||||
[".nyc_output"] = true,
|
||||
htmlcov = true,
|
||||
|
||||
-- Logs
|
||||
logs = true,
|
||||
log = true,
|
||||
|
||||
-- OS files
|
||||
[".DS_Store"] = true,
|
||||
Thumbs_db = true,
|
||||
}
|
||||
|
||||
--- Files to skip (patterns)
|
||||
local SKIP_FILES = {
|
||||
-- Lock files
|
||||
"package%-lock%.json",
|
||||
"yarn%.lock",
|
||||
"pnpm%-lock%.yaml",
|
||||
"Gemfile%.lock",
|
||||
"Cargo%.lock",
|
||||
"poetry%.lock",
|
||||
"Pipfile%.lock",
|
||||
"composer%.lock",
|
||||
"go%.sum",
|
||||
"flake%.lock",
|
||||
"%.lock$",
|
||||
"%-lock%.json$",
|
||||
"%-lock%.yaml$",
|
||||
|
||||
-- Generated files
|
||||
"%.min%.js$",
|
||||
"%.min%.css$",
|
||||
"%.bundle%.js$",
|
||||
"%.chunk%.js$",
|
||||
"%.map$",
|
||||
"%.d%.ts$",
|
||||
|
||||
-- Binary/media (shouldn't match anyway but be safe)
|
||||
"%.png$",
|
||||
"%.jpg$",
|
||||
"%.jpeg$",
|
||||
"%.gif$",
|
||||
"%.ico$",
|
||||
"%.svg$",
|
||||
"%.woff",
|
||||
"%.ttf$",
|
||||
"%.eot$",
|
||||
"%.pdf$",
|
||||
"%.zip$",
|
||||
"%.tar",
|
||||
"%.gz$",
|
||||
|
||||
-- Config that's not useful
|
||||
"%.env",
|
||||
"%.env%.",
|
||||
}
|
||||
|
||||
--- Log a message during exploration
|
||||
---@param msg string
|
||||
---@param level? string "info"|"debug"|"file"|"progress"
|
||||
local function log(msg, level)
|
||||
level = level or "info"
|
||||
if state.on_log then
|
||||
state.on_log(msg, level)
|
||||
end
|
||||
end
|
||||
|
||||
--- Check if file should be skipped
|
||||
---@param filename string
|
||||
---@return boolean
|
||||
local function should_skip_file(filename)
|
||||
for _, pattern in ipairs(SKIP_FILES) do
|
||||
if filename:match(pattern) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Check if directory should be skipped
|
||||
---@param dirname string
|
||||
---@return boolean
|
||||
local function should_skip_dir(dirname)
|
||||
-- Direct match
|
||||
if SKIP_DIRS[dirname] then
|
||||
return true
|
||||
end
|
||||
-- Pattern match for .cursor* etc
|
||||
if dirname:match("^%.cursor") then
|
||||
return true
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Get all files in project
|
||||
---@param root string Project root
|
||||
---@return string[] files
|
||||
local function get_project_files(root)
|
||||
local files = {}
|
||||
|
||||
local function scan_dir(dir)
|
||||
local handle = vim.loop.fs_scandir(dir)
|
||||
if not handle then
|
||||
return
|
||||
end
|
||||
|
||||
while true do
|
||||
local name, type = vim.loop.fs_scandir_next(handle)
|
||||
if not name then
|
||||
break
|
||||
end
|
||||
|
||||
local full_path = dir .. "/" .. name
|
||||
|
||||
if type == "directory" then
|
||||
if not should_skip_dir(name) then
|
||||
scan_dir(full_path)
|
||||
end
|
||||
elseif type == "file" then
|
||||
if not should_skip_file(name) then
|
||||
local ext = name:match("%.([^%.]+)$")
|
||||
if ext and ANALYZABLE_EXTENSIONS[ext:lower()] then
|
||||
table.insert(files, full_path)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
scan_dir(root)
|
||||
return files
|
||||
end
|
||||
|
||||
--- Analyze a single file
|
||||
---@param filepath string
|
||||
---@return table|nil analysis
|
||||
local function analyze_file(filepath)
|
||||
local content = utils.read_file(filepath)
|
||||
if not content or content == "" then
|
||||
return nil
|
||||
end
|
||||
|
||||
local ext = filepath:match("%.([^%.]+)$") or ""
|
||||
local lines = vim.split(content, "\n")
|
||||
|
||||
local analysis = {
|
||||
path = filepath,
|
||||
extension = ext,
|
||||
lines = #lines,
|
||||
size = #content,
|
||||
imports = {},
|
||||
exports = {},
|
||||
functions = {},
|
||||
classes = {},
|
||||
summary = "",
|
||||
}
|
||||
|
||||
-- Extract key patterns based on file type
|
||||
for i, line in ipairs(lines) do
|
||||
-- Imports/requires
|
||||
local import = line:match('import%s+.*%s+from%s+["\']([^"\']+)["\']')
|
||||
or line:match('require%(["\']([^"\']+)["\']%)')
|
||||
or line:match("from%s+([%w_.]+)%s+import")
|
||||
if import then
|
||||
table.insert(analysis.imports, { source = import, line = i })
|
||||
end
|
||||
|
||||
-- Function definitions
|
||||
local func = line:match("^%s*function%s+([%w_:%.]+)%s*%(")
|
||||
or line:match("^%s*local%s+function%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*def%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*func%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*async%s+function%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*public%s+.*%s+([%w_]+)%s*%(")
|
||||
if func then
|
||||
table.insert(analysis.functions, { name = func, line = i })
|
||||
end
|
||||
|
||||
-- Class definitions
|
||||
local class = line:match("^%s*class%s+([%w_]+)")
|
||||
or line:match("^%s*public%s+class%s+([%w_]+)")
|
||||
or line:match("^%s*interface%s+([%w_]+)")
|
||||
if class then
|
||||
table.insert(analysis.classes, { name = class, line = i })
|
||||
end
|
||||
|
||||
-- Exports
|
||||
local exp = line:match("^%s*export%s+.*%s+([%w_]+)")
|
||||
or line:match("^%s*module%.exports%s*=")
|
||||
or line:match("^return%s+M")
|
||||
if exp then
|
||||
table.insert(analysis.exports, { name = exp, line = i })
|
||||
end
|
||||
end
|
||||
|
||||
-- Create summary
|
||||
local parts = {}
|
||||
if #analysis.functions > 0 then
|
||||
table.insert(parts, #analysis.functions .. " functions")
|
||||
end
|
||||
if #analysis.classes > 0 then
|
||||
table.insert(parts, #analysis.classes .. " classes")
|
||||
end
|
||||
if #analysis.imports > 0 then
|
||||
table.insert(parts, #analysis.imports .. " imports")
|
||||
end
|
||||
analysis.summary = table.concat(parts, ", ")
|
||||
|
||||
return analysis
|
||||
end
|
||||
|
||||
--- Detect project type from files
|
||||
---@param root string
|
||||
---@return string type, table info
|
||||
local function detect_project_type(root)
|
||||
local info = {
|
||||
name = vim.fn.fnamemodify(root, ":t"),
|
||||
type = "unknown",
|
||||
framework = nil,
|
||||
language = nil,
|
||||
}
|
||||
|
||||
-- Check for common project files
|
||||
if utils.file_exists(root .. "/package.json") then
|
||||
info.type = "node"
|
||||
info.language = "JavaScript/TypeScript"
|
||||
local content = utils.read_file(root .. "/package.json")
|
||||
if content then
|
||||
local ok, pkg = pcall(vim.json.decode, content)
|
||||
if ok then
|
||||
info.name = pkg.name or info.name
|
||||
if pkg.dependencies then
|
||||
if pkg.dependencies.react then
|
||||
info.framework = "React"
|
||||
elseif pkg.dependencies.vue then
|
||||
info.framework = "Vue"
|
||||
elseif pkg.dependencies.next then
|
||||
info.framework = "Next.js"
|
||||
elseif pkg.dependencies.express then
|
||||
info.framework = "Express"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
elseif utils.file_exists(root .. "/pom.xml") then
|
||||
info.type = "maven"
|
||||
info.language = "Java"
|
||||
local content = utils.read_file(root .. "/pom.xml")
|
||||
if content and content:match("spring%-boot") then
|
||||
info.framework = "Spring Boot"
|
||||
end
|
||||
elseif utils.file_exists(root .. "/Cargo.toml") then
|
||||
info.type = "rust"
|
||||
info.language = "Rust"
|
||||
elseif utils.file_exists(root .. "/go.mod") then
|
||||
info.type = "go"
|
||||
info.language = "Go"
|
||||
elseif utils.file_exists(root .. "/requirements.txt") or utils.file_exists(root .. "/pyproject.toml") then
|
||||
info.type = "python"
|
||||
info.language = "Python"
|
||||
elseif utils.file_exists(root .. "/init.lua") or utils.file_exists(root .. "/plugin/") then
|
||||
info.type = "neovim-plugin"
|
||||
info.language = "Lua"
|
||||
end
|
||||
|
||||
return info.type, info
|
||||
end
|
||||
|
||||
--- Build project structure summary
|
||||
---@param files string[]
|
||||
---@param root string
|
||||
---@return table structure
|
||||
local function build_structure(files, root)
|
||||
local structure = {
|
||||
directories = {},
|
||||
by_extension = {},
|
||||
total_files = #files,
|
||||
}
|
||||
|
||||
for _, file in ipairs(files) do
|
||||
local relative = file:gsub("^" .. vim.pesc(root) .. "/", "")
|
||||
local dir = vim.fn.fnamemodify(relative, ":h")
|
||||
local ext = file:match("%.([^%.]+)$") or "unknown"
|
||||
|
||||
structure.directories[dir] = (structure.directories[dir] or 0) + 1
|
||||
structure.by_extension[ext] = (structure.by_extension[ext] or 0) + 1
|
||||
end
|
||||
|
||||
return structure
|
||||
end
|
||||
|
||||
--- Explore project and build context
|
||||
---@param root string Project root
|
||||
---@param on_log fun(msg: string, level: string) Log callback
|
||||
---@param on_complete fun(result: table) Completion callback
|
||||
function M.explore(root, on_log, on_complete)
|
||||
if state.is_exploring then
|
||||
on_log("⚠️ Already exploring...", "warning")
|
||||
return
|
||||
end
|
||||
|
||||
state.is_exploring = true
|
||||
state.on_log = on_log
|
||||
state.findings = {}
|
||||
|
||||
-- Start exploration
|
||||
log("⏺ Exploring project structure...", "info")
|
||||
log("", "info")
|
||||
|
||||
-- Detect project type
|
||||
log(" Detect(Project type)", "progress")
|
||||
local project_type, project_info = detect_project_type(root)
|
||||
log(" ⎿ " .. project_info.language .. " (" .. (project_info.framework or project_type) .. ")", "debug")
|
||||
|
||||
state.findings.project = project_info
|
||||
|
||||
-- Get all files
|
||||
log("", "info")
|
||||
log(" Scan(Project files)", "progress")
|
||||
local files = get_project_files(root)
|
||||
state.total_files = #files
|
||||
log(" ⎿ Found " .. #files .. " analyzable files", "debug")
|
||||
|
||||
-- Build structure
|
||||
local structure = build_structure(files, root)
|
||||
state.findings.structure = structure
|
||||
|
||||
-- Show directory breakdown
|
||||
log("", "info")
|
||||
log(" Structure(Directories)", "progress")
|
||||
local sorted_dirs = {}
|
||||
for dir, count in pairs(structure.directories) do
|
||||
table.insert(sorted_dirs, { dir = dir, count = count })
|
||||
end
|
||||
table.sort(sorted_dirs, function(a, b)
|
||||
return a.count > b.count
|
||||
end)
|
||||
for i, entry in ipairs(sorted_dirs) do
|
||||
if i <= 5 then
|
||||
log(" ⎿ " .. entry.dir .. " (" .. entry.count .. " files)", "debug")
|
||||
end
|
||||
end
|
||||
if #sorted_dirs > 5 then
|
||||
log(" ⎿ +" .. (#sorted_dirs - 5) .. " more directories", "debug")
|
||||
end
|
||||
|
||||
-- Analyze files asynchronously
|
||||
log("", "info")
|
||||
log(" Analyze(Source files)", "progress")
|
||||
|
||||
state.files_scanned = 0
|
||||
local analyses = {}
|
||||
local key_files = {}
|
||||
|
||||
-- Process files in batches to avoid blocking
|
||||
local batch_size = 10
|
||||
local current_batch = 0
|
||||
|
||||
local function process_batch()
|
||||
local start_idx = current_batch * batch_size + 1
|
||||
local end_idx = math.min(start_idx + batch_size - 1, #files)
|
||||
|
||||
for i = start_idx, end_idx do
|
||||
local file = files[i]
|
||||
local relative = file:gsub("^" .. vim.pesc(root) .. "/", "")
|
||||
|
||||
state.files_scanned = state.files_scanned + 1
|
||||
state.current_file = relative
|
||||
|
||||
local analysis = analyze_file(file)
|
||||
if analysis then
|
||||
analysis.relative_path = relative
|
||||
table.insert(analyses, analysis)
|
||||
|
||||
-- Track key files (many functions/classes)
|
||||
if #analysis.functions >= 3 or #analysis.classes >= 1 then
|
||||
table.insert(key_files, {
|
||||
path = relative,
|
||||
functions = #analysis.functions,
|
||||
classes = #analysis.classes,
|
||||
summary = analysis.summary,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Log some files
|
||||
if i <= 3 or (i % 20 == 0) then
|
||||
log(" ⎿ " .. relative .. ": " .. (analysis and analysis.summary or "(empty)"), "file")
|
||||
end
|
||||
end
|
||||
|
||||
-- Progress update
|
||||
local progress = math.floor((state.files_scanned / state.total_files) * 100)
|
||||
if progress % 25 == 0 and progress > 0 then
|
||||
log(" ⎿ " .. progress .. "% complete (" .. state.files_scanned .. "/" .. state.total_files .. ")", "debug")
|
||||
end
|
||||
|
||||
current_batch = current_batch + 1
|
||||
|
||||
if end_idx < #files then
|
||||
-- Schedule next batch
|
||||
vim.defer_fn(process_batch, 10)
|
||||
else
|
||||
-- Complete
|
||||
finish_exploration(root, analyses, key_files, on_complete)
|
||||
end
|
||||
end
|
||||
|
||||
-- Start processing
|
||||
vim.defer_fn(process_batch, 10)
|
||||
end
|
||||
|
||||
--- Finish exploration and store results
|
||||
---@param root string
|
||||
---@param analyses table
|
||||
---@param key_files table
|
||||
---@param on_complete fun(result: table)
|
||||
function finish_exploration(root, analyses, key_files, on_complete)
|
||||
log(" ⎿ +" .. (#analyses - 3) .. " more files analyzed", "debug")
|
||||
|
||||
-- Show key files
|
||||
if #key_files > 0 then
|
||||
log("", "info")
|
||||
log(" KeyFiles(Important components)", "progress")
|
||||
table.sort(key_files, function(a, b)
|
||||
return (a.functions + a.classes * 2) > (b.functions + b.classes * 2)
|
||||
end)
|
||||
for i, kf in ipairs(key_files) do
|
||||
if i <= 5 then
|
||||
log(" ⎿ " .. kf.path .. ": " .. kf.summary, "file")
|
||||
end
|
||||
end
|
||||
if #key_files > 5 then
|
||||
log(" ⎿ +" .. (#key_files - 5) .. " more key files", "debug")
|
||||
end
|
||||
end
|
||||
|
||||
state.findings.analyses = analyses
|
||||
state.findings.key_files = key_files
|
||||
|
||||
-- Store in brain if available
|
||||
local ok_brain, brain = pcall(require, "codetyper.brain")
|
||||
if ok_brain and brain.is_initialized() then
|
||||
log("", "info")
|
||||
log(" Store(Brain context)", "progress")
|
||||
|
||||
-- Store project pattern
|
||||
brain.learn({
|
||||
type = "pattern",
|
||||
file = root,
|
||||
content = {
|
||||
summary = "Project: " .. state.findings.project.name,
|
||||
detail = state.findings.project.language
|
||||
.. " "
|
||||
.. (state.findings.project.framework or state.findings.project.type),
|
||||
code = nil,
|
||||
},
|
||||
context = {
|
||||
file = root,
|
||||
language = state.findings.project.language,
|
||||
},
|
||||
})
|
||||
|
||||
-- Store key file patterns
|
||||
for i, kf in ipairs(key_files) do
|
||||
if i <= 10 then
|
||||
brain.learn({
|
||||
type = "pattern",
|
||||
file = root .. "/" .. kf.path,
|
||||
content = {
|
||||
summary = kf.path .. " - " .. kf.summary,
|
||||
detail = kf.summary,
|
||||
},
|
||||
context = {
|
||||
file = kf.path,
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
log(" ⎿ Stored " .. math.min(#key_files, 10) + 1 .. " patterns in brain", "debug")
|
||||
end
|
||||
|
||||
-- Store in indexer if available
|
||||
local ok_indexer, indexer = pcall(require, "codetyper.indexer")
|
||||
if ok_indexer then
|
||||
log(" Index(Project index)", "progress")
|
||||
indexer.index_project(function(index)
|
||||
log(" ⎿ Indexed " .. (index.stats.files or 0) .. " files", "debug")
|
||||
end)
|
||||
end
|
||||
|
||||
log("", "info")
|
||||
log("✓ Exploration complete!", "info")
|
||||
log("", "info")
|
||||
|
||||
-- Build result
|
||||
local result = {
|
||||
project = state.findings.project,
|
||||
structure = state.findings.structure,
|
||||
key_files = key_files,
|
||||
total_files = state.total_files,
|
||||
analyses = analyses,
|
||||
}
|
||||
|
||||
state.is_exploring = false
|
||||
state.on_log = nil
|
||||
|
||||
on_complete(result)
|
||||
end
|
||||
|
||||
--- Check if exploration is in progress
|
||||
---@return boolean
|
||||
function M.is_exploring()
|
||||
return state.is_exploring
|
||||
end
|
||||
|
||||
--- Get exploration progress
|
||||
---@return number scanned, number total
|
||||
function M.get_progress()
|
||||
return state.files_scanned, state.total_files
|
||||
end
|
||||
|
||||
--- Build context string from exploration result
|
||||
---@param result table Exploration result
|
||||
---@return string context
|
||||
function M.build_context(result)
|
||||
local parts = {}
|
||||
|
||||
-- Project info
|
||||
table.insert(parts, "## Project: " .. result.project.name)
|
||||
table.insert(parts, "- Type: " .. result.project.type)
|
||||
table.insert(parts, "- Language: " .. (result.project.language or "Unknown"))
|
||||
if result.project.framework then
|
||||
table.insert(parts, "- Framework: " .. result.project.framework)
|
||||
end
|
||||
table.insert(parts, "- Files: " .. result.total_files)
|
||||
table.insert(parts, "")
|
||||
|
||||
-- Structure
|
||||
table.insert(parts, "## Structure")
|
||||
if result.structure and result.structure.by_extension then
|
||||
for ext, count in pairs(result.structure.by_extension) do
|
||||
table.insert(parts, "- ." .. ext .. ": " .. count .. " files")
|
||||
end
|
||||
end
|
||||
table.insert(parts, "")
|
||||
|
||||
-- Key components
|
||||
if result.key_files and #result.key_files > 0 then
|
||||
table.insert(parts, "## Key Components")
|
||||
for i, kf in ipairs(result.key_files) do
|
||||
if i <= 10 then
|
||||
table.insert(parts, "- " .. kf.path .. ": " .. kf.summary)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return table.concat(parts, "\n")
|
||||
end
|
||||
|
||||
return M
|
||||
302
lua/codetyper/ask/intent.lua
Normal file
302
lua/codetyper/ask/intent.lua
Normal file
@@ -0,0 +1,302 @@
|
||||
---@mod codetyper.ask.intent Intent detection for Ask mode
|
||||
---@brief [[
|
||||
--- Analyzes user prompts to detect intent (ask/explain vs code generation).
|
||||
--- Routes to appropriate prompt type and context sources.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@alias IntentType "ask"|"explain"|"generate"|"refactor"|"document"|"test"
|
||||
|
||||
---@class Intent
|
||||
---@field type IntentType Detected intent type
|
||||
---@field confidence number 0-1 confidence score
|
||||
---@field needs_project_context boolean Whether project-wide context is needed
|
||||
---@field needs_brain_context boolean Whether brain/learned context is helpful
|
||||
---@field needs_exploration boolean Whether full project exploration is needed
|
||||
---@field keywords string[] Keywords that influenced detection
|
||||
|
||||
--- Patterns for detecting ask/explain intent (questions about code)
|
||||
local ASK_PATTERNS = {
|
||||
-- Question words
|
||||
{ pattern = "^what%s", weight = 0.9 },
|
||||
{ pattern = "^why%s", weight = 0.95 },
|
||||
{ pattern = "^how%s+does", weight = 0.9 },
|
||||
{ pattern = "^how%s+do%s+i", weight = 0.7 }, -- Could be asking for code
|
||||
{ pattern = "^where%s", weight = 0.85 },
|
||||
{ pattern = "^when%s", weight = 0.85 },
|
||||
{ pattern = "^which%s", weight = 0.8 },
|
||||
{ pattern = "^who%s", weight = 0.85 },
|
||||
{ pattern = "^can%s+you%s+explain", weight = 0.95 },
|
||||
{ pattern = "^could%s+you%s+explain", weight = 0.95 },
|
||||
{ pattern = "^please%s+explain", weight = 0.95 },
|
||||
|
||||
-- Explanation requests
|
||||
{ pattern = "explain%s", weight = 0.9 },
|
||||
{ pattern = "describe%s", weight = 0.85 },
|
||||
{ pattern = "tell%s+me%s+about", weight = 0.85 },
|
||||
{ pattern = "walk%s+me%s+through", weight = 0.9 },
|
||||
{ pattern = "help%s+me%s+understand", weight = 0.95 },
|
||||
{ pattern = "what%s+is%s+the%s+purpose", weight = 0.95 },
|
||||
{ pattern = "what%s+does%s+this", weight = 0.9 },
|
||||
{ pattern = "what%s+does%s+it", weight = 0.9 },
|
||||
{ pattern = "how%s+does%s+this%s+work", weight = 0.95 },
|
||||
{ pattern = "how%s+does%s+it%s+work", weight = 0.95 },
|
||||
|
||||
-- Understanding queries
|
||||
{ pattern = "understand", weight = 0.7 },
|
||||
{ pattern = "meaning%s+of", weight = 0.85 },
|
||||
{ pattern = "difference%s+between", weight = 0.9 },
|
||||
{ pattern = "compared%s+to", weight = 0.8 },
|
||||
{ pattern = "vs%s", weight = 0.7 },
|
||||
{ pattern = "versus", weight = 0.7 },
|
||||
{ pattern = "pros%s+and%s+cons", weight = 0.9 },
|
||||
{ pattern = "advantages", weight = 0.8 },
|
||||
{ pattern = "disadvantages", weight = 0.8 },
|
||||
{ pattern = "trade%-?offs?", weight = 0.85 },
|
||||
|
||||
-- Analysis requests
|
||||
{ pattern = "analyze", weight = 0.85 },
|
||||
{ pattern = "review", weight = 0.7 }, -- Could also be refactor
|
||||
{ pattern = "overview", weight = 0.9 },
|
||||
{ pattern = "summary", weight = 0.9 },
|
||||
{ pattern = "summarize", weight = 0.9 },
|
||||
|
||||
-- Question marks (weaker signal)
|
||||
{ pattern = "%?$", weight = 0.3 },
|
||||
{ pattern = "%?%s*$", weight = 0.3 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting code generation intent
|
||||
local GENERATE_PATTERNS = {
|
||||
-- Direct commands
|
||||
{ pattern = "^create%s", weight = 0.9 },
|
||||
{ pattern = "^make%s", weight = 0.85 },
|
||||
{ pattern = "^build%s", weight = 0.85 },
|
||||
{ pattern = "^write%s", weight = 0.9 },
|
||||
{ pattern = "^add%s", weight = 0.85 },
|
||||
{ pattern = "^implement%s", weight = 0.95 },
|
||||
{ pattern = "^generate%s", weight = 0.95 },
|
||||
{ pattern = "^code%s", weight = 0.8 },
|
||||
|
||||
-- Modification commands
|
||||
{ pattern = "^fix%s", weight = 0.9 },
|
||||
{ pattern = "^change%s", weight = 0.8 },
|
||||
{ pattern = "^update%s", weight = 0.75 },
|
||||
{ pattern = "^modify%s", weight = 0.8 },
|
||||
{ pattern = "^replace%s", weight = 0.85 },
|
||||
{ pattern = "^remove%s", weight = 0.85 },
|
||||
{ pattern = "^delete%s", weight = 0.85 },
|
||||
|
||||
-- Feature requests
|
||||
{ pattern = "i%s+need%s+a", weight = 0.8 },
|
||||
{ pattern = "i%s+want%s+a", weight = 0.8 },
|
||||
{ pattern = "give%s+me", weight = 0.7 },
|
||||
{ pattern = "show%s+me%s+how%s+to%s+code", weight = 0.9 },
|
||||
{ pattern = "how%s+do%s+i%s+implement", weight = 0.85 },
|
||||
{ pattern = "can%s+you%s+write", weight = 0.9 },
|
||||
{ pattern = "can%s+you%s+create", weight = 0.9 },
|
||||
{ pattern = "can%s+you%s+add", weight = 0.85 },
|
||||
{ pattern = "can%s+you%s+make", weight = 0.85 },
|
||||
|
||||
-- Code-specific terms
|
||||
{ pattern = "function%s+that", weight = 0.85 },
|
||||
{ pattern = "class%s+that", weight = 0.85 },
|
||||
{ pattern = "method%s+that", weight = 0.85 },
|
||||
{ pattern = "component%s+that", weight = 0.85 },
|
||||
{ pattern = "module%s+that", weight = 0.85 },
|
||||
{ pattern = "api%s+for", weight = 0.8 },
|
||||
{ pattern = "endpoint%s+for", weight = 0.8 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting refactor intent
|
||||
local REFACTOR_PATTERNS = {
|
||||
{ pattern = "^refactor%s", weight = 0.95 },
|
||||
{ pattern = "refactor%s+this", weight = 0.95 },
|
||||
{ pattern = "clean%s+up", weight = 0.85 },
|
||||
{ pattern = "improve%s+this%s+code", weight = 0.85 },
|
||||
{ pattern = "make%s+this%s+cleaner", weight = 0.85 },
|
||||
{ pattern = "simplify", weight = 0.8 },
|
||||
{ pattern = "optimize", weight = 0.75 }, -- Could be explain
|
||||
{ pattern = "reorganize", weight = 0.9 },
|
||||
{ pattern = "restructure", weight = 0.9 },
|
||||
{ pattern = "extract%s+to", weight = 0.9 },
|
||||
{ pattern = "split%s+into", weight = 0.85 },
|
||||
{ pattern = "dry%s+this", weight = 0.9 }, -- Don't repeat yourself
|
||||
{ pattern = "reduce%s+duplication", weight = 0.9 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting documentation intent
|
||||
local DOCUMENT_PATTERNS = {
|
||||
{ pattern = "^document%s", weight = 0.95 },
|
||||
{ pattern = "add%s+documentation", weight = 0.95 },
|
||||
{ pattern = "add%s+docs", weight = 0.95 },
|
||||
{ pattern = "add%s+comments", weight = 0.9 },
|
||||
{ pattern = "add%s+docstring", weight = 0.95 },
|
||||
{ pattern = "add%s+jsdoc", weight = 0.95 },
|
||||
{ pattern = "write%s+documentation", weight = 0.95 },
|
||||
{ pattern = "document%s+this", weight = 0.95 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting test generation intent
|
||||
local TEST_PATTERNS = {
|
||||
{ pattern = "^test%s", weight = 0.9 },
|
||||
{ pattern = "write%s+tests?%s+for", weight = 0.95 },
|
||||
{ pattern = "add%s+tests?%s+for", weight = 0.95 },
|
||||
{ pattern = "create%s+tests?%s+for", weight = 0.95 },
|
||||
{ pattern = "generate%s+tests?", weight = 0.95 },
|
||||
{ pattern = "unit%s+tests?", weight = 0.9 },
|
||||
{ pattern = "test%s+cases?%s+for", weight = 0.95 },
|
||||
{ pattern = "spec%s+for", weight = 0.85 },
|
||||
}
|
||||
|
||||
--- Patterns indicating project-wide context is needed
|
||||
local PROJECT_CONTEXT_PATTERNS = {
|
||||
{ pattern = "project", weight = 0.9 },
|
||||
{ pattern = "codebase", weight = 0.95 },
|
||||
{ pattern = "entire", weight = 0.7 },
|
||||
{ pattern = "whole", weight = 0.7 },
|
||||
{ pattern = "all%s+files", weight = 0.9 },
|
||||
{ pattern = "architecture", weight = 0.95 },
|
||||
{ pattern = "structure", weight = 0.85 },
|
||||
{ pattern = "how%s+is%s+.*%s+organized", weight = 0.95 },
|
||||
{ pattern = "where%s+is%s+.*%s+defined", weight = 0.9 },
|
||||
{ pattern = "dependencies", weight = 0.85 },
|
||||
{ pattern = "imports?%s+from", weight = 0.7 },
|
||||
{ pattern = "modules?", weight = 0.6 },
|
||||
{ pattern = "packages?", weight = 0.6 },
|
||||
}
|
||||
|
||||
--- Patterns indicating project exploration is needed (full indexing)
|
||||
local EXPLORE_PATTERNS = {
|
||||
{ pattern = "explain%s+.*%s*project", weight = 1.0 },
|
||||
{ pattern = "explain%s+.*%s*codebase", weight = 1.0 },
|
||||
{ pattern = "explain%s+me%s+the%s+project", weight = 1.0 },
|
||||
{ pattern = "tell%s+me%s+about%s+.*%s*project", weight = 0.95 },
|
||||
{ pattern = "what%s+is%s+this%s+project", weight = 0.95 },
|
||||
{ pattern = "overview%s+of%s+.*%s*project", weight = 0.95 },
|
||||
{ pattern = "understand%s+.*%s*project", weight = 0.9 },
|
||||
{ pattern = "analyze%s+.*%s*project", weight = 0.9 },
|
||||
{ pattern = "explore%s+.*%s*project", weight = 1.0 },
|
||||
{ pattern = "explore%s+.*%s*codebase", weight = 1.0 },
|
||||
{ pattern = "index%s+.*%s*project", weight = 1.0 },
|
||||
{ pattern = "scan%s+.*%s*project", weight = 0.95 },
|
||||
}
|
||||
|
||||
--- Match patterns against text
|
||||
---@param text string Lowercased text to match
|
||||
---@param patterns table Pattern list with weights
|
||||
---@return number Score, string[] Matched keywords
|
||||
local function match_patterns(text, patterns)
|
||||
local score = 0
|
||||
local matched = {}
|
||||
|
||||
for _, p in ipairs(patterns) do
|
||||
if text:match(p.pattern) then
|
||||
score = score + p.weight
|
||||
table.insert(matched, p.pattern)
|
||||
end
|
||||
end
|
||||
|
||||
return score, matched
|
||||
end
|
||||
|
||||
--- Detect intent from user prompt
|
||||
---@param prompt string User's question/request
|
||||
---@return Intent Detected intent
|
||||
function M.detect(prompt)
|
||||
local text = prompt:lower()
|
||||
|
||||
-- Calculate raw scores for each intent type (sum of matched weights)
|
||||
local ask_score, ask_kw = match_patterns(text, ASK_PATTERNS)
|
||||
local gen_score, gen_kw = match_patterns(text, GENERATE_PATTERNS)
|
||||
local ref_score, ref_kw = match_patterns(text, REFACTOR_PATTERNS)
|
||||
local doc_score, doc_kw = match_patterns(text, DOCUMENT_PATTERNS)
|
||||
local test_score, test_kw = match_patterns(text, TEST_PATTERNS)
|
||||
local proj_score, _ = match_patterns(text, PROJECT_CONTEXT_PATTERNS)
|
||||
local explore_score, _ = match_patterns(text, EXPLORE_PATTERNS)
|
||||
|
||||
-- Find the winner by raw score (highest accumulated weight)
|
||||
local scores = {
|
||||
{ type = "ask", score = ask_score, keywords = ask_kw },
|
||||
{ type = "generate", score = gen_score, keywords = gen_kw },
|
||||
{ type = "refactor", score = ref_score, keywords = ref_kw },
|
||||
{ type = "document", score = doc_score, keywords = doc_kw },
|
||||
{ type = "test", score = test_score, keywords = test_kw },
|
||||
}
|
||||
|
||||
table.sort(scores, function(a, b)
|
||||
return a.score > b.score
|
||||
end)
|
||||
|
||||
local winner = scores[1]
|
||||
|
||||
-- If top score is very low, default to ask (safer for Q&A)
|
||||
if winner.score < 0.3 then
|
||||
winner = { type = "ask", score = 0.5, keywords = {} }
|
||||
end
|
||||
|
||||
-- If ask and generate are close AND there's a question mark, prefer ask
|
||||
if winner.type == "generate" and ask_score > 0 then
|
||||
if text:match("%?%s*$") and ask_score >= gen_score * 0.5 then
|
||||
winner = { type = "ask", score = ask_score, keywords = ask_kw }
|
||||
end
|
||||
end
|
||||
|
||||
-- Determine if "explain" vs "ask" (explain needs more context)
|
||||
local intent_type = winner.type
|
||||
if intent_type == "ask" then
|
||||
-- "explain" if asking about how something works, otherwise "ask"
|
||||
if text:match("explain") or text:match("how%s+does") or text:match("walk%s+me%s+through") then
|
||||
intent_type = "explain"
|
||||
end
|
||||
end
|
||||
|
||||
-- Normalize confidence to 0-1 range (cap at reasonable max)
|
||||
local confidence = math.min(winner.score / 2, 1.0)
|
||||
|
||||
-- Check if exploration is needed (full project indexing)
|
||||
local needs_exploration = explore_score >= 0.9
|
||||
|
||||
---@type Intent
|
||||
local intent = {
|
||||
type = intent_type,
|
||||
confidence = confidence,
|
||||
needs_project_context = proj_score > 0.5 or needs_exploration,
|
||||
needs_brain_context = intent_type == "ask" or intent_type == "explain",
|
||||
needs_exploration = needs_exploration,
|
||||
keywords = winner.keywords,
|
||||
}
|
||||
|
||||
return intent
|
||||
end
|
||||
|
||||
--- Get prompt type for system prompt selection
|
||||
---@param intent Intent Detected intent
|
||||
---@return string Prompt type for prompts.system
|
||||
function M.get_prompt_type(intent)
|
||||
local mapping = {
|
||||
ask = "ask",
|
||||
explain = "ask", -- Uses same prompt as ask
|
||||
generate = "code_generation",
|
||||
refactor = "refactor",
|
||||
document = "document",
|
||||
test = "test",
|
||||
}
|
||||
return mapping[intent.type] or "ask"
|
||||
end
|
||||
|
||||
--- Check if intent requires code output
|
||||
---@param intent Intent
|
||||
---@return boolean
|
||||
function M.produces_code(intent)
|
||||
local code_intents = {
|
||||
generate = true,
|
||||
refactor = true,
|
||||
document = true, -- Documentation is code (comments)
|
||||
test = true,
|
||||
}
|
||||
return code_intents[intent.type] or false
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -18,6 +18,16 @@ local processed_prompts = {}
|
||||
--- Track if we're currently asking for preferences
|
||||
local asking_preference = false
|
||||
|
||||
--- Track if we're currently processing prompts (busy flag)
|
||||
local is_processing = false
|
||||
|
||||
--- Track the previous mode for visual mode detection
|
||||
local previous_mode = "n"
|
||||
|
||||
--- Debounce timer for prompt processing
|
||||
local prompt_process_timer = nil
|
||||
local PROMPT_PROCESS_DEBOUNCE_MS = 200 -- Wait 200ms after mode change before processing
|
||||
|
||||
--- Generate a unique key for a prompt
|
||||
---@param bufnr number Buffer number
|
||||
---@param prompt table Prompt object
|
||||
@@ -64,6 +74,20 @@ function M.setup()
|
||||
desc = "Check for closed prompt tags on InsertLeave",
|
||||
})
|
||||
|
||||
-- Track mode changes for visual mode detection
|
||||
vim.api.nvim_create_autocmd("ModeChanged", {
|
||||
group = group,
|
||||
pattern = "*",
|
||||
callback = function(ev)
|
||||
-- Extract old mode from pattern (format: "old_mode:new_mode")
|
||||
local old_mode = ev.match:match("^(.-):")
|
||||
if old_mode then
|
||||
previous_mode = old_mode
|
||||
end
|
||||
end,
|
||||
desc = "Track previous mode for visual mode detection",
|
||||
})
|
||||
|
||||
-- Auto-process prompts when entering normal mode (works on ALL files)
|
||||
vim.api.nvim_create_autocmd("ModeChanged", {
|
||||
group = group,
|
||||
@@ -74,10 +98,33 @@ function M.setup()
|
||||
if buftype ~= "" then
|
||||
return
|
||||
end
|
||||
-- Slight delay to let buffer settle
|
||||
vim.defer_fn(function()
|
||||
|
||||
-- Skip if currently processing (avoid concurrent processing)
|
||||
if is_processing then
|
||||
return
|
||||
end
|
||||
|
||||
-- Skip if coming from visual mode (v, V, CTRL-V) - user is still editing
|
||||
if previous_mode == "v" or previous_mode == "V" or previous_mode == "\22" then
|
||||
return
|
||||
end
|
||||
|
||||
-- Cancel any pending processing timer
|
||||
if prompt_process_timer then
|
||||
prompt_process_timer:stop()
|
||||
prompt_process_timer = nil
|
||||
end
|
||||
|
||||
-- Debounced processing - wait for user to truly be idle
|
||||
prompt_process_timer = vim.defer_fn(function()
|
||||
prompt_process_timer = nil
|
||||
-- Double-check we're still in normal mode
|
||||
local mode = vim.api.nvim_get_mode().mode
|
||||
if mode ~= "n" then
|
||||
return
|
||||
end
|
||||
M.check_all_prompts_with_preference()
|
||||
end, 50)
|
||||
end, PROMPT_PROCESS_DEBOUNCE_MS)
|
||||
end,
|
||||
desc = "Auto-process closed prompts when entering normal mode",
|
||||
})
|
||||
@@ -92,6 +139,10 @@ function M.setup()
|
||||
if buftype ~= "" then
|
||||
return
|
||||
end
|
||||
-- Skip if currently processing
|
||||
if is_processing then
|
||||
return
|
||||
end
|
||||
local mode = vim.api.nvim_get_mode().mode
|
||||
if mode == "n" then
|
||||
M.check_all_prompts_with_preference()
|
||||
@@ -155,10 +206,28 @@ function M.setup()
|
||||
if filepath:match("%.coder%.") or filepath:match("tree%.log$") then
|
||||
return
|
||||
end
|
||||
-- Skip non-project files
|
||||
if filepath:match("node_modules") or filepath:match("%.git/") or filepath:match("%.coder/") then
|
||||
return
|
||||
end
|
||||
-- Schedule tree update with debounce
|
||||
schedule_tree_update()
|
||||
|
||||
-- Trigger incremental indexing if enabled
|
||||
local ok_indexer, indexer = pcall(require, "codetyper.indexer")
|
||||
if ok_indexer then
|
||||
indexer.schedule_index_file(filepath)
|
||||
end
|
||||
|
||||
-- Update brain with file patterns
|
||||
local ok_brain, brain = pcall(require, "codetyper.brain")
|
||||
if ok_brain and brain.is_initialized and brain.is_initialized() then
|
||||
vim.defer_fn(function()
|
||||
M.update_brain_from_file(filepath)
|
||||
end, 500) -- Debounce brain updates
|
||||
end
|
||||
end,
|
||||
desc = "Update tree.log on file creation/save",
|
||||
desc = "Update tree.log, index, and brain on file creation/save",
|
||||
})
|
||||
|
||||
-- Update tree.log when files are deleted (via netrw or file explorer)
|
||||
@@ -186,6 +255,19 @@ function M.setup()
|
||||
desc = "Update tree.log on directory change",
|
||||
})
|
||||
|
||||
-- Shutdown brain on Vim exit
|
||||
vim.api.nvim_create_autocmd("VimLeavePre", {
|
||||
group = group,
|
||||
pattern = "*",
|
||||
callback = function()
|
||||
local ok, brain = pcall(require, "codetyper.brain")
|
||||
if ok and brain.is_initialized and brain.is_initialized() then
|
||||
brain.shutdown()
|
||||
end
|
||||
end,
|
||||
desc = "Shutdown brain and flush pending changes",
|
||||
})
|
||||
|
||||
-- Auto-index: Create/open coder companion file when opening source files
|
||||
vim.api.nvim_create_autocmd("BufEnter", {
|
||||
group = group,
|
||||
@@ -211,7 +293,7 @@ local function get_config_safe()
|
||||
open_tag = "/@",
|
||||
close_tag = "@/",
|
||||
file_pattern = "*.coder.*",
|
||||
}
|
||||
},
|
||||
}
|
||||
end
|
||||
return config
|
||||
@@ -260,6 +342,12 @@ end
|
||||
|
||||
--- Check if the buffer has a newly closed prompt and auto-process (works on ANY file)
|
||||
function M.check_for_closed_prompt()
|
||||
-- Skip if already processing
|
||||
if is_processing then
|
||||
return
|
||||
end
|
||||
is_processing = true
|
||||
|
||||
local config = get_config_safe()
|
||||
local parser = require("codetyper.parser")
|
||||
|
||||
@@ -268,6 +356,7 @@ function M.check_for_closed_prompt()
|
||||
|
||||
-- Skip if no file
|
||||
if current_file == "" then
|
||||
is_processing = false
|
||||
return
|
||||
end
|
||||
|
||||
@@ -277,6 +366,7 @@ function M.check_for_closed_prompt()
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, line - 1, line, false)
|
||||
|
||||
if #lines == 0 then
|
||||
is_processing = false
|
||||
return
|
||||
end
|
||||
|
||||
@@ -292,6 +382,7 @@ function M.check_for_closed_prompt()
|
||||
|
||||
-- Check if already processed
|
||||
if processed_prompts[prompt_key] then
|
||||
is_processing = false
|
||||
return
|
||||
end
|
||||
|
||||
@@ -335,23 +426,38 @@ function M.check_for_closed_prompt()
|
||||
-- Clean prompt content (strip file references)
|
||||
local cleaned = parser.clean_prompt(parser.strip_file_references(prompt.content))
|
||||
|
||||
-- Resolve scope in target file FIRST (need it to adjust intent)
|
||||
local target_bufnr = vim.fn.bufnr(target_path)
|
||||
if target_bufnr == -1 then
|
||||
target_bufnr = bufnr
|
||||
end
|
||||
-- Check if we're working from a coder file
|
||||
local is_from_coder_file = utils.is_coder_file(current_file)
|
||||
|
||||
-- Resolve scope in target file FIRST (need it to adjust intent)
|
||||
-- Only resolve scope if NOT from coder file (line numbers don't apply)
|
||||
local target_bufnr = vim.fn.bufnr(target_path)
|
||||
local scope = nil
|
||||
local scope_text = nil
|
||||
local scope_range = nil
|
||||
|
||||
scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1)
|
||||
if scope and scope.type ~= "file" then
|
||||
scope_text = scope.text
|
||||
scope_range = {
|
||||
start_line = scope.range.start_row,
|
||||
end_line = scope.range.end_row,
|
||||
}
|
||||
if not is_from_coder_file then
|
||||
-- Prompt is in the actual source file, use line position for scope
|
||||
if target_bufnr == -1 then
|
||||
target_bufnr = bufnr
|
||||
end
|
||||
scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1)
|
||||
if scope and scope.type ~= "file" then
|
||||
scope_text = scope.text
|
||||
scope_range = {
|
||||
start_line = scope.range.start_row,
|
||||
end_line = scope.range.end_row,
|
||||
}
|
||||
end
|
||||
else
|
||||
-- Prompt is in coder file - load target if needed, but don't use scope
|
||||
-- Code from coder files should append to target by default
|
||||
if target_bufnr == -1 then
|
||||
target_bufnr = vim.fn.bufadd(target_path)
|
||||
if target_bufnr ~= 0 then
|
||||
vim.fn.bufload(target_bufnr)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Detect intent from prompt
|
||||
@@ -359,7 +465,8 @@ function M.check_for_closed_prompt()
|
||||
|
||||
-- IMPORTANT: If prompt is inside a function/method and intent is "add",
|
||||
-- override to "complete" since we're completing the function body
|
||||
if scope and (scope.type == "function" or scope.type == "method") then
|
||||
-- But NOT for coder files - they should use "add/append" by default
|
||||
if not is_from_coder_file and scope and (scope.type == "function" or scope.type == "method") then
|
||||
if intent.type == "add" or intent.action == "insert" or intent.action == "append" then
|
||||
-- Override to complete the function instead of adding new code
|
||||
intent = {
|
||||
@@ -372,6 +479,16 @@ function M.check_for_closed_prompt()
|
||||
end
|
||||
end
|
||||
|
||||
-- For coder files, default to "add" with "append" action
|
||||
if is_from_coder_file and (intent.action == "replace" or intent.type == "complete") then
|
||||
intent = {
|
||||
type = intent.type == "complete" and "add" or intent.type,
|
||||
confidence = intent.confidence,
|
||||
action = "append",
|
||||
keywords = intent.keywords,
|
||||
}
|
||||
end
|
||||
|
||||
-- Determine priority based on intent
|
||||
local priority = 2 -- Normal
|
||||
if intent.type == "fix" or intent.type == "complete" then
|
||||
@@ -400,13 +517,11 @@ function M.check_for_closed_prompt()
|
||||
attached_files = attached_files,
|
||||
})
|
||||
|
||||
local scope_info = scope and scope.type ~= "file"
|
||||
and string.format(" [%s: %s]", scope.type, scope.name or "anonymous")
|
||||
local scope_info = scope
|
||||
and scope.type ~= "file"
|
||||
and string.format(" [%s: %s]", scope.type, scope.name or "anonymous")
|
||||
or ""
|
||||
utils.notify(
|
||||
string.format("Prompt queued: %s%s", intent.type, scope_info),
|
||||
vim.log.levels.INFO
|
||||
)
|
||||
utils.notify(string.format("Prompt queued: %s%s", intent.type, scope_info), vim.log.levels.INFO)
|
||||
end)
|
||||
else
|
||||
-- Legacy: direct processing
|
||||
@@ -417,6 +532,7 @@ function M.check_for_closed_prompt()
|
||||
end
|
||||
end
|
||||
end
|
||||
is_processing = false
|
||||
end
|
||||
|
||||
--- Check and process all closed prompts in the buffer (works on ANY file)
|
||||
@@ -478,7 +594,8 @@ function M.check_all_prompts()
|
||||
|
||||
-- Get target path - for coder files, get the target; for regular files, use self
|
||||
local target_path
|
||||
if utils.is_coder_file(current_file) then
|
||||
local is_from_coder_file = utils.is_coder_file(current_file)
|
||||
if is_from_coder_file then
|
||||
target_path = utils.get_target_path(current_file)
|
||||
else
|
||||
target_path = current_file
|
||||
@@ -491,22 +608,33 @@ function M.check_all_prompts()
|
||||
local cleaned = parser.clean_prompt(parser.strip_file_references(prompt.content))
|
||||
|
||||
-- Resolve scope in target file FIRST (need it to adjust intent)
|
||||
-- Only resolve scope if NOT from coder file (line numbers don't apply)
|
||||
local target_bufnr = vim.fn.bufnr(target_path)
|
||||
if target_bufnr == -1 then
|
||||
target_bufnr = bufnr -- Use current buffer if target not loaded
|
||||
end
|
||||
|
||||
local scope = nil
|
||||
local scope_text = nil
|
||||
local scope_range = nil
|
||||
|
||||
scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1)
|
||||
if scope and scope.type ~= "file" then
|
||||
scope_text = scope.text
|
||||
scope_range = {
|
||||
start_line = scope.range.start_row,
|
||||
end_line = scope.range.end_row,
|
||||
}
|
||||
if not is_from_coder_file then
|
||||
-- Prompt is in the actual source file, use line position for scope
|
||||
if target_bufnr == -1 then
|
||||
target_bufnr = bufnr
|
||||
end
|
||||
scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1)
|
||||
if scope and scope.type ~= "file" then
|
||||
scope_text = scope.text
|
||||
scope_range = {
|
||||
start_line = scope.range.start_row,
|
||||
end_line = scope.range.end_row,
|
||||
}
|
||||
end
|
||||
else
|
||||
-- Prompt is in coder file - load target if needed
|
||||
if target_bufnr == -1 then
|
||||
target_bufnr = vim.fn.bufadd(target_path)
|
||||
if target_bufnr ~= 0 then
|
||||
vim.fn.bufload(target_bufnr)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Detect intent from prompt
|
||||
@@ -514,7 +642,8 @@ function M.check_all_prompts()
|
||||
|
||||
-- IMPORTANT: If prompt is inside a function/method and intent is "add",
|
||||
-- override to "complete" since we're completing the function body
|
||||
if scope and (scope.type == "function" or scope.type == "method") then
|
||||
-- But NOT for coder files - they should use "add/append" by default
|
||||
if not is_from_coder_file and scope and (scope.type == "function" or scope.type == "method") then
|
||||
if intent.type == "add" or intent.action == "insert" or intent.action == "append" then
|
||||
-- Override to complete the function instead of adding new code
|
||||
intent = {
|
||||
@@ -527,6 +656,16 @@ function M.check_all_prompts()
|
||||
end
|
||||
end
|
||||
|
||||
-- For coder files, default to "add" with "append" action
|
||||
if is_from_coder_file and (intent.action == "replace" or intent.type == "complete") then
|
||||
intent = {
|
||||
type = intent.type == "complete" and "add" or intent.type,
|
||||
confidence = intent.confidence,
|
||||
action = "append",
|
||||
keywords = intent.keywords,
|
||||
}
|
||||
end
|
||||
|
||||
-- Determine priority based on intent
|
||||
local priority = 2
|
||||
if intent.type == "fix" or intent.type == "complete" then
|
||||
@@ -555,13 +694,11 @@ function M.check_all_prompts()
|
||||
attached_files = attached_files,
|
||||
})
|
||||
|
||||
local scope_info = scope and scope.type ~= "file"
|
||||
and string.format(" [%s: %s]", scope.type, scope.name or "anonymous")
|
||||
local scope_info = scope
|
||||
and scope.type ~= "file"
|
||||
and string.format(" [%s: %s]", scope.type, scope.name or "anonymous")
|
||||
or ""
|
||||
utils.notify(
|
||||
string.format("Prompt queued: %s%s", intent.type, scope_info),
|
||||
vim.log.levels.INFO
|
||||
)
|
||||
utils.notify(string.format("Prompt queued: %s%s", intent.type, scope_info), vim.log.levels.INFO)
|
||||
end)
|
||||
|
||||
::continue::
|
||||
@@ -822,15 +959,135 @@ function M.clear()
|
||||
vim.api.nvim_del_augroup_by_name(AUGROUP)
|
||||
end
|
||||
|
||||
--- Debounce timers for brain updates per file
|
||||
---@type table<string, uv_timer_t>
|
||||
local brain_update_timers = {}
|
||||
|
||||
--- Update brain with patterns from a file
|
||||
---@param filepath string
|
||||
function M.update_brain_from_file(filepath)
|
||||
local ok_brain, brain = pcall(require, "codetyper.brain")
|
||||
if not ok_brain or not brain.is_initialized() then
|
||||
return
|
||||
end
|
||||
|
||||
-- Read file content
|
||||
local content = utils.read_file(filepath)
|
||||
if not content or content == "" then
|
||||
return
|
||||
end
|
||||
|
||||
local ext = vim.fn.fnamemodify(filepath, ":e")
|
||||
local lines = vim.split(content, "\n")
|
||||
|
||||
-- Extract key patterns from the file
|
||||
local functions = {}
|
||||
local classes = {}
|
||||
local imports = {}
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
-- Functions
|
||||
local func = line:match("^%s*function%s+([%w_:%.]+)%s*%(")
|
||||
or line:match("^%s*local%s+function%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*def%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*func%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*async%s+function%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*public%s+.*%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*private%s+.*%s+([%w_]+)%s*%(")
|
||||
if func then
|
||||
table.insert(functions, { name = func, line = i })
|
||||
end
|
||||
|
||||
-- Classes
|
||||
local class = line:match("^%s*class%s+([%w_]+)")
|
||||
or line:match("^%s*public%s+class%s+([%w_]+)")
|
||||
or line:match("^%s*interface%s+([%w_]+)")
|
||||
or line:match("^%s*struct%s+([%w_]+)")
|
||||
if class then
|
||||
table.insert(classes, { name = class, line = i })
|
||||
end
|
||||
|
||||
-- Imports
|
||||
local imp = line:match("import%s+.*%s+from%s+[\"']([^\"']+)[\"']")
|
||||
or line:match("require%([\"']([^\"']+)[\"']%)")
|
||||
or line:match("from%s+([%w_.]+)%s+import")
|
||||
if imp then
|
||||
table.insert(imports, imp)
|
||||
end
|
||||
end
|
||||
|
||||
-- Only store if file has meaningful content
|
||||
if #functions == 0 and #classes == 0 then
|
||||
return
|
||||
end
|
||||
|
||||
-- Build summary
|
||||
local parts = {}
|
||||
if #functions > 0 then
|
||||
local func_names = {}
|
||||
for i, f in ipairs(functions) do
|
||||
if i <= 5 then
|
||||
table.insert(func_names, f.name)
|
||||
end
|
||||
end
|
||||
table.insert(parts, "functions: " .. table.concat(func_names, ", "))
|
||||
end
|
||||
if #classes > 0 then
|
||||
local class_names = {}
|
||||
for _, c in ipairs(classes) do
|
||||
table.insert(class_names, c.name)
|
||||
end
|
||||
table.insert(parts, "classes: " .. table.concat(class_names, ", "))
|
||||
end
|
||||
|
||||
local summary = vim.fn.fnamemodify(filepath, ":t") .. " - " .. table.concat(parts, "; ")
|
||||
|
||||
-- Learn this pattern - use "pattern_detected" type to match the pattern learner
|
||||
brain.learn({
|
||||
type = "pattern_detected",
|
||||
file = filepath,
|
||||
timestamp = os.time(),
|
||||
data = {
|
||||
name = summary,
|
||||
description = #functions .. " functions, " .. #classes .. " classes",
|
||||
language = ext,
|
||||
symbols = vim.tbl_map(function(f) return f.name end, functions),
|
||||
example = nil,
|
||||
},
|
||||
})
|
||||
end
|
||||
|
||||
--- Track buffers that have been auto-indexed
|
||||
---@type table<number, boolean>
|
||||
local auto_indexed_buffers = {}
|
||||
|
||||
--- Supported file extensions for auto-indexing
|
||||
local supported_extensions = {
|
||||
"ts", "tsx", "js", "jsx", "py", "lua", "go", "rs", "rb",
|
||||
"java", "c", "cpp", "cs", "json", "yaml", "yml", "md",
|
||||
"html", "css", "scss", "vue", "svelte", "php", "sh", "zsh",
|
||||
"ts",
|
||||
"tsx",
|
||||
"js",
|
||||
"jsx",
|
||||
"py",
|
||||
"lua",
|
||||
"go",
|
||||
"rs",
|
||||
"rb",
|
||||
"java",
|
||||
"c",
|
||||
"cpp",
|
||||
"cs",
|
||||
"json",
|
||||
"yaml",
|
||||
"yml",
|
||||
"md",
|
||||
"html",
|
||||
"css",
|
||||
"scss",
|
||||
"vue",
|
||||
"svelte",
|
||||
"php",
|
||||
"sh",
|
||||
"zsh",
|
||||
}
|
||||
|
||||
--- Check if extension is supported
|
||||
@@ -847,6 +1104,126 @@ end
|
||||
|
||||
--- Auto-index a file by creating/opening its coder companion
|
||||
---@param bufnr number Buffer number
|
||||
--- Directories to ignore for coder file creation.
-- Matched as path components by is_in_ignored_directory(); covers VCS
-- metadata, editor state, dependency caches, and build output.
local ignored_directories = {
  ".git",
  ".coder",
  ".claude",
  ".vscode",
  ".idea",
  "node_modules",
  "vendor",
  "dist",
  "build",
  "target",
  "__pycache__",
  ".cache",
  ".npm",
  ".yarn",
  "coverage",
  ".next",
  ".nuxt",
  ".svelte-kit",
  "out",
  "bin",
  "obj",
}
|
||||
|
||||
--- Files to ignore for coder file creation (matched by exact basename).
-- NOTE(review): entries beginning with "." are redundant with the
-- leading-dot check in should_ignore_for_coder, but are kept here for
-- documentation value.
local ignored_files = {
  -- Git files
  ".gitignore",
  ".gitattributes",
  ".gitmodules",
  -- Lock files
  "package-lock.json",
  "yarn.lock",
  "pnpm-lock.yaml",
  "Cargo.lock",
  "Gemfile.lock",
  "poetry.lock",
  "composer.lock",
  -- Config files that don't need coder companions
  ".env",
  ".env.local",
  ".env.development",
  ".env.production",
  ".eslintrc",
  ".eslintrc.json",
  ".prettierrc",
  ".prettierrc.json",
  ".editorconfig",
  ".dockerignore",
  "Dockerfile",
  "docker-compose.yml",
  "docker-compose.yaml",
  ".npmrc",
  ".yarnrc",
  ".nvmrc",
  "tsconfig.json",
  "jsconfig.json",
  "babel.config.js",
  "webpack.config.js",
  "vite.config.js",
  "rollup.config.js",
  "jest.config.js",
  "vitest.config.js",
  ".stylelintrc",
  "tailwind.config.js",
  "postcss.config.js",
  -- Other non-code files
  "README.md",
  "CHANGELOG.md",
  "LICENSE",
  "LICENSE.md",
  "CONTRIBUTING.md",
  "Makefile",
  "CMakeLists.txt",
}
|
||||
|
||||
--- Check if a file path contains an ignored directory.
--- Fix: directory names were interpolated into Lua patterns unescaped;
--- "." is a pattern magic character, so e.g. ".git" also matched paths
--- like "/Xgit/". Names are now escaped and matched literally.
---@param filepath string Full file path
---@return boolean
local function is_in_ignored_directory(filepath)
  for _, dir in ipairs(ignored_directories) do
    -- Escape every non-alphanumeric character so the directory name
    -- is matched literally rather than as a pattern.
    local lit = dir:gsub("%W", "%%%0")
    -- /dir/ anywhere, /dir at the end, or dir/ at the start (relative paths)
    if filepath:match("/" .. lit .. "/")
        or filepath:match("/" .. lit .. "$")
        or filepath:match("^" .. lit .. "/") then
      return true
    end
  end
  return false
end
|
||||
|
||||
--- Check if a file should be ignored for coder companion creation.
--- A file is ignored when it is hidden (leading dot), on the exact-name
--- blacklist, or located inside an ignored directory.
---@param filepath string Full file path
---@return boolean
local function should_ignore_for_coder(filepath)
  local name = vim.fn.fnamemodify(filepath, ":t")

  -- Hidden/config dotfiles never get companions.
  if name:sub(1, 1) == "." then
    return true
  end

  -- Exact-basename blacklist (lock files, build configs, docs, ...).
  for _, blocked in ipairs(ignored_files) do
    if blocked == name then
      return true
    end
  end

  -- Anything under an ignored directory is skipped as well.
  return is_in_ignored_directory(filepath)
end
|
||||
|
||||
function M.auto_index_file(bufnr)
|
||||
-- Skip if buffer is invalid
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
@@ -881,6 +1258,11 @@ function M.auto_index_file(bufnr)
|
||||
return
|
||||
end
|
||||
|
||||
-- Skip ignored directories and files (node_modules, .git, config files, etc.)
|
||||
if should_ignore_for_coder(filepath) then
|
||||
return
|
||||
end
|
||||
|
||||
-- Skip if auto_index is disabled in config
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
@@ -897,23 +1279,137 @@ function M.auto_index_file(bufnr)
|
||||
-- Check if coder file already exists
|
||||
local coder_exists = utils.file_exists(coder_path)
|
||||
|
||||
-- Create coder file with template if it doesn't exist
|
||||
-- Create coder file with pseudo-code context if it doesn't exist
|
||||
if not coder_exists then
|
||||
local filename = vim.fn.fnamemodify(filepath, ":t")
|
||||
local template = string.format(
|
||||
[[-- Coder companion for %s
|
||||
-- Use /@ @/ tags to write pseudo-code prompts
|
||||
-- Example:
|
||||
-- /@
|
||||
-- Add a function that validates user input
|
||||
-- - Check for empty strings
|
||||
-- - Validate email format
|
||||
-- @/
|
||||
local ext = vim.fn.fnamemodify(filepath, ":e")
|
||||
|
||||
]],
|
||||
filename
|
||||
)
|
||||
utils.write_file(coder_path, template)
|
||||
-- Determine comment style based on extension
|
||||
local comment_prefix = "--"
|
||||
local comment_block_start = "--[["
|
||||
local comment_block_end = "]]"
|
||||
if ext == "ts" or ext == "tsx" or ext == "js" or ext == "jsx" or ext == "java" or ext == "c" or ext == "cpp" or ext == "cs" or ext == "go" or ext == "rs" then
|
||||
comment_prefix = "//"
|
||||
comment_block_start = "/*"
|
||||
comment_block_end = "*/"
|
||||
elseif ext == "py" or ext == "rb" or ext == "yaml" or ext == "yml" then
|
||||
comment_prefix = "#"
|
||||
comment_block_start = '"""'
|
||||
comment_block_end = '"""'
|
||||
end
|
||||
|
||||
-- Read target file to analyze its structure
|
||||
local content = ""
|
||||
pcall(function()
|
||||
local lines = vim.fn.readfile(filepath)
|
||||
if lines then
|
||||
content = table.concat(lines, "\n")
|
||||
end
|
||||
end)
|
||||
|
||||
-- Extract structure from the file
|
||||
local functions = extract_functions(content, ext)
|
||||
local classes = extract_classes(content, ext)
|
||||
local imports = extract_imports(content, ext)
|
||||
|
||||
-- Build pseudo-code context
|
||||
local pseudo_code = {}
|
||||
|
||||
-- Header
|
||||
table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════")
|
||||
table.insert(pseudo_code, comment_prefix .. " CODER COMPANION: " .. filename)
|
||||
table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════")
|
||||
table.insert(pseudo_code, comment_prefix .. " This file describes the business logic and behavior of " .. filename)
|
||||
table.insert(pseudo_code, comment_prefix .. " Edit this pseudo-code to guide code generation.")
|
||||
table.insert(pseudo_code, comment_prefix .. " Use /@ @/ tags for specific generation requests.")
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
|
||||
-- Module purpose
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " MODULE PURPOSE:")
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " TODO: Describe what this module/file is responsible for")
|
||||
table.insert(pseudo_code, comment_prefix .. " Example: \"Handles user authentication and session management\"")
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
|
||||
-- Dependencies section
|
||||
if #imports > 0 then
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " DEPENDENCIES:")
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
for _, imp in ipairs(imports) do
|
||||
table.insert(pseudo_code, comment_prefix .. " • " .. imp)
|
||||
end
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
end
|
||||
|
||||
-- Classes section
|
||||
if #classes > 0 then
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " CLASSES:")
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
for _, class in ipairs(classes) do
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
table.insert(pseudo_code, comment_prefix .. " class " .. class.name .. ":")
|
||||
table.insert(pseudo_code, comment_prefix .. " PURPOSE: TODO - describe what this class represents")
|
||||
table.insert(pseudo_code, comment_prefix .. " RESPONSIBILITIES:")
|
||||
table.insert(pseudo_code, comment_prefix .. " - TODO: list main responsibilities")
|
||||
end
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
end
|
||||
|
||||
-- Functions section
|
||||
if #functions > 0 then
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " FUNCTIONS:")
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
for _, func in ipairs(functions) do
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
table.insert(pseudo_code, comment_prefix .. " " .. func.name .. "():")
|
||||
table.insert(pseudo_code, comment_prefix .. " PURPOSE: TODO - what does this function do?")
|
||||
table.insert(pseudo_code, comment_prefix .. " INPUTS: TODO - describe parameters")
|
||||
table.insert(pseudo_code, comment_prefix .. " OUTPUTS: TODO - describe return value")
|
||||
table.insert(pseudo_code, comment_prefix .. " BEHAVIOR:")
|
||||
table.insert(pseudo_code, comment_prefix .. " - TODO: describe step-by-step logic")
|
||||
end
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
end
|
||||
|
||||
-- If empty file, provide starter template
|
||||
if #functions == 0 and #classes == 0 then
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " PLANNED STRUCTURE:")
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " TODO: Describe what you want to build in this file")
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
table.insert(pseudo_code, comment_prefix .. " Example pseudo-code:")
|
||||
table.insert(pseudo_code, comment_prefix .. " /@")
|
||||
table.insert(pseudo_code, comment_prefix .. " Create a module that:")
|
||||
table.insert(pseudo_code, comment_prefix .. " 1. Exports a main function")
|
||||
table.insert(pseudo_code, comment_prefix .. " 2. Handles errors gracefully")
|
||||
table.insert(pseudo_code, comment_prefix .. " 3. Returns structured data")
|
||||
table.insert(pseudo_code, comment_prefix .. " @/")
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
end
|
||||
|
||||
-- Business rules section
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " BUSINESS RULES:")
|
||||
table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────")
|
||||
table.insert(pseudo_code, comment_prefix .. " TODO: Document any business rules, constraints, or requirements")
|
||||
table.insert(pseudo_code, comment_prefix .. " Example:")
|
||||
table.insert(pseudo_code, comment_prefix .. " - Users must be authenticated before accessing this feature")
|
||||
table.insert(pseudo_code, comment_prefix .. " - Data must be validated before saving")
|
||||
table.insert(pseudo_code, comment_prefix .. " - Errors should be logged but not exposed to users")
|
||||
table.insert(pseudo_code, comment_prefix .. "")
|
||||
|
||||
-- Footer with generation tags example
|
||||
table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════")
|
||||
table.insert(pseudo_code, comment_prefix .. " Use /@ @/ tags below to request code generation:")
|
||||
table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════")
|
||||
table.insert(pseudo_code, "")
|
||||
|
||||
utils.write_file(coder_path, table.concat(pseudo_code, "\n"))
|
||||
end
|
||||
|
||||
-- Notify user about the coder companion
|
||||
@@ -968,14 +1464,23 @@ function M.open_coder_companion(open_split)
|
||||
%s @/%s
|
||||
|
||||
]],
|
||||
comment_prefix, filename, close_comment,
|
||||
comment_prefix, close_comment,
|
||||
comment_prefix, close_comment,
|
||||
comment_prefix, close_comment,
|
||||
comment_prefix, close_comment,
|
||||
comment_prefix, close_comment,
|
||||
comment_prefix, close_comment,
|
||||
comment_prefix, close_comment
|
||||
comment_prefix,
|
||||
filename,
|
||||
close_comment,
|
||||
comment_prefix,
|
||||
close_comment,
|
||||
comment_prefix,
|
||||
close_comment,
|
||||
comment_prefix,
|
||||
close_comment,
|
||||
comment_prefix,
|
||||
close_comment,
|
||||
comment_prefix,
|
||||
close_comment,
|
||||
comment_prefix,
|
||||
close_comment,
|
||||
comment_prefix,
|
||||
close_comment
|
||||
)
|
||||
utils.write_file(coder_path, template)
|
||||
end
|
||||
|
||||
291
lua/codetyper/brain/delta/commit.lua
Normal file
291
lua/codetyper/brain/delta/commit.lua
Normal file
@@ -0,0 +1,291 @@
|
||||
--- Brain Delta Commit Operations
|
||||
--- Git-like commit creation and management
|
||||
|
||||
local storage = require("codetyper.brain.storage")
|
||||
local hash_mod = require("codetyper.brain.hash")
|
||||
local diff_mod = require("codetyper.brain.delta.diff")
|
||||
local types = require("codetyper.brain.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Create a new delta commit from a list of changes.
--- Persists the delta, advances HEAD to it, and bumps the stored
--- commit counter.
---@param changes table[] Changes to commit
---@param message string Commit message
---@param trigger? string Trigger source (defaults to "manual")
---@return Delta|nil Created delta, or nil when there is nothing to commit
function M.create(changes, message, trigger)
  if changes == nil or #changes == 0 then
    return nil
  end

  local timestamp = os.time()
  local parent = storage.get_head()

  -- Record only the persisted fields of each change.
  local recorded = {}
  for _, c in ipairs(changes) do
    recorded[#recorded + 1] = {
      op = c.op,
      path = c.path,
      bh = c.bh,
      ah = c.ah,
      diff = c.diff,
    }
  end

  local delta = {
    h = hash_mod.delta_hash(changes, parent, timestamp),
    p = parent,
    ts = timestamp,
    ch = recorded,
    m = {
      msg = message or "Unnamed commit",
      trig = trigger or "manual",
    },
  }

  -- Persist, advance HEAD, and update the delta count in meta.
  storage.save_delta(delta)
  storage.set_head(delta.h)
  local meta = storage.get_meta()
  storage.update_meta({ dc = meta.dc + 1 })

  return delta
end
|
||||
|
||||
--- Get a delta by hash.
-- Thin wrapper over storage; returns nil when the hash is unknown.
---@param delta_hash string Delta hash
---@return Delta|nil
function M.get(delta_hash)
  return storage.get_delta(delta_hash)
end
|
||||
|
||||
--- Get the delta currently pointed to by HEAD.
---@return Delta|nil nil when no commits exist yet
function M.get_head()
  local head = storage.get_head()
  return head and M.get(head) or nil
end
|
||||
|
||||
--- Get delta history by walking parent pointers (newest first).
---@param limit? number Max entries (default 50)
---@param from_hash? string Starting hash (default: HEAD)
---@return Delta[]
function M.get_history(limit, from_hash)
  local max_entries = limit or 50
  local out = {}
  local cursor = from_hash or storage.get_head()

  while cursor ~= nil and #out < max_entries do
    local d = M.get(cursor)
    if d == nil then
      break -- broken chain: stop rather than error
    end
    out[#out + 1] = d
    cursor = d.p
  end

  return out
end
|
||||
|
||||
--- Check if a delta exists in storage.
---@param delta_hash string Delta hash
---@return boolean
function M.exists(delta_hash)
  return M.get(delta_hash) ~= nil
end
|
||||
|
||||
--- Get the path of deltas from one delta to another.
--- Walks to_hash's ancestry until it meets a hash also reachable from
--- from_hash; the result is ordered oldest-first and ends at to_hash.
---@param from_hash string Start delta hash
---@param to_hash string End delta hash
---@return Delta[]|nil Path of deltas, or nil if no common ancestor is found
function M.get_path(from_hash, to_hash)
  -- Mark every hash reachable from from_hash (inclusive).
  local reachable = {}
  do
    local cursor = from_hash
    while cursor do
      reachable[cursor] = true
      local d = M.get(cursor)
      if d == nil then break end
      cursor = d.p
    end
  end

  -- Walk back from to_hash, prepending, until we hit that ancestry.
  local path = {}
  local cursor = to_hash
  while cursor do
    local d = M.get(cursor)
    if d == nil then break end
    table.insert(path, 1, d)
    if reachable[cursor] then
      return path -- common ancestor found
    end
    cursor = d.p
  end

  return nil
end
|
||||
|
||||
--- Get all changes between two deltas, in application order.
---@param from_hash string|nil Start delta hash (nil = walk to the beginning)
---@param to_hash string End delta hash
---@return table[] Combined changes, oldest delta first
function M.get_changes_between(from_hash, to_hash)
  -- Collect deltas from to_hash back to (but excluding) from_hash.
  local chain = {}
  local cursor = to_hash
  while cursor ~= nil and cursor ~= from_hash do
    local d = M.get(cursor)
    if d == nil then break end
    table.insert(chain, 1, d)
    cursor = d.p
  end

  -- Flatten their change lists in order.
  local merged = {}
  for _, d in ipairs(chain) do
    for _, c in ipairs(d.ch) do
      merged[#merged + 1] = c
    end
  end

  return merged
end
|
||||
|
||||
--- Compute the reverse of a delta's changes, for rollback.
--- Changes are emitted in reverse order; ADD becomes DELETE, DELETE
--- becomes ADD, and MODIFY swaps its before/after hashes (reversing
--- the field diff when one is present).
---@param delta Delta Delta to reverse
---@return table[] Reverse changes
function M.compute_reverse(delta)
  local out = {}

  for i = #delta.ch, 1, -1 do
    local c = delta.ch[i]
    local inverse = { path = c.path }

    if c.op == types.DELTA_OPS.ADD then
      inverse.op = types.DELTA_OPS.DELETE
      inverse.bh = c.ah
    elseif c.op == types.DELTA_OPS.DELETE then
      inverse.op = types.DELTA_OPS.ADD
      inverse.ah = c.bh
    elseif c.op == types.DELTA_OPS.MODIFY then
      inverse.op = types.DELTA_OPS.MODIFY
      inverse.bh = c.ah
      inverse.ah = c.bh
      if c.diff then
        inverse.diff = diff_mod.reverse(c.diff)
      end
    end

    out[#out + 1] = inverse
  end

  return out
end
|
||||
|
||||
--- Squash multiple deltas into a single commit.
---@param delta_hashes string[] Delta hashes to squash, oldest first
---@param message string Squash commit message
---@return Delta|nil Squashed delta (nil when no hashes were given)
function M.squash(delta_hashes, message)
  if #delta_hashes == 0 then
    return nil
  end

  -- Gather every change from every resolvable delta, in order.
  local merged = {}
  for _, h in ipairs(delta_hashes) do
    local d = M.get(h)
    if d ~= nil then
      for _, c in ipairs(d.ch) do
        merged[#merged + 1] = c
      end
    end
  end

  -- Collapse redundant changes, then commit the result.
  return M.create(diff_mod.compact(merged), message, "squash")
end
|
||||
|
||||
--- Get a summary of a delta: per-op counts plus touched categories.
---@param delta Delta Delta to summarize
---@return table Summary (hash, parent, timestamp, message, trigger, stats, categories)
function M.summarize(delta)
  local counts = { add = 0, mod = 0, del = 0 }
  local categories = {}

  for _, c in ipairs(delta.ch) do
    if c.op == types.DELTA_OPS.ADD then
      counts.add = counts.add + 1
    elseif c.op == types.DELTA_OPS.MODIFY then
      counts.mod = counts.mod + 1
    elseif c.op == types.DELTA_OPS.DELETE then
      counts.del = counts.del + 1
    end

    -- The first dotted segment of the path is the change's category.
    local head = vim.split(c.path, ".", { plain = true })[1]
    if head then
      categories[head] = true
    end
  end

  return {
    hash = delta.h,
    parent = delta.p,
    timestamp = delta.ts,
    message = delta.m.msg,
    trigger = delta.m.trig,
    stats = {
      adds = counts.add,
      modifies = counts.mod,
      deletes = counts.del,
      total = counts.add + counts.mod + counts.del,
    },
    categories = vim.tbl_keys(categories),
  }
end
|
||||
|
||||
--- Format a delta for human-readable display (git-log style).
---@param delta Delta Delta to format
---@return string[] Lines
function M.format(delta)
  local s = M.summarize(delta)
  return {
    string.format("commit %s", delta.h),
    string.format("Date: %s", os.date("%Y-%m-%d %H:%M:%S", delta.ts)),
    string.format("Parent: %s", delta.p or "(none)"),
    "",
    " " .. (delta.m.msg or "No message"),
    "",
    string.format(
      " %d additions, %d modifications, %d deletions",
      s.stats.adds, s.stats.modifies, s.stats.deletes
    ),
  }
end
|
||||
|
||||
return M
|
||||
261
lua/codetyper/brain/delta/diff.lua
Normal file
261
lua/codetyper/brain/delta/diff.lua
Normal file
@@ -0,0 +1,261 @@
|
||||
--- Brain Delta Diff Computation
|
||||
--- Field-level diff algorithms for delta versioning
|
||||
|
||||
local hash = require("codetyper.brain.hash")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Compute a field-level diff between two values.
--- Produces a flat list of entries with dotted paths; tables are
--- compared recursively, everything else by equality.
---@param before any Before value
---@param after any After value
---@param path? string Current path ("" at the root)
---@return table[] Diff entries (each has path, op, and value or from/to)
function M.compute(before, after, path)
  local at = path or ""
  local out = {}

  -- Both absent: nothing changed.
  if before == nil and after == nil then
    return out
  end

  -- One side absent: whole-value add or delete.
  if before == nil then
    out[1] = { path = at, op = "add", value = after }
    return out
  end
  if after == nil then
    out[1] = { path = at, op = "delete", value = before }
    return out
  end

  -- Different types: treat as a wholesale replacement.
  if type(before) ~= type(after) then
    out[1] = { path = at, op = "replace", from = before, to = after }
    return out
  end

  if type(before) == "table" then
    -- Union of keys from both sides, each diffed recursively.
    local keys = {}
    for k in pairs(before) do keys[k] = true end
    for k in pairs(after) do keys[k] = true end

    for k in pairs(keys) do
      local child = at == "" and tostring(k) or (at .. "." .. tostring(k))
      for _, entry in ipairs(M.compute(before[k], after[k], child)) do
        out[#out + 1] = entry
      end
    end
    return out
  end

  -- Primitives: record a replacement only when the value changed.
  if before ~= after then
    out[1] = { path = at, op = "replace", from = before, to = after }
  end
  return out
end
|
||||
|
||||
--- Apply a list of diff entries to a value.
---@param base any Base value (deep-copied; never mutated)
---@param diffs table[] Diff entries
---@return any Result value
function M.apply(base, diffs)
  local out = vim.deepcopy(base) or {}
  for _, entry in ipairs(diffs) do
    M.apply_single(out, entry)
  end
  return out
end
|
||||
|
||||
--- Apply a single diff entry to a target table, in place.
--- The entry's dotted path is navigated segment by segment, creating
--- intermediate tables as needed, then the final key is set/cleared
--- according to the entry's op.
---@param target table Target table (mutated)
---@param diff table Diff entry ({path, op, value|from/to})
function M.apply_single(target, diff)
  local path = diff.path
  local parts = vim.split(path, ".", { plain = true })

  if #parts == 0 or parts[1] == "" then
    -- Root-level change: merge the new table's keys into target.
    -- NOTE(review): assumes diff.value / diff.to is a table here;
    -- a root-level primitive replacement would error — confirm callers.
    if diff.op == "add" or diff.op == "replace" then
      for k, v in pairs(diff.value or diff.to or {}) do
        target[k] = v
      end
    end
    return
  end

  -- Navigate to parent, creating intermediate tables as needed.
  local current = target
  for i = 1, #parts - 1 do
    local key = parts[i]
    -- Numeric-looking segments are coerced to number keys.
    -- NOTE(review): this means a string key like "1" cannot be
    -- addressed — it always resolves to the integer slot.
    local num_key = tonumber(key)
    key = num_key or key

    if current[key] == nil then
      current[key] = {}
    end
    current = current[key]
  end

  -- Apply the operation at the final key (same numeric coercion).
  local final_key = parts[#parts]
  local num_key = tonumber(final_key)
  final_key = num_key or final_key

  if diff.op == "add" then
    current[final_key] = diff.value
  elseif diff.op == "delete" then
    current[final_key] = nil
  elseif diff.op == "replace" then
    current[final_key] = diff.to
  end
end
|
||||
|
||||
--- Reverse a list of diff entries (for rollback).
--- Entries come back in reverse order with each op inverted:
--- add <-> delete, replace with from/to swapped.
---@param diffs table[] Diff entries
---@return table[] Reversed diffs
function M.reverse(diffs)
  local out = {}

  for i = #diffs, 1, -1 do
    local d = diffs[i]
    local inv = { path = d.path }

    if d.op == "add" then
      inv.op, inv.value = "delete", d.value
    elseif d.op == "delete" then
      inv.op, inv.value = "add", d.value
    elseif d.op == "replace" then
      inv.op, inv.from, inv.to = "replace", d.to, d.from
    end

    out[#out + 1] = inv
  end

  return out
end
|
||||
|
||||
--- Compact diffs: merge multiple entries for the same path into one.
--- For repeated replaces the first "from" and last "to" are kept; a
--- delete following earlier changes wins for that path.
---
--- Fix: the original filter dropped every delete whose entry lacked a
--- "from" field — but delete entries produced by M.compute never carry
--- "from" (they carry "value"), so genuine deletions were silently
--- discarded. Only an add-then-delete sequence (a net no-op) should be
--- dropped; we now track each path's first op explicitly.
---@param diffs table[] Diff entries
---@return table[] Compacted diffs
function M.compact(diffs)
  local by_path = {}
  local first_op = {} -- first operation seen for each path

  for _, diff in ipairs(diffs) do
    local existing = by_path[diff.path]
    if existing then
      -- Combine: keep first "from", use last "to"
      if diff.op == "replace" then
        existing.to = diff.to
      elseif diff.op == "delete" then
        existing.op = "delete"
        existing.to = nil
      end
    else
      by_path[diff.path] = vim.deepcopy(diff)
      first_op[diff.path] = diff.op
    end
  end

  -- Convert back to array, dropping add-then-delete no-ops only.
  local result = {}
  for path, diff in pairs(by_path) do
    if not (diff.op == "delete" and first_op[path] == "add") then
      table.insert(result, diff)
    end
  end

  return result
end
|
||||
|
||||
--- Create a minimal summary of a diff list for storage.
---@param diffs table[] Diff entries
---@return table Summary {adds, deletes, replaces, paths, total}
function M.summarize(diffs)
  local counts = { add = 0, delete = 0, replace = 0 }
  local top_paths = {}

  for _, d in ipairs(diffs) do
    if counts[d.op] ~= nil then
      counts[d.op] = counts[d.op] + 1
    end

    -- Record the top-level path segment.
    local head = vim.split(d.path, ".", { plain = true })[1]
    if head then
      top_paths[head] = true
    end
  end

  return {
    adds = counts.add,
    deletes = counts.delete,
    replaces = counts.replace,
    paths = vim.tbl_keys(top_paths),
    total = counts.add + counts.delete + counts.replace,
  }
end
|
||||
|
||||
--- Check whether two states are structurally equal (i.e. empty diff).
---@param state1 any First state
---@param state2 any Second state
---@return boolean
function M.equals(state1, state2)
  return #M.compute(state1, state2) == 0
end
|
||||
|
||||
--- Get a stable hash of a diff list, for deduplication.
-- Delegates to the brain hash module's table hashing.
---@param diffs table[] Diff entries
---@return string Hash
function M.hash(diffs)
  return hash.compute_table(diffs)
end
|
||||
|
||||
return M
|
||||
278
lua/codetyper/brain/delta/init.lua
Normal file
278
lua/codetyper/brain/delta/init.lua
Normal file
@@ -0,0 +1,278 @@
|
||||
--- Brain Delta Coordinator
|
||||
--- Git-like versioning system for brain state
|
||||
|
||||
local storage = require("codetyper.brain.storage")
|
||||
local commit_mod = require("codetyper.brain.delta.commit")
|
||||
local diff_mod = require("codetyper.brain.delta.diff")
|
||||
local types = require("codetyper.brain.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
-- Re-export submodules
|
||||
M.commit = commit_mod
|
||||
M.diff = diff_mod
|
||||
|
||||
--- Create a commit from the graph's pending changes.
---@param message string Commit message
---@param trigger? string Trigger source (defaults to "auto")
---@return string|nil Delta hash, or nil when nothing is pending
function M.commit(message, trigger)
  local graph = require("codetyper.brain.graph")
  local pending = graph.get_pending_changes()

  if #pending == 0 then
    return nil
  end

  local delta = commit_mod.create(pending, message, trigger or "auto")
  return delta and delta.h or nil
end
|
||||
|
||||
--- Rollback brain state to a specific delta.
--- Walks HEAD's ancestry to the target, applies each delta's reverse
--- changes newest-first, moves HEAD, and records a "rollback" commit
--- describing the HEAD move.
---@param target_hash string Target delta hash (must be an ancestor of HEAD)
---@return boolean Success (false when there is no HEAD, the chain is
--- broken, or the target is not in HEAD's ancestry)
function M.rollback(target_hash)
  local current_hash = storage.get_head()
  if not current_hash then
    return false
  end

  if current_hash == target_hash then
    return true -- Already at target
  end

  -- Collect the deltas between HEAD and the target, newest first.
  local deltas_to_reverse = {}
  local current = current_hash

  while current and current ~= target_hash do
    local delta = commit_mod.get(current)
    if not delta then
      return false -- Broken chain
    end
    table.insert(deltas_to_reverse, delta)
    current = delta.p
  end

  if current ~= target_hash then
    return false -- Target not in ancestry
  end

  -- Apply reverse changes newest-first so later edits are undone
  -- before the edits they were built on.
  for _, delta in ipairs(deltas_to_reverse) do
    local reverse_changes = commit_mod.compute_reverse(delta)
    M.apply_changes(reverse_changes)
  end

  -- Update HEAD
  storage.set_head(target_hash)

  -- Record the rollback itself as a commit (note: commit_mod.create
  -- advances HEAD again, onto the new rollback delta).
  commit_mod.create({
    {
      op = types.DELTA_OPS.MODIFY,
      path = "meta.head",
      bh = current_hash,
      ah = target_hash,
    },
  }, "Rollback to " .. target_hash:sub(1, 8), "rollback")

  return true
end
|
||||
|
||||
--- Apply a list of delta changes to the current brain state.
--- Dispatches on the first path segment: "nodes.<type>.<id>" entries
--- mutate graph nodes, "graph.edges.<id>" entries mutate edges.
--- NOTE(review): this is primarily used with REVERSED changes during
--- rollback, which is why ADD is handled as a deletion.
---@param changes table[] Changes to apply
function M.apply_changes(changes)
  local node_mod = require("codetyper.brain.graph.node")

  for _, change in ipairs(changes) do
    local parts = vim.split(change.path, ".", { plain = true })

    if parts[1] == "nodes" and #parts >= 3 then
      local node_type = parts[2]
      local node_id = parts[3]

      if change.op == types.DELTA_OPS.ADD then
        -- Node was added, need to delete for reverse
        node_mod.delete(node_id)
      elseif change.op == types.DELTA_OPS.DELETE then
        -- Node was deleted, would need original data to restore.
        -- This is a known limitation - we'd need content storage.
      elseif change.op == types.DELTA_OPS.MODIFY then
        -- Apply the field-level diff if one was recorded.
        if change.diff then
          local node = node_mod.get(node_id)
          if node then
            local updated = diff_mod.apply(node, change.diff)
            -- Write directly through storage so this update does not
            -- generate new pending changes of its own.
            local nodes = storage.get_nodes(node_type)
            nodes[node_id] = updated
            storage.save_nodes(node_type, nodes)
          end
        end
      end
    elseif parts[1] == "graph" then
      -- Handle graph/edge changes
      local edge_mod = require("codetyper.brain.graph.edge")
      if parts[2] == "edges" and #parts >= 3 then
        local edge_id = parts[3]
        if change.op == types.DELTA_OPS.ADD then
          -- Edge was added, delete for reverse.
          -- Look the edge up to recover its source/target/type.
          local graph = storage.get_graph()
          if graph.edges and graph.edges[edge_id] then
            local edge = graph.edges[edge_id]
            edge_mod.delete(edge.s, edge.t, edge.ty)
          end
        end
      end
    end
  end
end
|
||||
|
||||
--- Get delta history starting at HEAD (newest first).
-- Thin wrapper over commit.get_history.
---@param limit? number Max entries
---@return Delta[]
function M.get_history(limit)
  return commit_mod.get_history(limit)
end
|
||||
|
||||
--- Get formatted log
--- Formats each history entry via the commit module, separating entries
--- with a blank line.
---@param limit? number Max entries (default: 20)
---@return string[] Log lines
function M.log(limit)
  local entries = M.get_history(limit or 20)
  local out = {}

  for _, delta in ipairs(entries) do
    -- Append every formatted line for this delta, then a blank separator.
    vim.list_extend(out, commit_mod.format(delta))
    out[#out + 1] = ""
  end

  return out
end
|
||||
|
||||
--- Get current HEAD hash
---@return string|nil Hash of the most recent committed delta, or nil when none
function M.head()
  return storage.get_head()
end
|
||||
|
||||
--- Check if there are uncommitted changes
--- True when either the node or the edge pending-change queue is non-empty.
---@return boolean
function M.has_pending()
  -- Fix: removed an unused `local graph = require("codetyper.brain.graph")`
  -- that was never referenced (dead require / unused local).
  local node_pending = require("codetyper.brain.graph.node").pending
  local edge_pending = require("codetyper.brain.graph.edge").pending
  return #node_pending > 0 or #edge_pending > 0
end
|
||||
|
||||
--- Get status (like git status)
--- Tallies pending node and edge changes by operation kind.
---@return table Status info {head, pending = {adds, modifies, deletes, total}, clean}
function M.status()
  local node_pending = require("codetyper.brain.graph.node").pending
  local edge_pending = require("codetyper.brain.graph.edge").pending

  local counts = { add = 0, mod = 0, del = 0 }

  for _, change in ipairs(node_pending) do
    if change.op == types.DELTA_OPS.ADD then
      counts.add = counts.add + 1
    elseif change.op == types.DELTA_OPS.MODIFY then
      counts.mod = counts.mod + 1
    elseif change.op == types.DELTA_OPS.DELETE then
      counts.del = counts.del + 1
    end
  end

  -- Edge pending entries only ever carry ADD or DELETE.
  for _, change in ipairs(edge_pending) do
    if change.op == types.DELTA_OPS.ADD then
      counts.add = counts.add + 1
    elseif change.op == types.DELTA_OPS.DELETE then
      counts.del = counts.del + 1
    end
  end

  local total = counts.add + counts.mod + counts.del

  return {
    head = storage.get_head(),
    pending = {
      adds = counts.add,
      modifies = counts.mod,
      deletes = counts.del,
      total = total,
    },
    clean = total == 0,
  }
end
|
||||
|
||||
--- Prune old deltas
--- Deletes delta object files beyond the `keep` most recent history entries
--- and adjusts the delta counter by the number actually removed.
---@param keep number Number of recent deltas to keep (default: 100)
---@return number Number of pruned deltas
function M.prune_history(keep)
  keep = keep or 100
  -- NOTE(review): entries beyond 1000 are never fetched and therefore never
  -- pruned — confirm whether get_history supports an unlimited fetch.
  local history = M.get_history(1000) -- Get all

  if #history <= keep then
    return 0
  end

  local pruned = 0
  local brain_dir = storage.get_brain_dir()

  -- Assumes history is ordered newest-first, so everything past `keep` goes.
  for i = keep + 1, #history do
    local delta = history[i]
    local filepath = brain_dir .. "/deltas/objects/" .. delta.h .. ".json"
    -- os.remove returns nil on failure; already-missing files are skipped
    -- silently and not counted.
    if os.remove(filepath) then
      pruned = pruned + 1
    end
  end

  -- Update meta
  local meta = storage.get_meta()
  storage.update_meta({ dc = math.max(0, meta.dc - pruned) })

  return pruned
end
|
||||
|
||||
--- Reset to initial state (dangerous!)
--- Wipes all node buckets, the graph, all indices, meta counters, and the
--- pending queues, then flushes storage.
---@return boolean Success (always true)
function M.reset()
  -- Clear all nodes
  -- NOTE(review): node.lua's TYPE_MAP stores "feedback" WITHOUT a trailing
  -- "s", while `node_type .. "s"` would target "feedbacks" — verify the
  -- NODE_TYPES values against the storage keys node.lua actually uses.
  for _, node_type in pairs(types.NODE_TYPES) do
    storage.save_nodes(node_type .. "s", {})
  end

  -- Clear graph
  storage.save_graph({ adj = {}, radj = {}, edges = {} })

  -- Clear indices
  storage.save_index("by_file", {})
  storage.save_index("by_time", {})
  storage.save_index("by_symbol", {})

  -- Reset meta
  -- NOTE(review): `head = nil` inside a table constructor creates no key at
  -- all, so update_meta may never see/clear "head" — confirm update_meta's
  -- handling of absent keys.
  storage.update_meta({
    head = nil,
    nc = 0,
    ec = 0,
    dc = 0,
  })

  -- Clear pending
  require("codetyper.brain.graph.node").pending = {}
  require("codetyper.brain.graph.edge").pending = {}

  storage.flush_all()
  return true
end
|
||||
|
||||
return M
|
||||
367
lua/codetyper/brain/graph/edge.lua
Normal file
367
lua/codetyper/brain/graph/edge.lua
Normal file
@@ -0,0 +1,367 @@
|
||||
--- Brain Graph Edge Operations
|
||||
--- CRUD operations for node connections
|
||||
|
||||
local storage = require("codetyper.brain.storage")
|
||||
local hash = require("codetyper.brain.hash")
|
||||
local types = require("codetyper.brain.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Pending changes for delta tracking
|
||||
---@type table[]
|
||||
M.pending = {}
|
||||
|
||||
--- Create a new edge between nodes
--- Updates forward/reverse adjacency, stores the edge record, bumps the
--- edge counter, and queues an ADD pending change. If the pair already has
--- an edge of this type, the existing edge is strengthened instead.
--- NOTE(review): hash.edge_id takes only (source, target), so two edges of
--- DIFFERENT types between the same pair share one graph.edges slot — the
--- second create overwrites the first's record. Confirm this is intended.
---@param source_id string Source node ID
---@param target_id string Target node ID
---@param edge_type EdgeType Edge type
---@param props? EdgeProps Edge properties
---@return Edge|nil Created edge
function M.create(source_id, target_id, edge_type, props)
  props = props or {}

  local edge = {
    id = hash.edge_id(source_id, target_id),
    s = source_id,
    t = target_id,
    ty = edge_type,
    p = {
      w = props.w or 0.5,      -- weight defaults to neutral 0.5
      dir = props.dir or "bi", -- bidirectional unless stated otherwise
      r = props.r,             -- optional human-readable reason
    },
    ts = os.time(),
  }

  -- Update adjacency lists
  local graph = storage.get_graph()

  -- Forward adjacency
  graph.adj[source_id] = graph.adj[source_id] or {}
  graph.adj[source_id][edge_type] = graph.adj[source_id][edge_type] or {}

  -- Check for duplicate
  if vim.tbl_contains(graph.adj[source_id][edge_type], target_id) then
    -- Edge exists, strengthen it instead
    return M.strengthen(source_id, target_id, edge_type)
  end

  table.insert(graph.adj[source_id][edge_type], target_id)

  -- Reverse adjacency
  graph.radj[target_id] = graph.radj[target_id] or {}
  graph.radj[target_id][edge_type] = graph.radj[target_id][edge_type] or {}
  table.insert(graph.radj[target_id][edge_type], source_id)

  -- Store edge properties separately (for weight/metadata)
  graph.edges = graph.edges or {}
  graph.edges[edge.id] = edge

  storage.save_graph(graph)

  -- Update meta
  local meta = storage.get_meta()
  storage.update_meta({ ec = meta.ec + 1 })

  -- Track pending change
  table.insert(M.pending, {
    op = types.DELTA_OPS.ADD,
    path = "graph.edges." .. edge.id,
    ah = hash.compute_table(edge), -- hash of the record after creation
  })

  return edge
end
|
||||
|
||||
--- Get edge by source and target
---@param source_id string Source node ID
---@param target_id string Target node ID
---@param edge_type? EdgeType Optional edge type filter
---@return Edge|nil
function M.get(source_id, target_id, edge_type)
  local edges = storage.get_graph().edges
  if not edges then
    return nil
  end

  local found = edges[hash.edge_id(source_id, target_id)]
  if not found then
    return nil
  end

  -- When a type filter is supplied, only a matching edge counts.
  if edge_type and found.ty ~= edge_type then
    return nil
  end

  return found
end
|
||||
|
||||
--- Get all edges for a node
--- Walks the forward and/or reverse adjacency maps and resolves each pair
--- to its stored edge record.
---@param node_id string Node ID
---@param edge_types? EdgeType[] Edge types to include
---@param direction? "out"|"in"|"both" Direction (default: "out")
---@return Edge[]
function M.get_edges(node_id, edge_types, direction)
  direction = direction or "out"
  local graph = storage.get_graph()
  local wanted = edge_types or vim.tbl_values(types.EDGE_TYPES)
  local results = {}

  -- Resolve every (self, other) pair in one adjacency map to an edge record.
  local function collect(adjacency, self_is_source)
    if not adjacency then
      return
    end
    for _, ty in ipairs(wanted) do
      for _, other_id in ipairs(adjacency[ty] or {}) do
        local eid
        if self_is_source then
          eid = hash.edge_id(node_id, other_id)
        else
          eid = hash.edge_id(other_id, node_id)
        end
        local record = graph.edges and graph.edges[eid]
        if record then
          results[#results + 1] = record
        end
      end
    end
  end

  if direction == "out" or direction == "both" then
    collect(graph.adj[node_id], true)
  end
  if direction == "in" or direction == "both" then
    collect(graph.radj[node_id], false)
  end

  return results
end
|
||||
|
||||
--- Get neighbor node IDs
--- Follows outgoing and/or incoming edges of the given types and returns
--- each neighbor exactly once, in insertion order.
---@param node_id string Node ID
---@param edge_types? EdgeType[] Edge types to follow (default: all)
---@param direction? "out"|"in"|"both" Direction (default: "out")
---@return string[] Neighbor node IDs
function M.get_neighbors(node_id, edge_types, direction)
  direction = direction or "out"
  local graph = storage.get_graph()
  local neighbors = {}
  -- Perf fix: an O(1) membership set replaces the original per-insert
  -- vim.tbl_contains scan (worst-case O(n^2) dedup). Output order is the
  -- same first-seen order as before.
  local seen = {}

  edge_types = edge_types or vim.tbl_values(types.EDGE_TYPES)

  -- Append every not-yet-seen ID reachable through one adjacency map.
  local function add_from(adjacency)
    if not adjacency then
      return
    end
    for _, edge_type in ipairs(edge_types) do
      for _, id in ipairs(adjacency[edge_type] or {}) do
        if not seen[id] then
          seen[id] = true
          neighbors[#neighbors + 1] = id
        end
      end
    end
  end

  if direction == "out" or direction == "both" then
    add_from(graph.adj[node_id])
  end
  if direction == "in" or direction == "both" then
    add_from(graph.radj[node_id])
  end

  return neighbors
end
|
||||
|
||||
--- Delete an edge
--- Removes the pair from both adjacency maps, drops the edge record,
--- decrements the edge counter, and queues a DELETE pending change.
---@param source_id string Source node ID
---@param target_id string Target node ID
---@param edge_type? EdgeType Edge type (deletes all if nil)
---@return boolean Success
function M.delete(source_id, target_id, edge_type)
  local graph = storage.get_graph()
  local edge_id = hash.edge_id(source_id, target_id)

  if not graph.edges or not graph.edges[edge_id] then
    return false
  end

  local edge = graph.edges[edge_id]

  -- NOTE(review): only one record exists per (source, target) pair, so a
  -- non-matching type filter deletes nothing; "deletes all if nil" above
  -- effectively means "deletes the single stored record".
  if edge_type and edge.ty ~= edge_type then
    return false
  end

  -- Hash of the record as it was before deletion (for delta tracking).
  local before_hash = hash.compute_table(edge)

  -- Remove from adjacency
  if graph.adj[source_id] and graph.adj[source_id][edge.ty] then
    graph.adj[source_id][edge.ty] = vim.tbl_filter(function(id)
      return id ~= target_id
    end, graph.adj[source_id][edge.ty])
  end

  -- Remove from reverse adjacency
  if graph.radj[target_id] and graph.radj[target_id][edge.ty] then
    graph.radj[target_id][edge.ty] = vim.tbl_filter(function(id)
      return id ~= source_id
    end, graph.radj[target_id][edge.ty])
  end

  -- Remove edge data
  graph.edges[edge_id] = nil

  storage.save_graph(graph)

  -- Update meta
  local meta = storage.get_meta()
  storage.update_meta({ ec = math.max(0, meta.ec - 1) })

  -- Track pending change
  table.insert(M.pending, {
    op = types.DELTA_OPS.DELETE,
    path = "graph.edges." .. edge_id,
    bh = before_hash,
  })

  return true
end
|
||||
|
||||
--- Delete all edges for a node
--- Tears down both incoming and outgoing edges.
---@param node_id string Node ID
---@return number Number of deleted edges
function M.delete_all(node_id)
  local removed = 0
  for _, e in ipairs(M.get_edges(node_id, nil, "both")) do
    if M.delete(e.s, e.t, e.ty) then
      removed = removed + 1
    end
  end
  return removed
end
|
||||
|
||||
--- Strengthen an existing edge
--- Moves the edge weight 10% of the remaining distance toward 1.0
--- (diminishing returns) and refreshes its timestamp.
---@param source_id string Source node ID
---@param target_id string Target node ID
---@param edge_type EdgeType Edge type
---@return Edge|nil Updated edge (nil when missing or the type doesn't match)
function M.strengthen(source_id, target_id, edge_type)
  local graph = storage.get_graph()
  local eid = hash.edge_id(source_id, target_id)
  local edge = graph.edges and graph.edges[eid]

  if not edge or edge.ty ~= edge_type then
    return nil
  end

  edge.p.w = math.min(1.0, edge.p.w + (1 - edge.p.w) * 0.1)
  edge.ts = os.time()

  graph.edges[eid] = edge
  storage.save_graph(graph)

  return edge
end
|
||||
|
||||
--- Find path between two nodes
--- Breadth-first search following edges in BOTH directions, so the first
--- hit is a fewest-hops path. Returns {found = false} when no path exists
--- within max_depth.
---@param from_id string Start node ID
---@param to_id string End node ID
---@param max_depth? number Maximum depth (default: 5)
---@return table|nil Path info {nodes: string[], edges: Edge[], found: boolean}
function M.find_path(from_id, to_id, max_depth)
  max_depth = max_depth or 5

  -- BFS
  -- `path` holds the nodes visited BEFORE the current one; the current node
  -- is appended when it is expanded (or when it turns out to be the goal).
  local queue = { { id = from_id, path = {}, edges = {} } }
  local visited = { [from_id] = true }

  while #queue > 0 do
    -- FIFO pop; table.remove(_, 1) is O(n) but fine for small frontiers.
    local current = table.remove(queue, 1)

    if current.id == to_id then
      table.insert(current.path, to_id)
      return {
        nodes = current.path,
        edges = current.edges,
        found = true,
      }
    end

    -- Depth limit: stop expanding once the hop budget is spent.
    if #current.path >= max_depth then
      goto continue
    end

    -- Get all neighbors
    local edges = M.get_edges(current.id, nil, "both")

    for _, edge in ipairs(edges) do
      -- The neighbor is whichever endpoint isn't the current node.
      local neighbor = edge.s == current.id and edge.t or edge.s

      if not visited[neighbor] then
        visited[neighbor] = true

        -- Copy-on-extend so sibling queue entries don't share tables.
        local new_path = vim.list_extend({}, current.path)
        table.insert(new_path, current.id)

        local new_edges = vim.list_extend({}, current.edges)
        table.insert(new_edges, edge)

        table.insert(queue, {
          id = neighbor,
          path = new_path,
          edges = new_edges,
        })
      end
    end

    ::continue::
  end

  return { nodes = {}, edges = {}, found = false }
end
|
||||
|
||||
--- Get pending changes and clear
--- Drains the pending-change queue: the caller receives the accumulated
--- entries and the module starts a fresh, empty queue.
---@return table[] Pending changes
function M.get_and_clear_pending()
  local drained = M.pending
  M.pending = {}
  return drained
end
|
||||
|
||||
--- Check if two nodes are connected
--- An edge in either direction counts as connected.
---@param node_id_1 string First node ID
---@param node_id_2 string Second node ID
---@param edge_type? EdgeType Edge type filter
---@return boolean
function M.are_connected(node_id_1, node_id_2, edge_type)
  return M.get(node_id_1, node_id_2, edge_type) ~= nil
    or M.get(node_id_2, node_id_1, edge_type) ~= nil
end
|
||||
|
||||
return M
|
||||
213
lua/codetyper/brain/graph/init.lua
Normal file
213
lua/codetyper/brain/graph/init.lua
Normal file
@@ -0,0 +1,213 @@
|
||||
--- Brain Graph Coordinator
|
||||
--- High-level graph operations
|
||||
|
||||
local node = require("codetyper.brain.graph.node")
|
||||
local edge = require("codetyper.brain.graph.edge")
|
||||
local query = require("codetyper.brain.graph.query")
|
||||
local storage = require("codetyper.brain.storage")
|
||||
local types = require("codetyper.brain.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
-- Re-export submodules
|
||||
M.node = node
|
||||
M.edge = edge
|
||||
M.query = query
|
||||
|
||||
--- Add a learning with automatic edge creation
--- Creates the node, links it to any explicitly related nodes (file edge
--- when they share a file, semantic edge otherwise), then links it to
--- semantically similar existing nodes found via search.
---@param node_type NodeType Node type
---@param content NodeContent Content
---@param context? NodeContext Context
---@param related_ids? string[] Related node IDs
---@return Node Created node
function M.add_learning(node_type, content, context, related_ids)
  -- Create the node
  local new_node = node.create(node_type, content, context)

  -- Create edges to related nodes
  if related_ids then
    for _, related_id in ipairs(related_ids) do
      local related_node = node.get(related_id)
      if related_node then
        -- Determine edge type based on relationship
        local edge_type = types.EDGE_TYPES.SEMANTIC

        -- If same file, use file edge
        if context and context.f and related_node.ctx and related_node.ctx.f == context.f then
          edge_type = types.EDGE_TYPES.FILE
        end

        edge.create(new_node.id, related_id, edge_type, {
          w = 0.5,
          r = "Related learning",
        })
      end
    end
  end

  -- Find and link to similar existing nodes
  local similar = query.semantic_search(content.s, 5)
  for _, sim_node in ipairs(similar) do
    if sim_node.id ~= new_node.id then
      -- Create semantic edge if similarity is high enough
      local sim_score = query.compute_relevance(sim_node, { query = content.s })
      if sim_score > 0.5 then
        edge.create(new_node.id, sim_node.id, types.EDGE_TYPES.SEMANTIC, {
          w = sim_score, -- edge weight mirrors the similarity score
          r = "Semantic similarity",
        })
      end
    end
  end

  return new_node
end
|
||||
|
||||
--- Remove a learning and its edges
--- Edges are deleted first so no dangling adjacency entries remain.
---@param node_id string Node ID to remove
---@return boolean Success
function M.remove_learning(node_id)
  edge.delete_all(node_id)
  return node.delete(node_id)
end
|
||||
|
||||
--- Prune low-value nodes
--- A node is pruned when (a) its weight is below `threshold` AND it hasn't
--- been used/updated within `unused_days`, OR (b) it was never used and its
--- creation time predates the cutoff. Pruning removes the node and its edges.
---@param opts? table Prune options {threshold?: number, unused_days?: number}
---@return number Number of pruned nodes
function M.prune(opts)
  opts = opts or {}
  local threshold = opts.threshold or 0.1
  local unused_days = opts.unused_days or 90
  local now = os.time()
  local cutoff = now - (unused_days * 86400) -- seconds per day

  local pruned = 0

  -- Find nodes to prune
  for _, node_type in pairs(types.NODE_TYPES) do
    local nodes_to_prune = node.find({
      types = { node_type },
      min_weight = 0, -- Get all
    })

    for _, n in ipairs(nodes_to_prune) do
      local should_prune = false

      -- Prune if weight below threshold and not used recently
      -- (falls back to the update timestamp when last-used is absent)
      if n.sc.w < threshold and (n.ts.lu or n.ts.up) < cutoff then
        should_prune = true
      end

      -- Prune if never used and old
      if n.sc.u == 0 and n.ts.cr < cutoff then
        should_prune = true
      end

      if should_prune then
        if M.remove_learning(n.id) then
          pruned = pruned + 1
        end
      end
    end
  end

  return pruned
end
|
||||
|
||||
--- Get all pending changes from nodes and edges
--- Draining clears each submodule's queue as a side effect; node changes
--- come first, then edge changes.
---@return table[] Combined pending changes
function M.get_pending_changes()
  local combined = {}
  vim.list_extend(combined, node.get_and_clear_pending())
  vim.list_extend(combined, edge.get_and_clear_pending())
  return combined
end
|
||||
|
||||
--- Get graph statistics
--- Totals come from the maintained meta counters; per-type breakdowns are
--- recomputed by scanning storage.
---@return table Stats {node_count, edge_count, delta_count, nodes_by_type, edges_by_type}
function M.stats()
  local meta = storage.get_meta()

  -- Count nodes by type
  -- NOTE(review): node.lua's TYPE_MAP uses "feedback" (no trailing "s") as
  -- its storage key, while `node_type .. "s"` would read "feedbacks" —
  -- verify NODE_TYPES values against the storage keys actually written.
  local by_type = {}
  for _, node_type in pairs(types.NODE_TYPES) do
    local nodes = storage.get_nodes(node_type .. "s")
    by_type[node_type] = vim.tbl_count(nodes)
  end

  -- Count edges by type
  local graph = storage.get_graph()
  local edges_by_type = {}
  if graph.edges then
    for _, e in pairs(graph.edges) do
      edges_by_type[e.ty] = (edges_by_type[e.ty] or 0) + 1
    end
  end

  return {
    node_count = meta.nc,
    edge_count = meta.ec,
    delta_count = meta.dc,
    nodes_by_type = by_type,
    edges_by_type = edges_by_type,
  }
end
|
||||
|
||||
--- Create temporal edge between nodes created in sequence
--- Chains each consecutive pair with a forward-directed temporal edge.
---@param node_ids string[] Node IDs in temporal order
function M.link_temporal(node_ids)
  for i = 2, #node_ids do
    edge.create(node_ids[i - 1], node_ids[i], types.EDGE_TYPES.TEMPORAL, {
      w = 0.7,
      dir = "fwd",
      r = "Temporal sequence",
    })
  end
end
|
||||
|
||||
--- Create causal edge (this caused that)
---@param cause_id string Cause node ID
---@param effect_id string Effect node ID
---@param reason? string Reason description (default: "Caused by")
function M.link_causal(cause_id, effect_id, reason)
  local props = {
    w = 0.8,
    dir = "fwd",
    r = reason or "Caused by",
  }
  edge.create(cause_id, effect_id, types.EDGE_TYPES.CAUSAL, props)
end
|
||||
|
||||
--- Mark a node as superseded by another
--- Links old -> new with a SUPERSEDES edge and halves the old node's
--- weight so it ranks lower in retrieval.
---@param old_id string Old node ID
---@param new_id string New node ID
function M.supersede(old_id, new_id)
  edge.create(old_id, new_id, types.EDGE_TYPES.SUPERSEDES, {
    w = 1.0,
    dir = "fwd",
    r = "Superseded by newer learning",
  })

  local previous = node.get(old_id)
  if previous then
    node.update(old_id, { sc = { w = previous.sc.w * 0.5 } })
  end
end
|
||||
|
||||
return M
|
||||
403
lua/codetyper/brain/graph/node.lua
Normal file
403
lua/codetyper/brain/graph/node.lua
Normal file
@@ -0,0 +1,403 @@
|
||||
--- Brain Graph Node Operations
|
||||
--- CRUD operations for learning nodes
|
||||
|
||||
local storage = require("codetyper.brain.storage")
|
||||
local hash = require("codetyper.brain.hash")
|
||||
local types = require("codetyper.brain.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Pending changes for delta tracking
|
||||
---@type table[]
|
||||
M.pending = {}
|
||||
|
||||
--- Node type to file mapping
--- Maps both the short NODE_TYPES values and the full bucket names to the
--- storage bucket each node kind lives in.
--- NOTE(review): "feedback" has no plural variant here; code elsewhere that
--- builds keys via `node_type .. "s"` will miss this bucket.
local TYPE_MAP = {
  [types.NODE_TYPES.PATTERN] = "patterns",
  [types.NODE_TYPES.CORRECTION] = "corrections",
  [types.NODE_TYPES.DECISION] = "decisions",
  [types.NODE_TYPES.CONVENTION] = "conventions",
  [types.NODE_TYPES.FEEDBACK] = "feedback",
  [types.NODE_TYPES.SESSION] = "sessions",
  -- Full names for convenience
  patterns = "patterns",
  corrections = "corrections",
  decisions = "decisions",
  conventions = "conventions",
  feedback = "feedback",
  sessions = "sessions",
}

--- Get storage key for node type
---@param node_type string Node type
---@return string Storage key ("patterns" when the type is unknown)
local function get_storage_key(node_type)
  return TYPE_MAP[node_type] or "patterns"
end
|
||||
|
||||
--- Create a new node
--- Builds the node record, persists it in its type bucket, bumps the node
--- counter, refreshes indices, and queues an ADD pending change.
---@param node_type NodeType Node type
---@param content NodeContent Content ({s: summary, d?: detail, code?, lang?});
--- content.s is assumed non-nil (it is concatenated for the hash) — TODO confirm
---@param context? NodeContext Context
---@param opts? table Additional options {weight?: number, source?: string}
---@return Node Created node
function M.create(node_type, content, context, opts)
  opts = opts or {}
  local now = os.time()

  local node = {
    id = hash.node_id(node_type, content.s),
    t = node_type,
    -- Content hash, used for delta before/after comparison.
    h = hash.compute(content.s .. (content.d or "")),
    c = {
      s = content.s or "",
      d = content.d or content.s or "", -- detail falls back to summary
      code = content.code,
      lang = content.lang,
    },
    ctx = context or {},
    sc = {
      w = opts.weight or 0.5, -- weight
      u = 0,                  -- usage count
      sr = 1.0,               -- success rate
    },
    ts = {
      cr = now, -- created
      up = now, -- updated
      lu = now, -- last used
    },
    m = {
      src = opts.source or types.SOURCES.AUTO,
      v = 1, -- version, bumped on every update
    },
  }

  -- Store node
  local storage_key = get_storage_key(node_type)
  local nodes = storage.get_nodes(storage_key)
  nodes[node.id] = node
  storage.save_nodes(storage_key, nodes)

  -- Update meta
  local meta = storage.get_meta()
  storage.update_meta({ nc = meta.nc + 1 })

  -- Update indices
  M.update_indices(node, "add")

  -- Track pending change
  table.insert(M.pending, {
    op = types.DELTA_OPS.ADD,
    path = "nodes." .. storage_key .. "." .. node.id,
    ah = node.h,
  })

  return node
end
|
||||
|
||||
--- Get a node by ID
--- IDs follow the shape n_<type>_<timestamp>_<hash>; the second segment
--- selects which storage bucket to look in.
---@param node_id string Node ID
---@return Node|nil
function M.get(node_id)
  local segments = vim.split(node_id, "_")
  if #segments < 3 then
    return nil
  end

  local bucket = get_storage_key(segments[2])
  return storage.get_nodes(bucket)[node_id]
end
|
||||
|
||||
--- Update a node
--- Deep-merges partial updates into content/context/scores, refreshes the
--- timestamp, content hash and version, persists, rebuilds indices when the
--- context changed, and queues a MODIFY pending change.
---@param node_id string Node ID
---@param updates table Partial updates {c?: table, ctx?: table, sc?: table}
---@return Node|nil Updated node (nil when the node doesn't exist)
function M.update(node_id, updates)
  local node = M.get(node_id)
  if not node then
    return nil
  end

  local before_hash = node.h

  -- Apply updates
  if updates.c then
    node.c = vim.tbl_deep_extend("force", node.c, updates.c)
  end
  if updates.ctx then
    node.ctx = vim.tbl_deep_extend("force", node.ctx, updates.ctx)
  end
  if updates.sc then
    node.sc = vim.tbl_deep_extend("force", node.sc, updates.sc)
  end

  -- Update timestamps and hash
  node.ts.up = os.time()
  node.h = hash.compute((node.c.s or "") .. (node.c.d or ""))
  node.m.v = (node.m.v or 0) + 1

  -- Save
  local storage_key = get_storage_key(node.t)
  local nodes = storage.get_nodes(storage_key)
  nodes[node_id] = node
  storage.save_nodes(storage_key, nodes)

  -- Update indices if context changed
  if updates.ctx then
    M.update_indices(node, "update")
  end

  -- Track pending change
  table.insert(M.pending, {
    op = types.DELTA_OPS.MODIFY,
    path = "nodes." .. storage_key .. "." .. node_id,
    bh = before_hash, -- hash before the edit
    ah = node.h,      -- hash after the edit
  })

  return node
end
|
||||
|
||||
--- Delete a node
--- Removes the record from its bucket, decrements the node counter, removes
--- it from indices, and queues a DELETE pending change. Does NOT touch
--- edges — callers wanting a full teardown use graph.remove_learning.
---@param node_id string Node ID
---@return boolean Success
function M.delete(node_id)
  local node = M.get(node_id)
  if not node then
    return false
  end

  local storage_key = get_storage_key(node.t)
  local nodes = storage.get_nodes(storage_key)

  if not nodes[node_id] then
    return false
  end

  local before_hash = node.h
  nodes[node_id] = nil
  storage.save_nodes(storage_key, nodes)

  -- Update meta
  local meta = storage.get_meta()
  storage.update_meta({ nc = math.max(0, meta.nc - 1) })

  -- Update indices
  M.update_indices(node, "delete")

  -- Track pending change
  table.insert(M.pending, {
    op = types.DELTA_OPS.DELETE,
    path = "nodes." .. storage_key .. "." .. node_id,
    bh = before_hash, -- hash of the record as it was before deletion
  })

  return true
end
|
||||
|
||||
--- Find nodes by criteria
--- Scans the requested type buckets, applies all filters, sorts by
--- weight-times-recency, and optionally truncates to `limit` results.
---@param criteria table {types?, file?, min_weight?, since?, query?, limit?}
---@return Node[]
function M.find(criteria)
  local results = {}

  local node_types = criteria.types or vim.tbl_values(types.NODE_TYPES)

  for _, node_type in ipairs(node_types) do
    local storage_key = get_storage_key(node_type)
    local nodes = storage.get_nodes(storage_key)

    for _, node in pairs(nodes) do
      local matches = true

      -- Filter by file
      if criteria.file and node.ctx.f ~= criteria.file then
        matches = false
      end

      -- Filter by min weight
      if criteria.min_weight and node.sc.w < criteria.min_weight then
        matches = false
      end

      -- Filter by since timestamp
      if criteria.since and node.ts.cr < criteria.since then
        matches = false
      end

      -- Filter by content match: case-insensitive PLAIN substring search
      -- (find's 4th arg true disables Lua patterns) over summary and detail
      if criteria.query then
        local query_lower = criteria.query:lower()
        local summary_lower = (node.c.s or ""):lower()
        local detail_lower = (node.c.d or ""):lower()
        if not summary_lower:find(query_lower, 1, true) and not detail_lower:find(query_lower, 1, true) then
          matches = false
        end
      end

      if matches then
        table.insert(results, node)
      end
    end
  end

  -- Sort by relevance (weight * recency); recency decays hyperbolically
  -- with days since last use.
  table.sort(results, function(a, b)
    local score_a = a.sc.w * (1 / (1 + (os.time() - a.ts.lu) / 86400))
    local score_b = b.sc.w * (1 / (1 + (os.time() - b.ts.lu) / 86400))
    return score_a > score_b
  end)

  -- Apply limit
  if criteria.limit and #results > criteria.limit then
    local limited = {}
    for i = 1, criteria.limit do
      limited[i] = results[i]
    end
    return limited
  end

  return results
end
|
||||
|
||||
--- Record usage of a node
--- Bumps usage count and last-used timestamp, folds `success` into the
--- running success rate, and slowly raises the weight of popular nodes.
--- Saves directly: usage bookkeeping creates no pending delta changes.
---@param node_id string Node ID
---@param success? boolean Was the usage successful
function M.record_usage(node_id, success)
  local node = M.get(node_id)
  if not node then
    return
  end

  -- Update usage stats
  node.sc.u = node.sc.u + 1
  node.ts.lu = os.time()

  -- Update success rate (incremental mean over all recorded usages)
  if success ~= nil then
    local total = node.sc.u
    local successes = node.sc.sr * (total - 1) + (success and 1 or 0)
    node.sc.sr = successes / total
  end

  -- Increase weight slightly for frequently used nodes
  if node.sc.u > 5 then
    node.sc.w = math.min(1.0, node.sc.w + 0.01)
  end

  -- Save (direct save, no pending change tracking for usage)
  local storage_key = get_storage_key(node.t)
  local nodes = storage.get_nodes(storage_key)
  nodes[node_id] = node
  storage.save_nodes(storage_key, nodes)
end
|
||||
|
||||
--- Update indices for a node
--- Maintains three secondary indices: by_file (ctx.f -> node ids),
--- by_symbol (each ctx.sym entry -> node ids), and by_time (creation-day
--- bucket -> node ids).
--- NOTE(review): file/symbol indices insert on both "add" and "update",
--- while by_time inserts only on "add" — presumably because creation day
--- never changes; confirm the asymmetry is intended.
---@param node Node The node
---@param op "add"|"update"|"delete" Operation type
function M.update_indices(node, op)
  -- File index
  if node.ctx.f then
    local by_file = storage.get_index("by_file")

    if op == "delete" then
      if by_file[node.ctx.f] then
        by_file[node.ctx.f] = vim.tbl_filter(function(id)
          return id ~= node.id
        end, by_file[node.ctx.f])
      end
    else
      by_file[node.ctx.f] = by_file[node.ctx.f] or {}
      if not vim.tbl_contains(by_file[node.ctx.f], node.id) then
        table.insert(by_file[node.ctx.f], node.id)
      end
    end

    storage.save_index("by_file", by_file)
  end

  -- Symbol index
  if node.ctx.sym then
    local by_symbol = storage.get_index("by_symbol")

    for _, sym in ipairs(node.ctx.sym) do
      if op == "delete" then
        if by_symbol[sym] then
          by_symbol[sym] = vim.tbl_filter(function(id)
            return id ~= node.id
          end, by_symbol[sym])
        end
      else
        by_symbol[sym] = by_symbol[sym] or {}
        if not vim.tbl_contains(by_symbol[sym], node.id) then
          table.insert(by_symbol[sym], node.id)
        end
      end
    end

    storage.save_index("by_symbol", by_symbol)
  end

  -- Time index (daily buckets)
  local day = os.date("%Y-%m-%d", node.ts.cr)
  local by_time = storage.get_index("by_time")

  if op == "delete" then
    if by_time[day] then
      by_time[day] = vim.tbl_filter(function(id)
        return id ~= node.id
      end, by_time[day])
    end
  elseif op == "add" then
    by_time[day] = by_time[day] or {}
    if not vim.tbl_contains(by_time[day], node.id) then
      table.insert(by_time[day], node.id)
    end
  end

  storage.save_index("by_time", by_time)
end
|
||||
|
||||
--- Get pending changes and clear
--- Drains the pending-change queue: the caller receives the accumulated
--- entries and the module starts a fresh, empty queue.
---@return table[] Pending changes
function M.get_and_clear_pending()
  local drained = M.pending
  M.pending = {}
  return drained
end
|
||||
|
||||
--- Merge two similar nodes into the first one.
--- Keeps the longer detail text, averages the weights, sums the usage
--- counts, then deletes the second node.
---@param node_id_1 string First node ID (survives the merge)
---@param node_id_2 string Second node ID (will be deleted)
---@return Node|nil Merged node, or nil if either node is missing
function M.merge(node_id_1, node_id_2)
  local node1 = M.get(node_id_1)
  local node2 = M.get(node_id_2)

  if not node1 or not node2 then
    return nil
  end

  -- Merge content (keep longer detail). Detail may be absent on a node
  -- (other modules read node.c.d defensively), so default to "" instead
  -- of letting the length operator error on nil.
  local detail1 = node1.c.d or ""
  local detail2 = node2.c.d or ""
  local merged_detail = #detail1 > #detail2 and detail1 or detail2

  -- Merge scores (combine weights and usage)
  local merged_weight = (node1.sc.w + node2.sc.w) / 2
  local merged_usage = node1.sc.u + node2.sc.u

  M.update(node_id_1, {
    c = { d = merged_detail },
    sc = { w = merged_weight, u = merged_usage },
  })

  -- Delete the second node
  M.delete(node_id_2)

  return M.get(node_id_1)
end
|
||||
|
||||
return M
|
||||
488
lua/codetyper/brain/graph/query.lua
Normal file
488
lua/codetyper/brain/graph/query.lua
Normal file
@@ -0,0 +1,488 @@
|
||||
--- Brain Graph Query Engine
|
||||
--- Multi-dimensional traversal and relevance scoring
|
||||
|
||||
local storage = require("codetyper.brain.storage")
|
||||
local types = require("codetyper.brain.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Lazy load dependencies to avoid circular requires
-- Resolved at call time; require() caches in package.loaded, so repeated
-- calls are cheap.
local function get_node_module()
  return require("codetyper.brain.graph.node")
end
|
||||
|
||||
-- Lazy-loaded for the same reason as get_node_module: breaks the
-- circular require between this module and the graph submodules.
local function get_edge_module()
  return require("codetyper.brain.graph.edge")
end
|
||||
|
||||
--- Jaccard-style similarity between two texts based on shared words.
--- Both texts are lowered and split on word characters; the score is
--- |intersection| / |union| of the resulting word sets.
---@param text1 string First text
---@param text2 string Second text
---@return number Similarity score (0-1)
local function text_similarity(text1, text2)
  if not text1 or not text2 then
    return 0
  end

  -- Tokenize one lowered text into a set of words.
  local function word_set(text)
    local set = {}
    for word in text:lower():gmatch("%w+") do
      set[word] = true
    end
    return set
  end

  local words1 = word_set(text1)
  local words2 = word_set(text2)

  local shared, union = 0, 0
  for word in pairs(words1) do
    union = union + 1
    if words2[word] then
      shared = shared + 1
    end
  end
  for word in pairs(words2) do
    if not words1[word] then
      union = union + 1
    end
  end

  if union == 0 then
    return 0
  end
  return shared / union
end
|
||||
|
||||
--- Compute relevance score for a node
--- Weighted sum of six signals; each term is normalized to [0, 1] so the
--- total also stays within [0, 1].
---@param node Node Node to score
---@param opts QueryOpts Query options
---@return number Relevance score (0-1)
function M.compute_relevance(node, opts)
  local score = 0
  -- Signal weights; they sum to 1.0.
  local weights = {
    content_match = 0.30,
    recency = 0.20,
    usage = 0.15,
    weight = 0.15,
    connection_density = 0.10,
    success_rate = 0.10,
  }

  -- Content similarity: best of summary match and (discounted) detail match.
  if opts.query then
    local summary = node.c.s or ""
    local detail = node.c.d or ""
    local similarity = math.max(text_similarity(opts.query, summary), text_similarity(opts.query, detail) * 0.8)
    score = score + (similarity * weights.content_match)
  else
    score = score + weights.content_match * 0.5 -- Base score if no query
  end

  -- Recency: exponential decay with a 30-day time constant (score drops to
  -- 1/e after 30 days), based on last-use (fallback: last-update) timestamp.
  local age_days = (os.time() - (node.ts.lu or node.ts.up)) / 86400
  local recency = math.exp(-age_days / 30)
  score = score + (recency * weights.recency)

  -- Usage frequency, saturating at 10 uses.
  local usage = math.min(node.sc.u / 10, 1.0)
  score = score + (usage * weights.usage)

  -- Node's stored weight (assumed already in [0, 1] — TODO confirm).
  score = score + (node.sc.w * weights.weight)

  -- Connection density: edges in either direction, saturating at 5.
  local edge_mod = get_edge_module()
  local connections = #edge_mod.get_edges(node.id, nil, "both")
  local density = math.min(connections / 5, 1.0)
  score = score + (density * weights.connection_density)

  -- Success rate (assumed already in [0, 1] — TODO confirm).
  score = score + (node.sc.sr * weights.success_rate)

  return score
end
|
||||
|
||||
--- Breadth-first traversal of the graph starting from seed nodes.
--- Each hop expands the frontier through neighboring edges; every node is
--- visited at most once.
---@param seed_ids string[] Starting node IDs
---@param depth number Traversal depth (number of hops)
---@param edge_types? EdgeType[] Edge types to follow
---@return table<string, Node> Discovered nodes indexed by ID
local function traverse(seed_ids, depth, edge_types)
  local node_mod = get_node_module()
  local edge_mod = get_edge_module()
  local discovered = {}
  local frontier = seed_ids

  for _ = 1, depth do
    local next_frontier = {}

    for _, node_id in ipairs(frontier) do
      -- Visit each node at most once.
      if not discovered[node_id] then
        local node = node_mod.get(node_id)
        if node then
          discovered[node_id] = node

          -- Queue unvisited neighbors for the next hop.
          local neighbors = edge_mod.get_neighbors(node_id, edge_types, "both")
          for _, neighbor_id in ipairs(neighbors) do
            if not discovered[neighbor_id] then
              table.insert(next_frontier, neighbor_id)
            end
          end
        end
      end
    end

    frontier = next_frontier
    if #frontier == 0 then
      break
    end
  end

  return discovered
end
|
||||
|
||||
--- Spreading activation - mimics human associative memory
--- Activation spreads from seed nodes along edges, decaying by weight
--- Nodes accumulate activation from multiple paths (like neural pathways)
---@param seed_activations table<string, number> Initial activations {node_id: activation}
---@param max_iterations number Max spread iterations (default 3)
---@param decay number Activation decay per hop (default 0.5)
---@param threshold number Minimum activation to continue spreading (default 0.1)
---@return table<string, number> Final activations {node_id: accumulated_activation}
local function spreading_activation(seed_activations, max_iterations, decay, threshold)
  local edge_mod = get_edge_module()
  max_iterations = max_iterations or 3
  decay = decay or 0.5
  threshold = threshold or 0.1

  -- Accumulated activation for each node (seeds keep their initial values).
  local activation = {}
  for node_id, act in pairs(seed_activations) do
    activation[node_id] = act
  end

  -- Current frontier with their activation levels
  local frontier = {}
  for node_id, act in pairs(seed_activations) do
    frontier[node_id] = act
  end

  -- Spread activation iteratively
  for _ = 1, max_iterations do
    local next_frontier = {}

    for source_id, source_activation in pairs(frontier) do
      -- Edges in both directions; activation flows regardless of direction.
      local edges = edge_mod.get_edges(source_id, nil, "both")

      for _, edge in ipairs(edges) do
        -- Determine target (could be source or target of edge)
        local target_id = edge.s == source_id and edge.t or edge.s

        -- Calculate spreading activation:
        -- spread = source_activation * edge_weight * decay
        -- Edges without properties get a neutral 0.5 weight.
        local edge_weight = edge.p and edge.p.w or 0.5
        local spread_amount = source_activation * edge_weight * decay

        -- Only spread if above threshold
        if spread_amount >= threshold then
          -- Accumulate activation (multiple paths add up)
          activation[target_id] = (activation[target_id] or 0) + spread_amount

          -- Keep the strongest incoming spread as the node's frontier level,
          -- so the next hop spreads from the best path found this iteration.
          if not next_frontier[target_id] or next_frontier[target_id] < spread_amount then
            next_frontier[target_id] = spread_amount
          end
        end
      end
    end

    -- Stop if no more spreading
    if vim.tbl_count(next_frontier) == 0 then
      break
    end

    frontier = next_frontier
  end

  return activation
end
|
||||
|
||||
--- Execute a query across all dimensions
--- Pipeline: gather candidates along the semantic, file, and temporal
--- dimensions; seed them with relevance-based activation; run spreading
--- activation; rank by a blend of activation and base relevance; truncate
--- to the limit; finally collect the edges among the surviving nodes.
---@param opts QueryOpts Query options
---@return QueryResult
function M.execute(opts)
  opts = opts or {}
  local node_mod = get_node_module()
  local results = {
    semantic = {},
    file = {},
    temporal = {},
  }

  -- 1. Semantic traversal (content similarity)
  if opts.query then
    local seed_nodes = node_mod.find({
      query = opts.query,
      types = opts.types,
      limit = 10,
    })

    local seed_ids = vim.tbl_map(function(n)
      return n.id
    end, seed_nodes)
    local depth = opts.depth or 2

    local discovered = traverse(seed_ids, depth, { types.EDGE_TYPES.SEMANTIC })
    for id, node in pairs(discovered) do
      results.semantic[id] = node
    end
  end

  -- 2. File-based traversal
  if opts.file then
    local by_file = storage.get_index("by_file")
    local file_node_ids = by_file[opts.file] or {}

    for _, node_id in ipairs(file_node_ids) do
      local node = node_mod.get(node_id)
      if node then
        results.file[node.id] = node
      end
    end

    -- Also get nodes from related files via edges
    local discovered = traverse(file_node_ids, 1, { types.EDGE_TYPES.FILE })
    for id, node in pairs(discovered) do
      results.file[id] = node
    end
  end

  -- 3. Temporal traversal (recent context)
  if opts.since then
    local by_time = storage.get_index("by_time")

    for day, node_ids in pairs(by_time) do
      -- Parse "YYYY-MM-DD" bucket key to a timestamp
      local year, month, day_num = day:match("(%d+)-(%d+)-(%d+)")
      if year then
        local day_ts = os.time({ year = tonumber(year), month = tonumber(month), day = tonumber(day_num) })
        if day_ts >= opts.since then
          for _, node_id in ipairs(node_ids) do
            local node = node_mod.get(node_id)
            if node then
              results.temporal[node.id] = node
            end
          end
        end
      end
    end

    -- Follow temporal edges
    local temporal_ids = vim.tbl_keys(results.temporal)
    local discovered = traverse(temporal_ids, 1, { types.EDGE_TYPES.TEMPORAL })
    for id, node in pairs(discovered) do
      results.temporal[id] = node
    end
  end

  -- 4. Combine all found nodes and compute seed activations
  local all_nodes = {}
  local seed_activations = {}

  for _, category in pairs(results) do
    for id, node in pairs(category) do
      if not all_nodes[id] then
        all_nodes[id] = node
        -- Compute initial activation based on relevance
        local relevance = M.compute_relevance(node, opts)
        seed_activations[id] = relevance
      end
    end
  end

  -- 5. Apply spreading activation - like human associative memory.
  -- Nodes connected to multiple relevant seeds accumulate higher activation.
  local final_activations = spreading_activation(
    seed_activations,
    opts.spread_iterations or 3, -- How far activation spreads
    opts.spread_decay or 0.5, -- How much activation decays per hop
    opts.spread_threshold or 0.05 -- Minimum activation to continue spreading
  )

  -- 6. Score and rank by combined activation. Activation can reach nodes
  -- never seen in step 4, so fetch those lazily.
  local scored = {}
  for id, activation in pairs(final_activations) do
    local node = all_nodes[id] or node_mod.get(id)
    if node then
      all_nodes[id] = node
      -- Final score = spreading activation (60%) + base relevance (40%)
      local base_relevance = M.compute_relevance(node, opts)
      local final_score = (activation * 0.6) + (base_relevance * 0.4)
      table.insert(scored, { node = node, relevance = final_score, activation = activation })
    end
  end

  table.sort(scored, function(a, b)
    return a.relevance > b.relevance
  end)

  -- 7. Apply limit
  local limit = opts.limit or 50
  local result_nodes = {}
  local truncated = #scored > limit

  for i = 1, math.min(limit, #scored) do
    table.insert(result_nodes, scored[i].node)
  end

  -- 8. Get edges between result nodes (outgoing only, to avoid duplicates)
  local edge_mod = get_edge_module()
  local result_edges = {}
  local node_ids = {}
  for _, node in ipairs(result_nodes) do
    node_ids[node.id] = true
  end

  for _, node in ipairs(result_nodes) do
    local edges = edge_mod.get_edges(node.id, nil, "out")
    for _, edge in ipairs(edges) do
      if node_ids[edge.t] then
        table.insert(result_edges, edge)
      end
    end
  end

  return {
    nodes = result_nodes,
    edges = result_edges,
    stats = {
      semantic_count = vim.tbl_count(results.semantic),
      file_count = vim.tbl_count(results.file),
      temporal_count = vim.tbl_count(results.temporal),
      total_scored = #scored,
      seed_nodes = vim.tbl_count(seed_activations),
      activated_nodes = vim.tbl_count(final_activations),
    },
    truncated = truncated,
  }
end
|
||||
|
||||
--- Expose spreading activation for direct use
--- Useful for custom activation patterns or debugging; takes the same
--- arguments as the local definition above.
M.spreading_activation = spreading_activation
|
||||
|
||||
--- Find nodes associated with a file.
--- Convenience wrapper around M.execute restricted to the file dimension.
---@param filepath string File path
---@param limit? number Max results (default 20)
---@return Node[]
function M.by_file(filepath, limit)
  return M.execute({
    file = filepath,
    limit = limit or 20,
  }).nodes
end
|
||||
|
||||
--- Find nodes created within a time range, newest first.
--- Walks the daily index buckets and keeps those whose bucket timestamp
--- falls inside [since, until_ts].
---@param since number Start timestamp
---@param until_ts? number End timestamp (defaults to now)
---@param limit? number Max results
---@return Node[]
function M.by_time_range(since, until_ts, limit)
  local node_mod = get_node_module()
  local by_time = storage.get_index("by_time")
  local found = {}
  local cutoff = until_ts or os.time()

  for day, node_ids in pairs(by_time) do
    -- Bucket keys are "YYYY-MM-DD"; skip anything malformed.
    local y, m, d = day:match("(%d+)-(%d+)-(%d+)")
    if y then
      local bucket_ts = os.time({ year = tonumber(y), month = tonumber(m), day = tonumber(d) })
      if bucket_ts >= since and bucket_ts <= cutoff then
        for _, node_id in ipairs(node_ids) do
          local node = node_mod.get(node_id)
          if node then
            found[#found + 1] = node
          end
        end
      end
    end
  end

  -- Newest creation time first.
  table.sort(found, function(a, b)
    return a.ts.cr > b.ts.cr
  end)

  if limit and #found > limit then
    local head = {}
    for i = 1, limit do
      head[i] = found[i]
    end
    return head
  end

  return found
end
|
||||
|
||||
--- Find semantically similar nodes.
--- Convenience wrapper around M.execute restricted to the query dimension.
---@param query string Query text
---@param limit? number Max results (default 10)
---@return Node[]
function M.semantic_search(query, limit)
  return M.execute({
    query = query,
    limit = limit or 10,
    depth = 2,
  }).nodes
end
|
||||
|
||||
--- Get context chain (path) for explanation
--- Produces one formatted line per node plus a connector line for each
--- edge that exists between consecutive node IDs.
---@param node_ids string[] Node IDs to chain
---@return string[] Chain descriptions
function M.get_context_chain(node_ids)
  local node_mod = get_node_module()
  local edge_mod = get_edge_module()
  local chain = {}

  for i, node_id in ipairs(node_ids) do
    local node = node_mod.get(node_id)
    if node then
      local entry = string.format("[%s] %s (w:%.2f)", node.t:upper(), node.c.s, node.sc.w)
      table.insert(chain, entry)

      -- Add edge to next node if one exists
      if node_ids[i + 1] then
        local edge = edge_mod.get(node_id, node_ids[i + 1])
        if edge then
          -- edge.p (properties) can be absent; fall back to the same 0.5
          -- default weight that spreading_activation uses, instead of
          -- crashing on a nil index.
          local edge_weight = edge.p and edge.p.w or 0.5
          table.insert(chain, string.format(" -> %s (w:%.2f)", edge.ty, edge_weight))
        end
      end
    end
  end

  return chain
end
|
||||
|
||||
return M
|
||||
112
lua/codetyper/brain/hash.lua
Normal file
112
lua/codetyper/brain/hash.lua
Normal file
@@ -0,0 +1,112 @@
|
||||
--- Brain Hashing Utilities
|
||||
--- Content-addressable storage with 8-character hashes
|
||||
|
||||
local M = {}
|
||||
|
||||
--- DJB2 string hash (fast, good distribution), truncated to 32 bits.
---@param str string String to hash
---@return number Hash value in [0, 2^32)
local function djb2(str)
  local byte = string.byte
  local h = 5381
  for i = 1, #str do
    -- hash = hash * 33 + c, kept within 32 bits
    h = (h * 33 + byte(str, i)) % 0x100000000
  end
  return h
end
|
||||
|
||||
--- Convert a number to a fixed-width lowercase hex string.
--- Shorter values are zero-padded on the left; longer values keep only
--- their trailing `len` digits.
---@param num number Number to convert
---@param len number Desired length
---@return string Hex string of exactly `len` characters
local function to_hex(num, len)
  local hex = string.format("%x", num)
  local pad = len - #hex
  if pad > 0 then
    hex = string.rep("0", pad) .. hex
  end
  return hex:sub(-len)
end
|
||||
|
||||
--- Compute 8-character hash from string
---@param content string Content to hash
---@return string 8-character hex hash
function M.compute(content)
  -- Missing/empty input gets a fixed sentinel hash.
  if not content or content == "" then
    return "00000000"
  end
  return to_hex(djb2(content), 8)
end
|
||||
|
||||
--- Compute hash from table (JSON-serialized)
--- Unencodable tables (cycles, functions, ...) yield the sentinel hash.
---@param tbl table Table to hash
---@return string 8-character hex hash
function M.compute_table(tbl)
  local ok, json = pcall(vim.json.encode, tbl)
  -- M.compute always returns a non-empty string, so the and/or chain is safe.
  return ok and M.compute(json) or "00000000"
end
|
||||
|
||||
--- Generate a unique node ID.
---@param node_type string Node type prefix
---@param content? string Optional content for hash
---@return string Node ID in the form n_<type>_<timestamp>_<hash>
function M.node_id(node_type, content)
  local now = os.time()
  -- Salt with time and a random number so identical content still gets
  -- distinct IDs.
  local seed = (content or "") .. tostring(now) .. tostring(math.random(100000))
  local suffix = M.compute(seed):sub(1, 6)
  return string.format("n_%s_%d_%s", node_type, now, suffix)
end
|
||||
|
||||
--- Generate unique edge ID from the two endpoint node IDs.
--- NOTE(review): only 4 hex chars per endpoint — distinct node pairs can
--- collide; verify callers tolerate this.
---@param source_id string Source node ID
---@param target_id string Target node ID
---@return string Edge ID (e_<source_hash>_<target_hash>)
function M.edge_id(source_id, target_id)
  local src = M.compute(source_id):sub(1, 4)
  local tgt = M.compute(target_id):sub(1, 4)
  return string.format("e_%s_%s", src, tgt)
end
|
||||
|
||||
--- Generate delta hash
--- Hashes the parent hash, timestamp, and each change's op/path.
---@param changes table[] Delta changes
---@param parent string|nil Parent delta hash
---@param timestamp number Delta timestamp
---@return string 8-character delta hash
function M.delta_hash(changes, parent, timestamp)
  -- Build the hash input in a buffer: repeated `..` in a loop is O(n^2),
  -- table.concat is linear.
  local parts = { parent or "root", tostring(timestamp) }
  for _, change in ipairs(changes or {}) do
    parts[#parts + 1] = change.op or ""
    parts[#parts + 1] = change.path or ""
  end
  return M.compute(table.concat(parts))
end
|
||||
|
||||
--- Hash file path for storage
-- Thin alias over M.compute kept for call-site readability.
---@param filepath string File path
---@return string 8-character hash
function M.path_hash(filepath)
  return M.compute(filepath)
end
|
||||
|
||||
--- Check if two hashes match
-- Plain string equality; exists to make hash comparison explicit at call
-- sites.
---@param hash1 string First hash
---@param hash2 string Second hash
---@return boolean True if matching
function M.matches(hash1, hash2)
  return hash1 == hash2
end
|
||||
|
||||
--- Generate random hash (for testing/temporary IDs)
--- Draws 8 independent hex digits from math.random.
---@return string 8-character random hash
function M.random()
  local chars = "0123456789abcdef"
  local digits = {}
  for i = 1, 8 do
    local idx = math.random(1, #chars)
    digits[i] = chars:sub(idx, idx)
  end
  return table.concat(digits)
end
|
||||
|
||||
return M
|
||||
276
lua/codetyper/brain/init.lua
Normal file
276
lua/codetyper/brain/init.lua
Normal file
@@ -0,0 +1,276 @@
|
||||
--- Brain Learning System
|
||||
--- Graph-based knowledge storage with delta versioning
|
||||
|
||||
local storage = require("codetyper.brain.storage")
|
||||
local types = require("codetyper.brain.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
---@type BrainConfig|nil
|
||||
local config = nil
|
||||
|
||||
---@type boolean
|
||||
local initialized = false
|
||||
|
||||
--- Pending changes counter for auto-commit
|
||||
local pending_changes = 0
|
||||
|
||||
--- Default configuration
local DEFAULT_CONFIG = {
  enabled = true, -- master switch; setup() bails out early when false
  auto_learn = true, -- M.learn() becomes a no-op when false
  auto_commit = true, -- commit automatically once commit_threshold is hit
  commit_threshold = 10, -- pending changes that trigger an auto-commit
  max_nodes = 5000, -- NOTE(review): enforcement not visible here — confirm
  max_deltas = 500, -- NOTE(review): enforcement not visible here — confirm
  prune = {
    enabled = true,
    threshold = 0.1, -- passed to graph.prune() as the score threshold
    unused_days = 90, -- passed to graph.prune() as the staleness cutoff
  },
  output = {
    max_tokens = 4000, -- default token budget for get_context_for_llm()
    format = "compact", -- "json" or anything else (treated as compact)
  },
}
|
||||
|
||||
--- Initialize brain system
---@param opts? BrainConfig Configuration options
function M.setup(opts)
  config = vim.tbl_deep_extend("force", DEFAULT_CONFIG, opts or {})

  -- Respect the master switch: leave the module uninitialized when disabled.
  if not config.enabled then
    return
  end

  storage.ensure_dirs() -- make sure storage directories exist
  storage.get_meta() -- initializes meta if it does not exist yet

  initialized = true
end
|
||||
|
||||
--- Check if brain is initialized
---@return boolean
function M.is_initialized()
  -- Coerce to a real boolean: the bare `and` chain would otherwise leak
  -- nil (before setup) or the raw config.enabled value to callers, even
  -- though the annotation promises a boolean. Truthiness is unchanged.
  return not not (initialized and config and config.enabled)
end
|
||||
|
||||
--- Get current configuration
-- Returns the live config table (not a copy); nil before setup() runs.
---@return BrainConfig|nil
function M.get_config()
  return config
end
|
||||
|
||||
--- Learn from an event
--- Delegates to the learners module; on success bumps the pending-change
--- counter and may trigger an auto-commit.
---@param event LearnEvent Learning event
---@return string|nil Node ID if created
function M.learn(event)
  if not M.is_initialized() or not config.auto_learn then
    return nil
  end

  -- Required at call time, matching the lazy-require pattern used
  -- elsewhere in this module.
  local learners = require("codetyper.brain.learners")
  local node_id = learners.process(event)

  if node_id then
    pending_changes = pending_changes + 1

    -- Auto-commit if threshold reached, then reset the counter.
    if config.auto_commit and pending_changes >= config.commit_threshold then
      M.commit("Auto-commit: " .. pending_changes .. " changes")
      pending_changes = 0
    end
  end

  return node_id
end
|
||||
|
||||
--- Query relevant knowledge for context
---@param opts QueryOpts Query options
---@return QueryResult
function M.query(opts)
  if not M.is_initialized() then
    -- Empty-but-well-formed result so callers never need a nil check.
    return { nodes = {}, edges = {}, stats = {}, truncated = false }
  end
  return require("codetyper.brain.graph.query").execute(opts)
end
|
||||
|
||||
--- Get LLM-optimized context string
---@param opts? QueryOpts Query options
---@return string Formatted context ("" when uninitialized)
function M.get_context_for_llm(opts)
  if not M.is_initialized() then
    return ""
  end

  opts = opts or {}
  opts.max_tokens = opts.max_tokens or config.output.max_tokens

  local result = M.query(opts)
  local formatter = require("codetyper.brain.output.formatter")

  -- "json" gets structured output; any other format falls back to compact.
  local render = config.output.format == "json" and formatter.to_json or formatter.to_compact
  return render(result, opts)
end
|
||||
|
||||
--- Create a delta commit
---@param message string Commit message
---@return string|nil Delta hash, or nil when uninitialized
function M.commit(message)
  if not M.is_initialized() then
    return nil
  end
  return require("codetyper.brain.delta").commit(message)
end
|
||||
|
||||
--- Rollback to a previous delta
---@param delta_hash string Target delta hash
---@return boolean Success (false when uninitialized)
function M.rollback(delta_hash)
  if not M.is_initialized() then
    return false
  end
  return require("codetyper.brain.delta").rollback(delta_hash)
end
|
||||
|
||||
--- Get delta history
---@param limit? number Max entries (default 50)
---@return Delta[] History entries (empty when uninitialized)
function M.get_history(limit)
  if not M.is_initialized() then
    return {}
  end
  return require("codetyper.brain.delta").get_history(limit or 50)
end
|
||||
|
||||
--- Prune low-value nodes
---@param opts? table Prune options (threshold, unused_days overrides)
---@return number Number of pruned nodes
function M.prune(opts)
  if not M.is_initialized() or not config.prune.enabled then
    return 0
  end

  -- Configured defaults, overridden by any caller-supplied options.
  local merged = vim.tbl_extend("force", {
    threshold = config.prune.threshold,
    unused_days = config.prune.unused_days,
  }, opts or {})

  return require("codetyper.brain.graph").prune(merged)
end
|
||||
|
||||
--- Export brain state
--- Snapshot of meta, graph, all node collections, and all indices, tagged
--- with the schema version so M.import() can validate it.
---@return table|nil Exported data, or nil when uninitialized
function M.export()
  if not M.is_initialized() then
    return nil
  end

  return {
    schema = types.SCHEMA_VERSION,
    meta = storage.get_meta(),
    graph = storage.get_graph(),
    nodes = {
      patterns = storage.get_nodes("patterns"),
      corrections = storage.get_nodes("corrections"),
      decisions = storage.get_nodes("decisions"),
      conventions = storage.get_nodes("conventions"),
      feedback = storage.get_nodes("feedback"),
      sessions = storage.get_nodes("sessions"),
    },
    indices = {
      by_file = storage.get_index("by_file"),
      by_time = storage.get_index("by_time"),
      by_symbol = storage.get_index("by_symbol"),
    },
  }
end
|
||||
|
||||
--- Import brain state
--- Restores nodes, graph, indices, and meta previously produced by
--- M.export(). Data from a different schema version is rejected.
---@param data table Exported data
---@return boolean Success
function M.import(data)
  if not data or data.schema ~= types.SCHEMA_VERSION then
    return false
  end

  storage.ensure_dirs()

  -- Import nodes
  if data.nodes then
    for node_type, nodes in pairs(data.nodes) do
      storage.save_nodes(node_type, nodes)
    end
  end

  -- Import graph
  if data.graph then
    storage.save_graph(data.graph)
  end

  -- Import indices
  if data.indices then
    for index_type, index_data in pairs(data.indices) do
      storage.save_index(index_type, index_data)
    end
  end

  -- Import meta last, one key at a time — presumably update_meta merges
  -- rather than replaces; verify against the storage module.
  if data.meta then
    for k, v in pairs(data.meta) do
      storage.update_meta({ [k] = v })
    end
  end

  storage.flush_all()
  return true
end
|
||||
|
||||
--- Get stats about the brain
---@return table Stats (empty table when uninitialized)
function M.stats()
  if not M.is_initialized() then
    return {}
  end

  local meta = storage.get_meta()
  local snapshot = {
    initialized = true,
    node_count = meta.nc,
    edge_count = meta.ec,
    delta_count = meta.dc,
    head = meta.head,
    pending_changes = pending_changes,
  }
  return snapshot
end
|
||||
|
||||
--- Flush all pending writes to disk
-- Delegates to the storage layer; safe to call at any time.
function M.flush()
  storage.flush_all()
end
|
||||
|
||||
--- Shutdown brain (call before exit)
--- Commits anything still pending, flushes storage, and marks the module
--- uninitialized.
function M.shutdown()
  if pending_changes > 0 then
    M.commit("Session end: " .. pending_changes .. " changes")
  end

  storage.flush_all()
  initialized = false
end
|
||||
|
||||
return M
|
||||
233
lua/codetyper/brain/learners/convention.lua
Normal file
233
lua/codetyper/brain/learners/convention.lua
Normal file
@@ -0,0 +1,233 @@
|
||||
--- Brain Convention Learner
|
||||
--- Learns project conventions and coding standards
|
||||
|
||||
local types = require("codetyper.brain.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Detect if event contains convention info
---@param event LearnEvent Learning event
---@return boolean
function M.detect(event)
  -- Set lookup instead of a linear scan over a list.
  local convention_events = {
    convention_detected = true,
    naming_pattern = true,
    style_pattern = true,
    project_structure = true,
    config_change = true,
  }
  return convention_events[event.type] == true
end
|
||||
|
||||
--- Extract convention data from event
--- Each supported event type maps to a normalized table with at least a
--- summary; unknown event types yield nil.
---@param event LearnEvent Learning event
---@return table|nil Extracted data
function M.extract(event)
  local data = event.data or {}

  if event.type == "convention_detected" then
    return {
      summary = "Convention: " .. (data.name or "unnamed"),
      detail = data.description or data.name,
      rule = data.rule,
      examples = data.examples,
      category = data.category or "general",
      file = event.file,
    }
  end

  if event.type == "naming_pattern" then
    return {
      -- Final "unnamed" fallback so the concatenation cannot error when
      -- both pattern_name and pattern are absent.
      summary = "Naming: " .. (data.pattern_name or data.pattern or "unnamed"),
      detail = "Naming convention: " .. (data.description or data.pattern or "unnamed"),
      rule = data.pattern,
      examples = data.examples,
      category = "naming",
      scope = data.scope, -- function, variable, class, file
    }
  end

  if event.type == "style_pattern" then
    return {
      summary = "Style: " .. (data.name or "unnamed"),
      detail = data.description or "Code style pattern",
      rule = data.rule,
      examples = data.examples,
      category = "style",
      lang = data.language,
    }
  end

  if event.type == "project_structure" then
    return {
      summary = "Structure: " .. (data.pattern or "project layout"),
      detail = data.description or "Project structure convention",
      rule = data.rule,
      category = "structure",
      paths = data.paths,
    }
  end

  if event.type == "config_change" then
    return {
      summary = "Config: " .. (data.setting or "setting change"),
      -- Same guard: description and setting may both be absent.
      detail = "Configuration: " .. (data.description or data.setting or "setting change"),
      before = data.before,
      after = data.after,
      category = "config",
      file = event.file,
    }
  end

  return nil
end
|
||||
|
||||
--- Check if convention should be learned
--- Requires a summary and a minimally descriptive detail (>= 5 chars);
--- anything vaguer is skipped.
---@param data table Extracted data
---@return boolean
function M.should_learn(data)
  return data.summary ~= nil and data.detail ~= nil and #data.detail >= 5
end
|
||||
|
||||
--- Create node from convention data
--- Assembles the detail text (base detail + examples + rule) and wraps it
--- in the node-creation parameter shape.
---@param data table Extracted data
---@return table Node creation params
function M.create_node_params(data)
  -- Build the detail text in a buffer instead of repeated concatenation.
  local parts = { data.detail or "" }

  if data.examples and #data.examples > 0 then
    parts[#parts + 1] = "\n\nExamples:"
    for _, ex in ipairs(data.examples) do
      parts[#parts + 1] = "\n- " .. tostring(ex)
    end
  end

  if data.rule then
    parts[#parts + 1] = "\n\nRule: " .. tostring(data.rule)
  end

  return {
    node_type = types.NODE_TYPES.CONVENTION,
    content = {
      s = data.summary:sub(1, 200),
      d = table.concat(parts),
      lang = data.lang,
    },
    context = {
      f = data.file,
      sym = data.scope and { data.scope } or nil,
    },
    opts = {
      weight = 0.6,
      source = types.SOURCES.AUTO,
    },
  }
end
|
||||
|
||||
--- Find related conventions
--- Collects ids of conventions sharing this one's category, plus patterns
--- that match its rule text.
---@param data table Extracted data
---@param query_fn function Query function; takes query opts, returns a node list
---@return string[] Related node IDs (deduplicated)
function M.find_related(data, query_fn)
  local related = {}

  -- Find conventions in same category
  if data.category then
    local similar = query_fn({
      query = data.category,
      types = { types.NODE_TYPES.CONVENTION },
      limit = 5,
    })
    for _, node in ipairs(similar) do
      table.insert(related, node.id)
    end
  end

  -- Find patterns that follow this convention
  if data.rule then
    local patterns = query_fn({
      query = data.rule,
      types = { types.NODE_TYPES.PATTERN },
      limit = 3,
    })
    for _, node in ipairs(patterns) do
      -- Skip ids already collected by the category query above.
      if not vim.tbl_contains(related, node.id) then
        table.insert(related, node.id)
      end
    end
  end

  return related
end
|
||||
|
||||
--- Detect naming convention from symbol names
--- Classifies each symbol against an ordered matcher chain (first match
--- wins, so "foo" counts as snake_case rather than camelCase, and an
--- all-caps name like "ABC" counts as PascalCase before SCREAMING_SNAKE),
--- then reports the dominant style when it covers at least 60% of symbols.
---@param symbols string[] Symbol names to analyze
---@return table|nil Detected convention ({ pattern, confidence, sample_size })
function M.detect_naming(symbols)
  -- Need a minimal sample before claiming a convention.
  if not symbols or #symbols < 3 then
    return nil
  end

  -- Ordered: earlier matchers take precedence over later ones.
  local matchers = {
    { "snake_case", "^[a-z][a-z0-9_]*$" },
    { "camelCase", "^[a-z][a-zA-Z0-9]*$" },
    { "PascalCase", "^[A-Z][a-zA-Z0-9]*$" },
    { "SCREAMING_SNAKE", "^[A-Z][A-Z0-9_]*$" },
    { "kebab_case", "^[a-z][a-z0-9%-]*$" },
  }

  local counts = {
    snake_case = 0,
    camelCase = 0,
    PascalCase = 0,
    SCREAMING_SNAKE = 0,
    kebab_case = 0,
  }

  for _, name in ipairs(symbols) do
    for _, m in ipairs(matchers) do
      if name:match(m[2]) then
        counts[m[1]] = counts[m[1]] + 1
        break
      end
    end
  end

  -- Pick the most frequent style.
  local best, best_count = nil, 0
  for style, n in pairs(counts) do
    if n > best_count then
      best, best_count = style, n
    end
  end

  -- Only report when the dominant style covers >= 60% of the sample.
  if best and best_count >= #symbols * 0.6 then
    return {
      pattern = best,
      confidence = best_count / #symbols,
      sample_size = #symbols,
    }
  end

  return nil
end
|
||||
|
||||
return M
|
||||
213
lua/codetyper/brain/learners/correction.lua
Normal file
213
lua/codetyper/brain/learners/correction.lua
Normal file
@@ -0,0 +1,213 @@
|
||||
--- Brain Correction Learner
|
||||
--- Learns from user corrections and edits
|
||||
|
||||
local types = require("codetyper.brain.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Detect if event is a correction
--- Membership test against the set of event types this learner handles.
---@param event LearnEvent Learning event
---@return boolean
function M.detect(event)
  local correction_types = {
    user_correction = true,
    code_rejected = true,
    code_modified = true,
    suggestion_rejected = true,
  }
  return correction_types[event.type] == true
end
|
||||
|
||||
--- Extract correction data from event
--- Returns a normalized data table per event type, or nil when the event
--- carries nothing this learner can use.
---@param event LearnEvent Learning event
---@return table|nil Extracted data
function M.extract(event)
  local data = event.data or {}

  -- Explicit user edit of generated code: keep both sides of the edit.
  if event.type == "user_correction" then
    return {
      summary = "Correction: " .. (data.error_type or "user edit"),
      detail = data.description or "User corrected the generated code",
      before = data.before,
      after = data.after,
      error_type = data.error_type,
      file = event.file,
      function_name = data.function_name,
      lines = data.lines,
    }
  end

  -- Outright rejection: record the rejected code and the stated reason.
  if event.type == "code_rejected" then
    return {
      summary = "Rejected: " .. (data.reason or "not accepted"),
      detail = data.description or "User rejected generated code",
      rejected_code = data.code,
      reason = data.reason,
      file = event.file,
      intent = data.intent,
    }
  end

  -- Post-hoc modification: summarize via the line-diff analyzer.
  if event.type == "code_modified" then
    local changes = M.analyze_changes(data.before, data.after)
    return {
      summary = "Modified: " .. changes.summary,
      detail = changes.detail,
      before = data.before,
      after = data.after,
      change_type = changes.type,
      file = event.file,
      lines = data.lines,
    }
  end

  -- NOTE(review): "suggestion_rejected" passes detect() but has no branch
  -- here, so such events yield nil — confirm whether that is intentional.
  return nil
end
|
||||
|
||||
--- Analyze changes between before/after code
--- Performs a simple positional line diff (line i of `before` vs line i of
--- `after`); it does not detect moved or reordered lines.
---@param before string Before code (nil treated as "")
---@param after string After code (nil treated as "")
---@return table Change analysis: { type, summary, detail, stats = { added, removed, modified } }
function M.analyze_changes(before, after)
  before = before or ""
  after = after or ""

  local before_lines = vim.split(before, "\n")
  local after_lines = vim.split(after, "\n")

  local added = 0
  local removed = 0
  local modified = 0

  -- Simple line-based diff: compare line i to line i.
  local max_lines = math.max(#before_lines, #after_lines)
  for i = 1, max_lines do
    local b = before_lines[i]
    local a = after_lines[i]

    if b == nil and a ~= nil then
      added = added + 1
    elseif b ~= nil and a == nil then
      removed = removed + 1
    elseif b ~= a then
      modified = modified + 1
    end
  end

  -- Classify: a single kind of change gets a specific label; a combination
  -- is "mixed". Fix: identical inputs are now reported as "none" instead of
  -- falling through to the misleading "mixed" default.
  local change_type
  if added == 0 and removed == 0 and modified == 0 then
    change_type = "none"
  elseif added > 0 and removed == 0 and modified == 0 then
    change_type = "addition"
  elseif removed > 0 and added == 0 and modified == 0 then
    change_type = "deletion"
  elseif modified > 0 and added == 0 and removed == 0 then
    change_type = "modification"
  else
    change_type = "mixed"
  end

  return {
    type = change_type,
    summary = string.format("+%d -%d ~%d lines", added, removed, modified),
    detail = string.format("Added %d, removed %d, modified %d lines", added, removed, modified),
    stats = {
      added = added,
      removed = removed,
      modified = modified,
    },
  }
end
|
||||
|
||||
--- Check if correction should be learned
--- Corrections are generally valuable; only unusable (no summary) or
--- trivial (whitespace-only) edits are dropped.
---@param data table Extracted data
---@return boolean
function M.should_learn(data)
  if not data.summary then
    return false
  end

  -- Whitespace-only edits carry no signal: compare both sides with all
  -- whitespace stripped and skip when they are identical.
  if data.before and data.after then
    local strip = function(s)
      return (s:gsub("%s+", ""))
    end
    if strip(data.before) == strip(data.after) then
      return false
    end
  end

  return true
end
|
||||
|
||||
--- Create node from correction data
--- Builds the CORRECTION node payload; before/after snippets are embedded
--- in the detail text so the edit itself is part of the learning.
---@param data table Extracted data
---@return table Node creation params ({ node_type, content, context, opts })
function M.create_node_params(data)
  local detail = data.detail or ""

  -- Include before/after in detail for learning (each side capped at 500
  -- chars to bound node size).
  if data.before and data.after then
    detail = detail .. "\n\nBefore:\n" .. data.before:sub(1, 500)
    detail = detail .. "\n\nAfter:\n" .. data.after:sub(1, 500)
  end

  return {
    node_type = types.NODE_TYPES.CORRECTION,
    content = {
      s = data.summary:sub(1, 200), -- summary capped at 200 chars
      d = detail,
      -- Prefer the corrected code; fall back to rejected code for rejections.
      code = data.after or data.rejected_code,
      -- NOTE(review): extract() never sets `lang` for corrections — confirm
      -- whether this field is ever populated.
      lang = data.lang,
    },
    context = {
      f = data.file,
      fn = data.function_name,
      ln = data.lines,
    },
    opts = {
      weight = 0.7, -- Corrections are valuable
      source = types.SOURCES.USER,
    },
  }
end
|
||||
|
||||
--- Find related nodes for corrections
--- Looks up patterns resembling the corrected code and other corrections
--- recorded for the same file.
---@param data table Extracted data
---@param query_fn function Query function; takes query opts, returns a node list
---@return string[] Related node IDs (deduplicated)
function M.find_related(data, query_fn)
  local related = {}
  local seen = {}

  -- Collect ids, skipping duplicates. The sibling learners (pattern,
  -- convention) deduplicate with vim.tbl_contains; this learner previously
  -- could return the same id twice when a node matched both queries below.
  local function add(nodes)
    for _, node in ipairs(nodes) do
      if not seen[node.id] then
        seen[node.id] = true
        table.insert(related, node.id)
      end
    end
  end

  -- Patterns that resemble the code being corrected.
  if data.before then
    add(query_fn({
      query = data.before:sub(1, 100),
      types = { types.NODE_TYPES.PATTERN },
      limit = 3,
    }))
  end

  -- Other corrections already recorded in the same file.
  if data.file then
    add(query_fn({
      file = data.file,
      types = { types.NODE_TYPES.CORRECTION },
      limit = 3,
    }))
  end

  return related
end
|
||||
|
||||
return M
|
||||
232
lua/codetyper/brain/learners/init.lua
Normal file
232
lua/codetyper/brain/learners/init.lua
Normal file
@@ -0,0 +1,232 @@
|
||||
--- Brain Learners Coordinator
|
||||
--- Routes learning events to appropriate learners
|
||||
|
||||
local types = require("codetyper.brain.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
-- Learner modules are required lazily so that merely loading this
-- coordinator does not pull in every learner.

--- All available learners, in detection priority order (first detect()
--- match wins in M.process).
local LEARNERS = {
  {
    name = "pattern",
    loader = function()
      return require("codetyper.brain.learners.pattern")
    end,
  },
  {
    name = "correction",
    loader = function()
      return require("codetyper.brain.learners.correction")
    end,
  },
  {
    name = "convention",
    loader = function()
      return require("codetyper.brain.learners.convention")
    end,
  },
}
|
||||
|
||||
--- Process a learning event
--- Routes the event to the first learner whose detect() accepts it; falls
--- back to the generic feedback/session handlers.
---@param event LearnEvent Learning event
---@return string|nil Created node ID (nil when nothing was learned)
function M.process(event)
  if not event or not event.type then
    return nil
  end

  -- Add timestamp if missing
  event.timestamp = event.timestamp or os.time()

  -- Find matching learner (LEARNERS order is priority; first match wins).
  for _, learner_info in ipairs(LEARNERS) do
    local learner = learner_info.loader()

    if learner.detect(event) then
      return M.learn_with(learner, event)
    end
  end

  -- Handle generic feedback events
  if event.type == "user_feedback" then
    return M.process_feedback(event)
  end

  -- Handle session events
  if event.type == "session_start" or event.type == "session_end" then
    return M.process_session(event)
  end

  return nil
end
|
||||
|
||||
--- Learn using a specific learner
--- Extracts data from the event and creates one learning per extraction.
---@param learner table Learner module (extract/should_learn/create_node_params)
---@param event LearnEvent Learning event
---@return string|nil Created node ID (first id when multiple were created)
function M.learn_with(learner, event)
  -- Extract data
  local extracted = learner.extract(event)
  if not extracted then
    return nil
  end

  -- Handle multiple extractions (e.g., from file indexing)
  -- NOTE(review): vim.islist requires Neovim 0.10+ (older releases use
  -- vim.tbl_islist) — confirm against the plugin's minimum version.
  if vim.islist(extracted) then
    local node_ids = {}
    for _, data in ipairs(extracted) do
      local node_id = M.create_learning(learner, data, event)
      if node_id then
        table.insert(node_ids, node_id)
      end
    end
    return node_ids[1] -- Return first for now
  end

  return M.create_learning(learner, extracted, event)
end
|
||||
|
||||
--- Create a learning from extracted data
--- Gates on should_learn, gathers related nodes, inserts the node into the
--- graph, and applies the learner's requested initial weight.
---@param learner table Learner module
---@param data table Extracted data
---@param event LearnEvent Original event
---  NOTE(review): `event` is currently unused in this function — confirm
---  whether it is kept only for interface symmetry.
---@return string|nil Created node ID (nil when the learner rejects the data)
function M.create_learning(learner, data, event)
  -- Check if should learn
  if not learner.should_learn(data) then
    return nil
  end

  -- Get node params
  local params = learner.create_node_params(data)

  -- Get graph module
  local graph = require("codetyper.brain.graph")

  -- Find related nodes (optional learner capability)
  local related_ids = {}
  if learner.find_related then
    related_ids = learner.find_related(data, function(opts)
      return graph.query.execute(opts).nodes
    end)
  end

  -- Create the learning
  local node = graph.add_learning(params.node_type, params.content, params.context, related_ids)

  -- Update weight if specified
  if params.opts and params.opts.weight then
    graph.node.update(node.id, { sc = { w = params.opts.weight } })
  end

  return node.id
end
|
||||
|
||||
--- Process feedback event
--- When the feedback references an existing node, nudges that node's weight
--- (+0.1 on "accepted", -0.1 otherwise, clamped to [0, 1]), records usage,
--- and creates a FEEDBACK node linked to it. Otherwise a standalone
--- FEEDBACK node is created.
---@param event LearnEvent Feedback event
---@return string|nil Created node ID
function M.process_feedback(event)
  local data = event.data or {}
  local graph = require("codetyper.brain.graph")

  local content = {
    s = "Feedback: " .. (data.feedback or "unknown"),
    d = data.description or ("User " .. (data.feedback or "gave feedback")),
  }

  local context = {
    f = event.file,
  }

  -- If feedback references a node, update it
  if data.node_id then
    local node = graph.node.get(data.node_id)
    if node then
      -- Positive feedback nudges the weight up, anything else down.
      local weight_delta = data.feedback == "accepted" and 0.1 or -0.1
      -- NOTE(review): assumes node.sc.w is always set — confirm against the
      -- graph module's node schema.
      local new_weight = math.max(0, math.min(1, node.sc.w + weight_delta))

      graph.node.update(data.node_id, {
        sc = { w = new_weight },
      })

      -- Record usage
      graph.node.record_usage(data.node_id, data.feedback == "accepted")

      -- Create feedback node linked to original
      local fb_node = graph.add_learning(types.NODE_TYPES.FEEDBACK, content, context, { data.node_id })

      return fb_node.id
    end
  end

  -- Create standalone feedback node (also reached when the referenced
  -- node_id no longer exists in the graph).
  local node = graph.add_learning(types.NODE_TYPES.FEEDBACK, content, context)
  return node.id
end
|
||||
|
||||
--- Process session event
--- Records a SESSION node; on session_end it appends session statistics to
--- the description and temporally links the last hour's nodes together.
---@param event LearnEvent Session event (type "session_start" or "session_end")
---@return string|nil Created node ID
function M.process_session(event)
  local data = event.data or {}
  local graph = require("codetyper.brain.graph")

  local content = {
    s = event.type == "session_start" and "Session started" or "Session ended",
    d = data.description or event.type,
  }

  -- Append session statistics to the description when available.
  if event.type == "session_end" and data.stats then
    content.d = content.d .. "\n\nStats:"
    content.d = content.d .. "\n- Completions: " .. (data.stats.completions or 0)
    content.d = content.d .. "\n- Corrections: " .. (data.stats.corrections or 0)
    content.d = content.d .. "\n- Files: " .. (data.stats.files or 0)
  end

  local node = graph.add_learning(types.NODE_TYPES.SESSION, content, {})

  -- Link to recent session nodes
  if event.type == "session_end" then
    local recent = graph.query.by_time_range(os.time() - 3600, os.time(), 20) -- Last hour
    local session_nodes = {}

    for _, n in ipairs(recent) do
      -- Exclude the node just created for this event.
      if n.id ~= node.id then
        table.insert(session_nodes, n.id)
      end
    end

    -- Create temporal links
    if #session_nodes > 0 then
      graph.link_temporal(session_nodes)
    end
  end

  return node.id
end
|
||||
|
||||
--- Batch process multiple events
--- Runs M.process over each event, collecting only the ids of events that
--- actually produced a learning.
---@param events LearnEvent[] Events to process
---@return string[] Created node IDs
function M.batch_process(events)
  local ids = {}

  for _, ev in ipairs(events) do
    local id = M.process(ev)
    if id then
      ids[#ids + 1] = id
    end
  end

  return ids
end
|
||||
|
||||
--- Get learner names
---@return string[] Registered learner names, in priority order
function M.get_learner_names()
  local names = {}
  for i, learner in ipairs(LEARNERS) do
    names[i] = learner.name
  end
  return names
end
|
||||
|
||||
return M
|
||||
176
lua/codetyper/brain/learners/pattern.lua
Normal file
176
lua/codetyper/brain/learners/pattern.lua
Normal file
@@ -0,0 +1,176 @@
|
||||
--- Brain Pattern Learner
|
||||
--- Detects and learns code patterns
|
||||
|
||||
local types = require("codetyper.brain.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Detect if event contains a learnable pattern
--- Membership test against the set of event types this learner handles.
---@param event LearnEvent Learning event
---@return boolean
function M.detect(event)
  if not event or not event.type then
    return false
  end

  local pattern_events = {
    code_completion = true,
    file_indexed = true,
    code_analyzed = true,
    pattern_detected = true,
  }

  return pattern_events[event.type] == true
end
|
||||
|
||||
--- Extract pattern data from event
--- Returns a single extraction table, a list of extractions (file_indexed),
--- or nil when the event carries nothing learnable.
---@param event LearnEvent Learning event
---@return table|table[]|nil Extracted data
function M.extract(event)
  local data = event.data or {}

  -- Extract from code completion: one pattern per accepted completion.
  if event.type == "code_completion" then
    return {
      summary = "Code pattern: " .. (data.intent or "unknown"),
      detail = data.code or data.content or "",
      code = data.code,
      lang = data.language,
      file = event.file,
      function_name = data.function_name,
      symbols = data.symbols,
    }
  end

  -- Extract from file indexing: one pattern per function and per class.
  if event.type == "file_indexed" then
    local patterns = {}

    -- Extract function patterns
    if data.functions then
      for _, func in ipairs(data.functions) do
        table.insert(patterns, {
          summary = "Function: " .. func.name,
          detail = func.signature or func.name,
          code = func.body,
          lang = data.language,
          file = event.file,
          function_name = func.name,
          lines = func.lines,
        })
      end
    end

    -- Extract class patterns
    if data.classes then
      for _, class in ipairs(data.classes) do
        table.insert(patterns, {
          summary = "Class: " .. class.name,
          detail = class.description or class.name,
          lang = data.language,
          file = event.file,
          symbols = { class.name },
        })
      end
    end

    return #patterns > 0 and patterns or nil
  end

  -- Extract from explicit pattern detection
  if event.type == "pattern_detected" then
    return {
      summary = data.name or "Unnamed pattern",
      detail = data.description or data.name or "",
      code = data.example,
      lang = data.language,
      file = event.file,
      symbols = data.symbols,
    }
  end

  -- NOTE(review): "code_analyzed" passes detect() but has no branch here,
  -- so such events yield nil — confirm whether that is intentional.
  return nil
end
|
||||
|
||||
--- Check if pattern should be learned
--- Rejects patterns with missing/empty/whitespace-only summaries and
--- patterns whose detail text is too short to be useful.
---@param data table Extracted data
---@return boolean
function M.should_learn(data)
  local summary = data.summary

  -- Reject missing, empty, or whitespace-only summaries.
  if not summary or summary == "" or summary:match("^%s*$") then
    return false
  end

  -- Reject patterns whose detail is present but under 10 characters.
  if data.detail and #data.detail < 10 then
    return false
  end

  return true
end
|
||||
|
||||
--- Create node from pattern data
--- Builds the PATTERN node payload with capped summary and full context.
---@param data table Extracted data
---@return table Node creation params ({ node_type, content, context, opts })
function M.create_node_params(data)
  return {
    node_type = types.NODE_TYPES.PATTERN,
    content = {
      s = data.summary:sub(1, 200), -- Limit summary
      d = data.detail,
      code = data.code,
      lang = data.lang,
    },
    context = {
      f = data.file,
      fn = data.function_name,
      ln = data.lines,
      sym = data.symbols,
    },
    opts = {
      -- Auto-detected patterns start at a neutral 0.5 weight.
      weight = 0.5,
      source = types.SOURCES.AUTO,
    },
  }
end
|
||||
|
||||
--- Find potentially related nodes
--- Collects nodes from the same file plus semantically similar nodes.
---@param data table Extracted data
---@param query_fn function Query function; takes query opts, returns a node list
---@return string[] Related node IDs (deduplicated)
function M.find_related(data, query_fn)
  local related = {}

  -- Find nodes in same file
  if data.file then
    local file_nodes = query_fn({ file = data.file, limit = 5 })
    for _, node in ipairs(file_nodes) do
      table.insert(related, node.id)
    end
  end

  -- Find semantically similar
  if data.summary then
    local similar = query_fn({ query = data.summary, limit = 3 })
    for _, node in ipairs(similar) do
      -- Skip ids already collected from the same-file query.
      if not vim.tbl_contains(related, node.id) then
        table.insert(related, node.id)
      end
    end
  end

  return related
end
|
||||
|
||||
return M
|
||||
279
lua/codetyper/brain/output/formatter.lua
Normal file
279
lua/codetyper/brain/output/formatter.lua
Normal file
@@ -0,0 +1,279 @@
|
||||
--- Brain Output Formatter
|
||||
--- LLM-optimized output formatting
|
||||
|
||||
local types = require("codetyper.brain.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Estimate token count (rough approximation)
--- Uses the common heuristic of roughly one token per 4 characters.
---@param text string Text to estimate (nil yields 0)
---@return number Estimated tokens
function M.estimate_tokens(text)
  if not text then
    return 0
  end
  return math.ceil(#text * 0.25)
end
|
||||
|
||||
--- Format nodes to compact text format
--- Emits a "---BRAIN_CONTEXT---" block with one line per node, optional
--- file-context sublines, and edge connections — stopping when the token
--- budget is nearly exhausted (budget margins: 100 for nodes, 200 to start
--- the connections section, 50 per connection line).
---@param result QueryResult Query result (nodes assumed pre-sorted by relevance)
---@param opts? table Options (query: echoed in header; max_tokens: budget, default 4000)
---@return string Formatted output
function M.to_compact(result, opts)
  opts = opts or {}
  local max_tokens = opts.max_tokens or 4000
  local lines = {}
  local current_tokens = 0

  -- Header
  table.insert(lines, "---BRAIN_CONTEXT---")
  if opts.query then
    table.insert(lines, "Q: " .. opts.query)
  end
  table.insert(lines, "")

  -- Add nodes by relevance (already sorted)
  table.insert(lines, "Learnings:")

  for i, node in ipairs(result.nodes) do
    -- Format: [idx] TYPE | w:0.85 u:5 | Summary
    local line = string.format(
      "[%d] %s | w:%.2f u:%d | %s",
      i,
      (node.t or "?"):upper(),
      node.sc.w or 0,
      node.sc.u or 0,
      (node.c.s or ""):sub(1, 100)
    )

    local line_tokens = M.estimate_tokens(line)
    if current_tokens + line_tokens > max_tokens - 100 then
      table.insert(lines, "... (truncated)")
      break
    end

    table.insert(lines, line)
    current_tokens = current_tokens + line_tokens

    -- Add context if file-related
    if node.ctx and node.ctx.f then
      local ctx_line = " @ " .. node.ctx.f
      if node.ctx.fn then
        ctx_line = ctx_line .. ":" .. node.ctx.fn
      end
      -- NOTE(review): assumes ctx.ln is a list whose first entry is the
      -- start line — confirm against the graph node schema.
      if node.ctx.ln then
        ctx_line = ctx_line .. " L" .. node.ctx.ln[1]
      end
      table.insert(lines, ctx_line)
      current_tokens = current_tokens + M.estimate_tokens(ctx_line)
    end
  end

  -- Add connections if space allows
  if #result.edges > 0 and current_tokens < max_tokens - 200 then
    table.insert(lines, "")
    table.insert(lines, "Connections:")

    for _, edge in ipairs(result.edges) do
      if current_tokens >= max_tokens - 50 then
        break
      end

      -- Node ids are shown abbreviated to their last 8 characters.
      local conn_line = string.format(
        " %s --%s(%.2f)--> %s",
        edge.s:sub(-8),
        edge.ty,
        edge.p.w or 0.5,
        edge.t:sub(-8)
      )
      table.insert(lines, conn_line)
      current_tokens = current_tokens + M.estimate_tokens(conn_line)
    end
  end

  table.insert(lines, "---END_CONTEXT---")

  return table.concat(lines, "\n")
end
|
||||
|
||||
--- Format nodes to JSON format
--- Emits a compact "brain-v1" JSON document with abbreviated keys
--- (l = learnings, c = connections), trimming entries to the token budget.
---@param result QueryResult Query result
---@param opts? table Options (query; max_tokens: budget, default 4000)
---@return string JSON output
function M.to_json(result, opts)
  opts = opts or {}
  local max_tokens = opts.max_tokens or 4000

  local output = {
    _s = "brain-v1", -- Schema
    q = opts.query,
    l = {}, -- Learnings
    c = {}, -- Connections
  }

  local current_tokens = 50 -- Base overhead

  -- Add nodes
  for _, node in ipairs(result.nodes) do
    local entry = {
      t = node.t,
      s = (node.c.s or ""):sub(1, 150),
      w = node.sc.w,
      u = node.sc.u,
    }

    if node.ctx and node.ctx.f then
      entry.f = node.ctx.f
    end

    -- Budget check uses the actual encoded size of this entry.
    local entry_tokens = M.estimate_tokens(vim.json.encode(entry))
    if current_tokens + entry_tokens > max_tokens - 100 then
      break
    end

    table.insert(output.l, entry)
    current_tokens = current_tokens + entry_tokens
  end

  -- Add edges if space
  if current_tokens < max_tokens - 200 then
    for _, edge in ipairs(result.edges) do
      if current_tokens >= max_tokens - 50 then
        break
      end

      -- Node ids abbreviated to their last 8 characters.
      local e = {
        s = edge.s:sub(-8),
        t = edge.t:sub(-8),
        r = edge.ty,
        w = edge.p.w,
      }

      table.insert(output.c, e)
      -- Flat per-edge token estimate (edges are small and uniform).
      current_tokens = current_tokens + 30
    end
  end

  return vim.json.encode(output)
end
|
||||
|
||||
--- Format as natural language
--- Groups learnings by node type and renders markdown-ish bullet lists,
--- stopping when the token budget is nearly exhausted.
---@param result QueryResult Query result
---@param opts? table Options (max_tokens: token budget, default 4000)
---@return string Natural language output
function M.to_natural(result, opts)
  opts = opts or {}
  local max_tokens = opts.max_tokens or 4000
  local lines = {}
  local current_tokens = 0

  if #result.nodes == 0 then
    return "No relevant learnings found."
  end

  table.insert(lines, "Based on previous learnings:")
  table.insert(lines, "")

  -- Group by type
  local by_type = {}
  for _, node in ipairs(result.nodes) do
    by_type[node.t] = by_type[node.t] or {}
    table.insert(by_type[node.t], node)
  end

  -- Human-readable section headers per node type.
  local type_names = {
    [types.NODE_TYPES.PATTERN] = "Code Patterns",
    [types.NODE_TYPES.CORRECTION] = "Previous Corrections",
    [types.NODE_TYPES.CONVENTION] = "Project Conventions",
    [types.NODE_TYPES.DECISION] = "Architectural Decisions",
    [types.NODE_TYPES.FEEDBACK] = "User Preferences",
    [types.NODE_TYPES.SESSION] = "Session Context",
  }

  for node_type, nodes in pairs(by_type) do
    local type_name = type_names[node_type] or node_type

    table.insert(lines, "**" .. type_name .. "**")

    for _, node in ipairs(nodes) do
      if current_tokens >= max_tokens - 100 then
        table.insert(lines, "...")
        goto done
      end

      local bullet = string.format("- %s (confidence: %.0f%%)", node.c.s or "?", (node.sc.w or 0) * 100)

      table.insert(lines, bullet)
      current_tokens = current_tokens + M.estimate_tokens(bullet)

      -- Add detail for high-confidence nodes. Fix: default a missing weight
      -- to 0 here, matching the bullet line above — previously a node
      -- without `sc.w` crashed this comparison.
      if (node.sc.w or 0) > 0.7 and node.c.d and #node.c.d > #(node.c.s or "") then
        local detail = " " .. node.c.d:sub(1, 150)
        if #node.c.d > 150 then
          detail = detail .. "..."
        end
        table.insert(lines, detail)
        current_tokens = current_tokens + M.estimate_tokens(detail)
      end
    end

    table.insert(lines, "")
  end

  ::done::

  return table.concat(lines, "\n")
end
|
||||
|
||||
--- Format context chain for explanation
--- Renders an alternating node/edge chain as an arrow-separated trace.
---@param chain table[] Chain entries carrying `.node` and/or `.edge`
---@return string Chain explanation
function M.format_chain(chain)
  local lines = {}

  for i, item in ipairs(chain) do
    if item.node then
      local prefix = i == 1 and "" or " -> "
      -- Guard optional fields (type/summary/weight) like the other
      -- formatters do; previously a node missing any of them crashed here.
      table.insert(
        lines,
        string.format(
          "%s[%s] %s (w:%.2f)",
          prefix,
          (item.node.t or "?"):upper(),
          (item.node.c.s or ""):sub(1, 50),
          item.node.sc.w or 0
        )
      )
    end
    if item.edge then
      -- Edge weight defaults to 0.5 for display, matching to_compact.
      table.insert(lines, string.format(" via %s (w:%.2f)", item.edge.ty, item.edge.p.w or 0.5))
    end
  end

  return table.concat(lines, "\n")
end
|
||||
|
||||
--- Compress output to fit token budget
--- Returns the text untouched when it fits; otherwise truncates with an
--- ellipsis marker.
---@param text string Text to compress
---@param max_tokens number Token budget
---@return string Compressed text
function M.compress(text, max_tokens)
  local estimated = M.estimate_tokens(text)
  if estimated <= max_tokens then
    return text
  end

  -- Scale the character count by the token ratio, keeping a 10% safety
  -- margin since the token estimate is only approximate.
  local keep = math.floor(#text * (max_tokens / estimated) * 0.9)
  return text:sub(1, keep) .. "\n...(truncated)"
end
|
||||
|
||||
--- Get minimal context for quick lookups
--- Produces "type:summary" pairs (summary capped at 40 chars) joined by " | ".
---@param nodes Node[] Nodes to format
---@return string Minimal context
function M.minimal(nodes)
  local parts = {}
  for i, node in ipairs(nodes) do
    parts[i] = string.format("%s:%s", node.t, (node.c.s or ""):sub(1, 40))
  end
  return table.concat(parts, " | ")
end
|
||||
|
||||
return M
|
||||
166
lua/codetyper/brain/output/init.lua
Normal file
166
lua/codetyper/brain/output/init.lua
Normal file
@@ -0,0 +1,166 @@
|
||||
--- Brain Output Coordinator
|
||||
--- Manages LLM context generation
|
||||
|
||||
local formatter = require("codetyper.brain.output.formatter")
|
||||
|
||||
local M = {}
|
||||
|
||||
-- Re-export formatter
|
||||
M.formatter = formatter
|
||||
|
||||
--- Default token budget
|
||||
local DEFAULT_MAX_TOKENS = 4000
|
||||
|
||||
--- Generate context for LLM prompt
--- Queries the brain graph and renders the result in the requested format.
---@param opts? table Options (query, file, types, since, limit, depth,
---  max_tokens, format: "compact"|"json"|"natural")
---@return string Context string ("" when the brain is uninitialized or the
---  query matches nothing)
function M.generate(opts)
  opts = opts or {}

  local brain = require("codetyper.brain")
  if not brain.is_initialized() then
    return ""
  end

  -- Build query opts (defaults: 30 nodes, traversal depth 2)
  local query_opts = {
    query = opts.query,
    file = opts.file,
    types = opts.types,
    since = opts.since,
    limit = opts.limit or 30,
    depth = opts.depth or 2,
    max_tokens = opts.max_tokens or DEFAULT_MAX_TOKENS,
  }

  -- Execute query
  local result = brain.query(query_opts)

  if #result.nodes == 0 then
    return ""
  end

  -- Format based on style (defaults to the compact text form)
  local format = opts.format or "compact"

  if format == "json" then
    return formatter.to_json(result, query_opts)
  elseif format == "natural" then
    return formatter.to_natural(result, query_opts)
  else
    return formatter.to_compact(result, query_opts)
  end
end
|
||||
|
||||
--- Generate context for a specific file
---@param filepath string File path
---@param opts? table Options (not modified)
---@return string Context string
function M.for_file(filepath, opts)
  -- Shallow-copy the options so the caller's table is not mutated
  -- (previously `opts.file` was written into the caller-owned table).
  local query_opts = {}
  if opts then
    for k, v in pairs(opts) do
      query_opts[k] = v
    end
  end
  query_opts.file = filepath
  return M.generate(query_opts)
end
|
||||
|
||||
--- Generate context for the file loaded in the current buffer.
---@param opts? table Options (see M.generate)
---@return string Context string ("" for unnamed buffers)
function M.for_current_buffer(opts)
  local path = vim.fn.expand("%:p")
  if path == "" then
    return ""
  end
  return M.for_file(path, opts)
end
|
||||
|
||||
--- Generate context relevant to a free-text query.
---@param query string Query text
---@param opts? table Options (see M.generate)
---@return string Context string
function M.for_query(query, opts)
  local merged = opts or {}
  merged.query = query
  return M.generate(merged)
end
|
||||
|
||||
--- Build a system-prompt preamble from learned project context.
---@param opts? table Options (see M.generate); limit defaults to 20, format to "compact"
---@return string System context ("" when no context is available)
function M.system_context(opts)
  opts = opts or {}
  opts.limit = opts.limit or 20
  opts.format = opts.format or "compact"

  local context = M.generate(opts)
  if context == "" then
    return ""
  end

  -- Wrap the raw context with usage instructions for the LLM.
  return [[
The following context contains learned patterns and conventions from this project:

]] .. context .. [[


Use this context to inform your responses, following established patterns and conventions.
]]
end
|
||||
|
||||
--- Build context for code completion near the cursor.
---@param prefix string Code before cursor
---@param suffix string Code after cursor (currently unused by this implementation)
---@param filepath string Current file
---@return string Context
function M.for_completion(prefix, suffix, filepath)
  -- Harvest likely identifiers from the code before the cursor:
  -- CamelCase words plus declared function names.
  local terms = {}
  for word in prefix:gmatch("[A-Z][a-zA-Z0-9]+") do
    terms[#terms + 1] = word
  end
  for word in prefix:gmatch("function%s+([a-zA-Z_][a-zA-Z0-9_]*)") do
    terms[#terms + 1] = word
  end

  return M.generate({
    query = table.concat(terms, " "),
    file = filepath,
    limit = 15,
    max_tokens = 2000,
    format = "compact",
  })
end
|
||||
|
||||
--- Whether any learned context is available for this project.
---@return boolean
function M.has_context()
  local brain = require("codetyper.brain")
  if not brain.is_initialized() then
    return false
  end
  return brain.stats().node_count > 0
end
|
||||
|
||||
--- Summarize brain availability and graph size.
---@return table Stats ({ available = false } when the brain is uninitialized)
function M.stats()
  local brain = require("codetyper.brain")
  if not brain.is_initialized() then
    return { available = false }
  end
  local s = brain.stats()
  return {
    available = true,
    node_count = s.node_count,
    edge_count = s.edge_count,
  }
end
|
||||
|
||||
return M
|
||||
338
lua/codetyper/brain/storage.lua
Normal file
338
lua/codetyper/brain/storage.lua
Normal file
@@ -0,0 +1,338 @@
|
||||
--- Brain Storage Layer
|
||||
--- Cache + disk persistence with lazy loading
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
local types = require("codetyper.brain.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- In-memory cache keyed by project root
|
||||
---@type table<string, table>
|
||||
local cache = {}
|
||||
|
||||
--- Dirty flags for pending writes
|
||||
---@type table<string, table<string, boolean>>
|
||||
local dirty = {}
|
||||
|
||||
--- Debounce timers
|
||||
---@type table<string, userdata>
|
||||
local timers = {}
|
||||
|
||||
local DEBOUNCE_MS = 500
|
||||
|
||||
--- Path of the brain directory for a project.
---@param root? string Project root (defaults to the current project)
---@return string Brain directory path
function M.get_brain_dir(root)
  return (root or utils.get_project_root()) .. "/.coder/brain"
end
|
||||
|
||||
--- Create the brain directory tree if it does not exist.
---@param root? string Project root
---@return boolean Success (false as soon as one mkdir fails)
function M.ensure_dirs(root)
  local base = M.get_brain_dir(root)
  -- Order matters: parents before children.
  local suffixes = { "", "/nodes", "/indices", "/deltas", "/deltas/objects" }
  for _, suffix in ipairs(suffixes) do
    if not utils.ensure_dir(base .. suffix) then
      return false
    end
  end
  return true
end
|
||||
|
||||
--- Resolve a storage key to its JSON file path.
--- Dots in the key map to directory separators, e.g.
--- "nodes.patterns" -> "<brain>/nodes/patterns.json".
---@param key string Storage key (e.g., "meta", "nodes.patterns", "deltas.objects.abc123")
---@param root? string Project root
---@return string File path
function M.get_path(key, root)
  local brain_dir = M.get_brain_dir(root)
  -- plain=true: treat "." literally, not as a Lua pattern.
  local segments = vim.split(key, ".", { plain = true })
  return brain_dir .. "/" .. table.concat(segments, "/") .. ".json"
end
|
||||
|
||||
--- Fetch (lazily creating) the in-memory cache table for a project.
--- Also initializes the project's dirty-flag table on first access.
---@param root? string Project root
---@return table Project cache
local function get_cache(root)
  root = root or utils.get_project_root()
  local existing = cache[root]
  if existing then
    return existing
  end
  cache[root] = {}
  dirty[root] = {}
  return cache[root]
end
|
||||
|
||||
--- Decode a JSON file from disk.
---@param filepath string File path
---@return table|nil Decoded data, or nil for a missing/empty/invalid file
local function read_json(filepath)
  local raw = utils.read_file(filepath)
  if raw == nil or raw == "" then
    return nil
  end
  local ok, decoded = pcall(vim.json.decode, raw)
  if not ok then
    return nil
  end
  return decoded
end
|
||||
|
||||
--- Encode a table as JSON and write it to disk.
---@param filepath string File path
---@param data table Data to write
---@return boolean Success (false when encoding or writing fails)
local function write_json(filepath, data)
  local ok, encoded = pcall(vim.json.encode, data)
  if not ok then
    return false
  end
  return utils.write_file(filepath, encoded)
end
|
||||
|
||||
--- Load data for a key, reading from disk on first access.
---@param key string Storage key
---@param root? string Project root
---@return table|nil Cached or loaded data (an empty table when nothing is on disk)
function M.load(key, root)
  root = root or utils.get_project_root()
  local project_cache = get_cache(root)

  local hit = project_cache[key]
  if hit ~= nil then
    return hit
  end

  -- Cache miss: read from disk. A failed read is cached as an empty
  -- table so subsequent calls do not hit the filesystem again.
  local loaded = read_json(M.get_path(key, root)) or {}
  project_cache[key] = loaded
  return loaded
end
|
||||
|
||||
--- Store data in the cache and schedule a (debounced) disk write.
---@param key string Storage key
---@param data table Data to save
---@param root? string Project root
---@param immediate? boolean Write to disk now instead of debouncing
function M.save(key, data, root, immediate)
  root = root or utils.get_project_root()
  get_cache(root)[key] = data
  dirty[root][key] = true

  if immediate then
    M.flush(key, root)
    return
  end

  -- Debounce: restart any pending timer for this root/key pair.
  -- NOTE(review): a stopped timer is not closed here — presumably
  -- acceptable churn; confirm against libuv timer lifetime if it matters.
  local timer_key = root .. ":" .. key
  local pending = timers[timer_key]
  if pending then
    pending:stop()
  end
  timers[timer_key] = vim.defer_fn(function()
    M.flush(key, root)
    timers[timer_key] = nil
  end, DEBOUNCE_MS)
end
|
||||
|
||||
--- Write one dirty key to disk immediately.
---@param key string Storage key
---@param root? string Project root
---@return boolean Success (true when there was nothing to write)
function M.flush(key, root)
  root = root or utils.get_project_root()
  local project_cache = get_cache(root)

  if not dirty[root][key] then
    return true
  end

  M.ensure_dirs(root)
  local filepath = M.get_path(key, root)
  local data = project_cache[key]

  -- A nil cached value means the key was deleted: remove its file.
  if data == nil then
    os.remove(filepath)
    dirty[root][key] = nil
    return true
  end

  local ok = write_json(filepath, data)
  if ok then
    dirty[root][key] = nil
  end
  return ok
end
|
||||
|
||||
--- Flush every dirty key for a project to disk.
---@param root? string Project root
function M.flush_all(root)
  root = root or utils.get_project_root()
  local pending = dirty[root]
  if not pending then
    return
  end
  -- M.flush clears entries as it goes; mutating existing keys during a
  -- pairs() traversal is permitted in Lua (only adding keys is not).
  for key, is_dirty in pairs(pending) do
    if is_dirty then
      M.flush(key, root)
    end
  end
end
|
||||
|
||||
--- Read meta.json, initializing it with defaults when missing or invalid.
---@param root? string Project root
---@return GraphMeta
function M.get_meta(root)
  local meta = M.load("meta", root)
  if meta and meta.v then
    return meta
  end
  -- No (valid) meta on disk: create and persist a fresh one.
  local fresh = {
    v = types.SCHEMA_VERSION,
    head = nil, -- documents intent; nil fields are absent in Lua tables
    nc = 0,
    ec = 0,
    dc = 0,
  }
  M.save("meta", fresh, root)
  return fresh
end
|
||||
|
||||
--- Merge partial updates into meta.json and persist it.
--- Note: nil values cannot be expressed in the updates table (they are
--- simply absent keys in Lua); use a direct save to clear a field.
---@param updates table Partial updates
---@param root? string Project root
function M.update_meta(updates, root)
  local meta = M.get_meta(root)
  for field, value in pairs(updates) do
    meta[field] = value
  end
  M.save("meta", meta, root)
end
|
||||
|
||||
--- Nodes of a given type, indexed by ID.
---@param node_type string Node type (e.g., "patterns", "corrections")
---@param root? string Project root
---@return table<string, Node> Nodes indexed by ID (empty table when none)
function M.get_nodes(node_type, root)
  local nodes = M.load("nodes." .. node_type, root)
  return nodes or {}
end
|
||||
|
||||
--- Persist nodes of a given type.
---@param node_type string Node type
---@param nodes table<string, Node> Nodes indexed by ID
---@param root? string Project root
function M.save_nodes(node_type, nodes, root)
  local key = "nodes." .. node_type
  M.save(key, nodes, root)
end
|
||||
|
||||
--- Load the adjacency graph, creating an empty one when missing or invalid.
---@param root? string Project root
---@return Graph Graph data
function M.get_graph(root)
  local graph = M.load("graph", root)
  if graph and graph.adj then
    return graph
  end
  local fresh = {
    adj = {}, -- forward adjacency
    radj = {}, -- reverse adjacency
  }
  M.save("graph", fresh, root)
  return fresh
end
|
||||
|
||||
--- Persist the adjacency graph.
---@param graph Graph Graph data
---@param root? string Project root
function M.save_graph(graph, root)
  M.save("graph", graph, root)
end
|
||||
|
||||
--- Load an index by type, defaulting to an empty table.
---@param index_type string Index type (e.g., "by_file", "by_time")
---@param root? string Project root
---@return table Index data
function M.get_index(index_type, root)
  local index = M.load("indices." .. index_type, root)
  return index or {}
end
|
||||
|
||||
--- Persist an index by type.
---@param index_type string Index type
---@param data table Index data
---@param root? string Project root
function M.save_index(index_type, data, root)
  local key = "indices." .. index_type
  M.save(key, data, root)
end
|
||||
|
||||
--- Load a delta object by its content hash.
---@param hash string Delta hash
---@param root? string Project root
---@return Delta|nil Delta data
function M.get_delta(hash, root)
  local key = "deltas.objects." .. hash
  return M.load(key, root)
end
|
||||
|
||||
--- Persist a delta object. Deltas are written immediately (no debounce)
--- since they record history and should not be lost to a pending timer.
---@param delta Delta Delta data
---@param root? string Project root
function M.save_delta(delta, root)
  local key = "deltas.objects." .. delta.h
  M.save(key, delta, root, true)
end
|
||||
|
||||
--- Current HEAD delta hash, if any.
---@param root? string Project root
---@return string|nil HEAD hash
function M.get_head(root)
  return M.get_meta(root).head
end
|
||||
|
||||
--- Set (or clear) the HEAD delta hash.
--- Fix: the previous implementation passed `{ head = nil }` to
--- update_meta, which is an EMPTY table in Lua (nil values are absent
--- keys), so clearing HEAD was silently a no-op. Clearing is now done
--- by mutating meta directly and saving it.
---@param hash string|nil Delta hash (nil clears HEAD)
---@param root? string Project root
function M.set_head(hash, root)
  if hash == nil then
    local meta = M.get_meta(root)
    meta.head = nil
    M.save("meta", meta, root)
  else
    M.update_meta({ head = hash }, root)
  end
end
|
||||
|
||||
--- Drop all in-memory state and cancel pending write timers (test helper).
--- Pending unflushed writes are discarded, not persisted.
function M.clear_cache()
  for _, timer in pairs(timers) do
    if timer then
      timer:stop()
    end
  end
  timers = {}
  cache = {}
  dirty = {}
end
|
||||
|
||||
--- Whether a brain directory exists on disk for the project.
---@param root? string Project root
---@return boolean
function M.exists(root)
  return vim.fn.isdirectory(M.get_brain_dir(root)) == 1
end
|
||||
|
||||
return M
|
||||
175
lua/codetyper/brain/types.lua
Normal file
175
lua/codetyper/brain/types.lua
Normal file
@@ -0,0 +1,175 @@
|
||||
---@meta
|
||||
--- Brain Learning System Type Definitions
|
||||
--- Optimized for LLM consumption with compact field names
|
||||
|
||||
local M = {}
|
||||
|
||||
---@alias NodeType "pat"|"cor"|"dec"|"con"|"fbk"|"ses"
|
||||
-- pat = pattern, cor = correction, dec = decision
|
||||
-- con = convention, fbk = feedback, ses = session
|
||||
|
||||
---@alias EdgeType "sem"|"file"|"temp"|"caus"|"sup"
|
||||
-- sem = semantic, file = file-based, temp = temporal
|
||||
-- caus = causal, sup = supersedes
|
||||
|
||||
---@alias DeltaOp "add"|"mod"|"del"
|
||||
|
||||
---@class NodeContent
|
||||
---@field s string Summary (max 200 chars)
|
||||
---@field d string Detail (full description)
|
||||
---@field code? string Optional code snippet
|
||||
---@field lang? string Language identifier
|
||||
|
||||
---@class NodeContext
|
||||
---@field f? string File path (relative)
|
||||
---@field fn? string Function name
|
||||
---@field ln? number[] Line range [start, end]
|
||||
---@field sym? string[] Symbol references
|
||||
|
||||
---@class NodeScores
|
||||
---@field w number Weight (0-1)
|
||||
---@field u number Usage count
|
||||
---@field sr number Success rate (0-1)
|
||||
|
||||
---@class NodeTimestamps
|
||||
---@field cr number Created (unix timestamp)
|
||||
---@field up number Updated (unix timestamp)
|
||||
---@field lu? number Last used (unix timestamp)
|
||||
|
||||
---@class NodeMeta
|
||||
---@field src "auto"|"user"|"llm" Source of learning
|
||||
---@field v number Version number
|
||||
---@field dr? string[] Delta references
|
||||
|
||||
---@class Node
|
||||
---@field id string Unique identifier (n_<timestamp>_<hash>)
|
||||
---@field t NodeType Node type
|
||||
---@field h string Content hash (8 chars)
|
||||
---@field c NodeContent Content
|
||||
---@field ctx NodeContext Context
|
||||
---@field sc NodeScores Scores
|
||||
---@field ts NodeTimestamps Timestamps
|
||||
---@field m? NodeMeta Metadata
|
||||
|
||||
---@class EdgeProps
|
||||
---@field w number Weight (0-1)
|
||||
---@field dir "bi"|"fwd"|"bwd" Direction
|
||||
---@field r? string Reason/description
|
||||
|
||||
---@class Edge
|
||||
---@field id string Unique identifier (e_<source>_<target>)
|
||||
---@field s string Source node ID
|
||||
---@field t string Target node ID
|
||||
---@field ty EdgeType Edge type
|
||||
---@field p EdgeProps Properties
|
||||
---@field ts number Created timestamp
|
||||
|
||||
---@class DeltaChange
|
||||
---@field op DeltaOp Operation type
|
||||
---@field path string JSON path (e.g., "nodes.pat.n_123")
|
||||
---@field bh? string Before hash
|
||||
---@field ah? string After hash
|
||||
---@field diff? table Field-level diff
|
||||
|
||||
---@class DeltaMeta
|
||||
---@field msg string Commit message
|
||||
---@field trig string Trigger source
|
||||
---@field sid? string Session ID
|
||||
|
||||
---@class Delta
|
||||
---@field h string Hash (8 chars)
|
||||
---@field p? string Parent hash
|
||||
---@field ts number Timestamp
|
||||
---@field ch DeltaChange[] Changes
|
||||
---@field m DeltaMeta Metadata
|
||||
|
||||
---@class GraphMeta
|
||||
---@field v number Schema version
|
||||
---@field head? string Current HEAD delta hash
|
||||
---@field nc number Node count
|
||||
---@field ec number Edge count
|
||||
---@field dc number Delta count
|
||||
|
||||
---@class AdjacencyEntry
|
||||
---@field sem? string[] Semantic edges
|
||||
---@field file? string[] File edges
|
||||
---@field temp? string[] Temporal edges
|
||||
---@field caus? string[] Causal edges
|
||||
---@field sup? string[] Supersedes edges
|
||||
|
||||
---@class Graph
|
||||
---@field meta GraphMeta Metadata
|
||||
---@field adj table<string, AdjacencyEntry> Adjacency list
|
||||
---@field radj table<string, AdjacencyEntry> Reverse adjacency
|
||||
|
||||
---@class QueryOpts
|
||||
---@field query? string Text query
|
||||
---@field file? string File path filter
|
||||
---@field types? NodeType[] Node types to include
|
||||
---@field since? number Timestamp filter
|
||||
---@field limit? number Max results
|
||||
---@field depth? number Traversal depth
|
||||
---@field max_tokens? number Token budget
|
||||
|
||||
---@class QueryResult
|
||||
---@field nodes Node[] Matched nodes
|
||||
---@field edges Edge[] Related edges
|
||||
---@field stats table Query statistics
|
||||
---@field truncated boolean Whether results were truncated
|
||||
|
||||
---@class LLMContext
|
||||
---@field schema string Schema version
|
||||
---@field query string Original query
|
||||
---@field learnings table[] Compact learning entries
|
||||
---@field connections table[] Connection summaries
|
||||
---@field tokens number Estimated token count
|
||||
|
||||
---@class LearnEvent
|
||||
---@field type string Event type
|
||||
---@field data table Event data
|
||||
---@field file? string Related file
|
||||
---@field timestamp number Event timestamp
|
||||
|
||||
---@class BrainConfig
|
||||
---@field enabled boolean Enable brain system
|
||||
---@field auto_learn boolean Auto-learn from events
|
||||
---@field auto_commit boolean Auto-commit after threshold
|
||||
---@field commit_threshold number Changes before auto-commit
|
||||
---@field max_nodes number Max nodes before pruning
|
||||
---@field max_deltas number Max delta history
|
||||
---@field prune table Pruning config
|
||||
---@field output table Output config
|
||||
|
||||
-- Runtime constants mirroring the compact codes in the aliases above.
-- The string values are persisted to disk; do not change them.

--- Node type codes (NodeType alias).
M.NODE_TYPES = {
  PATTERN = "pat",
  CORRECTION = "cor",
  DECISION = "dec",
  CONVENTION = "con",
  FEEDBACK = "fbk",
  SESSION = "ses",
}

--- Edge type codes (EdgeType alias).
M.EDGE_TYPES = {
  SEMANTIC = "sem",
  FILE = "file",
  TEMPORAL = "temp",
  CAUSAL = "caus",
  SUPERSEDES = "sup",
}

--- Delta operation codes (DeltaOp alias).
M.DELTA_OPS = {
  ADD = "add",
  MODIFY = "mod",
  DELETE = "del",
}

--- Learning source codes (NodeMeta.src).
M.SOURCES = {
  AUTO = "auto",
  USER = "user",
  LLM = "llm",
}

--- Current on-disk schema version.
M.SCHEMA_VERSION = 1
|
||||
|
||||
return M
|
||||
301
lua/codetyper/cmp_source/init.lua
Normal file
301
lua/codetyper/cmp_source/init.lua
Normal file
@@ -0,0 +1,301 @@
|
||||
---@mod codetyper.cmp_source Completion source for nvim-cmp
|
||||
---@brief [[
|
||||
--- Provides intelligent code completions using the brain, indexer, and LLM.
|
||||
--- Integrates with nvim-cmp as a custom source.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local source = {}
|
||||
|
||||
--- Whether nvim-cmp can be required.
--- Fix: `return pcall(...)` propagated pcall's second return value (the
--- module on success, the error message on failure) to callers; the
--- extra parentheses truncate the result to exactly one boolean.
---@return boolean
local function has_cmp()
  return (pcall(require, "cmp"))
end
|
||||
|
||||
--- Collect completion items from learned brain patterns.
---@param prefix string Current word prefix
---@return table[] items nvim-cmp completion items (possibly empty)
local function get_brain_completions(prefix)
  local items = {}

  local ok_brain, brain = pcall(require, "codetyper.brain")
  if not ok_brain then
    return items
  end

  -- Check initialization defensively: the brain module may be partially
  -- loaded, and is_initialized itself may throw.
  local is_init = false
  if brain.is_initialized then
    local ok, result = pcall(brain.is_initialized)
    is_init = ok and result
  end
  if not is_init then
    return items
  end

  -- Query the brain for relevant pattern nodes.
  -- Fix: QueryOpts (brain/types.lua) declares `limit`, not `max_results`,
  -- so the old option was ignored; and node types use the compact code
  -- "pat" (NODE_TYPES.PATTERN), so `{ "pattern" }` matched nothing.
  local ok_query, result = pcall(brain.query, {
    query = prefix,
    limit = 10,
    types = { "pat" },
  })
  if not (ok_query and result and result.nodes) then
    return items
  end

  local prefix_lower = prefix:lower()

  --- Extract identifiers from one "label: a, b; ..." section of a node
  --- summary and append matching ones as completion items.
  local function harvest(summary, section_pat, kind)
    for chunk in summary:gmatch(section_pat) do
      for name in chunk:gmatch("([%w_]+)") do
        if name:lower():find(prefix_lower, 1, true) then
          table.insert(items, {
            label = name,
            kind = kind,
            detail = "[brain]",
            documentation = summary,
          })
        end
      end
    end
  end

  for _, node in ipairs(result.nodes) do
    if node.c and node.c.s then
      harvest(node.c.s, "functions:%s*([^;]+)", 3) -- LSP kind: Function
      harvest(node.c.s, "classes:%s*([^;]+)", 7) -- LSP kind: Class
    end
  end

  return items
end
|
||||
|
||||
--- Collect completion items from the project index (symbols, functions,
--- classes).
---@param prefix string Current word prefix
---@return table[] items nvim-cmp completion items (possibly empty)
local function get_indexer_completions(prefix)
  local items = {}

  local ok_indexer, indexer = pcall(require, "codetyper.indexer")
  if not ok_indexer then
    return items
  end

  local ok_load, index = pcall(indexer.load_index)
  if not ok_load or not index then
    return items
  end

  local prefix_lower = prefix:lower()

  -- Global symbol table: symbol -> file(s).
  if index.symbols then
    for symbol, files in pairs(index.symbols) do
      if symbol:lower():find(prefix_lower, 1, true) then
        local files_str = type(files) == "table" and table.concat(files, ", ") or tostring(files)
        table.insert(items, {
          label = symbol,
          kind = 6, -- Variable (generic)
          detail = "[index] " .. files_str:sub(1, 30),
          documentation = "Symbol found in: " .. files_str,
        })
      end
    end
  end

  -- Per-file function and class entries.
  if index.files then
    for filepath, file_index in pairs(index.files) do
      if file_index and file_index.functions then
        for _, func in ipairs(file_index.functions) do
          if func.name and func.name:lower():find(prefix_lower, 1, true) then
            table.insert(items, {
              label = func.name,
              kind = 3, -- Function
              detail = "[index] " .. vim.fn.fnamemodify(filepath, ":t"),
              documentation = func.docstring or ("Function at line " .. (func.line or "?")),
            })
          end
        end
      end
      if file_index and file_index.classes then
        for _, class in ipairs(file_index.classes) do
          if class.name and class.name:lower():find(prefix_lower, 1, true) then
            table.insert(items, {
              label = class.name,
              kind = 7, -- Class
              detail = "[index] " .. vim.fn.fnamemodify(filepath, ":t"),
              documentation = class.docstring or ("Class at line " .. (class.line or "?")),
            })
          end
        end
      end
    end
  end

  return items
end
|
||||
|
||||
--- Fallback completions: identifier-like words harvested from the buffer.
---@param prefix string Current word prefix
---@param bufnr number Buffer number
---@return table[] items nvim-cmp completion items (possibly empty)
local function get_buffer_completions(prefix, bufnr)
  local items = {}
  local seen = {}
  local prefix_lower = prefix:lower()

  for _, line in ipairs(vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)) do
    for word in line:gmatch("[%a_][%w_]*") do
      -- Keep words of 3+ chars that contain the prefix, excluding the
      -- prefix itself and duplicates.
      local keep = #word >= 3
        and word ~= prefix
        and not seen[word]
        and word:lower():find(prefix_lower, 1, true)
      if keep then
        seen[word] = true
        items[#items + 1] = {
          label = word,
          kind = 1, -- Text
          detail = "[buffer]",
        }
      end
    end
  end

  return items
end
|
||||
|
||||
--- Construct a new cmp source instance delegating to the `source` table.
function source.new()
  local instance = {}
  return setmetatable(instance, { __index = source })
end
|
||||
|
||||
--- Vim regex used by cmp to extract the keyword under the cursor.
--- (The previous comment incorrectly described this as "Get source name".)
function source:get_keyword_pattern()
  return [[\k\+]]
end
|
||||
|
||||
--- The source is unconditionally available.
function source:is_available()
  return true
end
|
||||
|
||||
--- Name shown in cmp's debug output for this source.
function source:get_debug_name()
  return "codetyper"
end
|
||||
|
||||
--- Characters that trigger completion in addition to keyword typing.
function source:get_trigger_characters()
  return { ".", ":", "_" }
end
|
||||
|
||||
--- Produce completion items for nvim-cmp, merging three providers:
--- brain patterns (rank 1), project index (rank 2), buffer words (rank 3).
---@param params table
---@param callback fun(response: table|nil)
function source:complete(params, callback)
  local prefix = params.context.cursor_before_line:match("[%w_]+$") or ""

  -- Require at least two characters before offering anything.
  if #prefix < 2 then
    callback({ items = {}, isIncomplete = true })
    return
  end

  local items = {}
  local seen = {}

  --- Merge one provider's items, deduplicating by label and prefixing
  --- sortText so earlier providers rank higher.
  local function merge(provider_items, rank)
    if not provider_items then
      return
    end
    for _, item in ipairs(provider_items) do
      if not seen[item.label] then
        seen[item.label] = true
        item.sortText = rank .. item.label
        table.insert(items, item)
      end
    end
  end

  local ok1, brain_items = pcall(get_brain_completions, prefix)
  if ok1 then
    merge(brain_items, "1")
  end

  local ok2, indexer_items = pcall(get_indexer_completions, prefix)
  if ok2 then
    merge(indexer_items, "2")
  end

  local bufnr = params.context.bufnr
  if bufnr then
    local ok3, buffer_items = pcall(get_buffer_completions, prefix, bufnr)
    if ok3 then
      merge(buffer_items, "3")
    end
  end

  callback({
    items = items,
    isIncomplete = #items >= 50,
  })
end
|
||||
|
||||
--- Register the "codetyper" source with nvim-cmp.
---@return boolean True when cmp was found and the source was registered
function M.setup()
  if not has_cmp() then
    return false
  end
  local cmp = require("cmp")
  cmp.register_source("codetyper", source.new())
  return true
end
|
||||
|
||||
--- Whether the "codetyper" source appears in cmp's configured sources.
--- Note: this inspects cmp's config, not its internal registry, so a
--- source registered but not added to config reports false.
---@return boolean
function M.is_registered()
  local ok, cmp = pcall(require, "cmp")
  if not ok then
    return false
  end

  local config = cmp.get_config()
  if not (config and config.sources) then
    return false
  end

  for _, src in ipairs(config.sources) do
    if src.name == "codetyper" then
      return true
    end
  end
  return false
end
|
||||
|
||||
--- Expose the raw source table for manual registration by the user.
function M.get_source()
  return source
end
|
||||
|
||||
return M
|
||||
@@ -164,13 +164,19 @@ local function cmd_status()
|
||||
"Provider: " .. config.llm.provider,
|
||||
}
|
||||
|
||||
if config.llm.provider == "claude" then
|
||||
local has_key = (config.llm.claude.api_key or vim.env.ANTHROPIC_API_KEY) ~= nil
|
||||
table.insert(status, "Claude API Key: " .. (has_key and "configured" or "NOT SET"))
|
||||
table.insert(status, "Claude Model: " .. config.llm.claude.model)
|
||||
else
|
||||
if config.llm.provider == "ollama" then
|
||||
table.insert(status, "Ollama Host: " .. config.llm.ollama.host)
|
||||
table.insert(status, "Ollama Model: " .. config.llm.ollama.model)
|
||||
elseif config.llm.provider == "openai" then
|
||||
local has_key = (config.llm.openai.api_key or vim.env.OPENAI_API_KEY) ~= nil
|
||||
table.insert(status, "OpenAI API Key: " .. (has_key and "configured" or "NOT SET"))
|
||||
table.insert(status, "OpenAI Model: " .. config.llm.openai.model)
|
||||
elseif config.llm.provider == "gemini" then
|
||||
local has_key = (config.llm.gemini.api_key or vim.env.GEMINI_API_KEY) ~= nil
|
||||
table.insert(status, "Gemini API Key: " .. (has_key and "configured" or "NOT SET"))
|
||||
table.insert(status, "Gemini Model: " .. config.llm.gemini.model)
|
||||
elseif config.llm.provider == "copilot" then
|
||||
table.insert(status, "Copilot Model: " .. config.llm.copilot.model)
|
||||
end
|
||||
|
||||
table.insert(status, "")
|
||||
@@ -281,6 +287,118 @@ local function cmd_agent_stop()
|
||||
end
|
||||
end
|
||||
|
||||
--- Run the agentic loop on a task, streaming progress into the logs panel.
---@param task string The task to accomplish
---@param agent_name? string Optional agent name (defaults to "coder")
local function cmd_agentic_run(task, agent_name)
  local agentic = require("codetyper.agent.agentic")
  local logs_panel = require("codetyper.logs_panel")
  local logs = require("codetyper.agent.logs")

  logs_panel.open()

  logs.info("Starting agentic task: " .. task:sub(1, 50) .. "...")
  utils.notify("Running agentic task...", vim.log.levels.INFO)

  -- Seed context with the file in the current window, if any.
  local files = {}
  local current_file = vim.fn.expand("%:p")
  if current_file ~= "" then
    files[#files + 1] = current_file
  end

  agentic.run({
    task = task,
    files = files,
    agent = agent_name or "coder",
    on_status = function(status)
      logs.thinking(status)
    end,
    on_tool_start = function(name, _)
      logs.info("Tool: " .. name)
    end,
    on_tool_end = function(name, _, err)
      if err then
        logs.error(name .. " failed: " .. err)
      else
        logs.debug(name .. " completed")
      end
    end,
    on_file_change = function(path, action)
      logs.info("File " .. action .. ": " .. path)
    end,
    on_message = function(msg)
      -- Only surface non-empty textual assistant messages, truncated.
      local is_text = msg.role == "assistant" and type(msg.content) == "string"
      if is_text and msg.content ~= "" then
        logs.thinking(msg.content:sub(1, 100) .. "...")
      end
    end,
    on_complete = function(result, err)
      if err then
        logs.error("Task failed: " .. err)
        utils.notify("Agentic task failed: " .. err, vim.log.levels.ERROR)
        return
      end
      logs.info("Task completed successfully")
      utils.notify("Agentic task completed!", vim.log.levels.INFO)
      if result and result ~= "" then
        -- Show a truncated result summary off the callback's fast path.
        vim.schedule(function()
          vim.notify("Result:\n" .. result:sub(1, 500), vim.log.levels.INFO)
        end)
      end
    end,
  })
end
|
||||
|
||||
--- Show the available agents (builtin and custom) in a notification.
local function cmd_agentic_list()
  local agentic = require("codetyper.agent.agentic")

  local lines = {
    "Available Agents",
    "================",
    "",
  }

  for _, agent in ipairs(agentic.list_agents()) do
    local badge = agent.builtin and "[builtin]" or "[custom]"
    lines[#lines + 1] = string.format(" %s %s", agent.name, badge)
    lines[#lines + 1] = string.format(" %s", agent.description)
    lines[#lines + 1] = ""
  end

  lines[#lines + 1] = "Use :CoderAgenticRun <task> [agent] to run a task"
  lines[#lines + 1] = "Use :CoderAgenticInit to create custom agents"

  utils.notify(table.concat(lines, "\n"))
end
|
||||
|
||||
--- Create .coder/agents/ and .coder/rules/ with starter templates, then
--- tell the user what was created.
local function cmd_agentic_init()
  local agentic = require("codetyper.agent.agentic")
  agentic.init()

  local cwd = vim.fn.getcwd()
  local agents_dir = cwd .. "/.coder/agents"
  local rules_dir = cwd .. "/.coder/rules"

  local lines = {
    "Initialized Coder directories:",
    "",
    " " .. agents_dir,
    " - example.md (template for custom agents)",
    "",
    " " .. rules_dir,
    " - code-style.md (template for project rules)",
    "",
    "Edit these files to customize agent behavior.",
    "Create new .md files to add more agents/rules.",
  }

  utils.notify(table.concat(lines, "\n"))
end
|
||||
|
||||
--- Show chat type switcher modal (Ask/Agent)
|
||||
local function cmd_type_toggle()
|
||||
local switcher = require("codetyper.chat_switcher")
|
||||
@@ -618,6 +736,131 @@ local function cmd_transform_visual()
|
||||
cmd_transform_range(start_line, end_line)
|
||||
end
|
||||
|
||||
--- Index the entire project
-- Kicks off an asynchronous project index and reports the resulting stats
-- (or a failure) via notification.
local function cmd_index_project()
  local indexer = require("codetyper.indexer")

  utils.notify("Indexing project...", vim.log.levels.INFO)

  indexer.index_project(function(index)
    -- Callback receives nil on failure.
    if not index then
      utils.notify("Failed to index project", vim.log.levels.ERROR)
      return
    end

    local s = index.stats
    utils.notify(
      string.format(
        "Indexed: %d files, %d functions, %d classes, %d exports",
        s.files,
        s.functions,
        s.classes,
        s.exports
      ),
      vim.log.levels.INFO
    )
  end)
end
|
||||
|
||||
--- Show index status
-- Builds a text report of the project index state plus memory counters and
-- notifies it.
local function cmd_index_status()
  local indexer = require("codetyper.indexer")
  local memory = require("codetyper.indexer.memory")

  local status = indexer.get_status()
  local mem_stats = memory.get_stats()

  local report = {
    "Project Index Status",
    "====================",
    "",
  }
  local function add(line)
    report[#report + 1] = line
  end

  if status.indexed then
    add("Status: Indexed")
    add("Project Type: " .. (status.project_type or "unknown"))
    add("Last Indexed: " .. os.date("%Y-%m-%d %H:%M:%S", status.last_indexed))
    add("")
    add("Stats:")
    add(" Files: " .. (status.stats.files or 0))
    add(" Functions: " .. (status.stats.functions or 0))
    add(" Classes: " .. (status.stats.classes or 0))
    add(" Exports: " .. (status.stats.exports or 0))
  else
    add("Status: Not indexed")
    add("Run :CoderIndexProject to index")
  end

  add("")
  add("Memories:")
  add(" Patterns: " .. mem_stats.patterns)
  add(" Conventions: " .. mem_stats.conventions)
  add(" Symbols: " .. mem_stats.symbols)

  utils.notify(table.concat(report, "\n"))
end
|
||||
|
||||
--- Append up to 10 memory entries to `lines`, then an overflow or "(none)"
--- marker. Extracted because the original duplicated this logic verbatim for
--- patterns and conventions.
-- NOTE(review): iteration uses pairs, so WHICH 10 entries are shown is
-- unspecified when more than 10 exist — preserved from the original.
---@param lines string[] Output buffer (mutated in place)
---@param entries table Memory entries; each may carry a `content` string
local function append_memory_entries(lines, entries)
  local count = 0
  for _, mem in pairs(entries) do
    count = count + 1
    if count <= 10 then
      -- Truncate each entry to 60 characters for a compact listing.
      table.insert(lines, " - " .. (mem.content or ""):sub(1, 60))
    end
  end
  if count > 10 then
    table.insert(lines, " ... and " .. (count - 10) .. " more")
  elseif count == 0 then
    table.insert(lines, " (none)")
  end
end

--- Show learned memories
-- Lists learned patterns and conventions (max 10 each) in a notification.
local function cmd_memories()
  local memory = require("codetyper.indexer.memory")
  local all = memory.get_all()

  local lines = {
    "Learned Memories",
    "================",
    "",
    "Patterns:",
  }
  append_memory_entries(lines, all.patterns)

  table.insert(lines, "")
  table.insert(lines, "Conventions:")
  append_memory_entries(lines, all.conventions)

  utils.notify(table.concat(lines, "\n"))
end
|
||||
|
||||
--- Clear memories
-- With a pattern, clears matching memories immediately; without one, asks
-- for confirmation before wiping everything.
---@param pattern string|nil Optional pattern to match
local function cmd_forget(pattern)
  local memory = require("codetyper.indexer.memory")

  if pattern and pattern ~= "" then
    memory.clear(pattern)
    utils.notify("Cleared memories matching: " .. pattern, vim.log.levels.INFO)
    return
  end

  -- No pattern given: clearing everything is destructive, so confirm first.
  vim.ui.select({ "Yes", "No" }, { prompt = "Clear all memories?" }, function(choice)
    if choice == "Yes" then
      memory.clear()
      utils.notify("All memories cleared", vim.log.levels.INFO)
    end
  end)
end
|
||||
|
||||
--- Transform a single prompt at cursor position
|
||||
local function cmd_transform_at_cursor()
|
||||
local parser = require("codetyper.parser")
|
||||
@@ -713,6 +956,65 @@ end
|
||||
|
||||
--- Main command dispatcher
|
||||
---@param args table Command arguments
|
||||
--- Show LLM accuracy statistics
-- Renders per-provider request/accuracy counters (Ollama and Copilot) as a
-- notification.
local function cmd_llm_stats()
  local llm = require("codetyper.llm")
  local stats = llm.get_accuracy_stats()

  local report = {
    "LLM Provider Accuracy Statistics",
    "================================",
    "",
  }

  -- Same layout for each provider section.
  for _, section in ipairs({ { "Ollama", stats.ollama }, { "Copilot", stats.copilot } }) do
    local label, s = section[1], section[2]
    report[#report + 1] = label .. ":"
    report[#report + 1] = string.format(" Total requests: %d", s.total)
    report[#report + 1] = string.format(" Correct: %d", s.correct)
    report[#report + 1] = string.format(" Accuracy: %.1f%%", s.accuracy * 100)
    report[#report + 1] = ""
  end

  report[#report + 1] = "Note: Smart selection prefers Ollama when brain memories"
  report[#report + 1] = "provide enough context. Accuracy improves over time via"
  report[#report + 1] = "pondering (verification with other LLMs)."

  vim.notify(table.concat(report, "\n"), vim.log.levels.INFO)
end
|
||||
|
||||
--- Report feedback on last LLM response
-- Tries to recover the provider that produced the last response from the
-- agent logs; falls back to "ollama" when the logs are unavailable.
---@param was_good boolean Whether the response was good
local function cmd_llm_feedback(was_good)
  local llm = require("codetyper.llm")

  local provider = "ollama" -- default when no log entry identifies the provider

  -- Best-effort log scan; pcall so a missing/failing logs module never
  -- blocks feedback reporting.
  pcall(function()
    local entries = require("codetyper.agent.logs").get(10)
    -- Walk newest-to-oldest and take the first "LLM: <provider>" entry.
    for i = #entries, 1, -1 do
      local msg = entries[i].message
      if msg and msg:match("^LLM:") then
        provider = msg:match("LLM: (%w+)") or provider
        break
      end
    end
  end)

  llm.report_feedback(provider, was_good)
  utils.notify(
    string.format("Reported %s feedback for %s", was_good and "positive" or "negative", provider),
    vim.log.levels.INFO
  )
end
|
||||
|
||||
--- Reset LLM accuracy statistics
-- Delegates to the selector module and confirms via notification.
local function cmd_llm_reset_stats()
  require("codetyper.llm.selector").reset_accuracy_stats()
  utils.notify("LLM accuracy statistics reset", vim.log.levels.INFO)
end
|
||||
|
||||
local function coder_cmd(args)
|
||||
local subcommand = args.fargs[1] or "toggle"
|
||||
|
||||
@@ -741,6 +1043,23 @@ local function coder_cmd(args)
|
||||
["logs-toggle"] = cmd_logs_toggle,
|
||||
["queue-status"] = cmd_queue_status,
|
||||
["queue-process"] = cmd_queue_process,
|
||||
-- Agentic commands
|
||||
["agentic-run"] = function(args)
|
||||
local task = table.concat(vim.list_slice(args.fargs, 2), " ")
|
||||
if task == "" then
|
||||
utils.notify("Usage: Coder agentic-run <task> [agent]", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
cmd_agentic_run(task)
|
||||
end,
|
||||
["agentic-list"] = cmd_agentic_list,
|
||||
["agentic-init"] = cmd_agentic_init,
|
||||
["index-project"] = cmd_index_project,
|
||||
["index-status"] = cmd_index_status,
|
||||
memories = cmd_memories,
|
||||
forget = function(args)
|
||||
cmd_forget(args.fargs[2])
|
||||
end,
|
||||
["auto-toggle"] = function()
|
||||
local preferences = require("codetyper.preferences")
|
||||
preferences.toggle_auto_process()
|
||||
@@ -764,6 +1083,41 @@ local function coder_cmd(args)
|
||||
end
|
||||
end
|
||||
end,
|
||||
-- LLM smart selection commands
|
||||
["llm-stats"] = cmd_llm_stats,
|
||||
["llm-feedback-good"] = function()
|
||||
cmd_llm_feedback(true)
|
||||
end,
|
||||
["llm-feedback-bad"] = function()
|
||||
cmd_llm_feedback(false)
|
||||
end,
|
||||
["llm-reset-stats"] = cmd_llm_reset_stats,
|
||||
-- Cost tracking commands
|
||||
["cost"] = function()
|
||||
local cost = require("codetyper.cost")
|
||||
cost.toggle()
|
||||
end,
|
||||
["cost-clear"] = function()
|
||||
local cost = require("codetyper.cost")
|
||||
cost.clear()
|
||||
end,
|
||||
-- Credentials management commands
|
||||
["add-api-key"] = function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.interactive_add()
|
||||
end,
|
||||
["remove-api-key"] = function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.interactive_remove()
|
||||
end,
|
||||
["credentials"] = function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.show_status()
|
||||
end,
|
||||
["switch-provider"] = function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.interactive_switch_provider()
|
||||
end,
|
||||
}
|
||||
|
||||
local cmd_fn = commands[subcommand]
|
||||
@@ -785,9 +1139,14 @@ function M.setup()
|
||||
"ask", "ask-close", "ask-toggle", "ask-clear",
|
||||
"transform", "transform-cursor",
|
||||
"agent", "agent-close", "agent-toggle", "agent-stop",
|
||||
"agentic-run", "agentic-list", "agentic-init",
|
||||
"type-toggle", "logs-toggle",
|
||||
"queue-status", "queue-process",
|
||||
"index-project", "index-status", "memories", "forget",
|
||||
"auto-toggle", "auto-set",
|
||||
"llm-stats", "llm-feedback-good", "llm-feedback-bad", "llm-reset-stats",
|
||||
"cost", "cost-clear",
|
||||
"add-api-key", "remove-api-key", "credentials", "switch-provider",
|
||||
}
|
||||
end,
|
||||
desc = "Codetyper.nvim commands",
|
||||
@@ -859,6 +1218,31 @@ function M.setup()
|
||||
cmd_agent_stop()
|
||||
end, { desc = "Stop running agent" })
|
||||
|
||||
-- Agentic commands (full IDE-like agent functionality)
|
||||
vim.api.nvim_create_user_command("CoderAgenticRun", function(opts)
|
||||
local task = opts.args
|
||||
if task == "" then
|
||||
vim.ui.input({ prompt = "Task: " }, function(input)
|
||||
if input and input ~= "" then
|
||||
cmd_agentic_run(input)
|
||||
end
|
||||
end)
|
||||
else
|
||||
cmd_agentic_run(task)
|
||||
end
|
||||
end, {
|
||||
desc = "Run agentic task (IDE-like multi-file changes)",
|
||||
nargs = "*",
|
||||
})
|
||||
|
||||
vim.api.nvim_create_user_command("CoderAgenticList", function()
|
||||
cmd_agentic_list()
|
||||
end, { desc = "List available agents" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderAgenticInit", function()
|
||||
cmd_agentic_init()
|
||||
end, { desc = "Initialize .coder/agents/ and .coder/rules/ directories" })
|
||||
|
||||
-- Chat type switcher command
|
||||
vim.api.nvim_create_user_command("CoderType", function()
|
||||
cmd_type_toggle()
|
||||
@@ -875,6 +1259,26 @@ function M.setup()
|
||||
autocmds.open_coder_companion()
|
||||
end, { desc = "Open coder companion for current file" })
|
||||
|
||||
-- Project indexer commands
|
||||
vim.api.nvim_create_user_command("CoderIndexProject", function()
|
||||
cmd_index_project()
|
||||
end, { desc = "Index the entire project" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderIndexStatus", function()
|
||||
cmd_index_status()
|
||||
end, { desc = "Show project index status" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderMemories", function()
|
||||
cmd_memories()
|
||||
end, { desc = "Show learned memories" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderForget", function(opts)
|
||||
cmd_forget(opts.args ~= "" and opts.args or nil)
|
||||
end, {
|
||||
desc = "Clear memories (optionally matching pattern)",
|
||||
nargs = "?",
|
||||
})
|
||||
|
||||
-- Queue commands
|
||||
vim.api.nvim_create_user_command("CoderQueueStatus", function()
|
||||
cmd_queue_status()
|
||||
@@ -917,6 +1321,147 @@ function M.setup()
|
||||
end,
|
||||
})
|
||||
|
||||
-- Brain feedback command - teach the brain from your experience
|
||||
vim.api.nvim_create_user_command("CoderFeedback", function(opts)
|
||||
local brain = require("codetyper.brain")
|
||||
if not brain.is_initialized() then
|
||||
vim.notify("Brain not initialized", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
local feedback_type = opts.args:lower()
|
||||
local current_file = vim.fn.expand("%:p")
|
||||
|
||||
if feedback_type == "good" or feedback_type == "accept" or feedback_type == "+" then
|
||||
-- Learn positive feedback
|
||||
brain.learn({
|
||||
type = "user_feedback",
|
||||
file = current_file,
|
||||
timestamp = os.time(),
|
||||
data = {
|
||||
feedback = "accepted",
|
||||
description = "User marked code as good/accepted",
|
||||
},
|
||||
})
|
||||
vim.notify("Brain: Learned positive feedback ✓", vim.log.levels.INFO)
|
||||
|
||||
elseif feedback_type == "bad" or feedback_type == "reject" or feedback_type == "-" then
|
||||
-- Learn negative feedback
|
||||
brain.learn({
|
||||
type = "user_feedback",
|
||||
file = current_file,
|
||||
timestamp = os.time(),
|
||||
data = {
|
||||
feedback = "rejected",
|
||||
description = "User marked code as bad/rejected",
|
||||
},
|
||||
})
|
||||
vim.notify("Brain: Learned negative feedback ✗", vim.log.levels.INFO)
|
||||
|
||||
elseif feedback_type == "stats" or feedback_type == "status" then
|
||||
-- Show brain stats
|
||||
local stats = brain.stats()
|
||||
local msg = string.format(
|
||||
"Brain Stats:\n• Nodes: %d\n• Edges: %d\n• Pending: %d\n• Deltas: %d",
|
||||
stats.node_count or 0,
|
||||
stats.edge_count or 0,
|
||||
stats.pending_changes or 0,
|
||||
stats.delta_count or 0
|
||||
)
|
||||
vim.notify(msg, vim.log.levels.INFO)
|
||||
|
||||
else
|
||||
vim.notify("Usage: CoderFeedback <good|bad|stats>", vim.log.levels.INFO)
|
||||
end
|
||||
end, {
|
||||
desc = "Give feedback to the brain (good/bad/stats)",
|
||||
nargs = "?",
|
||||
complete = function()
|
||||
return { "good", "bad", "stats" }
|
||||
end,
|
||||
})
|
||||
|
||||
-- Brain stats command
|
||||
vim.api.nvim_create_user_command("CoderBrain", function(opts)
|
||||
local brain = require("codetyper.brain")
|
||||
if not brain.is_initialized() then
|
||||
vim.notify("Brain not initialized", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
local action = opts.args:lower()
|
||||
|
||||
if action == "stats" or action == "" then
|
||||
local stats = brain.stats()
|
||||
local lines = {
|
||||
"╭─────────────────────────────────╮",
|
||||
"│ CODETYPER BRAIN │",
|
||||
"╰─────────────────────────────────╯",
|
||||
"",
|
||||
string.format(" Nodes: %d", stats.node_count or 0),
|
||||
string.format(" Edges: %d", stats.edge_count or 0),
|
||||
string.format(" Deltas: %d", stats.delta_count or 0),
|
||||
string.format(" Pending: %d", stats.pending_changes or 0),
|
||||
"",
|
||||
" The more you use Codetyper,",
|
||||
" the smarter it becomes!",
|
||||
}
|
||||
vim.notify(table.concat(lines, "\n"), vim.log.levels.INFO)
|
||||
|
||||
elseif action == "commit" then
|
||||
local hash = brain.commit("Manual commit")
|
||||
if hash then
|
||||
vim.notify("Brain: Committed changes (hash: " .. hash:sub(1, 8) .. ")", vim.log.levels.INFO)
|
||||
else
|
||||
vim.notify("Brain: Nothing to commit", vim.log.levels.INFO)
|
||||
end
|
||||
|
||||
elseif action == "flush" then
|
||||
brain.flush()
|
||||
vim.notify("Brain: Flushed to disk", vim.log.levels.INFO)
|
||||
|
||||
elseif action == "prune" then
|
||||
local pruned = brain.prune()
|
||||
vim.notify("Brain: Pruned " .. pruned .. " low-value nodes", vim.log.levels.INFO)
|
||||
|
||||
else
|
||||
vim.notify("Usage: CoderBrain <stats|commit|flush|prune>", vim.log.levels.INFO)
|
||||
end
|
||||
end, {
|
||||
desc = "Brain management commands",
|
||||
nargs = "?",
|
||||
complete = function()
|
||||
return { "stats", "commit", "flush", "prune" }
|
||||
end,
|
||||
})
|
||||
|
||||
-- Cost estimation command
|
||||
vim.api.nvim_create_user_command("CoderCost", function()
|
||||
local cost = require("codetyper.cost")
|
||||
cost.toggle()
|
||||
end, { desc = "Show LLM cost estimation window" })
|
||||
|
||||
-- Credentials management commands
|
||||
vim.api.nvim_create_user_command("CoderAddApiKey", function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.interactive_add()
|
||||
end, { desc = "Add or update LLM provider API key" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderRemoveApiKey", function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.interactive_remove()
|
||||
end, { desc = "Remove LLM provider credentials" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderCredentials", function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.show_status()
|
||||
end, { desc = "Show credentials status" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderSwitchProvider", function()
|
||||
local credentials = require("codetyper.credentials")
|
||||
credentials.interactive_switch_provider()
|
||||
end, { desc = "Switch active LLM provider" })
|
||||
|
||||
-- Setup default keymaps
|
||||
M.setup_keymaps()
|
||||
end
|
||||
|
||||
@@ -5,11 +5,7 @@ local M = {}
|
||||
---@type CoderConfig
|
||||
local defaults = {
|
||||
llm = {
|
||||
provider = "ollama", -- Options: "claude", "ollama", "openai", "gemini", "copilot"
|
||||
claude = {
|
||||
api_key = nil, -- Will use ANTHROPIC_API_KEY env var if nil
|
||||
model = "claude-sonnet-4-20250514",
|
||||
},
|
||||
provider = "ollama", -- Options: "ollama", "openai", "gemini", "copilot"
|
||||
ollama = {
|
||||
host = "http://localhost:11434",
|
||||
model = "deepseek-coder:6.7b",
|
||||
@@ -48,6 +44,48 @@ local defaults = {
|
||||
completion_delay_ms = 100, -- Wait after completion popup closes
|
||||
apply_delay_ms = 5000, -- Wait before removing tags and applying code (ms)
|
||||
},
|
||||
indexer = {
|
||||
enabled = true, -- Enable project indexing
|
||||
auto_index = true, -- Index files on save
|
||||
index_on_open = false, -- Index project when opening
|
||||
max_file_size = 100000, -- Skip files larger than 100KB
|
||||
excluded_dirs = { "node_modules", "dist", "build", ".git", ".coder", "__pycache__", "vendor", "target" },
|
||||
index_extensions = { "lua", "ts", "tsx", "js", "jsx", "py", "go", "rs", "rb", "java", "c", "cpp", "h", "hpp" },
|
||||
memory = {
|
||||
enabled = true, -- Enable memory persistence
|
||||
max_memories = 1000, -- Maximum stored memories
|
||||
prune_threshold = 0.1, -- Remove low-weight memories
|
||||
},
|
||||
},
|
||||
brain = {
|
||||
enabled = true, -- Enable brain learning system
|
||||
auto_learn = true, -- Auto-learn from events
|
||||
auto_commit = true, -- Auto-commit after threshold
|
||||
commit_threshold = 10, -- Changes before auto-commit
|
||||
max_nodes = 5000, -- Maximum nodes before pruning
|
||||
max_deltas = 500, -- Maximum delta history
|
||||
prune = {
|
||||
enabled = true, -- Enable auto-pruning
|
||||
threshold = 0.1, -- Remove nodes below this weight
|
||||
unused_days = 90, -- Remove unused nodes after N days
|
||||
},
|
||||
output = {
|
||||
max_tokens = 4000, -- Token budget for LLM context
|
||||
format = "compact", -- "compact"|"json"|"natural"
|
||||
},
|
||||
},
|
||||
suggestion = {
|
||||
enabled = true, -- Enable ghost text suggestions (Copilot-style)
|
||||
auto_trigger = true, -- Auto-trigger on typing
|
||||
debounce = 150, -- Debounce in milliseconds
|
||||
use_copilot = true, -- Use copilot.lua suggestions when available, fallback to codetyper
|
||||
keymap = {
|
||||
accept = "<Tab>", -- Accept suggestion
|
||||
next = "<M-]>", -- Next suggestion (Alt+])
|
||||
prev = "<M-[>", -- Previous suggestion (Alt+[)
|
||||
dismiss = "<C-]>", -- Dismiss suggestion (Ctrl+])
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
--- Deep merge two tables
|
||||
@@ -88,7 +126,7 @@ function M.validate(config)
|
||||
return false, "Missing LLM configuration"
|
||||
end
|
||||
|
||||
local valid_providers = { "claude", "ollama", "openai", "gemini", "copilot" }
|
||||
local valid_providers = { "ollama", "openai", "gemini", "copilot" }
|
||||
local is_valid_provider = false
|
||||
for _, p in ipairs(valid_providers) do
|
||||
if config.llm.provider == p then
|
||||
@@ -102,12 +140,7 @@ function M.validate(config)
|
||||
end
|
||||
|
||||
-- Validate provider-specific configuration
|
||||
if config.llm.provider == "claude" then
|
||||
local api_key = config.llm.claude.api_key or vim.env.ANTHROPIC_API_KEY
|
||||
if not api_key or api_key == "" then
|
||||
return false, "Claude API key not configured. Set llm.claude.api_key or ANTHROPIC_API_KEY env var"
|
||||
end
|
||||
elseif config.llm.provider == "openai" then
|
||||
if config.llm.provider == "openai" then
|
||||
local api_key = config.llm.openai.api_key or vim.env.OPENAI_API_KEY
|
||||
if not api_key or api_key == "" then
|
||||
return false, "OpenAI API key not configured. Set llm.openai.api_key or OPENAI_API_KEY env var"
|
||||
|
||||
750
lua/codetyper/cost.lua
Normal file
750
lua/codetyper/cost.lua
Normal file
@@ -0,0 +1,750 @@
|
||||
---@mod codetyper.cost Cost estimation for LLM usage
|
||||
---@brief [[
|
||||
--- Tracks token usage and estimates costs based on model pricing.
|
||||
--- Prices are per 1M tokens. Persists usage data in the brain.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
|
||||
--- Cost history file name
|
||||
local COST_HISTORY_FILE = "cost_history.json"
|
||||
|
||||
--- Get path to cost history file
-- The file lives under the project's .coder directory.
---@return string File path
local function get_history_path()
  return ("%s/.coder/%s"):format(utils.get_project_root(), COST_HISTORY_FILE)
end
|
||||
|
||||
--- Default model for savings comparison (what you'd pay if not using Ollama)
-- Used by M.calculate_savings to price the same token counts on a paid model.
M.comparison_model = "gpt-4o"

--- Models considered "free" (Ollama, local, Copilot subscription)
-- Keys are normalized model names (see normalize_model); checked first by
-- M.is_free_model before falling back to zero-priced pricing entries.
M.free_models = {
  ["ollama"] = true,
  ["codellama"] = true,
  ["llama2"] = true,
  ["llama3"] = true,
  ["mistral"] = true,
  ["deepseek-coder"] = true,
  ["copilot"] = true, -- included in the subscription; usage still tracked
}
|
||||
|
||||
--- Model pricing table (per 1M tokens in USD)
-- Keys are normalized model names (see normalize_model). `cached_input` is
-- nil when the model has no cached-input tier (M.calculate_cost then falls
-- back to the full input rate); `output` is nil for image models not billed
-- per output token. Update these values when providers revise pricing.
---@type table<string, {input: number, cached_input: number|nil, output: number|nil}>
M.pricing = {
  -- GPT-5.x series
  ["gpt-5.2"] = { input = 1.75, cached_input = 0.175, output = 14.00 },
  ["gpt-5.1"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
  ["gpt-5"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
  ["gpt-5-mini"] = { input = 0.25, cached_input = 0.025, output = 2.00 },
  ["gpt-5-nano"] = { input = 0.05, cached_input = 0.005, output = 0.40 },
  ["gpt-5.2-chat-latest"] = { input = 1.75, cached_input = 0.175, output = 14.00 },
  ["gpt-5.1-chat-latest"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
  ["gpt-5-chat-latest"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
  ["gpt-5.2-codex"] = { input = 1.75, cached_input = 0.175, output = 14.00 },
  ["gpt-5.1-codex-max"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
  ["gpt-5.1-codex"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
  ["gpt-5-codex"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
  ["gpt-5.2-pro"] = { input = 21.00, cached_input = nil, output = 168.00 },
  ["gpt-5-pro"] = { input = 15.00, cached_input = nil, output = 120.00 },
  ["gpt-5.1-codex-mini"] = { input = 0.25, cached_input = 0.025, output = 2.00 },
  ["gpt-5-search-api"] = { input = 1.25, cached_input = 0.125, output = 10.00 },

  -- GPT-4.x series
  ["gpt-4.1"] = { input = 2.00, cached_input = 0.50, output = 8.00 },
  ["gpt-4.1-mini"] = { input = 0.40, cached_input = 0.10, output = 1.60 },
  ["gpt-4.1-nano"] = { input = 0.10, cached_input = 0.025, output = 0.40 },
  ["gpt-4o"] = { input = 2.50, cached_input = 1.25, output = 10.00 },
  ["gpt-4o-2024-05-13"] = { input = 5.00, cached_input = nil, output = 15.00 },
  ["gpt-4o-mini"] = { input = 0.15, cached_input = 0.075, output = 0.60 },

  -- Realtime models
  ["gpt-realtime"] = { input = 4.00, cached_input = 0.40, output = 16.00 },
  ["gpt-realtime-mini"] = { input = 0.60, cached_input = 0.06, output = 2.40 },
  ["gpt-4o-realtime-preview"] = { input = 5.00, cached_input = 2.50, output = 20.00 },
  ["gpt-4o-mini-realtime-preview"] = { input = 0.60, cached_input = 0.30, output = 2.40 },

  -- Audio models
  ["gpt-audio"] = { input = 2.50, cached_input = nil, output = 10.00 },
  ["gpt-audio-mini"] = { input = 0.60, cached_input = nil, output = 2.40 },
  ["gpt-4o-audio-preview"] = { input = 2.50, cached_input = nil, output = 10.00 },
  ["gpt-4o-mini-audio-preview"] = { input = 0.15, cached_input = nil, output = 0.60 },

  -- O-series reasoning models
  ["o1"] = { input = 15.00, cached_input = 7.50, output = 60.00 },
  ["o1-pro"] = { input = 150.00, cached_input = nil, output = 600.00 },
  ["o3-pro"] = { input = 20.00, cached_input = nil, output = 80.00 },
  ["o3"] = { input = 2.00, cached_input = 0.50, output = 8.00 },
  ["o3-deep-research"] = { input = 10.00, cached_input = 2.50, output = 40.00 },
  ["o4-mini"] = { input = 1.10, cached_input = 0.275, output = 4.40 },
  ["o4-mini-deep-research"] = { input = 2.00, cached_input = 0.50, output = 8.00 },
  ["o3-mini"] = { input = 1.10, cached_input = 0.55, output = 4.40 },
  ["o1-mini"] = { input = 1.10, cached_input = 0.55, output = 4.40 },

  -- Codex
  ["codex-mini-latest"] = { input = 1.50, cached_input = 0.375, output = 6.00 },

  -- Search models
  ["gpt-4o-mini-search-preview"] = { input = 0.15, cached_input = nil, output = 0.60 },
  ["gpt-4o-search-preview"] = { input = 2.50, cached_input = nil, output = 10.00 },

  -- Computer use
  ["computer-use-preview"] = { input = 3.00, cached_input = nil, output = 12.00 },

  -- Image models (output may not be token-billed, hence nil)
  ["gpt-image-1.5"] = { input = 5.00, cached_input = 1.25, output = 10.00 },
  ["chatgpt-image-latest"] = { input = 5.00, cached_input = 1.25, output = 10.00 },
  ["gpt-image-1"] = { input = 5.00, cached_input = 1.25, output = nil },
  ["gpt-image-1-mini"] = { input = 2.00, cached_input = 0.20, output = nil },

  -- Claude models (Anthropic)
  ["claude-3-opus"] = { input = 15.00, cached_input = 7.50, output = 75.00 },
  ["claude-3-sonnet"] = { input = 3.00, cached_input = 1.50, output = 15.00 },
  ["claude-3-haiku"] = { input = 0.25, cached_input = 0.125, output = 1.25 },
  ["claude-3.5-sonnet"] = { input = 3.00, cached_input = 1.50, output = 15.00 },
  ["claude-3.5-haiku"] = { input = 0.80, cached_input = 0.40, output = 4.00 },

  -- Ollama/Local models (free)
  ["ollama"] = { input = 0, cached_input = 0, output = 0 },
  ["codellama"] = { input = 0, cached_input = 0, output = 0 },
  ["llama2"] = { input = 0, cached_input = 0, output = 0 },
  ["llama3"] = { input = 0, cached_input = 0, output = 0 },
  ["mistral"] = { input = 0, cached_input = 0, output = 0 },
  ["deepseek-coder"] = { input = 0, cached_input = 0, output = 0 },

  -- Copilot (included in subscription, but tracking usage)
  ["copilot"] = { input = 0, cached_input = 0, output = 0 },
}
|
||||
|
||||
---@class CostUsage
---@field model string Model name
---@field input_tokens number Input tokens used
---@field output_tokens number Output tokens used
---@field cached_tokens number Cached input tokens
---@field timestamp number Unix timestamp
---@field cost number Calculated cost in USD

---@class CostState
---@field usage CostUsage[] Current session usage
---@field all_usage CostUsage[] All historical usage from brain
---@field session_start number Session start timestamp
---@field win number|nil Window handle
---@field buf number|nil Buffer handle
---@field loaded boolean Whether historical data has been loaded

-- Module-private state. `usage` holds only this session's records; they are
-- merged with `all_usage` when persisted (see save_to_disk). `all_usage` is
-- populated lazily from disk by M.load_from_history (guarded by `loaded`).
local state = {
  usage = {},
  all_usage = {},
  session_start = os.time(),
  win = nil,
  buf = nil,
  loaded = false,
}
|
||||
|
||||
--- Load historical usage from disk
-- Idempotent: runs once per session, guarded by state.loaded. Malformed or
-- missing files are silently ignored (history simply starts empty).
function M.load_from_history()
  if state.loaded then
    return
  end

  local raw = utils.read_file(get_history_path())
  if raw and raw ~= "" then
    local ok, parsed = pcall(vim.json.decode, raw)
    if ok and parsed and parsed.usage then
      state.all_usage = parsed.usage
    end
  end

  state.loaded = true
end
|
||||
|
||||
--- Save all usage to disk (debounced)
-- Pending-write timer handle; replaced on every call so rapid successive
-- recordings collapse into a single write.
local save_timer = nil

--- Persist session + historical usage to the cost history file.
-- Debounced by 500ms: only the last call in a burst actually writes.
local function save_to_disk()
  -- Cancel any pending write. The original only called stop(), which leaks
  -- the libuv handle; close() releases it.
  if save_timer then
    save_timer:stop()
    if not save_timer:is_closing() then
      save_timer:close()
    end
    save_timer = nil
  end

  save_timer = vim.loop.new_timer()
  save_timer:start(500, 0, vim.schedule_wrap(function()
    local history_path = get_history_path()

    -- Ensure the .coder directory exists before writing.
    utils.ensure_dir(vim.fn.fnamemodify(history_path, ":h"))

    -- Merge session usage into a copy of the historical record.
    local all_data = vim.deepcopy(state.all_usage)
    for _, usage in ipairs(state.usage) do
      table.insert(all_data, usage)
    end

    local ok, json = pcall(vim.json.encode, {
      version = 1,
      updated = os.time(),
      usage = all_data,
    })
    if ok then
      utils.write_file(history_path, json)
    end

    -- Free the fired timer handle instead of just dropping the reference.
    if save_timer then
      if not save_timer:is_closing() then
        save_timer:close()
      end
      save_timer = nil
    end
  end))
end
|
||||
|
||||
--- Normalize model name for pricing lookup
-- Lowercases, collapses Copilot variants, strips vendor prefixes, then looks
-- for an exact pricing key and finally a substring match.
---@param model string|nil Model name from API
---@return string Normalized model name (may be absent from M.pricing)
local function normalize_model(model)
  if not model then
    return "unknown"
  end

  local normalized = model:lower()

  -- All Copilot variants are billed as "copilot".
  if normalized:match("copilot") then
    return "copilot"
  end

  -- Strip vendor prefixes such as "openai/gpt-4o" or "anthropic/claude-...".
  normalized = normalized:gsub("^openai/", ""):gsub("^anthropic/", "")

  -- Exact pricing key wins.
  if M.pricing[normalized] then
    return normalized
  end

  -- Substring fallback. Pricing keys contain Lua-pattern magic characters
  -- ("-" is lazy repeat, "." is any-char), so the original's string.match
  -- could mis-match; use plain find instead. Prefer the longest matching
  -- key so "gpt-4o-mini-2024..." resolves to "gpt-4o-mini" rather than
  -- whichever key pairs() happened to visit first.
  local best
  for price_model in pairs(M.pricing) do
    if normalized:find(price_model, 1, true) or price_model:find(normalized, 1, true) then
      if not best or #price_model > #best then
        best = price_model
      end
    end
  end
  if best then
    return best
  end

  -- Unknown model: return the normalized name (costs 0 downstream).
  return normalized
end
|
||||
|
||||
--- Check if a model is considered "free" (local/Ollama/Copilot subscription)
|
||||
---@param model string Model name
|
||||
---@return boolean True if free
|
||||
function M.is_free_model(model)
|
||||
local normalized = normalize_model(model)
|
||||
|
||||
-- Check direct match
|
||||
if M.free_models[normalized] then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Check if it's an Ollama model (any model with : in name like deepseek-coder:6.7b)
|
||||
if model:match(":") then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Check pricing - if cost is 0, it's free
|
||||
local pricing = M.pricing[normalized]
|
||||
if pricing and pricing.input == 0 and pricing.output == 0 then
|
||||
return true
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Calculate cost for token usage
-- Prices are per 1M tokens (see M.pricing). Unknown models cost 0.
---@param model string Model name
---@param input_tokens number Input tokens
---@param output_tokens number Output tokens
---@param cached_tokens? number Cached input tokens (subset of input_tokens)
---@return number Cost in USD
function M.calculate_cost(model, input_tokens, output_tokens, cached_tokens)
  local pricing = M.pricing[normalize_model(model)]
  if not pricing then
    -- No pricing data for this model: report zero rather than guessing.
    return 0
  end

  cached_tokens = cached_tokens or 0
  -- Clamp: a caller reporting more cached than total input tokens must not
  -- produce a negative billable amount (original could go negative).
  local regular_input = math.max(input_tokens - cached_tokens, 0)

  -- Cached input falls back to the full input rate when the model has no
  -- cached tier; missing output pricing counts as 0.
  local input_cost = (regular_input / 1000000) * (pricing.input or 0)
  local cached_cost = (cached_tokens / 1000000) * (pricing.cached_input or pricing.input or 0)
  local output_cost = (output_tokens / 1000000) * (pricing.output or 0)

  return input_cost + cached_cost + output_cost
end
|
||||
|
||||
--- Calculate estimated savings (what would have been paid if using comparison model)
-- Delegates to M.calculate_cost with M.comparison_model substituted for the
-- actual (free) model.
---@param input_tokens number Input tokens
---@param output_tokens number Output tokens
---@param cached_tokens? number Cached input tokens
---@return number Estimated savings in USD
function M.calculate_savings(input_tokens, output_tokens, cached_tokens)
  -- Calculate what it would have cost with the comparison model
  return M.calculate_cost(M.comparison_model, input_tokens, output_tokens, cached_tokens)
end
|
||||
|
||||
--- Record token usage for a single LLM request.
--- Computes cost and (for free models) savings, appends the record to the
--- session state, persists to disk, and refreshes the cost window if open.
---@param model string Model name
---@param input_tokens number Input tokens
---@param output_tokens number Output tokens
---@param cached_tokens? number Cached input tokens
function M.record_usage(model, input_tokens, output_tokens, cached_tokens)
  cached_tokens = cached_tokens or 0
  local cost = M.calculate_cost(model, input_tokens, output_tokens, cached_tokens)

  -- Hoisted: is_free_model was previously evaluated twice per call.
  local is_free = M.is_free_model(model)

  -- Free models accrue "savings" relative to the comparison model.
  local savings = 0
  if is_free then
    savings = M.calculate_savings(input_tokens, output_tokens, cached_tokens)
  end

  table.insert(state.usage, {
    model = model,
    input_tokens = input_tokens,
    output_tokens = output_tokens,
    cached_tokens = cached_tokens,
    timestamp = os.time(),
    cost = cost,
    savings = savings,
    is_free = is_free,
  })

  -- Save to disk (debounced)
  save_to_disk()

  -- Live-update the floating window when visible.
  if state.win and vim.api.nvim_win_is_valid(state.win) then
    M.refresh_window()
  end
end
|
||||
|
||||
--- Aggregate usage data into stats
-- Single pass over the records: sums token counts / cost / savings, counts
-- free vs paid requests, and builds a per-model breakdown keyed by model name.
---@param usage_list CostUsage[] List of usage records
---@return table Stats (total_* sums, free/paid request counts, by_model map)
local function aggregate_usage(usage_list)
  local stats = {
    total_input = 0,
    total_output = 0,
    total_cached = 0,
    total_cost = 0,
    total_savings = 0,
    free_requests = 0,
    paid_requests = 0,
    by_model = {},
    request_count = #usage_list,
  }

  for _, usage in ipairs(usage_list) do
    stats.total_input = stats.total_input + (usage.input_tokens or 0)
    stats.total_output = stats.total_output + (usage.output_tokens or 0)
    stats.total_cached = stats.total_cached + (usage.cached_tokens or 0)
    stats.total_cost = stats.total_cost + (usage.cost or 0)

    -- Track savings
    local usage_savings = usage.savings or 0
    -- For historical data without savings field, calculate it
    -- (records written before the savings/is_free fields existed have
    -- is_free == nil, so they are recomputed here)
    if usage_savings == 0 and usage.is_free == nil then
      local model = usage.model or "unknown"
      if M.is_free_model(model) then
        usage_savings = M.calculate_savings(
          usage.input_tokens or 0,
          usage.output_tokens or 0,
          usage.cached_tokens or 0
        )
      end
    end
    stats.total_savings = stats.total_savings + usage_savings

    -- Track free vs paid; legacy records missing the flag are re-classified
    local is_free = usage.is_free
    if is_free == nil then
      is_free = M.is_free_model(usage.model or "unknown")
    end
    if is_free then
      stats.free_requests = stats.free_requests + 1
    else
      stats.paid_requests = stats.paid_requests + 1
    end

    local model = usage.model or "unknown"
    if not stats.by_model[model] then
      -- First record for this model: initialize its accumulator entry.
      stats.by_model[model] = {
        input_tokens = 0,
        output_tokens = 0,
        cached_tokens = 0,
        cost = 0,
        savings = 0,
        requests = 0,
        is_free = is_free,
      }
    end

    stats.by_model[model].input_tokens = stats.by_model[model].input_tokens + (usage.input_tokens or 0)
    stats.by_model[model].output_tokens = stats.by_model[model].output_tokens + (usage.output_tokens or 0)
    stats.by_model[model].cached_tokens = stats.by_model[model].cached_tokens + (usage.cached_tokens or 0)
    stats.by_model[model].cost = stats.by_model[model].cost + (usage.cost or 0)
    stats.by_model[model].savings = stats.by_model[model].savings + usage_savings
    stats.by_model[model].requests = stats.by_model[model].requests + 1
  end

  return stats
end
|
||||
|
||||
--- Get session statistics
-- Aggregates only the current session's usage records and attaches the
-- elapsed session duration in seconds.
---@return table Statistics
function M.get_stats()
  local stats = aggregate_usage(state.usage)
  stats.session_duration = os.time() - state.session_start
  return stats
end
|
||||
|
||||
--- Get all-time statistics (session + historical).
--- Merges persisted history with the current session's records, aggregates
--- them, and attaches `time_span` (seconds from oldest record to now).
---@return table Statistics
function M.get_all_time_stats()
  -- Make sure historical usage has been read from disk.
  M.load_from_history()

  -- Historical records are deep-copied; session records are appended as-is.
  local combined = vim.deepcopy(state.all_usage)
  vim.list_extend(combined, state.usage)

  local stats = aggregate_usage(combined)

  -- Span from the oldest recorded timestamp to now (0 when empty).
  stats.time_span = 0
  if #combined > 0 then
    local now = os.time()
    local oldest = now
    for _, entry in ipairs(combined) do
      oldest = math.min(oldest, entry.timestamp or now)
    end
    stats.time_span = now - oldest
  end

  return stats
end
|
||||
|
||||
--- Format a USD amount with precision scaled to its magnitude.
--- Sub-cent values get 4 decimals, sub-dollar 3, otherwise 2.
---@param cost number Cost in USD
---@return string Formatted cost, e.g. "$0.0042", "$0.123", "$1.23"
local function format_cost(cost)
  local fmt = "$%.2f"
  if cost < 0.01 then
    fmt = "$%.4f"
  elseif cost < 1 then
    fmt = "$%.3f"
  end
  return fmt:format(cost)
end
|
||||
|
||||
--- Format a token count using K/M suffixes for readability.
---@param tokens number Token count
---@return string Formatted count, e.g. "500", "1.5K", "2.50M"
local function format_tokens(tokens)
  if tokens >= 1000000 then
    return ("%.2fM"):format(tokens / 1000000)
  end
  if tokens >= 1000 then
    return ("%.1fK"):format(tokens / 1000)
  end
  return tostring(tokens)
end
|
||||
|
||||
--- Format a duration in seconds as "Ns", "Nm Ns", or "Nh Nm".
---@param seconds number Duration in whole seconds
---@return string Formatted duration
local function format_duration(seconds)
  if seconds < 60 then
    return ("%ds"):format(seconds)
  end
  if seconds < 3600 then
    return ("%dm %ds"):format(math.floor(seconds / 60), seconds % 60)
  end
  local hours = math.floor(seconds / 3600)
  local mins = math.floor((seconds % 3600) / 60)
  return ("%dh %dm"):format(hours, mins)
end
|
||||
|
||||
--- Generate model breakdown section
-- Renders one sub-section per model, sorted by descending cost. Free models
-- show accumulated savings; paid models show cost and the per-1M rate.
---@param stats table Stats with by_model
---@return string[] Lines
local function generate_model_breakdown(stats)
  local lines = {}

  if next(stats.by_model) then
    -- Sort models by cost (descending)
    local models = {}
    for model, data in pairs(stats.by_model) do
      table.insert(models, { name = model, data = data })
    end
    table.sort(models, function(a, b)
      return a.data.cost > b.data.cost
    end)

    for _, item in ipairs(models) do
      local model = item.name
      local data = item.data
      local pricing = M.pricing[normalize_model(model)]
      -- Fall back to a live check for records written without the flag.
      local is_free = data.is_free or M.is_free_model(model)

      table.insert(lines, "")
      local model_icon = is_free and "🆓" or "💳"
      table.insert(lines, string.format("  %s %s", model_icon, model))
      table.insert(lines, string.format("     Requests: %d", data.requests))
      table.insert(lines, string.format("     Input:  %s tokens", format_tokens(data.input_tokens)))
      table.insert(lines, string.format("     Output: %s tokens", format_tokens(data.output_tokens)))

      if is_free then
        -- Show savings for free models
        if data.savings and data.savings > 0 then
          table.insert(lines, string.format("     Saved:  %s", format_cost(data.savings)))
        end
      else
        table.insert(lines, string.format("     Cost:   %s", format_cost(data.cost)))
      end

      -- Show pricing info for paid models
      if pricing and not is_free then
        local price_info = string.format(
          "     Rate: $%.2f/1M in, $%.2f/1M out",
          pricing.input or 0,
          pricing.output or 0
        )
        table.insert(lines, price_info)
      end
    end
  else
    table.insert(lines, "  No usage recorded.")
  end

  return lines
end
|
||||
|
||||
--- Generate window content
-- Builds the full buffer text for the cost window: header, all-time summary,
-- current-session summary, per-model breakdown, and a keymap footer.
-- Line order here is the visual layout — do not reorder casually.
---@return string[] Lines for the buffer
local function generate_content()
  local session_stats = M.get_stats()
  local all_time_stats = M.get_all_time_stats()
  local lines = {}

  -- Header
  table.insert(lines, "╔══════════════════════════════════════════════════════╗")
  table.insert(lines, "║           💰 LLM Cost Estimation                     ║")
  table.insert(lines, "╠══════════════════════════════════════════════════════╣")
  table.insert(lines, "")

  -- All-time summary (prominent)
  table.insert(lines, "🌐 All-Time Summary (Project)")
  table.insert(lines, "───────────────────────────────────────────────────────")
  if all_time_stats.time_span > 0 then
    table.insert(lines, string.format("  Time span:     %s", format_duration(all_time_stats.time_span)))
  end
  table.insert(lines, string.format("  Requests:      %d total", all_time_stats.request_count))
  table.insert(lines, string.format("  Local/Free:    %d requests", all_time_stats.free_requests or 0))
  table.insert(lines, string.format("  Paid API:      %d requests", all_time_stats.paid_requests or 0))
  table.insert(lines, string.format("  Input tokens:  %s", format_tokens(all_time_stats.total_input)))
  table.insert(lines, string.format("  Output tokens: %s", format_tokens(all_time_stats.total_output)))
  if all_time_stats.total_cached > 0 then
    table.insert(lines, string.format("  Cached tokens: %s", format_tokens(all_time_stats.total_cached)))
  end
  table.insert(lines, "")
  table.insert(lines, string.format("  💵 Total Cost:  %s", format_cost(all_time_stats.total_cost)))

  -- Show savings prominently if there are any
  if all_time_stats.total_savings and all_time_stats.total_savings > 0 then
    table.insert(lines, string.format("  💚 Saved:       %s (vs %s)", format_cost(all_time_stats.total_savings), M.comparison_model))
  end
  table.insert(lines, "")

  -- Session summary
  table.insert(lines, "📊 Current Session")
  table.insert(lines, "───────────────────────────────────────────────────────")
  table.insert(lines, string.format("  Duration:      %s", format_duration(session_stats.session_duration)))
  table.insert(lines, string.format("  Requests:      %d (%d free, %d paid)",
    session_stats.request_count,
    session_stats.free_requests or 0,
    session_stats.paid_requests or 0))
  table.insert(lines, string.format("  Input tokens:  %s", format_tokens(session_stats.total_input)))
  table.insert(lines, string.format("  Output tokens: %s", format_tokens(session_stats.total_output)))
  if session_stats.total_cached > 0 then
    table.insert(lines, string.format("  Cached tokens: %s", format_tokens(session_stats.total_cached)))
  end
  table.insert(lines, string.format("  Session Cost:  %s", format_cost(session_stats.total_cost)))
  if session_stats.total_savings and session_stats.total_savings > 0 then
    table.insert(lines, string.format("  Session Saved: %s", format_cost(session_stats.total_savings)))
  end
  table.insert(lines, "")

  -- Per-model breakdown (all-time)
  table.insert(lines, "📈 Cost by Model (All-Time)")
  table.insert(lines, "───────────────────────────────────────────────────────")
  local model_lines = generate_model_breakdown(all_time_stats)
  for _, line in ipairs(model_lines) do
    table.insert(lines, line)
  end

  table.insert(lines, "")
  table.insert(lines, "───────────────────────────────────────────────────────")
  table.insert(lines, "  'q' close | 'r' refresh | 'c' clear session | 'C' all")
  table.insert(lines, "╚══════════════════════════════════════════════════════╝")

  return lines
end
|
||||
|
||||
--- Refresh the cost window content.
--- No-op when the backing buffer no longer exists. The buffer is briefly
--- made modifiable to rewrite its lines, then locked again.
function M.refresh_window()
  local buf = state.buf
  if not (buf and vim.api.nvim_buf_is_valid(buf)) then
    return
  end

  vim.bo[buf].modifiable = true
  vim.api.nvim_buf_set_lines(buf, 0, -1, false, generate_content())
  vim.bo[buf].modifiable = false
end
|
||||
|
||||
--- Open the cost estimation window
-- Creates a scratch buffer in a centered floating window, populates it,
-- and wires up buffer-local keymaps and highlight matches.
function M.open()
  -- Load historical data if not loaded
  M.load_from_history()

  -- Close existing window if open
  if state.win and vim.api.nvim_win_is_valid(state.win) then
    vim.api.nvim_win_close(state.win, true)
  end

  -- Create scratch buffer (unlisted, wiped when hidden)
  state.buf = vim.api.nvim_create_buf(false, true)
  vim.bo[state.buf].buftype = "nofile"
  vim.bo[state.buf].bufhidden = "wipe"
  vim.bo[state.buf].swapfile = false
  vim.bo[state.buf].filetype = "codetyper-cost"

  -- Calculate window size. Clamp to the editor dimensions so a small
  -- terminal never yields a negative row/col, which nvim_open_win rejects.
  local width = math.min(58, math.max(vim.o.columns - 2, 1))
  local height = math.min(40, math.max(vim.o.lines - 4, 1))
  local row = math.max(0, math.floor((vim.o.lines - height) / 2))
  local col = math.max(0, math.floor((vim.o.columns - width) / 2))

  -- Create floating window
  state.win = vim.api.nvim_open_win(state.buf, true, {
    relative = "editor",
    width = width,
    height = height,
    row = row,
    col = col,
    style = "minimal",
    border = "rounded",
    title = " Cost Estimation ",
    title_pos = "center",
  })

  -- Set window options
  vim.wo[state.win].wrap = false
  vim.wo[state.win].cursorline = false

  -- Populate content
  M.refresh_window()

  -- Buffer-local keymaps: q/<Esc> close, r refresh, c clear session, C clear all
  local opts = { buffer = state.buf, silent = true }
  vim.keymap.set("n", "q", function()
    M.close()
  end, opts)
  vim.keymap.set("n", "<Esc>", function()
    M.close()
  end, opts)
  vim.keymap.set("n", "r", function()
    M.refresh_window()
  end, opts)
  vim.keymap.set("n", "c", function()
    M.clear_session()
    M.refresh_window()
  end, opts)
  vim.keymap.set("n", "C", function()
    M.clear_all()
    M.refresh_window()
  end, opts)

  -- Set up highlights (best-effort pattern matches on rendered text)
  vim.api.nvim_buf_call(state.buf, function()
    vim.fn.matchadd("Title", "LLM Cost Estimation")
    vim.fn.matchadd("Number", "\\$[0-9.]*")
    vim.fn.matchadd("Keyword", "[0-9.]*[KM]\\? tokens")
    vim.fn.matchadd("Special", "🤖\\|💰\\|📊\\|📈\\|💵")
  end)
end
|
||||
|
||||
--- Close the cost window
-- Safe to call when the window is already gone; always clears the
-- buffer/window references so a stale handle is never reused.
function M.close()
  if state.win and vim.api.nvim_win_is_valid(state.win) then
    vim.api.nvim_win_close(state.win, true)
  end
  state.win = nil
  state.buf = nil
end
|
||||
|
||||
--- Toggle the cost window: close it when visible, open it otherwise.
function M.toggle()
  local is_open = state.win ~= nil and vim.api.nvim_win_is_valid(state.win)
  if is_open then
    M.close()
  else
    M.open()
  end
end
|
||||
|
||||
--- Clear session usage (not history)
-- Resets the in-memory usage list and restarts the session clock;
-- persisted history on disk is left untouched.
function M.clear_session()
  state.usage = {}
  state.session_start = os.time()
  utils.notify("Session cost tracking cleared", vim.log.levels.INFO)
end
|
||||
|
||||
--- Clear all history (session + saved)
-- Wipes in-memory session and all-time usage, resets the session clock,
-- and deletes the on-disk history file. `loaded = false` forces a
-- re-read from disk on the next access.
function M.clear_all()
  state.usage = {}
  state.all_usage = {}
  state.session_start = os.time()
  state.loaded = false

  -- Delete history file
  local history_path = get_history_path()
  local ok, err = os.remove(history_path)
  -- A missing file is not worth reporting.
  -- NOTE(review): the "No such file" substring is platform/locale
  -- dependent — confirm on supported OSes.
  if not ok and err and not err:match("No such file") then
    utils.notify("Failed to delete history: " .. err, vim.log.levels.WARN)
  end

  utils.notify("All cost history cleared", vim.log.levels.INFO)
end
|
||||
|
||||
--- Clear usage history (alias for clear_session)
-- Kept for backward compatibility with callers of the old name.
function M.clear()
  M.clear_session()
end
|
||||
|
||||
--- Reset session
-- Alias for clear_session; kept for API symmetry.
function M.reset()
  M.clear_session()
end
|
||||
|
||||
return M
|
||||
602
lua/codetyper/credentials.lua
Normal file
602
lua/codetyper/credentials.lua
Normal file
@@ -0,0 +1,602 @@
|
||||
---@mod codetyper.credentials Secure credential storage for Codetyper.nvim
|
||||
---@brief [[
|
||||
--- Manages API keys and model preferences stored outside of config files.
|
||||
--- Credentials are stored in ~/.local/share/nvim/codetyper/configuration.json
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
|
||||
--- Get the absolute path of the credentials file under Neovim's data dir.
---@return string Path to credentials file
local function get_credentials_path()
  return string.format("%s/codetyper/configuration.json", vim.fn.stdpath("data"))
end
|
||||
|
||||
--- Ensure the directory that holds the credentials file exists.
---@return boolean Success
local function ensure_dir()
  return utils.ensure_dir(vim.fn.stdpath("data") .. "/codetyper")
end
|
||||
|
||||
--- Load credentials from disk.
--- Returns a fresh empty skeleton when the file is missing, empty,
--- or cannot be parsed as JSON.
---@return table Credentials data ({ version, providers, ... })
function M.load()
  local content = utils.read_file(get_credentials_path())
  if not content or content == "" then
    return { version = 1, providers = {} }
  end

  local ok, data = pcall(vim.json.decode, content)
  if ok and data then
    return data
  end
  -- Corrupt file: fall back to an empty store rather than erroring.
  return { version = 1, providers = {} }
end
|
||||
|
||||
--- Persist credentials to disk as JSON.
---@param data table Credentials data
---@return boolean Success (false if the dir, encoding, or write fails)
function M.save(data)
  if not ensure_dir() then
    return false
  end

  local encoded_ok, json = pcall(vim.json.encode, data)
  if not encoded_ok then
    return false
  end

  return utils.write_file(get_credentials_path(), json)
end
|
||||
|
||||
--- Read one field from a provider's stored entry.
-- Shared helper: the four public getters below previously duplicated
-- this load-and-lookup logic verbatim.
---@param provider string Provider name
---@param field string Field key within the provider entry
---@return any|nil Field value, or nil when absent/falsy
local function get_provider_field(provider, field)
  local data = M.load()
  local provider_data = data.providers and data.providers[provider]

  if provider_data and provider_data[field] then
    return provider_data[field]
  end

  return nil
end

--- Get API key for a provider
---@param provider string Provider name (claude, openai, gemini, copilot, ollama)
---@return string|nil API key or nil if not found
function M.get_api_key(provider)
  return get_provider_field(provider, "api_key")
end

--- Get model for a provider
---@param provider string Provider name
---@return string|nil Model name or nil if not found
function M.get_model(provider)
  return get_provider_field(provider, "model")
end

--- Get endpoint for a provider (for custom OpenAI-compatible endpoints)
---@param provider string Provider name
---@return string|nil Endpoint URL or nil if not found
function M.get_endpoint(provider)
  return get_provider_field(provider, "endpoint")
end

--- Get host for Ollama
---@return string|nil Host URL or nil if not found
function M.get_ollama_host()
  return get_provider_field("ollama", "host")
end
|
||||
|
||||
--- Set (merge) credentials for a provider.
--- Only non-nil, non-empty values are written; existing fields are kept.
---@param provider string Provider name
---@param credentials table Credentials (api_key, model, endpoint, host)
---@return boolean Success
function M.set_credentials(provider, credentials)
  local data = M.load()

  data.providers = data.providers or {}
  data.providers[provider] = data.providers[provider] or {}
  local entry = data.providers[provider]

  -- Merge: empty strings and nils never clobber stored values.
  for key, value in pairs(credentials) do
    if value and value ~= "" then
      entry[key] = value
    end
  end

  data.updated = os.time()

  return M.save(data)
end
|
||||
|
||||
--- Remove stored credentials for a provider.
--- Succeeds trivially when nothing is stored for the provider.
---@param provider string Provider name
---@return boolean Success
function M.remove_credentials(provider)
  local data = M.load()
  local entry = data.providers and data.providers[provider]

  if not entry then
    return true
  end

  data.providers[provider] = nil
  data.updated = os.time()
  return M.save(data)
end
|
||||
|
||||
--- List all configured providers (checks both stored credentials AND config)
-- A provider counts as configured if it has stored credentials, or is
-- supplied via the user's setup() config / environment variables.
---@return table List of status tables:
---  { name, configured, has_api_key, has_model, model, source }
function M.list_providers()
  local data = M.load()
  local result = {}

  -- Loop-invariant hoist: resolve the runtime config once, not once per
  -- provider (previously pcall(require)+get_config ran on every iteration).
  local config = nil
  local ok, codetyper = pcall(require, "codetyper")
  if ok then
    config = codetyper.get_config()
  end

  local all_providers = { "claude", "openai", "gemini", "copilot", "ollama" }

  for _, provider in ipairs(all_providers) do
    local provider_data = data.providers and data.providers[provider]
    local has_stored_key = provider_data and provider_data.api_key and provider_data.api_key ~= ""
    local has_model = provider_data and provider_data.model and provider_data.model ~= ""

    -- Check if configured from config or environment
    local configured_from_config = false
    local config_model = nil
    if config and config.llm and config.llm[provider] then
      local pc = config.llm[provider]
      config_model = pc.model

      if provider == "claude" then
        configured_from_config = pc.api_key ~= nil or vim.env.ANTHROPIC_API_KEY ~= nil
      elseif provider == "openai" then
        configured_from_config = pc.api_key ~= nil or vim.env.OPENAI_API_KEY ~= nil
      elseif provider == "gemini" then
        configured_from_config = pc.api_key ~= nil or vim.env.GEMINI_API_KEY ~= nil
      elseif provider == "copilot" then
        configured_from_config = true -- Just needs copilot.lua
      elseif provider == "ollama" then
        configured_from_config = pc.host ~= nil
      end
    end

    -- Ollama/Copilot count as configured with any stored entry (no key).
    local is_configured = has_stored_key
      or (provider == "ollama" and provider_data ~= nil)
      or (provider == "copilot" and (provider_data ~= nil or configured_from_config))
      or configured_from_config

    table.insert(result, {
      name = provider,
      configured = is_configured,
      has_api_key = has_stored_key,
      has_model = has_model or config_model ~= nil,
      model = (provider_data and provider_data.model) or config_model,
      source = has_stored_key and "stored" or (configured_from_config and "config" or nil),
    })
  end

  return result
end
|
||||
|
||||
--- Default models for each provider
-- Used as the pre-filled defaults in the interactive setup prompts;
-- the user can override each at the model-selection step.
M.default_models = {
  claude = "claude-sonnet-4-20250514",
  openai = "gpt-4o",
  gemini = "gemini-2.0-flash",
  copilot = "gpt-4o",
  ollama = "deepseek-coder:6.7b",
}
|
||||
|
||||
--- Interactive command to add/update API key (:CoderAddApiKey)
-- Flow: pick provider -> enter API key (host/model for Ollama) -> pick model.
function M.interactive_add()
  local providers = { "claude", "openai", "gemini", "copilot", "ollama" }

  -- Hoisted: the credentials file was previously re-read from disk
  -- once per rendered menu item inside format_item.
  local creds = M.load()

  -- Step 1: Select provider
  vim.ui.select(providers, {
    prompt = "Select LLM provider:",
    format_item = function(item)
      local display = item:sub(1, 1):upper() .. item:sub(2)
      local configured = creds.providers and creds.providers[item]
      -- Ollama needs no API key, so any stored entry counts as configured.
      if configured and (configured.api_key or item == "ollama") then
        return display .. " [configured]"
      end
      return display
    end,
  }, function(provider)
    if not provider then
      return
    end

    -- Step 2: Get API key (skip for Ollama)
    if provider == "ollama" then
      M.interactive_ollama_config()
    else
      M.interactive_api_key(provider)
    end
  end)
end
|
||||
|
||||
--- Interactive API key input
-- Prompts for the provider's API key, then continues to model selection.
-- Copilot is redirected to its own OAuth-based flow.
---@param provider string Provider name
function M.interactive_api_key(provider)
  -- Copilot uses OAuth from copilot.lua, no API key needed
  if provider == "copilot" then
    M.interactive_copilot_config()
    return
  end

  local prompt = string.format("Enter %s API key (leave empty to skip): ", provider:upper())

  vim.ui.input({ prompt = prompt }, function(api_key)
    -- nil means the prompt was cancelled; "" proceeds without storing a key.
    if api_key == nil then
      return -- Cancelled
    end

    -- Step 3: Get model
    M.interactive_model(provider, api_key)
  end)
end
|
||||
|
||||
--- Interactive Copilot configuration (no API key, uses OAuth)
-- Only prompts for the model; authentication is handled by copilot.lua.
function M.interactive_copilot_config()
  utils.notify("Copilot uses OAuth from copilot.lua/copilot.vim - no API key needed", vim.log.levels.INFO)

  -- Just ask for model
  local default_model = M.default_models.copilot
  vim.ui.input({
    prompt = string.format("Copilot model (default: %s): ", default_model),
    default = default_model,
  }, function(model)
    -- nil means the prompt was cancelled; "" falls back to the default.
    if model == nil then
      return -- Cancelled
    end

    if model == "" then
      model = default_model
    end

    M.save_and_notify("copilot", {
      model = model,
      -- Mark as configured even without API key
      configured = true,
    })
  end)
end
|
||||
|
||||
--- Interactive model selection
-- Prompts for the model name (pre-filled with the provider default),
-- then saves the credentials; OpenAI additionally asks for an endpoint.
---@param provider string Provider name
---@param api_key string|nil API key ("" or nil means none stored)
function M.interactive_model(provider, api_key)
  local default_model = M.default_models[provider] or ""
  local prompt = string.format("Enter model (default: %s): ", default_model)

  vim.ui.input({ prompt = prompt, default = default_model }, function(model)
    -- nil means the prompt was cancelled.
    if model == nil then
      return -- Cancelled
    end

    -- Use default if empty
    if model == "" then
      model = default_model
    end

    -- Save credentials
    local credentials = {
      model = model,
    }

    if api_key and api_key ~= "" then
      credentials.api_key = api_key
    end

    -- For OpenAI, also ask for custom endpoint
    if provider == "openai" then
      M.interactive_endpoint(provider, credentials)
    else
      M.save_and_notify(provider, credentials)
    end
  end)
end
|
||||
|
||||
--- Prompt for a custom endpoint (OpenAI-compatible providers), then save.
---@param provider string Provider name
---@param credentials table Credentials collected so far (mutated in place)
function M.interactive_endpoint(provider, credentials)
  vim.ui.input({
    prompt = "Custom endpoint (leave empty for default OpenAI): ",
  }, function(endpoint)
    -- nil = prompt cancelled; "" = keep the default endpoint.
    if endpoint == nil then
      return
    end

    if endpoint ~= "" then
      credentials.endpoint = endpoint
    end

    M.save_and_notify(provider, credentials)
  end)
end
|
||||
|
||||
--- Interactive Ollama configuration
-- Two nested prompts: host first, then model. Either prompt returning nil
-- (cancel) aborts without saving; empty answers fall back to defaults.
function M.interactive_ollama_config()
  vim.ui.input({
    prompt = "Ollama host (default: http://localhost:11434): ",
    default = "http://localhost:11434",
  }, function(host)
    if host == nil then
      return -- Cancelled
    end

    if host == "" then
      host = "http://localhost:11434"
    end

    -- Get model
    local default_model = M.default_models.ollama
    vim.ui.input({
      prompt = string.format("Ollama model (default: %s): ", default_model),
      default = default_model,
    }, function(model)
      if model == nil then
        return -- Cancelled
      end

      if model == "" then
        model = default_model
      end

      M.save_and_notify("ollama", {
        host = host,
        model = model,
      })
    end)
  end)
end
|
||||
|
||||
--- Persist provider credentials and report the outcome to the user.
---@param provider string Provider name
---@param credentials table Credentials to save
function M.save_and_notify(provider, credentials)
  if not M.set_credentials(provider, credentials) then
    utils.notify("Failed to save credentials", vim.log.levels.ERROR)
    return
  end

  local msg = string.format("Saved %s configuration", provider:upper())
  if credentials.model then
    msg = msg .. " (model: " .. credentials.model .. ")"
  end
  utils.notify(msg, vim.log.levels.INFO)
end
|
||||
|
||||
--- Show current credentials status
-- Renders a plain-text report (storage path, active provider, per-provider
-- configuration state and model) and emits it via utils.notify.
function M.show_status()
  local providers = M.list_providers()

  -- Get current active provider
  local codetyper = require("codetyper")
  local current = codetyper.get_config().llm.provider

  local lines = {
    "Codetyper Credentials Status",
    "============================",
    "",
    "Storage: " .. get_credentials_path(),
    "Active:  " .. current:upper(),
    "",
  }

  for _, p in ipairs(providers) do
    local status_icon = p.configured and "✓" or "✗"
    local active_marker = p.name == current and " [ACTIVE]" or ""
    -- "(stored)" = credentials file; "(config)" = setup() config / env var
    local source_info = ""
    if p.configured then
      source_info = p.source == "stored" and " (stored)" or " (config)"
    end
    local model_info = p.model and (" - " .. p.model) or ""

    table.insert(lines, string.format("  %s %s%s%s%s",
      status_icon,
      p.name:upper(),
      active_marker,
      source_info,
      model_info))
  end

  table.insert(lines, "")
  table.insert(lines, "Commands:")
  table.insert(lines, "  :CoderAddApiKey      - Add/update credentials")
  table.insert(lines, "  :CoderSwitchProvider - Switch active provider")
  table.insert(lines, "  :CoderRemoveApiKey   - Remove stored credentials")

  utils.notify(table.concat(lines, "\n"))
end
|
||||
|
||||
--- Interactive remove credentials
-- Lists providers that currently have stored entries, asks which to
-- remove, confirms with a Yes/No prompt, then deletes and notifies.
function M.interactive_remove()
  local data = M.load()
  local configured = {}

  -- Only offer providers that actually have a stored entry.
  for provider, _ in pairs(data.providers or {}) do
    table.insert(configured, provider)
  end

  if #configured == 0 then
    utils.notify("No credentials configured", vim.log.levels.INFO)
    return
  end

  vim.ui.select(configured, {
    prompt = "Select provider to remove:",
  }, function(provider)
    if not provider then
      return
    end

    -- Confirmation step before destructive removal.
    vim.ui.select({ "Yes", "No" }, {
      prompt = "Remove " .. provider:upper() .. " credentials?",
    }, function(choice)
      if choice == "Yes" then
        if M.remove_credentials(provider) then
          utils.notify("Removed " .. provider:upper() .. " credentials", vim.log.levels.INFO)
        else
          utils.notify("Failed to remove credentials", vim.log.levels.ERROR)
        end
      end
    end)
  end)
end
|
||||
|
||||
--- Set the active provider
--- Persists the choice to the stored-credentials file and mirrors it into
--- the in-memory runtime config so the switch takes effect immediately.
---@param provider string Provider name
function M.set_active_provider(provider)
  -- Persist to disk.
  local data = M.load()
  data.active_provider = provider
  data.updated = os.time()
  M.save(data)

  -- Also update the runtime config.
  require("codetyper").get_config().llm.provider = provider

  utils.notify("Active provider set to: " .. provider:upper(), vim.log.levels.INFO)
end
|
||||
|
||||
--- Get the active provider from stored config
---@return string|nil Active provider name, or nil when none has been stored
function M.get_active_provider()
  return M.load().active_provider
end
|
||||
|
||||
--- Check if a provider is configured (from stored credentials OR config)
---@param provider string Provider name
---@return boolean configured, string|nil source -- source is "stored" or "config"
local function is_provider_configured(provider)
  -- Stored credentials take precedence over config-file settings.
  local data = M.load()
  local stored = data.providers and data.providers[provider]
  if stored then
    -- ollama/copilot need no API key, so their mere presence counts.
    local keyless = provider == "ollama" or provider == "copilot"
    if stored.configured or stored.api_key or keyless then
      return true, "stored"
    end
  end

  -- Fall back to the codetyper runtime config.
  local ok, codetyper = pcall(require, "codetyper")
  if not ok then
    return false, nil
  end

  local config = codetyper.get_config()
  if not (config and config.llm) then
    return false, nil
  end

  local provider_config = config.llm[provider]
  if not provider_config then
    return false, nil
  end

  -- API-key providers: the key may live in config or in an env variable.
  local env_keys = {
    claude = "ANTHROPIC_API_KEY",
    openai = "OPENAI_API_KEY",
    gemini = "GEMINI_API_KEY",
  }
  local env_key = env_keys[provider]
  if env_key then
    if provider_config.api_key or vim.env[env_key] then
      return true, "config"
    end
  elseif provider == "copilot" then
    -- Copilot just needs copilot.lua installed.
    return true, "config"
  elseif provider == "ollama" then
    -- Ollama just needs a host configured.
    if provider_config.host then
      return true, "config"
    end
  end

  return false, nil
end
|
||||
|
||||
--- Interactive switch provider
--- Presents only the providers that are actually configured and activates
--- whichever one the user picks.
function M.interactive_switch_provider()
  local all_providers = { "claude", "openai", "gemini", "copilot", "ollama" }
  local available = {}
  local sources = {}

  for _, provider in ipairs(all_providers) do
    local configured, source = is_provider_configured(provider)
    if configured then
      available[#available + 1] = provider
      sources[provider] = source
    end
  end

  if #available == 0 then
    utils.notify("No providers configured. Use :CoderAddApiKey or add to your config.", vim.log.levels.WARN)
    return
  end

  local current = require("codetyper").get_config().llm.provider

  vim.ui.select(available, {
    prompt = "Select provider (current: " .. current .. "):",
    format_item = function(item)
      local marker = (item == current) and " [active]" or ""
      local source_marker = (sources[item] == "stored") and " (stored)" or " (config)"
      return item:upper() .. marker .. source_marker
    end,
  }, function(provider)
    -- provider is nil when the user cancels.
    if provider then
      M.set_active_provider(provider)
    end
  end)
end
|
||||
|
||||
return M
|
||||
@@ -36,15 +36,7 @@ function M.check()
|
||||
|
||||
health.info("LLM Provider: " .. config.llm.provider)
|
||||
|
||||
if config.llm.provider == "claude" then
|
||||
local api_key = config.llm.claude.api_key or vim.env.ANTHROPIC_API_KEY
|
||||
if api_key and api_key ~= "" then
|
||||
health.ok("Claude API key configured")
|
||||
else
|
||||
health.warn("Claude API key not set. Set ANTHROPIC_API_KEY or llm.claude.api_key")
|
||||
end
|
||||
health.info("Claude model: " .. config.llm.claude.model)
|
||||
elseif config.llm.provider == "ollama" then
|
||||
if config.llm.provider == "ollama" then
|
||||
health.info("Ollama host: " .. config.llm.ollama.host)
|
||||
health.info("Ollama model: " .. config.llm.ollama.model)
|
||||
|
||||
|
||||
585
lua/codetyper/indexer/analyzer.lua
Normal file
585
lua/codetyper/indexer/analyzer.lua
Normal file
@@ -0,0 +1,585 @@
|
||||
---@mod codetyper.indexer.analyzer Code analyzer using Tree-sitter
|
||||
---@brief [[
|
||||
--- Analyzes source files to extract functions, classes, exports, and imports.
|
||||
--- Uses Tree-sitter when available, falls back to pattern matching.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
local scanner = require("codetyper.indexer.scanner")
|
||||
|
||||
--- Language-specific query patterns for Tree-sitter.
--- Keys are Tree-sitter language names; each entry maps a category
--- (functions / classes / exports / imports) to an S-expression query.
--- Captures: @func/@class/@export/@import mark whole nodes, @name marks
--- identifier nodes. Languages absent from this table fall back to the
--- generic tree walk (analyze_tree_generic).
local TS_QUERIES = {
  lua = {
    -- Covers `function f()`, anonymous functions, `local function f()`,
    -- and `f = function()` assignments.
    functions = [[
      (function_declaration name: (identifier) @name) @func
      (function_definition) @func
      (local_function name: (identifier) @name) @func
      (assignment_statement
        (variable_list name: (identifier) @name)
        (expression_list value: (function_definition) @func))
    ]],
    -- Module export idiom: `return { ... }` at the end of the file.
    exports = [[
      (return_statement (expression_list (table_constructor))) @export
    ]],
  },
  typescript = {
    functions = [[
      (function_declaration name: (identifier) @name) @func
      (method_definition name: (property_identifier) @name) @func
      (arrow_function) @func
      (lexical_declaration
        (variable_declarator name: (identifier) @name value: (arrow_function) @func))
    ]],
    exports = [[
      (export_statement) @export
    ]],
    imports = [[
      (import_statement) @import
    ]],
  },
  javascript = {
    functions = [[
      (function_declaration name: (identifier) @name) @func
      (method_definition name: (property_identifier) @name) @func
      (arrow_function) @func
    ]],
    exports = [[
      (export_statement) @export
    ]],
    imports = [[
      (import_statement) @import
    ]],
  },
  python = {
    functions = [[
      (function_definition name: (identifier) @name) @func
    ]],
    classes = [[
      (class_definition name: (identifier) @name) @class
    ]],
    imports = [[
      (import_statement) @import
      (import_from_statement) @import
    ]],
  },
  go = {
    functions = [[
      (function_declaration name: (identifier) @name) @func
      (method_declaration name: (field_identifier) @name) @func
    ]],
    imports = [[
      (import_declaration) @import
    ]],
  },
  rust = {
    functions = [[
      (function_item name: (identifier) @name) @func
    ]],
    imports = [[
      (use_declaration) @import
    ]],
  },
}
||||
|
||||
-- Forward declaration: analyze_with_treesitter references this before the
-- actual function value is assigned further down.
local analyze_tree_generic
|
||||
|
||||
--- Hash file content for change detection.
--- Simple polynomial rolling hash over at most the first 10000 bytes, so
--- hashing stays cheap even for very large files. Not cryptographic —
--- only used to detect "file changed since last index".
---@param content string
---@return string 8-character lowercase hex digest
local function hash_content(content)
  local acc = 0
  local limit = math.min(#content, 10000)
  for i = 1, limit do
    acc = (acc * 31 + string.byte(content, i)) % 2147483647
  end
  return string.format("%08x", acc)
end
|
||||
|
||||
--- Try to get Tree-sitter parser for a language.
---@param lang string Tree-sitter language name
---@return boolean true when a parser for `lang` is installed
local function has_ts_parser(lang)
  -- inspect() errors when no parser exists; pcall converts that to a boolean.
  return (pcall(vim.treesitter.language.inspect, lang))
end
|
||||
|
||||
--- Analyze file using Tree-sitter.
--- Parses `content` in a scratch buffer and extracts functions, classes,
--- exports and imports using the language's queries from TS_QUERIES,
--- falling back to a generic tree walk when no queries exist for `lang`.
--- Fix over previous version: function-name dedupe used
--- vim.tbl_contains(vim.tbl_map(...)) inside the capture loop (O(n^2));
--- replaced with an O(1) seen-set.
---@param filepath string Unused here; kept for interface stability
---@param lang string Tree-sitter language name
---@param content string File content to analyze
---@return table|nil result {functions, classes, exports, imports}, or nil
---  when no parser is available or parsing fails
local function analyze_with_treesitter(filepath, lang, content)
  if not has_ts_parser(lang) then
    return nil
  end

  local result = {
    functions = {},
    classes = {},
    exports = {},
    imports = {},
  }

  -- Create a temporary scratch buffer so vim.treesitter can parse the text.
  local bufnr = vim.api.nvim_create_buf(false, true)
  vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, vim.split(content, "\n"))

  local ok, parser = pcall(vim.treesitter.get_parser, bufnr, lang)
  if not ok or not parser then
    vim.api.nvim_buf_delete(bufnr, { force = true })
    return nil
  end

  local tree = parser:parse()[1]
  if not tree then
    vim.api.nvim_buf_delete(bufnr, { force = true })
    return nil
  end

  local root = tree:root()
  local queries = TS_QUERIES[lang]

  if not queries then
    -- No language-specific queries: walk the tree generically.
    result = analyze_tree_generic(root, bufnr)
  else
    -- Functions: dedupe by name with a set (the same function can be hit
    -- by both the @func and @name captures).
    if queries.functions then
      local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.functions)
      if query_ok then
        local seen = {}
        for id, node in query:iter_captures(root, bufnr, 0, -1) do
          local capture_name = query.captures[id]
          if capture_name == "func" or capture_name == "name" then
            local start_row, _, end_row, _ = node:range()
            local name = nil

            -- For whole-function captures, read the name from the "name"
            -- field; for @name captures the node text IS the name.
            if capture_name == "func" then
              local name_node = node:field("name")[1]
              if name_node then
                name = vim.treesitter.get_node_text(name_node, bufnr)
              end
            else
              name = vim.treesitter.get_node_text(node, bufnr)
            end

            if name and not seen[name] then
              seen[name] = true
              table.insert(result.functions, {
                name = name,
                line = start_row + 1,
                end_line = end_row + 1,
                params = {},
              })
            end
          end
        end
      end
    end

    if queries.classes then
      local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.classes)
      if query_ok then
        for id, node in query:iter_captures(root, bufnr, 0, -1) do
          local capture_name = query.captures[id]
          if capture_name == "class" then
            local start_row, _, end_row, _ = node:range()
            local name_node = node:field("name")[1]
            local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous"

            table.insert(result.classes, {
              name = name,
              line = start_row + 1,
              end_line = end_row + 1,
              methods = {},
            })
          end
        end
      end
    end

    if queries.exports then
      local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.exports)
      if query_ok then
        for _, node in query:iter_captures(root, bufnr, 0, -1) do
          local text = vim.treesitter.get_node_text(node, bufnr)
          local start_row, _, _, _ = node:range()

          -- Extract export names (simplified textual scan of the node).
          local names = {}
          for name in text:gmatch("export%s+[%w_]+%s+([%w_]+)") do
            table.insert(names, name)
          end
          for name in text:gmatch("export%s*{([^}]+)}") do
            for n in name:gmatch("([%w_]+)") do
              table.insert(names, n)
            end
          end

          for _, name in ipairs(names) do
            table.insert(result.exports, {
              name = name,
              type = "unknown",
              line = start_row + 1,
            })
          end
        end
      end
    end

    if queries.imports then
      local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.imports)
      if query_ok then
        for _, node in query:iter_captures(root, bufnr, 0, -1) do
          local text = vim.treesitter.get_node_text(node, bufnr)
          local start_row, _, _, _ = node:range()

          -- The import source is the first quoted string in the node.
          local source = text:match('["\']([^"\']+)["\']')
          if source then
            table.insert(result.imports, {
              source = source,
              names = {},
              line = start_row + 1,
            })
          end
        end
      end
    end
  end

  vim.api.nvim_buf_delete(bufnr, { force = true })
  return result
end
|
||||
|
||||
--- Generic tree analysis for unsupported languages.
--- Depth-first walk over the syntax tree collecting anything whose node
--- type looks like a function or a class/struct/impl.
---@param root TSNode
---@param bufnr number
---@return table result {functions, classes, exports, imports}
analyze_tree_generic = function(root, bufnr)
  local result = {
    functions = {},
    classes = {},
    exports = {},
    imports = {},
  }

  -- Read the node's "name" field text, defaulting to "anonymous".
  local function node_name(node)
    local name_node = node:field("name")[1]
    return name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous"
  end

  local function visit(node)
    local node_type = node:type()
    local start_row, _, end_row, _ = node:range()

    -- Function-like nodes (heuristic on the node type name).
    if
      node_type:match("function")
      or node_type:match("method")
      or node_type == "arrow_function"
      or node_type == "func_literal"
    then
      result.functions[#result.functions + 1] = {
        name = node_name(node),
        line = start_row + 1,
        end_line = end_row + 1,
        params = {},
      }
    end

    -- Class-like nodes.
    if node_type:match("class") or node_type == "struct_item" or node_type == "impl_item" then
      result.classes[#result.classes + 1] = {
        name = node_name(node),
        line = start_row + 1,
        end_line = end_row + 1,
        methods = {},
      }
    end

    for child in node:iter_children() do
      visit(child)
    end
  end

  visit(root)
  return result
end
|
||||
|
||||
--- Analyze file using pattern matching (fallback).
--- Line-oriented scan used when no Tree-sitter parser is available.
--- Fixes over the previous version:
---  * Lua patterns have no optional-group syntax, so "local?" / "pub?"
---    only made the FINAL LETTER optional — plain `function foo` (Lua)
---    and `fn foo` (Rust) declarations were never matched. Each category
---    now tries a list of alternative patterns instead.
---  * The Rust `impl_start` pattern existed but was never checked; impl
---    blocks are now recorded as classes like the table intended.
---  * Line splitting no longer depends on vim.split (pure Lua).
---@param content string File content
---@param lang string Detected language
---@return table result {functions, classes, exports, imports}
local function analyze_with_patterns(content, lang)
  local result = {
    functions = {},
    classes = {},
    exports = {},
    imports = {},
  }

  -- Split into lines; behaves like vim.split(content, "\n").
  local lines = {}
  for line in (content .. "\n"):gmatch("([^\n]*)\n") do
    lines[#lines + 1] = line
  end

  -- Per-language pattern alternatives. Order matters: most specific first.
  local patterns = {
    lua = {
      functions = {
        "^%s*local%s+function%s+([%w_%.]+)",
        "^%s*function%s+([%w_%.]+)",
        "^%s*([%w_%.]+)%s*=%s*function",
      },
    },
    javascript = {
      functions = { "^%s*function%s+([%w_]+)" },
      arrow = "^%s*const%s+([%w_]+)%s*=%s*",
      classes = { "^%s*class%s+([%w_]+)" },
      export_line = "^%s*export%s+",
      import_line = "^%s*import%s+",
    },
    typescript = {
      functions = { "^%s*function%s+([%w_]+)" },
      arrow = "^%s*const%s+([%w_]+)%s*=%s*",
      classes = { "^%s*class%s+([%w_]+)" },
      export_line = "^%s*export%s+",
      import_line = "^%s*import%s+",
    },
    python = {
      functions = { "^%s*def%s+([%w_]+)" },
      classes = { "^%s*class%s+([%w_]+)" },
      import_line = "^%s*import%s+",
      from_import = "^%s*from%s+",
    },
    go = {
      functions = {
        "^func%s+%([^%)]+%)%s+([%w_]+)", -- methods with receivers
        "^func%s+([%w_]+)",
      },
      import_line = "^import%s+",
    },
    rust = {
      functions = {
        "^%s*pub%s+fn%s+([%w_]+)",
        "^%s*fn%s+([%w_]+)",
      },
      classes = {
        "^%s*pub%s+struct%s+([%w_]+)",
        "^%s*struct%s+([%w_]+)",
        "^%s*impl%s+([%w_<>]+)",
      },
      use_line = "^%s*use%s+",
    },
  }

  local lp = patterns[lang] or patterns.javascript

  -- Return the first capture from any pattern in `pats` matching `line`.
  local function first_match(line, pats)
    for _, pat in ipairs(pats or {}) do
      local captured = line:match(pat)
      if captured then
        return captured
      end
    end
    return nil
  end

  for i, line in ipairs(lines) do
    -- Functions
    local func_name = first_match(line, lp.functions)
    if func_name then
      table.insert(result.functions, {
        name = func_name,
        line = i,
        end_line = i,
        params = {},
      })
    end

    -- Arrow functions: require the `=>` token on the same line.
    if lp.arrow then
      local name = line:match(lp.arrow)
      if name and line:match("=>") then
        table.insert(result.functions, {
          name = name,
          line = i,
          end_line = i,
          params = {},
        })
      end
    end

    -- Classes / structs / impl blocks
    local class_name = first_match(line, lp.classes)
    if class_name then
      table.insert(result.classes, {
        name = class_name,
        line = i,
        end_line = i,
        methods = {},
      })
    end

    -- Exports
    if lp.export_line and line:match(lp.export_line) then
      local name = line:match("export%s+[%w_]+%s+([%w_]+)")
        or line:match("export%s+default%s+([%w_]+)")
        or line:match("export%s+{%s*([%w_]+)")
      if name then
        table.insert(result.exports, { name = name, type = "unknown", line = i })
      end
    end

    -- Imports (quoted module path)
    if lp.import_line and line:match(lp.import_line) then
      local source = line:match('["\']([^"\']+)["\']')
      if source then
        table.insert(result.imports, { source = source, names = {}, line = i })
      end
    end

    -- Python `from X import ...`
    if lp.from_import and line:match(lp.from_import) then
      local source = line:match("from%s+([%w_%.]+)")
      if source then
        table.insert(result.imports, { source = source, names = {}, line = i })
      end
    end

    -- Rust `use path::to::thing`
    if lp.use_line and line:match(lp.use_line) then
      local source = line:match("use%s+([%w_:]+)")
      if source then
        table.insert(result.imports, { source = source, names = {}, line = i })
      end
    end
  end

  -- For Lua, functions assigned to the module table M count as exports.
  if lang == "lua" then
    for _, func in ipairs(result.functions) do
      if func.name:match("^M%.") then
        table.insert(result.exports, {
          name = (func.name:gsub("^M%.", "")),
          type = "function",
          line = func.line,
        })
      end
    end
  end

  return result
end
|
||||
|
||||
--- Analyze a single file.
--- Reads the file, picks a Tree-sitter grammar (with a fallback to plain
--- pattern matching), and packages the results as a FileIndex.
---@param filepath string Full path to file
---@return FileIndex|nil nil when the file cannot be read
function M.analyze_file(filepath)
  local content = utils.read_file(filepath)
  if not content then
    return nil
  end

  local lang = scanner.get_language(filepath)

  -- Map detected languages to Tree-sitter grammar names.
  local ts_lang_map = {
    typescript = "typescript",
    typescriptreact = "tsx",
    javascript = "javascript",
    javascriptreact = "javascript",
    python = "python",
    go = "go",
    rust = "rust",
    lua = "lua",
  }
  local ts_lang = ts_lang_map[lang] or lang

  -- Prefer Tree-sitter; fall back to pattern matching when unavailable.
  local analysis = analyze_with_treesitter(filepath, ts_lang, content)
    or analyze_with_patterns(content, lang)

  return {
    path = filepath,
    language = lang,
    hash = hash_content(content),
    exports = analysis.exports,
    imports = analysis.imports,
    functions = analysis.functions,
    classes = analysis.classes,
    last_indexed = os.time(),
  }
end
|
||||
|
||||
--- Extract exports from a buffer.
---@param bufnr number
---@return Export[] empty when the buffer's file cannot be analyzed
function M.extract_exports(bufnr)
  local analysis = M.analyze_file(vim.api.nvim_buf_get_name(bufnr))
  return analysis and analysis.exports or {}
end
|
||||
|
||||
--- Extract functions from a buffer.
---@param bufnr number
---@return FunctionInfo[] empty when the buffer's file cannot be analyzed
function M.extract_functions(bufnr)
  local analysis = M.analyze_file(vim.api.nvim_buf_get_name(bufnr))
  return analysis and analysis.functions or {}
end
|
||||
|
||||
--- Extract imports from a buffer.
---@param bufnr number
---@return Import[] empty when the buffer's file cannot be analyzed
function M.extract_imports(bufnr)
  local analysis = M.analyze_file(vim.api.nvim_buf_get_name(bufnr))
  return analysis and analysis.imports or {}
end
|
||||
|
||||
return M
|
||||
604
lua/codetyper/indexer/init.lua
Normal file
604
lua/codetyper/indexer/init.lua
Normal file
@@ -0,0 +1,604 @@
|
||||
---@mod codetyper.indexer Project indexer for Codetyper.nvim
|
||||
---@brief [[
|
||||
--- Indexes project structure, dependencies, and code symbols.
|
||||
--- Stores knowledge in .coder/ directory for enriching LLM context.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
|
||||
--- Index schema version for migrations. load_index() rejects any on-disk
--- index whose version differs, forcing a rebuild.
local INDEX_VERSION = 1

--- Index file name (stored under <project>/.coder/).
local INDEX_FILE = "index.json"

--- Debounce timer for file indexing (timer handle; nil when idle).
local index_timer = nil
-- How long to wait after the last change before re-indexing a file.
local INDEX_DEBOUNCE_MS = 500

--- Default indexer configuration.
local default_config = {
  enabled = true,
  auto_index = true,
  index_on_open = false,
  -- Files larger than this (bytes) are skipped by the scanner.
  max_file_size = 100000,
  excluded_dirs = { "node_modules", "dist", "build", ".git", ".coder", "__pycache__", "vendor", "target" },
  index_extensions = { "lua", "ts", "tsx", "js", "jsx", "py", "go", "rs", "rb", "java", "c", "cpp", "h", "hpp" },
  memory = {
    enabled = true,
    max_memories = 1000,
    prune_threshold = 0.1,
  },
}

--- Current configuration (deep copy so user setup never mutates defaults).
---@type table
local config = vim.deepcopy(default_config)

--- Cached project index, keyed by project root path.
---@type table<string, ProjectIndex>
local index_cache = {}
|
||||
|
||||
---@class ProjectIndex
|
||||
---@field version number Index schema version
|
||||
---@field project_root string Absolute path to project
|
||||
---@field project_name string Project name
|
||||
---@field project_type string "node"|"rust"|"go"|"python"|"lua"|"unknown"
|
||||
---@field dependencies table<string, string> name -> version
|
||||
---@field dev_dependencies table<string, string> name -> version
|
||||
---@field files table<string, FileIndex> path -> FileIndex
|
||||
---@field symbols table<string, string[]> symbol -> [file paths]
|
||||
---@field last_indexed number Timestamp
|
||||
---@field stats {files: number, functions: number, classes: number, exports: number}
|
||||
|
||||
---@class FileIndex
|
||||
---@field path string Relative path from project root
|
||||
---@field language string Detected language
|
||||
---@field hash string Content hash for change detection
|
||||
---@field exports Export[] Exported symbols
|
||||
---@field imports Import[] Dependencies
|
||||
---@field functions FunctionInfo[]
|
||||
---@field classes ClassInfo[]
|
||||
---@field last_indexed number Timestamp
|
||||
|
||||
---@class Export
|
||||
---@field name string Symbol name
|
||||
---@field type string "function"|"class"|"constant"|"type"|"variable"
|
||||
---@field line number Line number
|
||||
|
||||
---@class Import
|
||||
---@field source string Import source/module
|
||||
---@field names string[] Imported names
|
||||
---@field line number Line number
|
||||
|
||||
---@class FunctionInfo
|
||||
---@field name string Function name
|
||||
---@field params string[] Parameter names
|
||||
---@field line number Start line
|
||||
---@field end_line number End line
|
||||
---@field docstring string|nil Documentation
|
||||
|
||||
---@class ClassInfo
|
||||
---@field name string Class name
|
||||
---@field methods string[] Method names
|
||||
---@field line number Start line
|
||||
---@field end_line number End line
|
||||
---@field docstring string|nil Documentation
|
||||
|
||||
--- Get the index file path.
---@return string|nil Absolute path to .coder/index.json, or nil when not
---  inside a recognized project
local function get_index_path()
  local root = utils.get_project_root()
  return root and (root .. "/.coder/" .. INDEX_FILE) or nil
end
|
||||
|
||||
--- Create empty index structure.
---@return ProjectIndex fresh index with zeroed stats for the current project
local function create_empty_index()
  local root = utils.get_project_root()
  -- Project name is the last path component of the root directory.
  local name = root and vim.fn.fnamemodify(root, ":t") or ""
  return {
    version = INDEX_VERSION,
    project_root = root or "",
    project_name = name,
    project_type = "unknown",
    dependencies = {},
    dev_dependencies = {},
    files = {},
    symbols = {},
    last_indexed = os.time(),
    stats = { files = 0, functions = 0, classes = 0, exports = 0 },
  }
end
|
||||
|
||||
--- Load index from disk.
--- Serves from the in-memory cache when possible. Returns nil when there
--- is no project, no index file, the JSON is invalid, or the schema
--- version does not match (callers should rebuild in that case).
---@return ProjectIndex|nil
function M.load_index()
  local root = utils.get_project_root()
  if not root then
    return nil
  end

  -- Cache hit: skip disk entirely.
  local cached = index_cache[root]
  if cached then
    return cached
  end

  local path = get_index_path()
  if not path then
    return nil
  end

  local content = utils.read_file(path)
  if not content then
    return nil
  end

  local ok, decoded = pcall(vim.json.decode, content)
  if not ok or not decoded then
    return nil
  end

  -- A version mismatch means the index needs migration or a full rebuild.
  if decoded.version ~= INDEX_VERSION then
    return nil
  end

  index_cache[root] = decoded
  return decoded
end
|
||||
|
||||
--- Save index to disk.
--- Ensures the .coder directory exists, JSON-encodes the index, writes it,
--- and refreshes the in-memory cache on success.
---@param index ProjectIndex
---@return boolean true on success
function M.save_index(index)
  local root = utils.get_project_root()
  if not root then
    return false
  end

  utils.ensure_dir(root .. "/.coder")

  local path = get_index_path()
  if not path then
    return false
  end

  local ok, encoded = pcall(vim.json.encode, index)
  if not ok then
    return false
  end

  local success = utils.write_file(path, encoded)
  if success then
    -- Keep the cache consistent with what is now on disk.
    index_cache[root] = index
  end
  return success
end
|
||||
|
||||
--- Index the entire project.
--- Detects project type and dependencies, analyzes every indexable file,
--- builds the symbol map and stats, persists the index, and feeds
--- summaries to memory and brain.
---@param callback? fun(index: ProjectIndex) invoked with the finished index
---@return ProjectIndex|nil
function M.index_project(callback)
  local scanner = require("codetyper.indexer.scanner")
  local analyzer = require("codetyper.indexer.analyzer")

  local index = create_empty_index()
  local root = utils.get_project_root()

  -- No project root: hand back the empty index unchanged.
  if not root then
    if callback then
      callback(index)
    end
    return index
  end

  -- Project metadata: type and declared dependencies.
  index.project_type = scanner.detect_project_type(root)
  local deps = scanner.parse_dependencies(root, index.project_type)
  index.dependencies = deps.dependencies or {}
  index.dev_dependencies = deps.dev_dependencies or {}

  local files = scanner.get_indexable_files(root, config)

  local total_functions, total_classes, total_exports = 0, 0, 0

  for _, filepath in ipairs(files) do
    local relative_path = filepath:gsub("^" .. vim.pesc(root) .. "/", "")
    local file_index = analyzer.analyze_file(filepath)

    if file_index then
      file_index.path = relative_path
      index.files[relative_path] = file_index

      -- Register each export in the reverse symbol -> files map.
      for _, exp in ipairs(file_index.exports or {}) do
        local locations = index.symbols[exp.name]
        if not locations then
          locations = {}
          index.symbols[exp.name] = locations
        end
        locations[#locations + 1] = relative_path
        total_exports = total_exports + 1
      end

      total_functions = total_functions + #(file_index.functions or {})
      total_classes = total_classes + #(file_index.classes or {})
    end
  end

  index.stats = {
    files = #files,
    functions = total_functions,
    classes = total_classes,
    exports = total_exports,
  }
  index.last_indexed = os.time()

  M.save_index(index)

  -- Record a summary in long-term memory and sync to the brain store.
  require("codetyper.indexer.memory").store_index_summary(index)
  M.sync_project_to_brain(index, files, root)

  if callback then
    callback(index)
  end

  return index
end
|
||||
|
||||
--- Sync project index to brain.
--- Stores one project-level pattern plus the top 20 "key" files, ranked by
--- definition density (classes count double).
---@param index ProjectIndex
---@param files string[] List of file paths
---@param root string Project root
function M.sync_project_to_brain(index, files, root)
  -- Brain is optional; bail out quietly when absent or uninitialized.
  local ok_brain, brain = pcall(require, "codetyper.brain")
  if not ok_brain or not brain.is_initialized or not brain.is_initialized() then
    return
  end

  -- Project-level summary pattern.
  local summary = "Project: "
    .. index.project_name
    .. " ("
    .. index.project_type
    .. ") - "
    .. index.stats.files
    .. " files"
  brain.learn({
    type = "pattern",
    file = root,
    content = {
      summary = summary,
      detail = string.format(
        "%d functions, %d classes, %d exports",
        index.stats.functions,
        index.stats.classes,
        index.stats.exports
      ),
    },
    context = {
      file = root,
      project_type = index.project_type,
      dependencies = index.dependencies,
    },
  })

  -- Rank files: one point per function, two per class; ignore sparse files.
  local key_files = {}
  for path, file_index in pairs(index.files) do
    local score = #(file_index.functions or {}) + (#(file_index.classes or {}) * 2)
    if score >= 3 then
      key_files[#key_files + 1] = { path = path, index = file_index, score = score }
    end
  end

  table.sort(key_files, function(a, b)
    return a.score > b.score
  end)

  -- Only the 20 highest-scoring files are pushed to the brain.
  for i = 1, math.min(#key_files, 20) do
    local kf = key_files[i]
    M.sync_to_brain(root .. "/" .. kf.path, kf.index)
  end
end
|
||||
|
||||
--- Index a single file (incremental update).
--- Re-analyzes one file, splices the result into the persisted project
--- index, rebuilds project-wide stats, and propagates the update to the
--- memory store and (if available) the brain system.
--- NOTE: the order below matters — stale symbol entries for this file must
--- be removed before the fresh exports are re-inserted.
---@param filepath string
---@return FileIndex|nil  nil when no project root is found or analysis fails
function M.index_file(filepath)
  local analyzer = require("codetyper.indexer.analyzer")
  local memory = require("codetyper.indexer.memory")
  local root = utils.get_project_root()

  if not root then
    return nil
  end

  -- Load existing index (or start fresh if none exists yet).
  local index = M.load_index() or create_empty_index()

  -- Analyze file; bail out without touching the index on failure.
  local file_index = analyzer.analyze_file(filepath)
  if not file_index then
    return nil
  end

  -- Paths are stored project-relative; vim.pesc escapes pattern magic in root.
  local relative_path = filepath:gsub("^" .. vim.pesc(root) .. "/", "")
  file_index.path = relative_path

  -- Remove old symbol references for this file (iterate downward so
  -- table.remove does not skip elements); drop symbols left empty.
  for symbol, paths in pairs(index.symbols) do
    for i = #paths, 1, -1 do
      if paths[i] == relative_path then
        table.remove(paths, i)
      end
    end
    if #paths == 0 then
      index.symbols[symbol] = nil
    end
  end

  -- Add new file index
  index.files[relative_path] = file_index

  -- Update symbol index with the freshly analyzed exports.
  for _, exp in ipairs(file_index.exports or {}) do
    if not index.symbols[exp.name] then
      index.symbols[exp.name] = {}
    end
    table.insert(index.symbols[exp.name], relative_path)
  end

  -- Recalculate stats from scratch across every indexed file.
  local total_functions = 0
  local total_classes = 0
  local total_exports = 0
  local file_count = 0

  for _, f in pairs(index.files) do
    file_count = file_count + 1
    total_functions = total_functions + #(f.functions or {})
    total_classes = total_classes + #(f.classes or {})
    total_exports = total_exports + #(f.exports or {})
  end

  index.stats = {
    files = file_count,
    functions = total_functions,
    classes = total_classes,
    exports = total_exports,
  }
  index.last_indexed = os.time()

  -- Save to disk
  M.save_index(index)

  -- Store file memory
  memory.store_file_memory(relative_path, file_index)

  -- Sync to brain if available (sync_to_brain handles the unavailable case)
  M.sync_to_brain(filepath, file_index)

  return file_index
end
|
||||
|
||||
--- Sync one file's analysis to the brain system.
--- No-op when codetyper.brain is absent/uninitialized, or when the file
--- contains neither functions nor classes.
---@param filepath string Full file path
---@param file_index FileIndex File analysis
function M.sync_to_brain(filepath, file_index)
  local ok_brain, brain = pcall(require, "codetyper.brain")
  if not ok_brain or not brain.is_initialized or not brain.is_initialized() then
    return
  end

  -- Only store if file has meaningful content.
  local funcs = file_index.functions or {}
  local classes = file_index.classes or {}
  if #funcs == 0 and #classes == 0 then
    return
  end

  -- Build a short human-readable summary of the file's structure.
  local parts = {}
  if #funcs > 0 then
    -- Fix: bounded loop instead of walking the full list with an `i <= 5`
    -- guard — only the first five names are ever used.
    local func_names = {}
    for i = 1, math.min(5, #funcs) do
      func_names[i] = funcs[i].name
    end
    table.insert(parts, "functions: " .. table.concat(func_names, ", "))
    if #funcs > 5 then
      table.insert(parts, "(+" .. (#funcs - 5) .. " more)")
    end
  end
  if #classes > 0 then
    local class_names = {}
    for i, c in ipairs(classes) do
      class_names[i] = c.name
    end
    table.insert(parts, "classes: " .. table.concat(class_names, ", "))
  end

  -- ":t" keeps only the tail (basename) of the path.
  local filename = vim.fn.fnamemodify(filepath, ":t")
  local summary = filename .. " - " .. table.concat(parts, "; ")

  -- Learn this pattern in brain.
  brain.learn({
    type = "pattern",
    file = filepath,
    content = {
      summary = summary,
      detail = #funcs .. " functions, " .. #classes .. " classes",
    },
    context = {
      file = file_index.path or filepath,
      language = file_index.language,
      functions = funcs,
      classes = classes,
      exports = file_index.exports,
      imports = file_index.imports,
    },
  })
end
|
||||
|
||||
--- Schedule file indexing with debounce: repeated calls reset the timer, so
--- only the most recent file is indexed after INDEX_DEBOUNCE_MS of quiet.
---@param filepath string
function M.schedule_index_file(filepath)
  if not config.enabled or not config.auto_index then
    return
  end

  -- Skip files the scanner would refuse to index anyway.
  local scanner = require("codetyper.indexer.scanner")
  if not scanner.should_index(filepath, config) then
    return
  end

  -- Cancel any pending run. Fix: also close the stopped timer — a timer
  -- returned by vim.defer_fn only closes itself when its callback fires,
  -- so stop() without close() leaked one libuv handle per reschedule.
  if index_timer then
    index_timer:stop()
    if not index_timer:is_closing() then
      index_timer:close()
    end
    index_timer = nil
  end

  -- Schedule new index.
  index_timer = vim.defer_fn(function()
    M.index_file(filepath)
    index_timer = nil
  end, INDEX_DEBOUNCE_MS)
end
|
||||
|
||||
--- Get relevant context for a prompt.
--- Combines project metadata, symbols mentioned in the prompt, the current
--- file's index entry, and the top relevant memories.
---@param opts {file: string, intent: table|nil, prompt: string, scope: string|nil}
---@return table Context information
function M.get_context_for(opts)
  local memory = require("codetyper.indexer.memory")
  local index = M.load_index()

  -- Defaults returned verbatim when no index exists yet.
  local context = {
    project_type = "unknown",
    dependencies = {},
    relevant_files = {},
    relevant_symbols = {},
    patterns = {},
  }

  if not index then
    return context
  end

  context.project_type = index.project_type
  context.dependencies = index.dependencies

  -- Collect significant (length > 2) lowercase words from the prompt.
  local prompt_words = {}
  for token in opts.prompt:gmatch("%w+") do
    if #token > 2 then
      prompt_words[token:lower()] = true
    end
  end

  -- Any indexed symbol mentioned in the prompt becomes relevant.
  for symbol, symbol_files in pairs(index.symbols) do
    if prompt_words[symbol:lower()] then
      context.relevant_symbols[symbol] = symbol_files
    end
  end

  -- Attach the current file's index entry when it can be resolved.
  if opts.file then
    local root = utils.get_project_root()
    if root then
      local rel = opts.file:gsub("^" .. vim.pesc(root) .. "/", "")
      context.current_file = index.files[rel]
    end
  end

  -- Top five memories matching the prompt.
  context.patterns = memory.get_relevant(opts.prompt, 5)

  return context
end
|
||||
|
||||
--- Get index status.
---@return table Status information (indexed flag, stats, last_indexed, project_type)
function M.get_status()
  local index = M.load_index()
  if index then
    return {
      indexed = true,
      stats = index.stats,
      last_indexed = index.last_indexed,
      project_type = index.project_type,
    }
  end
  -- Nothing indexed yet: explicit nil fields document the expected shape.
  return {
    indexed = false,
    stats = nil,
    last_indexed = nil,
  }
end
|
||||
|
||||
--- Clear the project index: drop the in-memory cache entry for the current
--- project root and delete the on-disk index file if present.
function M.clear()
  local project_root = utils.get_project_root()
  if project_root then
    index_cache[project_root] = nil
  end

  local index_path = get_index_path()
  if index_path and utils.file_exists(index_path) then
    os.remove(index_path)
  end
end
|
||||
|
||||
--- Setup the indexer with configuration.
---@param opts? table Configuration options (deep-merged over defaults)
function M.setup(opts)
  if opts then
    config = vim.tbl_deep_extend("force", config, opts)
  end

  -- Kick off a full project index shortly after startup when configured;
  -- the 1s delay keeps it off the startup critical path.
  if not config.index_on_open then
    return
  end
  vim.defer_fn(function()
    M.index_project()
  end, 1000)
end
|
||||
|
||||
--- Get current configuration.
--- Returns a deep copy so callers cannot mutate the live config.
---@return table
function M.get_config()
  local snapshot = vim.deepcopy(config)
  return snapshot
end
|
||||
|
||||
return M
|
||||
539
lua/codetyper/indexer/memory.lua
Normal file
539
lua/codetyper/indexer/memory.lua
Normal file
@@ -0,0 +1,539 @@
|
||||
---@mod codetyper.indexer.memory Memory persistence manager
---@brief [[
--- Stores and retrieves learned patterns and memories in .coder/memories/.
--- Supports session history for learning from interactions.
---@brief ]]

local M = {}

local utils = require("codetyper.utils")

--- Memory directories (all relative to <project_root>/.coder/)
local MEMORIES_DIR = "memories"
local SESSIONS_DIR = "sessions"
local FILES_DIR = "files"

--- Memory files stored inside the memories directory
local PATTERNS_FILE = "patterns.json"
local CONVENTIONS_FILE = "conventions.json"
local SYMBOLS_FILE = "symbols.json"

--- In-memory cache of decoded memory files; a field is set back to nil to
--- force a re-read from disk on the next load_* call.
local cache = {
  patterns = nil,
  conventions = nil,
  symbols = nil,
}

---@class Memory
---@field id string Unique identifier (see generate_id)
---@field type "pattern"|"convention"|"session"|"interaction"
---@field content string The learned information
---@field context table Where/when learned
---@field weight number Importance score (0.0-1.0)
---@field created_at number Timestamp
---@field updated_at number Last update timestamp
---@field used_count number Times referenced
|
||||
--- Get the memories base directory (<root>/.coder/memories).
---@return string|nil nil when no project root can be determined
local function get_memories_dir()
  local root = utils.get_project_root()
  return root and (root .. "/.coder/" .. MEMORIES_DIR) or nil
end
|
||||
|
||||
--- Get the sessions directory (<root>/.coder/sessions).
---@return string|nil nil when no project root can be determined
local function get_sessions_dir()
  local root = utils.get_project_root()
  return root and (root .. "/.coder/" .. SESSIONS_DIR) or nil
end
|
||||
|
||||
--- Ensure the memories directory (and its files/ subdirectory) exists.
---@return boolean false only when the project root cannot be determined
local function ensure_memories_dir()
  local base = get_memories_dir()
  if not base then
    return false
  end
  utils.ensure_dir(base)
  utils.ensure_dir(base .. "/" .. FILES_DIR)
  return true
end
|
||||
|
||||
--- Ensure the sessions directory exists.
---@return boolean result of utils.ensure_dir, or false without a project root
local function ensure_sessions_dir()
  local base = get_sessions_dir()
  if not base then
    return false
  end
  return utils.ensure_dir(base)
end
|
||||
|
||||
--- Generate a unique(ish) memory ID: "mem_<epoch>_<random digits>".
--- Uniqueness is probabilistic — two calls in the same second could
--- collide if math.random yields the same fractional digits.
---@return string
local function generate_id()
  local suffix = tostring(math.random()):sub(3, 8)
  return ("mem_%d_%s"):format(os.time(), suffix)
end
|
||||
|
||||
--- Load and JSON-decode a memory file; any failure yields an empty table.
---@param filename string Name of a file inside the memories directory
---@return table
local function load_memory_file(filename)
  local base = get_memories_dir()
  if not base then
    return {}
  end

  local raw = utils.read_file(base .. "/" .. filename)
  if not raw then
    return {}
  end

  -- Decode defensively: malformed JSON must not propagate an error.
  local ok, decoded = pcall(vim.json.decode, raw)
  return (ok and decoded) or {}
end
|
||||
|
||||
--- JSON-encode and persist a memory file, creating directories as needed.
---@param filename string Name of a file inside the memories directory
---@param data table
---@return boolean true on successful write
local function save_memory_file(filename, data)
  if not ensure_memories_dir() then
    return false
  end

  local base = get_memories_dir()
  if not base then
    return false
  end

  -- Encode defensively: unencodable values must not propagate an error.
  local ok, encoded = pcall(vim.json.encode, data)
  if not ok then
    return false
  end

  return utils.write_file(base .. "/" .. filename, encoded)
end
|
||||
|
||||
--- Hash a file path to a stable 8-hex-digit name for per-file storage.
--- Java-style polynomial rolling hash (base 31) modulo 2^31-1.
---@param filepath string
---@return string 8-character lowercase hex string
local function hash_path(filepath)
  local acc = 0
  for idx = 1, #filepath do
    acc = (acc * 31 + filepath:byte(idx)) % 2147483647
  end
  return ("%08x"):format(acc)
end
|
||||
|
||||
--- Load patterns from cache, falling back to disk on a cache miss.
---@return table
function M.load_patterns()
  if not cache.patterns then
    cache.patterns = load_memory_file(PATTERNS_FILE)
  end
  return cache.patterns
end
|
||||
|
||||
--- Load conventions from cache, falling back to disk on a cache miss.
---@return table
function M.load_conventions()
  if not cache.conventions then
    cache.conventions = load_memory_file(CONVENTIONS_FILE)
  end
  return cache.conventions
end
|
||||
|
||||
--- Load symbols from cache, falling back to disk on a cache miss.
---@return table
function M.load_symbols()
  if not cache.symbols then
    cache.symbols = load_memory_file(SYMBOLS_FILE)
  end
  return cache.symbols
end
|
||||
|
||||
--- Store a new memory, filling in defaults for missing bookkeeping fields.
--- "convention" memories go to the conventions file; every other type
--- (including "session"/"interaction") lands in the patterns file.
---@param memory Memory
---@return boolean true on successful write
function M.store_memory(memory)
  memory.id = memory.id or generate_id()
  memory.created_at = memory.created_at or os.time()
  memory.updated_at = os.time()
  memory.used_count = memory.used_count or 0
  memory.weight = memory.weight or 0.5

  -- Pick the backing file and invalidate its cache entry.
  local filename
  if memory.type == "convention" then
    filename = CONVENTIONS_FILE
    cache.conventions = nil
  else
    filename = PATTERNS_FILE
    cache.patterns = nil
  end

  local store = load_memory_file(filename)
  store[memory.id] = memory
  return save_memory_file(filename, store)
end
|
||||
|
||||
--- Store file-specific memory under memories/files/<hash>.json.
---@param relative_path string Relative file path
---@param file_index table FileIndex data
---@return boolean true on successful write
function M.store_file_memory(relative_path, file_index)
  if not ensure_memories_dir() then
    return false
  end

  local base = get_memories_dir()
  if not base then
    return false
  end

  -- File name derived from a stable hash of the relative path.
  local target = base .. "/" .. FILES_DIR .. "/" .. hash_path(relative_path) .. ".json"

  local record = {
    path = relative_path,
    indexed_at = os.time(),
    functions = file_index.functions or {},
    classes = file_index.classes or {},
    exports = file_index.exports or {},
    imports = file_index.imports or {},
  }

  local ok, encoded = pcall(vim.json.encode, record)
  if not ok then
    return false
  end

  return utils.write_file(target, encoded)
end
|
||||
|
||||
--- Load file-specific memory previously stored by store_file_memory.
---@param relative_path string
---@return table|nil nil when missing or undecodable
function M.load_file_memory(relative_path)
  local base = get_memories_dir()
  if not base then
    return nil
  end

  local target = base .. "/" .. FILES_DIR .. "/" .. hash_path(relative_path) .. ".json"

  local raw = utils.read_file(target)
  if not raw then
    return nil
  end

  local ok, decoded = pcall(vim.json.decode, raw)
  if not ok then
    return nil
  end
  return decoded
end
|
||||
|
||||
--- Store index summary as memories: project-type convention, dependency
--- pattern, and a refreshed symbols file.
---@param index ProjectIndex
function M.store_index_summary(index)
  -- High-weight convention for a recognized project ecosystem.
  if index.project_type and index.project_type ~= "unknown" then
    M.store_memory({
      type = "convention",
      content = "Project uses " .. index.project_type .. " ecosystem",
      context = {
        project_root = index.project_root,
        detected_at = os.time(),
      },
      weight = 0.9,
    })
  end

  -- Collect dependency names in a single pass (order is pairs-order,
  -- i.e. unspecified, same as before).
  local dep_names = {}
  for name in pairs(index.dependencies or {}) do
    dep_names[#dep_names + 1] = name
  end

  if #dep_names > 0 then
    M.store_memory({
      type = "pattern",
      content = "Project dependencies: " .. table.concat(dep_names, ", "),
      context = {
        dependency_count = #dep_names,
      },
      weight = 0.7,
    })
  end

  -- Replace the symbols file wholesale and drop its cache entry.
  cache.symbols = nil
  save_memory_file(SYMBOLS_FILE, index.symbols or {})
end
|
||||
|
||||
--- Store a session interaction in a per-day JSON log
--- (.coder/sessions/YYYY-MM-DD.json), keeping at most the last 100 entries.
--- Best-effort: read/decode/encode failures are silently ignored.
---@param interaction {prompt: string, response: string, file: string|nil, success: boolean}
function M.store_session(interaction)
  if not ensure_sessions_dir() then
    return
  end

  local dir = get_sessions_dir()
  if not dir then
    return
  end

  -- Use date-based session files
  local date = os.date("%Y-%m-%d")
  local path = dir .. "/" .. date .. ".json"

  -- Load today's existing log; fall back to an empty list on any failure.
  local sessions = {}
  local content = utils.read_file(path)
  if content then
    local ok, data = pcall(vim.json.decode, content)
    if ok and data then
      sessions = data
    end
  end

  table.insert(sessions, {
    timestamp = os.time(),
    prompt = interaction.prompt,
    response = string.sub(interaction.response or "", 1, 500), -- Truncate
    file = interaction.file,
    success = interaction.success,
  })

  -- Limit session size: keep only the newest 100 entries.
  -- NOTE: global `unpack` is the Lua 5.1/LuaJIT spelling (Neovim runtime);
  -- at most 101 entries here, so stack limits are not a concern.
  if #sessions > 100 then
    sessions = { unpack(sessions, #sessions - 99) }
  end

  local ok, encoded = pcall(vim.json.encode, sessions)
  if ok then
    utils.write_file(path, encoded)
  end
end
|
||||
|
||||
--- Get relevant memories for a query.
--- Scores each pattern/convention by counting significant query words found
--- (plain substring match) in its content, weighted by memory.weight, and
--- returns the top `limit` results sorted by relevance.
--- Note: relevance_score is written onto the (cached) memory tables, as the
--- original implementation did.
---@param query string Search query
---@param limit number Maximum results (default 10)
---@return Memory[]
function M.get_relevant(query, limit)
  limit = limit or 10
  local results = {}

  -- Tokenize query into significant lowercase words (length > 2).
  local query_words = {}
  for word in query:lower():gmatch("%w+") do
    if #word > 2 then
      query_words[word] = true
    end
  end

  -- Fix: the pattern and convention searches were two copy-pasted loops;
  -- a single helper scores one store and appends hits to results.
  local function collect(store)
    for _, memory in pairs(store) do
      local score = 0
      local content_lower = (memory.content or ""):lower()

      -- Plain-text find (4th arg true) so pattern magic in words is inert.
      for word in pairs(query_words) do
        if content_lower:find(word, 1, true) then
          score = score + 1
        end
      end

      if score > 0 then
        memory.relevance_score = score * (memory.weight or 0.5)
        table.insert(results, memory)
      end
    end
  end

  collect(M.load_patterns())
  collect(M.load_conventions())

  -- Sort by relevance, highest first.
  table.sort(results, function(a, b)
    return (a.relevance_score or 0) > (b.relevance_score or 0)
  end)

  -- Truncate to the requested number of results.
  local limited = {}
  for i = 1, math.min(limit, #results) do
    limited[i] = results[i]
  end

  return limited
end
|
||||
|
||||
--- Update a memory's usage count and timestamp, wherever it lives.
--- Checks patterns first, then conventions; persists and invalidates the
--- matching cache entry on a hit.
---@param memory_id string
function M.update_usage(memory_id)
  -- Fix: the pattern and convention branches were copy-pasted; one helper
  -- bumps the entry in a given store and reports whether it was found.
  local function bump(store, filename, cache_key)
    local entry = store[memory_id]
    if not entry then
      return false
    end
    entry.used_count = (entry.used_count or 0) + 1
    entry.updated_at = os.time()
    save_memory_file(filename, store)
    cache[cache_key] = nil
    return true
  end

  if bump(M.load_patterns(), PATTERNS_FILE, "patterns") then
    return
  end
  bump(M.load_conventions(), CONVENTIONS_FILE, "conventions")
end
|
||||
|
||||
--- Get all memories grouped by kind.
---@return {patterns: table, conventions: table, symbols: table}
function M.get_all()
  local everything = {}
  everything.patterns = M.load_patterns()
  everything.conventions = M.load_conventions()
  everything.symbols = M.load_symbols()
  return everything
end
|
||||
|
||||
--- Clear memories.
--- Without a pattern, wipes patterns, conventions, and symbols entirely;
--- with a pattern, removes only entries whose ID matches it (Lua pattern).
---@param pattern? string Optional pattern to match memory IDs
function M.clear(pattern)
  if not pattern then
    -- Clear all stores and reset the cache.
    cache = { patterns = nil, conventions = nil, symbols = nil }
    save_memory_file(PATTERNS_FILE, {})
    save_memory_file(CONVENTIONS_FILE, {})
    save_memory_file(SYMBOLS_FILE, {})
    return
  end

  -- Fix: the per-store matching logic was duplicated; one helper removes
  -- matching IDs, persists, and invalidates the store's cache entry.
  -- (Assigning nil to existing keys during pairs traversal is legal Lua.)
  local function clear_matching(store, filename, cache_key)
    for id in pairs(store) do
      if id:match(pattern) then
        store[id] = nil
      end
    end
    save_memory_file(filename, store)
    cache[cache_key] = nil
  end

  clear_matching(M.load_patterns(), PATTERNS_FILE, "patterns")
  clear_matching(M.load_conventions(), CONVENTIONS_FILE, "conventions")
end
|
||||
|
||||
--- Prune low-weight, never-used memories from patterns and conventions.
---@param threshold number Weight threshold (default: 0.1)
---@return number Total number of memories removed
function M.prune(threshold)
  threshold = threshold or 0.1

  -- Fix: the original shared one `pruned` counter between the two stores,
  -- so `if pruned > 0` saved/invalidated the conventions file even when no
  -- convention was actually removed. Count per store instead.
  local function prune_store(store, filename, cache_key)
    local removed = 0
    for id, memory in pairs(store) do
      if (memory.weight or 0) < threshold and (memory.used_count or 0) == 0 then
        store[id] = nil
        removed = removed + 1
      end
    end
    if removed > 0 then
      save_memory_file(filename, store)
      cache[cache_key] = nil
    end
    return removed
  end

  local pruned = prune_store(M.load_patterns(), PATTERNS_FILE, "patterns")
  pruned = pruned + prune_store(M.load_conventions(), CONVENTIONS_FILE, "conventions")

  return pruned
end
|
||||
|
||||
--- Get memory statistics.
--- `total` counts only patterns + conventions (symbols excluded), matching
--- the established contract.
---@return table
function M.get_stats()
  -- Hash-part tables have no usable #, so count by iteration.
  local function count_entries(t)
    local n = 0
    for _ in pairs(t) do
      n = n + 1
    end
    return n
  end

  local pattern_count = count_entries(M.load_patterns())
  local convention_count = count_entries(M.load_conventions())
  local symbol_count = count_entries(M.load_symbols())

  return {
    patterns = pattern_count,
    conventions = convention_count,
    symbols = symbol_count,
    total = pattern_count + convention_count,
  }
end
|
||||
|
||||
return M
|
||||
409
lua/codetyper/indexer/scanner.lua
Normal file
409
lua/codetyper/indexer/scanner.lua
Normal file
@@ -0,0 +1,409 @@
|
||||
---@mod codetyper.indexer.scanner File scanner for project indexing
---@brief [[
--- Discovers indexable files, detects project type, and parses dependencies.
---@brief ]]

local M = {}

local utils = require("codetyper.utils")

--- Project type markers: presence of any listed file at the project root
--- identifies the project type. Entries starting with "*" are glob suffixes.
local PROJECT_MARKERS = {
  node = { "package.json" },
  rust = { "Cargo.toml" },
  go = { "go.mod" },
  python = { "pyproject.toml", "setup.py", "requirements.txt" },
  lua = { "init.lua", ".luarc.json" },
  ruby = { "Gemfile" },
  java = { "pom.xml", "build.gradle" },
  csharp = { "*.csproj", "*.sln" },
}

--- File extension to language mapping
--- (not referenced in this chunk — presumably used by functions later in
--- the file, e.g. the file analyzer; verify before removing)
local EXTENSION_LANGUAGE = {
  lua = "lua",
  ts = "typescript",
  tsx = "typescriptreact",
  js = "javascript",
  jsx = "javascriptreact",
  py = "python",
  go = "go",
  rs = "rust",
  rb = "ruby",
  java = "java",
  c = "c",
  cpp = "cpp",
  h = "c",
  hpp = "cpp",
  cs = "csharp",
}

--- Default ignore patterns (Lua patterns matched against file/dir names;
--- "%." escapes a literal dot)
local DEFAULT_IGNORES = {
  "^%.", -- Hidden files/folders
  "^node_modules$",
  "^__pycache__$",
  "^%.git$",
  "^%.coder$",
  "^dist$",
  "^build$",
  "^target$",
  "^vendor$",
  "^%.next$",
  "^%.nuxt$",
  "^coverage$",
  "%.min%.js$",
  "%.min%.css$",
  "%.map$",
  "%.lock$",
  "%-lock%.json$",
}
|
||||
|
||||
--- Detect project type from root markers.
--- Returns the first type whose marker exists; with multiple ecosystems
--- present, the winner depends on pairs() order (unspecified), exactly as
--- in the original behavior.
---@param root string Project root path
---@return string Project type ("unknown" when nothing matches)
function M.detect_project_type(root)
  for project_type, markers in pairs(PROJECT_MARKERS) do
    for _, marker in ipairs(markers) do
      if marker:match("^%*") then
        -- Glob marker like "*.csproj": strip the "*" and glob at the root.
        local suffix = marker:gsub("^%*", "")
        local hits = vim.fn.glob(root .. "/*" .. suffix, false, true)
        if #hits > 0 then
          return project_type
        end
      elseif utils.file_exists(root .. "/" .. marker) then
        return project_type
      end
    end
  end
  return "unknown"
end
|
||||
|
||||
--- Parse project dependencies, dispatching on project type.
--- Unrecognized types yield empty dependency tables.
---@param root string Project root path
---@param project_type string Project type
---@return {dependencies: table<string, string>, dev_dependencies: table<string, string>}
function M.parse_dependencies(root, project_type)
  local parser_for = {
    node = M.parse_package_json,
    rust = M.parse_cargo_toml,
    go = M.parse_go_mod,
    python = M.parse_python_deps,
  }

  local parser = parser_for[project_type]
  if parser then
    return parser(root)
  end

  return {
    dependencies = {},
    dev_dependencies = {},
  }
end
|
||||
|
||||
--- Parse package.json for Node.js projects.
--- Returns empty tables when the file is missing or not valid JSON.
---@param root string Project root path
---@return {dependencies: table, dev_dependencies: table}
function M.parse_package_json(root)
  local empty = { dependencies = {}, dev_dependencies = {} }

  local raw = utils.read_file(root .. "/package.json")
  if not raw then
    return empty
  end

  local ok, pkg = pcall(vim.json.decode, raw)
  if not ok or not pkg then
    return empty
  end

  return {
    dependencies = pkg.dependencies or {},
    dev_dependencies = pkg.devDependencies or {},
  }
end
|
||||
|
||||
--- Parse Cargo.toml for Rust projects.
--- Line-oriented heuristic, not a real TOML parser: tracks whether the
--- cursor is inside [dependencies] or [dev-dependencies] and extracts
--- `name = "version"` lines.
---@param root string Project root path
---@return {dependencies: table, dev_dependencies: table}
function M.parse_cargo_toml(root)
  local path = root .. "/Cargo.toml"
  local content = utils.read_file(path)
  if not content then
    return { dependencies = {}, dev_dependencies = {} }
  end

  local deps = {}
  local dev_deps = {}
  -- Section flags: which dependency table the current line belongs to.
  local in_deps = false
  local in_dev_deps = false

  for line in content:gmatch("[^\n]+") do
    if line:match("^%[dependencies%]") then
      in_deps = true
      in_dev_deps = false
    elseif line:match("^%[dev%-dependencies%]") then
      in_deps = false
      in_dev_deps = true
    elseif line:match("^%[") then
      -- Any other section header ends dependency scanning.
      in_deps = false
      in_dev_deps = false
    elseif in_deps or in_dev_deps then
      -- Simple form: name = "1.2.3"
      local name, version = line:match('^([%w_%-]+)%s*=%s*"([^"]+)"')
      if not name then
        -- Fallback for non-string values (inline tables, workspace deps);
        -- NOTE(review): labels every such entry "workspace", including
        -- plain inline-table versions — confirm if precision matters.
        name = line:match("^([%w_%-]+)%s*=")
        version = "workspace"
      end
      if name then
        if in_deps then
          deps[name] = version or "unknown"
        else
          dev_deps[name] = version or "unknown"
        end
      end
    end
  end

  return { dependencies = deps, dev_dependencies = dev_deps }
end
|
||||
|
||||
--- Parse go.mod for Go projects.
--- Handles both block form `require ( ... )` and single-line
--- `require module version`. Go has no dev-dependency notion, so
--- dev_dependencies is always empty.
---@param root string Project root path
---@return {dependencies: table, dev_dependencies: table}
function M.parse_go_mod(root)
  local path = root .. "/go.mod"
  local content = utils.read_file(path)
  if not content then
    return { dependencies = {}, dev_dependencies = {} }
  end

  local deps = {}
  -- True while inside a `require ( ... )` block.
  local in_require = false

  for line in content:gmatch("[^\n]+") do
    if line:match("^require%s*%(") then
      in_require = true
    elseif line:match("^%)") then
      in_require = false
    elseif in_require then
      -- Block entry: "  module/path v1.2.3" (trailing "// indirect" ignored
      -- by the pattern).
      local module, version = line:match("^%s*([%w%.%-%_/]+)%s+([%w%.%-]+)")
      if module then
        deps[module] = version
      end
    else
      -- Single-line form: "require module/path v1.2.3"
      local module, version = line:match("^require%s+([%w%.%-%_/]+)%s+([%w%.%-]+)")
      if module then
        deps[module] = version
      end
    end
  end

  return { dependencies = deps, dev_dependencies = {} }
end
|
||||
|
||||
--- Parse Python dependencies (pyproject.toml or requirements.txt).
--- Both parsers are heuristics, not spec-compliant: pyproject scanning is
--- keyword-based, and requirements parsing only recognizes `name==x.y.z`
--- pins (other specifiers fall back to "latest"). requirements.txt entries
--- are merged over pyproject results into `dependencies`.
---@param root string Project root path
---@return {dependencies: table, dev_dependencies: table}
function M.parse_python_deps(root)
  local deps = {}
  local dev_deps = {}

  -- Try pyproject.toml first
  local pyproject = root .. "/pyproject.toml"
  local content = utils.read_file(pyproject)

  if content then
    -- Simple parsing for dependencies
    local in_deps = false
    local in_dev = false

    for line in content:gmatch("[^\n]+") do
      if line:match("^%[project%.dependencies%]") or line:match("^dependencies%s*=") then
        in_deps = true
        in_dev = false
      elseif line:match("dev") and line:match("dependencies") then
        -- NOTE(review): any line containing both "dev" and "dependencies"
        -- flips into dev mode — a loose heuristic; confirm it doesn't
        -- misfire on comments or unrelated keys.
        in_deps = false
        in_dev = true
      elseif line:match("^%[") then
        -- Any other section header ends dependency scanning.
        in_deps = false
        in_dev = false
      elseif in_deps or in_dev then
        -- Grab the package name from a quoted requirement string,
        -- discarding any version specifier.
        local name = line:match('"([%w_%-]+)')
        if name then
          if in_deps then
            deps[name] = "latest"
          else
            dev_deps[name] = "latest"
          end
        end
      end
    end
  end

  -- Fallback to requirements.txt
  local req_file = root .. "/requirements.txt"
  content = utils.read_file(req_file)

  if content then
    for line in content:gmatch("[^\n]+") do
      -- Skip comments and blank lines.
      if not line:match("^#") and not line:match("^%s*$") then
        local name, version = line:match("^([%w_%-]+)==([%d%.]+)")
        if not name then
          name = line:match("^([%w_%-]+)")
          version = "latest"
        end
        if name then
          deps[name] = version or "latest"
        end
      end
    end
  end

  return { dependencies = deps, dev_dependencies = dev_deps }
end
|
||||
|
||||
--- Check if a file/directory should be ignored
|
||||
---@param name string File or directory name
|
||||
---@param config table Indexer configuration
|
||||
---@return boolean
|
||||
function M.should_ignore(name, config)
|
||||
-- Check default patterns
|
||||
for _, pattern in ipairs(DEFAULT_IGNORES) do
|
||||
if name:match(pattern) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
-- Check config excluded dirs
|
||||
if config and config.excluded_dirs then
|
||||
for _, dir in ipairs(config.excluded_dirs) do
|
||||
if name == dir then
|
||||
return true
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Check if a file should be indexed
|
||||
---@param filepath string Full file path
|
||||
---@param config table Indexer configuration
|
||||
---@return boolean
|
||||
function M.should_index(filepath, config)
|
||||
local name = vim.fn.fnamemodify(filepath, ":t")
|
||||
local ext = vim.fn.fnamemodify(filepath, ":e")
|
||||
|
||||
-- Check if it's a coder file
|
||||
if utils.is_coder_file(filepath) then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Check file size
|
||||
if config and config.max_file_size then
|
||||
local stat = vim.loop.fs_stat(filepath)
|
||||
if stat and stat.size > config.max_file_size then
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
-- Check extension
|
||||
if config and config.index_extensions then
|
||||
local valid_ext = false
|
||||
for _, allowed_ext in ipairs(config.index_extensions) do
|
||||
if ext == allowed_ext then
|
||||
valid_ext = true
|
||||
break
|
||||
end
|
||||
end
|
||||
if not valid_ext then
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
-- Check ignore patterns
|
||||
if M.should_ignore(name, config) then
|
||||
return false
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Get all indexable files in the project
|
||||
---@param root string Project root path
|
||||
---@param config table Indexer configuration
|
||||
---@return string[] List of file paths
|
||||
function M.get_indexable_files(root, config)
|
||||
local files = {}
|
||||
|
||||
local function scan_dir(path)
|
||||
local handle = vim.loop.fs_scandir(path)
|
||||
if not handle then
|
||||
return
|
||||
end
|
||||
|
||||
while true do
|
||||
local name, type = vim.loop.fs_scandir_next(handle)
|
||||
if not name then
|
||||
break
|
||||
end
|
||||
|
||||
local full_path = path .. "/" .. name
|
||||
|
||||
if M.should_ignore(name, config) then
|
||||
goto continue
|
||||
end
|
||||
|
||||
if type == "directory" then
|
||||
scan_dir(full_path)
|
||||
elseif type == "file" then
|
||||
if M.should_index(full_path, config) then
|
||||
table.insert(files, full_path)
|
||||
end
|
||||
end
|
||||
|
||||
::continue::
|
||||
end
|
||||
end
|
||||
|
||||
scan_dir(root)
|
||||
return files
|
||||
end
|
||||
|
||||
--- Get language from file extension
|
||||
---@param filepath string File path
|
||||
---@return string Language name
|
||||
function M.get_language(filepath)
|
||||
local ext = vim.fn.fnamemodify(filepath, ":e")
|
||||
return EXTENSION_LANGUAGE[ext] or ext
|
||||
end
|
||||
|
||||
--- Read .gitignore patterns
|
||||
---@param root string Project root
|
||||
---@return string[] Patterns
|
||||
function M.read_gitignore(root)
|
||||
local patterns = {}
|
||||
local path = root .. "/.gitignore"
|
||||
local content = utils.read_file(path)
|
||||
|
||||
if not content then
|
||||
return patterns
|
||||
end
|
||||
|
||||
for line in content:gmatch("[^\n]+") do
|
||||
-- Skip comments and empty lines
|
||||
if not line:match("^#") and not line:match("^%s*$") then
|
||||
-- Convert gitignore pattern to Lua pattern (simplified)
|
||||
local pattern = line:gsub("^/", "^"):gsub("%*%*", ".*"):gsub("%*", "[^/]*"):gsub("%?", ".")
|
||||
table.insert(patterns, pattern)
|
||||
end
|
||||
end
|
||||
|
||||
return patterns
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,7 +1,7 @@
|
||||
---@mod codetyper Codetyper.nvim - AI-powered coding partner
|
||||
---@brief [[
|
||||
--- Codetyper.nvim is a Neovim plugin that acts as your coding partner.
|
||||
--- It uses LLM APIs (Claude, OpenAI, Gemini, Copilot, Ollama) to help you
|
||||
--- It uses LLM APIs (OpenAI, Gemini, Copilot, Ollama) to help you
|
||||
--- write code faster using special `.coder.*` files and inline prompt tags.
|
||||
--- Features an event-driven scheduler with confidence scoring and
|
||||
--- completion-aware injection timing.
|
||||
@@ -51,6 +51,24 @@ function M.setup(opts)
|
||||
-- Initialize tree logging (creates .coder folder and initial tree.log)
|
||||
tree.setup()
|
||||
|
||||
-- Initialize project indexer if enabled
|
||||
if M.config.indexer and M.config.indexer.enabled then
|
||||
local indexer = require("codetyper.indexer")
|
||||
indexer.setup(M.config.indexer)
|
||||
end
|
||||
|
||||
-- Initialize brain learning system if enabled
|
||||
if M.config.brain and M.config.brain.enabled then
|
||||
local brain = require("codetyper.brain")
|
||||
brain.setup(M.config.brain)
|
||||
end
|
||||
|
||||
-- Setup inline ghost text suggestions (Copilot-style)
|
||||
if M.config.suggestion and M.config.suggestion.enabled then
|
||||
local suggestion = require("codetyper.suggestion")
|
||||
suggestion.setup(M.config.suggestion)
|
||||
end
|
||||
|
||||
-- Start the event-driven scheduler if enabled
|
||||
if M.config.scheduler and M.config.scheduler.enabled then
|
||||
local scheduler = require("codetyper.agent.scheduler")
|
||||
|
||||
@@ -1,364 +0,0 @@
|
||||
---@mod codetyper.llm.claude Claude API client for Codetyper.nvim
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
local llm = require("codetyper.llm")
|
||||
|
||||
--- Claude API endpoint
|
||||
local API_URL = "https://api.anthropic.com/v1/messages"
|
||||
|
||||
--- Get API key from config or environment
|
||||
---@return string|nil API key
|
||||
local function get_api_key()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
return config.llm.claude.api_key or vim.env.ANTHROPIC_API_KEY
|
||||
end
|
||||
|
||||
--- Get model from config
|
||||
---@return string Model name
|
||||
local function get_model()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
return config.llm.claude.model
|
||||
end
|
||||
|
||||
--- Build request body for Claude API
|
||||
---@param prompt string User prompt
|
||||
---@param context table Context information
|
||||
---@return table Request body
|
||||
local function build_request_body(prompt, context)
|
||||
local system_prompt = llm.build_system_prompt(context)
|
||||
|
||||
return {
|
||||
model = get_model(),
|
||||
max_tokens = 4096,
|
||||
system = system_prompt,
|
||||
messages = {
|
||||
{
|
||||
role = "user",
|
||||
content = prompt,
|
||||
},
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
--- Make HTTP request to Claude API
|
||||
---@param body table Request body
|
||||
---@param callback fun(response: string|nil, error: string|nil, usage: table|nil) Callback function
|
||||
local function make_request(body, callback)
|
||||
local api_key = get_api_key()
|
||||
if not api_key then
|
||||
callback(nil, "Claude API key not configured", nil)
|
||||
return
|
||||
end
|
||||
|
||||
local json_body = vim.json.encode(body)
|
||||
|
||||
-- Use curl for HTTP request (plenary.curl alternative)
|
||||
local cmd = {
|
||||
"curl",
|
||||
"-s",
|
||||
"-X",
|
||||
"POST",
|
||||
API_URL,
|
||||
"-H",
|
||||
"Content-Type: application/json",
|
||||
"-H",
|
||||
"x-api-key: " .. api_key,
|
||||
"-H",
|
||||
"anthropic-version: 2023-06-01",
|
||||
"-d",
|
||||
json_body,
|
||||
}
|
||||
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if not data or #data == 0 or (data[1] == "" and #data == 1) then
|
||||
return
|
||||
end
|
||||
|
||||
local response_text = table.concat(data, "\n")
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
|
||||
if not ok then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Failed to parse Claude response", nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
vim.schedule(function()
|
||||
callback(nil, response.error.message or "Claude API error", nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Extract usage info
|
||||
local usage = response.usage or {}
|
||||
|
||||
if response.content and response.content[1] and response.content[1].text then
|
||||
local code = llm.extract_code(response.content[1].text)
|
||||
vim.schedule(function()
|
||||
callback(code, nil, usage)
|
||||
end)
|
||||
else
|
||||
vim.schedule(function()
|
||||
callback(nil, "No content in Claude response", nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Claude API request failed: " .. table.concat(data, "\n"), nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Claude API request failed with code: " .. code, nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- Generate code using Claude API
|
||||
---@param prompt string The user's prompt
|
||||
---@param context table Context information
|
||||
---@param callback fun(response: string|nil, error: string|nil) Callback function
|
||||
function M.generate(prompt, context, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local model = get_model()
|
||||
|
||||
-- Log the request
|
||||
logs.request("claude", model)
|
||||
logs.thinking("Building request body...")
|
||||
|
||||
local body = build_request_body(prompt, context)
|
||||
|
||||
-- Estimate prompt tokens
|
||||
local prompt_estimate = logs.estimate_tokens(vim.json.encode(body))
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
logs.thinking("Sending to Claude API...")
|
||||
|
||||
utils.notify("Sending request to Claude...", vim.log.levels.INFO)
|
||||
|
||||
make_request(body, function(response, err, usage)
|
||||
if err then
|
||||
logs.error(err)
|
||||
utils.notify(err, vim.log.levels.ERROR)
|
||||
callback(nil, err)
|
||||
else
|
||||
-- Log token usage
|
||||
if usage then
|
||||
logs.response(usage.input_tokens or 0, usage.output_tokens or 0, "end_turn")
|
||||
end
|
||||
logs.thinking("Response received, extracting code...")
|
||||
logs.info("Code generated successfully")
|
||||
utils.notify("Code generated successfully", vim.log.levels.INFO)
|
||||
callback(response, nil)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Check if Claude is properly configured
|
||||
---@return boolean, string? Valid status and optional error message
|
||||
function M.validate()
|
||||
local api_key = get_api_key()
|
||||
if not api_key or api_key == "" then
|
||||
return false, "Claude API key not configured"
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
--- Generate with tool use support for agentic mode
|
||||
---@param messages table[] Conversation history
|
||||
---@param context table Context information
|
||||
---@param tool_definitions table Tool definitions
|
||||
---@param callback fun(response: table|nil, error: string|nil) Callback with raw response
|
||||
function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local model = get_model()
|
||||
|
||||
-- Log the request
|
||||
logs.request("claude", model)
|
||||
logs.thinking("Preparing agent request...")
|
||||
|
||||
local api_key = get_api_key()
|
||||
if not api_key then
|
||||
logs.error("Claude API key not configured")
|
||||
callback(nil, "Claude API key not configured")
|
||||
return
|
||||
end
|
||||
|
||||
local tools_module = require("codetyper.agent.tools")
|
||||
local agent_prompts = require("codetyper.prompts.agent")
|
||||
|
||||
-- Build system prompt with agent instructions
|
||||
local system_prompt = llm.build_system_prompt(context)
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.system
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.tool_instructions
|
||||
|
||||
-- Build request body with tools
|
||||
local body = {
|
||||
model = get_model(),
|
||||
max_tokens = 4096,
|
||||
system = system_prompt,
|
||||
messages = M.format_messages_for_claude(messages),
|
||||
tools = tools_module.to_claude_format(),
|
||||
}
|
||||
|
||||
local json_body = vim.json.encode(body)
|
||||
|
||||
-- Estimate prompt tokens
|
||||
local prompt_estimate = logs.estimate_tokens(json_body)
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
logs.thinking("Sending to Claude API...")
|
||||
|
||||
local cmd = {
|
||||
"curl",
|
||||
"-s",
|
||||
"-X",
|
||||
"POST",
|
||||
API_URL,
|
||||
"-H",
|
||||
"Content-Type: application/json",
|
||||
"-H",
|
||||
"x-api-key: " .. api_key,
|
||||
"-H",
|
||||
"anthropic-version: 2023-06-01",
|
||||
"-d",
|
||||
json_body,
|
||||
}
|
||||
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if not data or #data == 0 or (data[1] == "" and #data == 1) then
|
||||
return
|
||||
end
|
||||
|
||||
local response_text = table.concat(data, "\n")
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
|
||||
if not ok then
|
||||
vim.schedule(function()
|
||||
logs.error("Failed to parse Claude response")
|
||||
callback(nil, "Failed to parse Claude response")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
vim.schedule(function()
|
||||
logs.error(response.error.message or "Claude API error")
|
||||
callback(nil, response.error.message or "Claude API error")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Log token usage from response
|
||||
if response.usage then
|
||||
logs.response(response.usage.input_tokens or 0, response.usage.output_tokens or 0, response.stop_reason)
|
||||
end
|
||||
|
||||
-- Log what's in the response
|
||||
if response.content then
|
||||
for _, block in ipairs(response.content) do
|
||||
if block.type == "text" then
|
||||
logs.thinking("Response contains text")
|
||||
elseif block.type == "tool_use" then
|
||||
logs.thinking("Tool call: " .. block.name)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Return raw response for parser to handle
|
||||
vim.schedule(function()
|
||||
callback(response, nil)
|
||||
end)
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
logs.error("Claude API request failed: " .. table.concat(data, "\n"))
|
||||
callback(nil, "Claude API request failed: " .. table.concat(data, "\n"))
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
vim.schedule(function()
|
||||
logs.error("Claude API request failed with code: " .. code)
|
||||
callback(nil, "Claude API request failed with code: " .. code)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- Format messages for Claude API
|
||||
---@param messages table[] Internal message format
|
||||
---@return table[] Claude API message format
|
||||
function M.format_messages_for_claude(messages)
|
||||
local formatted = {}
|
||||
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role == "user" then
|
||||
if type(msg.content) == "table" then
|
||||
-- Tool results
|
||||
table.insert(formatted, {
|
||||
role = "user",
|
||||
content = msg.content,
|
||||
})
|
||||
else
|
||||
table.insert(formatted, {
|
||||
role = "user",
|
||||
content = msg.content,
|
||||
})
|
||||
end
|
||||
elseif msg.role == "assistant" then
|
||||
-- Build content array for assistant messages
|
||||
local content = {}
|
||||
|
||||
-- Add text if present
|
||||
if msg.content and msg.content ~= "" then
|
||||
table.insert(content, {
|
||||
type = "text",
|
||||
text = msg.content,
|
||||
})
|
||||
end
|
||||
|
||||
-- Add tool uses if present
|
||||
if msg.tool_calls then
|
||||
for _, tool_call in ipairs(msg.tool_calls) do
|
||||
table.insert(content, {
|
||||
type = "tool_use",
|
||||
id = tool_call.id,
|
||||
name = tool_call.name,
|
||||
input = tool_call.parameters,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
if #content > 0 then
|
||||
table.insert(formatted, {
|
||||
role = "assistant",
|
||||
content = content,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return formatted
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -14,6 +14,51 @@ local AUTH_URL = "https://api.github.com/copilot_internal/v2/token"
|
||||
---@field github_token table|nil
|
||||
M.state = nil
|
||||
|
||||
--- Track if we've already suggested Ollama fallback this session
|
||||
local ollama_fallback_suggested = false
|
||||
|
||||
--- Suggest switching to Ollama when rate limits are hit
|
||||
---@param error_msg string The error message that triggered this
|
||||
function M.suggest_ollama_fallback(error_msg)
|
||||
if ollama_fallback_suggested then
|
||||
return
|
||||
end
|
||||
|
||||
-- Check if Ollama is available
|
||||
local ollama_available = false
|
||||
vim.fn.jobstart({ "curl", "-s", "http://localhost:11434/api/tags" }, {
|
||||
on_exit = function(_, code)
|
||||
if code == 0 then
|
||||
ollama_available = true
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
if ollama_available then
|
||||
-- Switch to Ollama automatically
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
config.llm.provider = "ollama"
|
||||
|
||||
ollama_fallback_suggested = true
|
||||
utils.notify(
|
||||
"⚠️ Copilot rate limit reached. Switched to Ollama automatically.\n"
|
||||
.. "Original error: "
|
||||
.. error_msg:sub(1, 100),
|
||||
vim.log.levels.WARN
|
||||
)
|
||||
else
|
||||
utils.notify(
|
||||
"⚠️ Copilot rate limit reached. Ollama not available.\n"
|
||||
.. "Start Ollama with: ollama serve\n"
|
||||
.. "Or wait for Copilot limits to reset.",
|
||||
vim.log.levels.WARN
|
||||
)
|
||||
end
|
||||
end)
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- Get OAuth token from copilot.lua or copilot.vim config
|
||||
---@return string|nil OAuth token
|
||||
local function get_oauth_token()
|
||||
@@ -51,9 +96,16 @@ local function get_oauth_token()
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Get model from config
|
||||
--- Get model from stored credentials or config
|
||||
---@return string Model name
|
||||
local function get_model()
|
||||
-- Priority: stored credentials > config
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_model = credentials.get_model("copilot")
|
||||
if stored_model then
|
||||
return stored_model
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.copilot.model
|
||||
@@ -204,15 +256,37 @@ local function make_request(token, body, callback)
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
|
||||
if not ok then
|
||||
-- Show the actual response text as the error (truncated if too long)
|
||||
local error_msg = response_text
|
||||
if #error_msg > 200 then
|
||||
error_msg = error_msg:sub(1, 200) .. "..."
|
||||
end
|
||||
|
||||
-- Clean up common patterns
|
||||
if response_text:match("<!DOCTYPE") or response_text:match("<html") then
|
||||
error_msg = "Copilot API returned HTML error page. Service may be unavailable."
|
||||
end
|
||||
|
||||
-- Check for rate limit and suggest Ollama fallback
|
||||
if response_text:match("limit") or response_text:match("Upgrade") or response_text:match("quota") then
|
||||
M.suggest_ollama_fallback(error_msg)
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
callback(nil, "Failed to parse Copilot response", nil)
|
||||
callback(nil, error_msg, nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
local error_msg = response.error.message or "Copilot API error"
|
||||
if response.error.code == "rate_limit_exceeded" or (error_msg:match("limit") and error_msg:match("plan")) then
|
||||
error_msg = "Copilot rate limit: " .. error_msg
|
||||
M.suggest_ollama_fallback(error_msg)
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
callback(nil, response.error.message or "Copilot API error", nil)
|
||||
callback(nil, error_msg, nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
@@ -220,6 +294,17 @@ local function make_request(token, body, callback)
|
||||
-- Extract usage info
|
||||
local usage = response.usage or {}
|
||||
|
||||
-- Record usage for cost tracking
|
||||
if usage.prompt_tokens or usage.completion_tokens then
|
||||
local cost = require("codetyper.cost")
|
||||
cost.record_usage(
|
||||
get_model(),
|
||||
usage.prompt_tokens or 0,
|
||||
usage.completion_tokens or 0,
|
||||
usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens or 0
|
||||
)
|
||||
end
|
||||
|
||||
if response.choices and response.choices[1] and response.choices[1].message then
|
||||
local code = llm.extract_code(response.choices[1].message.content)
|
||||
vim.schedule(function()
|
||||
@@ -362,20 +447,46 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
-- Format messages for Copilot (OpenAI-compatible format)
|
||||
local copilot_messages = { { role = "system", content = system_prompt } }
|
||||
for _, msg in ipairs(messages) do
|
||||
if type(msg.content) == "string" then
|
||||
table.insert(copilot_messages, { role = msg.role, content = msg.content })
|
||||
elseif type(msg.content) == "table" then
|
||||
local text_parts = {}
|
||||
for _, part in ipairs(msg.content) do
|
||||
if part.type == "tool_result" then
|
||||
table.insert(text_parts, "[" .. (part.name or "tool") .. " result]: " .. (part.content or ""))
|
||||
elseif part.type == "text" then
|
||||
table.insert(text_parts, part.text or "")
|
||||
if msg.role == "user" then
|
||||
-- User messages - handle string or table content
|
||||
if type(msg.content) == "string" then
|
||||
table.insert(copilot_messages, { role = "user", content = msg.content })
|
||||
elseif type(msg.content) == "table" then
|
||||
-- Handle complex content (like tool results from user perspective)
|
||||
local text_parts = {}
|
||||
for _, part in ipairs(msg.content) do
|
||||
if part.type == "tool_result" then
|
||||
table.insert(text_parts, "[" .. (part.name or "tool") .. " result]: " .. (part.content or ""))
|
||||
elseif part.type == "text" then
|
||||
table.insert(text_parts, part.text or "")
|
||||
end
|
||||
end
|
||||
if #text_parts > 0 then
|
||||
table.insert(copilot_messages, { role = "user", content = table.concat(text_parts, "\n") })
|
||||
end
|
||||
end
|
||||
if #text_parts > 0 then
|
||||
table.insert(copilot_messages, { role = msg.role, content = table.concat(text_parts, "\n") })
|
||||
elseif msg.role == "assistant" then
|
||||
-- Assistant messages - must preserve tool_calls if present
|
||||
local assistant_msg = {
|
||||
role = "assistant",
|
||||
content = type(msg.content) == "string" and msg.content or nil,
|
||||
}
|
||||
-- Preserve tool_calls for the API
|
||||
if msg.tool_calls then
|
||||
assistant_msg.tool_calls = msg.tool_calls
|
||||
-- Ensure content is not nil when tool_calls present
|
||||
if assistant_msg.content == nil then
|
||||
assistant_msg.content = ""
|
||||
end
|
||||
end
|
||||
table.insert(copilot_messages, assistant_msg)
|
||||
elseif msg.role == "tool" then
|
||||
-- Tool result messages - must have tool_call_id
|
||||
table.insert(copilot_messages, {
|
||||
role = "tool",
|
||||
tool_call_id = msg.tool_call_id,
|
||||
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
@@ -396,6 +507,20 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
logs.thinking("Sending to Copilot API...")
|
||||
|
||||
-- Log request to debug file
|
||||
local debug_log_path = vim.fn.expand("~/.local/codetyper-debug.log")
|
||||
local debug_f = io.open(debug_log_path, "a")
|
||||
if debug_f then
|
||||
debug_f:write(os.date("[%Y-%m-%d %H:%M:%S] ") .. "COPILOT REQUEST\n")
|
||||
debug_f:write("Messages count: " .. #copilot_messages .. "\n")
|
||||
for i, m in ipairs(copilot_messages) do
|
||||
debug_f:write(string.format(" [%d] role=%s, has_tool_calls=%s, has_tool_call_id=%s\n",
|
||||
i, m.role, tostring(m.tool_calls ~= nil), tostring(m.tool_call_id ~= nil)))
|
||||
end
|
||||
debug_f:write("---\n")
|
||||
debug_f:close()
|
||||
end
|
||||
|
||||
local headers = build_headers(token)
|
||||
local cmd = {
|
||||
"curl",
|
||||
@@ -413,35 +538,97 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
table.insert(cmd, "-d")
|
||||
table.insert(cmd, json_body)
|
||||
|
||||
-- Debug logging helper
|
||||
local function debug_log(msg, data)
|
||||
local log_path = vim.fn.expand("~/.local/codetyper-debug.log")
|
||||
local f = io.open(log_path, "a")
|
||||
if f then
|
||||
f:write(os.date("[%Y-%m-%d %H:%M:%S] ") .. msg .. "\n")
|
||||
if data then
|
||||
f:write("DATA: " .. tostring(data):sub(1, 2000) .. "\n")
|
||||
end
|
||||
f:write("---\n")
|
||||
f:close()
|
||||
end
|
||||
end
|
||||
|
||||
-- Prevent double callback calls
|
||||
local callback_called = false
|
||||
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if callback_called then
|
||||
debug_log("on_stdout: callback already called, skipping")
|
||||
return
|
||||
end
|
||||
|
||||
if not data or #data == 0 or (data[1] == "" and #data == 1) then
|
||||
debug_log("on_stdout: empty data")
|
||||
return
|
||||
end
|
||||
|
||||
local response_text = table.concat(data, "\n")
|
||||
debug_log("on_stdout: received response", response_text)
|
||||
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
|
||||
if not ok then
|
||||
debug_log("JSON parse failed", response_text)
|
||||
callback_called = true
|
||||
|
||||
-- Show the actual response text as the error (truncated if too long)
|
||||
local error_msg = response_text
|
||||
if #error_msg > 200 then
|
||||
error_msg = error_msg:sub(1, 200) .. "..."
|
||||
end
|
||||
|
||||
-- Clean up common patterns
|
||||
if response_text:match("<!DOCTYPE") or response_text:match("<html") then
|
||||
error_msg = "Copilot API returned HTML error page. Service may be unavailable."
|
||||
end
|
||||
|
||||
-- Check for rate limit and suggest Ollama fallback
|
||||
if response_text:match("limit") or response_text:match("Upgrade") or response_text:match("quota") then
|
||||
M.suggest_ollama_fallback(error_msg)
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
logs.error("Failed to parse Copilot response")
|
||||
callback(nil, "Failed to parse Copilot response")
|
||||
logs.error(error_msg)
|
||||
callback(nil, error_msg)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
callback_called = true
|
||||
local error_msg = response.error.message or "Copilot API error"
|
||||
|
||||
-- Check for rate limit in structured error
|
||||
if response.error.code == "rate_limit_exceeded" or (error_msg:match("limit") and error_msg:match("plan")) then
|
||||
error_msg = "Copilot rate limit: " .. error_msg
|
||||
M.suggest_ollama_fallback(error_msg)
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
logs.error(response.error.message or "Copilot API error")
|
||||
callback(nil, response.error.message or "Copilot API error")
|
||||
logs.error(error_msg)
|
||||
callback(nil, error_msg)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Log token usage
|
||||
-- Log token usage and record cost
|
||||
if response.usage then
|
||||
logs.response(response.usage.prompt_tokens or 0, response.usage.completion_tokens or 0, "stop")
|
||||
|
||||
-- Record usage for cost tracking
|
||||
local cost_tracker = require("codetyper.cost")
|
||||
cost_tracker.record_usage(
|
||||
get_model(),
|
||||
response.usage.prompt_tokens or 0,
|
||||
response.usage.completion_tokens or 0,
|
||||
response.usage.prompt_tokens_details and response.usage.prompt_tokens_details.cached_tokens or 0
|
||||
)
|
||||
end
|
||||
|
||||
-- Convert to Claude-like format for parser compatibility
|
||||
@@ -474,12 +661,19 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
end
|
||||
end
|
||||
|
||||
callback_called = true
|
||||
debug_log("on_stdout: success, calling callback")
|
||||
vim.schedule(function()
|
||||
callback(converted, nil)
|
||||
end)
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if callback_called then
|
||||
return
|
||||
end
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
debug_log("on_stderr", table.concat(data, "\n"))
|
||||
callback_called = true
|
||||
vim.schedule(function()
|
||||
logs.error("Copilot API request failed: " .. table.concat(data, "\n"))
|
||||
callback(nil, "Copilot API request failed: " .. table.concat(data, "\n"))
|
||||
@@ -487,7 +681,12 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
debug_log("on_exit: code=" .. code .. ", callback_called=" .. tostring(callback_called))
|
||||
if callback_called then
|
||||
return
|
||||
end
|
||||
if code ~= 0 then
|
||||
callback_called = true
|
||||
vim.schedule(function()
|
||||
logs.error("Copilot API request failed with code: " .. code)
|
||||
callback(nil, "Copilot API request failed with code: " .. code)
|
||||
|
||||
@@ -8,17 +8,31 @@ local llm = require("codetyper.llm")
|
||||
--- Gemini API endpoint
|
||||
local API_URL = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
|
||||
--- Get API key from config or environment
|
||||
--- Get API key from stored credentials, config, or environment
|
||||
---@return string|nil API key
|
||||
local function get_api_key()
|
||||
-- Priority: stored credentials > config > environment
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_key = credentials.get_api_key("gemini")
|
||||
if stored_key then
|
||||
return stored_key
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.gemini.api_key or vim.env.GEMINI_API_KEY
|
||||
end
|
||||
|
||||
--- Get model from config
|
||||
--- Get model from stored credentials or config
|
||||
---@return string Model name
|
||||
local function get_model()
|
||||
-- Priority: stored credentials > config
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_model = credentials.get_model("gemini")
|
||||
if stored_model then
|
||||
return stored_model
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.gemini.model
|
||||
|
||||
@@ -10,9 +10,7 @@ function M.get_client()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
if config.llm.provider == "claude" then
|
||||
return require("codetyper.llm.claude")
|
||||
elseif config.llm.provider == "ollama" then
|
||||
if config.llm.provider == "ollama" then
|
||||
return require("codetyper.llm.ollama")
|
||||
elseif config.llm.provider == "openai" then
|
||||
return require("codetyper.llm.openai")
|
||||
@@ -34,6 +32,32 @@ function M.generate(prompt, context, callback)
|
||||
client.generate(prompt, context, callback)
|
||||
end
|
||||
|
||||
--- Smart generate with automatic provider selection based on brain memories
|
||||
--- Prefers Ollama when context is rich, falls back to Copilot otherwise.
|
||||
--- Implements verification pondering to reinforce Ollama accuracy over time.
|
||||
---@param prompt string The user's prompt
|
||||
---@param context table Context information
|
||||
---@param callback fun(response: string|nil, error: string|nil, metadata: table|nil) Callback
|
||||
function M.smart_generate(prompt, context, callback)
|
||||
local selector = require("codetyper.llm.selector")
|
||||
selector.smart_generate(prompt, context, callback)
|
||||
end
|
||||
|
||||
--- Get accuracy statistics for providers
|
||||
---@return table Statistics for each provider
|
||||
function M.get_accuracy_stats()
|
||||
local selector = require("codetyper.llm.selector")
|
||||
return selector.get_accuracy_stats()
|
||||
end
|
||||
|
||||
--- Report user feedback on response quality (for reinforcement learning)
|
||||
---@param provider string Which provider generated the response
|
||||
---@param was_correct boolean Whether the response was good
|
||||
function M.report_feedback(provider, was_correct)
|
||||
local selector = require("codetyper.llm.selector")
|
||||
selector.report_feedback(provider, was_correct)
|
||||
end
|
||||
|
||||
--- Build the system prompt for code generation
|
||||
---@param context table Context information
|
||||
---@return string System prompt
|
||||
@@ -50,7 +74,49 @@ function M.build_system_prompt(context)
|
||||
system = system:gsub("{{language}}", context.language or "unknown")
|
||||
system = system:gsub("{{filepath}}", context.file_path or "unknown")
|
||||
|
||||
-- Add file content with analysis hints
|
||||
-- For agent mode, include project context
|
||||
if prompt_type == "agent" then
|
||||
local project_info = "\n\n## PROJECT CONTEXT\n"
|
||||
|
||||
if context.project_root then
|
||||
project_info = project_info .. "- Project root: " .. context.project_root .. "\n"
|
||||
end
|
||||
if context.cwd then
|
||||
project_info = project_info .. "- Working directory: " .. context.cwd .. "\n"
|
||||
end
|
||||
if context.project_type then
|
||||
project_info = project_info .. "- Project type: " .. context.project_type .. "\n"
|
||||
end
|
||||
if context.project_stats then
|
||||
project_info = project_info
|
||||
.. string.format(
|
||||
"- Stats: %d files, %d functions, %d classes\n",
|
||||
context.project_stats.files or 0,
|
||||
context.project_stats.functions or 0,
|
||||
context.project_stats.classes or 0
|
||||
)
|
||||
end
|
||||
if context.file_path then
|
||||
project_info = project_info .. "- Current file: " .. context.file_path .. "\n"
|
||||
end
|
||||
|
||||
system = system .. project_info
|
||||
return system
|
||||
end
|
||||
|
||||
-- For "ask" or "explain" mode, don't add code generation instructions
|
||||
if prompt_type == "ask" or prompt_type == "explain" then
|
||||
-- Just add context about the file if available
|
||||
if context.file_path then
|
||||
system = system .. "\n\nContext: The user is working with " .. context.file_path
|
||||
if context.language then
|
||||
system = system .. " (" .. context.language .. ")"
|
||||
end
|
||||
end
|
||||
return system
|
||||
end
|
||||
|
||||
-- Add file content with analysis hints (for code generation modes)
|
||||
if context.file_content and context.file_content ~= "" then
|
||||
system = system .. "\n\n===== EXISTING FILE CONTENT (analyze and match this style) =====\n"
|
||||
system = system .. context.file_content
|
||||
@@ -74,13 +140,34 @@ function M.build_context(target_path, prompt_type)
|
||||
local content = utils.read_file(target_path)
|
||||
local ext = vim.fn.fnamemodify(target_path, ":e")
|
||||
|
||||
return {
|
||||
local context = {
|
||||
file_content = content,
|
||||
language = lang_map[ext] or ext,
|
||||
extension = ext,
|
||||
prompt_type = prompt_type,
|
||||
file_path = target_path,
|
||||
}
|
||||
|
||||
-- For agent mode, include additional project context
|
||||
if prompt_type == "agent" then
|
||||
local project_root = utils.get_project_root()
|
||||
context.project_root = project_root
|
||||
|
||||
-- Try to get project info from indexer
|
||||
local ok_indexer, indexer = pcall(require, "codetyper.indexer")
|
||||
if ok_indexer then
|
||||
local status = indexer.get_status()
|
||||
if status.indexed then
|
||||
context.project_type = status.project_type
|
||||
context.project_stats = status.stats
|
||||
end
|
||||
end
|
||||
|
||||
-- Include working directory
|
||||
context.cwd = vim.fn.getcwd()
|
||||
end
|
||||
|
||||
return context
|
||||
end
|
||||
|
||||
--- Parse LLM response and extract code
|
||||
|
||||
@@ -5,21 +5,33 @@ local M = {}
|
||||
local utils = require("codetyper.utils")
|
||||
local llm = require("codetyper.llm")
|
||||
|
||||
--- Get Ollama host from config
|
||||
--- Get Ollama host from stored credentials or config
|
||||
---@return string Host URL
|
||||
local function get_host()
|
||||
-- Priority: stored credentials > config
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_host = credentials.get_ollama_host()
|
||||
if stored_host then
|
||||
return stored_host
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
return config.llm.ollama.host
|
||||
end
|
||||
|
||||
--- Get model from config
|
||||
--- Get model from stored credentials or config
|
||||
---@return string Model name
|
||||
local function get_model()
|
||||
-- Priority: stored credentials > config
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_model = credentials.get_model("ollama")
|
||||
if stored_model then
|
||||
return stored_model
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
return config.llm.ollama.model
|
||||
end
|
||||
|
||||
@@ -199,47 +211,41 @@ function M.validate()
|
||||
return true
|
||||
end
|
||||
|
||||
--- Build system prompt for agent mode with tool instructions
|
||||
--- Generate with tool use support for agentic mode (text-based tool calling)
|
||||
---@param messages table[] Conversation history
|
||||
---@param context table Context information
|
||||
---@return string System prompt
|
||||
local function build_agent_system_prompt(context)
|
||||
---@param tool_definitions table Tool definitions
|
||||
---@param callback fun(response: table|nil, error: string|nil) Callback with Claude-like response format
|
||||
function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local agent_prompts = require("codetyper.prompts.agent")
|
||||
local tools_module = require("codetyper.agent.tools")
|
||||
|
||||
local system_prompt = agent_prompts.system .. "\n\n"
|
||||
system_prompt = system_prompt .. tools_module.to_prompt_format() .. "\n\n"
|
||||
system_prompt = system_prompt .. agent_prompts.tool_instructions
|
||||
logs.request("ollama", get_model())
|
||||
logs.thinking("Preparing agent request...")
|
||||
|
||||
-- Add context about current file if available
|
||||
if context.file_path then
|
||||
system_prompt = system_prompt .. "\n\nCurrent working context:\n"
|
||||
system_prompt = system_prompt .. "- File: " .. context.file_path .. "\n"
|
||||
if context.language then
|
||||
system_prompt = system_prompt .. "- Language: " .. context.language .. "\n"
|
||||
-- Build system prompt with tool instructions
|
||||
local system_prompt = llm.build_system_prompt(context)
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.system
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.tool_instructions
|
||||
|
||||
-- Add tool descriptions
|
||||
system_prompt = system_prompt .. "\n\n## Available Tools\n"
|
||||
system_prompt = system_prompt .. "Call tools by outputting JSON in this exact format:\n"
|
||||
system_prompt = system_prompt .. '```json\n{"tool": "tool_name", "arguments": {...}}\n```\n\n'
|
||||
|
||||
for _, tool in ipairs(tool_definitions) do
|
||||
local name = tool.name or (tool["function"] and tool["function"].name)
|
||||
local desc = tool.description or (tool["function"] and tool["function"].description)
|
||||
if name then
|
||||
system_prompt = system_prompt .. string.format("### %s\n%s\n\n", name, desc or "")
|
||||
end
|
||||
end
|
||||
|
||||
-- Add project root info
|
||||
local root = utils.get_project_root()
|
||||
if root then
|
||||
system_prompt = system_prompt .. "- Project root: " .. root .. "\n"
|
||||
end
|
||||
|
||||
return system_prompt
|
||||
end
|
||||
|
||||
--- Build request body for Ollama API with tools (chat format)
|
||||
---@param messages table[] Conversation messages
|
||||
---@param context table Context information
|
||||
---@return table Request body
|
||||
local function build_tools_request_body(messages, context)
|
||||
local system_prompt = build_agent_system_prompt(context)
|
||||
|
||||
-- Convert messages to Ollama chat format
|
||||
local ollama_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
local content = msg.content
|
||||
-- Handle complex content (like tool results)
|
||||
if type(content) == "table" then
|
||||
local text_parts = {}
|
||||
for _, part in ipairs(content) do
|
||||
@@ -251,14 +257,10 @@ local function build_tools_request_body(messages, context)
|
||||
end
|
||||
content = table.concat(text_parts, "\n")
|
||||
end
|
||||
|
||||
table.insert(ollama_messages, {
|
||||
role = msg.role,
|
||||
content = content,
|
||||
})
|
||||
table.insert(ollama_messages, { role = msg.role, content = content })
|
||||
end
|
||||
|
||||
return {
|
||||
local body = {
|
||||
model = get_model(),
|
||||
messages = ollama_messages,
|
||||
system = system_prompt,
|
||||
@@ -268,16 +270,15 @@ local function build_tools_request_body(messages, context)
|
||||
num_predict = 4096,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
--- Make HTTP request to Ollama chat API
|
||||
---@param body table Request body
|
||||
---@param callback fun(response: string|nil, error: string|nil, usage: table|nil) Callback function
|
||||
local function make_chat_request(body, callback)
|
||||
local host = get_host()
|
||||
local url = host .. "/api/chat"
|
||||
local json_body = vim.json.encode(body)
|
||||
|
||||
local prompt_estimate = logs.estimate_tokens(json_body)
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
logs.thinking("Sending to Ollama API...")
|
||||
|
||||
local cmd = {
|
||||
"curl",
|
||||
"-s",
|
||||
@@ -302,196 +303,82 @@ local function make_chat_request(body, callback)
|
||||
|
||||
if not ok then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Failed to parse Ollama response", nil)
|
||||
logs.error("Failed to parse Ollama response")
|
||||
callback(nil, "Failed to parse Ollama response")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
vim.schedule(function()
|
||||
callback(nil, response.error or "Ollama API error", nil)
|
||||
logs.error(response.error or "Ollama API error")
|
||||
callback(nil, response.error or "Ollama API error")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Extract usage info
|
||||
local usage = {
|
||||
prompt_tokens = response.prompt_eval_count or 0,
|
||||
response_tokens = response.eval_count or 0,
|
||||
}
|
||||
-- Log token usage and record cost (Ollama is free but we track usage)
|
||||
if response.prompt_eval_count or response.eval_count then
|
||||
logs.response(response.prompt_eval_count or 0, response.eval_count or 0, "stop")
|
||||
|
||||
-- Return the message content for agent parsing
|
||||
if response.message and response.message.content then
|
||||
vim.schedule(function()
|
||||
callback(response.message.content, nil, usage)
|
||||
end)
|
||||
else
|
||||
vim.schedule(function()
|
||||
callback(nil, "No response from Ollama", nil)
|
||||
end)
|
||||
-- Record usage for cost tracking (free for local models)
|
||||
local cost = require("codetyper.cost")
|
||||
cost.record_usage(
|
||||
get_model(),
|
||||
response.prompt_eval_count or 0,
|
||||
response.eval_count or 0,
|
||||
0 -- No cached tokens for Ollama
|
||||
)
|
||||
end
|
||||
|
||||
-- Parse the response text for tool calls
|
||||
local content_text = response.message and response.message.content or ""
|
||||
local converted = { content = {}, stop_reason = "end_turn" }
|
||||
|
||||
-- Try to extract JSON tool calls from response
|
||||
local json_match = content_text:match("```json%s*(%b{})%s*```")
|
||||
if json_match then
|
||||
local ok_json, parsed = pcall(vim.json.decode, json_match)
|
||||
if ok_json and parsed.tool then
|
||||
table.insert(converted.content, {
|
||||
type = "tool_use",
|
||||
id = "call_" .. string.format("%x", os.time()) .. "_" .. string.format("%x", math.random(0, 0xFFFF)),
|
||||
name = parsed.tool,
|
||||
input = parsed.arguments or {},
|
||||
})
|
||||
logs.thinking("Tool call: " .. parsed.tool)
|
||||
content_text = content_text:gsub("```json.-```", ""):gsub("^%s+", ""):gsub("%s+$", "")
|
||||
converted.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
|
||||
-- Add text content
|
||||
if content_text and content_text ~= "" then
|
||||
table.insert(converted.content, 1, { type = "text", text = content_text })
|
||||
logs.thinking("Response contains text")
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
callback(converted, nil)
|
||||
end)
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil)
|
||||
logs.error("Ollama API request failed: " .. table.concat(data, "\n"))
|
||||
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"))
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
-- Don't double-report errors
|
||||
vim.schedule(function()
|
||||
logs.error("Ollama API request failed with code: " .. code)
|
||||
callback(nil, "Ollama API request failed with code: " .. code)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- Generate response with tools using Ollama API
|
||||
---@param messages table[] Conversation history
|
||||
---@param context table Context information
|
||||
---@param tools table Tool definitions (embedded in prompt for Ollama)
|
||||
---@param callback fun(response: string|nil, error: string|nil) Callback function
|
||||
function M.generate_with_tools(messages, context, tools, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
|
||||
-- Log the request
|
||||
local model = get_model()
|
||||
logs.request("ollama", model)
|
||||
logs.thinking("Preparing API request...")
|
||||
|
||||
local body = build_tools_request_body(messages, context)
|
||||
|
||||
-- Estimate prompt tokens
|
||||
local prompt_estimate = logs.estimate_tokens(vim.json.encode(body))
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
|
||||
make_chat_request(body, function(response, err, usage)
|
||||
if err then
|
||||
logs.error(err)
|
||||
callback(nil, err)
|
||||
else
|
||||
-- Log token usage
|
||||
if usage then
|
||||
logs.response(usage.prompt_tokens or 0, usage.response_tokens or 0, "end_turn")
|
||||
end
|
||||
|
||||
-- Log if response contains tool calls
|
||||
if response then
|
||||
local parser = require("codetyper.agent.parser")
|
||||
local parsed = parser.parse_ollama_response(response)
|
||||
if #parsed.tool_calls > 0 then
|
||||
for _, tc in ipairs(parsed.tool_calls) do
|
||||
logs.thinking("Tool call: " .. tc.name)
|
||||
end
|
||||
end
|
||||
if parsed.text and parsed.text ~= "" then
|
||||
logs.thinking("Response contains text")
|
||||
end
|
||||
end
|
||||
|
||||
callback(response, nil)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Generate with tool use support for agentic mode (simulated via prompts)
|
||||
---@param messages table[] Conversation history
|
||||
---@param context table Context information
|
||||
---@param tool_definitions table Tool definitions
|
||||
---@param callback fun(response: string|nil, error: string|nil) Callback with response text
|
||||
function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
local tools_module = require("codetyper.agent.tools")
|
||||
local agent_prompts = require("codetyper.prompts.agent")
|
||||
|
||||
-- Build system prompt with agent instructions and tool definitions
|
||||
local system_prompt = llm.build_system_prompt(context)
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.system
|
||||
system_prompt = system_prompt .. "\n\n" .. tools_module.to_prompt_format()
|
||||
|
||||
-- Flatten messages to single prompt (Ollama's generate API)
|
||||
local prompt_parts = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if type(msg.content) == "string" then
|
||||
local role_prefix = msg.role == "user" and "User" or "Assistant"
|
||||
table.insert(prompt_parts, role_prefix .. ": " .. msg.content)
|
||||
elseif type(msg.content) == "table" then
|
||||
-- Handle tool results
|
||||
for _, item in ipairs(msg.content) do
|
||||
if item.type == "tool_result" then
|
||||
table.insert(prompt_parts, "Tool result: " .. item.content)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local body = {
|
||||
model = get_model(),
|
||||
system = system_prompt,
|
||||
prompt = table.concat(prompt_parts, "\n\n"),
|
||||
stream = false,
|
||||
options = {
|
||||
temperature = 0.2,
|
||||
num_predict = 4096,
|
||||
},
|
||||
}
|
||||
|
||||
local host = get_host()
|
||||
local url = host .. "/api/generate"
|
||||
local json_body = vim.json.encode(body)
|
||||
|
||||
local cmd = {
|
||||
"curl",
|
||||
"-s",
|
||||
"-X", "POST",
|
||||
url,
|
||||
"-H", "Content-Type: application/json",
|
||||
"-d", json_body,
|
||||
}
|
||||
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if not data or #data == 0 or (data[1] == "" and #data == 1) then
|
||||
return
|
||||
end
|
||||
|
||||
local response_text = table.concat(data, "\n")
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
|
||||
if not ok then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Failed to parse Ollama response")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
vim.schedule(function()
|
||||
callback(nil, response.error or "Ollama API error")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Return raw response text for parser to handle
|
||||
vim.schedule(function()
|
||||
callback(response.response or "", nil)
|
||||
end)
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"))
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Ollama API request failed with code: " .. code)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
@@ -8,25 +8,46 @@ local llm = require("codetyper.llm")
|
||||
--- OpenAI API endpoint
|
||||
local API_URL = "https://api.openai.com/v1/chat/completions"
|
||||
|
||||
--- Get API key from config or environment
|
||||
--- Get API key from stored credentials, config, or environment
|
||||
---@return string|nil API key
|
||||
local function get_api_key()
|
||||
-- Priority: stored credentials > config > environment
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_key = credentials.get_api_key("openai")
|
||||
if stored_key then
|
||||
return stored_key
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.openai.api_key or vim.env.OPENAI_API_KEY
|
||||
end
|
||||
|
||||
--- Get model from config
|
||||
--- Get model from stored credentials or config
|
||||
---@return string Model name
|
||||
local function get_model()
|
||||
-- Priority: stored credentials > config
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_model = credentials.get_model("openai")
|
||||
if stored_model then
|
||||
return stored_model
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.openai.model
|
||||
end
|
||||
|
||||
--- Get endpoint from config (allows custom endpoints like Azure, OpenRouter)
|
||||
--- Get endpoint from stored credentials or config (allows custom endpoints like Azure, OpenRouter)
|
||||
---@return string API endpoint
|
||||
local function get_endpoint()
|
||||
-- Priority: stored credentials > config > default
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_endpoint = credentials.get_endpoint("openai")
|
||||
if stored_endpoint then
|
||||
return stored_endpoint
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.openai.endpoint or API_URL
|
||||
@@ -284,9 +305,18 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
return
|
||||
end
|
||||
|
||||
-- Log token usage
|
||||
-- Log token usage and record cost
|
||||
if response.usage then
|
||||
logs.response(response.usage.prompt_tokens or 0, response.usage.completion_tokens or 0, "stop")
|
||||
|
||||
-- Record usage for cost tracking
|
||||
local cost = require("codetyper.cost")
|
||||
cost.record_usage(
|
||||
model,
|
||||
response.usage.prompt_tokens or 0,
|
||||
response.usage.completion_tokens or 0,
|
||||
response.usage.prompt_tokens_details and response.usage.prompt_tokens_details.cached_tokens or 0
|
||||
)
|
||||
end
|
||||
|
||||
-- Convert to Claude-like format for parser compatibility
|
||||
|
||||
514
lua/codetyper/llm/selector.lua
Normal file
514
lua/codetyper/llm/selector.lua
Normal file
@@ -0,0 +1,514 @@
|
||||
---@mod codetyper.llm.selector Smart LLM selection with memory-based confidence
|
||||
---@brief [[
|
||||
--- Intelligent LLM provider selection based on brain memories.
|
||||
--- Prefers local Ollama when context is rich, falls back to Copilot otherwise.
|
||||
--- Implements verification pondering to reinforce Ollama accuracy over time.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class SelectionResult
|
||||
---@field provider string Selected provider name
|
||||
---@field confidence number Confidence score (0-1)
|
||||
---@field memory_count number Number of relevant memories found
|
||||
---@field reason string Human-readable reason for selection
|
||||
|
||||
---@class PonderResult
|
||||
---@field ollama_response string Ollama's response
|
||||
---@field verifier_response string Verifier's response
|
||||
---@field agreement_score number How much they agree (0-1)
|
||||
---@field ollama_correct boolean Whether Ollama was deemed correct
|
||||
---@field feedback string Feedback for learning
|
||||
|
||||
--- Minimum memories required for high confidence
|
||||
local MIN_MEMORIES_FOR_LOCAL = 3
|
||||
|
||||
--- Minimum memory relevance score for local provider
|
||||
local MIN_RELEVANCE_FOR_LOCAL = 0.6
|
||||
|
||||
--- Agreement threshold for Ollama verification
|
||||
local AGREEMENT_THRESHOLD = 0.7
|
||||
|
||||
--- Pondering sample rate (0-1) - how often to verify Ollama
|
||||
local PONDER_SAMPLE_RATE = 0.2
|
||||
|
||||
--- Provider accuracy tracking (persisted in brain)
|
||||
local accuracy_cache = {
|
||||
ollama = { correct = 0, total = 0 },
|
||||
copilot = { correct = 0, total = 0 },
|
||||
}
|
||||
|
||||
--- Get the brain module safely
|
||||
---@return table|nil
|
||||
local function get_brain()
|
||||
local ok, brain = pcall(require, "codetyper.brain")
|
||||
if ok and brain.is_initialized and brain.is_initialized() then
|
||||
return brain
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Load accuracy stats from brain
|
||||
local function load_accuracy_stats()
|
||||
local brain = get_brain()
|
||||
if not brain then
|
||||
return
|
||||
end
|
||||
|
||||
-- Query for accuracy tracking nodes
|
||||
pcall(function()
|
||||
local result = brain.query({
|
||||
query = "provider_accuracy_stats",
|
||||
types = { "metric" },
|
||||
limit = 1,
|
||||
})
|
||||
|
||||
if result and result.nodes and #result.nodes > 0 then
|
||||
local node = result.nodes[1]
|
||||
if node.c and node.c.d then
|
||||
local ok, stats = pcall(vim.json.decode, node.c.d)
|
||||
if ok and stats then
|
||||
accuracy_cache = stats
|
||||
end
|
||||
end
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Save accuracy stats to brain
|
||||
local function save_accuracy_stats()
|
||||
local brain = get_brain()
|
||||
if not brain then
|
||||
return
|
||||
end
|
||||
|
||||
pcall(function()
|
||||
brain.learn({
|
||||
type = "metric",
|
||||
summary = "provider_accuracy_stats",
|
||||
detail = vim.json.encode(accuracy_cache),
|
||||
weight = 1.0,
|
||||
})
|
||||
end)
|
||||
end
|
||||
|
||||
--- Calculate Ollama confidence based on historical accuracy
|
||||
---@return number confidence (0-1)
|
||||
local function get_ollama_historical_confidence()
|
||||
local stats = accuracy_cache.ollama
|
||||
if stats.total < 5 then
|
||||
-- Not enough data, return neutral confidence
|
||||
return 0.5
|
||||
end
|
||||
|
||||
local accuracy = stats.correct / stats.total
|
||||
-- Boost confidence if accuracy is high
|
||||
return math.min(1.0, accuracy * 1.2)
|
||||
end
|
||||
|
||||
--- Query brain for relevant context
|
||||
---@param prompt string User prompt
|
||||
---@param file_path string|nil Current file path
|
||||
---@return table result {memories: table[], relevance: number, count: number}
|
||||
local function query_brain_context(prompt, file_path)
|
||||
local result = {
|
||||
memories = {},
|
||||
relevance = 0,
|
||||
count = 0,
|
||||
}
|
||||
|
||||
local brain = get_brain()
|
||||
if not brain then
|
||||
return result
|
||||
end
|
||||
|
||||
-- Query brain with multiple dimensions
|
||||
local ok, query_result = pcall(function()
|
||||
return brain.query({
|
||||
query = prompt,
|
||||
file = file_path,
|
||||
limit = 10,
|
||||
types = { "pattern", "correction", "convention", "fact" },
|
||||
})
|
||||
end)
|
||||
|
||||
if not ok or not query_result then
|
||||
return result
|
||||
end
|
||||
|
||||
result.memories = query_result.nodes or {}
|
||||
result.count = #result.memories
|
||||
|
||||
-- Calculate average relevance
|
||||
if result.count > 0 then
|
||||
local total_relevance = 0
|
||||
for _, node in ipairs(result.memories) do
|
||||
-- Use node weight and success rate as relevance indicators
|
||||
local node_relevance = (node.sc and node.sc.w or 0.5) * (node.sc and node.sc.sr or 0.5)
|
||||
total_relevance = total_relevance + node_relevance
|
||||
end
|
||||
result.relevance = total_relevance / result.count
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Select the best LLM provider based on context
|
||||
---@param prompt string User prompt
|
||||
---@param context table LLM context
|
||||
---@return SelectionResult
|
||||
function M.select_provider(prompt, context)
|
||||
-- Load accuracy stats on first call
|
||||
if accuracy_cache.ollama.total == 0 then
|
||||
load_accuracy_stats()
|
||||
end
|
||||
|
||||
local file_path = context.file_path
|
||||
|
||||
-- Query brain for relevant memories
|
||||
local brain_context = query_brain_context(prompt, file_path)
|
||||
|
||||
-- Calculate base confidence from memories
|
||||
local memory_confidence = 0
|
||||
if brain_context.count >= MIN_MEMORIES_FOR_LOCAL then
|
||||
memory_confidence = math.min(1.0, brain_context.count / 10) * brain_context.relevance
|
||||
end
|
||||
|
||||
-- Factor in historical Ollama accuracy
|
||||
local historical_confidence = get_ollama_historical_confidence()
|
||||
|
||||
-- Combined confidence score
|
||||
local combined_confidence = (memory_confidence * 0.6) + (historical_confidence * 0.4)
|
||||
|
||||
-- Decision logic
|
||||
local provider = "copilot" -- Default to more capable
|
||||
local reason = ""
|
||||
|
||||
if brain_context.count >= MIN_MEMORIES_FOR_LOCAL and combined_confidence >= MIN_RELEVANCE_FOR_LOCAL then
|
||||
provider = "ollama"
|
||||
reason = string.format(
|
||||
"Rich context: %d memories (%.1f%% relevance), historical accuracy: %.1f%%",
|
||||
brain_context.count,
|
||||
brain_context.relevance * 100,
|
||||
historical_confidence * 100
|
||||
)
|
||||
elseif brain_context.count > 0 and combined_confidence >= 0.4 then
|
||||
-- Medium confidence - use Ollama but with pondering
|
||||
provider = "ollama"
|
||||
reason = string.format(
|
||||
"Moderate context: %d memories, will verify with pondering",
|
||||
brain_context.count
|
||||
)
|
||||
else
|
||||
reason = string.format(
|
||||
"Insufficient context: %d memories (need %d), using capable provider",
|
||||
brain_context.count,
|
||||
MIN_MEMORIES_FOR_LOCAL
|
||||
)
|
||||
end
|
||||
|
||||
return {
|
||||
provider = provider,
|
||||
confidence = combined_confidence,
|
||||
memory_count = brain_context.count,
|
||||
reason = reason,
|
||||
memories = brain_context.memories,
|
||||
}
|
||||
end
|
||||
|
||||
--- Decide whether an Ollama response should be verified ("pondered").
--- Medium confidence always verifies; high confidence is randomly sampled
--- so the system keeps learning; low confidence never ponders (such
--- requests are routed away from Ollama by the provider selector anyway).
---@param confidence number Current confidence level (0-1)
---@return boolean
function M.should_ponder(confidence)
  if confidence < 0.4 then
    -- Low confidence shouldn't reach Ollama in the first place
    return false
  end

  if confidence < 0.7 then
    -- Medium confidence: always verify
    return true
  end

  -- High confidence: random sampling to keep learning
  return math.random() < PONDER_SAMPLE_RATE
end
|
||||
|
||||
--- Score how closely two LLM responses agree with each other.
--- Combines a multiset Jaccard similarity over normalized words (70%)
--- with a rough structural similarity based on counts of the literal
--- substring "function" (30%).
---@param response1 string First response
---@param response2 string Second response
---@return number Agreement score (0-1)
local function calculate_agreement(response1, response2)
  -- Tally word frequencies from a lowercased, punctuation-stripped copy.
  local function word_counts(text)
    local normalized = text:lower():gsub("%s+", " "):gsub("[^%w%s]", "")
    local counts = {}
    for token in normalized:gmatch("%w+") do
      counts[token] = (counts[token] or 0) + 1
    end
    return counts
  end

  local counts1 = word_counts(response1)
  local counts2 = word_counts(response2)

  -- Multiset Jaccard: sum of per-word minima over sum of per-word maxima.
  local shared, total = 0, 0
  for token, n1 in pairs(counts1) do
    local n2 = counts2[token] or 0
    shared = shared + math.min(n1, n2)
    total = total + math.max(n1, n2)
  end
  for token, n2 in pairs(counts2) do
    if counts1[token] == nil then
      total = total + n2
    end
  end

  if total == 0 then
    return 1.0 -- Both responses are empty
  end

  -- Structural similarity: compare "function" occurrence counts.
  local fn_count1 = select(2, response1:gsub("function", ""))
  local fn_count2 = select(2, response2:gsub("function", ""))
  local struct_score
  if fn_count1 == 0 and fn_count2 == 0 then
    struct_score = 1.0
  else
    struct_score = 1 - math.abs(fn_count1 - fn_count2) / math.max(fn_count1, fn_count2, 1)
  end

  -- Weighted blend of word-level and structural agreement.
  return (shared / total) * 0.7 + (struct_score * 0.3)
end
|
||||
|
||||
--- Ponder (verify) Ollama's response by asking a second LLM (Copilot) the
--- same prompt and comparing the two answers. Updates the file-level
--- accuracy cache, persists it, and (best-effort) records the outcome in
--- the brain before invoking the callback with a PonderResult.
---@param prompt string Original prompt
---@param context table LLM context (context.file_path is read for brain records)
---@param ollama_response string Ollama's response
---@param callback fun(result: PonderResult) Callback with pondering result
function M.ponder(prompt, context, ollama_response, callback)
  -- Use Copilot as verifier
  local copilot = require("codetyper.llm.copilot")

  -- Build verification prompt.
  -- NOTE(review): currently just re-sends the original prompt verbatim to
  -- the verifier; no extra verification instructions are added.
  local verify_prompt = prompt

  -- NOTE(review): the callback parameter `error` shadows Lua's global
  -- `error` function inside this closure.
  copilot.generate(verify_prompt, context, function(verifier_response, error)
    if error or not verifier_response then
      -- Verification failed: fail open and treat Ollama as correct with
      -- full agreement so the caller keeps Ollama's response.
      callback({
        ollama_response = ollama_response,
        verifier_response = "",
        agreement_score = 1.0,
        ollama_correct = true,
        feedback = "Verification unavailable, trusting Ollama",
      })
      return
    end

    -- Word/structure similarity between the two responses (0-1).
    local agreement = calculate_agreement(ollama_response, verifier_response)

    -- Ollama counts as correct when agreement meets the file-level threshold.
    local ollama_correct = agreement >= AGREEMENT_THRESHOLD

    -- Human-readable feedback string for the result.
    local feedback
    if ollama_correct then
      feedback = string.format("Agreement: %.1f%% - Ollama response validated", agreement * 100)
    else
      feedback = string.format(
        "Disagreement: %.1f%% - Ollama may need correction",
        (1 - agreement) * 100
      )
    end

    -- Update accuracy tracking (every pondered response counts toward
    -- Ollama's total) and persist immediately.
    accuracy_cache.ollama.total = accuracy_cache.ollama.total + 1
    if ollama_correct then
      accuracy_cache.ollama.correct = accuracy_cache.ollama.correct + 1
    end
    save_accuracy_stats()

    -- Learn from this verification; wrapped in pcall so a brain failure
    -- never blocks delivering the result to the caller.
    local brain = get_brain()
    if brain then
      pcall(function()
        if ollama_correct then
          -- Reinforce the pattern (prompt truncated to 100 chars)
          brain.learn({
            type = "correction",
            summary = "Ollama verified correct",
            detail = string.format(
              "Prompt: %s\nAgreement: %.1f%%",
              prompt:sub(1, 100),
              agreement * 100
            ),
            weight = 0.8,
            file = context.file_path,
          })
        else
          -- Learn the correction (responses truncated to 200 chars each)
          brain.learn({
            type = "correction",
            summary = "Ollama needed correction",
            detail = string.format(
              "Prompt: %s\nOllama: %s\nCorrect: %s",
              prompt:sub(1, 100),
              ollama_response:sub(1, 200),
              verifier_response:sub(1, 200)
            ),
            weight = 0.9,
            file = context.file_path,
          })
        end
      end)
    end

    -- Deliver the full pondering result to the caller.
    callback({
      ollama_response = ollama_response,
      verifier_response = verifier_response,
      agreement_score = agreement,
      ollama_correct = ollama_correct,
      feedback = feedback,
    })
  end)
end
|
||||
|
||||
--- Smart generate with automatic provider selection and pondering.
--- Picks a provider via M.select_provider, generates, falls back from
--- Ollama to Copilot on error, and optionally verifies ("ponders")
--- Ollama responses before delivering them.
---@param prompt string User prompt
---@param context table LLM context
---@param callback fun(response: string|nil, error: string|nil, metadata: table|nil) Callback
function M.smart_generate(prompt, context, callback)
  -- Select provider based on memory context and historical accuracy.
  local selection = M.select_provider(prompt, context)

  -- Log selection (best-effort; a missing logs module must not abort generation).
  pcall(function()
    local logs = require("codetyper.agent.logs")
    logs.add({
      type = "info",
      message = string.format(
        "LLM: %s (confidence: %.1f%%, %s)",
        selection.provider,
        selection.confidence * 100,
        selection.reason
      ),
    })
  end)

  -- Get the selected client; anything that is not "ollama" resolves to Copilot.
  local client
  if selection.provider == "ollama" then
    client = require("codetyper.llm.ollama")
  else
    client = require("codetyper.llm.copilot")
  end

  -- Generate response.
  -- NOTE(review): the callback parameter `error` shadows Lua's global
  -- `error` function inside this closure.
  client.generate(prompt, context, function(response, error)
    if error then
      -- Fallback on error: only an Ollama failure retries with Copilot;
      -- a Copilot failure is reported directly.
      if selection.provider == "ollama" then
        -- Try Copilot as fallback, preserving the original error in metadata.
        local copilot = require("codetyper.llm.copilot")
        copilot.generate(prompt, context, function(fallback_response, fallback_error)
          callback(fallback_response, fallback_error, {
            provider = "copilot",
            fallback = true,
            original_provider = "ollama",
            original_error = error,
          })
        end)
        return
      end
      callback(nil, error, { provider = selection.provider })
      return
    end

    -- Check if we should ponder (verify) this Ollama response.
    if selection.provider == "ollama" and M.should_ponder(selection.confidence) then
      M.ponder(prompt, context, response, function(ponder_result)
        if ponder_result.ollama_correct then
          -- Ollama was correct: deliver its response with verification metadata.
          callback(response, nil, {
            provider = "ollama",
            pondered = true,
            agreement = ponder_result.agreement_score,
            confidence = selection.confidence,
          })
        else
          -- Disagreement: deliver the verifier's response instead and flag
          -- the correction in metadata.
          callback(ponder_result.verifier_response, nil, {
            provider = "copilot",
            pondered = true,
            agreement = ponder_result.agreement_score,
            original_provider = "ollama",
            corrected = true,
          })
        end
      end)
    else
      -- No pondering needed: deliver the response as-is.
      callback(response, nil, {
        provider = selection.provider,
        pondered = false,
        confidence = selection.confidence,
      })
    end
  end)
end
|
||||
|
||||
--- Get current accuracy statistics for both providers.
--- Returns a fresh snapshot; mutating it does not affect the cache.
---@return table {ollama: {correct, total, accuracy}, copilot: {correct, total, accuracy}}
function M.get_accuracy_stats()
  -- Copy one provider's counters, deriving its accuracy ratio (0 when no data).
  local function snapshot(counters)
    local accuracy = 0
    if counters.total > 0 then
      accuracy = counters.correct / counters.total
    end
    return {
      correct = counters.correct,
      total = counters.total,
      accuracy = accuracy,
    }
  end

  return {
    ollama = snapshot(accuracy_cache.ollama),
    copilot = snapshot(accuracy_cache.copilot),
  }
end
|
||||
|
||||
--- Reset accuracy statistics.
--- Zeroes both providers' counters and persists the cleared state.
function M.reset_accuracy_stats()
  local cleared = {}
  for _, provider in ipairs({ "ollama", "copilot" }) do
    cleared[provider] = { correct = 0, total = 0 }
  end
  accuracy_cache = cleared
  save_accuracy_stats()
end
|
||||
|
||||
--- Report user feedback on response quality.
--- Unknown providers are ignored silently; known ones get their counters
--- bumped and the stats persisted.
---@param provider string Which provider generated the response
---@param was_correct boolean Whether the response was good
function M.report_feedback(provider, was_correct)
  local counters = accuracy_cache[provider]
  if not counters then
    return -- Unknown provider: nothing to record
  end

  counters.total = counters.total + 1
  if was_correct then
    counters.correct = counters.correct + 1
  end
  save_accuracy_stats()
end
|
||||
|
||||
return M
|
||||
@@ -5,66 +5,86 @@
|
||||
local M = {}
|
||||
|
||||
--- System prompt for agent mode
|
||||
M.system = [[You are an AI coding agent integrated into Neovim via Codetyper.nvim.
|
||||
M.system =
|
||||
[[You are an expert AI coding assistant integrated into Neovim. You help developers by reading, writing, and modifying code files, as well as running shell commands.
|
||||
|
||||
Your role is to ASSIST the developer by planning, coordinating, and executing
|
||||
SAFE, MINIMAL changes using the available tools.
|
||||
## YOUR CAPABILITIES
|
||||
|
||||
You do NOT operate autonomously on the entire codebase.
|
||||
You operate on clearly defined tasks and scopes.
|
||||
You have access to these tools - USE THEM to accomplish tasks:
|
||||
|
||||
You have access to the following tools:
|
||||
- read_file: Read file contents
|
||||
- edit_file: Apply a precise, scoped replacement to a file
|
||||
- write_file: Create a new file or fully replace an existing file
|
||||
- bash: Execute non-destructive shell commands when necessary
|
||||
### File Operations
|
||||
- **view**: Read any file. ALWAYS read files before modifying them. Parameters: path (string)
|
||||
- **write**: Create new files or completely replace existing ones. Use for new files. Parameters: path (string), content (string)
|
||||
- **edit**: Make precise edits to existing files using search/replace. Parameters: path (string), old_string (string), new_string (string)
|
||||
- **glob**: Find files by pattern (e.g., "**/*.lua"). Parameters: pattern (string), path (optional)
|
||||
- **grep**: Search file contents with regex. Parameters: pattern (string), path (optional)
|
||||
|
||||
OPERATING PRINCIPLES:
|
||||
1. Prefer understanding over action — read before modifying
|
||||
2. Prefer small, scoped edits over large rewrites
|
||||
3. Preserve existing behavior unless explicitly instructed otherwise
|
||||
4. Minimize the number of tool calls required
|
||||
5. Never surprise the user
|
||||
### Shell Commands
|
||||
- **bash**: Run shell commands (git, npm, make, etc.). User approves each command. Parameters: command (string)
|
||||
|
||||
IMPORTANT EDITING RULES:
|
||||
- Always read a file before editing it
|
||||
- Use edit_file ONLY for well-scoped, exact replacements
|
||||
- The "find" field MUST match existing content exactly
|
||||
- Include enough surrounding context to ensure uniqueness
|
||||
- Use write_file ONLY for new files or intentional full replacements
|
||||
- NEVER delete files unless explicitly confirmed by the user
|
||||
## HOW TO WORK
|
||||
|
||||
BASH SAFETY:
|
||||
- Use bash only when code inspection or execution is required
|
||||
- Do NOT run destructive commands (rm, mv, chmod, etc.)
|
||||
- Prefer read_file over bash when inspecting files
|
||||
1. **UNDERSTAND FIRST**: Use view, glob, or grep to understand the codebase before making changes.
|
||||
|
||||
THINKING AND PLANNING:
|
||||
- If a task requires multiple steps, outline a brief plan internally
|
||||
- Execute steps one at a time
|
||||
- Re-evaluate after each tool result
|
||||
- If uncertainty arises, stop and ask for clarification
|
||||
2. **MAKE CHANGES**: Use write for new files, edit for modifications.
|
||||
- For edit: The "old_string" parameter must match file content EXACTLY (including whitespace)
|
||||
- Include enough context in "old_string" to be unique
|
||||
- For write: Provide complete file content
|
||||
|
||||
COMMUNICATION:
|
||||
- Do NOT explain every micro-step while working
|
||||
- After completing changes, provide a clear, concise summary
|
||||
- If no changes were made, explain why
|
||||
3. **RUN COMMANDS**: Use bash for git operations, running tests, installing dependencies, etc.
|
||||
|
||||
4. **ITERATE**: After each tool result, decide if more actions are needed.
|
||||
|
||||
## EXAMPLE WORKFLOW
|
||||
|
||||
User: "Create a new React component for a login form"
|
||||
|
||||
Your approach:
|
||||
1. Use glob to see project structure (glob pattern="**/*.tsx")
|
||||
2. Use view to check existing component patterns
|
||||
3. Use write to create the new component file
|
||||
4. Use write to create a test file if appropriate
|
||||
5. Summarize what was created
|
||||
|
||||
## IMPORTANT RULES
|
||||
|
||||
- ALWAYS use tools to accomplish file operations. Don't just describe what to do - DO IT.
|
||||
- Read files before editing to ensure your "old_string" matches exactly.
|
||||
- When creating files, write complete, working code.
|
||||
- When editing, preserve existing code style and conventions.
|
||||
- If a file path is provided, use it. If not, infer from context.
|
||||
- For multi-file tasks, handle each file sequentially.
|
||||
|
||||
## OUTPUT STYLE
|
||||
|
||||
- Be concise in explanations
|
||||
- Use tools proactively to complete tasks
|
||||
- After making changes, briefly summarize what was done
|
||||
]]
|
||||
|
||||
--- Tool usage instructions appended to system prompt
|
||||
M.tool_instructions = [[
|
||||
When you need to use a tool, output ONLY a single tool call in valid JSON.
|
||||
Do NOT include explanations alongside the tool call.
|
||||
## TOOL USAGE
|
||||
|
||||
After receiving a tool result:
|
||||
- Decide whether another tool call is required
|
||||
- Or produce a final response to the user
|
||||
When you need to perform an action, call the appropriate tool. You can call tools to:
|
||||
- Read files with view (parameters: path)
|
||||
- Create new files with write (parameters: path, content)
|
||||
- Modify existing files with edit (parameters: path, old_string, new_string) - read first!
|
||||
- Find files by pattern with glob (parameters: pattern, path)
|
||||
- Search file contents with grep (parameters: pattern, path)
|
||||
- Run shell commands with bash (parameters: command)
|
||||
|
||||
SAFETY RULES:
|
||||
- Never run destructive or irreversible commands
|
||||
- Never modify code outside the requested scope
|
||||
- Never guess file contents — read them first
|
||||
- If a requested change appears risky or ambiguous, ask before proceeding
|
||||
After receiving a tool result, continue working:
|
||||
- If more actions are needed, call another tool
|
||||
- When the task is complete, provide a brief summary
|
||||
|
||||
## CRITICAL RULES
|
||||
|
||||
1. **Always read before editing**: Use view before edit to ensure exact matches
|
||||
2. **Be precise with edits**: The "old_string" parameter must match the file content EXACTLY
|
||||
3. **Create complete files**: When using write, provide fully working code
|
||||
4. **User approval required**: File writes, edits, and bash commands need approval
|
||||
5. **Don't guess**: If unsure about file structure, use glob or grep
|
||||
]]
|
||||
|
||||
--- Prompt for when agent finishes
|
||||
|
||||
@@ -11,6 +11,7 @@ M.code = require("codetyper.prompts.code")
|
||||
M.ask = require("codetyper.prompts.ask")
|
||||
M.refactor = require("codetyper.prompts.refactor")
|
||||
M.document = require("codetyper.prompts.document")
|
||||
M.agent = require("codetyper.prompts.agent")
|
||||
|
||||
--- Get a prompt by category and name
|
||||
---@param category string Category name (system, code, ask, refactor, document)
|
||||
|
||||
@@ -45,11 +45,14 @@ GUIDELINES:
|
||||
6. Focus on practical understanding and tradeoffs
|
||||
|
||||
IMPORTANT:
|
||||
- Do NOT output raw code intended for insertion
|
||||
- Do NOT refuse to explain code - that IS your purpose in this mode
|
||||
- Do NOT assume missing context
|
||||
- Do NOT speculate beyond the provided information
|
||||
- Provide helpful, detailed explanations when asked
|
||||
]]
|
||||
|
||||
-- Alias for backward compatibility
|
||||
M.explain = M.ask
|
||||
|
||||
--- System prompt for scoped refactoring
|
||||
M.refactor = [[You are an expert refactoring assistant integrated into Neovim via Codetyper.nvim.
|
||||
|
||||
@@ -121,4 +124,8 @@ Language: {{language}}
|
||||
REMEMBER: Output ONLY valid {{language}} test code.
|
||||
]]
|
||||
|
||||
--- Base prompt for agent mode (full prompt is in agent.lua)
|
||||
--- This provides minimal context; the agent prompts module adds tool instructions
|
||||
M.agent = [[]]
|
||||
|
||||
return M
|
||||
|
||||
491
lua/codetyper/suggestion/init.lua
Normal file
491
lua/codetyper/suggestion/init.lua
Normal file
@@ -0,0 +1,491 @@
|
||||
---@mod codetyper.suggestion Inline ghost text suggestions
|
||||
---@brief [[
|
||||
--- Provides Copilot-style inline suggestions with ghost text.
|
||||
--- Uses Copilot when available, falls back to codetyper's own suggestions.
|
||||
--- Shows suggestions as grayed-out text that can be accepted with Tab.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}

---@class SuggestionState
---@field current_suggestion string|nil Current suggestion text
---@field suggestions string[] List of available suggestions
---@field current_index number Current suggestion index
---@field extmark_id number|nil Virtual text extmark ID
---@field bufnr number|nil Buffer where suggestion is shown
---@field line number|nil Line (0-based) where suggestion is shown
---@field col number|nil Column where suggestion starts
---@field timer any|nil Debounce timer
---@field using_copilot boolean Whether currently using copilot

-- Single module-wide state: only one suggestion can be active at a time.
local state = {
  current_suggestion = nil,
  suggestions = {},
  current_index = 0,
  extmark_id = nil,
  bufnr = nil,
  line = nil,
  col = nil,
  timer = nil,
  using_copilot = false,
}

--- Namespace for virtual text
local ns = vim.api.nvim_create_namespace("codetyper_suggestion")

--- Highlight group for ghost text
-- NOTE(review): reuses nvim-cmp's ghost-text group name; linked to Comment
-- in setup_highlights, so it works even without nvim-cmp installed.
local hl_group = "CmpGhostText"

--- Configuration (overridable via M.setup)
local config = {
  enabled = true,
  auto_trigger = true,
  debounce = 150, -- milliseconds before codetyper computes its own suggestion
  use_copilot = true, -- Use copilot when available
  keymap = {
    accept = "<Tab>",
    next = "<M-]>",
    prev = "<M-[>",
    dismiss = "<C-]>",
  },
}
|
||||
|
||||
--- Resolve the copilot.lua suggestion module, if usable.
--- Returns false when copilot integration is configured off, the plugin
--- is not installed, or copilot reports itself disabled.
---@return boolean, table|nil available, copilot_suggestion module
local function get_copilot()
  if not config.use_copilot then
    return false, nil
  end

  local loaded, suggestion_mod = pcall(require, "copilot.suggestion")
  if not loaded then
    return false, nil
  end

  -- Respect copilot's own disabled state when its client module is present.
  local client_loaded, client_mod = pcall(require, "copilot.client")
  local disabled = client_loaded and client_mod.is_disabled and client_mod.is_disabled()
  if disabled then
    return false, nil
  end

  return true, suggestion_mod
end
|
||||
|
||||
--- Whether any suggestion (copilot's or codetyper's) is currently shown.
--- Side effect: records which backend owns the visible suggestion in
--- state.using_copilot.
---@return boolean
function M.is_visible()
  local has_copilot, suggestion_mod = get_copilot()
  local copilot_showing = has_copilot and suggestion_mod.is_visible()

  state.using_copilot = copilot_showing and true or false
  if copilot_showing then
    return true
  end

  -- Fall back to codetyper's own ghost text.
  return state.extmark_id ~= nil and state.current_suggestion ~= nil
end
|
||||
|
||||
--- Clear any active suggestion (copilot's and codetyper's) and reset state.
function M.dismiss()
  -- Ask copilot to drop its ghost text if it is currently showing one.
  local has_copilot, suggestion_mod = get_copilot()
  if has_copilot and suggestion_mod.is_visible() then
    suggestion_mod.dismiss()
  end

  -- Remove our own virtual-text extmark, if any (pcall: the buffer may be gone).
  if state.extmark_id and state.bufnr then
    pcall(vim.api.nvim_buf_del_extmark, state.bufnr, ns, state.extmark_id)
  end

  -- Reset all tracking fields to their idle values.
  for _, field in ipairs({ "current_suggestion", "extmark_id", "bufnr", "line", "col" }) do
    state[field] = nil
  end
  state.suggestions = {}
  state.current_index = 0
  state.using_copilot = false
end
|
||||
|
||||
--- Display suggestion as ghost text at the current cursor position.
--- The first line is rendered inline; remaining lines are rendered as
--- virtual lines below the cursor line. Records placement in `state`.
---@param suggestion string The suggestion to display
local function display_suggestion(suggestion)
  if not suggestion or suggestion == "" then
    return
  end

  -- Clear any previous ghost text before drawing the new one.
  M.dismiss()

  local bufnr = vim.api.nvim_get_current_buf()
  local cursor = vim.api.nvim_win_get_cursor(0)
  local line = cursor[1] - 1 -- nvim_win_get_cursor rows are 1-based; extmarks are 0-based
  local col = cursor[2]

  -- Split suggestion into lines
  local lines = vim.split(suggestion, "\n", { plain = true })

  -- Build virtual text
  local virt_text = {}
  local virt_lines = {}

  -- First line goes inline
  if #lines > 0 then
    virt_text = { { lines[1], hl_group } }
  end

  -- Remaining lines go below
  for i = 2, #lines do
    table.insert(virt_lines, { { lines[i], hl_group } })
  end

  -- Create extmark with virtual text.
  -- NOTE(review): "overlay" draws the ghost text OVER any existing text
  -- after the cursor; "inline" (nvim 0.10+) would shift it instead — TODO
  -- confirm the intended rendering.
  local opts = {
    virt_text = virt_text,
    virt_text_pos = "overlay",
    hl_mode = "combine",
  }

  if #virt_lines > 0 then
    opts.virt_lines = virt_lines
  end

  state.extmark_id = vim.api.nvim_buf_set_extmark(bufnr, ns, line, col, opts)
  state.bufnr = bufnr
  state.line = line
  state.col = col
  state.current_suggestion = suggestion
end
|
||||
|
||||
--- Accept the current suggestion, inserting its text at the recorded
--- cursor position. Delegates to copilot when copilot owns the visible
--- suggestion; otherwise splices codetyper's ghost text into the buffer.
---@return boolean Whether a suggestion was accepted
function M.accept()
  -- Check copilot first
  local copilot_ok, copilot_suggestion = get_copilot()
  if copilot_ok and copilot_suggestion.is_visible() then
    copilot_suggestion.accept()
    state.using_copilot = false
    return true
  end

  -- Accept codetyper's suggestion
  if not M.is_visible() then
    return false
  end

  -- Capture placement before dismiss() wipes the state.
  local suggestion = state.current_suggestion
  local bufnr = state.bufnr
  local line = state.line
  local col = state.col

  M.dismiss()

  if suggestion and bufnr and line ~= nil and col ~= nil then
    -- Get current line content (may be "" if the buffer shrank meanwhile)
    local current_line = vim.api.nvim_buf_get_lines(bufnr, line, line + 1, false)[1] or ""

    -- Split suggestion into lines
    local suggestion_lines = vim.split(suggestion, "\n", { plain = true })

    if #suggestion_lines == 1 then
      -- Single line - insert at cursor
      local new_line = current_line:sub(1, col) .. suggestion .. current_line:sub(col + 1)
      vim.api.nvim_buf_set_lines(bufnr, line, line + 1, false, { new_line })
      -- Move cursor to end of inserted text (rows 1-based for win_set_cursor)
      vim.api.nvim_win_set_cursor(0, { line + 1, col + #suggestion })
    else
      -- Multi-line - first suggestion line continues the current line,
      -- the last suggestion line absorbs the text that followed the cursor.
      local first_line = current_line:sub(1, col) .. suggestion_lines[1]
      local last_line = suggestion_lines[#suggestion_lines] .. current_line:sub(col + 1)

      local new_lines = { first_line }
      for i = 2, #suggestion_lines - 1 do
        table.insert(new_lines, suggestion_lines[i])
      end
      table.insert(new_lines, last_line)

      -- Replace the single original line with the expanded set.
      vim.api.nvim_buf_set_lines(bufnr, line, line + 1, false, new_lines)
      -- Move cursor to end of last inserted suggestion line
      vim.api.nvim_win_set_cursor(0, { line + #new_lines, #suggestion_lines[#suggestion_lines] })
    end

    return true
  end

  return false
end
|
||||
|
||||
--- Cycle forward to the next available suggestion (wraps around).
function M.next()
  -- Delegate to copilot when it owns the visible suggestion.
  local has_copilot, suggestion_mod = get_copilot()
  if has_copilot and suggestion_mod.is_visible() then
    suggestion_mod.next()
    return
  end

  local count = #state.suggestions
  if count <= 1 then
    return -- Nothing to cycle through
  end

  -- Wrap-around increment: index runs 1..count.
  local next_index = (state.current_index % count) + 1
  state.current_index = next_index
  display_suggestion(state.suggestions[next_index])
end
|
||||
|
||||
--- Cycle backward to the previous available suggestion (wraps around).
function M.prev()
  -- Delegate to copilot when it owns the visible suggestion.
  local has_copilot, suggestion_mod = get_copilot()
  if has_copilot and suggestion_mod.is_visible() then
    suggestion_mod.prev()
    return
  end

  local count = #state.suggestions
  if count <= 1 then
    return -- Nothing to cycle through
  end

  -- Wrap-around decrement: falling below 1 jumps to the last entry.
  local prev_index = state.current_index - 1
  if prev_index < 1 then
    prev_index = count
  end
  state.current_index = prev_index
  display_suggestion(state.suggestions[prev_index])
end
|
||||
|
||||
--- Get suggestions from brain, project index, and current-buffer words.
--- Brain entries are full code snippets; indexer and buffer entries are
--- completion TAILS (the text to insert after the typed prefix).
---@param prefix string Current word prefix
---@param context table Context info (not read here; reserved for future ranking)
---@return string[] suggestions
local function get_suggestions(prefix, context)
  local suggestions = {}

  -- 1) Pattern snippets from the brain, when it is initialized.
  local ok_brain, brain = pcall(require, "codetyper.brain")
  if ok_brain and brain.is_initialized and brain.is_initialized() then
    local result = brain.query({
      query = prefix,
      max_results = 5,
      types = { "pattern" },
    })

    if result and result.nodes then
      for _, node in ipairs(result.nodes) do
        if node.c and node.c.code then
          table.insert(suggestions, node.c.code)
        end
      end
    end
  end

  -- 2) Symbol-name completions from the project index.
  local ok_indexer, indexer = pcall(require, "codetyper.indexer")
  if ok_indexer then
    local index = indexer.load_index()
    if index and index.symbols then
      for symbol, _ in pairs(index.symbols) do
        -- BUGFIX: the match must be anchored at position 1 (symbol STARTS
        -- with prefix); sub(#prefix + 1) below assumes that. The previous
        -- unanchored find produced garbage tails for mid-symbol matches.
        if symbol:lower():find(prefix:lower(), 1, true) == 1 and symbol ~= prefix then
          -- Just complete the symbol name
          local completion = symbol:sub(#prefix + 1)
          if completion ~= "" then
            table.insert(suggestions, completion)
          end
        end
      end
    end
  end

  -- 3) Word completions harvested from the current buffer.
  local bufnr = vim.api.nvim_get_current_buf()
  local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
  local seen = {}

  for _, line in ipairs(lines) do
    for word in line:gmatch("[%a_][%w_]*") do
      if
        #word > #prefix
        and word:lower():find(prefix:lower(), 1, true) == 1
        and not seen[word]
        and word ~= prefix
      then
        seen[word] = true
        local completion = word:sub(#prefix + 1)
        if completion ~= "" then
          table.insert(suggestions, completion)
        end
      end
    end
  end

  return suggestions
end
|
||||
|
||||
--- Trigger suggestion generation for the word being typed at the cursor.
--- Debounced: each call cancels the previous pending timer. Defers to
--- copilot whenever copilot already has a visible suggestion.
function M.trigger()
  if not config.enabled then
    return
  end

  -- If copilot is available and has a suggestion, don't show codetyper's
  local copilot_ok, copilot_suggestion = get_copilot()
  if copilot_ok and copilot_suggestion.is_visible() then
    -- Copilot is handling suggestions
    state.using_copilot = true
    return
  end

  -- Cancel existing timer so rapid keystrokes coalesce into one lookup.
  if state.timer then
    state.timer:stop()
    state.timer = nil
  end

  -- Get current context (captured now; the deferred callback uses these
  -- values, not the cursor position at fire time)
  local cursor = vim.api.nvim_win_get_cursor(0)
  local line = vim.api.nvim_get_current_line()
  local col = cursor[2]
  local before_cursor = line:sub(1, col)

  -- Extract prefix (word being typed): trailing identifier characters
  local prefix = before_cursor:match("[%a_][%w_]*$") or ""

  -- Require at least 2 typed characters before suggesting.
  if #prefix < 2 then
    M.dismiss()
    return
  end

  -- Debounce - wait a bit longer (extra 200ms) to let copilot try first
  local debounce_time = copilot_ok and (config.debounce + 200) or config.debounce

  state.timer = vim.defer_fn(function()
    -- Check again if copilot has shown something while we waited
    if copilot_ok and copilot_suggestion.is_visible() then
      state.using_copilot = true
      state.timer = nil
      return
    end

    local suggestions = get_suggestions(prefix, {
      line = line,
      col = col,
      bufnr = vim.api.nvim_get_current_buf(),
    })

    if #suggestions > 0 then
      state.suggestions = suggestions
      state.current_index = 1
      display_suggestion(suggestions[1])
    else
      M.dismiss()
    end

    state.timer = nil
  end, debounce_time)
end
|
||||
|
||||
--- Setup insert-mode keymaps for accept / cycle / dismiss.
local function setup_keymaps()
  -- Accept with Tab (only when suggestion visible).
  -- NOTE(review): M.accept modifies the buffer from inside an expr
  -- mapping; Neovim restricts buffer edits during expr evaluation
  -- (textlock) — confirm this works or wrap the edit in vim.schedule.
  vim.keymap.set("i", config.keymap.accept, function()
    if M.is_visible() then
      M.accept()
      return ""
    end
    -- Fallback to normal Tab behavior.
    -- NOTE(review): hardcodes "<Tab>" — if config.keymap.accept is remapped
    -- to another key, the fallback still emits a literal Tab.
    return vim.api.nvim_replace_termcodes("<Tab>", true, false, true)
  end, { expr = true, silent = true, desc = "Accept codetyper suggestion" })

  -- Next suggestion
  vim.keymap.set("i", config.keymap.next, function()
    M.next()
  end, { silent = true, desc = "Next codetyper suggestion" })

  -- Previous suggestion
  vim.keymap.set("i", config.keymap.prev, function()
    M.prev()
  end, { silent = true, desc = "Previous codetyper suggestion" })

  -- Dismiss
  vim.keymap.set("i", config.keymap.dismiss, function()
    M.dismiss()
  end, { silent = true, desc = "Dismiss codetyper suggestion" })
end
|
||||
|
||||
--- Register the autocmds that drive the suggestion lifecycle:
--- trigger while typing, dismiss on insert-leave or when the cursor
--- leaves the suggestion's line.
local function setup_autocmds()
  local augroup = vim.api.nvim_create_augroup("CodetypeSuggestion", { clear = true })

  -- Re-trigger suggestions as the user types in insert mode.
  if config.auto_trigger then
    vim.api.nvim_create_autocmd("TextChangedI", {
      group = augroup,
      callback = function()
        M.trigger()
      end,
    })
  end

  -- Dismiss on leaving insert mode.
  vim.api.nvim_create_autocmd("InsertLeave", {
    group = augroup,
    callback = function()
      M.dismiss()
    end,
  })

  -- Dismiss when the cursor moves to a different line than the suggestion's.
  vim.api.nvim_create_autocmd("CursorMovedI", {
    group = augroup,
    callback = function()
      if state.line == nil then
        return
      end
      local row = vim.api.nvim_win_get_cursor(0)[1]
      if row - 1 ~= state.line then
        M.dismiss()
      end
    end,
  })
end
|
||||
|
||||
--- Setup highlight group for ghost text.
--- Defines hl_group ("CmpGhostText") as a link to Comment so the ghost
--- text renders dimmed regardless of whether nvim-cmp is installed.
local function setup_highlights()
  -- Use Comment highlight or define custom ghost text style
  vim.api.nvim_set_hl(0, hl_group, { link = "Comment" })
end
|
||||
|
||||
--- Initialize the suggestion system: merge user options over the
--- defaults, then install highlights, keymaps, and autocmds.
---@param opts? table Configuration overrides (deep-merged, user wins)
function M.setup(opts)
  config = vim.tbl_deep_extend("force", config, opts or {})

  setup_highlights()
  setup_keymaps()
  setup_autocmds()
end
|
||||
|
||||
--- Enable suggestions
|
||||
function M.enable()
|
||||
config.enabled = true
|
||||
end
|
||||
|
||||
--- Disable suggestions
|
||||
function M.disable()
|
||||
config.enabled = false
|
||||
M.dismiss()
|
||||
end
|
||||
|
||||
--- Toggle suggestions
|
||||
function M.toggle()
|
||||
if config.enabled then
|
||||
M.disable()
|
||||
else
|
||||
M.enable()
|
||||
end
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -7,18 +7,28 @@
|
||||
---@field auto_gitignore boolean Auto-manage .gitignore
|
||||
|
||||
---@class LLMConfig
|
||||
---@field provider "claude" | "ollama" The LLM provider to use
|
||||
---@field claude ClaudeConfig Claude-specific configuration
|
||||
---@field provider "ollama" | "openai" | "gemini" | "copilot" The LLM provider to use
|
||||
---@field ollama OllamaConfig Ollama-specific configuration
|
||||
|
||||
---@class ClaudeConfig
|
||||
---@field api_key string | nil Claude API key (or env var ANTHROPIC_API_KEY)
|
||||
---@field model string Claude model to use
|
||||
---@field openai OpenAIConfig OpenAI-specific configuration
|
||||
---@field gemini GeminiConfig Gemini-specific configuration
|
||||
---@field copilot CopilotConfig Copilot-specific configuration
|
||||
|
||||
---@class OllamaConfig
|
||||
---@field host string Ollama host URL
|
||||
---@field model string Ollama model to use
|
||||
|
||||
---@class OpenAIConfig
|
||||
---@field api_key string | nil OpenAI API key (or env var OPENAI_API_KEY)
|
||||
---@field model string OpenAI model to use
|
||||
---@field endpoint string | nil Custom endpoint (Azure, OpenRouter, etc.)
|
||||
|
||||
---@class GeminiConfig
|
||||
---@field api_key string | nil Gemini API key (or env var GEMINI_API_KEY)
|
||||
---@field model string Gemini model to use
|
||||
|
||||
---@class CopilotConfig
|
||||
---@field model string Copilot model to use
|
||||
|
||||
---@class WindowConfig
|
||||
---@field width number Width of the coder window (percentage or columns)
|
||||
---@field position "left" | "right" Position of the coder window
|
||||
|
||||
@@ -18,14 +18,14 @@ M._target_buf = nil
|
||||
|
||||
--- Calculate window width based on configuration
|
||||
---@param config CoderConfig Plugin configuration
|
||||
---@return number Width in columns
|
||||
---@return number Width in columns (minimum 30)
|
||||
local function calculate_width(config)
|
||||
local width = config.window.width
|
||||
if width <= 1 then
|
||||
-- Percentage of total width
|
||||
return math.floor(vim.o.columns * width)
|
||||
-- Percentage of total width (1/4 of screen with minimum 30)
|
||||
return math.max(math.floor(vim.o.columns * width), 30)
|
||||
end
|
||||
return math.floor(width)
|
||||
return math.max(math.floor(width), 30)
|
||||
end
|
||||
|
||||
--- Open the coder split view
|
||||
|
||||
427
tests/spec/agent_tools_spec.lua
Normal file
427
tests/spec/agent_tools_spec.lua
Normal file
@@ -0,0 +1,427 @@
|
||||
--- Tests for agent tools system
|
||||
|
||||
describe("codetyper.agent.tools", function()
|
||||
local tools
|
||||
|
||||
before_each(function()
|
||||
tools = require("codetyper.agent.tools")
|
||||
-- Clear any existing registrations
|
||||
for name, _ in pairs(tools.get_all()) do
|
||||
tools.unregister(name)
|
||||
end
|
||||
end)
|
||||
|
||||
describe("tool registration", function()
|
||||
it("should register a tool", function()
|
||||
local test_tool = {
|
||||
name = "test_tool",
|
||||
description = "A test tool",
|
||||
params = {
|
||||
{ name = "input", type = "string", description = "Test input" },
|
||||
},
|
||||
func = function(input, opts)
|
||||
return "result", nil
|
||||
end,
|
||||
}
|
||||
|
||||
tools.register(test_tool)
|
||||
local retrieved = tools.get("test_tool")
|
||||
|
||||
assert.is_not_nil(retrieved)
|
||||
assert.equals("test_tool", retrieved.name)
|
||||
end)
|
||||
|
||||
it("should unregister a tool", function()
|
||||
local test_tool = {
|
||||
name = "temp_tool",
|
||||
description = "Temporary",
|
||||
func = function() end,
|
||||
}
|
||||
|
||||
tools.register(test_tool)
|
||||
assert.is_not_nil(tools.get("temp_tool"))
|
||||
|
||||
tools.unregister("temp_tool")
|
||||
assert.is_nil(tools.get("temp_tool"))
|
||||
end)
|
||||
|
||||
it("should list all tools", function()
|
||||
tools.register({ name = "tool1", func = function() end })
|
||||
tools.register({ name = "tool2", func = function() end })
|
||||
tools.register({ name = "tool3", func = function() end })
|
||||
|
||||
local list = tools.list()
|
||||
assert.equals(3, #list)
|
||||
end)
|
||||
|
||||
it("should filter tools with predicate", function()
|
||||
tools.register({ name = "safe_tool", requires_confirmation = false, func = function() end })
|
||||
tools.register({ name = "dangerous_tool", requires_confirmation = true, func = function() end })
|
||||
|
||||
local safe_list = tools.list(function(t)
|
||||
return not t.requires_confirmation
|
||||
end)
|
||||
|
||||
assert.equals(1, #safe_list)
|
||||
assert.equals("safe_tool", safe_list[1].name)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("tool execution", function()
|
||||
it("should execute a tool and return result", function()
|
||||
tools.register({
|
||||
name = "adder",
|
||||
params = {
|
||||
{ name = "a", type = "number" },
|
||||
{ name = "b", type = "number" },
|
||||
},
|
||||
func = function(input, opts)
|
||||
return input.a + input.b, nil
|
||||
end,
|
||||
})
|
||||
|
||||
local result, err = tools.execute("adder", { a = 5, b = 3 }, {})
|
||||
|
||||
assert.is_nil(err)
|
||||
assert.equals(8, result)
|
||||
end)
|
||||
|
||||
it("should return error for unknown tool", function()
|
||||
local result, err = tools.execute("nonexistent", {}, {})
|
||||
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("Unknown tool"))
|
||||
end)
|
||||
|
||||
it("should track execution history", function()
|
||||
tools.clear_history()
|
||||
tools.register({
|
||||
name = "tracked_tool",
|
||||
func = function()
|
||||
return "done", nil
|
||||
end,
|
||||
})
|
||||
|
||||
tools.execute("tracked_tool", {}, {})
|
||||
tools.execute("tracked_tool", {}, {})
|
||||
|
||||
local history = tools.get_history()
|
||||
assert.equals(2, #history)
|
||||
assert.equals("tracked_tool", history[1].tool)
|
||||
assert.equals("completed", history[1].status)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("tool schemas", function()
|
||||
it("should generate JSON schema for tools", function()
|
||||
tools.register({
|
||||
name = "schema_test",
|
||||
description = "Test schema generation",
|
||||
params = {
|
||||
{ name = "required_param", type = "string", description = "A required param" },
|
||||
{ name = "optional_param", type = "number", description = "Optional", optional = true },
|
||||
},
|
||||
returns = {
|
||||
{ name = "result", type = "string" },
|
||||
},
|
||||
to_schema = require("codetyper.agent.tools.base").to_schema,
|
||||
func = function() end,
|
||||
})
|
||||
|
||||
local schemas = tools.get_schemas()
|
||||
assert.equals(1, #schemas)
|
||||
|
||||
local schema = schemas[1]
|
||||
assert.equals("function", schema.type)
|
||||
assert.equals("schema_test", schema.function_def.name)
|
||||
assert.is_not_nil(schema.function_def.parameters.properties.required_param)
|
||||
assert.is_not_nil(schema.function_def.parameters.properties.optional_param)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("process_tool_call", function()
|
||||
it("should process tool call with name and input", function()
|
||||
tools.register({
|
||||
name = "processor_test",
|
||||
func = function(input, opts)
|
||||
return "processed: " .. input.value, nil
|
||||
end,
|
||||
})
|
||||
|
||||
local result, err = tools.process_tool_call({
|
||||
name = "processor_test",
|
||||
input = { value = "test" },
|
||||
}, {})
|
||||
|
||||
assert.is_nil(err)
|
||||
assert.equals("processed: test", result)
|
||||
end)
|
||||
|
||||
it("should parse JSON string arguments", function()
|
||||
tools.register({
|
||||
name = "json_parser_test",
|
||||
func = function(input, opts)
|
||||
return input.key, nil
|
||||
end,
|
||||
})
|
||||
|
||||
local result, err = tools.process_tool_call({
|
||||
name = "json_parser_test",
|
||||
arguments = '{"key": "value"}',
|
||||
}, {})
|
||||
|
||||
assert.is_nil(err)
|
||||
assert.equals("value", result)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("codetyper.agent.tools.base", function()
|
||||
local base
|
||||
|
||||
before_each(function()
|
||||
base = require("codetyper.agent.tools.base")
|
||||
end)
|
||||
|
||||
describe("validate_input", function()
|
||||
it("should validate required parameters", function()
|
||||
local tool = setmetatable({
|
||||
params = {
|
||||
{ name = "required", type = "string" },
|
||||
{ name = "optional", type = "string", optional = true },
|
||||
},
|
||||
}, base)
|
||||
|
||||
local valid, err = tool:validate_input({ required = "value" })
|
||||
assert.is_true(valid)
|
||||
assert.is_nil(err)
|
||||
end)
|
||||
|
||||
it("should fail on missing required parameter", function()
|
||||
local tool = setmetatable({
|
||||
params = {
|
||||
{ name = "required", type = "string" },
|
||||
},
|
||||
}, base)
|
||||
|
||||
local valid, err = tool:validate_input({})
|
||||
assert.is_false(valid)
|
||||
assert.truthy(err:match("Missing required parameter"))
|
||||
end)
|
||||
|
||||
it("should validate parameter types", function()
|
||||
local tool = setmetatable({
|
||||
params = {
|
||||
{ name = "num", type = "number" },
|
||||
},
|
||||
}, base)
|
||||
|
||||
local valid1, _ = tool:validate_input({ num = 42 })
|
||||
assert.is_true(valid1)
|
||||
|
||||
local valid2, err2 = tool:validate_input({ num = "not a number" })
|
||||
assert.is_false(valid2)
|
||||
assert.truthy(err2:match("must be number"))
|
||||
end)
|
||||
|
||||
it("should validate integer type", function()
|
||||
local tool = setmetatable({
|
||||
params = {
|
||||
{ name = "int", type = "integer" },
|
||||
},
|
||||
}, base)
|
||||
|
||||
local valid1, _ = tool:validate_input({ int = 42 })
|
||||
assert.is_true(valid1)
|
||||
|
||||
local valid2, err2 = tool:validate_input({ int = 42.5 })
|
||||
assert.is_false(valid2)
|
||||
assert.truthy(err2:match("must be an integer"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("get_description", function()
|
||||
it("should return string description", function()
|
||||
local tool = setmetatable({
|
||||
description = "Static description",
|
||||
}, base)
|
||||
|
||||
assert.equals("Static description", tool:get_description())
|
||||
end)
|
||||
|
||||
it("should call function description", function()
|
||||
local tool = setmetatable({
|
||||
description = function()
|
||||
return "Dynamic description"
|
||||
end,
|
||||
}, base)
|
||||
|
||||
assert.equals("Dynamic description", tool:get_description())
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("to_schema", function()
|
||||
it("should generate valid schema", function()
|
||||
local tool = setmetatable({
|
||||
name = "test",
|
||||
description = "Test tool",
|
||||
params = {
|
||||
{ name = "input", type = "string", description = "Input value" },
|
||||
{ name = "count", type = "integer", description = "Count", optional = true },
|
||||
},
|
||||
}, base)
|
||||
|
||||
local schema = tool:to_schema()
|
||||
|
||||
assert.equals("function", schema.type)
|
||||
assert.equals("test", schema.function_def.name)
|
||||
assert.equals("Test tool", schema.function_def.description)
|
||||
assert.equals("object", schema.function_def.parameters.type)
|
||||
assert.is_not_nil(schema.function_def.parameters.properties.input)
|
||||
assert.is_not_nil(schema.function_def.parameters.properties.count)
|
||||
assert.same({ "input" }, schema.function_def.parameters.required)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("built-in tools", function()
|
||||
describe("view tool", function()
|
||||
local view
|
||||
|
||||
before_each(function()
|
||||
view = require("codetyper.agent.tools.view")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("view", view.name)
|
||||
assert.is_string(view.description)
|
||||
assert.is_table(view.params)
|
||||
assert.is_function(view.func)
|
||||
end)
|
||||
|
||||
it("should require path parameter", function()
|
||||
local result, err = view.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("path is required"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("grep tool", function()
|
||||
local grep
|
||||
|
||||
before_each(function()
|
||||
grep = require("codetyper.agent.tools.grep")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("grep", grep.name)
|
||||
assert.is_string(grep.description)
|
||||
assert.is_table(grep.params)
|
||||
assert.is_function(grep.func)
|
||||
end)
|
||||
|
||||
it("should require pattern parameter", function()
|
||||
local result, err = grep.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("pattern is required"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("glob tool", function()
|
||||
local glob
|
||||
|
||||
before_each(function()
|
||||
glob = require("codetyper.agent.tools.glob")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("glob", glob.name)
|
||||
assert.is_string(glob.description)
|
||||
assert.is_table(glob.params)
|
||||
assert.is_function(glob.func)
|
||||
end)
|
||||
|
||||
it("should require pattern parameter", function()
|
||||
local result, err = glob.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("pattern is required"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("edit tool", function()
|
||||
local edit
|
||||
|
||||
before_each(function()
|
||||
edit = require("codetyper.agent.tools.edit")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("edit", edit.name)
|
||||
assert.is_string(edit.description)
|
||||
assert.is_table(edit.params)
|
||||
assert.is_function(edit.func)
|
||||
end)
|
||||
|
||||
it("should require path parameter", function()
|
||||
local result, err = edit.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("path is required"))
|
||||
end)
|
||||
|
||||
it("should require old_string parameter", function()
|
||||
local result, err = edit.func({ path = "/tmp/test" }, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("old_string is required"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("write tool", function()
|
||||
local write
|
||||
|
||||
before_each(function()
|
||||
write = require("codetyper.agent.tools.write")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("write", write.name)
|
||||
assert.is_string(write.description)
|
||||
assert.is_table(write.params)
|
||||
assert.is_function(write.func)
|
||||
end)
|
||||
|
||||
it("should require path parameter", function()
|
||||
local result, err = write.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("path is required"))
|
||||
end)
|
||||
|
||||
it("should require content parameter", function()
|
||||
local result, err = write.func({ path = "/tmp/test" }, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("content is required"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("bash tool", function()
|
||||
local bash
|
||||
|
||||
before_each(function()
|
||||
bash = require("codetyper.agent.tools.bash")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("bash", bash.name)
|
||||
assert.is_function(bash.func)
|
||||
end)
|
||||
|
||||
it("should require command parameter", function()
|
||||
local result, err = bash.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("command is required"))
|
||||
end)
|
||||
|
||||
it("should require confirmation by default", function()
|
||||
assert.is_true(bash.requires_confirmation)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
312
tests/spec/agentic_spec.lua
Normal file
312
tests/spec/agentic_spec.lua
Normal file
@@ -0,0 +1,312 @@
|
||||
---@diagnostic disable: undefined-global
|
||||
-- Unit tests for the agentic system
|
||||
|
||||
describe("agentic module", function()
|
||||
local agentic
|
||||
|
||||
before_each(function()
|
||||
-- Reset and reload
|
||||
package.loaded["codetyper.agent.agentic"] = nil
|
||||
agentic = require("codetyper.agent.agentic")
|
||||
end)
|
||||
|
||||
it("should list built-in agents", function()
|
||||
local agents = agentic.list_agents()
|
||||
assert.is_table(agents)
|
||||
assert.is_true(#agents >= 3) -- coder, planner, explorer
|
||||
|
||||
local names = {}
|
||||
for _, agent in ipairs(agents) do
|
||||
names[agent.name] = true
|
||||
end
|
||||
|
||||
assert.is_true(names["coder"])
|
||||
assert.is_true(names["planner"])
|
||||
assert.is_true(names["explorer"])
|
||||
end)
|
||||
|
||||
it("should have description for each agent", function()
|
||||
local agents = agentic.list_agents()
|
||||
for _, agent in ipairs(agents) do
|
||||
assert.is_string(agent.description)
|
||||
assert.is_true(#agent.description > 0)
|
||||
end
|
||||
end)
|
||||
|
||||
it("should mark built-in agents as builtin", function()
|
||||
local agents = agentic.list_agents()
|
||||
local coder = nil
|
||||
for _, agent in ipairs(agents) do
|
||||
if agent.name == "coder" then
|
||||
coder = agent
|
||||
break
|
||||
end
|
||||
end
|
||||
assert.is_not_nil(coder)
|
||||
assert.is_true(coder.builtin)
|
||||
end)
|
||||
|
||||
it("should have init function to create directories", function()
|
||||
assert.is_function(agentic.init)
|
||||
assert.is_function(agentic.init_agents_dir)
|
||||
assert.is_function(agentic.init_rules_dir)
|
||||
end)
|
||||
|
||||
it("should have run function for executing tasks", function()
|
||||
assert.is_function(agentic.run)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("tools format conversion", function()
|
||||
local tools_module
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools"] = nil
|
||||
tools_module = require("codetyper.agent.tools")
|
||||
-- Load tools
|
||||
if tools_module.load_builtins then
|
||||
pcall(tools_module.load_builtins)
|
||||
end
|
||||
end)
|
||||
|
||||
it("should have to_openai_format function", function()
|
||||
assert.is_function(tools_module.to_openai_format)
|
||||
end)
|
||||
|
||||
it("should have to_claude_format function", function()
|
||||
assert.is_function(tools_module.to_claude_format)
|
||||
end)
|
||||
|
||||
it("should convert tools to OpenAI format", function()
|
||||
local openai_tools = tools_module.to_openai_format()
|
||||
assert.is_table(openai_tools)
|
||||
|
||||
-- If tools are loaded, check format
|
||||
if #openai_tools > 0 then
|
||||
local first_tool = openai_tools[1]
|
||||
assert.equals("function", first_tool.type)
|
||||
assert.is_table(first_tool["function"])
|
||||
assert.is_string(first_tool["function"].name)
|
||||
end
|
||||
end)
|
||||
|
||||
it("should convert tools to Claude format", function()
|
||||
local claude_tools = tools_module.to_claude_format()
|
||||
assert.is_table(claude_tools)
|
||||
|
||||
-- If tools are loaded, check format
|
||||
if #claude_tools > 0 then
|
||||
local first_tool = claude_tools[1]
|
||||
assert.is_string(first_tool.name)
|
||||
assert.is_table(first_tool.input_schema)
|
||||
end
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("edit tool", function()
|
||||
local edit_tool
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.edit"] = nil
|
||||
edit_tool = require("codetyper.agent.tools.edit")
|
||||
end)
|
||||
|
||||
it("should have name 'edit'", function()
|
||||
assert.equals("edit", edit_tool.name)
|
||||
end)
|
||||
|
||||
it("should have description mentioning matching strategies", function()
|
||||
local desc = edit_tool:get_description()
|
||||
assert.is_string(desc)
|
||||
-- Should mention the matching capabilities
|
||||
assert.is_true(desc:lower():match("match") ~= nil or desc:lower():match("replac") ~= nil)
|
||||
end)
|
||||
|
||||
it("should have params defined", function()
|
||||
assert.is_table(edit_tool.params)
|
||||
assert.is_true(#edit_tool.params >= 3) -- path, old_string, new_string
|
||||
end)
|
||||
|
||||
it("should require path parameter", function()
|
||||
local valid, err = edit_tool:validate_input({
|
||||
old_string = "test",
|
||||
new_string = "test2",
|
||||
})
|
||||
assert.is_false(valid)
|
||||
assert.is_string(err)
|
||||
end)
|
||||
|
||||
it("should require old_string parameter", function()
|
||||
local valid, err = edit_tool:validate_input({
|
||||
path = "/test",
|
||||
new_string = "test",
|
||||
})
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should require new_string parameter", function()
|
||||
local valid, err = edit_tool:validate_input({
|
||||
path = "/test",
|
||||
old_string = "test",
|
||||
})
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should accept empty old_string for new file creation", function()
|
||||
local valid, err = edit_tool:validate_input({
|
||||
path = "/test/new_file.lua",
|
||||
old_string = "",
|
||||
new_string = "new content",
|
||||
})
|
||||
assert.is_true(valid)
|
||||
assert.is_nil(err)
|
||||
end)
|
||||
|
||||
it("should have func implementation", function()
|
||||
assert.is_function(edit_tool.func)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("view tool", function()
|
||||
local view_tool
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.view"] = nil
|
||||
view_tool = require("codetyper.agent.tools.view")
|
||||
end)
|
||||
|
||||
it("should have name 'view'", function()
|
||||
assert.equals("view", view_tool.name)
|
||||
end)
|
||||
|
||||
it("should require path parameter", function()
|
||||
local valid, err = view_tool:validate_input({})
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should accept valid path", function()
|
||||
local valid, err = view_tool:validate_input({
|
||||
path = "/test/file.lua",
|
||||
})
|
||||
assert.is_true(valid)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("write tool", function()
|
||||
local write_tool
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.write"] = nil
|
||||
write_tool = require("codetyper.agent.tools.write")
|
||||
end)
|
||||
|
||||
it("should have name 'write'", function()
|
||||
assert.equals("write", write_tool.name)
|
||||
end)
|
||||
|
||||
it("should require path and content parameters", function()
|
||||
local valid, err = write_tool:validate_input({})
|
||||
assert.is_false(valid)
|
||||
|
||||
valid, err = write_tool:validate_input({ path = "/test" })
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should accept valid input", function()
|
||||
local valid, err = write_tool:validate_input({
|
||||
path = "/test/file.lua",
|
||||
content = "test content",
|
||||
})
|
||||
assert.is_true(valid)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("grep tool", function()
|
||||
local grep_tool
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.grep"] = nil
|
||||
grep_tool = require("codetyper.agent.tools.grep")
|
||||
end)
|
||||
|
||||
it("should have name 'grep'", function()
|
||||
assert.equals("grep", grep_tool.name)
|
||||
end)
|
||||
|
||||
it("should require pattern parameter", function()
|
||||
local valid, err = grep_tool:validate_input({})
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should accept valid pattern", function()
|
||||
local valid, err = grep_tool:validate_input({
|
||||
pattern = "function.*test",
|
||||
})
|
||||
assert.is_true(valid)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("glob tool", function()
|
||||
local glob_tool
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.glob"] = nil
|
||||
glob_tool = require("codetyper.agent.tools.glob")
|
||||
end)
|
||||
|
||||
it("should have name 'glob'", function()
|
||||
assert.equals("glob", glob_tool.name)
|
||||
end)
|
||||
|
||||
it("should require pattern parameter", function()
|
||||
local valid, err = glob_tool:validate_input({})
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should accept valid pattern", function()
|
||||
local valid, err = glob_tool:validate_input({
|
||||
pattern = "**/*.lua",
|
||||
})
|
||||
assert.is_true(valid)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("base tool", function()
|
||||
local Base
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.base"] = nil
|
||||
Base = require("codetyper.agent.tools.base")
|
||||
end)
|
||||
|
||||
it("should have validate_input method", function()
|
||||
assert.is_function(Base.validate_input)
|
||||
end)
|
||||
|
||||
it("should have to_schema method", function()
|
||||
assert.is_function(Base.to_schema)
|
||||
end)
|
||||
|
||||
it("should have get_description method", function()
|
||||
assert.is_function(Base.get_description)
|
||||
end)
|
||||
|
||||
it("should generate valid schema", function()
|
||||
local test_tool = setmetatable({
|
||||
name = "test",
|
||||
description = "A test tool",
|
||||
params = {
|
||||
{ name = "arg1", type = "string", description = "First arg" },
|
||||
{ name = "arg2", type = "number", description = "Second arg", optional = true },
|
||||
},
|
||||
}, Base)
|
||||
|
||||
local schema = test_tool:to_schema()
|
||||
assert.equals("function", schema.type)
|
||||
assert.equals("test", schema.function_def.name)
|
||||
assert.is_table(schema.function_def.parameters.properties)
|
||||
assert.is_table(schema.function_def.parameters.required)
|
||||
assert.is_true(vim.tbl_contains(schema.function_def.parameters.required, "arg1"))
|
||||
assert.is_false(vim.tbl_contains(schema.function_def.parameters.required, "arg2"))
|
||||
end)
|
||||
end)
|
||||
229
tests/spec/ask_intent_spec.lua
Normal file
229
tests/spec/ask_intent_spec.lua
Normal file
@@ -0,0 +1,229 @@
|
||||
--- Tests for ask intent detection
|
||||
local intent = require("codetyper.ask.intent")
|
||||
|
||||
describe("ask.intent", function()
|
||||
describe("detect", function()
|
||||
-- Ask/Explain intent tests
|
||||
describe("ask intent", function()
|
||||
it("detects 'what' questions as ask", function()
|
||||
local result = intent.detect("What does this function do?")
|
||||
assert.equals("ask", result.type)
|
||||
assert.is_true(result.confidence > 0.3)
|
||||
end)
|
||||
|
||||
it("detects 'why' questions as ask", function()
|
||||
local result = intent.detect("Why is this variable undefined?")
|
||||
assert.equals("ask", result.type)
|
||||
end)
|
||||
|
||||
it("detects 'how does' as ask", function()
|
||||
local result = intent.detect("How does this algorithm work?")
|
||||
assert.is_true(result.type == "ask" or result.type == "explain")
|
||||
end)
|
||||
|
||||
it("detects 'explain' requests as explain", function()
|
||||
local result = intent.detect("Explain me the project structure")
|
||||
assert.equals("explain", result.type)
|
||||
assert.is_true(result.confidence > 0.4)
|
||||
end)
|
||||
|
||||
it("detects 'walk me through' as explain", function()
|
||||
local result = intent.detect("Walk me through this code")
|
||||
assert.equals("explain", result.type)
|
||||
end)
|
||||
|
||||
it("detects questions ending with ? as likely ask", function()
|
||||
local result = intent.detect("Is this the right approach?")
|
||||
assert.equals("ask", result.type)
|
||||
end)
|
||||
|
||||
it("sets needs_brain_context for ask intent", function()
|
||||
local result = intent.detect("What patterns are used here?")
|
||||
assert.is_true(result.needs_brain_context)
|
||||
end)
|
||||
end)
|
||||
|
||||
-- Generate intent tests
|
||||
describe("generate intent", function()
|
||||
it("detects 'create' commands as generate", function()
|
||||
local result = intent.detect("Create a function to sort arrays")
|
||||
assert.equals("generate", result.type)
|
||||
end)
|
||||
|
||||
it("detects 'write' commands as generate", function()
|
||||
local result = intent.detect("Write a unit test for this module")
|
||||
-- Could be generate or test
|
||||
assert.is_true(result.type == "generate" or result.type == "test")
|
||||
end)
|
||||
|
||||
it("detects 'implement' as generate", function()
|
||||
local result = intent.detect("Implement a binary search")
|
||||
assert.equals("generate", result.type)
|
||||
assert.is_true(result.confidence > 0.4)
|
||||
end)
|
||||
|
||||
it("detects 'add' commands as generate", function()
|
||||
local result = intent.detect("Add error handling to this function")
|
||||
assert.equals("generate", result.type)
|
||||
end)
|
||||
|
||||
it("detects 'fix' as generate", function()
|
||||
local result = intent.detect("Fix the bug in line 42")
|
||||
assert.equals("generate", result.type)
|
||||
end)
|
||||
end)
|
||||
|
||||
-- Refactor intent tests
|
||||
describe("refactor intent", function()
|
||||
it("detects explicit 'refactor' as refactor", function()
|
||||
local result = intent.detect("Refactor this function")
|
||||
assert.equals("refactor", result.type)
|
||||
end)
|
||||
|
||||
it("detects 'clean up' as refactor", function()
|
||||
local result = intent.detect("Clean up this messy code")
|
||||
assert.equals("refactor", result.type)
|
||||
end)
|
||||
|
||||
it("detects 'simplify' as refactor", function()
|
||||
local result = intent.detect("Simplify this logic")
|
||||
assert.equals("refactor", result.type)
|
||||
end)
|
||||
end)
|
||||
|
||||
-- Document intent tests
|
||||
describe("document intent", function()
|
||||
it("detects 'document' as document", function()
|
||||
local result = intent.detect("Document this function")
|
||||
assert.equals("document", result.type)
|
||||
end)
|
||||
|
||||
it("detects 'add documentation' as document", function()
|
||||
local result = intent.detect("Add documentation to this class")
|
||||
assert.equals("document", result.type)
|
||||
end)
|
||||
|
||||
it("detects 'add jsdoc' as document", function()
|
||||
local result = intent.detect("Add jsdoc comments")
|
||||
assert.equals("document", result.type)
|
||||
end)
|
||||
end)
|
||||
|
||||
-- Test intent tests
|
||||
describe("test intent", function()
|
||||
it("detects 'write tests for' as test", function()
|
||||
local result = intent.detect("Write tests for this module")
|
||||
assert.equals("test", result.type)
|
||||
end)
|
||||
|
||||
it("detects 'add unit tests' as test", function()
|
||||
local result = intent.detect("Add unit tests for the parser")
|
||||
assert.equals("test", result.type)
|
||||
end)
|
||||
|
||||
it("detects 'generate tests' as test", function()
|
||||
local result = intent.detect("Generate tests for the API")
|
||||
assert.equals("test", result.type)
|
||||
end)
|
||||
end)
|
||||
|
||||
-- Project context tests
|
||||
describe("project context detection", function()
|
||||
it("detects 'project' as needing project context", function()
|
||||
local result = intent.detect("Explain the project architecture")
|
||||
assert.is_true(result.needs_project_context)
|
||||
end)
|
||||
|
||||
it("detects 'codebase' as needing project context", function()
|
||||
local result = intent.detect("How is the codebase organized?")
|
||||
assert.is_true(result.needs_project_context)
|
||||
end)
|
||||
|
||||
it("does not need project context for simple questions", function()
|
||||
local result = intent.detect("What does this variable mean?")
|
||||
assert.is_false(result.needs_project_context)
|
||||
end)
|
||||
end)
|
||||
|
||||
-- Exploration tests
|
||||
describe("exploration detection", function()
|
||||
it("detects 'explain me the project' as needing exploration", function()
|
||||
local result = intent.detect("Explain me the project")
|
||||
assert.is_true(result.needs_exploration)
|
||||
end)
|
||||
|
||||
it("detects 'explain the codebase' as needing exploration", function()
|
||||
local result = intent.detect("Explain the codebase structure")
|
||||
assert.is_true(result.needs_exploration)
|
||||
end)
|
||||
|
||||
it("detects 'explore project' as needing exploration", function()
|
||||
local result = intent.detect("Explore this project")
|
||||
assert.is_true(result.needs_exploration)
|
||||
end)
|
||||
|
||||
it("does not need exploration for simple questions", function()
|
||||
local result = intent.detect("What does this function do?")
|
||||
assert.is_false(result.needs_exploration)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("get_prompt_type", function()
|
||||
it("maps ask to ask", function()
|
||||
local result = intent.get_prompt_type({ type = "ask" })
|
||||
assert.equals("ask", result)
|
||||
end)
|
||||
|
||||
it("maps explain to ask", function()
|
||||
local result = intent.get_prompt_type({ type = "explain" })
|
||||
assert.equals("ask", result)
|
||||
end)
|
||||
|
||||
it("maps generate to code_generation", function()
|
||||
local result = intent.get_prompt_type({ type = "generate" })
|
||||
assert.equals("code_generation", result)
|
||||
end)
|
||||
|
||||
it("maps refactor to refactor", function()
|
||||
local result = intent.get_prompt_type({ type = "refactor" })
|
||||
assert.equals("refactor", result)
|
||||
end)
|
||||
|
||||
it("maps document to document", function()
|
||||
local result = intent.get_prompt_type({ type = "document" })
|
||||
assert.equals("document", result)
|
||||
end)
|
||||
|
||||
it("maps test to test", function()
|
||||
local result = intent.get_prompt_type({ type = "test" })
|
||||
assert.equals("test", result)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("produces_code", function()
|
||||
it("returns false for ask", function()
|
||||
assert.is_false(intent.produces_code({ type = "ask" }))
|
||||
end)
|
||||
|
||||
it("returns false for explain", function()
|
||||
assert.is_false(intent.produces_code({ type = "explain" }))
|
||||
end)
|
||||
|
||||
it("returns true for generate", function()
|
||||
assert.is_true(intent.produces_code({ type = "generate" }))
|
||||
end)
|
||||
|
||||
it("returns true for refactor", function()
|
||||
assert.is_true(intent.produces_code({ type = "refactor" }))
|
||||
end)
|
||||
|
||||
it("returns true for document", function()
|
||||
assert.is_true(intent.produces_code({ type = "document" }))
|
||||
end)
|
||||
|
||||
it("returns true for test", function()
|
||||
assert.is_true(intent.produces_code({ type = "test" }))
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
252
tests/spec/brain_delta_spec.lua
Normal file
252
tests/spec/brain_delta_spec.lua
Normal file
@@ -0,0 +1,252 @@
|
||||
--- Tests for brain/delta modules
|
||||
describe("brain.delta", function()
|
||||
local diff
|
||||
local commit
|
||||
local storage
|
||||
local types
|
||||
local test_root = "/tmp/codetyper_test_" .. os.time()
|
||||
|
||||
before_each(function()
|
||||
-- Clear module cache
|
||||
package.loaded["codetyper.brain.delta.diff"] = nil
|
||||
package.loaded["codetyper.brain.delta.commit"] = nil
|
||||
package.loaded["codetyper.brain.storage"] = nil
|
||||
package.loaded["codetyper.brain.types"] = nil
|
||||
|
||||
diff = require("codetyper.brain.delta.diff")
|
||||
commit = require("codetyper.brain.delta.commit")
|
||||
storage = require("codetyper.brain.storage")
|
||||
types = require("codetyper.brain.types")
|
||||
|
||||
storage.clear_cache()
|
||||
vim.fn.mkdir(test_root, "p")
|
||||
storage.ensure_dirs(test_root)
|
||||
|
||||
-- Mock get_project_root
|
||||
local utils = require("codetyper.utils")
|
||||
utils.get_project_root = function()
|
||||
return test_root
|
||||
end
|
||||
end)
|
||||
|
||||
after_each(function()
|
||||
vim.fn.delete(test_root, "rf")
|
||||
storage.clear_cache()
|
||||
end)
|
||||
|
||||
describe("diff.compute", function()
|
||||
it("detects added values", function()
|
||||
local diffs = diff.compute(nil, { a = 1 })
|
||||
|
||||
assert.equals(1, #diffs)
|
||||
assert.equals("add", diffs[1].op)
|
||||
end)
|
||||
|
||||
it("detects deleted values", function()
|
||||
local diffs = diff.compute({ a = 1 }, nil)
|
||||
|
||||
assert.equals(1, #diffs)
|
||||
assert.equals("delete", diffs[1].op)
|
||||
end)
|
||||
|
||||
it("detects replaced values", function()
|
||||
local diffs = diff.compute({ a = 1 }, { a = 2 })
|
||||
|
||||
assert.equals(1, #diffs)
|
||||
assert.equals("replace", diffs[1].op)
|
||||
assert.equals(1, diffs[1].from)
|
||||
assert.equals(2, diffs[1].to)
|
||||
end)
|
||||
|
||||
it("detects nested changes", function()
|
||||
local before = { a = { b = 1 } }
|
||||
local after = { a = { b = 2 } }
|
||||
|
||||
local diffs = diff.compute(before, after)
|
||||
|
||||
assert.equals(1, #diffs)
|
||||
assert.equals("a.b", diffs[1].path)
|
||||
end)
|
||||
|
||||
it("returns empty for identical values", function()
|
||||
local diffs = diff.compute({ a = 1 }, { a = 1 })
|
||||
assert.equals(0, #diffs)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("diff.apply", function()
|
||||
it("applies add operation", function()
|
||||
local base = { a = 1 }
|
||||
local diffs = { { op = "add", path = "b", value = 2 } }
|
||||
|
||||
local result = diff.apply(base, diffs)
|
||||
|
||||
assert.equals(2, result.b)
|
||||
end)
|
||||
|
||||
it("applies replace operation", function()
|
||||
local base = { a = 1 }
|
||||
local diffs = { { op = "replace", path = "a", to = 2 } }
|
||||
|
||||
local result = diff.apply(base, diffs)
|
||||
|
||||
assert.equals(2, result.a)
|
||||
end)
|
||||
|
||||
it("applies delete operation", function()
|
||||
local base = { a = 1, b = 2 }
|
||||
local diffs = { { op = "delete", path = "a" } }
|
||||
|
||||
local result = diff.apply(base, diffs)
|
||||
|
||||
assert.is_nil(result.a)
|
||||
assert.equals(2, result.b)
|
||||
end)
|
||||
|
||||
it("applies nested changes", function()
|
||||
local base = { a = { b = 1 } }
|
||||
local diffs = { { op = "replace", path = "a.b", to = 2 } }
|
||||
|
||||
local result = diff.apply(base, diffs)
|
||||
|
||||
assert.equals(2, result.a.b)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("diff.reverse", function()
|
||||
it("reverses add to delete", function()
|
||||
local diffs = { { op = "add", path = "a", value = 1 } }
|
||||
|
||||
local reversed = diff.reverse(diffs)
|
||||
|
||||
assert.equals("delete", reversed[1].op)
|
||||
end)
|
||||
|
||||
it("reverses delete to add", function()
|
||||
local diffs = { { op = "delete", path = "a", value = 1 } }
|
||||
|
||||
local reversed = diff.reverse(diffs)
|
||||
|
||||
assert.equals("add", reversed[1].op)
|
||||
end)
|
||||
|
||||
it("reverses replace", function()
|
||||
local diffs = { { op = "replace", path = "a", from = 1, to = 2 } }
|
||||
|
||||
local reversed = diff.reverse(diffs)
|
||||
|
||||
assert.equals("replace", reversed[1].op)
|
||||
assert.equals(2, reversed[1].from)
|
||||
assert.equals(1, reversed[1].to)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("diff.equals", function()
|
||||
it("returns true for identical states", function()
|
||||
assert.is_true(diff.equals({ a = 1 }, { a = 1 }))
|
||||
end)
|
||||
|
||||
it("returns false for different states", function()
|
||||
assert.is_false(diff.equals({ a = 1 }, { a = 2 }))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("commit.create", function()
|
||||
it("creates a delta commit", function()
|
||||
local changes = {
|
||||
{ op = "add", path = "test.node1", ah = "abc123" },
|
||||
}
|
||||
|
||||
local delta = commit.create(changes, "Test commit", "test")
|
||||
|
||||
assert.is_not_nil(delta)
|
||||
assert.is_not_nil(delta.h)
|
||||
assert.equals("Test commit", delta.m.msg)
|
||||
assert.equals(1, #delta.ch)
|
||||
end)
|
||||
|
||||
it("updates HEAD", function()
|
||||
local changes = { { op = "add", path = "test.node1", ah = "abc123" } }
|
||||
|
||||
local delta = commit.create(changes, "Test", "test")
|
||||
|
||||
local head = storage.get_head(test_root)
|
||||
assert.equals(delta.h, head)
|
||||
end)
|
||||
|
||||
it("links to parent", function()
|
||||
local changes1 = { { op = "add", path = "test.node1", ah = "abc123" } }
|
||||
local delta1 = commit.create(changes1, "First", "test")
|
||||
|
||||
local changes2 = { { op = "add", path = "test.node2", ah = "def456" } }
|
||||
local delta2 = commit.create(changes2, "Second", "test")
|
||||
|
||||
assert.equals(delta1.h, delta2.p)
|
||||
end)
|
||||
|
||||
it("returns nil for empty changes", function()
|
||||
local delta = commit.create({}, "Empty")
|
||||
assert.is_nil(delta)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("commit.get", function()
|
||||
it("retrieves created delta", function()
|
||||
local changes = { { op = "add", path = "test.node1", ah = "abc123" } }
|
||||
local created = commit.create(changes, "Test", "test")
|
||||
|
||||
local retrieved = commit.get(created.h)
|
||||
|
||||
assert.is_not_nil(retrieved)
|
||||
assert.equals(created.h, retrieved.h)
|
||||
end)
|
||||
|
||||
it("returns nil for non-existent delta", function()
|
||||
local retrieved = commit.get("nonexistent")
|
||||
assert.is_nil(retrieved)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("commit.get_history", function()
|
||||
it("returns delta chain", function()
|
||||
commit.create({ { op = "add", path = "node1", ah = "1" } }, "First", "test")
|
||||
commit.create({ { op = "add", path = "node2", ah = "2" } }, "Second", "test")
|
||||
commit.create({ { op = "add", path = "node3", ah = "3" } }, "Third", "test")
|
||||
|
||||
local history = commit.get_history(10)
|
||||
|
||||
assert.equals(3, #history)
|
||||
assert.equals("Third", history[1].m.msg)
|
||||
assert.equals("Second", history[2].m.msg)
|
||||
assert.equals("First", history[3].m.msg)
|
||||
end)
|
||||
|
||||
it("respects limit", function()
|
||||
for i = 1, 5 do
|
||||
commit.create({ { op = "add", path = "node" .. i, ah = tostring(i) } }, "Commit " .. i, "test")
|
||||
end
|
||||
|
||||
local history = commit.get_history(3)
|
||||
|
||||
assert.equals(3, #history)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("commit.summarize", function()
|
||||
it("summarizes delta statistics", function()
|
||||
local changes = {
|
||||
{ op = "add", path = "nodes.a" },
|
||||
{ op = "add", path = "nodes.b" },
|
||||
{ op = "mod", path = "nodes.c" },
|
||||
{ op = "del", path = "nodes.d" },
|
||||
}
|
||||
local delta = commit.create(changes, "Test", "test")
|
||||
|
||||
local summary = commit.summarize(delta)
|
||||
|
||||
assert.equals(2, summary.stats.adds)
|
||||
assert.equals(4, summary.stats.total)
|
||||
assert.is_true(vim.tbl_contains(summary.categories, "nodes"))
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
128
tests/spec/brain_hash_spec.lua
Normal file
128
tests/spec/brain_hash_spec.lua
Normal file
@@ -0,0 +1,128 @@
|
||||
--- Tests for brain/hash.lua
|
||||
describe("brain.hash", function()
|
||||
local hash
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.brain.hash"] = nil
|
||||
hash = require("codetyper.brain.hash")
|
||||
end)
|
||||
|
||||
describe("compute", function()
|
||||
it("returns 8-character hash", function()
|
||||
local result = hash.compute("test string")
|
||||
assert.equals(8, #result)
|
||||
end)
|
||||
|
||||
it("returns consistent hash for same input", function()
|
||||
local result1 = hash.compute("test")
|
||||
local result2 = hash.compute("test")
|
||||
assert.equals(result1, result2)
|
||||
end)
|
||||
|
||||
it("returns different hash for different input", function()
|
||||
local result1 = hash.compute("test1")
|
||||
local result2 = hash.compute("test2")
|
||||
assert.not_equals(result1, result2)
|
||||
end)
|
||||
|
||||
it("handles empty string", function()
|
||||
local result = hash.compute("")
|
||||
assert.equals("00000000", result)
|
||||
end)
|
||||
|
||||
it("handles nil", function()
|
||||
local result = hash.compute(nil)
|
||||
assert.equals("00000000", result)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("compute_table", function()
|
||||
it("hashes table as JSON", function()
|
||||
local result = hash.compute_table({ a = 1, b = 2 })
|
||||
assert.equals(8, #result)
|
||||
end)
|
||||
|
||||
it("returns consistent hash for same table", function()
|
||||
local result1 = hash.compute_table({ x = "y" })
|
||||
local result2 = hash.compute_table({ x = "y" })
|
||||
assert.equals(result1, result2)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("node_id", function()
|
||||
it("generates ID with correct format", function()
|
||||
local id = hash.node_id("pat", "test content")
|
||||
assert.truthy(id:match("^n_pat_%d+_%w+$"))
|
||||
end)
|
||||
|
||||
it("generates unique IDs", function()
|
||||
local id1 = hash.node_id("pat", "test1")
|
||||
local id2 = hash.node_id("pat", "test2")
|
||||
assert.not_equals(id1, id2)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("edge_id", function()
|
||||
it("generates ID with correct format", function()
|
||||
local id = hash.edge_id("source_node", "target_node")
|
||||
assert.truthy(id:match("^e_%w+_%w+$"))
|
||||
end)
|
||||
|
||||
it("returns same ID for same source/target", function()
|
||||
local id1 = hash.edge_id("s1", "t1")
|
||||
local id2 = hash.edge_id("s1", "t1")
|
||||
assert.equals(id1, id2)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("delta_hash", function()
|
||||
it("generates 8-character hash", function()
|
||||
local changes = { { op = "add", path = "test" } }
|
||||
local result = hash.delta_hash(changes, "parent", 12345)
|
||||
assert.equals(8, #result)
|
||||
end)
|
||||
|
||||
it("includes parent in hash", function()
|
||||
local changes = { { op = "add", path = "test" } }
|
||||
local result1 = hash.delta_hash(changes, "parent1", 12345)
|
||||
local result2 = hash.delta_hash(changes, "parent2", 12345)
|
||||
assert.not_equals(result1, result2)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("path_hash", function()
|
||||
it("returns 8-character hash", function()
|
||||
local result = hash.path_hash("/path/to/file.lua")
|
||||
assert.equals(8, #result)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("matches", function()
|
||||
it("returns true for matching hashes", function()
|
||||
assert.is_true(hash.matches("abc12345", "abc12345"))
|
||||
end)
|
||||
|
||||
it("returns false for different hashes", function()
|
||||
assert.is_false(hash.matches("abc12345", "def67890"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("random", function()
|
||||
it("returns 8-character string", function()
|
||||
local result = hash.random()
|
||||
assert.equals(8, #result)
|
||||
end)
|
||||
|
||||
it("generates different values", function()
|
||||
local result1 = hash.random()
|
||||
local result2 = hash.random()
|
||||
-- Note: There's a tiny chance these could match, but very unlikely
|
||||
assert.not_equals(result1, result2)
|
||||
end)
|
||||
|
||||
it("contains only hex characters", function()
|
||||
local result = hash.random()
|
||||
assert.truthy(result:match("^[0-9a-f]+$"))
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
153
tests/spec/brain_learners_spec.lua
Normal file
153
tests/spec/brain_learners_spec.lua
Normal file
@@ -0,0 +1,153 @@
|
||||
--- Tests for brain/learners pattern detection and extraction
|
||||
describe("brain.learners", function()
|
||||
local pattern_learner
|
||||
|
||||
before_each(function()
|
||||
-- Clear module cache
|
||||
package.loaded["codetyper.brain.learners.pattern"] = nil
|
||||
package.loaded["codetyper.brain.types"] = nil
|
||||
|
||||
pattern_learner = require("codetyper.brain.learners.pattern")
|
||||
end)
|
||||
|
||||
describe("pattern learner detection", function()
|
||||
it("should detect code_completion events", function()
|
||||
local event = { type = "code_completion", data = {} }
|
||||
assert.is_true(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should detect file_indexed events", function()
|
||||
local event = { type = "file_indexed", data = {} }
|
||||
assert.is_true(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should detect code_analyzed events", function()
|
||||
local event = { type = "code_analyzed", data = {} }
|
||||
assert.is_true(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should detect pattern_detected events", function()
|
||||
local event = { type = "pattern_detected", data = {} }
|
||||
assert.is_true(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should NOT detect plain 'pattern' type events", function()
|
||||
-- This was the bug - 'pattern' type was not in the valid_types list
|
||||
local event = { type = "pattern", data = {} }
|
||||
assert.is_false(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should NOT detect unknown event types", function()
|
||||
local event = { type = "unknown_type", data = {} }
|
||||
assert.is_false(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should NOT detect nil events", function()
|
||||
assert.is_false(pattern_learner.detect(nil))
|
||||
end)
|
||||
|
||||
it("should NOT detect events without type", function()
|
||||
local event = { data = {} }
|
||||
assert.is_false(pattern_learner.detect(event))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("pattern learner extraction", function()
|
||||
it("should extract from pattern_detected events", function()
|
||||
local event = {
|
||||
type = "pattern_detected",
|
||||
file = "/path/to/file.lua",
|
||||
data = {
|
||||
name = "Test pattern",
|
||||
description = "Pattern description",
|
||||
language = "lua",
|
||||
symbols = { "func1", "func2" },
|
||||
},
|
||||
}
|
||||
|
||||
local extracted = pattern_learner.extract(event)
|
||||
|
||||
assert.is_not_nil(extracted)
|
||||
assert.equals("Test pattern", extracted.summary)
|
||||
assert.equals("Pattern description", extracted.detail)
|
||||
assert.equals("lua", extracted.lang)
|
||||
assert.equals("/path/to/file.lua", extracted.file)
|
||||
end)
|
||||
|
||||
it("should handle pattern_detected with minimal data", function()
|
||||
local event = {
|
||||
type = "pattern_detected",
|
||||
file = "/path/to/file.lua",
|
||||
data = {
|
||||
name = "Minimal pattern",
|
||||
},
|
||||
}
|
||||
|
||||
local extracted = pattern_learner.extract(event)
|
||||
|
||||
assert.is_not_nil(extracted)
|
||||
assert.equals("Minimal pattern", extracted.summary)
|
||||
assert.equals("Minimal pattern", extracted.detail)
|
||||
end)
|
||||
|
||||
it("should extract from code_completion events", function()
|
||||
local event = {
|
||||
type = "code_completion",
|
||||
file = "/path/to/file.lua",
|
||||
data = {
|
||||
intent = "add function",
|
||||
code = "function test() end",
|
||||
language = "lua",
|
||||
},
|
||||
}
|
||||
|
||||
local extracted = pattern_learner.extract(event)
|
||||
|
||||
assert.is_not_nil(extracted)
|
||||
assert.is_true(extracted.summary:find("Code pattern") ~= nil)
|
||||
assert.equals("function test() end", extracted.detail)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("should_learn validation", function()
|
||||
it("should accept valid patterns", function()
|
||||
local data = {
|
||||
summary = "Valid pattern summary",
|
||||
detail = "This is a detailed description of the pattern",
|
||||
}
|
||||
assert.is_true(pattern_learner.should_learn(data))
|
||||
end)
|
||||
|
||||
it("should reject patterns without summary", function()
|
||||
local data = {
|
||||
summary = "",
|
||||
detail = "Some detail",
|
||||
}
|
||||
assert.is_false(pattern_learner.should_learn(data))
|
||||
end)
|
||||
|
||||
it("should reject patterns with nil summary", function()
|
||||
local data = {
|
||||
summary = nil,
|
||||
detail = "Some detail",
|
||||
}
|
||||
assert.is_false(pattern_learner.should_learn(data))
|
||||
end)
|
||||
|
||||
it("should reject patterns with very short detail", function()
|
||||
local data = {
|
||||
summary = "Valid summary",
|
||||
detail = "short", -- Less than 10 chars
|
||||
}
|
||||
assert.is_false(pattern_learner.should_learn(data))
|
||||
end)
|
||||
|
||||
it("should reject whitespace-only summaries", function()
|
||||
local data = {
|
||||
summary = " ",
|
||||
detail = "Some valid detail here",
|
||||
}
|
||||
assert.is_false(pattern_learner.should_learn(data))
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
234
tests/spec/brain_node_spec.lua
Normal file
234
tests/spec/brain_node_spec.lua
Normal file
@@ -0,0 +1,234 @@
|
||||
--- Tests for brain/graph/node.lua
|
||||
describe("brain.graph.node", function()
|
||||
local node
|
||||
local storage
|
||||
local types
|
||||
local test_root = "/tmp/codetyper_test_" .. os.time()
|
||||
|
||||
before_each(function()
|
||||
-- Clear module cache
|
||||
package.loaded["codetyper.brain.graph.node"] = nil
|
||||
package.loaded["codetyper.brain.storage"] = nil
|
||||
package.loaded["codetyper.brain.types"] = nil
|
||||
package.loaded["codetyper.brain.hash"] = nil
|
||||
|
||||
storage = require("codetyper.brain.storage")
|
||||
types = require("codetyper.brain.types")
|
||||
node = require("codetyper.brain.graph.node")
|
||||
|
||||
storage.clear_cache()
|
||||
vim.fn.mkdir(test_root, "p")
|
||||
storage.ensure_dirs(test_root)
|
||||
|
||||
-- Mock get_project_root
|
||||
local utils = require("codetyper.utils")
|
||||
utils.get_project_root = function()
|
||||
return test_root
|
||||
end
|
||||
end)
|
||||
|
||||
after_each(function()
|
||||
vim.fn.delete(test_root, "rf")
|
||||
storage.clear_cache()
|
||||
node.pending = {}
|
||||
end)
|
||||
|
||||
describe("create", function()
|
||||
it("creates a new node with correct structure", function()
|
||||
local created = node.create(types.NODE_TYPES.PATTERN, {
|
||||
s = "Test pattern summary",
|
||||
d = "Test pattern detail",
|
||||
}, {
|
||||
f = "test.lua",
|
||||
})
|
||||
|
||||
assert.is_not_nil(created.id)
|
||||
assert.equals(types.NODE_TYPES.PATTERN, created.t)
|
||||
assert.equals("Test pattern summary", created.c.s)
|
||||
assert.equals("test.lua", created.ctx.f)
|
||||
assert.equals(0.5, created.sc.w)
|
||||
assert.equals(0, created.sc.u)
|
||||
end)
|
||||
|
||||
it("generates unique IDs", function()
|
||||
local node1 = node.create(types.NODE_TYPES.PATTERN, { s = "First" }, {})
|
||||
local node2 = node.create(types.NODE_TYPES.PATTERN, { s = "Second" }, {})
|
||||
|
||||
assert.is_not_nil(node1.id)
|
||||
assert.is_not_nil(node2.id)
|
||||
assert.not_equals(node1.id, node2.id)
|
||||
end)
|
||||
|
||||
it("updates meta node count", function()
|
||||
local meta_before = storage.get_meta(test_root)
|
||||
local count_before = meta_before.nc
|
||||
|
||||
node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {})
|
||||
|
||||
local meta_after = storage.get_meta(test_root)
|
||||
assert.equals(count_before + 1, meta_after.nc)
|
||||
end)
|
||||
|
||||
it("tracks pending change", function()
|
||||
node.pending = {}
|
||||
node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {})
|
||||
|
||||
assert.equals(1, #node.pending)
|
||||
assert.equals("add", node.pending[1].op)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("get", function()
|
||||
it("retrieves created node", function()
|
||||
local created = node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {})
|
||||
|
||||
local retrieved = node.get(created.id)
|
||||
|
||||
assert.is_not_nil(retrieved)
|
||||
assert.equals(created.id, retrieved.id)
|
||||
assert.equals("Test", retrieved.c.s)
|
||||
end)
|
||||
|
||||
it("returns nil for non-existent node", function()
|
||||
local retrieved = node.get("n_pat_0_nonexistent")
|
||||
assert.is_nil(retrieved)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("update", function()
|
||||
it("updates node content", function()
|
||||
local created = node.create(types.NODE_TYPES.PATTERN, { s = "Original" }, {})
|
||||
|
||||
node.update(created.id, { c = { s = "Updated" } })
|
||||
|
||||
local updated = node.get(created.id)
|
||||
assert.equals("Updated", updated.c.s)
|
||||
end)
|
||||
|
||||
it("updates node scores", function()
|
||||
local created = node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {})
|
||||
|
||||
node.update(created.id, { sc = { w = 0.9 } })
|
||||
|
||||
local updated = node.get(created.id)
|
||||
assert.equals(0.9, updated.sc.w)
|
||||
end)
|
||||
|
||||
it("increments version", function()
|
||||
local created = node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {})
|
||||
local original_version = created.m.v
|
||||
|
||||
node.update(created.id, { c = { s = "Updated" } })
|
||||
|
||||
local updated = node.get(created.id)
|
||||
assert.equals(original_version + 1, updated.m.v)
|
||||
end)
|
||||
|
||||
it("returns nil for non-existent node", function()
|
||||
local result = node.update("n_pat_0_nonexistent", { c = { s = "Test" } })
|
||||
assert.is_nil(result)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("delete", function()
|
||||
it("removes node", function()
|
||||
local created = node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {})
|
||||
|
||||
local result = node.delete(created.id)
|
||||
|
||||
assert.is_true(result)
|
||||
assert.is_nil(node.get(created.id))
|
||||
end)
|
||||
|
||||
it("decrements meta node count", function()
|
||||
local created = node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {})
|
||||
local meta_before = storage.get_meta(test_root)
|
||||
local count_before = meta_before.nc
|
||||
|
||||
node.delete(created.id)
|
||||
|
||||
local meta_after = storage.get_meta(test_root)
|
||||
assert.equals(count_before - 1, meta_after.nc)
|
||||
end)
|
||||
|
||||
it("returns false for non-existent node", function()
|
||||
local result = node.delete("n_pat_0_nonexistent")
|
||||
assert.is_false(result)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("find", function()
|
||||
it("finds nodes by type", function()
|
||||
node.create(types.NODE_TYPES.PATTERN, { s = "Pattern 1" }, {})
|
||||
node.create(types.NODE_TYPES.PATTERN, { s = "Pattern 2" }, {})
|
||||
node.create(types.NODE_TYPES.CORRECTION, { s = "Correction 1" }, {})
|
||||
|
||||
local patterns = node.find({ types = { types.NODE_TYPES.PATTERN } })
|
||||
|
||||
assert.equals(2, #patterns)
|
||||
end)
|
||||
|
||||
it("finds nodes by file", function()
|
||||
node.create(types.NODE_TYPES.PATTERN, { s = "Test 1" }, { f = "file1.lua" })
|
||||
node.create(types.NODE_TYPES.PATTERN, { s = "Test 2" }, { f = "file2.lua" })
|
||||
node.create(types.NODE_TYPES.PATTERN, { s = "Test 3" }, { f = "file1.lua" })
|
||||
|
||||
local found = node.find({ file = "file1.lua" })
|
||||
|
||||
assert.equals(2, #found)
|
||||
end)
|
||||
|
||||
it("finds nodes by query", function()
|
||||
node.create(types.NODE_TYPES.PATTERN, { s = "Foo bar baz" }, {})
|
||||
node.create(types.NODE_TYPES.PATTERN, { s = "Something else" }, {})
|
||||
|
||||
local found = node.find({ query = "foo" })
|
||||
|
||||
assert.equals(1, #found)
|
||||
assert.equals("Foo bar baz", found[1].c.s)
|
||||
end)
|
||||
|
||||
it("respects limit", function()
|
||||
for i = 1, 10 do
|
||||
node.create(types.NODE_TYPES.PATTERN, { s = "Node " .. i }, {})
|
||||
end
|
||||
|
||||
local found = node.find({ limit = 5 })
|
||||
|
||||
assert.equals(5, #found)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("record_usage", function()
|
||||
it("increments usage count", function()
|
||||
local created = node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {})
|
||||
|
||||
node.record_usage(created.id, true)
|
||||
|
||||
local updated = node.get(created.id)
|
||||
assert.equals(1, updated.sc.u)
|
||||
end)
|
||||
|
||||
it("updates success rate", function()
|
||||
local created = node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {})
|
||||
|
||||
node.record_usage(created.id, true)
|
||||
node.record_usage(created.id, false)
|
||||
|
||||
local updated = node.get(created.id)
|
||||
assert.equals(0.5, updated.sc.sr)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("get_and_clear_pending", function()
|
||||
it("returns and clears pending changes", function()
|
||||
node.pending = {}
|
||||
node.create(types.NODE_TYPES.PATTERN, { s = "Test" }, {})
|
||||
|
||||
local pending = node.get_and_clear_pending()
|
||||
|
||||
assert.equals(1, #pending)
|
||||
assert.equals(0, #node.pending)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
173
tests/spec/brain_storage_spec.lua
Normal file
173
tests/spec/brain_storage_spec.lua
Normal file
@@ -0,0 +1,173 @@
|
||||
--- Tests for brain/storage.lua
|
||||
describe("brain.storage", function()
|
||||
local storage
|
||||
local test_root = "/tmp/codetyper_test_" .. os.time()
|
||||
|
||||
before_each(function()
|
||||
-- Clear module cache to get fresh state
|
||||
package.loaded["codetyper.brain.storage"] = nil
|
||||
package.loaded["codetyper.brain.types"] = nil
|
||||
storage = require("codetyper.brain.storage")
|
||||
|
||||
-- Clear cache before each test
|
||||
storage.clear_cache()
|
||||
|
||||
-- Create test directory
|
||||
vim.fn.mkdir(test_root, "p")
|
||||
end)
|
||||
|
||||
after_each(function()
|
||||
-- Clean up test directory
|
||||
vim.fn.delete(test_root, "rf")
|
||||
storage.clear_cache()
|
||||
end)
|
||||
|
||||
describe("get_brain_dir", function()
|
||||
it("returns correct path", function()
|
||||
local dir = storage.get_brain_dir(test_root)
|
||||
assert.equals(test_root .. "/.coder/brain", dir)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("ensure_dirs", function()
|
||||
it("creates required directories", function()
|
||||
local result = storage.ensure_dirs(test_root)
|
||||
assert.is_true(result)
|
||||
|
||||
-- Check directories exist
|
||||
assert.equals(1, vim.fn.isdirectory(test_root .. "/.coder/brain"))
|
||||
assert.equals(1, vim.fn.isdirectory(test_root .. "/.coder/brain/nodes"))
|
||||
assert.equals(1, vim.fn.isdirectory(test_root .. "/.coder/brain/indices"))
|
||||
assert.equals(1, vim.fn.isdirectory(test_root .. "/.coder/brain/deltas"))
|
||||
assert.equals(1, vim.fn.isdirectory(test_root .. "/.coder/brain/deltas/objects"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("get_path", function()
|
||||
it("returns correct path for simple key", function()
|
||||
local path = storage.get_path("meta", test_root)
|
||||
assert.equals(test_root .. "/.coder/brain/meta.json", path)
|
||||
end)
|
||||
|
||||
it("returns correct path for nested key", function()
|
||||
local path = storage.get_path("nodes.patterns", test_root)
|
||||
assert.equals(test_root .. "/.coder/brain/nodes/patterns.json", path)
|
||||
end)
|
||||
|
||||
it("returns correct path for deeply nested key", function()
|
||||
local path = storage.get_path("deltas.objects.abc123", test_root)
|
||||
assert.equals(test_root .. "/.coder/brain/deltas/objects/abc123.json", path)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("save and load", function()
|
||||
it("saves and loads data correctly", function()
|
||||
storage.ensure_dirs(test_root)
|
||||
|
||||
local data = { test = "value", count = 42 }
|
||||
storage.save("meta", data, test_root, true) -- immediate
|
||||
|
||||
-- Clear cache and reload
|
||||
storage.clear_cache()
|
||||
local loaded = storage.load("meta", test_root)
|
||||
|
||||
assert.equals("value", loaded.test)
|
||||
assert.equals(42, loaded.count)
|
||||
end)
|
||||
|
||||
it("returns empty table for missing files", function()
|
||||
storage.ensure_dirs(test_root)
|
||||
|
||||
local loaded = storage.load("nonexistent", test_root)
|
||||
assert.same({}, loaded)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("get_meta", function()
|
||||
it("creates default meta if not exists", function()
|
||||
storage.ensure_dirs(test_root)
|
||||
|
||||
local meta = storage.get_meta(test_root)
|
||||
|
||||
assert.is_not_nil(meta.v)
|
||||
assert.equals(0, meta.nc)
|
||||
assert.equals(0, meta.ec)
|
||||
assert.equals(0, meta.dc)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("update_meta", function()
|
||||
it("updates meta values", function()
|
||||
storage.ensure_dirs(test_root)
|
||||
|
||||
storage.update_meta({ nc = 5 }, test_root)
|
||||
local meta = storage.get_meta(test_root)
|
||||
|
||||
assert.equals(5, meta.nc)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("get/save_nodes", function()
|
||||
it("saves and retrieves nodes by type", function()
|
||||
storage.ensure_dirs(test_root)
|
||||
|
||||
local nodes = {
|
||||
["n_pat_123_abc"] = { id = "n_pat_123_abc", t = "pat" },
|
||||
["n_pat_456_def"] = { id = "n_pat_456_def", t = "pat" },
|
||||
}
|
||||
|
||||
storage.save_nodes("patterns", nodes, test_root)
|
||||
storage.flush("nodes.patterns", test_root)
|
||||
|
||||
storage.clear_cache()
|
||||
local loaded = storage.get_nodes("patterns", test_root)
|
||||
|
||||
assert.equals(2, vim.tbl_count(loaded))
|
||||
assert.equals("n_pat_123_abc", loaded["n_pat_123_abc"].id)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("get/save_graph", function()
|
||||
it("saves and retrieves graph", function()
|
||||
storage.ensure_dirs(test_root)
|
||||
|
||||
local graph = {
|
||||
adj = { node1 = { sem = { "node2" } } },
|
||||
radj = { node2 = { sem = { "node1" } } },
|
||||
}
|
||||
|
||||
storage.save_graph(graph, test_root)
|
||||
storage.flush("graph", test_root)
|
||||
|
||||
storage.clear_cache()
|
||||
local loaded = storage.get_graph(test_root)
|
||||
|
||||
assert.same({ "node2" }, loaded.adj.node1.sem)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("get/set_head", function()
|
||||
it("stores and retrieves HEAD", function()
|
||||
storage.ensure_dirs(test_root)
|
||||
|
||||
storage.set_head("abc12345", test_root)
|
||||
storage.flush("meta", test_root) -- Ensure written to disk
|
||||
|
||||
storage.clear_cache()
|
||||
local head = storage.get_head(test_root)
|
||||
|
||||
assert.equals("abc12345", head)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("exists", function()
|
||||
it("returns false for non-existent brain", function()
|
||||
assert.is_false(storage.exists(test_root))
|
||||
end)
|
||||
|
||||
it("returns true after ensure_dirs", function()
|
||||
storage.ensure_dirs(test_root)
|
||||
assert.is_true(storage.exists(test_root))
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
194
tests/spec/coder_context_spec.lua
Normal file
194
tests/spec/coder_context_spec.lua
Normal file
@@ -0,0 +1,194 @@
|
||||
--- Tests for coder file context injection
|
||||
describe("coder context injection", function()
|
||||
local test_dir
|
||||
local original_filereadable
|
||||
|
||||
before_each(function()
|
||||
test_dir = "/tmp/codetyper_coder_test_" .. os.time()
|
||||
vim.fn.mkdir(test_dir, "p")
|
||||
|
||||
-- Store original function
|
||||
original_filereadable = vim.fn.filereadable
|
||||
end)
|
||||
|
||||
after_each(function()
|
||||
vim.fn.delete(test_dir, "rf")
|
||||
vim.fn.filereadable = original_filereadable
|
||||
end)
|
||||
|
||||
describe("get_coder_companion_path logic", function()
|
||||
-- Test the path generation logic (simulating the function behavior)
|
||||
local function get_coder_companion_path(target_path, file_exists_check)
|
||||
if not target_path or target_path == "" then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Skip if target is already a coder file
|
||||
if target_path:match("%.coder%.") then
|
||||
return nil
|
||||
end
|
||||
|
||||
local dir = vim.fn.fnamemodify(target_path, ":h")
|
||||
local name = vim.fn.fnamemodify(target_path, ":t:r")
|
||||
local ext = vim.fn.fnamemodify(target_path, ":e")
|
||||
|
||||
local coder_path = dir .. "/" .. name .. ".coder." .. ext
|
||||
if file_exists_check(coder_path) then
|
||||
return coder_path
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
it("should generate correct coder path for source file", function()
|
||||
local target = "/path/to/file.ts"
|
||||
local expected = "/path/to/file.coder.ts"
|
||||
|
||||
local path = get_coder_companion_path(target, function() return true end)
|
||||
|
||||
assert.equals(expected, path)
|
||||
end)
|
||||
|
||||
it("should return nil for empty path", function()
|
||||
local path = get_coder_companion_path("", function() return true end)
|
||||
assert.is_nil(path)
|
||||
end)
|
||||
|
||||
it("should return nil for nil path", function()
|
||||
local path = get_coder_companion_path(nil, function() return true end)
|
||||
assert.is_nil(path)
|
||||
end)
|
||||
|
||||
it("should return nil for coder files (avoid recursion)", function()
|
||||
local target = "/path/to/file.coder.ts"
|
||||
local path = get_coder_companion_path(target, function() return true end)
|
||||
assert.is_nil(path)
|
||||
end)
|
||||
|
||||
it("should return nil if coder file doesn't exist", function()
|
||||
local target = "/path/to/file.ts"
|
||||
local path = get_coder_companion_path(target, function() return false end)
|
||||
assert.is_nil(path)
|
||||
end)
|
||||
|
||||
it("should handle files with multiple dots", function()
|
||||
local target = "/path/to/my.component.ts"
|
||||
local expected = "/path/to/my.component.coder.ts"
|
||||
|
||||
local path = get_coder_companion_path(target, function() return true end)
|
||||
|
||||
assert.equals(expected, path)
|
||||
end)
|
||||
|
||||
it("should handle different extensions", function()
|
||||
local test_cases = {
|
||||
{ target = "/path/file.lua", expected = "/path/file.coder.lua" },
|
||||
{ target = "/path/file.py", expected = "/path/file.coder.py" },
|
||||
{ target = "/path/file.js", expected = "/path/file.coder.js" },
|
||||
{ target = "/path/file.go", expected = "/path/file.coder.go" },
|
||||
}
|
||||
|
||||
for _, tc in ipairs(test_cases) do
|
||||
local path = get_coder_companion_path(tc.target, function() return true end)
|
||||
assert.equals(tc.expected, path, "Failed for: " .. tc.target)
|
||||
end
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("coder content filtering", function()
|
||||
-- Test the filtering logic that skips template-only content
|
||||
local function has_meaningful_content(lines)
|
||||
for _, line in ipairs(lines) do
|
||||
local trimmed = line:gsub("^%s*", "")
|
||||
if not trimmed:match("^[%-#/]+%s*Coder companion")
|
||||
and not trimmed:match("^[%-#/]+%s*Use /@ @/")
|
||||
and not trimmed:match("^[%-#/]+%s*Example:")
|
||||
and not trimmed:match("^<!%-%-")
|
||||
and trimmed ~= ""
|
||||
and not trimmed:match("^[%-#/]+%s*$") then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
it("should detect meaningful content", function()
|
||||
local lines = {
|
||||
"-- Coder companion for test.lua",
|
||||
"-- This file handles authentication",
|
||||
"/@",
|
||||
"Add login function",
|
||||
"@/",
|
||||
}
|
||||
assert.is_true(has_meaningful_content(lines))
|
||||
end)
|
||||
|
||||
it("should reject template-only content", function()
|
||||
-- Template lines are filtered by specific patterns
|
||||
-- Only header comments that match the template format are filtered
|
||||
local lines = {
|
||||
"-- Coder companion for test.lua",
|
||||
"-- Use /@ @/ tags to write pseudo-code prompts",
|
||||
"-- Example:",
|
||||
"--",
|
||||
"",
|
||||
}
|
||||
assert.is_false(has_meaningful_content(lines))
|
||||
end)
|
||||
|
||||
it("should detect pseudo-code content", function()
|
||||
local lines = {
|
||||
"-- Authentication module",
|
||||
"",
|
||||
"-- This module should:",
|
||||
"-- 1. Validate user credentials",
|
||||
"-- 2. Generate JWT tokens",
|
||||
"-- 3. Handle session management",
|
||||
}
|
||||
-- "-- Authentication module" doesn't match template patterns
|
||||
assert.is_true(has_meaningful_content(lines))
|
||||
end)
|
||||
|
||||
it("should handle JavaScript style comments", function()
|
||||
local lines = {
|
||||
"// Coder companion for test.ts",
|
||||
"// Business logic for user authentication",
|
||||
"",
|
||||
"// The auth flow should:",
|
||||
"// 1. Check OAuth token",
|
||||
"// 2. Validate permissions",
|
||||
}
|
||||
-- "// Business logic..." doesn't match template patterns
|
||||
assert.is_true(has_meaningful_content(lines))
|
||||
end)
|
||||
|
||||
it("should handle empty lines", function()
|
||||
local lines = {
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
}
|
||||
assert.is_false(has_meaningful_content(lines))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("context format", function()
|
||||
it("should format context with proper header", function()
|
||||
local function format_coder_context(content, ext)
|
||||
return string.format(
|
||||
"\n\n--- Business Context / Pseudo-code ---\n" ..
|
||||
"The following describes the intended behavior and design for this file:\n" ..
|
||||
"```%s\n%s\n```",
|
||||
ext,
|
||||
content
|
||||
)
|
||||
end
|
||||
|
||||
local formatted = format_coder_context("-- Auth logic here", "lua")
|
||||
|
||||
assert.is_true(formatted:find("Business Context") ~= nil)
|
||||
assert.is_true(formatted:find("```lua") ~= nil)
|
||||
assert.is_true(formatted:find("Auth logic here") ~= nil)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
161
tests/spec/coder_ignore_spec.lua
Normal file
161
tests/spec/coder_ignore_spec.lua
Normal file
@@ -0,0 +1,161 @@
|
||||
--- Tests for coder file ignore logic
|
||||
describe("coder file ignore logic", function()
|
||||
-- Directories to ignore
|
||||
local ignored_directories = {
|
||||
".git",
|
||||
".coder",
|
||||
".claude",
|
||||
".vscode",
|
||||
".idea",
|
||||
"node_modules",
|
||||
"vendor",
|
||||
"dist",
|
||||
"build",
|
||||
"target",
|
||||
"__pycache__",
|
||||
".cache",
|
||||
".npm",
|
||||
".yarn",
|
||||
"coverage",
|
||||
".next",
|
||||
".nuxt",
|
||||
".svelte-kit",
|
||||
"out",
|
||||
"bin",
|
||||
"obj",
|
||||
}
|
||||
|
||||
-- Files to ignore
|
||||
local ignored_files = {
|
||||
".gitignore",
|
||||
".gitattributes",
|
||||
"package-lock.json",
|
||||
"yarn.lock",
|
||||
".env",
|
||||
".eslintrc",
|
||||
"tsconfig.json",
|
||||
"README.md",
|
||||
"LICENSE",
|
||||
"Makefile",
|
||||
}
|
||||
|
||||
local function is_in_ignored_directory(filepath)
|
||||
for _, dir in ipairs(ignored_directories) do
|
||||
if filepath:match("/" .. dir .. "/") or filepath:match("/" .. dir .. "$") then
|
||||
return true
|
||||
end
|
||||
if filepath:match("^" .. dir .. "/") then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
local function should_ignore_for_coder(filepath)
|
||||
local filename = vim.fn.fnamemodify(filepath, ":t")
|
||||
|
||||
for _, ignored in ipairs(ignored_files) do
|
||||
if filename == ignored then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
if filename:match("^%.") then
|
||||
return true
|
||||
end
|
||||
|
||||
if is_in_ignored_directory(filepath) then
|
||||
return true
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
describe("ignored directories", function()
|
||||
it("should ignore files in node_modules", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/node_modules/lodash/index.js"))
|
||||
assert.is_true(should_ignore_for_coder("/project/node_modules/react/index.js"))
|
||||
end)
|
||||
|
||||
it("should ignore files in .git", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/.git/config"))
|
||||
assert.is_true(should_ignore_for_coder("/project/.git/hooks/pre-commit"))
|
||||
end)
|
||||
|
||||
it("should ignore files in .coder", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/.coder/brain/meta.json"))
|
||||
end)
|
||||
|
||||
it("should ignore files in vendor", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/vendor/autoload.php"))
|
||||
end)
|
||||
|
||||
it("should ignore files in dist/build", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/dist/bundle.js"))
|
||||
assert.is_true(should_ignore_for_coder("/project/build/output.js"))
|
||||
end)
|
||||
|
||||
it("should ignore files in __pycache__", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/__pycache__/module.cpython-39.pyc"))
|
||||
end)
|
||||
|
||||
it("should NOT ignore regular source files", function()
|
||||
assert.is_false(should_ignore_for_coder("/project/src/index.ts"))
|
||||
assert.is_false(should_ignore_for_coder("/project/lib/utils.lua"))
|
||||
assert.is_false(should_ignore_for_coder("/project/app/main.py"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("ignored files", function()
|
||||
it("should ignore .gitignore", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/.gitignore"))
|
||||
end)
|
||||
|
||||
it("should ignore lock files", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/package-lock.json"))
|
||||
assert.is_true(should_ignore_for_coder("/project/yarn.lock"))
|
||||
end)
|
||||
|
||||
it("should ignore config files", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/tsconfig.json"))
|
||||
assert.is_true(should_ignore_for_coder("/project/.eslintrc"))
|
||||
end)
|
||||
|
||||
it("should ignore .env files", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/.env"))
|
||||
end)
|
||||
|
||||
it("should ignore README and LICENSE", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/README.md"))
|
||||
assert.is_true(should_ignore_for_coder("/project/LICENSE"))
|
||||
end)
|
||||
|
||||
it("should ignore hidden/dot files", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/.hidden"))
|
||||
assert.is_true(should_ignore_for_coder("/project/.secret"))
|
||||
end)
|
||||
|
||||
it("should NOT ignore regular source files", function()
|
||||
assert.is_false(should_ignore_for_coder("/project/src/app.ts"))
|
||||
assert.is_false(should_ignore_for_coder("/project/components/Button.tsx"))
|
||||
assert.is_false(should_ignore_for_coder("/project/utils/helpers.js"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("edge cases", function()
|
||||
it("should handle nested node_modules", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/packages/core/node_modules/dep/index.js"))
|
||||
end)
|
||||
|
||||
it("should handle files named like directories but not in them", function()
|
||||
-- A file named "node_modules.md" in root should be ignored (starts with .)
|
||||
-- But a file in a folder that contains "node" should NOT be ignored
|
||||
assert.is_false(should_ignore_for_coder("/project/src/node_utils.ts"))
|
||||
end)
|
||||
|
||||
it("should handle relative paths", function()
|
||||
assert.is_true(should_ignore_for_coder("node_modules/lodash/index.js"))
|
||||
assert.is_false(should_ignore_for_coder("src/index.ts"))
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
345
tests/spec/indexer_spec.lua
Normal file
345
tests/spec/indexer_spec.lua
Normal file
@@ -0,0 +1,345 @@
|
||||
---@diagnostic disable: undefined-global
|
||||
-- Tests for lua/codetyper/indexer/init.lua
|
||||
|
||||
describe("indexer", function()
|
||||
local indexer
|
||||
local utils
|
||||
|
||||
-- Mock cwd for testing
|
||||
local test_cwd = "/tmp/codetyper_test_indexer"
|
||||
|
||||
before_each(function()
|
||||
-- Reset modules
|
||||
package.loaded["codetyper.indexer"] = nil
|
||||
package.loaded["codetyper.indexer.scanner"] = nil
|
||||
package.loaded["codetyper.indexer.analyzer"] = nil
|
||||
package.loaded["codetyper.indexer.memory"] = nil
|
||||
package.loaded["codetyper.utils"] = nil
|
||||
|
||||
indexer = require("codetyper.indexer")
|
||||
utils = require("codetyper.utils")
|
||||
|
||||
-- Create test directory structure
|
||||
vim.fn.mkdir(test_cwd, "p")
|
||||
vim.fn.mkdir(test_cwd .. "/.coder", "p")
|
||||
vim.fn.mkdir(test_cwd .. "/src", "p")
|
||||
|
||||
-- Mock getcwd to return test directory
|
||||
vim.fn.getcwd = function()
|
||||
return test_cwd
|
||||
end
|
||||
|
||||
-- Mock get_project_root
|
||||
package.loaded["codetyper.utils"].get_project_root = function()
|
||||
return test_cwd
|
||||
end
|
||||
end)
|
||||
|
||||
after_each(function()
|
||||
-- Clean up test directory
|
||||
vim.fn.delete(test_cwd, "rf")
|
||||
end)
|
||||
|
||||
describe("setup", function()
|
||||
it("should accept configuration options", function()
|
||||
indexer.setup({
|
||||
enabled = true,
|
||||
auto_index = false,
|
||||
})
|
||||
|
||||
local config = indexer.get_config()
|
||||
assert.is_false(config.auto_index)
|
||||
end)
|
||||
|
||||
it("should use default configuration when no options provided", function()
|
||||
indexer.setup()
|
||||
|
||||
local config = indexer.get_config()
|
||||
assert.is_true(config.enabled)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("load_index", function()
|
||||
it("should return nil when no index exists", function()
|
||||
local index = indexer.load_index()
|
||||
|
||||
assert.is_nil(index)
|
||||
end)
|
||||
|
||||
it("should load existing index from file", function()
|
||||
-- Create a mock index file
|
||||
local mock_index = {
|
||||
version = 1,
|
||||
project_root = test_cwd,
|
||||
project_name = "test",
|
||||
project_type = "node",
|
||||
dependencies = {},
|
||||
dev_dependencies = {},
|
||||
files = {},
|
||||
symbols = {},
|
||||
last_indexed = os.time(),
|
||||
stats = { files = 0, functions = 0, classes = 0, exports = 0 },
|
||||
}
|
||||
utils.write_file(test_cwd .. "/.coder/index.json", vim.json.encode(mock_index))
|
||||
|
||||
local index = indexer.load_index()
|
||||
|
||||
assert.is_table(index)
|
||||
assert.equals("test", index.project_name)
|
||||
assert.equals("node", index.project_type)
|
||||
end)
|
||||
|
||||
it("should cache loaded index", function()
|
||||
local mock_index = {
|
||||
version = 1,
|
||||
project_root = test_cwd,
|
||||
project_name = "cached_test",
|
||||
project_type = "lua",
|
||||
dependencies = {},
|
||||
dev_dependencies = {},
|
||||
files = {},
|
||||
symbols = {},
|
||||
last_indexed = os.time(),
|
||||
stats = { files = 0, functions = 0, classes = 0, exports = 0 },
|
||||
}
|
||||
utils.write_file(test_cwd .. "/.coder/index.json", vim.json.encode(mock_index))
|
||||
|
||||
local index1 = indexer.load_index()
|
||||
local index2 = indexer.load_index()
|
||||
|
||||
assert.equals(index1.project_name, index2.project_name)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("save_index", function()
|
||||
it("should save index to file", function()
|
||||
local index = {
|
||||
version = 1,
|
||||
project_root = test_cwd,
|
||||
project_name = "save_test",
|
||||
project_type = "node",
|
||||
dependencies = { express = "^4.18.0" },
|
||||
dev_dependencies = {},
|
||||
files = {},
|
||||
symbols = {},
|
||||
last_indexed = os.time(),
|
||||
stats = { files = 0, functions = 0, classes = 0, exports = 0 },
|
||||
}
|
||||
|
||||
local result = indexer.save_index(index)
|
||||
|
||||
assert.is_true(result)
|
||||
|
||||
-- Verify file was created
|
||||
local content = utils.read_file(test_cwd .. "/.coder/index.json")
|
||||
assert.is_truthy(content)
|
||||
|
||||
local decoded = vim.json.decode(content)
|
||||
assert.equals("save_test", decoded.project_name)
|
||||
end)
|
||||
|
||||
it("should create .coder directory if it does not exist", function()
|
||||
vim.fn.delete(test_cwd .. "/.coder", "rf")
|
||||
|
||||
local index = {
|
||||
version = 1,
|
||||
project_root = test_cwd,
|
||||
project_name = "test",
|
||||
project_type = "unknown",
|
||||
dependencies = {},
|
||||
dev_dependencies = {},
|
||||
files = {},
|
||||
symbols = {},
|
||||
last_indexed = os.time(),
|
||||
stats = { files = 0, functions = 0, classes = 0, exports = 0 },
|
||||
}
|
||||
|
||||
indexer.save_index(index)
|
||||
|
||||
assert.equals(1, vim.fn.isdirectory(test_cwd .. "/.coder"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("index_project", function()
|
||||
it("should create an index for the project", function()
|
||||
-- Create some test files
|
||||
utils.write_file(test_cwd .. "/package.json", '{"name":"test","dependencies":{}}')
|
||||
utils.write_file(test_cwd .. "/src/main.lua", [[
|
||||
local M = {}
|
||||
function M.hello()
|
||||
return "world"
|
||||
end
|
||||
return M
|
||||
]])
|
||||
|
||||
indexer.setup({ index_extensions = { "lua" } })
|
||||
local index = indexer.index_project()
|
||||
|
||||
assert.is_table(index)
|
||||
assert.equals("node", index.project_type)
|
||||
assert.is_truthy(index.stats.files >= 0)
|
||||
end)
|
||||
|
||||
it("should detect project dependencies", function()
|
||||
utils.write_file(test_cwd .. "/package.json", [[{
|
||||
"name": "test",
|
||||
"dependencies": {
|
||||
"express": "^4.18.0",
|
||||
"lodash": "^4.17.0"
|
||||
}
|
||||
}]])
|
||||
|
||||
indexer.setup()
|
||||
local index = indexer.index_project()
|
||||
|
||||
assert.is_table(index.dependencies)
|
||||
assert.equals("^4.18.0", index.dependencies.express)
|
||||
end)
|
||||
|
||||
it("should call callback when complete", function()
|
||||
local callback_called = false
|
||||
local callback_index = nil
|
||||
|
||||
indexer.setup()
|
||||
indexer.index_project(function(index)
|
||||
callback_called = true
|
||||
callback_index = index
|
||||
end)
|
||||
|
||||
assert.is_true(callback_called)
|
||||
assert.is_table(callback_index)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("index_file", function()
|
||||
it("should index a single file", function()
|
||||
utils.write_file(test_cwd .. "/src/test.lua", [[
|
||||
local M = {}
|
||||
function M.add(a, b)
|
||||
return a + b
|
||||
end
|
||||
function M.subtract(a, b)
|
||||
return a - b
|
||||
end
|
||||
return M
|
||||
]])
|
||||
|
||||
indexer.setup({ index_extensions = { "lua" } })
|
||||
-- First create an initial index
|
||||
indexer.index_project()
|
||||
|
||||
local file_index = indexer.index_file(test_cwd .. "/src/test.lua")
|
||||
|
||||
assert.is_table(file_index)
|
||||
assert.equals("src/test.lua", file_index.path)
|
||||
end)
|
||||
|
||||
it("should update symbols in the main index", function()
|
||||
utils.write_file(test_cwd .. "/src/utils.lua", [[
|
||||
local M = {}
|
||||
function M.format_string(str)
|
||||
return string.upper(str)
|
||||
end
|
||||
return M
|
||||
]])
|
||||
|
||||
indexer.setup({ index_extensions = { "lua" } })
|
||||
indexer.index_project()
|
||||
indexer.index_file(test_cwd .. "/src/utils.lua")
|
||||
|
||||
local index = indexer.load_index()
|
||||
assert.is_table(index.files)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("get_status", function()
|
||||
it("should return indexed: false when no index exists", function()
|
||||
local status = indexer.get_status()
|
||||
|
||||
assert.is_false(status.indexed)
|
||||
assert.is_nil(status.stats)
|
||||
end)
|
||||
|
||||
it("should return status when index exists", function()
|
||||
indexer.setup()
|
||||
indexer.index_project()
|
||||
|
||||
local status = indexer.get_status()
|
||||
|
||||
assert.is_true(status.indexed)
|
||||
assert.is_table(status.stats)
|
||||
assert.is_truthy(status.last_indexed)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("get_context_for", function()
|
||||
it("should return context with project type", function()
|
||||
utils.write_file(test_cwd .. "/package.json", '{"name":"test"}')
|
||||
indexer.setup()
|
||||
indexer.index_project()
|
||||
|
||||
local context = indexer.get_context_for({
|
||||
file = test_cwd .. "/src/main.lua",
|
||||
prompt = "add a function",
|
||||
})
|
||||
|
||||
assert.is_table(context)
|
||||
assert.equals("node", context.project_type)
|
||||
end)
|
||||
|
||||
it("should find relevant symbols", function()
|
||||
utils.write_file(test_cwd .. "/src/utils.lua", [[
|
||||
local M = {}
|
||||
function M.calculate_total(items)
|
||||
return 0
|
||||
end
|
||||
return M
|
||||
]])
|
||||
indexer.setup({ index_extensions = { "lua" } })
|
||||
indexer.index_project()
|
||||
|
||||
local context = indexer.get_context_for({
|
||||
file = test_cwd .. "/src/main.lua",
|
||||
prompt = "use calculate_total function",
|
||||
})
|
||||
|
||||
assert.is_table(context)
|
||||
-- Should find the calculate symbol
|
||||
if context.relevant_symbols and context.relevant_symbols.calculate then
|
||||
assert.is_table(context.relevant_symbols.calculate)
|
||||
end
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("clear", function()
|
||||
it("should remove the index file", function()
|
||||
indexer.setup()
|
||||
indexer.index_project()
|
||||
|
||||
-- Verify index exists
|
||||
assert.is_true(indexer.get_status().indexed)
|
||||
|
||||
indexer.clear()
|
||||
|
||||
-- Verify index is gone
|
||||
local status = indexer.get_status()
|
||||
assert.is_false(status.indexed)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("schedule_index_file", function()
|
||||
it("should not index when disabled", function()
|
||||
indexer.setup({ enabled = false })
|
||||
|
||||
-- This should not throw or cause issues
|
||||
indexer.schedule_index_file(test_cwd .. "/src/test.lua")
|
||||
end)
|
||||
|
||||
it("should not index when auto_index is false", function()
|
||||
indexer.setup({ enabled = true, auto_index = false })
|
||||
|
||||
-- This should not throw or cause issues
|
||||
indexer.schedule_index_file(test_cwd .. "/src/test.lua")
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
371
tests/spec/inject_spec.lua
Normal file
371
tests/spec/inject_spec.lua
Normal file
@@ -0,0 +1,371 @@
|
||||
--- Tests for smart code injection with import handling
|
||||
|
||||
describe("codetyper.agent.inject", function()
|
||||
local inject
|
||||
|
||||
before_each(function()
|
||||
inject = require("codetyper.agent.inject")
|
||||
end)
|
||||
|
||||
describe("parse_code", function()
|
||||
describe("JavaScript/TypeScript", function()
|
||||
it("should detect ES6 named imports", function()
|
||||
local code = [[import { useState, useEffect } from 'react';
|
||||
import { Button } from './components';
|
||||
|
||||
function App() {
|
||||
return <div>Hello</div>;
|
||||
}]]
|
||||
local result = inject.parse_code(code, "typescript")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("useState"))
|
||||
assert.truthy(result.imports[2]:match("Button"))
|
||||
assert.truthy(#result.body > 0)
|
||||
end)
|
||||
|
||||
it("should detect ES6 default imports", function()
|
||||
local code = [[import React from 'react';
|
||||
import axios from 'axios';
|
||||
|
||||
const api = axios.create();]]
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("React"))
|
||||
assert.truthy(result.imports[2]:match("axios"))
|
||||
end)
|
||||
|
||||
it("should detect require imports", function()
|
||||
local code = [[const fs = require('fs');
|
||||
const path = require('path');
|
||||
|
||||
module.exports = { fs, path };]]
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("fs"))
|
||||
assert.truthy(result.imports[2]:match("path"))
|
||||
end)
|
||||
|
||||
it("should detect multi-line imports", function()
|
||||
local code = [[import {
|
||||
useState,
|
||||
useEffect,
|
||||
useCallback
|
||||
} from 'react';
|
||||
|
||||
function Component() {}]]
|
||||
local result = inject.parse_code(code, "typescript")
|
||||
|
||||
assert.equals(1, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("useState"))
|
||||
assert.truthy(result.imports[1]:match("useCallback"))
|
||||
end)
|
||||
|
||||
it("should detect namespace imports", function()
|
||||
local code = [[import * as React from 'react';
|
||||
|
||||
export default React;]]
|
||||
local result = inject.parse_code(code, "tsx")
|
||||
|
||||
assert.equals(1, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("%* as React"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("Python", function()
|
||||
it("should detect simple imports", function()
|
||||
local code = [[import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
def main():
|
||||
pass]]
|
||||
local result = inject.parse_code(code, "python")
|
||||
|
||||
assert.equals(3, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("import os"))
|
||||
assert.truthy(result.imports[2]:match("import sys"))
|
||||
assert.truthy(result.imports[3]:match("import json"))
|
||||
end)
|
||||
|
||||
it("should detect from imports", function()
|
||||
local code = [[from typing import List, Dict
|
||||
from pathlib import Path
|
||||
|
||||
def process(items: List[str]) -> None:
|
||||
pass]]
|
||||
local result = inject.parse_code(code, "py")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("from typing"))
|
||||
assert.truthy(result.imports[2]:match("from pathlib"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("Lua", function()
|
||||
it("should detect require statements", function()
|
||||
local code = [[local M = {}
|
||||
local utils = require("codetyper.utils")
|
||||
local config = require('codetyper.config')
|
||||
|
||||
function M.setup()
|
||||
end
|
||||
|
||||
return M]]
|
||||
local result = inject.parse_code(code, "lua")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("utils"))
|
||||
assert.truthy(result.imports[2]:match("config"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("Go", function()
|
||||
it("should detect single imports", function()
|
||||
local code = [[package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
func main() {
|
||||
fmt.Println("Hello")
|
||||
}]]
|
||||
local result = inject.parse_code(code, "go")
|
||||
|
||||
assert.equals(1, #result.imports)
|
||||
assert.truthy(result.imports[1]:match('import "fmt"'))
|
||||
end)
|
||||
|
||||
it("should detect grouped imports", function()
|
||||
local code = [[package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func main() {}]]
|
||||
local result = inject.parse_code(code, "go")
|
||||
|
||||
assert.equals(1, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("fmt"))
|
||||
assert.truthy(result.imports[1]:match("os"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("Rust", function()
|
||||
it("should detect use statements", function()
|
||||
local code = [[use std::io;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn main() {
|
||||
let map = HashMap::new();
|
||||
}]]
|
||||
local result = inject.parse_code(code, "rs")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("std::io"))
|
||||
assert.truthy(result.imports[2]:match("HashMap"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("C/C++", function()
|
||||
it("should detect include statements", function()
|
||||
local code = [[#include <stdio.h>
|
||||
#include "myheader.h"
|
||||
|
||||
int main() {
|
||||
return 0;
|
||||
}]]
|
||||
local result = inject.parse_code(code, "c")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("stdio"))
|
||||
assert.truthy(result.imports[2]:match("myheader"))
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("merge_imports", function()
|
||||
it("should merge without duplicates", function()
|
||||
local existing = {
|
||||
"import { useState } from 'react';",
|
||||
"import { Button } from './components';",
|
||||
}
|
||||
local new_imports = {
|
||||
"import { useEffect } from 'react';",
|
||||
"import { useState } from 'react';", -- duplicate
|
||||
"import { Card } from './components';",
|
||||
}
|
||||
|
||||
local merged = inject.merge_imports(existing, new_imports)
|
||||
|
||||
assert.equals(4, #merged) -- Should not have duplicate useState
|
||||
end)
|
||||
|
||||
it("should handle empty existing imports", function()
|
||||
local existing = {}
|
||||
local new_imports = {
|
||||
"import os",
|
||||
"import sys",
|
||||
}
|
||||
|
||||
local merged = inject.merge_imports(existing, new_imports)
|
||||
|
||||
assert.equals(2, #merged)
|
||||
end)
|
||||
|
||||
it("should handle empty new imports", function()
|
||||
local existing = {
|
||||
"import os",
|
||||
"import sys",
|
||||
}
|
||||
local new_imports = {}
|
||||
|
||||
local merged = inject.merge_imports(existing, new_imports)
|
||||
|
||||
assert.equals(2, #merged)
|
||||
end)
|
||||
|
||||
it("should handle whitespace variations in duplicates", function()
|
||||
local existing = {
|
||||
"import { useState } from 'react';",
|
||||
}
|
||||
local new_imports = {
|
||||
"import {useState} from 'react';", -- Same but different spacing
|
||||
}
|
||||
|
||||
local merged = inject.merge_imports(existing, new_imports)
|
||||
|
||||
assert.equals(1, #merged) -- Should detect as duplicate
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("sort_imports", function()
|
||||
it("should group imports by type for JavaScript", function()
|
||||
local imports = {
|
||||
"import React from 'react';",
|
||||
"import { Button } from './components';",
|
||||
"import axios from 'axios';",
|
||||
"import path from 'path';",
|
||||
}
|
||||
|
||||
local sorted = inject.sort_imports(imports, "javascript")
|
||||
|
||||
-- Check ordering: builtin -> third-party -> local
|
||||
local found_builtin = false
|
||||
local found_local = false
|
||||
local builtin_pos = 0
|
||||
local local_pos = 0
|
||||
|
||||
for i, imp in ipairs(sorted) do
|
||||
if imp:match("path") then
|
||||
found_builtin = true
|
||||
builtin_pos = i
|
||||
end
|
||||
if imp:match("%.%/") then
|
||||
found_local = true
|
||||
local_pos = i
|
||||
end
|
||||
end
|
||||
|
||||
-- Local imports should come after third-party
|
||||
if found_local and found_builtin then
|
||||
assert.truthy(local_pos > builtin_pos)
|
||||
end
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("has_imports", function()
|
||||
it("should return true when code has imports", function()
|
||||
local code = [[import { useState } from 'react';
|
||||
function App() {}]]
|
||||
|
||||
assert.is_true(inject.has_imports(code, "typescript"))
|
||||
end)
|
||||
|
||||
it("should return false when code has no imports", function()
|
||||
local code = [[function App() {
|
||||
return <div>Hello</div>;
|
||||
}]]
|
||||
|
||||
assert.is_false(inject.has_imports(code, "typescript"))
|
||||
end)
|
||||
|
||||
it("should detect Python imports", function()
|
||||
local code = [[from typing import List
|
||||
|
||||
def process(items: List[str]):
|
||||
pass]]
|
||||
|
||||
assert.is_true(inject.has_imports(code, "python"))
|
||||
end)
|
||||
|
||||
it("should detect Lua requires", function()
|
||||
local code = [[local utils = require("utils")
|
||||
|
||||
local M = {}
|
||||
return M]]
|
||||
|
||||
assert.is_true(inject.has_imports(code, "lua"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("edge cases", function()
|
||||
it("should handle empty code", function()
|
||||
local result = inject.parse_code("", "javascript")
|
||||
|
||||
assert.equals(0, #result.imports)
|
||||
assert.equals(1, #result.body) -- Empty string becomes one empty line
|
||||
end)
|
||||
|
||||
it("should handle code with only imports", function()
|
||||
local code = [[import React from 'react';
|
||||
import { useState } from 'react';]]
|
||||
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.equals(0, #result.body)
|
||||
end)
|
||||
|
||||
it("should handle code with only body", function()
|
||||
local code = [[function hello() {
|
||||
console.log("Hello");
|
||||
}]]
|
||||
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
assert.equals(0, #result.imports)
|
||||
assert.truthy(#result.body > 0)
|
||||
end)
|
||||
|
||||
it("should handle imports in string literals (not detect as imports)", function()
|
||||
local code = [[const example = "import { fake } from 'not-real';";
|
||||
const config = { import: true };
|
||||
|
||||
function test() {}]]
|
||||
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
-- The first line looks like an import but is in a string
|
||||
-- This is a known limitation - we accept some false positives
|
||||
-- The important thing is we don't break the code
|
||||
assert.truthy(#result.body >= 0)
|
||||
end)
|
||||
|
||||
it("should handle mixed import styles in same file", function()
|
||||
local code = [[import React from 'react';
|
||||
const axios = require('axios');
|
||||
import { useState } from 'react';
|
||||
|
||||
function App() {}]]
|
||||
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
assert.equals(3, #result.imports)
|
||||
end)
|
||||
end)
|
||||
end)
-- ============================================================
-- File: tests/spec/llm_selector_spec.lua (174 lines, new file)
-- ============================================================
--- Tests for smart LLM selection with memory-based confidence

describe("codetyper.llm.selector", function()
  local selector

  before_each(function()
    selector = require("codetyper.llm.selector")
    -- Reset stats for clean tests
    selector.reset_accuracy_stats()
  end)

  describe("select_provider", function()
    it("should return copilot when no brain memories exist", function()
      local result = selector.select_provider("write a function", {
        file_path = "/test/file.lua",
      })

      assert.equals("copilot", result.provider)
      assert.equals(0, result.memory_count)
      assert.truthy(result.reason:match("Insufficient context"))
    end)

    it("should return a valid selection result structure", function()
      local result = selector.select_provider("test prompt", {})

      assert.is_string(result.provider)
      assert.is_number(result.confidence)
      assert.is_number(result.memory_count)
      assert.is_string(result.reason)
    end)

    it("should have confidence between 0 and 1", function()
      local result = selector.select_provider("test", {})

      assert.truthy(result.confidence >= 0)
      assert.truthy(result.confidence <= 1)
    end)
  end)

  describe("should_ponder", function()
    it("should return true for medium confidence", function()
      assert.is_true(selector.should_ponder(0.5))
      assert.is_true(selector.should_ponder(0.6))
    end)

    it("should return false for low confidence", function()
      assert.is_false(selector.should_ponder(0.2))
      assert.is_false(selector.should_ponder(0.3))
    end)

    -- High confidence pondering is probabilistic, so we test the range
    it("should sometimes ponder for high confidence (sampling)", function()
      -- Run multiple times to test probabilistic behavior
      local pondered_count = 0
      for _ = 1, 100 do
        if selector.should_ponder(0.9) then
          pondered_count = pondered_count + 1
        end
      end
      -- Should ponder roughly 20% of the time (PONDER_SAMPLE_RATE = 0.2)
      -- Allow range of 5-40% due to randomness
      assert.truthy(pondered_count >= 5, "Should ponder at least sometimes")
      assert.truthy(pondered_count <= 40, "Should not ponder too often")
    end)
  end)

  describe("get_accuracy_stats", function()
    it("should return initial empty stats", function()
      local stats = selector.get_accuracy_stats()

      assert.equals(0, stats.ollama.total)
      assert.equals(0, stats.ollama.correct)
      assert.equals(0, stats.ollama.accuracy)
      assert.equals(0, stats.copilot.total)
      assert.equals(0, stats.copilot.correct)
      assert.equals(0, stats.copilot.accuracy)
    end)
  end)

  describe("report_feedback", function()
    it("should track positive feedback", function()
      selector.report_feedback("ollama", true)
      selector.report_feedback("ollama", true)
      selector.report_feedback("ollama", false)

      local stats = selector.get_accuracy_stats()
      assert.equals(3, stats.ollama.total)
      assert.equals(2, stats.ollama.correct)
    end)

    it("should track copilot feedback separately", function()
      selector.report_feedback("ollama", true)
      selector.report_feedback("copilot", true)
      selector.report_feedback("copilot", false)

      local stats = selector.get_accuracy_stats()
      assert.equals(1, stats.ollama.total)
      assert.equals(2, stats.copilot.total)
    end)

    it("should calculate accuracy correctly", function()
      selector.report_feedback("ollama", true)
      selector.report_feedback("ollama", true)
      selector.report_feedback("ollama", true)
      selector.report_feedback("ollama", false)

      local stats = selector.get_accuracy_stats()
      assert.equals(0.75, stats.ollama.accuracy)
    end)
  end)

  describe("reset_accuracy_stats", function()
    it("should clear all stats", function()
      selector.report_feedback("ollama", true)
      selector.report_feedback("copilot", true)

      selector.reset_accuracy_stats()

      local stats = selector.get_accuracy_stats()
      assert.equals(0, stats.ollama.total)
      assert.equals(0, stats.copilot.total)
    end)
  end)
end)
describe("agreement calculation", function()
|
||||
-- Test the internal agreement calculation through pondering behavior
|
||||
-- Since calculate_agreement is local, we test its effects indirectly
|
||||
|
||||
it("should detect high agreement for similar responses", function()
|
||||
-- This is tested through the pondering system
|
||||
-- When responses are similar, agreement should be high
|
||||
local selector = require("codetyper.llm.selector")
|
||||
|
||||
-- Verify that should_ponder returns predictable results
|
||||
-- for medium confidence (where pondering always happens)
|
||||
assert.is_true(selector.should_ponder(0.5))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("provider selection with accuracy history", function()
|
||||
local selector
|
||||
|
||||
before_each(function()
|
||||
selector = require("codetyper.llm.selector")
|
||||
selector.reset_accuracy_stats()
|
||||
end)
|
||||
|
||||
it("should factor in historical accuracy for selection", function()
|
||||
-- Simulate high Ollama accuracy
|
||||
for _ = 1, 10 do
|
||||
selector.report_feedback("ollama", true)
|
||||
end
|
||||
|
||||
-- Even with no brain context, historical accuracy should influence confidence
|
||||
local result = selector.select_provider("test", {})
|
||||
|
||||
-- Confidence should be higher due to historical accuracy
|
||||
-- but provider might still be copilot if no memories
|
||||
assert.is_number(result.confidence)
|
||||
end)
|
||||
|
||||
it("should have lower confidence for low historical accuracy", function()
|
||||
-- Simulate low Ollama accuracy
|
||||
for _ = 1, 10 do
|
||||
selector.report_feedback("ollama", false)
|
||||
end
|
||||
|
||||
local result = selector.select_provider("test", {})
|
||||
|
||||
-- With bad history and no memories, should definitely use copilot
|
||||
assert.equals("copilot", result.provider)
|
||||
end)
|
||||
end)
|
||||
-- ============================================================
-- File: tests/spec/memory_spec.lua (341 lines, new file)
-- ============================================================
---@diagnostic disable: undefined-global
-- Tests for lua/codetyper/indexer/memory.lua

describe("indexer.memory", function()
  local memory
  local utils

  -- Mock cwd for testing
  local test_cwd = "/tmp/codetyper_test_memory"

  before_each(function()
    -- Reset modules
    package.loaded["codetyper.indexer.memory"] = nil
    package.loaded["codetyper.utils"] = nil

    memory = require("codetyper.indexer.memory")
    utils = require("codetyper.utils")

    -- Create test directory structure
    vim.fn.mkdir(test_cwd, "p")
    vim.fn.mkdir(test_cwd .. "/.coder", "p")
    vim.fn.mkdir(test_cwd .. "/.coder/memories", "p")
    vim.fn.mkdir(test_cwd .. "/.coder/memories/files", "p")
    vim.fn.mkdir(test_cwd .. "/.coder/sessions", "p")

    -- Mock getcwd to return test directory
    vim.fn.getcwd = function()
      return test_cwd
    end

    -- Mock get_project_root
    package.loaded["codetyper.utils"].get_project_root = function()
      return test_cwd
    end
  end)

  after_each(function()
    -- Clean up test directory
    vim.fn.delete(test_cwd, "rf")
  end)

  describe("store_memory", function()
    it("should store a pattern memory", function()
      local mem = {
        type = "pattern",
        content = "Use snake_case for function names",
        weight = 0.8,
      }

      local result = memory.store_memory(mem)

      assert.is_true(result)
    end)

    it("should store a convention memory", function()
      local mem = {
        type = "convention",
        content = "Project uses TypeScript",
        weight = 0.9,
      }

      local result = memory.store_memory(mem)

      assert.is_true(result)
    end)

    it("should assign an ID to the memory", function()
      local mem = {
        type = "pattern",
        content = "Test memory",
      }

      memory.store_memory(mem)

      assert.is_truthy(mem.id)
      assert.is_true(mem.id:match("^mem_") ~= nil)
    end)

    it("should set timestamps", function()
      local mem = {
        type = "pattern",
        content = "Test memory",
      }

      memory.store_memory(mem)

      assert.is_truthy(mem.created_at)
      assert.is_truthy(mem.updated_at)
    end)
  end)

  describe("load_patterns", function()
    it("should return empty table when no patterns exist", function()
      local patterns = memory.load_patterns()

      assert.is_table(patterns)
    end)

    it("should load stored patterns", function()
      -- Store a pattern first
      memory.store_memory({
        type = "pattern",
        content = "Test pattern",
        weight = 0.5,
      })

      -- Force reload
      package.loaded["codetyper.indexer.memory"] = nil
      memory = require("codetyper.indexer.memory")

      local patterns = memory.load_patterns()

      assert.is_table(patterns)
      local count = 0
      for _ in pairs(patterns) do
        count = count + 1
      end
      assert.is_true(count >= 1)
    end)
  end)

  describe("load_conventions", function()
    it("should return empty table when no conventions exist", function()
      local conventions = memory.load_conventions()

      assert.is_table(conventions)
    end)
  end)

  describe("store_file_memory", function()
    it("should store file-specific memory", function()
      local file_index = {
        functions = {
          { name = "test_func", line = 10, end_line = 20 },
        },
        classes = {},
        exports = {},
        imports = {},
      }

      local result = memory.store_file_memory("src/main.lua", file_index)

      assert.is_true(result)
    end)
  end)

  describe("load_file_memory", function()
    it("should return nil when file memory does not exist", function()
      local result = memory.load_file_memory("nonexistent.lua")

      assert.is_nil(result)
    end)

    it("should load stored file memory", function()
      local file_index = {
        functions = {
          { name = "my_function", line = 5, end_line = 15 },
        },
        classes = {},
        exports = {},
        imports = {},
      }

      memory.store_file_memory("src/test.lua", file_index)
      local loaded = memory.load_file_memory("src/test.lua")

      assert.is_table(loaded)
      assert.equals("src/test.lua", loaded.path)
      assert.equals(1, #loaded.functions)
      assert.equals("my_function", loaded.functions[1].name)
    end)
  end)

  describe("get_relevant", function()
    it("should return empty table when no memories exist", function()
      local results = memory.get_relevant("test query", 10)

      assert.is_table(results)
      assert.equals(0, #results)
    end)

    it("should find relevant memories by keyword", function()
      memory.store_memory({
        type = "pattern",
        content = "Use TypeScript for type safety",
        weight = 0.8,
      })
      memory.store_memory({
        type = "pattern",
        content = "Use Python for data processing",
        weight = 0.7,
      })

      local results = memory.get_relevant("TypeScript", 10)

      assert.is_true(#results >= 1)
      -- First result should contain TypeScript
      local found = false
      for _, r in ipairs(results) do
        if r.content:find("TypeScript") then
          found = true
          break
        end
      end
      assert.is_true(found)
    end)

    it("should limit results", function()
      -- Store multiple memories
      for i = 1, 20 do
        memory.store_memory({
          type = "pattern",
          content = "Pattern number " .. i .. " about testing",
          weight = 0.5,
        })
      end

      local results = memory.get_relevant("testing", 5)

      assert.is_true(#results <= 5)
    end)
  end)

  describe("update_usage", function()
    it("should increment used_count", function()
      local mem = {
        type = "pattern",
        content = "Test pattern for usage tracking",
        weight = 0.5,
      }
      memory.store_memory(mem)

      memory.update_usage(mem.id)

      -- Reload and check
      package.loaded["codetyper.indexer.memory"] = nil
      memory = require("codetyper.indexer.memory")

      local patterns = memory.load_patterns()
      if patterns[mem.id] then
        assert.equals(1, patterns[mem.id].used_count)
      end
    end)
  end)

  describe("get_all", function()
    it("should return all memory types", function()
      memory.store_memory({ type = "pattern", content = "A pattern" })
      memory.store_memory({ type = "convention", content = "A convention" })

      local all = memory.get_all()

      assert.is_table(all.patterns)
      assert.is_table(all.conventions)
      assert.is_table(all.symbols)
    end)
  end)

  describe("clear", function()
    it("should clear all memories when no pattern provided", function()
      memory.store_memory({ type = "pattern", content = "Pattern 1" })
      memory.store_memory({ type = "convention", content = "Convention 1" })

      memory.clear()

      local all = memory.get_all()
      assert.equals(0, vim.tbl_count(all.patterns))
      assert.equals(0, vim.tbl_count(all.conventions))
    end)

    it("should clear only matching memories when pattern provided", function()
      local mem1 = { type = "pattern", content = "Pattern 1" }
      local mem2 = { type = "pattern", content = "Pattern 2" }
      memory.store_memory(mem1)
      memory.store_memory(mem2)

      -- Clear memories matching the first ID
      memory.clear(mem1.id)

      local patterns = memory.load_patterns()
      assert.is_nil(patterns[mem1.id])
    end)
  end)

  describe("prune", function()
    it("should remove low-weight unused memories", function()
      -- Store some low-weight memories
      memory.store_memory({
        type = "pattern",
        content = "Low weight pattern",
        weight = 0.05,
        used_count = 0,
      })
      memory.store_memory({
        type = "pattern",
        content = "High weight pattern",
        weight = 0.9,
        used_count = 0,
      })

      local pruned = memory.prune(0.1)

      -- Should have pruned at least one
      assert.is_true(pruned >= 0)
    end)

    it("should not remove frequently used memories", function()
      local mem = {
        type = "pattern",
        content = "Frequently used but low weight",
        weight = 0.05,
        used_count = 10,
      }
      memory.store_memory(mem)

      memory.prune(0.1)

      -- Memory should still exist because used_count > 0
      local patterns = memory.load_patterns()
      -- Note: prune only removes if used_count == 0 AND weight < threshold
      if patterns[mem.id] then
        assert.is_truthy(patterns[mem.id])
      end
    end)
  end)

  describe("get_stats", function()
    it("should return memory statistics", function()
      memory.store_memory({ type = "pattern", content = "P1" })
      memory.store_memory({ type = "pattern", content = "P2" })
      memory.store_memory({ type = "convention", content = "C1" })

      local stats = memory.get_stats()

      assert.is_table(stats)
      assert.equals(2, stats.patterns)
      assert.equals(1, stats.conventions)
      assert.equals(3, stats.total)
    end)
  end)
end)
-- ============================================================
-- File: tests/spec/scanner_spec.lua (285 lines, new file)
-- ============================================================
---@diagnostic disable: undefined-global
-- Tests for lua/codetyper/indexer/scanner.lua

describe("indexer.scanner", function()
  local scanner
  local utils

  -- Mock cwd for testing
  local test_cwd = "/tmp/codetyper_test_scanner"

  before_each(function()
    -- Reset modules
    package.loaded["codetyper.indexer.scanner"] = nil
    package.loaded["codetyper.utils"] = nil

    scanner = require("codetyper.indexer.scanner")
    utils = require("codetyper.utils")

    -- Create test directory
    vim.fn.mkdir(test_cwd, "p")

    -- Mock getcwd to return test directory
    vim.fn.getcwd = function()
      return test_cwd
    end
  end)

  after_each(function()
    -- Clean up test directory
    vim.fn.delete(test_cwd, "rf")
  end)

  describe("detect_project_type", function()
    it("should detect node project from package.json", function()
      utils.write_file(test_cwd .. "/package.json", '{"name":"test"}')

      local project_type = scanner.detect_project_type(test_cwd)

      assert.equals("node", project_type)
    end)

    it("should detect rust project from Cargo.toml", function()
      utils.write_file(test_cwd .. "/Cargo.toml", '[package]\nname = "test"')

      local project_type = scanner.detect_project_type(test_cwd)

      assert.equals("rust", project_type)
    end)

    it("should detect go project from go.mod", function()
      utils.write_file(test_cwd .. "/go.mod", "module example.com/test")

      local project_type = scanner.detect_project_type(test_cwd)

      assert.equals("go", project_type)
    end)

    it("should detect python project from pyproject.toml", function()
      utils.write_file(test_cwd .. "/pyproject.toml", '[project]\nname = "test"')

      local project_type = scanner.detect_project_type(test_cwd)

      assert.equals("python", project_type)
    end)

    it("should return unknown for unrecognized project", function()
      -- Empty directory
      local project_type = scanner.detect_project_type(test_cwd)

      assert.equals("unknown", project_type)
    end)
  end)

  describe("parse_package_json", function()
    it("should parse dependencies from package.json", function()
      local pkg_content = [[{
  "name": "test",
  "dependencies": {
    "express": "^4.18.0",
    "lodash": "^4.17.0"
  },
  "devDependencies": {
    "jest": "^29.0.0"
  }
}]]
      utils.write_file(test_cwd .. "/package.json", pkg_content)

      local result = scanner.parse_package_json(test_cwd)

      assert.is_table(result.dependencies)
      assert.is_table(result.dev_dependencies)
      assert.equals("^4.18.0", result.dependencies.express)
      assert.equals("^4.17.0", result.dependencies.lodash)
      assert.equals("^29.0.0", result.dev_dependencies.jest)
    end)

    it("should return empty tables when package.json does not exist", function()
      local result = scanner.parse_package_json(test_cwd)

      assert.is_table(result.dependencies)
      assert.is_table(result.dev_dependencies)
      assert.equals(0, vim.tbl_count(result.dependencies))
    end)

    it("should handle malformed JSON gracefully", function()
      utils.write_file(test_cwd .. "/package.json", "not valid json")

      local result = scanner.parse_package_json(test_cwd)

      assert.is_table(result.dependencies)
      assert.equals(0, vim.tbl_count(result.dependencies))
    end)
  end)

  describe("parse_cargo_toml", function()
    it("should parse dependencies from Cargo.toml", function()
      local cargo_content = [[
[package]
name = "test"

[dependencies]
serde = "1.0"
tokio = "1.28"

[dev-dependencies]
tempfile = "3.5"
]]
      utils.write_file(test_cwd .. "/Cargo.toml", cargo_content)

      local result = scanner.parse_cargo_toml(test_cwd)

      assert.is_table(result.dependencies)
      assert.equals("1.0", result.dependencies.serde)
      assert.equals("1.28", result.dependencies.tokio)
      assert.equals("3.5", result.dev_dependencies.tempfile)
    end)

    it("should return empty tables when Cargo.toml does not exist", function()
      local result = scanner.parse_cargo_toml(test_cwd)

      assert.equals(0, vim.tbl_count(result.dependencies))
    end)
  end)

  describe("parse_go_mod", function()
    it("should parse dependencies from go.mod", function()
      local go_mod_content = [[
module example.com/test

go 1.21

require (
    github.com/gin-gonic/gin v1.9.1
    github.com/stretchr/testify v1.8.4
)
]]
      utils.write_file(test_cwd .. "/go.mod", go_mod_content)

      local result = scanner.parse_go_mod(test_cwd)

      assert.is_table(result.dependencies)
      assert.equals("v1.9.1", result.dependencies["github.com/gin-gonic/gin"])
      assert.equals("v1.8.4", result.dependencies["github.com/stretchr/testify"])
    end)
  end)

  describe("should_ignore", function()
    it("should ignore hidden files", function()
      local config = { excluded_dirs = {} }

      assert.is_true(scanner.should_ignore(".hidden", config))
      assert.is_true(scanner.should_ignore(".git", config))
    end)

    it("should ignore node_modules", function()
      local config = { excluded_dirs = {} }

      assert.is_true(scanner.should_ignore("node_modules", config))
    end)

    it("should ignore configured directories", function()
      local config = { excluded_dirs = { "custom_ignore" } }

      assert.is_true(scanner.should_ignore("custom_ignore", config))
    end)

    it("should not ignore regular files", function()
      local config = { excluded_dirs = {} }

      assert.is_false(scanner.should_ignore("main.lua", config))
      assert.is_false(scanner.should_ignore("src", config))
    end)
  end)

  describe("should_index", function()
    it("should index files with allowed extensions", function()
      vim.fn.mkdir(test_cwd .. "/src", "p")
      utils.write_file(test_cwd .. "/src/main.lua", "-- test")

      local config = {
        index_extensions = { "lua", "ts", "js" },
        max_file_size = 100000,
        excluded_dirs = {},
      }

      assert.is_true(scanner.should_index(test_cwd .. "/src/main.lua", config))
    end)

    it("should not index coder files", function()
      utils.write_file(test_cwd .. "/main.coder.lua", "-- test")

      local config = {
        index_extensions = { "lua" },
        max_file_size = 100000,
        excluded_dirs = {},
      }

      assert.is_false(scanner.should_index(test_cwd .. "/main.coder.lua", config))
    end)

    it("should not index files with disallowed extensions", function()
      utils.write_file(test_cwd .. "/image.png", "binary")

      local config = {
        index_extensions = { "lua", "ts", "js" },
        max_file_size = 100000,
        excluded_dirs = {},
      }

      assert.is_false(scanner.should_index(test_cwd .. "/image.png", config))
    end)
  end)

  describe("get_indexable_files", function()
    it("should return list of indexable files", function()
      vim.fn.mkdir(test_cwd .. "/src", "p")
      utils.write_file(test_cwd .. "/src/main.lua", "-- main")
      utils.write_file(test_cwd .. "/src/utils.lua", "-- utils")
      utils.write_file(test_cwd .. "/README.md", "# Readme")

      local config = {
        index_extensions = { "lua" },
        max_file_size = 100000,
        excluded_dirs = { "node_modules" },
      }

      local files = scanner.get_indexable_files(test_cwd, config)

      assert.equals(2, #files)
    end)

    it("should skip ignored directories", function()
      vim.fn.mkdir(test_cwd .. "/src", "p")
      vim.fn.mkdir(test_cwd .. "/node_modules", "p")
      utils.write_file(test_cwd .. "/src/main.lua", "-- main")
      utils.write_file(test_cwd .. "/node_modules/package.lua", "-- ignore")

      local config = {
        index_extensions = { "lua" },
        max_file_size = 100000,
        excluded_dirs = { "node_modules" },
      }

      local files = scanner.get_indexable_files(test_cwd, config)

      -- Should only include src/main.lua
      assert.equals(1, #files)
    end)
  end)

  describe("get_language", function()
    it("should return correct language for extensions", function()
      assert.equals("lua", scanner.get_language("test.lua"))
      assert.equals("typescript", scanner.get_language("test.ts"))
      assert.equals("javascript", scanner.get_language("test.js"))
      assert.equals("python", scanner.get_language("test.py"))
      assert.equals("go", scanner.get_language("test.go"))
      assert.equals("rust", scanner.get_language("test.rs"))
    end)

    it("should return extension as fallback", function()
      assert.equals("unknown", scanner.get_language("test.unknown"))
    end)
  end)
end)
-- (end of exported diff; Gitea UI footer removed)