Files
avante.nvim/lua/avante/rag_service.lua
edmundhighcock f663865186 Added nbconvert needed for rag indexing jupyter notebooks (#2862)
* Added nbconvert, which is needed for RAG indexing of Jupyter notebooks

* Allow the RAG service image to be configured

* [pre-commit.ci lite] apply automatic fixes

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
2025-12-30 11:59:52 +08:00

449 lines
14 KiB
Lua

local curl = require("plenary.curl")
local Path = require("plenary.path")
local Config = require("avante.config")
local Utils = require("avante.utils")
-- Public module table; every service helper below is attached to it.
local M = {}
-- Docker container name; also reused in log messages for the nix runner.
local container_name = "avante-rag-service"
-- For the nix runner: passed to run.sh, used as DATA_DIR, and used as the
-- `pgrep -f` pattern that finds the running service.
-- NOTE(review): this lives under /tmp, while the docker runner keeps its data
-- under stdpath("data") — confirm that data loss on reboot is intended here.
local service_path = "/tmp/" .. container_name
--- Docker image for the RAG service: `Config.rag_service.image` when set,
--- otherwise the pinned default image.
---@return string
function M.get_rag_service_image()
  local rag = Config.rag_service
  local configured = rag and rag.image
  if configured then return configured end
  return "quay.io/yetoneful/avante-rag-service:0.0.11"
end
--- Fixed TCP port of the RAG service (published as-is by docker, passed as PORT to nix).
---@return integer
function M.get_rag_service_port()
  local port = 20250
  return port
end
--- Base URL used for every HTTP request to the service.
---@return string
function M.get_rag_service_url()
  return ("http://localhost:%d"):format(M.get_rag_service_port())
end
--- Host directory that is bind-mounted at /data in the docker container.
--- Created (with parents) the first time it is requested.
---@return table # plenary Path
function M.get_data_path()
  local data_dir = Path:new(vim.fn.stdpath("data")):joinpath("avante/rag_service")
  if data_dir:exists() then return data_dir end
  data_dir:mkdir({ parents = true })
  return data_dir
end
--- Image reference of the existing service container, as reported by `docker inspect`.
---@return string|nil # image name without surrounding whitespace; nil when the
---                    container does not exist or inspect fails
function M.get_current_image()
  local cmd = { "docker", "inspect", "--format", "{{.Config.Image}}", container_name }
  local result = vim.system(cmd, { text = true }):wait()
  if result.code ~= 0 then return nil end
  -- The formatted output ends with a newline; strip it so the value can be
  -- compared directly with the configured image name.
  local image = vim.trim(result.stdout or "")
  if image == "" then return nil end
  return image
end
--- Runner that hosts the service: `Config.rag_service.runner`, defaulting to "docker".
---@return string
function M.get_rag_service_runner()
  local rag = Config.rag_service
  if rag and rag.runner then return rag.runner end
  return "docker"
end
--- Quote one value for the shell that `vim.fn.jobstart` uses when given a string
--- command. `nil` becomes an empty string rather than the text "nil".
---@param value any
---@return string
local function rag_shell_arg(value)
  if value == nil then value = "" end
  return vim.fn.shellescape(tostring(value))
end

--- Resolve the API key for `Config.rag_service[section_name]`.
--- The configured `api_key` is the NAME of an environment variable; its value is the key.
--- Raises when the variable is named but unset or empty.
---@param section_name "llm"|"embed"
---@return string # "" when no environment variable name is configured
local function rag_api_key(section_name)
  local section = Config.rag_service and Config.rag_service[section_name]
  local env_name = section and section.api_key
  if not env_name or env_name == "" then return "" end
  local value = os.getenv(env_name)
  if value == nil or value == "" then
    error(string.format("cannot launch avante rag service, %s is not set", env_name))
  end
  return value
end

--- JSON-encode `Config.rag_service[section_name].extra`, or "{}" when unset.
--- Returns raw JSON; only the docker path quotes it, because only it goes through a shell.
---@param section_name "llm"|"embed"
---@return string
local function rag_extra_json(section_name)
  local section = Config.rag_service and Config.rag_service[section_name]
  if section and section.extra then return vim.json.encode(section.extra) end
  return "{}"
end

--- Environment variables shared by both runners, as ordered {name, value} pairs.
---@param llm_api_key string
---@param embed_api_key string
---@return table[]
local function rag_service_env(llm_api_key, embed_api_key)
  local rag = Config.rag_service
  return {
    { "RAG_EMBED_PROVIDER", rag.embed.provider },
    { "RAG_EMBED_ENDPOINT", rag.embed.endpoint },
    { "RAG_EMBED_API_KEY", embed_api_key },
    { "RAG_EMBED_MODEL", rag.embed.model },
    { "RAG_EMBED_EXTRA", rag_extra_json("embed") },
    { "RAG_LLM_PROVIDER", rag.llm.provider },
    { "RAG_LLM_ENDPOINT", rag.llm.endpoint },
    { "RAG_LLM_API_KEY", llm_api_key },
    { "RAG_LLM_MODEL", rag.llm.model },
    { "RAG_LLM_EXTRA", rag_extra_json("llm") },
  }
end

--- Reuse the running container when it has the expected image; otherwise remove
--- any stale container and `docker run` a new one.
---@param env table[] ordered {name, value} pairs from rag_service_env()
---@param cb fun()
local function launch_rag_service_docker(env, cb)
  local image = M.get_rag_service_image()
  local data_path = M.get_data_path()
  local inspect_cmd = { "docker", "inspect", "--format", "{{.State.Status}}", container_name }
  local inspect = vim.system(inspect_cmd, { text = true }):wait()
  if inspect.code ~= 0 then Utils.debug(string.format("cmd: %s execution error", table.concat(inspect_cmd, " "))) end
  -- The formatted output ends with a newline; compare only the bare status word.
  local status = vim.trim(inspect.stdout or "")

  if status == "" then
    Utils.debug(string.format("container %s not found, starting...", container_name))
  elseif status == "running" then
    Utils.debug(string.format("container %s already running", container_name))
    local current_image = M.get_current_image()
    if current_image then current_image = vim.trim(current_image) end
    if current_image == image then
      cb()
      return
    end
    Utils.debug(
      string.format(
        "container %s is running with different image: %s != %s, stopping...",
        container_name,
        tostring(current_image),
        image
      )
    )
    M.stop_rag_service()
  else
    Utils.info(string.format("container %s already started but not running, stopping...", container_name))
    M.stop_rag_service()
  end

  local rag = Config.rag_service
  local port = M.get_rag_service_port()
  local args = {
    "docker run --platform=linux/amd64 -d",
    -- NOTE(review): publishes the port on every interface although this module only
    -- talks to localhost; consider 127.0.0.1 unless remote access is intended.
    string.format("-p 0.0.0.0:%d:%d", port, port),
    "--name " .. rag_shell_arg(container_name),
    "-v " .. rag_shell_arg(data_path:absolute() .. ":/data"),
    "-v " .. rag_shell_arg(tostring(rag.host_mount) .. ":/host:ro"),
    "-e ALLOW_RESET=TRUE",
    "-e DATA_DIR=/data",
  }
  for _, pair in ipairs(env) do
    local value = pair[2] == nil and "" or tostring(pair[2])
    -- Quoted so URLs with `&`/`?` and secrets with shell metacharacters survive intact.
    args[#args + 1] = "-e " .. rag_shell_arg(pair[1] .. "=" .. value)
  end
  -- Left unquoted on purpose: this option may hold several shell words.
  args[#args + 1] = rag.docker_extra_args or ""
  args[#args + 1] = rag_shell_arg(image)

  vim.fn.jobstart(table.concat(args, " "), {
    detach = true,
    on_exit = function(_, exit_code)
      if exit_code ~= 0 then
        Utils.error(string.format("container %s failed to start, exit code: %d", container_name, exit_code))
      else
        Utils.debug(string.format("container %s started", container_name))
        cb()
      end
    end,
  })
end

--- Run `py/rag-service/run.sh` from the plugin root unless a process matching
--- `service_path` is already running.
---@param env table[] ordered {name, value} pairs from rag_service_env()
---@param cb fun()
local function launch_rag_service_nix(env, cb)
  local running = vim.system({ "pgrep", "-f", service_path }, { text = true }):wait().stdout or ""
  if vim.trim(running) ~= "" then
    Utils.debug(string.format("RAG service already running at %s", service_path))
    cb()
    return
  end
  -- This file is <plugin root>/lua/avante/rag_service.lua. Drop the chunk name's
  -- leading "@" and that suffix to get the plugin root.
  local dirname =
    Utils.trim(string.sub(debug.getinfo(1).source, 2, #"/lua/avante/rag_service.lua" * -1), { suffix = "/" })
  local rag_service_dir = dirname .. "/py/rag-service"
  local process_env = {
    ALLOW_RESET = "TRUE",
    PORT = M.get_rag_service_port(),
    DATA_DIR = service_path,
  }
  for _, pair in ipairs(env) do
    process_env[pair[1]] = pair[2]
  end
  Utils.debug(string.format("launching %s with nix...", container_name))
  vim.system({ "sh", "run.sh", service_path }, {
    detach = true,
    cwd = rag_service_dir,
    env = process_env,
  }, function(res)
    if res.code ~= 0 then
      Utils.error(string.format("service %s failed to start, exit code: %d", container_name, res.code))
    else
      Utils.debug(string.format("service %s started", container_name))
      cb()
    end
  end)
end

--- Launch the RAG service with the configured runner.
--- docker: `cb` runs when a container with the expected image is already running,
--- or after `docker run -d` exits with status 0.
--- nix: `cb` runs when a matching process already exists, or when `run.sh` exits with status 0.
--- Raises an error when a configured API-key environment variable is unset.
---@param cb fun()
function M.launch_rag_service(cb)
  -- Keys are resolved first (llm before embed) so a missing key aborts before anything starts.
  local llm_api_key = rag_api_key("llm")
  local embed_api_key = rag_api_key("embed")
  local runner = M.get_rag_service_runner()
  if runner == "docker" then
    launch_rag_service_docker(rag_service_env(llm_api_key, embed_api_key), cb)
  elseif runner == "nix" then
    launch_rag_service_nix(rag_service_env(llm_api_key, embed_api_key), cb)
  else
    Utils.error(string.format("unsupported rag_service runner: %s", tostring(runner)))
  end
end
--- Stop the RAG service for the configured runner.
--- docker: force-remove the container (`docker rm -fv`) if docker reports any status for it.
--- other runners: SIGKILL every process whose command line matches `service_path`.
function M.stop_rag_service()
  if M.get_rag_service_runner() == "docker" then
    local cmd = { "docker", "inspect", "--format", "{{.State.Status}}", container_name }
    -- Trim the newline that terminates the formatted output.
    local status = vim.trim(vim.system(cmd, { text = true }):wait().stdout or "")
    if status ~= "" then vim.system({ "docker", "rm", "-fv", container_name }):wait() end
  else
    local out = vim.system({ "pgrep", "-f", service_path }, { text = true }):wait().stdout or ""
    -- pgrep prints one PID per line; each must be its own argument to kill.
    local pids = vim.split(out, "%s+", { trimempty = true })
    if #pids > 0 then
      local kill_cmd = { "kill", "-9" }
      vim.list_extend(kill_cmd, pids)
      vim.system(kill_cmd):wait()
      Utils.debug(string.format("Attempted to kill processes related to %s", service_path))
    end
  end
end
--- Report whether the service is up.
---@return "running"|"stopped"|nil # nil for runners other than "docker" and "nix"
function M.get_rag_service_status()
  local runner = M.get_rag_service_runner()
  if runner == "docker" then
    local cmd = { "docker", "inspect", "--format", "{{.State.Status}}", container_name }
    -- The formatted output ends with a newline; compare the bare status word.
    local status = vim.trim(vim.system(cmd, { text = true }):wait().stdout or "")
    return status == "running" and "running" or "stopped"
  elseif runner == "nix" then
    local out = vim.system({ "pgrep", "-f", service_path }, { text = true }):wait().stdout or ""
    return vim.trim(out) == "" and "stopped" or "running"
  end
end
--- Scheme prefix of `uri` (the word characters before "://"), or "unknown".
---@param uri string
---@return string
function M.get_scheme(uri)
  return uri:match("^(%w+)://") or "unknown"
end
--- Translate a host-side URI into the path the docker container sees.
--- Files under `Config.rag_service.host_mount` are mounted at `/host`, so
--- `file://<host_mount>/a/b` becomes `file:///host/a/b`. Non-file URIs, files
--- outside the mount, and every URI under the "nix" runner are returned unchanged.
---@param uri string
---@return string
function M.to_container_uri(uri)
  if M.get_rag_service_runner() == "nix" then return uri end
  if M.get_scheme(uri) ~= "file" then return uri end
  local host_mount = Config.rag_service.host_mount
  if host_mount == nil then return uri end
  local path = uri:match("^file://(.*)$")
  -- Ignore trailing slashes so "/home/me/" behaves like "/home/me" ("/" becomes "").
  local host_dir = host_mount:gsub("/+$", "")
  -- Require a path boundary: "/home/me2" must not match host_dir "/home/me".
  if path == host_dir or path:sub(1, #host_dir + 1) == host_dir .. "/" then
    path = "/host" .. path:sub(#host_dir + 1)
  end
  return string.format("file://%s", path)
end
--- Inverse of to_container_uri: map `file:///host/...` back under
--- `Config.rag_service.host_mount`. Every other URI is returned unchanged.
---@param uri string
---@return string
function M.to_local_uri(uri)
  -- to_container_uri never rewrites URIs for the nix runner, so there is nothing to undo.
  if M.get_rag_service_runner() == "nix" then return uri end
  if M.get_scheme(uri) ~= "file" then return uri end
  local path = uri:match("^file:///host(.*)$")
  -- Only "/host" itself or paths below it belong to the mount, not "/hostname/...".
  if path == nil or (path ~= "" and path:sub(1, 1) ~= "/") then return uri end
  local host_dir = Config.rag_service.host_mount
  local full_path = Path:new(host_dir):joinpath(path:sub(2)):absolute()
  return string.format("file://%s", full_path)
end
--- Probe the service's `/api/health` endpoint once (blocking).
---@return boolean # true when curl gets an HTTP response with a status below 400
function M.is_ready()
  local result = vim
    .system(
      { "curl", "-s", "-o", "/dev/null", "-w", "%{http_code}", M.get_rag_service_url() .. "/api/health" },
      { text = true }
    )
    :wait()
  if result.code ~= 0 then return false end
  -- curl exits 0 for any HTTP response; the status code is what `-w %{http_code}` printed.
  local status = tonumber(vim.trim(result.stdout or ""))
  return status ~= nil and status >= 200 and status < 400
end
--- Normalize a URI for comparison by dropping trailing slashes.
---@param s string
---@return string
local function strip_trailing_slashes(s) return (s:gsub("/+$", "")) end

---@class AvanteRagServiceAddResourceResponse
---@field status string
---@field message string

--- Register `uri` with the RAG service for indexing.
--- The name defaults to the last path segment of a URI that ends in "/". If the URI
--- is already registered, its existing name is reused. Otherwise a taken name gets
--- the first free "-1" … "-100" suffix. The POST is asynchronous and its outcome is
--- only logged.
---@param uri string
function M.add_resource(uri)
  uri = M.to_container_uri(uri)
  local resource_name = uri:match("([^/]+)/$")
  local resources_resp = M.get_resources()
  if resources_resp == nil then
    Utils.error("Failed to get resources")
    return nil
  end
  local resources = resources_resp.resources or {}
  local target = strip_trailing_slashes(uri)
  local already_added = false
  for _, resource in ipairs(resources) do
    -- get_resources() hands back host-side URIs while `uri` is now in container
    -- form, so convert before comparing (a no-op for the nix runner).
    if strip_trailing_slashes(M.to_container_uri(resource.uri)) == target then
      already_added = true
      resource_name = resource.name
      break
    end
  end
  if not already_added then
    local names_map = {}
    for _, resource in ipairs(resources) do
      names_map[resource.name] = true
    end
    if names_map[resource_name] then
      for i = 1, 100 do
        local candidate = string.format("%s-%d", resource_name, i)
        if not names_map[candidate] then
          resource_name = candidate
          break
        end
      end
      if names_map[resource_name] then
        Utils.error(string.format("Failed to add resource, name conflict: %s", resource_name))
        return nil
      end
    end
  end
  local cmd = {
    "curl",
    -- -sS: no progress meter on stderr, but keep curl's error messages.
    -- --fail: exit non-zero on HTTP errors (curl otherwise exits 0 for any status).
    "-sS",
    "--fail",
    "-X",
    "POST",
    M.get_rag_service_url() .. "/api/v1/add_resource",
    "-H",
    "Content-Type: application/json",
    "-d",
    vim.json.encode({ name = resource_name, uri = uri }),
  }
  vim.system(cmd, { text = true }, function(output)
    if output.code == 0 then
      Utils.debug(string.format("Added resource: %s", uri))
    else
      Utils.error(string.format("Failed to add resource: %s; output: %s", uri, output.stderr))
    end
  end)
end
--- Ask the service to drop the resource registered under `uri` (blocking request).
---@param uri string host-side URI; translated to the container view first
---@return table|nil # decoded JSON response, or nil after logging a non-200 answer
function M.remove_resource(uri)
  local payload = vim.json.encode({ uri = M.to_container_uri(uri) })
  local endpoint = M.get_rag_service_url() .. "/api/v1/remove_resource"
  local resp = curl.post(endpoint, {
    headers = { ["Content-Type"] = "application/json" },
    body = payload,
  })
  if resp.status == 200 then return vim.json.decode(resp.body) end
  Utils.error("failed to remove resource: " .. resp.body)
end
---@class AvanteRagServiceRetrieveSource
---@field uri string
---@field content string
---@class AvanteRagServiceRetrieveResponse
---@field response string
---@field sources AvanteRagServiceRetrieveSource[]

--- Query the RAG service asynchronously (top 10 matches under `base_uri`).
--- On success `on_complete` gets the decoded response, with source URIs mapped back
--- to host paths. On an HTTP error, an undecodable body, or a curl failure reported
--- through plenary's `on_error`, it gets `nil` plus an error string.
---@param base_uri string
---@param query string
---@param on_complete fun(resp: AvanteRagServiceRetrieveResponse | nil, error: string | nil): nil
function M.retrieve(base_uri, query, on_complete)
  base_uri = M.to_container_uri(base_uri)
  curl.post(M.get_rag_service_url() .. "/api/v1/retrieve", {
    headers = {
      ["Content-Type"] = "application/json",
    },
    body = vim.json.encode({
      base_uri = base_uri,
      query = query,
      top_k = 10,
    }),
    timeout = 100000,
    callback = function(resp)
      if resp.status ~= 200 then
        Utils.error("failed to retrieve: " .. tostring(resp.body))
        on_complete(nil, resp.body)
        return
      end
      local ok, jsn = pcall(vim.json.decode, resp.body)
      if not ok or type(jsn) ~= "table" then
        local detail
        if ok then
          detail = tostring(resp.body)
        else
          detail = tostring(jsn)
        end
        local err = "failed to decode retrieve response: " .. detail
        Utils.error(err)
        on_complete(nil, err)
        return
      end
      jsn.sources = vim
        .iter(jsn.sources or {})
        :map(function(source)
          local uri = M.to_local_uri(source.uri)
          return vim.tbl_deep_extend("force", source, { uri = uri })
        end)
        :totable()
      on_complete(jsn, nil)
    end,
    -- When curl itself fails (e.g. connection refused), plenary calls this hook
    -- instead of raising, so the caller still receives a result.
    on_error = function(err)
      local msg = "failed to retrieve: " .. tostring(err and err.message)
      Utils.error(msg)
      on_complete(nil, msg)
    end,
  })
end
---@class AvanteRagServiceIndexingStatusSummary
---@field indexing integer
---@field completed integer
---@field failed integer
---@class AvanteRagServiceIndexingStatusResponse
---@field uri string
---@field is_watched boolean
---@field total_files integer
---@field status_summary AvanteRagServiceIndexingStatusSummary

--- Fetch indexing progress for `uri` (blocking request).
--- The returned `uri` field is translated back to its host-side form.
---@param uri string
---@return AvanteRagServiceIndexingStatusResponse | nil
function M.indexing_status(uri)
  local request_body = vim.json.encode({ uri = M.to_container_uri(uri) })
  local resp = curl.post(M.get_rag_service_url() .. "/api/v1/indexing_status", {
    headers = { ["Content-Type"] = "application/json" },
    body = request_body,
  })
  if resp.status ~= 200 then
    Utils.error("Failed to get indexing status: " .. resp.body)
    return
  end
  local status = vim.json.decode(resp.body)
  status.uri = M.to_local_uri(status.uri)
  return status
end
---@class AvanteRagServiceResource
---@field name string
---@field uri string
---@field type string
---@field status string
---@field indexing_status string
---@field created_at string
---@field indexing_started_at string | nil
---@field last_indexed_at string | nil
---@class AvanteRagServiceResourceListResponse
---@field resources AvanteRagServiceResource[]
---@field total_count number

--- List registered resources (blocking request), with each resource URI
--- translated back to its host-side form.
---@return AvanteRagServiceResourceListResponse | nil
function M.get_resources()
  local url = M.get_rag_service_url() .. "/api/v1/resources"
  local resp = curl.get(url, { headers = { ["Content-Type"] = "application/json" } })
  if resp.status ~= 200 then
    Utils.error("Failed to get resources: " .. resp.body)
    return
  end
  local listing = vim.json.decode(resp.body)
  listing.resources = vim.tbl_map(function(resource)
    return vim.tbl_deep_extend("force", resource, { uri = M.to_local_uri(resource.uri) })
  end, listing.resources)
  return listing
end
-- Export the module table (required as "avante.rag_service").
return M