feat: add support for the Ollama RAG provider (#1427)
* fix: OpenAI env handling * feat: add support for multiple RAG providers - Added provider, model, and endpoint configuration options for the RAG service - Updated the RAG service to support both the OpenAI and Ollama providers - Added Ollama embedding support and dependencies - Improved environment-variable handling for RAG service configuration Signed-off-by: wfhtqp@gmail.com <wfhtqp@gmail.com> * fix: update Docker env * feat: add Ollama LLM support to the RAG server * fix: pre-commit * feat: validate the embedding model and clean up * docs: add RAG server configuration docs * fix: pyright ignore --------- Signed-off-by: wfhtqp@gmail.com <wfhtqp@gmail.com>
This commit is contained in:
@@ -35,6 +35,10 @@ M._defaults = {
|
||||
tokenizer = "tiktoken",
|
||||
rag_service = {
|
||||
enabled = false, -- Enables the rag service, requires OPENAI_API_KEY to be set
|
||||
provider = "openai", -- The provider to use for RAG service. eg: openai or ollama
|
||||
llm_model = "", -- The LLM model to use for RAG service
|
||||
embed_model = "", -- The embedding model to use for RAG service
|
||||
endpoint = "https://api.openai.com/v1", -- The API endpoint for RAG service
|
||||
},
|
||||
web_search_engine = {
|
||||
provider = "tavily",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
local curl = require("plenary.curl")
|
||||
local Path = require("plenary.path")
|
||||
local Utils = require("avante.utils")
|
||||
local Config = require("avante.config")
|
||||
|
||||
local M = {}
|
||||
|
||||
@@ -32,12 +33,12 @@ end
|
||||
---@param cb fun()
|
||||
function M.launch_rag_service(cb)
|
||||
local openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
if openai_api_key == nil then
|
||||
error("cannot launch avante rag service, OPENAI_API_KEY is not set")
|
||||
return
|
||||
if Config.rag_service.provider == "openai" then
|
||||
if openai_api_key == nil then
|
||||
error("cannot launch avante rag service, OPENAI_API_KEY is not set")
|
||||
return
|
||||
end
|
||||
end
|
||||
local openai_base_url = os.getenv("OPENAI_BASE_URL")
|
||||
if openai_base_url == nil then openai_base_url = "https://api.openai.com/v1" end
|
||||
local port = M.get_rag_service_port()
|
||||
local image = M.get_rag_service_image()
|
||||
local data_path = M.get_data_path()
|
||||
@@ -63,13 +64,17 @@ function M.launch_rag_service(cb)
|
||||
Utils.debug(string.format("container %s not found, starting...", container_name))
|
||||
end
|
||||
local cmd_ = string.format(
|
||||
"docker run -d -p %d:8000 --name %s -v %s:/data -v /:/host -e DATA_DIR=/data -e OPENAI_API_KEY=%s -e OPENAI_API_BASE=%s -e OPENAI_EMBED_MODEL=%s %s",
|
||||
"docker run -d -p %d:8000 --name %s -v %s:/data -v /:/host -e DATA_DIR=/data -e RAG_PROVIDER=%s -e %s_API_KEY=%s -e %s_API_BASE=%s -e RAG_LLM_MODEL=%s -e RAG_EMBED_MODEL=%s %s",
|
||||
port,
|
||||
container_name,
|
||||
data_path,
|
||||
Config.rag_service.provider,
|
||||
Config.rag_service.provider:upper(),
|
||||
openai_api_key,
|
||||
openai_base_url,
|
||||
os.getenv("OPENAI_EMBED_MODEL"),
|
||||
Config.rag_service.provider:upper(),
|
||||
Config.rag_service.endpoint,
|
||||
Config.rag_service.llm_model,
|
||||
Config.rag_service.embed_model,
|
||||
image
|
||||
)
|
||||
vim.fn.jobstart(cmd_, {
|
||||
@@ -229,6 +234,7 @@ function M.retrieve(base_uri, query)
|
||||
query = query,
|
||||
top_k = 10,
|
||||
}),
|
||||
timeout = 100000,
|
||||
})
|
||||
if resp.status ~= 200 then
|
||||
Utils.error("failed to retrieve: " .. resp.body)
|
||||
|
||||
Reference in New Issue
Block a user