* Added nbconvert needed for rag indexing jupyter notebooks * Allow rag service image to be configured * [pre-commit.ci lite] apply automatic fixes --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
449 lines
14 KiB
Lua
449 lines
14 KiB
Lua
local curl = require("plenary.curl")
|
|
local Path = require("plenary.path")
|
|
local Config = require("avante.config")
|
|
local Utils = require("avante.utils")
|
|
|
|
-- Module table; the public functions of this file are attached to it.
local M = {}

-- Name given to the docker container (also used in log messages).
local container_name = "avante-rag-service"

-- Path derived from the container name. The non-docker runner uses it as the
-- `pgrep -f` pattern, as the argument to run.sh, and as DATA_DIR.
local service_path = "/tmp/" .. container_name
|
|
|
|
--- Image used for the RAG service container.
--- Prefers `Config.rag_service.image` and falls back to the built-in default.
---@return string
function M.get_rag_service_image()
  local rag = Config.rag_service
  local configured = rag and rag.image
  if configured then return configured end
  return "quay.io/yetoneful/avante-rag-service:0.0.11"
end
|
|
|
|
--- Port the RAG service is published on.
---@return integer
function M.get_rag_service_port()
  local port = 20250
  return port
end
|
|
|
|
--- Base URL of the RAG service on localhost, without a trailing slash.
---@return string
function M.get_rag_service_url()
  local port = M.get_rag_service_port()
  return string.format("http://localhost:%d", port)
end
|
|
|
|
--- Directory under Neovim's "data" stdpath where the service keeps its data.
--- The directory (and any missing parents) is created on first use.
---@return table path A plenary Path object
function M.get_data_path()
  local data_root = vim.fn.stdpath("data")
  local dir = Path:new(data_root):joinpath("avante/rag_service")
  if dir:exists() then return dir end
  dir:mkdir({ parents = true })
  return dir
end
|
|
|
|
--- Image the existing service container was created from.
--- Runs `docker inspect` on the container and returns the reported image name.
---@return string | nil image nil when the command fails or prints nothing
function M.get_current_image()
  local cmd = { "docker", "inspect", "--format", "{{.Config.Image}}", container_name }
  local result = vim.system(cmd, { text = true }):wait()
  if result.code ~= 0 then return nil end
  -- vim.system does not strip the trailing newline docker prints; without the
  -- trim the result can never compare equal to a configured image name.
  local image = vim.trim(result.stdout or "")
  if image == "" then return nil end
  return image
end
|
|
|
|
--- Runner used to host the RAG service.
--- Returns `Config.rag_service.runner` when set, otherwise "docker".
function M.get_rag_service_runner()
  local rag = Config.rag_service
  if rag and rag.runner then return rag.runner end
  return "docker"
end
|
|
|
|
--- Resolve the API key for one section ("llm" or "embed") of Config.rag_service.
--- The configured `api_key` is the NAME of an environment variable; its value is returned.
--- Raises an error when a name is configured but the variable is unset or empty.
---@param section_name string
---@return string key "" when no variable name is configured
local function resolve_api_key(section_name)
  local section = Config.rag_service and Config.rag_service[section_name]
  local env_name = section and section.api_key
  if not env_name or env_name == "" then return "" end
  local value = os.getenv(env_name)
  if value == nil or value == "" then
    error(string.format("cannot launch avante rag service, %s is not set", env_name))
  end
  return value
end

--- JSON-encode the `extra` table of one section ("llm" or "embed").
--- The result is plain JSON text; any quoting needed by a shell is applied by the caller.
---@param section_name string
---@return string json "{}" when the section has no `extra`
local function encode_extra(section_name)
  local section = Config.rag_service and Config.rag_service[section_name]
  if section and section.extra then return vim.json.encode(section.extra) end
  return "{}"
end

--- Start the service as a docker container.
--- A container that is already running from the configured image is reused;
--- any other existing container is removed before `docker run` is issued.
---@param cb fun() called once the container is known to be running / started
---@param settings table port, api keys and encoded extras (see M.launch_rag_service)
local function launch_with_docker(cb, settings)
  local image = M.get_rag_service_image()
  local data_path = M.get_data_path()
  local cmd = { "docker", "inspect", "--format", "{{.State.Status}}", container_name }
  local result = vim.system(cmd, { text = true }):wait()
  if result.code ~= 0 then Utils.debug(string.format("cmd: %s execution error", table.concat(cmd, " "))) end

  -- docker terminates its output with a newline and vim.system keeps it, so
  -- the status has to be trimmed before it is compared.
  local status = vim.trim(result.stdout or "")
  if status == "" then
    Utils.debug(string.format("container %s not found, starting...", container_name))
  elseif status == "running" then
    Utils.debug(string.format("container %s already running", container_name))
    local current_image = M.get_current_image()
    if type(current_image) == "string" then current_image = vim.trim(current_image) end
    if current_image == image then
      cb()
      return
    end
    Utils.debug(
      string.format(
        "container %s is running with different image: %s != %s, stopping...",
        container_name,
        tostring(current_image),
        tostring(image)
      )
    )
    M.stop_rag_service()
  else
    Utils.info(string.format("container %s already started but not running, stopping...", container_name))
    M.stop_rag_service()
  end

  local rag = Config.rag_service
  -- The command is handed to jobstart() as a string and therefore parsed by
  -- the user's shell: every interpolated value is shell-escaped so that
  -- spaces, quotes, `$` etc. in paths, keys or JSON survive unchanged.
  local function quote(value) return vim.fn.shellescape(tostring(value)) end
  local function env_flag(name, value) return "-e " .. quote(name .. "=" .. tostring(value)) end
  local cmd_ = table.concat({
    "docker run --platform=linux/amd64 -d",
    string.format("-p 0.0.0.0:%d:%d", settings.port, settings.port),
    "--name " .. container_name,
    "-v " .. quote(tostring(data_path) .. ":/data"),
    "-v " .. quote(tostring(rag.host_mount) .. ":/host:ro"),
    "-e ALLOW_RESET=TRUE",
    "-e DATA_DIR=/data",
    env_flag("RAG_EMBED_PROVIDER", rag.embed.provider),
    env_flag("RAG_EMBED_ENDPOINT", rag.embed.endpoint),
    env_flag("RAG_EMBED_API_KEY", settings.embed_api_key),
    env_flag("RAG_EMBED_MODEL", rag.embed.model),
    env_flag("RAG_EMBED_EXTRA", settings.embed_extra),
    env_flag("RAG_LLM_PROVIDER", rag.llm.provider),
    env_flag("RAG_LLM_ENDPOINT", rag.llm.endpoint),
    env_flag("RAG_LLM_API_KEY", settings.llm_api_key),
    env_flag("RAG_LLM_MODEL", rag.llm.model),
    env_flag("RAG_LLM_EXTRA", settings.llm_extra),
    -- Deliberately left unescaped: this is a user-supplied string of extra
    -- command-line arguments. An unset value must not turn into "nil".
    tostring(rag.docker_extra_args or ""),
    quote(image),
  }, " ")

  vim.fn.jobstart(cmd_, {
    detach = true,
    on_exit = function(_, exit_code)
      if exit_code ~= 0 then
        Utils.error(string.format("container %s failed to start, exit code: %d", container_name, exit_code))
      else
        Utils.debug(string.format("container %s started", container_name))
        cb()
      end
    end,
  })
end

--- Start the service through `run.sh` in the plugin's py/rag-service directory.
--- Nothing is started when `pgrep -f` already finds a process matching `service_path`.
---@param cb fun() called when a process is already running, or when run.sh exits with code 0
---@param settings table port, api keys and encoded extras (see M.launch_rag_service)
local function launch_with_nix(cb, settings)
  local check_cmd = { "pgrep", "-f", service_path }
  local check_result = vim.system(check_cmd, { text = true }):wait().stdout
  if check_result ~= "" then
    Utils.debug(string.format("RAG service already running at %s", service_path))
    cb()
    return
  end

  -- Derive the plugin root from this file's own location by cutting the
  -- "/lua/avante/rag_service.lua" tail off the chunk source path.
  local dirname =
    Utils.trim(string.sub(debug.getinfo(1).source, 2, #"/lua/avante/rag_service.lua" * -1), { suffix = "/" })
  local rag_service_dir = dirname .. "/py/rag-service"

  Utils.debug(string.format("launching %s with nix...", container_name))

  local rag = Config.rag_service
  vim.system({ "sh", "run.sh", service_path }, {
    detach = true,
    cwd = rag_service_dir,
    env = {
      ALLOW_RESET = "TRUE",
      PORT = settings.port,
      DATA_DIR = service_path,
      RAG_EMBED_PROVIDER = rag.embed.provider,
      RAG_EMBED_ENDPOINT = rag.embed.endpoint,
      RAG_EMBED_API_KEY = settings.embed_api_key,
      RAG_EMBED_MODEL = rag.embed.model,
      -- Plain JSON: no shell is involved here, so no quoting must be added.
      RAG_EMBED_EXTRA = settings.embed_extra,
      RAG_LLM_PROVIDER = rag.llm.provider,
      RAG_LLM_ENDPOINT = rag.llm.endpoint,
      RAG_LLM_API_KEY = settings.llm_api_key,
      RAG_LLM_MODEL = rag.llm.model,
      RAG_LLM_EXTRA = settings.llm_extra,
    },
  }, function(res)
    if res.code ~= 0 then
      Utils.error(string.format("service %s failed to start, exit code: %d", container_name, res.code))
    else
      Utils.debug(string.format("service %s started", container_name))
      cb()
    end
  end)
end

--- Launch the RAG service with the configured runner ("docker" or "nix").
--- Raises an error when a configured API-key environment variable is unset or empty.
--- Any other runner value results in no action.
---@param cb fun()
function M.launch_rag_service(cb)
  local settings = {
    port = M.get_rag_service_port(),
    llm_api_key = resolve_api_key("llm"),
    embed_api_key = resolve_api_key("embed"),
    embed_extra = encode_extra("embed"),
    llm_extra = encode_extra("llm"),
  }

  local runner = M.get_rag_service_runner()
  if runner == "docker" then
    launch_with_docker(cb, settings)
  elseif runner == "nix" then
    launch_with_nix(cb, settings)
  end
end
|
|
|
|
--- Stop the RAG service.
--- docker runner: force-removes the container (`docker rm -fv`) when `docker inspect` reports a status for it.
--- any other runner: sends SIGKILL to every process that `pgrep -f` matches against `service_path`.
function M.stop_rag_service()
  if M.get_rag_service_runner() == "docker" then
    local cmd = { "docker", "inspect", "--format", "{{.State.Status}}", container_name }
    -- Trim so that a bare newline is not mistaken for an existing container.
    local status = vim.trim(vim.system(cmd, { text = true }):wait().stdout or "")
    if status ~= "" then vim.system({ "docker", "rm", "-fv", container_name }):wait() end
    return
  end

  local output = vim.system({ "pgrep", "-f", service_path }, { text = true }):wait().stdout or ""
  -- pgrep prints one PID per line, each newline-terminated. Each PID must be
  -- its own argument; passing the raw output as one argument makes kill fail.
  local kill_cmd = { "kill", "-9" }
  for pid in output:gmatch("%d+") do
    kill_cmd[#kill_cmd + 1] = pid
  end
  if #kill_cmd > 2 then
    vim.system(kill_cmd):wait()
    Utils.debug(string.format("Attempted to kill processes related to %s", service_path))
  end
end
|
|
|
|
--- Report whether the RAG service is running for the configured runner.
--- docker: "running" only when `docker inspect` reports the status "running".
--- nix: "running" when `pgrep -f` finds a process matching `service_path`.
---@return string | nil status "running" or "stopped"; nil for any other runner
function M.get_rag_service_status()
  local runner = M.get_rag_service_runner()
  if runner == "docker" then
    local cmd = { "docker", "inspect", "--format", "{{.State.Status}}", container_name }
    -- The output ends with a newline; compare the trimmed text, otherwise the
    -- result would always be "stopped".
    local status = vim.trim(vim.system(cmd, { text = true }):wait().stdout or "")
    if status == "running" then return "running" end
    return "stopped"
  elseif runner == "nix" then
    local cmd = { "pgrep", "-f", service_path }
    local pids = vim.trim(vim.system(cmd, { text = true }):wait().stdout or "")
    if pids == "" then return "stopped" end
    return "running"
  end
end
|
|
|
|
--- Extract the scheme of a URI (the alphanumeric run in front of "://").
---@param uri string
---@return string scheme "unknown" when the URI does not start with `<scheme>://`
function M.get_scheme(uri)
  return uri:match("^(%w+)://") or "unknown"
end
|
|
|
|
--- Translate a local URI into the form used inside the service container.
--- Only `file://` URIs located under `Config.rag_service.host_mount` are rewritten
--- (the mount is exposed as `/host`); everything else is returned unchanged.
--- The nix runner never rewrites.
---@param uri string
---@return string
function M.to_container_uri(uri)
  if M.get_rag_service_runner() == "nix" then return uri end
  if M.get_scheme(uri) ~= "file" then return uri end

  local path = uri:match("^file://(.*)$")
  local host_dir = Config.rag_service.host_mount
  if type(host_dir) ~= "string" then return uri end

  -- Ignore trailing slashes on the mount so "/home/me/" behaves like
  -- "/home/me" and "/" maps the whole filesystem below /host.
  local prefix = (host_dir:gsub("/+$", ""))
  local rest = path:sub(#prefix + 1)
  -- Match whole path components only: "/home/me" must not claim "/home/meow".
  local at_boundary = rest == "" or rest:sub(1, 1) == "/"
  if path:sub(1, #prefix) == prefix and at_boundary then path = "/host" .. rest end

  return string.format("file://%s", path)
end
|
|
|
|
--- Translate a container URI (`file:///host/...`) back to a local `file://` URI
--- rooted at `Config.rag_service.host_mount`. Other URIs are returned unchanged.
--- The nix runner never rewrites, mirroring M.to_container_uri.
---@param uri string
---@return string
function M.to_local_uri(uri)
  if M.get_rag_service_runner() == "nix" then return uri end

  local path = uri:match("^file:///host(.*)$")
  if path == nil then return uri end
  -- "/host" has to be a complete path component; "file:///hostname/x" is not
  -- a container path and must be left alone.
  if path ~= "" and path:sub(1, 1) ~= "/" then return uri end

  local host_dir = Config.rag_service.host_mount
  local full_path = Path:new(host_dir):joinpath(path:sub(2)):absolute()
  return string.format("file://%s", full_path)
end
|
|
|
|
--- Probe the service's `/api/health` endpoint with curl.
--- Only curl's exit code is evaluated; the HTTP status code that `-w` writes
--- to stdout is not inspected.
---@return boolean ready true when curl exited with code 0
function M.is_ready()
  local health_url = M.get_rag_service_url() .. "/api/health"
  local probe = { "curl", "-s", "-o", "/dev/null", "-w", "%{http_code}", health_url }
  local outcome = vim.system(probe, { text = true }):wait()
  return outcome.code == 0
end
|
|
|
|
---@class AvanteRagServiceAddResourceResponse
---@field status string
---@field message string

--- Register a URI with the RAG service (asynchronous POST to /api/v1/add_resource).
--- The resource name defaults to the last path segment of the URI. When the URI
--- is already registered its existing name is reused; when only the name is
--- taken, a numeric suffix ("-1" .. "-100") is appended.
---@param uri string
---@return nil
function M.add_resource(uri)
  uri = M.to_container_uri(uri)
  -- Accept the last path segment with or without a trailing slash, so the
  -- name is not nil for URIs that do not end in "/".
  local resource_name = uri:match("([^/]+)/?$")
  local resources_resp = M.get_resources()
  if resources_resp == nil then
    Utils.error("Failed to get resources")
    return nil
  end

  -- M.get_resources() reports URIs in their local form, so the container URI
  -- alone would never match an existing entry; compare against both forms.
  local local_uri = M.to_local_uri(uri)
  local already_added = false
  local names_map = {}
  for _, resource in ipairs(resources_resp.resources) do
    names_map[resource.name] = true
    if not already_added and (resource.uri == uri or resource.uri == local_uri) then
      already_added = true
      resource_name = resource.name
    end
  end

  if not already_added and resource_name ~= nil and names_map[resource_name] then
    local base_name = resource_name
    for i = 1, 100 do
      local candidate = string.format("%s-%d", base_name, i)
      if not names_map[candidate] then
        resource_name = candidate
        break
      end
    end
    -- Still colliding: all 100 suffixed names are taken as well.
    if names_map[resource_name] then
      Utils.error(string.format("Failed to add resource, name conflict: %s", resource_name))
      return nil
    end
  end

  local cmd = {
    "curl",
    "-X",
    "POST",
    M.get_rag_service_url() .. "/api/v1/add_resource",
    "-H",
    "Content-Type: application/json",
    "-d",
    vim.json.encode({ name = resource_name, uri = uri }),
  }
  vim.system(cmd, { text = true }, function(output)
    if output.code == 0 then
      Utils.debug(string.format("Added resource: %s", uri))
    else
      Utils.error(string.format("Failed to add resource: %s; output: %s", uri, output.stderr))
    end
  end)
end
|
|
|
|
--- Ask the service to remove a resource (POST to /api/v1/remove_resource).
---@param uri string local URI; converted to its container form before sending
---@return table | nil decoded The decoded JSON body on HTTP 200, otherwise nothing
function M.remove_resource(uri)
  local container_uri = M.to_container_uri(uri)
  local endpoint = M.get_rag_service_url() .. "/api/v1/remove_resource"
  local request = {
    headers = { ["Content-Type"] = "application/json" },
    body = vim.json.encode({ uri = container_uri }),
  }

  local resp = curl.post(endpoint, request)
  if resp.status == 200 then return vim.json.decode(resp.body) end

  Utils.error("failed to remove resource: " .. resp.body)
end
|
|
|
|
---@class AvanteRagServiceRetrieveSource
---@field uri string
---@field content string

---@class AvanteRagServiceRetrieveResponse
---@field response string
---@field sources AvanteRagServiceRetrieveSource[]

--- Query the service (POST to /api/v1/retrieve, top_k = 10) and hand the
--- result to `on_complete`. Source URIs in the response are converted back to
--- local URIs. On a non-200 status the response body is passed as the error.
---@param base_uri string
---@param query string
---@param on_complete fun(resp: AvanteRagServiceRetrieveResponse | nil, error: string | nil): nil
function M.retrieve(base_uri, query, on_complete)
  local container_base = M.to_container_uri(base_uri)

  -- Copy of a source entry with its uri mapped to the local form.
  local function localize_source(source)
    return vim.tbl_deep_extend("force", source, { uri = M.to_local_uri(source.uri) })
  end

  local function handle_response(resp)
    if resp.status == 200 then
      local decoded = vim.json.decode(resp.body)
      decoded.sources = vim.iter(decoded.sources):map(localize_source):totable()
      on_complete(decoded, nil)
      return
    end
    Utils.error("failed to retrieve: " .. resp.body)
    on_complete(nil, resp.body)
  end

  curl.post(M.get_rag_service_url() .. "/api/v1/retrieve", {
    headers = { ["Content-Type"] = "application/json" },
    body = vim.json.encode({ base_uri = container_base, query = query, top_k = 10 }),
    timeout = 100000,
    callback = handle_response,
  })
end
|
|
|
|
---@class AvanteRagServiceIndexingStatusSummary
---@field indexing integer
---@field completed integer
---@field failed integer

---@class AvanteRagServiceIndexingStatusResponse
---@field uri string
---@field is_watched boolean
---@field total_files integer
---@field status_summary AvanteRagServiceIndexingStatusSummary

--- Fetch the indexing status of a resource (POST to /api/v1/indexing_status).
--- The `uri` field of the decoded response is converted back to a local URI.
---@param uri string
---@return AvanteRagServiceIndexingStatusResponse | nil
function M.indexing_status(uri)
  local container_uri = M.to_container_uri(uri)
  local endpoint = M.get_rag_service_url() .. "/api/v1/indexing_status"
  local request = {
    headers = { ["Content-Type"] = "application/json" },
    body = vim.json.encode({ uri = container_uri }),
  }

  local resp = curl.post(endpoint, request)
  if resp.status == 200 then
    local status = vim.json.decode(resp.body)
    status.uri = M.to_local_uri(status.uri)
    return status
  end

  Utils.error("Failed to get indexing status: " .. resp.body)
end
|
|
|
|
---@class AvanteRagServiceResource
---@field name string
---@field uri string
---@field type string
---@field status string
---@field indexing_status string
---@field created_at string
---@field indexing_started_at string | nil
---@field last_indexed_at string | nil

---@class AvanteRagServiceResourceListResponse
---@field resources AvanteRagServiceResource[]
---@field total_count number

--- List the resources known to the service (GET /api/v1/resources).
--- Every resource's `uri` is converted back to a local URI before returning.
---@return AvanteRagServiceResourceListResponse | nil
function M.get_resources()
  local endpoint = M.get_rag_service_url() .. "/api/v1/resources"
  local resp = curl.get(endpoint, {
    headers = { ["Content-Type"] = "application/json" },
  })

  if resp.status == 200 then
    -- Copy of a resource entry with its uri mapped to the local form.
    local function localize_resource(resource)
      return vim.tbl_deep_extend("force", resource, { uri = M.to_local_uri(resource.uri) })
    end

    local listing = vim.json.decode(resp.body)
    listing.resources = vim.iter(listing.resources):map(localize_resource):totable()
    return listing
  end

  Utils.error("Failed to get resources: " .. resp.body)
end
|
|
|
|
-- Hand the module table back to whoever loads this file.
return M
|