Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 258b29f5f0 | |||
| 6a69a524ea | |||
| 10c1de8843 | |||
| 4fb52596e3 | |||
| 9dfb52ac8d | |||
| c9be0cf804 | |||
| 60577f8951 | |||
| f5df1a9ac0 | |||
| 84c8bcf92c |
213
CHANGELOG.md
213
CHANGELOG.md
@@ -7,12 +7,114 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.6.0] - 2026-01-16
|
||||
|
||||
### Added
|
||||
|
||||
- **Conflict Resolution System** - Git-style diff visualization for code review
|
||||
- New `conflict.lua` module with full conflict management
|
||||
- Git-style markers: `<<<<<<< CURRENT`, `=======`, `>>>>>>> INCOMING`
|
||||
- Visual highlighting: green for original, blue for AI suggestions
|
||||
- Buffer-local keymaps: `co` (ours), `ct` (theirs), `cb` (both), `cn` (none)
|
||||
- Navigation keymaps: `]x` (next), `[x` (previous)
|
||||
- Floating menu with `cm` or `<CR>` on conflict
|
||||
- Number keys `1-4` for quick selection in menu
|
||||
- Auto-show menu after code injection
|
||||
- Auto-show menu for next conflict after resolution
|
||||
- Commands: `:CoderConflictToggle`, `:CoderConflictMenu`, `:CoderConflictNext`, `:CoderConflictPrev`, `:CoderConflictStatus`, `:CoderConflictResolveAll`, `:CoderConflictAcceptCurrent`, `:CoderConflictAcceptIncoming`, `:CoderConflictAcceptBoth`, `:CoderConflictAcceptNone`, `:CoderConflictAutoMenu`
|
||||
|
||||
- **Linter Validation System** - Auto-check and fix lint errors after code injection
|
||||
- New `linter.lua` module for LSP diagnostics integration
|
||||
- Auto-saves file after code injection
|
||||
- Waits for LSP diagnostics to update
|
||||
- Detects errors and warnings in injected code region
|
||||
- Auto-queues AI fix prompts for lint errors
|
||||
- Shows errors in quickfix list
|
||||
- Commands: `:CoderLintCheck`, `:CoderLintFix`, `:CoderLintQuickfix`, `:CoderLintToggleAuto`
|
||||
|
||||
- **SEARCH/REPLACE Block System** - Reliable code editing with fuzzy matching
|
||||
- New `search_replace.lua` module for reliable code editing
|
||||
- Parses SEARCH/REPLACE blocks from LLM responses
|
||||
- Fuzzy matching with configurable thresholds
|
||||
- Whitespace normalization for better matching
|
||||
- Multiple matching strategies: exact, normalized, line-by-line
|
||||
- Automatic fallback to line-based injection
|
||||
|
||||
- **Process and Show Menu Function** - Streamlined conflict handling
|
||||
- New `process_and_show_menu()` function combines processing and menu display
|
||||
- Ensures highlights and keymaps are set up before showing menu
|
||||
|
||||
### Changed
|
||||
|
||||
- Unified automatic and manual tag processing to use same code path
|
||||
- `insert_conflict()` now only inserts markers, callers handle processing
|
||||
- Added `nowait = true` to conflict keymaps to prevent delay from built-in `c` command
|
||||
- Improved patch application flow with conflict mode integration
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed `string.gsub` returning two values causing `table.insert` errors
|
||||
- Fixed keymaps not triggering due to Neovim's `c` command intercepting first character
|
||||
- Fixed menu not showing after code injection
|
||||
- Fixed diff highlighting not appearing
|
||||
|
||||
---
|
||||
|
||||
## [0.5.0] - 2026-01-15
|
||||
|
||||
### Added
|
||||
|
||||
- **Cost Tracking System** - Track LLM API costs across sessions
|
||||
- New `:CoderCost` command opens cost estimation floating window
|
||||
- Session costs tracked in real-time
|
||||
- All-time costs persisted in `.coder/cost_history.json`
|
||||
- Per-model breakdown with token counts
|
||||
- Pricing database for 50+ models (GPT-4/5, Claude, O-series, Gemini)
|
||||
- Window keymaps: `q` close, `r` refresh, `c` clear session, `C` clear all
|
||||
|
||||
- **Automatic Ollama Fallback** - Graceful degradation when API limits hit
|
||||
- Automatically switches to Ollama when Copilot rate limits exceeded
|
||||
- Detects local Ollama availability before fallback
|
||||
- Notifies user of provider switch
|
||||
|
||||
- **Enhanced Error Handling** - Better error messages for API failures
|
||||
- Shows actual API response on parse errors
|
||||
- Improved rate limit detection and messaging
|
||||
- Sanitized newlines in error notifications
|
||||
|
||||
- **Agent Tools System Improvements**
|
||||
- New `to_openai_format()` and `to_claude_format()` functions
|
||||
- `get_definitions()` for generic tool access
|
||||
- Fixed tool call argument serialization
|
||||
|
||||
- **Credentials Management System** - Store API keys outside of config files
|
||||
- New `:CoderAddApiKey` command for interactive credential setup
|
||||
- `:CoderRemoveApiKey` to remove stored credentials
|
||||
- `:CoderCredentials` to view credential status
|
||||
- `:CoderSwitchProvider` to switch active LLM provider
|
||||
- Credentials stored in `~/.local/share/nvim/codetyper/configuration.json`
|
||||
|
||||
### Changed
|
||||
|
||||
- Cost window now shows both session and all-time statistics
|
||||
- Improved agent prompt templates with correct tool names
|
||||
- Better error context in LLM provider responses
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed "Failed to parse Copilot response" error showing instead of actual error
|
||||
- Fixed `nvim_buf_set_lines` crash from newlines in error messages
|
||||
- Fixed `tools.definitions` nil error in agent initialization
|
||||
- Fixed tool name mismatch in agent prompts
|
||||
|
||||
---
|
||||
|
||||
## [0.4.0] - 2026-01-13
|
||||
|
||||
### Added
|
||||
|
||||
- **Event-Driven Architecture** - Complete rewrite of prompt processing system
|
||||
- Prompts are now treated as events with metadata (buffer state, priority, timestamps)
|
||||
- Prompts are now treated as events with metadata
|
||||
- New modules: `queue.lua`, `patch.lua`, `confidence.lua`, `worker.lua`, `scheduler.lua`
|
||||
- Priority-based event queue with observer pattern
|
||||
- Buffer snapshots for staleness detection
|
||||
@@ -23,42 +125,33 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Configurable escalation threshold (default: 0.7)
|
||||
|
||||
- **Confidence Scoring** - Response quality heuristics
|
||||
- 5 weighted heuristics: length, uncertainty phrases, syntax completeness, repetition, truncation
|
||||
- 5 weighted heuristics: length, uncertainty, syntax, repetition, truncation
|
||||
- Scores range from 0.0-1.0
|
||||
- Determines whether to escalate to more capable LLM
|
||||
|
||||
- **Staleness Detection** - Safe patch application
|
||||
- Track `vim.b.changedtick` and content hash at prompt time
|
||||
- Discard patches if buffer changed during generation
|
||||
- Prevents stale code injection
|
||||
|
||||
- **Completion-Aware Injection** - No fighting with autocomplete
|
||||
- Defer code injection while completion popup visible
|
||||
- Works with native popup, nvim-cmp, and coq_nvim
|
||||
- Configurable delay after popup closes (default: 100ms)
|
||||
|
||||
- **Tree-sitter Scope Resolution** - Smart context extraction
|
||||
- Automatically resolves prompts to enclosing function/method/class
|
||||
- Falls back to heuristics when Tree-sitter unavailable
|
||||
- Scope types: function, method, class, block, file
|
||||
|
||||
- **Intent Detection** - Understands what you want
|
||||
- Parses prompts to detect: complete, refactor, fix, add, document, test, optimize, explain
|
||||
- Intent determines injection strategy (replace vs insert vs append)
|
||||
- Priority adjustment based on intent type
|
||||
|
||||
- **Tag Precedence Rules** - Multiple tags handled cleanly
|
||||
- First tag in scope wins (FIFO ordering)
|
||||
- Later tags in same scope skipped with warning
|
||||
- Different scopes process independently
|
||||
- Intent determines injection strategy
|
||||
|
||||
### Configuration
|
||||
|
||||
New `scheduler` configuration block:
|
||||
```lua
|
||||
scheduler = {
|
||||
enabled = true, -- Enable event-driven mode
|
||||
ollama_scout = true, -- Use Ollama first
|
||||
enabled = true,
|
||||
ollama_scout = true,
|
||||
escalation_threshold = 0.7,
|
||||
max_concurrent = 2,
|
||||
completion_delay_ms = 100,
|
||||
@@ -71,50 +164,32 @@ scheduler = {
|
||||
|
||||
### Added
|
||||
|
||||
- **Multiple LLM Providers** - Support for additional providers beyond Claude and Ollama
|
||||
- OpenAI API with custom endpoint support (Azure, OpenRouter, etc.)
|
||||
- **Multiple LLM Providers** - Support for additional providers
|
||||
- OpenAI API with custom endpoint support
|
||||
- Google Gemini API
|
||||
- GitHub Copilot (uses existing copilot.lua/copilot.vim authentication)
|
||||
- GitHub Copilot
|
||||
|
||||
- **Agent Mode** - Autonomous coding assistant with tool use
|
||||
- `read_file` - Read file contents
|
||||
- `edit_file` - Edit files with find/replace
|
||||
- `write_file` - Create or overwrite files
|
||||
- `bash` - Execute shell commands
|
||||
- `read_file`, `edit_file`, `write_file`, `bash` tools
|
||||
- Real-time logging of agent actions
|
||||
- `:CoderAgent`, `:CoderAgentToggle`, `:CoderAgentStop` commands
|
||||
|
||||
- **Transform Commands** - Transform /@ @/ tags inline without split view
|
||||
- `:CoderTransform` - Transform all tags in file
|
||||
- `:CoderTransformCursor` - Transform tag at cursor
|
||||
- `:CoderTransformVisual` - Transform selected tags
|
||||
- Default keymaps: `<leader>ctt` (cursor/visual), `<leader>ctT` (all)
|
||||
- **Transform Commands** - Transform /@ @/ tags inline
|
||||
- `:CoderTransform`, `:CoderTransformCursor`, `:CoderTransformVisual`
|
||||
- Default keymaps: `<leader>ctt`, `<leader>ctT`
|
||||
|
||||
- **Auto-Index Feature** - Automatically create coder companion files
|
||||
- Creates `.coder.` companion files when opening source files
|
||||
- Language-aware templates with correct comment syntax
|
||||
- `:CoderIndex` command to manually open companion
|
||||
- `<leader>ci` keymap
|
||||
- Configurable via `auto_index` option (disabled by default)
|
||||
- Language-aware templates
|
||||
|
||||
- **Logs Panel** - Real-time visibility into LLM operations
|
||||
- Token usage tracking (prompt and completion tokens)
|
||||
- "Thinking" process visibility
|
||||
- Request/response logging
|
||||
- `:CoderLogs` command to toggle panel
|
||||
|
||||
- **Mode Switcher** - Switch between Ask and Agent modes
|
||||
- `:CoderType` command shows mode selection UI
|
||||
|
||||
### Changed
|
||||
|
||||
- Window width configuration now uses percentage as whole number (e.g., `25` for 25%)
|
||||
- Window width configuration now uses percentage as whole number
|
||||
- Improved code extraction from LLM responses
|
||||
- Better prompt templates for code generation
|
||||
|
||||
### Fixed
|
||||
|
||||
- Window width calculation consistency across modules
|
||||
|
||||
---
|
||||
|
||||
@@ -123,31 +198,23 @@ scheduler = {
|
||||
### Added
|
||||
|
||||
- **Ask Panel** - Chat interface for asking questions about code
|
||||
- Fixed at 1/4 (25%) screen width for consistent layout
|
||||
- File attachment with `@` key (uses Telescope if available)
|
||||
- `Ctrl+n` to start a new chat (clears input and history)
|
||||
- Fixed at 1/4 (25%) screen width
|
||||
- File attachment with `@` key
|
||||
- `Ctrl+n` to start a new chat
|
||||
- `Ctrl+Enter` to submit questions
|
||||
- `Ctrl+f` to add current file as context
|
||||
- `Ctrl+h/j/k/l` for window navigation
|
||||
- `K/J` to jump between output and input windows
|
||||
- `Y` to copy last response to clipboard
|
||||
- `q` to close panel (closes both windows together)
|
||||
- Auto-open Ask panel on startup (configurable via `auto_open_ask`)
|
||||
- File content is now sent to LLM when attaching files with `@`
|
||||
- `Y` to copy last response
|
||||
|
||||
### Changed
|
||||
|
||||
- Ask panel width is now fixed at 25% (1/4 of screen)
|
||||
- Improved close behavior - closing either Ask window closes both
|
||||
- Proper focus management after closing Ask panel
|
||||
- Compact UI elements to fit 1/4 width layout
|
||||
- Changed "Assistant" label to "AI" in chat messages
|
||||
- Ask panel width is now fixed at 25%
|
||||
- Improved close behavior
|
||||
- Changed "Assistant" label to "AI"
|
||||
|
||||
### Fixed
|
||||
|
||||
- Ask panel window state sync issues
|
||||
- Window focus returning to code after closing Ask panel
|
||||
- NerdTree/nvim-tree causing Ask panel to resize incorrectly
|
||||
- Window focus returning to code after closing
|
||||
|
||||
---
|
||||
|
||||
@@ -159,27 +226,13 @@ scheduler = {
|
||||
- Core plugin architecture with modular Lua structure
|
||||
- Split window view for coder and target files
|
||||
- Tag-based prompt system (`/@` to open, `@/` to close)
|
||||
- Claude API integration for code generation
|
||||
- Ollama API integration for local LLM support
|
||||
- Automatic `.gitignore` management for coder files and `.coder/` folder
|
||||
- Smart prompt type detection (refactor, add, document, explain)
|
||||
- Code injection system with multiple strategies
|
||||
- User commands: `Coder`, `CoderOpen`, `CoderClose`, `CoderToggle`, `CoderProcess`, `CoderTree`, `CoderTreeView`
|
||||
- Health check module (`:checkhealth codetyper`)
|
||||
- Comprehensive documentation and help files
|
||||
- Telescope integration for file selection (optional)
|
||||
- **Project tree logging**: Automatic `.coder/tree.log` maintenance
|
||||
- Updates on file create, save, delete
|
||||
- Debounced updates (1 second) for performance
|
||||
- File type icons for visual clarity
|
||||
- Ignores common build/dependency folders
|
||||
|
||||
### Configuration Options
|
||||
|
||||
- LLM provider selection (Claude/Ollama)
|
||||
- Window position and width customization
|
||||
- Custom prompt tag patterns
|
||||
- Auto gitignore toggle
|
||||
- Claude API integration
|
||||
- Ollama API integration
|
||||
- Automatic `.gitignore` management
|
||||
- Smart prompt type detection
|
||||
- Code injection system
|
||||
- Health check module
|
||||
- Project tree logging
|
||||
|
||||
---
|
||||
|
||||
@@ -194,7 +247,9 @@ scheduler = {
|
||||
- **Fixed** - Bug fixes
|
||||
- **Security** - Vulnerability fixes
|
||||
|
||||
[Unreleased]: https://github.com/cargdev/codetyper.nvim/compare/v0.4.0...HEAD
|
||||
[Unreleased]: https://github.com/cargdev/codetyper.nvim/compare/v0.6.0...HEAD
|
||||
[0.6.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.5.0...v0.6.0
|
||||
[0.5.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.4.0...v0.5.0
|
||||
[0.4.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.3.0...v0.4.0
|
||||
[0.3.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.2.0...v0.3.0
|
||||
[0.2.0]: https://github.com/cargdev/codetyper.nvim/compare/v0.1.0...v0.2.0
|
||||
|
||||
700
README.md
700
README.md
@@ -1,46 +1,56 @@
|
||||
# 🚀 Codetyper.nvim
|
||||
# Codetyper.nvim
|
||||
|
||||
**AI-powered coding partner for Neovim** - Write code faster with LLM assistance while staying in control of your logic.
|
||||
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://neovim.io/)
|
||||
|
||||
## ✨ Features
|
||||
## Features
|
||||
|
||||
- 📐 **Split View**: Work with your code and prompts side by side
|
||||
- 💬 **Ask Panel**: Chat interface for questions and explanations
|
||||
- 🤖 **Agent Mode**: Autonomous coding agent with tool use (read, edit, write, bash)
|
||||
- 🏷️ **Tag-based Prompts**: Use `/@` and `@/` tags to write natural language prompts
|
||||
- ⚡ **Transform Commands**: Transform prompts inline without leaving your file
|
||||
- 🔌 **Multiple LLM Providers**: Claude, OpenAI, Gemini, Copilot, and Ollama (local)
|
||||
- 📋 **Event-Driven Scheduler**: Queue-based processing with optimistic execution
|
||||
- 🎯 **Tree-sitter Scope Resolution**: Smart context extraction for functions/methods
|
||||
- 🧠 **Intent Detection**: Understands complete, refactor, fix, add, document intents
|
||||
- 📊 **Confidence Scoring**: Automatic escalation from local to remote LLMs
|
||||
- 🛡️ **Completion-Aware**: Safe injection that doesn't fight with autocomplete
|
||||
- 📁 **Auto-Index**: Automatically create coder companion files on file open
|
||||
- 📜 **Logs Panel**: Real-time visibility into LLM requests and token usage
|
||||
- 🔒 **Git Integration**: Automatically adds `.coder.*` files to `.gitignore`
|
||||
- 🌳 **Project Tree Logging**: Maintains a `tree.log` tracking your project structure
|
||||
- **Split View**: Work with your code and prompts side by side
|
||||
- **Ask Panel**: Chat interface for questions and explanations
|
||||
- **Agent Mode**: Autonomous coding agent with tool use (read, edit, write, bash)
|
||||
- **Tag-based Prompts**: Use `/@` and `@/` tags to write natural language prompts
|
||||
- **Transform Commands**: Transform prompts inline without leaving your file
|
||||
- **Multiple LLM Providers**: Claude, OpenAI, Gemini, Copilot, and Ollama (local)
|
||||
- **SEARCH/REPLACE Blocks**: Reliable code editing with fuzzy matching
|
||||
- **Conflict Resolution**: Git-style diff visualization with interactive resolution
|
||||
- **Linter Validation**: Auto-check and fix lint errors after code injection
|
||||
- **Event-Driven Scheduler**: Queue-based processing with optimistic execution
|
||||
- **Tree-sitter Scope Resolution**: Smart context extraction for functions/methods
|
||||
- **Intent Detection**: Understands complete, refactor, fix, add, document intents
|
||||
- **Confidence Scoring**: Automatic escalation from local to remote LLMs
|
||||
- **Completion-Aware**: Safe injection that doesn't fight with autocomplete
|
||||
- **Auto-Index**: Automatically create coder companion files on file open
|
||||
- **Logs Panel**: Real-time visibility into LLM requests and token usage
|
||||
- **Cost Tracking**: Persistent LLM cost estimation with session and all-time stats
|
||||
- **Git Integration**: Automatically adds `.coder.*` files to `.gitignore`
|
||||
- **Project Tree Logging**: Maintains a `tree.log` tracking your project structure
|
||||
- **Brain System**: Knowledge graph that learns from your coding patterns
|
||||
|
||||
---
|
||||
|
||||
## 📚 Table of Contents
|
||||
## Table of Contents
|
||||
|
||||
- [Requirements](#-requirements)
|
||||
- [Installation](#-installation)
|
||||
- [Quick Start](#-quick-start)
|
||||
- [Configuration](#-configuration)
|
||||
- [LLM Providers](#-llm-providers)
|
||||
- [Commands Reference](#-commands-reference)
|
||||
- [Usage Guide](#-usage-guide)
|
||||
- [Agent Mode](#-agent-mode)
|
||||
- [Keymaps](#-keymaps)
|
||||
- [Health Check](#-health-check)
|
||||
- [Requirements](#requirements)
|
||||
- [Installation](#installation)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Configuration](#configuration)
|
||||
- [LLM Providers](#llm-providers)
|
||||
- [Commands Reference](#commands-reference)
|
||||
- [Keymaps Reference](#keymaps-reference)
|
||||
- [Usage Guide](#usage-guide)
|
||||
- [Conflict Resolution](#conflict-resolution)
|
||||
- [Linter Validation](#linter-validation)
|
||||
- [Logs Panel](#logs-panel)
|
||||
- [Cost Tracking](#cost-tracking)
|
||||
- [Agent Mode](#agent-mode)
|
||||
- [Health Check](#health-check)
|
||||
- [Reporting Issues](#reporting-issues)
|
||||
|
||||
---
|
||||
|
||||
## 📋 Requirements
|
||||
## Requirements
|
||||
|
||||
- Neovim >= 0.8.0
|
||||
- curl (for API calls)
|
||||
@@ -58,7 +68,7 @@
|
||||
|
||||
---
|
||||
|
||||
## 📦 Installation
|
||||
## Installation
|
||||
|
||||
### Using [lazy.nvim](https://github.com/folke/lazy.nvim)
|
||||
|
||||
@@ -66,10 +76,10 @@
|
||||
{
|
||||
"cargdev/codetyper.nvim",
|
||||
dependencies = {
|
||||
"nvim-lua/plenary.nvim", -- Required: async utilities
|
||||
"nvim-treesitter/nvim-treesitter", -- Required: scope detection
|
||||
"nvim-treesitter/nvim-treesitter-textobjects", -- Optional: text objects
|
||||
"MunifTanjim/nui.nvim", -- Optional: UI components
|
||||
"nvim-lua/plenary.nvim",
|
||||
"nvim-treesitter/nvim-treesitter",
|
||||
"nvim-treesitter/nvim-treesitter-textobjects",
|
||||
"MunifTanjim/nui.nvim",
|
||||
},
|
||||
cmd = { "Coder", "CoderOpen", "CoderToggle", "CoderAgent" },
|
||||
keys = {
|
||||
@@ -100,7 +110,7 @@ use {
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Quick Start
|
||||
## Quick Start
|
||||
|
||||
**1. Open a file and start Coder:**
|
||||
```vim
|
||||
@@ -114,11 +124,17 @@ use {
|
||||
using regex, return boolean @/
|
||||
```
|
||||
|
||||
**3. The LLM generates code and injects it into `utils.ts` (right panel)**
|
||||
**3. The LLM generates code and shows a diff for you to review**
|
||||
|
||||
**4. Use conflict resolution keymaps to accept/reject changes:**
|
||||
- `ct` - Accept AI suggestion (theirs)
|
||||
- `co` - Keep original code (ours)
|
||||
- `cb` - Accept both versions
|
||||
- `cn` - Delete both (none)
|
||||
|
||||
---
|
||||
|
||||
## ⚙️ Configuration
|
||||
## Configuration
|
||||
|
||||
```lua
|
||||
require("codetyper").setup({
|
||||
@@ -126,31 +142,26 @@ require("codetyper").setup({
|
||||
llm = {
|
||||
provider = "claude", -- "claude", "openai", "gemini", "copilot", or "ollama"
|
||||
|
||||
-- Claude (Anthropic) settings
|
||||
claude = {
|
||||
api_key = nil, -- Uses ANTHROPIC_API_KEY env var if nil
|
||||
model = "claude-sonnet-4-20250514",
|
||||
},
|
||||
|
||||
-- OpenAI settings
|
||||
openai = {
|
||||
api_key = nil, -- Uses OPENAI_API_KEY env var if nil
|
||||
model = "gpt-4o",
|
||||
endpoint = nil, -- Custom endpoint (Azure, OpenRouter, etc.)
|
||||
},
|
||||
|
||||
-- Google Gemini settings
|
||||
gemini = {
|
||||
api_key = nil, -- Uses GEMINI_API_KEY env var if nil
|
||||
model = "gemini-2.0-flash",
|
||||
},
|
||||
|
||||
-- GitHub Copilot settings (uses copilot.lua/copilot.vim auth)
|
||||
copilot = {
|
||||
model = "gpt-4o",
|
||||
},
|
||||
|
||||
-- Ollama (local) settings
|
||||
ollama = {
|
||||
host = "http://localhost:11434",
|
||||
model = "deepseek-coder:6.7b",
|
||||
@@ -159,7 +170,7 @@ require("codetyper").setup({
|
||||
|
||||
-- Window Configuration
|
||||
window = {
|
||||
width = 25, -- Percentage of screen width (25 = 25%)
|
||||
width = 25, -- Percentage of screen width
|
||||
position = "left",
|
||||
border = "rounded",
|
||||
},
|
||||
@@ -172,18 +183,18 @@ require("codetyper").setup({
|
||||
},
|
||||
|
||||
-- Auto Features
|
||||
auto_gitignore = true, -- Automatically add coder files to .gitignore
|
||||
auto_open_ask = true, -- Auto-open Ask panel on startup
|
||||
auto_index = false, -- Auto-create coder companion files on file open
|
||||
auto_gitignore = true,
|
||||
auto_open_ask = true,
|
||||
auto_index = false,
|
||||
|
||||
-- Event-Driven Scheduler
|
||||
scheduler = {
|
||||
enabled = true, -- Enable event-driven prompt processing
|
||||
ollama_scout = true, -- Use Ollama for first attempt (fast local)
|
||||
escalation_threshold = 0.7, -- Below this confidence, escalate to remote
|
||||
max_concurrent = 2, -- Max parallel workers
|
||||
completion_delay_ms = 100, -- Delay injection after completion popup
|
||||
apply_delay_ms = 5000, -- Wait before applying code (ms), allows review
|
||||
enabled = true,
|
||||
ollama_scout = true,
|
||||
escalation_threshold = 0.7,
|
||||
max_concurrent = 2,
|
||||
completion_delay_ms = 100,
|
||||
apply_delay_ms = 5000,
|
||||
},
|
||||
})
|
||||
```
|
||||
@@ -196,12 +207,26 @@ require("codetyper").setup({
|
||||
| `OPENAI_API_KEY` | OpenAI API key |
|
||||
| `GEMINI_API_KEY` | Google Gemini API key |
|
||||
|
||||
### Credentials Management
|
||||
|
||||
Store API keys securely outside of config files:
|
||||
|
||||
```vim
|
||||
:CoderAddApiKey
|
||||
```
|
||||
|
||||
Credentials are stored in `~/.local/share/nvim/codetyper/configuration.json`.
|
||||
|
||||
**Priority order:**
|
||||
1. Stored credentials (via `:CoderAddApiKey`)
|
||||
2. Config file settings
|
||||
3. Environment variables
|
||||
|
||||
---
|
||||
|
||||
## 🔌 LLM Providers
|
||||
## LLM Providers
|
||||
|
||||
### Claude (Anthropic)
|
||||
Best for complex reasoning and code generation.
|
||||
### Claude
|
||||
```lua
|
||||
llm = {
|
||||
provider = "claude",
|
||||
@@ -210,19 +235,17 @@ llm = {
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
Supports custom endpoints for Azure, OpenRouter, etc.
|
||||
```lua
|
||||
llm = {
|
||||
provider = "openai",
|
||||
openai = {
|
||||
model = "gpt-4o",
|
||||
endpoint = "https://api.openai.com/v1/chat/completions", -- optional
|
||||
endpoint = "https://api.openai.com/v1/chat/completions",
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
### Google Gemini
|
||||
Fast and capable.
|
||||
```lua
|
||||
llm = {
|
||||
provider = "gemini",
|
||||
@@ -231,7 +254,6 @@ llm = {
|
||||
```
|
||||
|
||||
### GitHub Copilot
|
||||
Uses your existing Copilot subscription (requires copilot.lua or copilot.vim).
|
||||
```lua
|
||||
llm = {
|
||||
provider = "copilot",
|
||||
@@ -240,7 +262,6 @@ llm = {
|
||||
```
|
||||
|
||||
### Ollama (Local)
|
||||
Run models locally with no API costs.
|
||||
```lua
|
||||
llm = {
|
||||
provider = "ollama",
|
||||
@@ -253,166 +274,120 @@ llm = {
|
||||
|
||||
---
|
||||
|
||||
## 📝 Commands Reference
|
||||
## Commands Reference
|
||||
|
||||
### Main Commands
|
||||
### Core Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:Coder {subcommand}` | Main command with subcommands |
|
||||
| `:CoderOpen` | Open the coder split view |
|
||||
| `:CoderClose` | Close the coder split view |
|
||||
| `:CoderToggle` | Toggle the coder split view |
|
||||
| `:CoderProcess` | Process the last prompt |
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder open` | `:CoderOpen` | Open the coder split view |
|
||||
| `:Coder close` | `:CoderClose` | Close the coder split view |
|
||||
| `:Coder toggle` | `:CoderToggle` | Toggle the coder split view |
|
||||
| `:Coder process` | `:CoderProcess` | Process the last prompt |
|
||||
| `:Coder status` | - | Show plugin status |
|
||||
| `:Coder focus` | - | Switch focus between windows |
|
||||
| `:Coder reset` | - | Reset processed prompts |
|
||||
|
||||
### Ask Panel
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderAsk` | Open the Ask panel |
|
||||
| `:CoderAskToggle` | Toggle the Ask panel |
|
||||
| `:CoderAskClear` | Clear chat history |
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder ask` | `:CoderAsk` | Open the Ask panel |
|
||||
| `:Coder ask-toggle` | `:CoderAskToggle` | Toggle the Ask panel |
|
||||
| `:Coder ask-clear` | `:CoderAskClear` | Clear chat history |
|
||||
|
||||
### Agent Mode
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderAgent` | Open the Agent panel |
|
||||
| `:CoderAgentToggle` | Toggle the Agent panel |
|
||||
| `:CoderAgentStop` | Stop the running agent |
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder agent` | `:CoderAgent` | Open the Agent panel |
|
||||
| `:Coder agent-toggle` | `:CoderAgentToggle` | Toggle the Agent panel |
|
||||
| `:Coder agent-stop` | `:CoderAgentStop` | Stop running agent |
|
||||
|
||||
### Agentic Mode
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder agentic-run <task>` | `:CoderAgenticRun` | Run agentic task |
|
||||
| `:Coder agentic-list` | `:CoderAgenticList` | List available agents |
|
||||
| `:Coder agentic-init` | `:CoderAgenticInit` | Initialize .coder/agents/ |
|
||||
|
||||
### Transform Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderTransform` | Transform all /@ @/ tags in file |
|
||||
| `:CoderTransformCursor` | Transform tag at cursor position |
|
||||
| `:CoderTransformVisual` | Transform selected tags (visual mode) |
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder transform` | `:CoderTransform` | Transform all tags in file |
|
||||
| `:Coder transform-cursor` | `:CoderTransformCursor` | Transform tag at cursor |
|
||||
| - | `:CoderTransformVisual` | Transform selected tags |
|
||||
|
||||
### Utility Commands
|
||||
### Conflict Resolution
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderConflictToggle` | Toggle conflict mode |
|
||||
| `:CoderConflictMenu` | Show resolution menu |
|
||||
| `:CoderConflictNext` | Go to next conflict |
|
||||
| `:CoderConflictPrev` | Go to previous conflict |
|
||||
| `:CoderConflictStatus` | Show conflict status |
|
||||
| `:CoderConflictResolveAll [keep]` | Resolve all (ours/theirs/both/none) |
|
||||
| `:CoderConflictAcceptCurrent` | Accept original code |
|
||||
| `:CoderConflictAcceptIncoming` | Accept AI suggestion |
|
||||
| `:CoderConflictAcceptBoth` | Accept both versions |
|
||||
| `:CoderConflictAcceptNone` | Delete both |
|
||||
| `:CoderConflictAutoMenu` | Toggle auto-show menu |
|
||||
|
||||
### Linter Validation
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderLintCheck` | Check buffer for lint errors |
|
||||
| `:CoderLintFix` | Request AI to fix lint errors |
|
||||
| `:CoderLintQuickfix` | Show errors in quickfix |
|
||||
| `:CoderLintToggleAuto` | Toggle auto lint checking |
|
||||
|
||||
### Queue & Scheduler
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder queue-status` | `:CoderQueueStatus` | Show scheduler status |
|
||||
| `:Coder queue-process` | `:CoderQueueProcess` | Trigger queue processing |
|
||||
|
||||
### Processing Mode
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder auto-toggle` | `:CoderAutoToggle` | Toggle auto/manual mode |
|
||||
| `:Coder auto-set <mode>` | `:CoderAutoSet` | Set mode (auto/manual) |
|
||||
|
||||
### Brain & Memory
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderMemories` | Show learned memories |
|
||||
| `:CoderForget [pattern]` | Clear memories |
|
||||
| `:CoderBrain [action]` | Brain management (stats/commit/flush/prune) |
|
||||
| `:CoderFeedback <type>` | Give feedback (good/bad/stats) |
|
||||
|
||||
### Cost & Credentials
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderCost` | Show cost estimation window |
|
||||
| `:CoderAddApiKey` | Add/update API key |
|
||||
| `:CoderRemoveApiKey` | Remove credentials |
|
||||
| `:CoderCredentials` | Show credentials status |
|
||||
| `:CoderSwitchProvider` | Switch LLM provider |
|
||||
|
||||
### UI Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderIndex` | Open coder companion for current file |
|
||||
| `:CoderLogs` | Toggle logs panel |
|
||||
| `:CoderType` | Switch between Ask/Agent modes |
|
||||
| `:CoderTree` | Refresh tree.log |
|
||||
| `:CoderTreeView` | View tree.log |
|
||||
| `:CoderType` | Show Ask/Agent switcher |
|
||||
|
||||
---
|
||||
|
||||
## 📖 Usage Guide
|
||||
|
||||
### Tag-Based Prompts
|
||||
|
||||
Write prompts in your coder file using `/@` and `@/` tags:
|
||||
|
||||
```typescript
|
||||
/@ Create a Button component with the following props:
|
||||
- variant: 'primary' | 'secondary' | 'danger'
|
||||
- size: 'sm' | 'md' | 'lg'
|
||||
- disabled: boolean
|
||||
Use Tailwind CSS for styling @/
|
||||
```
|
||||
|
||||
When you close the tag with `@/`, the prompt is automatically processed.
|
||||
|
||||
### Transform Commands
|
||||
|
||||
Transform prompts inline without the split view:
|
||||
|
||||
```typescript
|
||||
// In your source file:
|
||||
/@ Add input validation for email and password @/
|
||||
|
||||
// Run :CoderTransformCursor to transform the prompt at cursor
|
||||
```
|
||||
|
||||
### Prompt Types
|
||||
|
||||
The plugin auto-detects prompt type:
|
||||
|
||||
| Keywords | Type | Behavior |
|
||||
|----------|------|----------|
|
||||
| `complete`, `finish`, `implement`, `todo` | Complete | Completes function body (replaces scope) |
|
||||
| `refactor`, `rewrite`, `simplify` | Refactor | Replaces code |
|
||||
| `fix`, `debug`, `bug`, `error` | Fix | Fixes bugs (replaces scope) |
|
||||
| `add`, `create`, `generate` | Add | Inserts new code |
|
||||
| `document`, `comment`, `jsdoc` | Document | Adds documentation |
|
||||
| `optimize`, `performance`, `faster` | Optimize | Optimizes code (replaces scope) |
|
||||
| `explain`, `what`, `how` | Explain | Shows explanation only |
|
||||
|
||||
### Function Completion
|
||||
|
||||
When you write a prompt **inside** a function body, the plugin uses Tree-sitter to detect the enclosing scope and automatically switches to "complete" mode:
|
||||
|
||||
```typescript
|
||||
function getUserById(id: number): User | null {
|
||||
/@ return the user from the database by id, handle not found case @/
|
||||
}
|
||||
```
|
||||
|
||||
The LLM will complete the function body while keeping the exact same signature. The entire function scope is replaced with the completed version.
|
||||
|
||||
---
|
||||
|
||||
## 📊 Logs Panel
|
||||
|
||||
The logs panel provides real-time visibility into LLM operations:
|
||||
|
||||
### Features
|
||||
|
||||
- **Generation Logs**: Shows all LLM requests, responses, and token usage
|
||||
- **Queue Display**: Shows pending and processing prompts
|
||||
- **Full Response View**: Complete LLM responses are logged for debugging
|
||||
- **Auto-cleanup**: Logs panel and queue windows automatically close when exiting Neovim
|
||||
|
||||
### Opening the Logs Panel
|
||||
|
||||
```vim
|
||||
:CoderLogs
|
||||
```
|
||||
|
||||
The logs panel opens automatically when processing prompts with the scheduler enabled.
|
||||
|
||||
### Keymaps
|
||||
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `q` | Close logs panel |
|
||||
| `<Esc>` | Close logs panel |
|
||||
|
||||
---
|
||||
|
||||
## 🤖 Agent Mode
|
||||
|
||||
The Agent mode provides an autonomous coding assistant with tool access:
|
||||
|
||||
### Available Tools
|
||||
|
||||
- **read_file**: Read file contents
|
||||
- **edit_file**: Edit files with find/replace
|
||||
- **write_file**: Create or overwrite files
|
||||
- **bash**: Execute shell commands
|
||||
|
||||
### Using Agent Mode
|
||||
|
||||
1. Open the agent panel: `:CoderAgent` or `<leader>ca`
|
||||
2. Describe what you want to accomplish
|
||||
3. The agent will use tools to complete the task
|
||||
4. Review changes before they're applied
|
||||
|
||||
### Agent Keymaps
|
||||
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `<CR>` | Submit message |
|
||||
| `Ctrl+c` | Stop agent execution |
|
||||
| `q` | Close agent panel |
|
||||
|
||||
---
|
||||
|
||||
## ⌨️ Keymaps
|
||||
## Keymaps Reference
|
||||
|
||||
### Default Keymaps (auto-configured)
|
||||
|
||||
@@ -422,7 +397,36 @@ The Agent mode provides an autonomous coding assistant with tool access:
|
||||
| `<leader>ctt` | Visual | Transform selected tags |
|
||||
| `<leader>ctT` | Normal | Transform all tags in file |
|
||||
| `<leader>ca` | Normal | Toggle Agent panel |
|
||||
| `<leader>ci` | Normal | Open coder companion (index) |
|
||||
| `<leader>ci` | Normal | Open coder companion |
|
||||
|
||||
### Conflict Resolution Keymaps (buffer-local when conflicts exist)
|
||||
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `co` | Accept CURRENT (original) code |
|
||||
| `ct` | Accept INCOMING (AI suggestion) |
|
||||
| `cb` | Accept BOTH versions |
|
||||
| `cn` | Delete conflict (accept NONE) |
|
||||
| `cm` | Show conflict resolution menu |
|
||||
| `]x` | Go to next conflict |
|
||||
| `[x` | Go to previous conflict |
|
||||
| `<CR>` | Show menu when on conflict |
|
||||
|
||||
### Conflict Menu Keymaps (in floating menu)
|
||||
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `1` | Accept current (original) |
|
||||
| `2` | Accept incoming (AI) |
|
||||
| `3` | Accept both |
|
||||
| `4` | Accept none |
|
||||
| `co` | Accept current |
|
||||
| `ct` | Accept incoming |
|
||||
| `cb` | Accept both |
|
||||
| `cn` | Accept none |
|
||||
| `]x` | Go to next conflict |
|
||||
| `[x` | Go to previous conflict |
|
||||
| `q` / `<Esc>` | Close menu |
|
||||
|
||||
### Ask Panel Keymaps
|
||||
|
||||
@@ -435,6 +439,29 @@ The Agent mode provides an autonomous coding assistant with tool access:
|
||||
| `q` | Close panel |
|
||||
| `Y` | Copy last response |
|
||||
|
||||
### Agent Panel Keymaps
|
||||
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `<CR>` | Submit message |
|
||||
| `Ctrl+c` | Stop agent execution |
|
||||
| `q` | Close agent panel |
|
||||
|
||||
### Logs Panel Keymaps
|
||||
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `q` / `<Esc>` | Close logs panel |
|
||||
|
||||
### Cost Window Keymaps
|
||||
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `q` / `<Esc>` | Close window |
|
||||
| `r` | Refresh display |
|
||||
| `c` | Clear session costs |
|
||||
| `C` | Clear all history |
|
||||
|
||||
### Suggested Additional Keymaps
|
||||
|
||||
```lua
|
||||
@@ -445,63 +472,288 @@ map("n", "<leader>cc", "<cmd>Coder close<cr>", { desc = "Coder: Close" })
|
||||
map("n", "<leader>ct", "<cmd>Coder toggle<cr>", { desc = "Coder: Toggle" })
|
||||
map("n", "<leader>cp", "<cmd>Coder process<cr>", { desc = "Coder: Process" })
|
||||
map("n", "<leader>cs", "<cmd>Coder status<cr>", { desc = "Coder: Status" })
|
||||
map("n", "<leader>cl", "<cmd>CoderLogs<cr>", { desc = "Coder: Logs" })
|
||||
map("n", "<leader>cm", "<cmd>CoderConflictMenu<cr>", { desc = "Coder: Conflict Menu" })
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🏥 Health Check
|
||||
## Usage Guide
|
||||
|
||||
Verify your setup:
|
||||
### Tag-Based Prompts
|
||||
|
||||
Write prompts using `/@` and `@/` tags:
|
||||
|
||||
```typescript
|
||||
/@ Create a Button component with:
|
||||
- variant: 'primary' | 'secondary' | 'danger'
|
||||
- size: 'sm' | 'md' | 'lg'
|
||||
Use Tailwind CSS for styling @/
|
||||
```
|
||||
|
||||
### Prompt Types
|
||||
|
||||
| Keywords | Type | Behavior |
|
||||
|----------|------|----------|
|
||||
| `complete`, `finish`, `implement` | Complete | Replaces scope |
|
||||
| `refactor`, `rewrite`, `simplify` | Refactor | Replaces code |
|
||||
| `fix`, `debug`, `bug`, `error` | Fix | Fixes bugs |
|
||||
| `add`, `create`, `generate` | Add | Inserts new code |
|
||||
| `document`, `comment`, `jsdoc` | Document | Adds docs |
|
||||
| `explain`, `what`, `how` | Explain | Shows explanation |
|
||||
|
||||
### Function Completion
|
||||
|
||||
When you write a prompt inside a function, the plugin detects the enclosing scope:
|
||||
|
||||
```typescript
|
||||
function getUserById(id: number): User | null {
|
||||
/@ return the user from the database by id @/
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Conflict Resolution
|
||||
|
||||
When code is generated, it's shown as a git-style conflict for you to review:
|
||||
|
||||
```
|
||||
<<<<<<< CURRENT
|
||||
// Original code here
|
||||
=======
|
||||
// AI-generated code here
|
||||
>>>>>>> INCOMING
|
||||
```
|
||||
|
||||
### Visual Indicators
|
||||
|
||||
- **Green background**: Original (CURRENT) code
|
||||
- **Blue background**: AI-generated (INCOMING) code
|
||||
- **Virtual text hints**: Shows available keymaps
|
||||
|
||||
### Resolution Options
|
||||
|
||||
1. **Accept Current (`co`)**: Keep your original code
|
||||
2. **Accept Incoming (`ct`)**: Use the AI suggestion
|
||||
3. **Accept Both (`cb`)**: Keep both versions
|
||||
4. **Accept None (`cn`)**: Delete the entire conflict
|
||||
|
||||
### Auto-Show Menu
|
||||
|
||||
When code is injected, a floating menu automatically appears. After resolving a conflict, the menu shows again for the next conflict.
|
||||
|
||||
Toggle auto-show: `:CoderConflictAutoMenu`
|
||||
|
||||
---
|
||||
|
||||
## Linter Validation
|
||||
|
||||
After accepting AI suggestions (`ct` or `cb`), the plugin:
|
||||
|
||||
1. **Saves the file** automatically
|
||||
2. **Checks LSP diagnostics** for errors/warnings
|
||||
3. **Offers to fix** lint errors with AI
|
||||
|
||||
### Configuration
|
||||
|
||||
```lua
|
||||
-- In conflict.lua config
|
||||
lint_after_accept = true, -- Check linter after accepting
|
||||
auto_fix_lint_errors = true, -- Auto-queue fix without prompting
|
||||
```
|
||||
|
||||
### Manual Commands
|
||||
|
||||
- `:CoderLintCheck` - Check current buffer
|
||||
- `:CoderLintFix` - Queue AI fix for errors
|
||||
- `:CoderLintQuickfix` - Show in quickfix list
|
||||
|
||||
---
|
||||
|
||||
## Logs Panel
|
||||
|
||||
Real-time visibility into LLM operations:
|
||||
|
||||
```vim
|
||||
:CoderLogs
|
||||
```
|
||||
|
||||
Shows:
|
||||
- Generation requests and responses
|
||||
- Token usage
|
||||
- Queue status
|
||||
- Errors and warnings
|
||||
|
||||
---
|
||||
|
||||
## Cost Tracking
|
||||
|
||||
Track LLM API costs across sessions:
|
||||
|
||||
```vim
|
||||
:CoderCost
|
||||
```
|
||||
|
||||
Features:
|
||||
- Session and all-time statistics
|
||||
- Per-model breakdown
|
||||
- Pricing for 50+ models
|
||||
- Persistent history in `.coder/cost_history.json`
|
||||
|
||||
---
|
||||
|
||||
## Agent Mode
|
||||
|
||||
Autonomous coding assistant with tool access:
|
||||
|
||||
### Available Tools
|
||||
|
||||
- **read_file**: Read file contents
|
||||
- **edit_file**: Edit files with find/replace
|
||||
- **write_file**: Create or overwrite files
|
||||
- **bash**: Execute shell commands
|
||||
|
||||
### Using Agent Mode
|
||||
|
||||
1. Open: `:CoderAgent` or `<leader>ca`
|
||||
2. Describe your task
|
||||
3. Agent uses tools autonomously
|
||||
4. Review changes in conflict mode
|
||||
|
||||
---
|
||||
|
||||
## Health Check
|
||||
|
||||
```vim
|
||||
:checkhealth codetyper
|
||||
```
|
||||
|
||||
This checks:
|
||||
- Neovim version
|
||||
- curl availability
|
||||
- LLM configuration
|
||||
- API key status
|
||||
- Telescope availability (optional)
|
||||
|
||||
---
|
||||
|
||||
## 📁 File Structure
|
||||
## File Structure
|
||||
|
||||
```
|
||||
your-project/
|
||||
├── .coder/ # Auto-created, gitignored
|
||||
│ └── tree.log # Project structure log
|
||||
├── .coder/
|
||||
│ ├── tree.log
|
||||
│ ├── cost_history.json
|
||||
│ ├── brain/
|
||||
│ ├── agents/
|
||||
│ └── rules/
|
||||
├── src/
|
||||
│ ├── index.ts # Your source file
|
||||
│ ├── index.coder.ts # Coder file (gitignored)
|
||||
└── .gitignore # Auto-updated with coder patterns
|
||||
│ ├── index.ts
|
||||
│ └── index.coder.ts
|
||||
└── .gitignore
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🤝 Contributing
|
||||
## Reporting Issues
|
||||
|
||||
Contributions are welcome! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
|
||||
Found a bug or have a feature request? Please create an issue on GitHub.
|
||||
|
||||
### Before Creating an Issue
|
||||
|
||||
1. **Search existing issues** to avoid duplicates
|
||||
2. **Update to the latest version** and check if the issue persists
|
||||
3. **Run health check**: `:checkhealth codetyper`
|
||||
|
||||
### Bug Reports
|
||||
|
||||
When reporting a bug, please include:
|
||||
|
||||
```markdown
|
||||
**Description**
|
||||
A clear description of what the bug is.
|
||||
|
||||
**Steps to Reproduce**
|
||||
1. Open file '...'
|
||||
2. Run command '...'
|
||||
3. See error
|
||||
|
||||
**Expected Behavior**
|
||||
What you expected to happen.
|
||||
|
||||
**Actual Behavior**
|
||||
What actually happened.
|
||||
|
||||
**Environment**
|
||||
- Neovim version: (output of `nvim --version`)
|
||||
- Plugin version: (commit hash or tag)
|
||||
- OS: (e.g., macOS 14.0, Ubuntu 22.04)
|
||||
- LLM Provider: (e.g., Claude, OpenAI, Ollama)
|
||||
|
||||
**Error Messages**
|
||||
Paste any error messages from `:messages`
|
||||
|
||||
**Minimal Config**
|
||||
If possible, provide a minimal config to reproduce:
|
||||
```lua
|
||||
-- minimal.lua
|
||||
require("codetyper").setup({
|
||||
llm = { provider = "..." },
|
||||
})
|
||||
```
|
||||
```
|
||||
|
||||
### Feature Requests
|
||||
|
||||
For feature requests, please describe:
|
||||
|
||||
- **Use case**: What problem does this solve?
|
||||
- **Proposed solution**: How should it work?
|
||||
- **Alternatives**: Other solutions you've considered
|
||||
|
||||
### Debug Information
|
||||
|
||||
To gather debug information:
|
||||
|
||||
```vim
|
||||
" Check plugin status
|
||||
:Coder status
|
||||
|
||||
" View logs
|
||||
:CoderLogs
|
||||
|
||||
" Check health
|
||||
:checkhealth codetyper
|
||||
|
||||
" View recent messages
|
||||
:messages
|
||||
```
|
||||
|
||||
### Issue Labels
|
||||
|
||||
- `bug` - Something isn't working
|
||||
- `enhancement` - New feature request
|
||||
- `documentation` - Documentation improvements
|
||||
- `question` - General questions
|
||||
- `help wanted` - Issues that need community help
|
||||
|
||||
---
|
||||
|
||||
## 📄 License
|
||||
## Contributing
|
||||
|
||||
MIT License - see [LICENSE](LICENSE) for details.
|
||||
Contributions welcome! See [CONTRIBUTING.md](CONTRIBUTING.md).
|
||||
|
||||
---
|
||||
|
||||
## 👨💻 Author
|
||||
## License
|
||||
|
||||
MIT License - see [LICENSE](LICENSE).
|
||||
|
||||
---
|
||||
|
||||
## Author
|
||||
|
||||
**cargdev**
|
||||
|
||||
- Website: [cargdev.io](https://cargdev.io)
|
||||
- Blog: [blog.cargdev.io](https://blog.cargdev.io)
|
||||
- Email: carlos.gutierrez@carg.dev
|
||||
|
||||
---
|
||||
|
||||
<p align="center">
|
||||
Made with ❤️ for the Neovim community
|
||||
Made with care for the Neovim community
|
||||
</p>
|
||||
|
||||
@@ -120,7 +120,7 @@ Default configuration: >lua
|
||||
5. LLM PROVIDERS *codetyper-providers*
|
||||
|
||||
*codetyper-claude*
|
||||
Claude (Anthropic)~
|
||||
Claude~
|
||||
Best for complex reasoning and code generation.
|
||||
>lua
|
||||
llm = {
|
||||
|
||||
475
llms.txt
475
llms.txt
@@ -14,61 +14,72 @@ Instead of having an AI generate entire files, Codetyper lets developers maintai
|
||||
2. A companion "coder file" is created (`index.coder.ts`)
|
||||
3. Developer writes prompts using special tags: `/@ prompt @/`
|
||||
4. When the closing tag is typed, the LLM generates code
|
||||
5. Generated code is injected into the target file
|
||||
5. Generated code is shown as a conflict for review
|
||||
6. Developer accepts/rejects changes using keymaps
|
||||
|
||||
## Plugin Architecture
|
||||
|
||||
```
|
||||
lua/codetyper/
|
||||
├── init.lua # Main entry, setup function, module initialization
|
||||
├── config.lua # Configuration management, defaults, validation
|
||||
├── types.lua # Lua type definitions for LSP/documentation
|
||||
├── utils.lua # Utility functions (file ops, notifications)
|
||||
├── commands.lua # Vim command definitions (:Coder, :CoderOpen, etc.)
|
||||
├── window.lua # Split window management (open, close, toggle)
|
||||
├── parser.lua # Parses /@ @/ tags from buffer content
|
||||
├── gitignore.lua # Manages .gitignore entries for coder files
|
||||
├── autocmds.lua # Autocommands for tag detection, filetype, auto-index
|
||||
├── inject.lua # Code injection strategies
|
||||
├── health.lua # Health check for :checkhealth
|
||||
├── tree.lua # Project tree logging (.coder/tree.log)
|
||||
├── logs_panel.lua # Standalone logs panel UI
|
||||
├── init.lua # Main entry, setup function
|
||||
├── config.lua # Configuration management
|
||||
├── types.lua # Lua type definitions
|
||||
├── utils.lua # Utility functions
|
||||
├── commands.lua # Vim command definitions
|
||||
├── window.lua # Split window management
|
||||
├── parser.lua # Parses /@ @/ tags
|
||||
├── gitignore.lua # Manages .gitignore entries
|
||||
├── autocmds.lua # Autocommands for tag detection
|
||||
├── inject.lua # Code injection strategies
|
||||
├── health.lua # Health check for :checkhealth
|
||||
├── tree.lua # Project tree logging
|
||||
├── logs_panel.lua # Standalone logs panel UI
|
||||
├── cost.lua # LLM cost tracking
|
||||
├── credentials.lua # Secure credential storage
|
||||
├── llm/
|
||||
│ ├── init.lua # LLM interface, provider selection
|
||||
│ ├── claude.lua # Claude API client (Anthropic)
|
||||
│ ├── openai.lua # OpenAI API client (with custom endpoint support)
|
||||
│ ├── gemini.lua # Google Gemini API client
|
||||
│ ├── copilot.lua # GitHub Copilot client (uses OAuth from copilot.lua/vim)
|
||||
│ └── ollama.lua # Ollama API client (local LLMs)
|
||||
│ ├── init.lua # LLM interface, provider selection
|
||||
│ ├── claude.lua # Claude API client
|
||||
│ ├── openai.lua # OpenAI API client
|
||||
│ ├── gemini.lua # Google Gemini API client
|
||||
│ ├── copilot.lua # GitHub Copilot client
|
||||
│ └── ollama.lua # Ollama API client (local)
|
||||
├── agent/
|
||||
│ ├── init.lua # Agent system entry point
|
||||
│ ├── ui.lua # Agent panel UI
|
||||
│ ├── logs.lua # Logging system with listeners
|
||||
│ ├── tools.lua # Tool definitions (read_file, edit_file, write_file, bash)
|
||||
│ ├── executor.lua # Tool execution logic
|
||||
│ ├── parser.lua # Parse tool calls from LLM responses
|
||||
│ ├── queue.lua # Event queue with priority heap
|
||||
│ ├── patch.lua # Patch candidates with staleness detection
|
||||
│ ├── confidence.lua # Response confidence scoring heuristics
|
||||
│ ├── worker.lua # Async LLM worker wrapper
|
||||
│ ├── scheduler.lua # Event scheduler with completion-awareness
|
||||
│ ├── scope.lua # Tree-sitter scope resolution
|
||||
│ └── intent.lua # Intent detection from prompts
|
||||
│ ├── init.lua # Agent system entry point
|
||||
│ ├── ui.lua # Agent panel UI
|
||||
│ ├── logs.lua # Logging system
|
||||
│ ├── tools.lua # Tool definitions (read, edit, write, bash)
|
||||
│ ├── executor.lua # Tool execution logic
|
||||
│ ├── parser.lua # Parse tool calls from responses
|
||||
│ ├── queue.lua # Event queue with priority heap
|
||||
│ ├── patch.lua # Patch candidates with staleness detection
|
||||
│ ├── confidence.lua # Response confidence scoring
|
||||
│ ├── worker.lua # Async LLM worker
|
||||
│ ├── scheduler.lua # Event scheduler
|
||||
│ ├── scope.lua # Tree-sitter scope resolution
|
||||
│ ├── intent.lua # Intent detection from prompts
|
||||
│ ├── conflict.lua # Git-style conflict resolution
|
||||
│ ├── linter.lua # LSP diagnostics validation
|
||||
│ └── search_replace.lua # SEARCH/REPLACE block parsing
|
||||
├── ask/
|
||||
│ ├── init.lua # Ask panel entry point
|
||||
│ └── ui.lua # Ask panel UI (chat interface)
|
||||
│ ├── init.lua # Ask panel entry point
|
||||
│ └── ui.lua # Ask panel UI (chat interface)
|
||||
└── prompts/
|
||||
├── init.lua # System prompts for code generation
|
||||
└── agent.lua # Agent-specific prompts and tool instructions
|
||||
├── init.lua # System prompts for code generation
|
||||
└── agent.lua # Agent-specific prompts
|
||||
```
|
||||
|
||||
## .coder/ Folder
|
||||
|
||||
The plugin automatically creates and maintains a `.coder/` folder in your project:
|
||||
|
||||
```
|
||||
.coder/
|
||||
└── tree.log # Project structure, auto-updated on file changes
|
||||
├── tree.log # Project structure, auto-updated
|
||||
├── cost_history.json # LLM cost tracking history
|
||||
├── brain/ # Knowledge graph storage
|
||||
│ ├── nodes/
|
||||
│ ├── indices/
|
||||
│ └── deltas/
|
||||
├── agents/ # Custom agent definitions
|
||||
└── rules/ # Project-specific rules
|
||||
```
|
||||
|
||||
## Key Features
|
||||
@@ -86,66 +97,122 @@ llm = {
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Agent Mode
|
||||
### 2. Conflict Resolution System
|
||||
|
||||
Git-style diff visualization for code review:
|
||||
|
||||
```
|
||||
<<<<<<< CURRENT
|
||||
// Original code
|
||||
=======
|
||||
// AI-generated code
|
||||
>>>>>>> INCOMING
|
||||
```
|
||||
|
||||
**Keymaps (buffer-local when conflicts exist):**
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `co` | Accept CURRENT (original) code |
|
||||
| `ct` | Accept INCOMING (AI suggestion) |
|
||||
| `cb` | Accept BOTH versions |
|
||||
| `cn` | Delete conflict (accept NONE) |
|
||||
| `cm` | Show conflict resolution menu |
|
||||
| `]x` | Go to next conflict |
|
||||
| `[x` | Go to previous conflict |
|
||||
| `<CR>` | Show menu when on conflict |
|
||||
|
||||
**Menu keymaps:**
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `1` | Accept current |
|
||||
| `2` | Accept incoming |
|
||||
| `3` | Accept both |
|
||||
| `4` | Accept none |
|
||||
| `q`/`<Esc>` | Close menu |
|
||||
|
||||
**Configuration:**
|
||||
```lua
|
||||
-- In conflict.lua
|
||||
config = {
|
||||
lint_after_accept = true, -- Check linter after accepting
|
||||
auto_fix_lint_errors = true, -- Auto-queue fix
|
||||
auto_show_menu = true, -- Show menu after injection
|
||||
auto_show_next_menu = true, -- Show menu for next conflict
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Linter Validation
|
||||
|
||||
Auto-check and fix lint errors after code injection:
|
||||
|
||||
```lua
|
||||
-- In linter.lua
|
||||
config = {
|
||||
auto_save = true, -- Save file after injection
|
||||
diagnostic_delay_ms = 500, -- Wait for LSP
|
||||
min_severity = vim.diagnostic.severity.WARN,
|
||||
auto_offer_fix = true, -- Offer to fix errors
|
||||
}
|
||||
```
|
||||
|
||||
**Commands:**
|
||||
- `:CoderLintCheck` - Check buffer for lint errors
|
||||
- `:CoderLintFix` - Request AI to fix lint errors
|
||||
- `:CoderLintQuickfix` - Show errors in quickfix
|
||||
- `:CoderLintToggleAuto` - Toggle auto lint checking
|
||||
|
||||
### 4. SEARCH/REPLACE Block System
|
||||
|
||||
Reliable code editing with fuzzy matching:
|
||||
|
||||
```
|
||||
<<<<<<< SEARCH
|
||||
function oldCode() {
|
||||
// original
|
||||
}
|
||||
=======
|
||||
function newCode() {
|
||||
// replacement
|
||||
}
|
||||
>>>>>>> REPLACE
|
||||
```
|
||||
|
||||
**Configuration:**
|
||||
```lua
|
||||
-- In search_replace.lua
|
||||
config = {
|
||||
fuzzy_threshold = 0.8, -- Minimum similarity
|
||||
normalize_whitespace = true, -- Ignore whitespace differences
|
||||
context_lines = 3, -- Lines for context matching
|
||||
}
|
||||
```
|
||||
|
||||
### 5. Agent Mode
|
||||
|
||||
Autonomous coding assistant with tool access:
|
||||
|
||||
**Available Tools:**
|
||||
- `read_file` - Read file contents
|
||||
- `edit_file` - Edit files with find/replace
|
||||
- `write_file` - Create or overwrite files
|
||||
- `bash` - Execute shell commands
|
||||
|
||||
### 3. Transform Commands
|
||||
|
||||
Transform `/@ @/` tags inline without split view:
|
||||
|
||||
- `:CoderTransform` - Transform all tags in file
|
||||
- `:CoderTransformCursor` - Transform tag at cursor
|
||||
- `:CoderTransformVisual` - Transform selected tags
|
||||
|
||||
### 4. Auto-Index
|
||||
|
||||
Automatically create coder companion files when opening source files:
|
||||
|
||||
```lua
|
||||
auto_index = true -- disabled by default
|
||||
```
|
||||
|
||||
### 5. Logs Panel
|
||||
|
||||
Real-time visibility into LLM operations with token usage tracking.
|
||||
|
||||
### 6. Event-Driven Scheduler
|
||||
|
||||
Prompts are treated as events, not commands:
|
||||
|
||||
```
|
||||
User types /@...@/ → Event queued → Scheduler dispatches → Worker processes → Patch created → Safe injection
|
||||
User types /@...@/ → Event queued → Scheduler dispatches → Worker processes → Patch created → Conflict shown
|
||||
```
|
||||
|
||||
**Key concepts:**
|
||||
|
||||
- **PromptEvent**: Captures buffer state (changedtick, content hash) at prompt time
|
||||
- **Optimistic Execution**: Ollama as fast scout, escalate to remote LLMs if confidence low
|
||||
- **Confidence Scoring**: 5 heuristics (length, uncertainty, syntax, repetition, truncation)
|
||||
- **Staleness Detection**: Discard patches if buffer changed during generation
|
||||
- **Completion Safety**: Defer injection while autocomplete popup visible
|
||||
|
||||
**Configuration:**
|
||||
|
||||
```lua
|
||||
scheduler = {
|
||||
enabled = true, -- Enable event-driven mode
|
||||
ollama_scout = true, -- Use Ollama first
|
||||
escalation_threshold = 0.7, -- Below this → escalate
|
||||
max_concurrent = 2, -- Parallel workers
|
||||
completion_delay_ms = 100, -- Wait after popup closes
|
||||
}
|
||||
```
|
||||
- **PromptEvent**: Captures buffer state at prompt time
|
||||
- **Optimistic Execution**: Ollama as fast scout
|
||||
- **Confidence Scoring**: 5 heuristics
|
||||
- **Staleness Detection**: Discard if buffer changed
|
||||
- **Completion Safety**: Defer while autocomplete visible
|
||||
|
||||
### 7. Tree-sitter Scope Resolution
|
||||
|
||||
Prompts automatically resolve to their enclosing function/method/class:
|
||||
Prompts automatically resolve to enclosing scope:
|
||||
|
||||
```lua
|
||||
function foo()
|
||||
@@ -155,13 +222,8 @@ end
|
||||
|
||||
**Scope types:** `function`, `method`, `class`, `block`, `file`
|
||||
|
||||
For replacement intents (complete, refactor, fix), the entire scope is extracted
|
||||
and sent to the LLM, then replaced with the transformed version.
|
||||
|
||||
### 8. Intent Detection
|
||||
|
||||
The system parses prompts to detect user intent:
|
||||
|
||||
| Intent | Keywords | Action |
|
||||
|--------|----------|--------|
|
||||
| complete | complete, finish, implement | replace |
|
||||
@@ -170,76 +232,174 @@ The system parses prompts to detect user intent:
|
||||
| add | add, create, insert, new | insert |
|
||||
| document | document, comment, jsdoc | replace |
|
||||
| test | test, spec, unit test | append |
|
||||
| optimize | optimize, performance, faster | replace |
|
||||
| explain | explain, what, how, why | none |
|
||||
| optimize | optimize, performance | replace |
|
||||
| explain | explain, what, how | none |
|
||||
|
||||
### 9. Tag Precedence
|
||||
### 9. Cost Tracking
|
||||
|
||||
Multiple tags in the same scope follow "first tag wins" rule:
|
||||
- Earlier (by line number) unresolved tag processes first
|
||||
- Later tags in same scope are skipped with warning
|
||||
- Different scopes process independently
|
||||
Track LLM API costs:
|
||||
- Session costs tracked in real-time
|
||||
- All-time costs in `.coder/cost_history.json`
|
||||
- Pricing for 50+ models
|
||||
|
||||
## Commands
|
||||
### 10. Credentials Management
|
||||
|
||||
### Main Commands
|
||||
- `:Coder open` - Opens split view with coder file
|
||||
- `:Coder close` - Closes the split
|
||||
- `:Coder toggle` - Toggles the view
|
||||
- `:Coder process` - Manually triggers code generation
|
||||
```vim
|
||||
:CoderAddApiKey
|
||||
```
|
||||
|
||||
Stored in `~/.local/share/nvim/codetyper/configuration.json`
|
||||
|
||||
**Priority:** stored credentials > config > environment variables
|
||||
|
||||
## Commands Reference
|
||||
|
||||
### Core Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder open` | `:CoderOpen` | Open coder split |
|
||||
| `:Coder close` | `:CoderClose` | Close coder split |
|
||||
| `:Coder toggle` | `:CoderToggle` | Toggle coder split |
|
||||
| `:Coder process` | `:CoderProcess` | Process last prompt |
|
||||
| `:Coder status` | - | Show status |
|
||||
| `:Coder focus` | - | Switch focus |
|
||||
| `:Coder reset` | - | Reset processed prompts |
|
||||
|
||||
### Ask Panel
|
||||
- `:CoderAsk` - Open Ask panel
|
||||
- `:CoderAskToggle` - Toggle Ask panel
|
||||
- `:CoderAskClear` - Clear chat history
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder ask` | `:CoderAsk` | Open Ask panel |
|
||||
| `:Coder ask-toggle` | `:CoderAskToggle` | Toggle Ask panel |
|
||||
| `:Coder ask-clear` | `:CoderAskClear` | Clear chat |
|
||||
|
||||
### Agent Mode
|
||||
- `:CoderAgent` - Open Agent panel
|
||||
- `:CoderAgentToggle` - Toggle Agent panel
|
||||
- `:CoderAgentStop` - Stop running agent
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder agent` | `:CoderAgent` | Open Agent panel |
|
||||
| `:Coder agent-toggle` | `:CoderAgentToggle` | Toggle Agent panel |
|
||||
| `:Coder agent-stop` | `:CoderAgentStop` | Stop agent |
|
||||
|
||||
### Transform
|
||||
- `:CoderTransform` - Transform all tags
|
||||
- `:CoderTransformCursor` - Transform at cursor
|
||||
- `:CoderTransformVisual` - Transform selection
|
||||
### Transform Commands
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder transform` | `:CoderTransform` | Transform all tags |
|
||||
| `:Coder transform-cursor` | `:CoderTransformCursor` | Transform at cursor |
|
||||
| - | `:CoderTransformVisual` | Transform selected |
|
||||
|
||||
### Utility
|
||||
- `:CoderIndex` - Open coder companion
|
||||
- `:CoderLogs` - Toggle logs panel
|
||||
- `:CoderType` - Switch Ask/Agent mode
|
||||
- `:CoderTree` - Refresh tree.log
|
||||
- `:CoderTreeView` - View tree.log
|
||||
### Conflict Resolution
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderConflictToggle` | Toggle conflict mode |
|
||||
| `:CoderConflictMenu` | Show resolution menu |
|
||||
| `:CoderConflictNext` | Go to next conflict |
|
||||
| `:CoderConflictPrev` | Go to previous conflict |
|
||||
| `:CoderConflictStatus` | Show conflict status |
|
||||
| `:CoderConflictResolveAll [keep]` | Resolve all |
|
||||
| `:CoderConflictAcceptCurrent` | Accept original |
|
||||
| `:CoderConflictAcceptIncoming` | Accept AI |
|
||||
| `:CoderConflictAcceptBoth` | Accept both |
|
||||
| `:CoderConflictAcceptNone` | Delete both |
|
||||
| `:CoderConflictAutoMenu` | Toggle auto-show menu |
|
||||
|
||||
### Linter Validation
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderLintCheck` | Check buffer |
|
||||
| `:CoderLintFix` | AI fix errors |
|
||||
| `:CoderLintQuickfix` | Show in quickfix |
|
||||
| `:CoderLintToggleAuto` | Toggle auto lint |
|
||||
|
||||
### Queue & Scheduler
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder queue-status` | `:CoderQueueStatus` | Show status |
|
||||
| `:Coder queue-process` | `:CoderQueueProcess` | Trigger processing |
|
||||
|
||||
### Processing Mode
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `:Coder auto-toggle` | `:CoderAutoToggle` | Toggle auto/manual |
|
||||
| `:Coder auto-set <mode>` | `:CoderAutoSet` | Set mode |
|
||||
|
||||
### Brain & Memory
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderMemories` | Show memories |
|
||||
| `:CoderForget [pattern]` | Clear memories |
|
||||
| `:CoderBrain [action]` | Brain management |
|
||||
| `:CoderFeedback <type>` | Give feedback |
|
||||
|
||||
### Cost & Credentials
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderCost` | Show cost window |
|
||||
| `:CoderAddApiKey` | Add/update API key |
|
||||
| `:CoderRemoveApiKey` | Remove credentials |
|
||||
| `:CoderCredentials` | Show credentials |
|
||||
| `:CoderSwitchProvider` | Switch provider |
|
||||
|
||||
### UI Commands
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `:CoderLogs` | Toggle logs panel |
|
||||
| `:CoderType` | Show mode switcher |
|
||||
|
||||
## Keymaps Reference
|
||||
|
||||
### Default Keymaps
|
||||
| Key | Mode | Description |
|
||||
|-----|------|-------------|
|
||||
| `<leader>ctt` | Normal | Transform tag at cursor |
|
||||
| `<leader>ctt` | Visual | Transform selected tags |
|
||||
| `<leader>ctT` | Normal | Transform all tags |
|
||||
| `<leader>ca` | Normal | Toggle Agent panel |
|
||||
| `<leader>ci` | Normal | Open coder companion |
|
||||
|
||||
### Ask Panel Keymaps
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `@` | Attach file |
|
||||
| `Ctrl+Enter` | Submit |
|
||||
| `Ctrl+n` | New chat |
|
||||
| `Ctrl+f` | Add current file |
|
||||
| `q` | Close |
|
||||
| `Y` | Copy response |
|
||||
|
||||
### Agent Panel Keymaps
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `<CR>` | Submit |
|
||||
| `Ctrl+c` | Stop agent |
|
||||
| `q` | Close |
|
||||
|
||||
### Logs Panel Keymaps
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `q`/`<Esc>` | Close |
|
||||
|
||||
### Cost Window Keymaps
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `q`/`<Esc>` | Close |
|
||||
| `r` | Refresh |
|
||||
| `c` | Clear session |
|
||||
| `C` | Clear all |
|
||||
|
||||
## Configuration Schema
|
||||
|
||||
```lua
|
||||
{
|
||||
llm = {
|
||||
provider = "claude", -- "claude" | "openai" | "gemini" | "copilot" | "ollama"
|
||||
claude = {
|
||||
api_key = nil, -- string, uses ANTHROPIC_API_KEY env if nil
|
||||
model = "claude-sonnet-4-20250514",
|
||||
},
|
||||
openai = {
|
||||
api_key = nil, -- string, uses OPENAI_API_KEY env if nil
|
||||
model = "gpt-4o",
|
||||
endpoint = nil, -- custom endpoint for Azure, OpenRouter, etc.
|
||||
},
|
||||
gemini = {
|
||||
api_key = nil, -- string, uses GEMINI_API_KEY env if nil
|
||||
model = "gemini-2.0-flash",
|
||||
},
|
||||
copilot = {
|
||||
model = "gpt-4o", -- uses OAuth from copilot.lua/copilot.vim
|
||||
},
|
||||
ollama = {
|
||||
host = "http://localhost:11434",
|
||||
model = "deepseek-coder:6.7b",
|
||||
},
|
||||
provider = "claude",
|
||||
claude = { api_key = nil, model = "claude-sonnet-4-20250514" },
|
||||
openai = { api_key = nil, model = "gpt-4o", endpoint = nil },
|
||||
gemini = { api_key = nil, model = "gemini-2.0-flash" },
|
||||
copilot = { model = "gpt-4o" },
|
||||
ollama = { host = "http://localhost:11434", model = "deepseek-coder:6.7b" },
|
||||
},
|
||||
window = {
|
||||
width = 25, -- percentage (25 = 25% of screen)
|
||||
position = "left", -- "left" | "right"
|
||||
width = 25,
|
||||
position = "left",
|
||||
border = "rounded",
|
||||
},
|
||||
patterns = {
|
||||
@@ -249,13 +409,14 @@ Multiple tags in the same scope follow "first tag wins" rule:
|
||||
},
|
||||
auto_gitignore = true,
|
||||
auto_open_ask = true,
|
||||
auto_index = false, -- auto-create coder companion files
|
||||
auto_index = false,
|
||||
scheduler = {
|
||||
enabled = true, -- enable event-driven scheduler
|
||||
ollama_scout = true, -- use Ollama as fast scout
|
||||
enabled = true,
|
||||
ollama_scout = true,
|
||||
escalation_threshold = 0.7,
|
||||
max_concurrent = 2,
|
||||
completion_delay_ms = 100,
|
||||
apply_delay_ms = 5000,
|
||||
},
|
||||
}
|
||||
```
|
||||
@@ -264,29 +425,26 @@ Multiple tags in the same scope follow "first tag wins" rule:
|
||||
|
||||
### Claude API
|
||||
- Endpoint: `https://api.anthropic.com/v1/messages`
|
||||
- Uses `x-api-key` header for authentication
|
||||
- Supports tool use for agent mode
|
||||
- Auth: `x-api-key` header
|
||||
- Supports tool use
|
||||
|
||||
### OpenAI API
|
||||
- Endpoint: `https://api.openai.com/v1/chat/completions` (configurable)
|
||||
- Uses `Authorization: Bearer` header
|
||||
- Supports tool use for agent mode
|
||||
- Compatible with Azure, OpenRouter, and other OpenAI-compatible APIs
|
||||
- Auth: `Authorization: Bearer`
|
||||
- Compatible with Azure, OpenRouter
|
||||
|
||||
### Gemini API
|
||||
- Endpoint: `https://generativelanguage.googleapis.com/v1beta/models`
|
||||
- Uses API key in URL parameter
|
||||
- Supports function calling for agent mode
|
||||
- Auth: API key in URL
|
||||
- Supports function calling
|
||||
|
||||
### Copilot API
|
||||
- Uses GitHub OAuth token from copilot.lua/copilot.vim
|
||||
- Endpoint from token response (typically `api.githubcopilot.com`)
|
||||
- OpenAI-compatible format
|
||||
|
||||
### Ollama API
|
||||
- Endpoint: `{host}/api/generate` or `{host}/api/chat`
|
||||
- No authentication required for local instances
|
||||
- Tool use via prompt-based approach
|
||||
- No auth required locally
|
||||
|
||||
## Agent Tool Definitions
|
||||
|
||||
@@ -299,13 +457,6 @@ tools = {
|
||||
}
|
||||
```
|
||||
|
||||
## Code Injection Strategies
|
||||
|
||||
1. **Refactor**: Replace entire file content
|
||||
2. **Add**: Insert at cursor position in target file
|
||||
3. **Document**: Insert above current function/class
|
||||
4. **Generic**: Prompt user for action
|
||||
|
||||
## File Naming Convention
|
||||
|
||||
| Target File | Coder File |
|
||||
@@ -318,8 +469,8 @@ Pattern: `name.coder.extension`
|
||||
|
||||
## Dependencies
|
||||
|
||||
- **Required**: Neovim >= 0.8.0, curl
|
||||
- **Optional**: telescope.nvim (enhanced file picker), copilot.lua or copilot.vim (for Copilot provider)
|
||||
- **Required**: Neovim >= 0.8.0, curl, plenary.nvim, nvim-treesitter
|
||||
- **Optional**: telescope.nvim, copilot.lua/copilot.vim, nui.nvim
|
||||
|
||||
## Contact
|
||||
|
||||
|
||||
1525
lua/codetyper/adapters/nvim/autocmds.lua
Normal file
1525
lua/codetyper/adapters/nvim/autocmds.lua
Normal file
File diff suppressed because it is too large
Load Diff
357
lua/codetyper/adapters/nvim/cmp/init.lua
Normal file
357
lua/codetyper/adapters/nvim/cmp/init.lua
Normal file
@@ -0,0 +1,357 @@
|
||||
---@mod codetyper.cmp_source Completion source for nvim-cmp
|
||||
---@brief [[
|
||||
--- Provides intelligent code completions using the brain, indexer, and LLM.
|
||||
--- Integrates with nvim-cmp as a custom source.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local source = {}
|
||||
|
||||
--- Check if cmp is available
|
||||
---@return boolean
|
||||
local function has_cmp()
|
||||
return pcall(require, "cmp")
|
||||
end
|
||||
|
||||
--- Get completion items from brain context
|
||||
---@param prefix string Current word prefix
|
||||
---@return table[] items
|
||||
local function get_brain_completions(prefix)
|
||||
local items = {}
|
||||
|
||||
local ok_brain, brain = pcall(require, "codetyper.brain")
|
||||
if not ok_brain then
|
||||
return items
|
||||
end
|
||||
|
||||
-- Check if brain is initialized safely
|
||||
local is_init = false
|
||||
if brain.is_initialized then
|
||||
local ok, result = pcall(brain.is_initialized)
|
||||
is_init = ok and result
|
||||
end
|
||||
|
||||
if not is_init then
|
||||
return items
|
||||
end
|
||||
|
||||
-- Query brain for relevant patterns
|
||||
local ok_query, result = pcall(brain.query, {
|
||||
query = prefix,
|
||||
max_results = 10,
|
||||
types = { "pattern" },
|
||||
})
|
||||
|
||||
if ok_query and result and result.nodes then
|
||||
for _, node in ipairs(result.nodes) do
|
||||
if node.c and node.c.s then
|
||||
-- Extract function/class names from summary
|
||||
local summary = node.c.s
|
||||
for name in summary:gmatch("functions:%s*([^;]+)") do
|
||||
for func in name:gmatch("([%w_]+)") do
|
||||
if func:lower():find(prefix:lower(), 1, true) then
|
||||
table.insert(items, {
|
||||
label = func,
|
||||
kind = 3, -- Function
|
||||
detail = "[brain]",
|
||||
documentation = summary,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
for name in summary:gmatch("classes:%s*([^;]+)") do
|
||||
for class in name:gmatch("([%w_]+)") do
|
||||
if class:lower():find(prefix:lower(), 1, true) then
|
||||
table.insert(items, {
|
||||
label = class,
|
||||
kind = 7, -- Class
|
||||
detail = "[brain]",
|
||||
documentation = summary,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return items
|
||||
end
|
||||
|
||||
--- Get completion items from indexer symbols
|
||||
---@param prefix string Current word prefix
|
||||
---@return table[] items
|
||||
local function get_indexer_completions(prefix)
|
||||
local items = {}
|
||||
|
||||
local ok_indexer, indexer = pcall(require, "codetyper.indexer")
|
||||
if not ok_indexer then
|
||||
return items
|
||||
end
|
||||
|
||||
local ok_load, index = pcall(indexer.load_index)
|
||||
if not ok_load or not index then
|
||||
return items
|
||||
end
|
||||
|
||||
-- Search symbols
|
||||
if index.symbols then
|
||||
for symbol, files in pairs(index.symbols) do
|
||||
if symbol:lower():find(prefix:lower(), 1, true) then
|
||||
local files_str = type(files) == "table" and table.concat(files, ", ") or tostring(files)
|
||||
table.insert(items, {
|
||||
label = symbol,
|
||||
kind = 6, -- Variable (generic)
|
||||
detail = "[index] " .. files_str:sub(1, 30),
|
||||
documentation = "Symbol found in: " .. files_str,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Search functions in files
|
||||
if index.files then
|
||||
for filepath, file_index in pairs(index.files) do
|
||||
if file_index and file_index.functions then
|
||||
for _, func in ipairs(file_index.functions) do
|
||||
if func.name and func.name:lower():find(prefix:lower(), 1, true) then
|
||||
table.insert(items, {
|
||||
label = func.name,
|
||||
kind = 3, -- Function
|
||||
detail = "[index] " .. vim.fn.fnamemodify(filepath, ":t"),
|
||||
documentation = func.docstring or ("Function at line " .. (func.line or "?")),
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
if file_index and file_index.classes then
|
||||
for _, class in ipairs(file_index.classes) do
|
||||
if class.name and class.name:lower():find(prefix:lower(), 1, true) then
|
||||
table.insert(items, {
|
||||
label = class.name,
|
||||
kind = 7, -- Class
|
||||
detail = "[index] " .. vim.fn.fnamemodify(filepath, ":t"),
|
||||
documentation = class.docstring or ("Class at line " .. (class.line or "?")),
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return items
|
||||
end
|
||||
|
||||
--- Get completion items from current buffer (fallback)
|
||||
---@param prefix string Current word prefix
|
||||
---@param bufnr number Buffer number
|
||||
---@return table[] items
|
||||
local function get_buffer_completions(prefix, bufnr)
|
||||
local items = {}
|
||||
local seen = {}
|
||||
|
||||
-- Get all lines in buffer
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local prefix_lower = prefix:lower()
|
||||
|
||||
for _, line in ipairs(lines) do
|
||||
-- Extract words that could be identifiers
|
||||
for word in line:gmatch("[%a_][%w_]*") do
|
||||
if #word >= 3 and word:lower():find(prefix_lower, 1, true) and not seen[word] and word ~= prefix then
|
||||
seen[word] = true
|
||||
table.insert(items, {
|
||||
label = word,
|
||||
kind = 1, -- Text
|
||||
detail = "[buffer]",
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return items
|
||||
end
|
||||
|
||||
--- Try to get Copilot suggestion if plugin is installed
|
||||
---@param prefix string
|
||||
---@return string|nil suggestion
|
||||
local function get_copilot_suggestion(prefix)
|
||||
-- Try copilot.lua suggestion API first
|
||||
local ok, copilot_suggestion = pcall(require, "copilot.suggestion")
|
||||
if ok and copilot_suggestion and type(copilot_suggestion.get_suggestion) == "function" then
|
||||
local ok2, suggestion = pcall(copilot_suggestion.get_suggestion)
|
||||
if ok2 and suggestion and suggestion ~= "" then
|
||||
-- Only return if suggestion seems to start with prefix (best-effort)
|
||||
if prefix == "" or suggestion:lower():match(prefix:lower(), 1) then
|
||||
return suggestion
|
||||
else
|
||||
return suggestion
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Fallback: try older copilot module if present
|
||||
local ok3, copilot = pcall(require, "copilot")
|
||||
if ok3 and copilot and type(copilot.get_suggestion) == "function" then
|
||||
local ok4, suggestion = pcall(copilot.get_suggestion)
|
||||
if ok4 and suggestion and suggestion ~= "" then
|
||||
return suggestion
|
||||
end
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Create new cmp source instance
|
||||
function source.new()
|
||||
return setmetatable({}, { __index = source })
|
||||
end
|
||||
|
||||
--- Get source name
|
||||
function source:get_keyword_pattern()
|
||||
return [[\k\+]]
|
||||
end
|
||||
|
||||
--- Check if source is available
|
||||
function source:is_available()
|
||||
return true
|
||||
end
|
||||
|
||||
--- Get debug name
|
||||
function source:get_debug_name()
|
||||
return "codetyper"
|
||||
end
|
||||
|
||||
--- Get trigger characters
|
||||
function source:get_trigger_characters()
|
||||
return { ".", ":", "_" }
|
||||
end
|
||||
|
||||
--- Complete
|
||||
---@param params table
|
||||
---@param callback fun(response: table|nil)
|
||||
function source:complete(params, callback)
|
||||
local prefix = params.context.cursor_before_line:match("[%w_]+$") or ""
|
||||
|
||||
if #prefix < 2 then
|
||||
callback({ items = {}, isIncomplete = true })
|
||||
return
|
||||
end
|
||||
|
||||
-- Collect completions from brain, indexer, and buffer
|
||||
local items = {}
|
||||
local seen = {}
|
||||
|
||||
-- Get brain completions (highest priority)
|
||||
local ok1, brain_items = pcall(get_brain_completions, prefix)
|
||||
if ok1 and brain_items then
|
||||
for _, item in ipairs(brain_items) do
|
||||
if not seen[item.label] then
|
||||
seen[item.label] = true
|
||||
item.sortText = "1" .. item.label
|
||||
table.insert(items, item)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Get indexer completions
|
||||
local ok2, indexer_items = pcall(get_indexer_completions, prefix)
|
||||
if ok2 and indexer_items then
|
||||
for _, item in ipairs(indexer_items) do
|
||||
if not seen[item.label] then
|
||||
seen[item.label] = true
|
||||
item.sortText = "2" .. item.label
|
||||
table.insert(items, item)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Get buffer completions as fallback (lower priority)
|
||||
local bufnr = params.context.bufnr
|
||||
if bufnr then
|
||||
local ok3, buffer_items = pcall(get_buffer_completions, prefix, bufnr)
|
||||
if ok3 and buffer_items then
|
||||
for _, item in ipairs(buffer_items) do
|
||||
if not seen[item.label] then
|
||||
seen[item.label] = true
|
||||
item.sortText = "3" .. item.label
|
||||
table.insert(items, item)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- If Copilot is installed, prefer its suggestion as a top-priority completion
|
||||
local ok_cp, _ = pcall(require, "copilot")
|
||||
if ok_cp then
|
||||
local suggestion = nil
|
||||
local ok_sug, res = pcall(get_copilot_suggestion, prefix)
|
||||
if ok_sug then
|
||||
suggestion = res
|
||||
end
|
||||
if suggestion and suggestion ~= "" then
|
||||
-- Truncate suggestion to first line for label display
|
||||
local first_line = suggestion:match("([^
|
||||
]+)") or suggestion
|
||||
-- Avoid duplicates
|
||||
if not seen[first_line] then
|
||||
seen[first_line] = true
|
||||
table.insert(items, 1, {
|
||||
label = first_line,
|
||||
kind = 1,
|
||||
detail = "[copilot]",
|
||||
documentation = suggestion,
|
||||
sortText = "0" .. first_line,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
callback({
|
||||
items = items,
|
||||
isIncomplete = #items >= 50,
|
||||
})
|
||||
end
|
||||
|
||||
--- Setup the completion source
|
||||
function M.setup()
|
||||
if not has_cmp() then
|
||||
return false
|
||||
end
|
||||
|
||||
local cmp = require("cmp")
|
||||
local new_source = source.new()
|
||||
|
||||
-- Register the source
|
||||
cmp.register_source("codetyper", new_source)
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Check if source is registered
|
||||
---@return boolean
|
||||
function M.is_registered()
|
||||
local ok, cmp = pcall(require, "cmp")
|
||||
if not ok then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Try to get registered sources
|
||||
local config = cmp.get_config()
|
||||
if config and config.sources then
|
||||
for _, src in ipairs(config.sources) do
|
||||
if src.name == "codetyper" then
|
||||
return true
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Get source for manual registration
|
||||
function M.get_source()
|
||||
return source
|
||||
end
|
||||
|
||||
return M
|
||||
1492
lua/codetyper/adapters/nvim/commands.lua
Normal file
1492
lua/codetyper/adapters/nvim/commands.lua
Normal file
File diff suppressed because it is too large
Load Diff
@@ -4,9 +4,9 @@
|
||||
|
||||
local M = {}
|
||||
|
||||
local agent = require("codetyper.agent")
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local utils = require("codetyper.utils")
|
||||
local agent = require("codetyper.features.agents")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
---@class AgentUIState
|
||||
---@field chat_buf number|nil Chat buffer
|
||||
@@ -29,20 +29,69 @@ local state = {
|
||||
is_open = false,
|
||||
log_listener_id = nil,
|
||||
referenced_files = {},
|
||||
selection_context = nil, -- Visual selection passed when opening
|
||||
}
|
||||
|
||||
--- Namespace for highlights
|
||||
local ns_chat = vim.api.nvim_create_namespace("codetyper_agent_chat")
|
||||
local ns_logs = vim.api.nvim_create_namespace("codetyper_agent_logs")
|
||||
|
||||
--- Fixed widths
|
||||
local CHAT_WIDTH = 300
|
||||
local LOGS_WIDTH = 50
|
||||
--- Fixed heights
|
||||
local INPUT_HEIGHT = 5
|
||||
local LOGS_WIDTH = 50
|
||||
|
||||
--- Calculate dynamic width (1/4 of screen, minimum 30)
|
||||
---@return number
|
||||
local function get_panel_width()
|
||||
return math.max(math.floor(vim.o.columns * 0.25), 30)
|
||||
end
|
||||
|
||||
--- Autocmd group
|
||||
local agent_augroup = nil
|
||||
|
||||
--- Autocmd group for width maintenance
|
||||
local width_augroup = nil
|
||||
|
||||
--- Store target width
|
||||
local target_width = nil
|
||||
|
||||
--- Setup autocmd to always maintain 1/4 window width
|
||||
local function setup_width_autocmd()
|
||||
-- Clear previous autocmd group if exists
|
||||
if width_augroup then
|
||||
pcall(vim.api.nvim_del_augroup_by_id, width_augroup)
|
||||
end
|
||||
|
||||
width_augroup = vim.api.nvim_create_augroup("CodetypeAgentWidth", { clear = true })
|
||||
|
||||
-- Always maintain 1/4 width on any window event
|
||||
vim.api.nvim_create_autocmd({ "WinResized", "WinNew", "WinClosed", "VimResized" }, {
|
||||
group = width_augroup,
|
||||
callback = function()
|
||||
if not state.is_open or not state.chat_win then
|
||||
return
|
||||
end
|
||||
if not vim.api.nvim_win_is_valid(state.chat_win) then
|
||||
return
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
if state.chat_win and vim.api.nvim_win_is_valid(state.chat_win) then
|
||||
-- Always calculate 1/4 of current screen width
|
||||
local new_target = math.max(math.floor(vim.o.columns * 0.25), 30)
|
||||
target_width = new_target
|
||||
|
||||
local current_width = vim.api.nvim_win_get_width(state.chat_win)
|
||||
if current_width ~= target_width then
|
||||
pcall(vim.api.nvim_win_set_width, state.chat_win, target_width)
|
||||
end
|
||||
end
|
||||
end)
|
||||
end,
|
||||
desc = "Maintain Agent panel at 1/4 window width",
|
||||
})
|
||||
end
|
||||
|
||||
--- Add a log entry to the logs buffer
|
||||
---@param entry table Log entry
|
||||
local function add_log_entry(entry)
|
||||
@@ -73,7 +122,9 @@ local function add_log_entry(entry)
|
||||
local lines = vim.api.nvim_buf_get_lines(state.logs_buf, 0, -1, false)
|
||||
local line_num = #lines
|
||||
|
||||
vim.api.nvim_buf_set_lines(state.logs_buf, -1, -1, false, { formatted })
|
||||
-- Split formatted log into individual lines to avoid passing newline-containing items
|
||||
local formatted_lines = vim.split(formatted, "\n")
|
||||
vim.api.nvim_buf_set_lines(state.logs_buf, -1, -1, false, formatted_lines)
|
||||
|
||||
-- Apply highlighting based on level
|
||||
local hl_map = {
|
||||
@@ -186,8 +237,16 @@ local function create_callbacks()
|
||||
|
||||
on_complete = function()
|
||||
vim.schedule(function()
|
||||
add_message("system", "Done.", "DiagnosticHint")
|
||||
logs.info("Agent loop completed")
|
||||
local changes_count = agent.get_changes_count()
|
||||
if changes_count > 0 then
|
||||
add_message("system",
|
||||
string.format("Done. %d file(s) changed. Press <leader>d to review changes.", changes_count),
|
||||
"DiagnosticHint")
|
||||
logs.info(string.format("Agent completed with %d change(s)", changes_count))
|
||||
else
|
||||
add_message("system", "Done.", "DiagnosticHint")
|
||||
logs.info("Agent loop completed")
|
||||
end
|
||||
M.focus_input()
|
||||
end)
|
||||
end,
|
||||
@@ -255,12 +314,15 @@ local function submit_input()
|
||||
"╔═══════════════════════════════════════════════════════════════╗",
|
||||
"║ [AGENT MODE] Can read/write files ║",
|
||||
"╠═══════════════════════════════════════════════════════════════╣",
|
||||
"║ @ attach file | C-f current file | :CoderType switch mode ║",
|
||||
"║ @ attach | C-f current file | <leader>d review changes ║",
|
||||
"╚═══════════════════════════════════════════════════════════════╝",
|
||||
"",
|
||||
})
|
||||
vim.bo[state.chat_buf].modifiable = false
|
||||
end
|
||||
-- Also clear collected diffs
|
||||
local diff_review = require("codetyper.adapters.nvim.ui.diff_review")
|
||||
diff_review.clear()
|
||||
return
|
||||
end
|
||||
|
||||
@@ -269,6 +331,30 @@ local function submit_input()
|
||||
return
|
||||
end
|
||||
|
||||
if input == "/continue" then
|
||||
if agent.is_running() then
|
||||
add_message("system", "Agent is already running. Use /stop first.")
|
||||
return
|
||||
end
|
||||
|
||||
if not agent.has_saved_session() then
|
||||
add_message("system", "No saved session to continue.")
|
||||
return
|
||||
end
|
||||
|
||||
local info = agent.get_saved_session_info()
|
||||
if info then
|
||||
add_message("system", string.format("Resuming session from %s...", info.saved_at))
|
||||
logs.info(string.format("Resuming: %d messages, iteration %d", info.messages, info.iteration))
|
||||
end
|
||||
|
||||
local success = agent.continue_session(create_callbacks())
|
||||
if not success then
|
||||
add_message("system", "Failed to resume session.")
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Build file context
|
||||
local file_context = build_file_context()
|
||||
local file_count = vim.tbl_count(state.referenced_files)
|
||||
@@ -301,7 +387,7 @@ local function submit_input()
|
||||
current_file = vim.fn.expand("%:p")
|
||||
end
|
||||
|
||||
local llm = require("codetyper.llm")
|
||||
local llm = require("codetyper.core.llm")
|
||||
local context = {}
|
||||
|
||||
if current_file ~= "" and vim.fn.filereadable(current_file) == 1 then
|
||||
@@ -311,8 +397,15 @@ local function submit_input()
|
||||
|
||||
-- Append file context to input
|
||||
local full_input = input
|
||||
|
||||
-- Add selection context if present
|
||||
local selection_ctx = M.get_selection_context()
|
||||
if selection_ctx then
|
||||
full_input = full_input .. "\n\n" .. selection_ctx
|
||||
end
|
||||
|
||||
if file_context ~= "" then
|
||||
full_input = input .. "\n\nATTACHED FILES:" .. file_context
|
||||
full_input = full_input .. "\n\nATTACHED FILES:" .. file_context
|
||||
end
|
||||
|
||||
logs.thinking("Starting...")
|
||||
@@ -446,12 +539,20 @@ local function update_logs_title()
|
||||
end
|
||||
|
||||
--- Open the agent UI
|
||||
function M.open()
|
||||
---@param selection table|nil Visual selection context {text, start_line, end_line, filepath, filename, language}
|
||||
function M.open(selection)
|
||||
if state.is_open then
|
||||
-- If already open and new selection provided, add it as context
|
||||
if selection and selection.text and selection.text ~= "" then
|
||||
M.add_selection_context(selection)
|
||||
end
|
||||
M.focus_input()
|
||||
return
|
||||
end
|
||||
|
||||
-- Store selection context
|
||||
state.selection_context = selection
|
||||
|
||||
-- Clear previous state
|
||||
logs.clear()
|
||||
state.referenced_files = {}
|
||||
@@ -479,7 +580,7 @@ function M.open()
|
||||
vim.cmd("topleft vsplit")
|
||||
state.chat_win = vim.api.nvim_get_current_win()
|
||||
vim.api.nvim_win_set_buf(state.chat_win, state.chat_buf)
|
||||
vim.api.nvim_win_set_width(state.chat_win, CHAT_WIDTH)
|
||||
vim.api.nvim_win_set_width(state.chat_win, get_panel_width())
|
||||
|
||||
-- Window options for chat
|
||||
vim.wo[state.chat_win].number = false
|
||||
@@ -526,7 +627,7 @@ function M.open()
|
||||
"╔═══════════════════════════════════════════════════════════════╗",
|
||||
"║ [AGENT MODE] Can read/write files ║",
|
||||
"╠═══════════════════════════════════════════════════════════════╣",
|
||||
"║ @ attach file | C-f current file | :CoderType switch mode ║",
|
||||
"║ @ attach | C-f current file | <leader>d review changes ║",
|
||||
"╚═══════════════════════════════════════════════════════════════╝",
|
||||
"",
|
||||
})
|
||||
@@ -559,6 +660,7 @@ function M.open()
|
||||
vim.keymap.set("n", "<Tab>", M.focus_chat, input_opts)
|
||||
vim.keymap.set("n", "q", M.close, input_opts)
|
||||
vim.keymap.set("n", "<Esc>", M.close, input_opts)
|
||||
vim.keymap.set("n", "<leader>d", M.show_diff_review, input_opts)
|
||||
|
||||
-- Set up keymaps for chat buffer
|
||||
local chat_opts = { buffer = state.chat_buf, noremap = true, silent = true }
|
||||
@@ -569,6 +671,7 @@ function M.open()
|
||||
vim.keymap.set("n", "<C-f>", M.include_current_file, chat_opts)
|
||||
vim.keymap.set("n", "<Tab>", M.focus_logs, chat_opts)
|
||||
vim.keymap.set("n", "q", M.close, chat_opts)
|
||||
vim.keymap.set("n", "<leader>d", M.show_diff_review, chat_opts)
|
||||
|
||||
-- Set up keymaps for logs buffer
|
||||
local logs_opts = { buffer = state.logs_buf, noremap = true, silent = true }
|
||||
@@ -592,18 +695,51 @@ function M.open()
|
||||
end,
|
||||
})
|
||||
|
||||
-- Setup autocmd to maintain 1/4 width
|
||||
target_width = get_panel_width()
|
||||
setup_width_autocmd()
|
||||
|
||||
state.is_open = true
|
||||
|
||||
-- Focus input and log startup
|
||||
M.focus_input()
|
||||
logs.info("Agent ready")
|
||||
|
||||
-- Check for saved session and notify user
|
||||
if agent.has_saved_session() then
|
||||
vim.schedule(function()
|
||||
local info = agent.get_saved_session_info()
|
||||
if info then
|
||||
add_message("system",
|
||||
string.format("Saved session available (%s). Type /continue to resume.", info.saved_at),
|
||||
"DiagnosticHint")
|
||||
logs.info("Saved session found: " .. (info.prompt or ""):sub(1, 30) .. "...")
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
-- If we have a selection, show it as context
|
||||
if selection and selection.text and selection.text ~= "" then
|
||||
vim.schedule(function()
|
||||
M.add_selection_context(selection)
|
||||
end)
|
||||
end
|
||||
|
||||
-- Log provider info
|
||||
local ok, codetyper = pcall(require, "codetyper")
|
||||
if ok then
|
||||
local config = codetyper.get_config()
|
||||
local provider = config.llm.provider
|
||||
local model = provider == "claude" and config.llm.claude.model or config.llm.ollama.model
|
||||
local model = "unknown"
|
||||
if provider == "ollama" then
|
||||
model = config.llm.ollama.model
|
||||
elseif provider == "openai" then
|
||||
model = config.llm.openai.model
|
||||
elseif provider == "gemini" then
|
||||
model = config.llm.gemini.model
|
||||
elseif provider == "copilot" then
|
||||
model = config.llm.copilot.model
|
||||
end
|
||||
logs.info(string.format("%s (%s)", provider, model))
|
||||
end
|
||||
end
|
||||
@@ -671,4 +807,101 @@ function M.is_open()
|
||||
return state.is_open
|
||||
end
|
||||
|
||||
--- Show the diff review for all changes made in this session
|
||||
function M.show_diff_review()
|
||||
local changes_count = agent.get_changes_count()
|
||||
if changes_count == 0 then
|
||||
utils.notify("No changes to review", vim.log.levels.INFO)
|
||||
return
|
||||
end
|
||||
agent.show_diff_review()
|
||||
end
|
||||
|
||||
--- Add visual selection as context in the chat
|
||||
---@param selection table Selection info {text, start_line, end_line, filepath, filename, language}
|
||||
function M.add_selection_context(selection)
|
||||
if not state.chat_buf or not vim.api.nvim_buf_is_valid(state.chat_buf) then
|
||||
return
|
||||
end
|
||||
|
||||
state.selection_context = selection
|
||||
|
||||
vim.bo[state.chat_buf].modifiable = true
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(state.chat_buf, 0, -1, false)
|
||||
|
||||
-- Format the selection display
|
||||
local location = ""
|
||||
if selection.filename then
|
||||
location = selection.filename
|
||||
if selection.start_line then
|
||||
location = location .. ":" .. selection.start_line
|
||||
if selection.end_line and selection.end_line ~= selection.start_line then
|
||||
location = location .. "-" .. selection.end_line
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local new_lines = {
|
||||
"",
|
||||
"┌─ Selected Code ─────────────────────",
|
||||
"│ " .. location,
|
||||
"│",
|
||||
}
|
||||
|
||||
-- Add the selected code
|
||||
for _, line in ipairs(vim.split(selection.text, "\n")) do
|
||||
table.insert(new_lines, "│ " .. line)
|
||||
end
|
||||
|
||||
table.insert(new_lines, "│")
|
||||
table.insert(new_lines, "└──────────────────────────────────────")
|
||||
table.insert(new_lines, "")
|
||||
table.insert(new_lines, "Describe what you'd like to do with this code.")
|
||||
|
||||
for _, line in ipairs(new_lines) do
|
||||
table.insert(lines, line)
|
||||
end
|
||||
|
||||
vim.api.nvim_buf_set_lines(state.chat_buf, 0, -1, false, lines)
|
||||
vim.bo[state.chat_buf].modifiable = false
|
||||
|
||||
-- Scroll to bottom
|
||||
if state.chat_win and vim.api.nvim_win_is_valid(state.chat_win) then
|
||||
local line_count = vim.api.nvim_buf_line_count(state.chat_buf)
|
||||
vim.api.nvim_win_set_cursor(state.chat_win, { line_count, 0 })
|
||||
end
|
||||
|
||||
-- Also add the file to referenced_files for context
|
||||
if selection.filepath and selection.filepath ~= "" then
|
||||
state.referenced_files[selection.filename or "selection"] = selection.filepath
|
||||
end
|
||||
|
||||
logs.info("Selection added: " .. location)
|
||||
end
|
||||
|
||||
--- Get selection context for agent prompt
|
||||
---@return string|nil Selection context string
|
||||
function M.get_selection_context()
|
||||
if not state.selection_context or not state.selection_context.text then
|
||||
return nil
|
||||
end
|
||||
|
||||
local sel = state.selection_context
|
||||
local location = sel.filename or "unknown"
|
||||
if sel.start_line then
|
||||
location = location .. ":" .. sel.start_line
|
||||
if sel.end_line and sel.end_line ~= sel.start_line then
|
||||
location = location .. "-" .. sel.end_line
|
||||
end
|
||||
end
|
||||
|
||||
return string.format(
|
||||
"SELECTED CODE (%s):\n```%s\n%s\n```",
|
||||
location,
|
||||
sel.language or "",
|
||||
sel.text
|
||||
)
|
||||
end
|
||||
|
||||
return M
|
||||
381
lua/codetyper/adapters/nvim/ui/context_modal.lua
Normal file
381
lua/codetyper/adapters/nvim/ui/context_modal.lua
Normal file
@@ -0,0 +1,381 @@
|
||||
---@mod codetyper.agent.context_modal Modal for additional context input
|
||||
---@brief [[
|
||||
--- Opens a floating window for user to provide additional context
|
||||
--- when the LLM requests more information.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class ContextModalState
|
||||
---@field buf number|nil Buffer number
|
||||
---@field win number|nil Window number
|
||||
---@field original_event table|nil Original prompt event
|
||||
---@field callback function|nil Callback with additional context
|
||||
---@field llm_response string|nil LLM's response asking for context
|
||||
|
||||
local state = {
|
||||
buf = nil,
|
||||
win = nil,
|
||||
original_event = nil,
|
||||
callback = nil,
|
||||
llm_response = nil,
|
||||
attached_files = nil,
|
||||
}
|
||||
|
||||
--- Close the context modal
|
||||
function M.close()
|
||||
if state.win and vim.api.nvim_win_is_valid(state.win) then
|
||||
vim.api.nvim_win_close(state.win, true)
|
||||
end
|
||||
if state.buf and vim.api.nvim_buf_is_valid(state.buf) then
|
||||
vim.api.nvim_buf_delete(state.buf, { force = true })
|
||||
end
|
||||
state.win = nil
|
||||
state.buf = nil
|
||||
state.original_event = nil
|
||||
state.callback = nil
|
||||
state.llm_response = nil
|
||||
end
|
||||
|
||||
--- Submit the additional context
|
||||
local function submit()
|
||||
if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then
|
||||
return
|
||||
end
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(state.buf, 0, -1, false)
|
||||
local additional_context = table.concat(lines, "\n")
|
||||
|
||||
-- Trim whitespace
|
||||
additional_context = additional_context:match("^%s*(.-)%s*$") or additional_context
|
||||
|
||||
if additional_context == "" then
|
||||
M.close()
|
||||
return
|
||||
end
|
||||
|
||||
local original_event = state.original_event
|
||||
local callback = state.callback
|
||||
|
||||
M.close()
|
||||
|
||||
if callback and original_event then
|
||||
-- Pass attached_files as third optional parameter
|
||||
callback(original_event, additional_context, state.attached_files)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
--- Parse requested file paths from LLM response and resolve to full paths
|
||||
local function parse_requested_files(response)
|
||||
if not response or response == "" then
|
||||
return {}
|
||||
end
|
||||
|
||||
local cwd = vim.fn.getcwd()
|
||||
local candidates = {}
|
||||
local seen = {}
|
||||
|
||||
for path in response:gmatch("`([%w%._%-%/]+%.[%w_]+)`") do
|
||||
if not seen[path] then
|
||||
table.insert(candidates, path)
|
||||
seen[path] = true
|
||||
end
|
||||
end
|
||||
for path in response:gmatch("([%w%._%-%/]+%.[%w_]+)") do
|
||||
if not seen[path] then
|
||||
table.insert(candidates, path)
|
||||
seen[path] = true
|
||||
end
|
||||
end
|
||||
|
||||
-- Resolve to full paths using cwd and glob
|
||||
local resolved = {}
|
||||
for _, p in ipairs(candidates) do
|
||||
local full = nil
|
||||
if p:sub(1,1) == "/" and vim.fn.filereadable(p) == 1 then
|
||||
full = p
|
||||
else
|
||||
local try1 = cwd .. "/" .. p
|
||||
if vim.fn.filereadable(try1) == 1 then
|
||||
full = try1
|
||||
else
|
||||
local tail = p:match("[^/]+$") or p
|
||||
local matches = vim.fn.globpath(cwd, "**/" .. tail, false, true)
|
||||
if matches and #matches > 0 then
|
||||
full = matches[1]
|
||||
end
|
||||
end
|
||||
end
|
||||
if full and vim.fn.filereadable(full) == 1 then
|
||||
table.insert(resolved, full)
|
||||
end
|
||||
end
|
||||
return resolved
|
||||
end
|
||||
|
||||
|
||||
--- Attach parsed files into the modal buffer and remember them for submission
|
||||
local function attach_requested_files()
|
||||
if not state.llm_response or state.llm_response == "" then
|
||||
return
|
||||
end
|
||||
local files = parse_requested_files(state.llm_response)
|
||||
if #files == 0 then
|
||||
local ui_prompts = require("codetyper.prompts.agents.modal").ui
|
||||
vim.api.nvim_buf_set_lines(state.buf, vim.api.nvim_buf_line_count(state.buf), -1, false, ui_prompts.files_header)
|
||||
return
|
||||
end
|
||||
|
||||
state.attached_files = state.attached_files or {}
|
||||
|
||||
for _, full in ipairs(files) do
|
||||
local ok, lines = pcall(vim.fn.readfile, full)
|
||||
if ok and lines and #lines > 0 then
|
||||
table.insert(state.attached_files, { path = vim.fn.fnamemodify(full, ":~:." ) , full_path = full, content = table.concat(lines, "\n") })
|
||||
local insert_at = vim.api.nvim_buf_line_count(state.buf)
|
||||
vim.api.nvim_buf_set_lines(state.buf, insert_at, insert_at, false, { "", "-- Attached: " .. full .. " --" })
|
||||
for i, l in ipairs(lines) do
|
||||
vim.api.nvim_buf_set_lines(state.buf, insert_at + 1 + i, insert_at + 1 + i, false, { l })
|
||||
end
|
||||
else
|
||||
local insert_at = vim.api.nvim_buf_line_count(state.buf)
|
||||
vim.api.nvim_buf_set_lines(state.buf, insert_at, insert_at, false, { "", "-- Failed to read: " .. full .. " --" })
|
||||
end
|
||||
end
|
||||
-- Move cursor to end and enter insert mode
|
||||
vim.api.nvim_win_set_cursor(state.win, { vim.api.nvim_buf_line_count(state.buf), 0 })
|
||||
vim.cmd("startinsert")
|
||||
end
|
||||
|
||||
--- Open the context modal
|
||||
---@param original_event table Original prompt event
|
||||
---@param llm_response string LLM's response asking for context
|
||||
---@param callback function(event: table, additional_context: string, attached_files?: table)
|
||||
---@param suggested_commands table[]|nil Optional list of {label,cmd} suggested shell commands
|
||||
function M.open(original_event, llm_response, callback, suggested_commands)
|
||||
-- Close any existing modal
|
||||
M.close()
|
||||
|
||||
state.original_event = original_event
|
||||
state.llm_response = llm_response
|
||||
state.callback = callback
|
||||
|
||||
-- Calculate window size
|
||||
local width = math.min(80, vim.o.columns - 10)
|
||||
local height = 10
|
||||
|
||||
-- Create buffer
|
||||
state.buf = vim.api.nvim_create_buf(false, true)
|
||||
vim.bo[state.buf].buftype = "nofile"
|
||||
vim.bo[state.buf].bufhidden = "wipe"
|
||||
vim.bo[state.buf].filetype = "markdown"
|
||||
|
||||
-- Create window
|
||||
local row = math.floor((vim.o.lines - height) / 2)
|
||||
local col = math.floor((vim.o.columns - width) / 2)
|
||||
|
||||
state.win = vim.api.nvim_open_win(state.buf, true, {
|
||||
relative = "editor",
|
||||
row = row,
|
||||
col = col,
|
||||
width = width,
|
||||
height = height,
|
||||
style = "minimal",
|
||||
border = "rounded",
|
||||
title = " Additional Context Needed ",
|
||||
title_pos = "center",
|
||||
})
|
||||
|
||||
-- Set window options
|
||||
vim.wo[state.win].wrap = true
|
||||
vim.wo[state.win].cursorline = true
|
||||
|
||||
local ui_prompts = require("codetyper.prompts.agents.modal").ui
|
||||
|
||||
-- Add header showing what the LLM said
|
||||
local header_lines = {
|
||||
ui_prompts.llm_response_header,
|
||||
}
|
||||
|
||||
-- Truncate LLM response for display
|
||||
local response_preview = llm_response or ""
|
||||
if #response_preview > 200 then
|
||||
response_preview = response_preview:sub(1, 200) .. "..."
|
||||
end
|
||||
for line in response_preview:gmatch("[^\n]+") do
|
||||
table.insert(header_lines, "-- " .. line)
|
||||
end
|
||||
|
||||
-- If suggested commands were provided, show them in the header
|
||||
if suggested_commands and #suggested_commands > 0 then
|
||||
table.insert(header_lines, "")
|
||||
table.insert(header_lines, ui_prompts.suggested_commands_header)
|
||||
for i, s in ipairs(suggested_commands) do
|
||||
local label = s.label or s.cmd
|
||||
table.insert(header_lines, string.format("[%d] %s: %s", i, label, s.cmd))
|
||||
end
|
||||
table.insert(header_lines, ui_prompts.commands_hint)
|
||||
end
|
||||
|
||||
table.insert(header_lines, "")
|
||||
table.insert(header_lines, ui_prompts.input_header)
|
||||
table.insert(header_lines, "")
|
||||
|
||||
vim.api.nvim_buf_set_lines(state.buf, 0, -1, false, header_lines)
|
||||
|
||||
-- Move cursor to the end
|
||||
vim.api.nvim_win_set_cursor(state.win, { #header_lines, 0 })
|
||||
|
||||
-- Set up keymaps
|
||||
local opts = { buffer = state.buf, noremap = true, silent = true }
|
||||
|
||||
-- Submit with Ctrl+Enter or <leader>s
|
||||
vim.keymap.set("n", "<C-CR>", submit, opts)
|
||||
vim.keymap.set("i", "<C-CR>", submit, opts)
|
||||
vim.keymap.set("n", "<leader>s", submit, opts)
|
||||
vim.keymap.set("n", "<CR><CR>", submit, opts)
|
||||
|
||||
-- Attach parsed files (from LLM response)
|
||||
vim.keymap.set("n", "a", function()
|
||||
attach_requested_files()
|
||||
end, opts)
|
||||
|
||||
-- Confirm and submit with 'c' (convenient when doing question round)
|
||||
vim.keymap.set("n", "c", submit, opts)
|
||||
|
||||
-- Quick run of project inspection from modal with <leader>r / <C-r> in insert mode
|
||||
vim.keymap.set("n", "<leader>r", run_project_inspect, opts)
|
||||
vim.keymap.set("i", "<C-r>", function()
|
||||
vim.schedule(run_project_inspect)
|
||||
end, { buffer = state.buf, noremap = true, silent = true })
|
||||
|
||||
-- If suggested commands provided, create per-command keymaps <leader>1..n to run them
|
||||
state.suggested_commands = suggested_commands
|
||||
if suggested_commands and #suggested_commands > 0 then
|
||||
for i, s in ipairs(suggested_commands) do
|
||||
local key = "<leader>" .. tostring(i)
|
||||
vim.keymap.set("n", key, function()
|
||||
-- run this single command and append output
|
||||
if not s or not s.cmd then
|
||||
return
|
||||
end
|
||||
local ok, out = pcall(vim.fn.systemlist, s.cmd)
|
||||
local insert_at = vim.api.nvim_buf_line_count(state.buf)
|
||||
vim.api.nvim_buf_set_lines(state.buf, insert_at, insert_at, false, { "", "-- Output: " .. s.cmd .. " --" })
|
||||
if ok and out and #out > 0 then
|
||||
for j, line in ipairs(out) do
|
||||
vim.api.nvim_buf_set_lines(state.buf, insert_at + j, insert_at + j, false, { line })
|
||||
end
|
||||
else
|
||||
vim.api.nvim_buf_set_lines(state.buf, insert_at + 1, insert_at + 1, false, { "(no output or command failed)" })
|
||||
end
|
||||
vim.api.nvim_win_set_cursor(state.win, { vim.api.nvim_buf_line_count(state.buf), 0 })
|
||||
vim.cmd("startinsert")
|
||||
end, opts)
|
||||
end
|
||||
-- Also map <leader>0 to run all suggested commands
|
||||
vim.keymap.set("n", "<leader>0", function()
|
||||
for _, s in ipairs(suggested_commands) do
|
||||
pcall(function()
|
||||
local ok, out = pcall(vim.fn.systemlist, s.cmd)
|
||||
local insert_at = vim.api.nvim_buf_line_count(state.buf)
|
||||
vim.api.nvim_buf_set_lines(state.buf, insert_at, insert_at, false, { "", "-- Output: " .. s.cmd .. " --" })
|
||||
if ok and out and #out > 0 then
|
||||
for j, line in ipairs(out) do
|
||||
vim.api.nvim_buf_set_lines(state.buf, insert_at + j, insert_at + j, false, { line })
|
||||
end
|
||||
else
|
||||
vim.api.nvim_buf_set_lines(state.buf, insert_at + 1, insert_at + 1, false, { "(no output or command failed)" })
|
||||
end
|
||||
end)
|
||||
end
|
||||
vim.api.nvim_win_set_cursor(state.win, { vim.api.nvim_buf_line_count(state.buf), 0 })
|
||||
vim.cmd("startinsert")
|
||||
end, opts)
|
||||
end
|
||||
|
||||
-- Close with Esc or q
|
||||
vim.keymap.set("n", "<Esc>", M.close, opts)
|
||||
vim.keymap.set("n", "q", M.close, opts)
|
||||
|
||||
-- Start in insert mode
|
||||
vim.cmd("startinsert")
|
||||
|
||||
-- Log
|
||||
pcall(function()
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = "Context modal opened - waiting for user input",
|
||||
})
|
||||
end)
|
||||
end
|
||||
|
||||
--- Run a small set of safe project inspection commands and insert outputs into the modal buffer
|
||||
local function run_project_inspect()
|
||||
if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then
|
||||
return
|
||||
end
|
||||
|
||||
local cmds = {
|
||||
{ label = "List files (ls -la)", cmd = "ls -la" },
|
||||
{ label = "Git status (git status --porcelain)", cmd = "git status --porcelain" },
|
||||
{ label = "Git top (git rev-parse --show-toplevel)", cmd = "git rev-parse --show-toplevel" },
|
||||
{ label = "Show repo files (git ls-files)", cmd = "git ls-files" },
|
||||
}
|
||||
|
||||
local ui_prompts = require("codetyper.prompts.agents.modal").ui
|
||||
local insert_pos = vim.api.nvim_buf_line_count(state.buf)
|
||||
vim.api.nvim_buf_set_lines(state.buf, insert_pos, insert_pos, false, ui_prompts.project_inspect_header)
|
||||
|
||||
for _, c in ipairs(cmds) do
|
||||
local ok, out = pcall(vim.fn.systemlist, c.cmd)
|
||||
if ok and out and #out > 0 then
|
||||
vim.api.nvim_buf_set_lines(state.buf, insert_pos + 2, insert_pos + 2, false, { "-- " .. c.label .. " --" })
|
||||
for i, line in ipairs(out) do
|
||||
vim.api.nvim_buf_set_lines(state.buf, insert_pos + 2 + i, insert_pos + 2 + i, false, { line })
|
||||
end
|
||||
insert_pos = vim.api.nvim_buf_line_count(state.buf)
|
||||
else
|
||||
vim.api.nvim_buf_set_lines(state.buf, insert_pos + 2, insert_pos + 2, false, { "-- " .. c.label .. " --", "(no output or command failed)" })
|
||||
insert_pos = vim.api.nvim_buf_line_count(state.buf)
|
||||
end
|
||||
end
|
||||
|
||||
-- Move cursor to end
|
||||
vim.api.nvim_win_set_cursor(state.win, { vim.api.nvim_buf_line_count(state.buf), 0 })
|
||||
vim.cmd("startinsert")
|
||||
end
|
||||
|
||||
-- Provide a keybinding in the modal to run project inspection commands
|
||||
pcall(function()
|
||||
if state.buf and vim.api.nvim_buf_is_valid(state.buf) then
|
||||
vim.keymap.set("n", "<leader>r", run_project_inspect, { buffer = state.buf, noremap = true, silent = true })
|
||||
vim.keymap.set("i", "<C-r>", function()
|
||||
vim.schedule(run_project_inspect)
|
||||
end, { buffer = state.buf, noremap = true, silent = true })
|
||||
end
|
||||
end)
|
||||
|
||||
--- Check if modal is open
|
||||
---@return boolean
|
||||
function M.is_open()
|
||||
return state.win ~= nil and vim.api.nvim_win_is_valid(state.win)
|
||||
end
|
||||
|
||||
--- Setup autocmds for the context modal
|
||||
function M.setup()
|
||||
local group = vim.api.nvim_create_augroup("CodetypeContextModal", { clear = true })
|
||||
|
||||
-- Close context modal when exiting Neovim
|
||||
vim.api.nvim_create_autocmd("VimLeavePre", {
|
||||
group = group,
|
||||
callback = function()
|
||||
M.close()
|
||||
end,
|
||||
desc = "Close context modal before exiting Neovim",
|
||||
})
|
||||
end
|
||||
|
||||
return M
|
||||
386
lua/codetyper/adapters/nvim/ui/diff_review.lua
Normal file
386
lua/codetyper/adapters/nvim/ui/diff_review.lua
Normal file
@@ -0,0 +1,386 @@
|
||||
---@mod codetyper.agent.diff_review Diff review UI for agent changes
|
||||
---
|
||||
--- Provides a lazygit-style window interface for reviewing all changes
|
||||
--- made during an agent session.
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
local prompts = require("codetyper.prompts.agents.diff")
|
||||
|
||||
|
||||
---@class DiffEntry
|
||||
---@field path string File path
|
||||
---@field operation string "create"|"edit"|"delete"
|
||||
---@field original string|nil Original content (nil for new files)
|
||||
---@field modified string New/modified content
|
||||
---@field approved boolean Whether change was approved
|
||||
---@field applied boolean Whether change was applied
|
||||
|
||||
---@class DiffReviewState
|
||||
---@field entries DiffEntry[] List of changes
|
||||
---@field current_index number Currently selected entry
|
||||
---@field list_buf number|nil File list buffer
|
||||
---@field list_win number|nil File list window
|
||||
---@field diff_buf number|nil Diff view buffer
|
||||
---@field diff_win number|nil Diff view window
|
||||
---@field is_open boolean Whether review UI is open
|
||||
|
||||
local state = {
|
||||
entries = {},
|
||||
current_index = 1,
|
||||
list_buf = nil,
|
||||
list_win = nil,
|
||||
diff_buf = nil,
|
||||
diff_win = nil,
|
||||
is_open = false,
|
||||
}
|
||||
|
||||
--- Clear all collected diffs
|
||||
function M.clear()
|
||||
state.entries = {}
|
||||
state.current_index = 1
|
||||
end
|
||||
|
||||
--- Add a diff entry
|
||||
---@param entry DiffEntry
|
||||
function M.add(entry)
|
||||
table.insert(state.entries, entry)
|
||||
end
|
||||
|
||||
--- Get all entries
|
||||
---@return DiffEntry[]
|
||||
function M.get_entries()
|
||||
return state.entries
|
||||
end
|
||||
|
||||
--- Get entry count
|
||||
---@return number
|
||||
function M.count()
|
||||
return #state.entries
|
||||
end
|
||||
|
||||
--- Generate unified diff between two strings
|
||||
---@param original string|nil
|
||||
---@param modified string
|
||||
---@param filepath string
|
||||
---@return string[]
|
||||
local function generate_diff_lines(original, modified, filepath)
|
||||
local lines = {}
|
||||
local filename = vim.fn.fnamemodify(filepath, ":t")
|
||||
|
||||
if not original then
|
||||
-- New file
|
||||
table.insert(lines, "--- /dev/null")
|
||||
table.insert(lines, "+++ b/" .. filename)
|
||||
table.insert(lines, "@@ -0,0 +1," .. #vim.split(modified, "\n") .. " @@")
|
||||
for _, line in ipairs(vim.split(modified, "\n")) do
|
||||
table.insert(lines, "+" .. line)
|
||||
end
|
||||
else
|
||||
-- Modified file - use vim's diff
|
||||
table.insert(lines, "--- a/" .. filename)
|
||||
table.insert(lines, "+++ b/" .. filename)
|
||||
|
||||
local orig_lines = vim.split(original, "\n")
|
||||
local mod_lines = vim.split(modified, "\n")
|
||||
|
||||
-- Simple diff: show removed and added lines
|
||||
local max_lines = math.max(#orig_lines, #mod_lines)
|
||||
local context_start = 1
|
||||
local in_change = false
|
||||
|
||||
for i = 1, max_lines do
|
||||
local orig = orig_lines[i] or ""
|
||||
local mod = mod_lines[i] or ""
|
||||
|
||||
if orig ~= mod then
|
||||
if not in_change then
|
||||
table.insert(lines, string.format("@@ -%d,%d +%d,%d @@",
|
||||
math.max(1, i - 2), math.min(5, #orig_lines - i + 3),
|
||||
math.max(1, i - 2), math.min(5, #mod_lines - i + 3)))
|
||||
in_change = true
|
||||
end
|
||||
if orig ~= "" then
|
||||
table.insert(lines, "-" .. orig)
|
||||
end
|
||||
if mod ~= "" then
|
||||
table.insert(lines, "+" .. mod)
|
||||
end
|
||||
else
|
||||
if in_change then
|
||||
table.insert(lines, " " .. orig)
|
||||
in_change = false
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return lines
|
||||
end
|
||||
|
||||
--- Update the diff view for current entry
|
||||
local function update_diff_view()
|
||||
if not state.diff_buf or not vim.api.nvim_buf_is_valid(state.diff_buf) then
|
||||
return
|
||||
end
|
||||
|
||||
local entry = state.entries[state.current_index]
|
||||
local ui_prompts = prompts.review
|
||||
if not entry then
|
||||
vim.bo[state.diff_buf].modifiable = true
|
||||
vim.api.nvim_buf_set_lines(state.diff_buf, 0, -1, false, { ui_prompts.messages.no_changes_short })
|
||||
vim.bo[state.diff_buf].modifiable = false
|
||||
return
|
||||
end
|
||||
|
||||
local lines = {}
|
||||
|
||||
-- Header
|
||||
local status_icon = entry.applied and " " or (entry.approved and " " or " ")
|
||||
local op_icon = entry.operation == "create" and "+" or (entry.operation == "delete" and "-" or "~")
|
||||
local current_status = entry.applied and ui_prompts.status.applied
|
||||
or (entry.approved and ui_prompts.status.approved or ui_prompts.status.pending)
|
||||
|
||||
table.insert(lines, string.format(ui_prompts.diff_header.top,
|
||||
status_icon, op_icon, vim.fn.fnamemodify(entry.path, ":t")))
|
||||
table.insert(lines, string.format(ui_prompts.diff_header.path, entry.path))
|
||||
table.insert(lines, string.format(ui_prompts.diff_header.op, entry.operation))
|
||||
table.insert(lines, string.format(ui_prompts.diff_header.status, current_status))
|
||||
table.insert(lines, ui_prompts.diff_header.bottom)
|
||||
table.insert(lines, "")
|
||||
|
||||
-- Diff content
|
||||
local diff_lines = generate_diff_lines(entry.original, entry.modified, entry.path)
|
||||
for _, line in ipairs(diff_lines) do
|
||||
table.insert(lines, line)
|
||||
end
|
||||
|
||||
vim.bo[state.diff_buf].modifiable = true
|
||||
vim.api.nvim_buf_set_lines(state.diff_buf, 0, -1, false, lines)
|
||||
vim.bo[state.diff_buf].modifiable = false
|
||||
vim.bo[state.diff_buf].filetype = "diff"
|
||||
end
|
||||
|
||||
--- Update the file list
|
||||
local function update_file_list()
|
||||
if not state.list_buf or not vim.api.nvim_buf_is_valid(state.list_buf) then
|
||||
return
|
||||
end
|
||||
|
||||
local ui_prompts = prompts.review
|
||||
local lines = {}
|
||||
table.insert(lines, string.format(ui_prompts.list_menu.top, #state.entries))
|
||||
for _, item in ipairs(ui_prompts.list_menu.items) do
|
||||
table.insert(lines, item)
|
||||
end
|
||||
table.insert(lines, ui_prompts.list_menu.bottom)
|
||||
table.insert(lines, "")
|
||||
|
||||
for i, entry in ipairs(state.entries) do
|
||||
local prefix = (i == state.current_index) and "▶ " or " "
|
||||
local status = entry.applied and "" or (entry.approved and "" or "○")
|
||||
local op = entry.operation == "create" and "[+]" or (entry.operation == "delete" and "[-]" or "[~]")
|
||||
local filename = vim.fn.fnamemodify(entry.path, ":t")
|
||||
|
||||
table.insert(lines, string.format("%s%s %s %s", prefix, status, op, filename))
|
||||
end
|
||||
|
||||
if #state.entries == 0 then
|
||||
table.insert(lines, ui_prompts.messages.no_changes)
|
||||
end
|
||||
|
||||
vim.bo[state.list_buf].modifiable = true
|
||||
vim.api.nvim_buf_set_lines(state.list_buf, 0, -1, false, lines)
|
||||
vim.bo[state.list_buf].modifiable = false
|
||||
|
||||
-- Highlight current line
|
||||
if state.list_win and vim.api.nvim_win_is_valid(state.list_win) then
|
||||
local target_line = 9 + state.current_index - 1
|
||||
if target_line <= vim.api.nvim_buf_line_count(state.list_buf) then
|
||||
vim.api.nvim_win_set_cursor(state.list_win, { target_line, 0 })
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Navigate to next entry
|
||||
function M.next()
|
||||
if state.current_index < #state.entries then
|
||||
state.current_index = state.current_index + 1
|
||||
update_file_list()
|
||||
update_diff_view()
|
||||
end
|
||||
end
|
||||
|
||||
--- Navigate to previous entry
|
||||
function M.prev()
|
||||
if state.current_index > 1 then
|
||||
state.current_index = state.current_index - 1
|
||||
update_file_list()
|
||||
update_diff_view()
|
||||
end
|
||||
end
|
||||
|
||||
--- Approve current entry
|
||||
function M.approve_current()
|
||||
local entry = state.entries[state.current_index]
|
||||
if entry and not entry.applied then
|
||||
entry.approved = true
|
||||
update_file_list()
|
||||
update_diff_view()
|
||||
end
|
||||
end
|
||||
|
||||
--- Reject current entry
|
||||
function M.reject_current()
|
||||
local entry = state.entries[state.current_index]
|
||||
if entry and not entry.applied then
|
||||
entry.approved = false
|
||||
update_file_list()
|
||||
update_diff_view()
|
||||
end
|
||||
end
|
||||
|
||||
--- Approve all entries
|
||||
function M.approve_all()
|
||||
for _, entry in ipairs(state.entries) do
|
||||
if not entry.applied then
|
||||
entry.approved = true
|
||||
end
|
||||
end
|
||||
update_file_list()
|
||||
update_diff_view()
|
||||
end
|
||||
|
||||
--- Apply approved changes
|
||||
function M.apply_approved()
|
||||
local applied_count = 0
|
||||
|
||||
for _, entry in ipairs(state.entries) do
|
||||
if entry.approved and not entry.applied then
|
||||
if entry.operation == "create" or entry.operation == "edit" then
|
||||
local ok = utils.write_file(entry.path, entry.modified)
|
||||
if ok then
|
||||
entry.applied = true
|
||||
applied_count = applied_count + 1
|
||||
end
|
||||
elseif entry.operation == "delete" then
|
||||
local ok = os.remove(entry.path)
|
||||
if ok then
|
||||
entry.applied = true
|
||||
applied_count = applied_count + 1
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
update_file_list()
|
||||
update_diff_view()
|
||||
|
||||
if applied_count > 0 then
|
||||
utils.notify(string.format(prompts.review.messages.applied_count, applied_count))
|
||||
end
|
||||
|
||||
return applied_count
|
||||
end
|
||||
|
||||
--- Open the diff review UI
|
||||
function M.open()
|
||||
if state.is_open then
|
||||
return
|
||||
end
|
||||
|
||||
if #state.entries == 0 then
|
||||
utils.notify(prompts.review.messages.no_changes_short, vim.log.levels.INFO)
|
||||
return
|
||||
end
|
||||
|
||||
-- Create list buffer
|
||||
state.list_buf = vim.api.nvim_create_buf(false, true)
|
||||
vim.bo[state.list_buf].buftype = "nofile"
|
||||
vim.bo[state.list_buf].bufhidden = "wipe"
|
||||
vim.bo[state.list_buf].swapfile = false
|
||||
|
||||
-- Create diff buffer
|
||||
state.diff_buf = vim.api.nvim_create_buf(false, true)
|
||||
vim.bo[state.diff_buf].buftype = "nofile"
|
||||
vim.bo[state.diff_buf].bufhidden = "wipe"
|
||||
vim.bo[state.diff_buf].swapfile = false
|
||||
|
||||
-- Create layout: list on left (30 cols), diff on right
|
||||
vim.cmd("tabnew")
|
||||
state.diff_win = vim.api.nvim_get_current_win()
|
||||
vim.api.nvim_win_set_buf(state.diff_win, state.diff_buf)
|
||||
|
||||
vim.cmd("topleft vsplit")
|
||||
state.list_win = vim.api.nvim_get_current_win()
|
||||
vim.api.nvim_win_set_buf(state.list_win, state.list_buf)
|
||||
vim.api.nvim_win_set_width(state.list_win, 35)
|
||||
|
||||
-- Window options
|
||||
for _, win in ipairs({ state.list_win, state.diff_win }) do
|
||||
vim.wo[win].number = false
|
||||
vim.wo[win].relativenumber = false
|
||||
vim.wo[win].signcolumn = "no"
|
||||
vim.wo[win].wrap = false
|
||||
vim.wo[win].cursorline = true
|
||||
end
|
||||
|
||||
-- Set up keymaps for list buffer
|
||||
local list_opts = { buffer = state.list_buf, noremap = true, silent = true }
|
||||
vim.keymap.set("n", "j", M.next, list_opts)
|
||||
vim.keymap.set("n", "k", M.prev, list_opts)
|
||||
vim.keymap.set("n", "<Down>", M.next, list_opts)
|
||||
vim.keymap.set("n", "<Up>", M.prev, list_opts)
|
||||
vim.keymap.set("n", "<CR>", function() vim.api.nvim_set_current_win(state.diff_win) end, list_opts)
|
||||
vim.keymap.set("n", "a", M.approve_current, list_opts)
|
||||
vim.keymap.set("n", "r", M.reject_current, list_opts)
|
||||
vim.keymap.set("n", "A", M.approve_all, list_opts)
|
||||
vim.keymap.set("n", "q", M.close, list_opts)
|
||||
vim.keymap.set("n", "<Esc>", M.close, list_opts)
|
||||
|
||||
-- Set up keymaps for diff buffer
|
||||
local diff_opts = { buffer = state.diff_buf, noremap = true, silent = true }
|
||||
vim.keymap.set("n", "j", M.next, diff_opts)
|
||||
vim.keymap.set("n", "k", M.prev, diff_opts)
|
||||
vim.keymap.set("n", "<Tab>", function() vim.api.nvim_set_current_win(state.list_win) end, diff_opts)
|
||||
vim.keymap.set("n", "a", M.approve_current, diff_opts)
|
||||
vim.keymap.set("n", "r", M.reject_current, diff_opts)
|
||||
vim.keymap.set("n", "A", M.approve_all, diff_opts)
|
||||
vim.keymap.set("n", "q", M.close, diff_opts)
|
||||
vim.keymap.set("n", "<Esc>", M.close, diff_opts)
|
||||
|
||||
state.is_open = true
|
||||
state.current_index = 1
|
||||
|
||||
-- Initial render
|
||||
update_file_list()
|
||||
update_diff_view()
|
||||
|
||||
-- Focus list window
|
||||
vim.api.nvim_set_current_win(state.list_win)
|
||||
end
|
||||
|
||||
--- Close the diff review UI
|
||||
function M.close()
|
||||
if not state.is_open then
|
||||
return
|
||||
end
|
||||
|
||||
-- Close the tab (which closes both windows)
|
||||
pcall(vim.cmd, "tabclose")
|
||||
|
||||
state.list_buf = nil
|
||||
state.list_win = nil
|
||||
state.diff_buf = nil
|
||||
state.diff_win = nil
|
||||
state.is_open = false
|
||||
end
|
||||
|
||||
--- Check if review UI is open
|
||||
---@return boolean
|
||||
function M.is_open()
|
||||
return state.is_open
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -4,6 +4,9 @@
|
||||
|
||||
local M = {}
|
||||
|
||||
local params = require("codetyper.params.agents.logs")
|
||||
|
||||
|
||||
---@class LogEntry
|
||||
---@field timestamp string ISO timestamp
|
||||
---@field level string "info" | "debug" | "request" | "response" | "tool" | "error"
|
||||
@@ -119,14 +122,7 @@ end
|
||||
---@param status string "start" | "success" | "error" | "approval"
|
||||
---@param details? string Additional details
|
||||
function M.tool(tool_name, status, details)
|
||||
local icons = {
|
||||
start = "->",
|
||||
success = "OK",
|
||||
error = "ERR",
|
||||
approval = "??",
|
||||
approved = "YES",
|
||||
rejected = "NO",
|
||||
}
|
||||
local icons = params.icons
|
||||
|
||||
local msg = string.format("[%s] %s", icons[status] or status, tool_name)
|
||||
if details then
|
||||
@@ -165,10 +161,83 @@ function M.add(entry)
|
||||
M.log(entry.type or "info", entry.message or "", entry.data)
|
||||
end
|
||||
|
||||
--- Log thinking/reasoning step
|
||||
--- Log thinking/reasoning step (Claude Code style)
|
||||
---@param step string Description of what's happening
|
||||
function M.thinking(step)
|
||||
M.log("debug", "> " .. step)
|
||||
M.log("thinking", step)
|
||||
end
|
||||
|
||||
--- Log a reasoning/explanation message (shown prominently)
|
||||
---@param message string The reasoning message
|
||||
function M.reason(message)
|
||||
M.log("reason", message)
|
||||
end
|
||||
|
||||
--- Log file read operation
|
||||
---@param filepath string Path of file being read
|
||||
---@param lines? number Number of lines read
|
||||
function M.read(filepath, lines)
|
||||
local msg = string.format("Read(%s)", vim.fn.fnamemodify(filepath, ":~:."))
|
||||
if lines then
|
||||
msg = msg .. string.format("\n ⎿ Read %d lines", lines)
|
||||
end
|
||||
M.log("action", msg)
|
||||
end
|
||||
|
||||
--- Log explore/search operation
|
||||
---@param description string What we're exploring
|
||||
function M.explore(description)
|
||||
M.log("action", string.format("Explore(%s)", description))
|
||||
end
|
||||
|
||||
--- Log explore done
|
||||
---@param tool_uses number Number of tool uses
|
||||
---@param tokens number Tokens used
|
||||
---@param duration number Duration in seconds
|
||||
function M.explore_done(tool_uses, tokens, duration)
|
||||
M.log("result", string.format(" ⎿ Done (%d tool uses · %.1fk tokens · %.1fs)", tool_uses, tokens / 1000, duration))
|
||||
end
|
||||
|
||||
--- Log update/edit operation
|
||||
---@param filepath string Path of file being edited
|
||||
---@param added? number Lines added
|
||||
---@param removed? number Lines removed
|
||||
function M.update(filepath, added, removed)
|
||||
local msg = string.format("Update(%s)", vim.fn.fnamemodify(filepath, ":~:."))
|
||||
if added or removed then
|
||||
local parts = {}
|
||||
if added and added > 0 then
|
||||
table.insert(parts, string.format("Added %d lines", added))
|
||||
end
|
||||
if removed and removed > 0 then
|
||||
table.insert(parts, string.format("Removed %d lines", removed))
|
||||
end
|
||||
if #parts > 0 then
|
||||
msg = msg .. "\n ⎿ " .. table.concat(parts, ", ")
|
||||
end
|
||||
end
|
||||
M.log("action", msg)
|
||||
end
|
||||
|
||||
--- Log a task/step that's in progress
|
||||
---@param task string Task name
|
||||
---@param status string Status message (optional)
|
||||
function M.task(task, status)
|
||||
local msg = task
|
||||
if status then
|
||||
msg = msg .. " " .. status
|
||||
end
|
||||
M.log("task", msg)
|
||||
end
|
||||
|
||||
--- Log task completion
|
||||
---@param next_task? string Next task (optional)
|
||||
function M.task_done(next_task)
|
||||
local msg = " ⎿ Done"
|
||||
if next_task then
|
||||
msg = msg .. "\n✶ " .. next_task
|
||||
end
|
||||
M.log("result", msg)
|
||||
end
|
||||
|
||||
--- Register a listener for new log entries
|
||||
@@ -223,18 +292,22 @@ end
|
||||
---@param entry LogEntry
|
||||
---@return string
|
||||
function M.format_entry(entry)
|
||||
local level_prefix = ({
|
||||
info = "i",
|
||||
debug = ".",
|
||||
request = ">",
|
||||
response = "<",
|
||||
tool = "T",
|
||||
error = "!",
|
||||
warning = "?",
|
||||
success = "i",
|
||||
queue = "Q",
|
||||
patch = "P",
|
||||
})[entry.level] or "?"
|
||||
-- Claude Code style formatting for thinking/action entries
|
||||
local thinking_types = params.thinking_types
|
||||
local is_thinking = vim.tbl_contains(thinking_types, entry.level)
|
||||
|
||||
if is_thinking then
|
||||
local prefix = params.thinking_prefixes[entry.level] or "⏺"
|
||||
|
||||
if prefix ~= "" then
|
||||
return prefix .. " " .. entry.message
|
||||
else
|
||||
return entry.message
|
||||
end
|
||||
end
|
||||
|
||||
-- Traditional log format for other types
|
||||
local level_prefix = params.level_icons[entry.level] or "?"
|
||||
|
||||
local base = string.format("[%s] %s %s", entry.timestamp, level_prefix, entry.message)
|
||||
|
||||
@@ -248,6 +321,54 @@ function M.format_entry(entry)
|
||||
return base
|
||||
end
|
||||
|
||||
--- Format entry for display in chat (compact Claude Code style)
|
||||
---@param entry LogEntry
|
||||
---@return string|nil Formatted string or nil to skip
|
||||
function M.format_for_chat(entry)
|
||||
-- Skip certain log types in chat view
|
||||
local skip_types = { "debug", "queue", "patch" }
|
||||
if vim.tbl_contains(skip_types, entry.level) then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Claude Code style formatting
|
||||
local thinking_types = params.thinking_types
|
||||
if vim.tbl_contains(thinking_types, entry.level) then
|
||||
local prefix = params.thinking_prefixes[entry.level] or "⏺"
|
||||
|
||||
if prefix ~= "" then
|
||||
return prefix .. " " .. entry.message
|
||||
else
|
||||
return entry.message
|
||||
end
|
||||
end
|
||||
|
||||
-- Tool logs
|
||||
if entry.level == "tool" then
|
||||
return "⏺ " .. entry.message:gsub("^%[.-%] ", "")
|
||||
end
|
||||
|
||||
-- Info/success
|
||||
if entry.level == "info" or entry.level == "success" then
|
||||
return "⏺ " .. entry.message
|
||||
end
|
||||
|
||||
-- Errors
|
||||
if entry.level == "error" then
|
||||
return "⚠ " .. entry.message
|
||||
end
|
||||
|
||||
-- Request/response (compact)
|
||||
if entry.level == "request" then
|
||||
return "⏺ " .. entry.message
|
||||
end
|
||||
if entry.level == "response" then
|
||||
return " ⎿ " .. entry.message
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Estimate token count for a string (rough approximation)
|
||||
---@param text string
|
||||
---@return number
|
||||
@@ -4,8 +4,8 @@
|
||||
|
||||
local M = {}
|
||||
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local queue = require("codetyper.agent.queue")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
local queue = require("codetyper.core.events.queue")
|
||||
|
||||
---@class LogsPanelState
|
||||
---@field buf number|nil Logs buffer
|
||||
@@ -20,8 +20,8 @@ function M.show()
|
||||
end
|
||||
|
||||
-- Close current panel first
|
||||
local ask = require("codetyper.ask")
|
||||
local agent_ui = require("codetyper.agent.ui")
|
||||
local ask = require("codetyper.features.ask.engine")
|
||||
local agent_ui = require("codetyper.adapters.nvim.ui.chat")
|
||||
|
||||
if ask.is_open() then
|
||||
ask.close()
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
---@type number|nil Current coder window ID
|
||||
M._coder_win = nil
|
||||
@@ -18,14 +18,14 @@ M._target_buf = nil
|
||||
|
||||
--- Calculate window width based on configuration
|
||||
---@param config CoderConfig Plugin configuration
|
||||
---@return number Width in columns
|
||||
---@return number Width in columns (minimum 30)
|
||||
local function calculate_width(config)
|
||||
local width = config.window.width
|
||||
if width <= 1 then
|
||||
-- Percentage of total width
|
||||
return math.floor(vim.o.columns * width)
|
||||
-- Percentage of total width (1/4 of screen with minimum 30)
|
||||
return math.max(math.floor(vim.o.columns * width), 30)
|
||||
end
|
||||
return math.floor(width)
|
||||
return math.max(math.floor(width), 30)
|
||||
end
|
||||
|
||||
--- Open the coder split view
|
||||
@@ -43,7 +43,7 @@ function M.open_split(target_path, coder_path)
|
||||
utils.write_file(coder_path, "")
|
||||
|
||||
-- Ensure gitignore is updated when creating a new coder file
|
||||
local gitignore = require("codetyper.gitignore")
|
||||
local gitignore = require("codetyper.support.gitignore")
|
||||
gitignore.ensure_ignored()
|
||||
end
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
---@mod codetyper.agent.context_modal Modal for additional context input
|
||||
---@brief [[
|
||||
--- Opens a floating window for user to provide additional context
|
||||
--- when the LLM requests more information.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class ContextModalState
|
||||
---@field buf number|nil Buffer number
|
||||
---@field win number|nil Window number
|
||||
---@field original_event table|nil Original prompt event
|
||||
---@field callback function|nil Callback with additional context
|
||||
---@field llm_response string|nil LLM's response asking for context
|
||||
|
||||
local state = {
|
||||
buf = nil,
|
||||
win = nil,
|
||||
original_event = nil,
|
||||
callback = nil,
|
||||
llm_response = nil,
|
||||
}
|
||||
|
||||
--- Close the context modal
|
||||
function M.close()
|
||||
if state.win and vim.api.nvim_win_is_valid(state.win) then
|
||||
vim.api.nvim_win_close(state.win, true)
|
||||
end
|
||||
if state.buf and vim.api.nvim_buf_is_valid(state.buf) then
|
||||
vim.api.nvim_buf_delete(state.buf, { force = true })
|
||||
end
|
||||
state.win = nil
|
||||
state.buf = nil
|
||||
state.original_event = nil
|
||||
state.callback = nil
|
||||
state.llm_response = nil
|
||||
end
|
||||
|
||||
--- Submit the additional context
|
||||
local function submit()
|
||||
if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then
|
||||
return
|
||||
end
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(state.buf, 0, -1, false)
|
||||
local additional_context = table.concat(lines, "\n")
|
||||
|
||||
-- Trim whitespace
|
||||
additional_context = additional_context:match("^%s*(.-)%s*$") or additional_context
|
||||
|
||||
if additional_context == "" then
|
||||
M.close()
|
||||
return
|
||||
end
|
||||
|
||||
local original_event = state.original_event
|
||||
local callback = state.callback
|
||||
|
||||
M.close()
|
||||
|
||||
if callback and original_event then
|
||||
callback(original_event, additional_context)
|
||||
end
|
||||
end
|
||||
|
||||
--- Open the context modal
|
||||
---@param original_event table Original prompt event
|
||||
---@param llm_response string LLM's response asking for context
|
||||
---@param callback function(event: table, additional_context: string)
|
||||
function M.open(original_event, llm_response, callback)
|
||||
-- Close any existing modal
|
||||
M.close()
|
||||
|
||||
state.original_event = original_event
|
||||
state.llm_response = llm_response
|
||||
state.callback = callback
|
||||
|
||||
-- Calculate window size
|
||||
local width = math.min(80, vim.o.columns - 10)
|
||||
local height = 10
|
||||
|
||||
-- Create buffer
|
||||
state.buf = vim.api.nvim_create_buf(false, true)
|
||||
vim.bo[state.buf].buftype = "nofile"
|
||||
vim.bo[state.buf].bufhidden = "wipe"
|
||||
vim.bo[state.buf].filetype = "markdown"
|
||||
|
||||
-- Create window
|
||||
local row = math.floor((vim.o.lines - height) / 2)
|
||||
local col = math.floor((vim.o.columns - width) / 2)
|
||||
|
||||
state.win = vim.api.nvim_open_win(state.buf, true, {
|
||||
relative = "editor",
|
||||
row = row,
|
||||
col = col,
|
||||
width = width,
|
||||
height = height,
|
||||
style = "minimal",
|
||||
border = "rounded",
|
||||
title = " Additional Context Needed ",
|
||||
title_pos = "center",
|
||||
})
|
||||
|
||||
-- Set window options
|
||||
vim.wo[state.win].wrap = true
|
||||
vim.wo[state.win].cursorline = true
|
||||
|
||||
-- Add header showing what the LLM said
|
||||
local header_lines = {
|
||||
"-- LLM Response: --",
|
||||
}
|
||||
|
||||
-- Truncate LLM response for display
|
||||
local response_preview = llm_response or ""
|
||||
if #response_preview > 200 then
|
||||
response_preview = response_preview:sub(1, 200) .. "..."
|
||||
end
|
||||
for line in response_preview:gmatch("[^\n]+") do
|
||||
table.insert(header_lines, "-- " .. line)
|
||||
end
|
||||
|
||||
table.insert(header_lines, "")
|
||||
table.insert(header_lines, "-- Enter additional context below (Ctrl-Enter to submit, Esc to cancel) --")
|
||||
table.insert(header_lines, "")
|
||||
|
||||
vim.api.nvim_buf_set_lines(state.buf, 0, -1, false, header_lines)
|
||||
|
||||
-- Move cursor to the end
|
||||
vim.api.nvim_win_set_cursor(state.win, { #header_lines, 0 })
|
||||
|
||||
-- Set up keymaps
|
||||
local opts = { buffer = state.buf, noremap = true, silent = true }
|
||||
|
||||
-- Submit with Ctrl+Enter or <leader>s
|
||||
vim.keymap.set("n", "<C-CR>", submit, opts)
|
||||
vim.keymap.set("i", "<C-CR>", submit, opts)
|
||||
vim.keymap.set("n", "<leader>s", submit, opts)
|
||||
vim.keymap.set("n", "<CR><CR>", submit, opts)
|
||||
|
||||
-- Close with Esc or q
|
||||
vim.keymap.set("n", "<Esc>", M.close, opts)
|
||||
vim.keymap.set("n", "q", M.close, opts)
|
||||
|
||||
-- Start in insert mode
|
||||
vim.cmd("startinsert")
|
||||
|
||||
-- Log
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = "Context modal opened - waiting for user input",
|
||||
})
|
||||
end)
|
||||
end
|
||||
|
||||
--- Check if modal is open
|
||||
---@return boolean
|
||||
function M.is_open()
|
||||
return state.win ~= nil and vim.api.nvim_win_is_valid(state.win)
|
||||
end
|
||||
|
||||
--- Setup autocmds for the context modal
|
||||
function M.setup()
|
||||
local group = vim.api.nvim_create_augroup("CodetypeContextModal", { clear = true })
|
||||
|
||||
-- Close context modal when exiting Neovim
|
||||
vim.api.nvim_create_autocmd("VimLeavePre", {
|
||||
group = group,
|
||||
callback = function()
|
||||
M.close()
|
||||
end,
|
||||
desc = "Close context modal before exiting Neovim",
|
||||
})
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,294 +0,0 @@
|
||||
---@mod codetyper.agent.executor Tool executor for agent system
|
||||
---
|
||||
--- Executes tools requested by the LLM and returns results.
|
||||
|
||||
local M = {}
|
||||
local utils = require("codetyper.utils")
|
||||
|
||||
---@class ExecutionResult
|
||||
---@field success boolean Whether the execution succeeded
|
||||
---@field result string Result message or content
|
||||
---@field requires_approval boolean Whether user approval is needed
|
||||
---@field diff_data? DiffData Data for diff preview (if requires_approval)
|
||||
|
||||
---@class DiffData
|
||||
---@field path string File path
|
||||
---@field original string Original content
|
||||
---@field modified string Modified content
|
||||
---@field operation string Operation type: "edit", "create", "overwrite", "bash"
|
||||
|
||||
--- Execute a tool and return result via callback
|
||||
---@param tool_name string Name of the tool to execute
|
||||
---@param parameters table Tool parameters
|
||||
---@param callback fun(result: ExecutionResult) Callback with result
|
||||
function M.execute(tool_name, parameters, callback)
|
||||
local handlers = {
|
||||
read_file = M.handle_read_file,
|
||||
edit_file = M.handle_edit_file,
|
||||
write_file = M.handle_write_file,
|
||||
bash = M.handle_bash,
|
||||
}
|
||||
|
||||
local handler = handlers[tool_name]
|
||||
if not handler then
|
||||
callback({
|
||||
success = false,
|
||||
result = "Unknown tool: " .. tool_name,
|
||||
requires_approval = false,
|
||||
})
|
||||
return
|
||||
end
|
||||
|
||||
handler(parameters, callback)
|
||||
end
|
||||
|
||||
--- Handle read_file tool
|
||||
---@param params table { path: string }
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
function M.handle_read_file(params, callback)
|
||||
local path = M.resolve_path(params.path)
|
||||
local content = utils.read_file(path)
|
||||
|
||||
if content then
|
||||
callback({
|
||||
success = true,
|
||||
result = content,
|
||||
requires_approval = false,
|
||||
})
|
||||
else
|
||||
callback({
|
||||
success = false,
|
||||
result = "Could not read file: " .. path,
|
||||
requires_approval = false,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
--- Handle edit_file tool
|
||||
---@param params table { path: string, find: string, replace: string }
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
function M.handle_edit_file(params, callback)
|
||||
local path = M.resolve_path(params.path)
|
||||
local original = utils.read_file(path)
|
||||
|
||||
if not original then
|
||||
callback({
|
||||
success = false,
|
||||
result = "File not found: " .. path,
|
||||
requires_approval = false,
|
||||
})
|
||||
return
|
||||
end
|
||||
|
||||
-- Try to find and replace the content
|
||||
local escaped_find = utils.escape_pattern(params.find)
|
||||
local new_content, count = original:gsub(escaped_find, params.replace, 1)
|
||||
|
||||
if count == 0 then
|
||||
callback({
|
||||
success = false,
|
||||
result = "Could not find content to replace in: " .. path,
|
||||
requires_approval = false,
|
||||
})
|
||||
return
|
||||
end
|
||||
|
||||
-- Requires user approval - show diff
|
||||
callback({
|
||||
success = true,
|
||||
result = "Edit prepared for: " .. path,
|
||||
requires_approval = true,
|
||||
diff_data = {
|
||||
path = path,
|
||||
original = original,
|
||||
modified = new_content,
|
||||
operation = "edit",
|
||||
},
|
||||
})
|
||||
end
|
||||
|
||||
--- Handle write_file tool
|
||||
---@param params table { path: string, content: string }
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
function M.handle_write_file(params, callback)
|
||||
local path = M.resolve_path(params.path)
|
||||
local original = utils.read_file(path) or ""
|
||||
local operation = original == "" and "create" or "overwrite"
|
||||
|
||||
-- Ensure parent directory exists
|
||||
local dir = vim.fn.fnamemodify(path, ":h")
|
||||
if dir ~= "" and dir ~= "." then
|
||||
utils.ensure_dir(dir)
|
||||
end
|
||||
|
||||
callback({
|
||||
success = true,
|
||||
result = (operation == "create" and "Create" or "Overwrite") .. " prepared for: " .. path,
|
||||
requires_approval = true,
|
||||
diff_data = {
|
||||
path = path,
|
||||
original = original,
|
||||
modified = params.content,
|
||||
operation = operation,
|
||||
},
|
||||
})
|
||||
end
|
||||
|
||||
--- Handle bash tool
|
||||
---@param params table { command: string, timeout?: number }
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
function M.handle_bash(params, callback)
|
||||
local command = params.command
|
||||
|
||||
-- Requires user approval first
|
||||
callback({
|
||||
success = true,
|
||||
result = "Command: " .. command,
|
||||
requires_approval = true,
|
||||
diff_data = {
|
||||
path = "[bash]",
|
||||
original = "",
|
||||
modified = "$ " .. command,
|
||||
operation = "bash",
|
||||
},
|
||||
bash_command = command,
|
||||
bash_timeout = params.timeout or 30000,
|
||||
})
|
||||
end
|
||||
|
||||
--- Actually apply an approved change
|
||||
---@param diff_data DiffData The diff data to apply
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
function M.apply_change(diff_data, callback)
|
||||
if diff_data.operation == "bash" then
|
||||
-- Extract command from modified (remove "$ " prefix)
|
||||
local command = diff_data.modified:gsub("^%$ ", "")
|
||||
M.execute_bash_command(command, 30000, callback)
|
||||
else
|
||||
-- Write file
|
||||
local success = utils.write_file(diff_data.path, diff_data.modified)
|
||||
if success then
|
||||
-- Reload buffer if it's open
|
||||
M.reload_buffer_if_open(diff_data.path)
|
||||
callback({
|
||||
success = true,
|
||||
result = "Changes applied to: " .. diff_data.path,
|
||||
requires_approval = false,
|
||||
})
|
||||
else
|
||||
callback({
|
||||
success = false,
|
||||
result = "Failed to write: " .. diff_data.path,
|
||||
requires_approval = false,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Execute a bash command
|
||||
---@param command string Command to execute
|
||||
---@param timeout number Timeout in milliseconds
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
function M.execute_bash_command(command, timeout, callback)
|
||||
local stdout_data = {}
|
||||
local stderr_data = {}
|
||||
local job_id
|
||||
|
||||
job_id = vim.fn.jobstart(command, {
|
||||
stdout_buffered = true,
|
||||
stderr_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if data then
|
||||
for _, line in ipairs(data) do
|
||||
if line ~= "" then
|
||||
table.insert(stdout_data, line)
|
||||
end
|
||||
end
|
||||
end
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data then
|
||||
for _, line in ipairs(data) do
|
||||
if line ~= "" then
|
||||
table.insert(stderr_data, line)
|
||||
end
|
||||
end
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, exit_code)
|
||||
vim.schedule(function()
|
||||
local result = table.concat(stdout_data, "\n")
|
||||
if #stderr_data > 0 then
|
||||
if result ~= "" then
|
||||
result = result .. "\n"
|
||||
end
|
||||
result = result .. "STDERR:\n" .. table.concat(stderr_data, "\n")
|
||||
end
|
||||
result = result .. "\n[Exit code: " .. exit_code .. "]"
|
||||
|
||||
callback({
|
||||
success = exit_code == 0,
|
||||
result = result,
|
||||
requires_approval = false,
|
||||
})
|
||||
end)
|
||||
end,
|
||||
})
|
||||
|
||||
-- Set up timeout
|
||||
if job_id > 0 then
|
||||
vim.defer_fn(function()
|
||||
if vim.fn.jobwait({ job_id }, 0)[1] == -1 then
|
||||
vim.fn.jobstop(job_id)
|
||||
vim.schedule(function()
|
||||
callback({
|
||||
success = false,
|
||||
result = "Command timed out after " .. timeout .. "ms",
|
||||
requires_approval = false,
|
||||
})
|
||||
end)
|
||||
end
|
||||
end, timeout)
|
||||
else
|
||||
callback({
|
||||
success = false,
|
||||
result = "Failed to start command",
|
||||
requires_approval = false,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
--- Reload a buffer if it's currently open
|
||||
---@param filepath string Path to the file
|
||||
function M.reload_buffer_if_open(filepath)
|
||||
local full_path = vim.fn.fnamemodify(filepath, ":p")
|
||||
for _, buf in ipairs(vim.api.nvim_list_bufs()) do
|
||||
if vim.api.nvim_buf_is_loaded(buf) then
|
||||
local buf_name = vim.api.nvim_buf_get_name(buf)
|
||||
if buf_name == full_path then
|
||||
vim.api.nvim_buf_call(buf, function()
|
||||
vim.cmd("edit!")
|
||||
end)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Resolve a path (expand ~ and make absolute if needed)
|
||||
---@param path string Path to resolve
|
||||
---@return string Resolved path
|
||||
function M.resolve_path(path)
|
||||
-- Expand ~ to home directory
|
||||
local expanded = vim.fn.expand(path)
|
||||
|
||||
-- If relative, make it relative to project root or cwd
|
||||
if not vim.startswith(expanded, "/") then
|
||||
local root = utils.get_project_root() or vim.fn.getcwd()
|
||||
expanded = root .. "/" .. expanded
|
||||
end
|
||||
|
||||
return vim.fn.fnamemodify(expanded, ":p")
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,312 +0,0 @@
|
||||
---@mod codetyper.agent.intent Intent detection from prompts
|
||||
---@brief [[
|
||||
--- Parses prompt content to determine user intent and target scope.
|
||||
--- Intents determine how the generated code should be applied.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class Intent
|
||||
---@field type string "complete"|"refactor"|"add"|"fix"|"document"|"test"|"explain"|"optimize"
|
||||
---@field scope_hint string|nil "function"|"class"|"block"|"file"|"selection"|nil
|
||||
---@field confidence number 0.0-1.0 how confident we are about the intent
|
||||
---@field action string "replace"|"insert"|"append"|"none"
|
||||
---@field keywords string[] Keywords that triggered this intent
|
||||
|
||||
--- Intent patterns with associated metadata
|
||||
local intent_patterns = {
|
||||
-- Complete: fill in missing implementation
|
||||
complete = {
|
||||
patterns = {
|
||||
"complete",
|
||||
"finish",
|
||||
"implement",
|
||||
"fill in",
|
||||
"fill out",
|
||||
"stub",
|
||||
"todo",
|
||||
"fixme",
|
||||
},
|
||||
scope_hint = "function",
|
||||
action = "replace",
|
||||
priority = 1,
|
||||
},
|
||||
|
||||
-- Refactor: rewrite existing code
|
||||
refactor = {
|
||||
patterns = {
|
||||
"refactor",
|
||||
"rewrite",
|
||||
"restructure",
|
||||
"reorganize",
|
||||
"clean up",
|
||||
"cleanup",
|
||||
"simplify",
|
||||
"improve",
|
||||
},
|
||||
scope_hint = "function",
|
||||
action = "replace",
|
||||
priority = 2,
|
||||
},
|
||||
|
||||
-- Fix: repair bugs or issues
|
||||
fix = {
|
||||
patterns = {
|
||||
"fix",
|
||||
"repair",
|
||||
"correct",
|
||||
"debug",
|
||||
"solve",
|
||||
"resolve",
|
||||
"patch",
|
||||
"bug",
|
||||
"error",
|
||||
"issue",
|
||||
},
|
||||
scope_hint = "function",
|
||||
action = "replace",
|
||||
priority = 1,
|
||||
},
|
||||
|
||||
-- Add: insert new code
|
||||
add = {
|
||||
patterns = {
|
||||
"add",
|
||||
"create",
|
||||
"insert",
|
||||
"include",
|
||||
"append",
|
||||
"new",
|
||||
"generate",
|
||||
"write",
|
||||
},
|
||||
scope_hint = nil, -- Could be anywhere
|
||||
action = "insert",
|
||||
priority = 3,
|
||||
},
|
||||
|
||||
-- Document: add documentation
|
||||
document = {
|
||||
patterns = {
|
||||
"document",
|
||||
"comment",
|
||||
"jsdoc",
|
||||
"docstring",
|
||||
"describe",
|
||||
"annotate",
|
||||
"type hint",
|
||||
"typehint",
|
||||
},
|
||||
scope_hint = "function",
|
||||
action = "replace", -- Replace with documented version
|
||||
priority = 2,
|
||||
},
|
||||
|
||||
-- Test: generate tests
|
||||
test = {
|
||||
patterns = {
|
||||
"test",
|
||||
"spec",
|
||||
"unit test",
|
||||
"integration test",
|
||||
"coverage",
|
||||
},
|
||||
scope_hint = "file",
|
||||
action = "append",
|
||||
priority = 3,
|
||||
},
|
||||
|
||||
-- Optimize: improve performance
|
||||
optimize = {
|
||||
patterns = {
|
||||
"optimize",
|
||||
"performance",
|
||||
"faster",
|
||||
"efficient",
|
||||
"speed up",
|
||||
"reduce",
|
||||
"minimize",
|
||||
},
|
||||
scope_hint = "function",
|
||||
action = "replace",
|
||||
priority = 2,
|
||||
},
|
||||
|
||||
-- Explain: provide explanation (no code change)
|
||||
explain = {
|
||||
patterns = {
|
||||
"explain",
|
||||
"what does",
|
||||
"how does",
|
||||
"why",
|
||||
"describe",
|
||||
"walk through",
|
||||
"understand",
|
||||
},
|
||||
scope_hint = "function",
|
||||
action = "none",
|
||||
priority = 4,
|
||||
},
|
||||
}
|
||||
|
||||
--- Scope hint patterns
|
||||
local scope_patterns = {
|
||||
["this function"] = "function",
|
||||
["this method"] = "function",
|
||||
["the function"] = "function",
|
||||
["the method"] = "function",
|
||||
["this class"] = "class",
|
||||
["the class"] = "class",
|
||||
["this file"] = "file",
|
||||
["the file"] = "file",
|
||||
["this block"] = "block",
|
||||
["the block"] = "block",
|
||||
["this"] = nil, -- Use Tree-sitter to determine
|
||||
["here"] = nil,
|
||||
}
|
||||
|
||||
--- Detect intent from prompt content
|
||||
---@param prompt string The prompt content
|
||||
---@return Intent
|
||||
function M.detect(prompt)
|
||||
local lower = prompt:lower()
|
||||
local best_match = nil
|
||||
local best_priority = 999
|
||||
local matched_keywords = {}
|
||||
|
||||
-- Check each intent type
|
||||
for intent_type, config in pairs(intent_patterns) do
|
||||
for _, pattern in ipairs(config.patterns) do
|
||||
if lower:find(pattern, 1, true) then
|
||||
if config.priority < best_priority then
|
||||
best_match = intent_type
|
||||
best_priority = config.priority
|
||||
matched_keywords = { pattern }
|
||||
elseif config.priority == best_priority and best_match == intent_type then
|
||||
table.insert(matched_keywords, pattern)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Default to "add" if no clear intent
|
||||
if not best_match then
|
||||
best_match = "add"
|
||||
matched_keywords = {}
|
||||
end
|
||||
|
||||
local config = intent_patterns[best_match]
|
||||
|
||||
-- Detect scope hint from prompt
|
||||
local scope_hint = config.scope_hint
|
||||
for pattern, hint in pairs(scope_patterns) do
|
||||
if lower:find(pattern, 1, true) then
|
||||
scope_hint = hint or scope_hint
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
-- Calculate confidence based on keyword matches
|
||||
local confidence = 0.5 + (#matched_keywords * 0.15)
|
||||
confidence = math.min(confidence, 1.0)
|
||||
|
||||
return {
|
||||
type = best_match,
|
||||
scope_hint = scope_hint,
|
||||
confidence = confidence,
|
||||
action = config.action,
|
||||
keywords = matched_keywords,
|
||||
}
|
||||
end
|
||||
|
||||
--- Check if intent requires code modification
|
||||
---@param intent Intent
|
||||
---@return boolean
|
||||
function M.modifies_code(intent)
|
||||
return intent.action ~= "none"
|
||||
end
|
||||
|
||||
--- Check if intent should replace existing code
|
||||
---@param intent Intent
|
||||
---@return boolean
|
||||
function M.is_replacement(intent)
|
||||
return intent.action == "replace"
|
||||
end
|
||||
|
||||
--- Check if intent adds new code
|
||||
---@param intent Intent
|
||||
---@return boolean
|
||||
function M.is_insertion(intent)
|
||||
return intent.action == "insert" or intent.action == "append"
|
||||
end
|
||||
|
||||
--- Get system prompt modifier based on intent
|
||||
---@param intent Intent
|
||||
---@return string
|
||||
function M.get_prompt_modifier(intent)
|
||||
local modifiers = {
|
||||
complete = [[
|
||||
You are completing an incomplete function.
|
||||
Return the complete function with all missing parts filled in.
|
||||
Keep the existing signature unless changes are required.
|
||||
Output only the code, no explanations.]],
|
||||
|
||||
refactor = [[
|
||||
You are refactoring existing code.
|
||||
Improve the code structure while maintaining the same behavior.
|
||||
Keep the function signature unchanged.
|
||||
Output only the refactored code, no explanations.]],
|
||||
|
||||
fix = [[
|
||||
You are fixing a bug in the code.
|
||||
Identify and correct the issue while minimizing changes.
|
||||
Preserve the original intent of the code.
|
||||
Output only the fixed code, no explanations.]],
|
||||
|
||||
add = [[
|
||||
You are adding new code.
|
||||
Follow the existing code style and conventions.
|
||||
Output only the new code to be inserted, no explanations.]],
|
||||
|
||||
document = [[
|
||||
You are adding documentation to the code.
|
||||
Add appropriate comments/docstrings for the function.
|
||||
Include parameter types, return types, and description.
|
||||
Output the complete function with documentation.]],
|
||||
|
||||
test = [[
|
||||
You are generating tests for the code.
|
||||
Create comprehensive unit tests covering edge cases.
|
||||
Follow the testing conventions of the project.
|
||||
Output only the test code, no explanations.]],
|
||||
|
||||
optimize = [[
|
||||
You are optimizing code for performance.
|
||||
Improve efficiency while maintaining correctness.
|
||||
Document any significant algorithmic changes.
|
||||
Output only the optimized code, no explanations.]],
|
||||
|
||||
explain = [[
|
||||
You are explaining code to a developer.
|
||||
Provide a clear, concise explanation of what the code does.
|
||||
Include information about the algorithm and any edge cases.
|
||||
Do not output code, only explanation.]],
|
||||
}
|
||||
|
||||
return modifiers[intent.type] or modifiers.add
|
||||
end
|
||||
|
||||
--- Format intent for logging
|
||||
---@param intent Intent
|
||||
---@return string
|
||||
function M.format(intent)
|
||||
return string.format(
|
||||
"%s (scope: %s, action: %s, confidence: %.2f)",
|
||||
intent.type,
|
||||
intent.scope_hint or "auto",
|
||||
intent.action,
|
||||
intent.confidence
|
||||
)
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,665 +0,0 @@
|
||||
---@mod codetyper.agent.patch Patch system with staleness detection
|
||||
---@brief [[
|
||||
--- Manages code patches with buffer snapshots for staleness detection.
|
||||
--- Patches are queued for safe injection when completion popup is not visible.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class BufferSnapshot
|
||||
---@field bufnr number Buffer number
|
||||
---@field changedtick number vim.b.changedtick at snapshot time
|
||||
---@field content_hash string Hash of buffer content in range
|
||||
---@field range {start_line: number, end_line: number}|nil Range snapshotted
|
||||
|
||||
---@class PatchCandidate
|
||||
---@field id string Unique patch ID
|
||||
---@field event_id string Related PromptEvent ID
|
||||
---@field target_bufnr number Target buffer for injection
|
||||
---@field target_path string Target file path
|
||||
---@field original_snapshot BufferSnapshot Snapshot at event creation
|
||||
---@field generated_code string Code to inject
|
||||
---@field injection_range {start_line: number, end_line: number}|nil
|
||||
---@field injection_strategy string "append"|"replace"|"insert"
|
||||
---@field confidence number Confidence score (0.0-1.0)
|
||||
---@field status string "pending"|"applied"|"stale"|"rejected"
|
||||
---@field created_at number Timestamp
|
||||
---@field applied_at number|nil When applied
|
||||
|
||||
--- Patch storage
|
||||
---@type PatchCandidate[]
|
||||
local patches = {}
|
||||
|
||||
--- Patch ID counter
|
||||
local patch_counter = 0
|
||||
|
||||
--- Generate unique patch ID
|
||||
---@return string
|
||||
function M.generate_id()
|
||||
patch_counter = patch_counter + 1
|
||||
return string.format("patch_%d_%d", os.time(), patch_counter)
|
||||
end
|
||||
|
||||
--- Hash buffer content in range
|
||||
---@param bufnr number
|
||||
---@param start_line number|nil 1-indexed, nil for whole buffer
|
||||
---@param end_line number|nil 1-indexed, nil for whole buffer
|
||||
---@return string
|
||||
local function hash_buffer_range(bufnr, start_line, end_line)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return ""
|
||||
end
|
||||
|
||||
local lines
|
||||
if start_line and end_line then
|
||||
lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false)
|
||||
else
|
||||
lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
end
|
||||
|
||||
local content = table.concat(lines, "\n")
|
||||
local hash = 0
|
||||
for i = 1, #content do
|
||||
hash = (hash * 31 + string.byte(content, i)) % 2147483647
|
||||
end
|
||||
return string.format("%x", hash)
|
||||
end
|
||||
|
||||
--- Take a snapshot of buffer state
|
||||
---@param bufnr number Buffer number
|
||||
---@param range {start_line: number, end_line: number}|nil Optional range
|
||||
---@return BufferSnapshot
|
||||
function M.snapshot_buffer(bufnr, range)
|
||||
local changedtick = 0
|
||||
if vim.api.nvim_buf_is_valid(bufnr) then
|
||||
changedtick = vim.api.nvim_buf_get_var(bufnr, "changedtick") or vim.b[bufnr].changedtick or 0
|
||||
end
|
||||
|
||||
local content_hash
|
||||
if range then
|
||||
content_hash = hash_buffer_range(bufnr, range.start_line, range.end_line)
|
||||
else
|
||||
content_hash = hash_buffer_range(bufnr, nil, nil)
|
||||
end
|
||||
|
||||
return {
|
||||
bufnr = bufnr,
|
||||
changedtick = changedtick,
|
||||
content_hash = content_hash,
|
||||
range = range,
|
||||
}
|
||||
end
|
||||
|
||||
--- Check if buffer changed since snapshot
|
||||
---@param snapshot BufferSnapshot
|
||||
---@return boolean is_stale
|
||||
---@return string|nil reason
|
||||
function M.is_snapshot_stale(snapshot)
|
||||
if not vim.api.nvim_buf_is_valid(snapshot.bufnr) then
|
||||
return true, "buffer_invalid"
|
||||
end
|
||||
|
||||
-- Check changedtick first (fast path)
|
||||
local current_tick = vim.api.nvim_buf_get_var(snapshot.bufnr, "changedtick")
|
||||
or vim.b[snapshot.bufnr].changedtick or 0
|
||||
|
||||
if current_tick ~= snapshot.changedtick then
|
||||
-- Changedtick differs, but might be just cursor movement
|
||||
-- Verify with content hash
|
||||
local current_hash
|
||||
if snapshot.range then
|
||||
current_hash = hash_buffer_range(
|
||||
snapshot.bufnr,
|
||||
snapshot.range.start_line,
|
||||
snapshot.range.end_line
|
||||
)
|
||||
else
|
||||
current_hash = hash_buffer_range(snapshot.bufnr, nil, nil)
|
||||
end
|
||||
|
||||
if current_hash ~= snapshot.content_hash then
|
||||
return true, "content_changed"
|
||||
end
|
||||
end
|
||||
|
||||
return false, nil
|
||||
end
|
||||
|
||||
--- Check if a patch is stale
|
||||
---@param patch PatchCandidate
|
||||
---@return boolean
|
||||
---@return string|nil reason
|
||||
function M.is_stale(patch)
|
||||
return M.is_snapshot_stale(patch.original_snapshot)
|
||||
end
|
||||
|
||||
--- Queue a patch for deferred application
|
||||
---@param patch PatchCandidate
|
||||
---@return PatchCandidate
|
||||
function M.queue_patch(patch)
|
||||
patch.id = patch.id or M.generate_id()
|
||||
patch.status = patch.status or "pending"
|
||||
patch.created_at = patch.created_at or os.time()
|
||||
|
||||
table.insert(patches, patch)
|
||||
|
||||
-- Log patch creation
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
logs.add({
|
||||
type = "patch",
|
||||
message = string.format(
|
||||
"Patch queued: %s (confidence: %.2f)",
|
||||
patch.id, patch.confidence or 0
|
||||
),
|
||||
data = {
|
||||
patch_id = patch.id,
|
||||
event_id = patch.event_id,
|
||||
target_path = patch.target_path,
|
||||
code_preview = patch.generated_code:sub(1, 50),
|
||||
},
|
||||
})
|
||||
end)
|
||||
|
||||
return patch
|
||||
end
|
||||
|
||||
--- Create patch from event and response
|
||||
---@param event table PromptEvent
|
||||
---@param generated_code string
|
||||
---@param confidence number
|
||||
---@param strategy string|nil Injection strategy (overrides intent-based)
|
||||
---@return PatchCandidate
|
||||
function M.create_from_event(event, generated_code, confidence, strategy)
|
||||
-- Get target buffer
|
||||
local target_bufnr = vim.fn.bufnr(event.target_path)
|
||||
if target_bufnr == -1 then
|
||||
-- Try to find by filename
|
||||
for _, buf in ipairs(vim.api.nvim_list_bufs()) do
|
||||
local name = vim.api.nvim_buf_get_name(buf)
|
||||
if name == event.target_path then
|
||||
target_bufnr = buf
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Take snapshot of the scope range in target buffer (for staleness detection)
|
||||
local snapshot_range = event.scope_range or event.range
|
||||
local snapshot = M.snapshot_buffer(
|
||||
target_bufnr ~= -1 and target_bufnr or event.bufnr,
|
||||
snapshot_range
|
||||
)
|
||||
|
||||
-- Determine injection strategy and range based on intent
|
||||
local injection_strategy = strategy
|
||||
local injection_range = nil
|
||||
|
||||
if not injection_strategy and event.intent then
|
||||
local intent_mod = require("codetyper.agent.intent")
|
||||
if intent_mod.is_replacement(event.intent) then
|
||||
injection_strategy = "replace"
|
||||
-- Use scope range for replacement
|
||||
if event.scope_range then
|
||||
injection_range = event.scope_range
|
||||
end
|
||||
elseif event.intent.action == "insert" then
|
||||
injection_strategy = "insert"
|
||||
-- Insert at prompt location
|
||||
injection_range = { start_line = event.range.start_line, end_line = event.range.start_line }
|
||||
elseif event.intent.action == "append" then
|
||||
injection_strategy = "append"
|
||||
-- Will append to end of file
|
||||
else
|
||||
injection_strategy = "append"
|
||||
end
|
||||
end
|
||||
|
||||
injection_strategy = injection_strategy or "append"
|
||||
|
||||
return {
|
||||
id = M.generate_id(),
|
||||
event_id = event.id,
|
||||
target_bufnr = target_bufnr,
|
||||
target_path = event.target_path,
|
||||
original_snapshot = snapshot,
|
||||
generated_code = generated_code,
|
||||
injection_range = injection_range,
|
||||
injection_strategy = injection_strategy,
|
||||
confidence = confidence,
|
||||
status = "pending",
|
||||
created_at = os.time(),
|
||||
intent = event.intent,
|
||||
scope = event.scope,
|
||||
-- Store the prompt tag range so we can delete it after applying
|
||||
prompt_tag_range = event.range,
|
||||
}
|
||||
end
|
||||
|
||||
--- Get all pending patches
|
||||
---@return PatchCandidate[]
|
||||
function M.get_pending()
|
||||
local pending = {}
|
||||
for _, patch in ipairs(patches) do
|
||||
if patch.status == "pending" then
|
||||
table.insert(pending, patch)
|
||||
end
|
||||
end
|
||||
return pending
|
||||
end
|
||||
|
||||
--- Get patch by ID
|
||||
---@param id string
|
||||
---@return PatchCandidate|nil
|
||||
function M.get(id)
|
||||
for _, patch in ipairs(patches) do
|
||||
if patch.id == id then
|
||||
return patch
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Get patches for event
|
||||
---@param event_id string
|
||||
---@return PatchCandidate[]
|
||||
function M.get_for_event(event_id)
|
||||
local result = {}
|
||||
for _, patch in ipairs(patches) do
|
||||
if patch.event_id == event_id then
|
||||
table.insert(result, patch)
|
||||
end
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
--- Mark patch as applied
|
||||
---@param id string
|
||||
---@return boolean
|
||||
function M.mark_applied(id)
|
||||
local patch = M.get(id)
|
||||
if patch then
|
||||
patch.status = "applied"
|
||||
patch.applied_at = os.time()
|
||||
return true
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Mark patch as stale
|
||||
---@param id string
|
||||
---@param reason string|nil
|
||||
---@return boolean
|
||||
function M.mark_stale(id, reason)
|
||||
local patch = M.get(id)
|
||||
if patch then
|
||||
patch.status = "stale"
|
||||
patch.stale_reason = reason
|
||||
return true
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Mark patch as rejected
|
||||
---@param id string
|
||||
---@param reason string|nil
|
||||
---@return boolean
|
||||
function M.mark_rejected(id, reason)
|
||||
local patch = M.get(id)
|
||||
if patch then
|
||||
patch.status = "rejected"
|
||||
patch.reject_reason = reason
|
||||
return true
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Remove /@ @/ prompt tags from buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@return number Number of tag regions removed
|
||||
local function remove_prompt_tags(bufnr)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return 0
|
||||
end
|
||||
|
||||
local removed = 0
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
|
||||
-- Find and remove all /@ ... @/ regions (can be multiline)
|
||||
local i = 1
|
||||
while i <= #lines do
|
||||
local line = lines[i]
|
||||
local open_start = line:find("/@")
|
||||
|
||||
if open_start then
|
||||
-- Found an opening tag, look for closing tag
|
||||
local close_end = nil
|
||||
local close_line = i
|
||||
|
||||
-- Check if closing tag is on same line
|
||||
local after_open = line:sub(open_start + 2)
|
||||
local same_line_close = after_open:find("@/")
|
||||
if same_line_close then
|
||||
-- Single line tag - remove just this portion
|
||||
local before = line:sub(1, open_start - 1)
|
||||
local after = line:sub(open_start + 2 + same_line_close + 1)
|
||||
lines[i] = before .. after
|
||||
-- If line is now empty or just whitespace, remove it
|
||||
if lines[i]:match("^%s*$") then
|
||||
table.remove(lines, i)
|
||||
else
|
||||
i = i + 1
|
||||
end
|
||||
removed = removed + 1
|
||||
else
|
||||
-- Multi-line tag - find the closing line
|
||||
for j = i, #lines do
|
||||
if lines[j]:find("@/") then
|
||||
close_line = j
|
||||
close_end = lines[j]:find("@/")
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if close_end then
|
||||
-- Remove lines from i to close_line
|
||||
-- Keep content before /@ on first line and after @/ on last line
|
||||
local before = lines[i]:sub(1, open_start - 1)
|
||||
local after = lines[close_line]:sub(close_end + 2)
|
||||
|
||||
-- Remove the lines containing the tag
|
||||
for _ = i, close_line do
|
||||
table.remove(lines, i)
|
||||
end
|
||||
|
||||
-- If there's content to keep, insert it back
|
||||
local remaining = (before .. after):match("^%s*(.-)%s*$")
|
||||
if remaining and remaining ~= "" then
|
||||
table.insert(lines, i, remaining)
|
||||
i = i + 1
|
||||
end
|
||||
|
||||
removed = removed + 1
|
||||
else
|
||||
-- No closing tag found, skip this line
|
||||
i = i + 1
|
||||
end
|
||||
end
|
||||
else
|
||||
i = i + 1
|
||||
end
|
||||
end
|
||||
|
||||
if removed > 0 then
|
||||
vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, lines)
|
||||
end
|
||||
|
||||
return removed
|
||||
end
|
||||
|
||||
--- Check if it's safe to modify the buffer (not in insert mode)
|
||||
---@return boolean
|
||||
local function is_safe_to_modify()
|
||||
local mode = vim.fn.mode()
|
||||
-- Don't modify if in insert mode or completion is visible
|
||||
if mode == "i" or mode == "ic" or mode == "ix" then
|
||||
return false
|
||||
end
|
||||
if vim.fn.pumvisible() == 1 then
|
||||
return false
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
--- Apply a patch to the target buffer
|
||||
---@param patch PatchCandidate
|
||||
---@return boolean success
|
||||
---@return string|nil error
|
||||
function M.apply(patch)
|
||||
-- Check if safe to modify (not in insert mode)
|
||||
if not is_safe_to_modify() then
|
||||
return false, "user_typing"
|
||||
end
|
||||
|
||||
-- Check staleness first
|
||||
local is_stale, stale_reason = M.is_stale(patch)
|
||||
if is_stale then
|
||||
M.mark_stale(patch.id, stale_reason)
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
logs.add({
|
||||
type = "warning",
|
||||
message = string.format("Patch %s is stale: %s", patch.id, stale_reason or "unknown"),
|
||||
})
|
||||
end)
|
||||
|
||||
return false, "patch_stale: " .. (stale_reason or "unknown")
|
||||
end
|
||||
|
||||
-- Ensure target buffer is valid
|
||||
local target_bufnr = patch.target_bufnr
|
||||
if target_bufnr == -1 or not vim.api.nvim_buf_is_valid(target_bufnr) then
|
||||
-- Try to load buffer from path
|
||||
target_bufnr = vim.fn.bufadd(patch.target_path)
|
||||
if target_bufnr == 0 then
|
||||
M.mark_rejected(patch.id, "buffer_not_found")
|
||||
return false, "target buffer not found"
|
||||
end
|
||||
vim.fn.bufload(target_bufnr)
|
||||
patch.target_bufnr = target_bufnr
|
||||
end
|
||||
|
||||
-- Prepare code lines
|
||||
local code_lines = vim.split(patch.generated_code, "\n", { plain = true })
|
||||
|
||||
-- FIRST: Remove the prompt tags from the buffer before applying code
|
||||
-- This prevents the infinite loop where tags stay and get re-detected
|
||||
local tags_removed = remove_prompt_tags(target_bufnr)
|
||||
|
||||
pcall(function()
|
||||
if tags_removed > 0 then
|
||||
local logs = require("codetyper.agent.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Removed %d prompt tag(s) from buffer", tags_removed),
|
||||
})
|
||||
end
|
||||
end)
|
||||
|
||||
-- Recalculate line count after tag removal
|
||||
local line_count = vim.api.nvim_buf_line_count(target_bufnr)
|
||||
|
||||
-- Apply based on strategy
|
||||
local ok, err = pcall(function()
|
||||
if patch.injection_strategy == "replace" and patch.injection_range then
|
||||
-- Replace the scope range with the new code
|
||||
-- The injection_range points to the function/method we're completing
|
||||
local start_line = patch.injection_range.start_line
|
||||
local end_line = patch.injection_range.end_line
|
||||
|
||||
-- Adjust for tag removal - find the new range by searching for the scope
|
||||
-- After removing tags, line numbers may have shifted
|
||||
-- Use the scope information to find the correct range
|
||||
if patch.scope and patch.scope.type then
|
||||
-- Try to find the scope using treesitter if available
|
||||
local found_range = nil
|
||||
pcall(function()
|
||||
local ts_utils = require("nvim-treesitter.ts_utils")
|
||||
local parsers = require("nvim-treesitter.parsers")
|
||||
local parser = parsers.get_parser(target_bufnr)
|
||||
if parser then
|
||||
local tree = parser:parse()[1]
|
||||
if tree then
|
||||
local root = tree:root()
|
||||
-- Find the function/method node that contains our original position
|
||||
local function find_scope_node(node)
|
||||
local node_type = node:type()
|
||||
local is_scope = node_type:match("function")
|
||||
or node_type:match("method")
|
||||
or node_type:match("class")
|
||||
or node_type:match("declaration")
|
||||
|
||||
if is_scope then
|
||||
local s_row, _, e_row, _ = node:range()
|
||||
-- Check if this scope roughly matches our expected range
|
||||
if math.abs(s_row - (start_line - 1)) <= 5 then
|
||||
found_range = { start_line = s_row + 1, end_line = e_row + 1 }
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
for child in node:iter_children() do
|
||||
if find_scope_node(child) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
find_scope_node(root)
|
||||
end
|
||||
end
|
||||
end)
|
||||
|
||||
if found_range then
|
||||
start_line = found_range.start_line
|
||||
end_line = found_range.end_line
|
||||
end
|
||||
end
|
||||
|
||||
-- Clamp to valid range
|
||||
start_line = math.max(1, start_line)
|
||||
end_line = math.min(line_count, end_line)
|
||||
|
||||
-- Replace the range (0-indexed for nvim_buf_set_lines)
|
||||
vim.api.nvim_buf_set_lines(target_bufnr, start_line - 1, end_line, false, code_lines)
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Replacing lines %d-%d with %d lines of code", start_line, end_line, #code_lines),
|
||||
})
|
||||
end)
|
||||
elseif patch.injection_strategy == "insert" and patch.injection_range then
|
||||
-- Insert at the specified location
|
||||
local insert_line = patch.injection_range.start_line
|
||||
insert_line = math.max(1, math.min(line_count + 1, insert_line))
|
||||
vim.api.nvim_buf_set_lines(target_bufnr, insert_line - 1, insert_line - 1, false, code_lines)
|
||||
else
|
||||
-- Default: append to end
|
||||
-- Check if last line is empty, if not add a blank line for spacing
|
||||
local last_line = vim.api.nvim_buf_get_lines(target_bufnr, line_count - 1, line_count, false)[1] or ""
|
||||
if last_line:match("%S") then
|
||||
-- Last line has content, add blank line for spacing
|
||||
table.insert(code_lines, 1, "")
|
||||
end
|
||||
vim.api.nvim_buf_set_lines(target_bufnr, line_count, line_count, false, code_lines)
|
||||
end
|
||||
end)
|
||||
|
||||
if not ok then
|
||||
M.mark_rejected(patch.id, err)
|
||||
return false, err
|
||||
end
|
||||
|
||||
M.mark_applied(patch.id)
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
logs.add({
|
||||
type = "success",
|
||||
message = string.format("Patch %s applied successfully", patch.id),
|
||||
data = {
|
||||
target_path = patch.target_path,
|
||||
lines_added = #code_lines,
|
||||
},
|
||||
})
|
||||
end)
|
||||
|
||||
return true, nil
|
||||
end
|
||||
|
||||
--- Flush all pending patches that are safe to apply
|
||||
---@return number applied_count
|
||||
---@return number stale_count
|
||||
---@return number deferred_count
|
||||
function M.flush_pending()
|
||||
local applied = 0
|
||||
local stale = 0
|
||||
local deferred = 0
|
||||
|
||||
for _, p in ipairs(patches) do
|
||||
if p.status == "pending" then
|
||||
local success, err = M.apply(p)
|
||||
if success then
|
||||
applied = applied + 1
|
||||
elseif err == "user_typing" then
|
||||
-- Keep pending, will retry later
|
||||
deferred = deferred + 1
|
||||
else
|
||||
stale = stale + 1
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return applied, stale, deferred
|
||||
end
|
||||
|
||||
--- Cancel all pending patches for a buffer
|
||||
---@param bufnr number
|
||||
---@return number cancelled_count
|
||||
function M.cancel_for_buffer(bufnr)
|
||||
local cancelled = 0
|
||||
for _, patch in ipairs(patches) do
|
||||
if patch.status == "pending" and
|
||||
(patch.target_bufnr == bufnr or patch.original_snapshot.bufnr == bufnr) then
|
||||
patch.status = "cancelled"
|
||||
cancelled = cancelled + 1
|
||||
end
|
||||
end
|
||||
return cancelled
|
||||
end
|
||||
|
||||
--- Cleanup old patches
|
||||
---@param max_age number Max age in seconds (default: 300)
|
||||
function M.cleanup(max_age)
|
||||
max_age = max_age or 300
|
||||
local now = os.time()
|
||||
local i = 1
|
||||
while i <= #patches do
|
||||
local patch = patches[i]
|
||||
if patch.status ~= "pending" and (now - patch.created_at) > max_age then
|
||||
table.remove(patches, i)
|
||||
else
|
||||
i = i + 1
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Get statistics
|
||||
---@return table
|
||||
function M.stats()
|
||||
local stats = {
|
||||
total = #patches,
|
||||
pending = 0,
|
||||
applied = 0,
|
||||
stale = 0,
|
||||
rejected = 0,
|
||||
cancelled = 0,
|
||||
}
|
||||
for _, patch in ipairs(patches) do
|
||||
local s = patch.status
|
||||
if stats[s] then
|
||||
stats[s] = stats[s] + 1
|
||||
end
|
||||
end
|
||||
return stats
|
||||
end
|
||||
|
||||
--- Clear all patches
|
||||
function M.clear()
|
||||
patches = {}
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,161 +0,0 @@
|
||||
---@mod codetyper.agent.tools Tool definitions for the agent system
|
||||
---
|
||||
--- Defines available tools that the LLM can use to interact with files and system.
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Tool definitions in a provider-agnostic format
|
||||
M.definitions = {
|
||||
read_file = {
|
||||
name = "read_file",
|
||||
description = "Read the contents of a file at the specified path",
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = {
|
||||
path = {
|
||||
type = "string",
|
||||
description = "Absolute or relative path to the file to read",
|
||||
},
|
||||
},
|
||||
required = { "path" },
|
||||
},
|
||||
},
|
||||
|
||||
edit_file = {
|
||||
name = "edit_file",
|
||||
description = "Edit a file by replacing specific content. Provide the exact content to find and the replacement.",
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = {
|
||||
path = {
|
||||
type = "string",
|
||||
description = "Path to the file to edit",
|
||||
},
|
||||
find = {
|
||||
type = "string",
|
||||
description = "Exact content to find (must match exactly, including whitespace)",
|
||||
},
|
||||
replace = {
|
||||
type = "string",
|
||||
description = "Content to replace with",
|
||||
},
|
||||
},
|
||||
required = { "path", "find", "replace" },
|
||||
},
|
||||
},
|
||||
|
||||
write_file = {
|
||||
name = "write_file",
|
||||
description = "Write content to a file, creating it if it doesn't exist or overwriting if it does",
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = {
|
||||
path = {
|
||||
type = "string",
|
||||
description = "Path to the file to write",
|
||||
},
|
||||
content = {
|
||||
type = "string",
|
||||
description = "Complete file content to write",
|
||||
},
|
||||
},
|
||||
required = { "path", "content" },
|
||||
},
|
||||
},
|
||||
|
||||
bash = {
|
||||
name = "bash",
|
||||
description = "Execute a bash command and return the output. Use for git, npm, build tools, etc.",
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = {
|
||||
command = {
|
||||
type = "string",
|
||||
description = "The bash command to execute",
|
||||
},
|
||||
timeout = {
|
||||
type = "number",
|
||||
description = "Timeout in milliseconds (default: 30000)",
|
||||
},
|
||||
},
|
||||
required = { "command" },
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
--- Convert tool definitions to Claude API format
|
||||
---@return table[] Tools in Claude's expected format
|
||||
function M.to_claude_format()
|
||||
local tools = {}
|
||||
for _, tool in pairs(M.definitions) do
|
||||
table.insert(tools, {
|
||||
name = tool.name,
|
||||
description = tool.description,
|
||||
input_schema = tool.parameters,
|
||||
})
|
||||
end
|
||||
return tools
|
||||
end
|
||||
|
||||
--- Convert tool definitions to OpenAI API format
|
||||
---@return table[] Tools in OpenAI's expected format
|
||||
function M.to_openai_format()
|
||||
local tools = {}
|
||||
for _, tool in pairs(M.definitions) do
|
||||
table.insert(tools, {
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool.name,
|
||||
description = tool.description,
|
||||
parameters = tool.parameters,
|
||||
},
|
||||
})
|
||||
end
|
||||
return tools
|
||||
end
|
||||
|
||||
--- Convert tool definitions to prompt format for Ollama
|
||||
---@return string Formatted tool descriptions for system prompt
|
||||
function M.to_prompt_format()
|
||||
local lines = {
|
||||
"You have access to the following tools. To use a tool, respond with a JSON block.",
|
||||
"",
|
||||
}
|
||||
|
||||
for _, tool in pairs(M.definitions) do
|
||||
table.insert(lines, "## " .. tool.name)
|
||||
table.insert(lines, tool.description)
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, "Parameters:")
|
||||
for prop_name, prop in pairs(tool.parameters.properties) do
|
||||
local required = vim.tbl_contains(tool.parameters.required or {}, prop_name)
|
||||
local req_str = required and " (required)" or " (optional)"
|
||||
table.insert(lines, " - " .. prop_name .. ": " .. prop.description .. req_str)
|
||||
end
|
||||
table.insert(lines, "")
|
||||
end
|
||||
|
||||
table.insert(lines, "---")
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, "To call a tool, output a JSON block like this:")
|
||||
table.insert(lines, "```json")
|
||||
table.insert(lines, '{"tool": "tool_name", "parameters": {"param1": "value1"}}')
|
||||
table.insert(lines, "```")
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, "After receiving tool results, continue your response or call another tool.")
|
||||
table.insert(lines, "When you're done, just respond normally without any tool calls.")
|
||||
|
||||
return table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
--- Get a list of tool names
|
||||
---@return string[]
|
||||
function M.get_tool_names()
|
||||
local names = {}
|
||||
for name, _ in pairs(M.definitions) do
|
||||
table.insert(names, name)
|
||||
end
|
||||
return names
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,999 +0,0 @@
|
||||
---@mod codetyper.autocmds Autocommands for Codetyper.nvim
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
|
||||
--- Autocommand group name
|
||||
local AUGROUP = "Codetyper"
|
||||
|
||||
--- Debounce timer for tree updates
|
||||
local tree_update_timer = nil
|
||||
local TREE_UPDATE_DEBOUNCE_MS = 1000 -- 1 second debounce
|
||||
|
||||
--- Track processed prompts to avoid re-processing
|
||||
---@type table<string, boolean>
|
||||
local processed_prompts = {}
|
||||
|
||||
--- Track if we're currently asking for preferences
|
||||
local asking_preference = false
|
||||
|
||||
--- Generate a unique key for a prompt
|
||||
---@param bufnr number Buffer number
|
||||
---@param prompt table Prompt object
|
||||
---@return string Unique key
|
||||
local function get_prompt_key(bufnr, prompt)
|
||||
return string.format("%d:%d:%d:%s", bufnr, prompt.start_line, prompt.end_line, prompt.content:sub(1, 50))
|
||||
end
|
||||
|
||||
--- Schedule tree update with debounce
|
||||
local function schedule_tree_update()
|
||||
if tree_update_timer then
|
||||
tree_update_timer:stop()
|
||||
end
|
||||
|
||||
tree_update_timer = vim.defer_fn(function()
|
||||
local tree = require("codetyper.tree")
|
||||
tree.update_tree_log()
|
||||
tree_update_timer = nil
|
||||
end, TREE_UPDATE_DEBOUNCE_MS)
|
||||
end
|
||||
|
||||
--- Setup autocommands
|
||||
function M.setup()
|
||||
local group = vim.api.nvim_create_augroup(AUGROUP, { clear = true })
|
||||
|
||||
-- Auto-check for closed prompts when leaving insert mode (works on ALL files)
|
||||
vim.api.nvim_create_autocmd("InsertLeave", {
|
||||
group = group,
|
||||
pattern = "*",
|
||||
callback = function()
|
||||
-- Skip special buffers
|
||||
local buftype = vim.bo.buftype
|
||||
if buftype ~= "" then
|
||||
return
|
||||
end
|
||||
-- Auto-save coder files only
|
||||
local filepath = vim.fn.expand("%:p")
|
||||
if utils.is_coder_file(filepath) and vim.bo.modified then
|
||||
vim.cmd("silent! write")
|
||||
end
|
||||
-- Check for closed prompts and auto-process (respects preferences)
|
||||
M.check_for_closed_prompt_with_preference()
|
||||
end,
|
||||
desc = "Check for closed prompt tags on InsertLeave",
|
||||
})
|
||||
|
||||
-- Auto-process prompts when entering normal mode (works on ALL files)
|
||||
vim.api.nvim_create_autocmd("ModeChanged", {
|
||||
group = group,
|
||||
pattern = "*:n",
|
||||
callback = function()
|
||||
-- Skip special buffers
|
||||
local buftype = vim.bo.buftype
|
||||
if buftype ~= "" then
|
||||
return
|
||||
end
|
||||
-- Slight delay to let buffer settle
|
||||
vim.defer_fn(function()
|
||||
M.check_all_prompts_with_preference()
|
||||
end, 50)
|
||||
end,
|
||||
desc = "Auto-process closed prompts when entering normal mode",
|
||||
})
|
||||
|
||||
-- Also check on CursorHold as backup (works on ALL files)
|
||||
vim.api.nvim_create_autocmd("CursorHold", {
|
||||
group = group,
|
||||
pattern = "*",
|
||||
callback = function()
|
||||
-- Skip special buffers
|
||||
local buftype = vim.bo.buftype
|
||||
if buftype ~= "" then
|
||||
return
|
||||
end
|
||||
local mode = vim.api.nvim_get_mode().mode
|
||||
if mode == "n" then
|
||||
M.check_all_prompts_with_preference()
|
||||
end
|
||||
end,
|
||||
desc = "Auto-process closed prompts when idle in normal mode",
|
||||
})
|
||||
|
||||
-- Auto-set filetype for coder files based on extension
|
||||
vim.api.nvim_create_autocmd({ "BufRead", "BufNewFile" }, {
|
||||
group = group,
|
||||
pattern = "*.coder.*",
|
||||
callback = function()
|
||||
M.set_coder_filetype()
|
||||
end,
|
||||
desc = "Set filetype for coder files",
|
||||
})
|
||||
|
||||
-- Auto-open split view when opening a coder file directly (e.g., from nvim-tree)
|
||||
vim.api.nvim_create_autocmd("BufEnter", {
|
||||
group = group,
|
||||
pattern = "*.coder.*",
|
||||
callback = function()
|
||||
-- Delay slightly to ensure buffer is fully loaded
|
||||
vim.defer_fn(function()
|
||||
M.auto_open_target_file()
|
||||
end, 50)
|
||||
end,
|
||||
desc = "Auto-open target file when coder file is opened",
|
||||
})
|
||||
|
||||
-- Cleanup on buffer close
|
||||
vim.api.nvim_create_autocmd("BufWipeout", {
|
||||
group = group,
|
||||
pattern = "*.coder.*",
|
||||
callback = function(ev)
|
||||
local window = require("codetyper.window")
|
||||
if window.is_open() then
|
||||
window.close_split()
|
||||
end
|
||||
-- Clear processed prompts for this buffer
|
||||
local bufnr = ev.buf
|
||||
for key, _ in pairs(processed_prompts) do
|
||||
if key:match("^" .. bufnr .. ":") then
|
||||
processed_prompts[key] = nil
|
||||
end
|
||||
end
|
||||
-- Clear auto-opened tracking
|
||||
M.clear_auto_opened(bufnr)
|
||||
end,
|
||||
desc = "Cleanup on coder buffer close",
|
||||
})
|
||||
|
||||
-- Update tree.log when files are created/written
|
||||
vim.api.nvim_create_autocmd({ "BufWritePost", "BufNewFile" }, {
|
||||
group = group,
|
||||
pattern = "*",
|
||||
callback = function(ev)
|
||||
-- Skip coder files and tree.log itself
|
||||
local filepath = ev.file or vim.fn.expand("%:p")
|
||||
if filepath:match("%.coder%.") or filepath:match("tree%.log$") then
|
||||
return
|
||||
end
|
||||
-- Schedule tree update with debounce
|
||||
schedule_tree_update()
|
||||
end,
|
||||
desc = "Update tree.log on file creation/save",
|
||||
})
|
||||
|
||||
-- Update tree.log when files are deleted (via netrw or file explorer)
|
||||
vim.api.nvim_create_autocmd("BufDelete", {
|
||||
group = group,
|
||||
pattern = "*",
|
||||
callback = function(ev)
|
||||
local filepath = ev.file or ""
|
||||
-- Skip special buffers and coder files
|
||||
if filepath == "" or filepath:match("%.coder%.") or filepath:match("tree%.log$") then
|
||||
return
|
||||
end
|
||||
schedule_tree_update()
|
||||
end,
|
||||
desc = "Update tree.log on file deletion",
|
||||
})
|
||||
|
||||
-- Update tree on directory change
|
||||
vim.api.nvim_create_autocmd("DirChanged", {
|
||||
group = group,
|
||||
pattern = "*",
|
||||
callback = function()
|
||||
schedule_tree_update()
|
||||
end,
|
||||
desc = "Update tree.log on directory change",
|
||||
})
|
||||
|
||||
-- Auto-index: Create/open coder companion file when opening source files
|
||||
vim.api.nvim_create_autocmd("BufEnter", {
|
||||
group = group,
|
||||
pattern = "*",
|
||||
callback = function(ev)
|
||||
-- Delay to ensure buffer is fully loaded
|
||||
vim.defer_fn(function()
|
||||
M.auto_index_file(ev.buf)
|
||||
end, 100)
|
||||
end,
|
||||
desc = "Auto-index source files with coder companion",
|
||||
})
|
||||
end
|
||||
|
||||
--- Get config with fallback defaults
|
||||
local function get_config_safe()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
-- Return defaults if not initialized
|
||||
if not config or not config.patterns then
|
||||
return {
|
||||
patterns = {
|
||||
open_tag = "/@",
|
||||
close_tag = "@/",
|
||||
file_pattern = "*.coder.*",
|
||||
}
|
||||
}
|
||||
end
|
||||
return config
|
||||
end
|
||||
|
||||
--- Read attached files from prompt content
|
||||
---@param prompt_content string Prompt content
|
||||
---@param base_path string Base path to resolve relative file paths
|
||||
---@return table[] attached_files List of {path, content} tables
|
||||
local function read_attached_files(prompt_content, base_path)
|
||||
local parser = require("codetyper.parser")
|
||||
local file_refs = parser.extract_file_references(prompt_content)
|
||||
local attached = {}
|
||||
local cwd = vim.fn.getcwd()
|
||||
local base_dir = vim.fn.fnamemodify(base_path, ":h")
|
||||
|
||||
for _, ref in ipairs(file_refs) do
|
||||
local file_path = nil
|
||||
|
||||
-- Try resolving relative to cwd first
|
||||
local cwd_path = cwd .. "/" .. ref
|
||||
if utils.file_exists(cwd_path) then
|
||||
file_path = cwd_path
|
||||
else
|
||||
-- Try resolving relative to base file directory
|
||||
local rel_path = base_dir .. "/" .. ref
|
||||
if utils.file_exists(rel_path) then
|
||||
file_path = rel_path
|
||||
end
|
||||
end
|
||||
|
||||
if file_path then
|
||||
local content = utils.read_file(file_path)
|
||||
if content then
|
||||
table.insert(attached, {
|
||||
path = ref,
|
||||
full_path = file_path,
|
||||
content = content,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return attached
|
||||
end
|
||||
|
||||
--- Check if the buffer has a newly closed prompt and auto-process (works on ANY file)
|
||||
function M.check_for_closed_prompt()
|
||||
local config = get_config_safe()
|
||||
local parser = require("codetyper.parser")
|
||||
|
||||
local bufnr = vim.api.nvim_get_current_buf()
|
||||
local current_file = vim.fn.expand("%:p")
|
||||
|
||||
-- Skip if no file
|
||||
if current_file == "" then
|
||||
return
|
||||
end
|
||||
|
||||
-- Get current line
|
||||
local cursor = vim.api.nvim_win_get_cursor(0)
|
||||
local line = cursor[1]
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, line - 1, line, false)
|
||||
|
||||
if #lines == 0 then
|
||||
return
|
||||
end
|
||||
|
||||
local current_line = lines[1]
|
||||
|
||||
-- Check if line contains closing tag
|
||||
if parser.has_closing_tag(current_line, config.patterns.close_tag) then
|
||||
-- Find the complete prompt
|
||||
local prompt = parser.get_last_prompt(bufnr)
|
||||
if prompt and prompt.content and prompt.content ~= "" then
|
||||
-- Generate unique key for this prompt
|
||||
local prompt_key = get_prompt_key(bufnr, prompt)
|
||||
|
||||
-- Check if already processed
|
||||
if processed_prompts[prompt_key] then
|
||||
return
|
||||
end
|
||||
|
||||
-- Mark as processed
|
||||
processed_prompts[prompt_key] = true
|
||||
|
||||
-- Check if scheduler is enabled
|
||||
local codetyper = require("codetyper")
|
||||
local ct_config = codetyper.get_config()
|
||||
local scheduler_enabled = ct_config and ct_config.scheduler and ct_config.scheduler.enabled
|
||||
|
||||
if scheduler_enabled then
|
||||
-- Event-driven: emit to queue
|
||||
vim.schedule(function()
|
||||
local queue = require("codetyper.agent.queue")
|
||||
local patch_mod = require("codetyper.agent.patch")
|
||||
local intent_mod = require("codetyper.agent.intent")
|
||||
local scope_mod = require("codetyper.agent.scope")
|
||||
local logs_panel = require("codetyper.logs_panel")
|
||||
|
||||
-- Open logs panel to show progress
|
||||
logs_panel.ensure_open()
|
||||
|
||||
-- Take buffer snapshot
|
||||
local snapshot = patch_mod.snapshot_buffer(bufnr, {
|
||||
start_line = prompt.start_line,
|
||||
end_line = prompt.end_line,
|
||||
})
|
||||
|
||||
-- Get target path - for coder files, get the target; for regular files, use self
|
||||
local target_path
|
||||
if utils.is_coder_file(current_file) then
|
||||
target_path = utils.get_target_path(current_file)
|
||||
else
|
||||
target_path = current_file
|
||||
end
|
||||
|
||||
-- Read attached files before cleaning
|
||||
local attached_files = read_attached_files(prompt.content, current_file)
|
||||
|
||||
-- Clean prompt content (strip file references)
|
||||
local cleaned = parser.clean_prompt(parser.strip_file_references(prompt.content))
|
||||
|
||||
-- Resolve scope in target file FIRST (need it to adjust intent)
|
||||
local target_bufnr = vim.fn.bufnr(target_path)
|
||||
if target_bufnr == -1 then
|
||||
target_bufnr = bufnr
|
||||
end
|
||||
|
||||
local scope = nil
|
||||
local scope_text = nil
|
||||
local scope_range = nil
|
||||
|
||||
scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1)
|
||||
if scope and scope.type ~= "file" then
|
||||
scope_text = scope.text
|
||||
scope_range = {
|
||||
start_line = scope.range.start_row,
|
||||
end_line = scope.range.end_row,
|
||||
}
|
||||
end
|
||||
|
||||
-- Detect intent from prompt
|
||||
local intent = intent_mod.detect(cleaned)
|
||||
|
||||
-- IMPORTANT: If prompt is inside a function/method and intent is "add",
|
||||
-- override to "complete" since we're completing the function body
|
||||
if scope and (scope.type == "function" or scope.type == "method") then
|
||||
if intent.type == "add" or intent.action == "insert" or intent.action == "append" then
|
||||
-- Override to complete the function instead of adding new code
|
||||
intent = {
|
||||
type = "complete",
|
||||
scope_hint = "function",
|
||||
confidence = intent.confidence,
|
||||
action = "replace",
|
||||
keywords = intent.keywords,
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
-- Determine priority based on intent
|
||||
local priority = 2 -- Normal
|
||||
if intent.type == "fix" or intent.type == "complete" then
|
||||
priority = 1 -- High priority for fixes and completions
|
||||
elseif intent.type == "test" or intent.type == "document" then
|
||||
priority = 3 -- Lower priority for tests and docs
|
||||
end
|
||||
|
||||
-- Enqueue the event
|
||||
queue.enqueue({
|
||||
id = queue.generate_id(),
|
||||
bufnr = bufnr,
|
||||
range = { start_line = prompt.start_line, end_line = prompt.end_line },
|
||||
timestamp = os.clock(),
|
||||
changedtick = snapshot.changedtick,
|
||||
content_hash = snapshot.content_hash,
|
||||
prompt_content = cleaned,
|
||||
target_path = target_path,
|
||||
priority = priority,
|
||||
status = "pending",
|
||||
attempt_count = 0,
|
||||
intent = intent,
|
||||
scope = scope,
|
||||
scope_text = scope_text,
|
||||
scope_range = scope_range,
|
||||
attached_files = attached_files,
|
||||
})
|
||||
|
||||
local scope_info = scope and scope.type ~= "file"
|
||||
and string.format(" [%s: %s]", scope.type, scope.name or "anonymous")
|
||||
or ""
|
||||
utils.notify(
|
||||
string.format("Prompt queued: %s%s", intent.type, scope_info),
|
||||
vim.log.levels.INFO
|
||||
)
|
||||
end)
|
||||
else
|
||||
-- Legacy: direct processing
|
||||
utils.notify("Processing prompt...", vim.log.levels.INFO)
|
||||
vim.schedule(function()
|
||||
vim.cmd("CoderProcess")
|
||||
end)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Check and process all closed prompts in the buffer (works on ANY file)
|
||||
function M.check_all_prompts()
|
||||
local parser = require("codetyper.parser")
|
||||
local bufnr = vim.api.nvim_get_current_buf()
|
||||
local current_file = vim.fn.expand("%:p")
|
||||
|
||||
-- Skip if no file
|
||||
if current_file == "" then
|
||||
return
|
||||
end
|
||||
|
||||
-- Find all prompts in buffer
|
||||
local prompts = parser.find_prompts_in_buffer(bufnr)
|
||||
|
||||
if #prompts == 0 then
|
||||
return
|
||||
end
|
||||
|
||||
-- Check if scheduler is enabled
|
||||
local codetyper = require("codetyper")
|
||||
local ct_config = codetyper.get_config()
|
||||
local scheduler_enabled = ct_config and ct_config.scheduler and ct_config.scheduler.enabled
|
||||
|
||||
if not scheduler_enabled then
|
||||
return
|
||||
end
|
||||
|
||||
for _, prompt in ipairs(prompts) do
|
||||
if prompt.content and prompt.content ~= "" then
|
||||
-- Generate unique key for this prompt
|
||||
local prompt_key = get_prompt_key(bufnr, prompt)
|
||||
|
||||
-- Skip if already processed
|
||||
if processed_prompts[prompt_key] then
|
||||
goto continue
|
||||
end
|
||||
|
||||
-- Mark as processed
|
||||
processed_prompts[prompt_key] = true
|
||||
|
||||
-- Process this prompt
|
||||
vim.schedule(function()
|
||||
local queue = require("codetyper.agent.queue")
|
||||
local patch_mod = require("codetyper.agent.patch")
|
||||
local intent_mod = require("codetyper.agent.intent")
|
||||
local scope_mod = require("codetyper.agent.scope")
|
||||
local logs_panel = require("codetyper.logs_panel")
|
||||
|
||||
-- Open logs panel to show progress
|
||||
logs_panel.ensure_open()
|
||||
|
||||
-- Take buffer snapshot
|
||||
local snapshot = patch_mod.snapshot_buffer(bufnr, {
|
||||
start_line = prompt.start_line,
|
||||
end_line = prompt.end_line,
|
||||
})
|
||||
|
||||
-- Get target path - for coder files, get the target; for regular files, use self
|
||||
local target_path
|
||||
if utils.is_coder_file(current_file) then
|
||||
target_path = utils.get_target_path(current_file)
|
||||
else
|
||||
target_path = current_file
|
||||
end
|
||||
|
||||
-- Read attached files before cleaning
|
||||
local attached_files = read_attached_files(prompt.content, current_file)
|
||||
|
||||
-- Clean prompt content (strip file references)
|
||||
local cleaned = parser.clean_prompt(parser.strip_file_references(prompt.content))
|
||||
|
||||
-- Resolve scope in target file FIRST (need it to adjust intent)
|
||||
local target_bufnr = vim.fn.bufnr(target_path)
|
||||
if target_bufnr == -1 then
|
||||
target_bufnr = bufnr -- Use current buffer if target not loaded
|
||||
end
|
||||
|
||||
local scope = nil
|
||||
local scope_text = nil
|
||||
local scope_range = nil
|
||||
|
||||
scope = scope_mod.resolve_scope(target_bufnr, prompt.start_line, 1)
|
||||
if scope and scope.type ~= "file" then
|
||||
scope_text = scope.text
|
||||
scope_range = {
|
||||
start_line = scope.range.start_row,
|
||||
end_line = scope.range.end_row,
|
||||
}
|
||||
end
|
||||
|
||||
-- Detect intent from prompt
|
||||
local intent = intent_mod.detect(cleaned)
|
||||
|
||||
-- IMPORTANT: If prompt is inside a function/method and intent is "add",
|
||||
-- override to "complete" since we're completing the function body
|
||||
if scope and (scope.type == "function" or scope.type == "method") then
|
||||
if intent.type == "add" or intent.action == "insert" or intent.action == "append" then
|
||||
-- Override to complete the function instead of adding new code
|
||||
intent = {
|
||||
type = "complete",
|
||||
scope_hint = "function",
|
||||
confidence = intent.confidence,
|
||||
action = "replace",
|
||||
keywords = intent.keywords,
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
-- Determine priority based on intent
|
||||
local priority = 2
|
||||
if intent.type == "fix" or intent.type == "complete" then
|
||||
priority = 1
|
||||
elseif intent.type == "test" or intent.type == "document" then
|
||||
priority = 3
|
||||
end
|
||||
|
||||
-- Enqueue the event
|
||||
queue.enqueue({
|
||||
id = queue.generate_id(),
|
||||
bufnr = bufnr,
|
||||
range = { start_line = prompt.start_line, end_line = prompt.end_line },
|
||||
timestamp = os.clock(),
|
||||
changedtick = snapshot.changedtick,
|
||||
content_hash = snapshot.content_hash,
|
||||
prompt_content = cleaned,
|
||||
target_path = target_path,
|
||||
priority = priority,
|
||||
status = "pending",
|
||||
attempt_count = 0,
|
||||
intent = intent,
|
||||
scope = scope,
|
||||
scope_text = scope_text,
|
||||
scope_range = scope_range,
|
||||
attached_files = attached_files,
|
||||
})
|
||||
|
||||
local scope_info = scope and scope.type ~= "file"
|
||||
and string.format(" [%s: %s]", scope.type, scope.name or "anonymous")
|
||||
or ""
|
||||
utils.notify(
|
||||
string.format("Prompt queued: %s%s", intent.type, scope_info),
|
||||
vim.log.levels.INFO
|
||||
)
|
||||
end)
|
||||
|
||||
::continue::
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Check for closed prompt with preference check
|
||||
--- If user hasn't chosen auto/manual mode, ask them first
|
||||
function M.check_for_closed_prompt_with_preference()
|
||||
local preferences = require("codetyper.preferences")
|
||||
local parser = require("codetyper.parser")
|
||||
|
||||
-- First check if there are any prompts to process
|
||||
local bufnr = vim.api.nvim_get_current_buf()
|
||||
local prompts = parser.find_prompts_in_buffer(bufnr)
|
||||
if #prompts == 0 then
|
||||
return
|
||||
end
|
||||
|
||||
-- Check user preference
|
||||
local auto_process = preferences.is_auto_process_enabled()
|
||||
|
||||
if auto_process == nil then
|
||||
-- Not yet decided - ask the user (but only once per session)
|
||||
if not asking_preference then
|
||||
asking_preference = true
|
||||
preferences.ask_auto_process_preference(function(enabled)
|
||||
asking_preference = false
|
||||
if enabled then
|
||||
-- User chose automatic - process now
|
||||
M.check_for_closed_prompt()
|
||||
else
|
||||
-- User chose manual - show hint
|
||||
utils.notify("Use :CoderProcess to process prompt tags manually", vim.log.levels.INFO)
|
||||
end
|
||||
end)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
if auto_process then
|
||||
-- Automatic mode - process prompts
|
||||
M.check_for_closed_prompt()
|
||||
end
|
||||
-- Manual mode - do nothing, user will run :CoderProcess
|
||||
end
|
||||
|
||||
--- Check all prompts with preference check
|
||||
function M.check_all_prompts_with_preference()
|
||||
local preferences = require("codetyper.preferences")
|
||||
local parser = require("codetyper.parser")
|
||||
|
||||
-- First check if there are any prompts to process
|
||||
local bufnr = vim.api.nvim_get_current_buf()
|
||||
local prompts = parser.find_prompts_in_buffer(bufnr)
|
||||
if #prompts == 0 then
|
||||
return
|
||||
end
|
||||
|
||||
-- Check if any prompts are unprocessed
|
||||
local has_unprocessed = false
|
||||
for _, prompt in ipairs(prompts) do
|
||||
local prompt_key = get_prompt_key(bufnr, prompt)
|
||||
if not processed_prompts[prompt_key] then
|
||||
has_unprocessed = true
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if not has_unprocessed then
|
||||
return
|
||||
end
|
||||
|
||||
-- Check user preference
|
||||
local auto_process = preferences.is_auto_process_enabled()
|
||||
|
||||
if auto_process == nil then
|
||||
-- Not yet decided - ask the user (but only once per session)
|
||||
if not asking_preference then
|
||||
asking_preference = true
|
||||
preferences.ask_auto_process_preference(function(enabled)
|
||||
asking_preference = false
|
||||
if enabled then
|
||||
-- User chose automatic - process now
|
||||
M.check_all_prompts()
|
||||
else
|
||||
-- User chose manual - show hint
|
||||
utils.notify("Use :CoderProcess to process prompt tags manually", vim.log.levels.INFO)
|
||||
end
|
||||
end)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
if auto_process then
|
||||
-- Automatic mode - process prompts
|
||||
M.check_all_prompts()
|
||||
end
|
||||
-- Manual mode - do nothing, user will run :CoderProcess
|
||||
end
|
||||
|
||||
--- Reset processed prompts for a buffer (useful for re-processing)
|
||||
---@param bufnr? number Buffer number (default: current)
|
||||
function M.reset_processed(bufnr)
|
||||
bufnr = bufnr or vim.api.nvim_get_current_buf()
|
||||
for key, _ in pairs(processed_prompts) do
|
||||
if key:match("^" .. bufnr .. ":") then
|
||||
processed_prompts[key] = nil
|
||||
end
|
||||
end
|
||||
utils.notify("Prompt history cleared - prompts can be re-processed")
|
||||
end
|
||||
|
||||
--- Track if we already opened the split for this buffer
|
||||
---@type table<number, boolean>
|
||||
local auto_opened_buffers = {}
|
||||
|
||||
--- Auto-open target file when a coder file is opened directly
|
||||
function M.auto_open_target_file()
|
||||
local window = require("codetyper.window")
|
||||
|
||||
-- Skip if split is already open
|
||||
if window.is_open() then
|
||||
return
|
||||
end
|
||||
|
||||
local bufnr = vim.api.nvim_get_current_buf()
|
||||
|
||||
-- Skip if we already handled this buffer
|
||||
if auto_opened_buffers[bufnr] then
|
||||
return
|
||||
end
|
||||
|
||||
local current_file = vim.fn.expand("%:p")
|
||||
|
||||
-- Skip empty paths
|
||||
if not current_file or current_file == "" then
|
||||
return
|
||||
end
|
||||
|
||||
-- Verify it's a coder file
|
||||
if not utils.is_coder_file(current_file) then
|
||||
return
|
||||
end
|
||||
|
||||
-- Skip if we're in a special buffer (nvim-tree, etc.)
|
||||
local buftype = vim.bo[bufnr].buftype
|
||||
if buftype ~= "" then
|
||||
return
|
||||
end
|
||||
|
||||
-- Mark as handled
|
||||
auto_opened_buffers[bufnr] = true
|
||||
|
||||
-- Get the target file path
|
||||
local target_path = utils.get_target_path(current_file)
|
||||
|
||||
-- Check if target file exists
|
||||
if not utils.file_exists(target_path) then
|
||||
utils.notify("Target file not found: " .. vim.fn.fnamemodify(target_path, ":t"), vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
-- Get config with fallback defaults
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
-- Fallback width if config not fully loaded (percentage, e.g., 25 = 25%)
|
||||
local width_pct = (config and config.window and config.window.width) or 25
|
||||
local width = math.ceil(vim.o.columns * (width_pct / 100))
|
||||
|
||||
-- Store current coder window
|
||||
local coder_win = vim.api.nvim_get_current_win()
|
||||
local coder_buf = bufnr
|
||||
|
||||
-- Open target file in a vertical split on the right
|
||||
local ok, err = pcall(function()
|
||||
vim.cmd("vsplit " .. vim.fn.fnameescape(target_path))
|
||||
end)
|
||||
|
||||
if not ok then
|
||||
utils.notify("Failed to open target file: " .. tostring(err), vim.log.levels.ERROR)
|
||||
auto_opened_buffers[bufnr] = nil -- Allow retry
|
||||
return
|
||||
end
|
||||
|
||||
-- Now we're in the target window (right side)
|
||||
local target_win = vim.api.nvim_get_current_win()
|
||||
local target_buf = vim.api.nvim_get_current_buf()
|
||||
|
||||
-- Set the coder window width (left side)
|
||||
pcall(vim.api.nvim_win_set_width, coder_win, width)
|
||||
|
||||
-- Update window module state
|
||||
window._coder_win = coder_win
|
||||
window._coder_buf = coder_buf
|
||||
window._target_win = target_win
|
||||
window._target_buf = target_buf
|
||||
|
||||
-- Set up window options for coder window
|
||||
pcall(function()
|
||||
vim.wo[coder_win].number = true
|
||||
vim.wo[coder_win].relativenumber = true
|
||||
vim.wo[coder_win].signcolumn = "yes"
|
||||
end)
|
||||
|
||||
utils.notify("Opened target: " .. vim.fn.fnamemodify(target_path, ":t"))
|
||||
end
|
||||
|
||||
--- Clear auto-opened tracking for a buffer
|
||||
---@param bufnr number Buffer number
|
||||
function M.clear_auto_opened(bufnr)
|
||||
auto_opened_buffers[bufnr] = nil
|
||||
end
|
||||
|
||||
--- Set appropriate filetype for coder files
|
||||
function M.set_coder_filetype()
|
||||
local filepath = vim.fn.expand("%:p")
|
||||
|
||||
-- Extract the actual extension (e.g., index.coder.ts -> ts)
|
||||
local ext = filepath:match("%.coder%.(%w+)$")
|
||||
|
||||
if ext then
|
||||
-- Map extension to filetype
|
||||
local ft_map = {
|
||||
ts = "typescript",
|
||||
tsx = "typescriptreact",
|
||||
js = "javascript",
|
||||
jsx = "javascriptreact",
|
||||
py = "python",
|
||||
lua = "lua",
|
||||
go = "go",
|
||||
rs = "rust",
|
||||
rb = "ruby",
|
||||
java = "java",
|
||||
c = "c",
|
||||
cpp = "cpp",
|
||||
cs = "cs",
|
||||
json = "json",
|
||||
yaml = "yaml",
|
||||
yml = "yaml",
|
||||
md = "markdown",
|
||||
html = "html",
|
||||
css = "css",
|
||||
scss = "scss",
|
||||
vue = "vue",
|
||||
svelte = "svelte",
|
||||
}
|
||||
|
||||
local filetype = ft_map[ext] or ext
|
||||
vim.bo.filetype = filetype
|
||||
end
|
||||
end
|
||||
|
||||
--- Clear all autocommands
|
||||
function M.clear()
|
||||
vim.api.nvim_del_augroup_by_name(AUGROUP)
|
||||
end
|
||||
|
||||
--- Track buffers that have been auto-indexed
|
||||
---@type table<number, boolean>
|
||||
local auto_indexed_buffers = {}
|
||||
|
||||
--- Supported file extensions for auto-indexing
|
||||
local supported_extensions = {
|
||||
"ts", "tsx", "js", "jsx", "py", "lua", "go", "rs", "rb",
|
||||
"java", "c", "cpp", "cs", "json", "yaml", "yml", "md",
|
||||
"html", "css", "scss", "vue", "svelte", "php", "sh", "zsh",
|
||||
}
|
||||
|
||||
--- Check if extension is supported
|
||||
---@param ext string File extension
|
||||
---@return boolean
|
||||
local function is_supported_extension(ext)
|
||||
for _, supported in ipairs(supported_extensions) do
|
||||
if ext == supported then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Auto-index a file by creating/opening its coder companion
|
||||
---@param bufnr number Buffer number
|
||||
function M.auto_index_file(bufnr)
|
||||
-- Skip if buffer is invalid
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return
|
||||
end
|
||||
|
||||
-- Skip if already indexed
|
||||
if auto_indexed_buffers[bufnr] then
|
||||
return
|
||||
end
|
||||
|
||||
-- Get file path
|
||||
local filepath = vim.api.nvim_buf_get_name(bufnr)
|
||||
if not filepath or filepath == "" then
|
||||
return
|
||||
end
|
||||
|
||||
-- Skip coder files
|
||||
if utils.is_coder_file(filepath) then
|
||||
return
|
||||
end
|
||||
|
||||
-- Skip special buffers
|
||||
local buftype = vim.bo[bufnr].buftype
|
||||
if buftype ~= "" then
|
||||
return
|
||||
end
|
||||
|
||||
-- Skip unsupported file types
|
||||
local ext = vim.fn.fnamemodify(filepath, ":e")
|
||||
if ext == "" or not is_supported_extension(ext) then
|
||||
return
|
||||
end
|
||||
|
||||
-- Skip if auto_index is disabled in config
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
if config and config.auto_index == false then
|
||||
return
|
||||
end
|
||||
|
||||
-- Mark as indexed
|
||||
auto_indexed_buffers[bufnr] = true
|
||||
|
||||
-- Get coder companion path
|
||||
local coder_path = utils.get_coder_path(filepath)
|
||||
|
||||
-- Check if coder file already exists
|
||||
local coder_exists = utils.file_exists(coder_path)
|
||||
|
||||
-- Create coder file with template if it doesn't exist
|
||||
if not coder_exists then
|
||||
local filename = vim.fn.fnamemodify(filepath, ":t")
|
||||
local template = string.format(
|
||||
[[-- Coder companion for %s
|
||||
-- Use /@ @/ tags to write pseudo-code prompts
|
||||
-- Example:
|
||||
-- /@
|
||||
-- Add a function that validates user input
|
||||
-- - Check for empty strings
|
||||
-- - Validate email format
|
||||
-- @/
|
||||
|
||||
]],
|
||||
filename
|
||||
)
|
||||
utils.write_file(coder_path, template)
|
||||
end
|
||||
|
||||
-- Notify user about the coder companion
|
||||
local coder_filename = vim.fn.fnamemodify(coder_path, ":t")
|
||||
if coder_exists then
|
||||
utils.notify("Coder companion available: " .. coder_filename, vim.log.levels.DEBUG)
|
||||
else
|
||||
utils.notify("Created coder companion: " .. coder_filename, vim.log.levels.INFO)
|
||||
end
|
||||
end
|
||||
|
||||
--- Open the coder companion for the current file
|
||||
---@param open_split? boolean Whether to open in split view (default: true)
|
||||
function M.open_coder_companion(open_split)
|
||||
open_split = open_split ~= false -- Default to true
|
||||
|
||||
local filepath = vim.fn.expand("%:p")
|
||||
if not filepath or filepath == "" then
|
||||
utils.notify("No file open", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
if utils.is_coder_file(filepath) then
|
||||
utils.notify("Already in coder file", vim.log.levels.INFO)
|
||||
return
|
||||
end
|
||||
|
||||
local coder_path = utils.get_coder_path(filepath)
|
||||
|
||||
-- Create if it doesn't exist
|
||||
if not utils.file_exists(coder_path) then
|
||||
local filename = vim.fn.fnamemodify(filepath, ":t")
|
||||
local ext = vim.fn.fnamemodify(filepath, ":e")
|
||||
local comment_prefix = "--"
|
||||
if vim.tbl_contains({ "js", "jsx", "ts", "tsx", "java", "c", "cpp", "cs", "go", "rs", "php" }, ext) then
|
||||
comment_prefix = "//"
|
||||
elseif vim.tbl_contains({ "py", "sh", "zsh", "yaml", "yml" }, ext) then
|
||||
comment_prefix = "#"
|
||||
elseif vim.tbl_contains({ "html", "md" }, ext) then
|
||||
comment_prefix = "<!--"
|
||||
end
|
||||
|
||||
local close_comment = comment_prefix == "<!--" and " -->" or ""
|
||||
local template = string.format(
|
||||
[[%s Coder companion for %s%s
|
||||
%s Use /@ @/ tags to write pseudo-code prompts%s
|
||||
%s Example:%s
|
||||
%s /@%s
|
||||
%s Add a function that validates user input%s
|
||||
%s - Check for empty strings%s
|
||||
%s - Validate email format%s
|
||||
%s @/%s
|
||||
|
||||
]],
|
||||
comment_prefix, filename, close_comment,
|
||||
comment_prefix, close_comment,
|
||||
comment_prefix, close_comment,
|
||||
comment_prefix, close_comment,
|
||||
comment_prefix, close_comment,
|
||||
comment_prefix, close_comment,
|
||||
comment_prefix, close_comment,
|
||||
comment_prefix, close_comment
|
||||
)
|
||||
utils.write_file(coder_path, template)
|
||||
end
|
||||
|
||||
if open_split then
|
||||
-- Use the window module to open split view
|
||||
local window = require("codetyper.window")
|
||||
window.open_split(coder_path, filepath)
|
||||
else
|
||||
-- Just open the coder file
|
||||
vim.cmd("edit " .. vim.fn.fnameescape(coder_path))
|
||||
end
|
||||
end
|
||||
|
||||
--- Clear auto-indexed tracking for a buffer
|
||||
---@param bufnr number Buffer number
|
||||
function M.clear_auto_indexed(bufnr)
|
||||
auto_indexed_buffers[bufnr] = nil
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,957 +0,0 @@
|
||||
---@mod codetyper.commands Command definitions for Codetyper.nvim
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
local window = require("codetyper.window")
|
||||
|
||||
--- Open coder view for current file or select one
|
||||
---@param opts? table Command options
|
||||
local function cmd_open(opts)
|
||||
opts = opts or {}
|
||||
|
||||
local current_file = vim.fn.expand("%:p")
|
||||
|
||||
-- If no file is open, prompt user to select one
|
||||
if current_file == "" or vim.bo.buftype ~= "" then
|
||||
-- Use telescope or vim.ui.select to pick a file
|
||||
if pcall(require, "telescope") then
|
||||
require("telescope.builtin").find_files({
|
||||
prompt_title = "Select file for Coder",
|
||||
attach_mappings = function(prompt_bufnr, map)
|
||||
local actions = require("telescope.actions")
|
||||
local action_state = require("telescope.actions.state")
|
||||
|
||||
actions.select_default:replace(function()
|
||||
actions.close(prompt_bufnr)
|
||||
local selection = action_state.get_selected_entry()
|
||||
if selection then
|
||||
local target_path = selection.path or selection[1]
|
||||
local coder_path = utils.get_coder_path(target_path)
|
||||
window.open_split(target_path, coder_path)
|
||||
end
|
||||
end)
|
||||
return true
|
||||
end,
|
||||
})
|
||||
else
|
||||
-- Fallback to input prompt
|
||||
vim.ui.input({ prompt = "Enter file path: " }, function(input)
|
||||
if input and input ~= "" then
|
||||
local target_path = vim.fn.fnamemodify(input, ":p")
|
||||
local coder_path = utils.get_coder_path(target_path)
|
||||
window.open_split(target_path, coder_path)
|
||||
end
|
||||
end)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
local target_path, coder_path
|
||||
|
||||
-- Check if current file is a coder file
|
||||
if utils.is_coder_file(current_file) then
|
||||
coder_path = current_file
|
||||
target_path = utils.get_target_path(current_file)
|
||||
else
|
||||
target_path = current_file
|
||||
coder_path = utils.get_coder_path(current_file)
|
||||
end
|
||||
|
||||
window.open_split(target_path, coder_path)
|
||||
end
|
||||
|
||||
--- Close coder view
|
||||
local function cmd_close()
|
||||
window.close_split()
|
||||
end
|
||||
|
||||
--- Toggle coder view
|
||||
local function cmd_toggle()
|
||||
local current_file = vim.fn.expand("%:p")
|
||||
|
||||
if current_file == "" then
|
||||
utils.notify("No file in current buffer", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
local target_path, coder_path
|
||||
|
||||
if utils.is_coder_file(current_file) then
|
||||
coder_path = current_file
|
||||
target_path = utils.get_target_path(current_file)
|
||||
else
|
||||
target_path = current_file
|
||||
coder_path = utils.get_coder_path(current_file)
|
||||
end
|
||||
|
||||
window.toggle_split(target_path, coder_path)
|
||||
end
|
||||
|
||||
--- Build enhanced user prompt with context
|
||||
---@param clean_prompt string The cleaned user prompt
|
||||
---@param context table Context information
|
||||
---@return string Enhanced prompt
|
||||
local function build_user_prompt(clean_prompt, context)
|
||||
local enhanced = "TASK: " .. clean_prompt .. "\n\n"
|
||||
|
||||
enhanced = enhanced .. "REQUIREMENTS:\n"
|
||||
enhanced = enhanced .. "- Generate ONLY " .. (context.language or "code") .. " code\n"
|
||||
enhanced = enhanced .. "- NO markdown code blocks (no ```)\n"
|
||||
enhanced = enhanced .. "- NO explanations or comments about what you did\n"
|
||||
enhanced = enhanced .. "- Match the coding style of the existing file exactly\n"
|
||||
enhanced = enhanced .. "- Output must be ready to insert directly into the file\n"
|
||||
|
||||
return enhanced
|
||||
end
|
||||
|
||||
--- Process prompt at cursor and generate code
|
||||
local function cmd_process()
|
||||
local parser = require("codetyper.parser")
|
||||
local llm = require("codetyper.llm")
|
||||
|
||||
local bufnr = vim.api.nvim_get_current_buf()
|
||||
local current_file = vim.fn.expand("%:p")
|
||||
|
||||
if not utils.is_coder_file(current_file) then
|
||||
utils.notify("Not a coder file. Use *.coder.* files", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
local prompt = parser.get_last_prompt(bufnr)
|
||||
if not prompt then
|
||||
utils.notify("No prompt found. Use /@ your prompt @/", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
local target_path = utils.get_target_path(current_file)
|
||||
local prompt_type = parser.detect_prompt_type(prompt.content)
|
||||
local context = llm.build_context(target_path, prompt_type)
|
||||
local clean_prompt = parser.clean_prompt(prompt.content)
|
||||
|
||||
-- Build enhanced prompt with explicit instructions
|
||||
local enhanced_prompt = build_user_prompt(clean_prompt, context)
|
||||
|
||||
utils.notify("Processing: " .. clean_prompt:sub(1, 50) .. "...", vim.log.levels.INFO)
|
||||
|
||||
llm.generate(enhanced_prompt, context, function(response, err)
|
||||
if err then
|
||||
utils.notify("Generation failed: " .. err, vim.log.levels.ERROR)
|
||||
return
|
||||
end
|
||||
|
||||
if response then
|
||||
-- Inject code into target file
|
||||
local inject = require("codetyper.inject")
|
||||
inject.inject_code(target_path, response, prompt_type)
|
||||
utils.notify("Code generated and injected!", vim.log.levels.INFO)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Show plugin status
|
||||
local function cmd_status()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
local tree = require("codetyper.tree")
|
||||
|
||||
local stats = tree.get_stats()
|
||||
|
||||
local status = {
|
||||
"Codetyper.nvim Status",
|
||||
"====================",
|
||||
"",
|
||||
"Provider: " .. config.llm.provider,
|
||||
}
|
||||
|
||||
if config.llm.provider == "claude" then
|
||||
local has_key = (config.llm.claude.api_key or vim.env.ANTHROPIC_API_KEY) ~= nil
|
||||
table.insert(status, "Claude API Key: " .. (has_key and "configured" or "NOT SET"))
|
||||
table.insert(status, "Claude Model: " .. config.llm.claude.model)
|
||||
else
|
||||
table.insert(status, "Ollama Host: " .. config.llm.ollama.host)
|
||||
table.insert(status, "Ollama Model: " .. config.llm.ollama.model)
|
||||
end
|
||||
|
||||
table.insert(status, "")
|
||||
table.insert(status, "Window Position: " .. config.window.position)
|
||||
table.insert(status, "Window Width: " .. tostring(config.window.width * 100) .. "%")
|
||||
table.insert(status, "")
|
||||
table.insert(status, "View Open: " .. (window.is_open() and "yes" or "no"))
|
||||
table.insert(status, "")
|
||||
table.insert(status, "Project Stats:")
|
||||
table.insert(status, " Files: " .. stats.files)
|
||||
table.insert(status, " Directories: " .. stats.directories)
|
||||
table.insert(status, " Tree Log: " .. (tree.get_tree_log_path() or "N/A"))
|
||||
|
||||
utils.notify(table.concat(status, "\n"))
|
||||
end
|
||||
|
||||
--- Refresh tree.log manually
|
||||
local function cmd_tree()
|
||||
local tree = require("codetyper.tree")
|
||||
if tree.update_tree_log() then
|
||||
utils.notify("Tree log updated: " .. tree.get_tree_log_path())
|
||||
else
|
||||
utils.notify("Failed to update tree log", vim.log.levels.ERROR)
|
||||
end
|
||||
end
|
||||
|
||||
--- Open tree.log file
|
||||
local function cmd_tree_view()
|
||||
local tree = require("codetyper.tree")
|
||||
local tree_log_path = tree.get_tree_log_path()
|
||||
|
||||
if not tree_log_path then
|
||||
utils.notify("Could not find tree.log", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
-- Ensure tree is up to date
|
||||
tree.update_tree_log()
|
||||
|
||||
-- Open in a new split
|
||||
vim.cmd("vsplit " .. vim.fn.fnameescape(tree_log_path))
|
||||
vim.bo.readonly = true
|
||||
vim.bo.modifiable = false
|
||||
end
|
||||
|
||||
--- Reset processed prompts to allow re-processing
|
||||
local function cmd_reset()
|
||||
local autocmds = require("codetyper.autocmds")
|
||||
autocmds.reset_processed()
|
||||
end
|
||||
|
||||
--- Force update gitignore
|
||||
local function cmd_gitignore()
|
||||
local gitignore = require("codetyper.gitignore")
|
||||
gitignore.force_update()
|
||||
end
|
||||
|
||||
--- Open ask panel
|
||||
local function cmd_ask()
|
||||
local ask = require("codetyper.ask")
|
||||
ask.open()
|
||||
end
|
||||
|
||||
--- Close ask panel
|
||||
local function cmd_ask_close()
|
||||
local ask = require("codetyper.ask")
|
||||
ask.close()
|
||||
end
|
||||
|
||||
--- Toggle ask panel
|
||||
local function cmd_ask_toggle()
|
||||
local ask = require("codetyper.ask")
|
||||
ask.toggle()
|
||||
end
|
||||
|
||||
--- Clear ask history
|
||||
local function cmd_ask_clear()
|
||||
local ask = require("codetyper.ask")
|
||||
ask.clear_history()
|
||||
end
|
||||
|
||||
--- Open agent panel
|
||||
local function cmd_agent()
|
||||
local agent_ui = require("codetyper.agent.ui")
|
||||
agent_ui.open()
|
||||
end
|
||||
|
||||
--- Close agent panel
|
||||
local function cmd_agent_close()
|
||||
local agent_ui = require("codetyper.agent.ui")
|
||||
agent_ui.close()
|
||||
end
|
||||
|
||||
--- Toggle agent panel
|
||||
local function cmd_agent_toggle()
|
||||
local agent_ui = require("codetyper.agent.ui")
|
||||
agent_ui.toggle()
|
||||
end
|
||||
|
||||
--- Stop running agent
|
||||
local function cmd_agent_stop()
|
||||
local agent = require("codetyper.agent")
|
||||
if agent.is_running() then
|
||||
agent.stop()
|
||||
utils.notify("Agent stopped")
|
||||
else
|
||||
utils.notify("No agent running", vim.log.levels.INFO)
|
||||
end
|
||||
end
|
||||
|
||||
--- Show chat type switcher modal (Ask/Agent)
|
||||
local function cmd_type_toggle()
|
||||
local switcher = require("codetyper.chat_switcher")
|
||||
switcher.show()
|
||||
end
|
||||
|
||||
--- Toggle logs panel
|
||||
local function cmd_logs_toggle()
|
||||
local logs_panel = require("codetyper.logs_panel")
|
||||
logs_panel.toggle()
|
||||
end
|
||||
|
||||
--- Show scheduler status and queue info
|
||||
local function cmd_queue_status()
|
||||
local scheduler = require("codetyper.agent.scheduler")
|
||||
local queue = require("codetyper.agent.queue")
|
||||
local parser = require("codetyper.parser")
|
||||
|
||||
local status = scheduler.status()
|
||||
local bufnr = vim.api.nvim_get_current_buf()
|
||||
local filepath = vim.fn.expand("%:p")
|
||||
|
||||
local lines = {
|
||||
"Scheduler Status",
|
||||
"================",
|
||||
"",
|
||||
"Running: " .. (status.running and "yes" or "NO"),
|
||||
"Paused: " .. (status.paused and "yes" or "no"),
|
||||
"Active Workers: " .. status.active_workers,
|
||||
"",
|
||||
"Queue Stats:",
|
||||
" Pending: " .. status.queue_stats.pending,
|
||||
" Processing: " .. status.queue_stats.processing,
|
||||
" Completed: " .. status.queue_stats.completed,
|
||||
" Cancelled: " .. status.queue_stats.cancelled,
|
||||
"",
|
||||
}
|
||||
|
||||
-- Check current buffer for prompts
|
||||
if filepath ~= "" then
|
||||
local prompts = parser.find_prompts_in_buffer(bufnr)
|
||||
table.insert(lines, "Current Buffer: " .. vim.fn.fnamemodify(filepath, ":t"))
|
||||
table.insert(lines, " Prompts found: " .. #prompts)
|
||||
for i, p in ipairs(prompts) do
|
||||
local preview = p.content:sub(1, 30):gsub("\n", " ")
|
||||
table.insert(lines, string.format(" %d. Line %d: %s...", i, p.start_line, preview))
|
||||
end
|
||||
end
|
||||
|
||||
utils.notify(table.concat(lines, "\n"))
|
||||
end
|
||||
|
||||
--- Manually trigger queue processing for current buffer
|
||||
local function cmd_queue_process()
|
||||
local autocmds = require("codetyper.autocmds")
|
||||
local logs_panel = require("codetyper.logs_panel")
|
||||
|
||||
-- Open logs panel to show progress
|
||||
logs_panel.open()
|
||||
|
||||
-- Check all prompts in current buffer
|
||||
autocmds.check_all_prompts()
|
||||
|
||||
utils.notify("Triggered queue processing for current buffer")
|
||||
end
|
||||
|
||||
--- Switch focus between coder and target windows
|
||||
local function cmd_focus()
|
||||
if not window.is_open() then
|
||||
utils.notify("Coder view not open", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
local current_win = vim.api.nvim_get_current_win()
|
||||
if current_win == window.get_coder_win() then
|
||||
window.focus_target()
|
||||
else
|
||||
window.focus_coder()
|
||||
end
|
||||
end
|
||||
|
||||
--- Transform inline /@ @/ tags in current file
|
||||
--- Works on ANY file, not just .coder.* files
|
||||
local function cmd_transform()
|
||||
local parser = require("codetyper.parser")
|
||||
local llm = require("codetyper.llm")
|
||||
local logs_panel = require("codetyper.logs_panel")
|
||||
local logs = require("codetyper.agent.logs")
|
||||
|
||||
local bufnr = vim.api.nvim_get_current_buf()
|
||||
local filepath = vim.fn.expand("%:p")
|
||||
|
||||
if filepath == "" then
|
||||
utils.notify("No file in current buffer", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
-- Find all prompts in the current buffer
|
||||
local prompts = parser.find_prompts_in_buffer(bufnr)
|
||||
|
||||
if #prompts == 0 then
|
||||
utils.notify("No /@ @/ tags found in current file", vim.log.levels.INFO)
|
||||
return
|
||||
end
|
||||
|
||||
-- Open the logs panel to show generation progress
|
||||
logs_panel.open()
|
||||
logs.info("Transform started: " .. #prompts .. " prompt(s)")
|
||||
|
||||
utils.notify("Found " .. #prompts .. " prompt(s) to transform...", vim.log.levels.INFO)
|
||||
|
||||
-- Build context for this file
|
||||
local ext = vim.fn.fnamemodify(filepath, ":e")
|
||||
local context = llm.build_context(filepath, "code_generation")
|
||||
|
||||
-- Process prompts in reverse order (bottom to top) to maintain line numbers
|
||||
local sorted_prompts = {}
|
||||
for i = #prompts, 1, -1 do
|
||||
table.insert(sorted_prompts, prompts[i])
|
||||
end
|
||||
|
||||
-- Track how many are being processed
|
||||
local pending = #sorted_prompts
|
||||
local completed = 0
|
||||
local errors = 0
|
||||
|
||||
-- Process each prompt
|
||||
for _, prompt in ipairs(sorted_prompts) do
|
||||
local clean_prompt = parser.clean_prompt(prompt.content)
|
||||
local prompt_type = parser.detect_prompt_type(prompt.content)
|
||||
|
||||
-- Build enhanced user prompt
|
||||
local enhanced_prompt = "TASK: " .. clean_prompt .. "\n\n"
|
||||
enhanced_prompt = enhanced_prompt .. "REQUIREMENTS:\n"
|
||||
enhanced_prompt = enhanced_prompt .. "- Generate ONLY " .. (context.language or "code") .. " code\n"
|
||||
enhanced_prompt = enhanced_prompt .. "- NO markdown code blocks (no ```)\n"
|
||||
enhanced_prompt = enhanced_prompt .. "- NO explanations or comments about what you did\n"
|
||||
enhanced_prompt = enhanced_prompt .. "- Match the coding style of the existing file exactly\n"
|
||||
enhanced_prompt = enhanced_prompt .. "- Output must be ready to insert directly into the file\n"
|
||||
|
||||
logs.info("Processing: " .. clean_prompt:sub(1, 40) .. "...")
|
||||
utils.notify("Processing: " .. clean_prompt:sub(1, 40) .. "...", vim.log.levels.INFO)
|
||||
|
||||
-- Generate code for this prompt
|
||||
llm.generate(enhanced_prompt, context, function(response, err)
|
||||
if err then
|
||||
logs.error("Failed: " .. err)
|
||||
utils.notify("Failed: " .. err, vim.log.levels.ERROR)
|
||||
errors = errors + 1
|
||||
elseif response then
|
||||
-- Replace the prompt tag with generated code
|
||||
vim.schedule(function()
|
||||
-- Get current buffer lines
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
|
||||
-- Calculate the exact range to replace
|
||||
local start_line = prompt.start_line
|
||||
local end_line = prompt.end_line
|
||||
|
||||
-- Find the full lines containing the tags
|
||||
local start_line_content = lines[start_line] or ""
|
||||
local end_line_content = lines[end_line] or ""
|
||||
|
||||
-- Check if there's content before the opening tag on the same line
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
local before_tag = ""
|
||||
local after_tag = ""
|
||||
|
||||
local open_pos = start_line_content:find(utils.escape_pattern(config.patterns.open_tag))
|
||||
if open_pos and open_pos > 1 then
|
||||
before_tag = start_line_content:sub(1, open_pos - 1)
|
||||
end
|
||||
|
||||
local close_pos = end_line_content:find(utils.escape_pattern(config.patterns.close_tag))
|
||||
if close_pos then
|
||||
local after_close = close_pos + #config.patterns.close_tag
|
||||
if after_close <= #end_line_content then
|
||||
after_tag = end_line_content:sub(after_close)
|
||||
end
|
||||
end
|
||||
|
||||
-- Build the replacement lines
|
||||
local replacement_lines = vim.split(response, "\n", { plain = true })
|
||||
|
||||
-- Add before/after content if any
|
||||
if before_tag ~= "" and #replacement_lines > 0 then
|
||||
replacement_lines[1] = before_tag .. replacement_lines[1]
|
||||
end
|
||||
if after_tag ~= "" and #replacement_lines > 0 then
|
||||
replacement_lines[#replacement_lines] = replacement_lines[#replacement_lines] .. after_tag
|
||||
end
|
||||
|
||||
-- Replace the lines in buffer
|
||||
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, replacement_lines)
|
||||
|
||||
completed = completed + 1
|
||||
if completed + errors >= pending then
|
||||
local msg = "Transform complete: " .. completed .. " succeeded, " .. errors .. " failed"
|
||||
logs.info(msg)
|
||||
utils.notify(msg, errors > 0 and vim.log.levels.WARN or vim.log.levels.INFO)
|
||||
end
|
||||
end)
|
||||
end
|
||||
end)
|
||||
end
|
||||
end
|
||||
|
||||
--- Transform prompts within a line range (for visual selection)
|
||||
---@param start_line number Start line (1-indexed)
|
||||
---@param end_line number End line (1-indexed)
|
||||
local function cmd_transform_range(start_line, end_line)
|
||||
local parser = require("codetyper.parser")
|
||||
local llm = require("codetyper.llm")
|
||||
local logs_panel = require("codetyper.logs_panel")
|
||||
local logs = require("codetyper.agent.logs")
|
||||
|
||||
local bufnr = vim.api.nvim_get_current_buf()
|
||||
local filepath = vim.fn.expand("%:p")
|
||||
|
||||
if filepath == "" then
|
||||
utils.notify("No file in current buffer", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
-- Find all prompts in the current buffer
|
||||
local all_prompts = parser.find_prompts_in_buffer(bufnr)
|
||||
|
||||
-- Filter prompts that are within the selected range
|
||||
local prompts = {}
|
||||
for _, prompt in ipairs(all_prompts) do
|
||||
if prompt.start_line >= start_line and prompt.end_line <= end_line then
|
||||
table.insert(prompts, prompt)
|
||||
end
|
||||
end
|
||||
|
||||
if #prompts == 0 then
|
||||
utils.notify("No /@ @/ tags found in selection (lines " .. start_line .. "-" .. end_line .. ")", vim.log.levels.INFO)
|
||||
return
|
||||
end
|
||||
|
||||
-- Open the logs panel to show generation progress
|
||||
logs_panel.open()
|
||||
logs.info("Transform selection: " .. #prompts .. " prompt(s)")
|
||||
|
||||
utils.notify("Found " .. #prompts .. " prompt(s) in selection to transform...", vim.log.levels.INFO)
|
||||
|
||||
-- Build context for this file
|
||||
local context = llm.build_context(filepath, "code_generation")
|
||||
|
||||
-- Process prompts in reverse order (bottom to top) to maintain line numbers
|
||||
local sorted_prompts = {}
|
||||
for i = #prompts, 1, -1 do
|
||||
table.insert(sorted_prompts, prompts[i])
|
||||
end
|
||||
|
||||
local pending = #sorted_prompts
|
||||
local completed = 0
|
||||
local errors = 0
|
||||
|
||||
for _, prompt in ipairs(sorted_prompts) do
|
||||
local clean_prompt = parser.clean_prompt(prompt.content)
|
||||
|
||||
local enhanced_prompt = "TASK: " .. clean_prompt .. "\n\n"
|
||||
enhanced_prompt = enhanced_prompt .. "REQUIREMENTS:\n"
|
||||
enhanced_prompt = enhanced_prompt .. "- Generate ONLY " .. (context.language or "code") .. " code\n"
|
||||
enhanced_prompt = enhanced_prompt .. "- NO markdown code blocks (no ```)\n"
|
||||
enhanced_prompt = enhanced_prompt .. "- NO explanations or comments about what you did\n"
|
||||
enhanced_prompt = enhanced_prompt .. "- Match the coding style of the existing file exactly\n"
|
||||
enhanced_prompt = enhanced_prompt .. "- Output must be ready to insert directly into the file\n"
|
||||
|
||||
logs.info("Processing: " .. clean_prompt:sub(1, 40) .. "...")
|
||||
utils.notify("Processing: " .. clean_prompt:sub(1, 40) .. "...", vim.log.levels.INFO)
|
||||
|
||||
llm.generate(enhanced_prompt, context, function(response, err)
|
||||
if err then
|
||||
logs.error("Failed: " .. err)
|
||||
utils.notify("Failed: " .. err, vim.log.levels.ERROR)
|
||||
errors = errors + 1
|
||||
elseif response then
|
||||
vim.schedule(function()
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local p_start_line = prompt.start_line
|
||||
local p_end_line = prompt.end_line
|
||||
|
||||
local start_line_content = lines[p_start_line] or ""
|
||||
local end_line_content = lines[p_end_line] or ""
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
local before_tag = ""
|
||||
local after_tag = ""
|
||||
|
||||
local open_pos = start_line_content:find(utils.escape_pattern(config.patterns.open_tag))
|
||||
if open_pos and open_pos > 1 then
|
||||
before_tag = start_line_content:sub(1, open_pos - 1)
|
||||
end
|
||||
|
||||
local close_pos = end_line_content:find(utils.escape_pattern(config.patterns.close_tag))
|
||||
if close_pos then
|
||||
local after_close = close_pos + #config.patterns.close_tag
|
||||
if after_close <= #end_line_content then
|
||||
after_tag = end_line_content:sub(after_close)
|
||||
end
|
||||
end
|
||||
|
||||
local replacement_lines = vim.split(response, "\n", { plain = true })
|
||||
|
||||
if before_tag ~= "" and #replacement_lines > 0 then
|
||||
replacement_lines[1] = before_tag .. replacement_lines[1]
|
||||
end
|
||||
if after_tag ~= "" and #replacement_lines > 0 then
|
||||
replacement_lines[#replacement_lines] = replacement_lines[#replacement_lines] .. after_tag
|
||||
end
|
||||
|
||||
vim.api.nvim_buf_set_lines(bufnr, p_start_line - 1, p_end_line, false, replacement_lines)
|
||||
|
||||
completed = completed + 1
|
||||
if completed + errors >= pending then
|
||||
local msg = "Transform complete: " .. completed .. " succeeded, " .. errors .. " failed"
|
||||
logs.info(msg)
|
||||
utils.notify(msg, errors > 0 and vim.log.levels.WARN or vim.log.levels.INFO)
|
||||
end
|
||||
end)
|
||||
end
|
||||
end)
|
||||
end
|
||||
end
|
||||
|
||||
--- Command wrapper for visual selection transform
|
||||
local function cmd_transform_visual()
|
||||
-- Get visual selection marks
|
||||
local start_line = vim.fn.line("'<")
|
||||
local end_line = vim.fn.line("'>")
|
||||
cmd_transform_range(start_line, end_line)
|
||||
end
|
||||
|
||||
--- Transform a single prompt at cursor position
|
||||
local function cmd_transform_at_cursor()
|
||||
local parser = require("codetyper.parser")
|
||||
local llm = require("codetyper.llm")
|
||||
local logs_panel = require("codetyper.logs_panel")
|
||||
local logs = require("codetyper.agent.logs")
|
||||
|
||||
local bufnr = vim.api.nvim_get_current_buf()
|
||||
local filepath = vim.fn.expand("%:p")
|
||||
|
||||
if filepath == "" then
|
||||
utils.notify("No file in current buffer", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
-- Find prompt at cursor
|
||||
local prompt = parser.get_prompt_at_cursor(bufnr)
|
||||
|
||||
if not prompt then
|
||||
utils.notify("No /@ @/ tag at cursor position", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
-- Open the logs panel to show generation progress
|
||||
logs_panel.open()
|
||||
|
||||
local clean_prompt = parser.clean_prompt(prompt.content)
|
||||
local context = llm.build_context(filepath, "code_generation")
|
||||
|
||||
logs.info("Transform cursor: " .. clean_prompt:sub(1, 40) .. "...")
|
||||
|
||||
-- Build enhanced user prompt
|
||||
local enhanced_prompt = "TASK: " .. clean_prompt .. "\n\n"
|
||||
enhanced_prompt = enhanced_prompt .. "REQUIREMENTS:\n"
|
||||
enhanced_prompt = enhanced_prompt .. "- Generate ONLY " .. (context.language or "code") .. " code\n"
|
||||
enhanced_prompt = enhanced_prompt .. "- NO markdown code blocks (no ```)\n"
|
||||
enhanced_prompt = enhanced_prompt .. "- NO explanations or comments about what you did\n"
|
||||
enhanced_prompt = enhanced_prompt .. "- Match the coding style of the existing file exactly\n"
|
||||
enhanced_prompt = enhanced_prompt .. "- Output must be ready to insert directly into the file\n"
|
||||
|
||||
utils.notify("Transforming: " .. clean_prompt:sub(1, 40) .. "...", vim.log.levels.INFO)
|
||||
|
||||
llm.generate(enhanced_prompt, context, function(response, err)
|
||||
if err then
|
||||
logs.error("Transform failed: " .. err)
|
||||
utils.notify("Transform failed: " .. err, vim.log.levels.ERROR)
|
||||
return
|
||||
end
|
||||
|
||||
if response then
|
||||
vim.schedule(function()
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local start_line = prompt.start_line
|
||||
local end_line = prompt.end_line
|
||||
|
||||
local start_line_content = lines[start_line] or ""
|
||||
local end_line_content = lines[end_line] or ""
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
local before_tag = ""
|
||||
local after_tag = ""
|
||||
|
||||
local open_pos = start_line_content:find(utils.escape_pattern(config.patterns.open_tag))
|
||||
if open_pos and open_pos > 1 then
|
||||
before_tag = start_line_content:sub(1, open_pos - 1)
|
||||
end
|
||||
|
||||
local close_pos = end_line_content:find(utils.escape_pattern(config.patterns.close_tag))
|
||||
if close_pos then
|
||||
local after_close = close_pos + #config.patterns.close_tag
|
||||
if after_close <= #end_line_content then
|
||||
after_tag = end_line_content:sub(after_close)
|
||||
end
|
||||
end
|
||||
|
||||
local replacement_lines = vim.split(response, "\n", { plain = true })
|
||||
|
||||
if before_tag ~= "" and #replacement_lines > 0 then
|
||||
replacement_lines[1] = before_tag .. replacement_lines[1]
|
||||
end
|
||||
if after_tag ~= "" and #replacement_lines > 0 then
|
||||
replacement_lines[#replacement_lines] = replacement_lines[#replacement_lines] .. after_tag
|
||||
end
|
||||
|
||||
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, replacement_lines)
|
||||
logs.info("Transform complete!")
|
||||
utils.notify("Transform complete!", vim.log.levels.INFO)
|
||||
end)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Main command dispatcher
|
||||
---@param args table Command arguments
|
||||
local function coder_cmd(args)
|
||||
local subcommand = args.fargs[1] or "toggle"
|
||||
|
||||
local commands = {
|
||||
open = cmd_open,
|
||||
close = cmd_close,
|
||||
toggle = cmd_toggle,
|
||||
process = cmd_process,
|
||||
status = cmd_status,
|
||||
focus = cmd_focus,
|
||||
tree = cmd_tree,
|
||||
["tree-view"] = cmd_tree_view,
|
||||
reset = cmd_reset,
|
||||
ask = cmd_ask,
|
||||
["ask-close"] = cmd_ask_close,
|
||||
["ask-toggle"] = cmd_ask_toggle,
|
||||
["ask-clear"] = cmd_ask_clear,
|
||||
gitignore = cmd_gitignore,
|
||||
transform = cmd_transform,
|
||||
["transform-cursor"] = cmd_transform_at_cursor,
|
||||
agent = cmd_agent,
|
||||
["agent-close"] = cmd_agent_close,
|
||||
["agent-toggle"] = cmd_agent_toggle,
|
||||
["agent-stop"] = cmd_agent_stop,
|
||||
["type-toggle"] = cmd_type_toggle,
|
||||
["logs-toggle"] = cmd_logs_toggle,
|
||||
["queue-status"] = cmd_queue_status,
|
||||
["queue-process"] = cmd_queue_process,
|
||||
["auto-toggle"] = function()
|
||||
local preferences = require("codetyper.preferences")
|
||||
preferences.toggle_auto_process()
|
||||
end,
|
||||
["auto-set"] = function(args)
|
||||
local preferences = require("codetyper.preferences")
|
||||
local arg = (args[1] or ""):lower()
|
||||
if arg == "auto" or arg == "automatic" or arg == "on" then
|
||||
preferences.set_auto_process(true)
|
||||
utils.notify("Set to automatic mode", vim.log.levels.INFO)
|
||||
elseif arg == "manual" or arg == "off" then
|
||||
preferences.set_auto_process(false)
|
||||
utils.notify("Set to manual mode", vim.log.levels.INFO)
|
||||
else
|
||||
local auto = preferences.is_auto_process_enabled()
|
||||
if auto == nil then
|
||||
utils.notify("Mode not set yet (will ask on first prompt)", vim.log.levels.INFO)
|
||||
else
|
||||
local mode = auto and "automatic" or "manual"
|
||||
utils.notify("Currently in " .. mode .. " mode", vim.log.levels.INFO)
|
||||
end
|
||||
end
|
||||
end,
|
||||
}
|
||||
|
||||
local cmd_fn = commands[subcommand]
|
||||
if cmd_fn then
|
||||
cmd_fn(args)
|
||||
else
|
||||
utils.notify("Unknown subcommand: " .. subcommand, vim.log.levels.ERROR)
|
||||
end
|
||||
end
|
||||
|
||||
--- Setup all commands
|
||||
function M.setup()
|
||||
vim.api.nvim_create_user_command("Coder", coder_cmd, {
|
||||
nargs = "?",
|
||||
complete = function()
|
||||
return {
|
||||
"open", "close", "toggle", "process", "status", "focus",
|
||||
"tree", "tree-view", "reset", "gitignore",
|
||||
"ask", "ask-close", "ask-toggle", "ask-clear",
|
||||
"transform", "transform-cursor",
|
||||
"agent", "agent-close", "agent-toggle", "agent-stop",
|
||||
"type-toggle", "logs-toggle",
|
||||
"queue-status", "queue-process",
|
||||
"auto-toggle", "auto-set",
|
||||
}
|
||||
end,
|
||||
desc = "Codetyper.nvim commands",
|
||||
})
|
||||
|
||||
-- Convenience aliases
|
||||
vim.api.nvim_create_user_command("CoderOpen", function()
|
||||
cmd_open()
|
||||
end, { desc = "Open Coder view" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderClose", function()
|
||||
cmd_close()
|
||||
end, { desc = "Close Coder view" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderToggle", function()
|
||||
cmd_toggle()
|
||||
end, { desc = "Toggle Coder view" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderProcess", function()
|
||||
cmd_process()
|
||||
end, { desc = "Process prompt and generate code" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderTree", function()
|
||||
cmd_tree()
|
||||
end, { desc = "Refresh tree.log" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderTreeView", function()
|
||||
cmd_tree_view()
|
||||
end, { desc = "View tree.log" })
|
||||
|
||||
-- Ask panel commands
|
||||
vim.api.nvim_create_user_command("CoderAsk", function()
|
||||
cmd_ask()
|
||||
end, { desc = "Open Ask panel" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderAskToggle", function()
|
||||
cmd_ask_toggle()
|
||||
end, { desc = "Toggle Ask panel" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderAskClear", function()
|
||||
cmd_ask_clear()
|
||||
end, { desc = "Clear Ask history" })
|
||||
|
||||
-- Transform commands (inline /@ @/ tag replacement)
|
||||
vim.api.nvim_create_user_command("CoderTransform", function()
|
||||
cmd_transform()
|
||||
end, { desc = "Transform all /@ @/ tags in current file" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderTransformCursor", function()
|
||||
cmd_transform_at_cursor()
|
||||
end, { desc = "Transform /@ @/ tag at cursor" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderTransformVisual", function(opts)
|
||||
local start_line = opts.line1
|
||||
local end_line = opts.line2
|
||||
cmd_transform_range(start_line, end_line)
|
||||
end, { range = true, desc = "Transform /@ @/ tags in visual selection" })
|
||||
|
||||
-- Agent commands
|
||||
vim.api.nvim_create_user_command("CoderAgent", function()
|
||||
cmd_agent()
|
||||
end, { desc = "Open Agent panel" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderAgentToggle", function()
|
||||
cmd_agent_toggle()
|
||||
end, { desc = "Toggle Agent panel" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderAgentStop", function()
|
||||
cmd_agent_stop()
|
||||
end, { desc = "Stop running agent" })
|
||||
|
||||
-- Chat type switcher command
|
||||
vim.api.nvim_create_user_command("CoderType", function()
|
||||
cmd_type_toggle()
|
||||
end, { desc = "Show Ask/Agent mode switcher" })
|
||||
|
||||
-- Logs panel command
|
||||
vim.api.nvim_create_user_command("CoderLogs", function()
|
||||
cmd_logs_toggle()
|
||||
end, { desc = "Toggle logs panel" })
|
||||
|
||||
-- Index command - open coder companion for current file
|
||||
vim.api.nvim_create_user_command("CoderIndex", function()
|
||||
local autocmds = require("codetyper.autocmds")
|
||||
autocmds.open_coder_companion()
|
||||
end, { desc = "Open coder companion for current file" })
|
||||
|
||||
-- Queue commands
|
||||
vim.api.nvim_create_user_command("CoderQueueStatus", function()
|
||||
cmd_queue_status()
|
||||
end, { desc = "Show scheduler and queue status" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderQueueProcess", function()
|
||||
cmd_queue_process()
|
||||
end, { desc = "Manually trigger queue processing" })
|
||||
|
||||
-- Preferences commands
|
||||
vim.api.nvim_create_user_command("CoderAutoToggle", function()
|
||||
local preferences = require("codetyper.preferences")
|
||||
preferences.toggle_auto_process()
|
||||
end, { desc = "Toggle automatic/manual prompt processing" })
|
||||
|
||||
vim.api.nvim_create_user_command("CoderAutoSet", function(opts)
|
||||
local preferences = require("codetyper.preferences")
|
||||
local arg = opts.args:lower()
|
||||
if arg == "auto" or arg == "automatic" or arg == "on" then
|
||||
preferences.set_auto_process(true)
|
||||
vim.notify("Codetyper: Set to automatic mode", vim.log.levels.INFO)
|
||||
elseif arg == "manual" or arg == "off" then
|
||||
preferences.set_auto_process(false)
|
||||
vim.notify("Codetyper: Set to manual mode", vim.log.levels.INFO)
|
||||
else
|
||||
-- Show current mode
|
||||
local auto = preferences.is_auto_process_enabled()
|
||||
if auto == nil then
|
||||
vim.notify("Codetyper: Mode not set yet (will ask on first prompt)", vim.log.levels.INFO)
|
||||
else
|
||||
local mode = auto and "automatic" or "manual"
|
||||
vim.notify("Codetyper: Currently in " .. mode .. " mode", vim.log.levels.INFO)
|
||||
end
|
||||
end
|
||||
end, {
|
||||
desc = "Set prompt processing mode (auto/manual)",
|
||||
nargs = "?",
|
||||
complete = function()
|
||||
return { "auto", "manual" }
|
||||
end,
|
||||
})
|
||||
|
||||
-- Setup default keymaps
|
||||
M.setup_keymaps()
|
||||
end
|
||||
|
||||
--- Setup default keymaps for transform commands
|
||||
function M.setup_keymaps()
|
||||
-- Visual mode: transform selected /@ @/ tags
|
||||
vim.keymap.set("v", "<leader>ctt", ":<C-u>CoderTransformVisual<CR>", {
|
||||
silent = true,
|
||||
desc = "Coder: Transform selected tags"
|
||||
})
|
||||
|
||||
-- Normal mode: transform tag at cursor
|
||||
vim.keymap.set("n", "<leader>ctt", "<cmd>CoderTransformCursor<CR>", {
|
||||
silent = true,
|
||||
desc = "Coder: Transform tag at cursor"
|
||||
})
|
||||
|
||||
-- Normal mode: transform all tags in file
|
||||
vim.keymap.set("n", "<leader>ctT", "<cmd>CoderTransform<CR>", {
|
||||
silent = true,
|
||||
desc = "Coder: Transform all tags in file"
|
||||
})
|
||||
|
||||
-- Agent keymaps
|
||||
vim.keymap.set("n", "<leader>ca", "<cmd>CoderAgentToggle<CR>", {
|
||||
silent = true,
|
||||
desc = "Coder: Toggle Agent panel"
|
||||
})
|
||||
|
||||
-- Index keymap - open coder companion
|
||||
vim.keymap.set("n", "<leader>ci", "<cmd>CoderIndex<CR>", {
|
||||
silent = true,
|
||||
desc = "Coder: Open coder companion for file"
|
||||
})
|
||||
end
|
||||
|
||||
return M
|
||||
18
lua/codetyper/commands/agents/banned.lua
Normal file
18
lua/codetyper/commands/agents/banned.lua
Normal file
@@ -0,0 +1,18 @@
|
||||
--- Banned commands for safety
|
||||
M.BANNED_COMMANDS = {
|
||||
"rm -rf /",
|
||||
"rm -rf /*",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
":(){ :|:& };:",
|
||||
"> /dev/sda",
|
||||
}
|
||||
|
||||
--- Banned patterns
|
||||
M.BANNED_PATTERNS = {
|
||||
"curl.*|.*sh",
|
||||
"wget.*|.*sh",
|
||||
"rm%s+%-rf%s+/",
|
||||
}
|
||||
|
||||
return M
|
||||
681
lua/codetyper/config/credentials.lua
Normal file
681
lua/codetyper/config/credentials.lua
Normal file
@@ -0,0 +1,681 @@
|
||||
---@mod codetyper.credentials Secure credential storage for Codetyper.nvim
|
||||
---@brief [[
|
||||
--- Manages API keys and model preferences stored outside of config files.
|
||||
--- Credentials are stored in ~/.local/share/nvim/codetyper/configuration.json
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Get the credentials file path
|
||||
---@return string Path to credentials file
|
||||
local function get_credentials_path()
|
||||
local data_dir = vim.fn.stdpath("data")
|
||||
return data_dir .. "/codetyper/configuration.json"
|
||||
end
|
||||
|
||||
--- Ensure the credentials directory exists
|
||||
---@return boolean Success
|
||||
local function ensure_dir()
|
||||
local data_dir = vim.fn.stdpath("data")
|
||||
local codetyper_dir = data_dir .. "/codetyper"
|
||||
return utils.ensure_dir(codetyper_dir)
|
||||
end
|
||||
|
||||
--- Load credentials from file
|
||||
---@return table Credentials data
|
||||
function M.load()
|
||||
local path = get_credentials_path()
|
||||
local content = utils.read_file(path)
|
||||
|
||||
if not content or content == "" then
|
||||
return {
|
||||
version = 1,
|
||||
providers = {},
|
||||
}
|
||||
end
|
||||
|
||||
local ok, data = pcall(vim.json.decode, content)
|
||||
if not ok or not data then
|
||||
return {
|
||||
version = 1,
|
||||
providers = {},
|
||||
}
|
||||
end
|
||||
|
||||
return data
|
||||
end
|
||||
|
||||
--- Save credentials to file
|
||||
---@param data table Credentials data
|
||||
---@return boolean Success
|
||||
function M.save(data)
|
||||
if not ensure_dir() then
|
||||
return false
|
||||
end
|
||||
|
||||
local path = get_credentials_path()
|
||||
local ok, json = pcall(vim.json.encode, data)
|
||||
if not ok then
|
||||
return false
|
||||
end
|
||||
|
||||
return utils.write_file(path, json)
|
||||
end
|
||||
|
||||
--- Get API key for a provider
|
||||
---@param provider string Provider name (claude, openai, gemini, copilot, ollama)
|
||||
---@return string|nil API key or nil if not found
|
||||
function M.get_api_key(provider)
|
||||
local data = M.load()
|
||||
local provider_data = data.providers and data.providers[provider]
|
||||
|
||||
if provider_data and provider_data.api_key then
|
||||
return provider_data.api_key
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Get model for a provider
|
||||
---@param provider string Provider name
|
||||
---@return string|nil Model name or nil if not found
|
||||
function M.get_model(provider)
|
||||
local data = M.load()
|
||||
local provider_data = data.providers and data.providers[provider]
|
||||
|
||||
if provider_data and provider_data.model then
|
||||
return provider_data.model
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Get endpoint for a provider (for custom OpenAI-compatible endpoints)
|
||||
---@param provider string Provider name
|
||||
---@return string|nil Endpoint URL or nil if not found
|
||||
function M.get_endpoint(provider)
|
||||
local data = M.load()
|
||||
local provider_data = data.providers and data.providers[provider]
|
||||
|
||||
if provider_data and provider_data.endpoint then
|
||||
return provider_data.endpoint
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Get host for Ollama
|
||||
---@return string|nil Host URL or nil if not found
|
||||
function M.get_ollama_host()
|
||||
local data = M.load()
|
||||
local provider_data = data.providers and data.providers.ollama
|
||||
|
||||
if provider_data and provider_data.host then
|
||||
return provider_data.host
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Set credentials for a provider
|
||||
---@param provider string Provider name
|
||||
---@param credentials table Credentials (api_key, model, endpoint, host)
|
||||
---@return boolean Success
|
||||
function M.set_credentials(provider, credentials)
|
||||
local data = M.load()
|
||||
|
||||
if not data.providers then
|
||||
data.providers = {}
|
||||
end
|
||||
|
||||
if not data.providers[provider] then
|
||||
data.providers[provider] = {}
|
||||
end
|
||||
|
||||
-- Merge credentials
|
||||
for key, value in pairs(credentials) do
|
||||
if value and value ~= "" then
|
||||
data.providers[provider][key] = value
|
||||
end
|
||||
end
|
||||
|
||||
data.updated = os.time()
|
||||
|
||||
return M.save(data)
|
||||
end
|
||||
|
||||
--- Remove credentials for a provider
|
||||
---@param provider string Provider name
|
||||
---@return boolean Success
|
||||
function M.remove_credentials(provider)
|
||||
local data = M.load()
|
||||
|
||||
if data.providers and data.providers[provider] then
|
||||
data.providers[provider] = nil
|
||||
data.updated = os.time()
|
||||
return M.save(data)
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- List all configured providers (checks both stored credentials AND config)
|
||||
---@return table List of provider names with their config status
|
||||
function M.list_providers()
|
||||
local data = M.load()
|
||||
local result = {}
|
||||
|
||||
local all_providers = { "claude", "openai", "gemini", "copilot", "ollama" }
|
||||
|
||||
for _, provider in ipairs(all_providers) do
|
||||
local provider_data = data.providers and data.providers[provider]
|
||||
local has_stored_key = provider_data and provider_data.api_key and provider_data.api_key ~= ""
|
||||
local has_model = provider_data and provider_data.model and provider_data.model ~= ""
|
||||
|
||||
-- Check if configured from config or environment
|
||||
local configured_from_config = false
|
||||
local config_model = nil
|
||||
local ok, codetyper = pcall(require, "codetyper")
|
||||
if ok then
|
||||
local config = codetyper.get_config()
|
||||
if config and config.llm and config.llm[provider] then
|
||||
local pc = config.llm[provider]
|
||||
config_model = pc.model
|
||||
|
||||
if provider == "claude" then
|
||||
configured_from_config = pc.api_key ~= nil or vim.env.ANTHROPIC_API_KEY ~= nil
|
||||
elseif provider == "openai" then
|
||||
configured_from_config = pc.api_key ~= nil or vim.env.OPENAI_API_KEY ~= nil
|
||||
elseif provider == "gemini" then
|
||||
configured_from_config = pc.api_key ~= nil or vim.env.GEMINI_API_KEY ~= nil
|
||||
elseif provider == "copilot" then
|
||||
configured_from_config = true -- Just needs copilot.lua
|
||||
elseif provider == "ollama" then
|
||||
configured_from_config = pc.host ~= nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local is_configured = has_stored_key
|
||||
or (provider == "ollama" and provider_data ~= nil)
|
||||
or (provider == "copilot" and (provider_data ~= nil or configured_from_config))
|
||||
or configured_from_config
|
||||
|
||||
table.insert(result, {
|
||||
name = provider,
|
||||
configured = is_configured,
|
||||
has_api_key = has_stored_key,
|
||||
has_model = has_model or config_model ~= nil,
|
||||
model = (provider_data and provider_data.model) or config_model,
|
||||
source = has_stored_key and "stored" or (configured_from_config and "config" or nil),
|
||||
})
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Default models for each provider
|
||||
M.default_models = {
|
||||
claude = "claude-sonnet-4-20250514",
|
||||
openai = "gpt-4o",
|
||||
gemini = "gemini-2.0-flash",
|
||||
copilot = "claude-sonnet-4",
|
||||
ollama = "deepseek-coder:6.7b",
|
||||
}
|
||||
|
||||
--- Available models for Copilot (GitHub Copilot Chat API)
|
||||
--- Models with cost multipliers: 0x = free, 0.33x = discount, 1x = standard, 3x = premium
|
||||
M.copilot_models = {
|
||||
-- Free tier (0x)
|
||||
{ name = "gpt-4.1", cost = "0x" },
|
||||
{ name = "gpt-4o", cost = "0x" },
|
||||
{ name = "gpt-5-mini", cost = "0x" },
|
||||
{ name = "grok-code-fast-1", cost = "0x" },
|
||||
{ name = "raptor-mini", cost = "0x" },
|
||||
-- Discount tier (0.33x)
|
||||
{ name = "claude-haiku-4.5", cost = "0.33x" },
|
||||
{ name = "gemini-3-flash", cost = "0.33x" },
|
||||
{ name = "gpt-5.1-codex-mini", cost = "0.33x" },
|
||||
-- Standard tier (1x)
|
||||
{ name = "claude-sonnet-4", cost = "1x" },
|
||||
{ name = "claude-sonnet-4.5", cost = "1x" },
|
||||
{ name = "gemini-2.5-pro", cost = "1x" },
|
||||
{ name = "gemini-3-pro", cost = "1x" },
|
||||
{ name = "gpt-5", cost = "1x" },
|
||||
{ name = "gpt-5-codex", cost = "1x" },
|
||||
{ name = "gpt-5.1", cost = "1x" },
|
||||
{ name = "gpt-5.1-codex", cost = "1x" },
|
||||
{ name = "gpt-5.1-codex-max", cost = "1x" },
|
||||
{ name = "gpt-5.2", cost = "1x" },
|
||||
{ name = "gpt-5.2-codex", cost = "1x" },
|
||||
-- Premium tier (3x)
|
||||
{ name = "claude-opus-4.5", cost = "3x" },
|
||||
}
|
||||
|
||||
--- Get list of copilot model names (for completion)
|
||||
---@return string[]
|
||||
function M.get_copilot_model_names()
|
||||
local names = {}
|
||||
for _, model in ipairs(M.copilot_models) do
|
||||
table.insert(names, model.name)
|
||||
end
|
||||
return names
|
||||
end
|
||||
|
||||
--- Get cost for a copilot model
|
||||
---@param model_name string
|
||||
---@return string|nil
|
||||
function M.get_copilot_model_cost(model_name)
|
||||
for _, model in ipairs(M.copilot_models) do
|
||||
if model.name == model_name then
|
||||
return model.cost
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Interactive command to add/update API key
|
||||
function M.interactive_add()
|
||||
local providers = { "claude", "openai", "gemini", "copilot", "ollama" }
|
||||
|
||||
-- Step 1: Select provider
|
||||
vim.ui.select(providers, {
|
||||
prompt = "Select LLM provider:",
|
||||
format_item = function(item)
|
||||
local display = item:sub(1, 1):upper() .. item:sub(2)
|
||||
local creds = M.load()
|
||||
local configured = creds.providers and creds.providers[item]
|
||||
if configured and (configured.api_key or item == "ollama") then
|
||||
return display .. " [configured]"
|
||||
end
|
||||
return display
|
||||
end,
|
||||
}, function(provider)
|
||||
if not provider then
|
||||
return
|
||||
end
|
||||
|
||||
-- Step 2: Get API key (skip for Ollama)
|
||||
if provider == "ollama" then
|
||||
M.interactive_ollama_config()
|
||||
else
|
||||
M.interactive_api_key(provider)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Interactive API key input
|
||||
---@param provider string Provider name
|
||||
function M.interactive_api_key(provider)
|
||||
-- Copilot uses OAuth from copilot.lua, no API key needed
|
||||
if provider == "copilot" then
|
||||
M.interactive_copilot_config()
|
||||
return
|
||||
end
|
||||
|
||||
local prompt = string.format("Enter %s API key (leave empty to skip): ", provider:upper())
|
||||
|
||||
vim.ui.input({ prompt = prompt }, function(api_key)
|
||||
if api_key == nil then
|
||||
return -- Cancelled
|
||||
end
|
||||
|
||||
-- Step 3: Get model
|
||||
M.interactive_model(provider, api_key)
|
||||
end)
|
||||
end
|
||||
|
||||
--- Interactive Copilot configuration (no API key, uses OAuth)
|
||||
---@param silent? boolean If true, don't show the OAuth info message
|
||||
function M.interactive_copilot_config(silent)
|
||||
if not silent then
|
||||
utils.notify("Copilot uses OAuth from copilot.lua/copilot.vim - no API key needed", vim.log.levels.INFO)
|
||||
end
|
||||
|
||||
-- Get current model if configured
|
||||
local current_model = M.get_model("copilot") or M.default_models.copilot
|
||||
local current_cost = M.get_copilot_model_cost(current_model) or "?"
|
||||
|
||||
-- Build model options with "Custom..." option
|
||||
local model_options = vim.deepcopy(M.copilot_models)
|
||||
table.insert(model_options, { name = "Custom...", cost = "" })
|
||||
|
||||
vim.ui.select(model_options, {
|
||||
prompt = "Select Copilot model (current: " .. current_model .. " — " .. current_cost .. "):",
|
||||
format_item = function(item)
|
||||
local display = item.name
|
||||
if item.cost and item.cost ~= "" then
|
||||
display = display .. " — " .. item.cost
|
||||
end
|
||||
if item.name == current_model then
|
||||
display = display .. " [current]"
|
||||
end
|
||||
return display
|
||||
end,
|
||||
}, function(choice)
|
||||
if choice == nil then
|
||||
return -- Cancelled
|
||||
end
|
||||
|
||||
if choice.name == "Custom..." then
|
||||
-- Allow custom model input
|
||||
vim.ui.input({
|
||||
prompt = "Enter custom model name: ",
|
||||
default = current_model,
|
||||
}, function(custom_model)
|
||||
if custom_model and custom_model ~= "" then
|
||||
M.save_and_notify("copilot", {
|
||||
model = custom_model,
|
||||
configured = true,
|
||||
})
|
||||
end
|
||||
end)
|
||||
else
|
||||
M.save_and_notify("copilot", {
|
||||
model = choice.name,
|
||||
configured = true,
|
||||
})
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Interactive model selection
|
||||
---@param provider string Provider name
|
||||
---@param api_key string|nil API key
|
||||
function M.interactive_model(provider, api_key)
|
||||
local default_model = M.default_models[provider] or ""
|
||||
local prompt = string.format("Enter model (default: %s): ", default_model)
|
||||
|
||||
vim.ui.input({ prompt = prompt, default = default_model }, function(model)
|
||||
if model == nil then
|
||||
return -- Cancelled
|
||||
end
|
||||
|
||||
-- Use default if empty
|
||||
if model == "" then
|
||||
model = default_model
|
||||
end
|
||||
|
||||
-- Save credentials
|
||||
local credentials = {
|
||||
model = model,
|
||||
}
|
||||
|
||||
if api_key and api_key ~= "" then
|
||||
credentials.api_key = api_key
|
||||
end
|
||||
|
||||
-- For OpenAI, also ask for custom endpoint
|
||||
if provider == "openai" then
|
||||
M.interactive_endpoint(provider, credentials)
|
||||
else
|
||||
M.save_and_notify(provider, credentials)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Interactive endpoint input for OpenAI-compatible providers
|
||||
---@param provider string Provider name
|
||||
---@param credentials table Current credentials
|
||||
function M.interactive_endpoint(provider, credentials)
|
||||
vim.ui.input({
|
||||
prompt = "Custom endpoint (leave empty for default OpenAI): ",
|
||||
}, function(endpoint)
|
||||
if endpoint == nil then
|
||||
return -- Cancelled
|
||||
end
|
||||
|
||||
if endpoint ~= "" then
|
||||
credentials.endpoint = endpoint
|
||||
end
|
||||
|
||||
M.save_and_notify(provider, credentials)
|
||||
end)
|
||||
end
|
||||
|
||||
--- Interactive Ollama configuration
|
||||
function M.interactive_ollama_config()
|
||||
vim.ui.input({
|
||||
prompt = "Ollama host (default: http://localhost:11434): ",
|
||||
default = "http://localhost:11434",
|
||||
}, function(host)
|
||||
if host == nil then
|
||||
return -- Cancelled
|
||||
end
|
||||
|
||||
if host == "" then
|
||||
host = "http://localhost:11434"
|
||||
end
|
||||
|
||||
-- Get model
|
||||
local default_model = M.default_models.ollama
|
||||
vim.ui.input({
|
||||
prompt = string.format("Ollama model (default: %s): ", default_model),
|
||||
default = default_model,
|
||||
}, function(model)
|
||||
if model == nil then
|
||||
return -- Cancelled
|
||||
end
|
||||
|
||||
if model == "" then
|
||||
model = default_model
|
||||
end
|
||||
|
||||
M.save_and_notify("ollama", {
|
||||
host = host,
|
||||
model = model,
|
||||
})
|
||||
end)
|
||||
end)
|
||||
end
|
||||
|
||||
--- Save credentials and notify user
|
||||
---@param provider string Provider name
|
||||
---@param credentials table Credentials to save
|
||||
function M.save_and_notify(provider, credentials)
|
||||
if M.set_credentials(provider, credentials) then
|
||||
local msg = string.format("Saved %s configuration", provider:upper())
|
||||
if credentials.model then
|
||||
msg = msg .. " (model: " .. credentials.model .. ")"
|
||||
end
|
||||
utils.notify(msg, vim.log.levels.INFO)
|
||||
else
|
||||
utils.notify("Failed to save credentials", vim.log.levels.ERROR)
|
||||
end
|
||||
end
|
||||
|
||||
--- Show current credentials status
|
||||
function M.show_status()
|
||||
local providers = M.list_providers()
|
||||
|
||||
-- Get current active provider
|
||||
local codetyper = require("codetyper")
|
||||
local current = codetyper.get_config().llm.provider
|
||||
|
||||
local lines = {
|
||||
"Codetyper Credentials Status",
|
||||
"============================",
|
||||
"",
|
||||
"Storage: " .. get_credentials_path(),
|
||||
"Active: " .. current:upper(),
|
||||
"",
|
||||
}
|
||||
|
||||
for _, p in ipairs(providers) do
|
||||
local status_icon = p.configured and "✓" or "✗"
|
||||
local active_marker = p.name == current and " [ACTIVE]" or ""
|
||||
local source_info = ""
|
||||
if p.configured then
|
||||
source_info = p.source == "stored" and " (stored)" or " (config)"
|
||||
end
|
||||
local model_info = p.model and (" - " .. p.model) or ""
|
||||
|
||||
table.insert(lines, string.format(" %s %s%s%s%s",
|
||||
status_icon,
|
||||
p.name:upper(),
|
||||
active_marker,
|
||||
source_info,
|
||||
model_info))
|
||||
end
|
||||
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, "Commands:")
|
||||
table.insert(lines, " :CoderAddApiKey - Add/update credentials")
|
||||
table.insert(lines, " :CoderSwitchProvider - Switch active provider")
|
||||
table.insert(lines, " :CoderRemoveApiKey - Remove stored credentials")
|
||||
|
||||
utils.notify(table.concat(lines, "\n"))
|
||||
end
|
||||
|
||||
--- Interactive remove credentials
|
||||
function M.interactive_remove()
|
||||
local data = M.load()
|
||||
local configured = {}
|
||||
|
||||
for provider, _ in pairs(data.providers or {}) do
|
||||
table.insert(configured, provider)
|
||||
end
|
||||
|
||||
if #configured == 0 then
|
||||
utils.notify("No credentials configured", vim.log.levels.INFO)
|
||||
return
|
||||
end
|
||||
|
||||
vim.ui.select(configured, {
|
||||
prompt = "Select provider to remove:",
|
||||
}, function(provider)
|
||||
if not provider then
|
||||
return
|
||||
end
|
||||
|
||||
vim.ui.select({ "Yes", "No" }, {
|
||||
prompt = "Remove " .. provider:upper() .. " credentials?",
|
||||
}, function(choice)
|
||||
if choice == "Yes" then
|
||||
if M.remove_credentials(provider) then
|
||||
utils.notify("Removed " .. provider:upper() .. " credentials", vim.log.levels.INFO)
|
||||
else
|
||||
utils.notify("Failed to remove credentials", vim.log.levels.ERROR)
|
||||
end
|
||||
end
|
||||
end)
|
||||
end)
|
||||
end
|
||||
|
||||
--- Set the active provider
|
||||
---@param provider string Provider name
|
||||
function M.set_active_provider(provider)
|
||||
local data = M.load()
|
||||
data.active_provider = provider
|
||||
data.updated = os.time()
|
||||
M.save(data)
|
||||
|
||||
-- Also update the runtime config
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
config.llm.provider = provider
|
||||
|
||||
utils.notify("Active provider set to: " .. provider:upper(), vim.log.levels.INFO)
|
||||
end
|
||||
|
||||
--- Get the active provider from stored config
|
||||
---@return string|nil Active provider
|
||||
function M.get_active_provider()
|
||||
local data = M.load()
|
||||
return data.active_provider
|
||||
end
|
||||
|
||||
--- Check if a provider is configured (from stored credentials OR config)
|
||||
---@param provider string Provider name
|
||||
---@return boolean configured, string|nil source
|
||||
local function is_provider_configured(provider)
|
||||
-- Check stored credentials first
|
||||
local data = M.load()
|
||||
local stored = data.providers and data.providers[provider]
|
||||
if stored then
|
||||
if stored.configured or stored.api_key or provider == "ollama" or provider == "copilot" then
|
||||
return true, "stored"
|
||||
end
|
||||
end
|
||||
|
||||
-- Check codetyper config
|
||||
local ok, codetyper = pcall(require, "codetyper")
|
||||
if not ok then
|
||||
return false, nil
|
||||
end
|
||||
|
||||
local config = codetyper.get_config()
|
||||
if not config or not config.llm then
|
||||
return false, nil
|
||||
end
|
||||
|
||||
local provider_config = config.llm[provider]
|
||||
if not provider_config then
|
||||
return false, nil
|
||||
end
|
||||
|
||||
-- Check for API key in config or environment
|
||||
if provider == "claude" then
|
||||
if provider_config.api_key or vim.env.ANTHROPIC_API_KEY then
|
||||
return true, "config"
|
||||
end
|
||||
elseif provider == "openai" then
|
||||
if provider_config.api_key or vim.env.OPENAI_API_KEY then
|
||||
return true, "config"
|
||||
end
|
||||
elseif provider == "gemini" then
|
||||
if provider_config.api_key or vim.env.GEMINI_API_KEY then
|
||||
return true, "config"
|
||||
end
|
||||
elseif provider == "copilot" then
|
||||
-- Copilot just needs copilot.lua installed
|
||||
return true, "config"
|
||||
elseif provider == "ollama" then
|
||||
-- Ollama just needs host configured
|
||||
if provider_config.host then
|
||||
return true, "config"
|
||||
end
|
||||
end
|
||||
|
||||
return false, nil
|
||||
end
|
||||
|
||||
--- Interactive switch provider
|
||||
function M.interactive_switch_provider()
|
||||
local all_providers = { "claude", "openai", "gemini", "copilot", "ollama" }
|
||||
local available = {}
|
||||
local sources = {}
|
||||
|
||||
for _, provider in ipairs(all_providers) do
|
||||
local configured, source = is_provider_configured(provider)
|
||||
if configured then
|
||||
table.insert(available, provider)
|
||||
sources[provider] = source
|
||||
end
|
||||
end
|
||||
|
||||
if #available == 0 then
|
||||
utils.notify("No providers configured. Use :CoderAddApiKey or add to your config.", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local current = codetyper.get_config().llm.provider
|
||||
|
||||
vim.ui.select(available, {
|
||||
prompt = "Select provider (current: " .. current .. "):",
|
||||
format_item = function(item)
|
||||
local marker = item == current and " [active]" or ""
|
||||
local source_marker = sources[item] == "stored" and " (stored)" or " (config)"
|
||||
return item:upper() .. marker .. source_marker
|
||||
end,
|
||||
}, function(provider)
|
||||
if provider then
|
||||
M.set_active_provider(provider)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -5,11 +5,7 @@ local M = {}
|
||||
---@type CoderConfig
|
||||
local defaults = {
|
||||
llm = {
|
||||
provider = "ollama", -- Options: "claude", "ollama", "openai", "gemini", "copilot"
|
||||
claude = {
|
||||
api_key = nil, -- Will use ANTHROPIC_API_KEY env var if nil
|
||||
model = "claude-sonnet-4-20250514",
|
||||
},
|
||||
provider = "ollama", -- Options: "ollama", "openai", "gemini", "copilot"
|
||||
ollama = {
|
||||
host = "http://localhost:11434",
|
||||
model = "deepseek-coder:6.7b",
|
||||
@@ -24,7 +20,7 @@ local defaults = {
|
||||
model = "gemini-2.0-flash",
|
||||
},
|
||||
copilot = {
|
||||
model = "gpt-4o", -- Uses GitHub Copilot authentication
|
||||
model = "claude-sonnet-4", -- Uses GitHub Copilot authentication
|
||||
},
|
||||
},
|
||||
window = {
|
||||
@@ -48,6 +44,48 @@ local defaults = {
|
||||
completion_delay_ms = 100, -- Wait after completion popup closes
|
||||
apply_delay_ms = 5000, -- Wait before removing tags and applying code (ms)
|
||||
},
|
||||
indexer = {
|
||||
enabled = true, -- Enable project indexing
|
||||
auto_index = true, -- Index files on save
|
||||
index_on_open = false, -- Index project when opening
|
||||
max_file_size = 100000, -- Skip files larger than 100KB
|
||||
excluded_dirs = { "node_modules", "dist", "build", ".git", ".coder", "__pycache__", "vendor", "target" },
|
||||
index_extensions = { "lua", "ts", "tsx", "js", "jsx", "py", "go", "rs", "rb", "java", "c", "cpp", "h", "hpp" },
|
||||
memory = {
|
||||
enabled = true, -- Enable memory persistence
|
||||
max_memories = 1000, -- Maximum stored memories
|
||||
prune_threshold = 0.1, -- Remove low-weight memories
|
||||
},
|
||||
},
|
||||
brain = {
|
||||
enabled = true, -- Enable brain learning system
|
||||
auto_learn = true, -- Auto-learn from events
|
||||
auto_commit = true, -- Auto-commit after threshold
|
||||
commit_threshold = 10, -- Changes before auto-commit
|
||||
max_nodes = 5000, -- Maximum nodes before pruning
|
||||
max_deltas = 500, -- Maximum delta history
|
||||
prune = {
|
||||
enabled = true, -- Enable auto-pruning
|
||||
threshold = 0.1, -- Remove nodes below this weight
|
||||
unused_days = 90, -- Remove unused nodes after N days
|
||||
},
|
||||
output = {
|
||||
max_tokens = 4000, -- Token budget for LLM context
|
||||
format = "compact", -- "compact"|"json"|"natural"
|
||||
},
|
||||
},
|
||||
suggestion = {
|
||||
enabled = true, -- Enable ghost text suggestions (Copilot-style)
|
||||
auto_trigger = true, -- Auto-trigger on typing
|
||||
debounce = 150, -- Debounce in milliseconds
|
||||
use_copilot = true, -- Use copilot.lua suggestions when available, fallback to codetyper
|
||||
keymap = {
|
||||
accept = "<Tab>", -- Accept suggestion
|
||||
next = "<M-]>", -- Next suggestion (Alt+])
|
||||
prev = "<M-[>", -- Previous suggestion (Alt+[)
|
||||
dismiss = "<C-]>", -- Dismiss suggestion (Ctrl+])
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
--- Deep merge two tables
|
||||
@@ -88,7 +126,7 @@ function M.validate(config)
|
||||
return false, "Missing LLM configuration"
|
||||
end
|
||||
|
||||
local valid_providers = { "claude", "ollama", "openai", "gemini", "copilot" }
|
||||
local valid_providers = { "ollama", "openai", "gemini", "copilot" }
|
||||
local is_valid_provider = false
|
||||
for _, p in ipairs(valid_providers) do
|
||||
if config.llm.provider == p then
|
||||
@@ -102,12 +140,7 @@ function M.validate(config)
|
||||
end
|
||||
|
||||
-- Validate provider-specific configuration
|
||||
if config.llm.provider == "claude" then
|
||||
local api_key = config.llm.claude.api_key or vim.env.ANTHROPIC_API_KEY
|
||||
if not api_key or api_key == "" then
|
||||
return false, "Claude API key not configured. Set llm.claude.api_key or ANTHROPIC_API_KEY env var"
|
||||
end
|
||||
elseif config.llm.provider == "openai" then
|
||||
if config.llm.provider == "openai" then
|
||||
local api_key = config.llm.openai.api_key or vim.env.OPENAI_API_KEY
|
||||
if not api_key or api_key == "" then
|
||||
return false, "OpenAI API key not configured. Set llm.openai.api_key or OPENAI_API_KEY env var"
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
---@class CoderPreferences
|
||||
---@field auto_process boolean Whether to auto-process /@ @/ tags (default: nil = ask)
|
||||
750
lua/codetyper/core/cost/init.lua
Normal file
750
lua/codetyper/core/cost/init.lua
Normal file
@@ -0,0 +1,750 @@
|
||||
---@mod codetyper.cost Cost estimation for LLM usage
|
||||
---@brief [[
|
||||
--- Tracks token usage and estimates costs based on model pricing.
|
||||
--- Prices are per 1M tokens. Persists usage data in the brain.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Cost history file name
|
||||
local COST_HISTORY_FILE = "cost_history.json"
|
||||
|
||||
--- Get path to cost history file
|
||||
---@return string File path
|
||||
local function get_history_path()
|
||||
local root = utils.get_project_root()
|
||||
return root .. "/.coder/" .. COST_HISTORY_FILE
|
||||
end
|
||||
|
||||
--- Default model for savings comparison (what you'd pay if not using Ollama)
|
||||
M.comparison_model = "gpt-4o"
|
||||
|
||||
--- Models considered "free" (Ollama, local, Copilot subscription)
|
||||
M.free_models = {
|
||||
["ollama"] = true,
|
||||
["codellama"] = true,
|
||||
["llama2"] = true,
|
||||
["llama3"] = true,
|
||||
["mistral"] = true,
|
||||
["deepseek-coder"] = true,
|
||||
["copilot"] = true,
|
||||
}
|
||||
|
||||
--- Model pricing table (per 1M tokens in USD)
|
||||
---@type table<string, {input: number, cached_input: number|nil, output: number|nil}>
|
||||
M.pricing = {
|
||||
-- GPT-5.x series
|
||||
["gpt-5.2"] = { input = 1.75, cached_input = 0.175, output = 14.00 },
|
||||
["gpt-5.1"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
["gpt-5"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
["gpt-5-mini"] = { input = 0.25, cached_input = 0.025, output = 2.00 },
|
||||
["gpt-5-nano"] = { input = 0.05, cached_input = 0.005, output = 0.40 },
|
||||
["gpt-5.2-chat-latest"] = { input = 1.75, cached_input = 0.175, output = 14.00 },
|
||||
["gpt-5.1-chat-latest"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
["gpt-5-chat-latest"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
["gpt-5.2-codex"] = { input = 1.75, cached_input = 0.175, output = 14.00 },
|
||||
["gpt-5.1-codex-max"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
["gpt-5.1-codex"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
["gpt-5-codex"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
["gpt-5.2-pro"] = { input = 21.00, cached_input = nil, output = 168.00 },
|
||||
["gpt-5-pro"] = { input = 15.00, cached_input = nil, output = 120.00 },
|
||||
["gpt-5.1-codex-mini"] = { input = 0.25, cached_input = 0.025, output = 2.00 },
|
||||
["gpt-5-search-api"] = { input = 1.25, cached_input = 0.125, output = 10.00 },
|
||||
|
||||
-- GPT-4.x series
|
||||
["gpt-4.1"] = { input = 2.00, cached_input = 0.50, output = 8.00 },
|
||||
["gpt-4.1-mini"] = { input = 0.40, cached_input = 0.10, output = 1.60 },
|
||||
["gpt-4.1-nano"] = { input = 0.10, cached_input = 0.025, output = 0.40 },
|
||||
["gpt-4o"] = { input = 2.50, cached_input = 1.25, output = 10.00 },
|
||||
["gpt-4o-2024-05-13"] = { input = 5.00, cached_input = nil, output = 15.00 },
|
||||
["gpt-4o-mini"] = { input = 0.15, cached_input = 0.075, output = 0.60 },
|
||||
|
||||
-- Realtime models
|
||||
["gpt-realtime"] = { input = 4.00, cached_input = 0.40, output = 16.00 },
|
||||
["gpt-realtime-mini"] = { input = 0.60, cached_input = 0.06, output = 2.40 },
|
||||
["gpt-4o-realtime-preview"] = { input = 5.00, cached_input = 2.50, output = 20.00 },
|
||||
["gpt-4o-mini-realtime-preview"] = { input = 0.60, cached_input = 0.30, output = 2.40 },
|
||||
|
||||
-- Audio models
|
||||
["gpt-audio"] = { input = 2.50, cached_input = nil, output = 10.00 },
|
||||
["gpt-audio-mini"] = { input = 0.60, cached_input = nil, output = 2.40 },
|
||||
["gpt-4o-audio-preview"] = { input = 2.50, cached_input = nil, output = 10.00 },
|
||||
["gpt-4o-mini-audio-preview"] = { input = 0.15, cached_input = nil, output = 0.60 },
|
||||
|
||||
-- O-series reasoning models
|
||||
["o1"] = { input = 15.00, cached_input = 7.50, output = 60.00 },
|
||||
["o1-pro"] = { input = 150.00, cached_input = nil, output = 600.00 },
|
||||
["o3-pro"] = { input = 20.00, cached_input = nil, output = 80.00 },
|
||||
["o3"] = { input = 2.00, cached_input = 0.50, output = 8.00 },
|
||||
["o3-deep-research"] = { input = 10.00, cached_input = 2.50, output = 40.00 },
|
||||
["o4-mini"] = { input = 1.10, cached_input = 0.275, output = 4.40 },
|
||||
["o4-mini-deep-research"] = { input = 2.00, cached_input = 0.50, output = 8.00 },
|
||||
["o3-mini"] = { input = 1.10, cached_input = 0.55, output = 4.40 },
|
||||
["o1-mini"] = { input = 1.10, cached_input = 0.55, output = 4.40 },
|
||||
|
||||
-- Codex
|
||||
["codex-mini-latest"] = { input = 1.50, cached_input = 0.375, output = 6.00 },
|
||||
|
||||
-- Search models
|
||||
["gpt-4o-mini-search-preview"] = { input = 0.15, cached_input = nil, output = 0.60 },
|
||||
["gpt-4o-search-preview"] = { input = 2.50, cached_input = nil, output = 10.00 },
|
||||
|
||||
-- Computer use
|
||||
["computer-use-preview"] = { input = 3.00, cached_input = nil, output = 12.00 },
|
||||
|
||||
-- Image models
|
||||
["gpt-image-1.5"] = { input = 5.00, cached_input = 1.25, output = 10.00 },
|
||||
["chatgpt-image-latest"] = { input = 5.00, cached_input = 1.25, output = 10.00 },
|
||||
["gpt-image-1"] = { input = 5.00, cached_input = 1.25, output = nil },
|
||||
["gpt-image-1-mini"] = { input = 2.00, cached_input = 0.20, output = nil },
|
||||
|
||||
-- Claude models
|
||||
["claude-3-opus"] = { input = 15.00, cached_input = 7.50, output = 75.00 },
|
||||
["claude-3-sonnet"] = { input = 3.00, cached_input = 1.50, output = 15.00 },
|
||||
["claude-3-haiku"] = { input = 0.25, cached_input = 0.125, output = 1.25 },
|
||||
["claude-3.5-sonnet"] = { input = 3.00, cached_input = 1.50, output = 15.00 },
|
||||
["claude-3.5-haiku"] = { input = 0.80, cached_input = 0.40, output = 4.00 },
|
||||
|
||||
-- Ollama/Local models (free)
|
||||
["ollama"] = { input = 0, cached_input = 0, output = 0 },
|
||||
["codellama"] = { input = 0, cached_input = 0, output = 0 },
|
||||
["llama2"] = { input = 0, cached_input = 0, output = 0 },
|
||||
["llama3"] = { input = 0, cached_input = 0, output = 0 },
|
||||
["mistral"] = { input = 0, cached_input = 0, output = 0 },
|
||||
["deepseek-coder"] = { input = 0, cached_input = 0, output = 0 },
|
||||
|
||||
-- Copilot (included in subscription, but tracking usage)
|
||||
["copilot"] = { input = 0, cached_input = 0, output = 0 },
|
||||
}
|
||||
|
||||
---@class CostUsage
|
||||
---@field model string Model name
|
||||
---@field input_tokens number Input tokens used
|
||||
---@field output_tokens number Output tokens used
|
||||
---@field cached_tokens number Cached input tokens
|
||||
---@field timestamp number Unix timestamp
|
||||
---@field cost number Calculated cost in USD
|
||||
|
||||
---@class CostState
|
||||
---@field usage CostUsage[] Current session usage
|
||||
---@field all_usage CostUsage[] All historical usage from brain
|
||||
---@field session_start number Session start timestamp
|
||||
---@field win number|nil Window handle
|
||||
---@field buf number|nil Buffer handle
|
||||
---@field loaded boolean Whether historical data has been loaded
|
||||
local state = {
|
||||
usage = {},
|
||||
all_usage = {},
|
||||
session_start = os.time(),
|
||||
win = nil,
|
||||
buf = nil,
|
||||
loaded = false,
|
||||
}
|
||||
|
||||
--- Load historical usage from disk
|
||||
function M.load_from_history()
|
||||
if state.loaded then
|
||||
return
|
||||
end
|
||||
|
||||
local history_path = get_history_path()
|
||||
local content = utils.read_file(history_path)
|
||||
|
||||
if content and content ~= "" then
|
||||
local ok, data = pcall(vim.json.decode, content)
|
||||
if ok and data and data.usage then
|
||||
state.all_usage = data.usage
|
||||
end
|
||||
end
|
||||
|
||||
state.loaded = true
|
||||
end
|
||||
|
||||
--- Save all usage to disk (debounced)
|
||||
local save_timer = nil
|
||||
local function save_to_disk()
|
||||
-- Cancel existing timer
|
||||
if save_timer then
|
||||
save_timer:stop()
|
||||
save_timer = nil
|
||||
end
|
||||
|
||||
-- Debounce writes (500ms)
|
||||
save_timer = vim.loop.new_timer()
|
||||
save_timer:start(500, 0, vim.schedule_wrap(function()
|
||||
local history_path = get_history_path()
|
||||
|
||||
-- Ensure directory exists
|
||||
local dir = vim.fn.fnamemodify(history_path, ":h")
|
||||
utils.ensure_dir(dir)
|
||||
|
||||
-- Merge session and historical usage
|
||||
local all_data = vim.deepcopy(state.all_usage)
|
||||
for _, usage in ipairs(state.usage) do
|
||||
table.insert(all_data, usage)
|
||||
end
|
||||
|
||||
-- Save to file
|
||||
local data = {
|
||||
version = 1,
|
||||
updated = os.time(),
|
||||
usage = all_data,
|
||||
}
|
||||
|
||||
local ok, json = pcall(vim.json.encode, data)
|
||||
if ok then
|
||||
utils.write_file(history_path, json)
|
||||
end
|
||||
|
||||
save_timer = nil
|
||||
end))
|
||||
end
|
||||
|
||||
--- Normalize model name for pricing lookup
|
||||
---@param model string Model name from API
|
||||
---@return string Normalized model name
|
||||
local function normalize_model(model)
|
||||
if not model then
|
||||
return "unknown"
|
||||
end
|
||||
|
||||
-- Convert to lowercase
|
||||
local normalized = model:lower()
|
||||
|
||||
-- Handle Copilot models
|
||||
if normalized:match("copilot") then
|
||||
return "copilot"
|
||||
end
|
||||
|
||||
-- Handle common prefixes
|
||||
normalized = normalized:gsub("^openai/", "")
|
||||
normalized = normalized:gsub("^anthropic/", "")
|
||||
|
||||
-- Try exact match first
|
||||
if M.pricing[normalized] then
|
||||
return normalized
|
||||
end
|
||||
|
||||
-- Try partial matches
|
||||
for price_model, _ in pairs(M.pricing) do
|
||||
if normalized:match(price_model) or price_model:match(normalized) then
|
||||
return price_model
|
||||
end
|
||||
end
|
||||
|
||||
return normalized
|
||||
end
|
||||
|
||||
--- Check if a model is considered "free" (local/Ollama/Copilot subscription)
|
||||
---@param model string Model name
|
||||
---@return boolean True if free
|
||||
function M.is_free_model(model)
|
||||
local normalized = normalize_model(model)
|
||||
|
||||
-- Check direct match
|
||||
if M.free_models[normalized] then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Check if it's an Ollama model (any model with : in name like deepseek-coder:6.7b)
|
||||
if model:match(":") then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Check pricing - if cost is 0, it's free
|
||||
local pricing = M.pricing[normalized]
|
||||
if pricing and pricing.input == 0 and pricing.output == 0 then
|
||||
return true
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Calculate cost for token usage
|
||||
---@param model string Model name
|
||||
---@param input_tokens number Input tokens
|
||||
---@param output_tokens number Output tokens
|
||||
---@param cached_tokens? number Cached input tokens
|
||||
---@return number Cost in USD
|
||||
function M.calculate_cost(model, input_tokens, output_tokens, cached_tokens)
|
||||
local normalized = normalize_model(model)
|
||||
local pricing = M.pricing[normalized]
|
||||
|
||||
if not pricing then
|
||||
-- Unknown model, return 0
|
||||
return 0
|
||||
end
|
||||
|
||||
cached_tokens = cached_tokens or 0
|
||||
local regular_input = input_tokens - cached_tokens
|
||||
|
||||
-- Calculate cost (prices are per 1M tokens)
|
||||
local input_cost = (regular_input / 1000000) * (pricing.input or 0)
|
||||
local cached_cost = (cached_tokens / 1000000) * (pricing.cached_input or pricing.input or 0)
|
||||
local output_cost = (output_tokens / 1000000) * (pricing.output or 0)
|
||||
|
||||
return input_cost + cached_cost + output_cost
|
||||
end
|
||||
|
||||
--- Calculate estimated savings (what would have been paid if using comparison model)
|
||||
---@param input_tokens number Input tokens
|
||||
---@param output_tokens number Output tokens
|
||||
---@param cached_tokens? number Cached input tokens
|
||||
---@return number Estimated savings in USD
|
||||
function M.calculate_savings(input_tokens, output_tokens, cached_tokens)
|
||||
-- Calculate what it would have cost with the comparison model
|
||||
return M.calculate_cost(M.comparison_model, input_tokens, output_tokens, cached_tokens)
|
||||
end
|
||||
|
||||
--- Record token usage
|
||||
---@param model string Model name
|
||||
---@param input_tokens number Input tokens
|
||||
---@param output_tokens number Output tokens
|
||||
---@param cached_tokens? number Cached input tokens
|
||||
function M.record_usage(model, input_tokens, output_tokens, cached_tokens)
|
||||
cached_tokens = cached_tokens or 0
|
||||
local cost = M.calculate_cost(model, input_tokens, output_tokens, cached_tokens)
|
||||
|
||||
-- Calculate savings if using a free model
|
||||
local savings = 0
|
||||
if M.is_free_model(model) then
|
||||
savings = M.calculate_savings(input_tokens, output_tokens, cached_tokens)
|
||||
end
|
||||
|
||||
table.insert(state.usage, {
|
||||
model = model,
|
||||
input_tokens = input_tokens,
|
||||
output_tokens = output_tokens,
|
||||
cached_tokens = cached_tokens,
|
||||
timestamp = os.time(),
|
||||
cost = cost,
|
||||
savings = savings,
|
||||
is_free = M.is_free_model(model),
|
||||
})
|
||||
|
||||
-- Save to disk (debounced)
|
||||
save_to_disk()
|
||||
|
||||
-- Update window if open
|
||||
if state.win and vim.api.nvim_win_is_valid(state.win) then
|
||||
M.refresh_window()
|
||||
end
|
||||
end
|
||||
|
||||
--- Aggregate usage data into stats
|
||||
---@param usage_list CostUsage[] List of usage records
|
||||
---@return table Stats
|
||||
local function aggregate_usage(usage_list)
|
||||
local stats = {
|
||||
total_input = 0,
|
||||
total_output = 0,
|
||||
total_cached = 0,
|
||||
total_cost = 0,
|
||||
total_savings = 0,
|
||||
free_requests = 0,
|
||||
paid_requests = 0,
|
||||
by_model = {},
|
||||
request_count = #usage_list,
|
||||
}
|
||||
|
||||
for _, usage in ipairs(usage_list) do
|
||||
stats.total_input = stats.total_input + (usage.input_tokens or 0)
|
||||
stats.total_output = stats.total_output + (usage.output_tokens or 0)
|
||||
stats.total_cached = stats.total_cached + (usage.cached_tokens or 0)
|
||||
stats.total_cost = stats.total_cost + (usage.cost or 0)
|
||||
|
||||
-- Track savings
|
||||
local usage_savings = usage.savings or 0
|
||||
-- For historical data without savings field, calculate it
|
||||
if usage_savings == 0 and usage.is_free == nil then
|
||||
local model = usage.model or "unknown"
|
||||
if M.is_free_model(model) then
|
||||
usage_savings = M.calculate_savings(
|
||||
usage.input_tokens or 0,
|
||||
usage.output_tokens or 0,
|
||||
usage.cached_tokens or 0
|
||||
)
|
||||
end
|
||||
end
|
||||
stats.total_savings = stats.total_savings + usage_savings
|
||||
|
||||
-- Track free vs paid
|
||||
local is_free = usage.is_free
|
||||
if is_free == nil then
|
||||
is_free = M.is_free_model(usage.model or "unknown")
|
||||
end
|
||||
if is_free then
|
||||
stats.free_requests = stats.free_requests + 1
|
||||
else
|
||||
stats.paid_requests = stats.paid_requests + 1
|
||||
end
|
||||
|
||||
local model = usage.model or "unknown"
|
||||
if not stats.by_model[model] then
|
||||
stats.by_model[model] = {
|
||||
input_tokens = 0,
|
||||
output_tokens = 0,
|
||||
cached_tokens = 0,
|
||||
cost = 0,
|
||||
savings = 0,
|
||||
requests = 0,
|
||||
is_free = is_free,
|
||||
}
|
||||
end
|
||||
|
||||
stats.by_model[model].input_tokens = stats.by_model[model].input_tokens + (usage.input_tokens or 0)
|
||||
stats.by_model[model].output_tokens = stats.by_model[model].output_tokens + (usage.output_tokens or 0)
|
||||
stats.by_model[model].cached_tokens = stats.by_model[model].cached_tokens + (usage.cached_tokens or 0)
|
||||
stats.by_model[model].cost = stats.by_model[model].cost + (usage.cost or 0)
|
||||
stats.by_model[model].savings = stats.by_model[model].savings + usage_savings
|
||||
stats.by_model[model].requests = stats.by_model[model].requests + 1
|
||||
end
|
||||
|
||||
return stats
|
||||
end
|
||||
|
||||
--- Get session statistics
|
||||
---@return table Statistics
|
||||
function M.get_stats()
|
||||
local stats = aggregate_usage(state.usage)
|
||||
stats.session_duration = os.time() - state.session_start
|
||||
return stats
|
||||
end
|
||||
|
||||
--- Get all-time statistics (session + historical)
|
||||
---@return table Statistics
|
||||
function M.get_all_time_stats()
|
||||
-- Load history if not loaded
|
||||
M.load_from_history()
|
||||
|
||||
-- Combine session and historical usage
|
||||
local all_usage = vim.deepcopy(state.all_usage)
|
||||
for _, usage in ipairs(state.usage) do
|
||||
table.insert(all_usage, usage)
|
||||
end
|
||||
|
||||
local stats = aggregate_usage(all_usage)
|
||||
|
||||
-- Calculate time span
|
||||
if #all_usage > 0 then
|
||||
local oldest = all_usage[1].timestamp or os.time()
|
||||
for _, usage in ipairs(all_usage) do
|
||||
if usage.timestamp and usage.timestamp < oldest then
|
||||
oldest = usage.timestamp
|
||||
end
|
||||
end
|
||||
stats.time_span = os.time() - oldest
|
||||
else
|
||||
stats.time_span = 0
|
||||
end
|
||||
|
||||
return stats
|
||||
end
|
||||
|
||||
--- Format cost as string
|
||||
---@param cost number Cost in USD
|
||||
---@return string Formatted cost
|
||||
local function format_cost(cost)
|
||||
if cost < 0.01 then
|
||||
return string.format("$%.4f", cost)
|
||||
elseif cost < 1 then
|
||||
return string.format("$%.3f", cost)
|
||||
else
|
||||
return string.format("$%.2f", cost)
|
||||
end
|
||||
end
|
||||
|
||||
--- Format token count
|
||||
---@param tokens number Token count
|
||||
---@return string Formatted count
|
||||
local function format_tokens(tokens)
|
||||
if tokens >= 1000000 then
|
||||
return string.format("%.2fM", tokens / 1000000)
|
||||
elseif tokens >= 1000 then
|
||||
return string.format("%.1fK", tokens / 1000)
|
||||
else
|
||||
return tostring(tokens)
|
||||
end
|
||||
end
|
||||
|
||||
--- Format duration
|
||||
---@param seconds number Duration in seconds
|
||||
---@return string Formatted duration
|
||||
local function format_duration(seconds)
|
||||
if seconds < 60 then
|
||||
return string.format("%ds", seconds)
|
||||
elseif seconds < 3600 then
|
||||
return string.format("%dm %ds", math.floor(seconds / 60), seconds % 60)
|
||||
else
|
||||
local hours = math.floor(seconds / 3600)
|
||||
local mins = math.floor((seconds % 3600) / 60)
|
||||
return string.format("%dh %dm", hours, mins)
|
||||
end
|
||||
end
|
||||
|
||||
--- Generate model breakdown section
|
||||
---@param stats table Stats with by_model
|
||||
---@return string[] Lines
|
||||
local function generate_model_breakdown(stats)
|
||||
local lines = {}
|
||||
|
||||
if next(stats.by_model) then
|
||||
-- Sort models by cost (descending)
|
||||
local models = {}
|
||||
for model, data in pairs(stats.by_model) do
|
||||
table.insert(models, { name = model, data = data })
|
||||
end
|
||||
table.sort(models, function(a, b)
|
||||
return a.data.cost > b.data.cost
|
||||
end)
|
||||
|
||||
for _, item in ipairs(models) do
|
||||
local model = item.name
|
||||
local data = item.data
|
||||
local pricing = M.pricing[normalize_model(model)]
|
||||
local is_free = data.is_free or M.is_free_model(model)
|
||||
|
||||
table.insert(lines, "")
|
||||
local model_icon = is_free and "🆓" or "💳"
|
||||
table.insert(lines, string.format(" %s %s", model_icon, model))
|
||||
table.insert(lines, string.format(" Requests: %d", data.requests))
|
||||
table.insert(lines, string.format(" Input: %s tokens", format_tokens(data.input_tokens)))
|
||||
table.insert(lines, string.format(" Output: %s tokens", format_tokens(data.output_tokens)))
|
||||
|
||||
if is_free then
|
||||
-- Show savings for free models
|
||||
if data.savings and data.savings > 0 then
|
||||
table.insert(lines, string.format(" Saved: %s", format_cost(data.savings)))
|
||||
end
|
||||
else
|
||||
table.insert(lines, string.format(" Cost: %s", format_cost(data.cost)))
|
||||
end
|
||||
|
||||
-- Show pricing info for paid models
|
||||
if pricing and not is_free then
|
||||
local price_info = string.format(
|
||||
" Rate: $%.2f/1M in, $%.2f/1M out",
|
||||
pricing.input or 0,
|
||||
pricing.output or 0
|
||||
)
|
||||
table.insert(lines, price_info)
|
||||
end
|
||||
end
|
||||
else
|
||||
table.insert(lines, " No usage recorded.")
|
||||
end
|
||||
|
||||
return lines
|
||||
end
|
||||
|
||||
--- Generate window content
|
||||
---@return string[] Lines for the buffer
|
||||
local function generate_content()
|
||||
local session_stats = M.get_stats()
|
||||
local all_time_stats = M.get_all_time_stats()
|
||||
local lines = {}
|
||||
|
||||
-- Header
|
||||
table.insert(lines, "╔══════════════════════════════════════════════════════╗")
|
||||
table.insert(lines, "║ 💰 LLM Cost Estimation ║")
|
||||
table.insert(lines, "╠══════════════════════════════════════════════════════╣")
|
||||
table.insert(lines, "")
|
||||
|
||||
-- All-time summary (prominent)
|
||||
table.insert(lines, "🌐 All-Time Summary (Project)")
|
||||
table.insert(lines, "───────────────────────────────────────────────────────")
|
||||
if all_time_stats.time_span > 0 then
|
||||
table.insert(lines, string.format(" Time span: %s", format_duration(all_time_stats.time_span)))
|
||||
end
|
||||
table.insert(lines, string.format(" Requests: %d total", all_time_stats.request_count))
|
||||
table.insert(lines, string.format(" Local/Free: %d requests", all_time_stats.free_requests or 0))
|
||||
table.insert(lines, string.format(" Paid API: %d requests", all_time_stats.paid_requests or 0))
|
||||
table.insert(lines, string.format(" Input tokens: %s", format_tokens(all_time_stats.total_input)))
|
||||
table.insert(lines, string.format(" Output tokens: %s", format_tokens(all_time_stats.total_output)))
|
||||
if all_time_stats.total_cached > 0 then
|
||||
table.insert(lines, string.format(" Cached tokens: %s", format_tokens(all_time_stats.total_cached)))
|
||||
end
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, string.format(" 💵 Total Cost: %s", format_cost(all_time_stats.total_cost)))
|
||||
|
||||
-- Show savings prominently if there are any
|
||||
if all_time_stats.total_savings and all_time_stats.total_savings > 0 then
|
||||
table.insert(lines, string.format(" 💚 Saved: %s (vs %s)", format_cost(all_time_stats.total_savings), M.comparison_model))
|
||||
end
|
||||
table.insert(lines, "")
|
||||
|
||||
-- Session summary
|
||||
table.insert(lines, "📊 Current Session")
|
||||
table.insert(lines, "───────────────────────────────────────────────────────")
|
||||
table.insert(lines, string.format(" Duration: %s", format_duration(session_stats.session_duration)))
|
||||
table.insert(lines, string.format(" Requests: %d (%d free, %d paid)",
|
||||
session_stats.request_count,
|
||||
session_stats.free_requests or 0,
|
||||
session_stats.paid_requests or 0))
|
||||
table.insert(lines, string.format(" Input tokens: %s", format_tokens(session_stats.total_input)))
|
||||
table.insert(lines, string.format(" Output tokens: %s", format_tokens(session_stats.total_output)))
|
||||
if session_stats.total_cached > 0 then
|
||||
table.insert(lines, string.format(" Cached tokens: %s", format_tokens(session_stats.total_cached)))
|
||||
end
|
||||
table.insert(lines, string.format(" Session Cost: %s", format_cost(session_stats.total_cost)))
|
||||
if session_stats.total_savings and session_stats.total_savings > 0 then
|
||||
table.insert(lines, string.format(" Session Saved: %s", format_cost(session_stats.total_savings)))
|
||||
end
|
||||
table.insert(lines, "")
|
||||
|
||||
-- Per-model breakdown (all-time)
|
||||
table.insert(lines, "📈 Cost by Model (All-Time)")
|
||||
table.insert(lines, "───────────────────────────────────────────────────────")
|
||||
local model_lines = generate_model_breakdown(all_time_stats)
|
||||
for _, line in ipairs(model_lines) do
|
||||
table.insert(lines, line)
|
||||
end
|
||||
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, "───────────────────────────────────────────────────────")
|
||||
table.insert(lines, " 'q' close | 'r' refresh | 'c' clear session | 'C' all")
|
||||
table.insert(lines, "╚══════════════════════════════════════════════════════╝")
|
||||
|
||||
return lines
|
||||
end
|
||||
|
||||
--- Refresh the cost window content
|
||||
function M.refresh_window()
|
||||
if not state.buf or not vim.api.nvim_buf_is_valid(state.buf) then
|
||||
return
|
||||
end
|
||||
|
||||
local lines = generate_content()
|
||||
|
||||
vim.bo[state.buf].modifiable = true
|
||||
vim.api.nvim_buf_set_lines(state.buf, 0, -1, false, lines)
|
||||
vim.bo[state.buf].modifiable = false
|
||||
end
|
||||
|
||||
--- Open the cost estimation window
|
||||
function M.open()
|
||||
-- Load historical data if not loaded
|
||||
M.load_from_history()
|
||||
|
||||
-- Close existing window if open
|
||||
if state.win and vim.api.nvim_win_is_valid(state.win) then
|
||||
vim.api.nvim_win_close(state.win, true)
|
||||
end
|
||||
|
||||
-- Create buffer
|
||||
state.buf = vim.api.nvim_create_buf(false, true)
|
||||
vim.bo[state.buf].buftype = "nofile"
|
||||
vim.bo[state.buf].bufhidden = "wipe"
|
||||
vim.bo[state.buf].swapfile = false
|
||||
vim.bo[state.buf].filetype = "codetyper-cost"
|
||||
|
||||
-- Calculate window size
|
||||
local width = 58
|
||||
local height = 40
|
||||
local row = math.floor((vim.o.lines - height) / 2)
|
||||
local col = math.floor((vim.o.columns - width) / 2)
|
||||
|
||||
-- Create floating window
|
||||
state.win = vim.api.nvim_open_win(state.buf, true, {
|
||||
relative = "editor",
|
||||
width = width,
|
||||
height = height,
|
||||
row = row,
|
||||
col = col,
|
||||
style = "minimal",
|
||||
border = "rounded",
|
||||
title = " Cost Estimation ",
|
||||
title_pos = "center",
|
||||
})
|
||||
|
||||
-- Set window options
|
||||
vim.wo[state.win].wrap = false
|
||||
vim.wo[state.win].cursorline = false
|
||||
|
||||
-- Populate content
|
||||
M.refresh_window()
|
||||
|
||||
-- Set up keymaps
|
||||
local opts = { buffer = state.buf, silent = true }
|
||||
vim.keymap.set("n", "q", function()
|
||||
M.close()
|
||||
end, opts)
|
||||
vim.keymap.set("n", "<Esc>", function()
|
||||
M.close()
|
||||
end, opts)
|
||||
vim.keymap.set("n", "r", function()
|
||||
M.refresh_window()
|
||||
end, opts)
|
||||
vim.keymap.set("n", "c", function()
|
||||
M.clear_session()
|
||||
M.refresh_window()
|
||||
end, opts)
|
||||
vim.keymap.set("n", "C", function()
|
||||
M.clear_all()
|
||||
M.refresh_window()
|
||||
end, opts)
|
||||
|
||||
-- Set up highlights
|
||||
vim.api.nvim_buf_call(state.buf, function()
|
||||
vim.fn.matchadd("Title", "LLM Cost Estimation")
|
||||
vim.fn.matchadd("Number", "\\$[0-9.]*")
|
||||
vim.fn.matchadd("Keyword", "[0-9.]*[KM]\\? tokens")
|
||||
vim.fn.matchadd("Special", "🤖\\|💰\\|📊\\|📈\\|💵")
|
||||
end)
|
||||
end
|
||||
|
||||
--- Close the cost window
|
||||
function M.close()
|
||||
if state.win and vim.api.nvim_win_is_valid(state.win) then
|
||||
vim.api.nvim_win_close(state.win, true)
|
||||
end
|
||||
state.win = nil
|
||||
state.buf = nil
|
||||
end
|
||||
|
||||
--- Toggle the cost window
|
||||
function M.toggle()
|
||||
if state.win and vim.api.nvim_win_is_valid(state.win) then
|
||||
M.close()
|
||||
else
|
||||
M.open()
|
||||
end
|
||||
end
|
||||
|
||||
--- Clear session usage (not history)
|
||||
function M.clear_session()
|
||||
state.usage = {}
|
||||
state.session_start = os.time()
|
||||
utils.notify("Session cost tracking cleared", vim.log.levels.INFO)
|
||||
end
|
||||
|
||||
--- Clear all history (session + saved)
|
||||
function M.clear_all()
|
||||
state.usage = {}
|
||||
state.all_usage = {}
|
||||
state.session_start = os.time()
|
||||
state.loaded = false
|
||||
|
||||
-- Delete history file
|
||||
local history_path = get_history_path()
|
||||
local ok, err = os.remove(history_path)
|
||||
if not ok and err and not err:match("No such file") then
|
||||
utils.notify("Failed to delete history: " .. err, vim.log.levels.WARN)
|
||||
end
|
||||
|
||||
utils.notify("All cost history cleared", vim.log.levels.INFO)
|
||||
end
|
||||
|
||||
--- Clear usage history (alias for clear_session)
|
||||
function M.clear()
|
||||
M.clear_session()
|
||||
end
|
||||
|
||||
--- Reset session
|
||||
function M.reset()
|
||||
M.clear_session()
|
||||
end
|
||||
|
||||
return M
|
||||
1052
lua/codetyper/core/diff/conflict.lua
Normal file
1052
lua/codetyper/core/diff/conflict.lua
Normal file
File diff suppressed because it is too large
Load Diff
@@ -9,7 +9,20 @@ local M = {}
|
||||
---@param callback fun(approved: boolean) Called with user decision
|
||||
function M.show_diff(diff_data, callback)
|
||||
local original_lines = vim.split(diff_data.original, "\n", { plain = true })
|
||||
local modified_lines = vim.split(diff_data.modified, "\n", { plain = true })
|
||||
local modified_lines
|
||||
|
||||
-- For delete operations, show a clear message
|
||||
if diff_data.operation == "delete" then
|
||||
modified_lines = {
|
||||
"",
|
||||
" FILE WILL BE DELETED",
|
||||
"",
|
||||
" Reason: " .. (diff_data.reason or "No reason provided"),
|
||||
"",
|
||||
}
|
||||
else
|
||||
modified_lines = vim.split(diff_data.modified, "\n", { plain = true })
|
||||
end
|
||||
|
||||
-- Calculate window dimensions
|
||||
local width = math.floor(vim.o.columns * 0.8)
|
||||
@@ -59,7 +72,7 @@ function M.show_diff(diff_data, callback)
|
||||
col = col + half_width + 1,
|
||||
style = "minimal",
|
||||
border = "rounded",
|
||||
title = " MODIFIED [" .. diff_data.operation .. "] ",
|
||||
title = diff_data.operation == "delete" and " ⚠️ DELETE " or (" MODIFIED [" .. diff_data.operation .. "] "),
|
||||
title_pos = "center",
|
||||
})
|
||||
|
||||
@@ -144,39 +157,67 @@ function M.show_diff(diff_data, callback)
|
||||
end
|
||||
|
||||
-- Show help message
|
||||
vim.api.nvim_echo({
|
||||
{ "Diff: ", "Normal" },
|
||||
{ diff_data.path, "Directory" },
|
||||
{ " | ", "Normal" },
|
||||
{ "y/<CR>", "Keyword" },
|
||||
{ " approve ", "Normal" },
|
||||
{ "n/q/<Esc>", "Keyword" },
|
||||
{ " reject ", "Normal" },
|
||||
{ "<Tab>", "Keyword" },
|
||||
{ " switch panes", "Normal" },
|
||||
}, false, {})
|
||||
local help_msg = require("codetyper.prompts.agents.diff").diff_help
|
||||
|
||||
-- Iterate to replace {path} variable
|
||||
local final_help = {}
|
||||
for _, item in ipairs(help_msg) do
|
||||
if item[1] == "{path}" then
|
||||
table.insert(final_help, { diff_data.path, item[2] })
|
||||
else
|
||||
table.insert(final_help, item)
|
||||
end
|
||||
end
|
||||
|
||||
vim.api.nvim_echo(final_help, false, {})
|
||||
end
|
||||
|
||||
--- Show approval dialog for bash commands
|
||||
---@alias BashApprovalResult {approved: boolean, permission_level: string|nil}
|
||||
|
||||
--- Show approval dialog for bash commands with permission levels
|
||||
---@param command string The bash command to approve
|
||||
---@param callback fun(approved: boolean) Called with user decision
|
||||
---@param callback fun(result: BashApprovalResult) Called with user decision
|
||||
function M.show_bash_approval(command, callback)
|
||||
-- Create a simple floating window for bash approval
|
||||
local permissions = require("codetyper.features.agents.permissions")
|
||||
|
||||
-- Check if command is auto-approved
|
||||
local perm_result = permissions.check_bash_permission(command)
|
||||
if perm_result.auto and perm_result.allowed then
|
||||
vim.schedule(function()
|
||||
callback({ approved = true, permission_level = "auto" })
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Create approval dialog with options
|
||||
local approval_prompts = require("codetyper.prompts.agents.diff").bash_approval
|
||||
local lines = {
|
||||
"",
|
||||
" BASH COMMAND APPROVAL",
|
||||
" " .. string.rep("-", 50),
|
||||
approval_prompts.title,
|
||||
approval_prompts.divider,
|
||||
"",
|
||||
" Command:",
|
||||
approval_prompts.command_label,
|
||||
" $ " .. command,
|
||||
"",
|
||||
" " .. string.rep("-", 50),
|
||||
" Press [y] or [Enter] to execute",
|
||||
" Press [n], [q], or [Esc] to cancel",
|
||||
"",
|
||||
}
|
||||
|
||||
local width = math.max(60, #command + 10)
|
||||
-- Add warning for dangerous commands
|
||||
if not perm_result.allowed and perm_result.reason ~= "Requires approval" then
|
||||
table.insert(lines, approval_prompts.warning_prefix .. perm_result.reason)
|
||||
table.insert(lines, "")
|
||||
end
|
||||
|
||||
table.insert(lines, approval_prompts.divider)
|
||||
table.insert(lines, "")
|
||||
for _, opt in ipairs(approval_prompts.options) do
|
||||
table.insert(lines, opt)
|
||||
end
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, approval_prompts.divider)
|
||||
table.insert(lines, approval_prompts.cancel_hint)
|
||||
table.insert(lines, "")
|
||||
|
||||
local width = math.max(65, #command + 15)
|
||||
local height = #lines
|
||||
|
||||
local buf = vim.api.nvim_create_buf(false, true)
|
||||
@@ -196,45 +237,84 @@ function M.show_bash_approval(command, callback)
|
||||
title_pos = "center",
|
||||
})
|
||||
|
||||
-- Apply some highlighting
|
||||
-- Apply highlighting
|
||||
vim.api.nvim_buf_add_highlight(buf, -1, "Title", 1, 0, -1)
|
||||
vim.api.nvim_buf_add_highlight(buf, -1, "String", 5, 0, -1)
|
||||
|
||||
-- Highlight options
|
||||
for i, line in ipairs(lines) do
|
||||
if line:match("^%s+%[y%]") then
|
||||
vim.api.nvim_buf_add_highlight(buf, -1, "DiagnosticOk", i - 1, 0, -1)
|
||||
elseif line:match("^%s+%[s%]") then
|
||||
vim.api.nvim_buf_add_highlight(buf, -1, "DiagnosticInfo", i - 1, 0, -1)
|
||||
elseif line:match("^%s+%[a%]") then
|
||||
vim.api.nvim_buf_add_highlight(buf, -1, "DiagnosticHint", i - 1, 0, -1)
|
||||
elseif line:match("^%s+%[n%]") then
|
||||
vim.api.nvim_buf_add_highlight(buf, -1, "DiagnosticError", i - 1, 0, -1)
|
||||
elseif line:match("⚠️") then
|
||||
vim.api.nvim_buf_add_highlight(buf, -1, "DiagnosticWarn", i - 1, 0, -1)
|
||||
end
|
||||
end
|
||||
|
||||
local callback_called = false
|
||||
|
||||
local function close_and_respond(approved)
|
||||
local function close_and_respond(approved, permission_level)
|
||||
if callback_called then
|
||||
return
|
||||
end
|
||||
callback_called = true
|
||||
|
||||
-- Grant permission if approved with session or list level
|
||||
if approved and permission_level then
|
||||
permissions.grant_permission(command, permission_level)
|
||||
end
|
||||
|
||||
pcall(vim.api.nvim_win_close, win, true)
|
||||
|
||||
vim.schedule(function()
|
||||
callback(approved)
|
||||
callback({ approved = approved, permission_level = permission_level })
|
||||
end)
|
||||
end
|
||||
|
||||
local keymap_opts = { buffer = buf, noremap = true, silent = true, nowait = true }
|
||||
|
||||
-- Approve
|
||||
-- Allow once
|
||||
vim.keymap.set("n", "y", function()
|
||||
close_and_respond(true)
|
||||
close_and_respond(true, "allow")
|
||||
end, keymap_opts)
|
||||
vim.keymap.set("n", "<CR>", function()
|
||||
close_and_respond(true)
|
||||
close_and_respond(true, "allow")
|
||||
end, keymap_opts)
|
||||
|
||||
-- Allow this session
|
||||
vim.keymap.set("n", "s", function()
|
||||
close_and_respond(true, "allow_session")
|
||||
end, keymap_opts)
|
||||
|
||||
-- Add to allow list
|
||||
vim.keymap.set("n", "a", function()
|
||||
close_and_respond(true, "allow_list")
|
||||
end, keymap_opts)
|
||||
|
||||
-- Reject
|
||||
vim.keymap.set("n", "n", function()
|
||||
close_and_respond(false)
|
||||
close_and_respond(false, nil)
|
||||
end, keymap_opts)
|
||||
vim.keymap.set("n", "q", function()
|
||||
close_and_respond(false)
|
||||
close_and_respond(false, nil)
|
||||
end, keymap_opts)
|
||||
vim.keymap.set("n", "<Esc>", function()
|
||||
close_and_respond(false)
|
||||
close_and_respond(false, nil)
|
||||
end, keymap_opts)
|
||||
end
|
||||
|
||||
--- Show approval dialog for bash commands (simple version for backward compatibility)
|
||||
---@param command string The bash command to approve
|
||||
---@param callback fun(approved: boolean) Called with user decision
|
||||
function M.show_bash_approval_simple(command, callback)
|
||||
M.show_bash_approval(command, function(result)
|
||||
callback(result.approved)
|
||||
end)
|
||||
end
|
||||
|
||||
return M
|
||||
1098
lua/codetyper/core/diff/patch.lua
Normal file
1098
lua/codetyper/core/diff/patch.lua
Normal file
File diff suppressed because it is too large
Load Diff
572
lua/codetyper/core/diff/search_replace.lua
Normal file
572
lua/codetyper/core/diff/search_replace.lua
Normal file
@@ -0,0 +1,572 @@
|
||||
---@mod codetyper.agent.search_replace Search/Replace editing system
|
||||
---@brief [[
|
||||
--- Implements SEARCH/REPLACE block parsing and fuzzy matching for reliable code edits.
|
||||
--- Parses and applies SEARCH/REPLACE blocks from LLM responses.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local params = require("codetyper.params.agents.search_replace").patterns
|
||||
|
||||
---@class SearchReplaceBlock
|
||||
---@field search string The text to search for
|
||||
---@field replace string The text to replace with
|
||||
---@field file_path string|nil Optional file path for multi-file edits
|
||||
|
||||
---@class MatchResult
|
||||
---@field start_line number 1-indexed start line
|
||||
---@field end_line number 1-indexed end line
|
||||
---@field start_col number 1-indexed start column (for partial line matches)
|
||||
---@field end_col number 1-indexed end column
|
||||
---@field strategy string Which matching strategy succeeded
|
||||
---@field confidence number Match confidence (0.0-1.0)
|
||||
|
||||
--- Parse SEARCH/REPLACE blocks from LLM response
|
||||
--- Supports multiple formats:
|
||||
--- Format 1 (dash style):
|
||||
--- ------- SEARCH
|
||||
--- old code
|
||||
--- =======
|
||||
--- new code
|
||||
--- +++++++ REPLACE
|
||||
---
|
||||
--- Format 2 (claude style):
|
||||
--- <<<<<<< SEARCH
|
||||
--- old code
|
||||
--- =======
|
||||
--- new code
|
||||
--- >>>>>>> REPLACE
|
||||
---
|
||||
--- Format 3 (simple):
|
||||
--- [SEARCH]
|
||||
--- old code
|
||||
--- [REPLACE]
|
||||
--- new code
|
||||
--- [END]
|
||||
---
|
||||
---@param response string LLM response text
|
||||
---@return SearchReplaceBlock[]
|
||||
function M.parse_blocks(response)
|
||||
local blocks = {}
|
||||
|
||||
-- Try dash-style format: ------- SEARCH ... ======= ... +++++++ REPLACE
|
||||
for search, replace in response:gmatch(params.dash_style) do
|
||||
table.insert(blocks, { search = search, replace = replace })
|
||||
end
|
||||
|
||||
if #blocks > 0 then
|
||||
return blocks
|
||||
end
|
||||
|
||||
-- Try claude-style format: <<<<<<< SEARCH ... ======= ... >>>>>>> REPLACE
|
||||
for search, replace in response:gmatch(params.claude_style) do
|
||||
table.insert(blocks, { search = search, replace = replace })
|
||||
end
|
||||
|
||||
if #blocks > 0 then
|
||||
return blocks
|
||||
end
|
||||
|
||||
-- Try simple format: [SEARCH] ... [REPLACE] ... [END]
|
||||
for search, replace in response:gmatch(params.simple_style) do
|
||||
table.insert(blocks, { search = search, replace = replace })
|
||||
end
|
||||
|
||||
if #blocks > 0 then
|
||||
return blocks
|
||||
end
|
||||
|
||||
-- Try markdown diff format: ```diff ... ```
|
||||
local diff_block = response:match(params.diff_block)
|
||||
if diff_block then
|
||||
local old_lines = {}
|
||||
local new_lines = {}
|
||||
for line in diff_block:gmatch("[^\n]+") do
|
||||
if line:match("^%-[^%-]") then
|
||||
-- Removed line (starts with single -)
|
||||
table.insert(old_lines, line:sub(2))
|
||||
elseif line:match("^%+[^%+]") then
|
||||
-- Added line (starts with single +)
|
||||
table.insert(new_lines, line:sub(2))
|
||||
elseif line:match("^%s") or line:match("^[^%-%+@]") then
|
||||
-- Context line
|
||||
table.insert(old_lines, line:match("^%s?(.*)"))
|
||||
table.insert(new_lines, line:match("^%s?(.*)"))
|
||||
end
|
||||
end
|
||||
if #old_lines > 0 or #new_lines > 0 then
|
||||
table.insert(blocks, {
|
||||
search = table.concat(old_lines, "\n"),
|
||||
replace = table.concat(new_lines, "\n"),
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return blocks
|
||||
end
|
||||
|
||||
--- Get indentation of a line
|
||||
---@param line string
|
||||
---@return string
|
||||
local function get_indentation(line)
|
||||
if not line then
|
||||
return ""
|
||||
end
|
||||
return line:match("^(%s*)") or ""
|
||||
end
|
||||
|
||||
--- Normalize whitespace in a string (collapse multiple spaces to one)
|
||||
---@param str string
|
||||
---@return string
|
||||
local function normalize_whitespace(str)
|
||||
-- Wrap in parentheses to only return first value (gsub returns string + count)
|
||||
return (str:gsub("%s+", " "):gsub("^%s*", ""):gsub("%s*$", ""))
|
||||
end
|
||||
|
||||
--- Trim trailing whitespace from each line
|
||||
---@param str string
|
||||
---@return string
|
||||
local function trim_lines(str)
|
||||
local lines = vim.split(str, "\n", { plain = true })
|
||||
for i, line in ipairs(lines) do
|
||||
-- Wrap in parentheses to only get string, not count
|
||||
lines[i] = (line:gsub("%s+$", ""))
|
||||
end
|
||||
return table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
--- Calculate Levenshtein distance between two strings
|
||||
---@param s1 string
|
||||
---@param s2 string
|
||||
---@return number
|
||||
local function levenshtein(s1, s2)
|
||||
local len1, len2 = #s1, #s2
|
||||
if len1 == 0 then
|
||||
return len2
|
||||
end
|
||||
if len2 == 0 then
|
||||
return len1
|
||||
end
|
||||
|
||||
local matrix = {}
|
||||
for i = 0, len1 do
|
||||
matrix[i] = { [0] = i }
|
||||
end
|
||||
for j = 0, len2 do
|
||||
matrix[0][j] = j
|
||||
end
|
||||
|
||||
for i = 1, len1 do
|
||||
for j = 1, len2 do
|
||||
local cost = (s1:sub(i, i) == s2:sub(j, j)) and 0 or 1
|
||||
matrix[i][j] = math.min(
|
||||
matrix[i - 1][j] + 1,
|
||||
matrix[i][j - 1] + 1,
|
||||
matrix[i - 1][j - 1] + cost
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
return matrix[len1][len2]
|
||||
end
|
||||
|
||||
--- Calculate similarity ratio (0.0-1.0) between two strings
|
||||
---@param s1 string
|
||||
---@param s2 string
|
||||
---@return number
|
||||
local function similarity(s1, s2)
|
||||
if s1 == s2 then
|
||||
return 1.0
|
||||
end
|
||||
local max_len = math.max(#s1, #s2)
|
||||
if max_len == 0 then
|
||||
return 1.0
|
||||
end
|
||||
local distance = levenshtein(s1, s2)
|
||||
return 1.0 - (distance / max_len)
|
||||
end
|
||||
|
||||
--- Strategy 1: Exact match
|
||||
---@param content_lines string[]
|
||||
---@param search_lines string[]
|
||||
---@return MatchResult|nil
|
||||
local function exact_match(content_lines, search_lines)
|
||||
if #search_lines == 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
for i = 1, #content_lines - #search_lines + 1 do
|
||||
local match = true
|
||||
for j = 1, #search_lines do
|
||||
if content_lines[i + j - 1] ~= search_lines[j] then
|
||||
match = false
|
||||
break
|
||||
end
|
||||
end
|
||||
if match then
|
||||
return {
|
||||
start_line = i,
|
||||
end_line = i + #search_lines - 1,
|
||||
start_col = 1,
|
||||
end_col = #content_lines[i + #search_lines - 1],
|
||||
strategy = "exact",
|
||||
confidence = 1.0,
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Strategy 2: Line-trimmed match (ignore trailing whitespace)
|
||||
---@param content_lines string[]
|
||||
---@param search_lines string[]
|
||||
---@return MatchResult|nil
|
||||
local function line_trimmed_match(content_lines, search_lines)
|
||||
if #search_lines == 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
local trimmed_search = {}
|
||||
for _, line in ipairs(search_lines) do
|
||||
table.insert(trimmed_search, (line:gsub("%s+$", "")))
|
||||
end
|
||||
|
||||
for i = 1, #content_lines - #search_lines + 1 do
|
||||
local match = true
|
||||
for j = 1, #search_lines do
|
||||
local trimmed_content = content_lines[i + j - 1]:gsub("%s+$", "")
|
||||
if trimmed_content ~= trimmed_search[j] then
|
||||
match = false
|
||||
break
|
||||
end
|
||||
end
|
||||
if match then
|
||||
return {
|
||||
start_line = i,
|
||||
end_line = i + #search_lines - 1,
|
||||
start_col = 1,
|
||||
end_col = #content_lines[i + #search_lines - 1],
|
||||
strategy = "line_trimmed",
|
||||
confidence = 0.95,
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Strategy 3: Indentation-flexible match (normalize indentation)
|
||||
---@param content_lines string[]
|
||||
---@param search_lines string[]
|
||||
---@return MatchResult|nil
|
||||
local function indentation_flexible_match(content_lines, search_lines)
|
||||
if #search_lines == 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Get base indentation from search (first non-empty line)
|
||||
local search_indent = ""
|
||||
for _, line in ipairs(search_lines) do
|
||||
if line:match("%S") then
|
||||
search_indent = get_indentation(line)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
-- Strip common indentation from search
|
||||
local stripped_search = {}
|
||||
for _, line in ipairs(search_lines) do
|
||||
if line:match("^" .. vim.pesc(search_indent)) then
|
||||
table.insert(stripped_search, line:sub(#search_indent + 1))
|
||||
else
|
||||
table.insert(stripped_search, line)
|
||||
end
|
||||
end
|
||||
|
||||
for i = 1, #content_lines - #search_lines + 1 do
|
||||
-- Get content indentation at this position
|
||||
local content_indent = ""
|
||||
for j = 0, #search_lines - 1 do
|
||||
local line = content_lines[i + j]
|
||||
if line:match("%S") then
|
||||
content_indent = get_indentation(line)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
local match = true
|
||||
for j = 1, #search_lines do
|
||||
local content_line = content_lines[i + j - 1]
|
||||
local expected = content_indent .. stripped_search[j]
|
||||
|
||||
-- Compare with normalized indentation
|
||||
if content_line:gsub("%s+$", "") ~= expected:gsub("%s+$", "") then
|
||||
match = false
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if match then
|
||||
return {
|
||||
start_line = i,
|
||||
end_line = i + #search_lines - 1,
|
||||
start_col = 1,
|
||||
end_col = #content_lines[i + #search_lines - 1],
|
||||
strategy = "indentation_flexible",
|
||||
confidence = 0.9,
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Strategy 4: Block anchor match (match first/last lines, fuzzy middle)
|
||||
---@param content_lines string[]
|
||||
---@param search_lines string[]
|
||||
---@return MatchResult|nil
|
||||
local function block_anchor_match(content_lines, search_lines)
|
||||
if #search_lines < 2 then
|
||||
return nil
|
||||
end
|
||||
|
||||
local first_search = search_lines[1]:gsub("%s+$", "")
|
||||
local last_search = search_lines[#search_lines]:gsub("%s+$", "")
|
||||
|
||||
-- Find potential start positions
|
||||
local candidates = {}
|
||||
for i = 1, #content_lines - #search_lines + 1 do
|
||||
local first_content = content_lines[i]:gsub("%s+$", "")
|
||||
if similarity(first_content, first_search) > 0.8 then
|
||||
-- Check if last line also matches
|
||||
local last_idx = i + #search_lines - 1
|
||||
if last_idx <= #content_lines then
|
||||
local last_content = content_lines[last_idx]:gsub("%s+$", "")
|
||||
if similarity(last_content, last_search) > 0.8 then
|
||||
-- Calculate overall similarity
|
||||
local total_sim = 0
|
||||
for j = 1, #search_lines do
|
||||
local c = content_lines[i + j - 1]:gsub("%s+$", "")
|
||||
local s = search_lines[j]:gsub("%s+$", "")
|
||||
total_sim = total_sim + similarity(c, s)
|
||||
end
|
||||
local avg_sim = total_sim / #search_lines
|
||||
if avg_sim > 0.7 then
|
||||
table.insert(candidates, { start = i, similarity = avg_sim })
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Return best match
|
||||
if #candidates > 0 then
|
||||
table.sort(candidates, function(a, b)
|
||||
return a.similarity > b.similarity
|
||||
end)
|
||||
local best = candidates[1]
|
||||
return {
|
||||
start_line = best.start,
|
||||
end_line = best.start + #search_lines - 1,
|
||||
start_col = 1,
|
||||
end_col = #content_lines[best.start + #search_lines - 1],
|
||||
strategy = "block_anchor",
|
||||
confidence = best.similarity * 0.85,
|
||||
}
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Strategy 5: Whitespace-normalized match
|
||||
---@param content_lines string[]
|
||||
---@param search_lines string[]
|
||||
---@return MatchResult|nil
|
||||
local function whitespace_normalized_match(content_lines, search_lines)
|
||||
if #search_lines == 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Normalize search lines
|
||||
local norm_search = {}
|
||||
for _, line in ipairs(search_lines) do
|
||||
table.insert(norm_search, normalize_whitespace(line))
|
||||
end
|
||||
|
||||
for i = 1, #content_lines - #search_lines + 1 do
|
||||
local match = true
|
||||
for j = 1, #search_lines do
|
||||
local norm_content = normalize_whitespace(content_lines[i + j - 1])
|
||||
if norm_content ~= norm_search[j] then
|
||||
match = false
|
||||
break
|
||||
end
|
||||
end
|
||||
if match then
|
||||
return {
|
||||
start_line = i,
|
||||
end_line = i + #search_lines - 1,
|
||||
start_col = 1,
|
||||
end_col = #content_lines[i + #search_lines - 1],
|
||||
strategy = "whitespace_normalized",
|
||||
confidence = 0.8,
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Find the best match for search text in content
|
||||
---@param content string File content
|
||||
---@param search string Text to search for
|
||||
---@return MatchResult|nil
|
||||
function M.find_match(content, search)
|
||||
local content_lines = vim.split(content, "\n", { plain = true })
|
||||
local search_lines = vim.split(search, "\n", { plain = true })
|
||||
|
||||
-- Remove trailing empty lines from search
|
||||
while #search_lines > 0 and search_lines[#search_lines]:match("^%s*$") do
|
||||
table.remove(search_lines)
|
||||
end
|
||||
|
||||
if #search_lines == 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Try strategies in order of strictness
|
||||
local strategies = {
|
||||
exact_match,
|
||||
line_trimmed_match,
|
||||
indentation_flexible_match,
|
||||
block_anchor_match,
|
||||
whitespace_normalized_match,
|
||||
}
|
||||
|
||||
for _, strategy in ipairs(strategies) do
|
||||
local result = strategy(content_lines, search_lines)
|
||||
if result then
|
||||
return result
|
||||
end
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Apply a single SEARCH/REPLACE block to content
|
||||
---@param content string Original file content
|
||||
---@param block SearchReplaceBlock
|
||||
---@return string|nil new_content
|
||||
---@return MatchResult|nil match_info
|
||||
---@return string|nil error
|
||||
function M.apply_block(content, block)
|
||||
local match = M.find_match(content, block.search)
|
||||
if not match then
|
||||
return nil, nil, "Could not find search text in file"
|
||||
end
|
||||
|
||||
local content_lines = vim.split(content, "\n", { plain = true })
|
||||
local replace_lines = vim.split(block.replace, "\n", { plain = true })
|
||||
|
||||
-- Adjust indentation of replacement to match original
|
||||
local original_indent = get_indentation(content_lines[match.start_line])
|
||||
local replace_indent = ""
|
||||
for _, line in ipairs(replace_lines) do
|
||||
if line:match("%S") then
|
||||
replace_indent = get_indentation(line)
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
-- Apply indentation adjustment
|
||||
local adjusted_replace = {}
|
||||
for _, line in ipairs(replace_lines) do
|
||||
if line:match("^" .. vim.pesc(replace_indent)) then
|
||||
table.insert(adjusted_replace, original_indent .. line:sub(#replace_indent + 1))
|
||||
elseif line:match("^%s*$") then
|
||||
table.insert(adjusted_replace, "")
|
||||
else
|
||||
table.insert(adjusted_replace, original_indent .. line)
|
||||
end
|
||||
end
|
||||
|
||||
-- Build new content
|
||||
local new_lines = {}
|
||||
for i = 1, match.start_line - 1 do
|
||||
table.insert(new_lines, content_lines[i])
|
||||
end
|
||||
for _, line in ipairs(adjusted_replace) do
|
||||
table.insert(new_lines, line)
|
||||
end
|
||||
for i = match.end_line + 1, #content_lines do
|
||||
table.insert(new_lines, content_lines[i])
|
||||
end
|
||||
|
||||
return table.concat(new_lines, "\n"), match, nil
|
||||
end
|
||||
|
||||
--- Apply multiple SEARCH/REPLACE blocks to content
|
||||
---@param content string Original file content
|
||||
---@param blocks SearchReplaceBlock[]
|
||||
---@return string new_content
|
||||
---@return table results Array of {success: boolean, match: MatchResult|nil, error: string|nil}
|
||||
function M.apply_blocks(content, blocks)
|
||||
local current_content = content
|
||||
local results = {}
|
||||
|
||||
for _, block in ipairs(blocks) do
|
||||
local new_content, match, err = M.apply_block(current_content, block)
|
||||
if new_content then
|
||||
current_content = new_content
|
||||
table.insert(results, { success = true, match = match })
|
||||
else
|
||||
table.insert(results, { success = false, error = err })
|
||||
end
|
||||
end
|
||||
|
||||
return current_content, results
|
||||
end
|
||||
|
||||
--- Apply SEARCH/REPLACE blocks to a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@param blocks SearchReplaceBlock[]
|
||||
---@return boolean success
|
||||
---@return string|nil error
|
||||
function M.apply_to_buffer(bufnr, blocks)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return false, "Invalid buffer"
|
||||
end
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local content = table.concat(lines, "\n")
|
||||
|
||||
local new_content, results = M.apply_blocks(content, blocks)
|
||||
|
||||
-- Check for any failures
|
||||
local failures = {}
|
||||
for i, result in ipairs(results) do
|
||||
if not result.success then
|
||||
table.insert(failures, string.format("Block %d: %s", i, result.error or "unknown error"))
|
||||
end
|
||||
end
|
||||
|
||||
if #failures > 0 then
|
||||
return false, table.concat(failures, "; ")
|
||||
end
|
||||
|
||||
-- Apply to buffer
|
||||
local new_lines = vim.split(new_content, "\n", { plain = true })
|
||||
vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, new_lines)
|
||||
|
||||
return true, nil
|
||||
end
|
||||
|
||||
--- Check if response contains SEARCH/REPLACE blocks
|
||||
---@param response string
|
||||
---@return boolean
|
||||
function M.has_blocks(response)
|
||||
return #M.parse_blocks(response) > 0
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -23,7 +23,7 @@ local M = {}
|
||||
---@field priority number Priority (1=high, 2=normal, 3=low)
|
||||
---@field status string "pending"|"processing"|"completed"|"escalated"|"cancelled"|"needs_context"|"failed"
|
||||
---@field attempt_count number Number of processing attempts
|
||||
---@field worker_type string|nil LLM provider used ("ollama"|"claude"|etc)
|
||||
---@field worker_type string|nil LLM provider used ("ollama"|"openai"|"gemini"|"copilot")
|
||||
---@field created_at number System time when created
|
||||
---@field intent Intent|nil Detected intent from prompt
|
||||
---@field scope ScopeInfo|nil Resolved scope (function/class/file)
|
||||
@@ -194,7 +194,7 @@ function M.enqueue(event)
|
||||
|
||||
-- Log to agent logs if available
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "queue",
|
||||
message = string.format("Event queued: %s (priority: %d)", event.id, event.priority),
|
||||
117
lua/codetyper/core/intent/init.lua
Normal file
117
lua/codetyper/core/intent/init.lua
Normal file
@@ -0,0 +1,117 @@
|
||||
---@mod codetyper.agent.intent Intent detection from prompts
|
||||
---@brief [[
|
||||
--- Parses prompt content to determine user intent and target scope.
|
||||
--- Intents determine how the generated code should be applied.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class Intent
|
||||
---@field type string "complete"|"refactor"|"add"|"fix"|"document"|"test"|"explain"|"optimize"
|
||||
---@field scope_hint string|nil "function"|"class"|"block"|"file"|"selection"|nil
|
||||
---@field confidence number 0.0-1.0 how confident we are about the intent
|
||||
---@field action string "replace"|"insert"|"append"|"none"
|
||||
---@field keywords string[] Keywords that triggered this intent
|
||||
|
||||
local params = require("codetyper.params.agents.intent")
|
||||
local intent_patterns = params.intent_patterns
|
||||
local scope_patterns = params.scope_patterns
|
||||
local prompts = require("codetyper.prompts.agents.intent")
|
||||
|
||||
--- Detect intent from prompt content
|
||||
---@param prompt string The prompt content
|
||||
---@return Intent
|
||||
function M.detect(prompt)
|
||||
local lower = prompt:lower()
|
||||
local best_match = nil
|
||||
local best_priority = 999
|
||||
local matched_keywords = {}
|
||||
|
||||
-- Check each intent type
|
||||
for intent_type, config in pairs(intent_patterns) do
|
||||
for _, pattern in ipairs(config.patterns) do
|
||||
if lower:find(pattern, 1, true) then
|
||||
if config.priority < best_priority then
|
||||
best_match = intent_type
|
||||
best_priority = config.priority
|
||||
matched_keywords = { pattern }
|
||||
elseif config.priority == best_priority and best_match == intent_type then
|
||||
table.insert(matched_keywords, pattern)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Default to "add" if no clear intent
|
||||
if not best_match then
|
||||
best_match = "add"
|
||||
matched_keywords = {}
|
||||
end
|
||||
|
||||
local config = intent_patterns[best_match]
|
||||
|
||||
-- Detect scope hint from prompt
|
||||
local scope_hint = config.scope_hint
|
||||
for pattern, hint in pairs(scope_patterns) do
|
||||
if lower:find(pattern, 1, true) then
|
||||
scope_hint = hint or scope_hint
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
-- Calculate confidence based on keyword matches
|
||||
local confidence = 0.5 + (#matched_keywords * 0.15)
|
||||
confidence = math.min(confidence, 1.0)
|
||||
|
||||
return {
|
||||
type = best_match,
|
||||
scope_hint = scope_hint,
|
||||
confidence = confidence,
|
||||
action = config.action,
|
||||
keywords = matched_keywords,
|
||||
}
|
||||
end
|
||||
|
||||
--- Check if intent requires code modification
|
||||
---@param intent Intent
|
||||
---@return boolean
|
||||
function M.modifies_code(intent)
|
||||
return intent.action ~= "none"
|
||||
end
|
||||
|
||||
--- Check if intent should replace existing code
|
||||
---@param intent Intent
|
||||
---@return boolean
|
||||
function M.is_replacement(intent)
|
||||
return intent.action == "replace"
|
||||
end
|
||||
|
||||
--- Check if intent adds new code
|
||||
---@param intent Intent
|
||||
---@return boolean
|
||||
function M.is_insertion(intent)
|
||||
return intent.action == "insert" or intent.action == "append"
|
||||
end
|
||||
|
||||
--- Get system prompt modifier based on intent
|
||||
---@param intent Intent
|
||||
---@return string
|
||||
function M.get_prompt_modifier(intent)
|
||||
local modifiers = prompts.modifiers
|
||||
return modifiers[intent.type] or modifiers.add
|
||||
end
|
||||
|
||||
--- Format intent for logging
|
||||
---@param intent Intent
|
||||
---@return string
|
||||
function M.format(intent)
|
||||
return string.format(
|
||||
"%s (scope: %s, action: %s, confidence: %.2f)",
|
||||
intent.type,
|
||||
intent.scope_hint or "auto",
|
||||
intent.action,
|
||||
intent.confidence
|
||||
)
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -6,41 +6,14 @@
|
||||
|
||||
local M = {}
|
||||
|
||||
local params = require("codetyper.params.agents.confidence")
|
||||
|
||||
--- Heuristic weights (must sum to 1.0)
|
||||
M.weights = {
|
||||
length = 0.15, -- Response length relative to prompt
|
||||
uncertainty = 0.30, -- Uncertainty phrases
|
||||
syntax = 0.25, -- Syntax completeness
|
||||
repetition = 0.15, -- Duplicate lines
|
||||
truncation = 0.15, -- Incomplete ending
|
||||
}
|
||||
M.weights = params.weights
|
||||
|
||||
--- Uncertainty phrases that indicate low confidence
|
||||
local uncertainty_phrases = {
|
||||
-- English
|
||||
"i'm not sure",
|
||||
"i am not sure",
|
||||
"maybe",
|
||||
"perhaps",
|
||||
"might work",
|
||||
"could work",
|
||||
"not certain",
|
||||
"uncertain",
|
||||
"i think",
|
||||
"possibly",
|
||||
"TODO",
|
||||
"FIXME",
|
||||
"XXX",
|
||||
"placeholder",
|
||||
"implement this",
|
||||
"fill in",
|
||||
"your code here",
|
||||
"...", -- Ellipsis as placeholder
|
||||
"# TODO",
|
||||
"// TODO",
|
||||
"-- TODO",
|
||||
"/* TODO",
|
||||
}
|
||||
local uncertainty_phrases = params.uncertainty_phrases
|
||||
|
||||
|
||||
--- Score based on response length relative to prompt
|
||||
---@param response string
|
||||
@@ -94,32 +67,6 @@ local function score_uncertainty(response)
|
||||
end
|
||||
end
|
||||
|
||||
--- Check bracket balance for common languages
|
||||
---@param response string
|
||||
---@return boolean balanced
|
||||
local function check_brackets(response)
|
||||
local pairs = {
|
||||
["{"] = "}",
|
||||
["["] = "]",
|
||||
["("] = ")",
|
||||
}
|
||||
|
||||
local stack = {}
|
||||
|
||||
for char in response:gmatch(".") do
|
||||
if pairs[char] then
|
||||
table.insert(stack, pairs[char])
|
||||
elseif char == "}" or char == "]" or char == ")" then
|
||||
if #stack == 0 or stack[#stack] ~= char then
|
||||
return false
|
||||
end
|
||||
table.remove(stack)
|
||||
end
|
||||
end
|
||||
|
||||
return #stack == 0
|
||||
end
|
||||
|
||||
--- Score based on syntax completeness
|
||||
---@param response string
|
||||
---@return number 0.0-1.0
|
||||
@@ -127,7 +74,7 @@ local function score_syntax(response)
|
||||
local score = 1.0
|
||||
|
||||
-- Check bracket balance
|
||||
if not check_brackets(response) then
|
||||
if not require("codetyper.support.utils").check_brackets(response) then
|
||||
score = score - 0.4
|
||||
end
|
||||
|
||||
@@ -255,14 +202,15 @@ function M.score(response, prompt, context)
|
||||
_ = context -- Reserved for future use
|
||||
|
||||
if not response or #response == 0 then
|
||||
return 0, {
|
||||
length = 0,
|
||||
uncertainty = 0,
|
||||
syntax = 0,
|
||||
repetition = 0,
|
||||
truncation = 0,
|
||||
weighted_total = 0,
|
||||
}
|
||||
return 0,
|
||||
{
|
||||
length = 0,
|
||||
uncertainty = 0,
|
||||
syntax = 0,
|
||||
repetition = 0,
|
||||
truncation = 0,
|
||||
weighted_total = 0,
|
||||
}
|
||||
end
|
||||
|
||||
local scores = {
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
local llm = require("codetyper.llm")
|
||||
local utils = require("codetyper.support.utils")
|
||||
local llm = require("codetyper.core.llm")
|
||||
|
||||
--- Copilot API endpoints
|
||||
local AUTH_URL = "https://api.github.com/copilot_internal/v2/token"
|
||||
@@ -14,6 +14,51 @@ local AUTH_URL = "https://api.github.com/copilot_internal/v2/token"
|
||||
---@field github_token table|nil
|
||||
M.state = nil
|
||||
|
||||
--- Track if we've already suggested Ollama fallback this session
|
||||
local ollama_fallback_suggested = false
|
||||
|
||||
--- Suggest switching to Ollama when rate limits are hit
|
||||
---@param error_msg string The error message that triggered this
|
||||
function M.suggest_ollama_fallback(error_msg)
|
||||
if ollama_fallback_suggested then
|
||||
return
|
||||
end
|
||||
|
||||
-- Check if Ollama is available
|
||||
local ollama_available = false
|
||||
vim.fn.jobstart({ "curl", "-s", "http://localhost:11434/api/tags" }, {
|
||||
on_exit = function(_, code)
|
||||
if code == 0 then
|
||||
ollama_available = true
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
if ollama_available then
|
||||
-- Switch to Ollama automatically
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
config.llm.provider = "ollama"
|
||||
|
||||
ollama_fallback_suggested = true
|
||||
utils.notify(
|
||||
"⚠️ Copilot rate limit reached. Switched to Ollama automatically.\n"
|
||||
.. "Original error: "
|
||||
.. error_msg:sub(1, 100),
|
||||
vim.log.levels.WARN
|
||||
)
|
||||
else
|
||||
utils.notify(
|
||||
"⚠️ Copilot rate limit reached. Ollama not available.\n"
|
||||
.. "Start Ollama with: ollama serve\n"
|
||||
.. "Or wait for Copilot limits to reset.",
|
||||
vim.log.levels.WARN
|
||||
)
|
||||
end
|
||||
end)
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- Get OAuth token from copilot.lua or copilot.vim config
|
||||
---@return string|nil OAuth token
|
||||
local function get_oauth_token()
|
||||
@@ -51,9 +96,16 @@ local function get_oauth_token()
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Get model from config
|
||||
--- Get model from stored credentials or config
|
||||
---@return string Model name
|
||||
local function get_model()
|
||||
-- Priority: stored credentials > config
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_model = credentials.get_model("copilot")
|
||||
if stored_model then
|
||||
return stored_model
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.copilot.model
|
||||
@@ -204,15 +256,37 @@ local function make_request(token, body, callback)
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
|
||||
if not ok then
|
||||
-- Show the actual response text as the error (truncated if too long)
|
||||
local error_msg = response_text
|
||||
if #error_msg > 200 then
|
||||
error_msg = error_msg:sub(1, 200) .. "..."
|
||||
end
|
||||
|
||||
-- Clean up common patterns
|
||||
if response_text:match("<!DOCTYPE") or response_text:match("<html") then
|
||||
error_msg = "Copilot API returned HTML error page. Service may be unavailable."
|
||||
end
|
||||
|
||||
-- Check for rate limit and suggest Ollama fallback
|
||||
if response_text:match("limit") or response_text:match("Upgrade") or response_text:match("quota") then
|
||||
M.suggest_ollama_fallback(error_msg)
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
callback(nil, "Failed to parse Copilot response", nil)
|
||||
callback(nil, error_msg, nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
local error_msg = response.error.message or "Copilot API error"
|
||||
if response.error.code == "rate_limit_exceeded" or (error_msg:match("limit") and error_msg:match("plan")) then
|
||||
error_msg = "Copilot rate limit: " .. error_msg
|
||||
M.suggest_ollama_fallback(error_msg)
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
callback(nil, response.error.message or "Copilot API error", nil)
|
||||
callback(nil, error_msg, nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
@@ -220,6 +294,17 @@ local function make_request(token, body, callback)
|
||||
-- Extract usage info
|
||||
local usage = response.usage or {}
|
||||
|
||||
-- Record usage for cost tracking
|
||||
if usage.prompt_tokens or usage.completion_tokens then
|
||||
local cost = require("codetyper.core.cost")
|
||||
cost.record_usage(
|
||||
get_model(),
|
||||
usage.prompt_tokens or 0,
|
||||
usage.completion_tokens or 0,
|
||||
usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens or 0
|
||||
)
|
||||
end
|
||||
|
||||
if response.choices and response.choices[1] and response.choices[1].message then
|
||||
local code = llm.extract_code(response.choices[1].message.content)
|
||||
vim.schedule(function()
|
||||
@@ -263,7 +348,7 @@ end
|
||||
---@param context table Context information
|
||||
---@param callback fun(response: string|nil, error: string|nil)
|
||||
function M.generate(prompt, context, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
|
||||
ensure_initialized()
|
||||
|
||||
@@ -329,7 +414,7 @@ end
|
||||
---@param tool_definitions table Tool definitions
|
||||
---@param callback fun(response: table|nil, error: string|nil)
|
||||
function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
|
||||
ensure_initialized()
|
||||
|
||||
@@ -351,31 +436,68 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
return
|
||||
end
|
||||
|
||||
local tools_module = require("codetyper.agent.tools")
|
||||
local agent_prompts = require("codetyper.prompts.agent")
|
||||
local tools_module = require("codetyper.core.tools")
|
||||
local agent_prompts = require("codetyper.prompts.agents")
|
||||
|
||||
-- Build system prompt with agent instructions
|
||||
-- Build system prompt with agent instructions and project context
|
||||
local system_prompt = llm.build_system_prompt(context)
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.system
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.tool_instructions
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.build_system_prompt()
|
||||
|
||||
-- Format messages for Copilot (OpenAI-compatible format)
|
||||
local copilot_messages = { { role = "system", content = system_prompt } }
|
||||
for _, msg in ipairs(messages) do
|
||||
if type(msg.content) == "string" then
|
||||
table.insert(copilot_messages, { role = msg.role, content = msg.content })
|
||||
elseif type(msg.content) == "table" then
|
||||
local text_parts = {}
|
||||
for _, part in ipairs(msg.content) do
|
||||
if part.type == "tool_result" then
|
||||
table.insert(text_parts, "[" .. (part.name or "tool") .. " result]: " .. (part.content or ""))
|
||||
elseif part.type == "text" then
|
||||
table.insert(text_parts, part.text or "")
|
||||
if msg.role == "user" then
|
||||
-- User messages - handle string or table content
|
||||
if type(msg.content) == "string" then
|
||||
table.insert(copilot_messages, { role = "user", content = msg.content })
|
||||
elseif type(msg.content) == "table" then
|
||||
-- Handle complex content (like tool results from user perspective)
|
||||
local text_parts = {}
|
||||
for _, part in ipairs(msg.content) do
|
||||
if part.type == "tool_result" then
|
||||
table.insert(text_parts, "[" .. (part.name or "tool") .. " result]: " .. (part.content or ""))
|
||||
elseif part.type == "text" then
|
||||
table.insert(text_parts, part.text or "")
|
||||
end
|
||||
end
|
||||
if #text_parts > 0 then
|
||||
table.insert(copilot_messages, { role = "user", content = table.concat(text_parts, "\n") })
|
||||
end
|
||||
end
|
||||
if #text_parts > 0 then
|
||||
table.insert(copilot_messages, { role = msg.role, content = table.concat(text_parts, "\n") })
|
||||
elseif msg.role == "assistant" then
|
||||
-- Assistant messages - must preserve tool_calls if present
|
||||
local assistant_msg = {
|
||||
role = "assistant",
|
||||
content = type(msg.content) == "string" and msg.content or nil,
|
||||
}
|
||||
-- Convert tool_calls to OpenAI format for the API
|
||||
if msg.tool_calls and #msg.tool_calls > 0 then
|
||||
assistant_msg.tool_calls = {}
|
||||
for _, tc in ipairs(msg.tool_calls) do
|
||||
-- Convert from parsed format {id, name, parameters} to OpenAI format
|
||||
local openai_tc = {
|
||||
id = tc.id,
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tc.name,
|
||||
arguments = vim.json.encode(tc.parameters or {}),
|
||||
},
|
||||
}
|
||||
table.insert(assistant_msg.tool_calls, openai_tc)
|
||||
end
|
||||
-- Ensure content is not nil when tool_calls present
|
||||
if assistant_msg.content == nil then
|
||||
assistant_msg.content = ""
|
||||
end
|
||||
end
|
||||
table.insert(copilot_messages, assistant_msg)
|
||||
elseif msg.role == "tool" then
|
||||
-- Tool result messages - must have tool_call_id
|
||||
table.insert(copilot_messages, {
|
||||
role = "tool",
|
||||
tool_call_id = msg.tool_call_id,
|
||||
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
@@ -386,6 +508,7 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
temperature = 0.3,
|
||||
stream = false,
|
||||
tools = tools_module.to_openai_format(),
|
||||
tool_choice = "auto", -- Encourage the model to use tools when appropriate
|
||||
}
|
||||
|
||||
local endpoint = (token.endpoints and token.endpoints.api or "https://api.githubcopilot.com")
|
||||
@@ -396,6 +519,20 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
logs.thinking("Sending to Copilot API...")
|
||||
|
||||
-- Log request to debug file
|
||||
local debug_log_path = vim.fn.expand("~/.local/codetyper-debug.log")
|
||||
local debug_f = io.open(debug_log_path, "a")
|
||||
if debug_f then
|
||||
debug_f:write(os.date("[%Y-%m-%d %H:%M:%S] ") .. "COPILOT REQUEST\n")
|
||||
debug_f:write("Messages count: " .. #copilot_messages .. "\n")
|
||||
for i, m in ipairs(copilot_messages) do
|
||||
debug_f:write(string.format(" [%d] role=%s, has_tool_calls=%s, has_tool_call_id=%s\n",
|
||||
i, m.role, tostring(m.tool_calls ~= nil), tostring(m.tool_call_id ~= nil)))
|
||||
end
|
||||
debug_f:write("---\n")
|
||||
debug_f:close()
|
||||
end
|
||||
|
||||
local headers = build_headers(token)
|
||||
local cmd = {
|
||||
"curl",
|
||||
@@ -413,35 +550,97 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
table.insert(cmd, "-d")
|
||||
table.insert(cmd, json_body)
|
||||
|
||||
-- Debug logging helper
|
||||
local function debug_log(msg, data)
|
||||
local log_path = vim.fn.expand("~/.local/codetyper-debug.log")
|
||||
local f = io.open(log_path, "a")
|
||||
if f then
|
||||
f:write(os.date("[%Y-%m-%d %H:%M:%S] ") .. msg .. "\n")
|
||||
if data then
|
||||
f:write("DATA: " .. tostring(data):sub(1, 2000) .. "\n")
|
||||
end
|
||||
f:write("---\n")
|
||||
f:close()
|
||||
end
|
||||
end
|
||||
|
||||
-- Prevent double callback calls
|
||||
local callback_called = false
|
||||
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if callback_called then
|
||||
debug_log("on_stdout: callback already called, skipping")
|
||||
return
|
||||
end
|
||||
|
||||
if not data or #data == 0 or (data[1] == "" and #data == 1) then
|
||||
debug_log("on_stdout: empty data")
|
||||
return
|
||||
end
|
||||
|
||||
local response_text = table.concat(data, "\n")
|
||||
debug_log("on_stdout: received response", response_text)
|
||||
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
|
||||
if not ok then
|
||||
debug_log("JSON parse failed", response_text)
|
||||
callback_called = true
|
||||
|
||||
-- Show the actual response text as the error (truncated if too long)
|
||||
local error_msg = response_text
|
||||
if #error_msg > 200 then
|
||||
error_msg = error_msg:sub(1, 200) .. "..."
|
||||
end
|
||||
|
||||
-- Clean up common patterns
|
||||
if response_text:match("<!DOCTYPE") or response_text:match("<html") then
|
||||
error_msg = "Copilot API returned HTML error page. Service may be unavailable."
|
||||
end
|
||||
|
||||
-- Check for rate limit and suggest Ollama fallback
|
||||
if response_text:match("limit") or response_text:match("Upgrade") or response_text:match("quota") then
|
||||
M.suggest_ollama_fallback(error_msg)
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
logs.error("Failed to parse Copilot response")
|
||||
callback(nil, "Failed to parse Copilot response")
|
||||
logs.error(error_msg)
|
||||
callback(nil, error_msg)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
callback_called = true
|
||||
local error_msg = response.error.message or "Copilot API error"
|
||||
|
||||
-- Check for rate limit in structured error
|
||||
if response.error.code == "rate_limit_exceeded" or (error_msg:match("limit") and error_msg:match("plan")) then
|
||||
error_msg = "Copilot rate limit: " .. error_msg
|
||||
M.suggest_ollama_fallback(error_msg)
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
logs.error(response.error.message or "Copilot API error")
|
||||
callback(nil, response.error.message or "Copilot API error")
|
||||
logs.error(error_msg)
|
||||
callback(nil, error_msg)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Log token usage
|
||||
-- Log token usage and record cost
|
||||
if response.usage then
|
||||
logs.response(response.usage.prompt_tokens or 0, response.usage.completion_tokens or 0, "stop")
|
||||
|
||||
-- Record usage for cost tracking
|
||||
local cost_tracker = require("codetyper.core.cost")
|
||||
cost_tracker.record_usage(
|
||||
get_model(),
|
||||
response.usage.prompt_tokens or 0,
|
||||
response.usage.completion_tokens or 0,
|
||||
response.usage.prompt_tokens_details and response.usage.prompt_tokens_details.cached_tokens or 0
|
||||
)
|
||||
end
|
||||
|
||||
-- Convert to Claude-like format for parser compatibility
|
||||
@@ -474,12 +673,19 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
end
|
||||
end
|
||||
|
||||
callback_called = true
|
||||
debug_log("on_stdout: success, calling callback")
|
||||
vim.schedule(function()
|
||||
callback(converted, nil)
|
||||
end)
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if callback_called then
|
||||
return
|
||||
end
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
debug_log("on_stderr", table.concat(data, "\n"))
|
||||
callback_called = true
|
||||
vim.schedule(function()
|
||||
logs.error("Copilot API request failed: " .. table.concat(data, "\n"))
|
||||
callback(nil, "Copilot API request failed: " .. table.concat(data, "\n"))
|
||||
@@ -487,7 +693,12 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
debug_log("on_exit: code=" .. code .. ", callback_called=" .. tostring(callback_called))
|
||||
if callback_called then
|
||||
return
|
||||
end
|
||||
if code ~= 0 then
|
||||
callback_called = true
|
||||
vim.schedule(function()
|
||||
logs.error("Copilot API request failed with code: " .. code)
|
||||
callback(nil, "Copilot API request failed with code: " .. code)
|
||||
@@ -2,23 +2,37 @@
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
local llm = require("codetyper.llm")
|
||||
local utils = require("codetyper.support.utils")
|
||||
local llm = require("codetyper.core.llm")
|
||||
|
||||
--- Gemini API endpoint
|
||||
local API_URL = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
|
||||
--- Get API key from config or environment
|
||||
--- Get API key from stored credentials, config, or environment
|
||||
---@return string|nil API key
|
||||
local function get_api_key()
|
||||
-- Priority: stored credentials > config > environment
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_key = credentials.get_api_key("gemini")
|
||||
if stored_key then
|
||||
return stored_key
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.gemini.api_key or vim.env.GEMINI_API_KEY
|
||||
end
|
||||
|
||||
--- Get model from config
|
||||
--- Get model from stored credentials or config
|
||||
---@return string Model name
|
||||
local function get_model()
|
||||
-- Priority: stored credentials > config
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_model = credentials.get_model("gemini")
|
||||
if stored_model then
|
||||
return stored_model
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.gemini.model
|
||||
@@ -153,7 +167,7 @@ end
|
||||
---@param context table Context information
|
||||
---@param callback fun(response: string|nil, error: string|nil) Callback function
|
||||
function M.generate(prompt, context, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
local model = get_model()
|
||||
|
||||
-- Log the request
|
||||
@@ -203,7 +217,7 @@ end
|
||||
---@param tool_definitions table Tool definitions
|
||||
---@param callback fun(response: table|nil, error: string|nil) Callback with raw response
|
||||
function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
local model = get_model()
|
||||
|
||||
logs.request("gemini", model)
|
||||
@@ -216,8 +230,8 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
return
|
||||
end
|
||||
|
||||
local tools_module = require("codetyper.agent.tools")
|
||||
local agent_prompts = require("codetyper.prompts.agent")
|
||||
local tools_module = require("codetyper.core.tools")
|
||||
local agent_prompts = require("codetyper.prompts.agents")
|
||||
|
||||
-- Build system prompt with agent instructions
|
||||
local system_prompt = llm.build_system_prompt(context)
|
||||
@@ -1,8 +1,8 @@
|
||||
---@mod codetyper.llm LLM interface for Codetyper.nvim
|
||||
|
||||
local M = {}
|
||||
local lang_map = require("codetyper.utils.langmap")
|
||||
local utils = require("codetyper.utils")
|
||||
local lang_map = require("codetyper.support.langmap")
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Get the appropriate LLM client based on configuration
|
||||
---@return table LLM client module
|
||||
@@ -10,16 +10,14 @@ function M.get_client()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
if config.llm.provider == "claude" then
|
||||
return require("codetyper.llm.claude")
|
||||
elseif config.llm.provider == "ollama" then
|
||||
return require("codetyper.llm.ollama")
|
||||
if config.llm.provider == "ollama" then
|
||||
return require("codetyper.core.llm.ollama")
|
||||
elseif config.llm.provider == "openai" then
|
||||
return require("codetyper.llm.openai")
|
||||
return require("codetyper.core.llm.openai")
|
||||
elseif config.llm.provider == "gemini" then
|
||||
return require("codetyper.llm.gemini")
|
||||
return require("codetyper.core.llm.gemini")
|
||||
elseif config.llm.provider == "copilot" then
|
||||
return require("codetyper.llm.copilot")
|
||||
return require("codetyper.core.llm.copilot")
|
||||
else
|
||||
error("Unknown LLM provider: " .. config.llm.provider)
|
||||
end
|
||||
@@ -34,6 +32,32 @@ function M.generate(prompt, context, callback)
|
||||
client.generate(prompt, context, callback)
|
||||
end
|
||||
|
||||
--- Smart generate with automatic provider selection based on brain memories
|
||||
--- Prefers Ollama when context is rich, falls back to Copilot otherwise.
|
||||
--- Implements verification pondering to reinforce Ollama accuracy over time.
|
||||
---@param prompt string The user's prompt
|
||||
---@param context table Context information
|
||||
---@param callback fun(response: string|nil, error: string|nil, metadata: table|nil) Callback
|
||||
function M.smart_generate(prompt, context, callback)
|
||||
local selector = require("codetyper.core.llm.selector")
|
||||
selector.smart_generate(prompt, context, callback)
|
||||
end
|
||||
|
||||
--- Get accuracy statistics for providers
|
||||
---@return table Statistics for each provider
|
||||
function M.get_accuracy_stats()
|
||||
local selector = require("codetyper.core.llm.selector")
|
||||
return selector.get_accuracy_stats()
|
||||
end
|
||||
|
||||
--- Report user feedback on response quality (for reinforcement learning)
|
||||
---@param provider string Which provider generated the response
|
||||
---@param was_correct boolean Whether the response was good
|
||||
function M.report_feedback(provider, was_correct)
|
||||
local selector = require("codetyper.core.llm.selector")
|
||||
selector.report_feedback(provider, was_correct)
|
||||
end
|
||||
|
||||
--- Build the system prompt for code generation
|
||||
---@param context table Context information
|
||||
---@return string System prompt
|
||||
@@ -50,7 +74,49 @@ function M.build_system_prompt(context)
|
||||
system = system:gsub("{{language}}", context.language or "unknown")
|
||||
system = system:gsub("{{filepath}}", context.file_path or "unknown")
|
||||
|
||||
-- Add file content with analysis hints
|
||||
-- For agent mode, include project context
|
||||
if prompt_type == "agent" then
|
||||
local project_info = "\n\n## PROJECT CONTEXT\n"
|
||||
|
||||
if context.project_root then
|
||||
project_info = project_info .. "- Project root: " .. context.project_root .. "\n"
|
||||
end
|
||||
if context.cwd then
|
||||
project_info = project_info .. "- Working directory: " .. context.cwd .. "\n"
|
||||
end
|
||||
if context.project_type then
|
||||
project_info = project_info .. "- Project type: " .. context.project_type .. "\n"
|
||||
end
|
||||
if context.project_stats then
|
||||
project_info = project_info
|
||||
.. string.format(
|
||||
"- Stats: %d files, %d functions, %d classes\n",
|
||||
context.project_stats.files or 0,
|
||||
context.project_stats.functions or 0,
|
||||
context.project_stats.classes or 0
|
||||
)
|
||||
end
|
||||
if context.file_path then
|
||||
project_info = project_info .. "- Current file: " .. context.file_path .. "\n"
|
||||
end
|
||||
|
||||
system = system .. project_info
|
||||
return system
|
||||
end
|
||||
|
||||
-- For "ask" or "explain" mode, don't add code generation instructions
|
||||
if prompt_type == "ask" or prompt_type == "explain" then
|
||||
-- Just add context about the file if available
|
||||
if context.file_path then
|
||||
system = system .. "\n\nContext: The user is working with " .. context.file_path
|
||||
if context.language then
|
||||
system = system .. " (" .. context.language .. ")"
|
||||
end
|
||||
end
|
||||
return system
|
||||
end
|
||||
|
||||
-- Add file content with analysis hints (for code generation modes)
|
||||
if context.file_content and context.file_content ~= "" then
|
||||
system = system .. "\n\n===== EXISTING FILE CONTENT (analyze and match this style) =====\n"
|
||||
system = system .. context.file_content
|
||||
@@ -74,13 +140,34 @@ function M.build_context(target_path, prompt_type)
|
||||
local content = utils.read_file(target_path)
|
||||
local ext = vim.fn.fnamemodify(target_path, ":e")
|
||||
|
||||
return {
|
||||
local context = {
|
||||
file_content = content,
|
||||
language = lang_map[ext] or ext,
|
||||
extension = ext,
|
||||
prompt_type = prompt_type,
|
||||
file_path = target_path,
|
||||
}
|
||||
|
||||
-- For agent mode, include additional project context
|
||||
if prompt_type == "agent" then
|
||||
local project_root = utils.get_project_root()
|
||||
context.project_root = project_root
|
||||
|
||||
-- Try to get project info from indexer
|
||||
local ok_indexer, indexer = pcall(require, "codetyper.indexer")
|
||||
if ok_indexer then
|
||||
local status = indexer.get_status()
|
||||
if status.indexed then
|
||||
context.project_type = status.project_type
|
||||
context.project_stats = status.stats
|
||||
end
|
||||
end
|
||||
|
||||
-- Include working directory
|
||||
context.cwd = vim.fn.getcwd()
|
||||
end
|
||||
|
||||
return context
|
||||
end
|
||||
|
||||
--- Parse LLM response and extract code
|
||||
@@ -2,24 +2,36 @@
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
local llm = require("codetyper.llm")
|
||||
local utils = require("codetyper.support.utils")
|
||||
local llm = require("codetyper.core.llm")
|
||||
|
||||
--- Get Ollama host from config
|
||||
--- Get Ollama host from stored credentials or config
|
||||
---@return string Host URL
|
||||
local function get_host()
|
||||
-- Priority: stored credentials > config
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_host = credentials.get_ollama_host()
|
||||
if stored_host then
|
||||
return stored_host
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
return config.llm.ollama.host
|
||||
end
|
||||
|
||||
--- Get model from config
|
||||
--- Get model from stored credentials or config
|
||||
---@return string Model name
|
||||
local function get_model()
|
||||
-- Priority: stored credentials > config
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_model = credentials.get_model("ollama")
|
||||
if stored_model then
|
||||
return stored_model
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
return config.llm.ollama.model
|
||||
end
|
||||
|
||||
@@ -125,7 +137,7 @@ end
|
||||
---@param context table Context information
|
||||
---@param callback fun(response: string|nil, error: string|nil) Callback function
|
||||
function M.generate(prompt, context, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
local model = get_model()
|
||||
|
||||
-- Log the request
|
||||
@@ -199,47 +211,41 @@ function M.validate()
|
||||
return true
|
||||
end
|
||||
|
||||
--- Build system prompt for agent mode with tool instructions
|
||||
--- Generate with tool use support for agentic mode (text-based tool calling)
|
||||
---@param messages table[] Conversation history
|
||||
---@param context table Context information
|
||||
---@return string System prompt
|
||||
local function build_agent_system_prompt(context)
|
||||
local agent_prompts = require("codetyper.prompts.agent")
|
||||
local tools_module = require("codetyper.agent.tools")
|
||||
---@param tool_definitions table Tool definitions
|
||||
---@param callback fun(response: table|nil, error: string|nil) Callback with Claude-like response format
|
||||
function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
local agent_prompts = require("codetyper.prompts.agents")
|
||||
local tools_module = require("codetyper.core.tools")
|
||||
|
||||
local system_prompt = agent_prompts.system .. "\n\n"
|
||||
system_prompt = system_prompt .. tools_module.to_prompt_format() .. "\n\n"
|
||||
system_prompt = system_prompt .. agent_prompts.tool_instructions
|
||||
logs.request("ollama", get_model())
|
||||
logs.thinking("Preparing agent request...")
|
||||
|
||||
-- Add context about current file if available
|
||||
if context.file_path then
|
||||
system_prompt = system_prompt .. "\n\nCurrent working context:\n"
|
||||
system_prompt = system_prompt .. "- File: " .. context.file_path .. "\n"
|
||||
if context.language then
|
||||
system_prompt = system_prompt .. "- Language: " .. context.language .. "\n"
|
||||
-- Build system prompt with tool instructions
|
||||
local system_prompt = llm.build_system_prompt(context)
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.system
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.tool_instructions
|
||||
|
||||
-- Add tool descriptions
|
||||
system_prompt = system_prompt .. "\n\n## Available Tools\n"
|
||||
system_prompt = system_prompt .. "Call tools by outputting JSON in this exact format:\n"
|
||||
system_prompt = system_prompt .. '```json\n{"tool": "tool_name", "arguments": {...}}\n```\n\n'
|
||||
|
||||
for _, tool in ipairs(tool_definitions) do
|
||||
local name = tool.name or (tool["function"] and tool["function"].name)
|
||||
local desc = tool.description or (tool["function"] and tool["function"].description)
|
||||
if name then
|
||||
system_prompt = system_prompt .. string.format("### %s\n%s\n\n", name, desc or "")
|
||||
end
|
||||
end
|
||||
|
||||
-- Add project root info
|
||||
local root = utils.get_project_root()
|
||||
if root then
|
||||
system_prompt = system_prompt .. "- Project root: " .. root .. "\n"
|
||||
end
|
||||
|
||||
return system_prompt
|
||||
end
|
||||
|
||||
--- Build request body for Ollama API with tools (chat format)
|
||||
---@param messages table[] Conversation messages
|
||||
---@param context table Context information
|
||||
---@return table Request body
|
||||
local function build_tools_request_body(messages, context)
|
||||
local system_prompt = build_agent_system_prompt(context)
|
||||
|
||||
-- Convert messages to Ollama chat format
|
||||
local ollama_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
local content = msg.content
|
||||
-- Handle complex content (like tool results)
|
||||
if type(content) == "table" then
|
||||
local text_parts = {}
|
||||
for _, part in ipairs(content) do
|
||||
@@ -251,14 +257,10 @@ local function build_tools_request_body(messages, context)
|
||||
end
|
||||
content = table.concat(text_parts, "\n")
|
||||
end
|
||||
|
||||
table.insert(ollama_messages, {
|
||||
role = msg.role,
|
||||
content = content,
|
||||
})
|
||||
table.insert(ollama_messages, { role = msg.role, content = content })
|
||||
end
|
||||
|
||||
return {
|
||||
local body = {
|
||||
model = get_model(),
|
||||
messages = ollama_messages,
|
||||
system = system_prompt,
|
||||
@@ -268,16 +270,15 @@ local function build_tools_request_body(messages, context)
|
||||
num_predict = 4096,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
--- Make HTTP request to Ollama chat API
|
||||
---@param body table Request body
|
||||
---@param callback fun(response: string|nil, error: string|nil, usage: table|nil) Callback function
|
||||
local function make_chat_request(body, callback)
|
||||
local host = get_host()
|
||||
local url = host .. "/api/chat"
|
||||
local json_body = vim.json.encode(body)
|
||||
|
||||
local prompt_estimate = logs.estimate_tokens(json_body)
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
logs.thinking("Sending to Ollama API...")
|
||||
|
||||
local cmd = {
|
||||
"curl",
|
||||
"-s",
|
||||
@@ -302,196 +303,82 @@ local function make_chat_request(body, callback)
|
||||
|
||||
if not ok then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Failed to parse Ollama response", nil)
|
||||
logs.error("Failed to parse Ollama response")
|
||||
callback(nil, "Failed to parse Ollama response")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
vim.schedule(function()
|
||||
callback(nil, response.error or "Ollama API error", nil)
|
||||
logs.error(response.error or "Ollama API error")
|
||||
callback(nil, response.error or "Ollama API error")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Extract usage info
|
||||
local usage = {
|
||||
prompt_tokens = response.prompt_eval_count or 0,
|
||||
response_tokens = response.eval_count or 0,
|
||||
}
|
||||
-- Log token usage and record cost (Ollama is free but we track usage)
|
||||
if response.prompt_eval_count or response.eval_count then
|
||||
logs.response(response.prompt_eval_count or 0, response.eval_count or 0, "stop")
|
||||
|
||||
-- Return the message content for agent parsing
|
||||
if response.message and response.message.content then
|
||||
vim.schedule(function()
|
||||
callback(response.message.content, nil, usage)
|
||||
end)
|
||||
else
|
||||
vim.schedule(function()
|
||||
callback(nil, "No response from Ollama", nil)
|
||||
end)
|
||||
-- Record usage for cost tracking (free for local models)
|
||||
local cost = require("codetyper.core.cost")
|
||||
cost.record_usage(
|
||||
get_model(),
|
||||
response.prompt_eval_count or 0,
|
||||
response.eval_count or 0,
|
||||
0 -- No cached tokens for Ollama
|
||||
)
|
||||
end
|
||||
|
||||
-- Parse the response text for tool calls
|
||||
local content_text = response.message and response.message.content or ""
|
||||
local converted = { content = {}, stop_reason = "end_turn" }
|
||||
|
||||
-- Try to extract JSON tool calls from response
|
||||
local json_match = content_text:match("```json%s*(%b{})%s*```")
|
||||
if json_match then
|
||||
local ok_json, parsed = pcall(vim.json.decode, json_match)
|
||||
if ok_json and parsed.tool then
|
||||
table.insert(converted.content, {
|
||||
type = "tool_use",
|
||||
id = "call_" .. string.format("%x", os.time()) .. "_" .. string.format("%x", math.random(0, 0xFFFF)),
|
||||
name = parsed.tool,
|
||||
input = parsed.arguments or {},
|
||||
})
|
||||
logs.thinking("Tool call: " .. parsed.tool)
|
||||
content_text = content_text:gsub("```json.-```", ""):gsub("^%s+", ""):gsub("%s+$", "")
|
||||
converted.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
|
||||
-- Add text content
|
||||
if content_text and content_text ~= "" then
|
||||
table.insert(converted.content, 1, { type = "text", text = content_text })
|
||||
logs.thinking("Response contains text")
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
callback(converted, nil)
|
||||
end)
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil)
|
||||
logs.error("Ollama API request failed: " .. table.concat(data, "\n"))
|
||||
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"))
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
-- Don't double-report errors
|
||||
vim.schedule(function()
|
||||
logs.error("Ollama API request failed with code: " .. code)
|
||||
callback(nil, "Ollama API request failed with code: " .. code)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- Generate response with tools using Ollama API
|
||||
---@param messages table[] Conversation history
|
||||
---@param context table Context information
|
||||
---@param tools table Tool definitions (embedded in prompt for Ollama)
|
||||
---@param callback fun(response: string|nil, error: string|nil) Callback function
|
||||
function M.generate_with_tools(messages, context, tools, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
|
||||
-- Log the request
|
||||
local model = get_model()
|
||||
logs.request("ollama", model)
|
||||
logs.thinking("Preparing API request...")
|
||||
|
||||
local body = build_tools_request_body(messages, context)
|
||||
|
||||
-- Estimate prompt tokens
|
||||
local prompt_estimate = logs.estimate_tokens(vim.json.encode(body))
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
|
||||
make_chat_request(body, function(response, err, usage)
|
||||
if err then
|
||||
logs.error(err)
|
||||
callback(nil, err)
|
||||
else
|
||||
-- Log token usage
|
||||
if usage then
|
||||
logs.response(usage.prompt_tokens or 0, usage.response_tokens or 0, "end_turn")
|
||||
end
|
||||
|
||||
-- Log if response contains tool calls
|
||||
if response then
|
||||
local parser = require("codetyper.agent.parser")
|
||||
local parsed = parser.parse_ollama_response(response)
|
||||
if #parsed.tool_calls > 0 then
|
||||
for _, tc in ipairs(parsed.tool_calls) do
|
||||
logs.thinking("Tool call: " .. tc.name)
|
||||
end
|
||||
end
|
||||
if parsed.text and parsed.text ~= "" then
|
||||
logs.thinking("Response contains text")
|
||||
end
|
||||
end
|
||||
|
||||
callback(response, nil)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Generate with tool use support for agentic mode (simulated via prompts)
|
||||
---@param messages table[] Conversation history
|
||||
---@param context table Context information
|
||||
---@param tool_definitions table Tool definitions
|
||||
---@param callback fun(response: string|nil, error: string|nil) Callback with response text
|
||||
function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
local tools_module = require("codetyper.agent.tools")
|
||||
local agent_prompts = require("codetyper.prompts.agent")
|
||||
|
||||
-- Build system prompt with agent instructions and tool definitions
|
||||
local system_prompt = llm.build_system_prompt(context)
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.system
|
||||
system_prompt = system_prompt .. "\n\n" .. tools_module.to_prompt_format()
|
||||
|
||||
-- Flatten messages to single prompt (Ollama's generate API)
|
||||
local prompt_parts = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if type(msg.content) == "string" then
|
||||
local role_prefix = msg.role == "user" and "User" or "Assistant"
|
||||
table.insert(prompt_parts, role_prefix .. ": " .. msg.content)
|
||||
elseif type(msg.content) == "table" then
|
||||
-- Handle tool results
|
||||
for _, item in ipairs(msg.content) do
|
||||
if item.type == "tool_result" then
|
||||
table.insert(prompt_parts, "Tool result: " .. item.content)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local body = {
|
||||
model = get_model(),
|
||||
system = system_prompt,
|
||||
prompt = table.concat(prompt_parts, "\n\n"),
|
||||
stream = false,
|
||||
options = {
|
||||
temperature = 0.2,
|
||||
num_predict = 4096,
|
||||
},
|
||||
}
|
||||
|
||||
local host = get_host()
|
||||
local url = host .. "/api/generate"
|
||||
local json_body = vim.json.encode(body)
|
||||
|
||||
local cmd = {
|
||||
"curl",
|
||||
"-s",
|
||||
"-X", "POST",
|
||||
url,
|
||||
"-H", "Content-Type: application/json",
|
||||
"-d", json_body,
|
||||
}
|
||||
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if not data or #data == 0 or (data[1] == "" and #data == 1) then
|
||||
return
|
||||
end
|
||||
|
||||
local response_text = table.concat(data, "\n")
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
|
||||
if not ok then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Failed to parse Ollama response")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
vim.schedule(function()
|
||||
callback(nil, response.error or "Ollama API error")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Return raw response text for parser to handle
|
||||
vim.schedule(function()
|
||||
callback(response.response or "", nil)
|
||||
end)
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"))
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Ollama API request failed with code: " .. code)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -2,31 +2,52 @@
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
local llm = require("codetyper.llm")
|
||||
local utils = require("codetyper.support.utils")
|
||||
local llm = require("codetyper.core.llm")
|
||||
|
||||
--- OpenAI API endpoint
|
||||
local API_URL = "https://api.openai.com/v1/chat/completions"
|
||||
|
||||
--- Get API key from config or environment
|
||||
--- Get API key from stored credentials, config, or environment
|
||||
---@return string|nil API key
|
||||
local function get_api_key()
|
||||
-- Priority: stored credentials > config > environment
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_key = credentials.get_api_key("openai")
|
||||
if stored_key then
|
||||
return stored_key
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.openai.api_key or vim.env.OPENAI_API_KEY
|
||||
end
|
||||
|
||||
--- Get model from config
|
||||
--- Get model from stored credentials or config
|
||||
---@return string Model name
|
||||
local function get_model()
|
||||
-- Priority: stored credentials > config
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_model = credentials.get_model("openai")
|
||||
if stored_model then
|
||||
return stored_model
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.openai.model
|
||||
end
|
||||
|
||||
--- Get endpoint from config (allows custom endpoints like Azure, OpenRouter)
|
||||
--- Get endpoint from stored credentials or config (allows custom endpoints like Azure, OpenRouter)
|
||||
---@return string API endpoint
|
||||
local function get_endpoint()
|
||||
-- Priority: stored credentials > config > default
|
||||
local credentials = require("codetyper.credentials")
|
||||
local stored_endpoint = credentials.get_endpoint("openai")
|
||||
if stored_endpoint then
|
||||
return stored_endpoint
|
||||
end
|
||||
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
return config.llm.openai.endpoint or API_URL
|
||||
@@ -137,7 +158,7 @@ end
|
||||
---@param context table Context information
|
||||
---@param callback fun(response: string|nil, error: string|nil) Callback function
|
||||
function M.generate(prompt, context, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
local model = get_model()
|
||||
|
||||
-- Log the request
|
||||
@@ -187,7 +208,7 @@ end
|
||||
---@param tool_definitions table Tool definitions
|
||||
---@param callback fun(response: table|nil, error: string|nil) Callback with raw response
|
||||
function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
local model = get_model()
|
||||
|
||||
logs.request("openai", model)
|
||||
@@ -200,8 +221,8 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
return
|
||||
end
|
||||
|
||||
local tools_module = require("codetyper.agent.tools")
|
||||
local agent_prompts = require("codetyper.prompts.agent")
|
||||
local tools_module = require("codetyper.core.tools")
|
||||
local agent_prompts = require("codetyper.prompts.agents")
|
||||
|
||||
-- Build system prompt with agent instructions
|
||||
local system_prompt = llm.build_system_prompt(context)
|
||||
@@ -284,9 +305,18 @@ function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
return
|
||||
end
|
||||
|
||||
-- Log token usage
|
||||
-- Log token usage and record cost
|
||||
if response.usage then
|
||||
logs.response(response.usage.prompt_tokens or 0, response.usage.completion_tokens or 0, "stop")
|
||||
|
||||
-- Record usage for cost tracking
|
||||
local cost = require("codetyper.core.cost")
|
||||
cost.record_usage(
|
||||
model,
|
||||
response.usage.prompt_tokens or 0,
|
||||
response.usage.completion_tokens or 0,
|
||||
response.usage.prompt_tokens_details and response.usage.prompt_tokens_details.cached_tokens or 0
|
||||
)
|
||||
end
|
||||
|
||||
-- Convert to Claude-like format for parser compatibility
|
||||
@@ -4,6 +4,9 @@
|
||||
|
||||
local M = {}
|
||||
|
||||
local params = require("codetyper.params.agents.parser")
|
||||
|
||||
|
||||
---@class ParsedResponse
|
||||
---@field text string Text content from the response
|
||||
---@field tool_calls ToolCall[] List of tool calls
|
||||
@@ -48,11 +51,11 @@ function M.parse_ollama_response(response_text)
|
||||
local result = {
|
||||
text = response_text,
|
||||
tool_calls = {},
|
||||
stop_reason = "end_turn",
|
||||
stop_reason = params.defaults.stop_reason,
|
||||
}
|
||||
|
||||
-- Pattern to find JSON tool blocks in fenced code blocks
|
||||
local fenced_pattern = "```json%s*(%b{})%s*```"
|
||||
local fenced_pattern = params.patterns.fenced_json
|
||||
|
||||
-- Find all fenced JSON blocks
|
||||
for json_str in response_text:gmatch(fenced_pattern) do
|
||||
@@ -63,14 +66,14 @@ function M.parse_ollama_response(response_text)
|
||||
name = parsed.tool,
|
||||
parameters = parsed.parameters,
|
||||
})
|
||||
result.stop_reason = "tool_use"
|
||||
result.stop_reason = params.defaults.tool_stop_reason
|
||||
end
|
||||
end
|
||||
|
||||
-- Also try to find inline JSON (not in code blocks)
|
||||
-- Pattern for {"tool": "...", "parameters": {...}}
|
||||
if #result.tool_calls == 0 then
|
||||
local inline_pattern = '(%{"tool"%s*:%s*"[^"]+"%s*,%s*"parameters"%s*:%s*%b{}%})'
|
||||
local inline_pattern = params.patterns.inline_json
|
||||
for json_str in response_text:gmatch(inline_pattern) do
|
||||
local ok, parsed = pcall(vim.json.decode, json_str)
|
||||
if ok and parsed.tool and parsed.parameters then
|
||||
@@ -79,15 +82,15 @@ function M.parse_ollama_response(response_text)
|
||||
name = parsed.tool,
|
||||
parameters = parsed.parameters,
|
||||
})
|
||||
result.stop_reason = "tool_use"
|
||||
result.stop_reason = params.defaults.tool_stop_reason
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Clean tool JSON from displayed text
|
||||
if #result.tool_calls > 0 then
|
||||
result.text = result.text:gsub("```json%s*%b{}%s*```", "[Tool call]")
|
||||
result.text = result.text:gsub('%{"tool"%s*:%s*"[^"]+"%s*,%s*"parameters"%s*:%s*%b{}%}', "[Tool call]")
|
||||
result.text = result.text:gsub(params.patterns.fenced_json, params.defaults.replacement_text)
|
||||
result.text = result.text:gsub(params.patterns.inline_json, params.defaults.replacement_text)
|
||||
end
|
||||
|
||||
return result
|
||||
514
lua/codetyper/core/llm/selector.lua
Normal file
514
lua/codetyper/core/llm/selector.lua
Normal file
@@ -0,0 +1,514 @@
|
||||
---@mod codetyper.llm.selector Smart LLM selection with memory-based confidence
|
||||
---@brief [[
|
||||
--- Intelligent LLM provider selection based on brain memories.
|
||||
--- Prefers local Ollama when context is rich, falls back to Copilot otherwise.
|
||||
--- Implements verification pondering to reinforce Ollama accuracy over time.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class SelectionResult
|
||||
---@field provider string Selected provider name
|
||||
---@field confidence number Confidence score (0-1)
|
||||
---@field memory_count number Number of relevant memories found
|
||||
---@field reason string Human-readable reason for selection
|
||||
|
||||
---@class PonderResult
|
||||
---@field ollama_response string Ollama's response
|
||||
---@field verifier_response string Verifier's response
|
||||
---@field agreement_score number How much they agree (0-1)
|
||||
---@field ollama_correct boolean Whether Ollama was deemed correct
|
||||
---@field feedback string Feedback for learning
|
||||
|
||||
--- Minimum memories required for high confidence
|
||||
local MIN_MEMORIES_FOR_LOCAL = 3
|
||||
|
||||
--- Minimum memory relevance score for local provider
|
||||
local MIN_RELEVANCE_FOR_LOCAL = 0.6
|
||||
|
||||
--- Agreement threshold for Ollama verification
|
||||
local AGREEMENT_THRESHOLD = 0.7
|
||||
|
||||
--- Pondering sample rate (0-1) - how often to verify Ollama
|
||||
local PONDER_SAMPLE_RATE = 0.2
|
||||
|
||||
--- Provider accuracy tracking (persisted in brain)
|
||||
local accuracy_cache = {
|
||||
ollama = { correct = 0, total = 0 },
|
||||
copilot = { correct = 0, total = 0 },
|
||||
}
|
||||
|
||||
--- Get the brain module safely
|
||||
---@return table|nil
|
||||
local function get_brain()
|
||||
local ok, brain = pcall(require, "codetyper.brain")
|
||||
if ok and brain.is_initialized and brain.is_initialized() then
|
||||
return brain
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Load accuracy stats from brain
|
||||
local function load_accuracy_stats()
|
||||
local brain = get_brain()
|
||||
if not brain then
|
||||
return
|
||||
end
|
||||
|
||||
-- Query for accuracy tracking nodes
|
||||
pcall(function()
|
||||
local result = brain.query({
|
||||
query = "provider_accuracy_stats",
|
||||
types = { "metric" },
|
||||
limit = 1,
|
||||
})
|
||||
|
||||
if result and result.nodes and #result.nodes > 0 then
|
||||
local node = result.nodes[1]
|
||||
if node.c and node.c.d then
|
||||
local ok, stats = pcall(vim.json.decode, node.c.d)
|
||||
if ok and stats then
|
||||
accuracy_cache = stats
|
||||
end
|
||||
end
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Save accuracy stats to brain
|
||||
local function save_accuracy_stats()
|
||||
local brain = get_brain()
|
||||
if not brain then
|
||||
return
|
||||
end
|
||||
|
||||
pcall(function()
|
||||
brain.learn({
|
||||
type = "metric",
|
||||
summary = "provider_accuracy_stats",
|
||||
detail = vim.json.encode(accuracy_cache),
|
||||
weight = 1.0,
|
||||
})
|
||||
end)
|
||||
end
|
||||
|
||||
--- Calculate Ollama confidence based on historical accuracy
|
||||
---@return number confidence (0-1)
|
||||
local function get_ollama_historical_confidence()
|
||||
local stats = accuracy_cache.ollama
|
||||
if stats.total < 5 then
|
||||
-- Not enough data, return neutral confidence
|
||||
return 0.5
|
||||
end
|
||||
|
||||
local accuracy = stats.correct / stats.total
|
||||
-- Boost confidence if accuracy is high
|
||||
return math.min(1.0, accuracy * 1.2)
|
||||
end
|
||||
|
||||
--- Query brain for relevant context
|
||||
---@param prompt string User prompt
|
||||
---@param file_path string|nil Current file path
|
||||
---@return table result {memories: table[], relevance: number, count: number}
|
||||
local function query_brain_context(prompt, file_path)
|
||||
local result = {
|
||||
memories = {},
|
||||
relevance = 0,
|
||||
count = 0,
|
||||
}
|
||||
|
||||
local brain = get_brain()
|
||||
if not brain then
|
||||
return result
|
||||
end
|
||||
|
||||
-- Query brain with multiple dimensions
|
||||
local ok, query_result = pcall(function()
|
||||
return brain.query({
|
||||
query = prompt,
|
||||
file = file_path,
|
||||
limit = 10,
|
||||
types = { "pattern", "correction", "convention", "fact" },
|
||||
})
|
||||
end)
|
||||
|
||||
if not ok or not query_result then
|
||||
return result
|
||||
end
|
||||
|
||||
result.memories = query_result.nodes or {}
|
||||
result.count = #result.memories
|
||||
|
||||
-- Calculate average relevance
|
||||
if result.count > 0 then
|
||||
local total_relevance = 0
|
||||
for _, node in ipairs(result.memories) do
|
||||
-- Use node weight and success rate as relevance indicators
|
||||
local node_relevance = (node.sc and node.sc.w or 0.5) * (node.sc and node.sc.sr or 0.5)
|
||||
total_relevance = total_relevance + node_relevance
|
||||
end
|
||||
result.relevance = total_relevance / result.count
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Select the best LLM provider based on context
|
||||
---@param prompt string User prompt
|
||||
---@param context table LLM context
|
||||
---@return SelectionResult
|
||||
function M.select_provider(prompt, context)
|
||||
-- Load accuracy stats on first call
|
||||
if accuracy_cache.ollama.total == 0 then
|
||||
load_accuracy_stats()
|
||||
end
|
||||
|
||||
local file_path = context.file_path
|
||||
|
||||
-- Query brain for relevant memories
|
||||
local brain_context = query_brain_context(prompt, file_path)
|
||||
|
||||
-- Calculate base confidence from memories
|
||||
local memory_confidence = 0
|
||||
if brain_context.count >= MIN_MEMORIES_FOR_LOCAL then
|
||||
memory_confidence = math.min(1.0, brain_context.count / 10) * brain_context.relevance
|
||||
end
|
||||
|
||||
-- Factor in historical Ollama accuracy
|
||||
local historical_confidence = get_ollama_historical_confidence()
|
||||
|
||||
-- Combined confidence score
|
||||
local combined_confidence = (memory_confidence * 0.6) + (historical_confidence * 0.4)
|
||||
|
||||
-- Decision logic
|
||||
local provider = "copilot" -- Default to more capable
|
||||
local reason = ""
|
||||
|
||||
if brain_context.count >= MIN_MEMORIES_FOR_LOCAL and combined_confidence >= MIN_RELEVANCE_FOR_LOCAL then
|
||||
provider = "ollama"
|
||||
reason = string.format(
|
||||
"Rich context: %d memories (%.1f%% relevance), historical accuracy: %.1f%%",
|
||||
brain_context.count,
|
||||
brain_context.relevance * 100,
|
||||
historical_confidence * 100
|
||||
)
|
||||
elseif brain_context.count > 0 and combined_confidence >= 0.4 then
|
||||
-- Medium confidence - use Ollama but with pondering
|
||||
provider = "ollama"
|
||||
reason = string.format(
|
||||
"Moderate context: %d memories, will verify with pondering",
|
||||
brain_context.count
|
||||
)
|
||||
else
|
||||
reason = string.format(
|
||||
"Insufficient context: %d memories (need %d), using capable provider",
|
||||
brain_context.count,
|
||||
MIN_MEMORIES_FOR_LOCAL
|
||||
)
|
||||
end
|
||||
|
||||
return {
|
||||
provider = provider,
|
||||
confidence = combined_confidence,
|
||||
memory_count = brain_context.count,
|
||||
reason = reason,
|
||||
memories = brain_context.memories,
|
||||
}
|
||||
end
|
||||
|
||||
--- Check if we should ponder (verify) this Ollama response
|
||||
---@param confidence number Current confidence level
|
||||
---@return boolean
|
||||
function M.should_ponder(confidence)
|
||||
-- Always ponder when confidence is medium
|
||||
if confidence >= 0.4 and confidence < 0.7 then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Random sampling for high confidence to keep learning
|
||||
if confidence >= 0.7 then
|
||||
return math.random() < PONDER_SAMPLE_RATE
|
||||
end
|
||||
|
||||
-- Low confidence shouldn't reach Ollama anyway
|
||||
return false
|
||||
end
|
||||
|
||||
--- Calculate agreement score between two responses
|
||||
---@param response1 string First response
|
||||
---@param response2 string Second response
|
||||
---@return number Agreement score (0-1)
|
||||
local function calculate_agreement(response1, response2)
|
||||
-- Normalize responses
|
||||
local norm1 = response1:lower():gsub("%s+", " "):gsub("[^%w%s]", "")
|
||||
local norm2 = response2:lower():gsub("%s+", " "):gsub("[^%w%s]", "")
|
||||
|
||||
-- Extract words
|
||||
local words1 = {}
|
||||
for word in norm1:gmatch("%w+") do
|
||||
words1[word] = (words1[word] or 0) + 1
|
||||
end
|
||||
|
||||
local words2 = {}
|
||||
for word in norm2:gmatch("%w+") do
|
||||
words2[word] = (words2[word] or 0) + 1
|
||||
end
|
||||
|
||||
-- Calculate Jaccard similarity
|
||||
local intersection = 0
|
||||
local union = 0
|
||||
|
||||
for word, count1 in pairs(words1) do
|
||||
local count2 = words2[word] or 0
|
||||
intersection = intersection + math.min(count1, count2)
|
||||
union = union + math.max(count1, count2)
|
||||
end
|
||||
|
||||
for word, count2 in pairs(words2) do
|
||||
if not words1[word] then
|
||||
union = union + count2
|
||||
end
|
||||
end
|
||||
|
||||
if union == 0 then
|
||||
return 1.0 -- Both empty
|
||||
end
|
||||
|
||||
-- Also check structural similarity (code structure)
|
||||
local struct_score = 0
|
||||
local function_count1 = select(2, response1:gsub("function", ""))
|
||||
local function_count2 = select(2, response2:gsub("function", ""))
|
||||
if function_count1 > 0 or function_count2 > 0 then
|
||||
struct_score = 1 - math.abs(function_count1 - function_count2) / math.max(function_count1, function_count2, 1)
|
||||
else
|
||||
struct_score = 1.0
|
||||
end
|
||||
|
||||
-- Combined score
|
||||
local jaccard = intersection / union
|
||||
return (jaccard * 0.7) + (struct_score * 0.3)
|
||||
end
|
||||
|
||||
--- Ponder (verify) Ollama's response with another LLM
|
||||
---@param prompt string Original prompt
|
||||
---@param context table LLM context
|
||||
---@param ollama_response string Ollama's response
|
||||
---@param callback fun(result: PonderResult) Callback with pondering result
|
||||
function M.ponder(prompt, context, ollama_response, callback)
|
||||
-- Use Copilot as verifier
|
||||
local copilot = require("codetyper.core.llm.copilot")
|
||||
|
||||
-- Build verification prompt
|
||||
local verify_prompt = prompt
|
||||
|
||||
copilot.generate(verify_prompt, context, function(verifier_response, error)
|
||||
if error or not verifier_response then
|
||||
-- Verification failed, assume Ollama is correct
|
||||
callback({
|
||||
ollama_response = ollama_response,
|
||||
verifier_response = "",
|
||||
agreement_score = 1.0,
|
||||
ollama_correct = true,
|
||||
feedback = "Verification unavailable, trusting Ollama",
|
||||
})
|
||||
return
|
||||
end
|
||||
|
||||
-- Calculate agreement
|
||||
local agreement = calculate_agreement(ollama_response, verifier_response)
|
||||
|
||||
-- Determine if Ollama was correct
|
||||
local ollama_correct = agreement >= AGREEMENT_THRESHOLD
|
||||
|
||||
-- Generate feedback
|
||||
local feedback
|
||||
if ollama_correct then
|
||||
feedback = string.format("Agreement: %.1f%% - Ollama response validated", agreement * 100)
|
||||
else
|
||||
feedback = string.format(
|
||||
"Disagreement: %.1f%% - Ollama may need correction",
|
||||
(1 - agreement) * 100
|
||||
)
|
||||
end
|
||||
|
||||
-- Update accuracy tracking
|
||||
accuracy_cache.ollama.total = accuracy_cache.ollama.total + 1
|
||||
if ollama_correct then
|
||||
accuracy_cache.ollama.correct = accuracy_cache.ollama.correct + 1
|
||||
end
|
||||
save_accuracy_stats()
|
||||
|
||||
-- Learn from this verification
|
||||
local brain = get_brain()
|
||||
if brain then
|
||||
pcall(function()
|
||||
if ollama_correct then
|
||||
-- Reinforce the pattern
|
||||
brain.learn({
|
||||
type = "correction",
|
||||
summary = "Ollama verified correct",
|
||||
detail = string.format(
|
||||
"Prompt: %s\nAgreement: %.1f%%",
|
||||
prompt:sub(1, 100),
|
||||
agreement * 100
|
||||
),
|
||||
weight = 0.8,
|
||||
file = context.file_path,
|
||||
})
|
||||
else
|
||||
-- Learn the correction
|
||||
brain.learn({
|
||||
type = "correction",
|
||||
summary = "Ollama needed correction",
|
||||
detail = string.format(
|
||||
"Prompt: %s\nOllama: %s\nCorrect: %s",
|
||||
prompt:sub(1, 100),
|
||||
ollama_response:sub(1, 200),
|
||||
verifier_response:sub(1, 200)
|
||||
),
|
||||
weight = 0.9,
|
||||
file = context.file_path,
|
||||
})
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
callback({
|
||||
ollama_response = ollama_response,
|
||||
verifier_response = verifier_response,
|
||||
agreement_score = agreement,
|
||||
ollama_correct = ollama_correct,
|
||||
feedback = feedback,
|
||||
})
|
||||
end)
|
||||
end
|
||||
|
||||
--- Smart generate with automatic provider selection and pondering
|
||||
---@param prompt string User prompt
|
||||
---@param context table LLM context
|
||||
---@param callback fun(response: string|nil, error: string|nil, metadata: table|nil) Callback
|
||||
function M.smart_generate(prompt, context, callback)
|
||||
-- Select provider
|
||||
local selection = M.select_provider(prompt, context)
|
||||
|
||||
-- Log selection
|
||||
pcall(function()
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format(
|
||||
"LLM: %s (confidence: %.1f%%, %s)",
|
||||
selection.provider,
|
||||
selection.confidence * 100,
|
||||
selection.reason
|
||||
),
|
||||
})
|
||||
end)
|
||||
|
||||
-- Get the selected client
|
||||
local client
|
||||
if selection.provider == "ollama" then
|
||||
client = require("codetyper.core.llm.ollama")
|
||||
else
|
||||
client = require("codetyper.core.llm.copilot")
|
||||
end
|
||||
|
||||
-- Generate response
|
||||
client.generate(prompt, context, function(response, error)
|
||||
if error then
|
||||
-- Fallback on error
|
||||
if selection.provider == "ollama" then
|
||||
-- Try Copilot as fallback
|
||||
local copilot = require("codetyper.core.llm.copilot")
|
||||
copilot.generate(prompt, context, function(fallback_response, fallback_error)
|
||||
callback(fallback_response, fallback_error, {
|
||||
provider = "copilot",
|
||||
fallback = true,
|
||||
original_provider = "ollama",
|
||||
original_error = error,
|
||||
})
|
||||
end)
|
||||
return
|
||||
end
|
||||
callback(nil, error, { provider = selection.provider })
|
||||
return
|
||||
end
|
||||
|
||||
-- Check if we should ponder
|
||||
if selection.provider == "ollama" and M.should_ponder(selection.confidence) then
|
||||
M.ponder(prompt, context, response, function(ponder_result)
|
||||
if ponder_result.ollama_correct then
|
||||
-- Ollama was correct, use its response
|
||||
callback(response, nil, {
|
||||
provider = "ollama",
|
||||
pondered = true,
|
||||
agreement = ponder_result.agreement_score,
|
||||
confidence = selection.confidence,
|
||||
})
|
||||
else
|
||||
-- Use verifier's response instead
|
||||
callback(ponder_result.verifier_response, nil, {
|
||||
provider = "copilot",
|
||||
pondered = true,
|
||||
agreement = ponder_result.agreement_score,
|
||||
original_provider = "ollama",
|
||||
corrected = true,
|
||||
})
|
||||
end
|
||||
end)
|
||||
else
|
||||
-- No pondering needed
|
||||
callback(response, nil, {
|
||||
provider = selection.provider,
|
||||
pondered = false,
|
||||
confidence = selection.confidence,
|
||||
})
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Get current accuracy statistics
|
||||
---@return table {ollama: {correct, total, accuracy}, copilot: {correct, total, accuracy}}
|
||||
function M.get_accuracy_stats()
|
||||
local stats = {
|
||||
ollama = {
|
||||
correct = accuracy_cache.ollama.correct,
|
||||
total = accuracy_cache.ollama.total,
|
||||
accuracy = accuracy_cache.ollama.total > 0
|
||||
and (accuracy_cache.ollama.correct / accuracy_cache.ollama.total)
|
||||
or 0,
|
||||
},
|
||||
copilot = {
|
||||
correct = accuracy_cache.copilot.correct,
|
||||
total = accuracy_cache.copilot.total,
|
||||
accuracy = accuracy_cache.copilot.total > 0
|
||||
and (accuracy_cache.copilot.correct / accuracy_cache.copilot.total)
|
||||
or 0,
|
||||
},
|
||||
}
|
||||
return stats
|
||||
end
|
||||
|
||||
--- Reset accuracy statistics
|
||||
function M.reset_accuracy_stats()
|
||||
accuracy_cache = {
|
||||
ollama = { correct = 0, total = 0 },
|
||||
copilot = { correct = 0, total = 0 },
|
||||
}
|
||||
save_accuracy_stats()
|
||||
end
|
||||
|
||||
--- Report user feedback on response quality
|
||||
---@param provider string Which provider generated the response
|
||||
---@param was_correct boolean Whether the response was good
|
||||
function M.report_feedback(provider, was_correct)
|
||||
if accuracy_cache[provider] then
|
||||
accuracy_cache[provider].total = accuracy_cache[provider].total + 1
|
||||
if was_correct then
|
||||
accuracy_cache[provider].correct = accuracy_cache[provider].correct + 1
|
||||
end
|
||||
save_accuracy_stats()
|
||||
end
|
||||
end
|
||||
|
||||
return M
|
||||
291
lua/codetyper/core/memory/delta/commit.lua
Normal file
291
lua/codetyper/core/memory/delta/commit.lua
Normal file
@@ -0,0 +1,291 @@
|
||||
--- Brain Delta Commit Operations
|
||||
--- Git-like commit creation and management
|
||||
|
||||
local storage = require("codetyper.core.memory.storage")
|
||||
local hash_mod = require("codetyper.core.memory.hash")
|
||||
local diff_mod = require("codetyper.core.memory.delta.diff")
|
||||
local types = require("codetyper.core.memory.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Create a new delta commit
|
||||
---@param changes table[] Changes to commit
|
||||
---@param message string Commit message
|
||||
---@param trigger? string Trigger source
|
||||
---@return Delta|nil Created delta
|
||||
function M.create(changes, message, trigger)
|
||||
if not changes or #changes == 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
local now = os.time()
|
||||
local head = storage.get_head()
|
||||
|
||||
-- Create delta object
|
||||
local delta = {
|
||||
h = hash_mod.delta_hash(changes, head, now),
|
||||
p = head,
|
||||
ts = now,
|
||||
ch = {},
|
||||
m = {
|
||||
msg = message or "Unnamed commit",
|
||||
trig = trigger or "manual",
|
||||
},
|
||||
}
|
||||
|
||||
-- Process changes
|
||||
for _, change in ipairs(changes) do
|
||||
table.insert(delta.ch, {
|
||||
op = change.op,
|
||||
path = change.path,
|
||||
bh = change.bh,
|
||||
ah = change.ah,
|
||||
diff = change.diff,
|
||||
})
|
||||
end
|
||||
|
||||
-- Save delta
|
||||
storage.save_delta(delta)
|
||||
|
||||
-- Update HEAD
|
||||
storage.set_head(delta.h)
|
||||
|
||||
-- Update meta
|
||||
local meta = storage.get_meta()
|
||||
storage.update_meta({ dc = meta.dc + 1 })
|
||||
|
||||
return delta
|
||||
end
|
||||
|
||||
--- Get a delta by hash
|
||||
---@param delta_hash string Delta hash
|
||||
---@return Delta|nil
|
||||
function M.get(delta_hash)
|
||||
return storage.get_delta(delta_hash)
|
||||
end
|
||||
|
||||
--- Get the current HEAD delta
|
||||
---@return Delta|nil
|
||||
function M.get_head()
|
||||
local head_hash = storage.get_head()
|
||||
if not head_hash then
|
||||
return nil
|
||||
end
|
||||
return M.get(head_hash)
|
||||
end
|
||||
|
||||
--- Get delta history (ancestry chain)
|
||||
---@param limit? number Max entries
|
||||
---@param from_hash? string Starting hash (default: HEAD)
|
||||
---@return Delta[]
|
||||
function M.get_history(limit, from_hash)
|
||||
limit = limit or 50
|
||||
local history = {}
|
||||
local current_hash = from_hash or storage.get_head()
|
||||
|
||||
while current_hash and #history < limit do
|
||||
local delta = M.get(current_hash)
|
||||
if not delta then
|
||||
break
|
||||
end
|
||||
|
||||
table.insert(history, delta)
|
||||
current_hash = delta.p
|
||||
end
|
||||
|
||||
return history
|
||||
end
|
||||
|
||||
--- Check if a delta exists
|
||||
---@param delta_hash string Delta hash
|
||||
---@return boolean
|
||||
function M.exists(delta_hash)
|
||||
return M.get(delta_hash) ~= nil
|
||||
end
|
||||
|
||||
--- Get the path from one delta to another
|
||||
---@param from_hash string Start delta hash
|
||||
---@param to_hash string End delta hash
|
||||
---@return Delta[]|nil Path of deltas, or nil if no path
|
||||
function M.get_path(from_hash, to_hash)
|
||||
-- Build ancestry from both sides
|
||||
local from_ancestry = {}
|
||||
local current = from_hash
|
||||
while current do
|
||||
from_ancestry[current] = true
|
||||
local delta = M.get(current)
|
||||
if not delta then
|
||||
break
|
||||
end
|
||||
current = delta.p
|
||||
end
|
||||
|
||||
-- Walk from to_hash back to find common ancestor
|
||||
local path = {}
|
||||
current = to_hash
|
||||
while current do
|
||||
local delta = M.get(current)
|
||||
if not delta then
|
||||
break
|
||||
end
|
||||
|
||||
table.insert(path, 1, delta)
|
||||
|
||||
if from_ancestry[current] then
|
||||
-- Found common ancestor
|
||||
return path
|
||||
end
|
||||
|
||||
current = delta.p
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Get all changes between two deltas
|
||||
---@param from_hash string|nil Start delta hash (nil = beginning)
|
||||
---@param to_hash string End delta hash
|
||||
---@return table[] Combined changes
|
||||
function M.get_changes_between(from_hash, to_hash)
|
||||
local path = {}
|
||||
local current = to_hash
|
||||
|
||||
while current and current ~= from_hash do
|
||||
local delta = M.get(current)
|
||||
if not delta then
|
||||
break
|
||||
end
|
||||
table.insert(path, 1, delta)
|
||||
current = delta.p
|
||||
end
|
||||
|
||||
-- Collect all changes
|
||||
local changes = {}
|
||||
for _, delta in ipairs(path) do
|
||||
for _, change in ipairs(delta.ch) do
|
||||
table.insert(changes, change)
|
||||
end
|
||||
end
|
||||
|
||||
return changes
|
||||
end
|
||||
|
||||
--- Compute reverse changes for rollback
|
||||
---@param delta Delta Delta to reverse
|
||||
---@return table[] Reverse changes
|
||||
function M.compute_reverse(delta)
|
||||
local reversed = {}
|
||||
|
||||
for i = #delta.ch, 1, -1 do
|
||||
local change = delta.ch[i]
|
||||
local rev = {
|
||||
path = change.path,
|
||||
}
|
||||
|
||||
if change.op == types.DELTA_OPS.ADD then
|
||||
rev.op = types.DELTA_OPS.DELETE
|
||||
rev.bh = change.ah
|
||||
elseif change.op == types.DELTA_OPS.DELETE then
|
||||
rev.op = types.DELTA_OPS.ADD
|
||||
rev.ah = change.bh
|
||||
elseif change.op == types.DELTA_OPS.MODIFY then
|
||||
rev.op = types.DELTA_OPS.MODIFY
|
||||
rev.bh = change.ah
|
||||
rev.ah = change.bh
|
||||
if change.diff then
|
||||
rev.diff = diff_mod.reverse(change.diff)
|
||||
end
|
||||
end
|
||||
|
||||
table.insert(reversed, rev)
|
||||
end
|
||||
|
||||
return reversed
|
||||
end
|
||||
|
||||
--- Squash multiple deltas into one
|
||||
---@param delta_hashes string[] Delta hashes to squash
|
||||
---@param message string Squash commit message
|
||||
---@return Delta|nil Squashed delta
|
||||
function M.squash(delta_hashes, message)
|
||||
if #delta_hashes == 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Collect all changes in order
|
||||
local all_changes = {}
|
||||
for _, delta_hash in ipairs(delta_hashes) do
|
||||
local delta = M.get(delta_hash)
|
||||
if delta then
|
||||
for _, change in ipairs(delta.ch) do
|
||||
table.insert(all_changes, change)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Compact the changes
|
||||
local compacted = diff_mod.compact(all_changes)
|
||||
|
||||
return M.create(compacted, message, "squash")
|
||||
end
|
||||
|
||||
--- Get summary of a delta
|
||||
---@param delta Delta Delta to summarize
|
||||
---@return table Summary
|
||||
function M.summarize(delta)
|
||||
local adds = 0
|
||||
local mods = 0
|
||||
local dels = 0
|
||||
local paths = {}
|
||||
|
||||
for _, change in ipairs(delta.ch) do
|
||||
if change.op == types.DELTA_OPS.ADD then
|
||||
adds = adds + 1
|
||||
elseif change.op == types.DELTA_OPS.MODIFY then
|
||||
mods = mods + 1
|
||||
elseif change.op == types.DELTA_OPS.DELETE then
|
||||
dels = dels + 1
|
||||
end
|
||||
|
||||
-- Extract category from path
|
||||
local parts = vim.split(change.path, ".", { plain = true })
|
||||
if parts[1] then
|
||||
paths[parts[1]] = true
|
||||
end
|
||||
end
|
||||
|
||||
return {
|
||||
hash = delta.h,
|
||||
parent = delta.p,
|
||||
timestamp = delta.ts,
|
||||
message = delta.m.msg,
|
||||
trigger = delta.m.trig,
|
||||
stats = {
|
||||
adds = adds,
|
||||
modifies = mods,
|
||||
deletes = dels,
|
||||
total = adds + mods + dels,
|
||||
},
|
||||
categories = vim.tbl_keys(paths),
|
||||
}
|
||||
end
|
||||
|
||||
--- Format delta for display
|
||||
---@param delta Delta Delta to format
|
||||
---@return string[] Lines
|
||||
function M.format(delta)
|
||||
local summary = M.summarize(delta)
|
||||
local lines = {
|
||||
string.format("commit %s", delta.h),
|
||||
string.format("Date: %s", os.date("%Y-%m-%d %H:%M:%S", delta.ts)),
|
||||
string.format("Parent: %s", delta.p or "(none)"),
|
||||
"",
|
||||
" " .. (delta.m.msg or "No message"),
|
||||
"",
|
||||
string.format(" %d additions, %d modifications, %d deletions", summary.stats.adds, summary.stats.modifies, summary.stats.deletes),
|
||||
}
|
||||
|
||||
return lines
|
||||
end
|
||||
|
||||
return M
|
||||
261
lua/codetyper/core/memory/delta/diff.lua
Normal file
261
lua/codetyper/core/memory/delta/diff.lua
Normal file
@@ -0,0 +1,261 @@
|
||||
--- Brain Delta Diff Computation
|
||||
--- Field-level diff algorithms for delta versioning
|
||||
|
||||
local hash = require("codetyper.core.memory.hash")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Compute diff between two values
|
||||
---@param before any Before value
|
||||
---@param after any After value
|
||||
---@param path? string Current path
|
||||
---@return table[] Diff entries
|
||||
function M.compute(before, after, path)
|
||||
path = path or ""
|
||||
local diffs = {}
|
||||
|
||||
local before_type = type(before)
|
||||
local after_type = type(after)
|
||||
|
||||
-- Handle nil cases
|
||||
if before == nil and after == nil then
|
||||
return diffs
|
||||
end
|
||||
|
||||
if before == nil then
|
||||
table.insert(diffs, {
|
||||
path = path,
|
||||
op = "add",
|
||||
value = after,
|
||||
})
|
||||
return diffs
|
||||
end
|
||||
|
||||
if after == nil then
|
||||
table.insert(diffs, {
|
||||
path = path,
|
||||
op = "delete",
|
||||
value = before,
|
||||
})
|
||||
return diffs
|
||||
end
|
||||
|
||||
-- Type change
|
||||
if before_type ~= after_type then
|
||||
table.insert(diffs, {
|
||||
path = path,
|
||||
op = "replace",
|
||||
from = before,
|
||||
to = after,
|
||||
})
|
||||
return diffs
|
||||
end
|
||||
|
||||
-- Tables (recursive)
|
||||
if before_type == "table" then
|
||||
-- Get all keys
|
||||
local keys = {}
|
||||
for k in pairs(before) do
|
||||
keys[k] = true
|
||||
end
|
||||
for k in pairs(after) do
|
||||
keys[k] = true
|
||||
end
|
||||
|
||||
for k in pairs(keys) do
|
||||
local sub_path = path == "" and tostring(k) or (path .. "." .. tostring(k))
|
||||
local sub_diffs = M.compute(before[k], after[k], sub_path)
|
||||
for _, d in ipairs(sub_diffs) do
|
||||
table.insert(diffs, d)
|
||||
end
|
||||
end
|
||||
|
||||
return diffs
|
||||
end
|
||||
|
||||
-- Primitive comparison
|
||||
if before ~= after then
|
||||
table.insert(diffs, {
|
||||
path = path,
|
||||
op = "replace",
|
||||
from = before,
|
||||
to = after,
|
||||
})
|
||||
end
|
||||
|
||||
return diffs
|
||||
end
|
||||
|
||||
--- Apply a diff to a value
|
||||
---@param base any Base value
|
||||
---@param diffs table[] Diff entries
|
||||
---@return any Result value
|
||||
function M.apply(base, diffs)
|
||||
local result = vim.deepcopy(base) or {}
|
||||
|
||||
for _, diff in ipairs(diffs) do
|
||||
M.apply_single(result, diff)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Apply a single diff entry
|
||||
---@param target table Target table
|
||||
---@param diff table Diff entry
|
||||
function M.apply_single(target, diff)
|
||||
local path = diff.path
|
||||
local parts = vim.split(path, ".", { plain = true })
|
||||
|
||||
if #parts == 0 or parts[1] == "" then
|
||||
-- Root-level change
|
||||
if diff.op == "add" or diff.op == "replace" then
|
||||
for k, v in pairs(diff.value or diff.to or {}) do
|
||||
target[k] = v
|
||||
end
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Navigate to parent
|
||||
local current = target
|
||||
for i = 1, #parts - 1 do
|
||||
local key = parts[i]
|
||||
-- Try numeric key
|
||||
local num_key = tonumber(key)
|
||||
key = num_key or key
|
||||
|
||||
if current[key] == nil then
|
||||
current[key] = {}
|
||||
end
|
||||
current = current[key]
|
||||
end
|
||||
|
||||
-- Apply to final key
|
||||
local final_key = parts[#parts]
|
||||
local num_key = tonumber(final_key)
|
||||
final_key = num_key or final_key
|
||||
|
||||
if diff.op == "add" then
|
||||
current[final_key] = diff.value
|
||||
elseif diff.op == "delete" then
|
||||
current[final_key] = nil
|
||||
elseif diff.op == "replace" then
|
||||
current[final_key] = diff.to
|
||||
end
|
||||
end
|
||||
|
||||
--- Reverse a diff (for rollback)
|
||||
---@param diffs table[] Diff entries
|
||||
---@return table[] Reversed diffs
|
||||
function M.reverse(diffs)
|
||||
local reversed = {}
|
||||
|
||||
for i = #diffs, 1, -1 do
|
||||
local diff = diffs[i]
|
||||
local rev = {
|
||||
path = diff.path,
|
||||
}
|
||||
|
||||
if diff.op == "add" then
|
||||
rev.op = "delete"
|
||||
rev.value = diff.value
|
||||
elseif diff.op == "delete" then
|
||||
rev.op = "add"
|
||||
rev.value = diff.value
|
||||
elseif diff.op == "replace" then
|
||||
rev.op = "replace"
|
||||
rev.from = diff.to
|
||||
rev.to = diff.from
|
||||
end
|
||||
|
||||
table.insert(reversed, rev)
|
||||
end
|
||||
|
||||
return reversed
|
||||
end
|
||||
|
||||
--- Compact diffs (combine related changes)
|
||||
---@param diffs table[] Diff entries
|
||||
---@return table[] Compacted diffs
|
||||
function M.compact(diffs)
|
||||
local by_path = {}
|
||||
|
||||
for _, diff in ipairs(diffs) do
|
||||
local existing = by_path[diff.path]
|
||||
if existing then
|
||||
-- Combine: keep first "from", use last "to"
|
||||
if diff.op == "replace" then
|
||||
existing.to = diff.to
|
||||
elseif diff.op == "delete" then
|
||||
existing.op = "delete"
|
||||
existing.to = nil
|
||||
end
|
||||
else
|
||||
by_path[diff.path] = vim.deepcopy(diff)
|
||||
end
|
||||
end
|
||||
|
||||
-- Convert back to array, filter out no-ops
|
||||
local result = {}
|
||||
for _, diff in pairs(by_path) do
|
||||
-- Skip if add then delete (net no change)
|
||||
if not (diff.op == "delete" and diff.from == nil) then
|
||||
table.insert(result, diff)
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Create a minimal diff summary for storage
|
||||
---@param diffs table[] Diff entries
|
||||
---@return table Summary
|
||||
function M.summarize(diffs)
|
||||
local adds = 0
|
||||
local deletes = 0
|
||||
local replaces = 0
|
||||
local paths = {}
|
||||
|
||||
for _, diff in ipairs(diffs) do
|
||||
if diff.op == "add" then
|
||||
adds = adds + 1
|
||||
elseif diff.op == "delete" then
|
||||
deletes = deletes + 1
|
||||
elseif diff.op == "replace" then
|
||||
replaces = replaces + 1
|
||||
end
|
||||
|
||||
-- Extract top-level path
|
||||
local parts = vim.split(diff.path, ".", { plain = true })
|
||||
if parts[1] then
|
||||
paths[parts[1]] = true
|
||||
end
|
||||
end
|
||||
|
||||
return {
|
||||
adds = adds,
|
||||
deletes = deletes,
|
||||
replaces = replaces,
|
||||
paths = vim.tbl_keys(paths),
|
||||
total = adds + deletes + replaces,
|
||||
}
|
||||
end
|
||||
|
||||
--- Check if two states are equal (no diff)
|
||||
---@param state1 any First state
|
||||
---@param state2 any Second state
|
||||
---@return boolean
|
||||
function M.equals(state1, state2)
|
||||
local diffs = M.compute(state1, state2)
|
||||
return #diffs == 0
|
||||
end
|
||||
|
||||
--- Get hash of diff for deduplication
|
||||
---@param diffs table[] Diff entries
|
||||
---@return string Hash
|
||||
function M.hash(diffs)
|
||||
return hash.compute_table(diffs)
|
||||
end
|
||||
|
||||
return M
|
||||
278
lua/codetyper/core/memory/delta/init.lua
Normal file
278
lua/codetyper/core/memory/delta/init.lua
Normal file
@@ -0,0 +1,278 @@
|
||||
--- Brain Delta Coordinator
|
||||
--- Git-like versioning system for brain state
|
||||
|
||||
local storage = require("codetyper.core.memory.storage")
|
||||
local commit_mod = require("codetyper.core.memory.delta.commit")
|
||||
local diff_mod = require("codetyper.core.memory.delta.diff")
|
||||
local types = require("codetyper.core.memory.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
-- Re-export submodules
|
||||
M.commit = commit_mod
|
||||
M.diff = diff_mod
|
||||
|
||||
--- Create a commit from pending graph changes
|
||||
---@param message string Commit message
|
||||
---@param trigger? string Trigger source
|
||||
---@return string|nil Delta hash
|
||||
function M.commit(message, trigger)
|
||||
local graph = require("codetyper.core.memory.graph")
|
||||
local changes = graph.get_pending_changes()
|
||||
|
||||
if #changes == 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
local delta = commit_mod.create(changes, message, trigger or "auto")
|
||||
if delta then
|
||||
return delta.h
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Rollback to a specific delta
|
||||
---@param target_hash string Target delta hash
|
||||
---@return boolean Success
|
||||
function M.rollback(target_hash)
|
||||
local current_hash = storage.get_head()
|
||||
if not current_hash then
|
||||
return false
|
||||
end
|
||||
|
||||
if current_hash == target_hash then
|
||||
return true -- Already at target
|
||||
end
|
||||
|
||||
-- Get path from target to current
|
||||
local deltas_to_reverse = {}
|
||||
local current = current_hash
|
||||
|
||||
while current and current ~= target_hash do
|
||||
local delta = commit_mod.get(current)
|
||||
if not delta then
|
||||
return false -- Broken chain
|
||||
end
|
||||
table.insert(deltas_to_reverse, delta)
|
||||
current = delta.p
|
||||
end
|
||||
|
||||
if current ~= target_hash then
|
||||
return false -- Target not in ancestry
|
||||
end
|
||||
|
||||
-- Apply reverse changes
|
||||
for _, delta in ipairs(deltas_to_reverse) do
|
||||
local reverse_changes = commit_mod.compute_reverse(delta)
|
||||
M.apply_changes(reverse_changes)
|
||||
end
|
||||
|
||||
-- Update HEAD
|
||||
storage.set_head(target_hash)
|
||||
|
||||
-- Create a rollback commit
|
||||
commit_mod.create({
|
||||
{
|
||||
op = types.DELTA_OPS.MODIFY,
|
||||
path = "meta.head",
|
||||
bh = current_hash,
|
||||
ah = target_hash,
|
||||
},
|
||||
}, "Rollback to " .. target_hash:sub(1, 8), "rollback")
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Apply changes to current state
|
||||
---@param changes table[] Changes to apply
|
||||
function M.apply_changes(changes)
|
||||
local node_mod = require("codetyper.core.memory.graph.node")
|
||||
|
||||
for _, change in ipairs(changes) do
|
||||
local parts = vim.split(change.path, ".", { plain = true })
|
||||
|
||||
if parts[1] == "nodes" and #parts >= 3 then
|
||||
local node_type = parts[2]
|
||||
local node_id = parts[3]
|
||||
|
||||
if change.op == types.DELTA_OPS.ADD then
|
||||
-- Node was added, need to delete for reverse
|
||||
node_mod.delete(node_id)
|
||||
elseif change.op == types.DELTA_OPS.DELETE then
|
||||
-- Node was deleted, would need original data to restore
|
||||
-- This is a limitation - we'd need content storage
|
||||
elseif change.op == types.DELTA_OPS.MODIFY then
|
||||
-- Apply diff if available
|
||||
if change.diff then
|
||||
local node = node_mod.get(node_id)
|
||||
if node then
|
||||
local updated = diff_mod.apply(node, change.diff)
|
||||
-- Direct update without tracking
|
||||
local nodes = storage.get_nodes(node_type)
|
||||
nodes[node_id] = updated
|
||||
storage.save_nodes(node_type, nodes)
|
||||
end
|
||||
end
|
||||
end
|
||||
elseif parts[1] == "graph" then
|
||||
-- Handle graph/edge changes
|
||||
local edge_mod = require("codetyper.core.memory.graph.edge")
|
||||
if parts[2] == "edges" and #parts >= 3 then
|
||||
local edge_id = parts[3]
|
||||
if change.op == types.DELTA_OPS.ADD then
|
||||
-- Edge was added, delete for reverse
|
||||
-- Parse edge_id to get source/target
|
||||
local graph = storage.get_graph()
|
||||
if graph.edges and graph.edges[edge_id] then
|
||||
local edge = graph.edges[edge_id]
|
||||
edge_mod.delete(edge.s, edge.t, edge.ty)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Get delta history
|
||||
---@param limit? number Max entries
|
||||
---@return Delta[]
|
||||
function M.get_history(limit)
|
||||
return commit_mod.get_history(limit)
|
||||
end
|
||||
|
||||
--- Get formatted log
|
||||
---@param limit? number Max entries
|
||||
---@return string[] Log lines
|
||||
function M.log(limit)
|
||||
local history = M.get_history(limit or 20)
|
||||
local lines = {}
|
||||
|
||||
for _, delta in ipairs(history) do
|
||||
local formatted = commit_mod.format(delta)
|
||||
for _, line in ipairs(formatted) do
|
||||
table.insert(lines, line)
|
||||
end
|
||||
table.insert(lines, "")
|
||||
end
|
||||
|
||||
return lines
|
||||
end
|
||||
|
||||
--- Get current HEAD hash
|
||||
---@return string|nil
|
||||
function M.head()
|
||||
return storage.get_head()
|
||||
end
|
||||
|
||||
--- Check if there are uncommitted changes
|
||||
---@return boolean
|
||||
function M.has_pending()
|
||||
local graph = require("codetyper.core.memory.graph")
|
||||
local node_pending = require("codetyper.core.memory.graph.node").pending
|
||||
local edge_pending = require("codetyper.core.memory.graph.edge").pending
|
||||
return #node_pending > 0 or #edge_pending > 0
|
||||
end
|
||||
|
||||
--- Get status (like git status)
|
||||
---@return table Status info
|
||||
function M.status()
|
||||
local node_pending = require("codetyper.core.memory.graph.node").pending
|
||||
local edge_pending = require("codetyper.core.memory.graph.edge").pending
|
||||
|
||||
local adds = 0
|
||||
local mods = 0
|
||||
local dels = 0
|
||||
|
||||
for _, change in ipairs(node_pending) do
|
||||
if change.op == types.DELTA_OPS.ADD then
|
||||
adds = adds + 1
|
||||
elseif change.op == types.DELTA_OPS.MODIFY then
|
||||
mods = mods + 1
|
||||
elseif change.op == types.DELTA_OPS.DELETE then
|
||||
dels = dels + 1
|
||||
end
|
||||
end
|
||||
|
||||
for _, change in ipairs(edge_pending) do
|
||||
if change.op == types.DELTA_OPS.ADD then
|
||||
adds = adds + 1
|
||||
elseif change.op == types.DELTA_OPS.DELETE then
|
||||
dels = dels + 1
|
||||
end
|
||||
end
|
||||
|
||||
return {
|
||||
head = storage.get_head(),
|
||||
pending = {
|
||||
adds = adds,
|
||||
modifies = mods,
|
||||
deletes = dels,
|
||||
total = adds + mods + dels,
|
||||
},
|
||||
clean = (adds + mods + dels) == 0,
|
||||
}
|
||||
end
|
||||
|
||||
--- Prune old deltas
|
||||
---@param keep number Number of recent deltas to keep
|
||||
---@return number Number of pruned deltas
|
||||
function M.prune_history(keep)
|
||||
keep = keep or 100
|
||||
local history = M.get_history(1000) -- Get all
|
||||
|
||||
if #history <= keep then
|
||||
return 0
|
||||
end
|
||||
|
||||
local pruned = 0
|
||||
local brain_dir = storage.get_brain_dir()
|
||||
|
||||
for i = keep + 1, #history do
|
||||
local delta = history[i]
|
||||
local filepath = brain_dir .. "/deltas/objects/" .. delta.h .. ".json"
|
||||
if os.remove(filepath) then
|
||||
pruned = pruned + 1
|
||||
end
|
||||
end
|
||||
|
||||
-- Update meta
|
||||
local meta = storage.get_meta()
|
||||
storage.update_meta({ dc = math.max(0, meta.dc - pruned) })
|
||||
|
||||
return pruned
|
||||
end
|
||||
|
||||
--- Reset to initial state (dangerous!)
|
||||
---@return boolean Success
|
||||
function M.reset()
|
||||
-- Clear all nodes
|
||||
for _, node_type in pairs(types.NODE_TYPES) do
|
||||
storage.save_nodes(node_type .. "s", {})
|
||||
end
|
||||
|
||||
-- Clear graph
|
||||
storage.save_graph({ adj = {}, radj = {}, edges = {} })
|
||||
|
||||
-- Clear indices
|
||||
storage.save_index("by_file", {})
|
||||
storage.save_index("by_time", {})
|
||||
storage.save_index("by_symbol", {})
|
||||
|
||||
-- Reset meta
|
||||
storage.update_meta({
|
||||
head = nil,
|
||||
nc = 0,
|
||||
ec = 0,
|
||||
dc = 0,
|
||||
})
|
||||
|
||||
-- Clear pending
|
||||
require("codetyper.core.memory.graph.node").pending = {}
|
||||
require("codetyper.core.memory.graph.edge").pending = {}
|
||||
|
||||
storage.flush_all()
|
||||
return true
|
||||
end
|
||||
|
||||
return M
|
||||
367
lua/codetyper/core/memory/graph/edge.lua
Normal file
367
lua/codetyper/core/memory/graph/edge.lua
Normal file
@@ -0,0 +1,367 @@
|
||||
--- Brain Graph Edge Operations
|
||||
--- CRUD operations for node connections
|
||||
|
||||
local storage = require("codetyper.core.memory.storage")
|
||||
local hash = require("codetyper.core.memory.hash")
|
||||
local types = require("codetyper.core.memory.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Pending changes for delta tracking
|
||||
---@type table[]
|
||||
M.pending = {}
|
||||
|
||||
--- Create a new edge between nodes
|
||||
---@param source_id string Source node ID
|
||||
---@param target_id string Target node ID
|
||||
---@param edge_type EdgeType Edge type
|
||||
---@param props? EdgeProps Edge properties
|
||||
---@return Edge|nil Created edge
|
||||
function M.create(source_id, target_id, edge_type, props)
|
||||
props = props or {}
|
||||
|
||||
local edge = {
|
||||
id = hash.edge_id(source_id, target_id),
|
||||
s = source_id,
|
||||
t = target_id,
|
||||
ty = edge_type,
|
||||
p = {
|
||||
w = props.w or 0.5,
|
||||
dir = props.dir or "bi",
|
||||
r = props.r,
|
||||
},
|
||||
ts = os.time(),
|
||||
}
|
||||
|
||||
-- Update adjacency lists
|
||||
local graph = storage.get_graph()
|
||||
|
||||
-- Forward adjacency
|
||||
graph.adj[source_id] = graph.adj[source_id] or {}
|
||||
graph.adj[source_id][edge_type] = graph.adj[source_id][edge_type] or {}
|
||||
|
||||
-- Check for duplicate
|
||||
if vim.tbl_contains(graph.adj[source_id][edge_type], target_id) then
|
||||
-- Edge exists, strengthen it instead
|
||||
return M.strengthen(source_id, target_id, edge_type)
|
||||
end
|
||||
|
||||
table.insert(graph.adj[source_id][edge_type], target_id)
|
||||
|
||||
-- Reverse adjacency
|
||||
graph.radj[target_id] = graph.radj[target_id] or {}
|
||||
graph.radj[target_id][edge_type] = graph.radj[target_id][edge_type] or {}
|
||||
table.insert(graph.radj[target_id][edge_type], source_id)
|
||||
|
||||
-- Store edge properties separately (for weight/metadata)
|
||||
graph.edges = graph.edges or {}
|
||||
graph.edges[edge.id] = edge
|
||||
|
||||
storage.save_graph(graph)
|
||||
|
||||
-- Update meta
|
||||
local meta = storage.get_meta()
|
||||
storage.update_meta({ ec = meta.ec + 1 })
|
||||
|
||||
-- Track pending change
|
||||
table.insert(M.pending, {
|
||||
op = types.DELTA_OPS.ADD,
|
||||
path = "graph.edges." .. edge.id,
|
||||
ah = hash.compute_table(edge),
|
||||
})
|
||||
|
||||
return edge
|
||||
end
|
||||
|
||||
--- Get edge by source and target
|
||||
---@param source_id string Source node ID
|
||||
---@param target_id string Target node ID
|
||||
---@param edge_type? EdgeType Optional edge type filter
|
||||
---@return Edge|nil
|
||||
function M.get(source_id, target_id, edge_type)
|
||||
local graph = storage.get_graph()
|
||||
local edge_id = hash.edge_id(source_id, target_id)
|
||||
|
||||
if not graph.edges or not graph.edges[edge_id] then
|
||||
return nil
|
||||
end
|
||||
|
||||
local edge = graph.edges[edge_id]
|
||||
|
||||
if edge_type and edge.ty ~= edge_type then
|
||||
return nil
|
||||
end
|
||||
|
||||
return edge
|
||||
end
|
||||
|
||||
--- Get all edges for a node
|
||||
---@param node_id string Node ID
|
||||
---@param edge_types? EdgeType[] Edge types to include
|
||||
---@param direction? "out"|"in"|"both" Direction (default: "out")
|
||||
---@return Edge[]
|
||||
function M.get_edges(node_id, edge_types, direction)
|
||||
direction = direction or "out"
|
||||
local graph = storage.get_graph()
|
||||
local results = {}
|
||||
|
||||
edge_types = edge_types or vim.tbl_values(types.EDGE_TYPES)
|
||||
|
||||
-- Outgoing edges
|
||||
if direction == "out" or direction == "both" then
|
||||
local adj = graph.adj[node_id]
|
||||
if adj then
|
||||
for _, edge_type in ipairs(edge_types) do
|
||||
local targets = adj[edge_type] or {}
|
||||
for _, target_id in ipairs(targets) do
|
||||
local edge_id = hash.edge_id(node_id, target_id)
|
||||
if graph.edges and graph.edges[edge_id] then
|
||||
table.insert(results, graph.edges[edge_id])
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Incoming edges
|
||||
if direction == "in" or direction == "both" then
|
||||
local radj = graph.radj[node_id]
|
||||
if radj then
|
||||
for _, edge_type in ipairs(edge_types) do
|
||||
local sources = radj[edge_type] or {}
|
||||
for _, source_id in ipairs(sources) do
|
||||
local edge_id = hash.edge_id(source_id, node_id)
|
||||
if graph.edges and graph.edges[edge_id] then
|
||||
table.insert(results, graph.edges[edge_id])
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return results
|
||||
end
|
||||
|
||||
--- Get neighbor node IDs
|
||||
---@param node_id string Node ID
|
||||
---@param edge_types? EdgeType[] Edge types to follow
|
||||
---@param direction? "out"|"in"|"both" Direction
|
||||
---@return string[] Neighbor node IDs
|
||||
function M.get_neighbors(node_id, edge_types, direction)
|
||||
direction = direction or "out"
|
||||
local graph = storage.get_graph()
|
||||
local neighbors = {}
|
||||
|
||||
edge_types = edge_types or vim.tbl_values(types.EDGE_TYPES)
|
||||
|
||||
-- Outgoing
|
||||
if direction == "out" or direction == "both" then
|
||||
local adj = graph.adj[node_id]
|
||||
if adj then
|
||||
for _, edge_type in ipairs(edge_types) do
|
||||
for _, target in ipairs(adj[edge_type] or {}) do
|
||||
if not vim.tbl_contains(neighbors, target) then
|
||||
table.insert(neighbors, target)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Incoming
|
||||
if direction == "in" or direction == "both" then
|
||||
local radj = graph.radj[node_id]
|
||||
if radj then
|
||||
for _, edge_type in ipairs(edge_types) do
|
||||
for _, source in ipairs(radj[edge_type] or {}) do
|
||||
if not vim.tbl_contains(neighbors, source) then
|
||||
table.insert(neighbors, source)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return neighbors
|
||||
end
|
||||
|
||||
--- Delete an edge
|
||||
---@param source_id string Source node ID
|
||||
---@param target_id string Target node ID
|
||||
---@param edge_type? EdgeType Edge type (deletes all if nil)
|
||||
---@return boolean Success
|
||||
function M.delete(source_id, target_id, edge_type)
|
||||
local graph = storage.get_graph()
|
||||
local edge_id = hash.edge_id(source_id, target_id)
|
||||
|
||||
if not graph.edges or not graph.edges[edge_id] then
|
||||
return false
|
||||
end
|
||||
|
||||
local edge = graph.edges[edge_id]
|
||||
|
||||
if edge_type and edge.ty ~= edge_type then
|
||||
return false
|
||||
end
|
||||
|
||||
local before_hash = hash.compute_table(edge)
|
||||
|
||||
-- Remove from adjacency
|
||||
if graph.adj[source_id] and graph.adj[source_id][edge.ty] then
|
||||
graph.adj[source_id][edge.ty] = vim.tbl_filter(function(id)
|
||||
return id ~= target_id
|
||||
end, graph.adj[source_id][edge.ty])
|
||||
end
|
||||
|
||||
-- Remove from reverse adjacency
|
||||
if graph.radj[target_id] and graph.radj[target_id][edge.ty] then
|
||||
graph.radj[target_id][edge.ty] = vim.tbl_filter(function(id)
|
||||
return id ~= source_id
|
||||
end, graph.radj[target_id][edge.ty])
|
||||
end
|
||||
|
||||
-- Remove edge data
|
||||
graph.edges[edge_id] = nil
|
||||
|
||||
storage.save_graph(graph)
|
||||
|
||||
-- Update meta
|
||||
local meta = storage.get_meta()
|
||||
storage.update_meta({ ec = math.max(0, meta.ec - 1) })
|
||||
|
||||
-- Track pending change
|
||||
table.insert(M.pending, {
|
||||
op = types.DELTA_OPS.DELETE,
|
||||
path = "graph.edges." .. edge_id,
|
||||
bh = before_hash,
|
||||
})
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Delete all edges for a node
|
||||
---@param node_id string Node ID
|
||||
---@return number Number of deleted edges
|
||||
function M.delete_all(node_id)
|
||||
local edges = M.get_edges(node_id, nil, "both")
|
||||
local count = 0
|
||||
|
||||
for _, edge in ipairs(edges) do
|
||||
if M.delete(edge.s, edge.t, edge.ty) then
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
|
||||
return count
|
||||
end
|
||||
|
||||
--- Strengthen an existing edge
|
||||
---@param source_id string Source node ID
|
||||
---@param target_id string Target node ID
|
||||
---@param edge_type EdgeType Edge type
|
||||
---@return Edge|nil Updated edge
|
||||
function M.strengthen(source_id, target_id, edge_type)
|
||||
local graph = storage.get_graph()
|
||||
local edge_id = hash.edge_id(source_id, target_id)
|
||||
|
||||
if not graph.edges or not graph.edges[edge_id] then
|
||||
return nil
|
||||
end
|
||||
|
||||
local edge = graph.edges[edge_id]
|
||||
|
||||
if edge.ty ~= edge_type then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Increase weight (diminishing returns)
|
||||
edge.p.w = math.min(1.0, edge.p.w + (1 - edge.p.w) * 0.1)
|
||||
edge.ts = os.time()
|
||||
|
||||
graph.edges[edge_id] = edge
|
||||
storage.save_graph(graph)
|
||||
|
||||
return edge
|
||||
end
|
||||
|
||||
--- Find path between two nodes
|
||||
---@param from_id string Start node ID
|
||||
---@param to_id string End node ID
|
||||
---@param max_depth? number Maximum depth (default: 5)
|
||||
---@return table|nil Path info {nodes: string[], edges: Edge[], found: boolean}
|
||||
function M.find_path(from_id, to_id, max_depth)
|
||||
max_depth = max_depth or 5
|
||||
|
||||
-- BFS
|
||||
local queue = { { id = from_id, path = {}, edges = {} } }
|
||||
local visited = { [from_id] = true }
|
||||
|
||||
while #queue > 0 do
|
||||
local current = table.remove(queue, 1)
|
||||
|
||||
if current.id == to_id then
|
||||
table.insert(current.path, to_id)
|
||||
return {
|
||||
nodes = current.path,
|
||||
edges = current.edges,
|
||||
found = true,
|
||||
}
|
||||
end
|
||||
|
||||
if #current.path >= max_depth then
|
||||
goto continue
|
||||
end
|
||||
|
||||
-- Get all neighbors
|
||||
local edges = M.get_edges(current.id, nil, "both")
|
||||
|
||||
for _, edge in ipairs(edges) do
|
||||
local neighbor = edge.s == current.id and edge.t or edge.s
|
||||
|
||||
if not visited[neighbor] then
|
||||
visited[neighbor] = true
|
||||
|
||||
local new_path = vim.list_extend({}, current.path)
|
||||
table.insert(new_path, current.id)
|
||||
|
||||
local new_edges = vim.list_extend({}, current.edges)
|
||||
table.insert(new_edges, edge)
|
||||
|
||||
table.insert(queue, {
|
||||
id = neighbor,
|
||||
path = new_path,
|
||||
edges = new_edges,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
::continue::
|
||||
end
|
||||
|
||||
return { nodes = {}, edges = {}, found = false }
|
||||
end
|
||||
|
||||
--- Get pending changes and clear
|
||||
---@return table[] Pending changes
|
||||
function M.get_and_clear_pending()
|
||||
local changes = M.pending
|
||||
M.pending = {}
|
||||
return changes
|
||||
end
|
||||
|
||||
--- Check if two nodes are connected
|
||||
---@param node_id_1 string First node ID
|
||||
---@param node_id_2 string Second node ID
|
||||
---@param edge_type? EdgeType Edge type filter
|
||||
---@return boolean
|
||||
function M.are_connected(node_id_1, node_id_2, edge_type)
|
||||
local edge = M.get(node_id_1, node_id_2, edge_type)
|
||||
if edge then
|
||||
return true
|
||||
end
|
||||
-- Check reverse
|
||||
edge = M.get(node_id_2, node_id_1, edge_type)
|
||||
return edge ~= nil
|
||||
end
|
||||
|
||||
return M
|
||||
213
lua/codetyper/core/memory/graph/init.lua
Normal file
213
lua/codetyper/core/memory/graph/init.lua
Normal file
@@ -0,0 +1,213 @@
|
||||
--- Brain Graph Coordinator
|
||||
--- High-level graph operations
|
||||
|
||||
local node = require("codetyper.core.memory.graph.node")
|
||||
local edge = require("codetyper.core.memory.graph.edge")
|
||||
local query = require("codetyper.core.memory.graph.query")
|
||||
local storage = require("codetyper.core.memory.storage")
|
||||
local types = require("codetyper.core.memory.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
-- Re-export submodules
|
||||
M.node = node
|
||||
M.edge = edge
|
||||
M.query = query
|
||||
|
||||
--- Add a learning with automatic edge creation
|
||||
---@param node_type NodeType Node type
|
||||
---@param content NodeContent Content
|
||||
---@param context? NodeContext Context
|
||||
---@param related_ids? string[] Related node IDs
|
||||
---@return Node Created node
|
||||
function M.add_learning(node_type, content, context, related_ids)
|
||||
-- Create the node
|
||||
local new_node = node.create(node_type, content, context)
|
||||
|
||||
-- Create edges to related nodes
|
||||
if related_ids then
|
||||
for _, related_id in ipairs(related_ids) do
|
||||
local related_node = node.get(related_id)
|
||||
if related_node then
|
||||
-- Determine edge type based on relationship
|
||||
local edge_type = types.EDGE_TYPES.SEMANTIC
|
||||
|
||||
-- If same file, use file edge
|
||||
if context and context.f and related_node.ctx and related_node.ctx.f == context.f then
|
||||
edge_type = types.EDGE_TYPES.FILE
|
||||
end
|
||||
|
||||
edge.create(new_node.id, related_id, edge_type, {
|
||||
w = 0.5,
|
||||
r = "Related learning",
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Find and link to similar existing nodes
|
||||
local similar = query.semantic_search(content.s, 5)
|
||||
for _, sim_node in ipairs(similar) do
|
||||
if sim_node.id ~= new_node.id then
|
||||
-- Create semantic edge if similarity is high enough
|
||||
local sim_score = query.compute_relevance(sim_node, { query = content.s })
|
||||
if sim_score > 0.5 then
|
||||
edge.create(new_node.id, sim_node.id, types.EDGE_TYPES.SEMANTIC, {
|
||||
w = sim_score,
|
||||
r = "Semantic similarity",
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return new_node
|
||||
end
|
||||
|
||||
--- Remove a learning and its edges
|
||||
---@param node_id string Node ID to remove
|
||||
---@return boolean Success
|
||||
function M.remove_learning(node_id)
|
||||
-- Delete all edges first
|
||||
edge.delete_all(node_id)
|
||||
|
||||
-- Delete the node
|
||||
return node.delete(node_id)
|
||||
end
|
||||
|
||||
--- Prune low-value nodes
|
||||
---@param opts? table Prune options
|
||||
---@return number Number of pruned nodes
|
||||
function M.prune(opts)
|
||||
opts = opts or {}
|
||||
local threshold = opts.threshold or 0.1
|
||||
local unused_days = opts.unused_days or 90
|
||||
local now = os.time()
|
||||
local cutoff = now - (unused_days * 86400)
|
||||
|
||||
local pruned = 0
|
||||
|
||||
-- Find nodes to prune
|
||||
for _, node_type in pairs(types.NODE_TYPES) do
|
||||
local nodes_to_prune = node.find({
|
||||
types = { node_type },
|
||||
min_weight = 0, -- Get all
|
||||
})
|
||||
|
||||
for _, n in ipairs(nodes_to_prune) do
|
||||
local should_prune = false
|
||||
|
||||
-- Prune if weight below threshold and not used recently
|
||||
if n.sc.w < threshold and (n.ts.lu or n.ts.up) < cutoff then
|
||||
should_prune = true
|
||||
end
|
||||
|
||||
-- Prune if never used and old
|
||||
if n.sc.u == 0 and n.ts.cr < cutoff then
|
||||
should_prune = true
|
||||
end
|
||||
|
||||
if should_prune then
|
||||
if M.remove_learning(n.id) then
|
||||
pruned = pruned + 1
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return pruned
|
||||
end
|
||||
|
||||
--- Get all pending changes from nodes and edges
|
||||
---@return table[] Combined pending changes
|
||||
function M.get_pending_changes()
|
||||
local changes = {}
|
||||
|
||||
-- Get node changes
|
||||
local node_changes = node.get_and_clear_pending()
|
||||
for _, change in ipairs(node_changes) do
|
||||
table.insert(changes, change)
|
||||
end
|
||||
|
||||
-- Get edge changes
|
||||
local edge_changes = edge.get_and_clear_pending()
|
||||
for _, change in ipairs(edge_changes) do
|
||||
table.insert(changes, change)
|
||||
end
|
||||
|
||||
return changes
|
||||
end
|
||||
|
||||
--- Get graph statistics
|
||||
---@return table Stats
|
||||
function M.stats()
|
||||
local meta = storage.get_meta()
|
||||
|
||||
-- Count nodes by type
|
||||
local by_type = {}
|
||||
for _, node_type in pairs(types.NODE_TYPES) do
|
||||
local nodes = storage.get_nodes(node_type .. "s")
|
||||
by_type[node_type] = vim.tbl_count(nodes)
|
||||
end
|
||||
|
||||
-- Count edges by type
|
||||
local graph = storage.get_graph()
|
||||
local edges_by_type = {}
|
||||
if graph.edges then
|
||||
for _, e in pairs(graph.edges) do
|
||||
edges_by_type[e.ty] = (edges_by_type[e.ty] or 0) + 1
|
||||
end
|
||||
end
|
||||
|
||||
return {
|
||||
node_count = meta.nc,
|
||||
edge_count = meta.ec,
|
||||
delta_count = meta.dc,
|
||||
nodes_by_type = by_type,
|
||||
edges_by_type = edges_by_type,
|
||||
}
|
||||
end
|
||||
|
||||
--- Create temporal edge between nodes created in sequence
|
||||
---@param node_ids string[] Node IDs in temporal order
|
||||
function M.link_temporal(node_ids)
|
||||
for i = 1, #node_ids - 1 do
|
||||
edge.create(node_ids[i], node_ids[i + 1], types.EDGE_TYPES.TEMPORAL, {
|
||||
w = 0.7,
|
||||
dir = "fwd",
|
||||
r = "Temporal sequence",
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
--- Create causal edge (this caused that)
|
||||
---@param cause_id string Cause node ID
|
||||
---@param effect_id string Effect node ID
|
||||
---@param reason? string Reason description
|
||||
function M.link_causal(cause_id, effect_id, reason)
|
||||
edge.create(cause_id, effect_id, types.EDGE_TYPES.CAUSAL, {
|
||||
w = 0.8,
|
||||
dir = "fwd",
|
||||
r = reason or "Caused by",
|
||||
})
|
||||
end
|
||||
|
||||
--- Mark a node as superseded by another
|
||||
---@param old_id string Old node ID
|
||||
---@param new_id string New node ID
|
||||
function M.supersede(old_id, new_id)
|
||||
edge.create(old_id, new_id, types.EDGE_TYPES.SUPERSEDES, {
|
||||
w = 1.0,
|
||||
dir = "fwd",
|
||||
r = "Superseded by newer learning",
|
||||
})
|
||||
|
||||
-- Reduce weight of old node
|
||||
local old_node = node.get(old_id)
|
||||
if old_node then
|
||||
node.update(old_id, {
|
||||
sc = { w = old_node.sc.w * 0.5 },
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return M
|
||||
403
lua/codetyper/core/memory/graph/node.lua
Normal file
403
lua/codetyper/core/memory/graph/node.lua
Normal file
@@ -0,0 +1,403 @@
|
||||
--- Brain Graph Node Operations
|
||||
--- CRUD operations for learning nodes
|
||||
|
||||
local storage = require("codetyper.core.memory.storage")
|
||||
local hash = require("codetyper.core.memory.hash")
|
||||
local types = require("codetyper.core.memory.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Pending changes for delta tracking
|
||||
---@type table[]
|
||||
M.pending = {}
|
||||
|
||||
--- Node type to file mapping
|
||||
local TYPE_MAP = {
|
||||
[types.NODE_TYPES.PATTERN] = "patterns",
|
||||
[types.NODE_TYPES.CORRECTION] = "corrections",
|
||||
[types.NODE_TYPES.DECISION] = "decisions",
|
||||
[types.NODE_TYPES.CONVENTION] = "conventions",
|
||||
[types.NODE_TYPES.FEEDBACK] = "feedback",
|
||||
[types.NODE_TYPES.SESSION] = "sessions",
|
||||
-- Full names for convenience
|
||||
patterns = "patterns",
|
||||
corrections = "corrections",
|
||||
decisions = "decisions",
|
||||
conventions = "conventions",
|
||||
feedback = "feedback",
|
||||
sessions = "sessions",
|
||||
}
|
||||
|
||||
--- Get storage key for node type
|
||||
---@param node_type string Node type
|
||||
---@return string Storage key
|
||||
local function get_storage_key(node_type)
|
||||
return TYPE_MAP[node_type] or "patterns"
|
||||
end
|
||||
|
||||
--- Create a new node
|
||||
---@param node_type NodeType Node type
|
||||
---@param content NodeContent Content
|
||||
---@param context? NodeContext Context
|
||||
---@param opts? table Additional options
|
||||
---@return Node Created node
|
||||
function M.create(node_type, content, context, opts)
|
||||
opts = opts or {}
|
||||
local now = os.time()
|
||||
|
||||
local node = {
|
||||
id = hash.node_id(node_type, content.s),
|
||||
t = node_type,
|
||||
h = hash.compute(content.s .. (content.d or "")),
|
||||
c = {
|
||||
s = content.s or "",
|
||||
d = content.d or content.s or "",
|
||||
code = content.code,
|
||||
lang = content.lang,
|
||||
},
|
||||
ctx = context or {},
|
||||
sc = {
|
||||
w = opts.weight or 0.5,
|
||||
u = 0,
|
||||
sr = 1.0,
|
||||
},
|
||||
ts = {
|
||||
cr = now,
|
||||
up = now,
|
||||
lu = now,
|
||||
},
|
||||
m = {
|
||||
src = opts.source or types.SOURCES.AUTO,
|
||||
v = 1,
|
||||
},
|
||||
}
|
||||
|
||||
-- Store node
|
||||
local storage_key = get_storage_key(node_type)
|
||||
local nodes = storage.get_nodes(storage_key)
|
||||
nodes[node.id] = node
|
||||
storage.save_nodes(storage_key, nodes)
|
||||
|
||||
-- Update meta
|
||||
local meta = storage.get_meta()
|
||||
storage.update_meta({ nc = meta.nc + 1 })
|
||||
|
||||
-- Update indices
|
||||
M.update_indices(node, "add")
|
||||
|
||||
-- Track pending change
|
||||
table.insert(M.pending, {
|
||||
op = types.DELTA_OPS.ADD,
|
||||
path = "nodes." .. storage_key .. "." .. node.id,
|
||||
ah = node.h,
|
||||
})
|
||||
|
||||
return node
|
||||
end
|
||||
|
||||
--- Get a node by ID
|
||||
---@param node_id string Node ID
|
||||
---@return Node|nil
|
||||
function M.get(node_id)
|
||||
-- Parse node type from ID (n_<type>_<timestamp>_<hash>)
|
||||
local parts = vim.split(node_id, "_")
|
||||
if #parts < 3 then
|
||||
return nil
|
||||
end
|
||||
|
||||
local node_type = parts[2]
|
||||
local storage_key = get_storage_key(node_type)
|
||||
local nodes = storage.get_nodes(storage_key)
|
||||
|
||||
return nodes[node_id]
|
||||
end
|
||||
|
||||
--- Update a node
|
||||
---@param node_id string Node ID
|
||||
---@param updates table Partial updates
|
||||
---@return Node|nil Updated node
|
||||
function M.update(node_id, updates)
|
||||
local node = M.get(node_id)
|
||||
if not node then
|
||||
return nil
|
||||
end
|
||||
|
||||
local before_hash = node.h
|
||||
|
||||
-- Apply updates
|
||||
if updates.c then
|
||||
node.c = vim.tbl_deep_extend("force", node.c, updates.c)
|
||||
end
|
||||
if updates.ctx then
|
||||
node.ctx = vim.tbl_deep_extend("force", node.ctx, updates.ctx)
|
||||
end
|
||||
if updates.sc then
|
||||
node.sc = vim.tbl_deep_extend("force", node.sc, updates.sc)
|
||||
end
|
||||
|
||||
-- Update timestamps and hash
|
||||
node.ts.up = os.time()
|
||||
node.h = hash.compute((node.c.s or "") .. (node.c.d or ""))
|
||||
node.m.v = (node.m.v or 0) + 1
|
||||
|
||||
-- Save
|
||||
local storage_key = get_storage_key(node.t)
|
||||
local nodes = storage.get_nodes(storage_key)
|
||||
nodes[node_id] = node
|
||||
storage.save_nodes(storage_key, nodes)
|
||||
|
||||
-- Update indices if context changed
|
||||
if updates.ctx then
|
||||
M.update_indices(node, "update")
|
||||
end
|
||||
|
||||
-- Track pending change
|
||||
table.insert(M.pending, {
|
||||
op = types.DELTA_OPS.MODIFY,
|
||||
path = "nodes." .. storage_key .. "." .. node_id,
|
||||
bh = before_hash,
|
||||
ah = node.h,
|
||||
})
|
||||
|
||||
return node
|
||||
end
|
||||
|
||||
--- Delete a node
|
||||
---@param node_id string Node ID
|
||||
---@return boolean Success
|
||||
function M.delete(node_id)
|
||||
local node = M.get(node_id)
|
||||
if not node then
|
||||
return false
|
||||
end
|
||||
|
||||
local storage_key = get_storage_key(node.t)
|
||||
local nodes = storage.get_nodes(storage_key)
|
||||
|
||||
if not nodes[node_id] then
|
||||
return false
|
||||
end
|
||||
|
||||
local before_hash = node.h
|
||||
nodes[node_id] = nil
|
||||
storage.save_nodes(storage_key, nodes)
|
||||
|
||||
-- Update meta
|
||||
local meta = storage.get_meta()
|
||||
storage.update_meta({ nc = math.max(0, meta.nc - 1) })
|
||||
|
||||
-- Update indices
|
||||
M.update_indices(node, "delete")
|
||||
|
||||
-- Track pending change
|
||||
table.insert(M.pending, {
|
||||
op = types.DELTA_OPS.DELETE,
|
||||
path = "nodes." .. storage_key .. "." .. node_id,
|
||||
bh = before_hash,
|
||||
})
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Find nodes by criteria
|
||||
---@param criteria table Search criteria
|
||||
---@return Node[]
|
||||
function M.find(criteria)
|
||||
local results = {}
|
||||
|
||||
local node_types = criteria.types or vim.tbl_values(types.NODE_TYPES)
|
||||
|
||||
for _, node_type in ipairs(node_types) do
|
||||
local storage_key = get_storage_key(node_type)
|
||||
local nodes = storage.get_nodes(storage_key)
|
||||
|
||||
for _, node in pairs(nodes) do
|
||||
local matches = true
|
||||
|
||||
-- Filter by file
|
||||
if criteria.file and node.ctx.f ~= criteria.file then
|
||||
matches = false
|
||||
end
|
||||
|
||||
-- Filter by min weight
|
||||
if criteria.min_weight and node.sc.w < criteria.min_weight then
|
||||
matches = false
|
||||
end
|
||||
|
||||
-- Filter by since timestamp
|
||||
if criteria.since and node.ts.cr < criteria.since then
|
||||
matches = false
|
||||
end
|
||||
|
||||
-- Filter by content match
|
||||
if criteria.query then
|
||||
local query_lower = criteria.query:lower()
|
||||
local summary_lower = (node.c.s or ""):lower()
|
||||
local detail_lower = (node.c.d or ""):lower()
|
||||
if not summary_lower:find(query_lower, 1, true) and not detail_lower:find(query_lower, 1, true) then
|
||||
matches = false
|
||||
end
|
||||
end
|
||||
|
||||
if matches then
|
||||
table.insert(results, node)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Sort by relevance (weight * recency)
|
||||
table.sort(results, function(a, b)
|
||||
local score_a = a.sc.w * (1 / (1 + (os.time() - a.ts.lu) / 86400))
|
||||
local score_b = b.sc.w * (1 / (1 + (os.time() - b.ts.lu) / 86400))
|
||||
return score_a > score_b
|
||||
end)
|
||||
|
||||
-- Apply limit
|
||||
if criteria.limit and #results > criteria.limit then
|
||||
local limited = {}
|
||||
for i = 1, criteria.limit do
|
||||
limited[i] = results[i]
|
||||
end
|
||||
return limited
|
||||
end
|
||||
|
||||
return results
|
||||
end
|
||||
|
||||
--- Record usage of a node
|
||||
---@param node_id string Node ID
|
||||
---@param success? boolean Was the usage successful
|
||||
function M.record_usage(node_id, success)
|
||||
local node = M.get(node_id)
|
||||
if not node then
|
||||
return
|
||||
end
|
||||
|
||||
-- Update usage stats
|
||||
node.sc.u = node.sc.u + 1
|
||||
node.ts.lu = os.time()
|
||||
|
||||
-- Update success rate
|
||||
if success ~= nil then
|
||||
local total = node.sc.u
|
||||
local successes = node.sc.sr * (total - 1) + (success and 1 or 0)
|
||||
node.sc.sr = successes / total
|
||||
end
|
||||
|
||||
-- Increase weight slightly for frequently used nodes
|
||||
if node.sc.u > 5 then
|
||||
node.sc.w = math.min(1.0, node.sc.w + 0.01)
|
||||
end
|
||||
|
||||
-- Save (direct save, no pending change tracking for usage)
|
||||
local storage_key = get_storage_key(node.t)
|
||||
local nodes = storage.get_nodes(storage_key)
|
||||
nodes[node_id] = node
|
||||
storage.save_nodes(storage_key, nodes)
|
||||
end
|
||||
|
||||
--- Update indices for a node
|
||||
---@param node Node The node
|
||||
---@param op "add"|"update"|"delete" Operation type
|
||||
function M.update_indices(node, op)
|
||||
-- File index
|
||||
if node.ctx.f then
|
||||
local by_file = storage.get_index("by_file")
|
||||
|
||||
if op == "delete" then
|
||||
if by_file[node.ctx.f] then
|
||||
by_file[node.ctx.f] = vim.tbl_filter(function(id)
|
||||
return id ~= node.id
|
||||
end, by_file[node.ctx.f])
|
||||
end
|
||||
else
|
||||
by_file[node.ctx.f] = by_file[node.ctx.f] or {}
|
||||
if not vim.tbl_contains(by_file[node.ctx.f], node.id) then
|
||||
table.insert(by_file[node.ctx.f], node.id)
|
||||
end
|
||||
end
|
||||
|
||||
storage.save_index("by_file", by_file)
|
||||
end
|
||||
|
||||
-- Symbol index
|
||||
if node.ctx.sym then
|
||||
local by_symbol = storage.get_index("by_symbol")
|
||||
|
||||
for _, sym in ipairs(node.ctx.sym) do
|
||||
if op == "delete" then
|
||||
if by_symbol[sym] then
|
||||
by_symbol[sym] = vim.tbl_filter(function(id)
|
||||
return id ~= node.id
|
||||
end, by_symbol[sym])
|
||||
end
|
||||
else
|
||||
by_symbol[sym] = by_symbol[sym] or {}
|
||||
if not vim.tbl_contains(by_symbol[sym], node.id) then
|
||||
table.insert(by_symbol[sym], node.id)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
storage.save_index("by_symbol", by_symbol)
|
||||
end
|
||||
|
||||
-- Time index (daily buckets)
|
||||
local day = os.date("%Y-%m-%d", node.ts.cr)
|
||||
local by_time = storage.get_index("by_time")
|
||||
|
||||
if op == "delete" then
|
||||
if by_time[day] then
|
||||
by_time[day] = vim.tbl_filter(function(id)
|
||||
return id ~= node.id
|
||||
end, by_time[day])
|
||||
end
|
||||
elseif op == "add" then
|
||||
by_time[day] = by_time[day] or {}
|
||||
if not vim.tbl_contains(by_time[day], node.id) then
|
||||
table.insert(by_time[day], node.id)
|
||||
end
|
||||
end
|
||||
|
||||
storage.save_index("by_time", by_time)
|
||||
end
|
||||
|
||||
--- Get pending changes and clear
|
||||
---@return table[] Pending changes
|
||||
function M.get_and_clear_pending()
|
||||
local changes = M.pending
|
||||
M.pending = {}
|
||||
return changes
|
||||
end
|
||||
|
||||
--- Merge two similar nodes
|
||||
---@param node_id_1 string First node ID
|
||||
---@param node_id_2 string Second node ID (will be deleted)
|
||||
---@return Node|nil Merged node
|
||||
function M.merge(node_id_1, node_id_2)
|
||||
local node1 = M.get(node_id_1)
|
||||
local node2 = M.get(node_id_2)
|
||||
|
||||
if not node1 or not node2 then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Merge content (keep longer detail)
|
||||
local merged_detail = #node1.c.d > #node2.c.d and node1.c.d or node2.c.d
|
||||
|
||||
-- Merge scores (combine weights and usage)
|
||||
local merged_weight = (node1.sc.w + node2.sc.w) / 2
|
||||
local merged_usage = node1.sc.u + node2.sc.u
|
||||
|
||||
M.update(node_id_1, {
|
||||
c = { d = merged_detail },
|
||||
sc = { w = merged_weight, u = merged_usage },
|
||||
})
|
||||
|
||||
-- Delete the second node
|
||||
M.delete(node_id_2)
|
||||
|
||||
return M.get(node_id_1)
|
||||
end
|
||||
|
||||
return M
|
||||
488
lua/codetyper/core/memory/graph/query.lua
Normal file
488
lua/codetyper/core/memory/graph/query.lua
Normal file
@@ -0,0 +1,488 @@
|
||||
--- Brain Graph Query Engine
|
||||
--- Multi-dimensional traversal and relevance scoring
|
||||
|
||||
local storage = require("codetyper.core.memory.storage")
|
||||
local types = require("codetyper.core.memory.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Lazy load dependencies to avoid circular requires
|
||||
local function get_node_module()
|
||||
return require("codetyper.core.memory.graph.node")
|
||||
end
|
||||
|
||||
local function get_edge_module()
|
||||
return require("codetyper.core.memory.graph.edge")
|
||||
end
|
||||
|
||||
--- Compute text similarity (simple keyword matching)
|
||||
---@param text1 string First text
|
||||
---@param text2 string Second text
|
||||
---@return number Similarity score (0-1)
|
||||
local function text_similarity(text1, text2)
|
||||
if not text1 or not text2 then
|
||||
return 0
|
||||
end
|
||||
|
||||
text1 = text1:lower()
|
||||
text2 = text2:lower()
|
||||
|
||||
-- Extract words
|
||||
local words1 = {}
|
||||
for word in text1:gmatch("%w+") do
|
||||
words1[word] = true
|
||||
end
|
||||
|
||||
local words2 = {}
|
||||
for word in text2:gmatch("%w+") do
|
||||
words2[word] = true
|
||||
end
|
||||
|
||||
-- Count matches
|
||||
local matches = 0
|
||||
local total = 0
|
||||
|
||||
for word in pairs(words1) do
|
||||
total = total + 1
|
||||
if words2[word] then
|
||||
matches = matches + 1
|
||||
end
|
||||
end
|
||||
|
||||
for word in pairs(words2) do
|
||||
if not words1[word] then
|
||||
total = total + 1
|
||||
end
|
||||
end
|
||||
|
||||
if total == 0 then
|
||||
return 0
|
||||
end
|
||||
|
||||
return matches / total
|
||||
end
|
||||
|
||||
--- Compute relevance score for a node
|
||||
---@param node Node Node to score
|
||||
---@param opts QueryOpts Query options
|
||||
---@return number Relevance score (0-1)
|
||||
function M.compute_relevance(node, opts)
|
||||
local score = 0
|
||||
local weights = {
|
||||
content_match = 0.30,
|
||||
recency = 0.20,
|
||||
usage = 0.15,
|
||||
weight = 0.15,
|
||||
connection_density = 0.10,
|
||||
success_rate = 0.10,
|
||||
}
|
||||
|
||||
-- Content similarity
|
||||
if opts.query then
|
||||
local summary = node.c.s or ""
|
||||
local detail = node.c.d or ""
|
||||
local similarity = math.max(text_similarity(opts.query, summary), text_similarity(opts.query, detail) * 0.8)
|
||||
score = score + (similarity * weights.content_match)
|
||||
else
|
||||
score = score + weights.content_match * 0.5 -- Base score if no query
|
||||
end
|
||||
|
||||
-- Recency decay (exponential with 30-day half-life)
|
||||
local age_days = (os.time() - (node.ts.lu or node.ts.up)) / 86400
|
||||
local recency = math.exp(-age_days / 30)
|
||||
score = score + (recency * weights.recency)
|
||||
|
||||
-- Usage frequency (normalized)
|
||||
local usage = math.min(node.sc.u / 10, 1.0)
|
||||
score = score + (usage * weights.usage)
|
||||
|
||||
-- Node weight
|
||||
score = score + (node.sc.w * weights.weight)
|
||||
|
||||
-- Connection density
|
||||
local edge_mod = get_edge_module()
|
||||
local connections = #edge_mod.get_edges(node.id, nil, "both")
|
||||
local density = math.min(connections / 5, 1.0)
|
||||
score = score + (density * weights.connection_density)
|
||||
|
||||
-- Success rate
|
||||
score = score + (node.sc.sr * weights.success_rate)
|
||||
|
||||
return score
|
||||
end
|
||||
|
||||
--- Traverse graph from seed nodes (basic traversal)
|
||||
---@param seed_ids string[] Starting node IDs
|
||||
---@param depth number Traversal depth
|
||||
---@param edge_types? EdgeType[] Edge types to follow
|
||||
---@return table<string, Node> Discovered nodes indexed by ID
|
||||
local function traverse(seed_ids, depth, edge_types)
|
||||
local node_mod = get_node_module()
|
||||
local edge_mod = get_edge_module()
|
||||
local discovered = {}
|
||||
local frontier = seed_ids
|
||||
|
||||
for _ = 1, depth do
|
||||
local next_frontier = {}
|
||||
|
||||
for _, node_id in ipairs(frontier) do
|
||||
-- Skip if already discovered
|
||||
if discovered[node_id] then
|
||||
goto continue
|
||||
end
|
||||
|
||||
-- Get and store node
|
||||
local node = node_mod.get(node_id)
|
||||
if node then
|
||||
discovered[node_id] = node
|
||||
|
||||
-- Get neighbors
|
||||
local neighbors = edge_mod.get_neighbors(node_id, edge_types, "both")
|
||||
for _, neighbor_id in ipairs(neighbors) do
|
||||
if not discovered[neighbor_id] then
|
||||
table.insert(next_frontier, neighbor_id)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
::continue::
|
||||
end
|
||||
|
||||
frontier = next_frontier
|
||||
if #frontier == 0 then
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
return discovered
|
||||
end
|
||||
|
||||
--- Spreading activation - mimics human associative memory
|
||||
--- Activation spreads from seed nodes along edges, decaying by weight
|
||||
--- Nodes accumulate activation from multiple paths (like neural pathways)
|
||||
---@param seed_activations table<string, number> Initial activations {node_id: activation}
|
||||
---@param max_iterations number Max spread iterations (default 3)
|
||||
---@param decay number Activation decay per hop (default 0.5)
|
||||
---@param threshold number Minimum activation to continue spreading (default 0.1)
|
||||
---@return table<string, number> Final activations {node_id: accumulated_activation}
|
||||
local function spreading_activation(seed_activations, max_iterations, decay, threshold)
|
||||
local edge_mod = get_edge_module()
|
||||
max_iterations = max_iterations or 3
|
||||
decay = decay or 0.5
|
||||
threshold = threshold or 0.1
|
||||
|
||||
-- Accumulated activation for each node
|
||||
local activation = {}
|
||||
for node_id, act in pairs(seed_activations) do
|
||||
activation[node_id] = act
|
||||
end
|
||||
|
||||
-- Current frontier with their activation levels
|
||||
local frontier = {}
|
||||
for node_id, act in pairs(seed_activations) do
|
||||
frontier[node_id] = act
|
||||
end
|
||||
|
||||
-- Spread activation iteratively
|
||||
for _ = 1, max_iterations do
|
||||
local next_frontier = {}
|
||||
|
||||
for source_id, source_activation in pairs(frontier) do
|
||||
-- Get all outgoing edges
|
||||
local edges = edge_mod.get_edges(source_id, nil, "both")
|
||||
|
||||
for _, edge in ipairs(edges) do
|
||||
-- Determine target (could be source or target of edge)
|
||||
local target_id = edge.s == source_id and edge.t or edge.s
|
||||
|
||||
-- Calculate spreading activation
|
||||
-- Activation = source_activation * edge_weight * decay
|
||||
local edge_weight = edge.p and edge.p.w or 0.5
|
||||
local spread_amount = source_activation * edge_weight * decay
|
||||
|
||||
-- Only spread if above threshold
|
||||
if spread_amount >= threshold then
|
||||
-- Accumulate activation (multiple paths add up)
|
||||
activation[target_id] = (activation[target_id] or 0) + spread_amount
|
||||
|
||||
-- Add to next frontier if not already processed with higher activation
|
||||
if not next_frontier[target_id] or next_frontier[target_id] < spread_amount then
|
||||
next_frontier[target_id] = spread_amount
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Stop if no more spreading
|
||||
if vim.tbl_count(next_frontier) == 0 then
|
||||
break
|
||||
end
|
||||
|
||||
frontier = next_frontier
|
||||
end
|
||||
|
||||
return activation
|
||||
end
|
||||
|
||||
--- Execute a query across all dimensions
|
||||
---@param opts QueryOpts Query options
|
||||
---@return QueryResult
|
||||
function M.execute(opts)
|
||||
opts = opts or {}
|
||||
local node_mod = get_node_module()
|
||||
local results = {
|
||||
semantic = {},
|
||||
file = {},
|
||||
temporal = {},
|
||||
}
|
||||
|
||||
-- 1. Semantic traversal (content similarity)
|
||||
if opts.query then
|
||||
local seed_nodes = node_mod.find({
|
||||
query = opts.query,
|
||||
types = opts.types,
|
||||
limit = 10,
|
||||
})
|
||||
|
||||
local seed_ids = vim.tbl_map(function(n)
|
||||
return n.id
|
||||
end, seed_nodes)
|
||||
local depth = opts.depth or 2
|
||||
|
||||
local discovered = traverse(seed_ids, depth, { types.EDGE_TYPES.SEMANTIC })
|
||||
for id, node in pairs(discovered) do
|
||||
results.semantic[id] = node
|
||||
end
|
||||
end
|
||||
|
||||
-- 2. File-based traversal
|
||||
if opts.file then
|
||||
local by_file = storage.get_index("by_file")
|
||||
local file_node_ids = by_file[opts.file] or {}
|
||||
|
||||
for _, node_id in ipairs(file_node_ids) do
|
||||
local node = node_mod.get(node_id)
|
||||
if node then
|
||||
results.file[node.id] = node
|
||||
end
|
||||
end
|
||||
|
||||
-- Also get nodes from related files via edges
|
||||
local discovered = traverse(file_node_ids, 1, { types.EDGE_TYPES.FILE })
|
||||
for id, node in pairs(discovered) do
|
||||
results.file[id] = node
|
||||
end
|
||||
end
|
||||
|
||||
-- 3. Temporal traversal (recent context)
|
||||
if opts.since then
|
||||
local by_time = storage.get_index("by_time")
|
||||
local now = os.time()
|
||||
|
||||
for day, node_ids in pairs(by_time) do
|
||||
-- Parse day to timestamp
|
||||
local year, month, day_num = day:match("(%d+)-(%d+)-(%d+)")
|
||||
if year then
|
||||
local day_ts = os.time({ year = tonumber(year), month = tonumber(month), day = tonumber(day_num) })
|
||||
if day_ts >= opts.since then
|
||||
for _, node_id in ipairs(node_ids) do
|
||||
local node = node_mod.get(node_id)
|
||||
if node then
|
||||
results.temporal[node.id] = node
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Follow temporal edges
|
||||
local temporal_ids = vim.tbl_keys(results.temporal)
|
||||
local discovered = traverse(temporal_ids, 1, { types.EDGE_TYPES.TEMPORAL })
|
||||
for id, node in pairs(discovered) do
|
||||
results.temporal[id] = node
|
||||
end
|
||||
end
|
||||
|
||||
-- 4. Combine all found nodes and compute seed activations
|
||||
local all_nodes = {}
|
||||
local seed_activations = {}
|
||||
|
||||
for _, category in pairs(results) do
|
||||
for id, node in pairs(category) do
|
||||
if not all_nodes[id] then
|
||||
all_nodes[id] = node
|
||||
-- Compute initial activation based on relevance
|
||||
local relevance = M.compute_relevance(node, opts)
|
||||
seed_activations[id] = relevance
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- 5. Apply spreading activation - like human associative memory
|
||||
-- Activation spreads from seed nodes along edges, accumulating
|
||||
-- Nodes connected to multiple relevant seeds get higher activation
|
||||
local final_activations = spreading_activation(
|
||||
seed_activations,
|
||||
opts.spread_iterations or 3, -- How far activation spreads
|
||||
opts.spread_decay or 0.5, -- How much activation decays per hop
|
||||
opts.spread_threshold or 0.05 -- Minimum activation to continue spreading
|
||||
)
|
||||
|
||||
-- 6. Score and rank by combined activation
|
||||
local scored = {}
|
||||
for id, activation in pairs(final_activations) do
|
||||
local node = all_nodes[id] or node_mod.get(id)
|
||||
if node then
|
||||
all_nodes[id] = node
|
||||
-- Final score = spreading activation + base relevance
|
||||
local base_relevance = M.compute_relevance(node, opts)
|
||||
local final_score = (activation * 0.6) + (base_relevance * 0.4)
|
||||
table.insert(scored, { node = node, relevance = final_score, activation = activation })
|
||||
end
|
||||
end
|
||||
|
||||
table.sort(scored, function(a, b)
|
||||
return a.relevance > b.relevance
|
||||
end)
|
||||
|
||||
-- 7. Apply limit
|
||||
local limit = opts.limit or 50
|
||||
local result_nodes = {}
|
||||
local truncated = #scored > limit
|
||||
|
||||
for i = 1, math.min(limit, #scored) do
|
||||
table.insert(result_nodes, scored[i].node)
|
||||
end
|
||||
|
||||
-- 8. Get edges between result nodes
|
||||
local edge_mod = get_edge_module()
|
||||
local result_edges = {}
|
||||
local node_ids = {}
|
||||
for _, node in ipairs(result_nodes) do
|
||||
node_ids[node.id] = true
|
||||
end
|
||||
|
||||
for _, node in ipairs(result_nodes) do
|
||||
local edges = edge_mod.get_edges(node.id, nil, "out")
|
||||
for _, edge in ipairs(edges) do
|
||||
if node_ids[edge.t] then
|
||||
table.insert(result_edges, edge)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return {
|
||||
nodes = result_nodes,
|
||||
edges = result_edges,
|
||||
stats = {
|
||||
semantic_count = vim.tbl_count(results.semantic),
|
||||
file_count = vim.tbl_count(results.file),
|
||||
temporal_count = vim.tbl_count(results.temporal),
|
||||
total_scored = #scored,
|
||||
seed_nodes = vim.tbl_count(seed_activations),
|
||||
activated_nodes = vim.tbl_count(final_activations),
|
||||
},
|
||||
truncated = truncated,
|
||||
}
|
||||
end
|
||||
|
||||
--- Expose spreading activation for direct use
|
||||
--- Useful for custom activation patterns or debugging
|
||||
M.spreading_activation = spreading_activation
|
||||
|
||||
--- Find nodes by file
|
||||
---@param filepath string File path
|
||||
---@param limit? number Max results
|
||||
---@return Node[]
|
||||
function M.by_file(filepath, limit)
|
||||
local result = M.execute({
|
||||
file = filepath,
|
||||
limit = limit or 20,
|
||||
})
|
||||
return result.nodes
|
||||
end
|
||||
|
||||
--- Find nodes by time range
|
||||
---@param since number Start timestamp
|
||||
---@param until_ts? number End timestamp
|
||||
---@param limit? number Max results
|
||||
---@return Node[]
|
||||
function M.by_time_range(since, until_ts, limit)
|
||||
local node_mod = get_node_module()
|
||||
local by_time = storage.get_index("by_time")
|
||||
local results = {}
|
||||
|
||||
until_ts = until_ts or os.time()
|
||||
|
||||
for day, node_ids in pairs(by_time) do
|
||||
local year, month, day_num = day:match("(%d+)-(%d+)-(%d+)")
|
||||
if year then
|
||||
local day_ts = os.time({ year = tonumber(year), month = tonumber(month), day = tonumber(day_num) })
|
||||
if day_ts >= since and day_ts <= until_ts then
|
||||
for _, node_id in ipairs(node_ids) do
|
||||
local node = node_mod.get(node_id)
|
||||
if node then
|
||||
table.insert(results, node)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Sort by creation time
|
||||
table.sort(results, function(a, b)
|
||||
return a.ts.cr > b.ts.cr
|
||||
end)
|
||||
|
||||
if limit and #results > limit then
|
||||
local limited = {}
|
||||
for i = 1, limit do
|
||||
limited[i] = results[i]
|
||||
end
|
||||
return limited
|
||||
end
|
||||
|
||||
return results
|
||||
end
|
||||
|
||||
--- Find semantically similar nodes
|
||||
---@param query string Query text
|
||||
---@param limit? number Max results
|
||||
---@return Node[]
|
||||
function M.semantic_search(query, limit)
|
||||
local result = M.execute({
|
||||
query = query,
|
||||
limit = limit or 10,
|
||||
depth = 2,
|
||||
})
|
||||
return result.nodes
|
||||
end
|
||||
|
||||
--- Get context chain (path) for explanation
|
||||
---@param node_ids string[] Node IDs to chain
|
||||
---@return string[] Chain descriptions
|
||||
function M.get_context_chain(node_ids)
|
||||
local node_mod = get_node_module()
|
||||
local edge_mod = get_edge_module()
|
||||
local chain = {}
|
||||
|
||||
for i, node_id in ipairs(node_ids) do
|
||||
local node = node_mod.get(node_id)
|
||||
if node then
|
||||
local entry = string.format("[%s] %s (w:%.2f)", node.t:upper(), node.c.s, node.sc.w)
|
||||
table.insert(chain, entry)
|
||||
|
||||
-- Add edge to next node if exists
|
||||
if node_ids[i + 1] then
|
||||
local edge = edge_mod.get(node_id, node_ids[i + 1])
|
||||
if edge then
|
||||
table.insert(chain, string.format(" -> %s (w:%.2f)", edge.ty, edge.p.w))
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return chain
|
||||
end
|
||||
|
||||
return M
|
||||
112
lua/codetyper/core/memory/hash.lua
Normal file
112
lua/codetyper/core/memory/hash.lua
Normal file
@@ -0,0 +1,112 @@
|
||||
--- Brain Hashing Utilities
|
||||
--- Content-addressable storage with 8-character hashes
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Simple DJB2 hash algorithm (fast, good distribution)
|
||||
---@param str string String to hash
|
||||
---@return number Hash value
|
||||
local function djb2(str)
|
||||
local hash = 5381
|
||||
for i = 1, #str do
|
||||
hash = ((hash * 33) + string.byte(str, i)) % 0x100000000
|
||||
end
|
||||
return hash
|
||||
end
|
||||
|
||||
--- Convert number to hex string
|
||||
---@param num number Number to convert
|
||||
---@param len number Desired length
|
||||
---@return string Hex string
|
||||
local function to_hex(num, len)
|
||||
local hex = string.format("%x", num)
|
||||
if #hex < len then
|
||||
hex = string.rep("0", len - #hex) .. hex
|
||||
end
|
||||
return hex:sub(-len)
|
||||
end
|
||||
|
||||
--- Compute 8-character hash from string
|
||||
---@param content string Content to hash
|
||||
---@return string 8-character hex hash
|
||||
function M.compute(content)
|
||||
if not content or content == "" then
|
||||
return "00000000"
|
||||
end
|
||||
local hash = djb2(content)
|
||||
return to_hex(hash, 8)
|
||||
end
|
||||
|
||||
--- Compute hash from table (JSON-serialized)
|
||||
---@param tbl table Table to hash
|
||||
---@return string 8-character hex hash
|
||||
function M.compute_table(tbl)
|
||||
local ok, json = pcall(vim.json.encode, tbl)
|
||||
if not ok then
|
||||
return "00000000"
|
||||
end
|
||||
return M.compute(json)
|
||||
end
|
||||
|
||||
--- Generate unique node ID
|
||||
---@param node_type string Node type prefix
|
||||
---@param content? string Optional content for hash
|
||||
---@return string Node ID (n_<timestamp>_<hash>)
|
||||
function M.node_id(node_type, content)
|
||||
local ts = os.time()
|
||||
local hash_input = (content or "") .. tostring(ts) .. tostring(math.random(100000))
|
||||
local hash = M.compute(hash_input):sub(1, 6)
|
||||
return string.format("n_%s_%d_%s", node_type, ts, hash)
|
||||
end
|
||||
|
||||
--- Generate unique edge ID
|
||||
---@param source_id string Source node ID
|
||||
---@param target_id string Target node ID
|
||||
---@return string Edge ID (e_<source_hash>_<target_hash>)
|
||||
function M.edge_id(source_id, target_id)
|
||||
local src_hash = M.compute(source_id):sub(1, 4)
|
||||
local tgt_hash = M.compute(target_id):sub(1, 4)
|
||||
return string.format("e_%s_%s", src_hash, tgt_hash)
|
||||
end
|
||||
|
||||
--- Generate delta hash
|
||||
---@param changes table[] Delta changes
|
||||
---@param parent string|nil Parent delta hash
|
||||
---@param timestamp number Delta timestamp
|
||||
---@return string 8-character delta hash
|
||||
function M.delta_hash(changes, parent, timestamp)
|
||||
local content = (parent or "root") .. tostring(timestamp)
|
||||
for _, change in ipairs(changes or {}) do
|
||||
content = content .. (change.op or "") .. (change.path or "")
|
||||
end
|
||||
return M.compute(content)
|
||||
end
|
||||
|
||||
--- Hash file path for storage
|
||||
---@param filepath string File path
|
||||
---@return string 8-character hash
|
||||
function M.path_hash(filepath)
|
||||
return M.compute(filepath)
|
||||
end
|
||||
|
||||
--- Check if two hashes match
|
||||
---@param hash1 string First hash
|
||||
---@param hash2 string Second hash
|
||||
---@return boolean True if matching
|
||||
function M.matches(hash1, hash2)
|
||||
return hash1 == hash2
|
||||
end
|
||||
|
||||
--- Generate random hash (for testing/temporary IDs)
|
||||
---@return string 8-character random hash
|
||||
function M.random()
|
||||
local chars = "0123456789abcdef"
|
||||
local result = ""
|
||||
for _ = 1, 8 do
|
||||
local idx = math.random(1, #chars)
|
||||
result = result .. chars:sub(idx, idx)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
return M
|
||||
276
lua/codetyper/core/memory/init.lua
Normal file
276
lua/codetyper/core/memory/init.lua
Normal file
@@ -0,0 +1,276 @@
|
||||
--- Brain Learning System
|
||||
--- Graph-based knowledge storage with delta versioning
|
||||
|
||||
local storage = require("codetyper.core.memory.storage")
|
||||
local types = require("codetyper.core.memory.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
---@type BrainConfig|nil
|
||||
local config = nil
|
||||
|
||||
---@type boolean
|
||||
local initialized = false
|
||||
|
||||
--- Pending changes counter for auto-commit
|
||||
local pending_changes = 0
|
||||
|
||||
--- Default configuration
|
||||
local DEFAULT_CONFIG = {
|
||||
enabled = true,
|
||||
auto_learn = true,
|
||||
auto_commit = true,
|
||||
commit_threshold = 10,
|
||||
max_nodes = 5000,
|
||||
max_deltas = 500,
|
||||
prune = {
|
||||
enabled = true,
|
||||
threshold = 0.1,
|
||||
unused_days = 90,
|
||||
},
|
||||
output = {
|
||||
max_tokens = 4000,
|
||||
format = "compact",
|
||||
},
|
||||
}
|
||||
|
||||
--- Initialize brain system
|
||||
---@param opts? BrainConfig Configuration options
|
||||
function M.setup(opts)
|
||||
config = vim.tbl_deep_extend("force", DEFAULT_CONFIG, opts or {})
|
||||
|
||||
if not config.enabled then
|
||||
return
|
||||
end
|
||||
|
||||
-- Ensure storage directories
|
||||
storage.ensure_dirs()
|
||||
|
||||
-- Initialize meta if not exists
|
||||
storage.get_meta()
|
||||
|
||||
initialized = true
|
||||
end
|
||||
|
||||
--- Check if brain is initialized
|
||||
---@return boolean
|
||||
function M.is_initialized()
|
||||
return initialized and config and config.enabled
|
||||
end
|
||||
|
||||
--- Get current configuration
|
||||
---@return BrainConfig|nil
|
||||
function M.get_config()
|
||||
return config
|
||||
end
|
||||
|
||||
--- Learn from an event
|
||||
---@param event LearnEvent Learning event
|
||||
---@return string|nil Node ID if created
|
||||
function M.learn(event)
|
||||
if not M.is_initialized() or not config.auto_learn then
|
||||
return nil
|
||||
end
|
||||
|
||||
local learners = require("codetyper.core.memory.learners")
|
||||
local node_id = learners.process(event)
|
||||
|
||||
if node_id then
|
||||
pending_changes = pending_changes + 1
|
||||
|
||||
-- Auto-commit if threshold reached
|
||||
if config.auto_commit and pending_changes >= config.commit_threshold then
|
||||
M.commit("Auto-commit: " .. pending_changes .. " changes")
|
||||
pending_changes = 0
|
||||
end
|
||||
end
|
||||
|
||||
return node_id
|
||||
end
|
||||
|
||||
--- Query relevant knowledge for context
|
||||
---@param opts QueryOpts Query options
|
||||
---@return QueryResult
|
||||
function M.query(opts)
|
||||
if not M.is_initialized() then
|
||||
return { nodes = {}, edges = {}, stats = {}, truncated = false }
|
||||
end
|
||||
|
||||
local query_engine = require("codetyper.core.memory.graph.query")
|
||||
return query_engine.execute(opts)
|
||||
end
|
||||
|
||||
--- Get LLM-optimized context string
|
||||
---@param opts? QueryOpts Query options
|
||||
---@return string Formatted context
|
||||
function M.get_context_for_llm(opts)
|
||||
if not M.is_initialized() then
|
||||
return ""
|
||||
end
|
||||
|
||||
opts = opts or {}
|
||||
opts.max_tokens = opts.max_tokens or config.output.max_tokens
|
||||
|
||||
local result = M.query(opts)
|
||||
local formatter = require("codetyper.core.memory.output.formatter")
|
||||
|
||||
if config.output.format == "json" then
|
||||
return formatter.to_json(result, opts)
|
||||
else
|
||||
return formatter.to_compact(result, opts)
|
||||
end
|
||||
end
|
||||
|
||||
--- Create a delta commit
|
||||
---@param message string Commit message
|
||||
---@return string|nil Delta hash
|
||||
function M.commit(message)
|
||||
if not M.is_initialized() then
|
||||
return nil
|
||||
end
|
||||
|
||||
local delta_mgr = require("codetyper.core.memory.delta")
|
||||
return delta_mgr.commit(message)
|
||||
end
|
||||
|
||||
--- Rollback to a previous delta
|
||||
---@param delta_hash string Target delta hash
|
||||
---@return boolean Success
|
||||
function M.rollback(delta_hash)
|
||||
if not M.is_initialized() then
|
||||
return false
|
||||
end
|
||||
|
||||
local delta_mgr = require("codetyper.core.memory.delta")
|
||||
return delta_mgr.rollback(delta_hash)
|
||||
end
|
||||
|
||||
--- Get delta history
|
||||
---@param limit? number Max entries
|
||||
---@return Delta[]
|
||||
function M.get_history(limit)
|
||||
if not M.is_initialized() then
|
||||
return {}
|
||||
end
|
||||
|
||||
local delta_mgr = require("codetyper.core.memory.delta")
|
||||
return delta_mgr.get_history(limit or 50)
|
||||
end
|
||||
|
||||
--- Prune low-value nodes
|
||||
---@param opts? table Prune options
|
||||
---@return number Number of pruned nodes
|
||||
function M.prune(opts)
|
||||
if not M.is_initialized() or not config.prune.enabled then
|
||||
return 0
|
||||
end
|
||||
|
||||
opts = vim.tbl_extend("force", {
|
||||
threshold = config.prune.threshold,
|
||||
unused_days = config.prune.unused_days,
|
||||
}, opts or {})
|
||||
|
||||
local graph = require("codetyper.core.memory.graph")
|
||||
return graph.prune(opts)
|
||||
end
|
||||
|
||||
--- Export brain state
|
||||
---@return table|nil Exported data
|
||||
function M.export()
|
||||
if not M.is_initialized() then
|
||||
return nil
|
||||
end
|
||||
|
||||
return {
|
||||
schema = types.SCHEMA_VERSION,
|
||||
meta = storage.get_meta(),
|
||||
graph = storage.get_graph(),
|
||||
nodes = {
|
||||
patterns = storage.get_nodes("patterns"),
|
||||
corrections = storage.get_nodes("corrections"),
|
||||
decisions = storage.get_nodes("decisions"),
|
||||
conventions = storage.get_nodes("conventions"),
|
||||
feedback = storage.get_nodes("feedback"),
|
||||
sessions = storage.get_nodes("sessions"),
|
||||
},
|
||||
indices = {
|
||||
by_file = storage.get_index("by_file"),
|
||||
by_time = storage.get_index("by_time"),
|
||||
by_symbol = storage.get_index("by_symbol"),
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
--- Import brain state
|
||||
---@param data table Exported data
|
||||
---@return boolean Success
|
||||
function M.import(data)
|
||||
if not data or data.schema ~= types.SCHEMA_VERSION then
|
||||
return false
|
||||
end
|
||||
|
||||
storage.ensure_dirs()
|
||||
|
||||
-- Import nodes
|
||||
if data.nodes then
|
||||
for node_type, nodes in pairs(data.nodes) do
|
||||
storage.save_nodes(node_type, nodes)
|
||||
end
|
||||
end
|
||||
|
||||
-- Import graph
|
||||
if data.graph then
|
||||
storage.save_graph(data.graph)
|
||||
end
|
||||
|
||||
-- Import indices
|
||||
if data.indices then
|
||||
for index_type, index_data in pairs(data.indices) do
|
||||
storage.save_index(index_type, index_data)
|
||||
end
|
||||
end
|
||||
|
||||
-- Import meta last
|
||||
if data.meta then
|
||||
for k, v in pairs(data.meta) do
|
||||
storage.update_meta({ [k] = v })
|
||||
end
|
||||
end
|
||||
|
||||
storage.flush_all()
|
||||
return true
|
||||
end
|
||||
|
||||
--- Get stats about the brain
|
||||
---@return table Stats
|
||||
function M.stats()
|
||||
if not M.is_initialized() then
|
||||
return {}
|
||||
end
|
||||
|
||||
local meta = storage.get_meta()
|
||||
return {
|
||||
initialized = true,
|
||||
node_count = meta.nc,
|
||||
edge_count = meta.ec,
|
||||
delta_count = meta.dc,
|
||||
head = meta.head,
|
||||
pending_changes = pending_changes,
|
||||
}
|
||||
end
|
||||
|
||||
--- Flush all pending writes to disk
|
||||
function M.flush()
|
||||
storage.flush_all()
|
||||
end
|
||||
|
||||
--- Shutdown brain (call before exit)
|
||||
function M.shutdown()
|
||||
if pending_changes > 0 then
|
||||
M.commit("Session end: " .. pending_changes .. " changes")
|
||||
end
|
||||
storage.flush_all()
|
||||
initialized = false
|
||||
end
|
||||
|
||||
return M
|
||||
233
lua/codetyper/core/memory/learners/convention.lua
Normal file
233
lua/codetyper/core/memory/learners/convention.lua
Normal file
@@ -0,0 +1,233 @@
|
||||
--- Brain Convention Learner
|
||||
--- Learns project conventions and coding standards
|
||||
|
||||
local types = require("codetyper.core.memory.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Detect if event contains convention info
|
||||
---@param event LearnEvent Learning event
|
||||
---@return boolean
|
||||
function M.detect(event)
|
||||
local valid_types = {
|
||||
"convention_detected",
|
||||
"naming_pattern",
|
||||
"style_pattern",
|
||||
"project_structure",
|
||||
"config_change",
|
||||
}
|
||||
|
||||
for _, t in ipairs(valid_types) do
|
||||
if event.type == t then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Extract convention data from event
|
||||
---@param event LearnEvent Learning event
|
||||
---@return table|nil Extracted data
|
||||
function M.extract(event)
|
||||
local data = event.data or {}
|
||||
|
||||
if event.type == "convention_detected" then
|
||||
return {
|
||||
summary = "Convention: " .. (data.name or "unnamed"),
|
||||
detail = data.description or data.name,
|
||||
rule = data.rule,
|
||||
examples = data.examples,
|
||||
category = data.category or "general",
|
||||
file = event.file,
|
||||
}
|
||||
end
|
||||
|
||||
if event.type == "naming_pattern" then
|
||||
return {
|
||||
summary = "Naming: " .. (data.pattern_name or data.pattern),
|
||||
detail = "Naming convention: " .. (data.description or data.pattern),
|
||||
rule = data.pattern,
|
||||
examples = data.examples,
|
||||
category = "naming",
|
||||
scope = data.scope, -- function, variable, class, file
|
||||
}
|
||||
end
|
||||
|
||||
if event.type == "style_pattern" then
|
||||
return {
|
||||
summary = "Style: " .. (data.name or "unnamed"),
|
||||
detail = data.description or "Code style pattern",
|
||||
rule = data.rule,
|
||||
examples = data.examples,
|
||||
category = "style",
|
||||
lang = data.language,
|
||||
}
|
||||
end
|
||||
|
||||
if event.type == "project_structure" then
|
||||
return {
|
||||
summary = "Structure: " .. (data.pattern or "project layout"),
|
||||
detail = data.description or "Project structure convention",
|
||||
rule = data.rule,
|
||||
category = "structure",
|
||||
paths = data.paths,
|
||||
}
|
||||
end
|
||||
|
||||
if event.type == "config_change" then
|
||||
return {
|
||||
summary = "Config: " .. (data.setting or "setting change"),
|
||||
detail = "Configuration: " .. (data.description or data.setting),
|
||||
before = data.before,
|
||||
after = data.after,
|
||||
category = "config",
|
||||
file = event.file,
|
||||
}
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Check if convention should be learned
|
||||
---@param data table Extracted data
|
||||
---@return boolean
|
||||
function M.should_learn(data)
|
||||
if not data.summary then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Skip very vague conventions
|
||||
if not data.detail or #data.detail < 5 then
|
||||
return false
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Create node from convention data
|
||||
---@param data table Extracted data
|
||||
---@return table Node creation params
|
||||
function M.create_node_params(data)
|
||||
local detail = data.detail or ""
|
||||
|
||||
-- Add examples if available
|
||||
if data.examples and #data.examples > 0 then
|
||||
detail = detail .. "\n\nExamples:"
|
||||
for _, ex in ipairs(data.examples) do
|
||||
detail = detail .. "\n- " .. tostring(ex)
|
||||
end
|
||||
end
|
||||
|
||||
-- Add rule if available
|
||||
if data.rule then
|
||||
detail = detail .. "\n\nRule: " .. tostring(data.rule)
|
||||
end
|
||||
|
||||
return {
|
||||
node_type = types.NODE_TYPES.CONVENTION,
|
||||
content = {
|
||||
s = data.summary:sub(1, 200),
|
||||
d = detail,
|
||||
lang = data.lang,
|
||||
},
|
||||
context = {
|
||||
f = data.file,
|
||||
sym = data.scope and { data.scope } or nil,
|
||||
},
|
||||
opts = {
|
||||
weight = 0.6,
|
||||
source = types.SOURCES.AUTO,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
--- Find related conventions
|
||||
---@param data table Extracted data
|
||||
---@param query_fn function Query function
|
||||
---@return string[] Related node IDs
|
||||
function M.find_related(data, query_fn)
|
||||
local related = {}
|
||||
|
||||
-- Find conventions in same category
|
||||
if data.category then
|
||||
local similar = query_fn({
|
||||
query = data.category,
|
||||
types = { types.NODE_TYPES.CONVENTION },
|
||||
limit = 5,
|
||||
})
|
||||
for _, node in ipairs(similar) do
|
||||
table.insert(related, node.id)
|
||||
end
|
||||
end
|
||||
|
||||
-- Find patterns that follow this convention
|
||||
if data.rule then
|
||||
local patterns = query_fn({
|
||||
query = data.rule,
|
||||
types = { types.NODE_TYPES.PATTERN },
|
||||
limit = 3,
|
||||
})
|
||||
for _, node in ipairs(patterns) do
|
||||
if not vim.tbl_contains(related, node.id) then
|
||||
table.insert(related, node.id)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return related
|
||||
end
|
||||
|
||||
--- Detect naming convention from symbol names
|
||||
---@param symbols string[] Symbol names to analyze
|
||||
---@return table|nil Detected convention
|
||||
function M.detect_naming(symbols)
|
||||
if not symbols or #symbols < 3 then
|
||||
return nil
|
||||
end
|
||||
|
||||
local patterns = {
|
||||
snake_case = 0,
|
||||
camelCase = 0,
|
||||
PascalCase = 0,
|
||||
SCREAMING_SNAKE = 0,
|
||||
kebab_case = 0,
|
||||
}
|
||||
|
||||
for _, sym in ipairs(symbols) do
|
||||
if sym:match("^[a-z][a-z0-9_]*$") then
|
||||
patterns.snake_case = patterns.snake_case + 1
|
||||
elseif sym:match("^[a-z][a-zA-Z0-9]*$") then
|
||||
patterns.camelCase = patterns.camelCase + 1
|
||||
elseif sym:match("^[A-Z][a-zA-Z0-9]*$") then
|
||||
patterns.PascalCase = patterns.PascalCase + 1
|
||||
elseif sym:match("^[A-Z][A-Z0-9_]*$") then
|
||||
patterns.SCREAMING_SNAKE = patterns.SCREAMING_SNAKE + 1
|
||||
elseif sym:match("^[a-z][a-z0-9%-]*$") then
|
||||
patterns.kebab_case = patterns.kebab_case + 1
|
||||
end
|
||||
end
|
||||
|
||||
-- Find dominant pattern
|
||||
local max_count = 0
|
||||
local dominant = nil
|
||||
|
||||
for pattern, count in pairs(patterns) do
|
||||
if count > max_count then
|
||||
max_count = count
|
||||
dominant = pattern
|
||||
end
|
||||
end
|
||||
|
||||
if dominant and max_count >= #symbols * 0.6 then
|
||||
return {
|
||||
pattern = dominant,
|
||||
confidence = max_count / #symbols,
|
||||
sample_size = #symbols,
|
||||
}
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
return M
|
||||
213
lua/codetyper/core/memory/learners/correction.lua
Normal file
213
lua/codetyper/core/memory/learners/correction.lua
Normal file
@@ -0,0 +1,213 @@
|
||||
--- Brain Correction Learner
|
||||
--- Learns from user corrections and edits
|
||||
|
||||
local types = require("codetyper.core.memory.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Detect if event is a correction
|
||||
---@param event LearnEvent Learning event
|
||||
---@return boolean
|
||||
function M.detect(event)
|
||||
local valid_types = {
|
||||
"user_correction",
|
||||
"code_rejected",
|
||||
"code_modified",
|
||||
"suggestion_rejected",
|
||||
}
|
||||
|
||||
for _, t in ipairs(valid_types) do
|
||||
if event.type == t then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Extract correction data from event
|
||||
---@param event LearnEvent Learning event
|
||||
---@return table|nil Extracted data
|
||||
function M.extract(event)
|
||||
local data = event.data or {}
|
||||
|
||||
if event.type == "user_correction" then
|
||||
return {
|
||||
summary = "Correction: " .. (data.error_type or "user edit"),
|
||||
detail = data.description or "User corrected the generated code",
|
||||
before = data.before,
|
||||
after = data.after,
|
||||
error_type = data.error_type,
|
||||
file = event.file,
|
||||
function_name = data.function_name,
|
||||
lines = data.lines,
|
||||
}
|
||||
end
|
||||
|
||||
if event.type == "code_rejected" then
|
||||
return {
|
||||
summary = "Rejected: " .. (data.reason or "not accepted"),
|
||||
detail = data.description or "User rejected generated code",
|
||||
rejected_code = data.code,
|
||||
reason = data.reason,
|
||||
file = event.file,
|
||||
intent = data.intent,
|
||||
}
|
||||
end
|
||||
|
||||
if event.type == "code_modified" then
|
||||
local changes = M.analyze_changes(data.before, data.after)
|
||||
return {
|
||||
summary = "Modified: " .. changes.summary,
|
||||
detail = changes.detail,
|
||||
before = data.before,
|
||||
after = data.after,
|
||||
change_type = changes.type,
|
||||
file = event.file,
|
||||
lines = data.lines,
|
||||
}
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Analyze changes between before/after code
|
||||
---@param before string Before code
|
||||
---@param after string After code
|
||||
---@return table Change analysis
|
||||
function M.analyze_changes(before, after)
|
||||
before = before or ""
|
||||
after = after or ""
|
||||
|
||||
local before_lines = vim.split(before, "\n")
|
||||
local after_lines = vim.split(after, "\n")
|
||||
|
||||
local added = 0
|
||||
local removed = 0
|
||||
local modified = 0
|
||||
|
||||
-- Simple line-based diff
|
||||
local max_lines = math.max(#before_lines, #after_lines)
|
||||
for i = 1, max_lines do
|
||||
local b = before_lines[i]
|
||||
local a = after_lines[i]
|
||||
|
||||
if b == nil and a ~= nil then
|
||||
added = added + 1
|
||||
elseif b ~= nil and a == nil then
|
||||
removed = removed + 1
|
||||
elseif b ~= a then
|
||||
modified = modified + 1
|
||||
end
|
||||
end
|
||||
|
||||
local change_type = "mixed"
|
||||
if added > 0 and removed == 0 and modified == 0 then
|
||||
change_type = "addition"
|
||||
elseif removed > 0 and added == 0 and modified == 0 then
|
||||
change_type = "deletion"
|
||||
elseif modified > 0 and added == 0 and removed == 0 then
|
||||
change_type = "modification"
|
||||
end
|
||||
|
||||
return {
|
||||
type = change_type,
|
||||
summary = string.format("+%d -%d ~%d lines", added, removed, modified),
|
||||
detail = string.format("Added %d, removed %d, modified %d lines", added, removed, modified),
|
||||
stats = {
|
||||
added = added,
|
||||
removed = removed,
|
||||
modified = modified,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
--- Check if correction should be learned
|
||||
---@param data table Extracted data
|
||||
---@return boolean
|
||||
function M.should_learn(data)
|
||||
-- Always learn corrections - they're valuable
|
||||
if not data.summary then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Skip trivial changes
|
||||
if data.before and data.after then
|
||||
-- Skip if only whitespace changed
|
||||
local before_trimmed = data.before:gsub("%s+", "")
|
||||
local after_trimmed = data.after:gsub("%s+", "")
|
||||
if before_trimmed == after_trimmed then
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Create node from correction data
|
||||
---@param data table Extracted data
|
||||
---@return table Node creation params
|
||||
function M.create_node_params(data)
|
||||
local detail = data.detail or ""
|
||||
|
||||
-- Include before/after in detail for learning
|
||||
if data.before and data.after then
|
||||
detail = detail .. "\n\nBefore:\n" .. data.before:sub(1, 500)
|
||||
detail = detail .. "\n\nAfter:\n" .. data.after:sub(1, 500)
|
||||
end
|
||||
|
||||
return {
|
||||
node_type = types.NODE_TYPES.CORRECTION,
|
||||
content = {
|
||||
s = data.summary:sub(1, 200),
|
||||
d = detail,
|
||||
code = data.after or data.rejected_code,
|
||||
lang = data.lang,
|
||||
},
|
||||
context = {
|
||||
f = data.file,
|
||||
fn = data.function_name,
|
||||
ln = data.lines,
|
||||
},
|
||||
opts = {
|
||||
weight = 0.7, -- Corrections are valuable
|
||||
source = types.SOURCES.USER,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
--- Find related nodes for corrections
|
||||
---@param data table Extracted data
|
||||
---@param query_fn function Query function
|
||||
---@return string[] Related node IDs
|
||||
function M.find_related(data, query_fn)
|
||||
local related = {}
|
||||
|
||||
-- Find patterns that might be corrected
|
||||
if data.before then
|
||||
local similar = query_fn({
|
||||
query = data.before:sub(1, 100),
|
||||
types = { types.NODE_TYPES.PATTERN },
|
||||
limit = 3,
|
||||
})
|
||||
for _, node in ipairs(similar) do
|
||||
table.insert(related, node.id)
|
||||
end
|
||||
end
|
||||
|
||||
-- Find other corrections in same file
|
||||
if data.file then
|
||||
local file_corrections = query_fn({
|
||||
file = data.file,
|
||||
types = { types.NODE_TYPES.CORRECTION },
|
||||
limit = 3,
|
||||
})
|
||||
for _, node in ipairs(file_corrections) do
|
||||
table.insert(related, node.id)
|
||||
end
|
||||
end
|
||||
|
||||
return related
|
||||
end
|
||||
|
||||
return M
|
||||
232
lua/codetyper/core/memory/learners/init.lua
Normal file
232
lua/codetyper/core/memory/learners/init.lua
Normal file
@@ -0,0 +1,232 @@
|
||||
--- Brain Learners Coordinator
|
||||
--- Routes learning events to appropriate learners
|
||||
|
||||
local types = require("codetyper.core.memory.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
-- Lazy load learners
|
||||
local function get_pattern_learner()
|
||||
return require("codetyper.core.memory.learners.pattern")
|
||||
end
|
||||
|
||||
local function get_correction_learner()
|
||||
return require("codetyper.core.memory.learners.correction")
|
||||
end
|
||||
|
||||
local function get_convention_learner()
|
||||
return require("codetyper.core.memory.learners.convention")
|
||||
end
|
||||
|
||||
--- All available learners
|
||||
local LEARNERS = {
|
||||
{ name = "pattern", loader = get_pattern_learner },
|
||||
{ name = "correction", loader = get_correction_learner },
|
||||
{ name = "convention", loader = get_convention_learner },
|
||||
}
|
||||
|
||||
--- Process a learning event
|
||||
---@param event LearnEvent Learning event
|
||||
---@return string|nil Created node ID
|
||||
function M.process(event)
|
||||
if not event or not event.type then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Add timestamp if missing
|
||||
event.timestamp = event.timestamp or os.time()
|
||||
|
||||
-- Find matching learner
|
||||
for _, learner_info in ipairs(LEARNERS) do
|
||||
local learner = learner_info.loader()
|
||||
|
||||
if learner.detect(event) then
|
||||
return M.learn_with(learner, event)
|
||||
end
|
||||
end
|
||||
|
||||
-- Handle generic feedback events
|
||||
if event.type == "user_feedback" then
|
||||
return M.process_feedback(event)
|
||||
end
|
||||
|
||||
-- Handle session events
|
||||
if event.type == "session_start" or event.type == "session_end" then
|
||||
return M.process_session(event)
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Learn using a specific learner
|
||||
---@param learner table Learner module
|
||||
---@param event LearnEvent Learning event
|
||||
---@return string|nil Created node ID
|
||||
function M.learn_with(learner, event)
|
||||
-- Extract data
|
||||
local extracted = learner.extract(event)
|
||||
if not extracted then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Handle multiple extractions (e.g., from file indexing)
|
||||
if vim.islist(extracted) then
|
||||
local node_ids = {}
|
||||
for _, data in ipairs(extracted) do
|
||||
local node_id = M.create_learning(learner, data, event)
|
||||
if node_id then
|
||||
table.insert(node_ids, node_id)
|
||||
end
|
||||
end
|
||||
return node_ids[1] -- Return first for now
|
||||
end
|
||||
|
||||
return M.create_learning(learner, extracted, event)
|
||||
end
|
||||
|
||||
--- Create a learning from extracted data
|
||||
---@param learner table Learner module
|
||||
---@param data table Extracted data
|
||||
---@param event LearnEvent Original event
|
||||
---@return string|nil Created node ID
|
||||
function M.create_learning(learner, data, event)
|
||||
-- Check if should learn
|
||||
if not learner.should_learn(data) then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Get node params
|
||||
local params = learner.create_node_params(data)
|
||||
|
||||
-- Get graph module
|
||||
local graph = require("codetyper.core.memory.graph")
|
||||
|
||||
-- Find related nodes
|
||||
local related_ids = {}
|
||||
if learner.find_related then
|
||||
related_ids = learner.find_related(data, function(opts)
|
||||
return graph.query.execute(opts).nodes
|
||||
end)
|
||||
end
|
||||
|
||||
-- Create the learning
|
||||
local node = graph.add_learning(params.node_type, params.content, params.context, related_ids)
|
||||
|
||||
-- Update weight if specified
|
||||
if params.opts and params.opts.weight then
|
||||
graph.node.update(node.id, { sc = { w = params.opts.weight } })
|
||||
end
|
||||
|
||||
return node.id
|
||||
end
|
||||
|
||||
--- Process feedback event
|
||||
---@param event LearnEvent Feedback event
|
||||
---@return string|nil Created node ID
|
||||
function M.process_feedback(event)
|
||||
local data = event.data or {}
|
||||
local graph = require("codetyper.core.memory.graph")
|
||||
|
||||
local content = {
|
||||
s = "Feedback: " .. (data.feedback or "unknown"),
|
||||
d = data.description or ("User " .. (data.feedback or "gave feedback")),
|
||||
}
|
||||
|
||||
local context = {
|
||||
f = event.file,
|
||||
}
|
||||
|
||||
-- If feedback references a node, update it
|
||||
if data.node_id then
|
||||
local node = graph.node.get(data.node_id)
|
||||
if node then
|
||||
local weight_delta = data.feedback == "accepted" and 0.1 or -0.1
|
||||
local new_weight = math.max(0, math.min(1, node.sc.w + weight_delta))
|
||||
|
||||
graph.node.update(data.node_id, {
|
||||
sc = { w = new_weight },
|
||||
})
|
||||
|
||||
-- Record usage
|
||||
graph.node.record_usage(data.node_id, data.feedback == "accepted")
|
||||
|
||||
-- Create feedback node linked to original
|
||||
local fb_node = graph.add_learning(types.NODE_TYPES.FEEDBACK, content, context, { data.node_id })
|
||||
|
||||
return fb_node.id
|
||||
end
|
||||
end
|
||||
|
||||
-- Create standalone feedback node
|
||||
local node = graph.add_learning(types.NODE_TYPES.FEEDBACK, content, context)
|
||||
return node.id
|
||||
end
|
||||
|
||||
--- Process session event
|
||||
---@param event LearnEvent Session event
|
||||
---@return string|nil Created node ID
|
||||
function M.process_session(event)
|
||||
local data = event.data or {}
|
||||
local graph = require("codetyper.core.memory.graph")
|
||||
|
||||
local content = {
|
||||
s = event.type == "session_start" and "Session started" or "Session ended",
|
||||
d = data.description or event.type,
|
||||
}
|
||||
|
||||
if event.type == "session_end" and data.stats then
|
||||
content.d = content.d .. "\n\nStats:"
|
||||
content.d = content.d .. "\n- Completions: " .. (data.stats.completions or 0)
|
||||
content.d = content.d .. "\n- Corrections: " .. (data.stats.corrections or 0)
|
||||
content.d = content.d .. "\n- Files: " .. (data.stats.files or 0)
|
||||
end
|
||||
|
||||
local node = graph.add_learning(types.NODE_TYPES.SESSION, content, {})
|
||||
|
||||
-- Link to recent session nodes
|
||||
if event.type == "session_end" then
|
||||
local recent = graph.query.by_time_range(os.time() - 3600, os.time(), 20) -- Last hour
|
||||
local session_nodes = {}
|
||||
|
||||
for _, n in ipairs(recent) do
|
||||
if n.id ~= node.id then
|
||||
table.insert(session_nodes, n.id)
|
||||
end
|
||||
end
|
||||
|
||||
-- Create temporal links
|
||||
if #session_nodes > 0 then
|
||||
graph.link_temporal(session_nodes)
|
||||
end
|
||||
end
|
||||
|
||||
return node.id
|
||||
end
|
||||
|
||||
--- Batch process multiple events
|
||||
---@param events LearnEvent[] Events to process
|
||||
---@return string[] Created node IDs
|
||||
function M.batch_process(events)
|
||||
local node_ids = {}
|
||||
|
||||
for _, event in ipairs(events) do
|
||||
local node_id = M.process(event)
|
||||
if node_id then
|
||||
table.insert(node_ids, node_id)
|
||||
end
|
||||
end
|
||||
|
||||
return node_ids
|
||||
end
|
||||
|
||||
--- Get learner names
|
||||
---@return string[]
|
||||
function M.get_learner_names()
|
||||
local names = {}
|
||||
for _, learner in ipairs(LEARNERS) do
|
||||
table.insert(names, learner.name)
|
||||
end
|
||||
return names
|
||||
end
|
||||
|
||||
return M
|
||||
176
lua/codetyper/core/memory/learners/pattern.lua
Normal file
176
lua/codetyper/core/memory/learners/pattern.lua
Normal file
@@ -0,0 +1,176 @@
|
||||
--- Brain Pattern Learner
|
||||
--- Detects and learns code patterns
|
||||
|
||||
local types = require("codetyper.core.memory.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Detect if event contains a learnable pattern
|
||||
---@param event LearnEvent Learning event
|
||||
---@return boolean
|
||||
function M.detect(event)
|
||||
if not event or not event.type then
|
||||
return false
|
||||
end
|
||||
|
||||
local valid_types = {
|
||||
"code_completion",
|
||||
"file_indexed",
|
||||
"code_analyzed",
|
||||
"pattern_detected",
|
||||
}
|
||||
|
||||
for _, t in ipairs(valid_types) do
|
||||
if event.type == t then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Extract pattern data from event
|
||||
---@param event LearnEvent Learning event
|
||||
---@return table|nil Extracted data
|
||||
function M.extract(event)
|
||||
local data = event.data or {}
|
||||
|
||||
-- Extract from code completion
|
||||
if event.type == "code_completion" then
|
||||
return {
|
||||
summary = "Code pattern: " .. (data.intent or "unknown"),
|
||||
detail = data.code or data.content or "",
|
||||
code = data.code,
|
||||
lang = data.language,
|
||||
file = event.file,
|
||||
function_name = data.function_name,
|
||||
symbols = data.symbols,
|
||||
}
|
||||
end
|
||||
|
||||
-- Extract from file indexing
|
||||
if event.type == "file_indexed" then
|
||||
local patterns = {}
|
||||
|
||||
-- Extract function patterns
|
||||
if data.functions then
|
||||
for _, func in ipairs(data.functions) do
|
||||
table.insert(patterns, {
|
||||
summary = "Function: " .. func.name,
|
||||
detail = func.signature or func.name,
|
||||
code = func.body,
|
||||
lang = data.language,
|
||||
file = event.file,
|
||||
function_name = func.name,
|
||||
lines = func.lines,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Extract class patterns
|
||||
if data.classes then
|
||||
for _, class in ipairs(data.classes) do
|
||||
table.insert(patterns, {
|
||||
summary = "Class: " .. class.name,
|
||||
detail = class.description or class.name,
|
||||
lang = data.language,
|
||||
file = event.file,
|
||||
symbols = { class.name },
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return #patterns > 0 and patterns or nil
|
||||
end
|
||||
|
||||
-- Extract from explicit pattern detection
|
||||
if event.type == "pattern_detected" then
|
||||
return {
|
||||
summary = data.name or "Unnamed pattern",
|
||||
detail = data.description or data.name or "",
|
||||
code = data.example,
|
||||
lang = data.language,
|
||||
file = event.file,
|
||||
symbols = data.symbols,
|
||||
}
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Check if pattern should be learned
|
||||
---@param data table Extracted data
|
||||
---@return boolean
|
||||
function M.should_learn(data)
|
||||
-- Skip if no meaningful content
|
||||
if not data.summary or data.summary == "" then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Skip very short patterns
|
||||
if data.detail and #data.detail < 10 then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Skip auto-generated patterns
|
||||
if data.summary:match("^%s*$") then
|
||||
return false
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Create node from pattern data
|
||||
---@param data table Extracted data
|
||||
---@return table Node creation params
|
||||
function M.create_node_params(data)
|
||||
return {
|
||||
node_type = types.NODE_TYPES.PATTERN,
|
||||
content = {
|
||||
s = data.summary:sub(1, 200), -- Limit summary
|
||||
d = data.detail,
|
||||
code = data.code,
|
||||
lang = data.lang,
|
||||
},
|
||||
context = {
|
||||
f = data.file,
|
||||
fn = data.function_name,
|
||||
ln = data.lines,
|
||||
sym = data.symbols,
|
||||
},
|
||||
opts = {
|
||||
weight = 0.5,
|
||||
source = types.SOURCES.AUTO,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
--- Find potentially related nodes
|
||||
---@param data table Extracted data
|
||||
---@param query_fn function Query function
|
||||
---@return string[] Related node IDs
|
||||
function M.find_related(data, query_fn)
|
||||
local related = {}
|
||||
|
||||
-- Find nodes in same file
|
||||
if data.file then
|
||||
local file_nodes = query_fn({ file = data.file, limit = 5 })
|
||||
for _, node in ipairs(file_nodes) do
|
||||
table.insert(related, node.id)
|
||||
end
|
||||
end
|
||||
|
||||
-- Find semantically similar
|
||||
if data.summary then
|
||||
local similar = query_fn({ query = data.summary, limit = 3 })
|
||||
for _, node in ipairs(similar) do
|
||||
if not vim.tbl_contains(related, node.id) then
|
||||
table.insert(related, node.id)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return related
|
||||
end
|
||||
|
||||
return M
|
||||
279
lua/codetyper/core/memory/output/formatter.lua
Normal file
279
lua/codetyper/core/memory/output/formatter.lua
Normal file
@@ -0,0 +1,279 @@
|
||||
--- Brain Output Formatter
|
||||
--- LLM-optimized output formatting
|
||||
|
||||
local types = require("codetyper.core.memory.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Estimate token count (rough approximation)
|
||||
---@param text string Text to estimate
|
||||
---@return number Estimated tokens
|
||||
function M.estimate_tokens(text)
|
||||
if not text then
|
||||
return 0
|
||||
end
|
||||
-- Rough estimate: 1 token ~= 4 characters
|
||||
return math.ceil(#text / 4)
|
||||
end
|
||||
|
||||
--- Format nodes to compact text format
|
||||
---@param result QueryResult Query result
|
||||
---@param opts? table Options
|
||||
---@return string Formatted output
|
||||
function M.to_compact(result, opts)
|
||||
opts = opts or {}
|
||||
local max_tokens = opts.max_tokens or 4000
|
||||
local lines = {}
|
||||
local current_tokens = 0
|
||||
|
||||
-- Header
|
||||
table.insert(lines, "---BRAIN_CONTEXT---")
|
||||
if opts.query then
|
||||
table.insert(lines, "Q: " .. opts.query)
|
||||
end
|
||||
table.insert(lines, "")
|
||||
|
||||
-- Add nodes by relevance (already sorted)
|
||||
table.insert(lines, "Learnings:")
|
||||
|
||||
for i, node in ipairs(result.nodes) do
|
||||
-- Format: [idx] TYPE | w:0.85 u:5 | Summary
|
||||
local line = string.format(
|
||||
"[%d] %s | w:%.2f u:%d | %s",
|
||||
i,
|
||||
(node.t or "?"):upper(),
|
||||
node.sc.w or 0,
|
||||
node.sc.u or 0,
|
||||
(node.c.s or ""):sub(1, 100)
|
||||
)
|
||||
|
||||
local line_tokens = M.estimate_tokens(line)
|
||||
if current_tokens + line_tokens > max_tokens - 100 then
|
||||
table.insert(lines, "... (truncated)")
|
||||
break
|
||||
end
|
||||
|
||||
table.insert(lines, line)
|
||||
current_tokens = current_tokens + line_tokens
|
||||
|
||||
-- Add context if file-related
|
||||
if node.ctx and node.ctx.f then
|
||||
local ctx_line = " @ " .. node.ctx.f
|
||||
if node.ctx.fn then
|
||||
ctx_line = ctx_line .. ":" .. node.ctx.fn
|
||||
end
|
||||
if node.ctx.ln then
|
||||
ctx_line = ctx_line .. " L" .. node.ctx.ln[1]
|
||||
end
|
||||
table.insert(lines, ctx_line)
|
||||
current_tokens = current_tokens + M.estimate_tokens(ctx_line)
|
||||
end
|
||||
end
|
||||
|
||||
-- Add connections if space allows
|
||||
if #result.edges > 0 and current_tokens < max_tokens - 200 then
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, "Connections:")
|
||||
|
||||
for _, edge in ipairs(result.edges) do
|
||||
if current_tokens >= max_tokens - 50 then
|
||||
break
|
||||
end
|
||||
|
||||
local conn_line = string.format(
|
||||
" %s --%s(%.2f)--> %s",
|
||||
edge.s:sub(-8),
|
||||
edge.ty,
|
||||
edge.p.w or 0.5,
|
||||
edge.t:sub(-8)
|
||||
)
|
||||
table.insert(lines, conn_line)
|
||||
current_tokens = current_tokens + M.estimate_tokens(conn_line)
|
||||
end
|
||||
end
|
||||
|
||||
table.insert(lines, "---END_CONTEXT---")
|
||||
|
||||
return table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
--- Format nodes to JSON format
|
||||
---@param result QueryResult Query result
|
||||
---@param opts? table Options
|
||||
---@return string JSON output
|
||||
function M.to_json(result, opts)
|
||||
opts = opts or {}
|
||||
local max_tokens = opts.max_tokens or 4000
|
||||
|
||||
local output = {
|
||||
_s = "brain-v1", -- Schema
|
||||
q = opts.query,
|
||||
l = {}, -- Learnings
|
||||
c = {}, -- Connections
|
||||
}
|
||||
|
||||
local current_tokens = 50 -- Base overhead
|
||||
|
||||
-- Add nodes
|
||||
for _, node in ipairs(result.nodes) do
|
||||
local entry = {
|
||||
t = node.t,
|
||||
s = (node.c.s or ""):sub(1, 150),
|
||||
w = node.sc.w,
|
||||
u = node.sc.u,
|
||||
}
|
||||
|
||||
if node.ctx and node.ctx.f then
|
||||
entry.f = node.ctx.f
|
||||
end
|
||||
|
||||
local entry_tokens = M.estimate_tokens(vim.json.encode(entry))
|
||||
if current_tokens + entry_tokens > max_tokens - 100 then
|
||||
break
|
||||
end
|
||||
|
||||
table.insert(output.l, entry)
|
||||
current_tokens = current_tokens + entry_tokens
|
||||
end
|
||||
|
||||
-- Add edges if space
|
||||
if current_tokens < max_tokens - 200 then
|
||||
for _, edge in ipairs(result.edges) do
|
||||
if current_tokens >= max_tokens - 50 then
|
||||
break
|
||||
end
|
||||
|
||||
local e = {
|
||||
s = edge.s:sub(-8),
|
||||
t = edge.t:sub(-8),
|
||||
r = edge.ty,
|
||||
w = edge.p.w,
|
||||
}
|
||||
|
||||
table.insert(output.c, e)
|
||||
current_tokens = current_tokens + 30
|
||||
end
|
||||
end
|
||||
|
||||
return vim.json.encode(output)
|
||||
end
|
||||
|
||||
--- Format as natural language
|
||||
---@param result QueryResult Query result
|
||||
---@param opts? table Options
|
||||
---@return string Natural language output
|
||||
function M.to_natural(result, opts)
|
||||
opts = opts or {}
|
||||
local max_tokens = opts.max_tokens or 4000
|
||||
local lines = {}
|
||||
local current_tokens = 0
|
||||
|
||||
if #result.nodes == 0 then
|
||||
return "No relevant learnings found."
|
||||
end
|
||||
|
||||
table.insert(lines, "Based on previous learnings:")
|
||||
table.insert(lines, "")
|
||||
|
||||
-- Group by type
|
||||
local by_type = {}
|
||||
for _, node in ipairs(result.nodes) do
|
||||
by_type[node.t] = by_type[node.t] or {}
|
||||
table.insert(by_type[node.t], node)
|
||||
end
|
||||
|
||||
local type_names = {
|
||||
[types.NODE_TYPES.PATTERN] = "Code Patterns",
|
||||
[types.NODE_TYPES.CORRECTION] = "Previous Corrections",
|
||||
[types.NODE_TYPES.CONVENTION] = "Project Conventions",
|
||||
[types.NODE_TYPES.DECISION] = "Architectural Decisions",
|
||||
[types.NODE_TYPES.FEEDBACK] = "User Preferences",
|
||||
[types.NODE_TYPES.SESSION] = "Session Context",
|
||||
}
|
||||
|
||||
for node_type, nodes in pairs(by_type) do
|
||||
local type_name = type_names[node_type] or node_type
|
||||
|
||||
table.insert(lines, "**" .. type_name .. "**")
|
||||
|
||||
for _, node in ipairs(nodes) do
|
||||
if current_tokens >= max_tokens - 100 then
|
||||
table.insert(lines, "...")
|
||||
goto done
|
||||
end
|
||||
|
||||
local bullet = string.format("- %s (confidence: %.0f%%)", node.c.s or "?", (node.sc.w or 0) * 100)
|
||||
|
||||
table.insert(lines, bullet)
|
||||
current_tokens = current_tokens + M.estimate_tokens(bullet)
|
||||
|
||||
-- Add detail if high weight
|
||||
if node.sc.w > 0.7 and node.c.d and #node.c.d > #(node.c.s or "") then
|
||||
local detail = " " .. node.c.d:sub(1, 150)
|
||||
if #node.c.d > 150 then
|
||||
detail = detail .. "..."
|
||||
end
|
||||
table.insert(lines, detail)
|
||||
current_tokens = current_tokens + M.estimate_tokens(detail)
|
||||
end
|
||||
end
|
||||
|
||||
table.insert(lines, "")
|
||||
end
|
||||
|
||||
::done::
|
||||
|
||||
return table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
--- Format context chain for explanation
|
||||
---@param chain table[] Chain of nodes and edges
|
||||
---@return string Chain explanation
|
||||
function M.format_chain(chain)
|
||||
local lines = {}
|
||||
|
||||
for i, item in ipairs(chain) do
|
||||
if item.node then
|
||||
local prefix = i == 1 and "" or " -> "
|
||||
table.insert(lines, string.format("%s[%s] %s (w:%.2f)", prefix, item.node.t:upper(), item.node.c.s:sub(1, 50), item.node.sc.w))
|
||||
end
|
||||
if item.edge then
|
||||
table.insert(lines, string.format(" via %s (w:%.2f)", item.edge.ty, item.edge.p.w))
|
||||
end
|
||||
end
|
||||
|
||||
return table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
--- Compress output to fit token budget
|
||||
---@param text string Text to compress
|
||||
---@param max_tokens number Token budget
|
||||
---@return string Compressed text
|
||||
function M.compress(text, max_tokens)
|
||||
local current = M.estimate_tokens(text)
|
||||
|
||||
if current <= max_tokens then
|
||||
return text
|
||||
end
|
||||
|
||||
-- Simple truncation with ellipsis
|
||||
local ratio = max_tokens / current
|
||||
local target_chars = math.floor(#text * ratio * 0.9) -- 10% buffer
|
||||
|
||||
return text:sub(1, target_chars) .. "\n...(truncated)"
|
||||
end
|
||||
|
||||
--- Get minimal context for quick lookups
|
||||
---@param nodes Node[] Nodes to format
|
||||
---@return string Minimal context
|
||||
function M.minimal(nodes)
|
||||
local items = {}
|
||||
|
||||
for _, node in ipairs(nodes) do
|
||||
table.insert(items, string.format("%s:%s", node.t, (node.c.s or ""):sub(1, 40)))
|
||||
end
|
||||
|
||||
return table.concat(items, " | ")
|
||||
end
|
||||
|
||||
return M
|
||||
166
lua/codetyper/core/memory/output/init.lua
Normal file
166
lua/codetyper/core/memory/output/init.lua
Normal file
@@ -0,0 +1,166 @@
|
||||
--- Brain Output Coordinator
|
||||
--- Manages LLM context generation
|
||||
|
||||
local formatter = require("codetyper.core.memory.output.formatter")
|
||||
|
||||
local M = {}
|
||||
|
||||
-- Re-export formatter
|
||||
M.formatter = formatter
|
||||
|
||||
--- Default token budget
|
||||
local DEFAULT_MAX_TOKENS = 4000
|
||||
|
||||
--- Generate context for LLM prompt
|
||||
---@param opts? table Options
|
||||
---@return string Context string
|
||||
function M.generate(opts)
|
||||
opts = opts or {}
|
||||
|
||||
local brain = require("codetyper.core.memory")
|
||||
if not brain.is_initialized() then
|
||||
return ""
|
||||
end
|
||||
|
||||
-- Build query opts
|
||||
local query_opts = {
|
||||
query = opts.query,
|
||||
file = opts.file,
|
||||
types = opts.types,
|
||||
since = opts.since,
|
||||
limit = opts.limit or 30,
|
||||
depth = opts.depth or 2,
|
||||
max_tokens = opts.max_tokens or DEFAULT_MAX_TOKENS,
|
||||
}
|
||||
|
||||
-- Execute query
|
||||
local result = brain.query(query_opts)
|
||||
|
||||
if #result.nodes == 0 then
|
||||
return ""
|
||||
end
|
||||
|
||||
-- Format based on style
|
||||
local format = opts.format or "compact"
|
||||
|
||||
if format == "json" then
|
||||
return formatter.to_json(result, query_opts)
|
||||
elseif format == "natural" then
|
||||
return formatter.to_natural(result, query_opts)
|
||||
else
|
||||
return formatter.to_compact(result, query_opts)
|
||||
end
|
||||
end
|
||||
|
||||
--- Generate context for a specific file
|
||||
---@param filepath string File path
|
||||
---@param opts? table Options
|
||||
---@return string Context string
|
||||
function M.for_file(filepath, opts)
|
||||
opts = opts or {}
|
||||
opts.file = filepath
|
||||
return M.generate(opts)
|
||||
end
|
||||
|
||||
--- Generate context for current buffer
|
||||
---@param opts? table Options
|
||||
---@return string Context string
|
||||
function M.for_current_buffer(opts)
|
||||
local filepath = vim.fn.expand("%:p")
|
||||
if filepath == "" then
|
||||
return ""
|
||||
end
|
||||
return M.for_file(filepath, opts)
|
||||
end
|
||||
|
||||
--- Generate context for a query/prompt
|
||||
---@param query string Query text
|
||||
---@param opts? table Options
|
||||
---@return string Context string
|
||||
function M.for_query(query, opts)
|
||||
opts = opts or {}
|
||||
opts.query = query
|
||||
return M.generate(opts)
|
||||
end
|
||||
|
||||
--- Get context for LLM system prompt
|
||||
---@param opts? table Options
|
||||
---@return string System context
|
||||
function M.system_context(opts)
|
||||
opts = opts or {}
|
||||
opts.limit = opts.limit or 20
|
||||
opts.format = opts.format or "compact"
|
||||
|
||||
local context = M.generate(opts)
|
||||
|
||||
if context == "" then
|
||||
return ""
|
||||
end
|
||||
|
||||
return [[
|
||||
The following context contains learned patterns and conventions from this project:
|
||||
|
||||
]] .. context .. [[
|
||||
|
||||
|
||||
Use this context to inform your responses, following established patterns and conventions.
|
||||
]]
|
||||
end
|
||||
|
||||
--- Get relevant context for code completion
|
||||
---@param prefix string Code before cursor
|
||||
---@param suffix string Code after cursor
|
||||
---@param filepath string Current file
|
||||
---@return string Context
|
||||
function M.for_completion(prefix, suffix, filepath)
|
||||
-- Extract relevant terms from code
|
||||
local terms = {}
|
||||
|
||||
-- Get function/class names
|
||||
for word in prefix:gmatch("[A-Z][a-zA-Z0-9]+") do
|
||||
table.insert(terms, word)
|
||||
end
|
||||
for word in prefix:gmatch("function%s+([a-zA-Z_][a-zA-Z0-9_]*)") do
|
||||
table.insert(terms, word)
|
||||
end
|
||||
|
||||
local query = table.concat(terms, " ")
|
||||
|
||||
return M.generate({
|
||||
query = query,
|
||||
file = filepath,
|
||||
limit = 15,
|
||||
max_tokens = 2000,
|
||||
format = "compact",
|
||||
})
|
||||
end
|
||||
|
||||
--- Check if context is available
|
||||
---@return boolean
|
||||
function M.has_context()
|
||||
local brain = require("codetyper.core.memory")
|
||||
if not brain.is_initialized() then
|
||||
return false
|
||||
end
|
||||
|
||||
local stats = brain.stats()
|
||||
return stats.node_count > 0
|
||||
end
|
||||
|
||||
--- Get context stats
|
||||
---@return table Stats
|
||||
function M.stats()
|
||||
local brain = require("codetyper.core.memory")
|
||||
if not brain.is_initialized() then
|
||||
return { available = false }
|
||||
end
|
||||
|
||||
local stats = brain.stats()
|
||||
return {
|
||||
available = true,
|
||||
node_count = stats.node_count,
|
||||
edge_count = stats.edge_count,
|
||||
}
|
||||
end
|
||||
|
||||
return M
|
||||
338
lua/codetyper/core/memory/storage.lua
Normal file
338
lua/codetyper/core/memory/storage.lua
Normal file
@@ -0,0 +1,338 @@
|
||||
--- Brain Storage Layer
|
||||
--- Cache + disk persistence with lazy loading
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
local types = require("codetyper.core.memory.types")
|
||||
|
||||
local M = {}
|
||||
|
||||
--- In-memory cache keyed by project root
|
||||
---@type table<string, table>
|
||||
local cache = {}
|
||||
|
||||
--- Dirty flags for pending writes
|
||||
---@type table<string, table<string, boolean>>
|
||||
local dirty = {}
|
||||
|
||||
--- Debounce timers
|
||||
---@type table<string, userdata>
|
||||
local timers = {}
|
||||
|
||||
local DEBOUNCE_MS = 500
|
||||
|
||||
--- Get brain directory path for current project
|
||||
---@param root? string Project root (defaults to current)
|
||||
---@return string Brain directory path
|
||||
function M.get_brain_dir(root)
|
||||
root = root or utils.get_project_root()
|
||||
return root .. "/.coder/brain"
|
||||
end
|
||||
|
||||
--- Ensure brain directory structure exists
|
||||
---@param root? string Project root
|
||||
---@return boolean Success
|
||||
function M.ensure_dirs(root)
|
||||
local brain_dir = M.get_brain_dir(root)
|
||||
local dirs = {
|
||||
brain_dir,
|
||||
brain_dir .. "/nodes",
|
||||
brain_dir .. "/indices",
|
||||
brain_dir .. "/deltas",
|
||||
brain_dir .. "/deltas/objects",
|
||||
}
|
||||
for _, dir in ipairs(dirs) do
|
||||
if not utils.ensure_dir(dir) then
|
||||
return false
|
||||
end
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
--- Get file path for a storage key
|
||||
---@param key string Storage key (e.g., "meta", "nodes.patterns", "deltas.objects.abc123")
|
||||
---@param root? string Project root
|
||||
---@return string File path
|
||||
function M.get_path(key, root)
|
||||
local brain_dir = M.get_brain_dir(root)
|
||||
local parts = vim.split(key, ".", { plain = true })
|
||||
|
||||
if #parts == 1 then
|
||||
return brain_dir .. "/" .. key .. ".json"
|
||||
elseif #parts == 2 then
|
||||
return brain_dir .. "/" .. parts[1] .. "/" .. parts[2] .. ".json"
|
||||
else
|
||||
return brain_dir .. "/" .. table.concat(parts, "/") .. ".json"
|
||||
end
|
||||
end
|
||||
|
||||
--- Get cache for project
|
||||
---@param root? string Project root
|
||||
---@return table Project cache
|
||||
local function get_cache(root)
|
||||
root = root or utils.get_project_root()
|
||||
if not cache[root] then
|
||||
cache[root] = {}
|
||||
dirty[root] = {}
|
||||
end
|
||||
return cache[root]
|
||||
end
|
||||
|
||||
--- Read JSON from disk
|
||||
---@param filepath string File path
|
||||
---@return table|nil Data or nil on error
|
||||
local function read_json(filepath)
|
||||
local content = utils.read_file(filepath)
|
||||
if not content or content == "" then
|
||||
return nil
|
||||
end
|
||||
local ok, data = pcall(vim.json.decode, content)
|
||||
if not ok then
|
||||
return nil
|
||||
end
|
||||
return data
|
||||
end
|
||||
|
||||
--- Write JSON to disk
|
||||
---@param filepath string File path
|
||||
---@param data table Data to write
|
||||
---@return boolean Success
|
||||
local function write_json(filepath, data)
|
||||
local ok, json = pcall(vim.json.encode, data)
|
||||
if not ok then
|
||||
return false
|
||||
end
|
||||
return utils.write_file(filepath, json)
|
||||
end
|
||||
|
||||
--- Load data from disk into cache
|
||||
---@param key string Storage key
|
||||
---@param root? string Project root
|
||||
---@return table|nil Data or nil
|
||||
function M.load(key, root)
|
||||
root = root or utils.get_project_root()
|
||||
local project_cache = get_cache(root)
|
||||
|
||||
-- Return cached if available
|
||||
if project_cache[key] ~= nil then
|
||||
return project_cache[key]
|
||||
end
|
||||
|
||||
-- Load from disk
|
||||
local filepath = M.get_path(key, root)
|
||||
local data = read_json(filepath)
|
||||
|
||||
-- Cache the result (even nil to avoid repeated reads)
|
||||
project_cache[key] = data or {}
|
||||
|
||||
return project_cache[key]
|
||||
end
|
||||
|
||||
--- Save data to cache and schedule disk write
|
||||
---@param key string Storage key
|
||||
---@param data table Data to save
|
||||
---@param root? string Project root
|
||||
---@param immediate? boolean Skip debounce
|
||||
function M.save(key, data, root, immediate)
|
||||
root = root or utils.get_project_root()
|
||||
local project_cache = get_cache(root)
|
||||
|
||||
-- Update cache
|
||||
project_cache[key] = data
|
||||
dirty[root][key] = true
|
||||
|
||||
if immediate then
|
||||
M.flush(key, root)
|
||||
return
|
||||
end
|
||||
|
||||
-- Debounced write
|
||||
local timer_key = root .. ":" .. key
|
||||
if timers[timer_key] then
|
||||
timers[timer_key]:stop()
|
||||
end
|
||||
|
||||
timers[timer_key] = vim.defer_fn(function()
|
||||
M.flush(key, root)
|
||||
timers[timer_key] = nil
|
||||
end, DEBOUNCE_MS)
|
||||
end
|
||||
|
||||
--- Flush a key to disk immediately
|
||||
---@param key string Storage key
|
||||
---@param root? string Project root
|
||||
---@return boolean Success
|
||||
function M.flush(key, root)
|
||||
root = root or utils.get_project_root()
|
||||
local project_cache = get_cache(root)
|
||||
|
||||
if not dirty[root][key] then
|
||||
return true
|
||||
end
|
||||
|
||||
M.ensure_dirs(root)
|
||||
local filepath = M.get_path(key, root)
|
||||
local data = project_cache[key]
|
||||
|
||||
if data == nil then
|
||||
-- Delete file if data is nil
|
||||
os.remove(filepath)
|
||||
dirty[root][key] = nil
|
||||
return true
|
||||
end
|
||||
|
||||
local success = write_json(filepath, data)
|
||||
if success then
|
||||
dirty[root][key] = nil
|
||||
end
|
||||
return success
|
||||
end
|
||||
|
||||
--- Flush all dirty keys to disk
|
||||
---@param root? string Project root
|
||||
function M.flush_all(root)
|
||||
root = root or utils.get_project_root()
|
||||
if not dirty[root] then
|
||||
return
|
||||
end
|
||||
|
||||
for key, is_dirty in pairs(dirty[root]) do
|
||||
if is_dirty then
|
||||
M.flush(key, root)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Get meta.json data
|
||||
---@param root? string Project root
|
||||
---@return GraphMeta
|
||||
function M.get_meta(root)
|
||||
local meta = M.load("meta", root)
|
||||
if not meta or not meta.v then
|
||||
meta = {
|
||||
v = types.SCHEMA_VERSION,
|
||||
head = nil,
|
||||
nc = 0,
|
||||
ec = 0,
|
||||
dc = 0,
|
||||
}
|
||||
M.save("meta", meta, root)
|
||||
end
|
||||
return meta
|
||||
end
|
||||
|
||||
--- Update meta.json
|
||||
---@param updates table Partial updates
|
||||
---@param root? string Project root
|
||||
function M.update_meta(updates, root)
|
||||
local meta = M.get_meta(root)
|
||||
for k, v in pairs(updates) do
|
||||
meta[k] = v
|
||||
end
|
||||
M.save("meta", meta, root)
|
||||
end
|
||||
|
||||
--- Get nodes by type
|
||||
---@param node_type string Node type (e.g., "patterns", "corrections")
|
||||
---@param root? string Project root
|
||||
---@return table<string, Node> Nodes indexed by ID
|
||||
function M.get_nodes(node_type, root)
|
||||
return M.load("nodes." .. node_type, root) or {}
|
||||
end
|
||||
|
||||
--- Save nodes by type
|
||||
---@param node_type string Node type
|
||||
---@param nodes table<string, Node> Nodes indexed by ID
|
||||
---@param root? string Project root
|
||||
function M.save_nodes(node_type, nodes, root)
|
||||
M.save("nodes." .. node_type, nodes, root)
|
||||
end
|
||||
|
||||
--- Get graph adjacency
|
||||
---@param root? string Project root
|
||||
---@return Graph Graph data
|
||||
function M.get_graph(root)
|
||||
local graph = M.load("graph", root)
|
||||
if not graph or not graph.adj then
|
||||
graph = {
|
||||
adj = {},
|
||||
radj = {},
|
||||
}
|
||||
M.save("graph", graph, root)
|
||||
end
|
||||
return graph
|
||||
end
|
||||
|
||||
--- Save graph
|
||||
---@param graph Graph Graph data
|
||||
---@param root? string Project root
|
||||
function M.save_graph(graph, root)
|
||||
M.save("graph", graph, root)
|
||||
end
|
||||
|
||||
--- Get index by type
|
||||
---@param index_type string Index type (e.g., "by_file", "by_time")
|
||||
---@param root? string Project root
|
||||
---@return table Index data
|
||||
function M.get_index(index_type, root)
|
||||
return M.load("indices." .. index_type, root) or {}
|
||||
end
|
||||
|
||||
--- Save index
|
||||
---@param index_type string Index type
|
||||
---@param data table Index data
|
||||
---@param root? string Project root
|
||||
function M.save_index(index_type, data, root)
|
||||
M.save("indices." .. index_type, data, root)
|
||||
end
|
||||
|
||||
--- Get delta by hash
|
||||
---@param hash string Delta hash
|
||||
---@param root? string Project root
|
||||
---@return Delta|nil Delta data
|
||||
function M.get_delta(hash, root)
|
||||
return M.load("deltas.objects." .. hash, root)
|
||||
end
|
||||
|
||||
--- Save delta
|
||||
---@param delta Delta Delta data
|
||||
---@param root? string Project root
|
||||
function M.save_delta(delta, root)
|
||||
M.save("deltas.objects." .. delta.h, delta, root, true) -- Immediate write for deltas
|
||||
end
|
||||
|
||||
--- Get HEAD delta hash
|
||||
---@param root? string Project root
|
||||
---@return string|nil HEAD hash
|
||||
function M.get_head(root)
|
||||
local meta = M.get_meta(root)
|
||||
return meta.head
|
||||
end
|
||||
|
||||
--- Set HEAD delta hash
|
||||
---@param hash string|nil Delta hash
|
||||
---@param root? string Project root
|
||||
function M.set_head(hash, root)
|
||||
M.update_meta({ head = hash }, root)
|
||||
end
|
||||
|
||||
--- Clear all caches (for testing)
|
||||
function M.clear_cache()
|
||||
cache = {}
|
||||
dirty = {}
|
||||
for _, timer in pairs(timers) do
|
||||
if timer then
|
||||
timer:stop()
|
||||
end
|
||||
end
|
||||
timers = {}
|
||||
end
|
||||
|
||||
--- Check if brain exists for project
|
||||
---@param root? string Project root
|
||||
---@return boolean
|
||||
function M.exists(root)
|
||||
local brain_dir = M.get_brain_dir(root)
|
||||
return vim.fn.isdirectory(brain_dir) == 1
|
||||
end
|
||||
|
||||
return M
|
||||
175
lua/codetyper/core/memory/types.lua
Normal file
175
lua/codetyper/core/memory/types.lua
Normal file
@@ -0,0 +1,175 @@
|
||||
---@meta
|
||||
--- Brain Learning System Type Definitions
|
||||
--- Optimized for LLM consumption with compact field names
|
||||
|
||||
local M = {}
|
||||
|
||||
---@alias NodeType "pat"|"cor"|"dec"|"con"|"fbk"|"ses"
|
||||
-- pat = pattern, cor = correction, dec = decision
|
||||
-- con = convention, fbk = feedback, ses = session
|
||||
|
||||
---@alias EdgeType "sem"|"file"|"temp"|"caus"|"sup"
|
||||
-- sem = semantic, file = file-based, temp = temporal
|
||||
-- caus = causal, sup = supersedes
|
||||
|
||||
---@alias DeltaOp "add"|"mod"|"del"
|
||||
|
||||
---@class NodeContent
|
||||
---@field s string Summary (max 200 chars)
|
||||
---@field d string Detail (full description)
|
||||
---@field code? string Optional code snippet
|
||||
---@field lang? string Language identifier
|
||||
|
||||
---@class NodeContext
|
||||
---@field f? string File path (relative)
|
||||
---@field fn? string Function name
|
||||
---@field ln? number[] Line range [start, end]
|
||||
---@field sym? string[] Symbol references
|
||||
|
||||
---@class NodeScores
|
||||
---@field w number Weight (0-1)
|
||||
---@field u number Usage count
|
||||
---@field sr number Success rate (0-1)
|
||||
|
||||
---@class NodeTimestamps
|
||||
---@field cr number Created (unix timestamp)
|
||||
---@field up number Updated (unix timestamp)
|
||||
---@field lu? number Last used (unix timestamp)
|
||||
|
||||
---@class NodeMeta
|
||||
---@field src "auto"|"user"|"llm" Source of learning
|
||||
---@field v number Version number
|
||||
---@field dr? string[] Delta references
|
||||
|
||||
---@class Node
|
||||
---@field id string Unique identifier (n_<timestamp>_<hash>)
|
||||
---@field t NodeType Node type
|
||||
---@field h string Content hash (8 chars)
|
||||
---@field c NodeContent Content
|
||||
---@field ctx NodeContext Context
|
||||
---@field sc NodeScores Scores
|
||||
---@field ts NodeTimestamps Timestamps
|
||||
---@field m? NodeMeta Metadata
|
||||
|
||||
---@class EdgeProps
|
||||
---@field w number Weight (0-1)
|
||||
---@field dir "bi"|"fwd"|"bwd" Direction
|
||||
---@field r? string Reason/description
|
||||
|
||||
---@class Edge
|
||||
---@field id string Unique identifier (e_<source>_<target>)
|
||||
---@field s string Source node ID
|
||||
---@field t string Target node ID
|
||||
---@field ty EdgeType Edge type
|
||||
---@field p EdgeProps Properties
|
||||
---@field ts number Created timestamp
|
||||
|
||||
---@class DeltaChange
|
||||
---@field op DeltaOp Operation type
|
||||
---@field path string JSON path (e.g., "nodes.pat.n_123")
|
||||
---@field bh? string Before hash
|
||||
---@field ah? string After hash
|
||||
---@field diff? table Field-level diff
|
||||
|
||||
---@class DeltaMeta
|
||||
---@field msg string Commit message
|
||||
---@field trig string Trigger source
|
||||
---@field sid? string Session ID
|
||||
|
||||
---@class Delta
|
||||
---@field h string Hash (8 chars)
|
||||
---@field p? string Parent hash
|
||||
---@field ts number Timestamp
|
||||
---@field ch DeltaChange[] Changes
|
||||
---@field m DeltaMeta Metadata
|
||||
|
||||
---@class GraphMeta
|
||||
---@field v number Schema version
|
||||
---@field head? string Current HEAD delta hash
|
||||
---@field nc number Node count
|
||||
---@field ec number Edge count
|
||||
---@field dc number Delta count
|
||||
|
||||
---@class AdjacencyEntry
|
||||
---@field sem? string[] Semantic edges
|
||||
---@field file? string[] File edges
|
||||
---@field temp? string[] Temporal edges
|
||||
---@field caus? string[] Causal edges
|
||||
---@field sup? string[] Supersedes edges
|
||||
|
||||
---@class Graph
|
||||
---@field meta GraphMeta Metadata
|
||||
---@field adj table<string, AdjacencyEntry> Adjacency list
|
||||
---@field radj table<string, AdjacencyEntry> Reverse adjacency
|
||||
|
||||
---@class QueryOpts
|
||||
---@field query? string Text query
|
||||
---@field file? string File path filter
|
||||
---@field types? NodeType[] Node types to include
|
||||
---@field since? number Timestamp filter
|
||||
---@field limit? number Max results
|
||||
---@field depth? number Traversal depth
|
||||
---@field max_tokens? number Token budget
|
||||
|
||||
---@class QueryResult
|
||||
---@field nodes Node[] Matched nodes
|
||||
---@field edges Edge[] Related edges
|
||||
---@field stats table Query statistics
|
||||
---@field truncated boolean Whether results were truncated
|
||||
|
||||
---@class LLMContext
|
||||
---@field schema string Schema version
|
||||
---@field query string Original query
|
||||
---@field learnings table[] Compact learning entries
|
||||
---@field connections table[] Connection summaries
|
||||
---@field tokens number Estimated token count
|
||||
|
||||
---@class LearnEvent
|
||||
---@field type string Event type
|
||||
---@field data table Event data
|
||||
---@field file? string Related file
|
||||
---@field timestamp number Event timestamp
|
||||
|
||||
---@class BrainConfig
|
||||
---@field enabled boolean Enable brain system
|
||||
---@field auto_learn boolean Auto-learn from events
|
||||
---@field auto_commit boolean Auto-commit after threshold
|
||||
---@field commit_threshold number Changes before auto-commit
|
||||
---@field max_nodes number Max nodes before pruning
|
||||
---@field max_deltas number Max delta history
|
||||
---@field prune table Pruning config
|
||||
---@field output table Output config
|
||||
|
||||
-- Type constants for runtime use
|
||||
M.NODE_TYPES = {
|
||||
PATTERN = "pat",
|
||||
CORRECTION = "cor",
|
||||
DECISION = "dec",
|
||||
CONVENTION = "con",
|
||||
FEEDBACK = "fbk",
|
||||
SESSION = "ses",
|
||||
}
|
||||
|
||||
M.EDGE_TYPES = {
|
||||
SEMANTIC = "sem",
|
||||
FILE = "file",
|
||||
TEMPORAL = "temp",
|
||||
CAUSAL = "caus",
|
||||
SUPERSEDES = "sup",
|
||||
}
|
||||
|
||||
M.DELTA_OPS = {
|
||||
ADD = "add",
|
||||
MODIFY = "mod",
|
||||
DELETE = "del",
|
||||
}
|
||||
|
||||
M.SOURCES = {
|
||||
AUTO = "auto",
|
||||
USER = "user",
|
||||
LLM = "llm",
|
||||
}
|
||||
|
||||
M.SCHEMA_VERSION = 1
|
||||
|
||||
return M
|
||||
616
lua/codetyper/core/scheduler/executor.lua
Normal file
616
lua/codetyper/core/scheduler/executor.lua
Normal file
@@ -0,0 +1,616 @@
|
||||
---@mod codetyper.agent.executor Tool executor for agent system
|
||||
---
|
||||
--- Executes tools requested by the LLM and returns results.
|
||||
|
||||
local M = {}
|
||||
local utils = require("codetyper.support.utils")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
|
||||
---@class ExecutionResult
|
||||
---@field success boolean Whether the execution succeeded
|
||||
---@field result string Result message or content
|
||||
---@field requires_approval boolean Whether user approval is needed
|
||||
---@field diff_data? DiffData Data for diff preview (if requires_approval)
|
||||
|
||||
--- Open a file in a buffer (in a non-agent window)
|
||||
---@param path string File path to open
|
||||
---@param jump_to_line? number Optional line number to jump to
|
||||
local function open_file_in_buffer(path, jump_to_line)
|
||||
if not path or path == "" then
|
||||
return
|
||||
end
|
||||
|
||||
-- Check if file exists
|
||||
if vim.fn.filereadable(path) ~= 1 then
|
||||
return
|
||||
end
|
||||
|
||||
vim.schedule(function()
|
||||
-- Find a suitable window (not the agent UI windows)
|
||||
local target_win = nil
|
||||
local agent_ui_ok, agent_ui = pcall(require, "codetyper.agent.ui")
|
||||
|
||||
for _, win in ipairs(vim.api.nvim_list_wins()) do
|
||||
local buf = vim.api.nvim_win_get_buf(win)
|
||||
local buftype = vim.bo[buf].buftype
|
||||
|
||||
-- Skip special buffers (agent UI, nofile, etc.)
|
||||
if buftype == "" or buftype == "acwrite" then
|
||||
-- Check if this is not an agent UI window
|
||||
local is_agent_win = false
|
||||
if agent_ui_ok and agent_ui.is_open() then
|
||||
-- Skip agent windows by checking if it's one of our special buffers
|
||||
local bufname = vim.api.nvim_buf_get_name(buf)
|
||||
if bufname == "" then
|
||||
-- Could be agent buffer, check by buffer option
|
||||
is_agent_win = vim.bo[buf].buftype == "nofile"
|
||||
end
|
||||
end
|
||||
|
||||
if not is_agent_win then
|
||||
target_win = win
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- If no suitable window found, create a new split
|
||||
if not target_win then
|
||||
-- Get the rightmost non-agent window or create one
|
||||
vim.cmd("rightbelow vsplit")
|
||||
target_win = vim.api.nvim_get_current_win()
|
||||
end
|
||||
|
||||
-- Open the file in the target window
|
||||
vim.api.nvim_set_current_win(target_win)
|
||||
vim.cmd("edit " .. vim.fn.fnameescape(path))
|
||||
|
||||
-- Jump to line if specified
|
||||
if jump_to_line and jump_to_line > 0 then
|
||||
local line_count = vim.api.nvim_buf_line_count(0)
|
||||
local target_line = math.min(jump_to_line, line_count)
|
||||
vim.api.nvim_win_set_cursor(target_win, { target_line, 0 })
|
||||
vim.cmd("normal! zz")
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Expose open_file_in_buffer for external use
|
||||
M.open_file_in_buffer = open_file_in_buffer
|
||||
|
||||
---@class DiffData
|
||||
---@field path string File path
|
||||
---@field original string Original content
|
||||
---@field modified string Modified content
|
||||
---@field operation string Operation type: "edit", "create", "overwrite", "bash"
|
||||
|
||||
--- Execute a tool and return result via callback
|
||||
---@param tool_name string Name of the tool to execute
|
||||
---@param parameters table Tool parameters
|
||||
---@param callback fun(result: ExecutionResult) Callback with result
|
||||
function M.execute(tool_name, parameters, callback)
|
||||
local handlers = {
|
||||
read_file = M.handle_read_file,
|
||||
edit_file = M.handle_edit_file,
|
||||
write_file = M.handle_write_file,
|
||||
bash = M.handle_bash,
|
||||
delete_file = M.handle_delete_file,
|
||||
list_directory = M.handle_list_directory,
|
||||
search_files = M.handle_search_files,
|
||||
}
|
||||
|
||||
local handler = handlers[tool_name]
|
||||
if not handler then
|
||||
callback({
|
||||
success = false,
|
||||
result = "Unknown tool: " .. tool_name,
|
||||
requires_approval = false,
|
||||
})
|
||||
return
|
||||
end
|
||||
|
||||
handler(parameters, callback)
|
||||
end
|
||||
|
||||
--- Handle read_file tool
|
||||
---@param params table { path: string }
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
function M.handle_read_file(params, callback)
|
||||
local path = M.resolve_path(params.path)
|
||||
|
||||
-- Log the read operation in Claude Code style
|
||||
local relative_path = vim.fn.fnamemodify(path, ":~:.")
|
||||
logs.read(relative_path)
|
||||
|
||||
local content = utils.read_file(path)
|
||||
|
||||
if content then
|
||||
-- Log how many lines were read
|
||||
local lines = vim.split(content, "\n", { plain = true })
|
||||
logs.add({ type = "result", message = string.format(" ⎿ Read %d lines", #lines) })
|
||||
|
||||
-- Open the file in a buffer so user can see it
|
||||
open_file_in_buffer(path)
|
||||
|
||||
callback({
|
||||
success = true,
|
||||
result = content,
|
||||
requires_approval = false,
|
||||
})
|
||||
else
|
||||
logs.add({ type = "error", message = " ⎿ File not found" })
|
||||
callback({
|
||||
success = false,
|
||||
result = "Could not read file: " .. path,
|
||||
requires_approval = false,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
--- Handle edit_file tool
|
||||
---@param params table { path: string, find: string, replace: string }
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
function M.handle_edit_file(params, callback)
|
||||
local path = M.resolve_path(params.path)
|
||||
local relative_path = vim.fn.fnamemodify(path, ":~:.")
|
||||
|
||||
-- Log the edit operation
|
||||
logs.add({ type = "action", message = string.format("Edit(%s)", relative_path) })
|
||||
|
||||
local original = utils.read_file(path)
|
||||
|
||||
if not original then
|
||||
logs.add({ type = "error", message = " ⎿ File not found" })
|
||||
callback({
|
||||
success = false,
|
||||
result = "File not found: " .. path,
|
||||
requires_approval = false,
|
||||
})
|
||||
return
|
||||
end
|
||||
|
||||
-- Try to find and replace the content
|
||||
local escaped_find = utils.escape_pattern(params.find)
|
||||
local new_content, count = original:gsub(escaped_find, params.replace, 1)
|
||||
|
||||
if count == 0 then
|
||||
logs.add({ type = "error", message = " ⎿ Content not found" })
|
||||
callback({
|
||||
success = false,
|
||||
result = "Could not find content to replace in: " .. path,
|
||||
requires_approval = false,
|
||||
})
|
||||
return
|
||||
end
|
||||
|
||||
-- Calculate lines changed
|
||||
local original_lines = #vim.split(original, "\n", { plain = true })
|
||||
local new_lines = #vim.split(new_content, "\n", { plain = true })
|
||||
local diff = new_lines - original_lines
|
||||
if diff > 0 then
|
||||
logs.add({ type = "result", message = string.format(" ⎿ +%d lines (pending approval)", diff) })
|
||||
elseif diff < 0 then
|
||||
logs.add({ type = "result", message = string.format(" ⎿ %d lines (pending approval)", diff) })
|
||||
else
|
||||
logs.add({ type = "result", message = " ⎿ Modified (pending approval)" })
|
||||
end
|
||||
|
||||
-- Requires user approval - show diff
|
||||
callback({
|
||||
success = true,
|
||||
result = "Edit prepared for: " .. path,
|
||||
requires_approval = true,
|
||||
diff_data = {
|
||||
path = path,
|
||||
original = original,
|
||||
modified = new_content,
|
||||
operation = "edit",
|
||||
},
|
||||
})
|
||||
end
|
||||
|
||||
--- Handle write_file tool
|
||||
---@param params table { path: string, content: string }
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
function M.handle_write_file(params, callback)
|
||||
local path = M.resolve_path(params.path)
|
||||
local relative_path = vim.fn.fnamemodify(path, ":~:.")
|
||||
local original = utils.read_file(path) or ""
|
||||
local operation = original == "" and "create" or "overwrite"
|
||||
|
||||
-- Log the write operation
|
||||
if operation == "create" then
|
||||
logs.add({ type = "action", message = string.format("Write(%s)", relative_path) })
|
||||
local new_lines = #vim.split(params.content, "\n", { plain = true })
|
||||
logs.add({ type = "result", message = string.format(" ⎿ New file (%d lines, pending approval)", new_lines) })
|
||||
else
|
||||
logs.add({ type = "action", message = string.format("Update(%s)", relative_path) })
|
||||
local original_lines = #vim.split(original, "\n", { plain = true })
|
||||
local new_lines = #vim.split(params.content, "\n", { plain = true })
|
||||
local diff = new_lines - original_lines
|
||||
if diff > 0 then
|
||||
logs.add({ type = "result", message = string.format(" ⎿ +%d lines (pending approval)", diff) })
|
||||
elseif diff < 0 then
|
||||
logs.add({ type = "result", message = string.format(" ⎿ %d lines (pending approval)", diff) })
|
||||
else
|
||||
logs.add({ type = "result", message = " ⎿ Modified (pending approval)" })
|
||||
end
|
||||
end
|
||||
|
||||
-- Ensure parent directory exists
|
||||
local dir = vim.fn.fnamemodify(path, ":h")
|
||||
if dir ~= "" and dir ~= "." then
|
||||
utils.ensure_dir(dir)
|
||||
end
|
||||
|
||||
callback({
|
||||
success = true,
|
||||
result = (operation == "create" and "Create" or "Overwrite") .. " prepared for: " .. path,
|
||||
requires_approval = true,
|
||||
diff_data = {
|
||||
path = path,
|
||||
original = original,
|
||||
modified = params.content,
|
||||
operation = operation,
|
||||
},
|
||||
})
|
||||
end
|
||||
|
||||
--- Handle bash tool
|
||||
---@param params table { command: string, timeout?: number }
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
function M.handle_bash(params, callback)
|
||||
local command = params.command
|
||||
|
||||
-- Log the bash operation
|
||||
logs.add({ type = "action", message = string.format("Bash(%s)", command:sub(1, 50) .. (#command > 50 and "..." or "")) })
|
||||
logs.add({ type = "result", message = " ⎿ Pending approval" })
|
||||
|
||||
-- Requires user approval first
|
||||
callback({
|
||||
success = true,
|
||||
result = "Command: " .. command,
|
||||
requires_approval = true,
|
||||
diff_data = {
|
||||
path = "[bash]",
|
||||
original = "",
|
||||
modified = "$ " .. command,
|
||||
operation = "bash",
|
||||
},
|
||||
bash_command = command,
|
||||
bash_timeout = params.timeout or 30000,
|
||||
})
|
||||
end
|
||||
|
||||
--- Handle delete_file tool
|
||||
---@param params table { path: string, reason: string }
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
function M.handle_delete_file(params, callback)
|
||||
local path = M.resolve_path(params.path)
|
||||
local reason = params.reason or "No reason provided"
|
||||
|
||||
-- Check if file exists
|
||||
if not utils.file_exists(path) then
|
||||
callback({
|
||||
success = false,
|
||||
result = "File not found: " .. path,
|
||||
requires_approval = false,
|
||||
})
|
||||
return
|
||||
end
|
||||
|
||||
-- Read content for showing in diff (so user knows what they're deleting)
|
||||
local content = utils.read_file(path) or "[Could not read file]"
|
||||
|
||||
callback({
|
||||
success = true,
|
||||
result = "Delete: " .. path .. " (" .. reason .. ")",
|
||||
requires_approval = true,
|
||||
diff_data = {
|
||||
path = path,
|
||||
original = content,
|
||||
modified = "", -- Empty = deletion
|
||||
operation = "delete",
|
||||
reason = reason,
|
||||
},
|
||||
})
|
||||
end
|
||||
|
||||
--- Handle list_directory tool
|
||||
---@param params table { path?: string, recursive?: boolean }
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
function M.handle_list_directory(params, callback)
|
||||
local path = params.path and M.resolve_path(params.path) or (utils.get_project_root() or vim.fn.getcwd())
|
||||
local recursive = params.recursive or false
|
||||
|
||||
-- Use vim.fn.readdir or glob for directory listing
|
||||
local entries = {}
|
||||
local function list_dir(dir, depth)
|
||||
if depth > 3 then
|
||||
return
|
||||
end
|
||||
|
||||
local ok, files = pcall(vim.fn.readdir, dir)
|
||||
if not ok or not files then
|
||||
return
|
||||
end
|
||||
|
||||
for _, name in ipairs(files) do
|
||||
if name ~= "." and name ~= ".." and not name:match("^%.git$") and not name:match("^node_modules$") then
|
||||
local full_path = dir .. "/" .. name
|
||||
local stat = vim.loop.fs_stat(full_path)
|
||||
if stat then
|
||||
local prefix = string.rep(" ", depth)
|
||||
local type_indicator = stat.type == "directory" and "/" or ""
|
||||
table.insert(entries, prefix .. name .. type_indicator)
|
||||
|
||||
if recursive and stat.type == "directory" then
|
||||
list_dir(full_path, depth + 1)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
list_dir(path, 0)
|
||||
|
||||
local result = "Directory: " .. path .. "\n\n" .. table.concat(entries, "\n")
|
||||
|
||||
callback({
|
||||
success = true,
|
||||
result = result,
|
||||
requires_approval = false,
|
||||
})
|
||||
end
|
||||
|
||||
--- Handle search_files tool
|
||||
---@param params table { pattern?: string, content?: string, path?: string }
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
function M.handle_search_files(params, callback)
|
||||
local search_path = params.path and M.resolve_path(params.path) or (utils.get_project_root() or vim.fn.getcwd())
|
||||
local pattern = params.pattern
|
||||
local content_search = params.content
|
||||
|
||||
local results = {}
|
||||
|
||||
if pattern then
|
||||
-- Search by file name pattern using glob
|
||||
local glob_pattern = search_path .. "/**/" .. pattern
|
||||
local files = vim.fn.glob(glob_pattern, false, true)
|
||||
|
||||
for _, file in ipairs(files) do
|
||||
-- Skip common ignore patterns
|
||||
if not file:match("node_modules") and not file:match("%.git/") then
|
||||
local relative = file:gsub(search_path .. "/", "")
|
||||
table.insert(results, relative)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if content_search then
|
||||
-- Search by content using grep
|
||||
local grep_results = {}
|
||||
local grep_cmd = string.format("grep -rl '%s' '%s' 2>/dev/null | head -20", content_search:gsub("'", "\\'"), search_path)
|
||||
|
||||
local handle = io.popen(grep_cmd)
|
||||
if handle then
|
||||
for line in handle:lines() do
|
||||
if not line:match("node_modules") and not line:match("%.git/") then
|
||||
local relative = line:gsub(search_path .. "/", "")
|
||||
table.insert(grep_results, relative)
|
||||
end
|
||||
end
|
||||
handle:close()
|
||||
end
|
||||
|
||||
-- Merge with pattern results or use as primary results
|
||||
if #results == 0 then
|
||||
results = grep_results
|
||||
else
|
||||
-- Intersection of pattern and content results
|
||||
local pattern_set = {}
|
||||
for _, f in ipairs(results) do
|
||||
pattern_set[f] = true
|
||||
end
|
||||
results = {}
|
||||
for _, f in ipairs(grep_results) do
|
||||
if pattern_set[f] then
|
||||
table.insert(results, f)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local result_text = "Search results"
|
||||
if pattern then
|
||||
result_text = result_text .. " (pattern: " .. pattern .. ")"
|
||||
end
|
||||
if content_search then
|
||||
result_text = result_text .. " (content: " .. content_search .. ")"
|
||||
end
|
||||
result_text = result_text .. ":\n\n"
|
||||
|
||||
if #results == 0 then
|
||||
result_text = result_text .. "No files found."
|
||||
else
|
||||
result_text = result_text .. table.concat(results, "\n")
|
||||
end
|
||||
|
||||
callback({
|
||||
success = true,
|
||||
result = result_text,
|
||||
requires_approval = false,
|
||||
})
|
||||
end
|
||||
|
||||
--- Actually apply an approved change
|
||||
---@param diff_data DiffData The diff data to apply
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
function M.apply_change(diff_data, callback)
|
||||
if diff_data.operation == "bash" then
|
||||
-- Extract command from modified (remove "$ " prefix)
|
||||
local command = diff_data.modified:gsub("^%$ ", "")
|
||||
M.execute_bash_command(command, 30000, callback)
|
||||
elseif diff_data.operation == "delete" then
|
||||
-- Delete file
|
||||
local ok, err = os.remove(diff_data.path)
|
||||
if ok then
|
||||
-- Close buffer if it's open
|
||||
M.close_buffer_if_open(diff_data.path)
|
||||
callback({
|
||||
success = true,
|
||||
result = "Deleted: " .. diff_data.path,
|
||||
requires_approval = false,
|
||||
})
|
||||
else
|
||||
callback({
|
||||
success = false,
|
||||
result = "Failed to delete: " .. diff_data.path .. " (" .. (err or "unknown error") .. ")",
|
||||
requires_approval = false,
|
||||
})
|
||||
end
|
||||
else
|
||||
-- Write file
|
||||
local success = utils.write_file(diff_data.path, diff_data.modified)
|
||||
if success then
|
||||
-- Open and/or reload buffer so user can see the changes
|
||||
open_file_in_buffer(diff_data.path)
|
||||
M.reload_buffer_if_open(diff_data.path)
|
||||
callback({
|
||||
success = true,
|
||||
result = "Changes applied to: " .. diff_data.path,
|
||||
requires_approval = false,
|
||||
})
|
||||
else
|
||||
callback({
|
||||
success = false,
|
||||
result = "Failed to write: " .. diff_data.path,
|
||||
requires_approval = false,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Execute a bash command
|
||||
---@param command string Command to execute
|
||||
---@param timeout number Timeout in milliseconds
|
||||
---@param callback fun(result: ExecutionResult)
|
||||
function M.execute_bash_command(command, timeout, callback)
|
||||
local stdout_data = {}
|
||||
local stderr_data = {}
|
||||
local job_id
|
||||
|
||||
job_id = vim.fn.jobstart(command, {
|
||||
stdout_buffered = true,
|
||||
stderr_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if data then
|
||||
for _, line in ipairs(data) do
|
||||
if line ~= "" then
|
||||
table.insert(stdout_data, line)
|
||||
end
|
||||
end
|
||||
end
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data then
|
||||
for _, line in ipairs(data) do
|
||||
if line ~= "" then
|
||||
table.insert(stderr_data, line)
|
||||
end
|
||||
end
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, exit_code)
|
||||
vim.schedule(function()
|
||||
local result = table.concat(stdout_data, "\n")
|
||||
if #stderr_data > 0 then
|
||||
if result ~= "" then
|
||||
result = result .. "\n"
|
||||
end
|
||||
result = result .. "STDERR:\n" .. table.concat(stderr_data, "\n")
|
||||
end
|
||||
result = result .. "\n[Exit code: " .. exit_code .. "]"
|
||||
|
||||
callback({
|
||||
success = exit_code == 0,
|
||||
result = result,
|
||||
requires_approval = false,
|
||||
})
|
||||
end)
|
||||
end,
|
||||
})
|
||||
|
||||
-- Set up timeout
|
||||
if job_id > 0 then
|
||||
vim.defer_fn(function()
|
||||
if vim.fn.jobwait({ job_id }, 0)[1] == -1 then
|
||||
vim.fn.jobstop(job_id)
|
||||
vim.schedule(function()
|
||||
callback({
|
||||
success = false,
|
||||
result = "Command timed out after " .. timeout .. "ms",
|
||||
requires_approval = false,
|
||||
})
|
||||
end)
|
||||
end
|
||||
end, timeout)
|
||||
else
|
||||
callback({
|
||||
success = false,
|
||||
result = "Failed to start command",
|
||||
requires_approval = false,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
--- Reload a buffer if it's currently open
|
||||
---@param filepath string Path to the file
|
||||
function M.reload_buffer_if_open(filepath)
|
||||
local full_path = vim.fn.fnamemodify(filepath, ":p")
|
||||
for _, buf in ipairs(vim.api.nvim_list_bufs()) do
|
||||
if vim.api.nvim_buf_is_loaded(buf) then
|
||||
local buf_name = vim.api.nvim_buf_get_name(buf)
|
||||
if buf_name == full_path then
|
||||
vim.api.nvim_buf_call(buf, function()
|
||||
vim.cmd("edit!")
|
||||
end)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Close a buffer if it's currently open (for deleted files)
|
||||
---@param filepath string Path to the file
|
||||
function M.close_buffer_if_open(filepath)
|
||||
local full_path = vim.fn.fnamemodify(filepath, ":p")
|
||||
for _, buf in ipairs(vim.api.nvim_list_bufs()) do
|
||||
if vim.api.nvim_buf_is_loaded(buf) then
|
||||
local buf_name = vim.api.nvim_buf_get_name(buf)
|
||||
if buf_name == full_path then
|
||||
-- Force close the buffer
|
||||
pcall(vim.api.nvim_buf_delete, buf, { force = true })
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Resolve a path (expand ~ and make absolute if needed)
|
||||
---@param path string Path to resolve
|
||||
---@return string Resolved path
|
||||
function M.resolve_path(path)
|
||||
-- Expand ~ to home directory
|
||||
local expanded = vim.fn.expand(path)
|
||||
|
||||
-- If relative, make it relative to project root or cwd
|
||||
if not vim.startswith(expanded, "/") then
|
||||
local root = utils.get_project_root() or vim.fn.getcwd()
|
||||
expanded = root .. "/" .. expanded
|
||||
end
|
||||
|
||||
return vim.fn.fnamemodify(expanded, ":p")
|
||||
end
|
||||
|
||||
return M
|
||||
381
lua/codetyper/core/scheduler/loop.lua
Normal file
381
lua/codetyper/core/scheduler/loop.lua
Normal file
@@ -0,0 +1,381 @@
|
||||
---@mod codetyper.agent.loop Agent loop with tool orchestration
|
||||
---@brief [[
|
||||
--- Main agent loop that handles multi-turn conversations with tool use.
|
||||
--- Agent execution loop with tool calling support.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local prompts = require("codetyper.prompts.agents.loop")
|
||||
|
||||
---@class AgentMessage
|
||||
---@field role "system"|"user"|"assistant"|"tool"
|
||||
---@field content string|table
|
||||
---@field tool_call_id? string For tool responses
|
||||
---@field tool_calls? table[] For assistant tool calls
|
||||
---@field name? string Tool name for tool responses
|
||||
|
||||
---@class AgentLoopOpts
|
||||
---@field system_prompt string System prompt
|
||||
---@field user_input string Initial user message
|
||||
---@field tools? CoderTool[] Available tools (default: all registered)
|
||||
---@field max_iterations? number Max tool call iterations (default: 10)
|
||||
---@field provider? string LLM provider to use
|
||||
---@field on_start? fun() Called when loop starts
|
||||
---@field on_chunk? fun(chunk: string) Called for each response chunk
|
||||
---@field on_tool_call? fun(name: string, input: table) Called before tool execution
|
||||
---@field on_tool_result? fun(name: string, result: any, error: string|nil) Called after tool execution
|
||||
---@field on_message? fun(message: AgentMessage) Called for each message added
|
||||
---@field on_complete? fun(result: string|nil, error: string|nil) Called when loop completes
|
||||
---@field session_ctx? table Session context shared across tools
|
||||
|
||||
--- Format tool definitions for OpenAI-compatible API
|
||||
---@param tools CoderTool[]
|
||||
---@return table[]
|
||||
local function format_tools_for_api(tools)
|
||||
local formatted = {}
|
||||
for _, tool in ipairs(tools) do
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(tool.params or {}) do
|
||||
properties[param.name] = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
table.insert(formatted, {
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool.name,
|
||||
description = type(tool.description) == "function" and tool.description() or tool.description,
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
},
|
||||
})
|
||||
end
|
||||
return formatted
|
||||
end
|
||||
|
||||
--- Parse tool calls from LLM response
|
||||
---@param response table LLM response
|
||||
---@return table[] tool_calls
|
||||
local function parse_tool_calls(response)
|
||||
local tool_calls = {}
|
||||
|
||||
-- Handle different response formats
|
||||
if response.tool_calls then
|
||||
-- OpenAI format
|
||||
for _, call in ipairs(response.tool_calls) do
|
||||
local args = call["function"].arguments
|
||||
if type(args) == "string" then
|
||||
local ok, parsed = pcall(vim.json.decode, args)
|
||||
if ok then
|
||||
args = parsed
|
||||
end
|
||||
end
|
||||
table.insert(tool_calls, {
|
||||
id = call.id,
|
||||
name = call["function"].name,
|
||||
input = args,
|
||||
})
|
||||
end
|
||||
elseif response.content and type(response.content) == "table" then
|
||||
-- Claude format (content blocks)
|
||||
for _, block in ipairs(response.content) do
|
||||
if block.type == "tool_use" then
|
||||
table.insert(tool_calls, {
|
||||
id = block.id,
|
||||
name = block.name,
|
||||
input = block.input,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return tool_calls
|
||||
end
|
||||
|
||||
--- Build messages for LLM request
|
||||
---@param history AgentMessage[]
|
||||
---@return table[]
|
||||
local function build_messages(history)
|
||||
local messages = {}
|
||||
|
||||
for _, msg in ipairs(history) do
|
||||
if msg.role == "system" then
|
||||
table.insert(messages, {
|
||||
role = "system",
|
||||
content = msg.content,
|
||||
})
|
||||
elseif msg.role == "user" then
|
||||
table.insert(messages, {
|
||||
role = "user",
|
||||
content = msg.content,
|
||||
})
|
||||
elseif msg.role == "assistant" then
|
||||
local message = {
|
||||
role = "assistant",
|
||||
content = msg.content,
|
||||
}
|
||||
if msg.tool_calls then
|
||||
message.tool_calls = msg.tool_calls
|
||||
end
|
||||
table.insert(messages, message)
|
||||
elseif msg.role == "tool" then
|
||||
table.insert(messages, {
|
||||
role = "tool",
|
||||
tool_call_id = msg.tool_call_id,
|
||||
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return messages
|
||||
end
|
||||
|
||||
--- Execute the agent loop
|
||||
---@param opts AgentLoopOpts
|
||||
function M.run(opts)
|
||||
local tools_mod = require("codetyper.core.tools")
|
||||
local llm = require("codetyper.core.llm")
|
||||
|
||||
-- Get tools
|
||||
local tools = opts.tools or tools_mod.list()
|
||||
local tool_map = {}
|
||||
for _, tool in ipairs(tools) do
|
||||
tool_map[tool.name] = tool
|
||||
end
|
||||
|
||||
-- Initialize conversation history
|
||||
---@type AgentMessage[]
|
||||
local history = {
|
||||
{ role = "system", content = opts.system_prompt },
|
||||
{ role = "user", content = opts.user_input },
|
||||
}
|
||||
|
||||
local session_ctx = opts.session_ctx or {}
|
||||
local max_iterations = opts.max_iterations or 10
|
||||
local iteration = 0
|
||||
|
||||
-- Callback wrappers
|
||||
local function on_message(msg)
|
||||
if opts.on_message then
|
||||
opts.on_message(msg)
|
||||
end
|
||||
end
|
||||
|
||||
-- Notify of initial messages
|
||||
for _, msg in ipairs(history) do
|
||||
on_message(msg)
|
||||
end
|
||||
|
||||
-- Start notification
|
||||
if opts.on_start then
|
||||
opts.on_start()
|
||||
end
|
||||
|
||||
--- Process one iteration of the loop
|
||||
local function process_iteration()
|
||||
iteration = iteration + 1
|
||||
|
||||
if iteration > max_iterations then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "Max iterations reached")
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Build request
|
||||
local messages = build_messages(history)
|
||||
local formatted_tools = format_tools_for_api(tools)
|
||||
|
||||
-- Build context for LLM
|
||||
local context = {
|
||||
file_content = "",
|
||||
language = "lua",
|
||||
extension = "lua",
|
||||
prompt_type = "agent",
|
||||
tools = formatted_tools,
|
||||
}
|
||||
|
||||
-- Get LLM response
|
||||
local client = llm.get_client()
|
||||
if not client then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "No LLM client available")
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Build prompt from messages
|
||||
local prompt_parts = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(prompt_parts, string.format("[%s]: %s", msg.role, msg.content or ""))
|
||||
end
|
||||
end
|
||||
local prompt = table.concat(prompt_parts, "\n\n")
|
||||
|
||||
client.generate(prompt, context, function(response, error)
|
||||
if error then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, error)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Chunk callback
|
||||
if opts.on_chunk then
|
||||
opts.on_chunk(response)
|
||||
end
|
||||
|
||||
-- Parse response for tool calls
|
||||
-- For now, we'll use a simple heuristic to detect tool calls in the response
|
||||
-- In a full implementation, the LLM would return structured tool calls
|
||||
local tool_calls = {}
|
||||
|
||||
-- Try to parse JSON tool calls from response
|
||||
local json_match = response:match("```json%s*(%b{})%s*```")
|
||||
if json_match then
|
||||
local ok, parsed = pcall(vim.json.decode, json_match)
|
||||
if ok and parsed.tool_calls then
|
||||
tool_calls = parsed.tool_calls
|
||||
end
|
||||
end
|
||||
|
||||
-- Add assistant message
|
||||
local assistant_msg = {
|
||||
role = "assistant",
|
||||
content = response,
|
||||
tool_calls = #tool_calls > 0 and tool_calls or nil,
|
||||
}
|
||||
table.insert(history, assistant_msg)
|
||||
on_message(assistant_msg)
|
||||
|
||||
-- Process tool calls
|
||||
if #tool_calls > 0 then
|
||||
local pending = #tool_calls
|
||||
local results = {}
|
||||
|
||||
for i, call in ipairs(tool_calls) do
|
||||
local tool = tool_map[call.name]
|
||||
if not tool then
|
||||
results[i] = { error = "Unknown tool: " .. call.name }
|
||||
pending = pending - 1
|
||||
else
|
||||
-- Notify of tool call
|
||||
if opts.on_tool_call then
|
||||
opts.on_tool_call(call.name, call.input)
|
||||
end
|
||||
|
||||
-- Execute tool
|
||||
local tool_opts = {
|
||||
on_log = function(msg)
|
||||
pcall(function()
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({ type = "tool", message = msg })
|
||||
end)
|
||||
end,
|
||||
on_complete = function(result, err)
|
||||
results[i] = { result = result, error = err }
|
||||
pending = pending - 1
|
||||
|
||||
-- Notify of tool result
|
||||
if opts.on_tool_result then
|
||||
opts.on_tool_result(call.name, result, err)
|
||||
end
|
||||
|
||||
-- Add tool response to history
|
||||
local tool_msg = {
|
||||
role = "tool",
|
||||
tool_call_id = call.id or tostring(i),
|
||||
name = call.name,
|
||||
content = err or result,
|
||||
}
|
||||
table.insert(history, tool_msg)
|
||||
on_message(tool_msg)
|
||||
|
||||
-- Continue loop when all tools complete
|
||||
if pending == 0 then
|
||||
vim.schedule(process_iteration)
|
||||
end
|
||||
end,
|
||||
session_ctx = session_ctx,
|
||||
}
|
||||
|
||||
-- Validate and execute
|
||||
local valid, validation_err = true, nil
|
||||
if tool.validate_input then
|
||||
valid, validation_err = tool:validate_input(call.input)
|
||||
end
|
||||
|
||||
if not valid then
|
||||
tool_opts.on_complete(nil, validation_err)
|
||||
else
|
||||
local result, err = tool.func(call.input, tool_opts)
|
||||
-- If sync result, call on_complete
|
||||
if result ~= nil or err ~= nil then
|
||||
tool_opts.on_complete(result, err)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
else
|
||||
-- No tool calls - loop complete
|
||||
if opts.on_complete then
|
||||
opts.on_complete(response, nil)
|
||||
end
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
-- Start the loop
|
||||
process_iteration()
|
||||
end
|
||||
|
||||
--- Create an agent with default settings
|
||||
---@param task string Task description
|
||||
---@param opts? AgentLoopOpts Additional options
|
||||
function M.create(task, opts)
|
||||
opts = opts or {}
|
||||
|
||||
local system_prompt = opts.system_prompt or prompts.default_system_prompt
|
||||
|
||||
M.run(vim.tbl_extend("force", opts, {
|
||||
system_prompt = system_prompt,
|
||||
user_input = task,
|
||||
}))
|
||||
end
|
||||
|
||||
--- Simple dispatch agent for sub-tasks
|
||||
---@param prompt string Task for the sub-agent
|
||||
---@param on_complete fun(result: string|nil, error: string|nil) Completion callback
|
||||
---@param opts? table Additional options
|
||||
function M.dispatch(prompt, on_complete, opts)
|
||||
opts = opts or {}
|
||||
|
||||
-- Sub-agents get limited tools by default
|
||||
local tools_mod = require("codetyper.core.tools")
|
||||
local safe_tools = tools_mod.list(function(tool)
|
||||
return tool.name == "view" or tool.name == "grep" or tool.name == "glob"
|
||||
end)
|
||||
|
||||
M.run({
|
||||
system_prompt = prompts.dispatch_prompt,
|
||||
user_input = prompt,
|
||||
tools = opts.tools or safe_tools,
|
||||
max_iterations = opts.max_iterations or 5,
|
||||
on_complete = on_complete,
|
||||
session_ctx = opts.session_ctx,
|
||||
})
|
||||
end
|
||||
|
||||
return M
|
||||
155
lua/codetyper/core/scheduler/resume.lua
Normal file
155
lua/codetyper/core/scheduler/resume.lua
Normal file
@@ -0,0 +1,155 @@
|
||||
---@mod codetyper.agent.resume Resume context for agent sessions
|
||||
---
|
||||
--- Saves and loads agent state to allow continuing long-running tasks.
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Get the resume context directory
|
||||
---@return string|nil
|
||||
local function get_resume_dir()
|
||||
local root = utils.get_project_root() or vim.fn.getcwd()
|
||||
return root .. "/.coder/tmp"
|
||||
end
|
||||
|
||||
--- Get the resume context file path
|
||||
---@return string|nil
|
||||
local function get_resume_path()
|
||||
local dir = get_resume_dir()
|
||||
if not dir then
|
||||
return nil
|
||||
end
|
||||
return dir .. "/agent_resume.json"
|
||||
end
|
||||
|
||||
--- Ensure the resume directory exists
|
||||
---@return boolean
|
||||
local function ensure_resume_dir()
|
||||
local dir = get_resume_dir()
|
||||
if not dir then
|
||||
return false
|
||||
end
|
||||
return utils.ensure_dir(dir)
|
||||
end
|
||||
|
||||
---@class ResumeContext
|
||||
---@field conversation table[] Message history
|
||||
---@field pending_tool_results table[] Pending results
|
||||
---@field iteration number Current iteration count
|
||||
---@field original_prompt string Original user prompt
|
||||
---@field timestamp number When saved
|
||||
---@field project_root string Project root path
|
||||
|
||||
--- Save the current agent state for resuming later
|
||||
---@param conversation table[] Conversation history
|
||||
---@param pending_results table[] Pending tool results
|
||||
---@param iteration number Current iteration
|
||||
---@param original_prompt string Original prompt
|
||||
---@return boolean Success
|
||||
function M.save(conversation, pending_results, iteration, original_prompt)
|
||||
if not ensure_resume_dir() then
|
||||
return false
|
||||
end
|
||||
|
||||
local path = get_resume_path()
|
||||
if not path then
|
||||
return false
|
||||
end
|
||||
|
||||
local context = {
|
||||
conversation = conversation,
|
||||
pending_tool_results = pending_results,
|
||||
iteration = iteration,
|
||||
original_prompt = original_prompt,
|
||||
timestamp = os.time(),
|
||||
project_root = utils.get_project_root() or vim.fn.getcwd(),
|
||||
}
|
||||
|
||||
local ok, json = pcall(vim.json.encode, context)
|
||||
if not ok then
|
||||
utils.notify("Failed to encode resume context", vim.log.levels.ERROR)
|
||||
return false
|
||||
end
|
||||
|
||||
local success = utils.write_file(path, json)
|
||||
if success then
|
||||
utils.notify("Agent state saved. Use /continue to resume.", vim.log.levels.INFO)
|
||||
end
|
||||
return success
|
||||
end
|
||||
|
||||
--- Load saved agent state
|
||||
---@return ResumeContext|nil
|
||||
function M.load()
|
||||
local path = get_resume_path()
|
||||
if not path then
|
||||
return nil
|
||||
end
|
||||
|
||||
local content = utils.read_file(path)
|
||||
if not content or content == "" then
|
||||
return nil
|
||||
end
|
||||
|
||||
local ok, context = pcall(vim.json.decode, content)
|
||||
if not ok or not context then
|
||||
return nil
|
||||
end
|
||||
|
||||
return context
|
||||
end
|
||||
|
||||
--- Check if there's a saved resume context
|
||||
---@return boolean
|
||||
function M.has_saved_state()
|
||||
local path = get_resume_path()
|
||||
if not path then
|
||||
return false
|
||||
end
|
||||
return vim.fn.filereadable(path) == 1
|
||||
end
|
||||
|
||||
--- Get info about saved state (for display)
|
||||
---@return table|nil
|
||||
function M.get_info()
|
||||
local context = M.load()
|
||||
if not context then
|
||||
return nil
|
||||
end
|
||||
|
||||
local age_seconds = os.time() - (context.timestamp or 0)
|
||||
local age_str
|
||||
if age_seconds < 60 then
|
||||
age_str = age_seconds .. " seconds ago"
|
||||
elseif age_seconds < 3600 then
|
||||
age_str = math.floor(age_seconds / 60) .. " minutes ago"
|
||||
else
|
||||
age_str = math.floor(age_seconds / 3600) .. " hours ago"
|
||||
end
|
||||
|
||||
return {
|
||||
prompt = context.original_prompt,
|
||||
iteration = context.iteration,
|
||||
messages = #context.conversation,
|
||||
saved_at = age_str,
|
||||
project = context.project_root,
|
||||
}
|
||||
end
|
||||
|
||||
--- Clear saved resume context
|
||||
---@return boolean
|
||||
function M.clear()
|
||||
local path = get_resume_path()
|
||||
if not path then
|
||||
return false
|
||||
end
|
||||
|
||||
if vim.fn.filereadable(path) == 1 then
|
||||
os.remove(path)
|
||||
return true
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -6,11 +6,12 @@
|
||||
|
||||
local M = {}
|
||||
|
||||
local queue = require("codetyper.agent.queue")
|
||||
local patch = require("codetyper.agent.patch")
|
||||
local worker = require("codetyper.agent.worker")
|
||||
local confidence_mod = require("codetyper.agent.confidence")
|
||||
local context_modal = require("codetyper.agent.context_modal")
|
||||
local queue = require("codetyper.core.events.queue")
|
||||
local patch = require("codetyper.core.diff.patch")
|
||||
local worker = require("codetyper.core.scheduler.worker")
|
||||
local confidence_mod = require("codetyper.core.llm.confidence")
|
||||
local context_modal = require("codetyper.adapters.nvim.ui.context_modal")
|
||||
local params = require("codetyper.params.agents.scheduler")
|
||||
|
||||
-- Setup context modal cleanup on exit
|
||||
context_modal.setup()
|
||||
@@ -21,15 +22,7 @@ local state = {
|
||||
timer = nil,
|
||||
poll_interval = 100, -- ms
|
||||
paused = false,
|
||||
config = {
|
||||
enabled = true,
|
||||
ollama_scout = true,
|
||||
escalation_threshold = 0.7,
|
||||
max_concurrent = 2,
|
||||
completion_delay_ms = 100,
|
||||
apply_delay_ms = 5000, -- Wait before applying code
|
||||
remote_provider = "claude", -- Default fallback provider
|
||||
},
|
||||
config = params.config,
|
||||
}
|
||||
|
||||
--- Autocommand group for injection timing
|
||||
@@ -90,9 +83,7 @@ local function get_remote_provider()
|
||||
-- If current provider is ollama, use configured remote
|
||||
if config.llm.provider == "ollama" then
|
||||
-- Check which remote provider is configured
|
||||
if config.llm.claude and config.llm.claude.api_key then
|
||||
return "claude"
|
||||
elseif config.llm.openai and config.llm.openai.api_key then
|
||||
if config.llm.openai and config.llm.openai.api_key then
|
||||
return "openai"
|
||||
elseif config.llm.gemini and config.llm.gemini.api_key then
|
||||
return "gemini"
|
||||
@@ -120,13 +111,13 @@ local function get_primary_provider()
|
||||
return config.llm.provider
|
||||
end
|
||||
end
|
||||
return "claude"
|
||||
return "ollama"
|
||||
end
|
||||
|
||||
--- Retry event with additional context
|
||||
---@param original_event table Original prompt event
|
||||
---@param additional_context string Additional context from user
|
||||
local function retry_with_context(original_event, additional_context)
|
||||
local function retry_with_context(original_event, additional_context, attached_files)
|
||||
-- Create new prompt content combining original + additional
|
||||
local combined_prompt = string.format(
|
||||
"%s\n\nAdditional context:\n%s",
|
||||
@@ -140,10 +131,14 @@ local function retry_with_context(original_event, additional_context)
|
||||
new_event.prompt_content = combined_prompt
|
||||
new_event.attempt_count = 0
|
||||
new_event.status = nil
|
||||
-- Preserve any attached files provided by the context modal
|
||||
if attached_files and #attached_files > 0 then
|
||||
new_event.attached_files = attached_files
|
||||
end
|
||||
|
||||
-- Log the retry
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Retrying with additional context (original: %s)", original_event.id),
|
||||
@@ -154,6 +149,79 @@ local function retry_with_context(original_event, additional_context)
|
||||
queue.enqueue(new_event)
|
||||
end
|
||||
|
||||
--- Try to parse requested file paths from an LLM response asking for more context
|
||||
---@param response string
|
||||
---@return string[] list of resolved full paths
|
||||
local function parse_requested_files(response)
|
||||
if not response or response == "" then
|
||||
return {}
|
||||
end
|
||||
|
||||
local cwd = vim.fn.getcwd()
|
||||
local results = {}
|
||||
local seen = {}
|
||||
|
||||
-- Heuristics: capture backticked paths, lines starting with - or *, or raw paths with slashes and extension
|
||||
for path in response:gmatch("`([%w%._%-%/]+%.[%w_]+)`") do
|
||||
if not seen[path] then
|
||||
table.insert(results, path)
|
||||
seen[path] = true
|
||||
end
|
||||
end
|
||||
|
||||
for path in response:gmatch("([%w%._%-%/]+%.[%w_]+)") do
|
||||
if not seen[path] then
|
||||
-- Filter out common English words that match the pattern
|
||||
if not path:match("^[Ii]$") and not path:match("^[Tt]his$") then
|
||||
table.insert(results, path)
|
||||
seen[path] = true
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Also capture list items like '- src/foo.lua'
|
||||
for line in response:gmatch("[^\\n]+") do
|
||||
local m = line:match("^%s*[-*]%s*([%w%._%-%/]+%.[%w_]+)%s*$")
|
||||
if m and not seen[m] then
|
||||
table.insert(results, m)
|
||||
seen[m] = true
|
||||
end
|
||||
end
|
||||
|
||||
-- Resolve each candidate to a full path by checking cwd and globbing
|
||||
local resolved = {}
|
||||
for _, p in ipairs(results) do
|
||||
local candidate = p
|
||||
local full = nil
|
||||
|
||||
-- If absolute or already rooted
|
||||
if candidate:sub(1,1) == "/" and vim.fn.filereadable(candidate) == 1 then
|
||||
full = candidate
|
||||
else
|
||||
-- Try relative to cwd
|
||||
local try1 = cwd .. "/" .. candidate
|
||||
if vim.fn.filereadable(try1) == 1 then
|
||||
full = try1
|
||||
else
|
||||
-- Try globbing for filename anywhere in project
|
||||
local basename = candidate
|
||||
-- If candidate contains slashes, try the tail
|
||||
local tail = candidate:match("[^/]+$") or candidate
|
||||
local matches = vim.fn.globpath(cwd, "**/" .. tail, false, true)
|
||||
if matches and #matches > 0 then
|
||||
full = matches[1]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if full and vim.fn.filereadable(full) == 1 then
|
||||
table.insert(resolved, full)
|
||||
end
|
||||
end
|
||||
|
||||
return resolved
|
||||
end
|
||||
|
||||
--- Process worker result and decide next action
|
||||
---@param event table PromptEvent
|
||||
---@param result table WorkerResult
|
||||
@@ -161,14 +229,94 @@ local function handle_worker_result(event, result)
|
||||
-- Check if LLM needs more context
|
||||
if result.needs_context then
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Event %s: LLM needs more context, opening modal", event.id),
|
||||
})
|
||||
end)
|
||||
|
||||
-- Open the context modal
|
||||
-- Try to auto-attach any files the LLM specifically requested in its response
|
||||
local requested = parse_requested_files(result.response or "")
|
||||
|
||||
-- Detect suggested shell commands the LLM may want executed (e.g., "run ls -la", "please run git status")
|
||||
local function detect_suggested_commands(response)
|
||||
if not response then
|
||||
return {}
|
||||
end
|
||||
local cmds = {}
|
||||
-- capture backticked commands: `ls -la`
|
||||
for c in response:gmatch("`([^`]+)`") do
|
||||
if #c > 1 and not c:match("%-%-help") then
|
||||
table.insert(cmds, { label = c, cmd = c })
|
||||
end
|
||||
end
|
||||
-- capture phrases like: run ls -la or run `ls -la`
|
||||
for m in response:gmatch("[Rr]un%s+([%w%p%s%-_/]+)") do
|
||||
local cand = m:gsub("^%s+",""):gsub("%s+$","")
|
||||
if cand and #cand > 1 then
|
||||
-- ignore long sentences; keep first line or command-like substring
|
||||
local line = cand:match("[^\n]+") or cand
|
||||
line = line:gsub("and then.*","")
|
||||
line = line:gsub("please.*","")
|
||||
if not line:match("%a+%s+files") then
|
||||
table.insert(cmds, { label = line, cmd = line })
|
||||
end
|
||||
end
|
||||
end
|
||||
-- dedupe
|
||||
local seen = {}
|
||||
local out = {}
|
||||
for _, v in ipairs(cmds) do
|
||||
if v.cmd and not seen[v.cmd] then
|
||||
seen[v.cmd] = true
|
||||
table.insert(out, v)
|
||||
end
|
||||
end
|
||||
return out
|
||||
end
|
||||
|
||||
local suggested_cmds = detect_suggested_commands(result.response or "")
|
||||
if suggested_cmds and #suggested_cmds > 0 then
|
||||
-- Open modal and show suggested commands for user approval
|
||||
context_modal.open(result.original_event or event, result.response or "", retry_with_context, suggested_cmds)
|
||||
queue.update_status(event.id, "needs_context", { response = result.response })
|
||||
return
|
||||
end
|
||||
if requested and #requested > 0 then
|
||||
pcall(function()
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({ type = "info", message = string.format("Auto-attaching %d requested file(s)", #requested) })
|
||||
end)
|
||||
|
||||
-- Build attached_files entries
|
||||
local attached = event.attached_files or {}
|
||||
for _, full in ipairs(requested) do
|
||||
local ok, content = pcall(function()
|
||||
return table.concat(vim.fn.readfile(full), "\n")
|
||||
end)
|
||||
if ok and content then
|
||||
table.insert(attached, {
|
||||
path = vim.fn.fnamemodify(full, ":~:."),
|
||||
full_path = full,
|
||||
content = content,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Retry automatically with same prompt but attached files
|
||||
local new_event = vim.deepcopy(result.original_event or event)
|
||||
new_event.id = nil
|
||||
new_event.attached_files = attached
|
||||
new_event.attempt_count = 0
|
||||
new_event.status = nil
|
||||
queue.enqueue(new_event)
|
||||
|
||||
queue.update_status(event.id, "needs_context", { response = result.response })
|
||||
return
|
||||
end
|
||||
|
||||
-- If no files parsed, open modal for manual context entry
|
||||
context_modal.open(result.original_event or event, result.response or "", retry_with_context)
|
||||
|
||||
-- Mark original event as needing context (not failed)
|
||||
@@ -180,7 +328,7 @@ local function handle_worker_result(event, result)
|
||||
-- Failed - try escalation if this was ollama
|
||||
if result.worker_type == "ollama" and event.attempt_count < 2 then
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format(
|
||||
@@ -210,7 +358,7 @@ local function handle_worker_result(event, result)
|
||||
if needs_escalation and result.worker_type == "ollama" and event.attempt_count < 2 then
|
||||
-- Low confidence from ollama - escalate to remote
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format(
|
||||
@@ -235,7 +383,7 @@ local function handle_worker_result(event, result)
|
||||
-- Schedule patch application after delay (gives user time to review/cancel)
|
||||
local delay = state.config.apply_delay_ms or 5000
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Code ready. Applying in %.1f seconds...", delay / 1000),
|
||||
@@ -268,7 +416,7 @@ local function dispatch_next()
|
||||
local should_skip, skip_reason = queue.check_precedence(event)
|
||||
if should_skip then
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "warning",
|
||||
message = string.format("Event %s skipped: %s", event.id, skip_reason or "conflict"),
|
||||
@@ -284,7 +432,7 @@ local function dispatch_next()
|
||||
|
||||
-- Log dispatch with intent/scope info
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
local intent_info = event.intent and event.intent.type or "unknown"
|
||||
local scope_info = event.scope and event.scope.type ~= "file"
|
||||
and string.format("%s:%s", event.scope.type, event.scope.name or "anon")
|
||||
@@ -323,10 +471,10 @@ function M.schedule_patch_flush()
|
||||
local safe, reason = M.is_safe_to_inject()
|
||||
if safe then
|
||||
waiting_to_flush = false
|
||||
local applied, stale = patch.flush_pending()
|
||||
local applied, stale = patch.flush_pending_smart()
|
||||
if applied > 0 or stale > 0 then
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Patches flushed: %d applied, %d stale", applied, stale),
|
||||
@@ -339,7 +487,7 @@ function M.schedule_patch_flush()
|
||||
if not waiting_to_flush then
|
||||
waiting_to_flush = true
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = "Waiting for user to finish typing before applying code...",
|
||||
@@ -384,7 +532,7 @@ local function setup_autocmds()
|
||||
callback = function()
|
||||
vim.defer_fn(function()
|
||||
if not M.is_completion_visible() then
|
||||
patch.flush_pending()
|
||||
patch.flush_pending_smart()
|
||||
end
|
||||
end, state.config.completion_delay_ms)
|
||||
end,
|
||||
@@ -396,7 +544,7 @@ local function setup_autocmds()
|
||||
group = augroup,
|
||||
callback = function()
|
||||
if not M.is_insert_mode() and not M.is_completion_visible() then
|
||||
patch.flush_pending()
|
||||
patch.flush_pending_smart()
|
||||
end
|
||||
end,
|
||||
desc = "Flush pending patches on CursorHold",
|
||||
@@ -480,7 +628,7 @@ function M.start(config)
|
||||
scheduler_loop()
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = "Scheduler started",
|
||||
@@ -512,7 +660,7 @@ function M.stop()
|
||||
end
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = "Scheduler stopped",
|
||||
@@ -565,7 +713,7 @@ end
|
||||
|
||||
--- Force flush all pending patches (ignores completion check)
|
||||
function M.force_flush()
|
||||
return patch.flush_pending()
|
||||
return patch.flush_pending_smart()
|
||||
end
|
||||
|
||||
--- Update configuration
|
||||
@@ -576,4 +724,33 @@ function M.configure(config)
|
||||
end
|
||||
end
|
||||
|
||||
--- Queue a prompt for processing
|
||||
--- This is a convenience function that creates a proper PromptEvent and enqueues it
|
||||
---@param opts table Prompt options
|
||||
--- - bufnr: number Source buffer number
|
||||
--- - filepath: string Source file path
|
||||
--- - target_path: string Target file for injection (can be same as filepath)
|
||||
--- - prompt_content: string The cleaned prompt text
|
||||
--- - range: {start_line: number, end_line: number} Line range of prompt tag
|
||||
--- - source: string|nil Source identifier (e.g., "transform_command", "autocmd")
|
||||
--- - priority: number|nil Priority (1=high, 2=normal, 3=low) default 2
|
||||
---@return table The enqueued event
|
||||
function M.queue_prompt(opts)
|
||||
-- Build the PromptEvent structure
|
||||
local event = {
|
||||
bufnr = opts.bufnr,
|
||||
filepath = opts.filepath,
|
||||
target_path = opts.target_path or opts.filepath,
|
||||
prompt_content = opts.prompt_content,
|
||||
range = opts.range,
|
||||
priority = opts.priority or 2,
|
||||
source = opts.source or "manual",
|
||||
-- Capture buffer state for staleness detection
|
||||
changedtick = vim.api.nvim_buf_get_changedtick(opts.bufnr),
|
||||
}
|
||||
|
||||
-- Enqueue through the queue module
|
||||
return queue.enqueue(event)
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -6,7 +6,8 @@
|
||||
|
||||
local M = {}
|
||||
|
||||
local confidence = require("codetyper.agent.confidence")
|
||||
local params = require("codetyper.params.agents.worker")
|
||||
local confidence = require("codetyper.core.llm.confidence")
|
||||
|
||||
---@class WorkerResult
|
||||
---@field success boolean Whether the request succeeded
|
||||
@@ -32,20 +33,7 @@ local confidence = require("codetyper.agent.confidence")
|
||||
local worker_counter = 0
|
||||
|
||||
--- Patterns that indicate LLM needs more context (must be near start of response)
|
||||
local context_needed_patterns = {
|
||||
"^%s*i need more context",
|
||||
"^%s*i'm sorry.-i need more",
|
||||
"^%s*i apologize.-i need more",
|
||||
"^%s*could you provide more context",
|
||||
"^%s*could you please provide more",
|
||||
"^%s*can you clarify",
|
||||
"^%s*please provide more context",
|
||||
"^%s*more information needed",
|
||||
"^%s*not enough context",
|
||||
"^%s*i don't have enough",
|
||||
"^%s*unclear what you",
|
||||
"^%s*what do you mean by",
|
||||
}
|
||||
local context_needed_patterns = params.context_needed_patterns
|
||||
|
||||
--- Check if response indicates need for more context
|
||||
--- Only triggers if the response primarily asks for context (no substantial code)
|
||||
@@ -83,6 +71,19 @@ local function needs_more_context(response)
|
||||
return false
|
||||
end
|
||||
|
||||
--- Check if response contains SEARCH/REPLACE blocks
|
||||
---@param response string
|
||||
---@return boolean
|
||||
local function has_search_replace_blocks(response)
|
||||
if not response then
|
||||
return false
|
||||
end
|
||||
-- Check for any of the supported SEARCH/REPLACE formats
|
||||
return response:match("<<<<<<<%s*SEARCH") ~= nil
|
||||
or response:match("%-%-%-%-%-%-%-?%s*SEARCH") ~= nil
|
||||
or response:match("%[SEARCH%]") ~= nil
|
||||
end
|
||||
|
||||
--- Clean LLM response to extract only code
|
||||
---@param response string Raw LLM response
|
||||
---@param filetype string|nil File type for language detection
|
||||
@@ -107,6 +108,13 @@ local function clean_response(response, filetype)
|
||||
-- Use [%s%S] to match any character including newlines (Lua's . doesn't match newlines)
|
||||
cleaned = cleaned:gsub("/@[%s%S]-@/", "")
|
||||
|
||||
-- IMPORTANT: If response contains SEARCH/REPLACE blocks, preserve them!
|
||||
-- Don't extract from markdown or remove "explanations" that are actually part of the format
|
||||
if has_search_replace_blocks(cleaned) then
|
||||
-- Just trim whitespace and return - the blocks will be parsed by search_replace module
|
||||
return cleaned:match("^%s*(.-)%s*$") or cleaned
|
||||
end
|
||||
|
||||
-- Try to extract code from markdown code blocks
|
||||
-- Match ```language\n...\n``` or just ```\n...\n```
|
||||
local code_block = cleaned:match("```[%w]*\n(.-)\n```")
|
||||
@@ -176,13 +184,7 @@ end
|
||||
local active_workers = {}
|
||||
|
||||
--- Default timeouts by provider type
|
||||
local default_timeouts = {
|
||||
ollama = 30000, -- 30s for local
|
||||
claude = 60000, -- 60s for remote
|
||||
openai = 60000,
|
||||
gemini = 60000,
|
||||
copilot = 60000,
|
||||
}
|
||||
local default_timeouts = params.default_timeouts
|
||||
|
||||
--- Generate worker ID
|
||||
---@return string
|
||||
@@ -225,29 +227,260 @@ local function format_attached_files(attached_files)
|
||||
return table.concat(parts, "")
|
||||
end
|
||||
|
||||
--- Get coder companion file path for a target file
|
||||
---@param target_path string Target file path
|
||||
---@return string|nil Coder file path if exists
|
||||
local function get_coder_companion_path(target_path)
|
||||
if not target_path or target_path == "" then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Skip if target is already a coder file
|
||||
if target_path:match("%.coder%.") then
|
||||
return nil
|
||||
end
|
||||
|
||||
local dir = vim.fn.fnamemodify(target_path, ":h")
|
||||
local name = vim.fn.fnamemodify(target_path, ":t:r") -- filename without extension
|
||||
local ext = vim.fn.fnamemodify(target_path, ":e")
|
||||
|
||||
local coder_path = dir .. "/" .. name .. ".coder." .. ext
|
||||
if vim.fn.filereadable(coder_path) == 1 then
|
||||
return coder_path
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
--- Read and format coder companion context (business logic, pseudo-code)
|
||||
---@param target_path string Target file path
|
||||
---@return string Formatted coder context
|
||||
local function get_coder_context(target_path)
|
||||
local coder_path = get_coder_companion_path(target_path)
|
||||
if not coder_path then
|
||||
return ""
|
||||
end
|
||||
|
||||
local ok, lines = pcall(function()
|
||||
return vim.fn.readfile(coder_path)
|
||||
end)
|
||||
|
||||
if not ok or not lines or #lines == 0 then
|
||||
return ""
|
||||
end
|
||||
|
||||
local content = table.concat(lines, "\n")
|
||||
|
||||
-- Skip if only template comments (no actual content)
|
||||
local stripped = content:gsub("^%s*", ""):gsub("%s*$", "")
|
||||
if stripped == "" then
|
||||
return ""
|
||||
end
|
||||
|
||||
-- Check if there's meaningful content (not just template)
|
||||
local has_content = false
|
||||
for _, line in ipairs(lines) do
|
||||
-- Skip comment lines that are part of the template
|
||||
local trimmed = line:gsub("^%s*", "")
|
||||
if not trimmed:match("^[%-#/]+%s*Coder companion")
|
||||
and not trimmed:match("^[%-#/]+%s*Use /@ @/")
|
||||
and not trimmed:match("^[%-#/]+%s*Example:")
|
||||
and not trimmed:match("^<!%-%-")
|
||||
and trimmed ~= ""
|
||||
and not trimmed:match("^[%-#/]+%s*$") then
|
||||
has_content = true
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if not has_content then
|
||||
return ""
|
||||
end
|
||||
|
||||
local ext = vim.fn.fnamemodify(coder_path, ":e")
|
||||
return string.format(
|
||||
"\n\n--- Business Context / Pseudo-code ---\n" ..
|
||||
"The following describes the intended behavior and design for this file:\n" ..
|
||||
"```%s\n%s\n```",
|
||||
ext,
|
||||
content:sub(1, 4000) -- Limit to 4000 chars
|
||||
)
|
||||
end
|
||||
|
||||
--- Format indexed project context for inclusion in prompt
|
||||
---@param indexed_context table|nil
|
||||
---@return string
|
||||
local function format_indexed_context(indexed_context)
|
||||
if not indexed_context then
|
||||
return ""
|
||||
end
|
||||
|
||||
local parts = {}
|
||||
|
||||
-- Project type
|
||||
if indexed_context.project_type and indexed_context.project_type ~= "unknown" then
|
||||
table.insert(parts, "Project type: " .. indexed_context.project_type)
|
||||
end
|
||||
|
||||
-- Relevant symbols
|
||||
if indexed_context.relevant_symbols then
|
||||
local symbol_list = {}
|
||||
for symbol, files in pairs(indexed_context.relevant_symbols) do
|
||||
if #files > 0 then
|
||||
table.insert(symbol_list, symbol .. " (in " .. files[1] .. ")")
|
||||
end
|
||||
end
|
||||
if #symbol_list > 0 then
|
||||
table.insert(parts, "Relevant symbols: " .. table.concat(symbol_list, ", "))
|
||||
end
|
||||
end
|
||||
|
||||
-- Learned patterns
|
||||
if indexed_context.patterns and #indexed_context.patterns > 0 then
|
||||
local pattern_list = {}
|
||||
for i, p in ipairs(indexed_context.patterns) do
|
||||
if i <= 3 then
|
||||
table.insert(pattern_list, p.content or "")
|
||||
end
|
||||
end
|
||||
if #pattern_list > 0 then
|
||||
table.insert(parts, "Project conventions: " .. table.concat(pattern_list, "; "))
|
||||
end
|
||||
end
|
||||
|
||||
if #parts == 0 then
|
||||
return ""
|
||||
end
|
||||
|
||||
return "\n\n--- Project Context ---\n" .. table.concat(parts, "\n")
|
||||
end
|
||||
|
||||
--- Check if this is an inline prompt (tags in target file, not a coder file)
|
||||
---@param event table
|
||||
---@return boolean
|
||||
local function is_inline_prompt(event)
|
||||
-- Inline prompts have a range with start_line/end_line from tag detection
|
||||
-- and the source file is the same as target (not a .coder. file)
|
||||
if not event.range or not event.range.start_line then
|
||||
return false
|
||||
end
|
||||
-- Check if source path (if any) equals target, or if target has no .coder. in it
|
||||
local target = event.target_path or ""
|
||||
if target:match("%.coder%.") then
|
||||
return false
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
--- Build file content with marked region for inline prompts
|
||||
---@param lines string[] File lines
|
||||
---@param start_line number 1-indexed
|
||||
---@param end_line number 1-indexed
|
||||
---@param prompt_content string The prompt inside the tags
|
||||
---@return string
|
||||
local function build_marked_file_content(lines, start_line, end_line, prompt_content)
|
||||
local result = {}
|
||||
for i, line in ipairs(lines) do
|
||||
if i == start_line then
|
||||
-- Mark the start of the region to be replaced
|
||||
table.insert(result, ">>> REPLACE THIS REGION (lines " .. start_line .. "-" .. end_line .. ") <<<")
|
||||
table.insert(result, "--- User request: " .. prompt_content:gsub("\n", " "):sub(1, 100) .. " ---")
|
||||
end
|
||||
table.insert(result, line)
|
||||
if i == end_line then
|
||||
table.insert(result, ">>> END OF REGION TO REPLACE <<<")
|
||||
end
|
||||
end
|
||||
return table.concat(result, "\n")
|
||||
end
|
||||
|
||||
--- Build prompt for code generation
|
||||
---@param event table PromptEvent
|
||||
---@return string prompt
|
||||
---@return table context
|
||||
local function build_prompt(event)
|
||||
local intent_mod = require("codetyper.agent.intent")
|
||||
local intent_mod = require("codetyper.core.intent")
|
||||
|
||||
-- Get target file content for context
|
||||
local target_content = ""
|
||||
local target_lines = {}
|
||||
if event.target_path then
|
||||
local ok, lines = pcall(function()
|
||||
return vim.fn.readfile(event.target_path)
|
||||
end)
|
||||
if ok and lines then
|
||||
target_lines = lines
|
||||
target_content = table.concat(lines, "\n")
|
||||
end
|
||||
end
|
||||
|
||||
local filetype = vim.fn.fnamemodify(event.target_path or "", ":e")
|
||||
|
||||
-- Get indexed project context
|
||||
local indexed_context = nil
|
||||
local indexed_content = ""
|
||||
pcall(function()
|
||||
local indexer = require("codetyper.features.indexer")
|
||||
indexed_context = indexer.get_context_for({
|
||||
file = event.target_path,
|
||||
intent = event.intent,
|
||||
prompt = event.prompt_content,
|
||||
scope = event.scope_text,
|
||||
})
|
||||
indexed_content = format_indexed_context(indexed_context)
|
||||
end)
|
||||
|
||||
-- Format attached files
|
||||
local attached_content = format_attached_files(event.attached_files)
|
||||
|
||||
-- Get coder companion context (business logic, pseudo-code)
|
||||
local coder_context = get_coder_context(event.target_path)
|
||||
|
||||
-- Get brain memories - contextual recall based on current task
|
||||
local brain_context = ""
|
||||
pcall(function()
|
||||
local brain = require("codetyper.core.memory")
|
||||
if brain.is_initialized() then
|
||||
-- Query brain for relevant memories based on:
|
||||
-- 1. Current file (file-specific patterns)
|
||||
-- 2. Prompt content (semantic similarity)
|
||||
-- 3. Intent type (relevant past generations)
|
||||
local query_text = event.prompt_content or ""
|
||||
if event.scope and event.scope.name then
|
||||
query_text = event.scope.name .. " " .. query_text
|
||||
end
|
||||
|
||||
local result = brain.query({
|
||||
query = query_text,
|
||||
file = event.target_path,
|
||||
max_results = 5,
|
||||
types = { "pattern", "correction", "convention" },
|
||||
})
|
||||
|
||||
if result and result.nodes and #result.nodes > 0 then
|
||||
local memories = { "\n\n--- Learned Patterns & Conventions ---" }
|
||||
for _, node in ipairs(result.nodes) do
|
||||
if node.c then
|
||||
local summary = node.c.s or ""
|
||||
local detail = node.c.d or ""
|
||||
if summary ~= "" then
|
||||
table.insert(memories, "• " .. summary)
|
||||
if detail ~= "" and #detail < 200 then
|
||||
table.insert(memories, " " .. detail)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
if #memories > 1 then
|
||||
brain_context = table.concat(memories, "\n")
|
||||
end
|
||||
end
|
||||
end
|
||||
end)
|
||||
|
||||
-- Combine all context sources: brain memories first, then coder context, attached files, indexed
|
||||
local extra_context = brain_context .. coder_context .. attached_content .. indexed_content
|
||||
|
||||
-- Build context with scope information
|
||||
local context = {
|
||||
target_path = event.target_path,
|
||||
@@ -258,6 +491,7 @@ local function build_prompt(event)
|
||||
scope_range = event.scope_range,
|
||||
intent = event.intent,
|
||||
attached_files = event.attached_files,
|
||||
indexed_context = indexed_context,
|
||||
}
|
||||
|
||||
-- Build the actual prompt based on intent and scope
|
||||
@@ -268,6 +502,93 @@ local function build_prompt(event)
|
||||
system_prompt = intent_mod.get_prompt_modifier(event.intent)
|
||||
end
|
||||
|
||||
-- SPECIAL HANDLING: Inline prompts with /@ ... @/ tags
|
||||
-- Uses SEARCH/REPLACE block format for reliable code editing
|
||||
if is_inline_prompt(event) and event.range and event.range.start_line then
|
||||
local start_line = event.range.start_line
|
||||
local end_line = event.range.end_line or start_line
|
||||
|
||||
-- Build full file content WITHOUT the /@ @/ tags for cleaner context
|
||||
local file_content_clean = {}
|
||||
for i, line in ipairs(target_lines) do
|
||||
-- Skip lines that are part of the tag
|
||||
if i < start_line or i > end_line then
|
||||
table.insert(file_content_clean, line)
|
||||
end
|
||||
end
|
||||
|
||||
user_prompt = string.format(
|
||||
[[You are editing a %s file: %s
|
||||
|
||||
TASK: %s
|
||||
|
||||
FULL FILE CONTENT:
|
||||
```%s
|
||||
%s
|
||||
```
|
||||
|
||||
IMPORTANT: The instruction above may ask you to make changes ANYWHERE in the file (e.g., "at the top", "after function X", etc.). Read the instruction carefully to determine WHERE to apply the change.
|
||||
|
||||
INSTRUCTIONS:
|
||||
You MUST respond using SEARCH/REPLACE blocks. This format lets you precisely specify what to find and what to replace it with.
|
||||
|
||||
FORMAT:
|
||||
<<<<<<< SEARCH
|
||||
[exact lines to find in the file - copy them exactly including whitespace]
|
||||
=======
|
||||
[new lines to replace them with]
|
||||
>>>>>>> REPLACE
|
||||
|
||||
RULES:
|
||||
1. The SEARCH section must contain EXACT lines from the file (copy-paste them)
|
||||
2. Include 2-3 context lines to uniquely identify the location
|
||||
3. The REPLACE section contains the modified code
|
||||
4. You can use multiple SEARCH/REPLACE blocks for multiple changes
|
||||
5. Preserve the original indentation style
|
||||
6. If adding new code at the start/end of file, include the first/last few lines in SEARCH
|
||||
|
||||
EXAMPLES:
|
||||
|
||||
Example 1 - Adding code at the TOP of file:
|
||||
Task: "Add a comment at the top"
|
||||
<<<<<<< SEARCH
|
||||
// existing first line
|
||||
// existing second line
|
||||
=======
|
||||
// NEW COMMENT ADDED HERE
|
||||
// existing first line
|
||||
// existing second line
|
||||
>>>>>>> REPLACE
|
||||
|
||||
Example 2 - Modifying a function:
|
||||
Task: "Add validation to setValue"
|
||||
<<<<<<< SEARCH
|
||||
export function setValue(key, value) {
|
||||
cache.set(key, value);
|
||||
}
|
||||
=======
|
||||
export function setValue(key, value) {
|
||||
if (!key) throw new Error("key required");
|
||||
cache.set(key, value);
|
||||
}
|
||||
>>>>>>> REPLACE
|
||||
|
||||
Now apply the requested changes using SEARCH/REPLACE blocks:]],
|
||||
filetype,
|
||||
vim.fn.fnamemodify(event.target_path or "", ":t"),
|
||||
event.prompt_content,
|
||||
filetype,
|
||||
table.concat(file_content_clean, "\n"):sub(1, 8000) -- Limit size
|
||||
)
|
||||
|
||||
context.system_prompt = system_prompt
|
||||
context.formatted_prompt = user_prompt
|
||||
context.is_inline_prompt = true
|
||||
context.use_search_replace = true
|
||||
|
||||
return user_prompt, context
|
||||
end
|
||||
|
||||
-- If we have a scope (function/method), include it in the prompt
|
||||
if event.scope_text and event.scope and event.scope.type ~= "file" then
|
||||
local scope_type = event.scope.type
|
||||
@@ -296,10 +617,15 @@ Return ONLY the complete %s with implementation. No explanations, no duplicates.
|
||||
scope_type,
|
||||
filetype,
|
||||
event.scope_text,
|
||||
attached_content,
|
||||
extra_context,
|
||||
event.prompt_content,
|
||||
scope_type
|
||||
)
|
||||
-- Remind the LLM not to repeat the original file content; ask for only the new/updated code or a unified diff
|
||||
user_prompt = user_prompt .. [[
|
||||
|
||||
IMPORTANT: Do NOT repeat the existing code provided above. Return ONLY the new or modified code (the updated function body). If you modify the file, prefer outputting a unified diff patch using standard diff headers (--- a/<file> / +++ b/<file> and @@ hunks). No explanations, no markdown, no code fences.
|
||||
]]
|
||||
-- For other replacement intents, provide the full scope to transform
|
||||
elseif event.intent and intent_mod.is_replacement(event.intent) then
|
||||
user_prompt = string.format(
|
||||
@@ -317,7 +643,7 @@ Return the complete transformed %s. Output only code, no explanations.]],
|
||||
filetype,
|
||||
filetype,
|
||||
event.scope_text,
|
||||
attached_content,
|
||||
extra_context,
|
||||
event.prompt_content,
|
||||
scope_type
|
||||
)
|
||||
@@ -337,9 +663,21 @@ Output only the code to insert, no explanations.]],
|
||||
scope_name,
|
||||
filetype,
|
||||
event.scope_text,
|
||||
attached_content,
|
||||
extra_context,
|
||||
event.prompt_content
|
||||
)
|
||||
|
||||
-- Remind the LLM not to repeat the full file content; ask for only the new/modified code or unified diff
|
||||
user_prompt = user_prompt .. [[
|
||||
|
||||
IMPORTANT: Do NOT repeat the full file content shown above. Return ONLY the new or modified code required to satisfy the request. If you modify the file, prefer outputting a unified diff patch using standard diff headers (--- a/<file> / +++ b/<file> and @@ hunks). No explanations, no markdown, no code fences.
|
||||
]]
|
||||
|
||||
-- Remind the LLM not to repeat the original file content; ask for only the inserted code or a unified diff
|
||||
user_prompt = user_prompt .. [[
|
||||
|
||||
IMPORTANT: Do NOT repeat the surrounding code provided above. Return ONLY the code to insert (the new snippet). If you modify multiple parts of the file, prefer outputting a unified diff patch using standard diff headers (--- a/<file> / +++ b/<file> and @@ hunks). No explanations, no markdown, no code fences.
|
||||
]]
|
||||
end
|
||||
else
|
||||
-- No scope resolved, use full file context
|
||||
@@ -357,7 +695,7 @@ Output only code, no explanations.]],
|
||||
filetype,
|
||||
filetype,
|
||||
target_content:sub(1, 4000), -- Limit context size
|
||||
attached_content,
|
||||
extra_context,
|
||||
event.prompt_content
|
||||
)
|
||||
end
|
||||
@@ -388,7 +726,7 @@ function M.create(event, worker_type, callback)
|
||||
|
||||
-- Log worker creation
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "worker",
|
||||
message = string.format("Worker %s started (%s)", worker.id, worker_type),
|
||||
@@ -418,7 +756,7 @@ function M.start(worker)
|
||||
active_workers[worker.id] = nil
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "warning",
|
||||
message = string.format("Worker %s timed out after %dms", worker.id, worker.timeout_ms),
|
||||
@@ -437,21 +775,21 @@ function M.start(worker)
|
||||
end
|
||||
end, worker.timeout_ms)
|
||||
|
||||
-- Get client and execute
|
||||
local client, client_err = get_client(worker.worker_type)
|
||||
if not client then
|
||||
M.complete(worker, nil, client_err)
|
||||
return
|
||||
end
|
||||
|
||||
local prompt, context = build_prompt(worker.event)
|
||||
|
||||
-- Call the LLM
|
||||
client.generate(prompt, context, function(response, err, usage)
|
||||
-- Check if smart selection is enabled (memory-based provider selection)
|
||||
local use_smart_selection = false
|
||||
pcall(function()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
use_smart_selection = config.llm.smart_selection ~= false -- Default to true
|
||||
end)
|
||||
|
||||
-- Define the response handler
|
||||
local function handle_response(response, err, usage_or_metadata)
|
||||
-- Cancel timeout timer
|
||||
if worker.timer then
|
||||
pcall(function()
|
||||
-- Timer might have already fired
|
||||
if type(worker.timer) == "userdata" and worker.timer.stop then
|
||||
worker.timer:stop()
|
||||
end
|
||||
@@ -462,8 +800,45 @@ function M.start(worker)
|
||||
return -- Already timed out or cancelled
|
||||
end
|
||||
|
||||
-- Extract usage from metadata if smart_generate was used
|
||||
local usage = usage_or_metadata
|
||||
if type(usage_or_metadata) == "table" and usage_or_metadata.provider then
|
||||
-- This is metadata from smart_generate
|
||||
usage = nil
|
||||
-- Update worker type to reflect actual provider used
|
||||
worker.worker_type = usage_or_metadata.provider
|
||||
-- Log if pondering occurred
|
||||
if usage_or_metadata.pondered then
|
||||
pcall(function()
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format(
|
||||
"Pondering: %s (agreement: %.0f%%)",
|
||||
usage_or_metadata.corrected and "corrected" or "validated",
|
||||
(usage_or_metadata.agreement or 1) * 100
|
||||
),
|
||||
})
|
||||
end)
|
||||
end
|
||||
end
|
||||
|
||||
M.complete(worker, response, err, usage)
|
||||
end)
|
||||
end
|
||||
|
||||
-- Use smart selection or direct client
|
||||
if use_smart_selection then
|
||||
local llm = require("codetyper.core.llm")
|
||||
llm.smart_generate(prompt, context, handle_response)
|
||||
else
|
||||
-- Get client and execute directly
|
||||
local client, client_err = get_client(worker.worker_type)
|
||||
if not client then
|
||||
M.complete(worker, nil, client_err)
|
||||
return
|
||||
end
|
||||
client.generate(prompt, context, handle_response)
|
||||
end
|
||||
end
|
||||
|
||||
--- Complete worker execution
|
||||
@@ -479,7 +854,7 @@ function M.complete(worker, response, error, usage)
|
||||
active_workers[worker.id] = nil
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "error",
|
||||
message = string.format("Worker %s failed: %s", worker.id, error),
|
||||
@@ -505,7 +880,7 @@ function M.complete(worker, response, error, usage)
|
||||
active_workers[worker.id] = nil
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Worker %s: LLM needs more context", worker.id),
|
||||
@@ -529,7 +904,7 @@ function M.complete(worker, response, error, usage)
|
||||
|
||||
-- Log the full raw LLM response (for debugging)
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "response",
|
||||
message = "--- LLM Response ---",
|
||||
@@ -550,7 +925,7 @@ function M.complete(worker, response, error, usage)
|
||||
active_workers[worker.id] = nil
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "success",
|
||||
message = string.format(
|
||||
@@ -597,7 +972,7 @@ function M.cancel(worker_id)
|
||||
active_workers[worker_id] = nil
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = string.format("Worker %s cancelled", worker_id),
|
||||
@@ -14,62 +14,10 @@ local M = {}
|
||||
---@field name string|nil Name of the function/class if available
|
||||
|
||||
--- Node types that represent function-like scopes per language
|
||||
local function_nodes = {
|
||||
-- Lua
|
||||
["function_declaration"] = "function",
|
||||
["function_definition"] = "function",
|
||||
["local_function"] = "function",
|
||||
["function"] = "function",
|
||||
|
||||
-- JavaScript/TypeScript
|
||||
["function_declaration"] = "function",
|
||||
["function_expression"] = "function",
|
||||
["arrow_function"] = "function",
|
||||
["method_definition"] = "method",
|
||||
["function"] = "function",
|
||||
|
||||
-- Python
|
||||
["function_definition"] = "function",
|
||||
["async_function_definition"] = "function",
|
||||
|
||||
-- Go
|
||||
["function_declaration"] = "function",
|
||||
["method_declaration"] = "method",
|
||||
|
||||
-- Rust
|
||||
["function_item"] = "function",
|
||||
["impl_item"] = "method",
|
||||
|
||||
-- Ruby
|
||||
["method"] = "method",
|
||||
["singleton_method"] = "method",
|
||||
|
||||
-- Java/C#
|
||||
["method_declaration"] = "method",
|
||||
["constructor_declaration"] = "method",
|
||||
|
||||
-- C/C++
|
||||
["function_definition"] = "function",
|
||||
}
|
||||
|
||||
--- Node types that represent class-like scopes
|
||||
local class_nodes = {
|
||||
["class_declaration"] = "class",
|
||||
["class_definition"] = "class",
|
||||
["class"] = "class",
|
||||
["struct_item"] = "class",
|
||||
["impl_item"] = "class",
|
||||
["interface_declaration"] = "class",
|
||||
["module"] = "class",
|
||||
}
|
||||
|
||||
--- Node types that represent block scopes
|
||||
local block_nodes = {
|
||||
["block"] = "block",
|
||||
["statement_block"] = "block",
|
||||
["compound_statement"] = "block",
|
||||
["do_block"] = "block",
|
||||
}
|
||||
local params = require("codetyper.params.agents.scope")
|
||||
local function_nodes = params.function_nodes
|
||||
local class_nodes = params.class_nodes
|
||||
local block_nodes = params.block_nodes
|
||||
|
||||
--- Check if Tree-sitter is available for buffer
|
||||
---@param bufnr number
|
||||
@@ -282,13 +230,21 @@ function M.resolve_scope_heuristic(bufnr, row, col)
|
||||
ending = nil, -- Python uses indentation
|
||||
},
|
||||
javascript = {
|
||||
start = "^%s*function%s+",
|
||||
start_alt = "^%s*const%s+%w+%s*=%s*",
|
||||
start = "^%s*export%s+function%s+",
|
||||
start_alt = "^%s*function%s+",
|
||||
start_alt2 = "^%s*export%s+const%s+%w+%s*=",
|
||||
start_alt3 = "^%s*const%s+%w+%s*=%s*",
|
||||
start_alt4 = "^%s*export%s+async%s+function%s+",
|
||||
start_alt5 = "^%s*async%s+function%s+",
|
||||
ending = "^%s*}%s*$",
|
||||
},
|
||||
typescript = {
|
||||
start = "^%s*function%s+",
|
||||
start_alt = "^%s*const%s+%w+%s*=%s*",
|
||||
start = "^%s*export%s+function%s+",
|
||||
start_alt = "^%s*function%s+",
|
||||
start_alt2 = "^%s*export%s+const%s+%w+%s*=",
|
||||
start_alt3 = "^%s*const%s+%w+%s*=%s*",
|
||||
start_alt4 = "^%s*export%s+async%s+function%s+",
|
||||
start_alt5 = "^%s*async%s+function%s+",
|
||||
ending = "^%s*}%s*$",
|
||||
},
|
||||
}
|
||||
@@ -302,8 +258,13 @@ function M.resolve_scope_heuristic(bufnr, row, col)
|
||||
local start_line = nil
|
||||
for i = row, 1, -1 do
|
||||
local line = lines[i]
|
||||
if line:match(lang_patterns.start) or
|
||||
(lang_patterns.start_alt and line:match(lang_patterns.start_alt)) then
|
||||
-- Check all start patterns
|
||||
if line:match(lang_patterns.start)
|
||||
or (lang_patterns.start_alt and line:match(lang_patterns.start_alt))
|
||||
or (lang_patterns.start_alt2 and line:match(lang_patterns.start_alt2))
|
||||
or (lang_patterns.start_alt3 and line:match(lang_patterns.start_alt3))
|
||||
or (lang_patterns.start_alt4 and line:match(lang_patterns.start_alt4))
|
||||
or (lang_patterns.start_alt5 and line:match(lang_patterns.start_alt5)) then
|
||||
start_line = i
|
||||
break
|
||||
end
|
||||
128
lua/codetyper/core/tools/base.lua
Normal file
128
lua/codetyper/core/tools/base.lua
Normal file
@@ -0,0 +1,128 @@
|
||||
---@mod codetyper.agent.tools.base Base tool definition
|
||||
---@brief [[
|
||||
--- Base metatable for all LLM tools.
|
||||
--- Tools extend this base to provide structured AI capabilities.
|
||||
---@brief ]]
|
||||
|
||||
---@class CoderToolParam
|
||||
---@field name string Parameter name
|
||||
---@field description string Parameter description
|
||||
---@field type string Parameter type ("string", "number", "boolean", "table")
|
||||
---@field optional? boolean Whether the parameter is optional
|
||||
---@field default? any Default value for optional parameters
|
||||
|
||||
---@class CoderToolReturn
|
||||
---@field name string Return value name
|
||||
---@field description string Return value description
|
||||
---@field type string Return type
|
||||
---@field optional? boolean Whether the return is optional
|
||||
|
||||
---@class CoderToolOpts
|
||||
---@field on_log? fun(message: string) Log callback
|
||||
---@field on_complete? fun(result: any, error: string|nil) Completion callback
|
||||
---@field session_ctx? table Session context
|
||||
---@field streaming? boolean Whether response is still streaming
|
||||
---@field confirm? fun(message: string, callback: fun(ok: boolean)) Confirmation callback
|
||||
|
||||
---@class CoderTool
|
||||
---@field name string Tool identifier
|
||||
---@field description string|fun(): string Tool description
|
||||
---@field params CoderToolParam[] Input parameters
|
||||
---@field returns CoderToolReturn[] Return values
|
||||
---@field requires_confirmation? boolean Whether tool needs user confirmation
|
||||
---@field func fun(input: table, opts: CoderToolOpts): any, string|nil Tool implementation
|
||||
|
||||
local M = {}
|
||||
M.__index = M
|
||||
|
||||
--- Call the tool function
|
||||
---@param opts CoderToolOpts Options for the tool call
|
||||
---@return any result
|
||||
---@return string|nil error
|
||||
function M:__call(opts, on_log, on_complete)
|
||||
return self.func(opts, on_log, on_complete)
|
||||
end
|
||||
|
||||
--- Get the tool description
|
||||
---@return string
|
||||
function M:get_description()
|
||||
if type(self.description) == "function" then
|
||||
return self.description()
|
||||
end
|
||||
return self.description
|
||||
end
|
||||
|
||||
--- Validate input against parameter schema
|
||||
---@param input table Input to validate
|
||||
---@return boolean valid
|
||||
---@return string|nil error
|
||||
function M:validate_input(input)
|
||||
if not self.params then
|
||||
return true
|
||||
end
|
||||
|
||||
for _, param in ipairs(self.params) do
|
||||
local value = input[param.name]
|
||||
|
||||
-- Check required parameters
|
||||
if not param.optional and value == nil then
|
||||
return false, string.format("Missing required parameter: %s", param.name)
|
||||
end
|
||||
|
||||
-- Type checking
|
||||
if value ~= nil then
|
||||
local actual_type = type(value)
|
||||
local expected_type = param.type
|
||||
|
||||
-- Handle special types
|
||||
if expected_type == "integer" and actual_type == "number" then
|
||||
if math.floor(value) ~= value then
|
||||
return false, string.format("Parameter %s must be an integer", param.name)
|
||||
end
|
||||
elseif expected_type ~= actual_type and expected_type ~= "any" then
|
||||
return false, string.format("Parameter %s must be %s, got %s", param.name, expected_type, actual_type)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Generate JSON schema for the tool (for LLM function calling)
|
||||
---@return table schema
|
||||
function M:to_schema()
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(self.params or {}) do
|
||||
local prop = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
|
||||
if param.default ~= nil then
|
||||
prop.default = param.default
|
||||
end
|
||||
|
||||
properties[param.name] = prop
|
||||
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
return {
|
||||
type = "function",
|
||||
function_def = {
|
||||
name = self.name,
|
||||
description = self:get_description(),
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
return M
|
||||
139
lua/codetyper/core/tools/bash.lua
Normal file
139
lua/codetyper/core/tools/bash.lua
Normal file
@@ -0,0 +1,139 @@
|
||||
---@mod codetyper.agent.tools.bash Shell command execution tool
|
||||
---@brief [[
|
||||
--- Tool for executing shell commands with safety checks.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.core.tools.base")
|
||||
local description = require("codetyper.prompts.agents.bash").description
|
||||
local params = require("codetyper.params.agents.bash").params
|
||||
local returns = require("codetyper.params.agents.bash").returns
|
||||
local BANNED_COMMANDS = require("codetyper.commands.agents.banned").BANNED_COMMANDS
|
||||
local BANNED_PATTERNS = require("codetyper.commands.agents.banned").BANNED_PATTERNS
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "bash"
|
||||
M.description = description
|
||||
M.params = params
|
||||
M.returns = returns
|
||||
M.requires_confirmation = true
|
||||
|
||||
--- Check if command is safe
|
||||
---@param command string
|
||||
---@return boolean safe
|
||||
---@return string|nil reason
|
||||
local function is_safe_command(command)
|
||||
-- Check exact matches
|
||||
for _, banned in ipairs(BANNED_COMMANDS) do
|
||||
if command == banned then
|
||||
return false, "Command is banned for safety"
|
||||
end
|
||||
end
|
||||
|
||||
-- Check patterns
|
||||
for _, pattern in ipairs(BANNED_PATTERNS) do
|
||||
if command:match(pattern) then
|
||||
return false, "Command matches banned pattern"
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
---@param input {command: string, cwd?: string, timeout?: integer}
|
||||
---@param opts CoderToolOpts
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.command then
|
||||
return nil, "command is required"
|
||||
end
|
||||
|
||||
-- Safety check
|
||||
local safe, reason = is_safe_command(input.command)
|
||||
if not safe then
|
||||
return nil, reason
|
||||
end
|
||||
|
||||
-- Confirmation required
|
||||
if M.requires_confirmation and opts.confirm then
|
||||
local confirmed = false
|
||||
local confirm_error = nil
|
||||
|
||||
opts.confirm("Execute command: " .. input.command, function(ok)
|
||||
if not ok then
|
||||
confirm_error = "User declined command execution"
|
||||
end
|
||||
confirmed = ok
|
||||
end)
|
||||
|
||||
-- Wait for confirmation (in async context, this would be handled differently)
|
||||
if confirm_error then
|
||||
return nil, confirm_error
|
||||
end
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Executing: " .. input.command)
|
||||
end
|
||||
|
||||
-- Prepare command
|
||||
local cwd = input.cwd or vim.fn.getcwd()
|
||||
local timeout = input.timeout or 120000
|
||||
|
||||
-- Execute command
|
||||
local output = ""
|
||||
local exit_code = 0
|
||||
|
||||
local job_opts = {
|
||||
command = "bash",
|
||||
args = { "-c", input.command },
|
||||
cwd = cwd,
|
||||
on_stdout = function(_, data)
|
||||
if data then
|
||||
output = output .. table.concat(data, "\n")
|
||||
end
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data then
|
||||
output = output .. table.concat(data, "\n")
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
exit_code = code
|
||||
end,
|
||||
}
|
||||
|
||||
-- Run synchronously with timeout
|
||||
local Job = require("plenary.job")
|
||||
local job = Job:new(job_opts)
|
||||
|
||||
job:sync(timeout)
|
||||
exit_code = job.code or 0
|
||||
output = table.concat(job:result() or {}, "\n")
|
||||
|
||||
-- Also get stderr
|
||||
local stderr = table.concat(job:stderr_result() or {}, "\n")
|
||||
if stderr and stderr ~= "" then
|
||||
output = output .. "\n" .. stderr
|
||||
end
|
||||
|
||||
-- Check result
|
||||
if exit_code ~= 0 then
|
||||
local error_msg = string.format("Command failed with exit code %d: %s", exit_code, output)
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, error_msg)
|
||||
end
|
||||
return nil, error_msg
|
||||
end
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(output, nil)
|
||||
end
|
||||
|
||||
return output, nil
|
||||
end
|
||||
|
||||
return M
|
||||
391
lua/codetyper/core/tools/edit.lua
Normal file
391
lua/codetyper/core/tools/edit.lua
Normal file
@@ -0,0 +1,391 @@
|
||||
---@mod codetyper.agent.tools.edit File editing tool with fallback matching
|
||||
---@brief [[
|
||||
--- Tool for making targeted edits to files using search/replace.
|
||||
--- Implements multiple fallback strategies for robust matching.
|
||||
--- Multi-strategy approach for reliable editing.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.core.tools.base")
|
||||
local description = require("codetyper.prompts.agents.edit").description
|
||||
local params = require("codetyper.params.agents.edit").params
|
||||
local returns = require("codetyper.params.agents.edit").returns
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "edit"
|
||||
M.description = description
|
||||
M.params = params
|
||||
M.returns = returns
|
||||
M.requires_confirmation = false
|
||||
|
||||
--- Normalize line endings to LF
|
||||
---@param str string
|
||||
---@return string
|
||||
local function normalize_line_endings(str)
|
||||
return str:gsub("\r\n", "\n"):gsub("\r", "\n")
|
||||
end
|
||||
|
||||
--- Strategy 1: Exact match
|
||||
---@param content string File content
|
||||
---@param old_str string String to find
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
local function exact_match(content, old_str)
|
||||
local pos = content:find(old_str, 1, true)
|
||||
if pos then
|
||||
return pos, pos + #old_str - 1
|
||||
end
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
--- Strategy 2: Whitespace-normalized match
|
||||
--- Collapses all whitespace to single spaces
|
||||
---@param content string
|
||||
---@param old_str string
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
local function whitespace_normalized_match(content, old_str)
|
||||
local function normalize_ws(s)
|
||||
return s:gsub("%s+", " "):gsub("^%s+", ""):gsub("%s+$", "")
|
||||
end
|
||||
|
||||
local norm_old = normalize_ws(old_str)
|
||||
local lines = vim.split(content, "\n")
|
||||
|
||||
-- Try to find matching block
|
||||
for i = 1, #lines do
|
||||
local block = {}
|
||||
local block_start = nil
|
||||
|
||||
for j = i, #lines do
|
||||
table.insert(block, lines[j])
|
||||
local block_text = table.concat(block, "\n")
|
||||
local norm_block = normalize_ws(block_text)
|
||||
|
||||
if norm_block == norm_old then
|
||||
-- Found match
|
||||
local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
|
||||
local start_pos = #before + (i > 1 and 2 or 1)
|
||||
local end_pos = start_pos + #block_text - 1
|
||||
return start_pos, end_pos
|
||||
end
|
||||
|
||||
-- If block is already longer than target, stop
|
||||
if #norm_block > #norm_old then
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
--- Strategy 3: Indentation-flexible match
|
||||
--- Ignores leading whitespace differences
|
||||
---@param content string
|
||||
---@param old_str string
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
local function indentation_flexible_match(content, old_str)
|
||||
local function strip_indent(s)
|
||||
local lines = vim.split(s, "\n")
|
||||
local result = {}
|
||||
for _, line in ipairs(lines) do
|
||||
table.insert(result, line:gsub("^%s+", ""))
|
||||
end
|
||||
return table.concat(result, "\n")
|
||||
end
|
||||
|
||||
local stripped_old = strip_indent(old_str)
|
||||
local lines = vim.split(content, "\n")
|
||||
local old_lines = vim.split(old_str, "\n")
|
||||
local num_old_lines = #old_lines
|
||||
|
||||
for i = 1, #lines - num_old_lines + 1 do
|
||||
local block = vim.list_slice(lines, i, i + num_old_lines - 1)
|
||||
local block_text = table.concat(block, "\n")
|
||||
|
||||
if strip_indent(block_text) == stripped_old then
|
||||
local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
|
||||
local start_pos = #before + (i > 1 and 2 or 1)
|
||||
local end_pos = start_pos + #block_text - 1
|
||||
return start_pos, end_pos
|
||||
end
|
||||
end
|
||||
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
--- Strategy 4: Line-trimmed match
|
||||
--- Trims each line before comparing
|
||||
---@param content string
|
||||
---@param old_str string
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
local function line_trimmed_match(content, old_str)
|
||||
local function trim_lines(s)
|
||||
local lines = vim.split(s, "\n")
|
||||
local result = {}
|
||||
for _, line in ipairs(lines) do
|
||||
table.insert(result, line:match("^%s*(.-)%s*$"))
|
||||
end
|
||||
return table.concat(result, "\n")
|
||||
end
|
||||
|
||||
local trimmed_old = trim_lines(old_str)
|
||||
local lines = vim.split(content, "\n")
|
||||
local old_lines = vim.split(old_str, "\n")
|
||||
local num_old_lines = #old_lines
|
||||
|
||||
for i = 1, #lines - num_old_lines + 1 do
|
||||
local block = vim.list_slice(lines, i, i + num_old_lines - 1)
|
||||
local block_text = table.concat(block, "\n")
|
||||
|
||||
if trim_lines(block_text) == trimmed_old then
|
||||
local before = table.concat(vim.list_slice(lines, 1, i - 1), "\n")
|
||||
local start_pos = #before + (i > 1 and 2 or 1)
|
||||
local end_pos = start_pos + #block_text - 1
|
||||
return start_pos, end_pos
|
||||
end
|
||||
end
|
||||
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
--- Calculate Levenshtein distance between two strings
|
||||
---@param s1 string
|
||||
---@param s2 string
|
||||
---@return number
|
||||
local function levenshtein(s1, s2)
|
||||
local len1, len2 = #s1, #s2
|
||||
local matrix = {}
|
||||
|
||||
for i = 0, len1 do
|
||||
matrix[i] = { [0] = i }
|
||||
end
|
||||
for j = 0, len2 do
|
||||
matrix[0][j] = j
|
||||
end
|
||||
|
||||
for i = 1, len1 do
|
||||
for j = 1, len2 do
|
||||
local cost = s1:sub(i, i) == s2:sub(j, j) and 0 or 1
|
||||
matrix[i][j] = math.min(matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, matrix[i - 1][j - 1] + cost)
|
||||
end
|
||||
end
|
||||
|
||||
return matrix[len1][len2]
|
||||
end
|
||||
|
||||
--- Strategy 5: Fuzzy anchor-based match
|
||||
--- Uses first and last lines as anchors, allows fuzzy matching in between
|
||||
---@param content string
|
||||
---@param old_str string
|
||||
---@param threshold? number Similarity threshold (0-1), default 0.8
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
local function fuzzy_anchor_match(content, old_str, threshold)
|
||||
threshold = threshold or 0.8
|
||||
|
||||
local old_lines = vim.split(old_str, "\n")
|
||||
if #old_lines < 2 then
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
local first_line = old_lines[1]:match("^%s*(.-)%s*$")
|
||||
local last_line = old_lines[#old_lines]:match("^%s*(.-)%s*$")
|
||||
local content_lines = vim.split(content, "\n")
|
||||
|
||||
-- Find potential start positions
|
||||
local candidates = {}
|
||||
for i, line in ipairs(content_lines) do
|
||||
local trimmed = line:match("^%s*(.-)%s*$")
|
||||
if
|
||||
trimmed == first_line
|
||||
or (
|
||||
#first_line > 0
|
||||
and 1 - (levenshtein(trimmed, first_line) / math.max(#trimmed, #first_line)) >= threshold
|
||||
)
|
||||
then
|
||||
table.insert(candidates, i)
|
||||
end
|
||||
end
|
||||
|
||||
-- For each candidate, look for matching end
|
||||
for _, start_idx in ipairs(candidates) do
|
||||
local expected_end = start_idx + #old_lines - 1
|
||||
if expected_end <= #content_lines then
|
||||
local end_line = content_lines[expected_end]:match("^%s*(.-)%s*$")
|
||||
if
|
||||
end_line == last_line
|
||||
or (
|
||||
#last_line > 0
|
||||
and 1 - (levenshtein(end_line, last_line) / math.max(#end_line, #last_line)) >= threshold
|
||||
)
|
||||
then
|
||||
-- Calculate positions
|
||||
local before = table.concat(vim.list_slice(content_lines, 1, start_idx - 1), "\n")
|
||||
local block = table.concat(vim.list_slice(content_lines, start_idx, expected_end), "\n")
|
||||
local start_pos = #before + (start_idx > 1 and 2 or 1)
|
||||
local end_pos = start_pos + #block - 1
|
||||
return start_pos, end_pos
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
--- Try all matching strategies in order
|
||||
---@param content string File content
|
||||
---@param old_str string String to find
|
||||
---@return number|nil start_pos
|
||||
---@return number|nil end_pos
|
||||
---@return string strategy_used
|
||||
local function find_match(content, old_str)
|
||||
-- Strategy 1: Exact match
|
||||
local start_pos, end_pos = exact_match(content, old_str)
|
||||
if start_pos then
|
||||
return start_pos, end_pos, "exact"
|
||||
end
|
||||
|
||||
-- Strategy 2: Whitespace-normalized
|
||||
start_pos, end_pos = whitespace_normalized_match(content, old_str)
|
||||
if start_pos then
|
||||
return start_pos, end_pos, "whitespace_normalized"
|
||||
end
|
||||
|
||||
-- Strategy 3: Indentation-flexible
|
||||
start_pos, end_pos = indentation_flexible_match(content, old_str)
|
||||
if start_pos then
|
||||
return start_pos, end_pos, "indentation_flexible"
|
||||
end
|
||||
|
||||
-- Strategy 4: Line-trimmed
|
||||
start_pos, end_pos = line_trimmed_match(content, old_str)
|
||||
if start_pos then
|
||||
return start_pos, end_pos, "line_trimmed"
|
||||
end
|
||||
|
||||
-- Strategy 5: Fuzzy anchor
|
||||
start_pos, end_pos = fuzzy_anchor_match(content, old_str)
|
||||
if start_pos then
|
||||
return start_pos, end_pos, "fuzzy_anchor"
|
||||
end
|
||||
|
||||
return nil, nil, "none"
|
||||
end
|
||||
|
||||
---@param input {path: string, old_string: string, new_string: string}
|
||||
---@param opts CoderToolOpts
|
||||
---@return boolean|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.path then
|
||||
return nil, "path is required"
|
||||
end
|
||||
if input.old_string == nil then
|
||||
return nil, "old_string is required"
|
||||
end
|
||||
if input.new_string == nil then
|
||||
return nil, "new_string is required"
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Editing file: " .. input.path)
|
||||
end
|
||||
|
||||
-- Resolve path
|
||||
local path = input.path
|
||||
if not vim.startswith(path, "/") then
|
||||
path = vim.fn.getcwd() .. "/" .. path
|
||||
end
|
||||
|
||||
-- Normalize inputs
|
||||
local old_str = normalize_line_endings(input.old_string)
|
||||
local new_str = normalize_line_endings(input.new_string)
|
||||
|
||||
-- Handle new file creation (empty old_string)
|
||||
if old_str == "" then
|
||||
-- Create parent directories
|
||||
local dir = vim.fn.fnamemodify(path, ":h")
|
||||
if vim.fn.isdirectory(dir) == 0 then
|
||||
vim.fn.mkdir(dir, "p")
|
||||
end
|
||||
|
||||
-- Write new file
|
||||
local lines = vim.split(new_str, "\n", { plain = true })
|
||||
local ok = pcall(vim.fn.writefile, lines, path)
|
||||
|
||||
if not ok then
|
||||
return nil, "Failed to create file: " .. input.path
|
||||
end
|
||||
|
||||
-- Reload buffer if open
|
||||
local bufnr = vim.fn.bufnr(path)
|
||||
if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
|
||||
vim.api.nvim_buf_call(bufnr, function()
|
||||
vim.cmd("edit!")
|
||||
end)
|
||||
end
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(true, nil)
|
||||
end
|
||||
|
||||
return true, nil
|
||||
end
|
||||
|
||||
-- Check if file exists
|
||||
if vim.fn.filereadable(path) ~= 1 then
|
||||
return nil, "File not found: " .. input.path
|
||||
end
|
||||
|
||||
-- Read current content
|
||||
local lines = vim.fn.readfile(path)
|
||||
if not lines then
|
||||
return nil, "Failed to read file: " .. input.path
|
||||
end
|
||||
|
||||
local content = normalize_line_endings(table.concat(lines, "\n"))
|
||||
|
||||
-- Find match using fallback strategies
|
||||
local start_pos, end_pos, strategy = find_match(content, old_str)
|
||||
|
||||
if not start_pos then
|
||||
return nil, "old_string not found in file (tried 5 matching strategies)"
|
||||
end
|
||||
|
||||
if opts.on_log then
|
||||
opts.on_log("Match found using strategy: " .. strategy)
|
||||
end
|
||||
|
||||
-- Perform replacement
|
||||
local new_content = content:sub(1, start_pos - 1) .. new_str .. content:sub(end_pos + 1)
|
||||
|
||||
-- Write back
|
||||
local new_lines = vim.split(new_content, "\n", { plain = true })
|
||||
local ok = pcall(vim.fn.writefile, new_lines, path)
|
||||
|
||||
if not ok then
|
||||
return nil, "Failed to write file: " .. input.path
|
||||
end
|
||||
|
||||
-- Reload buffer if open
|
||||
local bufnr = vim.fn.bufnr(path)
|
||||
if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
|
||||
vim.api.nvim_buf_call(bufnr, function()
|
||||
vim.cmd("edit!")
|
||||
end)
|
||||
end
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(true, nil)
|
||||
end
|
||||
|
||||
return true, nil
|
||||
end
|
||||
|
||||
return M
|
||||
146
lua/codetyper/core/tools/glob.lua
Normal file
146
lua/codetyper/core/tools/glob.lua
Normal file
@@ -0,0 +1,146 @@
|
||||
---@mod codetyper.agent.tools.glob File pattern matching tool
|
||||
---@brief [[
|
||||
--- Tool for finding files by glob pattern.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.core.tools.base")
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "glob"
|
||||
|
||||
M.description = [[Finds files matching a glob pattern.
|
||||
|
||||
Example patterns:
|
||||
- "**/*.lua" - All Lua files
|
||||
- "src/**/*.ts" - TypeScript files in src
|
||||
- "**/test_*.py" - Test files in Python]]
|
||||
|
||||
M.params = {
|
||||
{
|
||||
name = "pattern",
|
||||
description = "Glob pattern to match files",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "path",
|
||||
description = "Base directory to search in (default: project root)",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
{
|
||||
name = "max_results",
|
||||
description = "Maximum number of results (default: 100)",
|
||||
type = "integer",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.returns = {
|
||||
{
|
||||
name = "matches",
|
||||
description = "JSON array of matching file paths",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if glob failed",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.requires_confirmation = false
|
||||
|
||||
---@param input {pattern: string, path?: string, max_results?: integer}
|
||||
---@param opts CoderToolOpts
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.pattern then
|
||||
return nil, "pattern is required"
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Finding files: " .. input.pattern)
|
||||
end
|
||||
|
||||
-- Resolve base path
|
||||
local base_path = input.path or vim.fn.getcwd()
|
||||
if not vim.startswith(base_path, "/") then
|
||||
base_path = vim.fn.getcwd() .. "/" .. base_path
|
||||
end
|
||||
|
||||
local max_results = input.max_results or 100
|
||||
|
||||
-- Use vim.fn.glob or fd if available
|
||||
local matches = {}
|
||||
|
||||
if vim.fn.executable("fd") == 1 then
|
||||
-- Use fd for better performance
|
||||
local Job = require("plenary.job")
|
||||
|
||||
-- Convert glob to fd pattern
|
||||
local fd_pattern = input.pattern:gsub("%*%*/", ""):gsub("%*", ".*")
|
||||
|
||||
local job = Job:new({
|
||||
command = "fd",
|
||||
args = {
|
||||
"--type",
|
||||
"f",
|
||||
"--max-results",
|
||||
tostring(max_results),
|
||||
"--glob",
|
||||
input.pattern,
|
||||
base_path,
|
||||
},
|
||||
cwd = base_path,
|
||||
})
|
||||
|
||||
job:sync(30000)
|
||||
matches = job:result() or {}
|
||||
else
|
||||
-- Fallback to vim.fn.globpath
|
||||
local pattern = base_path .. "/" .. input.pattern
|
||||
local files = vim.fn.glob(pattern, false, true)
|
||||
|
||||
for i, file in ipairs(files) do
|
||||
if i > max_results then
|
||||
break
|
||||
end
|
||||
-- Make paths relative to base_path
|
||||
local relative = file:gsub("^" .. vim.pesc(base_path) .. "/", "")
|
||||
table.insert(matches, relative)
|
||||
end
|
||||
end
|
||||
|
||||
-- Clean up matches
|
||||
local cleaned = {}
|
||||
for _, match in ipairs(matches) do
|
||||
if match and match ~= "" then
|
||||
-- Make relative if absolute
|
||||
local relative = match
|
||||
if vim.startswith(match, base_path) then
|
||||
relative = match:sub(#base_path + 2)
|
||||
end
|
||||
table.insert(cleaned, relative)
|
||||
end
|
||||
end
|
||||
|
||||
-- Return as JSON
|
||||
local result = vim.json.encode({
|
||||
matches = cleaned,
|
||||
total = #cleaned,
|
||||
truncated = #cleaned >= max_results,
|
||||
})
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(result, nil)
|
||||
end
|
||||
|
||||
return result, nil
|
||||
end
|
||||
|
||||
return M
|
||||
107
lua/codetyper/core/tools/grep.lua
Normal file
107
lua/codetyper/core/tools/grep.lua
Normal file
@@ -0,0 +1,107 @@
|
||||
---@mod codetyper.agent.tools.grep Search tool
|
||||
---@brief [[
|
||||
--- Tool for searching file contents using ripgrep.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.core.tools.base")
|
||||
local description = require("codetyper.params.agents.grep").description
|
||||
local params = require("codetyper.prompts.agents.grep").params
|
||||
local returns = require("codetyper.prompts.agents.grep").returns
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "grep"
|
||||
M.description = description
|
||||
M.params = params
|
||||
M.returns = returns
|
||||
|
||||
M.requires_confirmation = false
|
||||
|
||||
---@param input {pattern: string, path?: string, include?: string, max_results?: integer}
|
||||
---@param opts CoderToolOpts
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.pattern then
|
||||
return nil, "pattern is required"
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Searching for: " .. input.pattern)
|
||||
end
|
||||
|
||||
-- Build ripgrep command
|
||||
local path = input.path or vim.fn.getcwd()
|
||||
local max_results = input.max_results or 50
|
||||
|
||||
-- Resolve path
|
||||
if not vim.startswith(path, "/") then
|
||||
path = vim.fn.getcwd() .. "/" .. path
|
||||
end
|
||||
|
||||
-- Check if ripgrep is available
|
||||
if vim.fn.executable("rg") ~= 1 then
|
||||
return nil, "ripgrep (rg) is not installed"
|
||||
end
|
||||
|
||||
-- Build command args
|
||||
local args = {
|
||||
"--json",
|
||||
"--max-count",
|
||||
tostring(max_results),
|
||||
"--no-heading",
|
||||
}
|
||||
|
||||
if input.include then
|
||||
table.insert(args, "--glob")
|
||||
table.insert(args, input.include)
|
||||
end
|
||||
|
||||
table.insert(args, input.pattern)
|
||||
table.insert(args, path)
|
||||
|
||||
-- Execute ripgrep
|
||||
local Job = require("plenary.job")
|
||||
local job = Job:new({
|
||||
command = "rg",
|
||||
args = args,
|
||||
cwd = vim.fn.getcwd(),
|
||||
})
|
||||
|
||||
job:sync(30000) -- 30 second timeout
|
||||
|
||||
local results = job:result() or {}
|
||||
local matches = {}
|
||||
|
||||
-- Parse JSON output
|
||||
for _, line in ipairs(results) do
|
||||
if line and line ~= "" then
|
||||
local ok, parsed = pcall(vim.json.decode, line)
|
||||
if ok and parsed.type == "match" then
|
||||
local data = parsed.data
|
||||
table.insert(matches, {
|
||||
file = data.path.text,
|
||||
line_number = data.line_number,
|
||||
content = data.lines.text:gsub("\n$", ""),
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Return as JSON
|
||||
local result = vim.json.encode({
|
||||
matches = matches,
|
||||
total = #matches,
|
||||
truncated = #matches >= max_results,
|
||||
})
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(result, nil)
|
||||
end
|
||||
|
||||
return result, nil
|
||||
end
|
||||
|
||||
return M
|
||||
90
lua/codetyper/core/tools/init.lua
Normal file
90
lua/codetyper/core/tools/init.lua
Normal file
@@ -0,0 +1,90 @@
|
||||
---@mod codetyper.agent.tools Tool definitions for the agent system
|
||||
---
|
||||
--- Defines available tools that the LLM can use to interact with files and system.
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Tool definitions in a provider-agnostic format
|
||||
M.definitions = require("codetyper.params.agents.tools").definitions
|
||||
|
||||
--- Convert tool definitions to Claude API format
|
||||
---@return table[] Tools in Claude's expected format
|
||||
function M.to_claude_format()
|
||||
local tools = {}
|
||||
for _, tool in pairs(M.definitions) do
|
||||
table.insert(tools, {
|
||||
name = tool.name,
|
||||
description = tool.description,
|
||||
input_schema = tool.parameters,
|
||||
})
|
||||
end
|
||||
return tools
|
||||
end
|
||||
|
||||
--- Convert tool definitions to OpenAI API format
|
||||
---@return table[] Tools in OpenAI's expected format
|
||||
function M.to_openai_format()
|
||||
local tools = {}
|
||||
for _, tool in pairs(M.definitions) do
|
||||
table.insert(tools, {
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool.name,
|
||||
description = tool.description,
|
||||
parameters = tool.parameters,
|
||||
},
|
||||
})
|
||||
end
|
||||
return tools
|
||||
end
|
||||
|
||||
--- Convert tool definitions to prompt format for Ollama
|
||||
---@return string Formatted tool descriptions for system prompt
|
||||
function M.to_prompt_format()
|
||||
local prompts = require("codetyper.prompts.agents.tools").instructions
|
||||
local lines = {
|
||||
prompts.intro,
|
||||
"",
|
||||
}
|
||||
|
||||
for _, tool in pairs(M.definitions) do
|
||||
table.insert(lines, "## " .. tool.name)
|
||||
table.insert(lines, tool.description)
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, "Parameters:")
|
||||
for prop_name, prop in pairs(tool.parameters.properties) do
|
||||
local required = vim.tbl_contains(tool.parameters.required or {}, prop_name)
|
||||
local req_str = required and " (required)" or " (optional)"
|
||||
table.insert(lines, " - " .. prop_name .. ": " .. prop.description .. req_str)
|
||||
end
|
||||
table.insert(lines, "")
|
||||
end
|
||||
|
||||
table.insert(lines, "---")
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, prompts.header)
|
||||
table.insert(lines, prompts.example)
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, prompts.footer)
|
||||
|
||||
return table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
--- Get a list of tool names
|
||||
---@return string[]
|
||||
function M.get_tool_names()
|
||||
local names = {}
|
||||
for name, _ in pairs(M.definitions) do
|
||||
table.insert(names, name)
|
||||
end
|
||||
return names
|
||||
end
|
||||
|
||||
--- Optional setup function for future extensibility
|
||||
---@param opts table|nil Configuration options
|
||||
function M.setup(opts)
|
||||
-- Currently a no-op. Plugins or tests may call setup(); keep for compatibility.
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
308
lua/codetyper/core/tools/registry.lua
Normal file
308
lua/codetyper/core/tools/registry.lua
Normal file
@@ -0,0 +1,308 @@
|
||||
---@mod codetyper.agent.tools Tool registry and orchestration
|
||||
---@brief [[
|
||||
--- Registry for LLM tools with execution and schema generation.
|
||||
--- Tool system for agent mode.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
--- Registered tools
|
||||
---@type table<string, CoderTool>
|
||||
local tools = {}
|
||||
|
||||
--- Tool execution history for current session
|
||||
---@type table[]
|
||||
local execution_history = {}
|
||||
|
||||
--- Register a tool
|
||||
---@param tool CoderTool Tool to register
|
||||
function M.register(tool)
|
||||
if not tool.name then
|
||||
error("Tool must have a name")
|
||||
end
|
||||
tools[tool.name] = tool
|
||||
end
|
||||
|
||||
--- Unregister a tool
|
||||
---@param name string Tool name
|
||||
function M.unregister(name)
|
||||
tools[name] = nil
|
||||
end
|
||||
|
||||
--- Get a tool by name
|
||||
---@param name string Tool name
|
||||
---@return CoderTool|nil
|
||||
function M.get(name)
|
||||
return tools[name]
|
||||
end
|
||||
|
||||
--- Get all registered tools
|
||||
---@return table<string, CoderTool>
|
||||
function M.get_all()
|
||||
return tools
|
||||
end
|
||||
|
||||
--- Get tools as a list
|
||||
---@param filter? fun(tool: CoderTool): boolean Optional filter function
|
||||
---@return CoderTool[]
|
||||
function M.list(filter)
|
||||
local result = {}
|
||||
for _, tool in pairs(tools) do
|
||||
if not filter or filter(tool) then
|
||||
table.insert(result, tool)
|
||||
end
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
--- Generate schemas for all tools (for LLM function calling)
|
||||
---@param filter? fun(tool: CoderTool): boolean Optional filter function
|
||||
---@return table[] schemas
|
||||
function M.get_schemas(filter)
|
||||
local schemas = {}
|
||||
for _, tool in pairs(tools) do
|
||||
if not filter or filter(tool) then
|
||||
if tool.to_schema then
|
||||
table.insert(schemas, tool:to_schema())
|
||||
end
|
||||
end
|
||||
end
|
||||
return schemas
|
||||
end
|
||||
|
||||
--- Execute a tool by name
|
||||
---@param name string Tool name
|
||||
---@param input table Input parameters
|
||||
---@param opts CoderToolOpts Execution options
|
||||
---@return any result
|
||||
---@return string|nil error
|
||||
function M.execute(name, input, opts)
|
||||
local tool = tools[name]
|
||||
if not tool then
|
||||
return nil, "Unknown tool: " .. name
|
||||
end
|
||||
|
||||
-- Validate input
|
||||
if tool.validate_input then
|
||||
local valid, err = tool:validate_input(input)
|
||||
if not valid then
|
||||
return nil, err
|
||||
end
|
||||
end
|
||||
|
||||
-- Log execution
|
||||
if opts.on_log then
|
||||
opts.on_log(string.format("Executing tool: %s", name))
|
||||
end
|
||||
|
||||
-- Track execution
|
||||
local execution = {
|
||||
tool = name,
|
||||
input = input,
|
||||
start_time = os.time(),
|
||||
status = "running",
|
||||
}
|
||||
table.insert(execution_history, execution)
|
||||
|
||||
-- Execute the tool
|
||||
local result, err = tool.func(input, opts)
|
||||
|
||||
-- Update execution record
|
||||
execution.end_time = os.time()
|
||||
execution.status = err and "error" or "completed"
|
||||
execution.result = result
|
||||
execution.error = err
|
||||
|
||||
return result, err
|
||||
end
|
||||
|
||||
--- Process a tool call from LLM response
|
||||
---@param tool_call table Tool call from LLM (name + input)
|
||||
---@param opts CoderToolOpts Execution options
|
||||
---@return any result
|
||||
---@return string|nil error
|
||||
function M.process_tool_call(tool_call, opts)
|
||||
local name = tool_call.name or tool_call.function_name
|
||||
local input = tool_call.input or tool_call.arguments or {}
|
||||
|
||||
-- Parse JSON arguments if string
|
||||
if type(input) == "string" then
|
||||
local ok, parsed = pcall(vim.json.decode, input)
|
||||
if ok then
|
||||
input = parsed
|
||||
else
|
||||
return nil, "Failed to parse tool arguments: " .. input
|
||||
end
|
||||
end
|
||||
|
||||
return M.execute(name, input, opts)
|
||||
end
|
||||
|
||||
--- Get execution history
|
||||
---@param limit? number Max entries to return
|
||||
---@return table[]
|
||||
function M.get_history(limit)
|
||||
if not limit then
|
||||
return execution_history
|
||||
end
|
||||
|
||||
local result = {}
|
||||
local start = math.max(1, #execution_history - limit + 1)
|
||||
for i = start, #execution_history do
|
||||
table.insert(result, execution_history[i])
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
--- Clear execution history
|
||||
function M.clear_history()
|
||||
execution_history = {}
|
||||
end
|
||||
|
||||
--- Load built-in tools
|
||||
function M.load_builtins()
|
||||
-- View file tool
|
||||
local view = require("codetyper.core.tools.view")
|
||||
M.register(view)
|
||||
|
||||
-- Bash tool
|
||||
local bash = require("codetyper.core.tools.bash")
|
||||
M.register(bash)
|
||||
|
||||
-- Grep tool
|
||||
local grep = require("codetyper.core.tools.grep")
|
||||
M.register(grep)
|
||||
|
||||
-- Glob tool
|
||||
local glob = require("codetyper.core.tools.glob")
|
||||
M.register(glob)
|
||||
|
||||
-- Write file tool
|
||||
local write = require("codetyper.core.tools.write")
|
||||
M.register(write)
|
||||
|
||||
-- Edit tool
|
||||
local edit = require("codetyper.core.tools.edit")
|
||||
M.register(edit)
|
||||
end
|
||||
|
||||
--- Initialize tools system
|
||||
function M.setup()
|
||||
M.load_builtins()
|
||||
end
|
||||
|
||||
--- Get tool definitions for LLM (lazy-loaded, OpenAI format)
|
||||
--- This is accessed as M.definitions property
|
||||
M.definitions = setmetatable({}, {
|
||||
__call = function()
|
||||
-- Ensure tools are loaded
|
||||
if vim.tbl_count(tools) == 0 then
|
||||
M.load_builtins()
|
||||
end
|
||||
return M.to_openai_format()
|
||||
end,
|
||||
__index = function(_, key)
|
||||
-- Make it work as both function and table
|
||||
if key == "get" then
|
||||
return function()
|
||||
if vim.tbl_count(tools) == 0 then
|
||||
M.load_builtins()
|
||||
end
|
||||
return M.to_openai_format()
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end,
|
||||
})
|
||||
|
||||
--- Get definitions as a function (for backwards compatibility)
|
||||
function M.get_definitions()
|
||||
if vim.tbl_count(tools) == 0 then
|
||||
M.load_builtins()
|
||||
end
|
||||
return M.to_openai_format()
|
||||
end
|
||||
|
||||
--- Convert all tools to OpenAI function calling format
|
||||
---@param filter? fun(tool: CoderTool): boolean Optional filter function
|
||||
---@return table[] OpenAI-compatible tool definitions
|
||||
function M.to_openai_format(filter)
|
||||
local openai_tools = {}
|
||||
|
||||
for _, tool in pairs(tools) do
|
||||
if not filter or filter(tool) then
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(tool.params or {}) do
|
||||
properties[param.name] = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
if param.default ~= nil then
|
||||
properties[param.name].default = param.default
|
||||
end
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
local description = type(tool.description) == "function" and tool.description() or tool.description
|
||||
|
||||
table.insert(openai_tools, {
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return openai_tools
|
||||
end
|
||||
|
||||
--- Convert all tools to Claude tool use format
|
||||
---@param filter? fun(tool: CoderTool): boolean Optional filter function
|
||||
---@return table[] Claude-compatible tool definitions
|
||||
function M.to_claude_format(filter)
|
||||
local claude_tools = {}
|
||||
|
||||
for _, tool in pairs(tools) do
|
||||
if not filter or filter(tool) then
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(tool.params or {}) do
|
||||
properties[param.name] = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
local description = type(tool.description) == "function" and tool.description() or tool.description
|
||||
|
||||
table.insert(claude_tools, {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
input_schema = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return claude_tools
|
||||
end
|
||||
|
||||
return M
|
||||
114
lua/codetyper/core/tools/view.lua
Normal file
114
lua/codetyper/core/tools/view.lua
Normal file
@@ -0,0 +1,114 @@
|
||||
---@mod codetyper.agent.tools.view File viewing tool
|
||||
---@brief [[
|
||||
--- Tool for reading file contents with line range support.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.core.tools.base")
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "view"
|
||||
|
||||
local params = require("codetyper.params.agents.view")
|
||||
local description = require("codetyper.prompts.agents.view").description
|
||||
|
||||
M.description = description
|
||||
M.params = params.params
|
||||
M.returns = params.returns
|
||||
|
||||
M.requires_confirmation = false
|
||||
|
||||
--- Maximum content size before truncation
|
||||
local MAX_CONTENT_SIZE = 200 * 1024 -- 200KB
|
||||
|
||||
---@param input {path: string, start_line?: integer, end_line?: integer}
|
||||
---@param opts CoderToolOpts
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.path then
|
||||
return nil, "path is required"
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Reading file: " .. input.path)
|
||||
end
|
||||
|
||||
-- Resolve path
|
||||
local path = input.path
|
||||
if not vim.startswith(path, "/") then
|
||||
-- Relative path - resolve from project root
|
||||
local root = vim.fn.getcwd()
|
||||
path = root .. "/" .. path
|
||||
end
|
||||
|
||||
-- Check if file exists
|
||||
local stat = vim.uv.fs_stat(path)
|
||||
if not stat then
|
||||
return nil, "File not found: " .. input.path
|
||||
end
|
||||
|
||||
if stat.type == "directory" then
|
||||
return nil, "Path is a directory: " .. input.path
|
||||
end
|
||||
|
||||
-- Read file
|
||||
local lines = vim.fn.readfile(path)
|
||||
if not lines then
|
||||
return nil, "Failed to read file: " .. input.path
|
||||
end
|
||||
|
||||
-- Apply line range
|
||||
local start_line = input.start_line or 1
|
||||
local end_line = input.end_line or #lines
|
||||
|
||||
start_line = math.max(1, start_line)
|
||||
end_line = math.min(#lines, end_line)
|
||||
|
||||
local total_lines = #lines
|
||||
local selected_lines = {}
|
||||
|
||||
for i = start_line, end_line do
|
||||
table.insert(selected_lines, lines[i])
|
||||
end
|
||||
|
||||
-- Check for truncation
|
||||
local content = table.concat(selected_lines, "\n")
|
||||
local is_truncated = false
|
||||
|
||||
if #content > MAX_CONTENT_SIZE then
|
||||
-- Truncate content
|
||||
local truncated_lines = {}
|
||||
local size = 0
|
||||
|
||||
for _, line in ipairs(selected_lines) do
|
||||
size = size + #line + 1
|
||||
if size > MAX_CONTENT_SIZE then
|
||||
is_truncated = true
|
||||
break
|
||||
end
|
||||
table.insert(truncated_lines, line)
|
||||
end
|
||||
|
||||
content = table.concat(truncated_lines, "\n")
|
||||
end
|
||||
|
||||
-- Return as JSON
|
||||
local result = vim.json.encode({
|
||||
content = content,
|
||||
total_line_count = total_lines,
|
||||
is_truncated = is_truncated,
|
||||
start_line = start_line,
|
||||
end_line = end_line,
|
||||
})
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(result, nil)
|
||||
end
|
||||
|
||||
return result, nil
|
||||
end
|
||||
|
||||
return M
|
||||
72
lua/codetyper/core/tools/write.lua
Normal file
72
lua/codetyper/core/tools/write.lua
Normal file
@@ -0,0 +1,72 @@
|
||||
---@mod codetyper.agent.tools.write File writing tool
|
||||
---@brief [[
|
||||
--- Tool for creating or overwriting files.
|
||||
---@brief ]]
|
||||
|
||||
local Base = require("codetyper.core.tools.base")
|
||||
local description = require("codetyper.prompts.agents.write").description
|
||||
local params = require("codetyper.params.agents.write")
|
||||
|
||||
---@class CoderTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "write"
|
||||
M.description = description
|
||||
M.params = params.params
|
||||
M.returns = params.returns
|
||||
|
||||
M.requires_confirmation = true
|
||||
|
||||
---@param input {path: string, content: string}
|
||||
---@param opts CoderToolOpts
|
||||
---@return boolean|nil result
|
||||
---@return string|nil error
|
||||
function M.func(input, opts)
|
||||
if not input.path then
|
||||
return nil, "path is required"
|
||||
end
|
||||
if not input.content then
|
||||
return nil, "content is required"
|
||||
end
|
||||
|
||||
-- Log the operation
|
||||
if opts.on_log then
|
||||
opts.on_log("Writing file: " .. input.path)
|
||||
end
|
||||
|
||||
-- Resolve path
|
||||
local path = input.path
|
||||
if not vim.startswith(path, "/") then
|
||||
path = vim.fn.getcwd() .. "/" .. path
|
||||
end
|
||||
|
||||
-- Create parent directories
|
||||
local dir = vim.fn.fnamemodify(path, ":h")
|
||||
if vim.fn.isdirectory(dir) == 0 then
|
||||
vim.fn.mkdir(dir, "p")
|
||||
end
|
||||
|
||||
-- Write the file
|
||||
local lines = vim.split(input.content, "\n", { plain = true })
|
||||
local ok = pcall(vim.fn.writefile, lines, path)
|
||||
|
||||
if not ok then
|
||||
return nil, "Failed to write file: " .. path
|
||||
end
|
||||
|
||||
-- Reload buffer if open
|
||||
local bufnr = vim.fn.bufnr(path)
|
||||
if bufnr ~= -1 and vim.api.nvim_buf_is_valid(bufnr) then
|
||||
vim.api.nvim_buf_call(bufnr, function()
|
||||
vim.cmd("edit!")
|
||||
end)
|
||||
end
|
||||
|
||||
if opts.on_complete then
|
||||
opts.on_complete(true, nil)
|
||||
end
|
||||
|
||||
return true, nil
|
||||
end
|
||||
|
||||
return M
|
||||
268
lua/codetyper/features/agents/context_builder.lua
Normal file
268
lua/codetyper/features/agents/context_builder.lua
Normal file
@@ -0,0 +1,268 @@
|
||||
---@mod codetyper.agent.context_builder Context builder for agent prompts
|
||||
---
|
||||
--- Builds rich context including project structure, memories, and conventions
|
||||
--- to help the LLM understand the codebase.
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
local params = require("codetyper.params.agents.context")
|
||||
|
||||
--- Get project structure as a tree string
|
||||
---@param max_depth? number Maximum depth to traverse (default: 3)
|
||||
---@param max_files? number Maximum files to show (default: 50)
|
||||
---@return string Project tree
|
||||
function M.get_project_structure(max_depth, max_files)
|
||||
max_depth = max_depth or 3
|
||||
max_files = max_files or 50
|
||||
|
||||
local root = utils.get_project_root() or vim.fn.getcwd()
|
||||
local lines = { "PROJECT STRUCTURE:", root, "" }
|
||||
local file_count = 0
|
||||
|
||||
-- Common ignore patterns
|
||||
local ignore_patterns = params.ignore_patterns
|
||||
|
||||
local function should_ignore(name)
|
||||
for _, pattern in ipairs(ignore_patterns) do
|
||||
if name:match(pattern) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
local function traverse(path, depth, prefix)
|
||||
if depth > max_depth or file_count >= max_files then
|
||||
return
|
||||
end
|
||||
|
||||
local entries = {}
|
||||
local handle = vim.loop.fs_scandir(path)
|
||||
if not handle then
|
||||
return
|
||||
end
|
||||
|
||||
while true do
|
||||
local name, type = vim.loop.fs_scandir_next(handle)
|
||||
if not name then
|
||||
break
|
||||
end
|
||||
if not should_ignore(name) then
|
||||
table.insert(entries, { name = name, type = type })
|
||||
end
|
||||
end
|
||||
|
||||
-- Sort: directories first, then alphabetically
|
||||
table.sort(entries, function(a, b)
|
||||
if a.type == "directory" and b.type ~= "directory" then
|
||||
return true
|
||||
elseif a.type ~= "directory" and b.type == "directory" then
|
||||
return false
|
||||
else
|
||||
return a.name < b.name
|
||||
end
|
||||
end)
|
||||
|
||||
for i, entry in ipairs(entries) do
|
||||
if file_count >= max_files then
|
||||
table.insert(lines, prefix .. "... (truncated)")
|
||||
return
|
||||
end
|
||||
|
||||
local is_last = (i == #entries)
|
||||
local branch = is_last and "└── " or "├── "
|
||||
local new_prefix = prefix .. (is_last and " " or "│ ")
|
||||
|
||||
local icon = entry.type == "directory" and "/" or ""
|
||||
table.insert(lines, prefix .. branch .. entry.name .. icon)
|
||||
file_count = file_count + 1
|
||||
|
||||
if entry.type == "directory" then
|
||||
traverse(path .. "/" .. entry.name, depth + 1, new_prefix)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
traverse(root, 1, "")
|
||||
|
||||
if file_count >= max_files then
|
||||
table.insert(lines, "")
|
||||
table.insert(lines, "(Structure truncated at " .. max_files .. " entries)")
|
||||
end
|
||||
|
||||
return table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
--- Get key files that are important for understanding the project
|
||||
---@return table<string, string> Map of filename to description
|
||||
function M.get_key_files()
|
||||
local root = utils.get_project_root() or vim.fn.getcwd()
|
||||
local key_files = {}
|
||||
|
||||
local important_files = {
|
||||
["package.json"] = "Node.js project config",
|
||||
["Cargo.toml"] = "Rust project config",
|
||||
["go.mod"] = "Go module config",
|
||||
["pyproject.toml"] = "Python project config",
|
||||
["setup.py"] = "Python setup config",
|
||||
["Makefile"] = "Build configuration",
|
||||
["CMakeLists.txt"] = "CMake config",
|
||||
[".gitignore"] = "Git ignore patterns",
|
||||
["README.md"] = "Project documentation",
|
||||
["init.lua"] = "Neovim plugin entry",
|
||||
["plugin.lua"] = "Neovim plugin config",
|
||||
}
|
||||
|
||||
for filename, desc in paparams.important_filesnd
|
||||
|
||||
return key_files
|
||||
end
|
||||
|
||||
--- Detect project type and language
|
||||
---@return table { type: string, language: string, framework?: string }
|
||||
function M.detect_project_type()
|
||||
local root = utils.get_project_root() or vim.fn.getcwd()
|
||||
|
||||
local indicators = {
|
||||
["package.json"] = { type = "node", language = "javascript/typescript" },
|
||||
["Cargo.toml"] = { type = "rust", language = "rust" },
|
||||
["go.mod"] = { type = "go", language = "go" },
|
||||
["pyproject.toml"] = { type = "python", language = "python" },
|
||||
["setup.py"] = { type = "python", language = "python" },
|
||||
["Gemfile"] = { type = "ruby", language = "ruby" },
|
||||
["pom.xml"] = { type = "maven", language = "java" },
|
||||
["build.gradle"] = { type = "gradle", language = "java/kotlin" },
|
||||
}
|
||||
|
||||
-- Check for Neovim plugin specifically
|
||||
if vim.fn.isdirectoparams.indicators return info
|
||||
end
|
||||
end
|
||||
|
||||
return { type = "unknown", language = "unknown" }
|
||||
end
|
||||
|
||||
--- Get memories/patterns from the brain system
|
||||
---@return string Formatted memories context
|
||||
function M.get_memories_context()
|
||||
local ok_memory, memory = pcall(require, "codetyper.indexer.memory")
|
||||
if not ok_memory then
|
||||
return ""
|
||||
end
|
||||
|
||||
local all = memory.get_all()
|
||||
if not all then
|
||||
return ""
|
||||
end
|
||||
|
||||
local lines = {}
|
||||
|
||||
-- Add patterns
|
||||
if all.patterns and next(all.patterns) then
|
||||
table.insert(lines, "LEARNED PATTERNS:")
|
||||
local count = 0
|
||||
for _, mem in pairs(all.patterns) do
|
||||
if count >= 5 then
|
||||
break
|
||||
end
|
||||
if mem.content then
|
||||
table.insert(lines, " - " .. mem.content:sub(1, 100))
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
table.insert(lines, "")
|
||||
end
|
||||
|
||||
-- Add conventions
|
||||
if all.conventions and next(all.conventions) then
|
||||
table.insert(lines, "CODING CONVENTIONS:")
|
||||
local count = 0
|
||||
for _, mem in pairs(all.conventions) do
|
||||
if count >= 5 then
|
||||
break
|
||||
end
|
||||
if mem.content then
|
||||
table.insert(lines, " - " .. mem.content:sub(1, 100))
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
table.insert(lines, "")
|
||||
end
|
||||
|
||||
return table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
--- Build the full context for agent prompts
|
||||
---@return string Full context string
|
||||
function M.build_full_context()
|
||||
local sections = {}
|
||||
|
||||
-- Project info
|
||||
local project_type = M.detect_project_type()
|
||||
table.insert(sections, string.format(
|
||||
"PROJECT INFO:\n Type: %s\n Language: %s%s\n",
|
||||
project_type.type,
|
||||
project_type.language,
|
||||
project_type.framework and ("\n Framework: " .. project_type.framework) or ""
|
||||
))
|
||||
|
||||
-- Project structure
|
||||
local structure = M.get_project_structure(3, 40)
|
||||
table.insert(sections, structure)
|
||||
|
||||
-- Key files
|
||||
local key_files = M.get_key_files()
|
||||
if next(key_files) then
|
||||
local key_lines = { "", "KEY FILES:" }
|
||||
for name, info in pairs(key_files) do
|
||||
table.insert(key_lines, string.format(" %s - %s", name, info.description))
|
||||
end
|
||||
table.insert(sections, table.concat(key_lines, "\n"))
|
||||
end
|
||||
|
||||
-- Memories
|
||||
local memories = M.get_memories_context()
|
||||
if memories ~= "" then
|
||||
table.insert(sections, "\n" .. memories)
|
||||
end
|
||||
|
||||
return table.concat(sections, "\n")
|
||||
end
|
||||
|
||||
--- Get a compact context summary for token efficiency
|
||||
---@return string Compact context
|
||||
function M.build_compact_context()
|
||||
local root = utils.get_project_root() or vim.fn.getcwd()
|
||||
local project_type = M.detect_project_type()
|
||||
|
||||
local lines = {
|
||||
"CONTEXT:",
|
||||
" Root: " .. root,
|
||||
" Type: " .. project_type.type .. " (" .. project_type.language .. ")",
|
||||
}
|
||||
|
||||
-- Add main directories
|
||||
local main_dirs = {}
|
||||
local handle = vim.loop.fs_scandir(root)
|
||||
if handle then
|
||||
while true do
|
||||
local name, type = vim.loop.fs_scandir_next(handle)
|
||||
if not name then
|
||||
break
|
||||
end
|
||||
if type == "directory" and not name:match("^%.") and not name:match("node_modules") then
|
||||
table.insert(main_dirs, name .. "/")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if #main_dirs > 0 then
|
||||
table.sort(main_dirs)
|
||||
table.insert(lines, " Main dirs: " .. table.concat(main_dirs, ", "))
|
||||
end
|
||||
|
||||
return table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
return M
|
||||
754
lua/codetyper/features/agents/engine.lua
Normal file
754
lua/codetyper/features/agents/engine.lua
Normal file
@@ -0,0 +1,754 @@
|
||||
---@mod codetyper.agent.agentic Agentic loop with proper tool calling
|
||||
---@brief [[
|
||||
--- Full agentic system that handles multi-file changes via tool calling.
|
||||
--- Multi-file agent system with tool orchestration.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class AgenticMessage
|
||||
---@field role "system"|"user"|"assistant"|"tool"
|
||||
---@field content string|table
|
||||
---@field tool_calls? table[] For assistant messages with tool calls
|
||||
---@field tool_call_id? string For tool result messages
|
||||
---@field name? string Tool name for tool results
|
||||
|
||||
---@class AgenticToolCall
|
||||
---@field id string Unique tool call ID
|
||||
---@field type "function"
|
||||
---@field function {name: string, arguments: string|table}
|
||||
|
||||
---@class AgenticOpts
|
||||
---@field task string The task to accomplish
|
||||
---@field files? string[] Initial files to include as context
|
||||
---@field agent? string Agent name to use (default: "coder")
|
||||
---@field model? string Model override
|
||||
---@field max_iterations? number Max tool call rounds (default: 20)
|
||||
---@field on_message? fun(msg: AgenticMessage) Called for each message
|
||||
---@field on_tool_start? fun(name: string, args: table) Called before tool execution
|
||||
---@field on_tool_end? fun(name: string, result: any, error: string|nil) Called after tool execution
|
||||
---@field on_file_change? fun(path: string, action: string) Called when file is modified
|
||||
---@field on_complete? fun(result: string|nil, error: string|nil) Called when done
|
||||
---@field on_status? fun(status: string) Status updates
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Load agent definition
|
||||
---@param name string Agent name
|
||||
---@return table|nil agent definition
|
||||
local function load_agent(name)
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
local agent_file = agents_dir .. "/" .. name .. ".md"
|
||||
|
||||
-- Check if custom agent exists
|
||||
if vim.fn.filereadable(agent_file) == 1 then
|
||||
local content = table.concat(vim.fn.readfile(agent_file), "\n")
|
||||
-- Parse frontmatter and content
|
||||
local frontmatter = {}
|
||||
local body = content
|
||||
|
||||
local fm_match = content:match("^%-%-%-\n(.-)%-%-%-\n(.*)$")
|
||||
if fm_match then
|
||||
-- Parse YAML-like frontmatter
|
||||
for line in content:match("^%-%-%-\n(.-)%-%-%-"):gmatch("[^\n]+") do
|
||||
local key, value = line:match("^(%w+):%s*(.+)$")
|
||||
if key and value then
|
||||
frontmatter[key] = value
|
||||
end
|
||||
end
|
||||
body = content:match("%-%-%-\n.-%-%-%-%s*\n(.*)$") or content
|
||||
end
|
||||
|
||||
return {
|
||||
name = name,
|
||||
description = frontmatter.description or "Custom agent: " .. name,
|
||||
system_prompt = body,
|
||||
tools = frontmatter.tools and vim.split(frontmatter.tools, ",") or nil,
|
||||
model = frontmatter.model,
|
||||
}
|
||||
end
|
||||
|
||||
-- Built-in agents
|
||||
local builtin_agents = require("codetyper.prompts.agents.personas").builtin
|
||||
|
||||
return builtin_agents[name]
|
||||
end
|
||||
|
||||
--- Load rules from .coder/rules/
|
||||
---@return string Combined rules content
|
||||
local function load_rules()
|
||||
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
|
||||
local rules = {}
|
||||
|
||||
if vim.fn.isdirectory(rules_dir) == 1 then
|
||||
local files = vim.fn.glob(rules_dir .. "/*.md", false, true)
|
||||
for _, file in ipairs(files) do
|
||||
local content = table.concat(vim.fn.readfile(file), "\n")
|
||||
local filename = vim.fn.fnamemodify(file, ":t:r")
|
||||
table.insert(rules, string.format("## Rule: %s\n%s", filename, content))
|
||||
end
|
||||
end
|
||||
|
||||
if #rules > 0 then
|
||||
return "\n\n# Project Rules\n" .. table.concat(rules, "\n\n")
|
||||
end
|
||||
return ""
|
||||
end
|
||||
|
||||
--- Build messages array for API request
|
||||
---@param history AgenticMessage[]
|
||||
---@param provider string "openai"|"claude"
|
||||
---@return table[] Formatted messages
|
||||
local function build_messages(history, provider)
|
||||
local messages = {}
|
||||
|
||||
for _, msg in ipairs(history) do
|
||||
if msg.role == "system" then
|
||||
if provider == "claude" then
|
||||
-- Claude uses system parameter, not message
|
||||
-- Skip system messages in array
|
||||
else
|
||||
table.insert(messages, {
|
||||
role = "system",
|
||||
content = msg.content,
|
||||
})
|
||||
end
|
||||
elseif msg.role == "user" then
|
||||
table.insert(messages, {
|
||||
role = "user",
|
||||
content = msg.content,
|
||||
})
|
||||
elseif msg.role == "assistant" then
|
||||
local message = {
|
||||
role = "assistant",
|
||||
content = msg.content,
|
||||
}
|
||||
if msg.tool_calls then
|
||||
message.tool_calls = msg.tool_calls
|
||||
if provider == "claude" then
|
||||
-- Claude format: content is array of blocks
|
||||
message.content = {}
|
||||
if msg.content and msg.content ~= "" then
|
||||
table.insert(message.content, {
|
||||
type = "text",
|
||||
text = msg.content,
|
||||
})
|
||||
end
|
||||
for _, tc in ipairs(msg.tool_calls) do
|
||||
table.insert(message.content, {
|
||||
type = "tool_use",
|
||||
id = tc.id,
|
||||
name = tc["function"].name,
|
||||
input = type(tc["function"].arguments) == "string"
|
||||
and vim.json.decode(tc["function"].arguments)
|
||||
or tc["function"].arguments,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
table.insert(messages, message)
|
||||
elseif msg.role == "tool" then
|
||||
if provider == "claude" then
|
||||
table.insert(messages, {
|
||||
role = "user",
|
||||
content = {
|
||||
{
|
||||
type = "tool_result",
|
||||
tool_use_id = msg.tool_call_id,
|
||||
content = msg.content,
|
||||
},
|
||||
},
|
||||
})
|
||||
else
|
||||
table.insert(messages, {
|
||||
role = "tool",
|
||||
tool_call_id = msg.tool_call_id,
|
||||
content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content),
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return messages
|
||||
end
|
||||
|
||||
--- Build tools array for API request
|
||||
---@param tool_names string[] Tool names to include
|
||||
---@param provider string "openai"|"claude"
|
||||
---@return table[] Formatted tools
|
||||
local function build_tools(tool_names, provider)
|
||||
local tools_mod = require("codetyper.core.tools")
|
||||
local tools = {}
|
||||
|
||||
for _, name in ipairs(tool_names) do
|
||||
local tool = tools_mod.get(name)
|
||||
if tool then
|
||||
local properties = {}
|
||||
local required = {}
|
||||
|
||||
for _, param in ipairs(tool.params or {}) do
|
||||
properties[param.name] = {
|
||||
type = param.type == "integer" and "number" or param.type,
|
||||
description = param.description,
|
||||
}
|
||||
if not param.optional then
|
||||
table.insert(required, param.name)
|
||||
end
|
||||
end
|
||||
|
||||
local description = type(tool.description) == "function" and tool.description() or tool.description
|
||||
|
||||
if provider == "claude" then
|
||||
table.insert(tools, {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
input_schema = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
})
|
||||
else
|
||||
table.insert(tools, {
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool.name,
|
||||
description = description,
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = properties,
|
||||
required = required,
|
||||
},
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return tools
|
||||
end
|
||||
|
||||
--- Execute a tool call
|
||||
---@param tool_call AgenticToolCall
|
||||
---@param opts AgenticOpts
|
||||
---@return string result
|
||||
---@return string|nil error
|
||||
local function execute_tool(tool_call, opts)
|
||||
local tools_mod = require("codetyper.core.tools")
|
||||
local name = tool_call["function"].name
|
||||
local args = tool_call["function"].arguments
|
||||
|
||||
-- Parse arguments if string
|
||||
if type(args) == "string" then
|
||||
local ok, parsed = pcall(vim.json.decode, args)
|
||||
if ok then
|
||||
args = parsed
|
||||
else
|
||||
return "", "Failed to parse tool arguments: " .. args
|
||||
end
|
||||
end
|
||||
|
||||
-- Notify tool start
|
||||
if opts.on_tool_start then
|
||||
opts.on_tool_start(name, args)
|
||||
end
|
||||
|
||||
if opts.on_status then
|
||||
opts.on_status("Executing: " .. name)
|
||||
end
|
||||
|
||||
-- Execute the tool
|
||||
local tool = tools_mod.get(name)
|
||||
if not tool then
|
||||
local err = "Unknown tool: " .. name
|
||||
if opts.on_tool_end then
|
||||
opts.on_tool_end(name, nil, err)
|
||||
end
|
||||
return "", err
|
||||
end
|
||||
|
||||
local result, err = tool.func(args, {
|
||||
on_log = function(msg)
|
||||
if opts.on_status then
|
||||
opts.on_status(msg)
|
||||
end
|
||||
end,
|
||||
})
|
||||
|
||||
-- Notify tool end
|
||||
if opts.on_tool_end then
|
||||
opts.on_tool_end(name, result, err)
|
||||
end
|
||||
|
||||
-- Track file changes
|
||||
if opts.on_file_change and (name == "write" or name == "edit") and not err then
|
||||
opts.on_file_change(args.path, name == "write" and "created" or "modified")
|
||||
end
|
||||
|
||||
if err then
|
||||
return "", err
|
||||
end
|
||||
|
||||
return type(result) == "string" and result or vim.json.encode(result), nil
|
||||
end
|
||||
|
||||
--- Parse tool calls from LLM response (unified Claude-like format)
|
||||
---@param response table Raw API response in unified format
|
||||
---@param provider string Provider name (unused, kept for signature compatibility)
|
||||
---@return AgenticToolCall[]
|
||||
local function parse_tool_calls(response, provider)
|
||||
local tool_calls = {}
|
||||
|
||||
-- Unified format: content array with tool_use blocks
|
||||
local content = response.content or {}
|
||||
for _, block in ipairs(content) do
|
||||
if block.type == "tool_use" then
|
||||
-- OpenAI expects arguments as JSON string, not table
|
||||
local args = block.input
|
||||
if type(args) == "table" then
|
||||
args = vim.json.encode(args)
|
||||
end
|
||||
|
||||
table.insert(tool_calls, {
|
||||
id = block.id or utils.generate_id("call"),
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = block.name,
|
||||
arguments = args,
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return tool_calls
|
||||
end
|
||||
|
||||
--- Extract text content from response (unified Claude-like format)
|
||||
---@param response table Raw API response in unified format
|
||||
---@param provider string Provider name (unused, kept for signature compatibility)
|
||||
---@return string
|
||||
local function extract_content(response, provider)
|
||||
local parts = {}
|
||||
for _, block in ipairs(response.content or {}) do
|
||||
if block.type == "text" then
|
||||
table.insert(parts, block.text)
|
||||
end
|
||||
end
|
||||
return table.concat(parts, "\n")
|
||||
end
|
||||
|
||||
--- Check if response indicates completion (unified Claude-like format)
|
||||
---@param response table Raw API response in unified format
|
||||
---@param provider string Provider name (unused, kept for signature compatibility)
|
||||
---@return boolean
|
||||
local function is_complete(response, provider)
|
||||
return response.stop_reason == "end_turn"
|
||||
end
|
||||
|
||||
--- Make API request to LLM with native tool calling support
|
||||
---@param messages table[] Formatted messages
|
||||
---@param tools table[] Formatted tools
|
||||
---@param system_prompt string System prompt
|
||||
---@param provider string "openai"|"claude"|"copilot"
|
||||
---@param model string Model name
|
||||
---@param callback fun(response: table|nil, error: string|nil)
|
||||
local function call_llm(messages, tools, system_prompt, provider, model, callback)
|
||||
local context = {
|
||||
language = "lua",
|
||||
file_content = "",
|
||||
prompt_type = "agent",
|
||||
project_root = vim.fn.getcwd(),
|
||||
cwd = vim.fn.getcwd(),
|
||||
}
|
||||
|
||||
-- Use native tool calling APIs
|
||||
if provider == "copilot" then
|
||||
local client = require("codetyper.core.llm.copilot")
|
||||
|
||||
-- Copilot's generate_with_tools expects messages in a specific format
|
||||
-- Convert to the format it expects
|
||||
local converted_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(converted_messages, msg)
|
||||
end
|
||||
end
|
||||
|
||||
client.generate_with_tools(converted_messages, context, tools, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Response is already in Claude-like format from the provider
|
||||
-- Convert to our internal format
|
||||
local result = {
|
||||
content = {},
|
||||
stop_reason = "end_turn",
|
||||
}
|
||||
|
||||
if response and response.content then
|
||||
for _, block in ipairs(response.content) do
|
||||
if block.type == "text" then
|
||||
table.insert(result.content, { type = "text", text = block.text })
|
||||
elseif block.type == "tool_use" then
|
||||
table.insert(result.content, {
|
||||
type = "tool_use",
|
||||
id = block.id or utils.generate_id("call"),
|
||||
name = block.name,
|
||||
input = block.input,
|
||||
})
|
||||
result.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
callback(result, nil)
|
||||
end)
|
||||
elseif provider == "openai" then
|
||||
local client = require("codetyper.core.llm.openai")
|
||||
|
||||
-- OpenAI's generate_with_tools
|
||||
local converted_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(converted_messages, msg)
|
||||
end
|
||||
end
|
||||
|
||||
client.generate_with_tools(converted_messages, context, tools, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Response is already in Claude-like format from the provider
|
||||
local result = {
|
||||
content = {},
|
||||
stop_reason = "end_turn",
|
||||
}
|
||||
|
||||
if response and response.content then
|
||||
for _, block in ipairs(response.content) do
|
||||
if block.type == "text" then
|
||||
table.insert(result.content, { type = "text", text = block.text })
|
||||
elseif block.type == "tool_use" then
|
||||
table.insert(result.content, {
|
||||
type = "tool_use",
|
||||
id = block.id or utils.generate_id("call"),
|
||||
name = block.name,
|
||||
input = block.input,
|
||||
})
|
||||
result.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
callback(result, nil)
|
||||
end)
|
||||
elseif provider == "ollama" then
|
||||
local client = require("codetyper.core.llm.ollama")
|
||||
|
||||
-- Ollama's generate_with_tools (text-based tool calling)
|
||||
local converted_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role ~= "system" then
|
||||
table.insert(converted_messages, msg)
|
||||
end
|
||||
end
|
||||
|
||||
client.generate_with_tools(converted_messages, context, tools, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Response is already in Claude-like format from the provider
|
||||
callback(response, nil)
|
||||
end)
|
||||
else
|
||||
-- Fallback for other providers (ollama, etc.) - use text-based parsing
|
||||
local client = require("codetyper.core.llm." .. provider)
|
||||
|
||||
-- Build prompt from messages
|
||||
local prompts = require("codetyper.prompts.agents")
|
||||
local prompt_parts = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role == "user" then
|
||||
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
|
||||
table.insert(prompt_parts, prompts.text_user_prefix .. content)
|
||||
elseif msg.role == "assistant" then
|
||||
local content = type(msg.content) == "string" and msg.content or vim.json.encode(msg.content)
|
||||
table.insert(prompt_parts, prompts.text_assistant_prefix .. content)
|
||||
end
|
||||
end
|
||||
|
||||
-- Add tool descriptions to prompt for text-based providers
|
||||
local tool_desc = require("codetyper.prompts.agents").tool_instructions_text
|
||||
for _, tool in ipairs(tools) do
|
||||
local name = tool.name or (tool["function"] and tool["function"].name)
|
||||
local desc = tool.description or (tool["function"] and tool["function"].description)
|
||||
if name then
|
||||
tool_desc = tool_desc .. string.format("- **%s**: %s\n", name, desc or "")
|
||||
end
|
||||
end
|
||||
|
||||
context.file_content = system_prompt .. tool_desc
|
||||
|
||||
client.generate(table.concat(prompt_parts, "\n\n"), context, function(response, err)
|
||||
if err then
|
||||
callback(nil, err)
|
||||
return
|
||||
end
|
||||
|
||||
-- Parse response for tool calls (text-based fallback)
|
||||
local result = {
|
||||
content = {},
|
||||
stop_reason = "end_turn",
|
||||
}
|
||||
|
||||
-- Extract text content
|
||||
local text_content = response
|
||||
|
||||
-- Try to extract JSON tool calls from response
|
||||
local json_match = response:match("```json%s*(%b{})%s*```")
|
||||
if json_match then
|
||||
local ok, parsed = pcall(vim.json.decode, json_match)
|
||||
if ok and parsed.tool then
|
||||
table.insert(result.content, {
|
||||
type = "tool_use",
|
||||
id = utils.generate_id("call"),
|
||||
name = parsed.tool,
|
||||
input = parsed.arguments or {},
|
||||
})
|
||||
text_content = response:gsub("```json.-```", ""):gsub("^%s+", ""):gsub("%s+$", "")
|
||||
result.stop_reason = "tool_use"
|
||||
end
|
||||
end
|
||||
|
||||
if text_content and text_content ~= "" then
|
||||
table.insert(result.content, 1, { type = "text", text = text_content })
|
||||
end
|
||||
|
||||
callback(result, nil)
|
||||
end)
|
||||
end
|
||||
end
|
||||
|
||||
--- Run the agentic loop
|
||||
---@param opts AgenticOpts
|
||||
function M.run(opts)
|
||||
-- Load agent
|
||||
local agent = load_agent(opts.agent or "coder")
|
||||
if not agent then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "Unknown agent: " .. (opts.agent or "coder"))
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Load rules
|
||||
local rules = load_rules()
|
||||
|
||||
-- Build system prompt
|
||||
local system_prompt = agent.system_prompt .. rules
|
||||
|
||||
-- Initialize message history
|
||||
---@type AgenticMessage[]
|
||||
local history = {
|
||||
{ role = "system", content = system_prompt },
|
||||
}
|
||||
|
||||
-- Add initial file context if provided
|
||||
if opts.files and #opts.files > 0 then
|
||||
local file_context = require("codetyper.prompts.agents").format_file_context(opts.files)
|
||||
table.insert(history, { role = "user", content = file_context })
|
||||
table.insert(history, { role = "assistant", content = "I've reviewed the provided files. What would you like me to do?" })
|
||||
end
|
||||
|
||||
-- Add the task
|
||||
table.insert(history, { role = "user", content = opts.task })
|
||||
|
||||
-- Determine provider
|
||||
local config = require("codetyper").get_config()
|
||||
local provider = config.llm.provider or "copilot"
|
||||
-- Note: Ollama has its own handler in call_llm, don't change it
|
||||
|
||||
-- Get tools for this agent
|
||||
local tool_names = agent.tools or { "view", "edit", "write", "grep", "glob", "bash" }
|
||||
|
||||
-- Ensure tools are loaded
|
||||
local tools_mod = require("codetyper.core.tools")
|
||||
tools_mod.setup()
|
||||
|
||||
-- Build tools for API
|
||||
local tools = build_tools(tool_names, provider)
|
||||
|
||||
-- Iteration tracking
|
||||
local iteration = 0
|
||||
local max_iterations = opts.max_iterations or 20
|
||||
|
||||
--- Process one iteration
|
||||
local function process_iteration()
|
||||
iteration = iteration + 1
|
||||
|
||||
if iteration > max_iterations then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, "Max iterations reached")
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
if opts.on_status then
|
||||
opts.on_status(string.format("Thinking... (iteration %d)", iteration))
|
||||
end
|
||||
|
||||
-- Build messages for API
|
||||
local messages = build_messages(history, provider)
|
||||
|
||||
-- Call LLM
|
||||
call_llm(messages, tools, system_prompt, provider, opts.model, function(response, err)
|
||||
if err then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(nil, err)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
-- Extract content and tool calls
|
||||
local content = extract_content(response, provider)
|
||||
local tool_calls = parse_tool_calls(response, provider)
|
||||
|
||||
-- Add assistant message to history
|
||||
local assistant_msg = {
|
||||
role = "assistant",
|
||||
content = content,
|
||||
tool_calls = #tool_calls > 0 and tool_calls or nil,
|
||||
}
|
||||
table.insert(history, assistant_msg)
|
||||
|
||||
if opts.on_message then
|
||||
opts.on_message(assistant_msg)
|
||||
end
|
||||
|
||||
-- Process tool calls if any
|
||||
if #tool_calls > 0 then
|
||||
for _, tc in ipairs(tool_calls) do
|
||||
local result, tool_err = execute_tool(tc, opts)
|
||||
|
||||
-- Add tool result to history
|
||||
local tool_msg = {
|
||||
role = "tool",
|
||||
tool_call_id = tc.id,
|
||||
name = tc["function"].name,
|
||||
content = tool_err or result,
|
||||
}
|
||||
table.insert(history, tool_msg)
|
||||
|
||||
if opts.on_message then
|
||||
opts.on_message(tool_msg)
|
||||
end
|
||||
end
|
||||
|
||||
-- Continue the loop
|
||||
vim.schedule(process_iteration)
|
||||
else
|
||||
-- No tool calls - check if complete
|
||||
if is_complete(response, provider) or content ~= "" then
|
||||
if opts.on_complete then
|
||||
opts.on_complete(content, nil)
|
||||
end
|
||||
else
|
||||
-- Continue if not explicitly complete
|
||||
vim.schedule(process_iteration)
|
||||
end
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
-- Start the loop
|
||||
process_iteration()
|
||||
end
|
||||
|
||||
--- Create default agent files in .coder/agents/
|
||||
function M.init_agents_dir()
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
vim.fn.mkdir(agents_dir, "p")
|
||||
|
||||
-- Create example agent
|
||||
local example_agent = require("codetyper.prompts.agents.templates").agent
|
||||
|
||||
local example_path = agents_dir .. "/example.md"
|
||||
if vim.fn.filereadable(example_path) ~= 1 then
|
||||
vim.fn.writefile(vim.split(example_agent, "\n"), example_path)
|
||||
end
|
||||
|
||||
return agents_dir
|
||||
end
|
||||
|
||||
--- Create default rules in .coder/rules/
|
||||
function M.init_rules_dir()
|
||||
local rules_dir = vim.fn.getcwd() .. "/.coder/rules"
|
||||
vim.fn.mkdir(rules_dir, "p")
|
||||
|
||||
-- Create example rule
|
||||
local example_rule = require("codetyper.prompts.agents.templates").rule
|
||||
|
||||
local example_path = rules_dir .. "/code-style.md"
|
||||
if vim.fn.filereadable(example_path) ~= 1 then
|
||||
vim.fn.writefile(vim.split(example_rule, "\n"), example_path)
|
||||
end
|
||||
|
||||
return rules_dir
|
||||
end
|
||||
|
||||
--- Initialize both agents and rules directories
|
||||
function M.init()
|
||||
M.init_agents_dir()
|
||||
M.init_rules_dir()
|
||||
end
|
||||
|
||||
--- List available agents
|
||||
---@return table[] List of {name, description, builtin}
|
||||
function M.list_agents()
|
||||
local agents = {}
|
||||
|
||||
-- Built-in agents
|
||||
local personas = require("codetyper.prompts.agents.personas").builtin
|
||||
local builtins = vim.tbl_keys(personas)
|
||||
table.sort(builtins)
|
||||
|
||||
for _, name in ipairs(builtins) do
|
||||
local agent = load_agent(name)
|
||||
if agent then
|
||||
table.insert(agents, {
|
||||
name = agent.name,
|
||||
description = agent.description,
|
||||
builtin = true,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Custom agents from .coder/agents/
|
||||
local agents_dir = vim.fn.getcwd() .. "/.coder/agents"
|
||||
if vim.fn.isdirectory(agents_dir) == 1 then
|
||||
local files = vim.fn.glob(agents_dir .. "/*.md", false, true)
|
||||
for _, file in ipairs(files) do
|
||||
local name = vim.fn.fnamemodify(file, ":t:r")
|
||||
if not vim.tbl_contains(builtins, name) then
|
||||
local agent = load_agent(name)
|
||||
if agent then
|
||||
table.insert(agents, {
|
||||
name = agent.name,
|
||||
description = agent.description,
|
||||
builtin = false,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return agents
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -4,12 +4,14 @@
|
||||
|
||||
local M = {}
|
||||
|
||||
local tools = require("codetyper.agent.tools")
|
||||
local executor = require("codetyper.agent.executor")
|
||||
local parser = require("codetyper.agent.parser")
|
||||
local diff = require("codetyper.agent.diff")
|
||||
local utils = require("codetyper.utils")
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local tools = require("codetyper.core.tools")
|
||||
local executor = require("codetyper.core.scheduler.executor")
|
||||
local parser = require("codetyper.core.llm.parser")
|
||||
local diff = require("codetyper.core.diff.diff")
|
||||
local diff_review = require("codetyper.adapters.nvim.ui.diff_review")
|
||||
local resume = require("codetyper.core.scheduler.resume")
|
||||
local utils = require("codetyper.support.utils")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
|
||||
---@class AgentState
|
||||
---@field conversation table[] Message history for multi-turn
|
||||
@@ -21,8 +23,11 @@ local state = {
|
||||
conversation = {},
|
||||
pending_tool_results = {},
|
||||
is_running = false,
|
||||
max_iterations = 10,
|
||||
max_iterations = 25, -- Increased for complex tasks (env setup, tests, fixes)
|
||||
current_iteration = 0,
|
||||
original_prompt = "", -- Store for resume functionality
|
||||
current_context = nil, -- Store context for resume
|
||||
current_callbacks = nil, -- Store callbacks for continue
|
||||
}
|
||||
|
||||
---@class AgentCallbacks
|
||||
@@ -38,6 +43,8 @@ function M.reset()
|
||||
state.pending_tool_results = {}
|
||||
state.is_running = false
|
||||
state.current_iteration = 0
|
||||
-- Clear collected diffs
|
||||
diff_review.clear()
|
||||
end
|
||||
|
||||
--- Check if agent is currently running
|
||||
@@ -67,6 +74,9 @@ function M.run(prompt, context, callbacks)
|
||||
|
||||
state.is_running = true
|
||||
state.current_iteration = 0
|
||||
state.original_prompt = prompt
|
||||
state.current_context = context
|
||||
state.current_callbacks = callbacks
|
||||
|
||||
-- Add user message to conversation
|
||||
table.insert(state.conversation, {
|
||||
@@ -91,13 +101,13 @@ function M.agent_loop(context, callbacks)
|
||||
logs.info(string.format("Agent loop iteration %d/%d", state.current_iteration, state.max_iterations))
|
||||
|
||||
if state.current_iteration > state.max_iterations then
|
||||
logs.error("Max iterations reached")
|
||||
callbacks.on_error("Max iterations reached (" .. state.max_iterations .. ")")
|
||||
state.is_running = false
|
||||
logs.info("Max iterations reached, asking user to continue or stop")
|
||||
-- Ask user if they want to continue
|
||||
M.prompt_continue(context, callbacks)
|
||||
return
|
||||
end
|
||||
|
||||
local llm = require("codetyper.llm")
|
||||
local llm = require("codetyper.core.llm")
|
||||
local client = llm.get_client()
|
||||
|
||||
-- Check if client supports tools
|
||||
@@ -111,7 +121,11 @@ function M.agent_loop(context, callbacks)
|
||||
logs.thinking("Calling LLM with " .. #state.conversation .. " messages...")
|
||||
|
||||
-- Generate with tools enabled
|
||||
client.generate_with_tools(state.conversation, context, tools.definitions, function(response, err)
|
||||
-- Ensure tools are loaded and get definitions
|
||||
tools.setup()
|
||||
local tool_defs = tools.to_openai_format()
|
||||
|
||||
client.generate_with_tools(state.conversation, context, tool_defs, function(response, err)
|
||||
if err then
|
||||
state.is_running = false
|
||||
callbacks.on_error(err)
|
||||
@@ -123,12 +137,14 @@ function M.agent_loop(context, callbacks)
|
||||
local config = codetyper.get_config()
|
||||
local parsed
|
||||
|
||||
if config.llm.provider == "claude" then
|
||||
-- Copilot uses Claude-like response format
|
||||
if config.llm.provider == "copilot" then
|
||||
parsed = parser.parse_claude_response(response)
|
||||
-- For Claude, preserve the original content array for proper tool_use handling
|
||||
table.insert(state.conversation, {
|
||||
role = "assistant",
|
||||
content = response.content, -- Keep original content blocks for Claude API
|
||||
content = parsed.text or "",
|
||||
tool_calls = parsed.tool_calls,
|
||||
_raw_content = response.content,
|
||||
})
|
||||
else
|
||||
-- For Ollama, response is the text directly
|
||||
@@ -200,11 +216,36 @@ function M.process_tool_calls(tool_calls, index, context, callbacks)
|
||||
show_fn = diff.show_diff
|
||||
end
|
||||
|
||||
show_fn(result.diff_data, function(approved)
|
||||
show_fn(result.diff_data, function(approval_result)
|
||||
-- Handle both old (boolean) and new (table) approval result formats
|
||||
local approved = type(approval_result) == "table" and approval_result.approved or approval_result
|
||||
local permission_level = type(approval_result) == "table" and approval_result.permission_level or nil
|
||||
|
||||
if approved then
|
||||
logs.tool(tool_call.name, "approved", "User approved")
|
||||
-- Apply the change
|
||||
local log_msg = "User approved"
|
||||
if permission_level == "allow_session" then
|
||||
log_msg = "Allowed for session"
|
||||
elseif permission_level == "allow_list" then
|
||||
log_msg = "Added to allow list"
|
||||
elseif permission_level == "auto" then
|
||||
log_msg = "Auto-approved"
|
||||
end
|
||||
logs.tool(tool_call.name, "approved", log_msg)
|
||||
|
||||
-- Apply the change and collect for review
|
||||
executor.apply_change(result.diff_data, function(apply_result)
|
||||
-- Collect the diff for end-of-session review
|
||||
if result.diff_data.operation ~= "bash" then
|
||||
diff_review.add({
|
||||
path = result.diff_data.path,
|
||||
operation = result.diff_data.operation,
|
||||
original = result.diff_data.original,
|
||||
modified = result.diff_data.modified,
|
||||
approved = true,
|
||||
applied = true,
|
||||
})
|
||||
end
|
||||
|
||||
-- Store result for sending back to LLM
|
||||
table.insert(state.pending_tool_results, {
|
||||
tool_use_id = tool_call.id,
|
||||
@@ -261,20 +302,16 @@ function M.continue_with_results(context, callbacks)
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
if config.llm.provider == "claude" then
|
||||
-- Claude format: tool_result blocks
|
||||
local content = {}
|
||||
-- Copilot uses OpenAI format for tool results (role: "tool")
|
||||
if config.llm.provider == "copilot" then
|
||||
-- OpenAI-style tool messages - each result is a separate message
|
||||
for _, result in ipairs(state.pending_tool_results) do
|
||||
table.insert(content, {
|
||||
type = "tool_result",
|
||||
tool_use_id = result.tool_use_id,
|
||||
table.insert(state.conversation, {
|
||||
role = "tool",
|
||||
tool_call_id = result.tool_use_id,
|
||||
content = result.result,
|
||||
})
|
||||
end
|
||||
table.insert(state.conversation, {
|
||||
role = "user",
|
||||
content = content,
|
||||
})
|
||||
else
|
||||
-- Ollama format: plain text describing results
|
||||
local result_text = "Tool results:\n"
|
||||
@@ -305,4 +342,114 @@ function M.set_max_iterations(max)
|
||||
state.max_iterations = max
|
||||
end
|
||||
|
||||
--- Get the count of collected changes
|
||||
---@return number
|
||||
function M.get_changes_count()
|
||||
return diff_review.count()
|
||||
end
|
||||
|
||||
--- Show the diff review UI for all collected changes
|
||||
function M.show_diff_review()
|
||||
diff_review.open()
|
||||
end
|
||||
|
||||
--- Check if diff review is open
|
||||
---@return boolean
|
||||
function M.is_review_open()
|
||||
return diff_review.is_open()
|
||||
end
|
||||
|
||||
--- Prompt user to continue or stop at max iterations
|
||||
---@param context table File context
|
||||
---@param callbacks AgentCallbacks
|
||||
function M.prompt_continue(context, callbacks)
|
||||
vim.schedule(function()
|
||||
vim.ui.select({ "Continue (25 more iterations)", "Stop and save for later" }, {
|
||||
prompt = string.format("Agent reached %d iterations. Continue?", state.max_iterations),
|
||||
}, function(choice)
|
||||
if choice and choice:match("^Continue") then
|
||||
-- Reset iteration counter and continue
|
||||
state.current_iteration = 0
|
||||
logs.info("User chose to continue, resetting iteration counter")
|
||||
M.agent_loop(context, callbacks)
|
||||
else
|
||||
-- Save state for later resume
|
||||
logs.info("User chose to stop, saving state for resume")
|
||||
resume.save(
|
||||
state.conversation,
|
||||
state.pending_tool_results,
|
||||
state.current_iteration,
|
||||
state.original_prompt
|
||||
)
|
||||
state.is_running = false
|
||||
callbacks.on_text("Agent paused. Use /continue to resume later.")
|
||||
callbacks.on_complete()
|
||||
end
|
||||
end)
|
||||
end)
|
||||
end
|
||||
|
||||
--- Continue a previously stopped agent session
|
||||
---@param callbacks AgentCallbacks
|
||||
---@return boolean Success
|
||||
function M.continue_session(callbacks)
|
||||
if state.is_running then
|
||||
utils.notify("Agent is already running", vim.log.levels.WARN)
|
||||
return false
|
||||
end
|
||||
|
||||
local saved = resume.load()
|
||||
if not saved then
|
||||
utils.notify("No saved agent session to continue", vim.log.levels.WARN)
|
||||
return false
|
||||
end
|
||||
|
||||
logs.info("Resuming agent session")
|
||||
logs.info(string.format("Loaded %d messages, iteration %d", #saved.conversation, saved.iteration))
|
||||
|
||||
-- Restore state
|
||||
state.conversation = saved.conversation
|
||||
state.pending_tool_results = saved.pending_tool_results or {}
|
||||
state.current_iteration = 0 -- Reset for fresh iterations
|
||||
state.original_prompt = saved.original_prompt
|
||||
state.is_running = true
|
||||
state.current_callbacks = callbacks
|
||||
|
||||
-- Build context from current state
|
||||
local llm = require("codetyper.core.llm")
|
||||
local context = {}
|
||||
local current_file = vim.fn.expand("%:p")
|
||||
if current_file ~= "" and vim.fn.filereadable(current_file) == 1 then
|
||||
context = llm.build_context(current_file, "agent")
|
||||
end
|
||||
state.current_context = context
|
||||
|
||||
-- Clear saved state
|
||||
resume.clear()
|
||||
|
||||
-- Add continuation message
|
||||
table.insert(state.conversation, {
|
||||
role = "user",
|
||||
content = "Continue where you left off. Complete the remaining tasks.",
|
||||
})
|
||||
|
||||
-- Continue the loop
|
||||
callbacks.on_text("Resuming agent session...")
|
||||
M.agent_loop(context, callbacks)
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Check if there's a saved session to continue
|
||||
---@return boolean
|
||||
function M.has_saved_session()
|
||||
return resume.has_saved_state()
|
||||
end
|
||||
|
||||
--- Get info about saved session
|
||||
---@return table|nil
|
||||
function M.get_saved_session_info()
|
||||
return resume.get_info()
|
||||
end
|
||||
|
||||
return M
|
||||
425
lua/codetyper/features/agents/linter.lua
Normal file
425
lua/codetyper/features/agents/linter.lua
Normal file
@@ -0,0 +1,425 @@
|
||||
---@mod codetyper.agent.linter Linter validation for generated code
|
||||
---@brief [[
|
||||
--- Validates generated code by checking LSP diagnostics after injection.
|
||||
--- Automatically saves the file and waits for LSP to update before checking.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local config_params = require("codetyper.params.agents.linter")
|
||||
local prompts = require("codetyper.prompts.agents.linter")
|
||||
|
||||
--- Configuration
|
||||
local config = config_params.config
|
||||
|
||||
--- Diagnostic results for tracking
|
||||
---@type table<number, table>
|
||||
local validation_results = {}
|
||||
|
||||
--- Configure linter behavior
|
||||
---@param opts table Configuration options
|
||||
function M.configure(opts)
|
||||
for k, v in pairs(opts) do
|
||||
if config[k] ~= nil then
|
||||
config[k] = v
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Get current configuration
|
||||
---@return table
|
||||
function M.get_config()
|
||||
return vim.deepcopy(config)
|
||||
end
|
||||
|
||||
--- Save buffer if modified
|
||||
---@param bufnr number Buffer number
|
||||
---@return boolean success
|
||||
local function save_buffer(bufnr)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Skip if buffer is not modified
|
||||
if not vim.bo[bufnr].modified then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Skip if buffer has no name (unsaved file)
|
||||
local bufname = vim.api.nvim_buf_get_name(bufnr)
|
||||
if bufname == "" then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Save the buffer
|
||||
local ok, err = pcall(function()
|
||||
vim.api.nvim_buf_call(bufnr, function()
|
||||
vim.cmd("silent! write")
|
||||
end)
|
||||
end)
|
||||
|
||||
if not ok then
|
||||
pcall(function()
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "warning",
|
||||
message = "Failed to save buffer: " .. tostring(err),
|
||||
})
|
||||
end)
|
||||
return false
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Get LSP diagnostics for a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@param start_line? number Start line (1-indexed)
|
||||
---@param end_line? number End line (1-indexed)
|
||||
---@return table[] diagnostics List of diagnostics
|
||||
function M.get_diagnostics(bufnr, start_line, end_line)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return {}
|
||||
end
|
||||
|
||||
local all_diagnostics = vim.diagnostic.get(bufnr)
|
||||
local filtered = {}
|
||||
|
||||
for _, diag in ipairs(all_diagnostics) do
|
||||
-- Filter by severity
|
||||
if diag.severity <= config.min_severity then
|
||||
-- Filter by line range if specified
|
||||
if start_line and end_line then
|
||||
local diag_line = diag.lnum + 1 -- Convert to 1-indexed
|
||||
if diag_line >= start_line and diag_line <= end_line then
|
||||
table.insert(filtered, diag)
|
||||
end
|
||||
else
|
||||
table.insert(filtered, diag)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return filtered
|
||||
end
|
||||
|
||||
--- Format a diagnostic for display
|
||||
---@param diag table Diagnostic object
|
||||
---@return string
|
||||
local function format_diagnostic(diag)
|
||||
local severity_names = {
|
||||
[vim.diagnostic.severity.ERROR] = "ERROR",
|
||||
[vim.diagnostic.severity.WARN] = "WARN",
|
||||
[vim.diagnostic.severity.INFO] = "INFO",
|
||||
[vim.diagnostic.severity.HINT] = "HINT",
|
||||
}
|
||||
local severity = severity_names[diag.severity] or "UNKNOWN"
|
||||
local line = diag.lnum + 1
|
||||
local source = diag.source or "lsp"
|
||||
return string.format("[%s] Line %d (%s): %s", severity, line, source, diag.message)
|
||||
end
|
||||
|
||||
--- Check if there are errors in generated code region
|
||||
---@param bufnr number Buffer number
|
||||
---@param start_line number Start line (1-indexed)
|
||||
---@param end_line number End line (1-indexed)
|
||||
---@return table result {has_errors, has_warnings, diagnostics, summary}
|
||||
function M.check_region(bufnr, start_line, end_line)
|
||||
local diagnostics = M.get_diagnostics(bufnr, start_line, end_line)
|
||||
|
||||
local errors = 0
|
||||
local warnings = 0
|
||||
|
||||
for _, diag in ipairs(diagnostics) do
|
||||
if diag.severity == vim.diagnostic.severity.ERROR then
|
||||
errors = errors + 1
|
||||
elseif diag.severity == vim.diagnostic.severity.WARN then
|
||||
warnings = warnings + 1
|
||||
end
|
||||
end
|
||||
|
||||
return {
|
||||
has_errors = errors > 0,
|
||||
has_warnings = warnings > 0,
|
||||
error_count = errors,
|
||||
warning_count = warnings,
|
||||
diagnostics = diagnostics,
|
||||
summary = string.format("%d error(s), %d warning(s)", errors, warnings),
|
||||
}
|
||||
end
|
||||
|
||||
--- Validate code after injection and report issues
|
||||
---@param bufnr number Buffer number
|
||||
---@param start_line? number Start line of injected code (1-indexed)
|
||||
---@param end_line? number End line of injected code (1-indexed)
|
||||
---@param callback? function Callback with (result) when validation completes
|
||||
function M.validate_after_injection(bufnr, start_line, end_line, callback)
|
||||
-- Save the file first
|
||||
if config.auto_save then
|
||||
save_buffer(bufnr)
|
||||
end
|
||||
|
||||
-- Wait for LSP to process changes
|
||||
vim.defer_fn(function()
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
if callback then callback(nil) end
|
||||
return
|
||||
end
|
||||
|
||||
local result
|
||||
if start_line and end_line then
|
||||
result = M.check_region(bufnr, start_line, end_line)
|
||||
else
|
||||
-- Check entire buffer
|
||||
local line_count = vim.api.nvim_buf_line_count(bufnr)
|
||||
result = M.check_region(bufnr, 1, line_count)
|
||||
end
|
||||
|
||||
-- Store result for this buffer
|
||||
validation_results[bufnr] = {
|
||||
timestamp = os.time(),
|
||||
result = result,
|
||||
start_line = start_line,
|
||||
end_line = end_line,
|
||||
}
|
||||
|
||||
-- Log results
|
||||
pcall(function()
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
if result.has_errors then
|
||||
logs.add({
|
||||
type = "error",
|
||||
message = string.format("Linter found issues: %s", result.summary),
|
||||
})
|
||||
-- Log individual errors
|
||||
for _, diag in ipairs(result.diagnostics) do
|
||||
if diag.severity == vim.diagnostic.severity.ERROR then
|
||||
logs.add({
|
||||
type = "error",
|
||||
message = format_diagnostic(diag),
|
||||
})
|
||||
end
|
||||
end
|
||||
elseif result.has_warnings then
|
||||
logs.add({
|
||||
type = "warning",
|
||||
message = string.format("Linter warnings: %s", result.summary),
|
||||
})
|
||||
else
|
||||
logs.add({
|
||||
type = "success",
|
||||
message = "Linter check passed - no errors or warnings",
|
||||
})
|
||||
end
|
||||
end)
|
||||
|
||||
-- Notify user
|
||||
if result.has_errors then
|
||||
vim.notify(
|
||||
string.format("Generated code has lint errors: %s", result.summary),
|
||||
vim.log.levels.ERROR
|
||||
)
|
||||
|
||||
-- Offer to fix if configured
|
||||
if config.auto_offer_fix and #result.diagnostics > 0 then
|
||||
M.offer_fix(bufnr, result)
|
||||
end
|
||||
elseif result.has_warnings then
|
||||
vim.notify(
|
||||
string.format("Generated code has warnings: %s", result.summary),
|
||||
vim.log.levels.WARN
|
||||
)
|
||||
end
|
||||
|
||||
if callback then
|
||||
callback(result)
|
||||
end
|
||||
end, config.diagnostic_delay_ms)
|
||||
end
|
||||
|
||||
--- Offer to fix lint errors using AI
|
||||
---@param bufnr number Buffer number
|
||||
---@param result table Validation result
|
||||
function M.offer_fix(bufnr, result)
|
||||
if not result.has_errors and not result.has_warnings then
|
||||
return
|
||||
end
|
||||
|
||||
-- Build error summary for prompt
|
||||
local error_messages = {}
|
||||
for _, diag in ipairs(result.diagnostics) do
|
||||
table.insert(error_messages, format_diagnostic(diag))
|
||||
end
|
||||
|
||||
vim.ui.select(
|
||||
{ "Yes - Auto-fix with AI", "No - I'll fix manually", "Show errors in quickfix" },
|
||||
{
|
||||
prompt = string.format("Found %d issue(s). Would you like AI to fix them?", #result.diagnostics),
|
||||
},
|
||||
function(choice)
|
||||
if not choice then return end
|
||||
|
||||
if choice:match("^Yes") then
|
||||
M.request_ai_fix(bufnr, result)
|
||||
elseif choice:match("quickfix") then
|
||||
M.show_in_quickfix(bufnr, result)
|
||||
end
|
||||
end
|
||||
)
|
||||
end
|
||||
|
||||
--- Show lint errors in quickfix list
|
||||
---@param bufnr number Buffer number
|
||||
---@param result table Validation result
|
||||
function M.show_in_quickfix(bufnr, result)
|
||||
local qf_items = {}
|
||||
local bufname = vim.api.nvim_buf_get_name(bufnr)
|
||||
|
||||
for _, diag in ipairs(result.diagnostics) do
|
||||
table.insert(qf_items, {
|
||||
bufnr = bufnr,
|
||||
filename = bufname,
|
||||
lnum = diag.lnum + 1,
|
||||
col = diag.col + 1,
|
||||
text = diag.message,
|
||||
type = diag.severity == vim.diagnostic.severity.ERROR and "E" or "W",
|
||||
})
|
||||
end
|
||||
|
||||
vim.fn.setqflist(qf_items, "r")
|
||||
vim.cmd("copen")
|
||||
end
|
||||
|
||||
--- Request AI to fix lint errors
|
||||
---@param bufnr number Buffer number
|
||||
---@param result table Validation result
|
||||
function M.request_ai_fix(bufnr, result)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return
|
||||
end
|
||||
|
||||
local filepath = vim.api.nvim_buf_get_name(bufnr)
|
||||
|
||||
-- Build fix prompt
|
||||
local error_list = {}
|
||||
for _, diag in ipairs(result.diagnostics) do
|
||||
table.insert(error_list, format_diagnostic(diag))
|
||||
end
|
||||
|
||||
-- Get the affected code region
|
||||
local start_line = result.diagnostics[1] and (result.diagnostics[1].lnum + 1) or 1
|
||||
local end_line = start_line
|
||||
for _, diag in ipairs(result.diagnostics) do
|
||||
local line = diag.lnum + 1
|
||||
if line < start_line then start_line = line end
|
||||
if line > end_line then end_line = line end
|
||||
end
|
||||
|
||||
-- Expand range by a few lines for context
|
||||
start_line = math.max(1, start_line - 5)
|
||||
end_line = math.min(vim.api.nvim_buf_line_count(bufnr), end_line + 5)
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false)
|
||||
local code_context = table.concat(lines, "\n")
|
||||
|
||||
-- Create fix prompt using inline tag
|
||||
local fix_prompt = string.format(
|
||||
prompts.fix_request,
|
||||
table.concat(error_list, "\n"),
|
||||
start_line,
|
||||
end_line,
|
||||
code_context
|
||||
)
|
||||
|
||||
-- Queue the fix through the scheduler
|
||||
local scheduler = require("codetyper.core.scheduler.scheduler")
|
||||
local queue = require("codetyper.core.events.queue")
|
||||
local patch_mod = require("codetyper.core.diff.patch")
|
||||
|
||||
-- Ensure scheduler is running
|
||||
if not scheduler.status().running then
|
||||
scheduler.start()
|
||||
end
|
||||
|
||||
-- Take snapshot
|
||||
local snapshot = patch_mod.snapshot_buffer(bufnr, {
|
||||
start_line = start_line,
|
||||
end_line = end_line,
|
||||
})
|
||||
|
||||
-- Enqueue fix request
|
||||
queue.enqueue({
|
||||
id = queue.generate_id(),
|
||||
bufnr = bufnr,
|
||||
range = { start_line = start_line, end_line = end_line },
|
||||
timestamp = os.clock(),
|
||||
changedtick = snapshot.changedtick,
|
||||
content_hash = snapshot.content_hash,
|
||||
prompt_content = fix_prompt,
|
||||
target_path = filepath,
|
||||
priority = 1, -- High priority for fixes
|
||||
status = "pending",
|
||||
attempt_count = 0,
|
||||
intent = {
|
||||
type = "fix",
|
||||
action = "replace",
|
||||
confidence = 0.9,
|
||||
},
|
||||
scope_range = { start_line = start_line, end_line = end_line },
|
||||
source = "linter_fix",
|
||||
})
|
||||
|
||||
pcall(function()
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.add({
|
||||
type = "info",
|
||||
message = "Queued AI fix request for lint errors",
|
||||
})
|
||||
end)
|
||||
|
||||
vim.notify("Queued AI fix request for lint errors", vim.log.levels.INFO)
|
||||
end
|
||||
|
||||
--- Get last validation result for a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@return table|nil result
|
||||
function M.get_last_result(bufnr)
|
||||
return validation_results[bufnr]
|
||||
end
|
||||
|
||||
--- Clear validation results for a buffer
|
||||
---@param bufnr number Buffer number
|
||||
function M.clear_result(bufnr)
|
||||
validation_results[bufnr] = nil
|
||||
end
|
||||
|
||||
--- Check if buffer has any lint errors currently
|
||||
---@param bufnr number Buffer number
|
||||
---@return boolean has_errors
|
||||
function M.has_errors(bufnr)
|
||||
local diagnostics = vim.diagnostic.get(bufnr, {
|
||||
severity = vim.diagnostic.severity.ERROR,
|
||||
})
|
||||
return #diagnostics > 0
|
||||
end
|
||||
|
||||
--- Check if buffer has any lint warnings currently
|
||||
---@param bufnr number Buffer number
|
||||
---@return boolean has_warnings
|
||||
function M.has_warnings(bufnr)
|
||||
local diagnostics = vim.diagnostic.get(bufnr, {
|
||||
severity = { min = vim.diagnostic.severity.WARN },
|
||||
})
|
||||
return #diagnostics > 0
|
||||
end
|
||||
|
||||
--- Validate all buffers with recent changes
|
||||
function M.validate_all_changed()
|
||||
for bufnr, data in pairs(validation_results) do
|
||||
if vim.api.nvim_buf_is_valid(bufnr) then
|
||||
M.validate_after_injection(bufnr, data.start_line, data.end_line)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return M
|
||||
182
lua/codetyper/features/agents/permissions.lua
Normal file
182
lua/codetyper/features/agents/permissions.lua
Normal file
@@ -0,0 +1,182 @@
|
||||
---@mod codetyper.agent.permissions Permission manager for agent actions
|
||||
---
|
||||
--- Manages permissions for bash commands and file operations with
|
||||
--- allow, allow-session, allow-list, and reject options.
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class PermissionState
|
||||
---@field session_allowed table<string, boolean> Commands allowed for this session
|
||||
---@field allow_list table<string, boolean> Patterns always allowed
|
||||
---@field deny_list table<string, boolean> Patterns always denied
|
||||
|
||||
local params = require("codetyper.params.agents.permissions")
|
||||
|
||||
local state = {
|
||||
session_allowed = {},
|
||||
allow_list = {},
|
||||
deny_list = {},
|
||||
}
|
||||
|
||||
--- Dangerous command patterns that should never be auto-allowed
|
||||
local DANGEROUS_PATTERNS = params.dangerous_patterns
|
||||
|
||||
--- Safe command patterns that can be auto-allowed
|
||||
local SAFE_PATTERNS = params.safe_patterns
|
||||
|
||||
---@alias PermissionLevel "allow"|"allow_session"|"allow_list"|"reject"
|
||||
|
||||
---@class PermissionResult
|
||||
---@field allowed boolean Whether action is allowed
|
||||
---@field reason string Reason for the decision
|
||||
---@field auto boolean Whether this was an automatic decision
|
||||
|
||||
--- Check if a command matches a pattern
|
||||
---@param command string The command to check
|
||||
---@param pattern string The pattern to match
|
||||
---@return boolean
|
||||
local function matches_pattern(command, pattern)
|
||||
return command:match(pattern) ~= nil
|
||||
end
|
||||
|
||||
--- Check if command is dangerous
|
||||
---@param command string The command to check
|
||||
---@return boolean, string|nil dangerous, reason
|
||||
local function is_dangerous(command)
|
||||
for _, pattern in ipairs(DANGEROUS_PATTERNS) do
|
||||
if matches_pattern(command, pattern) then
|
||||
return true, "Matches dangerous pattern: " .. pattern
|
||||
end
|
||||
end
|
||||
return false, nil
|
||||
end
|
||||
|
||||
--- Check if command is safe
|
||||
---@param command string The command to check
|
||||
---@return boolean
|
||||
local function is_safe(command)
|
||||
for _, pattern in ipairs(SAFE_PATTERNS) do
|
||||
if matches_pattern(command, pattern) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Normalize command for comparison (trim, lowercase first word)
|
||||
---@param command string
|
||||
---@return string
|
||||
local function normalize_command(command)
|
||||
return vim.trim(command)
|
||||
end
|
||||
|
||||
--- Check permission for a bash command
|
||||
---@param command string The command to check
|
||||
---@return PermissionResult
|
||||
function M.check_bash_permission(command)
|
||||
local normalized = normalize_command(command)
|
||||
|
||||
-- Check deny list first
|
||||
for pattern, _ in pairs(state.deny_list) do
|
||||
if matches_pattern(normalized, pattern) then
|
||||
return {
|
||||
allowed = false,
|
||||
reason = "Command in deny list",
|
||||
auto = true,
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
-- Check if command is dangerous
|
||||
local dangerous, reason = is_dangerous(normalized)
|
||||
if dangerous then
|
||||
return {
|
||||
allowed = false,
|
||||
reason = reason,
|
||||
auto = false, -- Require explicit approval for dangerous commands
|
||||
}
|
||||
end
|
||||
|
||||
-- Check session allowed
|
||||
if state.session_allowed[normalized] then
|
||||
return {
|
||||
allowed = true,
|
||||
reason = "Allowed for this session",
|
||||
auto = true,
|
||||
}
|
||||
end
|
||||
|
||||
-- Check allow list patterns
|
||||
for pattern, _ in pairs(state.allow_list) do
|
||||
if matches_pattern(normalized, pattern) then
|
||||
return {
|
||||
allowed = true,
|
||||
reason = "Matches allow list pattern",
|
||||
auto = true,
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
-- Check if command is inherently safe
|
||||
if is_safe(normalized) then
|
||||
return {
|
||||
allowed = true,
|
||||
reason = "Safe read-only command",
|
||||
auto = true,
|
||||
}
|
||||
end
|
||||
|
||||
-- Otherwise, require explicit permission
|
||||
return {
|
||||
allowed = false,
|
||||
reason = "Requires approval",
|
||||
auto = false,
|
||||
}
|
||||
end
|
||||
|
||||
--- Grant permission for a command
|
||||
---@param command string The command
|
||||
---@param level PermissionLevel The permission level
|
||||
function M.grant_permission(command, level)
|
||||
local normalized = normalize_command(command)
|
||||
|
||||
if level == "allow_session" then
|
||||
state.session_allowed[normalized] = true
|
||||
elseif level == "allow_list" then
|
||||
-- Add as pattern (escape special chars for exact match)
|
||||
local pattern = "^" .. vim.pesc(normalized) .. "$"
|
||||
state.allow_list[pattern] = true
|
||||
end
|
||||
end
|
||||
|
||||
--- Add a pattern to the allow list
|
||||
---@param pattern string Lua pattern to allow
|
||||
function M.add_to_allow_list(pattern)
|
||||
state.allow_list[pattern] = true
|
||||
end
|
||||
|
||||
--- Add a pattern to the deny list
|
||||
---@param pattern string Lua pattern to deny
|
||||
function M.add_to_deny_list(pattern)
|
||||
state.deny_list[pattern] = true
|
||||
end
|
||||
|
||||
--- Clear session permissions
|
||||
function M.clear_session()
|
||||
state.session_allowed = {}
|
||||
end
|
||||
|
||||
--- Reset all permissions
|
||||
function M.reset()
|
||||
state.session_allowed = {}
|
||||
state.allow_list = {}
|
||||
state.deny_list = {}
|
||||
end
|
||||
|
||||
--- Get current permission state (for debugging)
|
||||
---@return PermissionState
|
||||
function M.get_state()
|
||||
return vim.deepcopy(state)
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,8 +1,8 @@
|
||||
---@mod codetyper.ask Ask window for Codetyper.nvim (similar to avante.nvim)
|
||||
---@mod codetyper.ask Ask window for Codetyper.nvim
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
---@class AskState
|
||||
---@field input_buf number|nil Input buffer
|
||||
@@ -26,6 +26,7 @@ local state = {
|
||||
agent_mode = false, -- Whether agent mode is enabled (can make file changes)
|
||||
log_listener_id = nil, -- Listener ID for LLM logs
|
||||
show_logs = true, -- Whether to show LLM logs in chat
|
||||
selection_context = nil, -- Visual selection passed when opening
|
||||
}
|
||||
|
||||
--- Get the ask window configuration
|
||||
@@ -312,7 +313,9 @@ local function append_log_to_output(entry)
|
||||
}
|
||||
|
||||
local icon = icons[entry.level] or "•"
|
||||
local formatted = string.format("[%s] %s %s", entry.timestamp, icon, entry.message)
|
||||
-- Sanitize message - replace newlines with spaces to prevent nvim_buf_set_lines error
|
||||
local sanitized_message = entry.message:gsub("\n", " "):gsub("\r", "")
|
||||
local formatted = string.format("[%s] %s %s", entry.timestamp, icon, sanitized_message)
|
||||
|
||||
vim.schedule(function()
|
||||
if not state.output_buf or not vim.api.nvim_buf_is_valid(state.output_buf) then
|
||||
@@ -342,7 +345,7 @@ local function setup_log_listener()
|
||||
-- Remove existing listener if any
|
||||
if state.log_listener_id then
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.remove_listener(state.log_listener_id)
|
||||
end)
|
||||
state.log_listener_id = nil
|
||||
@@ -359,7 +362,7 @@ end
|
||||
local function remove_log_listener()
|
||||
if state.log_listener_id then
|
||||
pcall(function()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.adapters.nvim.ui.logs")
|
||||
logs.remove_listener(state.log_listener_id)
|
||||
end)
|
||||
state.log_listener_id = nil
|
||||
@@ -367,13 +370,21 @@ local function remove_log_listener()
|
||||
end
|
||||
|
||||
--- Open the ask panel
|
||||
function M.open()
|
||||
---@param selection table|nil Visual selection context {text, start_line, end_line, filepath, filename, language}
|
||||
function M.open(selection)
|
||||
-- Use the is_open() function which validates window state
|
||||
if M.is_open() then
|
||||
-- If already open and new selection provided, add it as context
|
||||
if selection and selection.text and selection.text ~= "" then
|
||||
M.add_selection_context(selection)
|
||||
end
|
||||
M.focus_input()
|
||||
return
|
||||
end
|
||||
|
||||
-- Store selection context for use in questions
|
||||
state.selection_context = selection
|
||||
|
||||
local dims = calculate_dimensions()
|
||||
|
||||
-- Store the target width
|
||||
@@ -477,6 +488,70 @@ function M.open()
|
||||
-- Focus the input window and start insert mode
|
||||
vim.api.nvim_set_current_win(state.input_win)
|
||||
vim.cmd("startinsert")
|
||||
|
||||
-- If we have a selection, show it as context
|
||||
if selection and selection.text and selection.text ~= "" then
|
||||
vim.schedule(function()
|
||||
M.add_selection_context(selection)
|
||||
end)
|
||||
end
|
||||
end
|
||||
|
||||
--- Add visual selection as context in the chat
|
||||
---@param selection table Selection info {text, start_line, end_line, filepath, filename, language}
|
||||
function M.add_selection_context(selection)
|
||||
if not state.output_buf or not vim.api.nvim_buf_is_valid(state.output_buf) then
|
||||
return
|
||||
end
|
||||
|
||||
state.selection_context = selection
|
||||
|
||||
vim.bo[state.output_buf].modifiable = true
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(state.output_buf, 0, -1, false)
|
||||
|
||||
-- Format the selection display
|
||||
local location = ""
|
||||
if selection.filename then
|
||||
location = selection.filename
|
||||
if selection.start_line then
|
||||
location = location .. ":" .. selection.start_line
|
||||
if selection.end_line and selection.end_line ~= selection.start_line then
|
||||
location = location .. "-" .. selection.end_line
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local new_lines = {
|
||||
"",
|
||||
"┌─ 📋 Selected Code ─────────────────",
|
||||
"│ " .. location,
|
||||
"│",
|
||||
}
|
||||
|
||||
-- Add the selected code with syntax hints
|
||||
local lang = selection.language or "text"
|
||||
for _, line in ipairs(vim.split(selection.text, "\n")) do
|
||||
table.insert(new_lines, "│ " .. line)
|
||||
end
|
||||
|
||||
table.insert(new_lines, "│")
|
||||
table.insert(new_lines, "└─────────────────────────────────────")
|
||||
table.insert(new_lines, "")
|
||||
table.insert(new_lines, "Ask about this code or describe what you'd like to do with it.")
|
||||
|
||||
for _, line in ipairs(new_lines) do
|
||||
table.insert(lines, line)
|
||||
end
|
||||
|
||||
vim.api.nvim_buf_set_lines(state.output_buf, 0, -1, false, lines)
|
||||
vim.bo[state.output_buf].modifiable = false
|
||||
|
||||
-- Scroll to bottom
|
||||
if state.output_win and vim.api.nvim_win_is_valid(state.output_win) then
|
||||
local line_count = vim.api.nvim_buf_line_count(state.output_buf)
|
||||
vim.api.nvim_win_set_cursor(state.output_win, { line_count, 0 })
|
||||
end
|
||||
end
|
||||
|
||||
--- Show file picker for @ mentions
|
||||
@@ -728,14 +803,17 @@ local function build_file_context()
|
||||
end
|
||||
|
||||
--- Build context for the question
|
||||
---@param intent? table Detected intent from intent module
|
||||
---@return table Context object
|
||||
local function build_context()
|
||||
local function build_context(intent)
|
||||
local context = {
|
||||
project_root = utils.get_project_root(),
|
||||
current_file = nil,
|
||||
current_content = nil,
|
||||
language = nil,
|
||||
referenced_files = state.referenced_files,
|
||||
brain_context = nil,
|
||||
indexer_context = nil,
|
||||
}
|
||||
|
||||
-- Try to get current file context from the non-ask window
|
||||
@@ -754,49 +832,140 @@ local function build_context()
|
||||
end
|
||||
end
|
||||
|
||||
-- Add brain context if intent needs it
|
||||
if intent and intent.needs_brain_context then
|
||||
local ok_brain, brain = pcall(require, "codetyper.brain")
|
||||
if ok_brain and brain.is_initialized() then
|
||||
context.brain_context = brain.get_context_for_llm({
|
||||
file = context.current_file,
|
||||
max_tokens = 1000,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Add indexer context if intent needs project-wide context
|
||||
if intent and intent.needs_project_context then
|
||||
local ok_indexer, indexer = pcall(require, "codetyper.indexer")
|
||||
if ok_indexer then
|
||||
context.indexer_context = indexer.get_context_for({
|
||||
file = context.current_file,
|
||||
prompt = "", -- Will be filled later
|
||||
intent = intent,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return context
|
||||
end
|
||||
|
||||
--- Submit the question to LLM
|
||||
function M.submit()
|
||||
local question = get_input_text()
|
||||
|
||||
if not question or question:match("^%s*$") then
|
||||
utils.notify("Please enter a question", vim.log.levels.WARN)
|
||||
M.focus_input()
|
||||
--- Append exploration log to output buffer
|
||||
---@param msg string
|
||||
---@param level string
|
||||
local function append_exploration_log(msg, level)
|
||||
if not state.output_buf or not vim.api.nvim_buf_is_valid(state.output_buf) then
|
||||
return
|
||||
end
|
||||
|
||||
-- Build context BEFORE clearing input (to preserve file references)
|
||||
local context = build_context()
|
||||
local file_context, file_count = build_file_context()
|
||||
vim.schedule(function()
|
||||
if not state.output_buf or not vim.api.nvim_buf_is_valid(state.output_buf) then
|
||||
return
|
||||
end
|
||||
|
||||
-- Build display message (without full file contents)
|
||||
local display_question = question
|
||||
if file_count > 0 then
|
||||
display_question = question .. "\n📎 " .. file_count .. " file(s) attached"
|
||||
vim.bo[state.output_buf].modifiable = true
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(state.output_buf, 0, -1, false)
|
||||
|
||||
-- Format based on level
|
||||
local formatted = msg
|
||||
if level == "progress" then
|
||||
formatted = msg
|
||||
elseif level == "debug" then
|
||||
formatted = msg
|
||||
elseif level == "file" then
|
||||
formatted = msg
|
||||
end
|
||||
|
||||
table.insert(lines, formatted)
|
||||
|
||||
vim.api.nvim_buf_set_lines(state.output_buf, 0, -1, false, lines)
|
||||
vim.bo[state.output_buf].modifiable = false
|
||||
|
||||
-- Scroll to bottom
|
||||
if state.output_win and vim.api.nvim_win_is_valid(state.output_win) then
|
||||
local line_count = vim.api.nvim_buf_line_count(state.output_buf)
|
||||
pcall(vim.api.nvim_win_set_cursor, state.output_win, { line_count, 0 })
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Continue submission after exploration
|
||||
---@param question string
|
||||
---@param intent table
|
||||
---@param context table
|
||||
---@param file_context string
|
||||
---@param file_count number
|
||||
---@param exploration_result table|nil
|
||||
local function continue_submit(question, intent, context, file_context, file_count, exploration_result)
|
||||
-- Get prompt type based on intent
|
||||
local ok_intent, intent_module = pcall(require, "codetyper.ask.intent")
|
||||
local prompt_type = "ask"
|
||||
if ok_intent then
|
||||
prompt_type = intent_module.get_prompt_type(intent)
|
||||
end
|
||||
|
||||
-- Add user message to output
|
||||
append_to_output(display_question, true)
|
||||
|
||||
-- Clear input and references AFTER building context
|
||||
M.clear_input()
|
||||
|
||||
-- Build system prompt for ask mode using prompts module
|
||||
-- Build system prompt using prompts module
|
||||
local prompts = require("codetyper.prompts")
|
||||
local system_prompt = prompts.system.ask
|
||||
local system_prompt = prompts.system[prompt_type] or prompts.system.ask
|
||||
|
||||
if context.current_file then
|
||||
system_prompt = system_prompt .. "\n\nCurrent open file: " .. context.current_file
|
||||
system_prompt = system_prompt .. "\nLanguage: " .. (context.language or "unknown")
|
||||
end
|
||||
|
||||
-- Add exploration context if available
|
||||
if exploration_result then
|
||||
local ok_explorer, explorer = pcall(require, "codetyper.ask.explorer")
|
||||
if ok_explorer then
|
||||
local explore_context = explorer.build_context(exploration_result)
|
||||
system_prompt = system_prompt .. "\n\n=== PROJECT EXPLORATION RESULTS ===\n"
|
||||
system_prompt = system_prompt .. explore_context
|
||||
system_prompt = system_prompt .. "\n=== END EXPLORATION ===\n"
|
||||
end
|
||||
end
|
||||
|
||||
-- Add brain context (learned patterns, conventions)
|
||||
if context.brain_context and context.brain_context ~= "" then
|
||||
system_prompt = system_prompt .. "\n\n=== LEARNED PROJECT KNOWLEDGE ===\n"
|
||||
system_prompt = system_prompt .. context.brain_context
|
||||
system_prompt = system_prompt .. "\n=== END LEARNED KNOWLEDGE ===\n"
|
||||
end
|
||||
|
||||
-- Add indexer context (project structure, symbols)
|
||||
if context.indexer_context then
|
||||
local idx_ctx = context.indexer_context
|
||||
if idx_ctx.project_type and idx_ctx.project_type ~= "unknown" then
|
||||
system_prompt = system_prompt .. "\n\nProject type: " .. idx_ctx.project_type
|
||||
end
|
||||
if idx_ctx.relevant_symbols and next(idx_ctx.relevant_symbols) then
|
||||
system_prompt = system_prompt .. "\n\nRelevant symbols in project:"
|
||||
for symbol, files in pairs(idx_ctx.relevant_symbols) do
|
||||
system_prompt = system_prompt .. "\n - " .. symbol .. " (in: " .. table.concat(files, ", ") .. ")"
|
||||
end
|
||||
end
|
||||
if idx_ctx.patterns and #idx_ctx.patterns > 0 then
|
||||
system_prompt = system_prompt .. "\n\nProject patterns/memories:"
|
||||
for _, pattern in ipairs(idx_ctx.patterns) do
|
||||
system_prompt = system_prompt .. "\n - " .. (pattern.summary or pattern.content or "")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Add to history
|
||||
table.insert(state.history, { role = "user", content = question })
|
||||
|
||||
-- Show loading indicator
|
||||
append_to_output("⏳ Thinking...", false)
|
||||
append_to_output("", false)
|
||||
append_to_output("⏳ Generating response...", false)
|
||||
|
||||
-- Get LLM client and generate response
|
||||
local ok, llm = pcall(require, "codetyper.llm")
|
||||
@@ -807,18 +976,49 @@ function M.submit()
|
||||
|
||||
local client = llm.get_client()
|
||||
|
||||
-- Build full prompt WITH file contents
|
||||
local full_prompt = question
|
||||
if file_context ~= "" then
|
||||
full_prompt = "USER QUESTION: "
|
||||
.. question
|
||||
.. "\n\n"
|
||||
.. "ATTACHED FILE CONTENTS (please analyze these):"
|
||||
.. file_context
|
||||
-- Build recent conversation context (limit to last N entries)
|
||||
local history_context = ""
|
||||
do
|
||||
local max_entries = 8
|
||||
local total = #state.history
|
||||
local start_i = 1
|
||||
if total > max_entries then
|
||||
start_i = total - max_entries + 1
|
||||
end
|
||||
if total > 0 then
|
||||
history_context = "\n\n=== PREVIOUS CONVERSATION ===\n"
|
||||
for i = start_i, total do
|
||||
local m = state.history[i]
|
||||
local role = (m.role == "assistant") and "ASSISTANT" or "USER"
|
||||
history_context = history_context .. role .. ": " .. (m.content or "") .. "\n"
|
||||
end
|
||||
history_context = history_context .. "=== END PREVIOUS CONVERSATION ===\n\n"
|
||||
end
|
||||
end
|
||||
|
||||
-- Also add current file if no files were explicitly attached
|
||||
if file_count == 0 and context.current_content and context.current_content ~= "" then
|
||||
-- Build full prompt starting with recent conversation + user question
|
||||
local full_prompt = history_context .. "USER QUESTION: " .. question
|
||||
|
||||
-- Add visual selection context if present
|
||||
if state.selection_context and state.selection_context.text and state.selection_context.text ~= "" then
|
||||
local sel = state.selection_context
|
||||
local location = sel.filename or "unknown"
|
||||
if sel.start_line then
|
||||
location = location .. ":" .. sel.start_line
|
||||
if sel.end_line and sel.end_line ~= sel.start_line then
|
||||
location = location .. "-" .. sel.end_line
|
||||
end
|
||||
end
|
||||
full_prompt = full_prompt .. "\n\nSELECTED CODE (" .. location .. "):\n```" .. (sel.language or "") .. "\n"
|
||||
full_prompt = full_prompt .. sel.text .. "\n```"
|
||||
end
|
||||
|
||||
if file_context ~= "" then
|
||||
full_prompt = full_prompt .. "\n\nATTACHED FILE CONTENTS (please analyze these):" .. file_context
|
||||
end
|
||||
|
||||
-- Also add current file if no files were explicitly attached and no selection
|
||||
if file_count == 0 and not state.selection_context and context.current_content and context.current_content ~= "" then
|
||||
full_prompt = "USER QUESTION: "
|
||||
.. question
|
||||
.. "\n\n"
|
||||
@@ -829,10 +1029,23 @@ function M.submit()
|
||||
.. "\n```"
|
||||
end
|
||||
|
||||
-- Add exploration summary to prompt if available
|
||||
if exploration_result then
|
||||
full_prompt = full_prompt
|
||||
.. "\n\nPROJECT EXPLORATION COMPLETE: "
|
||||
.. exploration_result.total_files
|
||||
.. " files analyzed. "
|
||||
.. "Project type: "
|
||||
.. exploration_result.project.language
|
||||
.. " ("
|
||||
.. (exploration_result.project.framework or exploration_result.project.type)
|
||||
.. ")"
|
||||
end
|
||||
|
||||
local request_context = {
|
||||
file_content = file_context ~= "" and file_context or context.current_content,
|
||||
language = context.language,
|
||||
prompt_type = "explain",
|
||||
prompt_type = prompt_type,
|
||||
file_path = context.current_file,
|
||||
}
|
||||
|
||||
@@ -844,9 +1057,9 @@ function M.submit()
|
||||
-- Remove last few lines (the thinking message)
|
||||
local to_remove = 0
|
||||
for i = #lines, 1, -1 do
|
||||
if lines[i]:match("Thinking") or lines[i]:match("^[│└┌─]") then
|
||||
if lines[i]:match("Generating") or lines[i]:match("^[│└┌─]") or lines[i] == "" then
|
||||
to_remove = to_remove + 1
|
||||
if lines[i]:match("┌") then
|
||||
if lines[i]:match("┌") or to_remove >= 5 then
|
||||
break
|
||||
end
|
||||
else
|
||||
@@ -879,6 +1092,77 @@ function M.submit()
|
||||
end)
|
||||
end
|
||||
|
||||
--- Submit the question to LLM
|
||||
function M.submit()
|
||||
local question = get_input_text()
|
||||
|
||||
if not question or question:match("^%s*$") then
|
||||
utils.notify("Please enter a question", vim.log.levels.WARN)
|
||||
M.focus_input()
|
||||
return
|
||||
end
|
||||
|
||||
-- Detect intent from prompt
|
||||
local ok_intent, intent_module = pcall(require, "codetyper.ask.intent")
|
||||
local intent = nil
|
||||
if ok_intent then
|
||||
intent = intent_module.detect(question)
|
||||
else
|
||||
-- Fallback intent
|
||||
intent = {
|
||||
type = "ask",
|
||||
confidence = 0.5,
|
||||
needs_project_context = false,
|
||||
needs_brain_context = true,
|
||||
needs_exploration = false,
|
||||
}
|
||||
end
|
||||
|
||||
-- Build context BEFORE clearing input (to preserve file references)
|
||||
local context = build_context(intent)
|
||||
local file_context, file_count = build_file_context()
|
||||
|
||||
-- Build display message (without full file contents)
|
||||
local display_question = question
|
||||
if file_count > 0 then
|
||||
display_question = question .. "\n📎 " .. file_count .. " file(s) attached"
|
||||
end
|
||||
-- Show detected intent if not standard ask
|
||||
if intent.type ~= "ask" then
|
||||
display_question = display_question .. "\n🎯 " .. intent.type:upper() .. " mode"
|
||||
end
|
||||
-- Show exploration indicator
|
||||
if intent.needs_exploration then
|
||||
display_question = display_question .. "\n🔍 Project exploration required"
|
||||
end
|
||||
|
||||
-- Add user message to output
|
||||
append_to_output(display_question, true)
|
||||
|
||||
-- Clear input and references AFTER building context
|
||||
M.clear_input()
|
||||
|
||||
-- Check if exploration is needed
|
||||
if intent.needs_exploration then
|
||||
local ok_explorer, explorer = pcall(require, "codetyper.ask.explorer")
|
||||
if ok_explorer then
|
||||
local root = utils.get_project_root()
|
||||
if root then
|
||||
-- Start exploration with logging
|
||||
append_to_output("", false)
|
||||
explorer.explore(root, append_exploration_log, function(exploration_result)
|
||||
-- After exploration completes, continue with LLM request
|
||||
continue_submit(question, intent, context, file_context, file_count, exploration_result)
|
||||
end)
|
||||
return
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- No exploration needed, continue directly
|
||||
continue_submit(question, intent, context, file_context, file_count, nil)
|
||||
end
|
||||
|
||||
--- Clear chat history
|
||||
function M.clear_history()
|
||||
state.history = {}
|
||||
676
lua/codetyper/features/ask/explorer.lua
Normal file
676
lua/codetyper/features/ask/explorer.lua
Normal file
@@ -0,0 +1,676 @@
|
||||
---@mod codetyper.ask.explorer Project exploration for Ask mode
|
||||
---@brief [[
|
||||
--- Performs comprehensive project exploration when explaining a project.
|
||||
--- Shows progress, indexes files, and builds brain context.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
---@class ExplorationState
|
||||
---@field is_exploring boolean
|
||||
---@field files_scanned number
|
||||
---@field total_files number
|
||||
---@field current_file string|nil
|
||||
---@field findings table
|
||||
---@field on_log fun(msg: string, level: string)|nil
|
||||
|
||||
local state = {
|
||||
is_exploring = false,
|
||||
files_scanned = 0,
|
||||
total_files = 0,
|
||||
current_file = nil,
|
||||
findings = {},
|
||||
on_log = nil,
|
||||
}
|
||||
|
||||
--- File extensions to analyze
|
||||
local ANALYZABLE_EXTENSIONS = {
|
||||
lua = true,
|
||||
ts = true,
|
||||
tsx = true,
|
||||
js = true,
|
||||
jsx = true,
|
||||
py = true,
|
||||
go = true,
|
||||
rs = true,
|
||||
rb = true,
|
||||
java = true,
|
||||
c = true,
|
||||
cpp = true,
|
||||
h = true,
|
||||
hpp = true,
|
||||
json = true,
|
||||
yaml = true,
|
||||
yml = true,
|
||||
toml = true,
|
||||
md = true,
|
||||
xml = true,
|
||||
}
|
||||
|
||||
--- Directories to skip
|
||||
local SKIP_DIRS = {
|
||||
-- Version control
|
||||
[".git"] = true,
|
||||
[".svn"] = true,
|
||||
[".hg"] = true,
|
||||
|
||||
-- IDE/Editor
|
||||
[".idea"] = true,
|
||||
[".vscode"] = true,
|
||||
[".cursor"] = true,
|
||||
[".cursorignore"] = true,
|
||||
[".claude"] = true,
|
||||
[".zed"] = true,
|
||||
|
||||
-- Project tooling
|
||||
[".coder"] = true,
|
||||
[".github"] = true,
|
||||
[".gitlab"] = true,
|
||||
[".husky"] = true,
|
||||
|
||||
-- Build outputs
|
||||
dist = true,
|
||||
build = true,
|
||||
out = true,
|
||||
target = true,
|
||||
bin = true,
|
||||
obj = true,
|
||||
[".build"] = true,
|
||||
[".output"] = true,
|
||||
|
||||
-- Dependencies
|
||||
node_modules = true,
|
||||
vendor = true,
|
||||
[".vendor"] = true,
|
||||
packages = true,
|
||||
bower_components = true,
|
||||
jspm_packages = true,
|
||||
|
||||
-- Cache/temp
|
||||
[".cache"] = true,
|
||||
[".tmp"] = true,
|
||||
[".temp"] = true,
|
||||
__pycache__ = true,
|
||||
[".pytest_cache"] = true,
|
||||
[".mypy_cache"] = true,
|
||||
[".ruff_cache"] = true,
|
||||
[".tox"] = true,
|
||||
[".nox"] = true,
|
||||
[".eggs"] = true,
|
||||
["*.egg-info"] = true,
|
||||
|
||||
-- Framework specific
|
||||
[".next"] = true,
|
||||
[".nuxt"] = true,
|
||||
[".svelte-kit"] = true,
|
||||
[".vercel"] = true,
|
||||
[".netlify"] = true,
|
||||
[".serverless"] = true,
|
||||
[".turbo"] = true,
|
||||
|
||||
-- Testing/coverage
|
||||
coverage = true,
|
||||
[".nyc_output"] = true,
|
||||
htmlcov = true,
|
||||
|
||||
-- Logs
|
||||
logs = true,
|
||||
log = true,
|
||||
|
||||
-- OS files
|
||||
[".DS_Store"] = true,
|
||||
Thumbs_db = true,
|
||||
}
|
||||
|
||||
--- Files to skip (patterns)
|
||||
local SKIP_FILES = {
|
||||
-- Lock files
|
||||
"package%-lock%.json",
|
||||
"yarn%.lock",
|
||||
"pnpm%-lock%.yaml",
|
||||
"Gemfile%.lock",
|
||||
"Cargo%.lock",
|
||||
"poetry%.lock",
|
||||
"Pipfile%.lock",
|
||||
"composer%.lock",
|
||||
"go%.sum",
|
||||
"flake%.lock",
|
||||
"%.lock$",
|
||||
"%-lock%.json$",
|
||||
"%-lock%.yaml$",
|
||||
|
||||
-- Generated files
|
||||
"%.min%.js$",
|
||||
"%.min%.css$",
|
||||
"%.bundle%.js$",
|
||||
"%.chunk%.js$",
|
||||
"%.map$",
|
||||
"%.d%.ts$",
|
||||
|
||||
-- Binary/media (shouldn't match anyway but be safe)
|
||||
"%.png$",
|
||||
"%.jpg$",
|
||||
"%.jpeg$",
|
||||
"%.gif$",
|
||||
"%.ico$",
|
||||
"%.svg$",
|
||||
"%.woff",
|
||||
"%.ttf$",
|
||||
"%.eot$",
|
||||
"%.pdf$",
|
||||
"%.zip$",
|
||||
"%.tar",
|
||||
"%.gz$",
|
||||
|
||||
-- Config that's not useful
|
||||
"%.env",
|
||||
"%.env%.",
|
||||
}
|
||||
|
||||
--- Log a message during exploration
|
||||
---@param msg string
|
||||
---@param level? string "info"|"debug"|"file"|"progress"
|
||||
local function log(msg, level)
|
||||
level = level or "info"
|
||||
if state.on_log then
|
||||
state.on_log(msg, level)
|
||||
end
|
||||
end
|
||||
|
||||
--- Check if file should be skipped
|
||||
---@param filename string
|
||||
---@return boolean
|
||||
local function should_skip_file(filename)
|
||||
for _, pattern in ipairs(SKIP_FILES) do
|
||||
if filename:match(pattern) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Check if directory should be skipped
|
||||
---@param dirname string
|
||||
---@return boolean
|
||||
local function should_skip_dir(dirname)
|
||||
-- Direct match
|
||||
if SKIP_DIRS[dirname] then
|
||||
return true
|
||||
end
|
||||
-- Pattern match for .cursor* etc
|
||||
if dirname:match("^%.cursor") then
|
||||
return true
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
--- Get all files in project
|
||||
---@param root string Project root
|
||||
---@return string[] files
|
||||
local function get_project_files(root)
|
||||
local files = {}
|
||||
|
||||
local function scan_dir(dir)
|
||||
local handle = vim.loop.fs_scandir(dir)
|
||||
if not handle then
|
||||
return
|
||||
end
|
||||
|
||||
while true do
|
||||
local name, type = vim.loop.fs_scandir_next(handle)
|
||||
if not name then
|
||||
break
|
||||
end
|
||||
|
||||
local full_path = dir .. "/" .. name
|
||||
|
||||
if type == "directory" then
|
||||
if not should_skip_dir(name) then
|
||||
scan_dir(full_path)
|
||||
end
|
||||
elseif type == "file" then
|
||||
if not should_skip_file(name) then
|
||||
local ext = name:match("%.([^%.]+)$")
|
||||
if ext and ANALYZABLE_EXTENSIONS[ext:lower()] then
|
||||
table.insert(files, full_path)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
scan_dir(root)
|
||||
return files
|
||||
end
|
||||
|
||||
--- Analyze a single file
|
||||
---@param filepath string
|
||||
---@return table|nil analysis
|
||||
local function analyze_file(filepath)
|
||||
local content = utils.read_file(filepath)
|
||||
if not content or content == "" then
|
||||
return nil
|
||||
end
|
||||
|
||||
local ext = filepath:match("%.([^%.]+)$") or ""
|
||||
local lines = vim.split(content, "\n")
|
||||
|
||||
local analysis = {
|
||||
path = filepath,
|
||||
extension = ext,
|
||||
lines = #lines,
|
||||
size = #content,
|
||||
imports = {},
|
||||
exports = {},
|
||||
functions = {},
|
||||
classes = {},
|
||||
summary = "",
|
||||
}
|
||||
|
||||
-- Extract key patterns based on file type
|
||||
for i, line in ipairs(lines) do
|
||||
-- Imports/requires
|
||||
local import = line:match('import%s+.*%s+from%s+["\']([^"\']+)["\']')
|
||||
or line:match('require%(["\']([^"\']+)["\']%)')
|
||||
or line:match("from%s+([%w_.]+)%s+import")
|
||||
if import then
|
||||
table.insert(analysis.imports, { source = import, line = i })
|
||||
end
|
||||
|
||||
-- Function definitions
|
||||
local func = line:match("^%s*function%s+([%w_:%.]+)%s*%(")
|
||||
or line:match("^%s*local%s+function%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*def%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*func%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*async%s+function%s+([%w_]+)%s*%(")
|
||||
or line:match("^%s*public%s+.*%s+([%w_]+)%s*%(")
|
||||
if func then
|
||||
table.insert(analysis.functions, { name = func, line = i })
|
||||
end
|
||||
|
||||
-- Class definitions
|
||||
local class = line:match("^%s*class%s+([%w_]+)")
|
||||
or line:match("^%s*public%s+class%s+([%w_]+)")
|
||||
or line:match("^%s*interface%s+([%w_]+)")
|
||||
if class then
|
||||
table.insert(analysis.classes, { name = class, line = i })
|
||||
end
|
||||
|
||||
-- Exports
|
||||
local exp = line:match("^%s*export%s+.*%s+([%w_]+)")
|
||||
or line:match("^%s*module%.exports%s*=")
|
||||
or line:match("^return%s+M")
|
||||
if exp then
|
||||
table.insert(analysis.exports, { name = exp, line = i })
|
||||
end
|
||||
end
|
||||
|
||||
-- Create summary
|
||||
local parts = {}
|
||||
if #analysis.functions > 0 then
|
||||
table.insert(parts, #analysis.functions .. " functions")
|
||||
end
|
||||
if #analysis.classes > 0 then
|
||||
table.insert(parts, #analysis.classes .. " classes")
|
||||
end
|
||||
if #analysis.imports > 0 then
|
||||
table.insert(parts, #analysis.imports .. " imports")
|
||||
end
|
||||
analysis.summary = table.concat(parts, ", ")
|
||||
|
||||
return analysis
|
||||
end
|
||||
|
||||
--- Detect project type from files
|
||||
---@param root string
|
||||
---@return string type, table info
|
||||
local function detect_project_type(root)
|
||||
local info = {
|
||||
name = vim.fn.fnamemodify(root, ":t"),
|
||||
type = "unknown",
|
||||
framework = nil,
|
||||
language = nil,
|
||||
}
|
||||
|
||||
-- Check for common project files
|
||||
if utils.file_exists(root .. "/package.json") then
|
||||
info.type = "node"
|
||||
info.language = "JavaScript/TypeScript"
|
||||
local content = utils.read_file(root .. "/package.json")
|
||||
if content then
|
||||
local ok, pkg = pcall(vim.json.decode, content)
|
||||
if ok then
|
||||
info.name = pkg.name or info.name
|
||||
if pkg.dependencies then
|
||||
if pkg.dependencies.react then
|
||||
info.framework = "React"
|
||||
elseif pkg.dependencies.vue then
|
||||
info.framework = "Vue"
|
||||
elseif pkg.dependencies.next then
|
||||
info.framework = "Next.js"
|
||||
elseif pkg.dependencies.express then
|
||||
info.framework = "Express"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
elseif utils.file_exists(root .. "/pom.xml") then
|
||||
info.type = "maven"
|
||||
info.language = "Java"
|
||||
local content = utils.read_file(root .. "/pom.xml")
|
||||
if content and content:match("spring%-boot") then
|
||||
info.framework = "Spring Boot"
|
||||
end
|
||||
elseif utils.file_exists(root .. "/Cargo.toml") then
|
||||
info.type = "rust"
|
||||
info.language = "Rust"
|
||||
elseif utils.file_exists(root .. "/go.mod") then
|
||||
info.type = "go"
|
||||
info.language = "Go"
|
||||
elseif utils.file_exists(root .. "/requirements.txt") or utils.file_exists(root .. "/pyproject.toml") then
|
||||
info.type = "python"
|
||||
info.language = "Python"
|
||||
elseif utils.file_exists(root .. "/init.lua") or utils.file_exists(root .. "/plugin/") then
|
||||
info.type = "neovim-plugin"
|
||||
info.language = "Lua"
|
||||
end
|
||||
|
||||
return info.type, info
|
||||
end
|
||||
|
||||
--- Build project structure summary
|
||||
---@param files string[]
|
||||
---@param root string
|
||||
---@return table structure
|
||||
local function build_structure(files, root)
|
||||
local structure = {
|
||||
directories = {},
|
||||
by_extension = {},
|
||||
total_files = #files,
|
||||
}
|
||||
|
||||
for _, file in ipairs(files) do
|
||||
local relative = file:gsub("^" .. vim.pesc(root) .. "/", "")
|
||||
local dir = vim.fn.fnamemodify(relative, ":h")
|
||||
local ext = file:match("%.([^%.]+)$") or "unknown"
|
||||
|
||||
structure.directories[dir] = (structure.directories[dir] or 0) + 1
|
||||
structure.by_extension[ext] = (structure.by_extension[ext] or 0) + 1
|
||||
end
|
||||
|
||||
return structure
|
||||
end
|
||||
|
||||
--- Explore project and build context
|
||||
---@param root string Project root
|
||||
---@param on_log fun(msg: string, level: string) Log callback
|
||||
---@param on_complete fun(result: table) Completion callback
|
||||
function M.explore(root, on_log, on_complete)
|
||||
if state.is_exploring then
|
||||
on_log("⚠️ Already exploring...", "warning")
|
||||
return
|
||||
end
|
||||
|
||||
state.is_exploring = true
|
||||
state.on_log = on_log
|
||||
state.findings = {}
|
||||
|
||||
-- Start exploration
|
||||
log("⏺ Exploring project structure...", "info")
|
||||
log("", "info")
|
||||
|
||||
-- Detect project type
|
||||
log(" Detect(Project type)", "progress")
|
||||
local project_type, project_info = detect_project_type(root)
|
||||
log(" ⎿ " .. project_info.language .. " (" .. (project_info.framework or project_type) .. ")", "debug")
|
||||
|
||||
state.findings.project = project_info
|
||||
|
||||
-- Get all files
|
||||
log("", "info")
|
||||
log(" Scan(Project files)", "progress")
|
||||
local files = get_project_files(root)
|
||||
state.total_files = #files
|
||||
log(" ⎿ Found " .. #files .. " analyzable files", "debug")
|
||||
|
||||
-- Build structure
|
||||
local structure = build_structure(files, root)
|
||||
state.findings.structure = structure
|
||||
|
||||
-- Show directory breakdown
|
||||
log("", "info")
|
||||
log(" Structure(Directories)", "progress")
|
||||
local sorted_dirs = {}
|
||||
for dir, count in pairs(structure.directories) do
|
||||
table.insert(sorted_dirs, { dir = dir, count = count })
|
||||
end
|
||||
table.sort(sorted_dirs, function(a, b)
|
||||
return a.count > b.count
|
||||
end)
|
||||
for i, entry in ipairs(sorted_dirs) do
|
||||
if i <= 5 then
|
||||
log(" ⎿ " .. entry.dir .. " (" .. entry.count .. " files)", "debug")
|
||||
end
|
||||
end
|
||||
if #sorted_dirs > 5 then
|
||||
log(" ⎿ +" .. (#sorted_dirs - 5) .. " more directories", "debug")
|
||||
end
|
||||
|
||||
-- Analyze files asynchronously
|
||||
log("", "info")
|
||||
log(" Analyze(Source files)", "progress")
|
||||
|
||||
state.files_scanned = 0
|
||||
local analyses = {}
|
||||
local key_files = {}
|
||||
|
||||
-- Process files in batches to avoid blocking
|
||||
local batch_size = 10
|
||||
local current_batch = 0
|
||||
|
||||
local function process_batch()
|
||||
local start_idx = current_batch * batch_size + 1
|
||||
local end_idx = math.min(start_idx + batch_size - 1, #files)
|
||||
|
||||
for i = start_idx, end_idx do
|
||||
local file = files[i]
|
||||
local relative = file:gsub("^" .. vim.pesc(root) .. "/", "")
|
||||
|
||||
state.files_scanned = state.files_scanned + 1
|
||||
state.current_file = relative
|
||||
|
||||
local analysis = analyze_file(file)
|
||||
if analysis then
|
||||
analysis.relative_path = relative
|
||||
table.insert(analyses, analysis)
|
||||
|
||||
-- Track key files (many functions/classes)
|
||||
if #analysis.functions >= 3 or #analysis.classes >= 1 then
|
||||
table.insert(key_files, {
|
||||
path = relative,
|
||||
functions = #analysis.functions,
|
||||
classes = #analysis.classes,
|
||||
summary = analysis.summary,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Log some files
|
||||
if i <= 3 or (i % 20 == 0) then
|
||||
log(" ⎿ " .. relative .. ": " .. (analysis and analysis.summary or "(empty)"), "file")
|
||||
end
|
||||
end
|
||||
|
||||
-- Progress update
|
||||
local progress = math.floor((state.files_scanned / state.total_files) * 100)
|
||||
if progress % 25 == 0 and progress > 0 then
|
||||
log(" ⎿ " .. progress .. "% complete (" .. state.files_scanned .. "/" .. state.total_files .. ")", "debug")
|
||||
end
|
||||
|
||||
current_batch = current_batch + 1
|
||||
|
||||
if end_idx < #files then
|
||||
-- Schedule next batch
|
||||
vim.defer_fn(process_batch, 10)
|
||||
else
|
||||
-- Complete
|
||||
finish_exploration(root, analyses, key_files, on_complete)
|
||||
end
|
||||
end
|
||||
|
||||
-- Start processing
|
||||
vim.defer_fn(process_batch, 10)
|
||||
end
|
||||
|
||||
--- Finish exploration and store results
|
||||
---@param root string
|
||||
---@param analyses table
|
||||
---@param key_files table
|
||||
---@param on_complete fun(result: table)
|
||||
function finish_exploration(root, analyses, key_files, on_complete)
|
||||
log(" ⎿ +" .. (#analyses - 3) .. " more files analyzed", "debug")
|
||||
|
||||
-- Show key files
|
||||
if #key_files > 0 then
|
||||
log("", "info")
|
||||
log(" KeyFiles(Important components)", "progress")
|
||||
table.sort(key_files, function(a, b)
|
||||
return (a.functions + a.classes * 2) > (b.functions + b.classes * 2)
|
||||
end)
|
||||
for i, kf in ipairs(key_files) do
|
||||
if i <= 5 then
|
||||
log(" ⎿ " .. kf.path .. ": " .. kf.summary, "file")
|
||||
end
|
||||
end
|
||||
if #key_files > 5 then
|
||||
log(" ⎿ +" .. (#key_files - 5) .. " more key files", "debug")
|
||||
end
|
||||
end
|
||||
|
||||
state.findings.analyses = analyses
|
||||
state.findings.key_files = key_files
|
||||
|
||||
-- Store in brain if available
|
||||
local ok_brain, brain = pcall(require, "codetyper.brain")
|
||||
if ok_brain and brain.is_initialized() then
|
||||
log("", "info")
|
||||
log(" Store(Brain context)", "progress")
|
||||
|
||||
-- Store project pattern
|
||||
brain.learn({
|
||||
type = "pattern",
|
||||
file = root,
|
||||
content = {
|
||||
summary = "Project: " .. state.findings.project.name,
|
||||
detail = state.findings.project.language
|
||||
.. " "
|
||||
.. (state.findings.project.framework or state.findings.project.type),
|
||||
code = nil,
|
||||
},
|
||||
context = {
|
||||
file = root,
|
||||
language = state.findings.project.language,
|
||||
},
|
||||
})
|
||||
|
||||
-- Store key file patterns
|
||||
for i, kf in ipairs(key_files) do
|
||||
if i <= 10 then
|
||||
brain.learn({
|
||||
type = "pattern",
|
||||
file = root .. "/" .. kf.path,
|
||||
content = {
|
||||
summary = kf.path .. " - " .. kf.summary,
|
||||
detail = kf.summary,
|
||||
},
|
||||
context = {
|
||||
file = kf.path,
|
||||
},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
log(" ⎿ Stored " .. math.min(#key_files, 10) + 1 .. " patterns in brain", "debug")
|
||||
end
|
||||
|
||||
-- Store in indexer if available
|
||||
local ok_indexer, indexer = pcall(require, "codetyper.indexer")
|
||||
if ok_indexer then
|
||||
log(" Index(Project index)", "progress")
|
||||
indexer.index_project(function(index)
|
||||
log(" ⎿ Indexed " .. (index.stats.files or 0) .. " files", "debug")
|
||||
end)
|
||||
end
|
||||
|
||||
log("", "info")
|
||||
log("✓ Exploration complete!", "info")
|
||||
log("", "info")
|
||||
|
||||
-- Build result
|
||||
local result = {
|
||||
project = state.findings.project,
|
||||
structure = state.findings.structure,
|
||||
key_files = key_files,
|
||||
total_files = state.total_files,
|
||||
analyses = analyses,
|
||||
}
|
||||
|
||||
state.is_exploring = false
|
||||
state.on_log = nil
|
||||
|
||||
on_complete(result)
|
||||
end
|
||||
|
||||
--- Check if exploration is in progress
|
||||
---@return boolean
|
||||
function M.is_exploring()
|
||||
return state.is_exploring
|
||||
end
|
||||
|
||||
--- Get exploration progress
|
||||
---@return number scanned, number total
|
||||
function M.get_progress()
|
||||
return state.files_scanned, state.total_files
|
||||
end
|
||||
|
||||
--- Build context string from exploration result
|
||||
---@param result table Exploration result
|
||||
---@return string context
|
||||
function M.build_context(result)
|
||||
local parts = {}
|
||||
|
||||
-- Project info
|
||||
table.insert(parts, "## Project: " .. result.project.name)
|
||||
table.insert(parts, "- Type: " .. result.project.type)
|
||||
table.insert(parts, "- Language: " .. (result.project.language or "Unknown"))
|
||||
if result.project.framework then
|
||||
table.insert(parts, "- Framework: " .. result.project.framework)
|
||||
end
|
||||
table.insert(parts, "- Files: " .. result.total_files)
|
||||
table.insert(parts, "")
|
||||
|
||||
-- Structure
|
||||
table.insert(parts, "## Structure")
|
||||
if result.structure and result.structure.by_extension then
|
||||
for ext, count in pairs(result.structure.by_extension) do
|
||||
table.insert(parts, "- ." .. ext .. ": " .. count .. " files")
|
||||
end
|
||||
end
|
||||
table.insert(parts, "")
|
||||
|
||||
-- Key components
|
||||
if result.key_files and #result.key_files > 0 then
|
||||
table.insert(parts, "## Key Components")
|
||||
for i, kf in ipairs(result.key_files) do
|
||||
if i <= 10 then
|
||||
table.insert(parts, "- " .. kf.path .. ": " .. kf.summary)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return table.concat(parts, "\n")
|
||||
end
|
||||
|
||||
return M
|
||||
302
lua/codetyper/features/ask/intent.lua
Normal file
302
lua/codetyper/features/ask/intent.lua
Normal file
@@ -0,0 +1,302 @@
|
||||
---@mod codetyper.ask.intent Intent detection for Ask mode
|
||||
---@brief [[
|
||||
--- Analyzes user prompts to detect intent (ask/explain vs code generation).
|
||||
--- Routes to appropriate prompt type and context sources.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@alias IntentType "ask"|"explain"|"generate"|"refactor"|"document"|"test"
|
||||
|
||||
---@class Intent
|
||||
---@field type IntentType Detected intent type
|
||||
---@field confidence number 0-1 confidence score
|
||||
---@field needs_project_context boolean Whether project-wide context is needed
|
||||
---@field needs_brain_context boolean Whether brain/learned context is helpful
|
||||
---@field needs_exploration boolean Whether full project exploration is needed
|
||||
---@field keywords string[] Keywords that influenced detection
|
||||
|
||||
--- Patterns for detecting ask/explain intent (questions about code)
|
||||
local ASK_PATTERNS = {
|
||||
-- Question words
|
||||
{ pattern = "^what%s", weight = 0.9 },
|
||||
{ pattern = "^why%s", weight = 0.95 },
|
||||
{ pattern = "^how%s+does", weight = 0.9 },
|
||||
{ pattern = "^how%s+do%s+i", weight = 0.7 }, -- Could be asking for code
|
||||
{ pattern = "^where%s", weight = 0.85 },
|
||||
{ pattern = "^when%s", weight = 0.85 },
|
||||
{ pattern = "^which%s", weight = 0.8 },
|
||||
{ pattern = "^who%s", weight = 0.85 },
|
||||
{ pattern = "^can%s+you%s+explain", weight = 0.95 },
|
||||
{ pattern = "^could%s+you%s+explain", weight = 0.95 },
|
||||
{ pattern = "^please%s+explain", weight = 0.95 },
|
||||
|
||||
-- Explanation requests
|
||||
{ pattern = "explain%s", weight = 0.9 },
|
||||
{ pattern = "describe%s", weight = 0.85 },
|
||||
{ pattern = "tell%s+me%s+about", weight = 0.85 },
|
||||
{ pattern = "walk%s+me%s+through", weight = 0.9 },
|
||||
{ pattern = "help%s+me%s+understand", weight = 0.95 },
|
||||
{ pattern = "what%s+is%s+the%s+purpose", weight = 0.95 },
|
||||
{ pattern = "what%s+does%s+this", weight = 0.9 },
|
||||
{ pattern = "what%s+does%s+it", weight = 0.9 },
|
||||
{ pattern = "how%s+does%s+this%s+work", weight = 0.95 },
|
||||
{ pattern = "how%s+does%s+it%s+work", weight = 0.95 },
|
||||
|
||||
-- Understanding queries
|
||||
{ pattern = "understand", weight = 0.7 },
|
||||
{ pattern = "meaning%s+of", weight = 0.85 },
|
||||
{ pattern = "difference%s+between", weight = 0.9 },
|
||||
{ pattern = "compared%s+to", weight = 0.8 },
|
||||
{ pattern = "vs%s", weight = 0.7 },
|
||||
{ pattern = "versus", weight = 0.7 },
|
||||
{ pattern = "pros%s+and%s+cons", weight = 0.9 },
|
||||
{ pattern = "advantages", weight = 0.8 },
|
||||
{ pattern = "disadvantages", weight = 0.8 },
|
||||
{ pattern = "trade%-?offs?", weight = 0.85 },
|
||||
|
||||
-- Analysis requests
|
||||
{ pattern = "analyze", weight = 0.85 },
|
||||
{ pattern = "review", weight = 0.7 }, -- Could also be refactor
|
||||
{ pattern = "overview", weight = 0.9 },
|
||||
{ pattern = "summary", weight = 0.9 },
|
||||
{ pattern = "summarize", weight = 0.9 },
|
||||
|
||||
-- Question marks (weaker signal)
|
||||
{ pattern = "%?$", weight = 0.3 },
|
||||
{ pattern = "%?%s*$", weight = 0.3 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting code generation intent
|
||||
local GENERATE_PATTERNS = {
|
||||
-- Direct commands
|
||||
{ pattern = "^create%s", weight = 0.9 },
|
||||
{ pattern = "^make%s", weight = 0.85 },
|
||||
{ pattern = "^build%s", weight = 0.85 },
|
||||
{ pattern = "^write%s", weight = 0.9 },
|
||||
{ pattern = "^add%s", weight = 0.85 },
|
||||
{ pattern = "^implement%s", weight = 0.95 },
|
||||
{ pattern = "^generate%s", weight = 0.95 },
|
||||
{ pattern = "^code%s", weight = 0.8 },
|
||||
|
||||
-- Modification commands
|
||||
{ pattern = "^fix%s", weight = 0.9 },
|
||||
{ pattern = "^change%s", weight = 0.8 },
|
||||
{ pattern = "^update%s", weight = 0.75 },
|
||||
{ pattern = "^modify%s", weight = 0.8 },
|
||||
{ pattern = "^replace%s", weight = 0.85 },
|
||||
{ pattern = "^remove%s", weight = 0.85 },
|
||||
{ pattern = "^delete%s", weight = 0.85 },
|
||||
|
||||
-- Feature requests
|
||||
{ pattern = "i%s+need%s+a", weight = 0.8 },
|
||||
{ pattern = "i%s+want%s+a", weight = 0.8 },
|
||||
{ pattern = "give%s+me", weight = 0.7 },
|
||||
{ pattern = "show%s+me%s+how%s+to%s+code", weight = 0.9 },
|
||||
{ pattern = "how%s+do%s+i%s+implement", weight = 0.85 },
|
||||
{ pattern = "can%s+you%s+write", weight = 0.9 },
|
||||
{ pattern = "can%s+you%s+create", weight = 0.9 },
|
||||
{ pattern = "can%s+you%s+add", weight = 0.85 },
|
||||
{ pattern = "can%s+you%s+make", weight = 0.85 },
|
||||
|
||||
-- Code-specific terms
|
||||
{ pattern = "function%s+that", weight = 0.85 },
|
||||
{ pattern = "class%s+that", weight = 0.85 },
|
||||
{ pattern = "method%s+that", weight = 0.85 },
|
||||
{ pattern = "component%s+that", weight = 0.85 },
|
||||
{ pattern = "module%s+that", weight = 0.85 },
|
||||
{ pattern = "api%s+for", weight = 0.8 },
|
||||
{ pattern = "endpoint%s+for", weight = 0.8 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting refactor intent
|
||||
local REFACTOR_PATTERNS = {
|
||||
{ pattern = "^refactor%s", weight = 0.95 },
|
||||
{ pattern = "refactor%s+this", weight = 0.95 },
|
||||
{ pattern = "clean%s+up", weight = 0.85 },
|
||||
{ pattern = "improve%s+this%s+code", weight = 0.85 },
|
||||
{ pattern = "make%s+this%s+cleaner", weight = 0.85 },
|
||||
{ pattern = "simplify", weight = 0.8 },
|
||||
{ pattern = "optimize", weight = 0.75 }, -- Could be explain
|
||||
{ pattern = "reorganize", weight = 0.9 },
|
||||
{ pattern = "restructure", weight = 0.9 },
|
||||
{ pattern = "extract%s+to", weight = 0.9 },
|
||||
{ pattern = "split%s+into", weight = 0.85 },
|
||||
{ pattern = "dry%s+this", weight = 0.9 }, -- Don't repeat yourself
|
||||
{ pattern = "reduce%s+duplication", weight = 0.9 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting documentation intent
|
||||
local DOCUMENT_PATTERNS = {
|
||||
{ pattern = "^document%s", weight = 0.95 },
|
||||
{ pattern = "add%s+documentation", weight = 0.95 },
|
||||
{ pattern = "add%s+docs", weight = 0.95 },
|
||||
{ pattern = "add%s+comments", weight = 0.9 },
|
||||
{ pattern = "add%s+docstring", weight = 0.95 },
|
||||
{ pattern = "add%s+jsdoc", weight = 0.95 },
|
||||
{ pattern = "write%s+documentation", weight = 0.95 },
|
||||
{ pattern = "document%s+this", weight = 0.95 },
|
||||
}
|
||||
|
||||
--- Patterns for detecting test generation intent
|
||||
local TEST_PATTERNS = {
|
||||
{ pattern = "^test%s", weight = 0.9 },
|
||||
{ pattern = "write%s+tests?%s+for", weight = 0.95 },
|
||||
{ pattern = "add%s+tests?%s+for", weight = 0.95 },
|
||||
{ pattern = "create%s+tests?%s+for", weight = 0.95 },
|
||||
{ pattern = "generate%s+tests?", weight = 0.95 },
|
||||
{ pattern = "unit%s+tests?", weight = 0.9 },
|
||||
{ pattern = "test%s+cases?%s+for", weight = 0.95 },
|
||||
{ pattern = "spec%s+for", weight = 0.85 },
|
||||
}
|
||||
|
||||
--- Patterns indicating project-wide context is needed
|
||||
local PROJECT_CONTEXT_PATTERNS = {
|
||||
{ pattern = "project", weight = 0.9 },
|
||||
{ pattern = "codebase", weight = 0.95 },
|
||||
{ pattern = "entire", weight = 0.7 },
|
||||
{ pattern = "whole", weight = 0.7 },
|
||||
{ pattern = "all%s+files", weight = 0.9 },
|
||||
{ pattern = "architecture", weight = 0.95 },
|
||||
{ pattern = "structure", weight = 0.85 },
|
||||
{ pattern = "how%s+is%s+.*%s+organized", weight = 0.95 },
|
||||
{ pattern = "where%s+is%s+.*%s+defined", weight = 0.9 },
|
||||
{ pattern = "dependencies", weight = 0.85 },
|
||||
{ pattern = "imports?%s+from", weight = 0.7 },
|
||||
{ pattern = "modules?", weight = 0.6 },
|
||||
{ pattern = "packages?", weight = 0.6 },
|
||||
}
|
||||
|
||||
--- Patterns indicating project exploration is needed (full indexing)
|
||||
local EXPLORE_PATTERNS = {
|
||||
{ pattern = "explain%s+.*%s*project", weight = 1.0 },
|
||||
{ pattern = "explain%s+.*%s*codebase", weight = 1.0 },
|
||||
{ pattern = "explain%s+me%s+the%s+project", weight = 1.0 },
|
||||
{ pattern = "tell%s+me%s+about%s+.*%s*project", weight = 0.95 },
|
||||
{ pattern = "what%s+is%s+this%s+project", weight = 0.95 },
|
||||
{ pattern = "overview%s+of%s+.*%s*project", weight = 0.95 },
|
||||
{ pattern = "understand%s+.*%s*project", weight = 0.9 },
|
||||
{ pattern = "analyze%s+.*%s*project", weight = 0.9 },
|
||||
{ pattern = "explore%s+.*%s*project", weight = 1.0 },
|
||||
{ pattern = "explore%s+.*%s*codebase", weight = 1.0 },
|
||||
{ pattern = "index%s+.*%s*project", weight = 1.0 },
|
||||
{ pattern = "scan%s+.*%s*project", weight = 0.95 },
|
||||
}
|
||||
|
||||
--- Match patterns against text
|
||||
---@param text string Lowercased text to match
|
||||
---@param patterns table Pattern list with weights
|
||||
---@return number Score, string[] Matched keywords
|
||||
local function match_patterns(text, patterns)
|
||||
local score = 0
|
||||
local matched = {}
|
||||
|
||||
for _, p in ipairs(patterns) do
|
||||
if text:match(p.pattern) then
|
||||
score = score + p.weight
|
||||
table.insert(matched, p.pattern)
|
||||
end
|
||||
end
|
||||
|
||||
return score, matched
|
||||
end
|
||||
|
||||
--- Detect intent from user prompt
|
||||
---@param prompt string User's question/request
|
||||
---@return Intent Detected intent
|
||||
function M.detect(prompt)
|
||||
local text = prompt:lower()
|
||||
|
||||
-- Calculate raw scores for each intent type (sum of matched weights)
|
||||
local ask_score, ask_kw = match_patterns(text, ASK_PATTERNS)
|
||||
local gen_score, gen_kw = match_patterns(text, GENERATE_PATTERNS)
|
||||
local ref_score, ref_kw = match_patterns(text, REFACTOR_PATTERNS)
|
||||
local doc_score, doc_kw = match_patterns(text, DOCUMENT_PATTERNS)
|
||||
local test_score, test_kw = match_patterns(text, TEST_PATTERNS)
|
||||
local proj_score, _ = match_patterns(text, PROJECT_CONTEXT_PATTERNS)
|
||||
local explore_score, _ = match_patterns(text, EXPLORE_PATTERNS)
|
||||
|
||||
-- Find the winner by raw score (highest accumulated weight)
|
||||
local scores = {
|
||||
{ type = "ask", score = ask_score, keywords = ask_kw },
|
||||
{ type = "generate", score = gen_score, keywords = gen_kw },
|
||||
{ type = "refactor", score = ref_score, keywords = ref_kw },
|
||||
{ type = "document", score = doc_score, keywords = doc_kw },
|
||||
{ type = "test", score = test_score, keywords = test_kw },
|
||||
}
|
||||
|
||||
table.sort(scores, function(a, b)
|
||||
return a.score > b.score
|
||||
end)
|
||||
|
||||
local winner = scores[1]
|
||||
|
||||
-- If top score is very low, default to ask (safer for Q&A)
|
||||
if winner.score < 0.3 then
|
||||
winner = { type = "ask", score = 0.5, keywords = {} }
|
||||
end
|
||||
|
||||
-- If ask and generate are close AND there's a question mark, prefer ask
|
||||
if winner.type == "generate" and ask_score > 0 then
|
||||
if text:match("%?%s*$") and ask_score >= gen_score * 0.5 then
|
||||
winner = { type = "ask", score = ask_score, keywords = ask_kw }
|
||||
end
|
||||
end
|
||||
|
||||
-- Determine if "explain" vs "ask" (explain needs more context)
|
||||
local intent_type = winner.type
|
||||
if intent_type == "ask" then
|
||||
-- "explain" if asking about how something works, otherwise "ask"
|
||||
if text:match("explain") or text:match("how%s+does") or text:match("walk%s+me%s+through") then
|
||||
intent_type = "explain"
|
||||
end
|
||||
end
|
||||
|
||||
-- Normalize confidence to 0-1 range (cap at reasonable max)
|
||||
local confidence = math.min(winner.score / 2, 1.0)
|
||||
|
||||
-- Check if exploration is needed (full project indexing)
|
||||
local needs_exploration = explore_score >= 0.9
|
||||
|
||||
---@type Intent
|
||||
local intent = {
|
||||
type = intent_type,
|
||||
confidence = confidence,
|
||||
needs_project_context = proj_score > 0.5 or needs_exploration,
|
||||
needs_brain_context = intent_type == "ask" or intent_type == "explain",
|
||||
needs_exploration = needs_exploration,
|
||||
keywords = winner.keywords,
|
||||
}
|
||||
|
||||
return intent
|
||||
end
|
||||
|
||||
--- Get prompt type for system prompt selection
|
||||
---@param intent Intent Detected intent
|
||||
---@return string Prompt type for prompts.system
|
||||
function M.get_prompt_type(intent)
|
||||
local mapping = {
|
||||
ask = "ask",
|
||||
explain = "ask", -- Uses same prompt as ask
|
||||
generate = "code_generation",
|
||||
refactor = "refactor",
|
||||
document = "document",
|
||||
test = "test",
|
||||
}
|
||||
return mapping[intent.type] or "ask"
|
||||
end
|
||||
|
||||
--- Check if intent requires code output
|
||||
---@param intent Intent
|
||||
---@return boolean
|
||||
function M.produces_code(intent)
|
||||
local code_intents = {
|
||||
generate = true,
|
||||
refactor = true,
|
||||
document = true, -- Documentation is code (comments)
|
||||
test = true,
|
||||
}
|
||||
return code_intents[intent.type] or false
|
||||
end
|
||||
|
||||
return M
|
||||
456
lua/codetyper/features/completion/inject.lua
Normal file
456
lua/codetyper/features/completion/inject.lua
Normal file
@@ -0,0 +1,456 @@
|
||||
---@mod codetyper.agent.inject Smart code injection with import handling
|
||||
---@brief [[
|
||||
--- Intelligent code injection that properly handles imports, merging them
|
||||
--- into existing import sections instead of blindly appending.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class ImportConfig
|
||||
---@field pattern string Lua pattern to match import statements
|
||||
---@field multi_line boolean Whether imports can span multiple lines
|
||||
---@field sort_key function|nil Function to extract sort key from import
|
||||
---@field group_by function|nil Function to group imports
|
||||
|
||||
---@class ParsedCode
|
||||
---@field imports string[] Import statements
|
||||
---@field body string[] Non-import code lines
|
||||
---@field import_lines table<number, boolean> Map of line numbers that are imports
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
local languages = require("codetyper.params.agents.languages")
|
||||
local import_patterns = languages.import_patterns
|
||||
|
||||
--- Check if a line is an import statement for the given language
|
||||
---@param line string
|
||||
---@param patterns table[] Import patterns for the language
|
||||
---@return boolean is_import
|
||||
---@return boolean is_multi_line
|
||||
local function is_import_line(line, patterns)
|
||||
for _, p in ipairs(patterns) do
|
||||
if line:match(p.pattern) then
|
||||
return true, p.multi_line or false
|
||||
end
|
||||
end
|
||||
return false, false
|
||||
end
|
||||
|
||||
|
||||
--- Check if a line ends a multi-line import
|
||||
---@param line string
|
||||
---@param filetype string
|
||||
---@return boolean
|
||||
local function ends_multiline_import(line, filetype)
|
||||
return utils.ends_multiline_import(line, filetype)
|
||||
end
|
||||
|
||||
--- Parse code into imports and body
|
||||
---@param code string|string[] Code to parse
|
||||
---@param filetype string File type/extension
|
||||
---@return ParsedCode
|
||||
function M.parse_code(code, filetype)
|
||||
local lines
|
||||
if type(code) == "string" then
|
||||
lines = vim.split(code, "\n", { plain = true })
|
||||
else
|
||||
lines = code
|
||||
end
|
||||
|
||||
local patterns = import_patterns[filetype] or import_patterns.javascript
|
||||
|
||||
local result = {
|
||||
imports = {},
|
||||
body = {},
|
||||
import_lines = {},
|
||||
}
|
||||
|
||||
local in_multiline_import = false
|
||||
local current_import_lines = {}
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
if in_multiline_import then
|
||||
-- Continue collecting multi-line import
|
||||
table.insert(current_import_lines, line)
|
||||
|
||||
if ends_multiline_import(line, filetype) then
|
||||
-- Complete the multi-line import
|
||||
table.insert(result.imports, table.concat(current_import_lines, "\n"))
|
||||
for j = i - #current_import_lines + 1, i do
|
||||
result.import_lines[j] = true
|
||||
end
|
||||
current_import_lines = {}
|
||||
in_multiline_import = false
|
||||
end
|
||||
else
|
||||
local is_import, is_multi = is_import_line(line, patterns)
|
||||
|
||||
if is_import then
|
||||
result.import_lines[i] = true
|
||||
|
||||
if is_multi and not ends_multiline_import(line, filetype) then
|
||||
-- Start of multi-line import
|
||||
in_multiline_import = true
|
||||
current_import_lines = { line }
|
||||
else
|
||||
-- Single-line import
|
||||
table.insert(result.imports, line)
|
||||
end
|
||||
else
|
||||
-- Non-import line
|
||||
table.insert(result.body, line)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Handle unclosed multi-line import (shouldn't happen with well-formed code)
|
||||
if #current_import_lines > 0 then
|
||||
table.insert(result.imports, table.concat(current_import_lines, "\n"))
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Find the import section range in a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@param filetype string
|
||||
---@return number|nil start_line First import line (1-indexed)
|
||||
---@return number|nil end_line Last import line (1-indexed)
|
||||
function M.find_import_section(bufnr, filetype)
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local patterns = import_patterns[filetype] or import_patterns.javascript
|
||||
|
||||
local first_import = nil
|
||||
local last_import = nil
|
||||
local in_multiline = false
|
||||
local consecutive_non_import = 0
|
||||
local max_gap = 3 -- Allow up to 3 blank/comment lines between imports
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
if in_multiline then
|
||||
last_import = i
|
||||
consecutive_non_import = 0
|
||||
|
||||
if ends_multiline_import(line, filetype) then
|
||||
in_multiline = false
|
||||
end
|
||||
else
|
||||
local is_import, is_multi = is_import_line(line, patterns)
|
||||
|
||||
if is_import then
|
||||
if not first_import then
|
||||
first_import = i
|
||||
end
|
||||
last_import = i
|
||||
consecutive_non_import = 0
|
||||
|
||||
if is_multi and not ends_multiline_import(line, filetype) then
|
||||
in_multiline = true
|
||||
end
|
||||
elseif utils.is_empty_or_comment(line, filetype) then
|
||||
-- Allow gaps in import section
|
||||
if first_import then
|
||||
consecutive_non_import = consecutive_non_import + 1
|
||||
if consecutive_non_import > max_gap then
|
||||
-- Too many non-import lines, import section has ended
|
||||
break
|
||||
end
|
||||
end
|
||||
else
|
||||
-- Non-import, non-empty line
|
||||
if first_import then
|
||||
-- Import section has ended
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return first_import, last_import
|
||||
end
|
||||
|
||||
--- Get existing imports from a buffer
|
||||
---@param bufnr number Buffer number
|
||||
---@param filetype string
|
||||
---@return string[] Existing import statements
|
||||
function M.get_existing_imports(bufnr, filetype)
|
||||
local start_line, end_line = M.find_import_section(bufnr, filetype)
|
||||
if not start_line then
|
||||
return {}
|
||||
end
|
||||
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false)
|
||||
local parsed = M.parse_code(lines, filetype)
|
||||
return parsed.imports
|
||||
end
|
||||
|
||||
--- Normalize an import for comparison (remove whitespace variations)
|
||||
---@param import_str string
|
||||
---@return string
|
||||
local function normalize_import(import_str)
|
||||
-- Remove trailing semicolon for comparison
|
||||
local normalized = import_str:gsub(";%s*$", "")
|
||||
-- Remove all whitespace around braces, commas, colons
|
||||
normalized = normalized:gsub("%s*{%s*", "{")
|
||||
normalized = normalized:gsub("%s*}%s*", "}")
|
||||
normalized = normalized:gsub("%s*,%s*", ",")
|
||||
normalized = normalized:gsub("%s*:%s*", ":")
|
||||
-- Collapse multiple whitespace to single space
|
||||
normalized = normalized:gsub("%s+", " ")
|
||||
-- Trim leading/trailing whitespace
|
||||
normalized = normalized:match("^%s*(.-)%s*$")
|
||||
return normalized
|
||||
end
|
||||
|
||||
--- Check if two imports are duplicates
|
||||
---@param import1 string
|
||||
---@param import2 string
|
||||
---@return boolean
|
||||
local function are_duplicate_imports(import1, import2)
|
||||
return normalize_import(import1) == normalize_import(import2)
|
||||
end
|
||||
|
||||
--- Merge new imports with existing ones, avoiding duplicates
|
||||
---@param existing string[] Existing imports
|
||||
---@param new_imports string[] New imports to merge
|
||||
---@return string[] Merged imports
|
||||
function M.merge_imports(existing, new_imports)
|
||||
local merged = {}
|
||||
local seen = {}
|
||||
|
||||
-- Add existing imports
|
||||
for _, imp in ipairs(existing) do
|
||||
local normalized = normalize_import(imp)
|
||||
if not seen[normalized] then
|
||||
seen[normalized] = true
|
||||
table.insert(merged, imp)
|
||||
end
|
||||
end
|
||||
|
||||
-- Add new imports that aren't duplicates
|
||||
for _, imp in ipairs(new_imports) do
|
||||
local normalized = normalize_import(imp)
|
||||
if not seen[normalized] then
|
||||
seen[normalized] = true
|
||||
table.insert(merged, imp)
|
||||
end
|
||||
end
|
||||
|
||||
return merged
|
||||
end
|
||||
|
||||
--- Sort imports by their source/module
|
||||
---@param imports string[]
|
||||
---@param filetype string
|
||||
---@return string[]
|
||||
function M.sort_imports(imports, filetype)
|
||||
-- Group imports: stdlib/builtin first, then third-party, then local
|
||||
local builtin = {}
|
||||
local third_party = {}
|
||||
local local_imports = {}
|
||||
|
||||
for _, imp in ipairs(imports) do
|
||||
local category = utils.classify_import(imp, filetype)
|
||||
|
||||
if category == "builtin" then
|
||||
table.insert(builtin, imp)
|
||||
elseif category == "local" then
|
||||
table.insert(local_imports, imp)
|
||||
else
|
||||
table.insert(third_party, imp)
|
||||
end
|
||||
end
|
||||
|
||||
-- Sort each group alphabetically
|
||||
table.sort(builtin)
|
||||
table.sort(third_party)
|
||||
table.sort(local_imports)
|
||||
|
||||
-- Combine with proper spacing
|
||||
local result = {}
|
||||
|
||||
for _, imp in ipairs(builtin) do
|
||||
table.insert(result, imp)
|
||||
end
|
||||
if #builtin > 0 and (#third_party > 0 or #local_imports > 0) then
|
||||
table.insert(result, "") -- Blank line between groups
|
||||
end
|
||||
|
||||
for _, imp in ipairs(third_party) do
|
||||
table.insert(result, imp)
|
||||
end
|
||||
if #third_party > 0 and #local_imports > 0 then
|
||||
table.insert(result, "") -- Blank line between groups
|
||||
end
|
||||
|
||||
for _, imp in ipairs(local_imports) do
|
||||
table.insert(result, imp)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
---@class InjectResult
|
||||
---@field success boolean
|
||||
---@field imports_added number Number of new imports added
|
||||
---@field imports_merged boolean Whether imports were merged into existing section
|
||||
---@field body_lines number Number of body lines injected
|
||||
|
||||
--- Smart inject code into a buffer, properly handling imports
|
||||
---@param bufnr number Target buffer
|
||||
---@param code string|string[] Code to inject
|
||||
---@param opts table Options: { strategy: "append"|"replace"|"insert", range: {start_line, end_line}|nil, filetype: string|nil, sort_imports: boolean|nil }
|
||||
---@return InjectResult
|
||||
function M.inject(bufnr, code, opts)
|
||||
opts = opts or {}
|
||||
|
||||
if not vim.api.nvim_buf_is_valid(bufnr) then
|
||||
return { success = false, imports_added = 0, imports_merged = false, body_lines = 0 }
|
||||
end
|
||||
|
||||
-- Get filetype
|
||||
local filetype = opts.filetype
|
||||
if not filetype then
|
||||
local bufname = vim.api.nvim_buf_get_name(bufnr)
|
||||
filetype = vim.fn.fnamemodify(bufname, ":e")
|
||||
end
|
||||
|
||||
-- Parse the code to separate imports from body
|
||||
local parsed = M.parse_code(code, filetype)
|
||||
|
||||
local result = {
|
||||
success = true,
|
||||
imports_added = 0,
|
||||
imports_merged = false,
|
||||
body_lines = #parsed.body,
|
||||
}
|
||||
|
||||
-- Handle imports first if there are any
|
||||
if #parsed.imports > 0 then
|
||||
local import_start, import_end = M.find_import_section(bufnr, filetype)
|
||||
|
||||
if import_start then
|
||||
-- Merge with existing import section
|
||||
local existing_imports = M.get_existing_imports(bufnr, filetype)
|
||||
local merged = M.merge_imports(existing_imports, parsed.imports)
|
||||
|
||||
-- Count how many new imports were actually added
|
||||
result.imports_added = #merged - #existing_imports
|
||||
result.imports_merged = true
|
||||
|
||||
-- Optionally sort imports
|
||||
if opts.sort_imports ~= false then
|
||||
merged = M.sort_imports(merged, filetype)
|
||||
end
|
||||
|
||||
-- Convert back to lines (handling multi-line imports)
|
||||
local import_lines = {}
|
||||
for _, imp in ipairs(merged) do
|
||||
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
|
||||
table.insert(import_lines, line)
|
||||
end
|
||||
end
|
||||
|
||||
-- Replace the import section
|
||||
vim.api.nvim_buf_set_lines(bufnr, import_start - 1, import_end, false, import_lines)
|
||||
|
||||
-- Adjust line numbers for body injection
|
||||
local lines_diff = #import_lines - (import_end - import_start + 1)
|
||||
if opts.range and opts.range.start_line and opts.range.start_line > import_end then
|
||||
opts.range.start_line = opts.range.start_line + lines_diff
|
||||
if opts.range.end_line then
|
||||
opts.range.end_line = opts.range.end_line + lines_diff
|
||||
end
|
||||
end
|
||||
else
|
||||
-- No existing import section, add imports at the top
|
||||
-- Find the first non-comment, non-empty line
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local insert_at = 0
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
local trimmed = line:match("^%s*(.-)%s*$")
|
||||
-- Skip shebang, docstrings, and initial comments
|
||||
if trimmed ~= "" and not trimmed:match("^#!")
|
||||
and not trimmed:match("^['\"]") and not utils.is_empty_or_comment(line, filetype) then
|
||||
insert_at = i - 1
|
||||
break
|
||||
end
|
||||
insert_at = i
|
||||
end
|
||||
|
||||
-- Add imports with a trailing blank line
|
||||
local import_lines = {}
|
||||
for _, imp in ipairs(parsed.imports) do
|
||||
for _, line in ipairs(vim.split(imp, "\n", { plain = true })) do
|
||||
table.insert(import_lines, line)
|
||||
end
|
||||
end
|
||||
table.insert(import_lines, "") -- Blank line after imports
|
||||
|
||||
vim.api.nvim_buf_set_lines(bufnr, insert_at, insert_at, false, import_lines)
|
||||
result.imports_added = #parsed.imports
|
||||
result.imports_merged = false
|
||||
|
||||
-- Adjust body injection range
|
||||
if opts.range and opts.range.start_line then
|
||||
opts.range.start_line = opts.range.start_line + #import_lines
|
||||
if opts.range.end_line then
|
||||
opts.range.end_line = opts.range.end_line + #import_lines
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Handle body (non-import) code
|
||||
if #parsed.body > 0 then
|
||||
-- Filter out empty leading/trailing lines from body
|
||||
local body_lines = parsed.body
|
||||
while #body_lines > 0 and body_lines[1]:match("^%s*$") do
|
||||
table.remove(body_lines, 1)
|
||||
end
|
||||
while #body_lines > 0 and body_lines[#body_lines]:match("^%s*$") do
|
||||
table.remove(body_lines)
|
||||
end
|
||||
|
||||
if #body_lines > 0 then
|
||||
local line_count = vim.api.nvim_buf_line_count(bufnr)
|
||||
local strategy = opts.strategy or "append"
|
||||
|
||||
if strategy == "replace" and opts.range then
|
||||
local start_line = math.max(1, opts.range.start_line)
|
||||
local end_line = math.min(line_count, opts.range.end_line)
|
||||
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, body_lines)
|
||||
elseif strategy == "insert" and opts.range then
|
||||
local insert_line = math.max(0, math.min(line_count, opts.range.start_line - 1))
|
||||
vim.api.nvim_buf_set_lines(bufnr, insert_line, insert_line, false, body_lines)
|
||||
else
|
||||
-- Default: append
|
||||
local last_line = vim.api.nvim_buf_get_lines(bufnr, line_count - 1, line_count, false)[1] or ""
|
||||
if last_line:match("%S") then
|
||||
-- Add blank line for spacing
|
||||
table.insert(body_lines, 1, "")
|
||||
end
|
||||
vim.api.nvim_buf_set_lines(bufnr, line_count, line_count, false, body_lines)
|
||||
end
|
||||
|
||||
result.body_lines = #body_lines
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Check if code contains imports
|
||||
---@param code string|string[]
|
||||
---@param filetype string
|
||||
---@return boolean
|
||||
function M.has_imports(code, filetype)
|
||||
local parsed = M.parse_code(code, filetype)
|
||||
return #parsed.imports > 0
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -5,7 +5,7 @@
|
||||
local M = {}
|
||||
|
||||
local parser = require("codetyper.parser")
|
||||
local utils = require("codetyper.utils")
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Get list of files for completion
|
||||
---@param prefix string Prefix to filter files
|
||||
491
lua/codetyper/features/completion/suggestion.lua
Normal file
491
lua/codetyper/features/completion/suggestion.lua
Normal file
@@ -0,0 +1,491 @@
|
||||
---@mod codetyper.suggestion Inline ghost text suggestions
|
||||
---@brief [[
|
||||
--- Provides Copilot-style inline suggestions with ghost text.
|
||||
--- Uses Copilot when available, falls back to codetyper's own suggestions.
|
||||
--- Shows suggestions as grayed-out text that can be accepted with Tab.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
---@class SuggestionState
|
||||
---@field current_suggestion string|nil Current suggestion text
|
||||
---@field suggestions string[] List of available suggestions
|
||||
---@field current_index number Current suggestion index
|
||||
---@field extmark_id number|nil Virtual text extmark ID
|
||||
---@field bufnr number|nil Buffer where suggestion is shown
|
||||
---@field line number|nil Line where suggestion is shown
|
||||
---@field col number|nil Column where suggestion starts
|
||||
---@field timer any|nil Debounce timer
|
||||
---@field using_copilot boolean Whether currently using copilot
|
||||
|
||||
local state = {
|
||||
current_suggestion = nil,
|
||||
suggestions = {},
|
||||
current_index = 0,
|
||||
extmark_id = nil,
|
||||
bufnr = nil,
|
||||
line = nil,
|
||||
col = nil,
|
||||
timer = nil,
|
||||
using_copilot = false,
|
||||
}
|
||||
|
||||
--- Namespace for virtual text
|
||||
local ns = vim.api.nvim_create_namespace("codetyper_suggestion")
|
||||
|
||||
--- Highlight group for ghost text
|
||||
local hl_group = "CmpGhostText"
|
||||
|
||||
--- Configuration
|
||||
local config = {
|
||||
enabled = true,
|
||||
auto_trigger = true,
|
||||
debounce = 150,
|
||||
use_copilot = true, -- Use copilot when available
|
||||
keymap = {
|
||||
accept = "<Tab>",
|
||||
next = "<M-]>",
|
||||
prev = "<M-[>",
|
||||
dismiss = "<C-]>",
|
||||
},
|
||||
}
|
||||
|
||||
--- Check if copilot is available and enabled
|
||||
---@return boolean, table|nil available, copilot_suggestion module
|
||||
local function get_copilot()
|
||||
if not config.use_copilot then
|
||||
return false, nil
|
||||
end
|
||||
|
||||
local ok, copilot_suggestion = pcall(require, "copilot.suggestion")
|
||||
if not ok then
|
||||
return false, nil
|
||||
end
|
||||
|
||||
-- Check if copilot suggestion is enabled
|
||||
local ok_client, copilot_client = pcall(require, "copilot.client")
|
||||
if ok_client and copilot_client.is_disabled and copilot_client.is_disabled() then
|
||||
return false, nil
|
||||
end
|
||||
|
||||
return true, copilot_suggestion
|
||||
end
|
||||
|
||||
--- Check if suggestion is visible (copilot or codetyper)
|
||||
---@return boolean
|
||||
function M.is_visible()
|
||||
-- Check copilot first
|
||||
local copilot_ok, copilot_suggestion = get_copilot()
|
||||
if copilot_ok and copilot_suggestion.is_visible() then
|
||||
state.using_copilot = true
|
||||
return true
|
||||
end
|
||||
|
||||
-- Check codetyper's own suggestion
|
||||
state.using_copilot = false
|
||||
return state.extmark_id ~= nil and state.current_suggestion ~= nil
|
||||
end
|
||||
|
||||
--- Clear the current suggestion
|
||||
function M.dismiss()
|
||||
-- Dismiss copilot if active
|
||||
local copilot_ok, copilot_suggestion = get_copilot()
|
||||
if copilot_ok and copilot_suggestion.is_visible() then
|
||||
copilot_suggestion.dismiss()
|
||||
end
|
||||
|
||||
-- Clear codetyper's suggestion
|
||||
if state.extmark_id and state.bufnr then
|
||||
pcall(vim.api.nvim_buf_del_extmark, state.bufnr, ns, state.extmark_id)
|
||||
end
|
||||
|
||||
state.current_suggestion = nil
|
||||
state.suggestions = {}
|
||||
state.current_index = 0
|
||||
state.extmark_id = nil
|
||||
state.bufnr = nil
|
||||
state.line = nil
|
||||
state.col = nil
|
||||
state.using_copilot = false
|
||||
end
|
||||
|
||||
--- Display suggestion as ghost text
|
||||
---@param suggestion string The suggestion to display
|
||||
local function display_suggestion(suggestion)
|
||||
if not suggestion or suggestion == "" then
|
||||
return
|
||||
end
|
||||
|
||||
M.dismiss()
|
||||
|
||||
local bufnr = vim.api.nvim_get_current_buf()
|
||||
local cursor = vim.api.nvim_win_get_cursor(0)
|
||||
local line = cursor[1] - 1
|
||||
local col = cursor[2]
|
||||
|
||||
-- Split suggestion into lines
|
||||
local lines = vim.split(suggestion, "\n", { plain = true })
|
||||
|
||||
-- Build virtual text
|
||||
local virt_text = {}
|
||||
local virt_lines = {}
|
||||
|
||||
-- First line goes inline
|
||||
if #lines > 0 then
|
||||
virt_text = { { lines[1], hl_group } }
|
||||
end
|
||||
|
||||
-- Remaining lines go below
|
||||
for i = 2, #lines do
|
||||
table.insert(virt_lines, { { lines[i], hl_group } })
|
||||
end
|
||||
|
||||
-- Create extmark with virtual text
|
||||
local opts = {
|
||||
virt_text = virt_text,
|
||||
virt_text_pos = "overlay",
|
||||
hl_mode = "combine",
|
||||
}
|
||||
|
||||
if #virt_lines > 0 then
|
||||
opts.virt_lines = virt_lines
|
||||
end
|
||||
|
||||
state.extmark_id = vim.api.nvim_buf_set_extmark(bufnr, ns, line, col, opts)
|
||||
state.bufnr = bufnr
|
||||
state.line = line
|
||||
state.col = col
|
||||
state.current_suggestion = suggestion
|
||||
end
|
||||
|
||||
--- Accept the current suggestion
|
||||
---@return boolean Whether a suggestion was accepted
|
||||
function M.accept()
|
||||
-- Check copilot first
|
||||
local copilot_ok, copilot_suggestion = get_copilot()
|
||||
if copilot_ok and copilot_suggestion.is_visible() then
|
||||
copilot_suggestion.accept()
|
||||
state.using_copilot = false
|
||||
return true
|
||||
end
|
||||
|
||||
-- Accept codetyper's suggestion
|
||||
if not M.is_visible() then
|
||||
return false
|
||||
end
|
||||
|
||||
local suggestion = state.current_suggestion
|
||||
local bufnr = state.bufnr
|
||||
local line = state.line
|
||||
local col = state.col
|
||||
|
||||
M.dismiss()
|
||||
|
||||
if suggestion and bufnr and line ~= nil and col ~= nil then
|
||||
-- Get current line content
|
||||
local current_line = vim.api.nvim_buf_get_lines(bufnr, line, line + 1, false)[1] or ""
|
||||
|
||||
-- Split suggestion into lines
|
||||
local suggestion_lines = vim.split(suggestion, "\n", { plain = true })
|
||||
|
||||
if #suggestion_lines == 1 then
|
||||
-- Single line - insert at cursor
|
||||
local new_line = current_line:sub(1, col) .. suggestion .. current_line:sub(col + 1)
|
||||
vim.api.nvim_buf_set_lines(bufnr, line, line + 1, false, { new_line })
|
||||
-- Move cursor to end of inserted text
|
||||
vim.api.nvim_win_set_cursor(0, { line + 1, col + #suggestion })
|
||||
else
|
||||
-- Multi-line - insert at cursor
|
||||
local first_line = current_line:sub(1, col) .. suggestion_lines[1]
|
||||
local last_line = suggestion_lines[#suggestion_lines] .. current_line:sub(col + 1)
|
||||
|
||||
local new_lines = { first_line }
|
||||
for i = 2, #suggestion_lines - 1 do
|
||||
table.insert(new_lines, suggestion_lines[i])
|
||||
end
|
||||
table.insert(new_lines, last_line)
|
||||
|
||||
vim.api.nvim_buf_set_lines(bufnr, line, line + 1, false, new_lines)
|
||||
-- Move cursor to end of last line
|
||||
vim.api.nvim_win_set_cursor(0, { line + #new_lines, #suggestion_lines[#suggestion_lines] })
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Show next suggestion
|
||||
function M.next()
|
||||
-- Check copilot first
|
||||
local copilot_ok, copilot_suggestion = get_copilot()
|
||||
if copilot_ok and copilot_suggestion.is_visible() then
|
||||
copilot_suggestion.next()
|
||||
return
|
||||
end
|
||||
|
||||
-- Codetyper's suggestions
|
||||
if #state.suggestions <= 1 then
|
||||
return
|
||||
end
|
||||
|
||||
state.current_index = (state.current_index % #state.suggestions) + 1
|
||||
display_suggestion(state.suggestions[state.current_index])
|
||||
end
|
||||
|
||||
--- Show previous suggestion
|
||||
function M.prev()
|
||||
-- Check copilot first
|
||||
local copilot_ok, copilot_suggestion = get_copilot()
|
||||
if copilot_ok and copilot_suggestion.is_visible() then
|
||||
copilot_suggestion.prev()
|
||||
return
|
||||
end
|
||||
|
||||
-- Codetyper's suggestions
|
||||
if #state.suggestions <= 1 then
|
||||
return
|
||||
end
|
||||
|
||||
state.current_index = state.current_index - 1
|
||||
if state.current_index < 1 then
|
||||
state.current_index = #state.suggestions
|
||||
end
|
||||
display_suggestion(state.suggestions[state.current_index])
|
||||
end
|
||||
|
||||
--- Get suggestions from brain/indexer
|
||||
---@param prefix string Current word prefix
|
||||
---@param context table Context info
|
||||
---@return string[] suggestions
|
||||
local function get_suggestions(prefix, context)
|
||||
local suggestions = {}
|
||||
|
||||
-- Get completions from brain
|
||||
local ok_brain, brain = pcall(require, "codetyper.brain")
|
||||
if ok_brain and brain.is_initialized and brain.is_initialized() then
|
||||
local result = brain.query({
|
||||
query = prefix,
|
||||
max_results = 5,
|
||||
types = { "pattern" },
|
||||
})
|
||||
|
||||
if result and result.nodes then
|
||||
for _, node in ipairs(result.nodes) do
|
||||
if node.c and node.c.code then
|
||||
table.insert(suggestions, node.c.code)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Get completions from indexer
|
||||
local ok_indexer, indexer = pcall(require, "codetyper.indexer")
|
||||
if ok_indexer then
|
||||
local index = indexer.load_index()
|
||||
if index and index.symbols then
|
||||
for symbol, _ in pairs(index.symbols) do
|
||||
if symbol:lower():find(prefix:lower(), 1, true) and symbol ~= prefix then
|
||||
-- Just complete the symbol name
|
||||
local completion = symbol:sub(#prefix + 1)
|
||||
if completion ~= "" then
|
||||
table.insert(suggestions, completion)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Buffer-based completions
|
||||
local bufnr = vim.api.nvim_get_current_buf()
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local seen = {}
|
||||
|
||||
for _, line in ipairs(lines) do
|
||||
for word in line:gmatch("[%a_][%w_]*") do
|
||||
if
|
||||
#word > #prefix
|
||||
and word:lower():find(prefix:lower(), 1, true) == 1
|
||||
and not seen[word]
|
||||
and word ~= prefix
|
||||
then
|
||||
seen[word] = true
|
||||
local completion = word:sub(#prefix + 1)
|
||||
if completion ~= "" then
|
||||
table.insert(suggestions, completion)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return suggestions
|
||||
end
|
||||
|
||||
--- Trigger suggestion generation
|
||||
function M.trigger()
|
||||
if not config.enabled then
|
||||
return
|
||||
end
|
||||
|
||||
-- If copilot is available and has a suggestion, don't show codetyper's
|
||||
local copilot_ok, copilot_suggestion = get_copilot()
|
||||
if copilot_ok and copilot_suggestion.is_visible() then
|
||||
-- Copilot is handling suggestions
|
||||
state.using_copilot = true
|
||||
return
|
||||
end
|
||||
|
||||
-- Cancel existing timer
|
||||
if state.timer then
|
||||
state.timer:stop()
|
||||
state.timer = nil
|
||||
end
|
||||
|
||||
-- Get current context
|
||||
local cursor = vim.api.nvim_win_get_cursor(0)
|
||||
local line = vim.api.nvim_get_current_line()
|
||||
local col = cursor[2]
|
||||
local before_cursor = line:sub(1, col)
|
||||
|
||||
-- Extract prefix (word being typed)
|
||||
local prefix = before_cursor:match("[%a_][%w_]*$") or ""
|
||||
|
||||
if #prefix < 2 then
|
||||
M.dismiss()
|
||||
return
|
||||
end
|
||||
|
||||
-- Debounce - wait a bit longer to let copilot try first
|
||||
local debounce_time = copilot_ok and (config.debounce + 200) or config.debounce
|
||||
|
||||
state.timer = vim.defer_fn(function()
|
||||
-- Check again if copilot has shown something
|
||||
if copilot_ok and copilot_suggestion.is_visible() then
|
||||
state.using_copilot = true
|
||||
state.timer = nil
|
||||
return
|
||||
end
|
||||
|
||||
local suggestions = get_suggestions(prefix, {
|
||||
line = line,
|
||||
col = col,
|
||||
bufnr = vim.api.nvim_get_current_buf(),
|
||||
})
|
||||
|
||||
if #suggestions > 0 then
|
||||
state.suggestions = suggestions
|
||||
state.current_index = 1
|
||||
display_suggestion(suggestions[1])
|
||||
else
|
||||
M.dismiss()
|
||||
end
|
||||
|
||||
state.timer = nil
|
||||
end, debounce_time)
|
||||
end
|
||||
|
||||
--- Setup keymaps
|
||||
local function setup_keymaps()
|
||||
-- Accept with Tab (only when suggestion visible)
|
||||
vim.keymap.set("i", config.keymap.accept, function()
|
||||
if M.is_visible() then
|
||||
M.accept()
|
||||
return ""
|
||||
end
|
||||
-- Fallback to normal Tab behavior
|
||||
return vim.api.nvim_replace_termcodes("<Tab>", true, false, true)
|
||||
end, { expr = true, silent = true, desc = "Accept codetyper suggestion" })
|
||||
|
||||
-- Next suggestion
|
||||
vim.keymap.set("i", config.keymap.next, function()
|
||||
M.next()
|
||||
end, { silent = true, desc = "Next codetyper suggestion" })
|
||||
|
||||
-- Previous suggestion
|
||||
vim.keymap.set("i", config.keymap.prev, function()
|
||||
M.prev()
|
||||
end, { silent = true, desc = "Previous codetyper suggestion" })
|
||||
|
||||
-- Dismiss
|
||||
vim.keymap.set("i", config.keymap.dismiss, function()
|
||||
M.dismiss()
|
||||
end, { silent = true, desc = "Dismiss codetyper suggestion" })
|
||||
end
|
||||
|
||||
--- Setup autocmds for auto-trigger
|
||||
local function setup_autocmds()
|
||||
local group = vim.api.nvim_create_augroup("CodetypeSuggestion", { clear = true })
|
||||
|
||||
-- Trigger on text change in insert mode
|
||||
if config.auto_trigger then
|
||||
vim.api.nvim_create_autocmd("TextChangedI", {
|
||||
group = group,
|
||||
callback = function()
|
||||
M.trigger()
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
-- Dismiss on leaving insert mode
|
||||
vim.api.nvim_create_autocmd("InsertLeave", {
|
||||
group = group,
|
||||
callback = function()
|
||||
M.dismiss()
|
||||
end,
|
||||
})
|
||||
|
||||
-- Dismiss on cursor move (not from typing)
|
||||
vim.api.nvim_create_autocmd("CursorMovedI", {
|
||||
group = group,
|
||||
callback = function()
|
||||
-- Only dismiss if cursor moved significantly
|
||||
if state.line ~= nil then
|
||||
local cursor = vim.api.nvim_win_get_cursor(0)
|
||||
if cursor[1] - 1 ~= state.line then
|
||||
M.dismiss()
|
||||
end
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- Setup highlight group
|
||||
local function setup_highlights()
|
||||
-- Use Comment highlight or define custom ghost text style
|
||||
vim.api.nvim_set_hl(0, hl_group, { link = "Comment" })
|
||||
end
|
||||
|
||||
--- Setup the suggestion system
|
||||
---@param opts? table Configuration options
|
||||
function M.setup(opts)
|
||||
if opts then
|
||||
config = vim.tbl_deep_extend("force", config, opts)
|
||||
end
|
||||
|
||||
setup_highlights()
|
||||
setup_keymaps()
|
||||
setup_autocmds()
|
||||
end
|
||||
|
||||
--- Enable suggestions
|
||||
function M.enable()
|
||||
config.enabled = true
|
||||
end
|
||||
|
||||
--- Disable suggestions
|
||||
function M.disable()
|
||||
config.enabled = false
|
||||
M.dismiss()
|
||||
end
|
||||
|
||||
--- Toggle suggestions
|
||||
function M.toggle()
|
||||
if config.enabled then
|
||||
M.disable()
|
||||
else
|
||||
M.enable()
|
||||
end
|
||||
end
|
||||
|
||||
return M
|
||||
585
lua/codetyper/features/indexer/analyzer.lua
Normal file
585
lua/codetyper/features/indexer/analyzer.lua
Normal file
@@ -0,0 +1,585 @@
|
||||
---@mod codetyper.indexer.analyzer Code analyzer using Tree-sitter
|
||||
---@brief [[
|
||||
--- Analyzes source files to extract functions, classes, exports, and imports.
|
||||
--- Uses Tree-sitter when available, falls back to pattern matching.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
local scanner = require("codetyper.features.indexer.scanner")
|
||||
|
||||
--- Language-specific query patterns for Tree-sitter
|
||||
local TS_QUERIES = {
|
||||
lua = {
|
||||
functions = [[
|
||||
(function_declaration name: (identifier) @name) @func
|
||||
(function_definition) @func
|
||||
(local_function name: (identifier) @name) @func
|
||||
(assignment_statement
|
||||
(variable_list name: (identifier) @name)
|
||||
(expression_list value: (function_definition) @func))
|
||||
]],
|
||||
exports = [[
|
||||
(return_statement (expression_list (table_constructor))) @export
|
||||
]],
|
||||
},
|
||||
typescript = {
|
||||
functions = [[
|
||||
(function_declaration name: (identifier) @name) @func
|
||||
(method_definition name: (property_identifier) @name) @func
|
||||
(arrow_function) @func
|
||||
(lexical_declaration
|
||||
(variable_declarator name: (identifier) @name value: (arrow_function) @func))
|
||||
]],
|
||||
exports = [[
|
||||
(export_statement) @export
|
||||
]],
|
||||
imports = [[
|
||||
(import_statement) @import
|
||||
]],
|
||||
},
|
||||
javascript = {
|
||||
functions = [[
|
||||
(function_declaration name: (identifier) @name) @func
|
||||
(method_definition name: (property_identifier) @name) @func
|
||||
(arrow_function) @func
|
||||
]],
|
||||
exports = [[
|
||||
(export_statement) @export
|
||||
]],
|
||||
imports = [[
|
||||
(import_statement) @import
|
||||
]],
|
||||
},
|
||||
python = {
|
||||
functions = [[
|
||||
(function_definition name: (identifier) @name) @func
|
||||
]],
|
||||
classes = [[
|
||||
(class_definition name: (identifier) @name) @class
|
||||
]],
|
||||
imports = [[
|
||||
(import_statement) @import
|
||||
(import_from_statement) @import
|
||||
]],
|
||||
},
|
||||
go = {
|
||||
functions = [[
|
||||
(function_declaration name: (identifier) @name) @func
|
||||
(method_declaration name: (field_identifier) @name) @func
|
||||
]],
|
||||
imports = [[
|
||||
(import_declaration) @import
|
||||
]],
|
||||
},
|
||||
rust = {
|
||||
functions = [[
|
||||
(function_item name: (identifier) @name) @func
|
||||
]],
|
||||
imports = [[
|
||||
(use_declaration) @import
|
||||
]],
|
||||
},
|
||||
}
|
||||
|
||||
-- Forward declaration for analyze_tree_generic (defined below)
|
||||
local analyze_tree_generic
|
||||
|
||||
--- Hash file content for change detection
|
||||
---@param content string
|
||||
---@return string
|
||||
local function hash_content(content)
|
||||
local hash = 0
|
||||
for i = 1, math.min(#content, 10000) do
|
||||
hash = (hash * 31 + string.byte(content, i)) % 2147483647
|
||||
end
|
||||
return string.format("%08x", hash)
|
||||
end
|
||||
|
||||
--- Try to get Tree-sitter parser for a language
|
||||
---@param lang string
|
||||
---@return boolean
|
||||
local function has_ts_parser(lang)
|
||||
local ok = pcall(vim.treesitter.language.inspect, lang)
|
||||
return ok
|
||||
end
|
||||
|
||||
--- Analyze file using Tree-sitter
|
||||
---@param filepath string
|
||||
---@param lang string
|
||||
---@param content string
|
||||
---@return table|nil
|
||||
local function analyze_with_treesitter(filepath, lang, content)
|
||||
if not has_ts_parser(lang) then
|
||||
return nil
|
||||
end
|
||||
|
||||
local result = {
|
||||
functions = {},
|
||||
classes = {},
|
||||
exports = {},
|
||||
imports = {},
|
||||
}
|
||||
|
||||
-- Create a temporary buffer for parsing
|
||||
local bufnr = vim.api.nvim_create_buf(false, true)
|
||||
vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, vim.split(content, "\n"))
|
||||
|
||||
local ok, parser = pcall(vim.treesitter.get_parser, bufnr, lang)
|
||||
if not ok or not parser then
|
||||
vim.api.nvim_buf_delete(bufnr, { force = true })
|
||||
return nil
|
||||
end
|
||||
|
||||
local tree = parser:parse()[1]
|
||||
if not tree then
|
||||
vim.api.nvim_buf_delete(bufnr, { force = true })
|
||||
return nil
|
||||
end
|
||||
|
||||
local root = tree:root()
|
||||
local queries = TS_QUERIES[lang]
|
||||
|
||||
if not queries then
|
||||
-- Fallback: walk tree manually for common patterns
|
||||
result = analyze_tree_generic(root, bufnr)
|
||||
else
|
||||
-- Use language-specific queries
|
||||
if queries.functions then
|
||||
local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.functions)
|
||||
if query_ok then
|
||||
for id, node in query:iter_captures(root, bufnr, 0, -1) do
|
||||
local capture_name = query.captures[id]
|
||||
if capture_name == "func" or capture_name == "name" then
|
||||
local start_row, _, end_row, _ = node:range()
|
||||
local name = nil
|
||||
|
||||
-- Try to get name from sibling capture or child
|
||||
if capture_name == "func" then
|
||||
local name_node = node:field("name")[1]
|
||||
if name_node then
|
||||
name = vim.treesitter.get_node_text(name_node, bufnr)
|
||||
end
|
||||
else
|
||||
name = vim.treesitter.get_node_text(node, bufnr)
|
||||
end
|
||||
|
||||
if name and not vim.tbl_contains(vim.tbl_map(function(f)
|
||||
return f.name
|
||||
end, result.functions), name) then
|
||||
table.insert(result.functions, {
|
||||
name = name,
|
||||
line = start_row + 1,
|
||||
end_line = end_row + 1,
|
||||
params = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if queries.classes then
|
||||
local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.classes)
|
||||
if query_ok then
|
||||
for id, node in query:iter_captures(root, bufnr, 0, -1) do
|
||||
local capture_name = query.captures[id]
|
||||
if capture_name == "class" then
|
||||
local start_row, _, end_row, _ = node:range()
|
||||
local name_node = node:field("name")[1]
|
||||
local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous"
|
||||
|
||||
table.insert(result.classes, {
|
||||
name = name,
|
||||
line = start_row + 1,
|
||||
end_line = end_row + 1,
|
||||
methods = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if queries.exports then
|
||||
local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.exports)
|
||||
if query_ok then
|
||||
for _, node in query:iter_captures(root, bufnr, 0, -1) do
|
||||
local text = vim.treesitter.get_node_text(node, bufnr)
|
||||
local start_row, _, _, _ = node:range()
|
||||
|
||||
-- Extract export names (simplified)
|
||||
local names = {}
|
||||
for name in text:gmatch("export%s+[%w_]+%s+([%w_]+)") do
|
||||
table.insert(names, name)
|
||||
end
|
||||
for name in text:gmatch("export%s*{([^}]+)}") do
|
||||
for n in name:gmatch("([%w_]+)") do
|
||||
table.insert(names, n)
|
||||
end
|
||||
end
|
||||
|
||||
for _, name in ipairs(names) do
|
||||
table.insert(result.exports, {
|
||||
name = name,
|
||||
type = "unknown",
|
||||
line = start_row + 1,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if queries.imports then
|
||||
local query_ok, query = pcall(vim.treesitter.query.parse, lang, queries.imports)
|
||||
if query_ok then
|
||||
for _, node in query:iter_captures(root, bufnr, 0, -1) do
|
||||
local text = vim.treesitter.get_node_text(node, bufnr)
|
||||
local start_row, _, _, _ = node:range()
|
||||
|
||||
-- Extract import source
|
||||
local source = text:match('["\']([^"\']+)["\']')
|
||||
if source then
|
||||
table.insert(result.imports, {
|
||||
source = source,
|
||||
names = {},
|
||||
line = start_row + 1,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
vim.api.nvim_buf_delete(bufnr, { force = true })
|
||||
return result
|
||||
end
|
||||
|
||||
--- Generic tree analysis for unsupported languages
|
||||
---@param root TSNode
|
||||
---@param bufnr number
|
||||
---@return table
|
||||
analyze_tree_generic = function(root, bufnr)
|
||||
local result = {
|
||||
functions = {},
|
||||
classes = {},
|
||||
exports = {},
|
||||
imports = {},
|
||||
}
|
||||
|
||||
local function visit(node)
|
||||
local node_type = node:type()
|
||||
|
||||
-- Common function patterns
|
||||
if
|
||||
node_type:match("function")
|
||||
or node_type:match("method")
|
||||
or node_type == "arrow_function"
|
||||
or node_type == "func_literal"
|
||||
then
|
||||
local start_row, _, end_row, _ = node:range()
|
||||
local name_node = node:field("name")[1]
|
||||
local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous"
|
||||
|
||||
table.insert(result.functions, {
|
||||
name = name,
|
||||
line = start_row + 1,
|
||||
end_line = end_row + 1,
|
||||
params = {},
|
||||
})
|
||||
end
|
||||
|
||||
-- Common class patterns
|
||||
if node_type:match("class") or node_type == "struct_item" or node_type == "impl_item" then
|
||||
local start_row, _, end_row, _ = node:range()
|
||||
local name_node = node:field("name")[1]
|
||||
local name = name_node and vim.treesitter.get_node_text(name_node, bufnr) or "anonymous"
|
||||
|
||||
table.insert(result.classes, {
|
||||
name = name,
|
||||
line = start_row + 1,
|
||||
end_line = end_row + 1,
|
||||
methods = {},
|
||||
})
|
||||
end
|
||||
|
||||
-- Recurse into children
|
||||
for child in node:iter_children() do
|
||||
visit(child)
|
||||
end
|
||||
end
|
||||
|
||||
visit(root)
|
||||
return result
|
||||
end
|
||||
|
||||
--- Analyze file using pattern matching (fallback)
|
||||
---@param content string
|
||||
---@param lang string
|
||||
---@return table
|
||||
local function analyze_with_patterns(content, lang)
|
||||
local result = {
|
||||
functions = {},
|
||||
classes = {},
|
||||
exports = {},
|
||||
imports = {},
|
||||
}
|
||||
|
||||
local lines = vim.split(content, "\n")
|
||||
|
||||
-- Language-specific patterns
|
||||
local patterns = {
|
||||
lua = {
|
||||
func_start = "^%s*local?%s*function%s+([%w_%.]+)",
|
||||
func_assign = "^%s*([%w_%.]+)%s*=%s*function",
|
||||
module_return = "^return%s+M",
|
||||
},
|
||||
javascript = {
|
||||
func_start = "^%s*function%s+([%w_]+)",
|
||||
func_arrow = "^%s*const%s+([%w_]+)%s*=%s*",
|
||||
class_start = "^%s*class%s+([%w_]+)",
|
||||
export_line = "^%s*export%s+",
|
||||
import_line = "^%s*import%s+",
|
||||
},
|
||||
typescript = {
|
||||
func_start = "^%s*function%s+([%w_]+)",
|
||||
func_arrow = "^%s*const%s+([%w_]+)%s*=%s*",
|
||||
class_start = "^%s*class%s+([%w_]+)",
|
||||
export_line = "^%s*export%s+",
|
||||
import_line = "^%s*import%s+",
|
||||
},
|
||||
python = {
|
||||
func_start = "^%s*def%s+([%w_]+)",
|
||||
class_start = "^%s*class%s+([%w_]+)",
|
||||
import_line = "^%s*import%s+",
|
||||
from_import = "^%s*from%s+",
|
||||
},
|
||||
go = {
|
||||
func_start = "^func%s+([%w_]+)",
|
||||
method_start = "^func%s+%([^%)]+%)%s+([%w_]+)",
|
||||
import_line = "^import%s+",
|
||||
},
|
||||
rust = {
|
||||
func_start = "^%s*pub?%s*fn%s+([%w_]+)",
|
||||
struct_start = "^%s*pub?%s*struct%s+([%w_]+)",
|
||||
impl_start = "^%s*impl%s+([%w_<>]+)",
|
||||
use_line = "^%s*use%s+",
|
||||
},
|
||||
}
|
||||
|
||||
local lang_patterns = patterns[lang] or patterns.javascript
|
||||
|
||||
for i, line in ipairs(lines) do
|
||||
-- Functions
|
||||
if lang_patterns.func_start then
|
||||
local name = line:match(lang_patterns.func_start)
|
||||
if name then
|
||||
table.insert(result.functions, {
|
||||
name = name,
|
||||
line = i,
|
||||
end_line = i,
|
||||
params = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
if lang_patterns.func_arrow then
|
||||
local name = line:match(lang_patterns.func_arrow)
|
||||
if name and line:match("=>") then
|
||||
table.insert(result.functions, {
|
||||
name = name,
|
||||
line = i,
|
||||
end_line = i,
|
||||
params = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
if lang_patterns.func_assign then
|
||||
local name = line:match(lang_patterns.func_assign)
|
||||
if name then
|
||||
table.insert(result.functions, {
|
||||
name = name,
|
||||
line = i,
|
||||
end_line = i,
|
||||
params = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
if lang_patterns.method_start then
|
||||
local name = line:match(lang_patterns.method_start)
|
||||
if name then
|
||||
table.insert(result.functions, {
|
||||
name = name,
|
||||
line = i,
|
||||
end_line = i,
|
||||
params = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Classes
|
||||
if lang_patterns.class_start then
|
||||
local name = line:match(lang_patterns.class_start)
|
||||
if name then
|
||||
table.insert(result.classes, {
|
||||
name = name,
|
||||
line = i,
|
||||
end_line = i,
|
||||
methods = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
if lang_patterns.struct_start then
|
||||
local name = line:match(lang_patterns.struct_start)
|
||||
if name then
|
||||
table.insert(result.classes, {
|
||||
name = name,
|
||||
line = i,
|
||||
end_line = i,
|
||||
methods = {},
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Exports
|
||||
if lang_patterns.export_line and line:match(lang_patterns.export_line) then
|
||||
local name = line:match("export%s+[%w_]+%s+([%w_]+)")
|
||||
or line:match("export%s+default%s+([%w_]+)")
|
||||
or line:match("export%s+{%s*([%w_]+)")
|
||||
if name then
|
||||
table.insert(result.exports, {
|
||||
name = name,
|
||||
type = "unknown",
|
||||
line = i,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
-- Imports
|
||||
if lang_patterns.import_line and line:match(lang_patterns.import_line) then
|
||||
local source = line:match('["\']([^"\']+)["\']')
|
||||
if source then
|
||||
table.insert(result.imports, {
|
||||
source = source,
|
||||
names = {},
|
||||
line = i,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
if lang_patterns.from_import and line:match(lang_patterns.from_import) then
|
||||
local source = line:match("from%s+([%w_%.]+)")
|
||||
if source then
|
||||
table.insert(result.imports, {
|
||||
source = source,
|
||||
names = {},
|
||||
line = i,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
if lang_patterns.use_line and line:match(lang_patterns.use_line) then
|
||||
local source = line:match("use%s+([%w_:]+)")
|
||||
if source then
|
||||
table.insert(result.imports, {
|
||||
source = source,
|
||||
names = {},
|
||||
line = i,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- For Lua, infer exports from module table
|
||||
if lang == "lua" then
|
||||
for _, func in ipairs(result.functions) do
|
||||
if func.name:match("^M%.") then
|
||||
local name = func.name:gsub("^M%.", "")
|
||||
table.insert(result.exports, {
|
||||
name = name,
|
||||
type = "function",
|
||||
line = func.line,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
--- Analyze a single file
|
||||
---@param filepath string Full path to file
|
||||
---@return FileIndex|nil
|
||||
function M.analyze_file(filepath)
|
||||
local content = utils.read_file(filepath)
|
||||
if not content then
|
||||
return nil
|
||||
end
|
||||
|
||||
local lang = scanner.get_language(filepath)
|
||||
|
||||
-- Map to Tree-sitter language names
|
||||
local ts_lang_map = {
|
||||
typescript = "typescript",
|
||||
typescriptreact = "tsx",
|
||||
javascript = "javascript",
|
||||
javascriptreact = "javascript",
|
||||
python = "python",
|
||||
go = "go",
|
||||
rust = "rust",
|
||||
lua = "lua",
|
||||
}
|
||||
|
||||
local ts_lang = ts_lang_map[lang] or lang
|
||||
|
||||
-- Try Tree-sitter first
|
||||
local analysis = analyze_with_treesitter(filepath, ts_lang, content)
|
||||
|
||||
-- Fallback to pattern matching
|
||||
if not analysis then
|
||||
analysis = analyze_with_patterns(content, lang)
|
||||
end
|
||||
|
||||
return {
|
||||
path = filepath,
|
||||
language = lang,
|
||||
hash = hash_content(content),
|
||||
exports = analysis.exports,
|
||||
imports = analysis.imports,
|
||||
functions = analysis.functions,
|
||||
classes = analysis.classes,
|
||||
last_indexed = os.time(),
|
||||
}
|
||||
end
|
||||
|
||||
--- Extract exports from a buffer
|
||||
---@param bufnr number
|
||||
---@return Export[]
|
||||
function M.extract_exports(bufnr)
|
||||
local filepath = vim.api.nvim_buf_get_name(bufnr)
|
||||
local analysis = M.analyze_file(filepath)
|
||||
return analysis and analysis.exports or {}
|
||||
end
|
||||
|
||||
--- Extract functions from a buffer
|
||||
---@param bufnr number
|
||||
---@return FunctionInfo[]
|
||||
function M.extract_functions(bufnr)
|
||||
local filepath = vim.api.nvim_buf_get_name(bufnr)
|
||||
local analysis = M.analyze_file(filepath)
|
||||
return analysis and analysis.functions or {}
|
||||
end
|
||||
|
||||
--- Extract imports from a buffer
|
||||
---@param bufnr number
|
||||
---@return Import[]
|
||||
function M.extract_imports(bufnr)
|
||||
local filepath = vim.api.nvim_buf_get_name(bufnr)
|
||||
local analysis = M.analyze_file(filepath)
|
||||
return analysis and analysis.imports or {}
|
||||
end
|
||||
|
||||
return M
|
||||
604
lua/codetyper/features/indexer/init.lua
Normal file
604
lua/codetyper/features/indexer/init.lua
Normal file
@@ -0,0 +1,604 @@
|
||||
---@mod codetyper.indexer Project indexer for Codetyper.nvim
|
||||
---@brief [[
|
||||
--- Indexes project structure, dependencies, and code symbols.
|
||||
--- Stores knowledge in .coder/ directory for enriching LLM context.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Index schema version for migrations
|
||||
local INDEX_VERSION = 1
|
||||
|
||||
--- Index file name
|
||||
local INDEX_FILE = "index.json"
|
||||
|
||||
--- Debounce timer for file indexing
|
||||
local index_timer = nil
|
||||
local INDEX_DEBOUNCE_MS = 500
|
||||
|
||||
--- Default indexer configuration
|
||||
local default_config = {
|
||||
enabled = true,
|
||||
auto_index = true,
|
||||
index_on_open = false,
|
||||
max_file_size = 100000,
|
||||
excluded_dirs = { "node_modules", "dist", "build", ".git", ".coder", "__pycache__", "vendor", "target" },
|
||||
index_extensions = { "lua", "ts", "tsx", "js", "jsx", "py", "go", "rs", "rb", "java", "c", "cpp", "h", "hpp" },
|
||||
memory = {
|
||||
enabled = true,
|
||||
max_memories = 1000,
|
||||
prune_threshold = 0.1,
|
||||
},
|
||||
}
|
||||
|
||||
--- Current configuration
|
||||
---@type table
|
||||
local config = vim.deepcopy(default_config)
|
||||
|
||||
--- Cached project index
|
||||
---@type table<string, ProjectIndex>
|
||||
local index_cache = {}
|
||||
|
||||
---@class ProjectIndex
|
||||
---@field version number Index schema version
|
||||
---@field project_root string Absolute path to project
|
||||
---@field project_name string Project name
|
||||
---@field project_type string "node"|"rust"|"go"|"python"|"lua"|"unknown"
|
||||
---@field dependencies table<string, string> name -> version
|
||||
---@field dev_dependencies table<string, string> name -> version
|
||||
---@field files table<string, FileIndex> path -> FileIndex
|
||||
---@field symbols table<string, string[]> symbol -> [file paths]
|
||||
---@field last_indexed number Timestamp
|
||||
---@field stats {files: number, functions: number, classes: number, exports: number}
|
||||
|
||||
---@class FileIndex
|
||||
---@field path string Relative path from project root
|
||||
---@field language string Detected language
|
||||
---@field hash string Content hash for change detection
|
||||
---@field exports Export[] Exported symbols
|
||||
---@field imports Import[] Dependencies
|
||||
---@field functions FunctionInfo[]
|
||||
---@field classes ClassInfo[]
|
||||
---@field last_indexed number Timestamp
|
||||
|
||||
---@class Export
|
||||
---@field name string Symbol name
|
||||
---@field type string "function"|"class"|"constant"|"type"|"variable"
|
||||
---@field line number Line number
|
||||
|
||||
---@class Import
|
||||
---@field source string Import source/module
|
||||
---@field names string[] Imported names
|
||||
---@field line number Line number
|
||||
|
||||
---@class FunctionInfo
|
||||
---@field name string Function name
|
||||
---@field params string[] Parameter names
|
||||
---@field line number Start line
|
||||
---@field end_line number End line
|
||||
---@field docstring string|nil Documentation
|
||||
|
||||
---@class ClassInfo
|
||||
---@field name string Class name
|
||||
---@field methods string[] Method names
|
||||
---@field line number Start line
|
||||
---@field end_line number End line
|
||||
---@field docstring string|nil Documentation
|
||||
|
||||
--- Get the index file path
|
||||
---@return string|nil
|
||||
local function get_index_path()
|
||||
local root = utils.get_project_root()
|
||||
if not root then
|
||||
return nil
|
||||
end
|
||||
return root .. "/.coder/" .. INDEX_FILE
|
||||
end
|
||||
|
||||
--- Create empty index structure
|
||||
---@return ProjectIndex
|
||||
local function create_empty_index()
|
||||
local root = utils.get_project_root()
|
||||
return {
|
||||
version = INDEX_VERSION,
|
||||
project_root = root or "",
|
||||
project_name = root and vim.fn.fnamemodify(root, ":t") or "",
|
||||
project_type = "unknown",
|
||||
dependencies = {},
|
||||
dev_dependencies = {},
|
||||
files = {},
|
||||
symbols = {},
|
||||
last_indexed = os.time(),
|
||||
stats = {
|
||||
files = 0,
|
||||
functions = 0,
|
||||
classes = 0,
|
||||
exports = 0,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
--- Load index from disk
|
||||
---@return ProjectIndex|nil
|
||||
function M.load_index()
|
||||
local root = utils.get_project_root()
|
||||
if not root then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Check cache first
|
||||
if index_cache[root] then
|
||||
return index_cache[root]
|
||||
end
|
||||
|
||||
local path = get_index_path()
|
||||
if not path then
|
||||
return nil
|
||||
end
|
||||
|
||||
local content = utils.read_file(path)
|
||||
if not content then
|
||||
return nil
|
||||
end
|
||||
|
||||
local ok, index = pcall(vim.json.decode, content)
|
||||
if not ok or not index then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Validate version
|
||||
if index.version ~= INDEX_VERSION then
|
||||
-- Index needs migration or rebuild
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Cache it
|
||||
index_cache[root] = index
|
||||
return index
|
||||
end
|
||||
|
||||
--- Save index to disk
|
||||
---@param index ProjectIndex
|
||||
---@return boolean
|
||||
function M.save_index(index)
|
||||
local root = utils.get_project_root()
|
||||
if not root then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Ensure .coder directory exists
|
||||
local coder_dir = root .. "/.coder"
|
||||
utils.ensure_dir(coder_dir)
|
||||
|
||||
local path = get_index_path()
|
||||
if not path then
|
||||
return false
|
||||
end
|
||||
|
||||
local ok, encoded = pcall(vim.json.encode, index)
|
||||
if not ok then
|
||||
return false
|
||||
end
|
||||
|
||||
local success = utils.write_file(path, encoded)
|
||||
if success then
|
||||
-- Update cache
|
||||
index_cache[root] = index
|
||||
end
|
||||
return success
|
||||
end
|
||||
|
||||
--- Index the entire project
|
||||
---@param callback? fun(index: ProjectIndex)
|
||||
---@return ProjectIndex|nil
|
||||
function M.index_project(callback)
|
||||
local scanner = require("codetyper.features.indexer.scanner")
|
||||
local analyzer = require("codetyper.features.indexer.analyzer")
|
||||
|
||||
local index = create_empty_index()
|
||||
local root = utils.get_project_root()
|
||||
|
||||
if not root then
|
||||
if callback then
|
||||
callback(index)
|
||||
end
|
||||
return index
|
||||
end
|
||||
|
||||
-- Detect project type and parse dependencies
|
||||
index.project_type = scanner.detect_project_type(root)
|
||||
local deps = scanner.parse_dependencies(root, index.project_type)
|
||||
index.dependencies = deps.dependencies or {}
|
||||
index.dev_dependencies = deps.dev_dependencies or {}
|
||||
|
||||
-- Get all indexable files
|
||||
local files = scanner.get_indexable_files(root, config)
|
||||
|
||||
-- Index each file
|
||||
local total_functions = 0
|
||||
local total_classes = 0
|
||||
local total_exports = 0
|
||||
|
||||
for _, filepath in ipairs(files) do
|
||||
local relative_path = filepath:gsub("^" .. vim.pesc(root) .. "/", "")
|
||||
local file_index = analyzer.analyze_file(filepath)
|
||||
|
||||
if file_index then
|
||||
file_index.path = relative_path
|
||||
index.files[relative_path] = file_index
|
||||
|
||||
-- Update symbol index
|
||||
for _, exp in ipairs(file_index.exports or {}) do
|
||||
if not index.symbols[exp.name] then
|
||||
index.symbols[exp.name] = {}
|
||||
end
|
||||
table.insert(index.symbols[exp.name], relative_path)
|
||||
total_exports = total_exports + 1
|
||||
end
|
||||
|
||||
total_functions = total_functions + #(file_index.functions or {})
|
||||
total_classes = total_classes + #(file_index.classes or {})
|
||||
end
|
||||
end
|
||||
|
||||
-- Update stats
|
||||
index.stats = {
|
||||
files = #files,
|
||||
functions = total_functions,
|
||||
classes = total_classes,
|
||||
exports = total_exports,
|
||||
}
|
||||
index.last_indexed = os.time()
|
||||
|
||||
-- Save to disk
|
||||
M.save_index(index)
|
||||
|
||||
-- Store memories
|
||||
local memory = require("codetyper.features.indexer.memory")
|
||||
memory.store_index_summary(index)
|
||||
|
||||
-- Sync project summary to brain
|
||||
M.sync_project_to_brain(index, files, root)
|
||||
|
||||
if callback then
|
||||
callback(index)
|
||||
end
|
||||
|
||||
return index
|
||||
end
|
||||
|
||||
--- Sync project index to brain
|
||||
---@param index ProjectIndex
|
||||
---@param files string[] List of file paths
|
||||
---@param root string Project root
|
||||
function M.sync_project_to_brain(index, files, root)
|
||||
local ok_brain, brain = pcall(require, "codetyper.brain")
|
||||
if not ok_brain or not brain.is_initialized or not brain.is_initialized() then
|
||||
return
|
||||
end
|
||||
|
||||
-- Store project-level pattern
|
||||
brain.learn({
|
||||
type = "pattern",
|
||||
file = root,
|
||||
content = {
|
||||
summary = "Project: "
|
||||
.. index.project_name
|
||||
.. " ("
|
||||
.. index.project_type
|
||||
.. ") - "
|
||||
.. index.stats.files
|
||||
.. " files",
|
||||
detail = string.format(
|
||||
"%d functions, %d classes, %d exports",
|
||||
index.stats.functions,
|
||||
index.stats.classes,
|
||||
index.stats.exports
|
||||
),
|
||||
},
|
||||
context = {
|
||||
file = root,
|
||||
project_type = index.project_type,
|
||||
dependencies = index.dependencies,
|
||||
},
|
||||
})
|
||||
|
||||
-- Store key file patterns (files with most functions/classes)
|
||||
local key_files = {}
|
||||
for path, file_index in pairs(index.files) do
|
||||
local score = #(file_index.functions or {}) + (#(file_index.classes or {}) * 2)
|
||||
if score >= 3 then
|
||||
table.insert(key_files, { path = path, index = file_index, score = score })
|
||||
end
|
||||
end
|
||||
|
||||
table.sort(key_files, function(a, b)
|
||||
return a.score > b.score
|
||||
end)
|
||||
|
||||
-- Store top 20 key files in brain
|
||||
for i, kf in ipairs(key_files) do
|
||||
if i > 20 then
|
||||
break
|
||||
end
|
||||
M.sync_to_brain(root .. "/" .. kf.path, kf.index)
|
||||
end
|
||||
end
|
||||
|
||||
--- Index a single file (incremental update)
|
||||
---@param filepath string
|
||||
---@return FileIndex|nil
|
||||
function M.index_file(filepath)
|
||||
local analyzer = require("codetyper.features.indexer.analyzer")
|
||||
local memory = require("codetyper.features.indexer.memory")
|
||||
local root = utils.get_project_root()
|
||||
|
||||
if not root then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Load existing index
|
||||
local index = M.load_index() or create_empty_index()
|
||||
|
||||
-- Analyze file
|
||||
local file_index = analyzer.analyze_file(filepath)
|
||||
if not file_index then
|
||||
return nil
|
||||
end
|
||||
|
||||
local relative_path = filepath:gsub("^" .. vim.pesc(root) .. "/", "")
|
||||
file_index.path = relative_path
|
||||
|
||||
-- Remove old symbol references for this file
|
||||
for symbol, paths in pairs(index.symbols) do
|
||||
for i = #paths, 1, -1 do
|
||||
if paths[i] == relative_path then
|
||||
table.remove(paths, i)
|
||||
end
|
||||
end
|
||||
if #paths == 0 then
|
||||
index.symbols[symbol] = nil
|
||||
end
|
||||
end
|
||||
|
||||
-- Add new file index
|
||||
index.files[relative_path] = file_index
|
||||
|
||||
-- Update symbol index
|
||||
for _, exp in ipairs(file_index.exports or {}) do
|
||||
if not index.symbols[exp.name] then
|
||||
index.symbols[exp.name] = {}
|
||||
end
|
||||
table.insert(index.symbols[exp.name], relative_path)
|
||||
end
|
||||
|
||||
-- Recalculate stats
|
||||
local total_functions = 0
|
||||
local total_classes = 0
|
||||
local total_exports = 0
|
||||
local file_count = 0
|
||||
|
||||
for _, f in pairs(index.files) do
|
||||
file_count = file_count + 1
|
||||
total_functions = total_functions + #(f.functions or {})
|
||||
total_classes = total_classes + #(f.classes or {})
|
||||
total_exports = total_exports + #(f.exports or {})
|
||||
end
|
||||
|
||||
index.stats = {
|
||||
files = file_count,
|
||||
functions = total_functions,
|
||||
classes = total_classes,
|
||||
exports = total_exports,
|
||||
}
|
||||
index.last_indexed = os.time()
|
||||
|
||||
-- Save to disk
|
||||
M.save_index(index)
|
||||
|
||||
-- Store file memory
|
||||
memory.store_file_memory(relative_path, file_index)
|
||||
|
||||
-- Sync to brain if available
|
||||
M.sync_to_brain(filepath, file_index)
|
||||
|
||||
return file_index
|
||||
end
|
||||
|
||||
--- Sync file analysis to brain system
|
||||
---@param filepath string Full file path
|
||||
---@param file_index FileIndex File analysis
|
||||
function M.sync_to_brain(filepath, file_index)
|
||||
local ok_brain, brain = pcall(require, "codetyper.brain")
|
||||
if not ok_brain or not brain.is_initialized or not brain.is_initialized() then
|
||||
return
|
||||
end
|
||||
|
||||
-- Only store if file has meaningful content
|
||||
local funcs = file_index.functions or {}
|
||||
local classes = file_index.classes or {}
|
||||
if #funcs == 0 and #classes == 0 then
|
||||
return
|
||||
end
|
||||
|
||||
-- Build summary
|
||||
local parts = {}
|
||||
if #funcs > 0 then
|
||||
local func_names = {}
|
||||
for i, f in ipairs(funcs) do
|
||||
if i <= 5 then
|
||||
table.insert(func_names, f.name)
|
||||
end
|
||||
end
|
||||
table.insert(parts, "functions: " .. table.concat(func_names, ", "))
|
||||
if #funcs > 5 then
|
||||
table.insert(parts, "(+" .. (#funcs - 5) .. " more)")
|
||||
end
|
||||
end
|
||||
if #classes > 0 then
|
||||
local class_names = {}
|
||||
for _, c in ipairs(classes) do
|
||||
table.insert(class_names, c.name)
|
||||
end
|
||||
table.insert(parts, "classes: " .. table.concat(class_names, ", "))
|
||||
end
|
||||
|
||||
local filename = vim.fn.fnamemodify(filepath, ":t")
|
||||
local summary = filename .. " - " .. table.concat(parts, "; ")
|
||||
|
||||
-- Learn this pattern in brain
|
||||
brain.learn({
|
||||
type = "pattern",
|
||||
file = filepath,
|
||||
content = {
|
||||
summary = summary,
|
||||
detail = #funcs .. " functions, " .. #classes .. " classes",
|
||||
},
|
||||
context = {
|
||||
file = file_index.path or filepath,
|
||||
language = file_index.language,
|
||||
functions = funcs,
|
||||
classes = classes,
|
||||
exports = file_index.exports,
|
||||
imports = file_index.imports,
|
||||
},
|
||||
})
|
||||
end
|
||||
|
||||
--- Schedule file indexing with debounce
|
||||
---@param filepath string
|
||||
function M.schedule_index_file(filepath)
|
||||
if not config.enabled or not config.auto_index then
|
||||
return
|
||||
end
|
||||
|
||||
-- Check if file should be indexed
|
||||
local scanner = require("codetyper.features.indexer.scanner")
|
||||
if not scanner.should_index(filepath, config) then
|
||||
return
|
||||
end
|
||||
|
||||
-- Cancel existing timer
|
||||
if index_timer then
|
||||
index_timer:stop()
|
||||
end
|
||||
|
||||
-- Schedule new index
|
||||
index_timer = vim.defer_fn(function()
|
||||
M.index_file(filepath)
|
||||
index_timer = nil
|
||||
end, INDEX_DEBOUNCE_MS)
|
||||
end
|
||||
|
||||
--- Get relevant context for a prompt
|
||||
---@param opts {file: string, intent: table|nil, prompt: string, scope: string|nil}
|
||||
---@return table Context information
|
||||
function M.get_context_for(opts)
|
||||
local memory = require("codetyper.features.indexer.memory")
|
||||
local index = M.load_index()
|
||||
|
||||
local context = {
|
||||
project_type = "unknown",
|
||||
dependencies = {},
|
||||
relevant_files = {},
|
||||
relevant_symbols = {},
|
||||
patterns = {},
|
||||
}
|
||||
|
||||
if not index then
|
||||
return context
|
||||
end
|
||||
|
||||
context.project_type = index.project_type
|
||||
context.dependencies = index.dependencies
|
||||
|
||||
-- Find relevant symbols from prompt
|
||||
local words = {}
|
||||
for word in opts.prompt:gmatch("%w+") do
|
||||
if #word > 2 then
|
||||
words[word:lower()] = true
|
||||
end
|
||||
end
|
||||
|
||||
-- Match symbols
|
||||
for symbol, files in pairs(index.symbols) do
|
||||
if words[symbol:lower()] then
|
||||
context.relevant_symbols[symbol] = files
|
||||
end
|
||||
end
|
||||
|
||||
-- Get file context if available
|
||||
if opts.file then
|
||||
local root = utils.get_project_root()
|
||||
if root then
|
||||
local relative_path = opts.file:gsub("^" .. vim.pesc(root) .. "/", "")
|
||||
local file_index = index.files[relative_path]
|
||||
if file_index then
|
||||
context.current_file = file_index
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Get relevant memories
|
||||
context.patterns = memory.get_relevant(opts.prompt, 5)
|
||||
|
||||
return context
|
||||
end
|
||||
|
||||
--- Get index status
|
||||
---@return table Status information
|
||||
function M.get_status()
|
||||
local index = M.load_index()
|
||||
if not index then
|
||||
return {
|
||||
indexed = false,
|
||||
stats = nil,
|
||||
last_indexed = nil,
|
||||
}
|
||||
end
|
||||
|
||||
return {
|
||||
indexed = true,
|
||||
stats = index.stats,
|
||||
last_indexed = index.last_indexed,
|
||||
project_type = index.project_type,
|
||||
}
|
||||
end
|
||||
|
||||
--- Clear the project index
|
||||
function M.clear()
|
||||
local root = utils.get_project_root()
|
||||
if root then
|
||||
index_cache[root] = nil
|
||||
end
|
||||
|
||||
local path = get_index_path()
|
||||
if path and utils.file_exists(path) then
|
||||
os.remove(path)
|
||||
end
|
||||
end
|
||||
|
||||
--- Setup the indexer with configuration
|
||||
---@param opts? table Configuration options
|
||||
function M.setup(opts)
|
||||
if opts then
|
||||
config = vim.tbl_deep_extend("force", config, opts)
|
||||
end
|
||||
|
||||
-- Index on startup if configured
|
||||
if config.index_on_open then
|
||||
vim.defer_fn(function()
|
||||
M.index_project()
|
||||
end, 1000)
|
||||
end
|
||||
end
|
||||
|
||||
--- Get current configuration
|
||||
---@return table
|
||||
function M.get_config()
|
||||
return vim.deepcopy(config)
|
||||
end
|
||||
|
||||
return M
|
||||
539
lua/codetyper/features/indexer/memory.lua
Normal file
539
lua/codetyper/features/indexer/memory.lua
Normal file
@@ -0,0 +1,539 @@
|
||||
---@mod codetyper.indexer.memory Memory persistence manager
|
||||
---@brief [[
|
||||
--- Stores and retrieves learned patterns and memories in .coder/memories/.
|
||||
--- Supports session history for learning from interactions.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Memory directories
|
||||
local MEMORIES_DIR = "memories"
|
||||
local SESSIONS_DIR = "sessions"
|
||||
local FILES_DIR = "files"
|
||||
|
||||
--- Memory files
|
||||
local PATTERNS_FILE = "patterns.json"
|
||||
local CONVENTIONS_FILE = "conventions.json"
|
||||
local SYMBOLS_FILE = "symbols.json"
|
||||
|
||||
--- In-memory cache
|
||||
local cache = {
|
||||
patterns = nil,
|
||||
conventions = nil,
|
||||
symbols = nil,
|
||||
}
|
||||
|
||||
---@class Memory
|
||||
---@field id string Unique identifier
|
||||
---@field type "pattern"|"convention"|"session"|"interaction"
|
||||
---@field content string The learned information
|
||||
---@field context table Where/when learned
|
||||
---@field weight number Importance score (0.0-1.0)
|
||||
---@field created_at number Timestamp
|
||||
---@field updated_at number Last update timestamp
|
||||
---@field used_count number Times referenced
|
||||
|
||||
--- Get the memories base directory
|
||||
---@return string|nil
|
||||
local function get_memories_dir()
|
||||
local root = utils.get_project_root()
|
||||
if not root then
|
||||
return nil
|
||||
end
|
||||
return root .. "/.coder/" .. MEMORIES_DIR
|
||||
end
|
||||
|
||||
--- Get the sessions directory
|
||||
---@return string|nil
|
||||
local function get_sessions_dir()
|
||||
local root = utils.get_project_root()
|
||||
if not root then
|
||||
return nil
|
||||
end
|
||||
return root .. "/.coder/" .. SESSIONS_DIR
|
||||
end
|
||||
|
||||
--- Ensure memories directory exists
|
||||
---@return boolean
|
||||
local function ensure_memories_dir()
|
||||
local dir = get_memories_dir()
|
||||
if not dir then
|
||||
return false
|
||||
end
|
||||
utils.ensure_dir(dir)
|
||||
utils.ensure_dir(dir .. "/" .. FILES_DIR)
|
||||
return true
|
||||
end
|
||||
|
||||
--- Ensure sessions directory exists
|
||||
---@return boolean
|
||||
local function ensure_sessions_dir()
|
||||
local dir = get_sessions_dir()
|
||||
if not dir then
|
||||
return false
|
||||
end
|
||||
return utils.ensure_dir(dir)
|
||||
end
|
||||
|
||||
--- Generate a unique ID
|
||||
---@return string
|
||||
local function generate_id()
|
||||
return string.format("mem_%d_%s", os.time(), string.sub(tostring(math.random()), 3, 8))
|
||||
end
|
||||
|
||||
--- Load a memory file
|
||||
---@param filename string
|
||||
---@return table
|
||||
local function load_memory_file(filename)
|
||||
local dir = get_memories_dir()
|
||||
if not dir then
|
||||
return {}
|
||||
end
|
||||
|
||||
local path = dir .. "/" .. filename
|
||||
local content = utils.read_file(path)
|
||||
if not content then
|
||||
return {}
|
||||
end
|
||||
|
||||
local ok, data = pcall(vim.json.decode, content)
|
||||
if not ok or not data then
|
||||
return {}
|
||||
end
|
||||
|
||||
return data
|
||||
end
|
||||
|
||||
--- Save a memory file
|
||||
---@param filename string
|
||||
---@param data table
|
||||
---@return boolean
|
||||
local function save_memory_file(filename, data)
|
||||
if not ensure_memories_dir() then
|
||||
return false
|
||||
end
|
||||
|
||||
local dir = get_memories_dir()
|
||||
if not dir then
|
||||
return false
|
||||
end
|
||||
|
||||
local path = dir .. "/" .. filename
|
||||
local ok, encoded = pcall(vim.json.encode, data)
|
||||
if not ok then
|
||||
return false
|
||||
end
|
||||
|
||||
return utils.write_file(path, encoded)
|
||||
end
|
||||
|
||||
--- Hash a file path for storage
|
||||
---@param filepath string
|
||||
---@return string
|
||||
local function hash_path(filepath)
|
||||
local hash = 0
|
||||
for i = 1, #filepath do
|
||||
hash = (hash * 31 + string.byte(filepath, i)) % 2147483647
|
||||
end
|
||||
return string.format("%08x", hash)
|
||||
end
|
||||
|
||||
--- Load patterns from cache or disk
|
||||
---@return table
|
||||
function M.load_patterns()
|
||||
if cache.patterns then
|
||||
return cache.patterns
|
||||
end
|
||||
cache.patterns = load_memory_file(PATTERNS_FILE)
|
||||
return cache.patterns
|
||||
end
|
||||
|
||||
--- Load conventions from cache or disk
|
||||
---@return table
|
||||
function M.load_conventions()
|
||||
if cache.conventions then
|
||||
return cache.conventions
|
||||
end
|
||||
cache.conventions = load_memory_file(CONVENTIONS_FILE)
|
||||
return cache.conventions
|
||||
end
|
||||
|
||||
--- Load symbols from cache or disk
|
||||
---@return table
|
||||
function M.load_symbols()
|
||||
if cache.symbols then
|
||||
return cache.symbols
|
||||
end
|
||||
cache.symbols = load_memory_file(SYMBOLS_FILE)
|
||||
return cache.symbols
|
||||
end
|
||||
|
||||
--- Store a new memory
|
||||
---@param memory Memory
|
||||
---@return boolean
|
||||
function M.store_memory(memory)
|
||||
memory.id = memory.id or generate_id()
|
||||
memory.created_at = memory.created_at or os.time()
|
||||
memory.updated_at = os.time()
|
||||
memory.used_count = memory.used_count or 0
|
||||
memory.weight = memory.weight or 0.5
|
||||
|
||||
local filename
|
||||
if memory.type == "pattern" then
|
||||
filename = PATTERNS_FILE
|
||||
cache.patterns = nil
|
||||
elseif memory.type == "convention" then
|
||||
filename = CONVENTIONS_FILE
|
||||
cache.conventions = nil
|
||||
else
|
||||
filename = PATTERNS_FILE
|
||||
cache.patterns = nil
|
||||
end
|
||||
|
||||
local data = load_memory_file(filename)
|
||||
data[memory.id] = memory
|
||||
|
||||
return save_memory_file(filename, data)
|
||||
end
|
||||
|
||||
--- Store file-specific memory
|
||||
---@param relative_path string Relative file path
|
||||
---@param file_index table FileIndex data
|
||||
---@return boolean
|
||||
function M.store_file_memory(relative_path, file_index)
|
||||
if not ensure_memories_dir() then
|
||||
return false
|
||||
end
|
||||
|
||||
local dir = get_memories_dir()
|
||||
if not dir then
|
||||
return false
|
||||
end
|
||||
|
||||
local hash = hash_path(relative_path)
|
||||
local path = dir .. "/" .. FILES_DIR .. "/" .. hash .. ".json"
|
||||
|
||||
local data = {
|
||||
path = relative_path,
|
||||
indexed_at = os.time(),
|
||||
functions = file_index.functions or {},
|
||||
classes = file_index.classes or {},
|
||||
exports = file_index.exports or {},
|
||||
imports = file_index.imports or {},
|
||||
}
|
||||
|
||||
local ok, encoded = pcall(vim.json.encode, data)
|
||||
if not ok then
|
||||
return false
|
||||
end
|
||||
|
||||
return utils.write_file(path, encoded)
|
||||
end
|
||||
|
||||
--- Load file-specific memory
|
||||
---@param relative_path string
|
||||
---@return table|nil
|
||||
function M.load_file_memory(relative_path)
|
||||
local dir = get_memories_dir()
|
||||
if not dir then
|
||||
return nil
|
||||
end
|
||||
|
||||
local hash = hash_path(relative_path)
|
||||
local path = dir .. "/" .. FILES_DIR .. "/" .. hash .. ".json"
|
||||
|
||||
local content = utils.read_file(path)
|
||||
if not content then
|
||||
return nil
|
||||
end
|
||||
|
||||
local ok, data = pcall(vim.json.decode, content)
|
||||
if not ok then
|
||||
return nil
|
||||
end
|
||||
|
||||
return data
|
||||
end
|
||||
|
||||
--- Store index summary as memories
|
||||
---@param index ProjectIndex
|
||||
function M.store_index_summary(index)
|
||||
-- Store project type convention
|
||||
if index.project_type and index.project_type ~= "unknown" then
|
||||
M.store_memory({
|
||||
type = "convention",
|
||||
content = "Project uses " .. index.project_type .. " ecosystem",
|
||||
context = {
|
||||
project_root = index.project_root,
|
||||
detected_at = os.time(),
|
||||
},
|
||||
weight = 0.9,
|
||||
})
|
||||
end
|
||||
|
||||
-- Store dependency patterns
|
||||
local dep_count = 0
|
||||
for _ in pairs(index.dependencies or {}) do
|
||||
dep_count = dep_count + 1
|
||||
end
|
||||
|
||||
if dep_count > 0 then
|
||||
local deps_list = {}
|
||||
for name, _ in pairs(index.dependencies) do
|
||||
table.insert(deps_list, name)
|
||||
end
|
||||
|
||||
M.store_memory({
|
||||
type = "pattern",
|
||||
content = "Project dependencies: " .. table.concat(deps_list, ", "),
|
||||
context = {
|
||||
dependency_count = dep_count,
|
||||
},
|
||||
weight = 0.7,
|
||||
})
|
||||
end
|
||||
|
||||
-- Update symbol cache
|
||||
cache.symbols = nil
|
||||
save_memory_file(SYMBOLS_FILE, index.symbols or {})
|
||||
end
|
||||
|
||||
--- Store session interaction
|
||||
---@param interaction {prompt: string, response: string, file: string|nil, success: boolean}
|
||||
function M.store_session(interaction)
|
||||
if not ensure_sessions_dir() then
|
||||
return
|
||||
end
|
||||
|
||||
local dir = get_sessions_dir()
|
||||
if not dir then
|
||||
return
|
||||
end
|
||||
|
||||
-- Use date-based session files
|
||||
local date = os.date("%Y-%m-%d")
|
||||
local path = dir .. "/" .. date .. ".json"
|
||||
|
||||
local sessions = {}
|
||||
local content = utils.read_file(path)
|
||||
if content then
|
||||
local ok, data = pcall(vim.json.decode, content)
|
||||
if ok and data then
|
||||
sessions = data
|
||||
end
|
||||
end
|
||||
|
||||
table.insert(sessions, {
|
||||
timestamp = os.time(),
|
||||
prompt = interaction.prompt,
|
||||
response = string.sub(interaction.response or "", 1, 500), -- Truncate
|
||||
file = interaction.file,
|
||||
success = interaction.success,
|
||||
})
|
||||
|
||||
-- Limit session size
|
||||
if #sessions > 100 then
|
||||
sessions = { unpack(sessions, #sessions - 99) }
|
||||
end
|
||||
|
||||
local ok, encoded = pcall(vim.json.encode, sessions)
|
||||
if ok then
|
||||
utils.write_file(path, encoded)
|
||||
end
|
||||
end
|
||||
|
||||
--- Get relevant memories for a query
|
||||
---@param query string Search query
|
||||
---@param limit number Maximum results
|
||||
---@return Memory[]
|
||||
function M.get_relevant(query, limit)
|
||||
limit = limit or 10
|
||||
local results = {}
|
||||
|
||||
-- Tokenize query
|
||||
local query_words = {}
|
||||
for word in query:lower():gmatch("%w+") do
|
||||
if #word > 2 then
|
||||
query_words[word] = true
|
||||
end
|
||||
end
|
||||
|
||||
-- Search patterns
|
||||
local patterns = M.load_patterns()
|
||||
for _, memory in pairs(patterns) do
|
||||
local score = 0
|
||||
local content_lower = (memory.content or ""):lower()
|
||||
|
||||
for word in pairs(query_words) do
|
||||
if content_lower:find(word, 1, true) then
|
||||
score = score + 1
|
||||
end
|
||||
end
|
||||
|
||||
if score > 0 then
|
||||
memory.relevance_score = score * (memory.weight or 0.5)
|
||||
table.insert(results, memory)
|
||||
end
|
||||
end
|
||||
|
||||
-- Search conventions
|
||||
local conventions = M.load_conventions()
|
||||
for _, memory in pairs(conventions) do
|
||||
local score = 0
|
||||
local content_lower = (memory.content or ""):lower()
|
||||
|
||||
for word in pairs(query_words) do
|
||||
if content_lower:find(word, 1, true) then
|
||||
score = score + 1
|
||||
end
|
||||
end
|
||||
|
||||
if score > 0 then
|
||||
memory.relevance_score = score * (memory.weight or 0.5)
|
||||
table.insert(results, memory)
|
||||
end
|
||||
end
|
||||
|
||||
-- Sort by relevance
|
||||
table.sort(results, function(a, b)
|
||||
return (a.relevance_score or 0) > (b.relevance_score or 0)
|
||||
end)
|
||||
|
||||
-- Limit results
|
||||
local limited = {}
|
||||
for i = 1, math.min(limit, #results) do
|
||||
limited[i] = results[i]
|
||||
end
|
||||
|
||||
return limited
|
||||
end
|
||||
|
||||
--- Update memory usage count
|
||||
---@param memory_id string
|
||||
function M.update_usage(memory_id)
|
||||
local patterns = M.load_patterns()
|
||||
if patterns[memory_id] then
|
||||
patterns[memory_id].used_count = (patterns[memory_id].used_count or 0) + 1
|
||||
patterns[memory_id].updated_at = os.time()
|
||||
save_memory_file(PATTERNS_FILE, patterns)
|
||||
cache.patterns = nil
|
||||
return
|
||||
end
|
||||
|
||||
local conventions = M.load_conventions()
|
||||
if conventions[memory_id] then
|
||||
conventions[memory_id].used_count = (conventions[memory_id].used_count or 0) + 1
|
||||
conventions[memory_id].updated_at = os.time()
|
||||
save_memory_file(CONVENTIONS_FILE, conventions)
|
||||
cache.conventions = nil
|
||||
end
|
||||
end
|
||||
|
||||
--- Get all memories
|
||||
---@return {patterns: table, conventions: table, symbols: table}
|
||||
function M.get_all()
|
||||
return {
|
||||
patterns = M.load_patterns(),
|
||||
conventions = M.load_conventions(),
|
||||
symbols = M.load_symbols(),
|
||||
}
|
||||
end
|
||||
|
||||
--- Clear all memories
|
||||
---@param pattern? string Optional pattern to match memory IDs
|
||||
function M.clear(pattern)
|
||||
if not pattern then
|
||||
-- Clear all
|
||||
cache = { patterns = nil, conventions = nil, symbols = nil }
|
||||
save_memory_file(PATTERNS_FILE, {})
|
||||
save_memory_file(CONVENTIONS_FILE, {})
|
||||
save_memory_file(SYMBOLS_FILE, {})
|
||||
return
|
||||
end
|
||||
|
||||
-- Clear matching pattern
|
||||
local patterns = M.load_patterns()
|
||||
for id in pairs(patterns) do
|
||||
if id:match(pattern) then
|
||||
patterns[id] = nil
|
||||
end
|
||||
end
|
||||
save_memory_file(PATTERNS_FILE, patterns)
|
||||
cache.patterns = nil
|
||||
|
||||
local conventions = M.load_conventions()
|
||||
for id in pairs(conventions) do
|
||||
if id:match(pattern) then
|
||||
conventions[id] = nil
|
||||
end
|
||||
end
|
||||
save_memory_file(CONVENTIONS_FILE, conventions)
|
||||
cache.conventions = nil
|
||||
end
|
||||
|
||||
--- Prune low-weight memories
|
||||
---@param threshold number Weight threshold (default: 0.1)
|
||||
function M.prune(threshold)
|
||||
threshold = threshold or 0.1
|
||||
|
||||
local patterns = M.load_patterns()
|
||||
local pruned = 0
|
||||
for id, memory in pairs(patterns) do
|
||||
if (memory.weight or 0) < threshold and (memory.used_count or 0) == 0 then
|
||||
patterns[id] = nil
|
||||
pruned = pruned + 1
|
||||
end
|
||||
end
|
||||
if pruned > 0 then
|
||||
save_memory_file(PATTERNS_FILE, patterns)
|
||||
cache.patterns = nil
|
||||
end
|
||||
|
||||
local conventions = M.load_conventions()
|
||||
for id, memory in pairs(conventions) do
|
||||
if (memory.weight or 0) < threshold and (memory.used_count or 0) == 0 then
|
||||
conventions[id] = nil
|
||||
pruned = pruned + 1
|
||||
end
|
||||
end
|
||||
if pruned > 0 then
|
||||
save_memory_file(CONVENTIONS_FILE, conventions)
|
||||
cache.conventions = nil
|
||||
end
|
||||
|
||||
return pruned
|
||||
end
|
||||
|
||||
--- Get memory statistics
|
||||
---@return table
|
||||
function M.get_stats()
|
||||
local patterns = M.load_patterns()
|
||||
local conventions = M.load_conventions()
|
||||
local symbols = M.load_symbols()
|
||||
|
||||
local pattern_count = 0
|
||||
for _ in pairs(patterns) do
|
||||
pattern_count = pattern_count + 1
|
||||
end
|
||||
|
||||
local convention_count = 0
|
||||
for _ in pairs(conventions) do
|
||||
convention_count = convention_count + 1
|
||||
end
|
||||
|
||||
local symbol_count = 0
|
||||
for _ in pairs(symbols) do
|
||||
symbol_count = symbol_count + 1
|
||||
end
|
||||
|
||||
return {
|
||||
patterns = pattern_count,
|
||||
conventions = convention_count,
|
||||
symbols = symbol_count,
|
||||
total = pattern_count + convention_count,
|
||||
}
|
||||
end
|
||||
|
||||
return M
|
||||
409
lua/codetyper/features/indexer/scanner.lua
Normal file
409
lua/codetyper/features/indexer/scanner.lua
Normal file
@@ -0,0 +1,409 @@
|
||||
---@mod codetyper.indexer.scanner File scanner for project indexing
|
||||
---@brief [[
|
||||
--- Discovers indexable files, detects project type, and parses dependencies.
|
||||
---@brief ]]
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Project type markers
|
||||
local PROJECT_MARKERS = {
|
||||
node = { "package.json" },
|
||||
rust = { "Cargo.toml" },
|
||||
go = { "go.mod" },
|
||||
python = { "pyproject.toml", "setup.py", "requirements.txt" },
|
||||
lua = { "init.lua", ".luarc.json" },
|
||||
ruby = { "Gemfile" },
|
||||
java = { "pom.xml", "build.gradle" },
|
||||
csharp = { "*.csproj", "*.sln" },
|
||||
}
|
||||
|
||||
--- File extension to language mapping
|
||||
local EXTENSION_LANGUAGE = {
|
||||
lua = "lua",
|
||||
ts = "typescript",
|
||||
tsx = "typescriptreact",
|
||||
js = "javascript",
|
||||
jsx = "javascriptreact",
|
||||
py = "python",
|
||||
go = "go",
|
||||
rs = "rust",
|
||||
rb = "ruby",
|
||||
java = "java",
|
||||
c = "c",
|
||||
cpp = "cpp",
|
||||
h = "c",
|
||||
hpp = "cpp",
|
||||
cs = "csharp",
|
||||
}
|
||||
|
||||
--- Default ignore patterns
|
||||
local DEFAULT_IGNORES = {
|
||||
"^%.", -- Hidden files/folders
|
||||
"^node_modules$",
|
||||
"^__pycache__$",
|
||||
"^%.git$",
|
||||
"^%.coder$",
|
||||
"^dist$",
|
||||
"^build$",
|
||||
"^target$",
|
||||
"^vendor$",
|
||||
"^%.next$",
|
||||
"^%.nuxt$",
|
||||
"^coverage$",
|
||||
"%.min%.js$",
|
||||
"%.min%.css$",
|
||||
"%.map$",
|
||||
"%.lock$",
|
||||
"%-lock%.json$",
|
||||
}
|
||||
|
||||
--- Detect project type from root markers
|
||||
---@param root string Project root path
|
||||
---@return string Project type
|
||||
function M.detect_project_type(root)
|
||||
for project_type, markers in pairs(PROJECT_MARKERS) do
|
||||
for _, marker in ipairs(markers) do
|
||||
local path = root .. "/" .. marker
|
||||
if marker:match("^%*") then
|
||||
-- Glob pattern
|
||||
local pattern = marker:gsub("^%*", "")
|
||||
local entries = vim.fn.glob(root .. "/*" .. pattern, false, true)
|
||||
if #entries > 0 then
|
||||
return project_type
|
||||
end
|
||||
else
|
||||
if utils.file_exists(path) then
|
||||
return project_type
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
return "unknown"
|
||||
end
|
||||
|
||||
--- Parse project dependencies
|
||||
---@param root string Project root path
|
||||
---@param project_type string Project type
|
||||
---@return {dependencies: table<string, string>, dev_dependencies: table<string, string>}
|
||||
function M.parse_dependencies(root, project_type)
|
||||
local deps = {
|
||||
dependencies = {},
|
||||
dev_dependencies = {},
|
||||
}
|
||||
|
||||
if project_type == "node" then
|
||||
deps = M.parse_package_json(root)
|
||||
elseif project_type == "rust" then
|
||||
deps = M.parse_cargo_toml(root)
|
||||
elseif project_type == "go" then
|
||||
deps = M.parse_go_mod(root)
|
||||
elseif project_type == "python" then
|
||||
deps = M.parse_python_deps(root)
|
||||
end
|
||||
|
||||
return deps
|
||||
end
|
||||
|
||||
--- Parse package.json for Node.js projects
|
||||
---@param root string Project root path
|
||||
---@return {dependencies: table, dev_dependencies: table}
|
||||
function M.parse_package_json(root)
|
||||
local path = root .. "/package.json"
|
||||
local content = utils.read_file(path)
|
||||
if not content then
|
||||
return { dependencies = {}, dev_dependencies = {} }
|
||||
end
|
||||
|
||||
local ok, pkg = pcall(vim.json.decode, content)
|
||||
if not ok or not pkg then
|
||||
return { dependencies = {}, dev_dependencies = {} }
|
||||
end
|
||||
|
||||
return {
|
||||
dependencies = pkg.dependencies or {},
|
||||
dev_dependencies = pkg.devDependencies or {},
|
||||
}
|
||||
end
|
||||
|
||||
--- Parse Cargo.toml for Rust projects
|
||||
---@param root string Project root path
|
||||
---@return {dependencies: table, dev_dependencies: table}
|
||||
function M.parse_cargo_toml(root)
|
||||
local path = root .. "/Cargo.toml"
|
||||
local content = utils.read_file(path)
|
||||
if not content then
|
||||
return { dependencies = {}, dev_dependencies = {} }
|
||||
end
|
||||
|
||||
local deps = {}
|
||||
local dev_deps = {}
|
||||
local in_deps = false
|
||||
local in_dev_deps = false
|
||||
|
||||
for line in content:gmatch("[^\n]+") do
|
||||
if line:match("^%[dependencies%]") then
|
||||
in_deps = true
|
||||
in_dev_deps = false
|
||||
elseif line:match("^%[dev%-dependencies%]") then
|
||||
in_deps = false
|
||||
in_dev_deps = true
|
||||
elseif line:match("^%[") then
|
||||
in_deps = false
|
||||
in_dev_deps = false
|
||||
elseif in_deps or in_dev_deps then
|
||||
local name, version = line:match('^([%w_%-]+)%s*=%s*"([^"]+)"')
|
||||
if not name then
|
||||
name = line:match("^([%w_%-]+)%s*=")
|
||||
version = "workspace"
|
||||
end
|
||||
if name then
|
||||
if in_deps then
|
||||
deps[name] = version or "unknown"
|
||||
else
|
||||
dev_deps[name] = version or "unknown"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return { dependencies = deps, dev_dependencies = dev_deps }
|
||||
end
|
||||
|
||||
--- Parse go.mod for Go projects
|
||||
---@param root string Project root path
|
||||
---@return {dependencies: table, dev_dependencies: table}
|
||||
function M.parse_go_mod(root)
|
||||
local path = root .. "/go.mod"
|
||||
local content = utils.read_file(path)
|
||||
if not content then
|
||||
return { dependencies = {}, dev_dependencies = {} }
|
||||
end
|
||||
|
||||
local deps = {}
|
||||
local in_require = false
|
||||
|
||||
for line in content:gmatch("[^\n]+") do
|
||||
if line:match("^require%s*%(") then
|
||||
in_require = true
|
||||
elseif line:match("^%)") then
|
||||
in_require = false
|
||||
elseif in_require then
|
||||
local module, version = line:match("^%s*([%w%.%-%_/]+)%s+([%w%.%-]+)")
|
||||
if module then
|
||||
deps[module] = version
|
||||
end
|
||||
else
|
||||
local module, version = line:match("^require%s+([%w%.%-%_/]+)%s+([%w%.%-]+)")
|
||||
if module then
|
||||
deps[module] = version
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return { dependencies = deps, dev_dependencies = {} }
|
||||
end
|
||||
|
||||
--- Parse Python dependencies (pyproject.toml or requirements.txt)
|
||||
---@param root string Project root path
|
||||
---@return {dependencies: table, dev_dependencies: table}
|
||||
function M.parse_python_deps(root)
|
||||
local deps = {}
|
||||
local dev_deps = {}
|
||||
|
||||
-- Try pyproject.toml first
|
||||
local pyproject = root .. "/pyproject.toml"
|
||||
local content = utils.read_file(pyproject)
|
||||
|
||||
if content then
|
||||
-- Simple parsing for dependencies
|
||||
local in_deps = false
|
||||
local in_dev = false
|
||||
|
||||
for line in content:gmatch("[^\n]+") do
|
||||
if line:match("^%[project%.dependencies%]") or line:match("^dependencies%s*=") then
|
||||
in_deps = true
|
||||
in_dev = false
|
||||
elseif line:match("dev") and line:match("dependencies") then
|
||||
in_deps = false
|
||||
in_dev = true
|
||||
elseif line:match("^%[") then
|
||||
in_deps = false
|
||||
in_dev = false
|
||||
elseif in_deps or in_dev then
|
||||
local name = line:match('"([%w_%-]+)')
|
||||
if name then
|
||||
if in_deps then
|
||||
deps[name] = "latest"
|
||||
else
|
||||
dev_deps[name] = "latest"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Fallback to requirements.txt
|
||||
local req_file = root .. "/requirements.txt"
|
||||
content = utils.read_file(req_file)
|
||||
|
||||
if content then
|
||||
for line in content:gmatch("[^\n]+") do
|
||||
if not line:match("^#") and not line:match("^%s*$") then
|
||||
local name, version = line:match("^([%w_%-]+)==([%d%.]+)")
|
||||
if not name then
|
||||
name = line:match("^([%w_%-]+)")
|
||||
version = "latest"
|
||||
end
|
||||
if name then
|
||||
deps[name] = version or "latest"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return { dependencies = deps, dev_dependencies = dev_deps }
|
||||
end
|
||||
|
||||
--- Check if a file/directory should be ignored
|
||||
---@param name string File or directory name
|
||||
---@param config table Indexer configuration
|
||||
---@return boolean
|
||||
function M.should_ignore(name, config)
|
||||
-- Check default patterns
|
||||
for _, pattern in ipairs(DEFAULT_IGNORES) do
|
||||
if name:match(pattern) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
-- Check config excluded dirs
|
||||
if config and config.excluded_dirs then
|
||||
for _, dir in ipairs(config.excluded_dirs) do
|
||||
if name == dir then
|
||||
return true
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--- Check if a file should be indexed
|
||||
---@param filepath string Full file path
|
||||
---@param config table Indexer configuration
|
||||
---@return boolean
|
||||
function M.should_index(filepath, config)
|
||||
local name = vim.fn.fnamemodify(filepath, ":t")
|
||||
local ext = vim.fn.fnamemodify(filepath, ":e")
|
||||
|
||||
-- Check if it's a coder file
|
||||
if utils.is_coder_file(filepath) then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Check file size
|
||||
if config and config.max_file_size then
|
||||
local stat = vim.loop.fs_stat(filepath)
|
||||
if stat and stat.size > config.max_file_size then
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
-- Check extension
|
||||
if config and config.index_extensions then
|
||||
local valid_ext = false
|
||||
for _, allowed_ext in ipairs(config.index_extensions) do
|
||||
if ext == allowed_ext then
|
||||
valid_ext = true
|
||||
break
|
||||
end
|
||||
end
|
||||
if not valid_ext then
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
-- Check ignore patterns
|
||||
if M.should_ignore(name, config) then
|
||||
return false
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
--- Get all indexable files in the project
|
||||
---@param root string Project root path
|
||||
---@param config table Indexer configuration
|
||||
---@return string[] List of file paths
|
||||
function M.get_indexable_files(root, config)
|
||||
local files = {}
|
||||
|
||||
local function scan_dir(path)
|
||||
local handle = vim.loop.fs_scandir(path)
|
||||
if not handle then
|
||||
return
|
||||
end
|
||||
|
||||
while true do
|
||||
local name, type = vim.loop.fs_scandir_next(handle)
|
||||
if not name then
|
||||
break
|
||||
end
|
||||
|
||||
local full_path = path .. "/" .. name
|
||||
|
||||
if M.should_ignore(name, config) then
|
||||
goto continue
|
||||
end
|
||||
|
||||
if type == "directory" then
|
||||
scan_dir(full_path)
|
||||
elseif type == "file" then
|
||||
if M.should_index(full_path, config) then
|
||||
table.insert(files, full_path)
|
||||
end
|
||||
end
|
||||
|
||||
::continue::
|
||||
end
|
||||
end
|
||||
|
||||
scan_dir(root)
|
||||
return files
|
||||
end
|
||||
|
||||
--- Get language from file extension
|
||||
---@param filepath string File path
|
||||
---@return string Language name
|
||||
function M.get_language(filepath)
|
||||
local ext = vim.fn.fnamemodify(filepath, ":e")
|
||||
return EXTENSION_LANGUAGE[ext] or ext
|
||||
end
|
||||
|
||||
--- Read .gitignore patterns
|
||||
---@param root string Project root
|
||||
---@return string[] Patterns
|
||||
function M.read_gitignore(root)
|
||||
local patterns = {}
|
||||
local path = root .. "/.gitignore"
|
||||
local content = utils.read_file(path)
|
||||
|
||||
if not content then
|
||||
return patterns
|
||||
end
|
||||
|
||||
for line in content:gmatch("[^\n]+") do
|
||||
-- Skip comments and empty lines
|
||||
if not line:match("^#") and not line:match("^%s*$") then
|
||||
-- Convert gitignore pattern to Lua pattern (simplified)
|
||||
local pattern = line:gsub("^/", "^"):gsub("%*%*", ".*"):gsub("%*", "[^/]*"):gsub("%?", ".")
|
||||
table.insert(patterns, pattern)
|
||||
end
|
||||
end
|
||||
|
||||
return patterns
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,7 +1,7 @@
|
||||
---@mod codetyper Codetyper.nvim - AI-powered coding partner
|
||||
---@brief [[
|
||||
--- Codetyper.nvim is a Neovim plugin that acts as your coding partner.
|
||||
--- It uses LLM APIs (Claude, OpenAI, Gemini, Copilot, Ollama) to help you
|
||||
--- It uses LLM APIs (OpenAI, Gemini, Copilot, Ollama) to help you
|
||||
--- write code faster using special `.coder.*` files and inline prompt tags.
|
||||
--- Features an event-driven scheduler with confidence scoring and
|
||||
--- completion-aware injection timing.
|
||||
@@ -22,16 +22,16 @@ function M.setup(opts)
|
||||
return
|
||||
end
|
||||
|
||||
local config = require("codetyper.config")
|
||||
local config = require("codetyper.config.defaults")
|
||||
M.config = config.setup(opts)
|
||||
|
||||
-- Initialize modules
|
||||
local commands = require("codetyper.commands")
|
||||
local gitignore = require("codetyper.gitignore")
|
||||
local autocmds = require("codetyper.autocmds")
|
||||
local tree = require("codetyper.tree")
|
||||
local completion = require("codetyper.completion")
|
||||
local logs_panel = require("codetyper.logs_panel")
|
||||
local commands = require("codetyper.adapters.nvim.commands")
|
||||
local gitignore = require("codetyper.support.gitignore")
|
||||
local autocmds = require("codetyper.adapters.nvim.autocmds")
|
||||
local tree = require("codetyper.support.tree")
|
||||
local completion = require("codetyper.features.completion.inline")
|
||||
local logs_panel = require("codetyper.adapters.nvim.ui.logs_panel")
|
||||
|
||||
-- Register commands
|
||||
commands.setup()
|
||||
@@ -51,9 +51,27 @@ function M.setup(opts)
|
||||
-- Initialize tree logging (creates .coder folder and initial tree.log)
|
||||
tree.setup()
|
||||
|
||||
-- Initialize project indexer if enabled
|
||||
if M.config.indexer and M.config.indexer.enabled then
|
||||
local indexer = require("codetyper.features.indexer")
|
||||
indexer.setup(M.config.indexer)
|
||||
end
|
||||
|
||||
-- Initialize brain learning system if enabled
|
||||
if M.config.brain and M.config.brain.enabled then
|
||||
local brain = require("codetyper.core.memory")
|
||||
brain.setup(M.config.brain)
|
||||
end
|
||||
|
||||
-- Setup inline ghost text suggestions (Copilot-style)
|
||||
if M.config.suggestion and M.config.suggestion.enabled then
|
||||
local suggestion = require("codetyper.features.completion.suggestion")
|
||||
suggestion.setup(M.config.suggestion)
|
||||
end
|
||||
|
||||
-- Start the event-driven scheduler if enabled
|
||||
if M.config.scheduler and M.config.scheduler.enabled then
|
||||
local scheduler = require("codetyper.agent.scheduler")
|
||||
local scheduler = require("codetyper.core.scheduler.scheduler")
|
||||
scheduler.start(M.config.scheduler)
|
||||
end
|
||||
|
||||
@@ -62,7 +80,7 @@ function M.setup(opts)
|
||||
-- Auto-open Ask panel after a short delay (to let UI settle)
|
||||
if M.config.auto_open_ask then
|
||||
vim.defer_fn(function()
|
||||
local ask = require("codetyper.ask")
|
||||
local ask = require("codetyper.features.ask.engine")
|
||||
if not ask.is_open() then
|
||||
ask.open()
|
||||
end
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
local utils = require("codetyper.support.utils")
|
||||
|
||||
--- Inject generated code into target file
|
||||
---@param target_path string Path to target file
|
||||
---@param code string Generated code
|
||||
---@param prompt_type string Type of prompt (refactor, add, document, etc.)
|
||||
function M.inject_code(target_path, code, prompt_type)
|
||||
local window = require("codetyper.window")
|
||||
local window = require("codetyper.adapters.nvim.windows")
|
||||
|
||||
-- Normalize the target path
|
||||
target_path = vim.fn.fnamemodify(target_path, ":p")
|
||||
@@ -109,7 +109,7 @@ function M.inject_add(bufnr, code)
|
||||
local lines = vim.split(code, "\n", { plain = true })
|
||||
|
||||
-- Get cursor position in target window
|
||||
local window = require("codetyper.window")
|
||||
local window = require("codetyper.adapters.nvim.windows")
|
||||
local target_win = window.get_target_win()
|
||||
|
||||
local insert_line
|
||||
|
||||
@@ -1,364 +0,0 @@
|
||||
---@mod codetyper.llm.claude Claude API client for Codetyper.nvim
|
||||
|
||||
local M = {}
|
||||
|
||||
local utils = require("codetyper.utils")
|
||||
local llm = require("codetyper.llm")
|
||||
|
||||
--- Claude API endpoint
|
||||
local API_URL = "https://api.anthropic.com/v1/messages"
|
||||
|
||||
--- Get API key from config or environment
|
||||
---@return string|nil API key
|
||||
local function get_api_key()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
return config.llm.claude.api_key or vim.env.ANTHROPIC_API_KEY
|
||||
end
|
||||
|
||||
--- Get model from config
|
||||
---@return string Model name
|
||||
local function get_model()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
return config.llm.claude.model
|
||||
end
|
||||
|
||||
--- Build request body for Claude API
|
||||
---@param prompt string User prompt
|
||||
---@param context table Context information
|
||||
---@return table Request body
|
||||
local function build_request_body(prompt, context)
|
||||
local system_prompt = llm.build_system_prompt(context)
|
||||
|
||||
return {
|
||||
model = get_model(),
|
||||
max_tokens = 4096,
|
||||
system = system_prompt,
|
||||
messages = {
|
||||
{
|
||||
role = "user",
|
||||
content = prompt,
|
||||
},
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
--- Make HTTP request to Claude API
|
||||
---@param body table Request body
|
||||
---@param callback fun(response: string|nil, error: string|nil, usage: table|nil) Callback function
|
||||
local function make_request(body, callback)
|
||||
local api_key = get_api_key()
|
||||
if not api_key then
|
||||
callback(nil, "Claude API key not configured", nil)
|
||||
return
|
||||
end
|
||||
|
||||
local json_body = vim.json.encode(body)
|
||||
|
||||
-- Use curl for HTTP request (plenary.curl alternative)
|
||||
local cmd = {
|
||||
"curl",
|
||||
"-s",
|
||||
"-X",
|
||||
"POST",
|
||||
API_URL,
|
||||
"-H",
|
||||
"Content-Type: application/json",
|
||||
"-H",
|
||||
"x-api-key: " .. api_key,
|
||||
"-H",
|
||||
"anthropic-version: 2023-06-01",
|
||||
"-d",
|
||||
json_body,
|
||||
}
|
||||
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if not data or #data == 0 or (data[1] == "" and #data == 1) then
|
||||
return
|
||||
end
|
||||
|
||||
local response_text = table.concat(data, "\n")
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
|
||||
if not ok then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Failed to parse Claude response", nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
vim.schedule(function()
|
||||
callback(nil, response.error.message or "Claude API error", nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Extract usage info
|
||||
local usage = response.usage or {}
|
||||
|
||||
if response.content and response.content[1] and response.content[1].text then
|
||||
local code = llm.extract_code(response.content[1].text)
|
||||
vim.schedule(function()
|
||||
callback(code, nil, usage)
|
||||
end)
|
||||
else
|
||||
vim.schedule(function()
|
||||
callback(nil, "No content in Claude response", nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Claude API request failed: " .. table.concat(data, "\n"), nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Claude API request failed with code: " .. code, nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- Generate code using Claude API
|
||||
---@param prompt string The user's prompt
|
||||
---@param context table Context information
|
||||
---@param callback fun(response: string|nil, error: string|nil) Callback function
|
||||
function M.generate(prompt, context, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local model = get_model()
|
||||
|
||||
-- Log the request
|
||||
logs.request("claude", model)
|
||||
logs.thinking("Building request body...")
|
||||
|
||||
local body = build_request_body(prompt, context)
|
||||
|
||||
-- Estimate prompt tokens
|
||||
local prompt_estimate = logs.estimate_tokens(vim.json.encode(body))
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
logs.thinking("Sending to Claude API...")
|
||||
|
||||
utils.notify("Sending request to Claude...", vim.log.levels.INFO)
|
||||
|
||||
make_request(body, function(response, err, usage)
|
||||
if err then
|
||||
logs.error(err)
|
||||
utils.notify(err, vim.log.levels.ERROR)
|
||||
callback(nil, err)
|
||||
else
|
||||
-- Log token usage
|
||||
if usage then
|
||||
logs.response(usage.input_tokens or 0, usage.output_tokens or 0, "end_turn")
|
||||
end
|
||||
logs.thinking("Response received, extracting code...")
|
||||
logs.info("Code generated successfully")
|
||||
utils.notify("Code generated successfully", vim.log.levels.INFO)
|
||||
callback(response, nil)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Check if Claude is properly configured
|
||||
---@return boolean, string? Valid status and optional error message
|
||||
function M.validate()
|
||||
local api_key = get_api_key()
|
||||
if not api_key or api_key == "" then
|
||||
return false, "Claude API key not configured"
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
--- Generate with tool use support for agentic mode
|
||||
---@param messages table[] Conversation history
|
||||
---@param context table Context information
|
||||
---@param tool_definitions table Tool definitions
|
||||
---@param callback fun(response: table|nil, error: string|nil) Callback with raw response
|
||||
function M.generate_with_tools(messages, context, tool_definitions, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local model = get_model()
|
||||
|
||||
-- Log the request
|
||||
logs.request("claude", model)
|
||||
logs.thinking("Preparing agent request...")
|
||||
|
||||
local api_key = get_api_key()
|
||||
if not api_key then
|
||||
logs.error("Claude API key not configured")
|
||||
callback(nil, "Claude API key not configured")
|
||||
return
|
||||
end
|
||||
|
||||
local tools_module = require("codetyper.agent.tools")
|
||||
local agent_prompts = require("codetyper.prompts.agent")
|
||||
|
||||
-- Build system prompt with agent instructions
|
||||
local system_prompt = llm.build_system_prompt(context)
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.system
|
||||
system_prompt = system_prompt .. "\n\n" .. agent_prompts.tool_instructions
|
||||
|
||||
-- Build request body with tools
|
||||
local body = {
|
||||
model = get_model(),
|
||||
max_tokens = 4096,
|
||||
system = system_prompt,
|
||||
messages = M.format_messages_for_claude(messages),
|
||||
tools = tools_module.to_claude_format(),
|
||||
}
|
||||
|
||||
local json_body = vim.json.encode(body)
|
||||
|
||||
-- Estimate prompt tokens
|
||||
local prompt_estimate = logs.estimate_tokens(json_body)
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
logs.thinking("Sending to Claude API...")
|
||||
|
||||
local cmd = {
|
||||
"curl",
|
||||
"-s",
|
||||
"-X",
|
||||
"POST",
|
||||
API_URL,
|
||||
"-H",
|
||||
"Content-Type: application/json",
|
||||
"-H",
|
||||
"x-api-key: " .. api_key,
|
||||
"-H",
|
||||
"anthropic-version: 2023-06-01",
|
||||
"-d",
|
||||
json_body,
|
||||
}
|
||||
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if not data or #data == 0 or (data[1] == "" and #data == 1) then
|
||||
return
|
||||
end
|
||||
|
||||
local response_text = table.concat(data, "\n")
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
|
||||
if not ok then
|
||||
vim.schedule(function()
|
||||
logs.error("Failed to parse Claude response")
|
||||
callback(nil, "Failed to parse Claude response")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
vim.schedule(function()
|
||||
logs.error(response.error.message or "Claude API error")
|
||||
callback(nil, response.error.message or "Claude API error")
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Log token usage from response
|
||||
if response.usage then
|
||||
logs.response(response.usage.input_tokens or 0, response.usage.output_tokens or 0, response.stop_reason)
|
||||
end
|
||||
|
||||
-- Log what's in the response
|
||||
if response.content then
|
||||
for _, block in ipairs(response.content) do
|
||||
if block.type == "text" then
|
||||
logs.thinking("Response contains text")
|
||||
elseif block.type == "tool_use" then
|
||||
logs.thinking("Tool call: " .. block.name)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Return raw response for parser to handle
|
||||
vim.schedule(function()
|
||||
callback(response, nil)
|
||||
end)
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
logs.error("Claude API request failed: " .. table.concat(data, "\n"))
|
||||
callback(nil, "Claude API request failed: " .. table.concat(data, "\n"))
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
vim.schedule(function()
|
||||
logs.error("Claude API request failed with code: " .. code)
|
||||
callback(nil, "Claude API request failed with code: " .. code)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- Format messages for Claude API
|
||||
---@param messages table[] Internal message format
|
||||
---@return table[] Claude API message format
|
||||
function M.format_messages_for_claude(messages)
|
||||
local formatted = {}
|
||||
|
||||
for _, msg in ipairs(messages) do
|
||||
if msg.role == "user" then
|
||||
if type(msg.content) == "table" then
|
||||
-- Tool results
|
||||
table.insert(formatted, {
|
||||
role = "user",
|
||||
content = msg.content,
|
||||
})
|
||||
else
|
||||
table.insert(formatted, {
|
||||
role = "user",
|
||||
content = msg.content,
|
||||
})
|
||||
end
|
||||
elseif msg.role == "assistant" then
|
||||
-- Build content array for assistant messages
|
||||
local content = {}
|
||||
|
||||
-- Add text if present
|
||||
if msg.content and msg.content ~= "" then
|
||||
table.insert(content, {
|
||||
type = "text",
|
||||
text = msg.content,
|
||||
})
|
||||
end
|
||||
|
||||
-- Add tool uses if present
|
||||
if msg.tool_calls then
|
||||
for _, tool_call in ipairs(msg.tool_calls) do
|
||||
table.insert(content, {
|
||||
type = "tool_use",
|
||||
id = tool_call.id,
|
||||
name = tool_call.name,
|
||||
input = tool_call.parameters,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
if #content > 0 then
|
||||
table.insert(formatted, {
|
||||
role = "assistant",
|
||||
content = content,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return formatted
|
||||
end
|
||||
|
||||
return M
|
||||
35
lua/codetyper/params/agents/bash.lua
Normal file
35
lua/codetyper/params/agents/bash.lua
Normal file
@@ -0,0 +1,35 @@
|
||||
M.params = {
|
||||
{
|
||||
name = "command",
|
||||
description = "The shell command to execute",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "cwd",
|
||||
description = "Working directory for the command (optional)",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
{
|
||||
name = "timeout",
|
||||
description = "Timeout in milliseconds (default: 120000)",
|
||||
type = "integer",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
M.returns = {
|
||||
{
|
||||
name = "stdout",
|
||||
description = "Command output",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if command failed",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
return M
|
||||
40
lua/codetyper/params/agents/confidence.lua
Normal file
40
lua/codetyper/params/agents/confidence.lua
Normal file
@@ -0,0 +1,40 @@
|
||||
---@mod codetyper.params.agents.confidence Parameters for confidence scoring
|
||||
local M = {}
|
||||
|
||||
--- Heuristic weights (must sum to 1.0)
|
||||
M.weights = {
|
||||
length = 0.15, -- Response length relative to prompt
|
||||
uncertainty = 0.30, -- Uncertainty phrases
|
||||
syntax = 0.25, -- Syntax completeness
|
||||
repetition = 0.15, -- Duplicate lines
|
||||
truncation = 0.15, -- Incomplete ending
|
||||
}
|
||||
|
||||
--- Uncertainty phrases that indicate low confidence
|
||||
M.uncertainty_phrases = {
|
||||
-- English
|
||||
"i'm not sure",
|
||||
"i am not sure",
|
||||
"maybe",
|
||||
"perhaps",
|
||||
"might work",
|
||||
"could work",
|
||||
"not certain",
|
||||
"uncertain",
|
||||
"i think",
|
||||
"possibly",
|
||||
"TODO",
|
||||
"FIXME",
|
||||
"XXX",
|
||||
"placeholder",
|
||||
"implement this",
|
||||
"fill in",
|
||||
"your code here",
|
||||
"...", -- Ellipsis as placeholder
|
||||
"# TODO",
|
||||
"// TODO",
|
||||
"-- TODO",
|
||||
"/* TODO",
|
||||
}
|
||||
|
||||
return M
|
||||
33
lua/codetyper/params/agents/conflict.lua
Normal file
33
lua/codetyper/params/agents/conflict.lua
Normal file
@@ -0,0 +1,33 @@
|
||||
---@mod codetyper.params.agents.conflict Parameters for conflict resolution
|
||||
local M = {}
|
||||
|
||||
--- Configuration defaults
|
||||
M.config = {
|
||||
-- Run linter check after accepting AI suggestions
|
||||
lint_after_accept = true,
|
||||
-- Auto-fix lint errors without prompting
|
||||
auto_fix_lint_errors = true,
|
||||
-- Auto-show menu after injecting conflict
|
||||
auto_show_menu = true,
|
||||
-- Auto-show menu for next conflict after resolving one
|
||||
auto_show_next_menu = true,
|
||||
}
|
||||
|
||||
--- Highlight groups
|
||||
M.hl_groups = {
|
||||
current = "CoderConflictCurrent",
|
||||
current_label = "CoderConflictCurrentLabel",
|
||||
incoming = "CoderConflictIncoming",
|
||||
incoming_label = "CoderConflictIncomingLabel",
|
||||
separator = "CoderConflictSeparator",
|
||||
hint = "CoderConflictHint",
|
||||
}
|
||||
|
||||
--- Conflict markers
|
||||
M.markers = {
|
||||
current_start = "<<<<<<< CURRENT",
|
||||
separator = "=======",
|
||||
incoming_end = ">>>>>>> INCOMING",
|
||||
}
|
||||
|
||||
return M
|
||||
48
lua/codetyper/params/agents/context.lua
Normal file
48
lua/codetyper/params/agents/context.lua
Normal file
@@ -0,0 +1,48 @@
|
||||
---@mod codetyper.params.agents.context Parameters for context building
|
||||
local M = {}
|
||||
|
||||
--- Common ignore patterns
|
||||
M.ignore_patterns = {
|
||||
"^%.", -- Hidden files/dirs
|
||||
"node_modules",
|
||||
"%.git$",
|
||||
"__pycache__",
|
||||
"%.pyc$",
|
||||
"target", -- Rust
|
||||
"build",
|
||||
"dist",
|
||||
"%.o$",
|
||||
"%.a$",
|
||||
"%.so$",
|
||||
"%.min%.",
|
||||
"%.map$",
|
||||
}
|
||||
|
||||
--- Key files that are important for understanding the project
|
||||
M.important_files = {
|
||||
["package.json"] = "Node.js project config",
|
||||
["Cargo.toml"] = "Rust project config",
|
||||
["go.mod"] = "Go module config",
|
||||
["pyproject.toml"] = "Python project config",
|
||||
["setup.py"] = "Python setup config",
|
||||
["Makefile"] = "Build configuration",
|
||||
["CMakeLists.txt"] = "CMake config",
|
||||
[".gitignore"] = "Git ignore patterns",
|
||||
["README.md"] = "Project documentation",
|
||||
["init.lua"] = "Neovim plugin entry",
|
||||
["plugin.lua"] = "Neovim plugin config",
|
||||
}
|
||||
|
||||
--- Project type detection indicators
|
||||
M.indicators = {
|
||||
["package.json"] = { type = "node", language = "javascript/typescript" },
|
||||
["Cargo.toml"] = { type = "rust", language = "rust" },
|
||||
["go.mod"] = { type = "go", language = "go" },
|
||||
["pyproject.toml"] = { type = "python", language = "python" },
|
||||
["setup.py"] = { type = "python", language = "python" },
|
||||
["Gemfile"] = { type = "ruby", language = "ruby" },
|
||||
["pom.xml"] = { type = "maven", language = "java" },
|
||||
["build.gradle"] = { type = "gradle", language = "java/kotlin" },
|
||||
}
|
||||
|
||||
return M
|
||||
33
lua/codetyper/params/agents/edit.lua
Normal file
33
lua/codetyper/params/agents/edit.lua
Normal file
@@ -0,0 +1,33 @@
|
||||
M.params = {
|
||||
{
|
||||
name = "path",
|
||||
description = "Path to the file to edit",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "old_string",
|
||||
description = "Text to find and replace (empty string to create new file or append)",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "new_string",
|
||||
description = "Text to replace with",
|
||||
type = "string",
|
||||
},
|
||||
}
|
||||
|
||||
M.returns = {
|
||||
{
|
||||
name = "success",
|
||||
description = "Whether the edit was applied",
|
||||
type = "boolean",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if edit failed",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
return M
|
||||
10
lua/codetyper/params/agents/grep.lua
Normal file
10
lua/codetyper/params/agents/grep.lua
Normal file
@@ -0,0 +1,10 @@
|
||||
M.description = [[Searches for a pattern in files using ripgrep.
|
||||
|
||||
Returns file paths and matching lines. Use this to find code by content.
|
||||
|
||||
Example patterns:
|
||||
- "function foo" - Find function definitions
|
||||
- "import.*react" - Find React imports
|
||||
- "TODO|FIXME" - Find todo comments]]
|
||||
|
||||
return M
|
||||
161
lua/codetyper/params/agents/intent.lua
Normal file
161
lua/codetyper/params/agents/intent.lua
Normal file
@@ -0,0 +1,161 @@
|
||||
---@mod codetyper.params.agents.intent Intent patterns and scope configuration
|
||||
local M = {}
|
||||
|
||||
--- Intent patterns with associated metadata
|
||||
M.intent_patterns = {
|
||||
-- Complete: fill in missing implementation
|
||||
complete = {
|
||||
patterns = {
|
||||
"complete",
|
||||
"finish",
|
||||
"implement",
|
||||
"fill in",
|
||||
"fill out",
|
||||
"stub",
|
||||
"todo",
|
||||
"fixme",
|
||||
},
|
||||
scope_hint = "function",
|
||||
action = "replace",
|
||||
priority = 1,
|
||||
},
|
||||
|
||||
-- Refactor: rewrite existing code
|
||||
refactor = {
|
||||
patterns = {
|
||||
"refactor",
|
||||
"rewrite",
|
||||
"restructure",
|
||||
"reorganize",
|
||||
"clean up",
|
||||
"cleanup",
|
||||
"simplify",
|
||||
"improve",
|
||||
},
|
||||
scope_hint = "function",
|
||||
action = "replace",
|
||||
priority = 2,
|
||||
},
|
||||
|
||||
-- Fix: repair bugs or issues
|
||||
fix = {
|
||||
patterns = {
|
||||
"fix",
|
||||
"repair",
|
||||
"correct",
|
||||
"debug",
|
||||
"solve",
|
||||
"resolve",
|
||||
"patch",
|
||||
"bug",
|
||||
"error",
|
||||
"issue",
|
||||
"update",
|
||||
"modify",
|
||||
"change",
|
||||
"adjust",
|
||||
"tweak",
|
||||
},
|
||||
scope_hint = "function",
|
||||
action = "replace",
|
||||
priority = 1,
|
||||
},
|
||||
|
||||
-- Add: insert new code
|
||||
add = {
|
||||
patterns = {
|
||||
"add",
|
||||
"create",
|
||||
"insert",
|
||||
"include",
|
||||
"append",
|
||||
"new",
|
||||
"generate",
|
||||
"write",
|
||||
},
|
||||
scope_hint = nil, -- Could be anywhere
|
||||
action = "insert",
|
||||
priority = 3,
|
||||
},
|
||||
|
||||
-- Document: add documentation
|
||||
document = {
|
||||
patterns = {
|
||||
"document",
|
||||
"comment",
|
||||
"jsdoc",
|
||||
"docstring",
|
||||
"describe",
|
||||
"annotate",
|
||||
"type hint",
|
||||
"typehint",
|
||||
},
|
||||
scope_hint = "function",
|
||||
action = "replace", -- Replace with documented version
|
||||
priority = 2,
|
||||
},
|
||||
|
||||
-- Test: generate tests
|
||||
test = {
|
||||
patterns = {
|
||||
"test",
|
||||
"spec",
|
||||
"unit test",
|
||||
"integration test",
|
||||
"coverage",
|
||||
},
|
||||
scope_hint = "file",
|
||||
action = "append",
|
||||
priority = 3,
|
||||
},
|
||||
|
||||
-- Optimize: improve performance
|
||||
optimize = {
|
||||
patterns = {
|
||||
"optimize",
|
||||
"performance",
|
||||
"faster",
|
||||
"efficient",
|
||||
"speed up",
|
||||
"reduce",
|
||||
"minimize",
|
||||
},
|
||||
scope_hint = "function",
|
||||
action = "replace",
|
||||
priority = 2,
|
||||
},
|
||||
|
||||
-- Explain: provide explanation (no code change)
|
||||
explain = {
|
||||
patterns = {
|
||||
"explain",
|
||||
"what does",
|
||||
"how does",
|
||||
"why",
|
||||
"describe",
|
||||
"walk through",
|
||||
"understand",
|
||||
},
|
||||
scope_hint = "function",
|
||||
action = "none",
|
||||
priority = 4,
|
||||
},
|
||||
}
|
||||
|
||||
--- Scope hint patterns
|
||||
M.scope_patterns = {
|
||||
["this function"] = "function",
|
||||
["this method"] = "function",
|
||||
["the function"] = "function",
|
||||
["the method"] = "function",
|
||||
["this class"] = "class",
|
||||
["the class"] = "class",
|
||||
["this file"] = "file",
|
||||
["the file"] = "file",
|
||||
["this block"] = "block",
|
||||
["the block"] = "block",
|
||||
["this"] = nil, -- Use Tree-sitter to determine
|
||||
["here"] = nil,
|
||||
}
|
||||
|
||||
return M
|
||||
87
lua/codetyper/params/agents/languages.lua
Normal file
87
lua/codetyper/params/agents/languages.lua
Normal file
@@ -0,0 +1,87 @@
|
||||
---@mod codetyper.params.agents.languages Language-specific patterns and configurations
|
||||
local M = {}
|
||||
|
||||
--- Language-specific import patterns
|
||||
M.import_patterns = {
|
||||
-- JavaScript/TypeScript
|
||||
javascript = {
|
||||
{ pattern = "^%s*import%s+.+%s+from%s+['\"]", multi_line = true },
|
||||
{ pattern = "^%s*import%s+['\"]", multi_line = false },
|
||||
{ pattern = "^%s*import%s*{", multi_line = true },
|
||||
{ pattern = "^%s*import%s*%*", multi_line = true },
|
||||
{ pattern = "^%s*export%s+{.+}%s+from%s+['\"]", multi_line = true },
|
||||
{ pattern = "^%s*const%s+%w+%s*=%s*require%(['\"]", multi_line = false },
|
||||
{ pattern = "^%s*let%s+%w+%s*=%s*require%(['\"]", multi_line = false },
|
||||
{ pattern = "^%s*var%s+%w+%s*=%s*require%(['\"]", multi_line = false },
|
||||
},
|
||||
-- Python
|
||||
python = {
|
||||
{ pattern = "^%s*import%s+%w", multi_line = false },
|
||||
{ pattern = "^%s*from%s+[%w%.]+%s+import%s+", multi_line = true },
|
||||
},
|
||||
-- Lua
|
||||
lua = {
|
||||
{ pattern = "^%s*local%s+%w+%s*=%s*require%s*%(?['\"]", multi_line = false },
|
||||
{ pattern = "^%s*require%s*%(?['\"]", multi_line = false },
|
||||
},
|
||||
-- Go
|
||||
go = {
|
||||
{ pattern = "^%s*import%s+%(?", multi_line = true },
|
||||
},
|
||||
-- Rust
|
||||
rust = {
|
||||
{ pattern = "^%s*use%s+", multi_line = true },
|
||||
{ pattern = "^%s*extern%s+crate%s+", multi_line = false },
|
||||
},
|
||||
-- C/C++
|
||||
c = {
|
||||
{ pattern = "^%s*#include%s*[<\"]", multi_line = false },
|
||||
},
|
||||
-- Java/Kotlin
|
||||
java = {
|
||||
{ pattern = "^%s*import%s+", multi_line = false },
|
||||
},
|
||||
-- Ruby
|
||||
ruby = {
|
||||
{ pattern = "^%s*require%s+['\"]", multi_line = false },
|
||||
{ pattern = "^%s*require_relative%s+['\"]", multi_line = false },
|
||||
},
|
||||
-- PHP
|
||||
php = {
|
||||
{ pattern = "^%s*use%s+", multi_line = false },
|
||||
{ pattern = "^%s*require%s+['\"]", multi_line = false },
|
||||
{ pattern = "^%s*require_once%s+['\"]", multi_line = false },
|
||||
{ pattern = "^%s*include%s+['\"]", multi_line = false },
|
||||
{ pattern = "^%s*include_once%s+['\"]", multi_line = false },
|
||||
},
|
||||
}
|
||||
|
||||
-- Alias common extensions to language configs
|
||||
M.import_patterns.ts = M.import_patterns.javascript
|
||||
M.import_patterns.tsx = M.import_patterns.javascript
|
||||
M.import_patterns.jsx = M.import_patterns.javascript
|
||||
M.import_patterns.mjs = M.import_patterns.javascript
|
||||
M.import_patterns.cjs = M.import_patterns.javascript
|
||||
M.import_patterns.py = M.import_patterns.python
|
||||
M.import_patterns.cpp = M.import_patterns.c
|
||||
M.import_patterns.hpp = M.import_patterns.c
|
||||
M.import_patterns.h = M.import_patterns.c
|
||||
M.import_patterns.kt = M.import_patterns.java
|
||||
M.import_patterns.rs = M.import_patterns.rust
|
||||
M.import_patterns.rb = M.import_patterns.ruby
|
||||
|
||||
--- Language-specific comment patterns
|
||||
M.comment_patterns = {
|
||||
lua = { "^%-%-" },
|
||||
python = { "^#" },
|
||||
javascript = { "^//", "^/%*", "^%*" },
|
||||
typescript = { "^//", "^/%*", "^%*" },
|
||||
go = { "^//", "^/%*", "^%*" },
|
||||
rust = { "^//", "^/%*", "^%*" },
|
||||
c = { "^//", "^/%*", "^%*", "^#" },
|
||||
java = { "^//", "^/%*", "^%*" },
|
||||
ruby = { "^#" },
|
||||
php = { "^//", "^/%*", "^%*", "^#" },
|
||||
}
|
||||
|
||||
return M
|
||||
15
lua/codetyper/params/agents/linter.lua
Normal file
15
lua/codetyper/params/agents/linter.lua
Normal file
@@ -0,0 +1,15 @@
|
||||
---@mod codetyper.params.agents.linter Linter configuration
|
||||
local M = {}
|
||||
|
||||
M.config = {
|
||||
-- Auto-save file after code injection
|
||||
auto_save = true,
|
||||
-- Delay in ms to wait for LSP diagnostics to update
|
||||
diagnostic_delay_ms = 500,
|
||||
-- Severity levels to check (1=Error, 2=Warning, 3=Info, 4=Hint)
|
||||
min_severity = vim.diagnostic.severity.WARN,
|
||||
-- Auto-offer to fix lint errors
|
||||
auto_offer_fix = true,
|
||||
}
|
||||
|
||||
return M
|
||||
36
lua/codetyper/params/agents/logs.lua
Normal file
36
lua/codetyper/params/agents/logs.lua
Normal file
@@ -0,0 +1,36 @@
|
||||
---@mod codetyper.params.agents.logs Log parameters
|
||||
local M = {}
|
||||
|
||||
M.icons = {
|
||||
start = "->",
|
||||
success = "OK",
|
||||
error = "ERR",
|
||||
approval = "??",
|
||||
approved = "YES",
|
||||
rejected = "NO",
|
||||
}
|
||||
|
||||
M.level_icons = {
|
||||
info = "i",
|
||||
debug = ".",
|
||||
request = ">",
|
||||
response = "<",
|
||||
tool = "T",
|
||||
error = "!",
|
||||
warning = "?",
|
||||
success = "i",
|
||||
queue = "Q",
|
||||
patch = "P",
|
||||
}
|
||||
|
||||
M.thinking_types = { "thinking", "reason", "action", "task", "result" }
|
||||
|
||||
M.thinking_prefixes = {
|
||||
thinking = "⏺",
|
||||
reason = "⏺",
|
||||
action = "⏺",
|
||||
task = "✶",
|
||||
result = "",
|
||||
}
|
||||
|
||||
return M
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user