From e47c27af664c8eafc5ced59338878ac4c7685f14 Mon Sep 17 00:00:00 2001 From: yetone Date: Sun, 23 Feb 2025 21:53:40 +0800 Subject: [PATCH] fix: python llm tool with confirmation (#1365) --- lua/avante/llm_tools.lua | 46 +++++++++++++++++++++++++++++------ tests/llm_tools_spec.lua | 52 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 7 deletions(-) diff --git a/lua/avante/llm_tools.lua b/lua/avante/llm_tools.lua index a8ca520..a06421d 100644 --- a/lua/avante/llm_tools.lua +++ b/lua/avante/llm_tools.lua @@ -12,7 +12,9 @@ local M = {} local function get_abs_path(rel_path) if Path:new(rel_path):is_absolute() then return rel_path end local project_root = Utils.get_project_root() - return Path:new(project_root):joinpath(rel_path):absolute() + local p = tostring(Path:new(project_root):joinpath(rel_path):absolute()) + if p:sub(-2) == "/." then p = p:sub(1, -3) end + return p end function M.confirm(msg) @@ -552,7 +554,7 @@ function M.rag_search(opts, on_log) return vim.json.encode(resp), nil end ----@param opts { code: string, rel_path: string } +---@param opts { code: string, rel_path: string, container_image?: string } ---@param on_log? fun(log: string): nil ---@return string|nil result ---@return string|nil error @@ -562,15 +564,45 @@ function M.python(opts, on_log) if not Path:new(abs_path):exists() then return nil, "Path not found: " .. abs_path end if on_log then on_log("cwd: " .. abs_path) end if on_log then on_log("code: " .. opts.code) end + if + not M.confirm( + "Are you sure you want to run the python code in the contianer in the directory: `" + .. abs_path + .. "`? code: " + .. opts.code + ) + then + return nil, "User canceled" + end + local container_image = opts.container_image or "python:3.11-slim-bookworm" ---change cwd to abs_path local old_cwd = vim.fn.getcwd() + vim.fn.chdir(abs_path) - local output = vim.fn.system({ "python", "-c", opts.code }) - local exit_code = vim.v.shell_error + local output = vim + .system({ + "docker", + "run", + "--rm", + "-v", + abs_path .. ":" .. abs_path, + "-w", + abs_path, + container_image, + "python", + "-c", + opts.code, + }, { + text = true, + }) + :wait() + vim.fn.chdir(old_cwd) - if exit_code ~= 0 then return nil, "Error: " .. output end - Utils.debug("output", output) - return output, nil + + if output.code ~= 0 then return nil, "Error: " .. (output.stderr or "Unknown error") end + + Utils.debug("output", output.stdout) + return output.stdout, nil end ---@return AvanteLLMTool[] diff --git a/tests/llm_tools_spec.lua b/tests/llm_tools_spec.lua index 95dacb7..a362b23 100644 --- a/tests/llm_tools_spec.lua +++ b/tests/llm_tools_spec.lua @@ -218,4 +218,56 @@ describe("llm_tools", function() assert.truthy(err:find("No permission to access path")) end) end) + + describe("python", function() + local original_system = vim.fn.system + + it("should execute Python code and return output", function() + local result, err = LlmTools.python({ + rel_path = ".", + code = "print('Hello from Python')", + }) + assert.is_nil(err) + assert.equals("Hello from Python\n", result) + end) + + it("should handle Python errors", function() + local result, err = LlmTools.python({ + rel_path = ".", + code = "print(undefined_variable)", + }) + assert.is_nil(result) + assert.truthy(err) + assert.truthy(err:find("Error")) + end) + + it("should respect path permissions", function() + local result, err = LlmTools.python({ + rel_path = "../outside_project", + code = "print('test')", + }) + assert.is_nil(result) + assert.truthy(err:find("No permission to access path")) + end) + + it("should handle non-existent paths", function() + local result, err = LlmTools.python({ + rel_path = "non_existent_dir", + code = "print('test')", + }) + assert.is_nil(result) + assert.truthy(err:find("Path not found")) + end) + + it("should support custom container image", function() + os.execute("docker image rm python:3.12-slim") + local result, err = LlmTools.python({ + rel_path = ".", + code = "print('Hello from custom container')", + container_image = "python:3.12-slim", + }) + assert.is_nil(err) + assert.equals("Hello from custom container\n", result) + end) + end) end)