diff --git a/.gitignore b/.gitignore index f03cc40..9995ad4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.coder.* .coder/ .claude/ +Makefile # Created by https://www.toptal.com/developers/gitignore/api/lua diff --git a/tests/minimal_init.lua b/tests/minimal_init.lua new file mode 100644 index 0000000..f98e1c4 --- /dev/null +++ b/tests/minimal_init.lua @@ -0,0 +1,48 @@ +-- Minimal init.lua for running tests +-- This sets up the minimum Neovim environment needed for testing + +-- Add the plugin to the runtimepath +local plugin_root = vim.fn.fnamemodify(debug.getinfo(1, "S").source:sub(2), ":p:h:h") +vim.opt.rtp:prepend(plugin_root) + +-- Add plenary for testing (if available) +local plenary_path = vim.fn.expand("~/.local/share/nvim/lazy/plenary.nvim") +if vim.fn.isdirectory(plenary_path) == 1 then + vim.opt.rtp:prepend(plenary_path) +end + +-- Alternative plenary paths +local alt_plenary_paths = { + vim.fn.expand("~/.local/share/nvim/site/pack/*/start/plenary.nvim"), + vim.fn.expand("~/.config/nvim/plugged/plenary.nvim"), + "/opt/homebrew/share/nvim/site/pack/packer/start/plenary.nvim", +} + +for _, path in ipairs(alt_plenary_paths) do + local expanded = vim.fn.glob(path) + if expanded ~= "" and vim.fn.isdirectory(expanded) == 1 then + vim.opt.rtp:prepend(expanded) + break + end +end + +-- Set up test environment +vim.opt.swapfile = false +vim.opt.backup = false +vim.opt.writebackup = false + +-- Initialize codetyper with test defaults +require("codetyper").setup({ + llm = { + provider = "ollama", + ollama = { + host = "http://localhost:11434", + model = "test-model", + }, + }, + scheduler = { + enabled = false, -- Disable scheduler during tests + }, + auto_gitignore = false, + auto_open_ask = false, +}) diff --git a/tests/run_tests.sh b/tests/run_tests.sh new file mode 100755 index 0000000..25f1dd8 --- /dev/null +++ b/tests/run_tests.sh @@ -0,0 +1,62 @@ +#!/bin/bash +# Run codetyper.nvim tests using plenary.nvim + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +echo -e "${YELLOW}Running codetyper.nvim tests...${NC}" +echo "Project root: $PROJECT_ROOT" +echo "" + +# Check if plenary is installed +PLENARY_PATH="" +POSSIBLE_PATHS=( + "$HOME/.local/share/nvim/lazy/plenary.nvim" + "$HOME/.local/share/nvim/site/pack/packer/start/plenary.nvim" + "$HOME/.config/nvim/plugged/plenary.nvim" + "/opt/homebrew/share/nvim/site/pack/packer/start/plenary.nvim" +) + +for path in "${POSSIBLE_PATHS[@]}"; do + if [ -d "$path" ]; then + PLENARY_PATH="$path" + break + fi +done + +if [ -z "$PLENARY_PATH" ]; then + echo -e "${RED}Error: plenary.nvim not found!${NC}" + echo "Please install plenary.nvim first:" + echo " - With lazy.nvim: { 'nvim-lua/plenary.nvim' }" + echo " - With packer: use 'nvim-lua/plenary.nvim'" + exit 1 +fi + +echo "Found plenary at: $PLENARY_PATH" +echo "" + +# Run tests +if [ "$1" == "--file" ] && [ -n "$2" ]; then + # Run specific test file + echo -e "${YELLOW}Running: $2${NC}" + nvim --headless \ + -u "$SCRIPT_DIR/minimal_init.lua" \ + -c "PlenaryBustedFile $SCRIPT_DIR/spec/$2" +else + # Run all tests + echo -e "${YELLOW}Running all tests in spec/ directory${NC}" + nvim --headless \ + -u "$SCRIPT_DIR/minimal_init.lua" \ + -c "PlenaryBustedDirectory $SCRIPT_DIR/spec/ {minimal_init = '$SCRIPT_DIR/minimal_init.lua'}" +fi + +echo "" +echo -e "${GREEN}Tests completed!${NC}" diff --git a/tests/spec/confidence_spec.lua b/tests/spec/confidence_spec.lua new file mode 100644 index 0000000..d0ca040 --- /dev/null +++ b/tests/spec/confidence_spec.lua @@ -0,0 +1,148 @@ +---@diagnostic disable: undefined-global +-- Tests for lua/codetyper/agent/confidence.lua + +describe("confidence", function() + local confidence = require("codetyper.agent.confidence") + + describe("weights", function() + it("should have weights that sum to 1.0", function() + local total = 0 + for _, weight in pairs(confidence.weights) do + total = total + weight + end + assert.is_near(1.0, total, 0.001) + end) + end) + + describe("score", function() + it("should return 0 for empty response", function() + local score, breakdown = confidence.score("", "some prompt") + + assert.equals(0, score) + assert.equals(0, breakdown.weighted_total) + end) + + it("should return high score for good response", function() + local good_response = [[ +function validateEmail(email) + local pattern = "^[%w%.]+@[%w%.]+%.%w+$" + return string.match(email, pattern) ~= nil +end +]] + local score, breakdown = confidence.score(good_response, "create email validator") + + assert.is_true(score > 0.7) + assert.is_true(breakdown.syntax > 0.5) + end) + + it("should return lower score for response with uncertainty", function() + local uncertain_response = [[ +-- I'm not sure if this is correct, maybe try: +function doSomething() + -- TODO: implement this + -- placeholder code here +end +]] + local score, _ = confidence.score(uncertain_response, "implement function") + + assert.is_true(score < 0.7) + end) + + it("should penalize unbalanced brackets", function() + local unbalanced = [[ +function test() { + if (true) { + console.log("missing bracket") +]] + local _, breakdown = confidence.score(unbalanced, "test") + + assert.is_true(breakdown.syntax < 0.7) + end) + + it("should penalize short responses to long prompts", function() + local long_prompt = "Create a comprehensive function that handles user authentication, " .. + "validates credentials against the database, generates JWT tokens, " .. + "handles refresh tokens, and logs all authentication attempts" + local short_response = "done" + + local score, breakdown = confidence.score(short_response, long_prompt) + + assert.is_true(breakdown.length < 0.5) + end) + + it("should penalize repetitive code", function() + local repetitive = [[ +console.log("test"); +console.log("test"); +console.log("test"); +console.log("test"); +console.log("test"); +console.log("test"); +console.log("test"); +console.log("test"); +]] + local _, breakdown = confidence.score(repetitive, "test") + + assert.is_true(breakdown.repetition < 0.7) + end) + + it("should penalize truncated responses", function() + local truncated = [[ +function process(data) { + const result = data.map(item => { + return { + id: item.id, + name: item... +]] + local _, breakdown = confidence.score(truncated, "test") + + assert.is_true(breakdown.truncation < 1.0) + end) + end) + + describe("needs_escalation", function() + it("should return true for low confidence", function() + assert.is_true(confidence.needs_escalation(0.5, 0.7)) + assert.is_true(confidence.needs_escalation(0.3, 0.7)) + end) + + it("should return false for high confidence", function() + assert.is_false(confidence.needs_escalation(0.8, 0.7)) + assert.is_false(confidence.needs_escalation(0.95, 0.7)) + end) + + it("should use default threshold of 0.7", function() + assert.is_true(confidence.needs_escalation(0.6)) + assert.is_false(confidence.needs_escalation(0.8)) + end) + end) + + describe("level_name", function() + it("should return correct level names", function() + assert.equals("excellent", confidence.level_name(0.95)) + assert.equals("good", confidence.level_name(0.85)) + assert.equals("acceptable", confidence.level_name(0.75)) + assert.equals("uncertain", confidence.level_name(0.6)) + assert.equals("poor", confidence.level_name(0.3)) + end) + end) + + describe("format_breakdown", function() + it("should format breakdown correctly", function() + local breakdown = { + length = 0.8, + uncertainty = 0.9, + syntax = 1.0, + repetition = 0.85, + truncation = 0.95, + weighted_total = 0.9, + } + + local formatted = confidence.format_breakdown(breakdown) + + assert.is_true(formatted:match("len:0.80")) + assert.is_true(formatted:match("unc:0.90")) + assert.is_true(formatted:match("syn:1.00")) + end) + end) +end) diff --git a/tests/spec/config_spec.lua b/tests/spec/config_spec.lua new file mode 100644 index 0000000..24606b6 --- /dev/null +++ b/tests/spec/config_spec.lua @@ -0,0 +1,149 @@ +---@diagnostic disable: undefined-global +-- Tests for lua/codetyper/config.lua + +describe("config", function() + local config = require("codetyper.config") + + describe("defaults", function() + local defaults = config.defaults + + it("should have llm configuration", function() + assert.is_table(defaults.llm) + assert.equals("claude", defaults.llm.provider) + end) + + it("should have window configuration", function() + assert.is_table(defaults.window) + assert.equals(25, defaults.window.width) + assert.equals("left", defaults.window.position) + end) + + it("should have pattern configuration", function() + assert.is_table(defaults.patterns) + assert.equals("/@", defaults.patterns.open_tag) + assert.equals("@/", defaults.patterns.close_tag) + end) + + it("should have scheduler configuration", function() + assert.is_table(defaults.scheduler) + assert.is_boolean(defaults.scheduler.enabled) + assert.is_boolean(defaults.scheduler.ollama_scout) + assert.is_number(defaults.scheduler.escalation_threshold) + end) + + it("should have claude configuration", function() + assert.is_table(defaults.llm.claude) + assert.is_truthy(defaults.llm.claude.model) + end) + + it("should have openai configuration", function() + assert.is_table(defaults.llm.openai) + assert.is_truthy(defaults.llm.openai.model) + end) + + it("should have gemini configuration", function() + assert.is_table(defaults.llm.gemini) + assert.is_truthy(defaults.llm.gemini.model) + end) + + it("should have ollama configuration", function() + assert.is_table(defaults.llm.ollama) + assert.is_truthy(defaults.llm.ollama.host) + assert.is_truthy(defaults.llm.ollama.model) + end) + end) + + describe("merge", function() + it("should merge user config with defaults", function() + local user_config = { + llm = { + provider = "openai", + }, + } + + local merged = config.merge(user_config) + + -- User value should override + assert.equals("openai", merged.llm.provider) + -- Other defaults should be preserved + assert.equals(25, merged.window.width) + end) + + it("should deep merge nested tables", function() + local user_config = { + llm = { + claude = { + model = "claude-opus-4", + }, + }, + } + + local merged = config.merge(user_config) + + -- User value should override + assert.equals("claude-opus-4", merged.llm.claude.model) + -- Provider default should be preserved + assert.equals("claude", merged.llm.provider) + end) + + it("should handle empty user config", function() + local merged = config.merge({}) + + assert.equals("claude", merged.llm.provider) + assert.equals(25, merged.window.width) + end) + + it("should handle nil user config", function() + local merged = config.merge(nil) + + assert.equals("claude", merged.llm.provider) + end) + end) + + describe("validate", function() + it("should return true for valid config", function() + local valid_config = config.defaults + local is_valid, err = config.validate(valid_config) + + assert.is_true(is_valid) + assert.is_nil(err) + end) + + it("should validate provider value", function() + local invalid_config = vim.tbl_deep_extend("force", {}, config.defaults) + invalid_config.llm.provider = "invalid_provider" + + local is_valid, err = config.validate(invalid_config) + + assert.is_false(is_valid) + assert.is_truthy(err) + end) + + it("should validate window width range", function() + local invalid_config = vim.tbl_deep_extend("force", {}, config.defaults) + invalid_config.window.width = 101 -- Over 100% + + local is_valid, err = config.validate(invalid_config) + + assert.is_false(is_valid) + end) + + it("should validate window position", function() + local invalid_config = vim.tbl_deep_extend("force", {}, config.defaults) + invalid_config.window.position = "center" -- Invalid + + local is_valid, err = config.validate(invalid_config) + + assert.is_false(is_valid) + end) + + it("should validate scheduler threshold range", function() + local invalid_config = vim.tbl_deep_extend("force", {}, config.defaults) + invalid_config.scheduler.escalation_threshold = 1.5 -- Over 1.0 + + local is_valid, err = config.validate(invalid_config) + + assert.is_false(is_valid) + end) + end) +end) diff --git a/tests/spec/intent_spec.lua b/tests/spec/intent_spec.lua new file mode 100644 index 0000000..1d85c90 --- /dev/null +++ b/tests/spec/intent_spec.lua @@ -0,0 +1,286 @@ +---@diagnostic disable: undefined-global +-- Tests for lua/codetyper/agent/intent.lua + +describe("intent", function() + local intent = require("codetyper.agent.intent") + + describe("detect", function() + describe("complete intent", function() + it("should detect 'complete' keyword", function() + local result = intent.detect("complete this function") + assert.equals("complete", result.type) + assert.equals("replace", result.action) + end) + + it("should detect 'finish' keyword", function() + local result = intent.detect("finish implementing this method") + assert.equals("complete", result.type) + end) + + it("should detect 'implement' keyword", function() + local result = intent.detect("implement the sorting algorithm") + assert.equals("complete", result.type) + end) + + it("should detect 'todo' keyword", function() + local result = intent.detect("fix the TODO here") + assert.equals("complete", result.type) + end) + end) + + describe("refactor intent", function() + it("should detect 'refactor' keyword", function() + local result = intent.detect("refactor this messy code") + assert.equals("refactor", result.type) + assert.equals("replace", result.action) + end) + + it("should detect 'rewrite' keyword", function() + local result = intent.detect("rewrite using async/await") + assert.equals("refactor", result.type) + end) + + it("should detect 'simplify' keyword", function() + local result = intent.detect("simplify this logic") + assert.equals("refactor", result.type) + end) + + it("should detect 'cleanup' keyword", function() + local result = intent.detect("cleanup this code") + assert.equals("refactor", result.type) + end) + end) + + describe("fix intent", function() + it("should detect 'fix' keyword", function() + local result = intent.detect("fix the bug in this function") + assert.equals("fix", result.type) + assert.equals("replace", result.action) + end) + + it("should detect 'debug' keyword", function() + local result = intent.detect("debug this issue") + assert.equals("fix", result.type) + end) + + it("should detect 'bug' keyword", function() + local result = intent.detect("there's a bug here") + assert.equals("fix", result.type) + end) + + it("should detect 'error' keyword", function() + local result = intent.detect("getting an error with this code") + assert.equals("fix", result.type) + end) + end) + + describe("add intent", function() + it("should detect 'add' keyword", function() + local result = intent.detect("add input validation") + assert.equals("add", result.type) + assert.equals("insert", result.action) + end) + + it("should detect 'create' keyword", function() + local result = intent.detect("create a new helper function") + assert.equals("add", result.type) + end) + + it("should detect 'generate' keyword", function() + local result = intent.detect("generate a utility function") + assert.equals("add", result.type) + end) + end) + + describe("document intent", function() + it("should detect 'document' keyword", function() + local result = intent.detect("document this function") + assert.equals("document", result.type) + assert.equals("replace", result.action) + end) + + it("should detect 'jsdoc' keyword", function() + local result = intent.detect("add jsdoc comments") + assert.equals("document", result.type) + end) + + it("should detect 'comment' keyword", function() + local result = intent.detect("add comments to explain") + assert.equals("document", result.type) + end) + end) + + describe("test intent", function() + it("should detect 'test' keyword", function() + local result = intent.detect("write tests for this function") + assert.equals("test", result.type) + assert.equals("append", result.action) + end) + + it("should detect 'unit test' keyword", function() + local result = intent.detect("create unit tests") + assert.equals("test", result.type) + end) + end) + + describe("optimize intent", function() + it("should detect 'optimize' keyword", function() + local result = intent.detect("optimize this loop") + assert.equals("optimize", result.type) + assert.equals("replace", result.action) + end) + + it("should detect 'performance' keyword", function() + local result = intent.detect("improve performance of this function") + assert.equals("optimize", result.type) + end) + + it("should detect 'faster' keyword", function() + local result = intent.detect("make this faster") + assert.equals("optimize", result.type) + end) + end) + + describe("explain intent", function() + it("should detect 'explain' keyword", function() + local result = intent.detect("explain what this does") + assert.equals("explain", result.type) + assert.equals("none", result.action) + end) + + it("should detect 'what does' pattern", function() + local result = intent.detect("what does this function do") + assert.equals("explain", result.type) + end) + + it("should detect 'how does' pattern", function() + local result = intent.detect("how does this algorithm work") + assert.equals("explain", result.type) + end) + end) + + describe("default intent", function() + it("should default to 'add' for unknown prompts", function() + local result = intent.detect("make it blue") + assert.equals("add", result.type) + end) + end) + + describe("scope hints", function() + it("should detect 'this function' scope hint", function() + local result = intent.detect("refactor this function") + assert.equals("function", result.scope_hint) + end) + + it("should detect 'this class' scope hint", function() + local result = intent.detect("document this class") + assert.equals("class", result.scope_hint) + end) + + it("should detect 'this file' scope hint", function() + local result = intent.detect("test this file") + assert.equals("file", result.scope_hint) + end) + end) + + describe("confidence", function() + it("should have higher confidence with more keyword matches", function() + local result1 = intent.detect("fix") + local result2 = intent.detect("fix the bug error") + + assert.is_true(result2.confidence >= result1.confidence) + end) + + it("should cap confidence at 1.0", function() + local result = intent.detect("fix debug bug error issue solve") + assert.is_true(result.confidence <= 1.0) + end) + end) + end) + + describe("modifies_code", function() + it("should return true for replacement intents", function() + assert.is_true(intent.modifies_code({ action = "replace" })) + end) + + it("should return true for insertion intents", function() + assert.is_true(intent.modifies_code({ action = "insert" })) + end) + + it("should return false for explain intent", function() + assert.is_false(intent.modifies_code({ action = "none" })) + end) + end) + + describe("is_replacement", function() + it("should return true for replace action", function() + assert.is_true(intent.is_replacement({ action = "replace" })) + end) + + it("should return false for insert action", function() + assert.is_false(intent.is_replacement({ action = "insert" })) + end) + end) + + describe("is_insertion", function() + it("should return true for insert action", function() + assert.is_true(intent.is_insertion({ action = "insert" })) + end) + + it("should return true for append action", function() + assert.is_true(intent.is_insertion({ action = "append" })) + end) + + it("should return false for replace action", function() + assert.is_false(intent.is_insertion({ action = "replace" })) + end) + end) + + describe("get_prompt_modifier", function() + it("should return modifier for each intent type", function() + local types = { "complete", "refactor", "fix", "add", "document", "test", "optimize", "explain" } + + for _, type_name in ipairs(types) do + local modifier = intent.get_prompt_modifier({ type = type_name }) + assert.is_truthy(modifier) + assert.is_true(#modifier > 0) + end + end) + + it("should return add modifier for unknown type", function() + local modifier = intent.get_prompt_modifier({ type = "unknown" }) + assert.is_truthy(modifier) + end) + end) + + describe("format", function() + it("should format intent correctly", function() + local i = { + type = "refactor", + scope_hint = "function", + action = "replace", + confidence = 0.85, + } + + local formatted = intent.format(i) + + assert.is_true(formatted:match("refactor")) + assert.is_true(formatted:match("function")) + assert.is_true(formatted:match("replace")) + assert.is_true(formatted:match("0.85")) + end) + + it("should handle nil scope_hint", function() + local i = { + type = "add", + scope_hint = nil, + action = "insert", + confidence = 0.5, + } + + local formatted = intent.format(i) + + assert.is_true(formatted:match("auto")) + end) + end) +end) diff --git a/tests/spec/llm_spec.lua b/tests/spec/llm_spec.lua new file mode 100644 index 0000000..be9e90b --- /dev/null +++ b/tests/spec/llm_spec.lua @@ -0,0 +1,118 @@ +---@diagnostic disable: undefined-global +-- Tests for lua/codetyper/llm/init.lua + +describe("llm", function() + local llm = require("codetyper.llm") + + describe("extract_code", function() + it("should extract code from markdown code block", function() + local response = [[ +Here is the code: + +```lua +function hello() + print("Hello!") +end +``` + +That should work. +]] + local code = llm.extract_code(response) + + assert.is_true(code:match("function hello")) + assert.is_true(code:match('print%("Hello!"%)')) + assert.is_false(code:match("```")) + assert.is_false(code:match("Here is the code")) + end) + + it("should extract code from generic code block", function() + local response = [[ +``` +const x = 1; +const y = 2; +``` +]] + local code = llm.extract_code(response) + + assert.is_true(code:match("const x = 1")) + end) + + it("should handle multiple code blocks (return first)", function() + local response = [[ +```javascript +const first = true; +``` + +```javascript +const second = true; +``` +]] + local code = llm.extract_code(response) + + assert.is_true(code:match("first")) + end) + + it("should return original if no code blocks", function() + local response = "function test() return true end" + local code = llm.extract_code(response) + + assert.equals(response, code) + end) + + it("should handle empty code blocks", function() + local response = [[ +``` +``` +]] + local code = llm.extract_code(response) + + assert.equals("", vim.trim(code)) + end) + + it("should preserve indentation in extracted code", function() + local response = [[ +```lua +function test() + if true then + print("nested") + end +end +``` +]] + local code = llm.extract_code(response) + + assert.is_true(code:match(" if true then")) + assert.is_true(code:match(" print")) + end) + end) + + describe("get_client", function() + it("should return a client with generate function", function() + -- This test depends on config, but verifies interface + local client = llm.get_client() + + assert.is_table(client) + assert.is_function(client.generate) + end) + end) + + describe("build_system_prompt", function() + it("should include language context when provided", function() + local context = { + language = "typescript", + file_path = "/test/file.ts", + } + + local prompt = llm.build_system_prompt(context) + + assert.is_true(prompt:match("typescript") or prompt:match("TypeScript")) + end) + + it("should work with minimal context", function() + local prompt = llm.build_system_prompt({}) + + assert.is_string(prompt) + assert.is_true(#prompt > 0) + end) + end) +end) diff --git a/tests/spec/logs_spec.lua b/tests/spec/logs_spec.lua new file mode 100644 index 0000000..0e19279 --- /dev/null +++ b/tests/spec/logs_spec.lua @@ -0,0 +1,280 @@ +---@diagnostic disable: undefined-global +-- Tests for lua/codetyper/agent/logs.lua + +describe("logs", function() + local logs + + before_each(function() + -- Reset module state before each test + package.loaded["codetyper.agent.logs"] = nil + logs = require("codetyper.agent.logs") + end) + + describe("log", function() + it("should add entry to log", function() + logs.log("info", "test message") + + local entries = logs.get_entries() + assert.equals(1, #entries) + assert.equals("info", entries[1].level) + assert.equals("test message", entries[1].message) + end) + + it("should include timestamp", function() + logs.log("info", "test") + + local entries = logs.get_entries() + assert.is_truthy(entries[1].timestamp) + assert.is_true(entries[1].timestamp:match("%d+:%d+:%d+")) + end) + + it("should include optional data", function() + logs.log("info", "test", { key = "value" }) + + local entries = logs.get_entries() + assert.equals("value", entries[1].data.key) + end) + end) + + describe("info", function() + it("should log with info level", function() + logs.info("info message") + + local entries = logs.get_entries() + assert.equals("info", entries[1].level) + end) + end) + + describe("debug", function() + it("should log with debug level", function() + logs.debug("debug message") + + local entries = logs.get_entries() + assert.equals("debug", entries[1].level) + end) + end) + + describe("error", function() + it("should log with error level", function() + logs.error("error message") + + local entries = logs.get_entries() + assert.equals("error", entries[1].level) + assert.is_true(entries[1].message:match("ERROR")) + end) + end) + + describe("warning", function() + it("should log with warning level", function() + logs.warning("warning message") + + local entries = logs.get_entries() + assert.equals("warning", entries[1].level) + assert.is_true(entries[1].message:match("WARN")) + end) + end) + + describe("request", function() + it("should log API request", function() + logs.request("claude", "claude-sonnet-4", 1000) + + local entries = logs.get_entries() + assert.equals("request", entries[1].level) + assert.is_true(entries[1].message:match("CLAUDE")) + assert.is_true(entries[1].message:match("claude%-sonnet%-4")) + end) + + it("should store provider info", function() + logs.request("openai", "gpt-4") + + local provider, model = logs.get_provider_info() + assert.equals("openai", provider) + assert.equals("gpt-4", model) + end) + end) + + describe("response", function() + it("should log API response with token counts", function() + logs.response(500, 200, "end_turn") + + local entries = logs.get_entries() + assert.equals("response", entries[1].level) + assert.is_true(entries[1].message:match("500")) + assert.is_true(entries[1].message:match("200")) + end) + + it("should accumulate token totals", function() + logs.response(100, 50) + logs.response(200, 100) + + local prompt_tokens, response_tokens = logs.get_token_totals() + assert.equals(300, prompt_tokens) + assert.equals(150, response_tokens) + end) + end) + + describe("tool", function() + it("should log tool execution", function() + logs.tool("read_file", "start", "/path/to/file.lua") + + local entries = logs.get_entries() + assert.equals("tool", entries[1].level) + assert.is_true(entries[1].message:match("read_file")) + end) + + it("should show correct status icons", function() + logs.tool("write_file", "success", "file created") + local entries = logs.get_entries() + assert.is_true(entries[1].message:match("OK")) + + logs.tool("bash", "error", "command failed") + entries = logs.get_entries() + assert.is_true(entries[2].message:match("ERR")) + end) + end) + + describe("thinking", function() + it("should log thinking step", function() + logs.thinking("Analyzing code structure") + + local entries = logs.get_entries() + assert.equals("debug", entries[1].level) + assert.is_true(entries[1].message:match("> Analyzing")) + end) + end) + + describe("add", function() + it("should add entry using type field", function() + logs.add({ type = "info", message = "test message" }) + + local entries = logs.get_entries() + assert.equals(1, #entries) + assert.equals("info", entries[1].level) + end) + + it("should handle clear type", function() + logs.info("test") + logs.add({ type = "clear" }) + + local entries = logs.get_entries() + assert.equals(0, #entries) + end) + end) + + describe("listeners", function() + it("should notify listeners on new entries", function() + local received = {} + logs.add_listener(function(entry) + table.insert(received, entry) + end) + + logs.info("test message") + + assert.equals(1, #received) + assert.equals("info", received[1].level) + end) + + it("should support multiple listeners", function() + local count = 0 + logs.add_listener(function() count = count + 1 end) + logs.add_listener(function() count = count + 1 end) + + logs.info("test") + + assert.equals(2, count) + end) + + it("should remove listener by ID", function() + local count = 0 + local id = logs.add_listener(function() count = count + 1 end) + + logs.info("test1") + logs.remove_listener(id) + logs.info("test2") + + assert.equals(1, count) + end) + end) + + describe("clear", function() + it("should clear all entries", function() + logs.info("test1") + logs.info("test2") + logs.clear() + + assert.equals(0, #logs.get_entries()) + end) + + it("should reset token totals", function() + logs.response(100, 50) + logs.clear() + + local prompt, response = logs.get_token_totals() + assert.equals(0, prompt) + assert.equals(0, response) + end) + + it("should notify listeners of clear", function() + local cleared = false + logs.add_listener(function(entry) + if entry.level == "clear" then + cleared = true + end + end) + + logs.clear() + + assert.is_true(cleared) + end) + end) + + describe("format_entry", function() + it("should format entry for display", function() + logs.info("test message") + local entry = logs.get_entries()[1] + + local formatted = logs.format_entry(entry) + + assert.is_true(formatted:match("%[%d+:%d+:%d+%]")) + assert.is_true(formatted:match("i")) -- info prefix + assert.is_true(formatted:match("test message")) + end) + + it("should use correct level prefixes", function() + local prefixes = { + { level = "info", prefix = "i" }, + { level = "debug", prefix = "%." }, + { level = "request", prefix = ">" }, + { level = "response", prefix = "<" }, + { level = "tool", prefix = "T" }, + { level = "error", prefix = "!" }, + } + + for _, test in ipairs(prefixes) do + local entry = { + timestamp = "12:00:00", + level = test.level, + message = "test", + } + local formatted = logs.format_entry(entry) + assert.is_true(formatted:match(test.prefix), "Missing prefix for " .. test.level) + end + end) + end) + + describe("estimate_tokens", function() + it("should estimate tokens from text", function() + local text = "This is a test string for token estimation." + local tokens = logs.estimate_tokens(text) + + -- Rough estimate: ~4 chars per token + assert.is_true(tokens > 0) + assert.is_true(tokens < #text) -- Should be less than character count + end) + + it("should handle empty string", function() + local tokens = logs.estimate_tokens("") + assert.equals(0, tokens) + end) + end) +end) diff --git a/tests/spec/parser_spec.lua b/tests/spec/parser_spec.lua new file mode 100644 index 0000000..e60f591 --- /dev/null +++ b/tests/spec/parser_spec.lua @@ -0,0 +1,141 @@ +---@diagnostic disable: undefined-global +-- Tests for lua/codetyper/parser.lua + +describe("parser", function() + local parser = require("codetyper.parser") + + describe("find_prompts", function() + it("should find single-line prompt", function() + local content = "/@ create a function @/" + local prompts = parser.find_prompts(content, "/@", "@/") + + assert.equals(1, #prompts) + assert.equals(" create a function ", prompts[1].content) + assert.equals(1, prompts[1].start_line) + assert.equals(1, prompts[1].end_line) + end) + + it("should find multi-line prompt", function() + local content = [[ +/@ create a function +that validates email +addresses @/ +]] + local prompts = parser.find_prompts(content, "/@", "@/") + + assert.equals(1, #prompts) + assert.is_true(prompts[1].content:match("validates email")) + assert.equals(2, prompts[1].start_line) + assert.equals(4, prompts[1].end_line) + end) + + it("should find multiple prompts", function() + local content = [[ +/@ first prompt @/ +some code here +/@ second prompt @/ +more code +/@ third prompt +multiline @/ +]] + local prompts = parser.find_prompts(content, "/@", "@/") + + assert.equals(3, #prompts) + assert.equals(" first prompt ", prompts[1].content) + assert.equals(" second prompt ", prompts[2].content) + assert.is_true(prompts[3].content:match("third prompt")) + end) + + it("should return empty table when no prompts found", function() + local content = "just some regular code\nno prompts here" + local prompts = parser.find_prompts(content, "/@", "@/") + + assert.equals(0, #prompts) + end) + + it("should handle prompts with special characters", function() + local content = "/@ add (function) with [brackets] @/" + local prompts = parser.find_prompts(content, "/@", "@/") + + assert.equals(1, #prompts) + assert.is_true(prompts[1].content:match("function")) + assert.is_true(prompts[1].content:match("brackets")) + end) + + it("should handle empty prompt content", function() + local content = "/@ @/" + local prompts = parser.find_prompts(content, "/@", "@/") + + assert.equals(1, #prompts) + assert.equals(" ", prompts[1].content) + end) + + it("should handle custom tags", function() + local content = "" + local prompts = parser.find_prompts(content, "") + + assert.equals(1, #prompts) + assert.is_true(prompts[1].content:match("create button")) + end) + end) + + describe("detect_prompt_type", function() + it("should detect refactor type", function() + assert.equals("refactor", parser.detect_prompt_type("refactor this code")) + assert.equals("refactor", parser.detect_prompt_type("REFACTOR the function")) + end) + + it("should detect add type", function() + assert.equals("add", parser.detect_prompt_type("add a new function")) + assert.equals("add", parser.detect_prompt_type("create a component")) + assert.equals("add", parser.detect_prompt_type("implement sorting algorithm")) + end) + + it("should detect document type", function() + assert.equals("document", parser.detect_prompt_type("document this function")) + assert.equals("document", parser.detect_prompt_type("add jsdoc comments")) + assert.equals("document", parser.detect_prompt_type("comment the code")) + end) + + it("should detect explain type", function() + assert.equals("explain", parser.detect_prompt_type("explain this code")) + assert.equals("explain", parser.detect_prompt_type("what does this do")) + assert.equals("explain", parser.detect_prompt_type("how does this work")) + end) + + it("should return generic for unknown types", function() + assert.equals("generic", parser.detect_prompt_type("do something")) + assert.equals("generic", parser.detect_prompt_type("make it better")) + end) + end) + + describe("clean_prompt", function() + it("should trim whitespace", function() + assert.equals("hello", parser.clean_prompt(" hello ")) + assert.equals("hello", parser.clean_prompt("\n\nhello\n\n")) + end) + + it("should normalize multiple newlines", function() + local input = "line1\n\n\n\nline2" + local expected = "line1\n\nline2" + assert.equals(expected, parser.clean_prompt(input)) + end) + + it("should preserve single newlines", function() + local input = "line1\nline2\nline3" + assert.equals(input, parser.clean_prompt(input)) + end) + end) + + describe("has_closing_tag", function() + it("should return true when closing tag exists", function() + assert.is_true(parser.has_closing_tag("some text @/", "@/")) + assert.is_true(parser.has_closing_tag("@/", "@/")) + end) + + it("should return false when closing tag missing", function() + assert.is_false(parser.has_closing_tag("some text", "@/")) + assert.is_false(parser.has_closing_tag("", "@/")) + end) + end) +end) diff --git a/tests/spec/patch_spec.lua b/tests/spec/patch_spec.lua new file mode 100644 index 0000000..84fedd5 --- /dev/null +++ b/tests/spec/patch_spec.lua @@ -0,0 +1,305 @@ +---@diagnostic disable: undefined-global +-- Tests for lua/codetyper/agent/patch.lua + +describe("patch", function() + local patch + + before_each(function() + -- Reset module state before each test + package.loaded["codetyper.agent.patch"] = nil + patch = require("codetyper.agent.patch") + end) + + describe("generate_id", function() + it("should generate unique IDs", function() + local id1 = patch.generate_id() + local id2 = patch.generate_id() + + assert.is_not.equals(id1, id2) + assert.is_true(id1:match("^patch_")) + end) + end) + + describe("snapshot_buffer", function() + local test_buf + + before_each(function() + test_buf = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_lines(test_buf, 0, -1, false, { + "line 1", + "line 2", + "line 3", + "line 4", + "line 5", + }) + end) + + after_each(function() + if vim.api.nvim_buf_is_valid(test_buf) then + vim.api.nvim_buf_delete(test_buf, { force = true }) + end + end) + + it("should capture changedtick", function() + local snapshot = patch.snapshot_buffer(test_buf) + + assert.is_number(snapshot.changedtick) + end) + + it("should capture content hash", function() + local snapshot = patch.snapshot_buffer(test_buf) + + assert.is_string(snapshot.content_hash) + assert.is_true(#snapshot.content_hash > 0) + end) + + it("should snapshot specific range", function() + local snapshot = patch.snapshot_buffer(test_buf, { start_line = 2, end_line = 4 }) + + assert.equals(test_buf, snapshot.bufnr) + assert.is_truthy(snapshot.range) + assert.equals(2, snapshot.range.start_line) + assert.equals(4, snapshot.range.end_line) + end) + end) + + describe("is_snapshot_stale", function() + local test_buf + + before_each(function() + test_buf = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_lines(test_buf, 0, -1, false, { + "original content", + "line 2", + }) + end) + + after_each(function() + if vim.api.nvim_buf_is_valid(test_buf) then + vim.api.nvim_buf_delete(test_buf, { force = true }) + end + end) + + it("should return false for unchanged buffer", function() + local snapshot = patch.snapshot_buffer(test_buf) + + local is_stale, reason = patch.is_snapshot_stale(snapshot) + + assert.is_false(is_stale) + assert.is_nil(reason) + end) + + it("should return true when content changes", function() + local snapshot = patch.snapshot_buffer(test_buf) + + -- Modify buffer + vim.api.nvim_buf_set_lines(test_buf, 0, 1, false, { "modified content" }) + + local is_stale, reason = patch.is_snapshot_stale(snapshot) + + assert.is_true(is_stale) + assert.equals("content_changed", reason) + end) + + it("should return true for invalid buffer", function() + local snapshot = patch.snapshot_buffer(test_buf) + + -- Delete buffer + vim.api.nvim_buf_delete(test_buf, { force = true }) + + local is_stale, reason = patch.is_snapshot_stale(snapshot) + + assert.is_true(is_stale) + assert.equals("buffer_invalid", reason) + end) + end) + + describe("queue_patch", function() + it("should add patch to queue", function() + local p = { + event_id = "test_event", + target_bufnr = 1, + target_path = "/test/file.lua", + original_snapshot = { + bufnr = 1, + changedtick = 0, + content_hash = "abc123", + }, + generated_code = "function test() end", + confidence = 0.9, + } + + local queued = patch.queue_patch(p) + + assert.is_truthy(queued.id) + assert.equals("pending", queued.status) + + local pending = patch.get_pending() + assert.equals(1, #pending) + end) + + it("should set default status", function() + local p = { + event_id = "test", + generated_code = "code", + confidence = 0.8, + original_snapshot = { bufnr = 1, changedtick = 0, content_hash = "x" }, + } + + local queued = patch.queue_patch(p) + + assert.equals("pending", queued.status) + end) + end) + + describe("get", function() + it("should return patch by ID", function() + local p = patch.queue_patch({ + event_id = "test", + generated_code = "code", + confidence = 0.8, + original_snapshot = { bufnr = 1, changedtick = 0, content_hash = "x" }, + }) + + local found = patch.get(p.id) + + assert.is_not.nil(found) + assert.equals(p.id, found.id) + end) + + it("should return nil for unknown ID", function() + local found = patch.get("unknown_id") + assert.is_nil(found) + end) + end) + + describe("mark_applied", function() + it("should mark patch as applied", function() + local p = patch.queue_patch({ + event_id = "test", + generated_code = "code", + confidence = 0.8, + original_snapshot = { bufnr = 1, changedtick = 0, content_hash = "x" }, + }) + + local success = patch.mark_applied(p.id) + + assert.is_true(success) + assert.equals("applied", patch.get(p.id).status) + assert.is_truthy(patch.get(p.id).applied_at) + end) + end) + + describe("mark_stale", function() + it("should mark patch as stale with reason", function() + local p = patch.queue_patch({ + event_id = "test", + generated_code = "code", + confidence = 0.8, + original_snapshot = { bufnr = 1, changedtick = 0, content_hash = "x" }, + }) + + local success = patch.mark_stale(p.id, "content_changed") + + assert.is_true(success) + assert.equals("stale", patch.get(p.id).status) + assert.equals("content_changed", patch.get(p.id).stale_reason) + end) + end) + + describe("stats", function() + it("should return correct statistics", function() + local p1 = patch.queue_patch({ + event_id = "test1", + generated_code = "code1", + confidence = 0.8, + original_snapshot = { bufnr = 1, changedtick = 0, content_hash = "x" }, + }) + + patch.queue_patch({ + event_id = "test2", + generated_code = "code2", + confidence = 0.9, + original_snapshot = { bufnr = 1, changedtick = 0, content_hash = "y" }, + }) + + patch.mark_applied(p1.id) + + local stats = patch.stats() + + assert.equals(2, stats.total) + assert.equals(1, stats.pending) + assert.equals(1, stats.applied) + end) + end) + + describe("get_for_event", function() + it("should return patches for specific event", function() + patch.queue_patch({ + event_id = "event_a", + generated_code = "code1", + confidence = 0.8, + original_snapshot = { bufnr = 1, changedtick = 0, content_hash = "x" }, + }) + + patch.queue_patch({ + event_id = "event_b", + generated_code = "code2", + confidence = 0.9, + original_snapshot = { bufnr = 1, changedtick = 0, content_hash = "y" }, + }) + + patch.queue_patch({ + event_id = "event_a", + generated_code = "code3", + confidence = 0.7, + original_snapshot = { bufnr = 1, changedtick = 0, content_hash = "z" }, + }) + + local event_a_patches = patch.get_for_event("event_a") + + assert.equals(2, #event_a_patches) + end) + end) + + describe("clear", function() + it("should clear all patches", function() + patch.queue_patch({ + event_id = "test", + generated_code = "code", + confidence = 0.8, + original_snapshot = { bufnr = 1, changedtick = 0, content_hash = "x" }, + }) + + patch.clear() + + assert.equals(0, #patch.get_pending()) + assert.equals(0, patch.stats().total) + end) + end) + + describe("cancel_for_buffer", function() + it("should cancel patches for specific buffer", function() + patch.queue_patch({ + event_id = "test1", + target_bufnr = 1, + generated_code = "code1", + confidence = 0.8, + original_snapshot = { bufnr = 1, changedtick = 0, content_hash = "x" }, + }) + + patch.queue_patch({ + event_id = "test2", + target_bufnr = 2, + generated_code = "code2", + confidence = 0.9, + original_snapshot = { bufnr = 2, changedtick = 0, content_hash = "y" }, + }) + + local cancelled = patch.cancel_for_buffer(1) + + assert.equals(1, cancelled) + assert.equals(1, #patch.get_pending()) + end) + end) +end) diff --git a/tests/spec/queue_spec.lua b/tests/spec/queue_spec.lua new file mode 100644 index 0000000..8ebaee6 --- /dev/null +++ b/tests/spec/queue_spec.lua @@ -0,0 +1,332 @@ +---@diagnostic disable: undefined-global +-- Tests for lua/codetyper/agent/queue.lua + +describe("queue", function() + local queue + + before_each(function() + -- Reset module state before each test + package.loaded["codetyper.agent.queue"] = nil + queue = require("codetyper.agent.queue") + end) + + describe("generate_id", function() + it("should generate unique IDs", function() + local id1 = queue.generate_id() + local id2 = queue.generate_id() + + assert.is_not.equals(id1, id2) + assert.is_true(id1:match("^evt_")) + assert.is_true(id2:match("^evt_")) + end) + end) + + describe("hash_content", function() + it("should generate consistent hashes", function() + local content = "test content" + local hash1 = queue.hash_content(content) + local hash2 = queue.hash_content(content) + + assert.equals(hash1, hash2) + end) + + it("should generate different hashes for different content", function() + local hash1 = queue.hash_content("content A") + local hash2 = queue.hash_content("content B") + + assert.is_not.equals(hash1, hash2) + end) + end) + + describe("enqueue", function() + it("should add event to queue", function() + local event = { + bufnr = 1, + prompt_content = "test prompt", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + } + + local enqueued = queue.enqueue(event) + + assert.is_not.nil(enqueued.id) + assert.equals("pending", enqueued.status) + assert.equals(1, queue.size()) + end) + + it("should set default priority to 2", function() + local event = { + bufnr = 1, + prompt_content = "test prompt", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + } + + local enqueued = queue.enqueue(event) + + assert.equals(2, enqueued.priority) + end) + + it("should maintain priority order", function() + queue.enqueue({ + bufnr = 1, + prompt_content = "low priority", + target_path = "/test/file.lua", + priority = 3, + range = { start_line = 1, end_line = 1 }, + }) + + queue.enqueue({ + bufnr = 1, + prompt_content = "high priority", + target_path = "/test/file.lua", + priority = 1, + range = { start_line = 1, end_line = 1 }, + }) + + local first = queue.dequeue() + assert.equals("high priority", first.prompt_content) + end) + + it("should generate content hash automatically", function() + local event = { + bufnr = 1, + prompt_content = "test prompt", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + } + + local enqueued = queue.enqueue(event) + + assert.is_not.nil(enqueued.content_hash) + end) + end) + + describe("dequeue", function() + it("should return nil when queue is empty", function() + local event = queue.dequeue() + assert.is_nil(event) + end) + + it("should return and mark event as processing", function() + queue.enqueue({ + bufnr = 1, + prompt_content = "test", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + }) + + local event = queue.dequeue() + + assert.is_not.nil(event) + assert.equals("processing", event.status) + end) + + it("should skip non-pending events", function() + local evt1 = queue.enqueue({ + bufnr = 1, + prompt_content = "first", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + }) + + queue.enqueue({ + bufnr = 1, + prompt_content = "second", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + }) + + -- Mark first as completed + queue.complete(evt1.id) + + local event = queue.dequeue() + assert.equals("second", event.prompt_content) + end) + end) + + describe("peek", function() + it("should return next pending without removing", function() + queue.enqueue({ + bufnr = 1, + prompt_content = "test", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + }) + + local event1 = queue.peek() + local event2 = queue.peek() + + assert.is_not.nil(event1) + assert.equals(event1.id, event2.id) + assert.equals("pending", event1.status) + end) + end) + + describe("get", function() + it("should return event by ID", function() + local enqueued = queue.enqueue({ + bufnr = 1, + prompt_content = "test", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + }) + + local event = queue.get(enqueued.id) + + assert.is_not.nil(event) + assert.equals(enqueued.id, event.id) + end) + + it("should return nil for unknown ID", function() + local event = queue.get("unknown_id") + assert.is_nil(event) + end) + end) + + describe("update_status", function() + it("should update event status", function() + local enqueued = queue.enqueue({ + bufnr = 1, + prompt_content = "test", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + }) + + local success = queue.update_status(enqueued.id, "completed") + + assert.is_true(success) + assert.equals("completed", queue.get(enqueued.id).status) + end) + + it("should return false for unknown ID", function() + local success = queue.update_status("unknown_id", "completed") + assert.is_false(success) + end) + + it("should merge extra fields", function() + local enqueued = queue.enqueue({ + bufnr = 1, + prompt_content = "test", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + }) + + queue.update_status(enqueued.id, "completed", { error = "test error" }) + + local event = queue.get(enqueued.id) + assert.equals("test error", event.error) + end) + end) + + describe("cancel_for_buffer", function() + it("should cancel all pending events for buffer", function() + queue.enqueue({ + bufnr = 1, + prompt_content = "buffer 1 - first", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + }) + + queue.enqueue({ + bufnr = 1, + prompt_content = "buffer 1 - second", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + }) + + queue.enqueue({ + bufnr = 2, + prompt_content = "buffer 2", + target_path = "/test/file2.lua", + range = { start_line = 1, end_line = 1 }, + }) + + local cancelled = queue.cancel_for_buffer(1) + + assert.equals(2, cancelled) + assert.equals(1, queue.pending_count()) + end) + end) + + describe("stats", function() + it("should return correct statistics", function() + queue.enqueue({ + bufnr = 1, + prompt_content = "pending", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + }) + + local evt = queue.enqueue({ + bufnr = 1, + prompt_content = "to complete", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + }) + queue.complete(evt.id) + + local stats = queue.stats() + + assert.equals(2, stats.total) + assert.equals(1, stats.pending) + assert.equals(1, stats.completed) + end) + end) + + describe("clear", function() + it("should clear all events", function() + queue.enqueue({ + bufnr = 1, + prompt_content = "test", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + }) + + queue.clear() + + assert.equals(0, queue.size()) + end) + + it("should clear only specified status", function() + local evt = queue.enqueue({ + bufnr = 1, + prompt_content = "to complete", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + }) + queue.complete(evt.id) + + queue.enqueue({ + bufnr = 1, + prompt_content = "pending", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + }) + + queue.clear("completed") + + assert.equals(1, queue.size()) + assert.equals(1, queue.pending_count()) + end) + end) + + describe("listeners", function() + it("should notify listeners on enqueue", function() + local notifications = {} + queue.add_listener(function(event_type, event, size) + table.insert(notifications, { type = event_type, event = event, size = size }) + end) + + queue.enqueue({ + bufnr = 1, + prompt_content = "test", + target_path = "/test/file.lua", + range = { start_line = 1, end_line = 1 }, + }) + + assert.equals(1, #notifications) + assert.equals("enqueue", notifications[1].type) + end) + end) +end) diff --git a/tests/spec/utils_spec.lua b/tests/spec/utils_spec.lua new file mode 100644 index 0000000..b13f391 --- /dev/null +++ b/tests/spec/utils_spec.lua @@ -0,0 +1,139 @@ +---@diagnostic disable: undefined-global +-- Tests for lua/codetyper/utils.lua + +describe("utils", function() + local utils = require("codetyper.utils") + + describe("is_coder_file", function() + it("should return true for coder files", function() + assert.is_true(utils.is_coder_file("index.coder.ts")) + assert.is_true(utils.is_coder_file("main.coder.lua")) + assert.is_true(utils.is_coder_file("/path/to/file.coder.py")) + end) + + it("should return false for regular files", function() + assert.is_false(utils.is_coder_file("index.ts")) + assert.is_false(utils.is_coder_file("main.lua")) + assert.is_false(utils.is_coder_file("coder.ts")) + end) + end) + + describe("get_target_path", function() + it("should convert coder path to target path", function() + assert.equals("index.ts", utils.get_target_path("index.coder.ts")) + assert.equals("main.lua", utils.get_target_path("main.coder.lua")) + assert.equals("/path/to/file.py", utils.get_target_path("/path/to/file.coder.py")) + end) + end) + + describe("get_coder_path", function() + it("should convert target path to coder path", function() + assert.equals("index.coder.ts", utils.get_coder_path("index.ts")) + assert.equals("main.coder.lua", utils.get_coder_path("main.lua")) + end) + + it("should preserve directory path", function() + local result = utils.get_coder_path("/path/to/file.py") + assert.is_truthy(result:match("/path/to/")) + assert.is_truthy(result:match("file%.coder%.py")) + end) + end) + + describe("escape_pattern", function() + it("should escape special pattern characters", function() + -- Note: @ is NOT a special Lua pattern character + -- Special chars are: ( ) . % + - * ? [ ] ^ $ + assert.equals("/@", utils.escape_pattern("/@")) + assert.equals("@/", utils.escape_pattern("@/")) + assert.equals("hello%.world", utils.escape_pattern("hello.world")) + assert.equals("test%+pattern", utils.escape_pattern("test+pattern")) + end) + + it("should handle multiple special characters", function() + local input = "(test)[pattern]" + local escaped = utils.escape_pattern(input) + -- Use string.find with plain=true to avoid pattern interpretation + assert.is_truthy(string.find(escaped, "%(", 1, true)) + assert.is_truthy(string.find(escaped, "%)", 1, true)) + assert.is_truthy(string.find(escaped, "%[", 1, true)) + assert.is_truthy(string.find(escaped, "%]", 1, true)) + end) + end) + + describe("file operations", function() + local test_dir + local test_file + + before_each(function() + test_dir = vim.fn.tempname() + utils.ensure_dir(test_dir) + test_file = test_dir .. "/test.txt" + end) + + after_each(function() + vim.fn.delete(test_dir, "rf") + end) + + describe("ensure_dir", function() + it("should create directory", function() + local new_dir = test_dir .. "/subdir" + local result = utils.ensure_dir(new_dir) + + assert.is_true(result) + assert.equals(1, vim.fn.isdirectory(new_dir)) + end) + + it("should return true for existing directory", function() + local result = utils.ensure_dir(test_dir) + assert.is_true(result) + end) + end) + + describe("write_file", function() + it("should write content to file", function() + local result = utils.write_file(test_file, "test content") + + assert.is_true(result) + assert.is_true(utils.file_exists(test_file)) + end) + end) + + describe("read_file", function() + it("should read file content", function() + utils.write_file(test_file, "test content") + + local content = utils.read_file(test_file) + + assert.equals("test content", content) + end) + + it("should return nil for non-existent file", function() + local content = utils.read_file("/non/existent/file.txt") + assert.is_nil(content) + end) + end) + + describe("file_exists", function() + it("should return true for existing file", function() + utils.write_file(test_file, "content") + assert.is_true(utils.file_exists(test_file)) + end) + + it("should return false for non-existent file", function() + assert.is_false(utils.file_exists("/non/existent/file.txt")) + end) + end) + end) + + describe("get_filetype", function() + it("should return filetype for buffer", function() + local buf = vim.api.nvim_create_buf(false, true) + vim.bo[buf].filetype = "lua" + + local ft = utils.get_filetype(buf) + + assert.equals("lua", ft) + vim.api.nvim_buf_delete(buf, { force = true }) + end) + end) +end)