Files
avante.nvim/lua/avante/rag_service.lua
Francesco Tassi 232c9a635c feat: improve avante-rag-service container execution (#1448)
* Refactor Docker mount to mount only user home

Mounting the whole filesystem exposes the user to security risks,
considering the container is running as root.

This mounts only the user home directory in the container, to mitigate
the security risks. The user home directory is mounted in read-only mode
to further reduce the risk of accidental or malicious modifications.

Mounting the whole home directory should allow the user to have multiple Neovim instances running at
the same time and sharing the same rag_service.

Also the container is started with the --rm flag to remove it after it stops.

* RAG mount point is now configurable

* Remove useless filter.lua file

* Use Path to join paths

This should be safer than concatenating strings directly.
2025-03-05 16:18:52 +08:00

395 lines
12 KiB
Lua

local curl = require("plenary.curl")
local Path = require("plenary.path")
local Config = require("avante.config")
local Utils = require("avante.utils")
local Config = require("avante.config")
-- Module table plus the fixed identifiers shared by both runners.
local M = {}

-- Docker container name; also used as the label in the nix runner's log messages.
local container_name = "avante-rag-service"

-- nix runner: passed to the service as DATA_DIR and used as the `pgrep -f` marker.
local service_path = string.format("/tmp/%s", container_name)

---Container image (pinned tag) started by the docker runner.
---@return string
function M.get_rag_service_image()
  local image = "quay.io/yetoneful/avante-rag-service:0.0.6"
  return image
end

---Host port the service listens on (docker publishes container port 8000 here).
---@return integer
function M.get_rag_service_port()
  local port = 20250
  return port
end

---Base URL of the service's HTTP API on localhost.
---@return string
function M.get_rag_service_url()
  local port = M.get_rag_service_port()
  return ("http://localhost:%d"):format(port)
end
---Directory under Neovim's data dir that the docker runner mounts at /data.
---Created (including parents) if it does not exist yet.
---@return table plenary Path object
function M.get_data_path()
  local data_dir = Path:new(vim.fn.stdpath("data")):joinpath("avante/rag_service")
  if data_dir:exists() then return data_dir end
  data_dir:mkdir({ parents = true })
  return data_dir
end
---Return the image reference the `avante-rag-service` container was created from.
---Reads `.Config.Image` through `docker inspect --format` rather than grepping the
---pretty-printed JSON. This avoids depending on a shell, on `grep`, or on the
---output layout.
---@return string|nil image nil when docker is unavailable, the container does not
---exist, or inspection fails
function M.get_current_image()
  -- argv-form system() raises on a missing executable; treat that as "no image".
  if vim.fn.executable("docker") == 0 then return nil end
  local output = vim.fn.system({ "docker", "inspect", "--format", "{{.Config.Image}}", container_name })
  if vim.v.shell_error ~= 0 then return nil end
  local image = vim.trim(output)
  if image == "" then return nil end
  return image
end
---Runner hosting the RAG service: `Config.rag_service.runner` when set,
---otherwise "docker".
---@return string
function M.get_rag_service_runner()
  local rag_config = Config.rag_service
  if rag_config and rag_config.runner then return rag_config.runner end
  return "docker"
end
---Environment variables that configure the RAG service, shared by both runners.
---Entries whose value is nil (unset config, missing API key) are dropped instead
---of being rendered as the literal string "nil"; the rest are stringified.
---@param api_key string|nil value exported as `<PROVIDER>_API_KEY`
---@return table<string, string>
local function rag_service_env(api_key)
  local rag = Config.rag_service
  local prefix = rag.provider:upper()
  local wanted = {
    RAG_PROVIDER = rag.provider,
    [prefix .. "_API_KEY"] = api_key,
    [prefix .. "_API_BASE"] = rag.endpoint,
    RAG_LLM_MODEL = rag.llm_model,
    RAG_EMBED_MODEL = rag.embed_model,
  }
  local env = {}
  for name, value in pairs(wanted) do
    env[name] = tostring(value)
  end
  return env
end

---Run `argv` directly (no shell) as a detached job.
---Calls `cb` once the job exits with status 0; otherwise logs the failure.
---@param argv string[]
---@param opts table extra `jobstart` options (e.g. `cwd`, `env`)
---@param kind string "container" or "service"; used in log messages
---@param cb fun()
local function start_detached(argv, opts, kind, cb)
  local job_opts = vim.tbl_extend("force", opts, {
    detach = true,
    on_exit = function(_, exit_code)
      if exit_code ~= 0 then
        Utils.error(string.format("%s %s failed to start, exit code: %d", kind, container_name, exit_code))
      else
        Utils.debug(string.format("%s %s started", kind, container_name))
        cb()
      end
    end,
  })
  -- jobstart returns 0 / -1 (and on_exit never fires) when argv cannot be spawned,
  -- e.g. when the executable is missing.
  local job_id = vim.fn.jobstart(argv, job_opts)
  if job_id <= 0 then
    Utils.error(string.format("%s %s failed to start: cannot run %s", kind, container_name, argv[1]))
  end
end

---docker runner: reuse an existing container created from the expected image.
---If the existing container uses another image, remove it first; then `docker run` a fresh one.
---@param cb fun()
---@param port integer
---@param api_key string|nil
local function launch_docker(cb, port, api_key)
  local image = M.get_rag_service_image()
  local data_path = M.get_data_path()
  local existing = vim.fn.system(string.format("docker ps -a | grep '%s'", container_name))
  if existing ~= "" then
    Utils.debug(string.format("container %s already running", container_name))
    local current_image = M.get_current_image()
    if current_image == image then
      cb()
      return
    end
    Utils.debug(
      string.format(
        "container %s is running with different image: %s != %s, stopping...",
        container_name,
        current_image,
        image
      )
    )
    M.stop_rag_service()
  else
    Utils.debug(string.format("container %s not found, starting...", container_name))
  end

  local env = rag_service_env(api_key)
  -- argv form: paths containing spaces or shell metacharacters are passed verbatim.
  local argv = {
    "docker",
    "run",
    "--rm",
    "-d",
    "-p",
    string.format("%d:8000", port),
    "--name",
    container_name,
    "-v",
    string.format("%s:/data", tostring(data_path)),
    "-v",
    string.format("%s:/host:ro", tostring(Config.rag_service.host_mount)),
    "-e",
    "DATA_DIR=/data",
  }
  -- `-e NAME` without `=value` makes docker copy NAME from its own environment,
  -- which the job's `env` option supplies. This keeps the API key out of argv
  -- (and therefore out of `ps` output).
  local names = vim.tbl_keys(env)
  table.sort(names)
  for _, name in ipairs(names) do
    vim.list_extend(argv, { "-e", name })
  end
  argv[#argv + 1] = image
  start_detached(argv, { env = env }, "container", cb)
end

---nix runner: start `sh run.sh` in the plugin's `py/rag-service` directory.
---Skipped when a process matching `service_path` is already running.
---@param cb fun()
---@param port integer
---@param api_key string|nil
local function launch_nix(cb, port, api_key)
  -- Run pgrep directly. Through a shell that does not exec its last command,
  -- the shell's own command line contains `service_path` and would match `-f`.
  if vim.fn.executable("pgrep") == 1 and vim.fn.system({ "pgrep", "-f", service_path }) ~= "" then
    Utils.debug(string.format("RAG service already running at %s", service_path))
    cb()
    return
  end
  -- Plugin root = this file's path minus "/lua/avante/rag_service.lua" (and the leading "@").
  local plugin_suffix = "/lua/avante/rag_service.lua"
  local source = debug.getinfo(1).source
  local dirname = Utils.trim(string.sub(source, 2, -#plugin_suffix), { suffix = "/" })
  local rag_service_dir = dirname .. "/py/rag-service"
  -- jobstart raises on an invalid `cwd`, so report it as a start failure instead.
  if vim.fn.isdirectory(rag_service_dir) == 0 then
    Utils.error(string.format("service %s failed to start: %s is not a directory", container_name, rag_service_dir))
    return
  end
  Utils.debug(string.format("launching %s with nix...", container_name))
  local env = rag_service_env(api_key)
  env.PORT = tostring(port)
  env.DATA_DIR = service_path
  start_detached({ "sh", "run.sh", service_path }, { cwd = rag_service_dir, env = env }, "service", cb)
end

---Start the RAG service with the configured runner ("docker" or "nix").
---`cb` runs once the service has been started, or immediately if it is already running.
---Raises when the provider is "openai" and OPENAI_API_KEY is not set.
---NOTE(review): OPENAI_API_KEY is exported as `<PROVIDER>_API_KEY` for every
---provider. Confirm this is intended for non-openai providers.
---@param cb fun()
function M.launch_rag_service(cb)
  local openai_api_key = os.getenv("OPENAI_API_KEY")
  if Config.rag_service.provider == "openai" and openai_api_key == nil then
    error("cannot launch avante rag service, OPENAI_API_KEY is not set")
  end
  local port = M.get_rag_service_port()
  local runner = M.get_rag_service_runner()
  if runner == "docker" then
    launch_docker(cb, port, openai_api_key)
  elseif runner == "nix" then
    launch_nix(cb, port, openai_api_key)
  end
end
---Stop the RAG service for the configured runner.
---docker: force-remove the container (and its anonymous volumes) if `docker ps -a` lists it.
---any other runner: SIGKILL every process whose command line contains `service_path`.
function M.stop_rag_service()
  if M.get_rag_service_runner() == "docker" then
    local result = vim.fn.system(string.format("docker ps -a | grep '%s'", container_name))
    if result ~= "" then vim.fn.system(string.format("docker rm -fv %s", container_name)) end
    return
  end
  -- Invoke pkill directly, without a shell. A wrapping shell's own command line
  -- contains `service_path`, so `-f` would match and kill it as well. This also
  -- avoids the non-portable `xargs -r`.
  if vim.fn.executable("pkill") == 1 then vim.fn.system({ "pkill", "-9", "-f", service_path }) end
  Utils.debug(string.format("Attempted to kill processes related to %s", service_path))
end
---Report whether the RAG service appears to be running.
---docker: "running" when `docker ps -a` lists the container.
---nix: "running" when `pgrep -f service_path` finds a process.
---@return "running"|"stopped"|nil nil for an unrecognised runner
function M.get_rag_service_status()
  local runner = M.get_rag_service_runner()
  if runner == "docker" then
    local result = vim.fn.system(string.format("docker ps -a | grep '%s'", container_name))
    return result == "" and "stopped" or "running"
  elseif runner == "nix" then
    -- Without pgrep the process cannot be detected; argv-form system() would raise.
    if vim.fn.executable("pgrep") == 0 then return "stopped" end
    -- Run pgrep directly so a wrapping shell (whose command line contains
    -- `service_path`) cannot be reported as the service.
    local result = vim.fn.system({ "pgrep", "-f", service_path })
    return result == "" and "stopped" or "running"
  end
end
---Extract the scheme of a URI, e.g. "file" from "file:///tmp/x".
---@param uri string
---@return string scheme, or "unknown" when `uri` has no `<alnum>://` prefix
function M.get_scheme(uri)
  return uri:match("^(%w+)://") or "unknown"
end
---Translate a host URI into the path the docker container sees.
---`file://` URIs under `Config.rag_service.host_mount` are rewritten to live
---under `/host` (the read-only mount). All other URIs, and every URI when the
---runner is "nix", are returned unchanged.
---@param uri string
---@return string
function M.to_container_uri(uri)
  if M.get_rag_service_runner() == "nix" then return uri end
  if M.get_scheme(uri) ~= "file" then return uri end
  local path = uri:match("^file://(.*)$")
  local host_dir = Config.rag_service.host_mount
  if host_dir == nil then return uri end
  -- Drop trailing separators so "/home/me/" and "/" compare on path boundaries.
  host_dir = host_dir:gsub("/+$", "")
  local rest = path:sub(#host_dir + 1)
  -- Require a path boundary: "/home/me" must not claim "/home/meow/...".
  local inside = path:sub(1, #host_dir) == host_dir and (rest == "" or rest:sub(1, 1) == "/")
  if inside then path = "/host" .. rest end
  return string.format("file://%s", path)
end
---Map a URI reported by the service back to the host's view.
---For the docker runner, `file:///host/...` becomes the same path under
---`Config.rag_service.host_mount`. Non-file URIs, file URIs outside the /host
---mount, and every URI under the nix runner are returned unchanged; previously
---those file URIs raised an error.
---@param uri string
---@return string
function M.to_local_uri(uri)
  -- The nix runner sends host paths to the service as-is, so there is nothing to undo.
  if M.get_rag_service_runner() == "nix" then return uri end
  if M.get_scheme(uri) ~= "file" then return uri end
  local rest = uri:match("^file:///host(.*)$")
  -- Require a path boundary so e.g. "file:///hostname/x" is not treated as inside /host.
  if rest == nil or (rest ~= "" and rest:sub(1, 1) ~= "/") then return uri end
  local host_dir = Config.rag_service.host_mount
  local full_path = Path:new(host_dir):joinpath(rest:sub(2)):absolute()
  return string.format("file://%s", full_path)
end
---Probe the service URL with curl.
---Only curl's exit status is checked, and curl is run without `--fail`. Any HTTP
---response therefore counts as ready, while a connection failure does not.
---@return boolean
function M.is_ready()
  local probe = string.format("curl -s -o /dev/null -w '%%{http_code}' %s", M.get_rag_service_url())
  vim.fn.system(probe)
  local ok = vim.v.shell_error == 0
  return ok
end
---@class AvanteRagServiceAddResourceResponse
---@field status string
---@field message string

---Return `base` if it is not in `taken`, otherwise the first free `base-<i>`
---for i = 1..100; nil when every candidate is taken.
---@param base string
---@param taken table<string, boolean>
---@return string|nil
local function pick_unique_resource_name(base, taken)
  if not taken[base] then return base end
  for i = 1, 100 do
    local candidate = string.format("%s-%d", base, i)
    if not taken[candidate] then return candidate end
  end
  return nil
end

---Register `uri` with the RAG service for indexing, via an asynchronous POST with curl.
---An already-registered URI keeps its existing name. Otherwise the name is the
---URI's last path segment, de-duplicated against the existing resource names.
---The request outcome is only logged; nothing is returned.
---@param uri string local (host-side) URI
function M.add_resource(uri)
  uri = M.to_container_uri(uri)
  -- Last path segment with or without a trailing slash (the old pattern required
  -- one and yielded nil otherwise); fall back to the whole URI.
  local resource_name = uri:match("([^/]+)/*$") or uri
  local resources_resp = M.get_resources()
  if resources_resp == nil then
    Utils.error("Failed to get resources")
    return nil
  end

  local already_added = false
  local taken = {}
  for _, resource in ipairs(resources_resp.resources) do
    -- get_resources() reports host-side URIs while `uri` is now in container form.
    -- Compare in container form; otherwise docker-mounted paths never match.
    if not already_added and M.to_container_uri(resource.uri) == uri then
      already_added = true
      resource_name = resource.name
    end
    if resource.name ~= nil then taken[resource.name] = true end
  end

  if not already_added then
    local unique = pick_unique_resource_name(resource_name, taken)
    if unique == nil then
      Utils.error(string.format("Failed to add resource, name conflict: %s", resource_name))
      return nil
    end
    resource_name = unique
  end

  local cmd = {
    "curl",
    "-X",
    "POST",
    M.get_rag_service_url() .. "/api/v1/add_resource",
    "-H",
    "Content-Type: application/json",
    "-d",
    vim.json.encode({ name = resource_name, uri = uri }),
  }
  -- vim.system's exit callback runs in a fast event context; schedule it so
  -- the logging helpers may use the Neovim API.
  vim.system(
    cmd,
    { text = true },
    vim.schedule_wrap(function(output)
      if output.code == 0 then
        Utils.debug(string.format("Added resource: %s", uri))
      else
        Utils.error(string.format("Failed to add resource: %s; output: %s", uri, output.stderr))
      end
    end)
  )
end
---Ask the service to drop the resource at `uri`.
---@param uri string local URI; translated to the container view before sending
---@return table|nil decoded JSON reply, or nil (after logging) on a non-200 status
function M.remove_resource(uri)
  local payload = vim.json.encode({ uri = M.to_container_uri(uri) })
  local endpoint = M.get_rag_service_url() .. "/api/v1/remove_resource"
  local resp = curl.post(endpoint, {
    headers = { ["Content-Type"] = "application/json" },
    body = payload,
  })
  if resp.status == 200 then return vim.json.decode(resp.body) end
  Utils.error("failed to remove resource: " .. resp.body)
end
---@class AvanteRagServiceRetrieveSource
---@field uri string
---@field content string

---@class AvanteRagServiceRetrieveResponse
---@field response string
---@field sources AvanteRagServiceRetrieveSource[]

---Query the index rooted at `base_uri` for the 10 best matches to `query`.
---Source URIs in the reply are mapped back to local URIs.
---@param base_uri string
---@param query string
---@return AvanteRagServiceRetrieveResponse | nil resp
---@return string | nil error
function M.retrieve(base_uri, query)
  local request = {
    base_uri = M.to_container_uri(base_uri),
    query = query,
    top_k = 10,
  }
  local resp = curl.post(M.get_rag_service_url() .. "/api/v1/retrieve", {
    headers = { ["Content-Type"] = "application/json" },
    body = vim.json.encode(request),
    timeout = 100000,
  })
  if resp.status ~= 200 then
    local message = "failed to retrieve: " .. resp.body
    Utils.error(message)
    return nil, message
  end
  local decoded = vim.json.decode(resp.body)
  local localized = {}
  for idx, source in ipairs(decoded.sources) do
    localized[idx] = vim.tbl_deep_extend("force", source, { uri = M.to_local_uri(source.uri) })
  end
  decoded.sources = localized
  return decoded, nil
end
---@class AvanteRagServiceIndexingStatusSummary
---@field indexing integer
---@field completed integer
---@field failed integer

---@class AvanteRagServiceIndexingStatusResponse
---@field uri string
---@field is_watched boolean
---@field total_files integer
---@field status_summary AvanteRagServiceIndexingStatusSummary

---Fetch indexing progress for the resource at `uri`.
---The `uri` field of the reply is mapped back to a local URI.
---@param uri string
---@return AvanteRagServiceIndexingStatusResponse | nil
function M.indexing_status(uri)
  local container_uri = M.to_container_uri(uri)
  local resp = curl.post(M.get_rag_service_url() .. "/api/v1/indexing_status", {
    headers = { ["Content-Type"] = "application/json" },
    body = vim.json.encode({ uri = container_uri }),
  })
  if resp.status == 200 then
    local status = vim.json.decode(resp.body)
    status.uri = M.to_local_uri(status.uri)
    return status
  end
  Utils.error("Failed to get indexing status: " .. resp.body)
end
---@class AvanteRagServiceResource
---@field name string
---@field uri string
---@field type string
---@field status string
---@field indexing_status string
---@field created_at string
---@field indexing_started_at string | nil
---@field last_indexed_at string | nil

---@class AvanteRagServiceResourceListResponse
---@field resources AvanteRagServiceResource[]
---@field total_count number

---List every resource registered with the service.
---Each resource's `uri` is mapped back to a local URI.
---@return AvanteRagServiceResourceListResponse | nil
function M.get_resources()
  local resp = curl.get(M.get_rag_service_url() .. "/api/v1/resources", {
    headers = { ["Content-Type"] = "application/json" },
  })
  if resp.status ~= 200 then
    Utils.error("Failed to get resources: " .. resp.body)
    return
  end
  local listing = vim.json.decode(resp.body)
  local localized = {}
  for idx, resource in ipairs(listing.resources) do
    localized[idx] = vim.tbl_deep_extend("force", resource, { uri = M.to_local_uri(resource.uri) })
  end
  listing.resources = localized
  return listing
end
-- Export the module table (launch/stop/status helpers and the HTTP client API).
return M