diff --git a/README.md b/README.md index dd64181..3a57bd1 100644 --- a/README.md +++ b/README.md @@ -658,6 +658,7 @@ Avante provides a RAG service, which is a tool for obtaining the required contex ```lua rag_service = { enabled = false, -- Enables the RAG service, requires OPENAI_API_KEY to be set + host_mount = os.getenv("HOME"), -- Host mount path for the rag service provider = "openai", -- The provider to use for RAG service (e.g. openai or ollama) llm_model = "", -- The LLM model to use for RAG service embed_model = "", -- The embedding model to use for RAG service @@ -667,7 +668,13 @@ rag_service = { Please note that since the RAG service uses OpenAI for embeddings, you must set `OPENAI_API_KEY` environment variable! -Additionally, RAG Service also depends on Docker! (For macOS users, OrbStack is recommended as a Docker alternative) +Additionally, RAG Service also depends on Docker! (For macOS users, OrbStack is recommended as a Docker alternative). +`host_mount` is the path that will be mounted to the container, and the default is the home directory. The mount is required +for the RAG service to access the files in the host machine. It is up to the user to decide if you want to mount the whole +`/` directory, just the project directory, or the home directory. If you plan using avante and RAG event for projects +stored outside your home directory, you will need to set the `host_mount` to the root directory of your file system. + +The mount will be read only. ## Web Search Engines diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 68515cf..f53f53c 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -35,6 +35,7 @@ M._defaults = { tokenizer = "tiktoken", rag_service = { enabled = false, -- Enables the rag service, requires OPENAI_API_KEY to be set + host_mount = os.getenv("HOME"), -- Host mount path for the rag service (docker will mount this path) runner = "docker", -- The runner for the rag service, (can use docker, or nix) provider = "openai", -- The provider to use for RAG service. eg: openai or ollama llm_model = "", -- The LLM model to use for RAG service diff --git a/lua/avante/rag_service.lua b/lua/avante/rag_service.lua index e0fadac..962b41e 100644 --- a/lua/avante/rag_service.lua +++ b/lua/avante/rag_service.lua @@ -2,6 +2,7 @@ local curl = require("plenary.curl") local Path = require("plenary.path") local Config = require("avante.config") local Utils = require("avante.utils") +local Config = require("avante.config") local M = {} @@ -69,10 +70,11 @@ function M.launch_rag_service(cb) Utils.debug(string.format("container %s not found, starting...", container_name)) end local cmd_ = string.format( - "docker run -d -p %d:8000 --name %s -v %s:/data -v /:/host -e DATA_DIR=/data -e RAG_PROVIDER=%s -e %s_API_KEY=%s -e %s_API_BASE=%s -e RAG_LLM_MODEL=%s -e RAG_EMBED_MODEL=%s %s", + "docker run --rm -d -p %d:8000 --name %s -v %s:/data -v %s:/host:ro -e DATA_DIR=/data -e RAG_PROVIDER=%s -e %s_API_KEY=%s -e %s_API_BASE=%s -e RAG_LLM_MODEL=%s -e RAG_EMBED_MODEL=%s %s", port, container_name, data_path, + Config.rag_service.host_mount, Config.rag_service.provider, Config.rag_service.provider:upper(), openai_api_key, @@ -182,7 +184,9 @@ function M.to_container_uri(uri) local scheme = M.get_scheme(uri) if scheme == "file" then local path = uri:match("^file://(.*)$") - uri = string.format("file:///host%s", path) + local host_dir = Config.rag_service.host_mount + if path:sub(1, #host_dir) == host_dir then path = "/host" .. path:sub(#host_dir + 1) end + uri = string.format("file://%s", path) end return uri end @@ -190,8 +194,10 @@ end function M.to_local_uri(uri) local scheme = M.get_scheme(uri) if scheme == "file" then - local path = uri:match("^file://host(.*)$") - uri = string.format("file://%s", path) + local path = uri:match("^file:///host(.*)$") + local host_dir = Config.rag_service.host_mount + local full_path = Path:new(host_dir):joinpath(path:sub(2)):absolute() + uri = string.format("file://%s", full_path) end return uri end diff --git a/tests/rag_service_spec.lua b/tests/rag_service_spec.lua new file mode 100644 index 0000000..8d685d2 --- /dev/null +++ b/tests/rag_service_spec.lua @@ -0,0 +1,38 @@ +local mock = require("luassert.mock") +local match = require("luassert.match") + +describe("RagService", function() + local RagService + local Config_mock + + before_each(function() + -- Load the module before each test + RagService = require("avante.rag_service") + + -- Setup common mocks + Config_mock = mock(require("avante.config"), true) + Config_mock.rag_service = { host_mount = "/home/user" } + end) + + after_each(function() + -- Clean up after each test + package.loaded["avante.rag_service"] = nil + mock.revert(Config_mock) + end) + + describe("URI conversion functions", function() + it("should convert URIs between host and container formats", function() + -- Test both directions of conversion + local host_uri = "file:///home/user/project/file.txt" + local container_uri = "file:///host/project/file.txt" + + -- Host to container + local result1 = RagService.to_container_uri(host_uri) + assert.equals(container_uri, result1) + + -- Container to host + local result2 = RagService.to_local_uri(container_uri) + assert.equals(host_uri, result2) + end) + end) +end)