Files
avante.nvim/lua/avante/rag_service.lua
2025-02-23 01:37:26 +08:00

303 lines
8.4 KiB
Lua

local curl = require("plenary.curl")
local Path = require("plenary.path")
local Utils = require("avante.utils")
local M = {}
-- Name given to the Docker container via `--name`; the same string is used to find the
-- container in `docker ps -a` / `docker inspect` output and to remove it.
local container_name = "avante-rag-service"
--- Docker image reference (including its pinned tag) used for the RAG service container.
---@return string image
function M.get_rag_service_image()
  local image = "ghcr.io/yetone/avante-rag-service:0.0.3"
  return image
end
--- Host port that the RAG service container is published on.
---@return integer port
function M.get_rag_service_port()
  local port = 20250
  return port
end
--- Base URL of the RAG service as seen from the host (localhost plus the service port).
---@return string url
function M.get_rag_service_url()
  local port = M.get_rag_service_port()
  return ("http://localhost:%d"):format(port)
end
--- Returns the directory holding the RAG service data, creating it (with parents) if missing.
--- The directory is `avante/rag_service` below Neovim's `stdpath("data")`.
---@return table data_dir plenary Path object for the directory
function M.get_data_path()
  local data_root = Path:new(vim.fn.stdpath("data"))
  local data_dir = data_root:joinpath("avante/rag_service")
  if data_dir:exists() then return data_dir end
  data_dir:mkdir({ parents = true })
  return data_dir
end
--- Looks up the image recorded for the existing RAG service container.
--- Shells out to `docker inspect` and keeps only the "Image" line that mentions the container name.
---@return string | nil image nil when the command prints nothing, exits non-zero, or no "Image" value can be parsed
function M.get_current_image()
  local inspect_cmd = ("docker inspect %s | grep Image | grep %s"):format(container_name, container_name)
  local output = vim.fn.system(inspect_cmd)
  -- Empty output and a failing pipeline both mean there is nothing to parse.
  if output == "" or vim.v.shell_error ~= 0 then return nil end
  -- Extra parentheses keep the result to exactly one value.
  return (output:match('"Image":%s*"(.*)"'))
end
--- Ensures the RAG service container exists with the expected image, creating it when needed.
---
--- Reads OPENAI_API_KEY (required) and OPENAI_BASE_URL (optional, defaults to
--- "https://api.openai.com/v1") from the environment and passes them to the container.
--- An existing container created from a different image is removed and replaced.
--- Raises an error when OPENAI_API_KEY is not set.
---@return boolean launched true when this call started a new container; false when a container
--- with the expected image already exists or `docker run` failed
function M.launch_rag_service()
  local openai_api_key = os.getenv("OPENAI_API_KEY")
  if openai_api_key == nil then error("cannot launch avante rag service, OPENAI_API_KEY is not set") end
  local openai_base_url = os.getenv("OPENAI_BASE_URL") or "https://api.openai.com/v1"
  local port = M.get_rag_service_port()
  local image = M.get_rag_service_image()
  local data_path = M.get_data_path()
  -- NOTE(review): `docker ps -a` also lists stopped containers, so a match only proves that the
  -- container exists; a stopped container with the right image is left as it is.
  local result = vim.fn.system(string.format("docker ps -a | grep '%s'", container_name))
  if result ~= "" then
    Utils.debug(string.format("container %s already running", container_name))
    local current_image = M.get_current_image()
    if current_image == image then return false end
    Utils.debug(
      string.format(
        "container %s is running with different image: %s != %s, stopping...",
        container_name,
        tostring(current_image),
        image
      )
    )
    M.stop_rag_service()
  else
    Utils.debug(string.format("container %s not found, starting...", container_name))
  end
  -- Pass the command as a list so it runs without a shell: the API key, base URL and data
  -- path may contain spaces or shell metacharacters and must reach docker verbatim.
  local output = vim.fn.system({
    "docker",
    "run",
    "-d",
    "-p",
    string.format("%d:8000", port),
    "--name",
    container_name,
    "-v",
    tostring(data_path) .. ":/data",
    "-v",
    "/:/host",
    "-e",
    "DATA_DIR=/data",
    "-e",
    "OPENAI_API_KEY=" .. openai_api_key,
    "-e",
    "OPENAI_BASE_URL=" .. openai_base_url,
    image,
  })
  if vim.v.shell_error ~= 0 then
    Utils.error(string.format("failed to start container %s: %s", container_name, tostring(output)))
    return false
  end
  Utils.debug(string.format("container %s started", container_name))
  return true
end
--- Removes the RAG service container with `docker rm -fv` when `docker ps -a` lists it.
--- Does nothing when no matching container is found.
function M.stop_rag_service()
  local listing = vim.fn.system(("docker ps -a | grep '%s'"):format(container_name))
  if listing == "" then return end
  vim.fn.system(("docker rm -fv %s"):format(container_name))
end
--- Reports whether the RAG service container is present according to `docker ps -a`.
---@return "running" | "stopped" status
function M.get_rag_service_status()
  local cmd = string.format("docker ps -a | grep '%s'", container_name)
  local result = vim.fn.system(cmd)
  -- grep prints nothing when no container matches the name, i.e. the service is not there.
  if result == "" then return "stopped" end
  -- NOTE(review): `-a` also lists exited containers, so "running" here means "container
  -- exists", which is the same notion launch_rag_service uses.
  return "running"
end
--- Extracts the scheme (the alphanumeric run before "://") from the start of a URI.
---@param uri string
---@return string scheme the scheme, or "unknown" when the URI does not start with one
function M.get_scheme(uri)
  return uri:match("^(%w+)://") or "unknown"
end
--- Rewrites a `file://` URI so that its path is prefixed with `/host`, the location at which
--- the container is given the host root (`-v /:/host`). Other schemes pass through untouched.
---@param uri string
---@return string
function M.to_container_uri(uri)
  if M.get_scheme(uri) ~= "file" then return uri end
  local host_path = uri:match("^file://(.*)$")
  local container_uri = ("file:///host%s"):format(host_path)
  return container_uri
end
--- Maps a container-side `file:///host/...` URI back to the host-side `file://...` form.
--- URIs with another scheme, and file URIs that do not point below `/host`, are returned unchanged.
---@param uri string
---@return string
function M.to_local_uri(uri)
  if M.get_scheme(uri) ~= "file" then return uri end
  -- to_container_uri emits "file:///host<path>" (three slashes). The optional third slash
  -- keeps the previously accepted "file://host<path>" spelling working as well.
  local path = uri:match("^file:///?host(.*)$")
  -- Require a path boundary after "host" so that e.g. "/hostname/x" is not mistaken for the
  -- mount point, and never format a nil capture into the result.
  if path == nil or (path ~= "" and path:sub(1, 1) ~= "/") then return uri end
  return "file://" .. path
end
--- Probes the service URL with curl and reports whether the probe succeeded.
--- Only curl's exit status is consulted; the HTTP status code it prints is not inspected.
---@return boolean ready
function M.is_ready()
  local probe_cmd = ("curl -s -o /dev/null -w '%%{http_code}' %s"):format(M.get_rag_service_url())
  vim.fn.system(probe_cmd)
  local exit_code = vim.v.shell_error
  return exit_code == 0
end
---@class AvanteRagServiceAddResourceResponse
---@field status string
---@field message string

--- Registers `uri` as a resource with the RAG service.
---
--- The resource name defaults to the last path segment of the URI. When the URI is already
--- registered its existing name is reused; otherwise a name clash with another resource is
--- resolved by appending "-1" .. "-100".
---@param uri string
---@return AvanteRagServiceAddResourceResponse | nil response nil when the resource list cannot
--- be fetched, no free name is found, or the service does not answer with status 200
function M.add_resource(uri)
  -- get_resources() reports URIs converted back to their local form, so keep the caller's
  -- URI around for the duplicate check instead of comparing only the container form.
  local local_uri = uri
  uri = M.to_container_uri(uri)
  -- Last path segment, with or without a trailing slash.
  local resource_name = uri:match("([^/]+)/?$")
  local resources_resp = M.get_resources()
  if resources_resp == nil then
    Utils.error("Failed to get resources")
    return nil
  end
  local already_added = false
  for _, resource in ipairs(resources_resp.resources) do
    if resource.uri == local_uri or resource.uri == uri then
      already_added = true
      resource_name = resource.name
      break
    end
  end
  if not already_added then
    local names_map = {}
    for _, resource in ipairs(resources_resp.resources) do
      if resource.name ~= nil then names_map[resource.name] = true end
    end
    if names_map[resource_name] then
      -- Try "<name>-1" .. "<name>-100" and take the first one that is still free.
      local free_name = nil
      for i = 1, 100 do
        local candidate = string.format("%s-%d", resource_name, i)
        if not names_map[candidate] then
          free_name = candidate
          break
        end
      end
      if free_name == nil then
        Utils.error(string.format("Failed to add resource, name conflict: %s", resource_name))
        return nil
      end
      resource_name = free_name
    end
  end
  local resp = curl.post(M.get_rag_service_url() .. "/api/v1/add_resource", {
    headers = {
      ["Content-Type"] = "application/json",
    },
    body = vim.json.encode({
      name = resource_name,
      uri = uri,
    }),
  })
  if resp.status ~= 200 then
    -- tostring() guards against a missing body so that reporting the failure cannot itself fail.
    Utils.error("failed to add resource: " .. tostring(resp.body))
    return nil
  end
  return vim.json.decode(resp.body)
end
--- Asks the RAG service to remove the resource registered for `uri`.
--- The URI is converted to its container form before it is sent.
---@param uri string
---@return table | nil response decoded JSON body, or nil when the status is not 200
function M.remove_resource(uri)
  local payload = vim.json.encode({ uri = M.to_container_uri(uri) })
  local endpoint = M.get_rag_service_url() .. "/api/v1/remove_resource"
  local resp = curl.post(endpoint, {
    headers = { ["Content-Type"] = "application/json" },
    body = payload,
  })
  if resp.status == 200 then return vim.json.decode(resp.body) end
  Utils.error("failed to remove resource: " .. resp.body)
end
---@class AvanteRagServiceRetrieveSource
---@field uri string
---@field content string

---@class AvanteRagServiceRetrieveResponse
---@field response string
---@field sources AvanteRagServiceRetrieveSource[]

--- Queries the RAG service for content related to `query` below `base_uri` (top_k = 10).
--- Source URIs in the answer are converted back to their local form.
---@param base_uri string
---@param query string
---@return AvanteRagServiceRetrieveResponse | nil resp
---@return string | nil error
function M.retrieve(base_uri, query)
  local payload = vim.json.encode({
    base_uri = M.to_container_uri(base_uri),
    query = query,
    top_k = 10,
  })
  local resp = curl.post(M.get_rag_service_url() .. "/api/v1/retrieve", {
    headers = { ["Content-Type"] = "application/json" },
    body = payload,
  })
  if resp.status ~= 200 then
    local message = "failed to retrieve: " .. resp.body
    Utils.error(message)
    return nil, message
  end
  local decoded = vim.json.decode(resp.body)
  -- Builds a copy of one source entry whose uri points at the local path again.
  local function with_local_uri(source)
    return vim.tbl_deep_extend("force", source, { uri = M.to_local_uri(source.uri) })
  end
  decoded.sources = vim.iter(decoded.sources):map(with_local_uri):totable()
  return decoded, nil
end
---@class AvanteRagServiceIndexingStatusSummary
---@field indexing integer
---@field completed integer
---@field failed integer

---@class AvanteRagServiceIndexingStatusResponse
---@field uri string
---@field is_watched boolean
---@field total_files integer
---@field status_summary AvanteRagServiceIndexingStatusSummary

--- Fetches the indexing status the RAG service reports for `uri`.
--- The uri in the answer is converted back to its local form.
---@param uri string
---@return AvanteRagServiceIndexingStatusResponse | nil
function M.indexing_status(uri)
  local payload = vim.json.encode({ uri = M.to_container_uri(uri) })
  local endpoint = M.get_rag_service_url() .. "/api/v1/indexing_status"
  local resp = curl.post(endpoint, {
    headers = { ["Content-Type"] = "application/json" },
    body = payload,
  })
  if resp.status ~= 200 then
    Utils.error("Failed to get indexing status: " .. resp.body)
    return
  end
  local status = vim.json.decode(resp.body)
  status.uri = M.to_local_uri(status.uri)
  return status
end
---@class AvanteRagServiceResource
---@field name string
---@field uri string
---@field type string
---@field status string
---@field indexing_status string
---@field created_at string
---@field indexing_started_at string | nil
---@field last_indexed_at string | nil

---@class AvanteRagServiceResourceListResponse
---@field resources AvanteRagServiceResource[]
---@field total_count number

--- Lists the resources known to the RAG service, with each uri converted to its local form.
---@return AvanteRagServiceResourceListResponse | nil
function M.get_resources()
  local endpoint = M.get_rag_service_url() .. "/api/v1/resources"
  local resp = curl.get(endpoint, {
    headers = { ["Content-Type"] = "application/json" },
  })
  if resp.status ~= 200 then
    Utils.error("Failed to get resources: " .. resp.body)
    return
  end
  local listing = vim.json.decode(resp.body)
  -- Builds a copy of one resource entry whose uri points at the local path again.
  local function with_local_uri(resource)
    return vim.tbl_deep_extend("force", resource, { uri = M.to_local_uri(resource.uri) })
  end
  listing.resources = vim.iter(listing.resources):map(with_local_uri):totable()
  return listing
end
return M