Adding more features
This commit is contained in:
427
tests/spec/agent_tools_spec.lua
Normal file
427
tests/spec/agent_tools_spec.lua
Normal file
@@ -0,0 +1,427 @@
|
||||
--- Tests for agent tools system
|
||||
|
||||
describe("codetyper.agent.tools", function()
|
||||
local tools
|
||||
|
||||
before_each(function()
|
||||
tools = require("codetyper.agent.tools")
|
||||
-- Clear any existing registrations
|
||||
for name, _ in pairs(tools.get_all()) do
|
||||
tools.unregister(name)
|
||||
end
|
||||
end)
|
||||
|
||||
describe("tool registration", function()
|
||||
it("should register a tool", function()
|
||||
local test_tool = {
|
||||
name = "test_tool",
|
||||
description = "A test tool",
|
||||
params = {
|
||||
{ name = "input", type = "string", description = "Test input" },
|
||||
},
|
||||
func = function(input, opts)
|
||||
return "result", nil
|
||||
end,
|
||||
}
|
||||
|
||||
tools.register(test_tool)
|
||||
local retrieved = tools.get("test_tool")
|
||||
|
||||
assert.is_not_nil(retrieved)
|
||||
assert.equals("test_tool", retrieved.name)
|
||||
end)
|
||||
|
||||
it("should unregister a tool", function()
|
||||
local test_tool = {
|
||||
name = "temp_tool",
|
||||
description = "Temporary",
|
||||
func = function() end,
|
||||
}
|
||||
|
||||
tools.register(test_tool)
|
||||
assert.is_not_nil(tools.get("temp_tool"))
|
||||
|
||||
tools.unregister("temp_tool")
|
||||
assert.is_nil(tools.get("temp_tool"))
|
||||
end)
|
||||
|
||||
it("should list all tools", function()
|
||||
tools.register({ name = "tool1", func = function() end })
|
||||
tools.register({ name = "tool2", func = function() end })
|
||||
tools.register({ name = "tool3", func = function() end })
|
||||
|
||||
local list = tools.list()
|
||||
assert.equals(3, #list)
|
||||
end)
|
||||
|
||||
it("should filter tools with predicate", function()
|
||||
tools.register({ name = "safe_tool", requires_confirmation = false, func = function() end })
|
||||
tools.register({ name = "dangerous_tool", requires_confirmation = true, func = function() end })
|
||||
|
||||
local safe_list = tools.list(function(t)
|
||||
return not t.requires_confirmation
|
||||
end)
|
||||
|
||||
assert.equals(1, #safe_list)
|
||||
assert.equals("safe_tool", safe_list[1].name)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("tool execution", function()
|
||||
it("should execute a tool and return result", function()
|
||||
tools.register({
|
||||
name = "adder",
|
||||
params = {
|
||||
{ name = "a", type = "number" },
|
||||
{ name = "b", type = "number" },
|
||||
},
|
||||
func = function(input, opts)
|
||||
return input.a + input.b, nil
|
||||
end,
|
||||
})
|
||||
|
||||
local result, err = tools.execute("adder", { a = 5, b = 3 }, {})
|
||||
|
||||
assert.is_nil(err)
|
||||
assert.equals(8, result)
|
||||
end)
|
||||
|
||||
it("should return error for unknown tool", function()
|
||||
local result, err = tools.execute("nonexistent", {}, {})
|
||||
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("Unknown tool"))
|
||||
end)
|
||||
|
||||
it("should track execution history", function()
|
||||
tools.clear_history()
|
||||
tools.register({
|
||||
name = "tracked_tool",
|
||||
func = function()
|
||||
return "done", nil
|
||||
end,
|
||||
})
|
||||
|
||||
tools.execute("tracked_tool", {}, {})
|
||||
tools.execute("tracked_tool", {}, {})
|
||||
|
||||
local history = tools.get_history()
|
||||
assert.equals(2, #history)
|
||||
assert.equals("tracked_tool", history[1].tool)
|
||||
assert.equals("completed", history[1].status)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("tool schemas", function()
|
||||
it("should generate JSON schema for tools", function()
|
||||
tools.register({
|
||||
name = "schema_test",
|
||||
description = "Test schema generation",
|
||||
params = {
|
||||
{ name = "required_param", type = "string", description = "A required param" },
|
||||
{ name = "optional_param", type = "number", description = "Optional", optional = true },
|
||||
},
|
||||
returns = {
|
||||
{ name = "result", type = "string" },
|
||||
},
|
||||
to_schema = require("codetyper.agent.tools.base").to_schema,
|
||||
func = function() end,
|
||||
})
|
||||
|
||||
local schemas = tools.get_schemas()
|
||||
assert.equals(1, #schemas)
|
||||
|
||||
local schema = schemas[1]
|
||||
assert.equals("function", schema.type)
|
||||
assert.equals("schema_test", schema.function_def.name)
|
||||
assert.is_not_nil(schema.function_def.parameters.properties.required_param)
|
||||
assert.is_not_nil(schema.function_def.parameters.properties.optional_param)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("process_tool_call", function()
|
||||
it("should process tool call with name and input", function()
|
||||
tools.register({
|
||||
name = "processor_test",
|
||||
func = function(input, opts)
|
||||
return "processed: " .. input.value, nil
|
||||
end,
|
||||
})
|
||||
|
||||
local result, err = tools.process_tool_call({
|
||||
name = "processor_test",
|
||||
input = { value = "test" },
|
||||
}, {})
|
||||
|
||||
assert.is_nil(err)
|
||||
assert.equals("processed: test", result)
|
||||
end)
|
||||
|
||||
it("should parse JSON string arguments", function()
|
||||
tools.register({
|
||||
name = "json_parser_test",
|
||||
func = function(input, opts)
|
||||
return input.key, nil
|
||||
end,
|
||||
})
|
||||
|
||||
local result, err = tools.process_tool_call({
|
||||
name = "json_parser_test",
|
||||
arguments = '{"key": "value"}',
|
||||
}, {})
|
||||
|
||||
assert.is_nil(err)
|
||||
assert.equals("value", result)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("codetyper.agent.tools.base", function()
|
||||
local base
|
||||
|
||||
before_each(function()
|
||||
base = require("codetyper.agent.tools.base")
|
||||
end)
|
||||
|
||||
describe("validate_input", function()
|
||||
it("should validate required parameters", function()
|
||||
local tool = setmetatable({
|
||||
params = {
|
||||
{ name = "required", type = "string" },
|
||||
{ name = "optional", type = "string", optional = true },
|
||||
},
|
||||
}, base)
|
||||
|
||||
local valid, err = tool:validate_input({ required = "value" })
|
||||
assert.is_true(valid)
|
||||
assert.is_nil(err)
|
||||
end)
|
||||
|
||||
it("should fail on missing required parameter", function()
|
||||
local tool = setmetatable({
|
||||
params = {
|
||||
{ name = "required", type = "string" },
|
||||
},
|
||||
}, base)
|
||||
|
||||
local valid, err = tool:validate_input({})
|
||||
assert.is_false(valid)
|
||||
assert.truthy(err:match("Missing required parameter"))
|
||||
end)
|
||||
|
||||
it("should validate parameter types", function()
|
||||
local tool = setmetatable({
|
||||
params = {
|
||||
{ name = "num", type = "number" },
|
||||
},
|
||||
}, base)
|
||||
|
||||
local valid1, _ = tool:validate_input({ num = 42 })
|
||||
assert.is_true(valid1)
|
||||
|
||||
local valid2, err2 = tool:validate_input({ num = "not a number" })
|
||||
assert.is_false(valid2)
|
||||
assert.truthy(err2:match("must be number"))
|
||||
end)
|
||||
|
||||
it("should validate integer type", function()
|
||||
local tool = setmetatable({
|
||||
params = {
|
||||
{ name = "int", type = "integer" },
|
||||
},
|
||||
}, base)
|
||||
|
||||
local valid1, _ = tool:validate_input({ int = 42 })
|
||||
assert.is_true(valid1)
|
||||
|
||||
local valid2, err2 = tool:validate_input({ int = 42.5 })
|
||||
assert.is_false(valid2)
|
||||
assert.truthy(err2:match("must be an integer"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("get_description", function()
|
||||
it("should return string description", function()
|
||||
local tool = setmetatable({
|
||||
description = "Static description",
|
||||
}, base)
|
||||
|
||||
assert.equals("Static description", tool:get_description())
|
||||
end)
|
||||
|
||||
it("should call function description", function()
|
||||
local tool = setmetatable({
|
||||
description = function()
|
||||
return "Dynamic description"
|
||||
end,
|
||||
}, base)
|
||||
|
||||
assert.equals("Dynamic description", tool:get_description())
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("to_schema", function()
|
||||
it("should generate valid schema", function()
|
||||
local tool = setmetatable({
|
||||
name = "test",
|
||||
description = "Test tool",
|
||||
params = {
|
||||
{ name = "input", type = "string", description = "Input value" },
|
||||
{ name = "count", type = "integer", description = "Count", optional = true },
|
||||
},
|
||||
}, base)
|
||||
|
||||
local schema = tool:to_schema()
|
||||
|
||||
assert.equals("function", schema.type)
|
||||
assert.equals("test", schema.function_def.name)
|
||||
assert.equals("Test tool", schema.function_def.description)
|
||||
assert.equals("object", schema.function_def.parameters.type)
|
||||
assert.is_not_nil(schema.function_def.parameters.properties.input)
|
||||
assert.is_not_nil(schema.function_def.parameters.properties.count)
|
||||
assert.same({ "input" }, schema.function_def.parameters.required)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("built-in tools", function()
|
||||
describe("view tool", function()
|
||||
local view
|
||||
|
||||
before_each(function()
|
||||
view = require("codetyper.agent.tools.view")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("view", view.name)
|
||||
assert.is_string(view.description)
|
||||
assert.is_table(view.params)
|
||||
assert.is_function(view.func)
|
||||
end)
|
||||
|
||||
it("should require path parameter", function()
|
||||
local result, err = view.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("path is required"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("grep tool", function()
|
||||
local grep
|
||||
|
||||
before_each(function()
|
||||
grep = require("codetyper.agent.tools.grep")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("grep", grep.name)
|
||||
assert.is_string(grep.description)
|
||||
assert.is_table(grep.params)
|
||||
assert.is_function(grep.func)
|
||||
end)
|
||||
|
||||
it("should require pattern parameter", function()
|
||||
local result, err = grep.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("pattern is required"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("glob tool", function()
|
||||
local glob
|
||||
|
||||
before_each(function()
|
||||
glob = require("codetyper.agent.tools.glob")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("glob", glob.name)
|
||||
assert.is_string(glob.description)
|
||||
assert.is_table(glob.params)
|
||||
assert.is_function(glob.func)
|
||||
end)
|
||||
|
||||
it("should require pattern parameter", function()
|
||||
local result, err = glob.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("pattern is required"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("edit tool", function()
|
||||
local edit
|
||||
|
||||
before_each(function()
|
||||
edit = require("codetyper.agent.tools.edit")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("edit", edit.name)
|
||||
assert.is_string(edit.description)
|
||||
assert.is_table(edit.params)
|
||||
assert.is_function(edit.func)
|
||||
end)
|
||||
|
||||
it("should require path parameter", function()
|
||||
local result, err = edit.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("path is required"))
|
||||
end)
|
||||
|
||||
it("should require old_string parameter", function()
|
||||
local result, err = edit.func({ path = "/tmp/test" }, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("old_string is required"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("write tool", function()
|
||||
local write
|
||||
|
||||
before_each(function()
|
||||
write = require("codetyper.agent.tools.write")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("write", write.name)
|
||||
assert.is_string(write.description)
|
||||
assert.is_table(write.params)
|
||||
assert.is_function(write.func)
|
||||
end)
|
||||
|
||||
it("should require path parameter", function()
|
||||
local result, err = write.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("path is required"))
|
||||
end)
|
||||
|
||||
it("should require content parameter", function()
|
||||
local result, err = write.func({ path = "/tmp/test" }, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("content is required"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("bash tool", function()
|
||||
local bash
|
||||
|
||||
before_each(function()
|
||||
bash = require("codetyper.agent.tools.bash")
|
||||
end)
|
||||
|
||||
it("should have required fields", function()
|
||||
assert.equals("bash", bash.name)
|
||||
assert.is_function(bash.func)
|
||||
end)
|
||||
|
||||
it("should require command parameter", function()
|
||||
local result, err = bash.func({}, {})
|
||||
assert.is_nil(result)
|
||||
assert.truthy(err:match("command is required"))
|
||||
end)
|
||||
|
||||
it("should require confirmation by default", function()
|
||||
assert.is_true(bash.requires_confirmation)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
312
tests/spec/agentic_spec.lua
Normal file
312
tests/spec/agentic_spec.lua
Normal file
@@ -0,0 +1,312 @@
|
||||
---@diagnostic disable: undefined-global
|
||||
-- Unit tests for the agentic system
|
||||
|
||||
describe("agentic module", function()
|
||||
local agentic
|
||||
|
||||
before_each(function()
|
||||
-- Reset and reload
|
||||
package.loaded["codetyper.agent.agentic"] = nil
|
||||
agentic = require("codetyper.agent.agentic")
|
||||
end)
|
||||
|
||||
it("should list built-in agents", function()
|
||||
local agents = agentic.list_agents()
|
||||
assert.is_table(agents)
|
||||
assert.is_true(#agents >= 3) -- coder, planner, explorer
|
||||
|
||||
local names = {}
|
||||
for _, agent in ipairs(agents) do
|
||||
names[agent.name] = true
|
||||
end
|
||||
|
||||
assert.is_true(names["coder"])
|
||||
assert.is_true(names["planner"])
|
||||
assert.is_true(names["explorer"])
|
||||
end)
|
||||
|
||||
it("should have description for each agent", function()
|
||||
local agents = agentic.list_agents()
|
||||
for _, agent in ipairs(agents) do
|
||||
assert.is_string(agent.description)
|
||||
assert.is_true(#agent.description > 0)
|
||||
end
|
||||
end)
|
||||
|
||||
it("should mark built-in agents as builtin", function()
|
||||
local agents = agentic.list_agents()
|
||||
local coder = nil
|
||||
for _, agent in ipairs(agents) do
|
||||
if agent.name == "coder" then
|
||||
coder = agent
|
||||
break
|
||||
end
|
||||
end
|
||||
assert.is_not_nil(coder)
|
||||
assert.is_true(coder.builtin)
|
||||
end)
|
||||
|
||||
it("should have init function to create directories", function()
|
||||
assert.is_function(agentic.init)
|
||||
assert.is_function(agentic.init_agents_dir)
|
||||
assert.is_function(agentic.init_rules_dir)
|
||||
end)
|
||||
|
||||
it("should have run function for executing tasks", function()
|
||||
assert.is_function(agentic.run)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("tools format conversion", function()
|
||||
local tools_module
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools"] = nil
|
||||
tools_module = require("codetyper.agent.tools")
|
||||
-- Load tools
|
||||
if tools_module.load_builtins then
|
||||
pcall(tools_module.load_builtins)
|
||||
end
|
||||
end)
|
||||
|
||||
it("should have to_openai_format function", function()
|
||||
assert.is_function(tools_module.to_openai_format)
|
||||
end)
|
||||
|
||||
it("should have to_claude_format function", function()
|
||||
assert.is_function(tools_module.to_claude_format)
|
||||
end)
|
||||
|
||||
it("should convert tools to OpenAI format", function()
|
||||
local openai_tools = tools_module.to_openai_format()
|
||||
assert.is_table(openai_tools)
|
||||
|
||||
-- If tools are loaded, check format
|
||||
if #openai_tools > 0 then
|
||||
local first_tool = openai_tools[1]
|
||||
assert.equals("function", first_tool.type)
|
||||
assert.is_table(first_tool["function"])
|
||||
assert.is_string(first_tool["function"].name)
|
||||
end
|
||||
end)
|
||||
|
||||
it("should convert tools to Claude format", function()
|
||||
local claude_tools = tools_module.to_claude_format()
|
||||
assert.is_table(claude_tools)
|
||||
|
||||
-- If tools are loaded, check format
|
||||
if #claude_tools > 0 then
|
||||
local first_tool = claude_tools[1]
|
||||
assert.is_string(first_tool.name)
|
||||
assert.is_table(first_tool.input_schema)
|
||||
end
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("edit tool", function()
|
||||
local edit_tool
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.edit"] = nil
|
||||
edit_tool = require("codetyper.agent.tools.edit")
|
||||
end)
|
||||
|
||||
it("should have name 'edit'", function()
|
||||
assert.equals("edit", edit_tool.name)
|
||||
end)
|
||||
|
||||
it("should have description mentioning matching strategies", function()
|
||||
local desc = edit_tool:get_description()
|
||||
assert.is_string(desc)
|
||||
-- Should mention the matching capabilities
|
||||
assert.is_true(desc:lower():match("match") ~= nil or desc:lower():match("replac") ~= nil)
|
||||
end)
|
||||
|
||||
it("should have params defined", function()
|
||||
assert.is_table(edit_tool.params)
|
||||
assert.is_true(#edit_tool.params >= 3) -- path, old_string, new_string
|
||||
end)
|
||||
|
||||
it("should require path parameter", function()
|
||||
local valid, err = edit_tool:validate_input({
|
||||
old_string = "test",
|
||||
new_string = "test2",
|
||||
})
|
||||
assert.is_false(valid)
|
||||
assert.is_string(err)
|
||||
end)
|
||||
|
||||
it("should require old_string parameter", function()
|
||||
local valid, err = edit_tool:validate_input({
|
||||
path = "/test",
|
||||
new_string = "test",
|
||||
})
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should require new_string parameter", function()
|
||||
local valid, err = edit_tool:validate_input({
|
||||
path = "/test",
|
||||
old_string = "test",
|
||||
})
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should accept empty old_string for new file creation", function()
|
||||
local valid, err = edit_tool:validate_input({
|
||||
path = "/test/new_file.lua",
|
||||
old_string = "",
|
||||
new_string = "new content",
|
||||
})
|
||||
assert.is_true(valid)
|
||||
assert.is_nil(err)
|
||||
end)
|
||||
|
||||
it("should have func implementation", function()
|
||||
assert.is_function(edit_tool.func)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("view tool", function()
|
||||
local view_tool
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.view"] = nil
|
||||
view_tool = require("codetyper.agent.tools.view")
|
||||
end)
|
||||
|
||||
it("should have name 'view'", function()
|
||||
assert.equals("view", view_tool.name)
|
||||
end)
|
||||
|
||||
it("should require path parameter", function()
|
||||
local valid, err = view_tool:validate_input({})
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should accept valid path", function()
|
||||
local valid, err = view_tool:validate_input({
|
||||
path = "/test/file.lua",
|
||||
})
|
||||
assert.is_true(valid)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("write tool", function()
|
||||
local write_tool
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.write"] = nil
|
||||
write_tool = require("codetyper.agent.tools.write")
|
||||
end)
|
||||
|
||||
it("should have name 'write'", function()
|
||||
assert.equals("write", write_tool.name)
|
||||
end)
|
||||
|
||||
it("should require path and content parameters", function()
|
||||
local valid, err = write_tool:validate_input({})
|
||||
assert.is_false(valid)
|
||||
|
||||
valid, err = write_tool:validate_input({ path = "/test" })
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should accept valid input", function()
|
||||
local valid, err = write_tool:validate_input({
|
||||
path = "/test/file.lua",
|
||||
content = "test content",
|
||||
})
|
||||
assert.is_true(valid)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("grep tool", function()
|
||||
local grep_tool
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.grep"] = nil
|
||||
grep_tool = require("codetyper.agent.tools.grep")
|
||||
end)
|
||||
|
||||
it("should have name 'grep'", function()
|
||||
assert.equals("grep", grep_tool.name)
|
||||
end)
|
||||
|
||||
it("should require pattern parameter", function()
|
||||
local valid, err = grep_tool:validate_input({})
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should accept valid pattern", function()
|
||||
local valid, err = grep_tool:validate_input({
|
||||
pattern = "function.*test",
|
||||
})
|
||||
assert.is_true(valid)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("glob tool", function()
|
||||
local glob_tool
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.glob"] = nil
|
||||
glob_tool = require("codetyper.agent.tools.glob")
|
||||
end)
|
||||
|
||||
it("should have name 'glob'", function()
|
||||
assert.equals("glob", glob_tool.name)
|
||||
end)
|
||||
|
||||
it("should require pattern parameter", function()
|
||||
local valid, err = glob_tool:validate_input({})
|
||||
assert.is_false(valid)
|
||||
end)
|
||||
|
||||
it("should accept valid pattern", function()
|
||||
local valid, err = glob_tool:validate_input({
|
||||
pattern = "**/*.lua",
|
||||
})
|
||||
assert.is_true(valid)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("base tool", function()
|
||||
local Base
|
||||
|
||||
before_each(function()
|
||||
package.loaded["codetyper.agent.tools.base"] = nil
|
||||
Base = require("codetyper.agent.tools.base")
|
||||
end)
|
||||
|
||||
it("should have validate_input method", function()
|
||||
assert.is_function(Base.validate_input)
|
||||
end)
|
||||
|
||||
it("should have to_schema method", function()
|
||||
assert.is_function(Base.to_schema)
|
||||
end)
|
||||
|
||||
it("should have get_description method", function()
|
||||
assert.is_function(Base.get_description)
|
||||
end)
|
||||
|
||||
it("should generate valid schema", function()
|
||||
local test_tool = setmetatable({
|
||||
name = "test",
|
||||
description = "A test tool",
|
||||
params = {
|
||||
{ name = "arg1", type = "string", description = "First arg" },
|
||||
{ name = "arg2", type = "number", description = "Second arg", optional = true },
|
||||
},
|
||||
}, Base)
|
||||
|
||||
local schema = test_tool:to_schema()
|
||||
assert.equals("function", schema.type)
|
||||
assert.equals("test", schema.function_def.name)
|
||||
assert.is_table(schema.function_def.parameters.properties)
|
||||
assert.is_table(schema.function_def.parameters.required)
|
||||
assert.is_true(vim.tbl_contains(schema.function_def.parameters.required, "arg1"))
|
||||
assert.is_false(vim.tbl_contains(schema.function_def.parameters.required, "arg2"))
|
||||
end)
|
||||
end)
|
||||
153
tests/spec/brain_learners_spec.lua
Normal file
153
tests/spec/brain_learners_spec.lua
Normal file
@@ -0,0 +1,153 @@
|
||||
--- Tests for brain/learners pattern detection and extraction
|
||||
describe("brain.learners", function()
|
||||
local pattern_learner
|
||||
|
||||
before_each(function()
|
||||
-- Clear module cache
|
||||
package.loaded["codetyper.brain.learners.pattern"] = nil
|
||||
package.loaded["codetyper.brain.types"] = nil
|
||||
|
||||
pattern_learner = require("codetyper.brain.learners.pattern")
|
||||
end)
|
||||
|
||||
describe("pattern learner detection", function()
|
||||
it("should detect code_completion events", function()
|
||||
local event = { type = "code_completion", data = {} }
|
||||
assert.is_true(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should detect file_indexed events", function()
|
||||
local event = { type = "file_indexed", data = {} }
|
||||
assert.is_true(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should detect code_analyzed events", function()
|
||||
local event = { type = "code_analyzed", data = {} }
|
||||
assert.is_true(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should detect pattern_detected events", function()
|
||||
local event = { type = "pattern_detected", data = {} }
|
||||
assert.is_true(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should NOT detect plain 'pattern' type events", function()
|
||||
-- This was the bug - 'pattern' type was not in the valid_types list
|
||||
local event = { type = "pattern", data = {} }
|
||||
assert.is_false(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should NOT detect unknown event types", function()
|
||||
local event = { type = "unknown_type", data = {} }
|
||||
assert.is_false(pattern_learner.detect(event))
|
||||
end)
|
||||
|
||||
it("should NOT detect nil events", function()
|
||||
assert.is_false(pattern_learner.detect(nil))
|
||||
end)
|
||||
|
||||
it("should NOT detect events without type", function()
|
||||
local event = { data = {} }
|
||||
assert.is_false(pattern_learner.detect(event))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("pattern learner extraction", function()
|
||||
it("should extract from pattern_detected events", function()
|
||||
local event = {
|
||||
type = "pattern_detected",
|
||||
file = "/path/to/file.lua",
|
||||
data = {
|
||||
name = "Test pattern",
|
||||
description = "Pattern description",
|
||||
language = "lua",
|
||||
symbols = { "func1", "func2" },
|
||||
},
|
||||
}
|
||||
|
||||
local extracted = pattern_learner.extract(event)
|
||||
|
||||
assert.is_not_nil(extracted)
|
||||
assert.equals("Test pattern", extracted.summary)
|
||||
assert.equals("Pattern description", extracted.detail)
|
||||
assert.equals("lua", extracted.lang)
|
||||
assert.equals("/path/to/file.lua", extracted.file)
|
||||
end)
|
||||
|
||||
it("should handle pattern_detected with minimal data", function()
|
||||
local event = {
|
||||
type = "pattern_detected",
|
||||
file = "/path/to/file.lua",
|
||||
data = {
|
||||
name = "Minimal pattern",
|
||||
},
|
||||
}
|
||||
|
||||
local extracted = pattern_learner.extract(event)
|
||||
|
||||
assert.is_not_nil(extracted)
|
||||
assert.equals("Minimal pattern", extracted.summary)
|
||||
assert.equals("Minimal pattern", extracted.detail)
|
||||
end)
|
||||
|
||||
it("should extract from code_completion events", function()
|
||||
local event = {
|
||||
type = "code_completion",
|
||||
file = "/path/to/file.lua",
|
||||
data = {
|
||||
intent = "add function",
|
||||
code = "function test() end",
|
||||
language = "lua",
|
||||
},
|
||||
}
|
||||
|
||||
local extracted = pattern_learner.extract(event)
|
||||
|
||||
assert.is_not_nil(extracted)
|
||||
assert.is_true(extracted.summary:find("Code pattern") ~= nil)
|
||||
assert.equals("function test() end", extracted.detail)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("should_learn validation", function()
|
||||
it("should accept valid patterns", function()
|
||||
local data = {
|
||||
summary = "Valid pattern summary",
|
||||
detail = "This is a detailed description of the pattern",
|
||||
}
|
||||
assert.is_true(pattern_learner.should_learn(data))
|
||||
end)
|
||||
|
||||
it("should reject patterns without summary", function()
|
||||
local data = {
|
||||
summary = "",
|
||||
detail = "Some detail",
|
||||
}
|
||||
assert.is_false(pattern_learner.should_learn(data))
|
||||
end)
|
||||
|
||||
it("should reject patterns with nil summary", function()
|
||||
local data = {
|
||||
summary = nil,
|
||||
detail = "Some detail",
|
||||
}
|
||||
assert.is_false(pattern_learner.should_learn(data))
|
||||
end)
|
||||
|
||||
it("should reject patterns with very short detail", function()
|
||||
local data = {
|
||||
summary = "Valid summary",
|
||||
detail = "short", -- Less than 10 chars
|
||||
}
|
||||
assert.is_false(pattern_learner.should_learn(data))
|
||||
end)
|
||||
|
||||
it("should reject whitespace-only summaries", function()
|
||||
local data = {
|
||||
summary = " ",
|
||||
detail = "Some valid detail here",
|
||||
}
|
||||
assert.is_false(pattern_learner.should_learn(data))
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
194
tests/spec/coder_context_spec.lua
Normal file
194
tests/spec/coder_context_spec.lua
Normal file
@@ -0,0 +1,194 @@
|
||||
--- Tests for coder file context injection
|
||||
describe("coder context injection", function()
|
||||
local test_dir
|
||||
local original_filereadable
|
||||
|
||||
before_each(function()
|
||||
test_dir = "/tmp/codetyper_coder_test_" .. os.time()
|
||||
vim.fn.mkdir(test_dir, "p")
|
||||
|
||||
-- Store original function
|
||||
original_filereadable = vim.fn.filereadable
|
||||
end)
|
||||
|
||||
after_each(function()
|
||||
vim.fn.delete(test_dir, "rf")
|
||||
vim.fn.filereadable = original_filereadable
|
||||
end)
|
||||
|
||||
describe("get_coder_companion_path logic", function()
|
||||
-- Test the path generation logic (simulating the function behavior)
|
||||
local function get_coder_companion_path(target_path, file_exists_check)
|
||||
if not target_path or target_path == "" then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Skip if target is already a coder file
|
||||
if target_path:match("%.coder%.") then
|
||||
return nil
|
||||
end
|
||||
|
||||
local dir = vim.fn.fnamemodify(target_path, ":h")
|
||||
local name = vim.fn.fnamemodify(target_path, ":t:r")
|
||||
local ext = vim.fn.fnamemodify(target_path, ":e")
|
||||
|
||||
local coder_path = dir .. "/" .. name .. ".coder." .. ext
|
||||
if file_exists_check(coder_path) then
|
||||
return coder_path
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
it("should generate correct coder path for source file", function()
|
||||
local target = "/path/to/file.ts"
|
||||
local expected = "/path/to/file.coder.ts"
|
||||
|
||||
local path = get_coder_companion_path(target, function() return true end)
|
||||
|
||||
assert.equals(expected, path)
|
||||
end)
|
||||
|
||||
it("should return nil for empty path", function()
|
||||
local path = get_coder_companion_path("", function() return true end)
|
||||
assert.is_nil(path)
|
||||
end)
|
||||
|
||||
it("should return nil for nil path", function()
|
||||
local path = get_coder_companion_path(nil, function() return true end)
|
||||
assert.is_nil(path)
|
||||
end)
|
||||
|
||||
it("should return nil for coder files (avoid recursion)", function()
|
||||
local target = "/path/to/file.coder.ts"
|
||||
local path = get_coder_companion_path(target, function() return true end)
|
||||
assert.is_nil(path)
|
||||
end)
|
||||
|
||||
it("should return nil if coder file doesn't exist", function()
|
||||
local target = "/path/to/file.ts"
|
||||
local path = get_coder_companion_path(target, function() return false end)
|
||||
assert.is_nil(path)
|
||||
end)
|
||||
|
||||
it("should handle files with multiple dots", function()
|
||||
local target = "/path/to/my.component.ts"
|
||||
local expected = "/path/to/my.component.coder.ts"
|
||||
|
||||
local path = get_coder_companion_path(target, function() return true end)
|
||||
|
||||
assert.equals(expected, path)
|
||||
end)
|
||||
|
||||
it("should handle different extensions", function()
|
||||
local test_cases = {
|
||||
{ target = "/path/file.lua", expected = "/path/file.coder.lua" },
|
||||
{ target = "/path/file.py", expected = "/path/file.coder.py" },
|
||||
{ target = "/path/file.js", expected = "/path/file.coder.js" },
|
||||
{ target = "/path/file.go", expected = "/path/file.coder.go" },
|
||||
}
|
||||
|
||||
for _, tc in ipairs(test_cases) do
|
||||
local path = get_coder_companion_path(tc.target, function() return true end)
|
||||
assert.equals(tc.expected, path, "Failed for: " .. tc.target)
|
||||
end
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("coder content filtering", function()
|
||||
-- Test the filtering logic that skips template-only content
|
||||
local function has_meaningful_content(lines)
|
||||
for _, line in ipairs(lines) do
|
||||
local trimmed = line:gsub("^%s*", "")
|
||||
if not trimmed:match("^[%-#/]+%s*Coder companion")
|
||||
and not trimmed:match("^[%-#/]+%s*Use /@ @/")
|
||||
and not trimmed:match("^[%-#/]+%s*Example:")
|
||||
and not trimmed:match("^<!%-%-")
|
||||
and trimmed ~= ""
|
||||
and not trimmed:match("^[%-#/]+%s*$") then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
it("should detect meaningful content", function()
|
||||
local lines = {
|
||||
"-- Coder companion for test.lua",
|
||||
"-- This file handles authentication",
|
||||
"/@",
|
||||
"Add login function",
|
||||
"@/",
|
||||
}
|
||||
assert.is_true(has_meaningful_content(lines))
|
||||
end)
|
||||
|
||||
it("should reject template-only content", function()
|
||||
-- Template lines are filtered by specific patterns
|
||||
-- Only header comments that match the template format are filtered
|
||||
local lines = {
|
||||
"-- Coder companion for test.lua",
|
||||
"-- Use /@ @/ tags to write pseudo-code prompts",
|
||||
"-- Example:",
|
||||
"--",
|
||||
"",
|
||||
}
|
||||
assert.is_false(has_meaningful_content(lines))
|
||||
end)
|
||||
|
||||
it("should detect pseudo-code content", function()
|
||||
local lines = {
|
||||
"-- Authentication module",
|
||||
"",
|
||||
"-- This module should:",
|
||||
"-- 1. Validate user credentials",
|
||||
"-- 2. Generate JWT tokens",
|
||||
"-- 3. Handle session management",
|
||||
}
|
||||
-- "-- Authentication module" doesn't match template patterns
|
||||
assert.is_true(has_meaningful_content(lines))
|
||||
end)
|
||||
|
||||
it("should handle JavaScript style comments", function()
|
||||
local lines = {
|
||||
"// Coder companion for test.ts",
|
||||
"// Business logic for user authentication",
|
||||
"",
|
||||
"// The auth flow should:",
|
||||
"// 1. Check OAuth token",
|
||||
"// 2. Validate permissions",
|
||||
}
|
||||
-- "// Business logic..." doesn't match template patterns
|
||||
assert.is_true(has_meaningful_content(lines))
|
||||
end)
|
||||
|
||||
it("should handle empty lines", function()
|
||||
local lines = {
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
}
|
||||
assert.is_false(has_meaningful_content(lines))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("context format", function()
|
||||
it("should format context with proper header", function()
|
||||
local function format_coder_context(content, ext)
|
||||
return string.format(
|
||||
"\n\n--- Business Context / Pseudo-code ---\n" ..
|
||||
"The following describes the intended behavior and design for this file:\n" ..
|
||||
"```%s\n%s\n```",
|
||||
ext,
|
||||
content
|
||||
)
|
||||
end
|
||||
|
||||
local formatted = format_coder_context("-- Auth logic here", "lua")
|
||||
|
||||
assert.is_true(formatted:find("Business Context") ~= nil)
|
||||
assert.is_true(formatted:find("```lua") ~= nil)
|
||||
assert.is_true(formatted:find("Auth logic here") ~= nil)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
161
tests/spec/coder_ignore_spec.lua
Normal file
161
tests/spec/coder_ignore_spec.lua
Normal file
@@ -0,0 +1,161 @@
|
||||
--- Tests for coder file ignore logic
|
||||
describe("coder file ignore logic", function()
|
||||
-- Directories to ignore
|
||||
local ignored_directories = {
|
||||
".git",
|
||||
".coder",
|
||||
".claude",
|
||||
".vscode",
|
||||
".idea",
|
||||
"node_modules",
|
||||
"vendor",
|
||||
"dist",
|
||||
"build",
|
||||
"target",
|
||||
"__pycache__",
|
||||
".cache",
|
||||
".npm",
|
||||
".yarn",
|
||||
"coverage",
|
||||
".next",
|
||||
".nuxt",
|
||||
".svelte-kit",
|
||||
"out",
|
||||
"bin",
|
||||
"obj",
|
||||
}
|
||||
|
||||
-- Files to ignore
|
||||
local ignored_files = {
|
||||
".gitignore",
|
||||
".gitattributes",
|
||||
"package-lock.json",
|
||||
"yarn.lock",
|
||||
".env",
|
||||
".eslintrc",
|
||||
"tsconfig.json",
|
||||
"README.md",
|
||||
"LICENSE",
|
||||
"Makefile",
|
||||
}
|
||||
|
||||
local function is_in_ignored_directory(filepath)
|
||||
for _, dir in ipairs(ignored_directories) do
|
||||
if filepath:match("/" .. dir .. "/") or filepath:match("/" .. dir .. "$") then
|
||||
return true
|
||||
end
|
||||
if filepath:match("^" .. dir .. "/") then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
local function should_ignore_for_coder(filepath)
|
||||
local filename = vim.fn.fnamemodify(filepath, ":t")
|
||||
|
||||
for _, ignored in ipairs(ignored_files) do
|
||||
if filename == ignored then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
if filename:match("^%.") then
|
||||
return true
|
||||
end
|
||||
|
||||
if is_in_ignored_directory(filepath) then
|
||||
return true
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
describe("ignored directories", function()
|
||||
it("should ignore files in node_modules", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/node_modules/lodash/index.js"))
|
||||
assert.is_true(should_ignore_for_coder("/project/node_modules/react/index.js"))
|
||||
end)
|
||||
|
||||
it("should ignore files in .git", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/.git/config"))
|
||||
assert.is_true(should_ignore_for_coder("/project/.git/hooks/pre-commit"))
|
||||
end)
|
||||
|
||||
it("should ignore files in .coder", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/.coder/brain/meta.json"))
|
||||
end)
|
||||
|
||||
it("should ignore files in vendor", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/vendor/autoload.php"))
|
||||
end)
|
||||
|
||||
it("should ignore files in dist/build", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/dist/bundle.js"))
|
||||
assert.is_true(should_ignore_for_coder("/project/build/output.js"))
|
||||
end)
|
||||
|
||||
it("should ignore files in __pycache__", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/__pycache__/module.cpython-39.pyc"))
|
||||
end)
|
||||
|
||||
it("should NOT ignore regular source files", function()
|
||||
assert.is_false(should_ignore_for_coder("/project/src/index.ts"))
|
||||
assert.is_false(should_ignore_for_coder("/project/lib/utils.lua"))
|
||||
assert.is_false(should_ignore_for_coder("/project/app/main.py"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("ignored files", function()
|
||||
it("should ignore .gitignore", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/.gitignore"))
|
||||
end)
|
||||
|
||||
it("should ignore lock files", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/package-lock.json"))
|
||||
assert.is_true(should_ignore_for_coder("/project/yarn.lock"))
|
||||
end)
|
||||
|
||||
it("should ignore config files", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/tsconfig.json"))
|
||||
assert.is_true(should_ignore_for_coder("/project/.eslintrc"))
|
||||
end)
|
||||
|
||||
it("should ignore .env files", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/.env"))
|
||||
end)
|
||||
|
||||
it("should ignore README and LICENSE", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/README.md"))
|
||||
assert.is_true(should_ignore_for_coder("/project/LICENSE"))
|
||||
end)
|
||||
|
||||
it("should ignore hidden/dot files", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/.hidden"))
|
||||
assert.is_true(should_ignore_for_coder("/project/.secret"))
|
||||
end)
|
||||
|
||||
it("should NOT ignore regular source files", function()
|
||||
assert.is_false(should_ignore_for_coder("/project/src/app.ts"))
|
||||
assert.is_false(should_ignore_for_coder("/project/components/Button.tsx"))
|
||||
assert.is_false(should_ignore_for_coder("/project/utils/helpers.js"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("edge cases", function()
|
||||
it("should handle nested node_modules", function()
|
||||
assert.is_true(should_ignore_for_coder("/project/packages/core/node_modules/dep/index.js"))
|
||||
end)
|
||||
|
||||
it("should handle files named like directories but not in them", function()
|
||||
-- A file named "node_modules.md" in root should be ignored (starts with .)
|
||||
-- But a file in a folder that contains "node" should NOT be ignored
|
||||
assert.is_false(should_ignore_for_coder("/project/src/node_utils.ts"))
|
||||
end)
|
||||
|
||||
it("should handle relative paths", function()
|
||||
assert.is_true(should_ignore_for_coder("node_modules/lodash/index.js"))
|
||||
assert.is_false(should_ignore_for_coder("src/index.ts"))
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
371
tests/spec/inject_spec.lua
Normal file
371
tests/spec/inject_spec.lua
Normal file
@@ -0,0 +1,371 @@
|
||||
--- Tests for smart code injection with import handling
|
||||
|
||||
describe("codetyper.agent.inject", function()
|
||||
local inject
|
||||
|
||||
before_each(function()
|
||||
inject = require("codetyper.agent.inject")
|
||||
end)
|
||||
|
||||
describe("parse_code", function()
|
||||
describe("JavaScript/TypeScript", function()
|
||||
it("should detect ES6 named imports", function()
|
||||
local code = [[import { useState, useEffect } from 'react';
|
||||
import { Button } from './components';
|
||||
|
||||
function App() {
|
||||
return <div>Hello</div>;
|
||||
}]]
|
||||
local result = inject.parse_code(code, "typescript")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("useState"))
|
||||
assert.truthy(result.imports[2]:match("Button"))
|
||||
assert.truthy(#result.body > 0)
|
||||
end)
|
||||
|
||||
it("should detect ES6 default imports", function()
|
||||
local code = [[import React from 'react';
|
||||
import axios from 'axios';
|
||||
|
||||
const api = axios.create();]]
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("React"))
|
||||
assert.truthy(result.imports[2]:match("axios"))
|
||||
end)
|
||||
|
||||
it("should detect require imports", function()
|
||||
local code = [[const fs = require('fs');
|
||||
const path = require('path');
|
||||
|
||||
module.exports = { fs, path };]]
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("fs"))
|
||||
assert.truthy(result.imports[2]:match("path"))
|
||||
end)
|
||||
|
||||
it("should detect multi-line imports", function()
|
||||
local code = [[import {
|
||||
useState,
|
||||
useEffect,
|
||||
useCallback
|
||||
} from 'react';
|
||||
|
||||
function Component() {}]]
|
||||
local result = inject.parse_code(code, "typescript")
|
||||
|
||||
assert.equals(1, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("useState"))
|
||||
assert.truthy(result.imports[1]:match("useCallback"))
|
||||
end)
|
||||
|
||||
it("should detect namespace imports", function()
|
||||
local code = [[import * as React from 'react';
|
||||
|
||||
export default React;]]
|
||||
local result = inject.parse_code(code, "tsx")
|
||||
|
||||
assert.equals(1, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("%* as React"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("Python", function()
|
||||
it("should detect simple imports", function()
|
||||
local code = [[import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
def main():
|
||||
pass]]
|
||||
local result = inject.parse_code(code, "python")
|
||||
|
||||
assert.equals(3, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("import os"))
|
||||
assert.truthy(result.imports[2]:match("import sys"))
|
||||
assert.truthy(result.imports[3]:match("import json"))
|
||||
end)
|
||||
|
||||
it("should detect from imports", function()
|
||||
local code = [[from typing import List, Dict
|
||||
from pathlib import Path
|
||||
|
||||
def process(items: List[str]) -> None:
|
||||
pass]]
|
||||
local result = inject.parse_code(code, "py")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("from typing"))
|
||||
assert.truthy(result.imports[2]:match("from pathlib"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("Lua", function()
|
||||
it("should detect require statements", function()
|
||||
local code = [[local M = {}
|
||||
local utils = require("codetyper.utils")
|
||||
local config = require('codetyper.config')
|
||||
|
||||
function M.setup()
|
||||
end
|
||||
|
||||
return M]]
|
||||
local result = inject.parse_code(code, "lua")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("utils"))
|
||||
assert.truthy(result.imports[2]:match("config"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("Go", function()
|
||||
it("should detect single imports", function()
|
||||
local code = [[package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
func main() {
|
||||
fmt.Println("Hello")
|
||||
}]]
|
||||
local result = inject.parse_code(code, "go")
|
||||
|
||||
assert.equals(1, #result.imports)
|
||||
assert.truthy(result.imports[1]:match('import "fmt"'))
|
||||
end)
|
||||
|
||||
it("should detect grouped imports", function()
|
||||
local code = [[package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func main() {}]]
|
||||
local result = inject.parse_code(code, "go")
|
||||
|
||||
assert.equals(1, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("fmt"))
|
||||
assert.truthy(result.imports[1]:match("os"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("Rust", function()
|
||||
it("should detect use statements", function()
|
||||
local code = [[use std::io;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn main() {
|
||||
let map = HashMap::new();
|
||||
}]]
|
||||
local result = inject.parse_code(code, "rs")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("std::io"))
|
||||
assert.truthy(result.imports[2]:match("HashMap"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("C/C++", function()
|
||||
it("should detect include statements", function()
|
||||
local code = [[#include <stdio.h>
|
||||
#include "myheader.h"
|
||||
|
||||
int main() {
|
||||
return 0;
|
||||
}]]
|
||||
local result = inject.parse_code(code, "c")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.truthy(result.imports[1]:match("stdio"))
|
||||
assert.truthy(result.imports[2]:match("myheader"))
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("merge_imports", function()
|
||||
it("should merge without duplicates", function()
|
||||
local existing = {
|
||||
"import { useState } from 'react';",
|
||||
"import { Button } from './components';",
|
||||
}
|
||||
local new_imports = {
|
||||
"import { useEffect } from 'react';",
|
||||
"import { useState } from 'react';", -- duplicate
|
||||
"import { Card } from './components';",
|
||||
}
|
||||
|
||||
local merged = inject.merge_imports(existing, new_imports)
|
||||
|
||||
assert.equals(4, #merged) -- Should not have duplicate useState
|
||||
end)
|
||||
|
||||
it("should handle empty existing imports", function()
|
||||
local existing = {}
|
||||
local new_imports = {
|
||||
"import os",
|
||||
"import sys",
|
||||
}
|
||||
|
||||
local merged = inject.merge_imports(existing, new_imports)
|
||||
|
||||
assert.equals(2, #merged)
|
||||
end)
|
||||
|
||||
it("should handle empty new imports", function()
|
||||
local existing = {
|
||||
"import os",
|
||||
"import sys",
|
||||
}
|
||||
local new_imports = {}
|
||||
|
||||
local merged = inject.merge_imports(existing, new_imports)
|
||||
|
||||
assert.equals(2, #merged)
|
||||
end)
|
||||
|
||||
it("should handle whitespace variations in duplicates", function()
|
||||
local existing = {
|
||||
"import { useState } from 'react';",
|
||||
}
|
||||
local new_imports = {
|
||||
"import {useState} from 'react';", -- Same but different spacing
|
||||
}
|
||||
|
||||
local merged = inject.merge_imports(existing, new_imports)
|
||||
|
||||
assert.equals(1, #merged) -- Should detect as duplicate
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("sort_imports", function()
|
||||
it("should group imports by type for JavaScript", function()
|
||||
local imports = {
|
||||
"import React from 'react';",
|
||||
"import { Button } from './components';",
|
||||
"import axios from 'axios';",
|
||||
"import path from 'path';",
|
||||
}
|
||||
|
||||
local sorted = inject.sort_imports(imports, "javascript")
|
||||
|
||||
-- Check ordering: builtin -> third-party -> local
|
||||
local found_builtin = false
|
||||
local found_local = false
|
||||
local builtin_pos = 0
|
||||
local local_pos = 0
|
||||
|
||||
for i, imp in ipairs(sorted) do
|
||||
if imp:match("path") then
|
||||
found_builtin = true
|
||||
builtin_pos = i
|
||||
end
|
||||
if imp:match("%.%/") then
|
||||
found_local = true
|
||||
local_pos = i
|
||||
end
|
||||
end
|
||||
|
||||
-- Local imports should come after third-party
|
||||
if found_local and found_builtin then
|
||||
assert.truthy(local_pos > builtin_pos)
|
||||
end
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("has_imports", function()
|
||||
it("should return true when code has imports", function()
|
||||
local code = [[import { useState } from 'react';
|
||||
function App() {}]]
|
||||
|
||||
assert.is_true(inject.has_imports(code, "typescript"))
|
||||
end)
|
||||
|
||||
it("should return false when code has no imports", function()
|
||||
local code = [[function App() {
|
||||
return <div>Hello</div>;
|
||||
}]]
|
||||
|
||||
assert.is_false(inject.has_imports(code, "typescript"))
|
||||
end)
|
||||
|
||||
it("should detect Python imports", function()
|
||||
local code = [[from typing import List
|
||||
|
||||
def process(items: List[str]):
|
||||
pass]]
|
||||
|
||||
assert.is_true(inject.has_imports(code, "python"))
|
||||
end)
|
||||
|
||||
it("should detect Lua requires", function()
|
||||
local code = [[local utils = require("utils")
|
||||
|
||||
local M = {}
|
||||
return M]]
|
||||
|
||||
assert.is_true(inject.has_imports(code, "lua"))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("edge cases", function()
|
||||
it("should handle empty code", function()
|
||||
local result = inject.parse_code("", "javascript")
|
||||
|
||||
assert.equals(0, #result.imports)
|
||||
assert.equals(1, #result.body) -- Empty string becomes one empty line
|
||||
end)
|
||||
|
||||
it("should handle code with only imports", function()
|
||||
local code = [[import React from 'react';
|
||||
import { useState } from 'react';]]
|
||||
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
assert.equals(2, #result.imports)
|
||||
assert.equals(0, #result.body)
|
||||
end)
|
||||
|
||||
it("should handle code with only body", function()
|
||||
local code = [[function hello() {
|
||||
console.log("Hello");
|
||||
}]]
|
||||
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
assert.equals(0, #result.imports)
|
||||
assert.truthy(#result.body > 0)
|
||||
end)
|
||||
|
||||
it("should handle imports in string literals (not detect as imports)", function()
|
||||
local code = [[const example = "import { fake } from 'not-real';";
|
||||
const config = { import: true };
|
||||
|
||||
function test() {}]]
|
||||
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
-- The first line looks like an import but is in a string
|
||||
-- This is a known limitation - we accept some false positives
|
||||
-- The important thing is we don't break the code
|
||||
assert.truthy(#result.body >= 0)
|
||||
end)
|
||||
|
||||
it("should handle mixed import styles in same file", function()
|
||||
local code = [[import React from 'react';
|
||||
const axios = require('axios');
|
||||
import { useState } from 'react';
|
||||
|
||||
function App() {}]]
|
||||
|
||||
local result = inject.parse_code(code, "javascript")
|
||||
|
||||
assert.equals(3, #result.imports)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
174
tests/spec/llm_selector_spec.lua
Normal file
174
tests/spec/llm_selector_spec.lua
Normal file
@@ -0,0 +1,174 @@
|
||||
--- Tests for smart LLM selection with memory-based confidence
|
||||
|
||||
describe("codetyper.llm.selector", function()
|
||||
local selector
|
||||
|
||||
before_each(function()
|
||||
selector = require("codetyper.llm.selector")
|
||||
-- Reset stats for clean tests
|
||||
selector.reset_accuracy_stats()
|
||||
end)
|
||||
|
||||
describe("select_provider", function()
|
||||
it("should return copilot when no brain memories exist", function()
|
||||
local result = selector.select_provider("write a function", {
|
||||
file_path = "/test/file.lua",
|
||||
})
|
||||
|
||||
assert.equals("copilot", result.provider)
|
||||
assert.equals(0, result.memory_count)
|
||||
assert.truthy(result.reason:match("Insufficient context"))
|
||||
end)
|
||||
|
||||
it("should return a valid selection result structure", function()
|
||||
local result = selector.select_provider("test prompt", {})
|
||||
|
||||
assert.is_string(result.provider)
|
||||
assert.is_number(result.confidence)
|
||||
assert.is_number(result.memory_count)
|
||||
assert.is_string(result.reason)
|
||||
end)
|
||||
|
||||
it("should have confidence between 0 and 1", function()
|
||||
local result = selector.select_provider("test", {})
|
||||
|
||||
assert.truthy(result.confidence >= 0)
|
||||
assert.truthy(result.confidence <= 1)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("should_ponder", function()
|
||||
it("should return true for medium confidence", function()
|
||||
assert.is_true(selector.should_ponder(0.5))
|
||||
assert.is_true(selector.should_ponder(0.6))
|
||||
end)
|
||||
|
||||
it("should return false for low confidence", function()
|
||||
assert.is_false(selector.should_ponder(0.2))
|
||||
assert.is_false(selector.should_ponder(0.3))
|
||||
end)
|
||||
|
||||
-- High confidence pondering is probabilistic, so we test the range
|
||||
it("should sometimes ponder for high confidence (sampling)", function()
|
||||
-- Run multiple times to test probabilistic behavior
|
||||
local pondered_count = 0
|
||||
for _ = 1, 100 do
|
||||
if selector.should_ponder(0.9) then
|
||||
pondered_count = pondered_count + 1
|
||||
end
|
||||
end
|
||||
-- Should ponder roughly 20% of the time (PONDER_SAMPLE_RATE = 0.2)
|
||||
-- Allow range of 5-40% due to randomness
|
||||
assert.truthy(pondered_count >= 5, "Should ponder at least sometimes")
|
||||
assert.truthy(pondered_count <= 40, "Should not ponder too often")
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("get_accuracy_stats", function()
|
||||
it("should return initial empty stats", function()
|
||||
local stats = selector.get_accuracy_stats()
|
||||
|
||||
assert.equals(0, stats.ollama.total)
|
||||
assert.equals(0, stats.ollama.correct)
|
||||
assert.equals(0, stats.ollama.accuracy)
|
||||
assert.equals(0, stats.copilot.total)
|
||||
assert.equals(0, stats.copilot.correct)
|
||||
assert.equals(0, stats.copilot.accuracy)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("report_feedback", function()
|
||||
it("should track positive feedback", function()
|
||||
selector.report_feedback("ollama", true)
|
||||
selector.report_feedback("ollama", true)
|
||||
selector.report_feedback("ollama", false)
|
||||
|
||||
local stats = selector.get_accuracy_stats()
|
||||
assert.equals(3, stats.ollama.total)
|
||||
assert.equals(2, stats.ollama.correct)
|
||||
end)
|
||||
|
||||
it("should track copilot feedback separately", function()
|
||||
selector.report_feedback("ollama", true)
|
||||
selector.report_feedback("copilot", true)
|
||||
selector.report_feedback("copilot", false)
|
||||
|
||||
local stats = selector.get_accuracy_stats()
|
||||
assert.equals(1, stats.ollama.total)
|
||||
assert.equals(2, stats.copilot.total)
|
||||
end)
|
||||
|
||||
it("should calculate accuracy correctly", function()
|
||||
selector.report_feedback("ollama", true)
|
||||
selector.report_feedback("ollama", true)
|
||||
selector.report_feedback("ollama", true)
|
||||
selector.report_feedback("ollama", false)
|
||||
|
||||
local stats = selector.get_accuracy_stats()
|
||||
assert.equals(0.75, stats.ollama.accuracy)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("reset_accuracy_stats", function()
|
||||
it("should clear all stats", function()
|
||||
selector.report_feedback("ollama", true)
|
||||
selector.report_feedback("copilot", true)
|
||||
|
||||
selector.reset_accuracy_stats()
|
||||
|
||||
local stats = selector.get_accuracy_stats()
|
||||
assert.equals(0, stats.ollama.total)
|
||||
assert.equals(0, stats.copilot.total)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("agreement calculation", function()
|
||||
-- Test the internal agreement calculation through pondering behavior
|
||||
-- Since calculate_agreement is local, we test its effects indirectly
|
||||
|
||||
it("should detect high agreement for similar responses", function()
|
||||
-- This is tested through the pondering system
|
||||
-- When responses are similar, agreement should be high
|
||||
local selector = require("codetyper.llm.selector")
|
||||
|
||||
-- Verify that should_ponder returns predictable results
|
||||
-- for medium confidence (where pondering always happens)
|
||||
assert.is_true(selector.should_ponder(0.5))
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("provider selection with accuracy history", function()
|
||||
local selector
|
||||
|
||||
before_each(function()
|
||||
selector = require("codetyper.llm.selector")
|
||||
selector.reset_accuracy_stats()
|
||||
end)
|
||||
|
||||
it("should factor in historical accuracy for selection", function()
|
||||
-- Simulate high Ollama accuracy
|
||||
for _ = 1, 10 do
|
||||
selector.report_feedback("ollama", true)
|
||||
end
|
||||
|
||||
-- Even with no brain context, historical accuracy should influence confidence
|
||||
local result = selector.select_provider("test", {})
|
||||
|
||||
-- Confidence should be higher due to historical accuracy
|
||||
-- but provider might still be copilot if no memories
|
||||
assert.is_number(result.confidence)
|
||||
end)
|
||||
|
||||
it("should have lower confidence for low historical accuracy", function()
|
||||
-- Simulate low Ollama accuracy
|
||||
for _ = 1, 10 do
|
||||
selector.report_feedback("ollama", false)
|
||||
end
|
||||
|
||||
local result = selector.select_provider("test", {})
|
||||
|
||||
-- With bad history and no memories, should definitely use copilot
|
||||
assert.equals("copilot", result.provider)
|
||||
end)
|
||||
end)
|
||||
Reference in New Issue
Block a user