From f5df1a9ac0359e2243afd173c7967bb7f5185406 Mon Sep 17 00:00:00 2001 From: Carlos Gutierrez Date: Thu, 15 Jan 2026 20:58:56 -0500 Subject: [PATCH] Adding more features --- CHANGELOG.md | 54 +- README.md | 214 +++++- llms.txt | 188 ++++- lua/codetyper/agent/agentic.lua | 854 +++++++++++++++++++++++ lua/codetyper/agent/confidence.lua | 25 +- lua/codetyper/agent/init.lua | 6 +- lua/codetyper/agent/inject.lua | 614 ++++++++++++++++ lua/codetyper/agent/loop.lua | 398 +++++++++++ lua/codetyper/agent/patch.lua | 149 ++-- lua/codetyper/agent/tools/base.lua | 128 ++++ lua/codetyper/agent/tools/bash.lua | 198 ++++++ lua/codetyper/agent/tools/edit.lua | 429 ++++++++++++ lua/codetyper/agent/tools/glob.lua | 146 ++++ lua/codetyper/agent/tools/grep.lua | 150 ++++ lua/codetyper/agent/tools/init.lua | 308 ++++++++ lua/codetyper/agent/tools/view.lua | 149 ++++ lua/codetyper/agent/tools/write.lua | 101 +++ lua/codetyper/agent/worker.lua | 188 ++++- lua/codetyper/ask.lua | 4 +- lua/codetyper/autocmds.lua | 454 ++++++++++-- lua/codetyper/brain/graph/query.lua | 110 ++- lua/codetyper/brain/learners/pattern.lua | 4 + lua/codetyper/commands.lua | 387 ++++++++++ lua/codetyper/cost.lua | 750 ++++++++++++++++++++ lua/codetyper/credentials.lua | 602 ++++++++++++++++ lua/codetyper/indexer/analyzer.lua | 5 +- lua/codetyper/llm/copilot.lua | 237 ++++++- lua/codetyper/llm/gemini.lua | 18 +- lua/codetyper/llm/init.lua | 26 + lua/codetyper/llm/ollama.lua | 307 +++----- lua/codetyper/llm/openai.lua | 38 +- lua/codetyper/llm/selector.lua | 514 ++++++++++++++ lua/codetyper/prompts/agent.lua | 56 +- tests/spec/agent_tools_spec.lua | 427 ++++++++++++ tests/spec/agentic_spec.lua | 312 +++++++++ tests/spec/brain_learners_spec.lua | 153 ++++ tests/spec/coder_context_spec.lua | 194 +++++ tests/spec/coder_ignore_spec.lua | 161 +++++ tests/spec/inject_spec.lua | 371 ++++++++++ tests/spec/llm_selector_spec.lua | 174 +++++ 40 files changed, 9145 insertions(+), 458 deletions(-) create mode 100644 lua/codetyper/agent/agentic.lua create mode 100644 lua/codetyper/agent/inject.lua create mode 100644 lua/codetyper/agent/loop.lua create mode 100644 lua/codetyper/agent/tools/base.lua create mode 100644 lua/codetyper/agent/tools/bash.lua create mode 100644 lua/codetyper/agent/tools/edit.lua create mode 100644 lua/codetyper/agent/tools/glob.lua create mode 100644 lua/codetyper/agent/tools/grep.lua create mode 100644 lua/codetyper/agent/tools/init.lua create mode 100644 lua/codetyper/agent/tools/view.lua create mode 100644 lua/codetyper/agent/tools/write.lua create mode 100644 lua/codetyper/cost.lua create mode 100644 lua/codetyper/credentials.lua create mode 100644 lua/codetyper/llm/selector.lua create mode 100644 tests/spec/agent_tools_spec.lua create mode 100644 tests/spec/agentic_spec.lua create mode 100644 tests/spec/brain_learners_spec.lua create mode 100644 tests/spec/coder_context_spec.lua create mode 100644 tests/spec/coder_ignore_spec.lua create mode 100644 tests/spec/inject_spec.lua create mode 100644 tests/spec/llm_selector_spec.lua diff --git a/CHANGELOG.md b/CHANGELOG.md index 3234e67..985727f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,57 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.5.0] - 2026-01-15 + +### Added + +- **Cost Tracking System** - Track LLM API costs across sessions + - New `:CoderCost` command opens cost estimation floating window + - Session costs tracked in real-time + - All-time costs persisted in `.coder/cost_history.json` + - Per-model breakdown with token counts + - Pricing database for 50+ models (GPT-4/5, Claude, O-series, Gemini) + - Window keymaps: `q` close, `r` refresh, `c` clear session, `C` clear all + +- **Automatic Ollama Fallback** - Graceful degradation when API limits hit + - Automatically switches to Ollama when Copilot rate limits exceeded + - Detects local Ollama availability before fallback + - Notifies user of provider switch + +- **Enhanced Error Handling** - Better error messages for API failures + - Shows actual API response on parse errors (not generic "failed to parse") + - Improved rate limit detection and messaging + - Sanitized newlines in error notifications to prevent UI crashes + +- **Agent Tools System Improvements** + - New `to_openai_format()` and `to_claude_format()` functions + - `get_definitions()` for generic tool access + - Fixed tool call argument serialization (JSON strings vs tables) + +- **Credentials Management System** - Store API keys outside of config files + - New `:CoderAddApiKey` command for interactive credential setup + - `:CoderRemoveApiKey` to remove stored credentials + - `:CoderCredentials` to view credential status + - `:CoderSwitchProvider` to switch active LLM provider + - Credentials stored in `~/.local/share/nvim/codetyper/configuration.json` + - Priority: stored credentials > config > environment variables + - Supports all providers: Claude, OpenAI, Gemini, Copilot, Ollama + +### Changed + +- Cost window now shows both session and all-time statistics +- Improved agent prompt templates with correct tool names +- Better error context in LLM provider responses + +### Fixed + +- Fixed "Failed to parse Copilot response" error showing instead of actual error +- Fixed `nvim_buf_set_lines` crash from newlines in error messages +- Fixed `tools.definitions` nil error in agent initialization +- Fixed tool name mismatch in agent prompts (write_file vs write) + +--- + ## [0.4.0] - 2026-01-13 ### Added @@ -194,7 +245,8 @@ scheduler = { - **Fixed** - Bug fixes - **Security** - Vulnerability fixes -[Unreleased]: https://github.com/cargdev/codetyper.nvim/compare/v0.4.0...HEAD +[Unreleased]: https://github.com/cargdev/codetyper.nvim/compare/v0.5.0...HEAD +[0.5.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.4.0...v0.5.0 [0.4.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.3.0...v0.4.0 [0.3.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.2.0...v0.3.0 [0.2.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.1.0...v0.2.0 diff --git a/README.md b/README.md index 7ab2c64..36de3f5 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,10 @@ - 🛡️ **Completion-Aware**: Safe injection that doesn't fight with autocomplete - 📁 **Auto-Index**: Automatically create coder companion files on file open - 📜 **Logs Panel**: Real-time visibility into LLM requests and token usage +- 💰 **Cost Tracking**: Persistent LLM cost estimation with session and all-time stats - 🔒 **Git Integration**: Automatically adds `.coder.*` files to `.gitignore` - 🌳 **Project Tree Logging**: Maintains a `tree.log` tracking your project structure +- 🧠 **Brain System**: Knowledge graph that learns from your coding patterns --- @@ -34,6 +36,8 @@ - [LLM Providers](#-llm-providers) - [Commands Reference](#-commands-reference) - [Usage Guide](#-usage-guide) +- [Logs Panel](#-logs-panel) +- [Cost Tracking](#-cost-tracking) - [Agent Mode](#-agent-mode) - [Keymaps](#-keymaps) - [Health Check](#-health-check) @@ -196,6 +200,32 @@ require("codetyper").setup({ | `OPENAI_API_KEY` | OpenAI API key | | `GEMINI_API_KEY` | Google Gemini API key | +### Credentials Management + +Instead of storing API keys in your config (which may be committed to git), you can use the credentials system: + +```vim +:CoderAddApiKey +``` + +This command interactively prompts for: +1. Provider selection (Claude, OpenAI, Gemini, Copilot, Ollama) +2. API key (for cloud providers) +3. Model name +4. Custom endpoint (for OpenAI-compatible APIs) + +Credentials are stored securely in `~/.local/share/nvim/codetyper/configuration.json` (not in your config files). + +**Priority order for credentials:** +1. Stored credentials (via `:CoderAddApiKey`) +2. Config file settings +3. Environment variables + +**Other credential commands:** +- `:CoderCredentials` - View configured providers +- `:CoderSwitchProvider` - Switch between configured providers +- `:CoderRemoveApiKey` - Remove stored credentials + --- ## 🔌 LLM Providers @@ -255,49 +285,129 @@ llm = { ## 📝 Commands Reference -### Main Commands +All commands can be invoked via `:Coder {subcommand}` or their dedicated command aliases. + +### Core Commands + +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder open` | `:CoderOpen` | Open the coder split view | +| `:Coder close` | `:CoderClose` | Close the coder split view | +| `:Coder toggle` | `:CoderToggle` | Toggle the coder split view | +| `:Coder process` | `:CoderProcess` | Process the last prompt in coder file | +| `:Coder status` | - | Show plugin status and configuration | +| `:Coder focus` | - | Switch focus between coder and target windows | +| `:Coder reset` | - | Reset processed prompts to allow re-processing | +| `:Coder gitignore` | - | Force update .gitignore with coder patterns | + +### Ask Panel (Chat Interface) + +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder ask` | `:CoderAsk` | Open the Ask panel | +| `:Coder ask-toggle` | `:CoderAskToggle` | Toggle the Ask panel | +| `:Coder ask-close` | - | Close the Ask panel | +| `:Coder ask-clear` | `:CoderAskClear` | Clear chat history | + +### Agent Mode (Autonomous Coding) + +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder agent` | `:CoderAgent` | Open the Agent panel | +| `:Coder agent-toggle` | `:CoderAgentToggle` | Toggle the Agent panel | +| `:Coder agent-close` | - | Close the Agent panel | +| `:Coder agent-stop` | `:CoderAgentStop` | Stop the running agent | + +### Agentic Mode (IDE-like Multi-file Agent) + +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder agentic-run ` | `:CoderAgenticRun ` | Run an agentic task (multi-file changes) | +| `:Coder agentic-list` | `:CoderAgenticList` | List available agents | +| `:Coder agentic-init` | `:CoderAgenticInit` | Initialize `.coder/agents/` and `.coder/rules/` | + +### Transform Commands (Inline Tag Processing) + +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder transform` | `:CoderTransform` | Transform all `/@ @/` tags in file | +| `:Coder transform-cursor` | `:CoderTransformCursor` | Transform tag at cursor position | +| - | `:CoderTransformVisual` | Transform selected tags (visual mode) | + +### Project & Index Commands + +| Command | Alias | Description | +|---------|-------|-------------| +| - | `:CoderIndex` | Open coder companion for current file | +| `:Coder index-project` | `:CoderIndexProject` | Index the entire project | +| `:Coder index-status` | `:CoderIndexStatus` | Show project index status | + +### Tree & Structure Commands + +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder tree` | `:CoderTree` | Refresh `.coder/tree.log` | +| `:Coder tree-view` | `:CoderTreeView` | View `.coder/tree.log` in split | + +### Queue & Scheduler Commands + +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder queue-status` | `:CoderQueueStatus` | Show scheduler and queue status | +| `:Coder queue-process` | `:CoderQueueProcess` | Manually trigger queue processing | + +### Processing Mode Commands + +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder auto-toggle` | `:CoderAutoToggle` | Toggle automatic/manual prompt processing | +| `:Coder auto-set ` | `:CoderAutoSet ` | Set processing mode (`auto`/`manual`) | + +### Memory & Learning Commands + +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder memories` | `:CoderMemories` | Show learned memories | +| `:Coder forget [pattern]` | `:CoderForget [pattern]` | Clear memories (optionally matching pattern) | + +### Brain Commands (Knowledge Graph) + +| Command | Alias | Description | +|---------|-------|-------------| +| - | `:CoderBrain [action]` | Brain management (`stats`/`commit`/`flush`/`prune`) | +| - | `:CoderFeedback ` | Give feedback to brain (`good`/`bad`/`stats`) | + +### LLM Statistics & Feedback | Command | Description | |---------|-------------| -| `:Coder {subcommand}` | Main command with subcommands | -| `:CoderOpen` | Open the coder split view | -| `:CoderClose` | Close the coder split view | -| `:CoderToggle` | Toggle the coder split view | -| `:CoderProcess` | Process the last prompt | +| `:Coder llm-stats` | Show LLM provider accuracy statistics | +| `:Coder llm-feedback-good` | Report positive feedback on last response | +| `:Coder llm-feedback-bad` | Report negative feedback on last response | +| `:Coder llm-reset-stats` | Reset LLM accuracy statistics | -### Ask Panel +### Cost Tracking -| Command | Description | -|---------|-------------| -| `:CoderAsk` | Open the Ask panel | -| `:CoderAskToggle` | Toggle the Ask panel | -| `:CoderAskClear` | Clear chat history | +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder cost` | `:CoderCost` | Show LLM cost estimation window | +| `:Coder cost-clear` | - | Clear session cost tracking | -### Agent Mode +### Credentials Management -| Command | Description | -|---------|-------------| -| `:CoderAgent` | Open the Agent panel | -| `:CoderAgentToggle` | Toggle the Agent panel | -| `:CoderAgentStop` | Stop the running agent | +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder add-api-key` | `:CoderAddApiKey` | Add or update LLM provider API key | +| `:Coder remove-api-key` | `:CoderRemoveApiKey` | Remove LLM provider credentials | +| `:Coder credentials` | `:CoderCredentials` | Show credentials status | +| `:Coder switch-provider` | `:CoderSwitchProvider` | Switch active LLM provider | -### Transform Commands +### UI Commands -| Command | Description | -|---------|-------------| -| `:CoderTransform` | Transform all /@ @/ tags in file | -| `:CoderTransformCursor` | Transform tag at cursor position | -| `:CoderTransformVisual` | Transform selected tags (visual mode) | - -### Utility Commands - -| Command | Description | -|---------|-------------| -| `:CoderIndex` | Open coder companion for current file | -| `:CoderLogs` | Toggle logs panel | -| `:CoderType` | Switch between Ask/Agent modes | -| `:CoderTree` | Refresh tree.log | -| `:CoderTreeView` | View tree.log | +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder type-toggle` | `:CoderType` | Show Ask/Agent mode switcher | +| `:Coder logs-toggle` | `:CoderLogs` | Toggle logs panel | --- @@ -384,6 +494,42 @@ The logs panel opens automatically when processing prompts with the scheduler en --- +## 💰 Cost Tracking + +Track your LLM API costs across sessions with the Cost Estimation window. + +### Features + +- **Session Tracking**: Monitor current session token usage and costs +- **All-Time Tracking**: Persistent cost history stored per-project in `.coder/cost_history.json` +- **Model Breakdown**: See costs by individual model +- **Pricing Database**: Built-in pricing for 50+ models (GPT, Claude, Gemini, O-series, etc.) + +### Opening the Cost Window + +```vim +:CoderCost +``` + +### Cost Window Keymaps + +| Key | Description | +|-----|-------------| +| `q` / `` | Close window | +| `r` | Refresh display | +| `c` | Clear session costs | +| `C` | Clear all history | + +### Supported Models + +The cost tracker includes pricing for: +- **OpenAI**: GPT-4, GPT-4o, GPT-4o-mini, O1, O3, O4-mini, and more +- **Anthropic**: Claude 3 Opus, Sonnet, Haiku, Claude 3.5 Sonnet/Haiku +- **Local**: Ollama models (free, but usage tracked) +- **Copilot**: Usage tracked (included in subscription) + +--- + ## 🤖 Agent Mode The Agent mode provides an autonomous coding assistant with tool access: diff --git a/llms.txt b/llms.txt index d00a8d8..a3ec734 100644 --- a/llms.txt +++ b/llms.txt @@ -33,6 +33,8 @@ lua/codetyper/ ├── health.lua # Health check for :checkhealth ├── tree.lua # Project tree logging (.coder/tree.log) ├── logs_panel.lua # Standalone logs panel UI +├── cost.lua # LLM cost tracking with persistent history +├── credentials.lua # Secure credential storage (API keys, models) ├── llm/ │ ├── init.lua # LLM interface, provider selection │ ├── claude.lua # Claude API client (Anthropic) @@ -68,7 +70,14 @@ The plugin automatically creates and maintains a `.coder/` folder in your projec ``` .coder/ -└── tree.log # Project structure, auto-updated on file changes +├── tree.log # Project structure, auto-updated on file changes +├── cost_history.json # LLM cost tracking history (persistent) +├── brain/ # Knowledge graph storage +│ ├── nodes/ # Learning nodes by type +│ ├── indices/ # Search indices +│ └── deltas/ # Version history +├── agents/ # Custom agent definitions +└── rules/ # Project-specific rules ``` ## Key Features @@ -115,7 +124,49 @@ auto_index = true -- disabled by default Real-time visibility into LLM operations with token usage tracking. -### 6. Event-Driven Scheduler +### 6. Cost Tracking + +Track LLM API costs across sessions: + +- **Session tracking**: Monitor current session costs in real-time +- **All-time tracking**: Persistent history in `.coder/cost_history.json` +- **Per-model breakdown**: See costs by individual model +- **50+ models**: Built-in pricing for GPT, Claude, O-series, Gemini + +Cost window keymaps: +- `q`/`` - Close window +- `r` - Refresh display +- `c` - Clear session costs +- `C` - Clear all history + +### 7. Automatic Ollama Fallback + +When API rate limits are hit (e.g., Copilot free tier), the plugin: +1. Detects the rate limit error +2. Checks if local Ollama is available +3. Automatically switches provider to Ollama +4. Notifies user of the provider change + +### 8. Credentials Management + +Store API keys securely outside of config files: + +```vim +:CoderAddApiKey +``` + +**Features:** +- Interactive prompts for provider, API key, model, endpoint +- Stored in `~/.local/share/nvim/codetyper/configuration.json` +- Supports all providers: Claude, OpenAI, Gemini, Copilot, Ollama +- Switch providers at runtime with `:CoderSwitchProvider` + +**Credential priority:** +1. Stored credentials (via `:CoderAddApiKey`) +2. Config file settings (`require("codetyper").setup({...})`) +3. Environment variables (`OPENAI_API_KEY`, etc.) + +### 9. Event-Driven Scheduler Prompts are treated as events, not commands: @@ -143,7 +194,7 @@ scheduler = { } ``` -### 7. Tree-sitter Scope Resolution +### 10. Tree-sitter Scope Resolution Prompts automatically resolve to their enclosing function/method/class: @@ -158,7 +209,7 @@ end For replacement intents (complete, refactor, fix), the entire scope is extracted and sent to the LLM, then replaced with the transformed version. -### 8. Intent Detection +### 11. Intent Detection The system parses prompts to detect user intent: @@ -173,7 +224,7 @@ The system parses prompts to detect user intent: | optimize | optimize, performance, faster | replace | | explain | explain, what, how, why | none | -### 9. Tag Precedence +### 12. Tag Precedence Multiple tags in the same scope follow "first tag wins" rule: - Earlier (by line number) unresolved tag processes first @@ -182,33 +233,114 @@ Multiple tags in the same scope follow "first tag wins" rule: ## Commands -### Main Commands -- `:Coder open` - Opens split view with coder file -- `:Coder close` - Closes the split -- `:Coder toggle` - Toggles the view -- `:Coder process` - Manually triggers code generation +All commands can be invoked via `:Coder {subcommand}` or dedicated aliases. -### Ask Panel -- `:CoderAsk` - Open Ask panel -- `:CoderAskToggle` - Toggle Ask panel -- `:CoderAskClear` - Clear chat history +### Core Commands +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder open` | `:CoderOpen` | Open coder split view | +| `:Coder close` | `:CoderClose` | Close coder split view | +| `:Coder toggle` | `:CoderToggle` | Toggle coder split view | +| `:Coder process` | `:CoderProcess` | Process last prompt in coder file | +| `:Coder status` | - | Show plugin status and configuration | +| `:Coder focus` | - | Switch focus between coder/target windows | +| `:Coder reset` | - | Reset processed prompts | +| `:Coder gitignore` | - | Force update .gitignore | -### Agent Mode -- `:CoderAgent` - Open Agent panel -- `:CoderAgentToggle` - Toggle Agent panel -- `:CoderAgentStop` - Stop running agent +### Ask Panel (Chat Interface) +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder ask` | `:CoderAsk` | Open Ask panel | +| `:Coder ask-toggle` | `:CoderAskToggle` | Toggle Ask panel | +| `:Coder ask-close` | - | Close Ask panel | +| `:Coder ask-clear` | `:CoderAskClear` | Clear chat history | -### Transform -- `:CoderTransform` - Transform all tags -- `:CoderTransformCursor` - Transform at cursor -- `:CoderTransformVisual` - Transform selection +### Agent Mode (Autonomous Coding) +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder agent` | `:CoderAgent` | Open Agent panel | +| `:Coder agent-toggle` | `:CoderAgentToggle` | Toggle Agent panel | +| `:Coder agent-close` | - | Close Agent panel | +| `:Coder agent-stop` | `:CoderAgentStop` | Stop running agent | -### Utility -- `:CoderIndex` - Open coder companion -- `:CoderLogs` - Toggle logs panel -- `:CoderType` - Switch Ask/Agent mode -- `:CoderTree` - Refresh tree.log -- `:CoderTreeView` - View tree.log +### Agentic Mode (IDE-like Multi-file Agent) +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder agentic-run ` | `:CoderAgenticRun ` | Run agentic task | +| `:Coder agentic-list` | `:CoderAgenticList` | List available agents | +| `:Coder agentic-init` | `:CoderAgenticInit` | Initialize .coder/agents/ and .coder/rules/ | + +### Transform Commands (Inline Tag Processing) +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder transform` | `:CoderTransform` | Transform all /@ @/ tags in file | +| `:Coder transform-cursor` | `:CoderTransformCursor` | Transform tag at cursor | +| - | `:CoderTransformVisual` | Transform selected tags (visual mode) | + +### Project & Index Commands +| Command | Alias | Description | +|---------|-------|-------------| +| - | `:CoderIndex` | Open coder companion for current file | +| `:Coder index-project` | `:CoderIndexProject` | Index entire project | +| `:Coder index-status` | `:CoderIndexStatus` | Show project index status | + +### Tree & Structure Commands +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder tree` | `:CoderTree` | Refresh .coder/tree.log | +| `:Coder tree-view` | `:CoderTreeView` | View .coder/tree.log | + +### Queue & Scheduler Commands +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder queue-status` | `:CoderQueueStatus` | Show scheduler/queue status | +| `:Coder queue-process` | `:CoderQueueProcess` | Manually trigger queue processing | + +### Processing Mode Commands +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder auto-toggle` | `:CoderAutoToggle` | Toggle automatic/manual processing | +| `:Coder auto-set ` | `:CoderAutoSet ` | Set mode (auto/manual) | + +### Memory & Learning Commands +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder memories` | `:CoderMemories` | Show learned memories | +| `:Coder forget [pattern]` | `:CoderForget [pattern]` | Clear memories | + +### Brain Commands (Knowledge Graph) +| Command | Alias | Description | +|---------|-------|-------------| +| - | `:CoderBrain [action]` | Brain management (stats/commit/flush/prune) | +| - | `:CoderFeedback ` | Give feedback (good/bad/stats) | + +### LLM Statistics & Feedback +| Command | Description | +|---------|-------------| +| `:Coder llm-stats` | Show LLM provider accuracy stats | +| `:Coder llm-feedback-good` | Report positive feedback | +| `:Coder llm-feedback-bad` | Report negative feedback | +| `:Coder llm-reset-stats` | Reset LLM accuracy stats | + +### Cost Tracking +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder cost` | `:CoderCost` | Show LLM cost estimation window | +| `:Coder cost-clear` | - | Clear session cost tracking | + +### Credentials Management +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder add-api-key` | `:CoderAddApiKey` | Add/update LLM provider credentials | +| `:Coder remove-api-key` | `:CoderRemoveApiKey` | Remove provider credentials | +| `:Coder credentials` | `:CoderCredentials` | Show credentials status | +| `:Coder switch-provider` | `:CoderSwitchProvider` | Switch active provider | + +### UI Commands +| Command | Alias | Description | +|---------|-------|-------------| +| `:Coder type-toggle` | `:CoderType` | Show Ask/Agent mode switcher | +| `:Coder logs-toggle` | `:CoderLogs` | Toggle logs panel | ## Configuration Schema diff --git a/lua/codetyper/agent/agentic.lua b/lua/codetyper/agent/agentic.lua new file mode 100644 index 0000000..fced368 --- /dev/null +++ b/lua/codetyper/agent/agentic.lua @@ -0,0 +1,854 @@ +---@mod codetyper.agent.agentic Agentic loop with proper tool calling +---@brief [[ +--- Full agentic system that handles multi-file changes via tool calling. +--- Inspired by avante.nvim and opencode patterns. +---@brief ]] + +local M = {} + +---@class AgenticMessage +---@field role "system"|"user"|"assistant"|"tool" +---@field content string|table +---@field tool_calls? table[] For assistant messages with tool calls +---@field tool_call_id? string For tool result messages +---@field name? string Tool name for tool results + +---@class AgenticToolCall +---@field id string Unique tool call ID +---@field type "function" +---@field function {name: string, arguments: string|table} + +---@class AgenticOpts +---@field task string The task to accomplish +---@field files? string[] Initial files to include as context +---@field agent? string Agent name to use (default: "coder") +---@field model? string Model override +---@field max_iterations? number Max tool call rounds (default: 20) +---@field on_message? fun(msg: AgenticMessage) Called for each message +---@field on_tool_start? fun(name: string, args: table) Called before tool execution +---@field on_tool_end? fun(name: string, result: any, error: string|nil) Called after tool execution +---@field on_file_change? fun(path: string, action: string) Called when file is modified +---@field on_complete? fun(result: string|nil, error: string|nil) Called when done +---@field on_status? fun(status: string) Status updates + +--- Generate unique tool call ID +local function generate_tool_call_id() + return "call_" .. string.format("%x", os.time()) .. "_" .. string.format("%x", math.random(0, 0xFFFF)) +end + +--- Load agent definition +---@param name string Agent name +---@return table|nil agent definition +local function load_agent(name) + local agents_dir = vim.fn.getcwd() .. "/.coder/agents" + local agent_file = agents_dir .. "/" .. name .. ".md" + + -- Check if custom agent exists + if vim.fn.filereadable(agent_file) == 1 then + local content = table.concat(vim.fn.readfile(agent_file), "\n") + -- Parse frontmatter and content + local frontmatter = {} + local body = content + + local fm_match = content:match("^%-%-%-\n(.-)%-%-%-\n(.*)$") + if fm_match then + -- Parse YAML-like frontmatter + for line in content:match("^%-%-%-\n(.-)%-%-%-"):gmatch("[^\n]+") do + local key, value = line:match("^(%w+):%s*(.+)$") + if key and value then + frontmatter[key] = value + end + end + body = content:match("%-%-%-\n.-%-%-%-%s*\n(.*)$") or content + end + + return { + name = name, + description = frontmatter.description or "Custom agent: " .. name, + system_prompt = body, + tools = frontmatter.tools and vim.split(frontmatter.tools, ",") or nil, + model = frontmatter.model, + } + end + + -- Built-in agents + local builtin_agents = { + coder = { + name = "coder", + description = "Full-featured coding agent with file modification capabilities", + system_prompt = [[You are an expert software engineer. You have access to tools to read, write, and modify files. + +## Your Capabilities +- Read files to understand the codebase +- Search for patterns with grep and glob +- Create new files with write tool +- Edit existing files with precise replacements +- Execute shell commands for builds and tests + +## Guidelines +1. Always read relevant files before making changes +2. Make minimal, focused changes +3. Follow existing code style and patterns +4. Create tests when adding new functionality +5. Verify changes work by running tests or builds + +## Important Rules +- NEVER guess file contents - always read first +- Make precise edits using exact string matching +- Explain your reasoning before making changes +- If unsure, ask for clarification]], + tools = { "view", "edit", "write", "grep", "glob", "bash" }, + }, + planner = { + name = "planner", + description = "Planning agent - read-only, helps design implementations", + system_prompt = [[You are a software architect. Analyze codebases and create implementation plans. + +You can read files and search the codebase, but cannot modify files. +Your role is to: +1. Understand the existing architecture +2. Identify relevant files and patterns +3. Create step-by-step implementation plans +4. Suggest which files to modify and how + +Be thorough in your analysis before making recommendations.]], + tools = { "view", "grep", "glob" }, + }, + explorer = { + name = "explorer", + description = "Exploration agent - quickly find information in codebase", + system_prompt = [[You are a codebase exploration assistant. Find information quickly and report back. + +Your goal is to efficiently search and summarize findings. +Use glob to find files, grep to search content, and view to read specific files. +Be concise and focused in your responses.]], + tools = { "view", "grep", "glob" }, + }, + } + + return builtin_agents[name] +end + +--- Load rules from .coder/rules/ +---@return string Combined rules content +local function load_rules() + local rules_dir = vim.fn.getcwd() .. "/.coder/rules" + local rules = {} + + if vim.fn.isdirectory(rules_dir) == 1 then + local files = vim.fn.glob(rules_dir .. "/*.md", false, true) + for _, file in ipairs(files) do + local content = table.concat(vim.fn.readfile(file), "\n") + local filename = vim.fn.fnamemodify(file, ":t:r") + table.insert(rules, string.format("## Rule: %s\n%s", filename, content)) + end + end + + if #rules > 0 then + return "\n\n# Project Rules\n" .. table.concat(rules, "\n\n") + end + return "" +end + +--- Build messages array for API request +---@param history AgenticMessage[] +---@param provider string "openai"|"claude" +---@return table[] Formatted messages +local function build_messages(history, provider) + local messages = {} + + for _, msg in ipairs(history) do + if msg.role == "system" then + if provider == "claude" then + -- Claude uses system parameter, not message + -- Skip system messages in array + else + table.insert(messages, { + role = "system", + content = msg.content, + }) + end + elseif msg.role == "user" then + table.insert(messages, { + role = "user", + content = msg.content, + }) + elseif msg.role == "assistant" then + local message = { + role = "assistant", + content = msg.content, + } + if msg.tool_calls then + message.tool_calls = msg.tool_calls + if provider == "claude" then + -- Claude format: content is array of blocks + message.content = {} + if msg.content and msg.content ~= "" then + table.insert(message.content, { + type = "text", + text = msg.content, + }) + end + for _, tc in ipairs(msg.tool_calls) do + table.insert(message.content, { + type = "tool_use", + id = tc.id, + name = tc["function"].name, + input = type(tc["function"].arguments) == "string" + and vim.json.decode(tc["function"].arguments) + or tc["function"].arguments, + }) + end + end + end + table.insert(messages, message) + elseif msg.role == "tool" then + if provider == "claude" then + table.insert(messages, { + role = "user", + content = { + { + type = "tool_result", + tool_use_id = msg.tool_call_id, + content = msg.content, + }, + }, + }) + else + table.insert(messages, { + role = "tool", + tool_call_id = msg.tool_call_id, + content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content), + }) + end + end + end + + return messages +end + +--- Build tools array for API request +---@param tool_names string[] Tool names to include +---@param provider string "openai"|"claude" +---@return table[] Formatted tools +local function build_tools(tool_names, provider) + local tools_mod = require("codetyper.agent.tools") + local tools = {} + + for _, name in ipairs(tool_names) do + local tool = tools_mod.get(name) + if tool then + local properties = {} + local required = {} + + for _, param in ipairs(tool.params or {}) do + properties[param.name] = { + type = param.type == "integer" and "number" or param.type, + description = param.description, + } + if not param.optional then + table.insert(required, param.name) + end + end + + local description = type(tool.description) == "function" and tool.description() or tool.description + + if provider == "claude" then + table.insert(tools, { + name = tool.name, + description = description, + input_schema = { + type = "object", + properties = properties, + required = required, + }, + }) + else + table.insert(tools, { + type = "function", + ["function"] = { + name = tool.name, + description = description, + parameters = { + type = "object", + properties = properties, + required = required, + }, + }, + }) + end + end + end + + return tools +end + +--- Execute a tool call +---@param tool_call AgenticToolCall +---@param opts AgenticOpts +---@return string result +---@return string|nil error +local function execute_tool(tool_call, opts) + local tools_mod = require("codetyper.agent.tools") + local name = tool_call["function"].name + local args = tool_call["function"].arguments + + -- Parse arguments if string + if type(args) == "string" then + local ok, parsed = pcall(vim.json.decode, args) + if ok then + args = parsed + else + return "", "Failed to parse tool arguments: " .. args + end + end + + -- Notify tool start + if opts.on_tool_start then + opts.on_tool_start(name, args) + end + + if opts.on_status then + opts.on_status("Executing: " .. name) + end + + -- Execute the tool + local tool = tools_mod.get(name) + if not tool then + local err = "Unknown tool: " .. name + if opts.on_tool_end then + opts.on_tool_end(name, nil, err) + end + return "", err + end + + local result, err = tool.func(args, { + on_log = function(msg) + if opts.on_status then + opts.on_status(msg) + end + end, + }) + + -- Notify tool end + if opts.on_tool_end then + opts.on_tool_end(name, result, err) + end + + -- Track file changes + if opts.on_file_change and (name == "write" or name == "edit") and not err then + opts.on_file_change(args.path, name == "write" and "created" or "modified") + end + + if err then + return "", err + end + + return type(result) == "string" and result or vim.json.encode(result), nil +end + +--- Parse tool calls from LLM response (unified Claude-like format) +---@param response table Raw API response in unified format +---@param provider string Provider name (unused, kept for signature compatibility) +---@return AgenticToolCall[] +local function parse_tool_calls(response, provider) + local tool_calls = {} + + -- Unified format: content array with tool_use blocks + local content = response.content or {} + for _, block in ipairs(content) do + if block.type == "tool_use" then + -- OpenAI expects arguments as JSON string, not table + local args = block.input + if type(args) == "table" then + args = vim.json.encode(args) + end + + table.insert(tool_calls, { + id = block.id or generate_tool_call_id(), + type = "function", + ["function"] = { + name = block.name, + arguments = args, + }, + }) + end + end + + return tool_calls +end + +--- Extract text content from response (unified Claude-like format) +---@param response table Raw API response in unified format +---@param provider string Provider name (unused, kept for signature compatibility) +---@return string +local function extract_content(response, provider) + local parts = {} + for _, block in ipairs(response.content or {}) do + if block.type == "text" then + table.insert(parts, block.text) + end + end + return table.concat(parts, "\n") +end + +--- Check if response indicates completion (unified Claude-like format) +---@param response table Raw API response in unified format +---@param provider string Provider name (unused, kept for signature compatibility) +---@return boolean +local function is_complete(response, provider) + return response.stop_reason == "end_turn" +end + +--- Make API request to LLM with native tool calling support +---@param messages table[] Formatted messages +---@param tools table[] Formatted tools +---@param system_prompt string System prompt +---@param provider string "openai"|"claude"|"copilot" +---@param model string Model name +---@param callback fun(response: table|nil, error: string|nil) +local function call_llm(messages, tools, system_prompt, provider, model, callback) + local context = { + language = "lua", + file_content = "", + prompt_type = "agent", + project_root = vim.fn.getcwd(), + cwd = vim.fn.getcwd(), + } + + -- Use native tool calling APIs + if provider == "copilot" then + local client = require("codetyper.llm.copilot") + + -- Copilot's generate_with_tools expects messages in a specific format + -- Convert to the format it expects + local converted_messages = {} + for _, msg in ipairs(messages) do + if msg.role ~= "system" then + table.insert(converted_messages, msg) + end + end + + client.generate_with_tools(converted_messages, context, tools, function(response, err) + if err then + callback(nil, err) + return + end + + -- Response is already in Claude-like format from the provider + -- Convert to our internal format + local result = { + content = {}, + stop_reason = "end_turn", + } + + if response and response.content then + for _, block in ipairs(response.content) do + if block.type == "text" then + table.insert(result.content, { type = "text", text = block.text }) + elseif block.type == "tool_use" then + table.insert(result.content, { + type = "tool_use", + id = block.id or generate_tool_call_id(), + name = block.name, + input = block.input, + }) + result.stop_reason = "tool_use" + end + end + end + + callback(result, nil) + end) + elseif provider == "openai" then + local client = require("codetyper.llm.openai") + + -- OpenAI's generate_with_tools + local converted_messages = {} + for _, msg in ipairs(messages) do + if msg.role ~= "system" then + table.insert(converted_messages, msg) + end + end + + client.generate_with_tools(converted_messages, context, tools, function(response, err) + if err then + callback(nil, err) + return + end + + -- Response is already in Claude-like format from the provider + local result = { + content = {}, + stop_reason = "end_turn", + } + + if response and response.content then + for _, block in ipairs(response.content) do + if block.type == "text" then + table.insert(result.content, { type = "text", text = block.text }) + elseif block.type == "tool_use" then + table.insert(result.content, { + type = "tool_use", + id = block.id or generate_tool_call_id(), + name = block.name, + input = block.input, + }) + result.stop_reason = "tool_use" + end + end + end + + callback(result, nil) + end) + elseif provider == "ollama" then + local client = require("codetyper.llm.ollama") + + -- Ollama's generate_with_tools (text-based tool calling) + local converted_messages = {} + for _, msg in ipairs(messages) do + if msg.role ~= "system" then + table.insert(converted_messages, msg) + end + end + + client.generate_with_tools(converted_messages, context, tools, function(response, err) + if err then + callback(nil, err) + return + end + + -- Response is already in Claude-like format from the provider + callback(response, nil) + end) + else + -- Fallback for other providers (ollama, etc.) - use text-based parsing + local client = require("codetyper.llm." .. provider) + + -- Build prompt from messages + local prompt_parts = {} + for _, msg in ipairs(messages) do + if msg.role == "user" then + local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content) + table.insert(prompt_parts, "User: " .. content) + elseif msg.role == "assistant" then + local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content) + table.insert(prompt_parts, "Assistant: " .. content) + end + end + + -- Add tool descriptions to prompt for text-based providers + local tool_desc = "\n\n## Available Tools\n" + tool_desc = tool_desc .. "Call tools by outputting JSON in this format:\n" + tool_desc = tool_desc .. '```json\n{"tool": "tool_name", "arguments": {...}}\n```\n\n' + for _, tool in ipairs(tools) do + local name = tool.name or (tool["function"] and tool["function"].name) + local desc = tool.description or (tool["function"] and tool["function"].description) + if name then + tool_desc = tool_desc .. string.format("- **%s**: %s\n", name, desc or "") + end + end + + context.file_content = system_prompt .. tool_desc + + client.generate(table.concat(prompt_parts, "\n\n"), context, function(response, err) + if err then + callback(nil, err) + return + end + + -- Parse response for tool calls (text-based fallback) + local result = { + content = {}, + stop_reason = "end_turn", + } + + -- Extract text content + local text_content = response + + -- Try to extract JSON tool calls from response + local json_match = response:match("```json%s*(%b{})%s*```") + if json_match then + local ok, parsed = pcall(vim.json.decode, json_match) + if ok and parsed.tool then + table.insert(result.content, { + type = "tool_use", + id = generate_tool_call_id(), + name = parsed.tool, + input = parsed.arguments or {}, + }) + text_content = response:gsub("```json.-```", ""):gsub("^%s+", ""):gsub("%s+$", "") + result.stop_reason = "tool_use" + end + end + + if text_content and text_content ~= "" then + table.insert(result.content, 1, { type = "text", text = text_content }) + end + + callback(result, nil) + end) + end +end + +--- Run the agentic loop +---@param opts AgenticOpts +function M.run(opts) + -- Load agent + local agent = load_agent(opts.agent or "coder") + if not agent then + if opts.on_complete then + opts.on_complete(nil, "Unknown agent: " .. (opts.agent or "coder")) + end + return + end + + -- Load rules + local rules = load_rules() + + -- Build system prompt + local system_prompt = agent.system_prompt .. rules + + -- Initialize message history + ---@type AgenticMessage[] + local history = { + { role = "system", content = system_prompt }, + } + + -- Add initial file context if provided + if opts.files and #opts.files > 0 then + local file_context = "# Initial Files\n" + for _, file_path in ipairs(opts.files) do + local content = table.concat(vim.fn.readfile(file_path) or {}, "\n") + file_context = file_context .. string.format("\n## %s\n```\n%s\n```\n", file_path, content) + end + table.insert(history, { role = "user", content = file_context }) + table.insert(history, { role = "assistant", content = "I've reviewed the provided files. What would you like me to do?" }) + end + + -- Add the task + table.insert(history, { role = "user", content = opts.task }) + + -- Determine provider + local config = require("codetyper").get_config() + local provider = config.llm.provider or "copilot" + -- Note: Ollama has its own handler in call_llm, don't change it + + -- Get tools for this agent + local tool_names = agent.tools or { "view", "edit", "write", "grep", "glob", "bash" } + + -- Ensure tools are loaded + local tools_mod = require("codetyper.agent.tools") + tools_mod.setup() + + -- Build tools for API + local tools = build_tools(tool_names, provider) + + -- Iteration tracking + local iteration = 0 + local max_iterations = opts.max_iterations or 20 + + --- Process one iteration + local function process_iteration() + iteration = iteration + 1 + + if iteration > max_iterations then + if opts.on_complete then + opts.on_complete(nil, "Max iterations reached") + end + return + end + + if opts.on_status then + opts.on_status(string.format("Thinking... (iteration %d)", iteration)) + end + + -- Build messages for API + local messages = build_messages(history, provider) + + -- Call LLM + call_llm(messages, tools, system_prompt, provider, opts.model, function(response, err) + if err then + if opts.on_complete then + opts.on_complete(nil, err) + end + return + end + + -- Extract content and tool calls + local content = extract_content(response, provider) + local tool_calls = parse_tool_calls(response, provider) + + -- Add assistant message to history + local assistant_msg = { + role = "assistant", + content = content, + tool_calls = #tool_calls > 0 and tool_calls or nil, + } + table.insert(history, assistant_msg) + + if opts.on_message then + opts.on_message(assistant_msg) + end + + -- Process tool calls if any + if #tool_calls > 0 then + for _, tc in ipairs(tool_calls) do + local result, tool_err = execute_tool(tc, opts) + + -- Add tool result to history + local tool_msg = { + role = "tool", + tool_call_id = tc.id, + name = tc["function"].name, + content = tool_err or result, + } + table.insert(history, tool_msg) + + if opts.on_message then + opts.on_message(tool_msg) + end + end + + -- Continue the loop + vim.schedule(process_iteration) + else + -- No tool calls - check if complete + if is_complete(response, provider) or content ~= "" then + if opts.on_complete then + opts.on_complete(content, nil) + end + else + -- Continue if not explicitly complete + vim.schedule(process_iteration) + end + end + end) + end + + -- Start the loop + process_iteration() +end + +--- Create default agent files in .coder/agents/ +function M.init_agents_dir() + local agents_dir = vim.fn.getcwd() .. "/.coder/agents" + vim.fn.mkdir(agents_dir, "p") + + -- Create example agent + local example_agent = [[--- +description: Example custom agent +tools: view,grep,glob,edit,write +model: +--- + +# Custom Agent + +You are a custom coding agent. Describe your specialized behavior here. + +## Your Role +- Define what this agent specializes in +- List specific capabilities + +## Guidelines +- Add agent-specific rules +- Define coding standards to follow + +## Examples +Provide examples of how to handle common tasks. +]] + + local example_path = agents_dir .. "/example.md" + if vim.fn.filereadable(example_path) ~= 1 then + vim.fn.writefile(vim.split(example_agent, "\n"), example_path) + end + + return agents_dir +end + +--- Create default rules in .coder/rules/ +function M.init_rules_dir() + local rules_dir = vim.fn.getcwd() .. "/.coder/rules" + vim.fn.mkdir(rules_dir, "p") + + -- Create example rule + local example_rule = [[# Code Style + +Follow these coding standards: + +## General +- Use consistent indentation (tabs or spaces based on project) +- Keep lines under 100 characters +- Add comments for complex logic + +## Naming Conventions +- Use descriptive variable names +- Functions should be verbs (e.g., getUserData, calculateTotal) +- Constants in UPPER_SNAKE_CASE + +## Testing +- Write tests for new functionality +- Aim for >80% code coverage +- Test edge cases + +## Documentation +- Document public APIs +- Include usage examples +- Keep docs up to date with code +]] + + local example_path = rules_dir .. "/code-style.md" + if vim.fn.filereadable(example_path) ~= 1 then + vim.fn.writefile(vim.split(example_rule, "\n"), example_path) + end + + return rules_dir +end + +--- Initialize both agents and rules directories +function M.init() + M.init_agents_dir() + M.init_rules_dir() +end + +--- List available agents +---@return table[] List of {name, description, builtin} +function M.list_agents() + local agents = {} + + -- Built-in agents + local builtins = { "coder", "planner", "explorer" } + for _, name in ipairs(builtins) do + local agent = load_agent(name) + if agent then + table.insert(agents, { + name = agent.name, + description = agent.description, + builtin = true, + }) + end + end + + -- Custom agents from .coder/agents/ + local agents_dir = vim.fn.getcwd() .. "/.coder/agents" + if vim.fn.isdirectory(agents_dir) == 1 then + local files = vim.fn.glob(agents_dir .. "/*.md", false, true) + for _, file in ipairs(files) do + local name = vim.fn.fnamemodify(file, ":t:r") + if not vim.tbl_contains(builtins, name) then + local agent = load_agent(name) + if agent then + table.insert(agents, { + name = agent.name, + description = agent.description, + builtin = false, + }) + end + end + end + end + + return agents +end + +return M diff --git a/lua/codetyper/agent/confidence.lua b/lua/codetyper/agent/confidence.lua index e8e002f..e02d4b7 100644 --- a/lua/codetyper/agent/confidence.lua +++ b/lua/codetyper/agent/confidence.lua @@ -8,11 +8,11 @@ local M = {} --- Heuristic weights (must sum to 1.0) M.weights = { - length = 0.15, -- Response length relative to prompt + length = 0.15, -- Response length relative to prompt uncertainty = 0.30, -- Uncertainty phrases - syntax = 0.25, -- Syntax completeness - repetition = 0.15, -- Duplicate lines - truncation = 0.15, -- Incomplete ending + syntax = 0.25, -- Syntax completeness + repetition = 0.15, -- Duplicate lines + truncation = 0.15, -- Incomplete ending } --- Uncertainty phrases that indicate low confidence @@ -255,14 +255,15 @@ function M.score(response, prompt, context) _ = context -- Reserved for future use if not response or #response == 0 then - return 0, { - length = 0, - uncertainty = 0, - syntax = 0, - repetition = 0, - truncation = 0, - weighted_total = 0, - } + return 0, + { + length = 0, + uncertainty = 0, + syntax = 0, + repetition = 0, + truncation = 0, + weighted_total = 0, + } end local scores = { diff --git a/lua/codetyper/agent/init.lua b/lua/codetyper/agent/init.lua index 3751dd1..bc781e5 100644 --- a/lua/codetyper/agent/init.lua +++ b/lua/codetyper/agent/init.lua @@ -111,7 +111,11 @@ function M.agent_loop(context, callbacks) logs.thinking("Calling LLM with " .. #state.conversation .. " messages...") -- Generate with tools enabled - client.generate_with_tools(state.conversation, context, tools.definitions, function(response, err) + -- Ensure tools are loaded and get definitions + tools.setup() + local tool_defs = tools.to_openai_format() + + client.generate_with_tools(state.conversation, context, tool_defs, function(response, err) if err then state.is_running = false callbacks.on_error(err) diff --git a/lua/codetyper/agent/inject.lua b/lua/codetyper/agent/inject.lua new file mode 100644 index 0000000..97c59cb --- /dev/null +++ b/lua/codetyper/agent/inject.lua @@ -0,0 +1,614 @@ +---@mod codetyper.agent.inject Smart code injection with import handling +---@brief [[ +--- Intelligent code injection that properly handles imports, merging them +--- into existing import sections instead of blindly appending. +---@brief ]] + +local M = {} + +---@class ImportConfig +---@field pattern string Lua pattern to match import statements +---@field multi_line boolean Whether imports can span multiple lines +---@field sort_key function|nil Function to extract sort key from import +---@field group_by function|nil Function to group imports + +---@class ParsedCode +---@field imports string[] Import statements +---@field body string[] Non-import code lines +---@field import_lines table Map of line numbers that are imports + +--- Language-specific import patterns +local import_patterns = { + -- JavaScript/TypeScript + javascript = { + { pattern = "^%s*import%s+.+%s+from%s+['\"]", multi_line = true }, + { pattern = "^%s*import%s+['\"]", multi_line = false }, + { pattern = "^%s*import%s*{", multi_line = true }, + { pattern = "^%s*import%s*%*", multi_line = true }, + { pattern = "^%s*export%s+{.+}%s+from%s+['\"]", multi_line = true }, + { pattern = "^%s*const%s+%w+%s*=%s*require%(['\"]", multi_line = false }, + { pattern = "^%s*let%s+%w+%s*=%s*require%(['\"]", multi_line = false }, + { pattern = "^%s*var%s+%w+%s*=%s*require%(['\"]", multi_line = false }, + }, + -- Python + python = { + { pattern = "^%s*import%s+%w", multi_line = false }, + { pattern = "^%s*from%s+[%w%.]+%s+import%s+", multi_line = true }, + }, + -- Lua + lua = { + { pattern = "^%s*local%s+%w+%s*=%s*require%s*%(?['\"]", multi_line = false }, + { pattern = "^%s*require%s*%(?['\"]", multi_line = false }, + }, + -- Go + go = { + { pattern = "^%s*import%s+%(?", multi_line = true }, + }, + -- Rust + rust = { + { pattern = "^%s*use%s+", multi_line = true }, + { pattern = "^%s*extern%s+crate%s+", multi_line = false }, + }, + -- C/C++ + c = { + { pattern = "^%s*#include%s*[<\"]", multi_line = false }, + }, + -- Java/Kotlin + java = { + { pattern = "^%s*import%s+", multi_line = false }, + }, + -- Ruby + ruby = { + { pattern = "^%s*require%s+['\"]", multi_line = false }, + { pattern = "^%s*require_relative%s+['\"]", multi_line = false }, + }, + -- PHP + php = { + { pattern = "^%s*use%s+", multi_line = false }, + { pattern = "^%s*require%s+['\"]", multi_line = false }, + { pattern = "^%s*require_once%s+['\"]", multi_line = false }, + { pattern = "^%s*include%s+['\"]", multi_line = false }, + { pattern = "^%s*include_once%s+['\"]", multi_line = false }, + }, +} + +-- Alias common extensions to language configs +import_patterns.ts = import_patterns.javascript +import_patterns.tsx = import_patterns.javascript +import_patterns.jsx = import_patterns.javascript +import_patterns.mjs = import_patterns.javascript +import_patterns.cjs = import_patterns.javascript +import_patterns.py = import_patterns.python +import_patterns.cpp = import_patterns.c +import_patterns.hpp = import_patterns.c +import_patterns.h = import_patterns.c +import_patterns.kt = import_patterns.java +import_patterns.rs = import_patterns.rust +import_patterns.rb = import_patterns.ruby + +--- Check if a line is an import statement for the given language +---@param line string +---@param patterns table[] Import patterns for the language +---@return boolean is_import +---@return boolean is_multi_line +local function is_import_line(line, patterns) + for _, p in ipairs(patterns) do + if line:match(p.pattern) then + return true, p.multi_line or false + end + end + return false, false +end + +--- Check if a line is empty or a comment +---@param line string +---@param filetype string +---@return boolean +local function is_empty_or_comment(line, filetype) + local trimmed = line:match("^%s*(.-)%s*$") + if trimmed == "" then + return true + end + + -- Language-specific comment patterns + local comment_patterns = { + lua = { "^%-%-" }, + python = { "^#" }, + javascript = { "^//", "^/%*", "^%*" }, + typescript = { "^//", "^/%*", "^%*" }, + go = { "^//", "^/%*", "^%*" }, + rust = { "^//", "^/%*", "^%*" }, + c = { "^//", "^/%*", "^%*", "^#" }, + java = { "^//", "^/%*", "^%*" }, + ruby = { "^#" }, + php = { "^//", "^/%*", "^%*", "^#" }, + } + + local patterns = comment_patterns[filetype] or comment_patterns.javascript + for _, pattern in ipairs(patterns) do + if trimmed:match(pattern) then + return true + end + end + + return false +end + +--- Check if a line ends a multi-line import +---@param line string +---@param filetype string +---@return boolean +local function ends_multiline_import(line, filetype) + -- Check for closing patterns + if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then + -- ES6 imports end with 'from "..." ;' or just ';' or a line with just '}' + if line:match("from%s+['\"][^'\"]+['\"]%s*;?%s*$") then + return true + end + if line:match("}%s*from%s+['\"]") then + return true + end + if line:match("^%s*}%s*;?%s*$") then + return true + end + if line:match(";%s*$") then + return true + end + elseif filetype == "python" or filetype == "py" then + -- Python single-line import: doesn't end with \, (, or , + -- Examples: "from typing import List, Dict" or "import os" + if not line:match("\\%s*$") and not line:match("%(%s*$") and not line:match(",%s*$") then + return true + end + -- Python multiline imports end with closing paren + if line:match("%)%s*$") then + return true + end + elseif filetype == "go" then + -- Go multi-line imports end with ')' + if line:match("%)%s*$") then + return true + end + elseif filetype == "rust" or filetype == "rs" then + -- Rust use statements end with ';' + if line:match(";%s*$") then + return true + end + end + + return false +end + +--- Parse code into imports and body +---@param code string|string[] Code to parse +---@param filetype string File type/extension +---@return ParsedCode +function M.parse_code(code, filetype) + local lines + if type(code) == "string" then + lines = vim.split(code, "\n", { plain = true }) + else + lines = code + end + + local patterns = import_patterns[filetype] or import_patterns.javascript + + local result = { + imports = {}, + body = {}, + import_lines = {}, + } + + local in_multiline_import = false + local current_import_lines = {} + + for i, line in ipairs(lines) do + if in_multiline_import then + -- Continue collecting multi-line import + table.insert(current_import_lines, line) + + if ends_multiline_import(line, filetype) then + -- Complete the multi-line import + table.insert(result.imports, table.concat(current_import_lines, "\n")) + for j = i - #current_import_lines + 1, i do + result.import_lines[j] = true + end + current_import_lines = {} + in_multiline_import = false + end + else + local is_import, is_multi = is_import_line(line, patterns) + + if is_import then + result.import_lines[i] = true + + if is_multi and not ends_multiline_import(line, filetype) then + -- Start of multi-line import + in_multiline_import = true + current_import_lines = { line } + else + -- Single-line import + table.insert(result.imports, line) + end + else + -- Non-import line + table.insert(result.body, line) + end + end + end + + -- Handle unclosed multi-line import (shouldn't happen with well-formed code) + if #current_import_lines > 0 then + table.insert(result.imports, table.concat(current_import_lines, "\n")) + end + + return result +end + +--- Find the import section range in a buffer +---@param bufnr number Buffer number +---@param filetype string +---@return number|nil start_line First import line (1-indexed) +---@return number|nil end_line Last import line (1-indexed) +function M.find_import_section(bufnr, filetype) + if not vim.api.nvim_buf_is_valid(bufnr) then + return nil, nil + end + + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local patterns = import_patterns[filetype] or import_patterns.javascript + + local first_import = nil + local last_import = nil + local in_multiline = false + local consecutive_non_import = 0 + local max_gap = 3 -- Allow up to 3 blank/comment lines between imports + + for i, line in ipairs(lines) do + if in_multiline then + last_import = i + consecutive_non_import = 0 + + if ends_multiline_import(line, filetype) then + in_multiline = false + end + else + local is_import, is_multi = is_import_line(line, patterns) + + if is_import then + if not first_import then + first_import = i + end + last_import = i + consecutive_non_import = 0 + + if is_multi and not ends_multiline_import(line, filetype) then + in_multiline = true + end + elseif is_empty_or_comment(line, filetype) then + -- Allow gaps in import section + if first_import then + consecutive_non_import = consecutive_non_import + 1 + if consecutive_non_import > max_gap then + -- Too many non-import lines, import section has ended + break + end + end + else + -- Non-import, non-empty line + if first_import then + -- Import section has ended + break + end + end + end + end + + return first_import, last_import +end + +--- Get existing imports from a buffer +---@param bufnr number Buffer number +---@param filetype string +---@return string[] Existing import statements +function M.get_existing_imports(bufnr, filetype) + local start_line, end_line = M.find_import_section(bufnr, filetype) + if not start_line then + return {} + end + + local lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false) + local parsed = M.parse_code(lines, filetype) + return parsed.imports +end + +--- Normalize an import for comparison (remove whitespace variations) +---@param import_str string +---@return string +local function normalize_import(import_str) + -- Remove trailing semicolon for comparison + local normalized = import_str:gsub(";%s*$", "") + -- Remove all whitespace around braces, commas, colons + normalized = normalized:gsub("%s*{%s*", "{") + normalized = normalized:gsub("%s*}%s*", "}") + normalized = normalized:gsub("%s*,%s*", ",") + normalized = normalized:gsub("%s*:%s*", ":") + -- Collapse multiple whitespace to single space + normalized = normalized:gsub("%s+", " ") + -- Trim leading/trailing whitespace + normalized = normalized:match("^%s*(.-)%s*$") + return normalized +end + +--- Check if two imports are duplicates +---@param import1 string +---@param import2 string +---@return boolean +local function are_duplicate_imports(import1, import2) + return normalize_import(import1) == normalize_import(import2) +end + +--- Merge new imports with existing ones, avoiding duplicates +---@param existing string[] Existing imports +---@param new_imports string[] New imports to merge +---@return string[] Merged imports +function M.merge_imports(existing, new_imports) + local merged = {} + local seen = {} + + -- Add existing imports + for _, imp in ipairs(existing) do + local normalized = normalize_import(imp) + if not seen[normalized] then + seen[normalized] = true + table.insert(merged, imp) + end + end + + -- Add new imports that aren't duplicates + for _, imp in ipairs(new_imports) do + local normalized = normalize_import(imp) + if not seen[normalized] then + seen[normalized] = true + table.insert(merged, imp) + end + end + + return merged +end + +--- Sort imports by their source/module +---@param imports string[] +---@param filetype string +---@return string[] +function M.sort_imports(imports, filetype) + -- Group imports: stdlib/builtin first, then third-party, then local + local builtin = {} + local third_party = {} + local local_imports = {} + + for _, imp in ipairs(imports) do + -- Detect import type based on patterns + local is_local = false + local is_builtin = false + + if filetype == "javascript" or filetype == "typescript" or filetype == "ts" or filetype == "tsx" then + -- Local: starts with . or .. + is_local = imp:match("from%s+['\"]%.") or imp:match("require%(['\"]%.") + -- Node builtin modules + is_builtin = imp:match("from%s+['\"]node:") or imp:match("from%s+['\"]fs['\"]") + or imp:match("from%s+['\"]path['\"]") or imp:match("from%s+['\"]http['\"]") + elseif filetype == "python" or filetype == "py" then + -- Local: relative imports + is_local = imp:match("^from%s+%.") or imp:match("^import%s+%.") + -- Python stdlib (simplified check) + is_builtin = imp:match("^import%s+os") or imp:match("^import%s+sys") + or imp:match("^from%s+os%s+") or imp:match("^from%s+sys%s+") + or imp:match("^import%s+re") or imp:match("^import%s+json") + elseif filetype == "lua" then + -- Local: relative requires + is_local = imp:match("require%(['\"]%.") or imp:match("require%s+['\"]%.") + elseif filetype == "go" then + -- Local: project imports (contain /) + is_local = imp:match("['\"][^'\"]+/[^'\"]+['\"]") and not imp:match("github%.com") + end + + if is_builtin then + table.insert(builtin, imp) + elseif is_local then + table.insert(local_imports, imp) + else + table.insert(third_party, imp) + end + end + + -- Sort each group alphabetically + table.sort(builtin) + table.sort(third_party) + table.sort(local_imports) + + -- Combine with proper spacing + local result = {} + + for _, imp in ipairs(builtin) do + table.insert(result, imp) + end + if #builtin > 0 and (#third_party > 0 or #local_imports > 0) then + table.insert(result, "") -- Blank line between groups + end + + for _, imp in ipairs(third_party) do + table.insert(result, imp) + end + if #third_party > 0 and #local_imports > 0 then + table.insert(result, "") -- Blank line between groups + end + + for _, imp in ipairs(local_imports) do + table.insert(result, imp) + end + + return result +end + +---@class InjectResult +---@field success boolean +---@field imports_added number Number of new imports added +---@field imports_merged boolean Whether imports were merged into existing section +---@field body_lines number Number of body lines injected + +--- Smart inject code into a buffer, properly handling imports +---@param bufnr number Target buffer +---@param code string|string[] Code to inject +---@param opts table Options: { strategy: "append"|"replace"|"insert", range: {start_line, end_line}|nil, filetype: string|nil, sort_imports: boolean|nil } +---@return InjectResult +function M.inject(bufnr, code, opts) + opts = opts or {} + + if not vim.api.nvim_buf_is_valid(bufnr) then + return { success = false, imports_added = 0, imports_merged = false, body_lines = 0 } + end + + -- Get filetype + local filetype = opts.filetype + if not filetype then + local bufname = vim.api.nvim_buf_get_name(bufnr) + filetype = vim.fn.fnamemodify(bufname, ":e") + end + + -- Parse the code to separate imports from body + local parsed = M.parse_code(code, filetype) + + local result = { + success = true, + imports_added = 0, + imports_merged = false, + body_lines = #parsed.body, + } + + -- Handle imports first if there are any + if #parsed.imports > 0 then + local import_start, import_end = M.find_import_section(bufnr, filetype) + + if import_start then + -- Merge with existing import section + local existing_imports = M.get_existing_imports(bufnr, filetype) + local merged = M.merge_imports(existing_imports, parsed.imports) + + -- Count how many new imports were actually added + result.imports_added = #merged - #existing_imports + result.imports_merged = true + + -- Optionally sort imports + if opts.sort_imports ~= false then + merged = M.sort_imports(merged, filetype) + end + + -- Convert back to lines (handling multi-line imports) + local import_lines = {} + for _, imp in ipairs(merged) do + for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do + table.insert(import_lines, line) + end + end + + -- Replace the import section + vim.api.nvim_buf_set_lines(bufnr, import_start - 1, import_end, false, import_lines) + + -- Adjust line numbers for body injection + local lines_diff = #import_lines - (import_end - import_start + 1) + if opts.range and opts.range.start_line and opts.range.start_line > import_end then + opts.range.start_line = opts.range.start_line + lines_diff + if opts.range.end_line then + opts.range.end_line = opts.range.end_line + lines_diff + end + end + else + -- No existing import section, add imports at the top + -- Find the first non-comment, non-empty line + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local insert_at = 0 + + for i, line in ipairs(lines) do + local trimmed = line:match("^%s*(.-)%s*$") + -- Skip shebang, docstrings, and initial comments + if trimmed ~= "" and not trimmed:match("^#!") + and not trimmed:match("^['\"]") and not is_empty_or_comment(line, filetype) then + insert_at = i - 1 + break + end + insert_at = i + end + + -- Add imports with a trailing blank line + local import_lines = {} + for _, imp in ipairs(parsed.imports) do + for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do + table.insert(import_lines, line) + end + end + table.insert(import_lines, "") -- Blank line after imports + + vim.api.nvim_buf_set_lines(bufnr, insert_at, insert_at, false, import_lines) + result.imports_added = #parsed.imports + result.imports_merged = false + + -- Adjust body injection range + if opts.range and opts.range.start_line then + opts.range.start_line = opts.range.start_line + #import_lines + if opts.range.end_line then + opts.range.end_line = opts.range.end_line + #import_lines + end + end + end + end + + -- Handle body (non-import) code + if #parsed.body > 0 then + -- Filter out empty leading/trailing lines from body + local body_lines = parsed.body + while #body_lines > 0 and body_lines[1]:match("^%s*$") do + table.remove(body_lines, 1) + end + while #body_lines > 0 and body_lines[#body_lines]:match("^%s*$") do + table.remove(body_lines) + end + + if #body_lines > 0 then + local line_count = vim.api.nvim_buf_line_count(bufnr) + local strategy = opts.strategy or "append" + + if strategy == "replace" and opts.range then + local start_line = math.max(1, opts.range.start_line) + local end_line = math.min(line_count, opts.range.end_line) + vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, body_lines) + elseif strategy == "insert" and opts.range then + local insert_line = math.max(0, math.min(line_count, opts.range.start_line - 1)) + vim.api.nvim_buf_set_lines(bufnr, insert_line, insert_line, false, body_lines) + else + -- Default: append + local last_line = vim.api.nvim_buf_get_lines(bufnr, line_count - 1, line_count, false)[1] or "" + if last_line:match("%S") then + -- Add blank line for spacing + table.insert(body_lines, 1, "") + end + vim.api.nvim_buf_set_lines(bufnr, line_count, line_count, false, body_lines) + end + + result.body_lines = #body_lines + end + end + + return result +end + +--- Check if code contains imports +---@param code string|string[] +---@param filetype string +---@return boolean +function M.has_imports(code, filetype) + local parsed = M.parse_code(code, filetype) + return #parsed.imports > 0 +end + +return M diff --git a/lua/codetyper/agent/loop.lua b/lua/codetyper/agent/loop.lua new file mode 100644 index 0000000..cea5832 --- /dev/null +++ b/lua/codetyper/agent/loop.lua @@ -0,0 +1,398 @@ +---@mod codetyper.agent.loop Agent loop with tool orchestration +---@brief [[ +--- Main agent loop that handles multi-turn conversations with tool use. +--- Inspired by avante.nvim's agent_loop pattern. +---@brief ]] + +local M = {} + +---@class AgentMessage +---@field role "system"|"user"|"assistant"|"tool" +---@field content string|table +---@field tool_call_id? string For tool responses +---@field tool_calls? table[] For assistant tool calls +---@field name? string Tool name for tool responses + +---@class AgentLoopOpts +---@field system_prompt string System prompt +---@field user_input string Initial user message +---@field tools? CoderTool[] Available tools (default: all registered) +---@field max_iterations? number Max tool call iterations (default: 10) +---@field provider? string LLM provider to use +---@field on_start? fun() Called when loop starts +---@field on_chunk? fun(chunk: string) Called for each response chunk +---@field on_tool_call? fun(name: string, input: table) Called before tool execution +---@field on_tool_result? fun(name: string, result: any, error: string|nil) Called after tool execution +---@field on_message? fun(message: AgentMessage) Called for each message added +---@field on_complete? fun(result: string|nil, error: string|nil) Called when loop completes +---@field session_ctx? table Session context shared across tools + +--- Format tool definitions for OpenAI-compatible API +---@param tools CoderTool[] +---@return table[] +local function format_tools_for_api(tools) + local formatted = {} + for _, tool in ipairs(tools) do + local properties = {} + local required = {} + + for _, param in ipairs(tool.params or {}) do + properties[param.name] = { + type = param.type == "integer" and "number" or param.type, + description = param.description, + } + if not param.optional then + table.insert(required, param.name) + end + end + + table.insert(formatted, { + type = "function", + ["function"] = { + name = tool.name, + description = type(tool.description) == "function" and tool.description() or tool.description, + parameters = { + type = "object", + properties = properties, + required = required, + }, + }, + }) + end + return formatted +end + +--- Parse tool calls from LLM response +---@param response table LLM response +---@return table[] tool_calls +local function parse_tool_calls(response) + local tool_calls = {} + + -- Handle different response formats + if response.tool_calls then + -- OpenAI format + for _, call in ipairs(response.tool_calls) do + local args = call["function"].arguments + if type(args) == "string" then + local ok, parsed = pcall(vim.json.decode, args) + if ok then + args = parsed + end + end + table.insert(tool_calls, { + id = call.id, + name = call["function"].name, + input = args, + }) + end + elseif response.content and type(response.content) == "table" then + -- Claude format (content blocks) + for _, block in ipairs(response.content) do + if block.type == "tool_use" then + table.insert(tool_calls, { + id = block.id, + name = block.name, + input = block.input, + }) + end + end + end + + return tool_calls +end + +--- Build messages for LLM request +---@param history AgentMessage[] +---@return table[] +local function build_messages(history) + local messages = {} + + for _, msg in ipairs(history) do + if msg.role == "system" then + table.insert(messages, { + role = "system", + content = msg.content, + }) + elseif msg.role == "user" then + table.insert(messages, { + role = "user", + content = msg.content, + }) + elseif msg.role == "assistant" then + local message = { + role = "assistant", + content = msg.content, + } + if msg.tool_calls then + message.tool_calls = msg.tool_calls + end + table.insert(messages, message) + elseif msg.role == "tool" then + table.insert(messages, { + role = "tool", + tool_call_id = msg.tool_call_id, + content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content), + }) + end + end + + return messages +end + +--- Execute the agent loop +---@param opts AgentLoopOpts +function M.run(opts) + local tools_mod = require("codetyper.agent.tools") + local llm = require("codetyper.llm") + + -- Get tools + local tools = opts.tools or tools_mod.list() + local tool_map = {} + for _, tool in ipairs(tools) do + tool_map[tool.name] = tool + end + + -- Initialize conversation history + ---@type AgentMessage[] + local history = { + { role = "system", content = opts.system_prompt }, + { role = "user", content = opts.user_input }, + } + + local session_ctx = opts.session_ctx or {} + local max_iterations = opts.max_iterations or 10 + local iteration = 0 + + -- Callback wrappers + local function on_message(msg) + if opts.on_message then + opts.on_message(msg) + end + end + + -- Notify of initial messages + for _, msg in ipairs(history) do + on_message(msg) + end + + -- Start notification + if opts.on_start then + opts.on_start() + end + + --- Process one iteration of the loop + local function process_iteration() + iteration = iteration + 1 + + if iteration > max_iterations then + if opts.on_complete then + opts.on_complete(nil, "Max iterations reached") + end + return + end + + -- Build request + local messages = build_messages(history) + local formatted_tools = format_tools_for_api(tools) + + -- Build context for LLM + local context = { + file_content = "", + language = "lua", + extension = "lua", + prompt_type = "agent", + tools = formatted_tools, + } + + -- Get LLM response + local client = llm.get_client() + if not client then + if opts.on_complete then + opts.on_complete(nil, "No LLM client available") + end + return + end + + -- Build prompt from messages + local prompt_parts = {} + for _, msg in ipairs(messages) do + if msg.role ~= "system" then + table.insert(prompt_parts, string.format("[%s]: %s", msg.role, msg.content or "")) + end + end + local prompt = table.concat(prompt_parts, "\n\n") + + client.generate(prompt, context, function(response, error) + if error then + if opts.on_complete then + opts.on_complete(nil, error) + end + return + end + + -- Chunk callback + if opts.on_chunk then + opts.on_chunk(response) + end + + -- Parse response for tool calls + -- For now, we'll use a simple heuristic to detect tool calls in the response + -- In a full implementation, the LLM would return structured tool calls + local tool_calls = {} + + -- Try to parse JSON tool calls from response + local json_match = response:match("```json%s*(%b{})%s*```") + if json_match then + local ok, parsed = pcall(vim.json.decode, json_match) + if ok and parsed.tool_calls then + tool_calls = parsed.tool_calls + end + end + + -- Add assistant message + local assistant_msg = { + role = "assistant", + content = response, + tool_calls = #tool_calls > 0 and tool_calls or nil, + } + table.insert(history, assistant_msg) + on_message(assistant_msg) + + -- Process tool calls + if #tool_calls > 0 then + local pending = #tool_calls + local results = {} + + for i, call in ipairs(tool_calls) do + local tool = tool_map[call.name] + if not tool then + results[i] = { error = "Unknown tool: " .. call.name } + pending = pending - 1 + else + -- Notify of tool call + if opts.on_tool_call then + opts.on_tool_call(call.name, call.input) + end + + -- Execute tool + local tool_opts = { + on_log = function(msg) + pcall(function() + local logs = require("codetyper.agent.logs") + logs.add({ type = "tool", message = msg }) + end) + end, + on_complete = function(result, err) + results[i] = { result = result, error = err } + pending = pending - 1 + + -- Notify of tool result + if opts.on_tool_result then + opts.on_tool_result(call.name, result, err) + end + + -- Add tool response to history + local tool_msg = { + role = "tool", + tool_call_id = call.id or tostring(i), + name = call.name, + content = err or result, + } + table.insert(history, tool_msg) + on_message(tool_msg) + + -- Continue loop when all tools complete + if pending == 0 then + vim.schedule(process_iteration) + end + end, + session_ctx = session_ctx, + } + + -- Validate and execute + local valid, validation_err = true, nil + if tool.validate_input then + valid, validation_err = tool:validate_input(call.input) + end + + if not valid then + tool_opts.on_complete(nil, validation_err) + else + local result, err = tool.func(call.input, tool_opts) + -- If sync result, call on_complete + if result ~= nil or err ~= nil then + tool_opts.on_complete(result, err) + end + end + end + end + else + -- No tool calls - loop complete + if opts.on_complete then + opts.on_complete(response, nil) + end + end + end) + end + + -- Start the loop + process_iteration() +end + +--- Create an agent with default settings +---@param task string Task description +---@param opts? AgentLoopOpts Additional options +function M.create(task, opts) + opts = opts or {} + + local system_prompt = opts.system_prompt or [[You are a helpful coding assistant with access to tools. + +Available tools: +- view: Read file contents +- grep: Search for patterns in files +- glob: Find files by pattern +- edit: Make targeted edits to files +- write: Create or overwrite files +- bash: Execute shell commands + +When you need to perform a task: +1. Use tools to gather information +2. Plan your approach +3. Execute changes using appropriate tools +4. Verify the results + +Always explain your reasoning before using tools. +When you're done, provide a clear summary of what was accomplished.]] + + M.run(vim.tbl_extend("force", opts, { + system_prompt = system_prompt, + user_input = task, + })) +end + +--- Simple dispatch agent for sub-tasks +---@param prompt string Task for the sub-agent +---@param on_complete fun(result: string|nil, error: string|nil) Completion callback +---@param opts? table Additional options +function M.dispatch(prompt, on_complete, opts) + opts = opts or {} + + -- Sub-agents get limited tools by default + local tools_mod = require("codetyper.agent.tools") + local safe_tools = tools_mod.list(function(tool) + return tool.name == "view" or tool.name == "grep" or tool.name == "glob" + end) + + M.run({ + system_prompt = [[You are a research assistant. Your task is to find information and report back. +You have access to: view (read files), grep (search content), glob (find files). +Be thorough and report your findings clearly.]], + user_input = prompt, + tools = opts.tools or safe_tools, + max_iterations = opts.max_iterations or 5, + on_complete = on_complete, + session_ctx = opts.session_ctx, + }) +end + +return M diff --git a/lua/codetyper/agent/patch.lua b/lua/codetyper/agent/patch.lua index b418b97..623d59f 100644 --- a/lua/codetyper/agent/patch.lua +++ b/lua/codetyper/agent/patch.lua @@ -2,10 +2,16 @@ ---@brief [[ --- Manages code patches with buffer snapshots for staleness detection. --- Patches are queued for safe injection when completion popup is not visible. +--- Uses smart injection for intelligent import merging. ---@brief ]] local M = {} +--- Lazy load inject module to avoid circular requires +local function get_inject_module() + return require("codetyper.agent.inject") +end + ---@class BufferSnapshot ---@field bufnr number Buffer number ---@field changedtick number vim.b.changedtick at snapshot time @@ -15,7 +21,8 @@ local M = {} ---@class PatchCandidate ---@field id string Unique patch ID ---@field event_id string Related PromptEvent ID ----@field target_bufnr number Target buffer for injection +---@field source_bufnr number Source buffer where prompt tags are (coder file) +---@field target_bufnr number Target buffer for injection (real file) ---@field target_path string Target file path ---@field original_snapshot BufferSnapshot Snapshot at event creation ---@field generated_code string Code to inject @@ -171,7 +178,10 @@ end ---@param strategy string|nil Injection strategy (overrides intent-based) ---@return PatchCandidate function M.create_from_event(event, generated_code, confidence, strategy) - -- Get target buffer + -- Source buffer is where the prompt tags are (could be coder file) + local source_bufnr = event.bufnr + + -- Get target buffer (where code should be injected - the real file) local target_bufnr = vim.fn.bufnr(event.target_path) if target_bufnr == -1 then -- Try to find by filename @@ -220,7 +230,8 @@ function M.create_from_event(event, generated_code, confidence, strategy) return { id = M.generate_id(), event_id = event.id, - target_bufnr = target_bufnr, + source_bufnr = source_bufnr, -- Where prompt tags are (coder file) + target_bufnr = target_bufnr, -- Where code goes (real file) target_path = event.target_path, original_snapshot = snapshot, generated_code = generated_code, @@ -453,39 +464,56 @@ function M.apply(patch) -- Prepare code lines local code_lines = vim.split(patch.generated_code, "\n", { plain = true }) - -- FIRST: Remove the prompt tags from the buffer before applying code - -- This prevents the infinite loop where tags stay and get re-detected - local tags_removed = remove_prompt_tags(target_bufnr) + -- FIRST: Remove the prompt tags from the SOURCE buffer (coder file), not target + -- The tags are in the coder file where the user wrote the prompt + -- Code goes to target file, tags get removed from source file + local source_bufnr = patch.source_bufnr + local tags_removed = 0 - pcall(function() - if tags_removed > 0 then - local logs = require("codetyper.agent.logs") - logs.add({ - type = "info", - message = string.format("Removed %d prompt tag(s) from buffer", tags_removed), - }) - end - end) + if source_bufnr and vim.api.nvim_buf_is_valid(source_bufnr) then + tags_removed = remove_prompt_tags(source_bufnr) - -- Recalculate line count after tag removal - local line_count = vim.api.nvim_buf_line_count(target_bufnr) + pcall(function() + if tags_removed > 0 then + local logs = require("codetyper.agent.logs") + local source_name = vim.api.nvim_buf_get_name(source_bufnr) + logs.add({ + type = "info", + message = string.format("Removed %d prompt tag(s) from %s", + tags_removed, + vim.fn.fnamemodify(source_name, ":t")), + }) + end + end) + end - -- Apply based on strategy + -- Get filetype for smart injection + local filetype = vim.fn.fnamemodify(patch.target_path or "", ":e") + + -- Use smart injection module for intelligent import handling + local inject = get_inject_module() + local inject_result = nil + + -- Apply based on strategy using smart injection local ok, err = pcall(function() + -- Prepare injection options + local inject_opts = { + strategy = patch.injection_strategy, + filetype = filetype, + sort_imports = true, + } + if patch.injection_strategy == "replace" and patch.injection_range then -- Replace the scope range with the new code - -- The injection_range points to the function/method we're completing local start_line = patch.injection_range.start_line local end_line = patch.injection_range.end_line -- Adjust for tag removal - find the new range by searching for the scope -- After removing tags, line numbers may have shifted - -- Use the scope information to find the correct range if patch.scope and patch.scope.type then -- Try to find the scope using treesitter if available local found_range = nil pcall(function() - local ts_utils = require("nvim-treesitter.ts_utils") local parsers = require("nvim-treesitter.parsers") local parser = parsers.get_parser(target_bufnr) if parser then @@ -528,34 +556,38 @@ function M.apply(patch) end -- Clamp to valid range + local line_count = vim.api.nvim_buf_line_count(target_bufnr) start_line = math.max(1, start_line) end_line = math.min(line_count, end_line) - -- Replace the range (0-indexed for nvim_buf_set_lines) - vim.api.nvim_buf_set_lines(target_bufnr, start_line - 1, end_line, false, code_lines) + inject_opts.range = { start_line = start_line, end_line = end_line } + elseif patch.injection_strategy == "insert" and patch.injection_range then + inject_opts.range = { start_line = patch.injection_range.start_line } + end - pcall(function() - local logs = require("codetyper.agent.logs") + -- Use smart injection - handles imports automatically + inject_result = inject.inject(target_bufnr, patch.generated_code, inject_opts) + + -- Log injection details + pcall(function() + local logs = require("codetyper.agent.logs") + if inject_result.imports_added > 0 then logs.add({ type = "info", - message = string.format("Replacing lines %d-%d with %d lines of code", start_line, end_line, #code_lines), + message = string.format( + "%s %d import(s), injected %d body line(s)", + inject_result.imports_merged and "Merged" or "Added", + inject_result.imports_added, + inject_result.body_lines + ), + }) + else + logs.add({ + type = "info", + message = string.format("Injected %d line(s) of code", inject_result.body_lines), }) - end) - elseif patch.injection_strategy == "insert" and patch.injection_range then - -- Insert at the specified location - local insert_line = patch.injection_range.start_line - insert_line = math.max(1, math.min(line_count + 1, insert_line)) - vim.api.nvim_buf_set_lines(target_bufnr, insert_line - 1, insert_line - 1, false, code_lines) - else - -- Default: append to end - -- Check if last line is empty, if not add a blank line for spacing - local last_line = vim.api.nvim_buf_get_lines(target_bufnr, line_count - 1, line_count, false)[1] or "" - if last_line:match("%S") then - -- Last line has content, add blank line for spacing - table.insert(code_lines, 1, "") end - vim.api.nvim_buf_set_lines(target_bufnr, line_count, line_count, false, code_lines) - end + end) end) if not ok then @@ -577,6 +609,41 @@ function M.apply(patch) }) end) + -- Learn from successful code generation - this builds neural pathways + -- The more code is successfully applied, the better the brain becomes + pcall(function() + local brain = require("codetyper.brain") + if brain.is_initialized() then + -- Learn the successful pattern + local intent_type = patch.intent and patch.intent.type or "unknown" + local scope_type = patch.scope and patch.scope.type or "file" + local scope_name = patch.scope and patch.scope.name or "" + + -- Create a meaningful summary for this learning + local summary = string.format( + "Generated %s: %s %s in %s", + intent_type, + scope_type, + scope_name ~= "" and scope_name or "", + vim.fn.fnamemodify(patch.target_path or "", ":t") + ) + + brain.learn({ + type = "code_completion", + file = patch.target_path, + timestamp = os.time(), + data = { + intent = intent_type, + code = patch.generated_code:sub(1, 500), -- Store first 500 chars + language = vim.fn.fnamemodify(patch.target_path or "", ":e"), + function_name = scope_name, + prompt = patch.prompt_content, + confidence = patch.confidence or 0.5, + }, + }) + end + end) + return true, nil end diff --git a/lua/codetyper/agent/tools/base.lua b/lua/codetyper/agent/tools/base.lua new file mode 100644 index 0000000..866ddf6 --- /dev/null +++ b/lua/codetyper/agent/tools/base.lua @@ -0,0 +1,128 @@ +---@mod codetyper.agent.tools.base Base tool definition +---@brief [[ +--- Base metatable for all LLM tools. +--- Tools extend this base to provide structured AI capabilities. +---@brief ]] + +---@class CoderToolParam +---@field name string Parameter name +---@field description string Parameter description +---@field type string Parameter type ("string", "number", "boolean", "table") +---@field optional? boolean Whether the parameter is optional +---@field default? any Default value for optional parameters + +---@class CoderToolReturn +---@field name string Return value name +---@field description string Return value description +---@field type string Return type +---@field optional? boolean Whether the return is optional + +---@class CoderToolOpts +---@field on_log? fun(message: string) Log callback +---@field on_complete? fun(result: any, error: string|nil) Completion callback +---@field session_ctx? table Session context +---@field streaming? boolean Whether response is still streaming +---@field confirm? fun(message: string, callback: fun(ok: boolean)) Confirmation callback + +---@class CoderTool +---@field name string Tool identifier +---@field description string|fun(): string Tool description +---@field params CoderToolParam[] Input parameters +---@field returns CoderToolReturn[] Return values +---@field requires_confirmation? boolean Whether tool needs user confirmation +---@field func fun(input: table, opts: CoderToolOpts): any, string|nil Tool implementation + +local M = {} +M.__index = M + +--- Call the tool function +---@param opts CoderToolOpts Options for the tool call +---@return any result +---@return string|nil error +function M:__call(opts, on_log, on_complete) + return self.func(opts, on_log, on_complete) +end + +--- Get the tool description +---@return string +function M:get_description() + if type(self.description) == "function" then + return self.description() + end + return self.description +end + +--- Validate input against parameter schema +---@param input table Input to validate +---@return boolean valid +---@return string|nil error +function M:validate_input(input) + if not self.params then + return true + end + + for _, param in ipairs(self.params) do + local value = input[param.name] + + -- Check required parameters + if not param.optional and value == nil then + return false, string.format("Missing required parameter: %s", param.name) + end + + -- Type checking + if value ~= nil then + local actual_type = type(value) + local expected_type = param.type + + -- Handle special types + if expected_type == "integer" and actual_type == "number" then + if math.floor(value) ~= value then + return false, string.format("Parameter %s must be an integer", param.name) + end + elseif expected_type ~= actual_type and expected_type ~= "any" then + return false, string.format("Parameter %s must be %s, got %s", param.name, expected_type, actual_type) + end + end + end + + return true +end + +--- Generate JSON schema for the tool (for LLM function calling) +---@return table schema +function M:to_schema() + local properties = {} + local required = {} + + for _, param in ipairs(self.params or {}) do + local prop = { + type = param.type == "integer" and "number" or param.type, + description = param.description, + } + + if param.default ~= nil then + prop.default = param.default + end + + properties[param.name] = prop + + if not param.optional then + table.insert(required, param.name) + end + end + + return { + type = "function", + function_def = { + name = self.name, + description = self:get_description(), + parameters = { + type = "object", + properties = properties, + required = required, + }, + }, + } +end + +return M diff --git a/lua/codetyper/agent/tools/bash.lua b/lua/codetyper/agent/tools/bash.lua new file mode 100644 index 0000000..4730584 --- /dev/null +++ b/lua/codetyper/agent/tools/bash.lua @@ -0,0 +1,198 @@ +---@mod codetyper.agent.tools.bash Shell command execution tool +---@brief [[ +--- Tool for executing shell commands with safety checks. +---@brief ]] + +local Base = require("codetyper.agent.tools.base") + +---@class CoderTool +local M = setmetatable({}, Base) + +M.name = "bash" + +M.description = [[Executes a bash command in a shell. + +IMPORTANT RULES: +- Do NOT use bash to read files (use 'view' tool instead) +- Do NOT use bash to modify files (use 'write' or 'edit' tools instead) +- Do NOT use interactive commands (vim, nano, less, etc.) +- Commands timeout after 2 minutes by default + +Allowed uses: +- Running builds (make, npm run build, cargo build) +- Running tests (npm test, pytest, cargo test) +- Git operations (git status, git diff, git commit) +- Package management (npm install, pip install) +- System info commands (ls, pwd, which)]] + +M.params = { + { + name = "command", + description = "The shell command to execute", + type = "string", + }, + { + name = "cwd", + description = "Working directory for the command (optional)", + type = "string", + optional = true, + }, + { + name = "timeout", + description = "Timeout in milliseconds (default: 120000)", + type = "integer", + optional = true, + }, +} + +M.returns = { + { + name = "stdout", + description = "Command output", + type = "string", + }, + { + name = "error", + description = "Error message if command failed", + type = "string", + optional = true, + }, +} + +M.requires_confirmation = true + +--- Banned commands for safety +local BANNED_COMMANDS = { + "rm -rf /", + "rm -rf /*", + "dd if=/dev/zero", + "mkfs", + ":(){ :|:& };:", + "> /dev/sda", +} + +--- Banned patterns +local BANNED_PATTERNS = { + "curl.*|.*sh", + "wget.*|.*sh", + "rm%s+%-rf%s+/", +} + +--- Check if command is safe +---@param command string +---@return boolean safe +---@return string|nil reason +local function is_safe_command(command) + -- Check exact matches + for _, banned in ipairs(BANNED_COMMANDS) do + if command == banned then + return false, "Command is banned for safety" + end + end + + -- Check patterns + for _, pattern in ipairs(BANNED_PATTERNS) do + if command:match(pattern) then + return false, "Command matches banned pattern" + end + end + + return true +end + +---@param input {command: string, cwd?: string, timeout?: integer} +---@param opts CoderToolOpts +---@return string|nil result +---@return string|nil error +function M.func(input, opts) + if not input.command then + return nil, "command is required" + end + + -- Safety check + local safe, reason = is_safe_command(input.command) + if not safe then + return nil, reason + end + + -- Confirmation required + if M.requires_confirmation and opts.confirm then + local confirmed = false + local confirm_error = nil + + opts.confirm("Execute command: " .. input.command, function(ok) + if not ok then + confirm_error = "User declined command execution" + end + confirmed = ok + end) + + -- Wait for confirmation (in async context, this would be handled differently) + if confirm_error then + return nil, confirm_error + end + end + + -- Log the operation + if opts.on_log then + opts.on_log("Executing: " .. input.command) + end + + -- Prepare command + local cwd = input.cwd or vim.fn.getcwd() + local timeout = input.timeout or 120000 + + -- Execute command + local output = "" + local exit_code = 0 + + local job_opts = { + command = "bash", + args = { "-c", input.command }, + cwd = cwd, + on_stdout = function(_, data) + if data then + output = output .. table.concat(data, "\n") + end + end, + on_stderr = function(_, data) + if data then + output = output .. table.concat(data, "\n") + end + end, + on_exit = function(_, code) + exit_code = code + end, + } + + -- Run synchronously with timeout + local Job = require("plenary.job") + local job = Job:new(job_opts) + + job:sync(timeout) + exit_code = job.code or 0 + output = table.concat(job:result() or {}, "\n") + + -- Also get stderr + local stderr = table.concat(job:stderr_result() or {}, "\n") + if stderr and stderr ~= "" then + output = output .. "\n" .. stderr + end + + -- Check result + if exit_code ~= 0 then + local error_msg = string.format("Command failed with exit code %d: %s", exit_code, output) + if opts.on_complete then + opts.on_complete(nil, error_msg) + end + return nil, error_msg + end + + if opts.on_complete then + opts.on_complete(output, nil) + end + + return output, nil +end + +return M diff --git a/lua/codetyper/agent/tools/edit.lua b/lua/codetyper/agent/tools/edit.lua new file mode 100644 index 0000000..b5ef998 --- /dev/null +++ b/lua/codetyper/agent/tools/edit.lua @@ -0,0 +1,429 @@ +---@mod codetyper.agent.tools.edit File editing tool with fallback matching +---@brief [[ +--- Tool for making targeted edits to files using search/replace. +--- Implements multiple fallback strategies for robust matching. +--- Inspired by opencode's 9-strategy approach. +---@brief ]] + +local Base = require("codetyper.agent.tools.base") + +---@class CoderTool +local M = setmetatable({}, Base) + +M.name = "edit" + +M.description = [[Makes a targeted edit to a file by replacing text. + +The old_string should match the content you want to replace. The tool uses multiple +matching strategies with fallbacks: +1. Exact match +2. Whitespace-normalized match +3. Indentation-flexible match +4. Line-trimmed match +5. Fuzzy anchor-based match + +For creating new files, use old_string="" and provide the full content in new_string. +For large changes, consider using 'write' tool instead.]] + +M.params = { + { + name = "path", + description = "Path to the file to edit", + type = "string", + }, + { + name = "old_string", + description = "Text to find and replace (empty string to create new file or append)", + type = "string", + }, + { + name = "new_string", + description = "Text to replace with", + type = "string", + }, +} + +M.returns = { + { + name = "success", + description = "Whether the edit was applied", + type = "boolean", + }, + { + name = "error", + description = "Error message if edit failed", + type = "string", + optional = true, + }, +} + +M.requires_confirmation = false + +--- Normalize line endings to LF +---@param str string +---@return string +local function normalize_line_endings(str) + return str:gsub("\r\n", "\n"):gsub("\r", "\n") +end + +--- Strategy 1: Exact match +---@param content string File content +---@param old_str string String to find +---@return number|nil start_pos +---@return number|nil end_pos +local function exact_match(content, old_str) + local pos = content:find(old_str, 1, true) + if pos then + return pos, pos + #old_str - 1 + end + return nil, nil +end + +--- Strategy 2: Whitespace-normalized match +--- Collapses all whitespace to single spaces +---@param content string +---@param old_str string +---@return number|nil start_pos +---@return number|nil end_pos +local function whitespace_normalized_match(content, old_str) + local function normalize_ws(s) + return s:gsub("%s+", " "):gsub("^%s+", ""):gsub("%s+$", "") + end + + local norm_old = normalize_ws(old_str) + local lines = vim.split(content, "\n") + + -- Try to find matching block + for i = 1, #lines do + local block = {} + local block_start = nil + + for j = i, #lines do + table.insert(block, lines[j]) + local block_text = table.concat(block, "\n") + local norm_block = normalize_ws(block_text) + + if norm_block == norm_old then + -- Found match + local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n") + local start_pos = #before + (i > 1 and 2 or 1) + local end_pos = start_pos + #block_text - 1 + return start_pos, end_pos + end + + -- If block is already longer than target, stop + if #norm_block > #norm_old then + break + end + end + end + + return nil, nil +end + +--- Strategy 3: Indentation-flexible match +--- Ignores leading whitespace differences +---@param content string +---@param old_str string +---@return number|nil start_pos +---@return number|nil end_pos +local function indentation_flexible_match(content, old_str) + local function strip_indent(s) + local lines = vim.split(s, "\n") + local result = {} + for _, line in ipairs(lines) do + table.insert(result, line:gsub("^%s+", "")) + end + return table.concat(result, "\n") + end + + local stripped_old = strip_indent(old_str) + local lines = vim.split(content, "\n") + local old_lines = vim.split(old_str, "\n") + local num_old_lines = #old_lines + + for i = 1, #lines - num_old_lines + 1 do + local block = vim.list_slice(lines, i, i + num_old_lines - 1) + local block_text = table.concat(block, "\n") + + if strip_indent(block_text) == stripped_old then + local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n") + local start_pos = #before + (i > 1 and 2 or 1) + local end_pos = start_pos + #block_text - 1 + return start_pos, end_pos + end + end + + return nil, nil +end + +--- Strategy 4: Line-trimmed match +--- Trims each line before comparing +---@param content string +---@param old_str string +---@return number|nil start_pos +---@return number|nil end_pos +local function line_trimmed_match(content, old_str) + local function trim_lines(s) + local lines = vim.split(s, "\n") + local result = {} + for _, line in ipairs(lines) do + table.insert(result, line:match("^%s*(.-)%s*$")) + end + return table.concat(result, "\n") + end + + local trimmed_old = trim_lines(old_str) + local lines = vim.split(content, "\n") + local old_lines = vim.split(old_str, "\n") + local num_old_lines = #old_lines + + for i = 1, #lines - num_old_lines + 1 do + local block = vim.list_slice(lines, i, i + num_old_lines - 1) + local block_text = table.concat(block, "\n") + + if trim_lines(block_text) == trimmed_old then + local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n") + local start_pos = #before + (i > 1 and 2 or 1) + local end_pos = start_pos + #block_text - 1 + return start_pos, end_pos + end + end + + return nil, nil +end + +--- Calculate Levenshtein distance between two strings +---@param s1 string +---@param s2 string +---@return number +local function levenshtein(s1, s2) + local len1, len2 = #s1, #s2 + local matrix = {} + + for i = 0, len1 do + matrix[i] = { [0] = i } + end + for j = 0, len2 do + matrix[0][j] = j + end + + for i = 1, len1 do + for j = 1, len2 do + local cost = s1:sub(i, i) == s2:sub(j, j) and 0 or 1 + matrix[i][j] = math.min( + matrix[i - 1][j] + 1, + matrix[i][j - 1] + 1, + matrix[i - 1][j - 1] + cost + ) + end + end + + return matrix[len1][len2] +end + +--- Strategy 5: Fuzzy anchor-based match +--- Uses first and last lines as anchors, allows fuzzy matching in between +---@param content string +---@param old_str string +---@param threshold? number Similarity threshold (0-1), default 0.8 +---@return number|nil start_pos +---@return number|nil end_pos +local function fuzzy_anchor_match(content, old_str, threshold) + threshold = threshold or 0.8 + + local old_lines = vim.split(old_str, "\n") + if #old_lines < 2 then + return nil, nil + end + + local first_line = old_lines[1]:match("^%s*(.-)%s*$") + local last_line = old_lines[#old_lines]:match("^%s*(.-)%s*$") + local content_lines = vim.split(content, "\n") + + -- Find potential start positions + local candidates = {} + for i, line in ipairs(content_lines) do + local trimmed = line:match("^%s*(.-)%s*$") + if trimmed == first_line or ( + #first_line > 0 and + 1 - (levenshtein(trimmed, first_line) / math.max(#trimmed, #first_line)) >= threshold + ) then + table.insert(candidates, i) + end + end + + -- For each candidate, look for matching end + for _, start_idx in ipairs(candidates) do + local expected_end = start_idx + #old_lines - 1 + if expected_end <= #content_lines then + local end_line = content_lines[expected_end]:match("^%s*(.-)%s*$") + if end_line == last_line or ( + #last_line > 0 and + 1 - (levenshtein(end_line, last_line) / math.max(#end_line, #last_line)) >= threshold + ) then + -- Calculate positions + local before = table.concat(vim.list_slice(content_lines, 1, start_idx - 1), "\n") + local block = table.concat(vim.list_slice(content_lines, start_idx, expected_end), "\n") + local start_pos = #before + (start_idx > 1 and 2 or 1) + local end_pos = start_pos + #block - 1 + return start_pos, end_pos + end + end + end + + return nil, nil +end + +--- Try all matching strategies in order +---@param content string File content +---@param old_str string String to find +---@return number|nil start_pos +---@return number|nil end_pos +---@return string strategy_used +local function find_match(content, old_str) + -- Strategy 1: Exact match + local start_pos, end_pos = exact_match(content, old_str) + if start_pos then + return start_pos, end_pos, "exact" + end + + -- Strategy 2: Whitespace-normalized + start_pos, end_pos = whitespace_normalized_match(content, old_str) + if start_pos then + return start_pos, end_pos, "whitespace_normalized" + end + + -- Strategy 3: Indentation-flexible + start_pos, end_pos = indentation_flexible_match(content, old_str) + if start_pos then + return start_pos, end_pos, "indentation_flexible" + end + + -- Strategy 4: Line-trimmed + start_pos, end_pos = line_trimmed_match(content, old_str) + if start_pos then + return start_pos, end_pos, "line_trimmed" + end + + -- Strategy 5: Fuzzy anchor + start_pos, end_pos = fuzzy_anchor_match(content, old_str) + if start_pos then + return start_pos, end_pos, "fuzzy_anchor" + end + + return nil, nil, "none" +end + +---@param input {path: string, old_string: string, new_string: string} +---@param opts CoderToolOpts +---@return boolean|nil result +---@return string|nil error +function M.func(input, opts) + if not input.path then + return nil, "path is required" + end + if input.old_string == nil then + return nil, "old_string is required" + end + if input.new_string == nil then + return nil, "new_string is required" + end + + -- Log the operation + if opts.on_log then + opts.on_log("Editing file: " .. input.path) + end + + -- Resolve path + local path = input.path + if not vim.startswith(path, "/") then + path = vim.fn.getcwd() .. "/" .. path + end + + -- Normalize inputs + local old_str = normalize_line_endings(input.old_string) + local new_str = normalize_line_endings(input.new_string) + + -- Handle new file creation (empty old_string) + if old_str == "" then + -- Create parent directories + local dir = vim.fn.fnamemodify(path, ":h") + if vim.fn.isdirectory(dir) == 0 then + vim.fn.mkdir(dir, "p") + end + + -- Write new file + local lines = vim.split(new_str, "\n", { plain = true }) + local ok = pcall(vim.fn.writefile, lines, path) + + if not ok then + return nil, "Failed to create file: " .. input.path + end + + -- Reload buffer if open + local bufnr = vim.fn.bufnr(path) + if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then + vim.api.nvim_buf_call(bufnr, function() + vim.cmd("edit!") + end) + end + + if opts.on_complete then + opts.on_complete(true, nil) + end + + return true, nil + end + + -- Check if file exists + if vim.fn.filereadable(path) ~= 1 then + return nil, "File not found: " .. input.path + end + + -- Read current content + local lines = vim.fn.readfile(path) + if not lines then + return nil, "Failed to read file: " .. input.path + end + + local content = normalize_line_endings(table.concat(lines, "\n")) + + -- Find match using fallback strategies + local start_pos, end_pos, strategy = find_match(content, old_str) + + if not start_pos then + return nil, "old_string not found in file (tried 5 matching strategies)" + end + + if opts.on_log then + opts.on_log("Match found using strategy: " .. strategy) + end + + -- Perform replacement + local new_content = content:sub(1, start_pos - 1) .. new_str .. content:sub(end_pos + 1) + + -- Write back + local new_lines = vim.split(new_content, "\n", { plain = true }) + local ok = pcall(vim.fn.writefile, new_lines, path) + + if not ok then + return nil, "Failed to write file: " .. input.path + end + + -- Reload buffer if open + local bufnr = vim.fn.bufnr(path) + if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then + vim.api.nvim_buf_call(bufnr, function() + vim.cmd("edit!") + end) + end + + if opts.on_complete then + opts.on_complete(true, nil) + end + + return true, nil +end + +return M diff --git a/lua/codetyper/agent/tools/glob.lua b/lua/codetyper/agent/tools/glob.lua new file mode 100644 index 0000000..7ede066 --- /dev/null +++ b/lua/codetyper/agent/tools/glob.lua @@ -0,0 +1,146 @@ +---@mod codetyper.agent.tools.glob File pattern matching tool +---@brief [[ +--- Tool for finding files by glob pattern. +---@brief ]] + +local Base = require("codetyper.agent.tools.base") + +---@class CoderTool +local M = setmetatable({}, Base) + +M.name = "glob" + +M.description = [[Finds files matching a glob pattern. + +Example patterns: +- "**/*.lua" - All Lua files +- "src/**/*.ts" - TypeScript files in src +- "**/test_*.py" - Test files in Python]] + +M.params = { + { + name = "pattern", + description = "Glob pattern to match files", + type = "string", + }, + { + name = "path", + description = "Base directory to search in (default: project root)", + type = "string", + optional = true, + }, + { + name = "max_results", + description = "Maximum number of results (default: 100)", + type = "integer", + optional = true, + }, +} + +M.returns = { + { + name = "matches", + description = "JSON array of matching file paths", + type = "string", + }, + { + name = "error", + description = "Error message if glob failed", + type = "string", + optional = true, + }, +} + +M.requires_confirmation = false + +---@param input {pattern: string, path?: string, max_results?: integer} +---@param opts CoderToolOpts +---@return string|nil result +---@return string|nil error +function M.func(input, opts) + if not input.pattern then + return nil, "pattern is required" + end + + -- Log the operation + if opts.on_log then + opts.on_log("Finding files: " .. input.pattern) + end + + -- Resolve base path + local base_path = input.path or vim.fn.getcwd() + if not vim.startswith(base_path, "/") then + base_path = vim.fn.getcwd() .. "/" .. base_path + end + + local max_results = input.max_results or 100 + + -- Use vim.fn.glob or fd if available + local matches = {} + + if vim.fn.executable("fd") == 1 then + -- Use fd for better performance + local Job = require("plenary.job") + + -- Convert glob to fd pattern + local fd_pattern = input.pattern:gsub("%*%*/", ""):gsub("%*", ".*") + + local job = Job:new({ + command = "fd", + args = { + "--type", + "f", + "--max-results", + tostring(max_results), + "--glob", + input.pattern, + base_path, + }, + cwd = base_path, + }) + + job:sync(30000) + matches = job:result() or {} + else + -- Fallback to vim.fn.globpath + local pattern = base_path .. "/" .. input.pattern + local files = vim.fn.glob(pattern, false, true) + + for i, file in ipairs(files) do + if i > max_results then + break + end + -- Make paths relative to base_path + local relative = file:gsub("^" .. vim.pesc(base_path) .. "/", "") + table.insert(matches, relative) + end + end + + -- Clean up matches + local cleaned = {} + for _, match in ipairs(matches) do + if match and match ~= "" then + -- Make relative if absolute + local relative = match + if vim.startswith(match, base_path) then + relative = match:sub(#base_path + 2) + end + table.insert(cleaned, relative) + end + end + + -- Return as JSON + local result = vim.json.encode({ + matches = cleaned, + total = #cleaned, + truncated = #cleaned >= max_results, + }) + + if opts.on_complete then + opts.on_complete(result, nil) + end + + return result, nil +end + +return M diff --git a/lua/codetyper/agent/tools/grep.lua b/lua/codetyper/agent/tools/grep.lua new file mode 100644 index 0000000..b1396df --- /dev/null +++ b/lua/codetyper/agent/tools/grep.lua @@ -0,0 +1,150 @@ +---@mod codetyper.agent.tools.grep Search tool +---@brief [[ +--- Tool for searching file contents using ripgrep. +---@brief ]] + +local Base = require("codetyper.agent.tools.base") + +---@class CoderTool +local M = setmetatable({}, Base) + +M.name = "grep" + +M.description = [[Searches for a pattern in files using ripgrep. + +Returns file paths and matching lines. Use this to find code by content. + +Example patterns: +- "function foo" - Find function definitions +- "import.*react" - Find React imports +- "TODO|FIXME" - Find todo comments]] + +M.params = { + { + name = "pattern", + description = "Regular expression pattern to search for", + type = "string", + }, + { + name = "path", + description = "Directory or file to search in (default: project root)", + type = "string", + optional = true, + }, + { + name = "include", + description = "File glob pattern to include (e.g., '*.lua')", + type = "string", + optional = true, + }, + { + name = "max_results", + description = "Maximum number of results (default: 50)", + type = "integer", + optional = true, + }, +} + +M.returns = { + { + name = "matches", + description = "JSON array of matches with file, line_number, and content", + type = "string", + }, + { + name = "error", + description = "Error message if search failed", + type = "string", + optional = true, + }, +} + +M.requires_confirmation = false + +---@param input {pattern: string, path?: string, include?: string, max_results?: integer} +---@param opts CoderToolOpts +---@return string|nil result +---@return string|nil error +function M.func(input, opts) + if not input.pattern then + return nil, "pattern is required" + end + + -- Log the operation + if opts.on_log then + opts.on_log("Searching for: " .. input.pattern) + end + + -- Build ripgrep command + local path = input.path or vim.fn.getcwd() + local max_results = input.max_results or 50 + + -- Resolve path + if not vim.startswith(path, "/") then + path = vim.fn.getcwd() .. "/" .. path + end + + -- Check if ripgrep is available + if vim.fn.executable("rg") ~= 1 then + return nil, "ripgrep (rg) is not installed" + end + + -- Build command args + local args = { + "--json", + "--max-count", + tostring(max_results), + "--no-heading", + } + + if input.include then + table.insert(args, "--glob") + table.insert(args, input.include) + end + + table.insert(args, input.pattern) + table.insert(args, path) + + -- Execute ripgrep + local Job = require("plenary.job") + local job = Job:new({ + command = "rg", + args = args, + cwd = vim.fn.getcwd(), + }) + + job:sync(30000) -- 30 second timeout + + local results = job:result() or {} + local matches = {} + + -- Parse JSON output + for _, line in ipairs(results) do + if line and line ~= "" then + local ok, parsed = pcall(vim.json.decode, line) + if ok and parsed.type == "match" then + local data = parsed.data + table.insert(matches, { + file = data.path.text, + line_number = data.line_number, + content = data.lines.text:gsub("\n$", ""), + }) + end + end + end + + -- Return as JSON + local result = vim.json.encode({ + matches = matches, + total = #matches, + truncated = #matches >= max_results, + }) + + if opts.on_complete then + opts.on_complete(result, nil) + end + + return result, nil +end + +return M diff --git a/lua/codetyper/agent/tools/init.lua b/lua/codetyper/agent/tools/init.lua new file mode 100644 index 0000000..386ece6 --- /dev/null +++ b/lua/codetyper/agent/tools/init.lua @@ -0,0 +1,308 @@ +---@mod codetyper.agent.tools Tool registry and orchestration +---@brief [[ +--- Registry for LLM tools with execution and schema generation. +--- Inspired by avante.nvim's tool system. +---@brief ]] + +local M = {} + +--- Registered tools +---@type table +local tools = {} + +--- Tool execution history for current session +---@type table[] +local execution_history = {} + +--- Register a tool +---@param tool CoderTool Tool to register +function M.register(tool) + if not tool.name then + error("Tool must have a name") + end + tools[tool.name] = tool +end + +--- Unregister a tool +---@param name string Tool name +function M.unregister(name) + tools[name] = nil +end + +--- Get a tool by name +---@param name string Tool name +---@return CoderTool|nil +function M.get(name) + return tools[name] +end + +--- Get all registered tools +---@return table +function M.get_all() + return tools +end + +--- Get tools as a list +---@param filter? fun(tool: CoderTool): boolean Optional filter function +---@return CoderTool[] +function M.list(filter) + local result = {} + for _, tool in pairs(tools) do + if not filter or filter(tool) then + table.insert(result, tool) + end + end + return result +end + +--- Generate schemas for all tools (for LLM function calling) +---@param filter? fun(tool: CoderTool): boolean Optional filter function +---@return table[] schemas +function M.get_schemas(filter) + local schemas = {} + for _, tool in pairs(tools) do + if not filter or filter(tool) then + if tool.to_schema then + table.insert(schemas, tool:to_schema()) + end + end + end + return schemas +end + +--- Execute a tool by name +---@param name string Tool name +---@param input table Input parameters +---@param opts CoderToolOpts Execution options +---@return any result +---@return string|nil error +function M.execute(name, input, opts) + local tool = tools[name] + if not tool then + return nil, "Unknown tool: " .. name + end + + -- Validate input + if tool.validate_input then + local valid, err = tool:validate_input(input) + if not valid then + return nil, err + end + end + + -- Log execution + if opts.on_log then + opts.on_log(string.format("Executing tool: %s", name)) + end + + -- Track execution + local execution = { + tool = name, + input = input, + start_time = os.time(), + status = "running", + } + table.insert(execution_history, execution) + + -- Execute the tool + local result, err = tool.func(input, opts) + + -- Update execution record + execution.end_time = os.time() + execution.status = err and "error" or "completed" + execution.result = result + execution.error = err + + return result, err +end + +--- Process a tool call from LLM response +---@param tool_call table Tool call from LLM (name + input) +---@param opts CoderToolOpts Execution options +---@return any result +---@return string|nil error +function M.process_tool_call(tool_call, opts) + local name = tool_call.name or tool_call.function_name + local input = tool_call.input or tool_call.arguments or {} + + -- Parse JSON arguments if string + if type(input) == "string" then + local ok, parsed = pcall(vim.json.decode, input) + if ok then + input = parsed + else + return nil, "Failed to parse tool arguments: " .. input + end + end + + return M.execute(name, input, opts) +end + +--- Get execution history +---@param limit? number Max entries to return +---@return table[] +function M.get_history(limit) + if not limit then + return execution_history + end + + local result = {} + local start = math.max(1, #execution_history - limit + 1) + for i = start, #execution_history do + table.insert(result, execution_history[i]) + end + return result +end + +--- Clear execution history +function M.clear_history() + execution_history = {} +end + +--- Load built-in tools +function M.load_builtins() + -- View file tool + local view = require("codetyper.agent.tools.view") + M.register(view) + + -- Bash tool + local bash = require("codetyper.agent.tools.bash") + M.register(bash) + + -- Grep tool + local grep = require("codetyper.agent.tools.grep") + M.register(grep) + + -- Glob tool + local glob = require("codetyper.agent.tools.glob") + M.register(glob) + + -- Write file tool + local write = require("codetyper.agent.tools.write") + M.register(write) + + -- Edit tool + local edit = require("codetyper.agent.tools.edit") + M.register(edit) +end + +--- Initialize tools system +function M.setup() + M.load_builtins() +end + +--- Get tool definitions for LLM (lazy-loaded, OpenAI format) +--- This is accessed as M.definitions property +M.definitions = setmetatable({}, { + __call = function() + -- Ensure tools are loaded + if vim.tbl_count(tools) == 0 then + M.load_builtins() + end + return M.to_openai_format() + end, + __index = function(_, key) + -- Make it work as both function and table + if key == "get" then + return function() + if vim.tbl_count(tools) == 0 then + M.load_builtins() + end + return M.to_openai_format() + end + end + return nil + end, +}) + +--- Get definitions as a function (for backwards compatibility) +function M.get_definitions() + if vim.tbl_count(tools) == 0 then + M.load_builtins() + end + return M.to_openai_format() +end + +--- Convert all tools to OpenAI function calling format +---@param filter? fun(tool: CoderTool): boolean Optional filter function +---@return table[] OpenAI-compatible tool definitions +function M.to_openai_format(filter) + local openai_tools = {} + + for _, tool in pairs(tools) do + if not filter or filter(tool) then + local properties = {} + local required = {} + + for _, param in ipairs(tool.params or {}) do + properties[param.name] = { + type = param.type == "integer" and "number" or param.type, + description = param.description, + } + if param.default ~= nil then + properties[param.name].default = param.default + end + if not param.optional then + table.insert(required, param.name) + end + end + + local description = type(tool.description) == "function" and tool.description() or tool.description + + table.insert(openai_tools, { + type = "function", + ["function"] = { + name = tool.name, + description = description, + parameters = { + type = "object", + properties = properties, + required = required, + }, + }, + }) + end + end + + return openai_tools +end + +--- Convert all tools to Claude tool use format +---@param filter? fun(tool: CoderTool): boolean Optional filter function +---@return table[] Claude-compatible tool definitions +function M.to_claude_format(filter) + local claude_tools = {} + + for _, tool in pairs(tools) do + if not filter or filter(tool) then + local properties = {} + local required = {} + + for _, param in ipairs(tool.params or {}) do + properties[param.name] = { + type = param.type == "integer" and "number" or param.type, + description = param.description, + } + if not param.optional then + table.insert(required, param.name) + end + end + + local description = type(tool.description) == "function" and tool.description() or tool.description + + table.insert(claude_tools, { + name = tool.name, + description = description, + input_schema = { + type = "object", + properties = properties, + required = required, + }, + }) + end + end + + return claude_tools +end + +return M diff --git a/lua/codetyper/agent/tools/view.lua b/lua/codetyper/agent/tools/view.lua new file mode 100644 index 0000000..8e52feb --- /dev/null +++ b/lua/codetyper/agent/tools/view.lua @@ -0,0 +1,149 @@ +---@mod codetyper.agent.tools.view File viewing tool +---@brief [[ +--- Tool for reading file contents with line range support. +---@brief ]] + +local Base = require("codetyper.agent.tools.base") + +---@class CoderTool +local M = setmetatable({}, Base) + +M.name = "view" + +M.description = [[Reads the content of a file. + +Usage notes: +- Provide the file path relative to the project root +- Use start_line and end_line to read specific sections +- If content is truncated, use line ranges to read in chunks +- Returns JSON with content, total_line_count, and is_truncated]] + +M.params = { + { + name = "path", + description = "Path to the file (relative to project root or absolute)", + type = "string", + }, + { + name = "start_line", + description = "Line number to start reading (1-indexed)", + type = "integer", + optional = true, + }, + { + name = "end_line", + description = "Line number to end reading (1-indexed, inclusive)", + type = "integer", + optional = true, + }, +} + +M.returns = { + { + name = "content", + description = "File contents as JSON with content, total_line_count, is_truncated", + type = "string", + }, + { + name = "error", + description = "Error message if file could not be read", + type = "string", + optional = true, + }, +} + +M.requires_confirmation = false + +--- Maximum content size before truncation +local MAX_CONTENT_SIZE = 200 * 1024 -- 200KB + +---@param input {path: string, start_line?: integer, end_line?: integer} +---@param opts CoderToolOpts +---@return string|nil result +---@return string|nil error +function M.func(input, opts) + if not input.path then + return nil, "path is required" + end + + -- Log the operation + if opts.on_log then + opts.on_log("Reading file: " .. input.path) + end + + -- Resolve path + local path = input.path + if not vim.startswith(path, "/") then + -- Relative path - resolve from project root + local root = vim.fn.getcwd() + path = root .. "/" .. path + end + + -- Check if file exists + local stat = vim.uv.fs_stat(path) + if not stat then + return nil, "File not found: " .. input.path + end + + if stat.type == "directory" then + return nil, "Path is a directory: " .. input.path + end + + -- Read file + local lines = vim.fn.readfile(path) + if not lines then + return nil, "Failed to read file: " .. input.path + end + + -- Apply line range + local start_line = input.start_line or 1 + local end_line = input.end_line or #lines + + start_line = math.max(1, start_line) + end_line = math.min(#lines, end_line) + + local total_lines = #lines + local selected_lines = {} + + for i = start_line, end_line do + table.insert(selected_lines, lines[i]) + end + + -- Check for truncation + local content = table.concat(selected_lines, "\n") + local is_truncated = false + + if #content > MAX_CONTENT_SIZE then + -- Truncate content + local truncated_lines = {} + local size = 0 + + for _, line in ipairs(selected_lines) do + size = size + #line + 1 + if size > MAX_CONTENT_SIZE then + is_truncated = true + break + end + table.insert(truncated_lines, line) + end + + content = table.concat(truncated_lines, "\n") + end + + -- Return as JSON + local result = vim.json.encode({ + content = content, + total_line_count = total_lines, + is_truncated = is_truncated, + start_line = start_line, + end_line = end_line, + }) + + if opts.on_complete then + opts.on_complete(result, nil) + end + + return result, nil +end + +return M diff --git a/lua/codetyper/agent/tools/write.lua b/lua/codetyper/agent/tools/write.lua new file mode 100644 index 0000000..e410a56 --- /dev/null +++ b/lua/codetyper/agent/tools/write.lua @@ -0,0 +1,101 @@ +---@mod codetyper.agent.tools.write File writing tool +---@brief [[ +--- Tool for creating or overwriting files. +---@brief ]] + +local Base = require("codetyper.agent.tools.base") + +---@class CoderTool +local M = setmetatable({}, Base) + +M.name = "write" + +M.description = [[Creates or overwrites a file with new content. + +IMPORTANT: +- This will completely replace the file contents +- Use 'edit' tool for partial modifications +- Parent directories will be created if needed]] + +M.params = { + { + name = "path", + description = "Path to the file to write", + type = "string", + }, + { + name = "content", + description = "Content to write to the file", + type = "string", + }, +} + +M.returns = { + { + name = "success", + description = "Whether the file was written successfully", + type = "boolean", + }, + { + name = "error", + description = "Error message if write failed", + type = "string", + optional = true, + }, +} + +M.requires_confirmation = true + +---@param input {path: string, content: string} +---@param opts CoderToolOpts +---@return boolean|nil result +---@return string|nil error +function M.func(input, opts) + if not input.path then + return nil, "path is required" + end + if not input.content then + return nil, "content is required" + end + + -- Log the operation + if opts.on_log then + opts.on_log("Writing file: " .. input.path) + end + + -- Resolve path + local path = input.path + if not vim.startswith(path, "/") then + path = vim.fn.getcwd() .. "/" .. path + end + + -- Create parent directories + local dir = vim.fn.fnamemodify(path, ":h") + if vim.fn.isdirectory(dir) == 0 then + vim.fn.mkdir(dir, "p") + end + + -- Write the file + local lines = vim.split(input.content, "\n", { plain = true }) + local ok = pcall(vim.fn.writefile, lines, path) + + if not ok then + return nil, "Failed to write file: " .. path + end + + -- Reload buffer if open + local bufnr = vim.fn.bufnr(path) + if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then + vim.api.nvim_buf_call(bufnr, function() + vim.cmd("edit!") + end) + end + + if opts.on_complete then + opts.on_complete(true, nil) + end + + return true, nil +end + +return M diff --git a/lua/codetyper/agent/worker.lua b/lua/codetyper/agent/worker.lua index 2af866c..cc3601e 100644 --- a/lua/codetyper/agent/worker.lua +++ b/lua/codetyper/agent/worker.lua @@ -224,6 +224,86 @@ local function format_attached_files(attached_files) return table.concat(parts, "") end +--- Get coder companion file path for a target file +---@param target_path string Target file path +---@return string|nil Coder file path if exists +local function get_coder_companion_path(target_path) + if not target_path or target_path == "" then + return nil + end + + -- Skip if target is already a coder file + if target_path:match("%.coder%.") then + return nil + end + + local dir = vim.fn.fnamemodify(target_path, ":h") + local name = vim.fn.fnamemodify(target_path, ":t:r") -- filename without extension + local ext = vim.fn.fnamemodify(target_path, ":e") + + local coder_path = dir .. "/" .. name .. ".coder." .. ext + if vim.fn.filereadable(coder_path) == 1 then + return coder_path + end + + return nil +end + +--- Read and format coder companion context (business logic, pseudo-code) +---@param target_path string Target file path +---@return string Formatted coder context +local function get_coder_context(target_path) + local coder_path = get_coder_companion_path(target_path) + if not coder_path then + return "" + end + + local ok, lines = pcall(function() + return vim.fn.readfile(coder_path) + end) + + if not ok or not lines or #lines == 0 then + return "" + end + + local content = table.concat(lines, "\n") + + -- Skip if only template comments (no actual content) + local stripped = content:gsub("^%s*", ""):gsub("%s*$", "") + if stripped == "" then + return "" + end + + -- Check if there's meaningful content (not just template) + local has_content = false + for _, line in ipairs(lines) do + -- Skip comment lines that are part of the template + local trimmed = line:gsub("^%s*", "") + if not trimmed:match("^[%-#/]+%s*Coder companion") + and not trimmed:match("^[%-#/]+%s*Use /@ @/") + and not trimmed:match("^[%-#/]+%s*Example:") + and not trimmed:match("^ 0 then + local memories = { "\n\n--- Learned Patterns & Conventions ---" } + for _, node in ipairs(result.nodes) do + if node.c then + local summary = node.c.s or "" + local detail = node.c.d or "" + if summary ~= "" then + table.insert(memories, "• " .. summary) + if detail ~= "" and #detail < 200 then + table.insert(memories, " " .. detail) + end + end + end + end + if #memories > 1 then + brain_context = table.concat(memories, "\n") + end + end + end + end) + + -- Combine all context sources: brain memories first, then coder context, attached files, indexed + local extra_context = brain_context .. coder_context .. attached_content .. indexed_content -- Build context with scope information local context = { @@ -502,21 +627,21 @@ function M.start(worker) end end, worker.timeout_ms) - -- Get client and execute - local client, client_err = get_client(worker.worker_type) - if not client then - M.complete(worker, nil, client_err) - return - end - local prompt, context = build_prompt(worker.event) - -- Call the LLM - client.generate(prompt, context, function(response, err, usage) + -- Check if smart selection is enabled (memory-based provider selection) + local use_smart_selection = false + pcall(function() + local codetyper = require("codetyper") + local config = codetyper.get_config() + use_smart_selection = config.llm.smart_selection ~= false -- Default to true + end) + + -- Define the response handler + local function handle_response(response, err, usage_or_metadata) -- Cancel timeout timer if worker.timer then pcall(function() - -- Timer might have already fired if type(worker.timer) == "userdata" and worker.timer.stop then worker.timer:stop() end @@ -527,8 +652,45 @@ function M.start(worker) return -- Already timed out or cancelled end + -- Extract usage from metadata if smart_generate was used + local usage = usage_or_metadata + if type(usage_or_metadata) == "table" and usage_or_metadata.provider then + -- This is metadata from smart_generate + usage = nil + -- Update worker type to reflect actual provider used + worker.worker_type = usage_or_metadata.provider + -- Log if pondering occurred + if usage_or_metadata.pondered then + pcall(function() + local logs = require("codetyper.agent.logs") + logs.add({ + type = "info", + message = string.format( + "Pondering: %s (agreement: %.0f%%)", + usage_or_metadata.corrected and "corrected" or "validated", + (usage_or_metadata.agreement or 1) * 100 + ), + }) + end) + end + end + M.complete(worker, response, err, usage) - end) + end + + -- Use smart selection or direct client + if use_smart_selection then + local llm = require("codetyper.llm") + llm.smart_generate(prompt, context, handle_response) + else + -- Get client and execute directly + local client, client_err = get_client(worker.worker_type) + if not client then + M.complete(worker, nil, client_err) + return + end + client.generate(prompt, context, handle_response) + end end --- Complete worker execution diff --git a/lua/codetyper/ask.lua b/lua/codetyper/ask.lua index 0282f0b..c3180db 100644 --- a/lua/codetyper/ask.lua +++ b/lua/codetyper/ask.lua @@ -312,7 +312,9 @@ local function append_log_to_output(entry) } local icon = icons[entry.level] or "•" - local formatted = string.format("[%s] %s %s", entry.timestamp, icon, entry.message) + -- Sanitize message - replace newlines with spaces to prevent nvim_buf_set_lines error + local sanitized_message = entry.message:gsub("\n", " "):gsub("\r", "") + local formatted = string.format("[%s] %s %s", entry.timestamp, icon, sanitized_message) vim.schedule(function() if not state.output_buf or not vim.api.nvim_buf_is_valid(state.output_buf) then diff --git a/lua/codetyper/autocmds.lua b/lua/codetyper/autocmds.lua index 26123f3..93f4987 100644 --- a/lua/codetyper/autocmds.lua +++ b/lua/codetyper/autocmds.lua @@ -18,6 +18,16 @@ local processed_prompts = {} --- Track if we're currently asking for preferences local asking_preference = false +--- Track if we're currently processing prompts (busy flag) +local is_processing = false + +--- Track the previous mode for visual mode detection +local previous_mode = "n" + +--- Debounce timer for prompt processing +local prompt_process_timer = nil +local PROMPT_PROCESS_DEBOUNCE_MS = 200 -- Wait 200ms after mode change before processing + --- Generate a unique key for a prompt ---@param bufnr number Buffer number ---@param prompt table Prompt object @@ -64,6 +74,20 @@ function M.setup() desc = "Check for closed prompt tags on InsertLeave", }) + -- Track mode changes for visual mode detection + vim.api.nvim_create_autocmd("ModeChanged", { + group = group, + pattern = "*", + callback = function(ev) + -- Extract old mode from pattern (format: "old_mode:new_mode") + local old_mode = ev.match:match("^(.-):") + if old_mode then + previous_mode = old_mode + end + end, + desc = "Track previous mode for visual mode detection", + }) + -- Auto-process prompts when entering normal mode (works on ALL files) vim.api.nvim_create_autocmd("ModeChanged", { group = group, @@ -74,10 +98,33 @@ function M.setup() if buftype ~= "" then return end - -- Slight delay to let buffer settle - vim.defer_fn(function() + + -- Skip if currently processing (avoid concurrent processing) + if is_processing then + return + end + + -- Skip if coming from visual mode (v, V, CTRL-V) - user is still editing + if previous_mode == "v" or previous_mode == "V" or previous_mode == "\22" then + return + end + + -- Cancel any pending processing timer + if prompt_process_timer then + prompt_process_timer:stop() + prompt_process_timer = nil + end + + -- Debounced processing - wait for user to truly be idle + prompt_process_timer = vim.defer_fn(function() + prompt_process_timer = nil + -- Double-check we're still in normal mode + local mode = vim.api.nvim_get_mode().mode + if mode ~= "n" then + return + end M.check_all_prompts_with_preference() - end, 50) + end, PROMPT_PROCESS_DEBOUNCE_MS) end, desc = "Auto-process closed prompts when entering normal mode", }) @@ -92,6 +139,10 @@ function M.setup() if buftype ~= "" then return end + -- Skip if currently processing + if is_processing then + return + end local mode = vim.api.nvim_get_mode().mode if mode == "n" then M.check_all_prompts_with_preference() @@ -291,6 +342,12 @@ end --- Check if the buffer has a newly closed prompt and auto-process (works on ANY file) function M.check_for_closed_prompt() + -- Skip if already processing + if is_processing then + return + end + is_processing = true + local config = get_config_safe() local parser = require("codetyper.parser") @@ -299,6 +356,7 @@ function M.check_for_closed_prompt() -- Skip if no file if current_file == "" then + is_processing = false return end @@ -308,6 +366,7 @@ function M.check_for_closed_prompt() local lines = vim.api.nvim_buf_get_lines(bufnr, line - 1, line, false) if #lines == 0 then + is_processing = false return end @@ -323,6 +382,7 @@ function M.check_for_closed_prompt() -- Check if already processed if processed_prompts[prompt_key] then + is_processing = false return end @@ -366,23 +426,38 @@ function M.check_for_closed_prompt() -- Clean prompt content (strip file references) local cleaned = parser.clean_prompt(parser.strip_file_references(prompt.content)) - -- Resolve scope in target file FIRST (need it to adjust intent) - local target_bufnr = vim.fn.bufnr(target_path) - if target_bufnr == -1 then - target_bufnr = bufnr - end + -- Check if we're working from a coder file + local is_from_coder_file = utils.is_coder_file(current_file) + -- Resolve scope in target file FIRST (need it to adjust intent) + -- Only resolve scope if NOT from coder file (line numbers don't apply) + local target_bufnr = vim.fn.bufnr(target_path) local scope = nil local scope_text = nil local scope_range = nil - scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1) - if scope and scope.type ~= "file" then - scope_text = scope.text - scope_range = { - start_line = scope.range.start_row, - end_line = scope.range.end_row, - } + if not is_from_coder_file then + -- Prompt is in the actual source file, use line position for scope + if target_bufnr == -1 then + target_bufnr = bufnr + end + scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1) + if scope and scope.type ~= "file" then + scope_text = scope.text + scope_range = { + start_line = scope.range.start_row, + end_line = scope.range.end_row, + } + end + else + -- Prompt is in coder file - load target if needed, but don't use scope + -- Code from coder files should append to target by default + if target_bufnr == -1 then + target_bufnr = vim.fn.bufadd(target_path) + if target_bufnr ~= 0 then + vim.fn.bufload(target_bufnr) + end + end end -- Detect intent from prompt @@ -390,7 +465,8 @@ function M.check_for_closed_prompt() -- IMPORTANT: If prompt is inside a function/method and intent is "add", -- override to "complete" since we're completing the function body - if scope and (scope.type == "function" or scope.type == "method") then + -- But NOT for coder files - they should use "add/append" by default + if not is_from_coder_file and scope and (scope.type == "function" or scope.type == "method") then if intent.type == "add" or intent.action == "insert" or intent.action == "append" then -- Override to complete the function instead of adding new code intent = { @@ -403,6 +479,16 @@ function M.check_for_closed_prompt() end end + -- For coder files, default to "add" with "append" action + if is_from_coder_file and (intent.action == "replace" or intent.type == "complete") then + intent = { + type = intent.type == "complete" and "add" or intent.type, + confidence = intent.confidence, + action = "append", + keywords = intent.keywords, + } + end + -- Determine priority based on intent local priority = 2 -- Normal if intent.type == "fix" or intent.type == "complete" then @@ -446,6 +532,7 @@ function M.check_for_closed_prompt() end end end + is_processing = false end --- Check and process all closed prompts in the buffer (works on ANY file) @@ -507,7 +594,8 @@ function M.check_all_prompts() -- Get target path - for coder files, get the target; for regular files, use self local target_path - if utils.is_coder_file(current_file) then + local is_from_coder_file = utils.is_coder_file(current_file) + if is_from_coder_file then target_path = utils.get_target_path(current_file) else target_path = current_file @@ -520,22 +608,33 @@ function M.check_all_prompts() local cleaned = parser.clean_prompt(parser.strip_file_references(prompt.content)) -- Resolve scope in target file FIRST (need it to adjust intent) + -- Only resolve scope if NOT from coder file (line numbers don't apply) local target_bufnr = vim.fn.bufnr(target_path) - if target_bufnr == -1 then - target_bufnr = bufnr -- Use current buffer if target not loaded - end - local scope = nil local scope_text = nil local scope_range = nil - scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1) - if scope and scope.type ~= "file" then - scope_text = scope.text - scope_range = { - start_line = scope.range.start_row, - end_line = scope.range.end_row, - } + if not is_from_coder_file then + -- Prompt is in the actual source file, use line position for scope + if target_bufnr == -1 then + target_bufnr = bufnr + end + scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1) + if scope and scope.type ~= "file" then + scope_text = scope.text + scope_range = { + start_line = scope.range.start_row, + end_line = scope.range.end_row, + } + end + else + -- Prompt is in coder file - load target if needed + if target_bufnr == -1 then + target_bufnr = vim.fn.bufadd(target_path) + if target_bufnr ~= 0 then + vim.fn.bufload(target_bufnr) + end + end end -- Detect intent from prompt @@ -543,7 +642,8 @@ function M.check_all_prompts() -- IMPORTANT: If prompt is inside a function/method and intent is "add", -- override to "complete" since we're completing the function body - if scope and (scope.type == "function" or scope.type == "method") then + -- But NOT for coder files - they should use "add/append" by default + if not is_from_coder_file and scope and (scope.type == "function" or scope.type == "method") then if intent.type == "add" or intent.action == "insert" or intent.action == "append" then -- Override to complete the function instead of adding new code intent = { @@ -556,6 +656,16 @@ function M.check_all_prompts() end end + -- For coder files, default to "add" with "append" action + if is_from_coder_file and (intent.action == "replace" or intent.type == "complete") then + intent = { + type = intent.type == "complete" and "add" or intent.type, + confidence = intent.confidence, + action = "append", + keywords = intent.keywords, + } + end + -- Determine priority based on intent local priority = 2 if intent.type == "fix" or intent.type == "complete" then @@ -932,20 +1042,17 @@ function M.update_brain_from_file(filepath) local summary = vim.fn.fnamemodify(filepath, ":t") .. " - " .. table.concat(parts, "; ") - -- Learn this pattern + -- Learn this pattern - use "pattern_detected" type to match the pattern learner brain.learn({ - type = "pattern", + type = "pattern_detected", file = filepath, - content = { - summary = summary, - detail = #functions .. " functions, " .. #classes .. " classes", - code = nil, - }, - context = { - file = filepath, + timestamp = os.time(), + data = { + name = summary, + description = #functions .. " functions, " .. #classes .. " classes", language = ext, - functions = functions, - classes = classes, + symbols = vim.tbl_map(function(f) return f.name end, functions), + example = nil, }, }) end @@ -997,6 +1104,126 @@ end --- Auto-index a file by creating/opening its coder companion ---@param bufnr number Buffer number +--- Directories to ignore for coder file creation +local ignored_directories = { + ".git", + ".coder", + ".claude", + ".vscode", + ".idea", + "node_modules", + "vendor", + "dist", + "build", + "target", + "__pycache__", + ".cache", + ".npm", + ".yarn", + "coverage", + ".next", + ".nuxt", + ".svelte-kit", + "out", + "bin", + "obj", +} + +--- Files to ignore for coder file creation (exact names or patterns) +local ignored_files = { + -- Git files + ".gitignore", + ".gitattributes", + ".gitmodules", + -- Lock files + "package-lock.json", + "yarn.lock", + "pnpm-lock.yaml", + "Cargo.lock", + "Gemfile.lock", + "poetry.lock", + "composer.lock", + -- Config files that don't need coder companions + ".env", + ".env.local", + ".env.development", + ".env.production", + ".eslintrc", + ".eslintrc.json", + ".prettierrc", + ".prettierrc.json", + ".editorconfig", + ".dockerignore", + "Dockerfile", + "docker-compose.yml", + "docker-compose.yaml", + ".npmrc", + ".yarnrc", + ".nvmrc", + "tsconfig.json", + "jsconfig.json", + "babel.config.js", + "webpack.config.js", + "vite.config.js", + "rollup.config.js", + "jest.config.js", + "vitest.config.js", + ".stylelintrc", + "tailwind.config.js", + "postcss.config.js", + -- Other non-code files + "README.md", + "CHANGELOG.md", + "LICENSE", + "LICENSE.md", + "CONTRIBUTING.md", + "Makefile", + "CMakeLists.txt", +} + +--- Check if a file path contains an ignored directory +---@param filepath string Full file path +---@return boolean +local function is_in_ignored_directory(filepath) + for _, dir in ipairs(ignored_directories) do + -- Check for /dirname/ or /dirname at end + if filepath:match("/" .. dir .. "/") or filepath:match("/" .. dir .. "$") then + return true + end + -- Also check for dirname/ at start (relative paths) + if filepath:match("^" .. dir .. "/") then + return true + end + end + return false +end + +--- Check if a file should be ignored for coder companion creation +---@param filepath string Full file path +---@return boolean +local function should_ignore_for_coder(filepath) + local filename = vim.fn.fnamemodify(filepath, ":t") + + -- Check exact filename matches + for _, ignored in ipairs(ignored_files) do + if filename == ignored then + return true + end + end + + -- Check if file starts with dot (hidden/config files) + if filename:match("^%.") then + return true + end + + -- Check if in ignored directory + if is_in_ignored_directory(filepath) then + return true + end + + return false +end + function M.auto_index_file(bufnr) -- Skip if buffer is invalid if not vim.api.nvim_buf_is_valid(bufnr) then @@ -1031,6 +1258,11 @@ function M.auto_index_file(bufnr) return end + -- Skip ignored directories and files (node_modules, .git, config files, etc.) + if should_ignore_for_coder(filepath) then + return + end + -- Skip if auto_index is disabled in config local codetyper = require("codetyper") local config = codetyper.get_config() @@ -1047,23 +1279,137 @@ function M.auto_index_file(bufnr) -- Check if coder file already exists local coder_exists = utils.file_exists(coder_path) - -- Create coder file with template if it doesn't exist + -- Create coder file with pseudo-code context if it doesn't exist if not coder_exists then local filename = vim.fn.fnamemodify(filepath, ":t") - local template = string.format( - [[-- Coder companion for %s --- Use /@ @/ tags to write pseudo-code prompts --- Example: --- /@ --- Add a function that validates user input --- - Check for empty strings --- - Validate email format --- @/ + local ext = vim.fn.fnamemodify(filepath, ":e") -]], - filename - ) - utils.write_file(coder_path, template) + -- Determine comment style based on extension + local comment_prefix = "--" + local comment_block_start = "--[[" + local comment_block_end = "]]" + if ext == "ts" or ext == "tsx" or ext == "js" or ext == "jsx" or ext == "java" or ext == "c" or ext == "cpp" or ext == "cs" or ext == "go" or ext == "rs" then + comment_prefix = "//" + comment_block_start = "/*" + comment_block_end = "*/" + elseif ext == "py" or ext == "rb" or ext == "yaml" or ext == "yml" then + comment_prefix = "#" + comment_block_start = '"""' + comment_block_end = '"""' + end + + -- Read target file to analyze its structure + local content = "" + pcall(function() + local lines = vim.fn.readfile(filepath) + if lines then + content = table.concat(lines, "\n") + end + end) + + -- Extract structure from the file + local functions = extract_functions(content, ext) + local classes = extract_classes(content, ext) + local imports = extract_imports(content, ext) + + -- Build pseudo-code context + local pseudo_code = {} + + -- Header + table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════") + table.insert(pseudo_code, comment_prefix .. " CODER COMPANION: " .. filename) + table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════") + table.insert(pseudo_code, comment_prefix .. " This file describes the business logic and behavior of " .. filename) + table.insert(pseudo_code, comment_prefix .. " Edit this pseudo-code to guide code generation.") + table.insert(pseudo_code, comment_prefix .. " Use /@ @/ tags for specific generation requests.") + table.insert(pseudo_code, comment_prefix .. "") + + -- Module purpose + table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────") + table.insert(pseudo_code, comment_prefix .. " MODULE PURPOSE:") + table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────") + table.insert(pseudo_code, comment_prefix .. " TODO: Describe what this module/file is responsible for") + table.insert(pseudo_code, comment_prefix .. " Example: \"Handles user authentication and session management\"") + table.insert(pseudo_code, comment_prefix .. "") + + -- Dependencies section + if #imports > 0 then + table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────") + table.insert(pseudo_code, comment_prefix .. " DEPENDENCIES:") + table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────") + for _, imp in ipairs(imports) do + table.insert(pseudo_code, comment_prefix .. " • " .. imp) + end + table.insert(pseudo_code, comment_prefix .. "") + end + + -- Classes section + if #classes > 0 then + table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────") + table.insert(pseudo_code, comment_prefix .. " CLASSES:") + table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────") + for _, class in ipairs(classes) do + table.insert(pseudo_code, comment_prefix .. "") + table.insert(pseudo_code, comment_prefix .. " class " .. class.name .. ":") + table.insert(pseudo_code, comment_prefix .. " PURPOSE: TODO - describe what this class represents") + table.insert(pseudo_code, comment_prefix .. " RESPONSIBILITIES:") + table.insert(pseudo_code, comment_prefix .. " - TODO: list main responsibilities") + end + table.insert(pseudo_code, comment_prefix .. "") + end + + -- Functions section + if #functions > 0 then + table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────") + table.insert(pseudo_code, comment_prefix .. " FUNCTIONS:") + table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────") + for _, func in ipairs(functions) do + table.insert(pseudo_code, comment_prefix .. "") + table.insert(pseudo_code, comment_prefix .. " " .. func.name .. "():") + table.insert(pseudo_code, comment_prefix .. " PURPOSE: TODO - what does this function do?") + table.insert(pseudo_code, comment_prefix .. " INPUTS: TODO - describe parameters") + table.insert(pseudo_code, comment_prefix .. " OUTPUTS: TODO - describe return value") + table.insert(pseudo_code, comment_prefix .. " BEHAVIOR:") + table.insert(pseudo_code, comment_prefix .. " - TODO: describe step-by-step logic") + end + table.insert(pseudo_code, comment_prefix .. "") + end + + -- If empty file, provide starter template + if #functions == 0 and #classes == 0 then + table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────") + table.insert(pseudo_code, comment_prefix .. " PLANNED STRUCTURE:") + table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────") + table.insert(pseudo_code, comment_prefix .. " TODO: Describe what you want to build in this file") + table.insert(pseudo_code, comment_prefix .. "") + table.insert(pseudo_code, comment_prefix .. " Example pseudo-code:") + table.insert(pseudo_code, comment_prefix .. " /@") + table.insert(pseudo_code, comment_prefix .. " Create a module that:") + table.insert(pseudo_code, comment_prefix .. " 1. Exports a main function") + table.insert(pseudo_code, comment_prefix .. " 2. Handles errors gracefully") + table.insert(pseudo_code, comment_prefix .. " 3. Returns structured data") + table.insert(pseudo_code, comment_prefix .. " @/") + table.insert(pseudo_code, comment_prefix .. "") + end + + -- Business rules section + table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────") + table.insert(pseudo_code, comment_prefix .. " BUSINESS RULES:") + table.insert(pseudo_code, comment_prefix .. " ─────────────────────────────────────────────────────────────") + table.insert(pseudo_code, comment_prefix .. " TODO: Document any business rules, constraints, or requirements") + table.insert(pseudo_code, comment_prefix .. " Example:") + table.insert(pseudo_code, comment_prefix .. " - Users must be authenticated before accessing this feature") + table.insert(pseudo_code, comment_prefix .. " - Data must be validated before saving") + table.insert(pseudo_code, comment_prefix .. " - Errors should be logged but not exposed to users") + table.insert(pseudo_code, comment_prefix .. "") + + -- Footer with generation tags example + table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════") + table.insert(pseudo_code, comment_prefix .. " Use /@ @/ tags below to request code generation:") + table.insert(pseudo_code, comment_prefix .. " ═══════════════════════════════════════════════════════════") + table.insert(pseudo_code, "") + + utils.write_file(coder_path, table.concat(pseudo_code, "\n")) end -- Notify user about the coder companion diff --git a/lua/codetyper/brain/graph/query.lua b/lua/codetyper/brain/graph/query.lua index a2e9f9a..b685c58 100644 --- a/lua/codetyper/brain/graph/query.lua +++ b/lua/codetyper/brain/graph/query.lua @@ -111,7 +111,7 @@ function M.compute_relevance(node, opts) return score end ---- Traverse graph from seed nodes +--- Traverse graph from seed nodes (basic traversal) ---@param seed_ids string[] Starting node IDs ---@param depth number Traversal depth ---@param edge_types? EdgeType[] Edge types to follow @@ -157,6 +157,73 @@ local function traverse(seed_ids, depth, edge_types) return discovered end +--- Spreading activation - mimics human associative memory +--- Activation spreads from seed nodes along edges, decaying by weight +--- Nodes accumulate activation from multiple paths (like neural pathways) +---@param seed_activations table Initial activations {node_id: activation} +---@param max_iterations number Max spread iterations (default 3) +---@param decay number Activation decay per hop (default 0.5) +---@param threshold number Minimum activation to continue spreading (default 0.1) +---@return table Final activations {node_id: accumulated_activation} +local function spreading_activation(seed_activations, max_iterations, decay, threshold) + local edge_mod = get_edge_module() + max_iterations = max_iterations or 3 + decay = decay or 0.5 + threshold = threshold or 0.1 + + -- Accumulated activation for each node + local activation = {} + for node_id, act in pairs(seed_activations) do + activation[node_id] = act + end + + -- Current frontier with their activation levels + local frontier = {} + for node_id, act in pairs(seed_activations) do + frontier[node_id] = act + end + + -- Spread activation iteratively + for _ = 1, max_iterations do + local next_frontier = {} + + for source_id, source_activation in pairs(frontier) do + -- Get all outgoing edges + local edges = edge_mod.get_edges(source_id, nil, "both") + + for _, edge in ipairs(edges) do + -- Determine target (could be source or target of edge) + local target_id = edge.s == source_id and edge.t or edge.s + + -- Calculate spreading activation + -- Activation = source_activation * edge_weight * decay + local edge_weight = edge.p and edge.p.w or 0.5 + local spread_amount = source_activation * edge_weight * decay + + -- Only spread if above threshold + if spread_amount >= threshold then + -- Accumulate activation (multiple paths add up) + activation[target_id] = (activation[target_id] or 0) + spread_amount + + -- Add to next frontier if not already processed with higher activation + if not next_frontier[target_id] or next_frontier[target_id] < spread_amount then + next_frontier[target_id] = spread_amount + end + end + end + end + + -- Stop if no more spreading + if vim.tbl_count(next_frontier) == 0 then + break + end + + frontier = next_frontier + end + + return activation +end + --- Execute a query across all dimensions ---@param opts QueryOpts Query options ---@return QueryResult @@ -236,28 +303,49 @@ function M.execute(opts) end end - -- 4. Combine and deduplicate + -- 4. Combine all found nodes and compute seed activations local all_nodes = {} + local seed_activations = {} + for _, category in pairs(results) do for id, node in pairs(category) do if not all_nodes[id] then all_nodes[id] = node + -- Compute initial activation based on relevance + local relevance = M.compute_relevance(node, opts) + seed_activations[id] = relevance end end end - -- 5. Score and rank + -- 5. Apply spreading activation - like human associative memory + -- Activation spreads from seed nodes along edges, accumulating + -- Nodes connected to multiple relevant seeds get higher activation + local final_activations = spreading_activation( + seed_activations, + opts.spread_iterations or 3, -- How far activation spreads + opts.spread_decay or 0.5, -- How much activation decays per hop + opts.spread_threshold or 0.05 -- Minimum activation to continue spreading + ) + + -- 6. Score and rank by combined activation local scored = {} - for id, node in pairs(all_nodes) do - local relevance = M.compute_relevance(node, opts) - table.insert(scored, { node = node, relevance = relevance }) + for id, activation in pairs(final_activations) do + local node = all_nodes[id] or node_mod.get(id) + if node then + all_nodes[id] = node + -- Final score = spreading activation + base relevance + local base_relevance = M.compute_relevance(node, opts) + local final_score = (activation * 0.6) + (base_relevance * 0.4) + table.insert(scored, { node = node, relevance = final_score, activation = activation }) + end end table.sort(scored, function(a, b) return a.relevance > b.relevance end) - -- 6. Apply limit + -- 7. Apply limit local limit = opts.limit or 50 local result_nodes = {} local truncated = #scored > limit @@ -266,7 +354,7 @@ function M.execute(opts) table.insert(result_nodes, scored[i].node) end - -- 7. Get edges between result nodes + -- 8. Get edges between result nodes local edge_mod = get_edge_module() local result_edges = {} local node_ids = {} @@ -291,11 +379,17 @@ function M.execute(opts) file_count = vim.tbl_count(results.file), temporal_count = vim.tbl_count(results.temporal), total_scored = #scored, + seed_nodes = vim.tbl_count(seed_activations), + activated_nodes = vim.tbl_count(final_activations), }, truncated = truncated, } end +--- Expose spreading activation for direct use +--- Useful for custom activation patterns or debugging +M.spreading_activation = spreading_activation + --- Find nodes by file ---@param filepath string File path ---@param limit? number Max results diff --git a/lua/codetyper/brain/learners/pattern.lua b/lua/codetyper/brain/learners/pattern.lua index c07276e..eabbb57 100644 --- a/lua/codetyper/brain/learners/pattern.lua +++ b/lua/codetyper/brain/learners/pattern.lua @@ -9,6 +9,10 @@ local M = {} ---@param event LearnEvent Learning event ---@return boolean function M.detect(event) + if not event or not event.type then + return false + end + local valid_types = { "code_completion", "file_indexed", diff --git a/lua/codetyper/commands.lua b/lua/codetyper/commands.lua index 7be1cb7..ab08885 100644 --- a/lua/codetyper/commands.lua +++ b/lua/codetyper/commands.lua @@ -287,6 +287,118 @@ local function cmd_agent_stop() end end +--- Run the agentic loop with a task +---@param task string The task to accomplish +---@param agent_name? string Optional agent name +local function cmd_agentic_run(task, agent_name) + local agentic = require("codetyper.agent.agentic") + local logs_panel = require("codetyper.logs_panel") + local logs = require("codetyper.agent.logs") + + -- Open logs panel + logs_panel.open() + + logs.info("Starting agentic task: " .. task:sub(1, 50) .. "...") + utils.notify("Running agentic task...", vim.log.levels.INFO) + + -- Get current file for context + local current_file = vim.fn.expand("%:p") + local files = {} + if current_file ~= "" then + table.insert(files, current_file) + end + + agentic.run({ + task = task, + files = files, + agent = agent_name or "coder", + on_status = function(status) + logs.thinking(status) + end, + on_tool_start = function(name, args) + logs.info("Tool: " .. name) + end, + on_tool_end = function(name, result, err) + if err then + logs.error(name .. " failed: " .. err) + else + logs.debug(name .. " completed") + end + end, + on_file_change = function(path, action) + logs.info("File " .. action .. ": " .. path) + end, + on_message = function(msg) + if msg.role == "assistant" and type(msg.content) == "string" and msg.content ~= "" then + logs.thinking(msg.content:sub(1, 100) .. "...") + end + end, + on_complete = function(result, err) + if err then + logs.error("Task failed: " .. err) + utils.notify("Agentic task failed: " .. err, vim.log.levels.ERROR) + else + logs.info("Task completed successfully") + utils.notify("Agentic task completed!", vim.log.levels.INFO) + if result and result ~= "" then + -- Show summary in a float + vim.schedule(function() + vim.notify("Result:\n" .. result:sub(1, 500), vim.log.levels.INFO) + end) + end + end + end, + }) +end + +--- List available agents +local function cmd_agentic_list() + local agentic = require("codetyper.agent.agentic") + local agents = agentic.list_agents() + + local lines = { + "Available Agents", + "================", + "", + } + + for _, agent in ipairs(agents) do + local badge = agent.builtin and "[builtin]" or "[custom]" + table.insert(lines, string.format(" %s %s", agent.name, badge)) + table.insert(lines, string.format(" %s", agent.description)) + table.insert(lines, "") + end + + table.insert(lines, "Use :CoderAgenticRun [agent] to run a task") + table.insert(lines, "Use :CoderAgenticInit to create custom agents") + + utils.notify(table.concat(lines, "\n")) +end + +--- Initialize .coder/agents/ and .coder/rules/ directories +local function cmd_agentic_init() + local agentic = require("codetyper.agent.agentic") + agentic.init() + + local agents_dir = vim.fn.getcwd() .. "/.coder/agents" + local rules_dir = vim.fn.getcwd() .. "/.coder/rules" + + local lines = { + "Initialized Coder directories:", + "", + " " .. agents_dir, + " - example.md (template for custom agents)", + "", + " " .. rules_dir, + " - code-style.md (template for project rules)", + "", + "Edit these files to customize agent behavior.", + "Create new .md files to add more agents/rules.", + } + + utils.notify(table.concat(lines, "\n")) +end + --- Show chat type switcher modal (Ask/Agent) local function cmd_type_toggle() local switcher = require("codetyper.chat_switcher") @@ -844,6 +956,65 @@ end --- Main command dispatcher ---@param args table Command arguments +--- Show LLM accuracy statistics +local function cmd_llm_stats() + local llm = require("codetyper.llm") + local stats = llm.get_accuracy_stats() + + local lines = { + "LLM Provider Accuracy Statistics", + "================================", + "", + string.format("Ollama:"), + string.format(" Total requests: %d", stats.ollama.total), + string.format(" Correct: %d", stats.ollama.correct), + string.format(" Accuracy: %.1f%%", stats.ollama.accuracy * 100), + "", + string.format("Copilot:"), + string.format(" Total requests: %d", stats.copilot.total), + string.format(" Correct: %d", stats.copilot.correct), + string.format(" Accuracy: %.1f%%", stats.copilot.accuracy * 100), + "", + "Note: Smart selection prefers Ollama when brain memories", + "provide enough context. Accuracy improves over time via", + "pondering (verification with other LLMs).", + } + + vim.notify(table.concat(lines, "\n"), vim.log.levels.INFO) +end + +--- Report feedback on last LLM response +---@param was_good boolean Whether the response was good +local function cmd_llm_feedback(was_good) + local llm = require("codetyper.llm") + -- Get the last used provider from logs or default + local provider = "ollama" -- Default assumption + + -- Try to get actual last provider from logs + pcall(function() + local logs = require("codetyper.agent.logs") + local entries = logs.get(10) + for i = #entries, 1, -1 do + local entry = entries[i] + if entry.message and entry.message:match("^LLM:") then + provider = entry.message:match("LLM: (%w+)") or provider + break + end + end + end) + + llm.report_feedback(provider, was_good) + local feedback_type = was_good and "positive" or "negative" + utils.notify(string.format("Reported %s feedback for %s", feedback_type, provider), vim.log.levels.INFO) +end + +--- Reset LLM accuracy statistics +local function cmd_llm_reset_stats() + local selector = require("codetyper.llm.selector") + selector.reset_accuracy_stats() + utils.notify("LLM accuracy statistics reset", vim.log.levels.INFO) +end + local function coder_cmd(args) local subcommand = args.fargs[1] or "toggle" @@ -872,6 +1043,17 @@ local function coder_cmd(args) ["logs-toggle"] = cmd_logs_toggle, ["queue-status"] = cmd_queue_status, ["queue-process"] = cmd_queue_process, + -- Agentic commands + ["agentic-run"] = function(args) + local task = table.concat(vim.list_slice(args.fargs, 2), " ") + if task == "" then + utils.notify("Usage: Coder agentic-run [agent]", vim.log.levels.WARN) + return + end + cmd_agentic_run(task) + end, + ["agentic-list"] = cmd_agentic_list, + ["agentic-init"] = cmd_agentic_init, ["index-project"] = cmd_index_project, ["index-status"] = cmd_index_status, memories = cmd_memories, @@ -901,6 +1083,41 @@ local function coder_cmd(args) end end end, + -- LLM smart selection commands + ["llm-stats"] = cmd_llm_stats, + ["llm-feedback-good"] = function() + cmd_llm_feedback(true) + end, + ["llm-feedback-bad"] = function() + cmd_llm_feedback(false) + end, + ["llm-reset-stats"] = cmd_llm_reset_stats, + -- Cost tracking commands + ["cost"] = function() + local cost = require("codetyper.cost") + cost.toggle() + end, + ["cost-clear"] = function() + local cost = require("codetyper.cost") + cost.clear() + end, + -- Credentials management commands + ["add-api-key"] = function() + local credentials = require("codetyper.credentials") + credentials.interactive_add() + end, + ["remove-api-key"] = function() + local credentials = require("codetyper.credentials") + credentials.interactive_remove() + end, + ["credentials"] = function() + local credentials = require("codetyper.credentials") + credentials.show_status() + end, + ["switch-provider"] = function() + local credentials = require("codetyper.credentials") + credentials.interactive_switch_provider() + end, } local cmd_fn = commands[subcommand] @@ -922,10 +1139,14 @@ function M.setup() "ask", "ask-close", "ask-toggle", "ask-clear", "transform", "transform-cursor", "agent", "agent-close", "agent-toggle", "agent-stop", + "agentic-run", "agentic-list", "agentic-init", "type-toggle", "logs-toggle", "queue-status", "queue-process", "index-project", "index-status", "memories", "forget", "auto-toggle", "auto-set", + "llm-stats", "llm-feedback-good", "llm-feedback-bad", "llm-reset-stats", + "cost", "cost-clear", + "add-api-key", "remove-api-key", "credentials", "switch-provider", } end, desc = "Codetyper.nvim commands", @@ -997,6 +1218,31 @@ function M.setup() cmd_agent_stop() end, { desc = "Stop running agent" }) + -- Agentic commands (full IDE-like agent functionality) + vim.api.nvim_create_user_command("CoderAgenticRun", function(opts) + local task = opts.args + if task == "" then + vim.ui.input({ prompt = "Task: " }, function(input) + if input and input ~= "" then + cmd_agentic_run(input) + end + end) + else + cmd_agentic_run(task) + end + end, { + desc = "Run agentic task (IDE-like multi-file changes)", + nargs = "*", + }) + + vim.api.nvim_create_user_command("CoderAgenticList", function() + cmd_agentic_list() + end, { desc = "List available agents" }) + + vim.api.nvim_create_user_command("CoderAgenticInit", function() + cmd_agentic_init() + end, { desc = "Initialize .coder/agents/ and .coder/rules/ directories" }) + -- Chat type switcher command vim.api.nvim_create_user_command("CoderType", function() cmd_type_toggle() @@ -1075,6 +1321,147 @@ function M.setup() end, }) + -- Brain feedback command - teach the brain from your experience + vim.api.nvim_create_user_command("CoderFeedback", function(opts) + local brain = require("codetyper.brain") + if not brain.is_initialized() then + vim.notify("Brain not initialized", vim.log.levels.WARN) + return + end + + local feedback_type = opts.args:lower() + local current_file = vim.fn.expand("%:p") + + if feedback_type == "good" or feedback_type == "accept" or feedback_type == "+" then + -- Learn positive feedback + brain.learn({ + type = "user_feedback", + file = current_file, + timestamp = os.time(), + data = { + feedback = "accepted", + description = "User marked code as good/accepted", + }, + }) + vim.notify("Brain: Learned positive feedback ✓", vim.log.levels.INFO) + + elseif feedback_type == "bad" or feedback_type == "reject" or feedback_type == "-" then + -- Learn negative feedback + brain.learn({ + type = "user_feedback", + file = current_file, + timestamp = os.time(), + data = { + feedback = "rejected", + description = "User marked code as bad/rejected", + }, + }) + vim.notify("Brain: Learned negative feedback ✗", vim.log.levels.INFO) + + elseif feedback_type == "stats" or feedback_type == "status" then + -- Show brain stats + local stats = brain.stats() + local msg = string.format( + "Brain Stats:\n• Nodes: %d\n• Edges: %d\n• Pending: %d\n• Deltas: %d", + stats.node_count or 0, + stats.edge_count or 0, + stats.pending_changes or 0, + stats.delta_count or 0 + ) + vim.notify(msg, vim.log.levels.INFO) + + else + vim.notify("Usage: CoderFeedback ", vim.log.levels.INFO) + end + end, { + desc = "Give feedback to the brain (good/bad/stats)", + nargs = "?", + complete = function() + return { "good", "bad", "stats" } + end, + }) + + -- Brain stats command + vim.api.nvim_create_user_command("CoderBrain", function(opts) + local brain = require("codetyper.brain") + if not brain.is_initialized() then + vim.notify("Brain not initialized", vim.log.levels.WARN) + return + end + + local action = opts.args:lower() + + if action == "stats" or action == "" then + local stats = brain.stats() + local lines = { + "╭─────────────────────────────────╮", + "│ CODETYPER BRAIN │", + "╰─────────────────────────────────╯", + "", + string.format(" Nodes: %d", stats.node_count or 0), + string.format(" Edges: %d", stats.edge_count or 0), + string.format(" Deltas: %d", stats.delta_count or 0), + string.format(" Pending: %d", stats.pending_changes or 0), + "", + " The more you use Codetyper,", + " the smarter it becomes!", + } + vim.notify(table.concat(lines, "\n"), vim.log.levels.INFO) + + elseif action == "commit" then + local hash = brain.commit("Manual commit") + if hash then + vim.notify("Brain: Committed changes (hash: " .. hash:sub(1, 8) .. ")", vim.log.levels.INFO) + else + vim.notify("Brain: Nothing to commit", vim.log.levels.INFO) + end + + elseif action == "flush" then + brain.flush() + vim.notify("Brain: Flushed to disk", vim.log.levels.INFO) + + elseif action == "prune" then + local pruned = brain.prune() + vim.notify("Brain: Pruned " .. pruned .. " low-value nodes", vim.log.levels.INFO) + + else + vim.notify("Usage: CoderBrain ", vim.log.levels.INFO) + end + end, { + desc = "Brain management commands", + nargs = "?", + complete = function() + return { "stats", "commit", "flush", "prune" } + end, + }) + + -- Cost estimation command + vim.api.nvim_create_user_command("CoderCost", function() + local cost = require("codetyper.cost") + cost.toggle() + end, { desc = "Show LLM cost estimation window" }) + + -- Credentials management commands + vim.api.nvim_create_user_command("CoderAddApiKey", function() + local credentials = require("codetyper.credentials") + credentials.interactive_add() + end, { desc = "Add or update LLM provider API key" }) + + vim.api.nvim_create_user_command("CoderRemoveApiKey", function() + local credentials = require("codetyper.credentials") + credentials.interactive_remove() + end, { desc = "Remove LLM provider credentials" }) + + vim.api.nvim_create_user_command("CoderCredentials", function() + local credentials = require("codetyper.credentials") + credentials.show_status() + end, { desc = "Show credentials status" }) + + vim.api.nvim_create_user_command("CoderSwitchProvider", function() + local credentials = require("codetyper.credentials") + credentials.interactive_switch_provider() + end, { desc = "Switch active LLM provider" }) + -- Setup default keymaps M.setup_keymaps() end diff --git a/lua/codetyper/cost.lua b/lua/codetyper/cost.lua new file mode 100644 index 0000000..3e95922 --- /dev/null +++ b/lua/codetyper/cost.lua @@ -0,0 +1,750 @@ +---@mod codetyper.cost Cost estimation for LLM usage +---@brief [[ +--- Tracks token usage and estimates costs based on model pricing. +--- Prices are per 1M tokens. Persists usage data in the brain. +---@brief ]] + +local M = {} + +local utils = require("codetyper.utils") + +--- Cost history file name +local COST_HISTORY_FILE = "cost_history.json" + +--- Get path to cost history file +---@return string File path +local function get_history_path() + local root = utils.get_project_root() + return root .. "/.coder/" .. COST_HISTORY_FILE +end + +--- Default model for savings comparison (what you'd pay if not using Ollama) +M.comparison_model = "gpt-4o" + +--- Models considered "free" (Ollama, local, Copilot subscription) +M.free_models = { + ["ollama"] = true, + ["codellama"] = true, + ["llama2"] = true, + ["llama3"] = true, + ["mistral"] = true, + ["deepseek-coder"] = true, + ["copilot"] = true, +} + +--- Model pricing table (per 1M tokens in USD) +---@type table +M.pricing = { + -- GPT-5.x series + ["gpt-5.2"] = { input = 1.75, cached_input = 0.175, output = 14.00 }, + ["gpt-5.1"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + ["gpt-5"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + ["gpt-5-mini"] = { input = 0.25, cached_input = 0.025, output = 2.00 }, + ["gpt-5-nano"] = { input = 0.05, cached_input = 0.005, output = 0.40 }, + ["gpt-5.2-chat-latest"] = { input = 1.75, cached_input = 0.175, output = 14.00 }, + ["gpt-5.1-chat-latest"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + ["gpt-5-chat-latest"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + ["gpt-5.2-codex"] = { input = 1.75, cached_input = 0.175, output = 14.00 }, + ["gpt-5.1-codex-max"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + ["gpt-5.1-codex"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + ["gpt-5-codex"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + ["gpt-5.2-pro"] = { input = 21.00, cached_input = nil, output = 168.00 }, + ["gpt-5-pro"] = { input = 15.00, cached_input = nil, output = 120.00 }, + ["gpt-5.1-codex-mini"] = { input = 0.25, cached_input = 0.025, output = 2.00 }, + ["gpt-5-search-api"] = { input = 1.25, cached_input = 0.125, output = 10.00 }, + + -- GPT-4.x series + ["gpt-4.1"] = { input = 2.00, cached_input = 0.50, output = 8.00 }, + ["gpt-4.1-mini"] = { input = 0.40, cached_input = 0.10, output = 1.60 }, + ["gpt-4.1-nano"] = { input = 0.10, cached_input = 0.025, output = 0.40 }, + ["gpt-4o"] = { input = 2.50, cached_input = 1.25, output = 10.00 }, + ["gpt-4o-2024-05-13"] = { input = 5.00, cached_input = nil, output = 15.00 }, + ["gpt-4o-mini"] = { input = 0.15, cached_input = 0.075, output = 0.60 }, + + -- Realtime models + ["gpt-realtime"] = { input = 4.00, cached_input = 0.40, output = 16.00 }, + ["gpt-realtime-mini"] = { input = 0.60, cached_input = 0.06, output = 2.40 }, + ["gpt-4o-realtime-preview"] = { input = 5.00, cached_input = 2.50, output = 20.00 }, + ["gpt-4o-mini-realtime-preview"] = { input = 0.60, cached_input = 0.30, output = 2.40 }, + + -- Audio models + ["gpt-audio"] = { input = 2.50, cached_input = nil, output = 10.00 }, + ["gpt-audio-mini"] = { input = 0.60, cached_input = nil, output = 2.40 }, + ["gpt-4o-audio-preview"] = { input = 2.50, cached_input = nil, output = 10.00 }, + ["gpt-4o-mini-audio-preview"] = { input = 0.15, cached_input = nil, output = 0.60 }, + + -- O-series reasoning models + ["o1"] = { input = 15.00, cached_input = 7.50, output = 60.00 }, + ["o1-pro"] = { input = 150.00, cached_input = nil, output = 600.00 }, + ["o3-pro"] = { input = 20.00, cached_input = nil, output = 80.00 }, + ["o3"] = { input = 2.00, cached_input = 0.50, output = 8.00 }, + ["o3-deep-research"] = { input = 10.00, cached_input = 2.50, output = 40.00 }, + ["o4-mini"] = { input = 1.10, cached_input = 0.275, output = 4.40 }, + ["o4-mini-deep-research"] = { input = 2.00, cached_input = 0.50, output = 8.00 }, + ["o3-mini"] = { input = 1.10, cached_input = 0.55, output = 4.40 }, + ["o1-mini"] = { input = 1.10, cached_input = 0.55, output = 4.40 }, + + -- Codex + ["codex-mini-latest"] = { input = 1.50, cached_input = 0.375, output = 6.00 }, + + -- Search models + ["gpt-4o-mini-search-preview"] = { input = 0.15, cached_input = nil, output = 0.60 }, + ["gpt-4o-search-preview"] = { input = 2.50, cached_input = nil, output = 10.00 }, + + -- Computer use + ["computer-use-preview"] = { input = 3.00, cached_input = nil, output = 12.00 }, + + -- Image models + ["gpt-image-1.5"] = { input = 5.00, cached_input = 1.25, output = 10.00 }, + ["chatgpt-image-latest"] = { input = 5.00, cached_input = 1.25, output = 10.00 }, + ["gpt-image-1"] = { input = 5.00, cached_input = 1.25, output = nil }, + ["gpt-image-1-mini"] = { input = 2.00, cached_input = 0.20, output = nil }, + + -- Claude models (Anthropic) + ["claude-3-opus"] = { input = 15.00, cached_input = 7.50, output = 75.00 }, + ["claude-3-sonnet"] = { input = 3.00, cached_input = 1.50, output = 15.00 }, + ["claude-3-haiku"] = { input = 0.25, cached_input = 0.125, output = 1.25 }, + ["claude-3.5-sonnet"] = { input = 3.00, cached_input = 1.50, output = 15.00 }, + ["claude-3.5-haiku"] = { input = 0.80, cached_input = 0.40, output = 4.00 }, + + -- Ollama/Local models (free) + ["ollama"] = { input = 0, cached_input = 0, output = 0 }, + ["codellama"] = { input = 0, cached_input = 0, output = 0 }, + ["llama2"] = { input = 0, cached_input = 0, output = 0 }, + ["llama3"] = { input = 0, cached_input = 0, output = 0 }, + ["mistral"] = { input = 0, cached_input = 0, output = 0 }, + ["deepseek-coder"] = { input = 0, cached_input = 0, output = 0 }, + + -- Copilot (included in subscription, but tracking usage) + ["copilot"] = { input = 0, cached_input = 0, output = 0 }, +} + +---@class CostUsage +---@field model string Model name +---@field input_tokens number Input tokens used +---@field output_tokens number Output tokens used +---@field cached_tokens number Cached input tokens +---@field timestamp number Unix timestamp +---@field cost number Calculated cost in USD + +---@class CostState +---@field usage CostUsage[] Current session usage +---@field all_usage CostUsage[] All historical usage from brain +---@field session_start number Session start timestamp +---@field win number|nil Window handle +---@field buf number|nil Buffer handle +---@field loaded boolean Whether historical data has been loaded +local state = { + usage = {}, + all_usage = {}, + session_start = os.time(), + win = nil, + buf = nil, + loaded = false, +} + +--- Load historical usage from disk +function M.load_from_history() + if state.loaded then + return + end + + local history_path = get_history_path() + local content = utils.read_file(history_path) + + if content and content ~= "" then + local ok, data = pcall(vim.json.decode, content) + if ok and data and data.usage then + state.all_usage = data.usage + end + end + + state.loaded = true +end + +--- Save all usage to disk (debounced) +local save_timer = nil +local function save_to_disk() + -- Cancel existing timer + if save_timer then + save_timer:stop() + save_timer = nil + end + + -- Debounce writes (500ms) + save_timer = vim.loop.new_timer() + save_timer:start(500, 0, vim.schedule_wrap(function() + local history_path = get_history_path() + + -- Ensure directory exists + local dir = vim.fn.fnamemodify(history_path, ":h") + utils.ensure_dir(dir) + + -- Merge session and historical usage + local all_data = vim.deepcopy(state.all_usage) + for _, usage in ipairs(state.usage) do + table.insert(all_data, usage) + end + + -- Save to file + local data = { + version = 1, + updated = os.time(), + usage = all_data, + } + + local ok, json = pcall(vim.json.encode, data) + if ok then + utils.write_file(history_path, json) + end + + save_timer = nil + end)) +end + +--- Normalize model name for pricing lookup +---@param model string Model name from API +---@return string Normalized model name +local function normalize_model(model) + if not model then + return "unknown" + end + + -- Convert to lowercase + local normalized = model:lower() + + -- Handle Copilot models + if normalized:match("copilot") then + return "copilot" + end + + -- Handle common prefixes + normalized = normalized:gsub("^openai/", "") + normalized = normalized:gsub("^anthropic/", "") + + -- Try exact match first + if M.pricing[normalized] then + return normalized + end + + -- Try partial matches + for price_model, _ in pairs(M.pricing) do + if normalized:match(price_model) or price_model:match(normalized) then + return price_model + end + end + + return normalized +end + +--- Check if a model is considered "free" (local/Ollama/Copilot subscription) +---@param model string Model name +---@return boolean True if free +function M.is_free_model(model) + local normalized = normalize_model(model) + + -- Check direct match + if M.free_models[normalized] then + return true + end + + -- Check if it's an Ollama model (any model with : in name like deepseek-coder:6.7b) + if model:match(":") then + return true + end + + -- Check pricing - if cost is 0, it's free + local pricing = M.pricing[normalized] + if pricing and pricing.input == 0 and pricing.output == 0 then + return true + end + + return false +end + +--- Calculate cost for token usage +---@param model string Model name +---@param input_tokens number Input tokens +---@param output_tokens number Output tokens +---@param cached_tokens? number Cached input tokens +---@return number Cost in USD +function M.calculate_cost(model, input_tokens, output_tokens, cached_tokens) + local normalized = normalize_model(model) + local pricing = M.pricing[normalized] + + if not pricing then + -- Unknown model, return 0 + return 0 + end + + cached_tokens = cached_tokens or 0 + local regular_input = input_tokens - cached_tokens + + -- Calculate cost (prices are per 1M tokens) + local input_cost = (regular_input / 1000000) * (pricing.input or 0) + local cached_cost = (cached_tokens / 1000000) * (pricing.cached_input or pricing.input or 0) + local output_cost = (output_tokens / 1000000) * (pricing.output or 0) + + return input_cost + cached_cost + output_cost +end + +--- Calculate estimated savings (what would have been paid if using comparison model) +---@param input_tokens number Input tokens +---@param output_tokens number Output tokens +---@param cached_tokens? number Cached input tokens +---@return number Estimated savings in USD +function M.calculate_savings(input_tokens, output_tokens, cached_tokens) + -- Calculate what it would have cost with the comparison model + return M.calculate_cost(M.comparison_model, input_tokens, output_tokens, cached_tokens) +end + +--- Record token usage +---@param model string Model name +---@param input_tokens number Input tokens +---@param output_tokens number Output tokens +---@param cached_tokens? number Cached input tokens +function M.record_usage(model, input_tokens, output_tokens, cached_tokens) + cached_tokens = cached_tokens or 0 + local cost = M.calculate_cost(model, input_tokens, output_tokens, cached_tokens) + + -- Calculate savings if using a free model + local savings = 0 + if M.is_free_model(model) then + savings = M.calculate_savings(input_tokens, output_tokens, cached_tokens) + end + + table.insert(state.usage, { + model = model, + input_tokens = input_tokens, + output_tokens = output_tokens, + cached_tokens = cached_tokens, + timestamp = os.time(), + cost = cost, + savings = savings, + is_free = M.is_free_model(model), + }) + + -- Save to disk (debounced) + save_to_disk() + + -- Update window if open + if state.win and vim.api.nvim_win_is_valid(state.win) then + M.refresh_window() + end +end + +--- Aggregate usage data into stats +---@param usage_list CostUsage[] List of usage records +---@return table Stats +local function aggregate_usage(usage_list) + local stats = { + total_input = 0, + total_output = 0, + total_cached = 0, + total_cost = 0, + total_savings = 0, + free_requests = 0, + paid_requests = 0, + by_model = {}, + request_count = #usage_list, + } + + for _, usage in ipairs(usage_list) do + stats.total_input = stats.total_input + (usage.input_tokens or 0) + stats.total_output = stats.total_output + (usage.output_tokens or 0) + stats.total_cached = stats.total_cached + (usage.cached_tokens or 0) + stats.total_cost = stats.total_cost + (usage.cost or 0) + + -- Track savings + local usage_savings = usage.savings or 0 + -- For historical data without savings field, calculate it + if usage_savings == 0 and usage.is_free == nil then + local model = usage.model or "unknown" + if M.is_free_model(model) then + usage_savings = M.calculate_savings( + usage.input_tokens or 0, + usage.output_tokens or 0, + usage.cached_tokens or 0 + ) + end + end + stats.total_savings = stats.total_savings + usage_savings + + -- Track free vs paid + local is_free = usage.is_free + if is_free == nil then + is_free = M.is_free_model(usage.model or "unknown") + end + if is_free then + stats.free_requests = stats.free_requests + 1 + else + stats.paid_requests = stats.paid_requests + 1 + end + + local model = usage.model or "unknown" + if not stats.by_model[model] then + stats.by_model[model] = { + input_tokens = 0, + output_tokens = 0, + cached_tokens = 0, + cost = 0, + savings = 0, + requests = 0, + is_free = is_free, + } + end + + stats.by_model[model].input_tokens = stats.by_model[model].input_tokens + (usage.input_tokens or 0) + stats.by_model[model].output_tokens = stats.by_model[model].output_tokens + (usage.output_tokens or 0) + stats.by_model[model].cached_tokens = stats.by_model[model].cached_tokens + (usage.cached_tokens or 0) + stats.by_model[model].cost = stats.by_model[model].cost + (usage.cost or 0) + stats.by_model[model].savings = stats.by_model[model].savings + usage_savings + stats.by_model[model].requests = stats.by_model[model].requests + 1 + end + + return stats +end + +--- Get session statistics +---@return table Statistics +function M.get_stats() + local stats = aggregate_usage(state.usage) + stats.session_duration = os.time() - state.session_start + return stats +end + +--- Get all-time statistics (session + historical) +---@return table Statistics +function M.get_all_time_stats() + -- Load history if not loaded + M.load_from_history() + + -- Combine session and historical usage + local all_usage = vim.deepcopy(state.all_usage) + for _, usage in ipairs(state.usage) do + table.insert(all_usage, usage) + end + + local stats = aggregate_usage(all_usage) + + -- Calculate time span + if #all_usage > 0 then + local oldest = all_usage[1].timestamp or os.time() + for _, usage in ipairs(all_usage) do + if usage.timestamp and usage.timestamp < oldest then + oldest = usage.timestamp + end + end + stats.time_span = os.time() - oldest + else + stats.time_span = 0 + end + + return stats +end + +--- Format cost as string +---@param cost number Cost in USD +---@return string Formatted cost +local function format_cost(cost) + if cost < 0.01 then + return string.format("$%.4f", cost) + elseif cost < 1 then + return string.format("$%.3f", cost) + else + return string.format("$%.2f", cost) + end +end + +--- Format token count +---@param tokens number Token count +---@return string Formatted count +local function format_tokens(tokens) + if tokens >= 1000000 then + return string.format("%.2fM", tokens / 1000000) + elseif tokens >= 1000 then + return string.format("%.1fK", tokens / 1000) + else + return tostring(tokens) + end +end + +--- Format duration +---@param seconds number Duration in seconds +---@return string Formatted duration +local function format_duration(seconds) + if seconds < 60 then + return string.format("%ds", seconds) + elseif seconds < 3600 then + return string.format("%dm %ds", math.floor(seconds / 60), seconds % 60) + else + local hours = math.floor(seconds / 3600) + local mins = math.floor((seconds % 3600) / 60) + return string.format("%dh %dm", hours, mins) + end +end + +--- Generate model breakdown section +---@param stats table Stats with by_model +---@return string[] Lines +local function generate_model_breakdown(stats) + local lines = {} + + if next(stats.by_model) then + -- Sort models by cost (descending) + local models = {} + for model, data in pairs(stats.by_model) do + table.insert(models, { name = model, data = data }) + end + table.sort(models, function(a, b) + return a.data.cost > b.data.cost + end) + + for _, item in ipairs(models) do + local model = item.name + local data = item.data + local pricing = M.pricing[normalize_model(model)] + local is_free = data.is_free or M.is_free_model(model) + + table.insert(lines, "") + local model_icon = is_free and "🆓" or "💳" + table.insert(lines, string.format(" %s %s", model_icon, model)) + table.insert(lines, string.format(" Requests: %d", data.requests)) + table.insert(lines, string.format(" Input: %s tokens", format_tokens(data.input_tokens))) + table.insert(lines, string.format(" Output: %s tokens", format_tokens(data.output_tokens))) + + if is_free then + -- Show savings for free models + if data.savings and data.savings > 0 then + table.insert(lines, string.format(" Saved: %s", format_cost(data.savings))) + end + else + table.insert(lines, string.format(" Cost: %s", format_cost(data.cost))) + end + + -- Show pricing info for paid models + if pricing and not is_free then + local price_info = string.format( + " Rate: $%.2f/1M in, $%.2f/1M out", + pricing.input or 0, + pricing.output or 0 + ) + table.insert(lines, price_info) + end + end + else + table.insert(lines, " No usage recorded.") + end + + return lines +end + +--- Generate window content +---@return string[] Lines for the buffer +local function generate_content() + local session_stats = M.get_stats() + local all_time_stats = M.get_all_time_stats() + local lines = {} + + -- Header + table.insert(lines, "╔══════════════════════════════════════════════════════╗") + table.insert(lines, "║ 💰 LLM Cost Estimation ║") + table.insert(lines, "╠══════════════════════════════════════════════════════╣") + table.insert(lines, "") + + -- All-time summary (prominent) + table.insert(lines, "🌐 All-Time Summary (Project)") + table.insert(lines, "───────────────────────────────────────────────────────") + if all_time_stats.time_span > 0 then + table.insert(lines, string.format(" Time span: %s", format_duration(all_time_stats.time_span))) + end + table.insert(lines, string.format(" Requests: %d total", all_time_stats.request_count)) + table.insert(lines, string.format(" Local/Free: %d requests", all_time_stats.free_requests or 0)) + table.insert(lines, string.format(" Paid API: %d requests", all_time_stats.paid_requests or 0)) + table.insert(lines, string.format(" Input tokens: %s", format_tokens(all_time_stats.total_input))) + table.insert(lines, string.format(" Output tokens: %s", format_tokens(all_time_stats.total_output))) + if all_time_stats.total_cached > 0 then + table.insert(lines, string.format(" Cached tokens: %s", format_tokens(all_time_stats.total_cached))) + end + table.insert(lines, "") + table.insert(lines, string.format(" 💵 Total Cost: %s", format_cost(all_time_stats.total_cost))) + + -- Show savings prominently if there are any + if all_time_stats.total_savings and all_time_stats.total_savings > 0 then + table.insert(lines, string.format(" 💚 Saved: %s (vs %s)", format_cost(all_time_stats.total_savings), M.comparison_model)) + end + table.insert(lines, "") + + -- Session summary + table.insert(lines, "📊 Current Session") + table.insert(lines, "───────────────────────────────────────────────────────") + table.insert(lines, string.format(" Duration: %s", format_duration(session_stats.session_duration))) + table.insert(lines, string.format(" Requests: %d (%d free, %d paid)", + session_stats.request_count, + session_stats.free_requests or 0, + session_stats.paid_requests or 0)) + table.insert(lines, string.format(" Input tokens: %s", format_tokens(session_stats.total_input))) + table.insert(lines, string.format(" Output tokens: %s", format_tokens(session_stats.total_output))) + if session_stats.total_cached > 0 then + table.insert(lines, string.format(" Cached tokens: %s", format_tokens(session_stats.total_cached))) + end + table.insert(lines, string.format(" Session Cost: %s", format_cost(session_stats.total_cost))) + if session_stats.total_savings and session_stats.total_savings > 0 then + table.insert(lines, string.format(" Session Saved: %s", format_cost(session_stats.total_savings))) + end + table.insert(lines, "") + + -- Per-model breakdown (all-time) + table.insert(lines, "📈 Cost by Model (All-Time)") + table.insert(lines, "───────────────────────────────────────────────────────") + local model_lines = generate_model_breakdown(all_time_stats) + for _, line in ipairs(model_lines) do + table.insert(lines, line) + end + + table.insert(lines, "") + table.insert(lines, "───────────────────────────────────────────────────────") + table.insert(lines, " 'q' close | 'r' refresh | 'c' clear session | 'C' all") + table.insert(lines, "╚══════════════════════════════════════════════════════╝") + + return lines +end + +--- Refresh the cost window content +function M.refresh_window() + if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then + return + end + + local lines = generate_content() + + vim.bo[state.buf].modifiable = true + vim.api.nvim_buf_set_lines(state.buf, 0, -1, false, lines) + vim.bo[state.buf].modifiable = false +end + +--- Open the cost estimation window +function M.open() + -- Load historical data if not loaded + M.load_from_history() + + -- Close existing window if open + if state.win and vim.api.nvim_win_is_valid(state.win) then + vim.api.nvim_win_close(state.win, true) + end + + -- Create buffer + state.buf = vim.api.nvim_create_buf(false, true) + vim.bo[state.buf].buftype = "nofile" + vim.bo[state.buf].bufhidden = "wipe" + vim.bo[state.buf].swapfile = false + vim.bo[state.buf].filetype = "codetyper-cost" + + -- Calculate window size + local width = 58 + local height = 40 + local row = math.floor((vim.o.lines - height) / 2) + local col = math.floor((vim.o.columns - width) / 2) + + -- Create floating window + state.win = vim.api.nvim_open_win(state.buf, true, { + relative = "editor", + width = width, + height = height, + row = row, + col = col, + style = "minimal", + border = "rounded", + title = " Cost Estimation ", + title_pos = "center", + }) + + -- Set window options + vim.wo[state.win].wrap = false + vim.wo[state.win].cursorline = false + + -- Populate content + M.refresh_window() + + -- Set up keymaps + local opts = { buffer = state.buf, silent = true } + vim.keymap.set("n", "q", function() + M.close() + end, opts) + vim.keymap.set("n", "", function() + M.close() + end, opts) + vim.keymap.set("n", "r", function() + M.refresh_window() + end, opts) + vim.keymap.set("n", "c", function() + M.clear_session() + M.refresh_window() + end, opts) + vim.keymap.set("n", "C", function() + M.clear_all() + M.refresh_window() + end, opts) + + -- Set up highlights + vim.api.nvim_buf_call(state.buf, function() + vim.fn.matchadd("Title", "LLM Cost Estimation") + vim.fn.matchadd("Number", "\\$[0-9.]*") + vim.fn.matchadd("Keyword", "[0-9.]*[KM]\\? tokens") + vim.fn.matchadd("Special", "🤖\\|💰\\|📊\\|📈\\|💵") + end) +end + +--- Close the cost window +function M.close() + if state.win and vim.api.nvim_win_is_valid(state.win) then + vim.api.nvim_win_close(state.win, true) + end + state.win = nil + state.buf = nil +end + +--- Toggle the cost window +function M.toggle() + if state.win and vim.api.nvim_win_is_valid(state.win) then + M.close() + else + M.open() + end +end + +--- Clear session usage (not history) +function M.clear_session() + state.usage = {} + state.session_start = os.time() + utils.notify("Session cost tracking cleared", vim.log.levels.INFO) +end + +--- Clear all history (session + saved) +function M.clear_all() + state.usage = {} + state.all_usage = {} + state.session_start = os.time() + state.loaded = false + + -- Delete history file + local history_path = get_history_path() + local ok, err = os.remove(history_path) + if not ok and err and not err:match("No such file") then + utils.notify("Failed to delete history: " .. err, vim.log.levels.WARN) + end + + utils.notify("All cost history cleared", vim.log.levels.INFO) +end + +--- Clear usage history (alias for clear_session) +function M.clear() + M.clear_session() +end + +--- Reset session +function M.reset() + M.clear_session() +end + +return M diff --git a/lua/codetyper/credentials.lua b/lua/codetyper/credentials.lua new file mode 100644 index 0000000..e4d5156 --- /dev/null +++ b/lua/codetyper/credentials.lua @@ -0,0 +1,602 @@ +---@mod codetyper.credentials Secure credential storage for Codetyper.nvim +---@brief [[ +--- Manages API keys and model preferences stored outside of config files. +--- Credentials are stored in ~/.local/share/nvim/codetyper/configuration.json +---@brief ]] + +local M = {} + +local utils = require("codetyper.utils") + +--- Get the credentials file path +---@return string Path to credentials file +local function get_credentials_path() + local data_dir = vim.fn.stdpath("data") + return data_dir .. "/codetyper/configuration.json" +end + +--- Ensure the credentials directory exists +---@return boolean Success +local function ensure_dir() + local data_dir = vim.fn.stdpath("data") + local codetyper_dir = data_dir .. "/codetyper" + return utils.ensure_dir(codetyper_dir) +end + +--- Load credentials from file +---@return table Credentials data +function M.load() + local path = get_credentials_path() + local content = utils.read_file(path) + + if not content or content == "" then + return { + version = 1, + providers = {}, + } + end + + local ok, data = pcall(vim.json.decode, content) + if not ok or not data then + return { + version = 1, + providers = {}, + } + end + + return data +end + +--- Save credentials to file +---@param data table Credentials data +---@return boolean Success +function M.save(data) + if not ensure_dir() then + return false + end + + local path = get_credentials_path() + local ok, json = pcall(vim.json.encode, data) + if not ok then + return false + end + + return utils.write_file(path, json) +end + +--- Get API key for a provider +---@param provider string Provider name (claude, openai, gemini, copilot, ollama) +---@return string|nil API key or nil if not found +function M.get_api_key(provider) + local data = M.load() + local provider_data = data.providers and data.providers[provider] + + if provider_data and provider_data.api_key then + return provider_data.api_key + end + + return nil +end + +--- Get model for a provider +---@param provider string Provider name +---@return string|nil Model name or nil if not found +function M.get_model(provider) + local data = M.load() + local provider_data = data.providers and data.providers[provider] + + if provider_data and provider_data.model then + return provider_data.model + end + + return nil +end + +--- Get endpoint for a provider (for custom OpenAI-compatible endpoints) +---@param provider string Provider name +---@return string|nil Endpoint URL or nil if not found +function M.get_endpoint(provider) + local data = M.load() + local provider_data = data.providers and data.providers[provider] + + if provider_data and provider_data.endpoint then + return provider_data.endpoint + end + + return nil +end + +--- Get host for Ollama +---@return string|nil Host URL or nil if not found +function M.get_ollama_host() + local data = M.load() + local provider_data = data.providers and data.providers.ollama + + if provider_data and provider_data.host then + return provider_data.host + end + + return nil +end + +--- Set credentials for a provider +---@param provider string Provider name +---@param credentials table Credentials (api_key, model, endpoint, host) +---@return boolean Success +function M.set_credentials(provider, credentials) + local data = M.load() + + if not data.providers then + data.providers = {} + end + + if not data.providers[provider] then + data.providers[provider] = {} + end + + -- Merge credentials + for key, value in pairs(credentials) do + if value and value ~= "" then + data.providers[provider][key] = value + end + end + + data.updated = os.time() + + return M.save(data) +end + +--- Remove credentials for a provider +---@param provider string Provider name +---@return boolean Success +function M.remove_credentials(provider) + local data = M.load() + + if data.providers and data.providers[provider] then + data.providers[provider] = nil + data.updated = os.time() + return M.save(data) + end + + return true +end + +--- List all configured providers (checks both stored credentials AND config) +---@return table List of provider names with their config status +function M.list_providers() + local data = M.load() + local result = {} + + local all_providers = { "claude", "openai", "gemini", "copilot", "ollama" } + + for _, provider in ipairs(all_providers) do + local provider_data = data.providers and data.providers[provider] + local has_stored_key = provider_data and provider_data.api_key and provider_data.api_key ~= "" + local has_model = provider_data and provider_data.model and provider_data.model ~= "" + + -- Check if configured from config or environment + local configured_from_config = false + local config_model = nil + local ok, codetyper = pcall(require, "codetyper") + if ok then + local config = codetyper.get_config() + if config and config.llm and config.llm[provider] then + local pc = config.llm[provider] + config_model = pc.model + + if provider == "claude" then + configured_from_config = pc.api_key ~= nil or vim.env.ANTHROPIC_API_KEY ~= nil + elseif provider == "openai" then + configured_from_config = pc.api_key ~= nil or vim.env.OPENAI_API_KEY ~= nil + elseif provider == "gemini" then + configured_from_config = pc.api_key ~= nil or vim.env.GEMINI_API_KEY ~= nil + elseif provider == "copilot" then + configured_from_config = true -- Just needs copilot.lua + elseif provider == "ollama" then + configured_from_config = pc.host ~= nil + end + end + end + + local is_configured = has_stored_key + or (provider == "ollama" and provider_data ~= nil) + or (provider == "copilot" and (provider_data ~= nil or configured_from_config)) + or configured_from_config + + table.insert(result, { + name = provider, + configured = is_configured, + has_api_key = has_stored_key, + has_model = has_model or config_model ~= nil, + model = (provider_data and provider_data.model) or config_model, + source = has_stored_key and "stored" or (configured_from_config and "config" or nil), + }) + end + + return result +end + +--- Default models for each provider +M.default_models = { + claude = "claude-sonnet-4-20250514", + openai = "gpt-4o", + gemini = "gemini-2.0-flash", + copilot = "gpt-4o", + ollama = "deepseek-coder:6.7b", +} + +--- Interactive command to add/update API key +function M.interactive_add() + local providers = { "claude", "openai", "gemini", "copilot", "ollama" } + + -- Step 1: Select provider + vim.ui.select(providers, { + prompt = "Select LLM provider:", + format_item = function(item) + local display = item:sub(1, 1):upper() .. item:sub(2) + local creds = M.load() + local configured = creds.providers and creds.providers[item] + if configured and (configured.api_key or item == "ollama") then + return display .. " [configured]" + end + return display + end, + }, function(provider) + if not provider then + return + end + + -- Step 2: Get API key (skip for Ollama) + if provider == "ollama" then + M.interactive_ollama_config() + else + M.interactive_api_key(provider) + end + end) +end + +--- Interactive API key input +---@param provider string Provider name +function M.interactive_api_key(provider) + -- Copilot uses OAuth from copilot.lua, no API key needed + if provider == "copilot" then + M.interactive_copilot_config() + return + end + + local prompt = string.format("Enter %s API key (leave empty to skip): ", provider:upper()) + + vim.ui.input({ prompt = prompt }, function(api_key) + if api_key == nil then + return -- Cancelled + end + + -- Step 3: Get model + M.interactive_model(provider, api_key) + end) +end + +--- Interactive Copilot configuration (no API key, uses OAuth) +function M.interactive_copilot_config() + utils.notify("Copilot uses OAuth from copilot.lua/copilot.vim - no API key needed", vim.log.levels.INFO) + + -- Just ask for model + local default_model = M.default_models.copilot + vim.ui.input({ + prompt = string.format("Copilot model (default: %s): ", default_model), + default = default_model, + }, function(model) + if model == nil then + return -- Cancelled + end + + if model == "" then + model = default_model + end + + M.save_and_notify("copilot", { + model = model, + -- Mark as configured even without API key + configured = true, + }) + end) +end + +--- Interactive model selection +---@param provider string Provider name +---@param api_key string|nil API key +function M.interactive_model(provider, api_key) + local default_model = M.default_models[provider] or "" + local prompt = string.format("Enter model (default: %s): ", default_model) + + vim.ui.input({ prompt = prompt, default = default_model }, function(model) + if model == nil then + return -- Cancelled + end + + -- Use default if empty + if model == "" then + model = default_model + end + + -- Save credentials + local credentials = { + model = model, + } + + if api_key and api_key ~= "" then + credentials.api_key = api_key + end + + -- For OpenAI, also ask for custom endpoint + if provider == "openai" then + M.interactive_endpoint(provider, credentials) + else + M.save_and_notify(provider, credentials) + end + end) +end + +--- Interactive endpoint input for OpenAI-compatible providers +---@param provider string Provider name +---@param credentials table Current credentials +function M.interactive_endpoint(provider, credentials) + vim.ui.input({ + prompt = "Custom endpoint (leave empty for default OpenAI): ", + }, function(endpoint) + if endpoint == nil then + return -- Cancelled + end + + if endpoint ~= "" then + credentials.endpoint = endpoint + end + + M.save_and_notify(provider, credentials) + end) +end + +--- Interactive Ollama configuration +function M.interactive_ollama_config() + vim.ui.input({ + prompt = "Ollama host (default: http://localhost:11434): ", + default = "http://localhost:11434", + }, function(host) + if host == nil then + return -- Cancelled + end + + if host == "" then + host = "http://localhost:11434" + end + + -- Get model + local default_model = M.default_models.ollama + vim.ui.input({ + prompt = string.format("Ollama model (default: %s): ", default_model), + default = default_model, + }, function(model) + if model == nil then + return -- Cancelled + end + + if model == "" then + model = default_model + end + + M.save_and_notify("ollama", { + host = host, + model = model, + }) + end) + end) +end + +--- Save credentials and notify user +---@param provider string Provider name +---@param credentials table Credentials to save +function M.save_and_notify(provider, credentials) + if M.set_credentials(provider, credentials) then + local msg = string.format("Saved %s configuration", provider:upper()) + if credentials.model then + msg = msg .. " (model: " .. credentials.model .. ")" + end + utils.notify(msg, vim.log.levels.INFO) + else + utils.notify("Failed to save credentials", vim.log.levels.ERROR) + end +end + +--- Show current credentials status +function M.show_status() + local providers = M.list_providers() + + -- Get current active provider + local codetyper = require("codetyper") + local current = codetyper.get_config().llm.provider + + local lines = { + "Codetyper Credentials Status", + "============================", + "", + "Storage: " .. get_credentials_path(), + "Active: " .. current:upper(), + "", + } + + for _, p in ipairs(providers) do + local status_icon = p.configured and "✓" or "✗" + local active_marker = p.name == current and " [ACTIVE]" or "" + local source_info = "" + if p.configured then + source_info = p.source == "stored" and " (stored)" or " (config)" + end + local model_info = p.model and (" - " .. p.model) or "" + + table.insert(lines, string.format(" %s %s%s%s%s", + status_icon, + p.name:upper(), + active_marker, + source_info, + model_info)) + end + + table.insert(lines, "") + table.insert(lines, "Commands:") + table.insert(lines, " :CoderAddApiKey - Add/update credentials") + table.insert(lines, " :CoderSwitchProvider - Switch active provider") + table.insert(lines, " :CoderRemoveApiKey - Remove stored credentials") + + utils.notify(table.concat(lines, "\n")) +end + +--- Interactive remove credentials +function M.interactive_remove() + local data = M.load() + local configured = {} + + for provider, _ in pairs(data.providers or {}) do + table.insert(configured, provider) + end + + if #configured == 0 then + utils.notify("No credentials configured", vim.log.levels.INFO) + return + end + + vim.ui.select(configured, { + prompt = "Select provider to remove:", + }, function(provider) + if not provider then + return + end + + vim.ui.select({ "Yes", "No" }, { + prompt = "Remove " .. provider:upper() .. " credentials?", + }, function(choice) + if choice == "Yes" then + if M.remove_credentials(provider) then + utils.notify("Removed " .. provider:upper() .. " credentials", vim.log.levels.INFO) + else + utils.notify("Failed to remove credentials", vim.log.levels.ERROR) + end + end + end) + end) +end + +--- Set the active provider +---@param provider string Provider name +function M.set_active_provider(provider) + local data = M.load() + data.active_provider = provider + data.updated = os.time() + M.save(data) + + -- Also update the runtime config + local codetyper = require("codetyper") + local config = codetyper.get_config() + config.llm.provider = provider + + utils.notify("Active provider set to: " .. provider:upper(), vim.log.levels.INFO) +end + +--- Get the active provider from stored config +---@return string|nil Active provider +function M.get_active_provider() + local data = M.load() + return data.active_provider +end + +--- Check if a provider is configured (from stored credentials OR config) +---@param provider string Provider name +---@return boolean configured, string|nil source +local function is_provider_configured(provider) + -- Check stored credentials first + local data = M.load() + local stored = data.providers and data.providers[provider] + if stored then + if stored.configured or stored.api_key or provider == "ollama" or provider == "copilot" then + return true, "stored" + end + end + + -- Check codetyper config + local ok, codetyper = pcall(require, "codetyper") + if not ok then + return false, nil + end + + local config = codetyper.get_config() + if not config or not config.llm then + return false, nil + end + + local provider_config = config.llm[provider] + if not provider_config then + return false, nil + end + + -- Check for API key in config or environment + if provider == "claude" then + if provider_config.api_key or vim.env.ANTHROPIC_API_KEY then + return true, "config" + end + elseif provider == "openai" then + if provider_config.api_key or vim.env.OPENAI_API_KEY then + return true, "config" + end + elseif provider == "gemini" then + if provider_config.api_key or vim.env.GEMINI_API_KEY then + return true, "config" + end + elseif provider == "copilot" then + -- Copilot just needs copilot.lua installed + return true, "config" + elseif provider == "ollama" then + -- Ollama just needs host configured + if provider_config.host then + return true, "config" + end + end + + return false, nil +end + +--- Interactive switch provider +function M.interactive_switch_provider() + local all_providers = { "claude", "openai", "gemini", "copilot", "ollama" } + local available = {} + local sources = {} + + for _, provider in ipairs(all_providers) do + local configured, source = is_provider_configured(provider) + if configured then + table.insert(available, provider) + sources[provider] = source + end + end + + if #available == 0 then + utils.notify("No providers configured. Use :CoderAddApiKey or add to your config.", vim.log.levels.WARN) + return + end + + local codetyper = require("codetyper") + local current = codetyper.get_config().llm.provider + + vim.ui.select(available, { + prompt = "Select provider (current: " .. current .. "):", + format_item = function(item) + local marker = item == current and " [active]" or "" + local source_marker = sources[item] == "stored" and " (stored)" or " (config)" + return item:upper() .. marker .. source_marker + end, + }, function(provider) + if provider then + M.set_active_provider(provider) + end + end) +end + +return M diff --git a/lua/codetyper/indexer/analyzer.lua b/lua/codetyper/indexer/analyzer.lua index 78aad8a..2babd95 100644 --- a/lua/codetyper/indexer/analyzer.lua +++ b/lua/codetyper/indexer/analyzer.lua @@ -83,6 +83,9 @@ local TS_QUERIES = { }, } +-- Forward declaration for analyze_tree_generic (defined below) +local analyze_tree_generic + --- Hash file content for change detection ---@param content string ---@return string @@ -256,7 +259,7 @@ end ---@param root TSNode ---@param bufnr number ---@return table -local function analyze_tree_generic(root, bufnr) +analyze_tree_generic = function(root, bufnr) local result = { functions = {}, classes = {}, diff --git a/lua/codetyper/llm/copilot.lua b/lua/codetyper/llm/copilot.lua index 5f08231..fa985f4 100644 --- a/lua/codetyper/llm/copilot.lua +++ b/lua/codetyper/llm/copilot.lua @@ -14,6 +14,51 @@ local AUTH_URL = "https://api.github.com/copilot_internal/v2/token" ---@field github_token table|nil M.state = nil +--- Track if we've already suggested Ollama fallback this session +local ollama_fallback_suggested = false + +--- Suggest switching to Ollama when rate limits are hit +---@param error_msg string The error message that triggered this +function M.suggest_ollama_fallback(error_msg) + if ollama_fallback_suggested then + return + end + + -- Check if Ollama is available + local ollama_available = false + vim.fn.jobstart({ "curl", "-s", "http://localhost:11434/api/tags" }, { + on_exit = function(_, code) + if code == 0 then + ollama_available = true + end + + vim.schedule(function() + if ollama_available then + -- Switch to Ollama automatically + local codetyper = require("codetyper") + local config = codetyper.get_config() + config.llm.provider = "ollama" + + ollama_fallback_suggested = true + utils.notify( + "⚠️ Copilot rate limit reached. Switched to Ollama automatically.\n" + .. "Original error: " + .. error_msg:sub(1, 100), + vim.log.levels.WARN + ) + else + utils.notify( + "⚠️ Copilot rate limit reached. Ollama not available.\n" + .. "Start Ollama with: ollama serve\n" + .. "Or wait for Copilot limits to reset.", + vim.log.levels.WARN + ) + end + end) + end, + }) +end + --- Get OAuth token from copilot.lua or copilot.vim config ---@return string|nil OAuth token local function get_oauth_token() @@ -51,9 +96,16 @@ local function get_oauth_token() return nil end ---- Get model from config +--- Get model from stored credentials or config ---@return string Model name local function get_model() + -- Priority: stored credentials > config + local credentials = require("codetyper.credentials") + local stored_model = credentials.get_model("copilot") + if stored_model then + return stored_model + end + local codetyper = require("codetyper") local config = codetyper.get_config() return config.llm.copilot.model @@ -204,15 +256,37 @@ local function make_request(token, body, callback) local ok, response = pcall(vim.json.decode, response_text) if not ok then + -- Show the actual response text as the error (truncated if too long) + local error_msg = response_text + if #error_msg > 200 then + error_msg = error_msg:sub(1, 200) .. "..." + end + + -- Clean up common patterns + if response_text:match(" 0 then + table.insert(copilot_messages, { role = "user", content = table.concat(text_parts, "\n") }) end end - if #text_parts > 0 then - table.insert(copilot_messages, { role = msg.role, content = table.concat(text_parts, "\n") }) + elseif msg.role == "assistant" then + -- Assistant messages - must preserve tool_calls if present + local assistant_msg = { + role = "assistant", + content = type(msg.content) == "string" and msg.content or nil, + } + -- Preserve tool_calls for the API + if msg.tool_calls then + assistant_msg.tool_calls = msg.tool_calls + -- Ensure content is not nil when tool_calls present + if assistant_msg.content == nil then + assistant_msg.content = "" + end end + table.insert(copilot_messages, assistant_msg) + elseif msg.role == "tool" then + -- Tool result messages - must have tool_call_id + table.insert(copilot_messages, { + role = "tool", + tool_call_id = msg.tool_call_id, + content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content), + }) end end @@ -396,6 +507,20 @@ function M.generate_with_tools(messages, context, tool_definitions, callback) logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate)) logs.thinking("Sending to Copilot API...") + -- Log request to debug file + local debug_log_path = vim.fn.expand("~/.local/codetyper-debug.log") + local debug_f = io.open(debug_log_path, "a") + if debug_f then + debug_f:write(os.date("[%Y-%m-%d %H:%M:%S] ") .. "COPILOT REQUEST\n") + debug_f:write("Messages count: " .. #copilot_messages .. "\n") + for i, m in ipairs(copilot_messages) do + debug_f:write(string.format(" [%d] role=%s, has_tool_calls=%s, has_tool_call_id=%s\n", + i, m.role, tostring(m.tool_calls ~= nil), tostring(m.tool_call_id ~= nil))) + end + debug_f:write("---\n") + debug_f:close() + end + local headers = build_headers(token) local cmd = { "curl", @@ -413,35 +538,97 @@ function M.generate_with_tools(messages, context, tool_definitions, callback) table.insert(cmd, "-d") table.insert(cmd, json_body) + -- Debug logging helper + local function debug_log(msg, data) + local log_path = vim.fn.expand("~/.local/codetyper-debug.log") + local f = io.open(log_path, "a") + if f then + f:write(os.date("[%Y-%m-%d %H:%M:%S] ") .. msg .. "\n") + if data then + f:write("DATA: " .. tostring(data):sub(1, 2000) .. "\n") + end + f:write("---\n") + f:close() + end + end + + -- Prevent double callback calls + local callback_called = false + vim.fn.jobstart(cmd, { stdout_buffered = true, on_stdout = function(_, data) + if callback_called then + debug_log("on_stdout: callback already called, skipping") + return + end + if not data or #data == 0 or (data[1] == "" and #data == 1) then + debug_log("on_stdout: empty data") return end local response_text = table.concat(data, "\n") + debug_log("on_stdout: received response", response_text) + local ok, response = pcall(vim.json.decode, response_text) if not ok then + debug_log("JSON parse failed", response_text) + callback_called = true + + -- Show the actual response text as the error (truncated if too long) + local error_msg = response_text + if #error_msg > 200 then + error_msg = error_msg:sub(1, 200) .. "..." + end + + -- Clean up common patterns + if response_text:match(" 0 and data[1] ~= "" then + debug_log("on_stderr", table.concat(data, "\n")) + callback_called = true vim.schedule(function() logs.error("Copilot API request failed: " .. table.concat(data, "\n")) callback(nil, "Copilot API request failed: " .. table.concat(data, "\n")) @@ -487,7 +681,12 @@ function M.generate_with_tools(messages, context, tool_definitions, callback) end end, on_exit = function(_, code) + debug_log("on_exit: code=" .. code .. ", callback_called=" .. tostring(callback_called)) + if callback_called then + return + end if code ~= 0 then + callback_called = true vim.schedule(function() logs.error("Copilot API request failed with code: " .. code) callback(nil, "Copilot API request failed with code: " .. code) diff --git a/lua/codetyper/llm/gemini.lua b/lua/codetyper/llm/gemini.lua index faac70c..d6ebdef 100644 --- a/lua/codetyper/llm/gemini.lua +++ b/lua/codetyper/llm/gemini.lua @@ -8,17 +8,31 @@ local llm = require("codetyper.llm") --- Gemini API endpoint local API_URL = "https://generativelanguage.googleapis.com/v1beta/models" ---- Get API key from config or environment +--- Get API key from stored credentials, config, or environment ---@return string|nil API key local function get_api_key() + -- Priority: stored credentials > config > environment + local credentials = require("codetyper.credentials") + local stored_key = credentials.get_api_key("gemini") + if stored_key then + return stored_key + end + local codetyper = require("codetyper") local config = codetyper.get_config() return config.llm.gemini.api_key or vim.env.GEMINI_API_KEY end ---- Get model from config +--- Get model from stored credentials or config ---@return string Model name local function get_model() + -- Priority: stored credentials > config + local credentials = require("codetyper.credentials") + local stored_model = credentials.get_model("gemini") + if stored_model then + return stored_model + end + local codetyper = require("codetyper") local config = codetyper.get_config() return config.llm.gemini.model diff --git a/lua/codetyper/llm/init.lua b/lua/codetyper/llm/init.lua index 7ffeef8..8e9f003 100644 --- a/lua/codetyper/llm/init.lua +++ b/lua/codetyper/llm/init.lua @@ -32,6 +32,32 @@ function M.generate(prompt, context, callback) client.generate(prompt, context, callback) end +--- Smart generate with automatic provider selection based on brain memories +--- Prefers Ollama when context is rich, falls back to Copilot otherwise. +--- Implements verification pondering to reinforce Ollama accuracy over time. +---@param prompt string The user's prompt +---@param context table Context information +---@param callback fun(response: string|nil, error: string|nil, metadata: table|nil) Callback +function M.smart_generate(prompt, context, callback) + local selector = require("codetyper.llm.selector") + selector.smart_generate(prompt, context, callback) +end + +--- Get accuracy statistics for providers +---@return table Statistics for each provider +function M.get_accuracy_stats() + local selector = require("codetyper.llm.selector") + return selector.get_accuracy_stats() +end + +--- Report user feedback on response quality (for reinforcement learning) +---@param provider string Which provider generated the response +---@param was_correct boolean Whether the response was good +function M.report_feedback(provider, was_correct) + local selector = require("codetyper.llm.selector") + selector.report_feedback(provider, was_correct) +end + --- Build the system prompt for code generation ---@param context table Context information ---@return string System prompt diff --git a/lua/codetyper/llm/ollama.lua b/lua/codetyper/llm/ollama.lua index a6f706f..1033c99 100644 --- a/lua/codetyper/llm/ollama.lua +++ b/lua/codetyper/llm/ollama.lua @@ -5,21 +5,33 @@ local M = {} local utils = require("codetyper.utils") local llm = require("codetyper.llm") ---- Get Ollama host from config +--- Get Ollama host from stored credentials or config ---@return string Host URL local function get_host() + -- Priority: stored credentials > config + local credentials = require("codetyper.credentials") + local stored_host = credentials.get_ollama_host() + if stored_host then + return stored_host + end + local codetyper = require("codetyper") local config = codetyper.get_config() - return config.llm.ollama.host end ---- Get model from config +--- Get model from stored credentials or config ---@return string Model name local function get_model() + -- Priority: stored credentials > config + local credentials = require("codetyper.credentials") + local stored_model = credentials.get_model("ollama") + if stored_model then + return stored_model + end + local codetyper = require("codetyper") local config = codetyper.get_config() - return config.llm.ollama.model end @@ -199,47 +211,41 @@ function M.validate() return true end ---- Build system prompt for agent mode with tool instructions +--- Generate with tool use support for agentic mode (text-based tool calling) +---@param messages table[] Conversation history ---@param context table Context information ----@return string System prompt -local function build_agent_system_prompt(context) +---@param tool_definitions table Tool definitions +---@param callback fun(response: table|nil, error: string|nil) Callback with Claude-like response format +function M.generate_with_tools(messages, context, tool_definitions, callback) + local logs = require("codetyper.agent.logs") local agent_prompts = require("codetyper.prompts.agent") local tools_module = require("codetyper.agent.tools") - local system_prompt = agent_prompts.system .. "\n\n" - system_prompt = system_prompt .. tools_module.to_prompt_format() .. "\n\n" - system_prompt = system_prompt .. agent_prompts.tool_instructions + logs.request("ollama", get_model()) + logs.thinking("Preparing agent request...") - -- Add context about current file if available - if context.file_path then - system_prompt = system_prompt .. "\n\nCurrent working context:\n" - system_prompt = system_prompt .. "- File: " .. context.file_path .. "\n" - if context.language then - system_prompt = system_prompt .. "- Language: " .. context.language .. "\n" + -- Build system prompt with tool instructions + local system_prompt = llm.build_system_prompt(context) + system_prompt = system_prompt .. "\n\n" .. agent_prompts.system + system_prompt = system_prompt .. "\n\n" .. agent_prompts.tool_instructions + + -- Add tool descriptions + system_prompt = system_prompt .. "\n\n## Available Tools\n" + system_prompt = system_prompt .. "Call tools by outputting JSON in this exact format:\n" + system_prompt = system_prompt .. '```json\n{"tool": "tool_name", "arguments": {...}}\n```\n\n' + + for _, tool in ipairs(tool_definitions) do + local name = tool.name or (tool["function"] and tool["function"].name) + local desc = tool.description or (tool["function"] and tool["function"].description) + if name then + system_prompt = system_prompt .. string.format("### %s\n%s\n\n", name, desc or "") end end - -- Add project root info - local root = utils.get_project_root() - if root then - system_prompt = system_prompt .. "- Project root: " .. root .. "\n" - end - - return system_prompt -end - ---- Build request body for Ollama API with tools (chat format) ----@param messages table[] Conversation messages ----@param context table Context information ----@return table Request body -local function build_tools_request_body(messages, context) - local system_prompt = build_agent_system_prompt(context) - -- Convert messages to Ollama chat format local ollama_messages = {} for _, msg in ipairs(messages) do local content = msg.content - -- Handle complex content (like tool results) if type(content) == "table" then local text_parts = {} for _, part in ipairs(content) do @@ -251,14 +257,10 @@ local function build_tools_request_body(messages, context) end content = table.concat(text_parts, "\n") end - - table.insert(ollama_messages, { - role = msg.role, - content = content, - }) + table.insert(ollama_messages, { role = msg.role, content = content }) end - return { + local body = { model = get_model(), messages = ollama_messages, system = system_prompt, @@ -268,16 +270,15 @@ local function build_tools_request_body(messages, context) num_predict = 4096, }, } -end ---- Make HTTP request to Ollama chat API ----@param body table Request body ----@param callback fun(response: string|nil, error: string|nil, usage: table|nil) Callback function -local function make_chat_request(body, callback) local host = get_host() local url = host .. "/api/chat" local json_body = vim.json.encode(body) + local prompt_estimate = logs.estimate_tokens(json_body) + logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate)) + logs.thinking("Sending to Ollama API...") + local cmd = { "curl", "-s", @@ -302,196 +303,82 @@ local function make_chat_request(body, callback) if not ok then vim.schedule(function() - callback(nil, "Failed to parse Ollama response", nil) + logs.error("Failed to parse Ollama response") + callback(nil, "Failed to parse Ollama response") end) return end if response.error then vim.schedule(function() - callback(nil, response.error or "Ollama API error", nil) + logs.error(response.error or "Ollama API error") + callback(nil, response.error or "Ollama API error") end) return end - -- Extract usage info - local usage = { - prompt_tokens = response.prompt_eval_count or 0, - response_tokens = response.eval_count or 0, - } + -- Log token usage and record cost (Ollama is free but we track usage) + if response.prompt_eval_count or response.eval_count then + logs.response(response.prompt_eval_count or 0, response.eval_count or 0, "stop") - -- Return the message content for agent parsing - if response.message and response.message.content then - vim.schedule(function() - callback(response.message.content, nil, usage) - end) - else - vim.schedule(function() - callback(nil, "No response from Ollama", nil) - end) + -- Record usage for cost tracking (free for local models) + local cost = require("codetyper.cost") + cost.record_usage( + get_model(), + response.prompt_eval_count or 0, + response.eval_count or 0, + 0 -- No cached tokens for Ollama + ) end + + -- Parse the response text for tool calls + local content_text = response.message and response.message.content or "" + local converted = { content = {}, stop_reason = "end_turn" } + + -- Try to extract JSON tool calls from response + local json_match = content_text:match("```json%s*(%b{})%s*```") + if json_match then + local ok_json, parsed = pcall(vim.json.decode, json_match) + if ok_json and parsed.tool then + table.insert(converted.content, { + type = "tool_use", + id = "call_" .. string.format("%x", os.time()) .. "_" .. string.format("%x", math.random(0, 0xFFFF)), + name = parsed.tool, + input = parsed.arguments or {}, + }) + logs.thinking("Tool call: " .. parsed.tool) + content_text = content_text:gsub("```json.-```", ""):gsub("^%s+", ""):gsub("%s+$", "") + converted.stop_reason = "tool_use" + end + end + + -- Add text content + if content_text and content_text ~= "" then + table.insert(converted.content, 1, { type = "text", text = content_text }) + logs.thinking("Response contains text") + end + + vim.schedule(function() + callback(converted, nil) + end) end, on_stderr = function(_, data) if data and #data > 0 and data[1] ~= "" then vim.schedule(function() - callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil) + logs.error("Ollama API request failed: " .. table.concat(data, "\n")) + callback(nil, "Ollama API request failed: " .. table.concat(data, "\n")) end) end end, on_exit = function(_, code) if code ~= 0 then - -- Don't double-report errors + vim.schedule(function() + logs.error("Ollama API request failed with code: " .. code) + callback(nil, "Ollama API request failed with code: " .. code) + end) end end, }) end ---- Generate response with tools using Ollama API ----@param messages table[] Conversation history ----@param context table Context information ----@param tools table Tool definitions (embedded in prompt for Ollama) ----@param callback fun(response: string|nil, error: string|nil) Callback function -function M.generate_with_tools(messages, context, tools, callback) - local logs = require("codetyper.agent.logs") - - -- Log the request - local model = get_model() - logs.request("ollama", model) - logs.thinking("Preparing API request...") - - local body = build_tools_request_body(messages, context) - - -- Estimate prompt tokens - local prompt_estimate = logs.estimate_tokens(vim.json.encode(body)) - logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate)) - - make_chat_request(body, function(response, err, usage) - if err then - logs.error(err) - callback(nil, err) - else - -- Log token usage - if usage then - logs.response(usage.prompt_tokens or 0, usage.response_tokens or 0, "end_turn") - end - - -- Log if response contains tool calls - if response then - local parser = require("codetyper.agent.parser") - local parsed = parser.parse_ollama_response(response) - if #parsed.tool_calls > 0 then - for _, tc in ipairs(parsed.tool_calls) do - logs.thinking("Tool call: " .. tc.name) - end - end - if parsed.text and parsed.text ~= "" then - logs.thinking("Response contains text") - end - end - - callback(response, nil) - end - end) -end - ---- Generate with tool use support for agentic mode (simulated via prompts) ----@param messages table[] Conversation history ----@param context table Context information ----@param tool_definitions table Tool definitions ----@param callback fun(response: string|nil, error: string|nil) Callback with response text -function M.generate_with_tools(messages, context, tool_definitions, callback) - local tools_module = require("codetyper.agent.tools") - local agent_prompts = require("codetyper.prompts.agent") - - -- Build system prompt with agent instructions and tool definitions - local system_prompt = llm.build_system_prompt(context) - system_prompt = system_prompt .. "\n\n" .. agent_prompts.system - system_prompt = system_prompt .. "\n\n" .. tools_module.to_prompt_format() - - -- Flatten messages to single prompt (Ollama's generate API) - local prompt_parts = {} - for _, msg in ipairs(messages) do - if type(msg.content) == "string" then - local role_prefix = msg.role == "user" and "User" or "Assistant" - table.insert(prompt_parts, role_prefix .. ": " .. msg.content) - elseif type(msg.content) == "table" then - -- Handle tool results - for _, item in ipairs(msg.content) do - if item.type == "tool_result" then - table.insert(prompt_parts, "Tool result: " .. item.content) - end - end - end - end - - local body = { - model = get_model(), - system = system_prompt, - prompt = table.concat(prompt_parts, "\n\n"), - stream = false, - options = { - temperature = 0.2, - num_predict = 4096, - }, - } - - local host = get_host() - local url = host .. "/api/generate" - local json_body = vim.json.encode(body) - - local cmd = { - "curl", - "-s", - "-X", "POST", - url, - "-H", "Content-Type: application/json", - "-d", json_body, - } - - vim.fn.jobstart(cmd, { - stdout_buffered = true, - on_stdout = function(_, data) - if not data or #data == 0 or (data[1] == "" and #data == 1) then - return - end - - local response_text = table.concat(data, "\n") - local ok, response = pcall(vim.json.decode, response_text) - - if not ok then - vim.schedule(function() - callback(nil, "Failed to parse Ollama response") - end) - return - end - - if response.error then - vim.schedule(function() - callback(nil, response.error or "Ollama API error") - end) - return - end - - -- Return raw response text for parser to handle - vim.schedule(function() - callback(response.response or "", nil) - end) - end, - on_stderr = function(_, data) - if data and #data > 0 and data[1] ~= "" then - vim.schedule(function() - callback(nil, "Ollama API request failed: " .. table.concat(data, "\n")) - end) - end - end, - on_exit = function(_, code) - if code ~= 0 then - vim.schedule(function() - callback(nil, "Ollama API request failed with code: " .. code) - end) - end - end, - }) -end - return M diff --git a/lua/codetyper/llm/openai.lua b/lua/codetyper/llm/openai.lua index 26b6fcb..4933945 100644 --- a/lua/codetyper/llm/openai.lua +++ b/lua/codetyper/llm/openai.lua @@ -8,25 +8,46 @@ local llm = require("codetyper.llm") --- OpenAI API endpoint local API_URL = "https://api.openai.com/v1/chat/completions" ---- Get API key from config or environment +--- Get API key from stored credentials, config, or environment ---@return string|nil API key local function get_api_key() + -- Priority: stored credentials > config > environment + local credentials = require("codetyper.credentials") + local stored_key = credentials.get_api_key("openai") + if stored_key then + return stored_key + end + local codetyper = require("codetyper") local config = codetyper.get_config() return config.llm.openai.api_key or vim.env.OPENAI_API_KEY end ---- Get model from config +--- Get model from stored credentials or config ---@return string Model name local function get_model() + -- Priority: stored credentials > config + local credentials = require("codetyper.credentials") + local stored_model = credentials.get_model("openai") + if stored_model then + return stored_model + end + local codetyper = require("codetyper") local config = codetyper.get_config() return config.llm.openai.model end ---- Get endpoint from config (allows custom endpoints like Azure, OpenRouter) +--- Get endpoint from stored credentials or config (allows custom endpoints like Azure, OpenRouter) ---@return string API endpoint local function get_endpoint() + -- Priority: stored credentials > config > default + local credentials = require("codetyper.credentials") + local stored_endpoint = credentials.get_endpoint("openai") + if stored_endpoint then + return stored_endpoint + end + local codetyper = require("codetyper") local config = codetyper.get_config() return config.llm.openai.endpoint or API_URL @@ -284,9 +305,18 @@ function M.generate_with_tools(messages, context, tool_definitions, callback) return end - -- Log token usage + -- Log token usage and record cost if response.usage then logs.response(response.usage.prompt_tokens or 0, response.usage.completion_tokens or 0, "stop") + + -- Record usage for cost tracking + local cost = require("codetyper.cost") + cost.record_usage( + model, + response.usage.prompt_tokens or 0, + response.usage.completion_tokens or 0, + response.usage.prompt_tokens_details and response.usage.prompt_tokens_details.cached_tokens or 0 + ) end -- Convert to Claude-like format for parser compatibility diff --git a/lua/codetyper/llm/selector.lua b/lua/codetyper/llm/selector.lua new file mode 100644 index 0000000..aebc21c --- /dev/null +++ b/lua/codetyper/llm/selector.lua @@ -0,0 +1,514 @@ +---@mod codetyper.llm.selector Smart LLM selection with memory-based confidence +---@brief [[ +--- Intelligent LLM provider selection based on brain memories. +--- Prefers local Ollama when context is rich, falls back to Copilot otherwise. +--- Implements verification pondering to reinforce Ollama accuracy over time. +---@brief ]] + +local M = {} + +---@class SelectionResult +---@field provider string Selected provider name +---@field confidence number Confidence score (0-1) +---@field memory_count number Number of relevant memories found +---@field reason string Human-readable reason for selection + +---@class PonderResult +---@field ollama_response string Ollama's response +---@field verifier_response string Verifier's response +---@field agreement_score number How much they agree (0-1) +---@field ollama_correct boolean Whether Ollama was deemed correct +---@field feedback string Feedback for learning + +--- Minimum memories required for high confidence +local MIN_MEMORIES_FOR_LOCAL = 3 + +--- Minimum memory relevance score for local provider +local MIN_RELEVANCE_FOR_LOCAL = 0.6 + +--- Agreement threshold for Ollama verification +local AGREEMENT_THRESHOLD = 0.7 + +--- Pondering sample rate (0-1) - how often to verify Ollama +local PONDER_SAMPLE_RATE = 0.2 + +--- Provider accuracy tracking (persisted in brain) +local accuracy_cache = { + ollama = { correct = 0, total = 0 }, + copilot = { correct = 0, total = 0 }, +} + +--- Get the brain module safely +---@return table|nil +local function get_brain() + local ok, brain = pcall(require, "codetyper.brain") + if ok and brain.is_initialized and brain.is_initialized() then + return brain + end + return nil +end + +--- Load accuracy stats from brain +local function load_accuracy_stats() + local brain = get_brain() + if not brain then + return + end + + -- Query for accuracy tracking nodes + pcall(function() + local result = brain.query({ + query = "provider_accuracy_stats", + types = { "metric" }, + limit = 1, + }) + + if result and result.nodes and #result.nodes > 0 then + local node = result.nodes[1] + if node.c and node.c.d then + local ok, stats = pcall(vim.json.decode, node.c.d) + if ok and stats then + accuracy_cache = stats + end + end + end + end) +end + +--- Save accuracy stats to brain +local function save_accuracy_stats() + local brain = get_brain() + if not brain then + return + end + + pcall(function() + brain.learn({ + type = "metric", + summary = "provider_accuracy_stats", + detail = vim.json.encode(accuracy_cache), + weight = 1.0, + }) + end) +end + +--- Calculate Ollama confidence based on historical accuracy +---@return number confidence (0-1) +local function get_ollama_historical_confidence() + local stats = accuracy_cache.ollama + if stats.total < 5 then + -- Not enough data, return neutral confidence + return 0.5 + end + + local accuracy = stats.correct / stats.total + -- Boost confidence if accuracy is high + return math.min(1.0, accuracy * 1.2) +end + +--- Query brain for relevant context +---@param prompt string User prompt +---@param file_path string|nil Current file path +---@return table result {memories: table[], relevance: number, count: number} +local function query_brain_context(prompt, file_path) + local result = { + memories = {}, + relevance = 0, + count = 0, + } + + local brain = get_brain() + if not brain then + return result + end + + -- Query brain with multiple dimensions + local ok, query_result = pcall(function() + return brain.query({ + query = prompt, + file = file_path, + limit = 10, + types = { "pattern", "correction", "convention", "fact" }, + }) + end) + + if not ok or not query_result then + return result + end + + result.memories = query_result.nodes or {} + result.count = #result.memories + + -- Calculate average relevance + if result.count > 0 then + local total_relevance = 0 + for _, node in ipairs(result.memories) do + -- Use node weight and success rate as relevance indicators + local node_relevance = (node.sc and node.sc.w or 0.5) * (node.sc and node.sc.sr or 0.5) + total_relevance = total_relevance + node_relevance + end + result.relevance = total_relevance / result.count + end + + return result +end + +--- Select the best LLM provider based on context +---@param prompt string User prompt +---@param context table LLM context +---@return SelectionResult +function M.select_provider(prompt, context) + -- Load accuracy stats on first call + if accuracy_cache.ollama.total == 0 then + load_accuracy_stats() + end + + local file_path = context.file_path + + -- Query brain for relevant memories + local brain_context = query_brain_context(prompt, file_path) + + -- Calculate base confidence from memories + local memory_confidence = 0 + if brain_context.count >= MIN_MEMORIES_FOR_LOCAL then + memory_confidence = math.min(1.0, brain_context.count / 10) * brain_context.relevance + end + + -- Factor in historical Ollama accuracy + local historical_confidence = get_ollama_historical_confidence() + + -- Combined confidence score + local combined_confidence = (memory_confidence * 0.6) + (historical_confidence * 0.4) + + -- Decision logic + local provider = "copilot" -- Default to more capable + local reason = "" + + if brain_context.count >= MIN_MEMORIES_FOR_LOCAL and combined_confidence >= MIN_RELEVANCE_FOR_LOCAL then + provider = "ollama" + reason = string.format( + "Rich context: %d memories (%.1f%% relevance), historical accuracy: %.1f%%", + brain_context.count, + brain_context.relevance * 100, + historical_confidence * 100 + ) + elseif brain_context.count > 0 and combined_confidence >= 0.4 then + -- Medium confidence - use Ollama but with pondering + provider = "ollama" + reason = string.format( + "Moderate context: %d memories, will verify with pondering", + brain_context.count + ) + else + reason = string.format( + "Insufficient context: %d memories (need %d), using capable provider", + brain_context.count, + MIN_MEMORIES_FOR_LOCAL + ) + end + + return { + provider = provider, + confidence = combined_confidence, + memory_count = brain_context.count, + reason = reason, + memories = brain_context.memories, + } +end + +--- Check if we should ponder (verify) this Ollama response +---@param confidence number Current confidence level +---@return boolean +function M.should_ponder(confidence) + -- Always ponder when confidence is medium + if confidence >= 0.4 and confidence < 0.7 then + return true + end + + -- Random sampling for high confidence to keep learning + if confidence >= 0.7 then + return math.random() < PONDER_SAMPLE_RATE + end + + -- Low confidence shouldn't reach Ollama anyway + return false +end + +--- Calculate agreement score between two responses +---@param response1 string First response +---@param response2 string Second response +---@return number Agreement score (0-1) +local function calculate_agreement(response1, response2) + -- Normalize responses + local norm1 = response1:lower():gsub("%s+", " "):gsub("[^%w%s]", "") + local norm2 = response2:lower():gsub("%s+", " "):gsub("[^%w%s]", "") + + -- Extract words + local words1 = {} + for word in norm1:gmatch("%w+") do + words1[word] = (words1[word] or 0) + 1 + end + + local words2 = {} + for word in norm2:gmatch("%w+") do + words2[word] = (words2[word] or 0) + 1 + end + + -- Calculate Jaccard similarity + local intersection = 0 + local union = 0 + + for word, count1 in pairs(words1) do + local count2 = words2[word] or 0 + intersection = intersection + math.min(count1, count2) + union = union + math.max(count1, count2) + end + + for word, count2 in pairs(words2) do + if not words1[word] then + union = union + count2 + end + end + + if union == 0 then + return 1.0 -- Both empty + end + + -- Also check structural similarity (code structure) + local struct_score = 0 + local function_count1 = select(2, response1:gsub("function", "")) + local function_count2 = select(2, response2:gsub("function", "")) + if function_count1 > 0 or function_count2 > 0 then + struct_score = 1 - math.abs(function_count1 - function_count2) / math.max(function_count1, function_count2, 1) + else + struct_score = 1.0 + end + + -- Combined score + local jaccard = intersection / union + return (jaccard * 0.7) + (struct_score * 0.3) +end + +--- Ponder (verify) Ollama's response with another LLM +---@param prompt string Original prompt +---@param context table LLM context +---@param ollama_response string Ollama's response +---@param callback fun(result: PonderResult) Callback with pondering result +function M.ponder(prompt, context, ollama_response, callback) + -- Use Copilot as verifier + local copilot = require("codetyper.llm.copilot") + + -- Build verification prompt + local verify_prompt = prompt + + copilot.generate(verify_prompt, context, function(verifier_response, error) + if error or not verifier_response then + -- Verification failed, assume Ollama is correct + callback({ + ollama_response = ollama_response, + verifier_response = "", + agreement_score = 1.0, + ollama_correct = true, + feedback = "Verification unavailable, trusting Ollama", + }) + return + end + + -- Calculate agreement + local agreement = calculate_agreement(ollama_response, verifier_response) + + -- Determine if Ollama was correct + local ollama_correct = agreement >= AGREEMENT_THRESHOLD + + -- Generate feedback + local feedback + if ollama_correct then + feedback = string.format("Agreement: %.1f%% - Ollama response validated", agreement * 100) + else + feedback = string.format( + "Disagreement: %.1f%% - Ollama may need correction", + (1 - agreement) * 100 + ) + end + + -- Update accuracy tracking + accuracy_cache.ollama.total = accuracy_cache.ollama.total + 1 + if ollama_correct then + accuracy_cache.ollama.correct = accuracy_cache.ollama.correct + 1 + end + save_accuracy_stats() + + -- Learn from this verification + local brain = get_brain() + if brain then + pcall(function() + if ollama_correct then + -- Reinforce the pattern + brain.learn({ + type = "correction", + summary = "Ollama verified correct", + detail = string.format( + "Prompt: %s\nAgreement: %.1f%%", + prompt:sub(1, 100), + agreement * 100 + ), + weight = 0.8, + file = context.file_path, + }) + else + -- Learn the correction + brain.learn({ + type = "correction", + summary = "Ollama needed correction", + detail = string.format( + "Prompt: %s\nOllama: %s\nCorrect: %s", + prompt:sub(1, 100), + ollama_response:sub(1, 200), + verifier_response:sub(1, 200) + ), + weight = 0.9, + file = context.file_path, + }) + end + end) + end + + callback({ + ollama_response = ollama_response, + verifier_response = verifier_response, + agreement_score = agreement, + ollama_correct = ollama_correct, + feedback = feedback, + }) + end) +end + +--- Smart generate with automatic provider selection and pondering +---@param prompt string User prompt +---@param context table LLM context +---@param callback fun(response: string|nil, error: string|nil, metadata: table|nil) Callback +function M.smart_generate(prompt, context, callback) + -- Select provider + local selection = M.select_provider(prompt, context) + + -- Log selection + pcall(function() + local logs = require("codetyper.agent.logs") + logs.add({ + type = "info", + message = string.format( + "LLM: %s (confidence: %.1f%%, %s)", + selection.provider, + selection.confidence * 100, + selection.reason + ), + }) + end) + + -- Get the selected client + local client + if selection.provider == "ollama" then + client = require("codetyper.llm.ollama") + else + client = require("codetyper.llm.copilot") + end + + -- Generate response + client.generate(prompt, context, function(response, error) + if error then + -- Fallback on error + if selection.provider == "ollama" then + -- Try Copilot as fallback + local copilot = require("codetyper.llm.copilot") + copilot.generate(prompt, context, function(fallback_response, fallback_error) + callback(fallback_response, fallback_error, { + provider = "copilot", + fallback = true, + original_provider = "ollama", + original_error = error, + }) + end) + return + end + callback(nil, error, { provider = selection.provider }) + return + end + + -- Check if we should ponder + if selection.provider == "ollama" and M.should_ponder(selection.confidence) then + M.ponder(prompt, context, response, function(ponder_result) + if ponder_result.ollama_correct then + -- Ollama was correct, use its response + callback(response, nil, { + provider = "ollama", + pondered = true, + agreement = ponder_result.agreement_score, + confidence = selection.confidence, + }) + else + -- Use verifier's response instead + callback(ponder_result.verifier_response, nil, { + provider = "copilot", + pondered = true, + agreement = ponder_result.agreement_score, + original_provider = "ollama", + corrected = true, + }) + end + end) + else + -- No pondering needed + callback(response, nil, { + provider = selection.provider, + pondered = false, + confidence = selection.confidence, + }) + end + end) +end + +--- Get current accuracy statistics +---@return table {ollama: {correct, total, accuracy}, copilot: {correct, total, accuracy}} +function M.get_accuracy_stats() + local stats = { + ollama = { + correct = accuracy_cache.ollama.correct, + total = accuracy_cache.ollama.total, + accuracy = accuracy_cache.ollama.total > 0 + and (accuracy_cache.ollama.correct / accuracy_cache.ollama.total) + or 0, + }, + copilot = { + correct = accuracy_cache.copilot.correct, + total = accuracy_cache.copilot.total, + accuracy = accuracy_cache.copilot.total > 0 + and (accuracy_cache.copilot.correct / accuracy_cache.copilot.total) + or 0, + }, + } + return stats +end + +--- Reset accuracy statistics +function M.reset_accuracy_stats() + accuracy_cache = { + ollama = { correct = 0, total = 0 }, + copilot = { correct = 0, total = 0 }, + } + save_accuracy_stats() +end + +--- Report user feedback on response quality +---@param provider string Which provider generated the response +---@param was_correct boolean Whether the response was good +function M.report_feedback(provider, was_correct) + if accuracy_cache[provider] then + accuracy_cache[provider].total = accuracy_cache[provider].total + 1 + if was_correct then + accuracy_cache[provider].correct = accuracy_cache[provider].correct + 1 + end + save_accuracy_stats() + end +end + +return M diff --git a/lua/codetyper/prompts/agent.lua b/lua/codetyper/prompts/agent.lua index ba51f76..67dd9f4 100644 --- a/lua/codetyper/prompts/agent.lua +++ b/lua/codetyper/prompts/agent.lua @@ -13,24 +13,23 @@ M.system = You have access to these tools - USE THEM to accomplish tasks: ### File Operations -- **read_file**: Read any file. ALWAYS read files before modifying them. -- **write_file**: Create new files or completely replace existing ones. Use for new files. -- **edit_file**: Make precise edits to existing files using find/replace. The "find" must match EXACTLY. -- **delete_file**: Delete files (requires user approval). Include a reason. -- **list_directory**: Explore project structure. See what files exist. -- **search_files**: Find files by pattern or content. +- **view**: Read any file. ALWAYS read files before modifying them. Parameters: path (string) +- **write**: Create new files or completely replace existing ones. Use for new files. Parameters: path (string), content (string) +- **edit**: Make precise edits to existing files using search/replace. Parameters: path (string), old_string (string), new_string (string) +- **glob**: Find files by pattern (e.g., "**/*.lua"). Parameters: pattern (string), path (optional) +- **grep**: Search file contents with regex. Parameters: pattern (string), path (optional) ### Shell Commands -- **bash**: Run shell commands (git, npm, make, etc.). User approves each command. +- **bash**: Run shell commands (git, npm, make, etc.). User approves each command. Parameters: command (string) ## HOW TO WORK -1. **UNDERSTAND FIRST**: Use read_file, list_directory, or search_files to understand the codebase before making changes. +1. **UNDERSTAND FIRST**: Use view, glob, or grep to understand the codebase before making changes. -2. **MAKE CHANGES**: Use write_file for new files, edit_file for modifications. - - For edit_file: The "find" parameter must match file content EXACTLY (including whitespace) - - Include enough context in "find" to be unique - - For write_file: Provide complete file content +2. **MAKE CHANGES**: Use write for new files, edit for modifications. + - For edit: The "old_string" parameter must match file content EXACTLY (including whitespace) + - Include enough context in "old_string" to be unique + - For write: Provide complete file content 3. **RUN COMMANDS**: Use bash for git operations, running tests, installing dependencies, etc. @@ -41,16 +40,16 @@ You have access to these tools - USE THEM to accomplish tasks: User: "Create a new React component for a login form" Your approach: -1. Use list_directory to see project structure -2. Use read_file to check existing component patterns -3. Use write_file to create the new component file -4. Use write_file to create a test file if appropriate +1. Use glob to see project structure (glob pattern="**/*.tsx") +2. Use view to check existing component patterns +3. Use write to create the new component file +4. Use write to create a test file if appropriate 5. Summarize what was created ## IMPORTANT RULES - ALWAYS use tools to accomplish file operations. Don't just describe what to do - DO IT. -- Read files before editing to ensure your "find" string matches exactly. +- Read files before editing to ensure your "old_string" matches exactly. - When creating files, write complete, working code. - When editing, preserve existing code style and conventions. - If a file path is provided, use it. If not, infer from context. @@ -68,13 +67,12 @@ M.tool_instructions = [[ ## TOOL USAGE When you need to perform an action, call the appropriate tool. You can call tools to: -- Read files to understand code -- Create new files with write_file -- Modify existing files with edit_file (read first!) -- Delete files with delete_file -- List directories to explore structure -- Search for files by name or content -- Run shell commands with bash +- Read files with view (parameters: path) +- Create new files with write (parameters: path, content) +- Modify existing files with edit (parameters: path, old_string, new_string) - read first! +- Find files by pattern with glob (parameters: pattern, path) +- Search file contents with grep (parameters: pattern, path) +- Run shell commands with bash (parameters: command) After receiving a tool result, continue working: - If more actions are needed, call another tool @@ -82,11 +80,11 @@ After receiving a tool result, continue working: ## CRITICAL RULES -1. **Always read before editing**: Use read_file before edit_file to ensure exact matches -2. **Be precise with edits**: The "find" parameter must match the file content EXACTLY -3. **Create complete files**: When using write_file, provide fully working code -4. **User approval required**: File writes, edits, deletes, and bash commands need approval -5. **Don't guess**: If unsure about file structure, use list_directory or search_files +1. **Always read before editing**: Use view before edit to ensure exact matches +2. **Be precise with edits**: The "old_string" parameter must match the file content EXACTLY +3. **Create complete files**: When using write, provide fully working code +4. **User approval required**: File writes, edits, and bash commands need approval +5. **Don't guess**: If unsure about file structure, use glob or grep ]] --- Prompt for when agent finishes diff --git a/tests/spec/agent_tools_spec.lua b/tests/spec/agent_tools_spec.lua new file mode 100644 index 0000000..b7f31f5 --- /dev/null +++ b/tests/spec/agent_tools_spec.lua @@ -0,0 +1,427 @@ +--- Tests for agent tools system + +describe("codetyper.agent.tools", function() + local tools + + before_each(function() + tools = require("codetyper.agent.tools") + -- Clear any existing registrations + for name, _ in pairs(tools.get_all()) do + tools.unregister(name) + end + end) + + describe("tool registration", function() + it("should register a tool", function() + local test_tool = { + name = "test_tool", + description = "A test tool", + params = { + { name = "input", type = "string", description = "Test input" }, + }, + func = function(input, opts) + return "result", nil + end, + } + + tools.register(test_tool) + local retrieved = tools.get("test_tool") + + assert.is_not_nil(retrieved) + assert.equals("test_tool", retrieved.name) + end) + + it("should unregister a tool", function() + local test_tool = { + name = "temp_tool", + description = "Temporary", + func = function() end, + } + + tools.register(test_tool) + assert.is_not_nil(tools.get("temp_tool")) + + tools.unregister("temp_tool") + assert.is_nil(tools.get("temp_tool")) + end) + + it("should list all tools", function() + tools.register({ name = "tool1", func = function() end }) + tools.register({ name = "tool2", func = function() end }) + tools.register({ name = "tool3", func = function() end }) + + local list = tools.list() + assert.equals(3, #list) + end) + + it("should filter tools with predicate", function() + tools.register({ name = "safe_tool", requires_confirmation = false, func = function() end }) + tools.register({ name = "dangerous_tool", requires_confirmation = true, func = function() end }) + + local safe_list = tools.list(function(t) + return not t.requires_confirmation + end) + + assert.equals(1, #safe_list) + assert.equals("safe_tool", safe_list[1].name) + end) + end) + + describe("tool execution", function() + it("should execute a tool and return result", function() + tools.register({ + name = "adder", + params = { + { name = "a", type = "number" }, + { name = "b", type = "number" }, + }, + func = function(input, opts) + return input.a + input.b, nil + end, + }) + + local result, err = tools.execute("adder", { a = 5, b = 3 }, {}) + + assert.is_nil(err) + assert.equals(8, result) + end) + + it("should return error for unknown tool", function() + local result, err = tools.execute("nonexistent", {}, {}) + + assert.is_nil(result) + assert.truthy(err:match("Unknown tool")) + end) + + it("should track execution history", function() + tools.clear_history() + tools.register({ + name = "tracked_tool", + func = function() + return "done", nil + end, + }) + + tools.execute("tracked_tool", {}, {}) + tools.execute("tracked_tool", {}, {}) + + local history = tools.get_history() + assert.equals(2, #history) + assert.equals("tracked_tool", history[1].tool) + assert.equals("completed", history[1].status) + end) + end) + + describe("tool schemas", function() + it("should generate JSON schema for tools", function() + tools.register({ + name = "schema_test", + description = "Test schema generation", + params = { + { name = "required_param", type = "string", description = "A required param" }, + { name = "optional_param", type = "number", description = "Optional", optional = true }, + }, + returns = { + { name = "result", type = "string" }, + }, + to_schema = require("codetyper.agent.tools.base").to_schema, + func = function() end, + }) + + local schemas = tools.get_schemas() + assert.equals(1, #schemas) + + local schema = schemas[1] + assert.equals("function", schema.type) + assert.equals("schema_test", schema.function_def.name) + assert.is_not_nil(schema.function_def.parameters.properties.required_param) + assert.is_not_nil(schema.function_def.parameters.properties.optional_param) + end) + end) + + describe("process_tool_call", function() + it("should process tool call with name and input", function() + tools.register({ + name = "processor_test", + func = function(input, opts) + return "processed: " .. input.value, nil + end, + }) + + local result, err = tools.process_tool_call({ + name = "processor_test", + input = { value = "test" }, + }, {}) + + assert.is_nil(err) + assert.equals("processed: test", result) + end) + + it("should parse JSON string arguments", function() + tools.register({ + name = "json_parser_test", + func = function(input, opts) + return input.key, nil + end, + }) + + local result, err = tools.process_tool_call({ + name = "json_parser_test", + arguments = '{"key": "value"}', + }, {}) + + assert.is_nil(err) + assert.equals("value", result) + end) + end) +end) + +describe("codetyper.agent.tools.base", function() + local base + + before_each(function() + base = require("codetyper.agent.tools.base") + end) + + describe("validate_input", function() + it("should validate required parameters", function() + local tool = setmetatable({ + params = { + { name = "required", type = "string" }, + { name = "optional", type = "string", optional = true }, + }, + }, base) + + local valid, err = tool:validate_input({ required = "value" }) + assert.is_true(valid) + assert.is_nil(err) + end) + + it("should fail on missing required parameter", function() + local tool = setmetatable({ + params = { + { name = "required", type = "string" }, + }, + }, base) + + local valid, err = tool:validate_input({}) + assert.is_false(valid) + assert.truthy(err:match("Missing required parameter")) + end) + + it("should validate parameter types", function() + local tool = setmetatable({ + params = { + { name = "num", type = "number" }, + }, + }, base) + + local valid1, _ = tool:validate_input({ num = 42 }) + assert.is_true(valid1) + + local valid2, err2 = tool:validate_input({ num = "not a number" }) + assert.is_false(valid2) + assert.truthy(err2:match("must be number")) + end) + + it("should validate integer type", function() + local tool = setmetatable({ + params = { + { name = "int", type = "integer" }, + }, + }, base) + + local valid1, _ = tool:validate_input({ int = 42 }) + assert.is_true(valid1) + + local valid2, err2 = tool:validate_input({ int = 42.5 }) + assert.is_false(valid2) + assert.truthy(err2:match("must be an integer")) + end) + end) + + describe("get_description", function() + it("should return string description", function() + local tool = setmetatable({ + description = "Static description", + }, base) + + assert.equals("Static description", tool:get_description()) + end) + + it("should call function description", function() + local tool = setmetatable({ + description = function() + return "Dynamic description" + end, + }, base) + + assert.equals("Dynamic description", tool:get_description()) + end) + end) + + describe("to_schema", function() + it("should generate valid schema", function() + local tool = setmetatable({ + name = "test", + description = "Test tool", + params = { + { name = "input", type = "string", description = "Input value" }, + { name = "count", type = "integer", description = "Count", optional = true }, + }, + }, base) + + local schema = tool:to_schema() + + assert.equals("function", schema.type) + assert.equals("test", schema.function_def.name) + assert.equals("Test tool", schema.function_def.description) + assert.equals("object", schema.function_def.parameters.type) + assert.is_not_nil(schema.function_def.parameters.properties.input) + assert.is_not_nil(schema.function_def.parameters.properties.count) + assert.same({ "input" }, schema.function_def.parameters.required) + end) + end) +end) + +describe("built-in tools", function() + describe("view tool", function() + local view + + before_each(function() + view = require("codetyper.agent.tools.view") + end) + + it("should have required fields", function() + assert.equals("view", view.name) + assert.is_string(view.description) + assert.is_table(view.params) + assert.is_function(view.func) + end) + + it("should require path parameter", function() + local result, err = view.func({}, {}) + assert.is_nil(result) + assert.truthy(err:match("path is required")) + end) + end) + + describe("grep tool", function() + local grep + + before_each(function() + grep = require("codetyper.agent.tools.grep") + end) + + it("should have required fields", function() + assert.equals("grep", grep.name) + assert.is_string(grep.description) + assert.is_table(grep.params) + assert.is_function(grep.func) + end) + + it("should require pattern parameter", function() + local result, err = grep.func({}, {}) + assert.is_nil(result) + assert.truthy(err:match("pattern is required")) + end) + end) + + describe("glob tool", function() + local glob + + before_each(function() + glob = require("codetyper.agent.tools.glob") + end) + + it("should have required fields", function() + assert.equals("glob", glob.name) + assert.is_string(glob.description) + assert.is_table(glob.params) + assert.is_function(glob.func) + end) + + it("should require pattern parameter", function() + local result, err = glob.func({}, {}) + assert.is_nil(result) + assert.truthy(err:match("pattern is required")) + end) + end) + + describe("edit tool", function() + local edit + + before_each(function() + edit = require("codetyper.agent.tools.edit") + end) + + it("should have required fields", function() + assert.equals("edit", edit.name) + assert.is_string(edit.description) + assert.is_table(edit.params) + assert.is_function(edit.func) + end) + + it("should require path parameter", function() + local result, err = edit.func({}, {}) + assert.is_nil(result) + assert.truthy(err:match("path is required")) + end) + + it("should require old_string parameter", function() + local result, err = edit.func({ path = "/tmp/test" }, {}) + assert.is_nil(result) + assert.truthy(err:match("old_string is required")) + end) + end) + + describe("write tool", function() + local write + + before_each(function() + write = require("codetyper.agent.tools.write") + end) + + it("should have required fields", function() + assert.equals("write", write.name) + assert.is_string(write.description) + assert.is_table(write.params) + assert.is_function(write.func) + end) + + it("should require path parameter", function() + local result, err = write.func({}, {}) + assert.is_nil(result) + assert.truthy(err:match("path is required")) + end) + + it("should require content parameter", function() + local result, err = write.func({ path = "/tmp/test" }, {}) + assert.is_nil(result) + assert.truthy(err:match("content is required")) + end) + end) + + describe("bash tool", function() + local bash + + before_each(function() + bash = require("codetyper.agent.tools.bash") + end) + + it("should have required fields", function() + assert.equals("bash", bash.name) + assert.is_function(bash.func) + end) + + it("should require command parameter", function() + local result, err = bash.func({}, {}) + assert.is_nil(result) + assert.truthy(err:match("command is required")) + end) + + it("should require confirmation by default", function() + assert.is_true(bash.requires_confirmation) + end) + end) +end) diff --git a/tests/spec/agentic_spec.lua b/tests/spec/agentic_spec.lua new file mode 100644 index 0000000..8597694 --- /dev/null +++ b/tests/spec/agentic_spec.lua @@ -0,0 +1,312 @@ +---@diagnostic disable: undefined-global +-- Unit tests for the agentic system + +describe("agentic module", function() + local agentic + + before_each(function() + -- Reset and reload + package.loaded["codetyper.agent.agentic"] = nil + agentic = require("codetyper.agent.agentic") + end) + + it("should list built-in agents", function() + local agents = agentic.list_agents() + assert.is_table(agents) + assert.is_true(#agents >= 3) -- coder, planner, explorer + + local names = {} + for _, agent in ipairs(agents) do + names[agent.name] = true + end + + assert.is_true(names["coder"]) + assert.is_true(names["planner"]) + assert.is_true(names["explorer"]) + end) + + it("should have description for each agent", function() + local agents = agentic.list_agents() + for _, agent in ipairs(agents) do + assert.is_string(agent.description) + assert.is_true(#agent.description > 0) + end + end) + + it("should mark built-in agents as builtin", function() + local agents = agentic.list_agents() + local coder = nil + for _, agent in ipairs(agents) do + if agent.name == "coder" then + coder = agent + break + end + end + assert.is_not_nil(coder) + assert.is_true(coder.builtin) + end) + + it("should have init function to create directories", function() + assert.is_function(agentic.init) + assert.is_function(agentic.init_agents_dir) + assert.is_function(agentic.init_rules_dir) + end) + + it("should have run function for executing tasks", function() + assert.is_function(agentic.run) + end) +end) + +describe("tools format conversion", function() + local tools_module + + before_each(function() + package.loaded["codetyper.agent.tools"] = nil + tools_module = require("codetyper.agent.tools") + -- Load tools + if tools_module.load_builtins then + pcall(tools_module.load_builtins) + end + end) + + it("should have to_openai_format function", function() + assert.is_function(tools_module.to_openai_format) + end) + + it("should have to_claude_format function", function() + assert.is_function(tools_module.to_claude_format) + end) + + it("should convert tools to OpenAI format", function() + local openai_tools = tools_module.to_openai_format() + assert.is_table(openai_tools) + + -- If tools are loaded, check format + if #openai_tools > 0 then + local first_tool = openai_tools[1] + assert.equals("function", first_tool.type) + assert.is_table(first_tool["function"]) + assert.is_string(first_tool["function"].name) + end + end) + + it("should convert tools to Claude format", function() + local claude_tools = tools_module.to_claude_format() + assert.is_table(claude_tools) + + -- If tools are loaded, check format + if #claude_tools > 0 then + local first_tool = claude_tools[1] + assert.is_string(first_tool.name) + assert.is_table(first_tool.input_schema) + end + end) +end) + +describe("edit tool", function() + local edit_tool + + before_each(function() + package.loaded["codetyper.agent.tools.edit"] = nil + edit_tool = require("codetyper.agent.tools.edit") + end) + + it("should have name 'edit'", function() + assert.equals("edit", edit_tool.name) + end) + + it("should have description mentioning matching strategies", function() + local desc = edit_tool:get_description() + assert.is_string(desc) + -- Should mention the matching capabilities + assert.is_true(desc:lower():match("match") ~= nil or desc:lower():match("replac") ~= nil) + end) + + it("should have params defined", function() + assert.is_table(edit_tool.params) + assert.is_true(#edit_tool.params >= 3) -- path, old_string, new_string + end) + + it("should require path parameter", function() + local valid, err = edit_tool:validate_input({ + old_string = "test", + new_string = "test2", + }) + assert.is_false(valid) + assert.is_string(err) + end) + + it("should require old_string parameter", function() + local valid, err = edit_tool:validate_input({ + path = "/test", + new_string = "test", + }) + assert.is_false(valid) + end) + + it("should require new_string parameter", function() + local valid, err = edit_tool:validate_input({ + path = "/test", + old_string = "test", + }) + assert.is_false(valid) + end) + + it("should accept empty old_string for new file creation", function() + local valid, err = edit_tool:validate_input({ + path = "/test/new_file.lua", + old_string = "", + new_string = "new content", + }) + assert.is_true(valid) + assert.is_nil(err) + end) + + it("should have func implementation", function() + assert.is_function(edit_tool.func) + end) +end) + +describe("view tool", function() + local view_tool + + before_each(function() + package.loaded["codetyper.agent.tools.view"] = nil + view_tool = require("codetyper.agent.tools.view") + end) + + it("should have name 'view'", function() + assert.equals("view", view_tool.name) + end) + + it("should require path parameter", function() + local valid, err = view_tool:validate_input({}) + assert.is_false(valid) + end) + + it("should accept valid path", function() + local valid, err = view_tool:validate_input({ + path = "/test/file.lua", + }) + assert.is_true(valid) + end) +end) + +describe("write tool", function() + local write_tool + + before_each(function() + package.loaded["codetyper.agent.tools.write"] = nil + write_tool = require("codetyper.agent.tools.write") + end) + + it("should have name 'write'", function() + assert.equals("write", write_tool.name) + end) + + it("should require path and content parameters", function() + local valid, err = write_tool:validate_input({}) + assert.is_false(valid) + + valid, err = write_tool:validate_input({ path = "/test" }) + assert.is_false(valid) + end) + + it("should accept valid input", function() + local valid, err = write_tool:validate_input({ + path = "/test/file.lua", + content = "test content", + }) + assert.is_true(valid) + end) +end) + +describe("grep tool", function() + local grep_tool + + before_each(function() + package.loaded["codetyper.agent.tools.grep"] = nil + grep_tool = require("codetyper.agent.tools.grep") + end) + + it("should have name 'grep'", function() + assert.equals("grep", grep_tool.name) + end) + + it("should require pattern parameter", function() + local valid, err = grep_tool:validate_input({}) + assert.is_false(valid) + end) + + it("should accept valid pattern", function() + local valid, err = grep_tool:validate_input({ + pattern = "function.*test", + }) + assert.is_true(valid) + end) +end) + +describe("glob tool", function() + local glob_tool + + before_each(function() + package.loaded["codetyper.agent.tools.glob"] = nil + glob_tool = require("codetyper.agent.tools.glob") + end) + + it("should have name 'glob'", function() + assert.equals("glob", glob_tool.name) + end) + + it("should require pattern parameter", function() + local valid, err = glob_tool:validate_input({}) + assert.is_false(valid) + end) + + it("should accept valid pattern", function() + local valid, err = glob_tool:validate_input({ + pattern = "**/*.lua", + }) + assert.is_true(valid) + end) +end) + +describe("base tool", function() + local Base + + before_each(function() + package.loaded["codetyper.agent.tools.base"] = nil + Base = require("codetyper.agent.tools.base") + end) + + it("should have validate_input method", function() + assert.is_function(Base.validate_input) + end) + + it("should have to_schema method", function() + assert.is_function(Base.to_schema) + end) + + it("should have get_description method", function() + assert.is_function(Base.get_description) + end) + + it("should generate valid schema", function() + local test_tool = setmetatable({ + name = "test", + description = "A test tool", + params = { + { name = "arg1", type = "string", description = "First arg" }, + { name = "arg2", type = "number", description = "Second arg", optional = true }, + }, + }, Base) + + local schema = test_tool:to_schema() + assert.equals("function", schema.type) + assert.equals("test", schema.function_def.name) + assert.is_table(schema.function_def.parameters.properties) + assert.is_table(schema.function_def.parameters.required) + assert.is_true(vim.tbl_contains(schema.function_def.parameters.required, "arg1")) + assert.is_false(vim.tbl_contains(schema.function_def.parameters.required, "arg2")) + end) +end) diff --git a/tests/spec/brain_learners_spec.lua b/tests/spec/brain_learners_spec.lua new file mode 100644 index 0000000..af9e774 --- /dev/null +++ b/tests/spec/brain_learners_spec.lua @@ -0,0 +1,153 @@ +--- Tests for brain/learners pattern detection and extraction +describe("brain.learners", function() + local pattern_learner + + before_each(function() + -- Clear module cache + package.loaded["codetyper.brain.learners.pattern"] = nil + package.loaded["codetyper.brain.types"] = nil + + pattern_learner = require("codetyper.brain.learners.pattern") + end) + + describe("pattern learner detection", function() + it("should detect code_completion events", function() + local event = { type = "code_completion", data = {} } + assert.is_true(pattern_learner.detect(event)) + end) + + it("should detect file_indexed events", function() + local event = { type = "file_indexed", data = {} } + assert.is_true(pattern_learner.detect(event)) + end) + + it("should detect code_analyzed events", function() + local event = { type = "code_analyzed", data = {} } + assert.is_true(pattern_learner.detect(event)) + end) + + it("should detect pattern_detected events", function() + local event = { type = "pattern_detected", data = {} } + assert.is_true(pattern_learner.detect(event)) + end) + + it("should NOT detect plain 'pattern' type events", function() + -- This was the bug - 'pattern' type was not in the valid_types list + local event = { type = "pattern", data = {} } + assert.is_false(pattern_learner.detect(event)) + end) + + it("should NOT detect unknown event types", function() + local event = { type = "unknown_type", data = {} } + assert.is_false(pattern_learner.detect(event)) + end) + + it("should NOT detect nil events", function() + assert.is_false(pattern_learner.detect(nil)) + end) + + it("should NOT detect events without type", function() + local event = { data = {} } + assert.is_false(pattern_learner.detect(event)) + end) + end) + + describe("pattern learner extraction", function() + it("should extract from pattern_detected events", function() + local event = { + type = "pattern_detected", + file = "/path/to/file.lua", + data = { + name = "Test pattern", + description = "Pattern description", + language = "lua", + symbols = { "func1", "func2" }, + }, + } + + local extracted = pattern_learner.extract(event) + + assert.is_not_nil(extracted) + assert.equals("Test pattern", extracted.summary) + assert.equals("Pattern description", extracted.detail) + assert.equals("lua", extracted.lang) + assert.equals("/path/to/file.lua", extracted.file) + end) + + it("should handle pattern_detected with minimal data", function() + local event = { + type = "pattern_detected", + file = "/path/to/file.lua", + data = { + name = "Minimal pattern", + }, + } + + local extracted = pattern_learner.extract(event) + + assert.is_not_nil(extracted) + assert.equals("Minimal pattern", extracted.summary) + assert.equals("Minimal pattern", extracted.detail) + end) + + it("should extract from code_completion events", function() + local event = { + type = "code_completion", + file = "/path/to/file.lua", + data = { + intent = "add function", + code = "function test() end", + language = "lua", + }, + } + + local extracted = pattern_learner.extract(event) + + assert.is_not_nil(extracted) + assert.is_true(extracted.summary:find("Code pattern") ~= nil) + assert.equals("function test() end", extracted.detail) + end) + end) + + describe("should_learn validation", function() + it("should accept valid patterns", function() + local data = { + summary = "Valid pattern summary", + detail = "This is a detailed description of the pattern", + } + assert.is_true(pattern_learner.should_learn(data)) + end) + + it("should reject patterns without summary", function() + local data = { + summary = "", + detail = "Some detail", + } + assert.is_false(pattern_learner.should_learn(data)) + end) + + it("should reject patterns with nil summary", function() + local data = { + summary = nil, + detail = "Some detail", + } + assert.is_false(pattern_learner.should_learn(data)) + end) + + it("should reject patterns with very short detail", function() + local data = { + summary = "Valid summary", + detail = "short", -- Less than 10 chars + } + assert.is_false(pattern_learner.should_learn(data)) + end) + + it("should reject whitespace-only summaries", function() + local data = { + summary = " ", + detail = "Some valid detail here", + } + assert.is_false(pattern_learner.should_learn(data)) + end) + end) +end) diff --git a/tests/spec/coder_context_spec.lua b/tests/spec/coder_context_spec.lua new file mode 100644 index 0000000..60121c7 --- /dev/null +++ b/tests/spec/coder_context_spec.lua @@ -0,0 +1,194 @@ +--- Tests for coder file context injection +describe("coder context injection", function() + local test_dir + local original_filereadable + + before_each(function() + test_dir = "/tmp/codetyper_coder_test_" .. os.time() + vim.fn.mkdir(test_dir, "p") + + -- Store original function + original_filereadable = vim.fn.filereadable + end) + + after_each(function() + vim.fn.delete(test_dir, "rf") + vim.fn.filereadable = original_filereadable + end) + + describe("get_coder_companion_path logic", function() + -- Test the path generation logic (simulating the function behavior) + local function get_coder_companion_path(target_path, file_exists_check) + if not target_path or target_path == "" then + return nil + end + + -- Skip if target is already a coder file + if target_path:match("%.coder%.") then + return nil + end + + local dir = vim.fn.fnamemodify(target_path, ":h") + local name = vim.fn.fnamemodify(target_path, ":t:r") + local ext = vim.fn.fnamemodify(target_path, ":e") + + local coder_path = dir .. "/" .. name .. ".coder." .. ext + if file_exists_check(coder_path) then + return coder_path + end + + return nil + end + + it("should generate correct coder path for source file", function() + local target = "/path/to/file.ts" + local expected = "/path/to/file.coder.ts" + + local path = get_coder_companion_path(target, function() return true end) + + assert.equals(expected, path) + end) + + it("should return nil for empty path", function() + local path = get_coder_companion_path("", function() return true end) + assert.is_nil(path) + end) + + it("should return nil for nil path", function() + local path = get_coder_companion_path(nil, function() return true end) + assert.is_nil(path) + end) + + it("should return nil for coder files (avoid recursion)", function() + local target = "/path/to/file.coder.ts" + local path = get_coder_companion_path(target, function() return true end) + assert.is_nil(path) + end) + + it("should return nil if coder file doesn't exist", function() + local target = "/path/to/file.ts" + local path = get_coder_companion_path(target, function() return false end) + assert.is_nil(path) + end) + + it("should handle files with multiple dots", function() + local target = "/path/to/my.component.ts" + local expected = "/path/to/my.component.coder.ts" + + local path = get_coder_companion_path(target, function() return true end) + + assert.equals(expected, path) + end) + + it("should handle different extensions", function() + local test_cases = { + { target = "/path/file.lua", expected = "/path/file.coder.lua" }, + { target = "/path/file.py", expected = "/path/file.coder.py" }, + { target = "/path/file.js", expected = "/path/file.coder.js" }, + { target = "/path/file.go", expected = "/path/file.coder.go" }, + } + + for _, tc in ipairs(test_cases) do + local path = get_coder_companion_path(tc.target, function() return true end) + assert.equals(tc.expected, path, "Failed for: " .. tc.target) + end + end) + end) + + describe("coder content filtering", function() + -- Test the filtering logic that skips template-only content + local function has_meaningful_content(lines) + for _, line in ipairs(lines) do + local trimmed = line:gsub("^%s*", "") + if not trimmed:match("^[%-#/]+%s*Coder companion") + and not trimmed:match("^[%-#/]+%s*Use /@ @/") + and not trimmed:match("^[%-#/]+%s*Example:") + and not trimmed:match("^Hello; +}]] + local result = inject.parse_code(code, "typescript") + + assert.equals(2, #result.imports) + assert.truthy(result.imports[1]:match("useState")) + assert.truthy(result.imports[2]:match("Button")) + assert.truthy(#result.body > 0) + end) + + it("should detect ES6 default imports", function() + local code = [[import React from 'react'; +import axios from 'axios'; + +const api = axios.create();]] + local result = inject.parse_code(code, "javascript") + + assert.equals(2, #result.imports) + assert.truthy(result.imports[1]:match("React")) + assert.truthy(result.imports[2]:match("axios")) + end) + + it("should detect require imports", function() + local code = [[const fs = require('fs'); +const path = require('path'); + +module.exports = { fs, path };]] + local result = inject.parse_code(code, "javascript") + + assert.equals(2, #result.imports) + assert.truthy(result.imports[1]:match("fs")) + assert.truthy(result.imports[2]:match("path")) + end) + + it("should detect multi-line imports", function() + local code = [[import { + useState, + useEffect, + useCallback +} from 'react'; + +function Component() {}]] + local result = inject.parse_code(code, "typescript") + + assert.equals(1, #result.imports) + assert.truthy(result.imports[1]:match("useState")) + assert.truthy(result.imports[1]:match("useCallback")) + end) + + it("should detect namespace imports", function() + local code = [[import * as React from 'react'; + +export default React;]] + local result = inject.parse_code(code, "tsx") + + assert.equals(1, #result.imports) + assert.truthy(result.imports[1]:match("%* as React")) + end) + end) + + describe("Python", function() + it("should detect simple imports", function() + local code = [[import os +import sys +import json + +def main(): + pass]] + local result = inject.parse_code(code, "python") + + assert.equals(3, #result.imports) + assert.truthy(result.imports[1]:match("import os")) + assert.truthy(result.imports[2]:match("import sys")) + assert.truthy(result.imports[3]:match("import json")) + end) + + it("should detect from imports", function() + local code = [[from typing import List, Dict +from pathlib import Path + +def process(items: List[str]) -> None: + pass]] + local result = inject.parse_code(code, "py") + + assert.equals(2, #result.imports) + assert.truthy(result.imports[1]:match("from typing")) + assert.truthy(result.imports[2]:match("from pathlib")) + end) + end) + + describe("Lua", function() + it("should detect require statements", function() + local code = [[local M = {} +local utils = require("codetyper.utils") +local config = require('codetyper.config') + +function M.setup() +end + +return M]] + local result = inject.parse_code(code, "lua") + + assert.equals(2, #result.imports) + assert.truthy(result.imports[1]:match("utils")) + assert.truthy(result.imports[2]:match("config")) + end) + end) + + describe("Go", function() + it("should detect single imports", function() + local code = [[package main + +import "fmt" + +func main() { + fmt.Println("Hello") +}]] + local result = inject.parse_code(code, "go") + + assert.equals(1, #result.imports) + assert.truthy(result.imports[1]:match('import "fmt"')) + end) + + it("should detect grouped imports", function() + local code = [[package main + +import ( + "fmt" + "os" + "strings" +) + +func main() {}]] + local result = inject.parse_code(code, "go") + + assert.equals(1, #result.imports) + assert.truthy(result.imports[1]:match("fmt")) + assert.truthy(result.imports[1]:match("os")) + end) + end) + + describe("Rust", function() + it("should detect use statements", function() + local code = [[use std::io; +use std::collections::HashMap; + +fn main() { + let map = HashMap::new(); +}]] + local result = inject.parse_code(code, "rs") + + assert.equals(2, #result.imports) + assert.truthy(result.imports[1]:match("std::io")) + assert.truthy(result.imports[2]:match("HashMap")) + end) + end) + + describe("C/C++", function() + it("should detect include statements", function() + local code = [[#include +#include "myheader.h" + +int main() { + return 0; +}]] + local result = inject.parse_code(code, "c") + + assert.equals(2, #result.imports) + assert.truthy(result.imports[1]:match("stdio")) + assert.truthy(result.imports[2]:match("myheader")) + end) + end) + end) + + describe("merge_imports", function() + it("should merge without duplicates", function() + local existing = { + "import { useState } from 'react';", + "import { Button } from './components';", + } + local new_imports = { + "import { useEffect } from 'react';", + "import { useState } from 'react';", -- duplicate + "import { Card } from './components';", + } + + local merged = inject.merge_imports(existing, new_imports) + + assert.equals(4, #merged) -- Should not have duplicate useState + end) + + it("should handle empty existing imports", function() + local existing = {} + local new_imports = { + "import os", + "import sys", + } + + local merged = inject.merge_imports(existing, new_imports) + + assert.equals(2, #merged) + end) + + it("should handle empty new imports", function() + local existing = { + "import os", + "import sys", + } + local new_imports = {} + + local merged = inject.merge_imports(existing, new_imports) + + assert.equals(2, #merged) + end) + + it("should handle whitespace variations in duplicates", function() + local existing = { + "import { useState } from 'react';", + } + local new_imports = { + "import {useState} from 'react';", -- Same but different spacing + } + + local merged = inject.merge_imports(existing, new_imports) + + assert.equals(1, #merged) -- Should detect as duplicate + end) + end) + + describe("sort_imports", function() + it("should group imports by type for JavaScript", function() + local imports = { + "import React from 'react';", + "import { Button } from './components';", + "import axios from 'axios';", + "import path from 'path';", + } + + local sorted = inject.sort_imports(imports, "javascript") + + -- Check ordering: builtin -> third-party -> local + local found_builtin = false + local found_local = false + local builtin_pos = 0 + local local_pos = 0 + + for i, imp in ipairs(sorted) do + if imp:match("path") then + found_builtin = true + builtin_pos = i + end + if imp:match("%.%/") then + found_local = true + local_pos = i + end + end + + -- Local imports should come after third-party + if found_local and found_builtin then + assert.truthy(local_pos > builtin_pos) + end + end) + end) + + describe("has_imports", function() + it("should return true when code has imports", function() + local code = [[import { useState } from 'react'; +function App() {}]] + + assert.is_true(inject.has_imports(code, "typescript")) + end) + + it("should return false when code has no imports", function() + local code = [[function App() { + return
Hello
; +}]] + + assert.is_false(inject.has_imports(code, "typescript")) + end) + + it("should detect Python imports", function() + local code = [[from typing import List + +def process(items: List[str]): + pass]] + + assert.is_true(inject.has_imports(code, "python")) + end) + + it("should detect Lua requires", function() + local code = [[local utils = require("utils") + +local M = {} +return M]] + + assert.is_true(inject.has_imports(code, "lua")) + end) + end) + + describe("edge cases", function() + it("should handle empty code", function() + local result = inject.parse_code("", "javascript") + + assert.equals(0, #result.imports) + assert.equals(1, #result.body) -- Empty string becomes one empty line + end) + + it("should handle code with only imports", function() + local code = [[import React from 'react'; +import { useState } from 'react';]] + + local result = inject.parse_code(code, "javascript") + + assert.equals(2, #result.imports) + assert.equals(0, #result.body) + end) + + it("should handle code with only body", function() + local code = [[function hello() { + console.log("Hello"); +}]] + + local result = inject.parse_code(code, "javascript") + + assert.equals(0, #result.imports) + assert.truthy(#result.body > 0) + end) + + it("should handle imports in string literals (not detect as imports)", function() + local code = [[const example = "import { fake } from 'not-real';"; +const config = { import: true }; + +function test() {}]] + + local result = inject.parse_code(code, "javascript") + + -- The first line looks like an import but is in a string + -- This is a known limitation - we accept some false positives + -- The important thing is we don't break the code + assert.truthy(#result.body >= 0) + end) + + it("should handle mixed import styles in same file", function() + local code = [[import React from 'react'; +const axios = require('axios'); +import { useState } from 'react'; + +function App() {}]] + + local result = inject.parse_code(code, "javascript") + + assert.equals(3, #result.imports) + end) + end) +end) diff --git a/tests/spec/llm_selector_spec.lua b/tests/spec/llm_selector_spec.lua new file mode 100644 index 0000000..1a0fd41 --- /dev/null +++ b/tests/spec/llm_selector_spec.lua @@ -0,0 +1,174 @@ +--- Tests for smart LLM selection with memory-based confidence + +describe("codetyper.llm.selector", function() + local selector + + before_each(function() + selector = require("codetyper.llm.selector") + -- Reset stats for clean tests + selector.reset_accuracy_stats() + end) + + describe("select_provider", function() + it("should return copilot when no brain memories exist", function() + local result = selector.select_provider("write a function", { + file_path = "/test/file.lua", + }) + + assert.equals("copilot", result.provider) + assert.equals(0, result.memory_count) + assert.truthy(result.reason:match("Insufficient context")) + end) + + it("should return a valid selection result structure", function() + local result = selector.select_provider("test prompt", {}) + + assert.is_string(result.provider) + assert.is_number(result.confidence) + assert.is_number(result.memory_count) + assert.is_string(result.reason) + end) + + it("should have confidence between 0 and 1", function() + local result = selector.select_provider("test", {}) + + assert.truthy(result.confidence >= 0) + assert.truthy(result.confidence <= 1) + end) + end) + + describe("should_ponder", function() + it("should return true for medium confidence", function() + assert.is_true(selector.should_ponder(0.5)) + assert.is_true(selector.should_ponder(0.6)) + end) + + it("should return false for low confidence", function() + assert.is_false(selector.should_ponder(0.2)) + assert.is_false(selector.should_ponder(0.3)) + end) + + -- High confidence pondering is probabilistic, so we test the range + it("should sometimes ponder for high confidence (sampling)", function() + -- Run multiple times to test probabilistic behavior + local pondered_count = 0 + for _ = 1, 100 do + if selector.should_ponder(0.9) then + pondered_count = pondered_count + 1 + end + end + -- Should ponder roughly 20% of the time (PONDER_SAMPLE_RATE = 0.2) + -- Allow range of 5-40% due to randomness + assert.truthy(pondered_count >= 5, "Should ponder at least sometimes") + assert.truthy(pondered_count <= 40, "Should not ponder too often") + end) + end) + + describe("get_accuracy_stats", function() + it("should return initial empty stats", function() + local stats = selector.get_accuracy_stats() + + assert.equals(0, stats.ollama.total) + assert.equals(0, stats.ollama.correct) + assert.equals(0, stats.ollama.accuracy) + assert.equals(0, stats.copilot.total) + assert.equals(0, stats.copilot.correct) + assert.equals(0, stats.copilot.accuracy) + end) + end) + + describe("report_feedback", function() + it("should track positive feedback", function() + selector.report_feedback("ollama", true) + selector.report_feedback("ollama", true) + selector.report_feedback("ollama", false) + + local stats = selector.get_accuracy_stats() + assert.equals(3, stats.ollama.total) + assert.equals(2, stats.ollama.correct) + end) + + it("should track copilot feedback separately", function() + selector.report_feedback("ollama", true) + selector.report_feedback("copilot", true) + selector.report_feedback("copilot", false) + + local stats = selector.get_accuracy_stats() + assert.equals(1, stats.ollama.total) + assert.equals(2, stats.copilot.total) + end) + + it("should calculate accuracy correctly", function() + selector.report_feedback("ollama", true) + selector.report_feedback("ollama", true) + selector.report_feedback("ollama", true) + selector.report_feedback("ollama", false) + + local stats = selector.get_accuracy_stats() + assert.equals(0.75, stats.ollama.accuracy) + end) + end) + + describe("reset_accuracy_stats", function() + it("should clear all stats", function() + selector.report_feedback("ollama", true) + selector.report_feedback("copilot", true) + + selector.reset_accuracy_stats() + + local stats = selector.get_accuracy_stats() + assert.equals(0, stats.ollama.total) + assert.equals(0, stats.copilot.total) + end) + end) +end) + +describe("agreement calculation", function() + -- Test the internal agreement calculation through pondering behavior + -- Since calculate_agreement is local, we test its effects indirectly + + it("should detect high agreement for similar responses", function() + -- This is tested through the pondering system + -- When responses are similar, agreement should be high + local selector = require("codetyper.llm.selector") + + -- Verify that should_ponder returns predictable results + -- for medium confidence (where pondering always happens) + assert.is_true(selector.should_ponder(0.5)) + end) +end) + +describe("provider selection with accuracy history", function() + local selector + + before_each(function() + selector = require("codetyper.llm.selector") + selector.reset_accuracy_stats() + end) + + it("should factor in historical accuracy for selection", function() + -- Simulate high Ollama accuracy + for _ = 1, 10 do + selector.report_feedback("ollama", true) + end + + -- Even with no brain context, historical accuracy should influence confidence + local result = selector.select_provider("test", {}) + + -- Confidence should be higher due to historical accuracy + -- but provider might still be copilot if no memories + assert.is_number(result.confidence) + end) + + it("should have lower confidence for low historical accuracy", function() + -- Simulate low Ollama accuracy + for _ = 1, 10 do + selector.report_feedback("ollama", false) + end + + local result = selector.select_provider("test", {}) + + -- With bad history and no memories, should definitely use copilot + assert.equals("copilot", result.provider) + end) +end)