feat: Enhanced Model Provider Support and Configuration Flexibility For Rag Service (#2056)
Co-authored-by: doodleEsc <cokie@foxmail.com> Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
This commit is contained in:
@@ -290,7 +290,6 @@ function M.remove_selected_file(filepath)
|
||||
|
||||
for _, file in ipairs(files) do
|
||||
local rel_path = Utils.uniform_path(file)
|
||||
vim.notify(rel_path)
|
||||
sidebar.file_selector:remove_selected_file(rel_path)
|
||||
end
|
||||
end
|
||||
|
||||
@@ -37,14 +37,24 @@ M._defaults = {
|
||||
tokenizer = "tiktoken",
|
||||
---@type string | (fun(): string) | nil
|
||||
system_prompt = nil,
|
||||
rag_service = {
|
||||
enabled = false, -- Enables the rag service, requires OPENAI_API_KEY to be set
|
||||
host_mount = os.getenv("HOME"), -- Host mount path for the rag service (docker will mount this path)
|
||||
runner = "docker", -- The runner for the rag service, (can use docker, or nix)
|
||||
provider = "openai", -- The provider to use for RAG service. eg: openai or ollama
|
||||
llm_model = "", -- The LLM model to use for RAG service
|
||||
embed_model = "", -- The embedding model to use for RAG service
|
||||
endpoint = "https://api.openai.com/v1", -- The API endpoint for RAG service
|
||||
rag_service = { -- RAG service configuration
|
||||
enabled = false, -- Enables the RAG service
|
||||
host_mount = os.getenv("HOME"), -- Host mount path for the RAG service (Docker will mount this path)
|
||||
runner = "docker", -- The runner for the RAG service (can use docker or nix)
|
||||
llm = { -- Configuration for the Language Model (LLM) used by the RAG service
|
||||
provider = "openai", -- The LLM provider
|
||||
endpoint = "https://api.openai.com/v1", -- The LLM API endpoint
|
||||
api_key = "OPENAI_API_KEY", -- The environment variable name for the LLM API key
|
||||
model = "gpt-4o-mini", -- The LLM model name
|
||||
extra = nil, -- Extra configuration options for the LLM
|
||||
},
|
||||
embed = { -- Configuration for the Embedding model used by the RAG service
|
||||
provider = "openai", -- The embedding provider
|
||||
endpoint = "https://api.openai.com/v1", -- The embedding API endpoint
|
||||
api_key = "OPENAI_API_KEY", -- The environment variable name for the embedding API key
|
||||
model = "text-embedding-3-large", -- The embedding model name
|
||||
extra = nil, -- Extra configuration options for the embedding model
|
||||
},
|
||||
docker_extra_args = "", -- Extra arguments to pass to the docker command
|
||||
},
|
||||
web_search_engine = {
|
||||
|
||||
@@ -8,7 +8,7 @@ local M = {}
|
||||
local container_name = "avante-rag-service"
|
||||
local service_path = "/tmp/" .. container_name
|
||||
|
||||
function M.get_rag_service_image() return "quay.io/yetoneful/avante-rag-service:0.0.10" end
|
||||
function M.get_rag_service_image() return "quay.io/yetoneful/avante-rag-service:0.0.11" end
|
||||
|
||||
function M.get_rag_service_port() return 20250 end
|
||||
|
||||
@@ -35,13 +35,46 @@ function M.get_rag_service_runner() return (Config.rag_service and Config.rag_se
|
||||
|
||||
---@param cb fun()
|
||||
function M.launch_rag_service(cb)
|
||||
local openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
if Config.rag_service.provider == "openai" then
|
||||
if openai_api_key == nil then
|
||||
error("cannot launch avante rag service, OPENAI_API_KEY is not set")
|
||||
--- If Config.rag_service.llm.api_key is nil or empty, llm_api_key will be an empty string.
|
||||
local llm_api_key = ""
|
||||
if
|
||||
Config.rag_service
|
||||
and Config.rag_service.llm
|
||||
and Config.rag_service.llm.api_key
|
||||
and Config.rag_service.llm.api_key ~= ""
|
||||
then
|
||||
llm_api_key = os.getenv(Config.rag_service.llm.api_key) or ""
|
||||
if llm_api_key == nil or llm_api_key == "" then
|
||||
error(string.format("cannot launch avante rag service, %s is not set", Config.rag_service.llm.api_key))
|
||||
return
|
||||
end
|
||||
end
|
||||
|
||||
--- If Config.rag_service.embed.api_key is nil or empty, embed_api_key will be an empty string.
|
||||
local embed_api_key = ""
|
||||
if
|
||||
Config.rag_service
|
||||
and Config.rag_service.embed
|
||||
and Config.rag_service.embed.api_key
|
||||
and Config.rag_service.embed.api_key ~= ""
|
||||
then
|
||||
embed_api_key = os.getenv(Config.rag_service.embed.api_key) or ""
|
||||
if embed_api_key == nil or embed_api_key == "" then
|
||||
error(string.format("cannot launch avante rag service, %s is not set", Config.rag_service.embed.api_key))
|
||||
return
|
||||
end
|
||||
end
|
||||
|
||||
local embed_extra = "{}" -- Default to empty JSON object string
|
||||
if Config.rag_service and Config.rag_service.embed and Config.rag_service.embed.extra then
|
||||
embed_extra = vim.json.encode(Config.rag_service.embed.extra):gsub('"', '\\"')
|
||||
end
|
||||
|
||||
local llm_extra = "{}" -- Default to empty JSON object string
|
||||
if Config.rag_service and Config.rag_service.llm and Config.rag_service.llm.extra then
|
||||
llm_extra = vim.json.encode(Config.rag_service.llm.extra):gsub('"', '\\"')
|
||||
end
|
||||
|
||||
local port = M.get_rag_service_port()
|
||||
|
||||
if M.get_rag_service_runner() == "docker" then
|
||||
@@ -74,19 +107,22 @@ function M.launch_rag_service(cb)
|
||||
M.stop_rag_service()
|
||||
end
|
||||
local cmd_ = string.format(
|
||||
"docker run --platform=linux/amd64 -d -p 0.0.0.0:%d:%d --name %s -v %s:/data -v %s:/host:ro -e ALLOW_RESET=TRUE -e DATA_DIR=/data -e RAG_PROVIDER=%s -e %s_API_KEY=%s -e %s_API_BASE=%s -e RAG_LLM_MODEL=%s -e RAG_EMBED_MODEL=%s %s %s",
|
||||
"docker run --platform=linux/amd64 -d -p 0.0.0.0:%d:%d --name %s -v %s:/data -v %s:/host:ro -e ALLOW_RESET=TRUE -e DATA_DIR=/data -e RAG_EMBED_PROVIDER=%s -e RAG_EMBED_ENDPOINT=%s -e RAG_EMBED_API_KEY=%s -e RAG_EMBED_MODEL=%s -e RAG_EMBED_EXTRA=%s -e RAG_LLM_PROVIDER=%s -e RAG_LLM_ENDPOINT=%s -e RAG_LLM_API_KEY=%s -e RAG_LLM_MODEL=%s -e RAG_LLM_EXTRA=%s %s %s",
|
||||
M.get_rag_service_port(),
|
||||
M.get_rag_service_port(),
|
||||
container_name,
|
||||
data_path,
|
||||
Config.rag_service.host_mount,
|
||||
Config.rag_service.provider,
|
||||
Config.rag_service.provider:upper(),
|
||||
openai_api_key,
|
||||
Config.rag_service.provider:upper(),
|
||||
Config.rag_service.endpoint,
|
||||
Config.rag_service.llm_model,
|
||||
Config.rag_service.embed_model,
|
||||
Config.rag_service.embed.provider,
|
||||
Config.rag_service.embed.endpoint,
|
||||
embed_api_key,
|
||||
Config.rag_service.embed.model,
|
||||
embed_extra,
|
||||
Config.rag_service.llm.provider,
|
||||
Config.rag_service.llm.endpoint,
|
||||
llm_api_key,
|
||||
Config.rag_service.llm.model,
|
||||
llm_extra,
|
||||
Config.rag_service.docker_extra_args,
|
||||
image
|
||||
)
|
||||
@@ -118,17 +154,20 @@ function M.launch_rag_service(cb)
|
||||
Utils.debug(string.format("launching %s with nix...", container_name))
|
||||
|
||||
local cmd = string.format(
|
||||
"cd %s && ALLOW_RESET=TRUE PORT=%d DATA_DIR=%s RAG_PROVIDER=%s %s_API_KEY=%s %s_API_BASE=%s RAG_LLM_MODEL=%s RAG_EMBED_MODEL=%s sh run.sh %s",
|
||||
"cd %s && ALLOW_RESET=TRUE PORT=%d DATA_DIR=%s RAG_EMBED_PROVIDER=%s RAG_EMBED_ENDPOINT=%s RAG_EMBED_API_KEY=%s RAG_EMBED_MODEL=%s RAG_EMBED_EXTRA=%s RAG_LLM_PROVIDER=%s RAG_LLM_ENDPOINT=%s RAG_LLM_API_KEY=%s RAG_LLM_MODEL=%s RAG_LLM_EXTRA=%s sh run.sh %s",
|
||||
rag_service_dir,
|
||||
port,
|
||||
service_path,
|
||||
Config.rag_service.provider,
|
||||
Config.rag_service.provider:upper(),
|
||||
openai_api_key,
|
||||
Config.rag_service.provider:upper(),
|
||||
Config.rag_service.endpoint,
|
||||
Config.rag_service.llm_model,
|
||||
Config.rag_service.embed_model,
|
||||
Config.rag_service.embed.provider,
|
||||
Config.rag_service.embed.endpoint,
|
||||
embed_api_key,
|
||||
Config.rag_service.embed.model,
|
||||
embed_extra,
|
||||
Config.rag_service.llm.provider,
|
||||
Config.rag_service.llm.endpoint,
|
||||
llm_api_key,
|
||||
Config.rag_service.llm.model,
|
||||
llm_extra,
|
||||
service_path
|
||||
)
|
||||
|
||||
@@ -211,7 +250,12 @@ function M.to_local_uri(uri)
|
||||
end
|
||||
|
||||
function M.is_ready()
|
||||
vim.fn.system(string.format("curl -s -o /dev/null -w '%%{http_code}' %s", M.get_rag_service_url()))
|
||||
vim.fn.system(
|
||||
string.format(
|
||||
"curl -s -o /dev/null -w '%%{http_code}' %s",
|
||||
string.format("%s%s", M.get_rag_service_url(), "/api/health")
|
||||
)
|
||||
)
|
||||
return vim.v.shell_error == 0
|
||||
end
|
||||
|
||||
|
||||
Reference in New Issue
Block a user