feat: Enhanced Model Provider Support and Configuration Flexibility for RAG Service (#2056)

Co-authored-by: doodleEsc <cokie@foxmail.com>
Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
This commit is contained in:
README.md (40 changed lines)
@@ -705,7 +705,7 @@ Given its early stage, `avante.nvim` currently supports the following basic func
 >       model = "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
 >       aws_profile = "bedrock",
 >       aws_region = "us-east-1",
->},
+> },
 > ```
 >
 > Note: Bedrock requires the [AWS CLI](https://aws.amazon.com/cli/) to be installed on your system.
@@ -884,21 +884,37 @@ For more information, see [Custom Providers](https://github.com/yetone/avante.nv

 Avante provides a RAG service, a tool for obtaining the context the AI needs to generate code. It is disabled by default. You can enable it this way:

 ```lua
-rag_service = {
-  enabled = false, -- Enables the RAG service
-  host_mount = os.getenv("HOME"), -- Host mount path for the rag service
-  provider = "openai", -- The provider to use for RAG service (e.g. openai or ollama)
-  llm_model = "", -- The LLM model to use for RAG service
-  embed_model = "", -- The embedding model to use for RAG service
-  endpoint = "https://api.openai.com/v1", -- The API endpoint for RAG service
-},
+rag_service = { -- RAG service configuration
+  enabled = false, -- Enables the RAG service
+  host_mount = os.getenv("HOME"), -- Host mount path for the RAG service (Docker will mount this path)
+  runner = "docker", -- Runner for the RAG service (can use docker or nix)
+  llm = { -- Language Model (LLM) configuration for the RAG service
+    provider = "openai", -- LLM provider
+    endpoint = "https://api.openai.com/v1", -- LLM API endpoint
+    api_key = "OPENAI_API_KEY", -- Environment variable name for the LLM API key
+    model = "gpt-4o-mini", -- LLM model name
+    extra = nil, -- Additional configuration options for the LLM
+  },
+  embed = { -- Embedding model configuration for the RAG service
+    provider = "openai", -- Embedding provider
+    endpoint = "https://api.openai.com/v1", -- Embedding API endpoint
+    api_key = "OPENAI_API_KEY", -- Environment variable name for the embedding API key
+    model = "text-embedding-3-large", -- Embedding model name
+    extra = nil, -- Additional configuration options for the embedding model
+  },
+  docker_extra_args = "", -- Extra arguments to pass to the docker command
+},
 ```

-If your rag_service provider is `openai`, then you need to set the `OPENAI_API_KEY` environment variable!
+The RAG service can configure the LLM and the embedding model separately. In the `llm` and `embed` configuration blocks, you can set the following fields:

-If your rag_service provider is `ollama`, you need to set the endpoint to `http://localhost:11434` (note there is no `/v1` at the end) or any address of your own ollama server.
+- `provider`: Model provider (e.g., "openai", "ollama", "dashscope", or "openrouter")
+- `endpoint`: API endpoint
+- `api_key`: Environment variable name for the API key
+- `model`: Model name
+- `extra`: Additional configuration options

-If your rag_service provider is `ollama`, when `llm_model` is empty, it defaults to `llama3`, and when `embed_model` is empty, it defaults to `nomic-embed-text`. Please make sure these models are available in your ollama server.
+For detailed configuration of different model providers, see [here](./py/rag-service/README.md).

 Additionally, the RAG service also depends on Docker! (For macOS users, OrbStack is recommended as a Docker alternative.)
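In the new schema, each `api_key` field holds the *name* of an environment variable rather than the secret itself; the plugin resolves that name at launch time and passes the value into the Docker or nix runner. A minimal sketch of the required setup (the key value below is a placeholder, not a real key):

```shell
# api_key = "OPENAI_API_KEY" in the config names this variable; export the
# real secret before starting Neovim so the service runner can read it.
export OPENAI_API_KEY="sk-placeholder"  # placeholder value

# Sanity check: the variable must be non-empty and visible to child processes.
if [ -n "$OPENAI_API_KEY" ]; then
  echo "OPENAI_API_KEY is set"
else
  echo "OPENAI_API_KEY is missing" >&2
fi
```

The same pattern applies to `DASHSCOPE_API_KEY`, `OPENROUTER_API_KEY`, or any other variable name you put in the `llm` or `embed` blocks.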
README_zh.md (38 changed lines)
@@ -747,21 +747,37 @@ Avante provides a set of default providers, but users can also create their own

 Avante provides a RAG service, a tool for obtaining the context the AI needs to generate code. It is disabled by default. You can enable it this way:

 ```lua
-rag_service = {
-  enabled = false, -- Enables the RAG service
-  host_mount = os.getenv("HOME"), -- Host mount path for the RAG service
-  provider = "openai", -- The provider to use for the RAG service (e.g. openai or ollama)
-  llm_model = "", -- The LLM model to use for the RAG service
-  embed_model = "", -- The embedding model to use for the RAG service
-  endpoint = "https://api.openai.com/v1", -- The API endpoint for the RAG service
-},
+rag_service = { -- RAG service configuration
+  enabled = false, -- Enables the RAG service
+  host_mount = os.getenv("HOME"), -- Host mount path for the RAG service (Docker will mount this path)
+  runner = "docker", -- Runner for the RAG service (can use docker or nix)
+  llm = { -- Language Model (LLM) configuration used by the RAG service
+    provider = "openai", -- LLM provider
+    endpoint = "https://api.openai.com/v1", -- LLM API endpoint
+    api_key = "OPENAI_API_KEY", -- Environment variable name for the LLM API key
+    model = "gpt-4o-mini", -- LLM model name
+    extra = nil, -- Extra configuration options for the LLM
+  },
+  embed = { -- Embedding model configuration used by the RAG service
+    provider = "openai", -- Embedding provider
+    endpoint = "https://api.openai.com/v1", -- Embedding API endpoint
+    api_key = "OPENAI_API_KEY", -- Environment variable name for the embedding API key
+    model = "text-embedding-3-large", -- Embedding model name
+    extra = nil, -- Extra configuration options for the embedding model
+  },
+  docker_extra_args = "", -- Extra arguments to pass to the docker command
+},
 ```

-If your rag_service provider is `openai`, you need to set the `OPENAI_API_KEY` environment variable!
+The RAG service can configure the LLM and the embedding model separately. In the `llm` and `embed` configuration blocks, you can set the following fields:

-If your rag_service provider is `ollama`, you need to set the endpoint to `http://localhost:11434` (note there is no `/v1` at the end) or any address of your own ollama server.
+- `provider`: Model provider (e.g., "openai", "ollama", "dashscope", or "openrouter")
+- `endpoint`: API endpoint
+- `api_key`: Environment variable name for the API key
+- `model`: Model name
+- `extra`: Extra configuration options

-If your rag_service provider is `ollama`, when `llm_model` is empty it defaults to `llama3`, and when `embed_model` is empty it defaults to `nomic-embed-text`. Please make sure these models are available in your ollama server.
+For detailed configuration of different model providers, see [here](./py/rag-service/README.md).

 Additionally, the RAG service also depends on Docker! (For macOS users, OrbStack is recommended as a Docker alternative.)
@@ -290,7 +290,6 @@ function M.remove_selected_file(filepath)

   for _, file in ipairs(files) do
     local rel_path = Utils.uniform_path(file)
-    vim.notify(rel_path)
    sidebar.file_selector:remove_selected_file(rel_path)
  end
end
@@ -37,14 +37,24 @@ M._defaults = {
   tokenizer = "tiktoken",
   ---@type string | (fun(): string) | nil
   system_prompt = nil,
-  rag_service = {
-    enabled = false, -- Enables the rag service, requires OPENAI_API_KEY to be set
-    host_mount = os.getenv("HOME"), -- Host mount path for the rag service (docker will mount this path)
-    runner = "docker", -- The runner for the rag service, (can use docker, or nix)
-    provider = "openai", -- The provider to use for RAG service. eg: openai or ollama
-    llm_model = "", -- The LLM model to use for RAG service
-    embed_model = "", -- The embedding model to use for RAG service
-    endpoint = "https://api.openai.com/v1", -- The API endpoint for RAG service
-  },
+  rag_service = { -- RAG service configuration
+    enabled = false, -- Enables the RAG service
+    host_mount = os.getenv("HOME"), -- Host mount path for the RAG service (Docker will mount this path)
+    runner = "docker", -- The runner for the RAG service (can use docker or nix)
+    llm = { -- Configuration for the Language Model (LLM) used by the RAG service
+      provider = "openai", -- The LLM provider
+      endpoint = "https://api.openai.com/v1", -- The LLM API endpoint
+      api_key = "OPENAI_API_KEY", -- The environment variable name for the LLM API key
+      model = "gpt-4o-mini", -- The LLM model name
+      extra = nil, -- Extra configuration options for the LLM
+    },
+    embed = { -- Configuration for the embedding model used by the RAG service
+      provider = "openai", -- The embedding provider
+      endpoint = "https://api.openai.com/v1", -- The embedding API endpoint
+      api_key = "OPENAI_API_KEY", -- The environment variable name for the embedding API key
+      model = "text-embedding-3-large", -- The embedding model name
+      extra = nil, -- Extra configuration options for the embedding model
+    },
+    docker_extra_args = "", -- Extra arguments to pass to the docker command
+  },
   web_search_engine = {
@@ -8,7 +8,7 @@ local M = {}
 local container_name = "avante-rag-service"
 local service_path = "/tmp/" .. container_name

-function M.get_rag_service_image() return "quay.io/yetoneful/avante-rag-service:0.0.10" end
+function M.get_rag_service_image() return "quay.io/yetoneful/avante-rag-service:0.0.11" end

 function M.get_rag_service_port() return 20250 end
@@ -35,13 +35,46 @@ function M.get_rag_service_runner() return (Config.rag_service and Config.rag_se

 ---@param cb fun()
 function M.launch_rag_service(cb)
-  local openai_api_key = os.getenv("OPENAI_API_KEY")
-  if Config.rag_service.provider == "openai" then
-    if openai_api_key == nil then
-      error("cannot launch avante rag service, OPENAI_API_KEY is not set")
-    end
-  end
+  --- If Config.rag_service.llm.api_key is nil or empty, llm_api_key will be an empty string.
+  local llm_api_key = ""
+  if
+    Config.rag_service
+    and Config.rag_service.llm
+    and Config.rag_service.llm.api_key
+    and Config.rag_service.llm.api_key ~= ""
+  then
+    llm_api_key = os.getenv(Config.rag_service.llm.api_key) or ""
+    if llm_api_key == "" then
+      error(string.format("cannot launch avante rag service, %s is not set", Config.rag_service.llm.api_key))
+      return
+    end
+  end
+
+  --- If Config.rag_service.embed.api_key is nil or empty, embed_api_key will be an empty string.
+  local embed_api_key = ""
+  if
+    Config.rag_service
+    and Config.rag_service.embed
+    and Config.rag_service.embed.api_key
+    and Config.rag_service.embed.api_key ~= ""
+  then
+    embed_api_key = os.getenv(Config.rag_service.embed.api_key) or ""
+    if embed_api_key == "" then
+      error(string.format("cannot launch avante rag service, %s is not set", Config.rag_service.embed.api_key))
+      return
+    end
+  end
+
+  local embed_extra = "{}" -- Default to empty JSON object string
+  if Config.rag_service and Config.rag_service.embed and Config.rag_service.embed.extra then
+    embed_extra = vim.json.encode(Config.rag_service.embed.extra):gsub('"', '\\"')
+  end
+
+  local llm_extra = "{}" -- Default to empty JSON object string
+  if Config.rag_service and Config.rag_service.llm and Config.rag_service.llm.extra then
+    llm_extra = vim.json.encode(Config.rag_service.llm.extra):gsub('"', '\\"')
+  end

   local port = M.get_rag_service_port()

   if M.get_rag_service_runner() == "docker" then
@@ -74,19 +107,22 @@ function M.launch_rag_service(cb)
       M.stop_rag_service()
     end
     local cmd_ = string.format(
-      "docker run --platform=linux/amd64 -d -p 0.0.0.0:%d:%d --name %s -v %s:/data -v %s:/host:ro -e ALLOW_RESET=TRUE -e DATA_DIR=/data -e RAG_PROVIDER=%s -e %s_API_KEY=%s -e %s_API_BASE=%s -e RAG_LLM_MODEL=%s -e RAG_EMBED_MODEL=%s %s %s",
+      "docker run --platform=linux/amd64 -d -p 0.0.0.0:%d:%d --name %s -v %s:/data -v %s:/host:ro -e ALLOW_RESET=TRUE -e DATA_DIR=/data -e RAG_EMBED_PROVIDER=%s -e RAG_EMBED_ENDPOINT=%s -e RAG_EMBED_API_KEY=%s -e RAG_EMBED_MODEL=%s -e RAG_EMBED_EXTRA=%s -e RAG_LLM_PROVIDER=%s -e RAG_LLM_ENDPOINT=%s -e RAG_LLM_API_KEY=%s -e RAG_LLM_MODEL=%s -e RAG_LLM_EXTRA=%s %s %s",
       M.get_rag_service_port(),
       M.get_rag_service_port(),
       container_name,
       data_path,
       Config.rag_service.host_mount,
-      Config.rag_service.provider,
-      Config.rag_service.provider:upper(),
-      openai_api_key,
-      Config.rag_service.provider:upper(),
-      Config.rag_service.endpoint,
-      Config.rag_service.llm_model,
-      Config.rag_service.embed_model,
+      Config.rag_service.embed.provider,
+      Config.rag_service.embed.endpoint,
+      embed_api_key,
+      Config.rag_service.embed.model,
+      embed_extra,
+      Config.rag_service.llm.provider,
+      Config.rag_service.llm.endpoint,
+      llm_api_key,
+      Config.rag_service.llm.model,
+      llm_extra,
       Config.rag_service.docker_extra_args,
       image
     )
@@ -118,17 +154,20 @@ function M.launch_rag_service(cb)
     Utils.debug(string.format("launching %s with nix...", container_name))

     local cmd = string.format(
-      "cd %s && ALLOW_RESET=TRUE PORT=%d DATA_DIR=%s RAG_PROVIDER=%s %s_API_KEY=%s %s_API_BASE=%s RAG_LLM_MODEL=%s RAG_EMBED_MODEL=%s sh run.sh %s",
+      "cd %s && ALLOW_RESET=TRUE PORT=%d DATA_DIR=%s RAG_EMBED_PROVIDER=%s RAG_EMBED_ENDPOINT=%s RAG_EMBED_API_KEY=%s RAG_EMBED_MODEL=%s RAG_EMBED_EXTRA=%s RAG_LLM_PROVIDER=%s RAG_LLM_ENDPOINT=%s RAG_LLM_API_KEY=%s RAG_LLM_MODEL=%s RAG_LLM_EXTRA=%s sh run.sh %s",
       rag_service_dir,
       port,
       service_path,
-      Config.rag_service.provider,
-      Config.rag_service.provider:upper(),
-      openai_api_key,
-      Config.rag_service.provider:upper(),
-      Config.rag_service.endpoint,
-      Config.rag_service.llm_model,
-      Config.rag_service.embed_model,
+      Config.rag_service.embed.provider,
+      Config.rag_service.embed.endpoint,
+      embed_api_key,
+      Config.rag_service.embed.model,
+      embed_extra,
+      Config.rag_service.llm.provider,
+      Config.rag_service.llm.endpoint,
+      llm_api_key, -- the RAG_LLM_API_KEY slot must receive llm_api_key, not embed_api_key
+      Config.rag_service.llm.model,
+      llm_extra,
       service_path
     )
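In both runners above, the `extra` tables are JSON-encoded on the Lua side and the embedded double quotes are escaped (`:gsub('"', '\\"')`) so the value survives being spliced into a single command line. A rough shell equivalent of that escaping step (the `extra` value is an example):

```shell
# Example `extra` table serialized to JSON, roughly as vim.json.encode emits it.
extra='{"temperature":0.7,"max_tokens":512}'

# Escape embedded double quotes so the JSON can be passed as e.g.
# -e RAG_LLM_EXTRA=... without the shell eating the quotes.
escaped=$(printf '%s' "$extra" | sed 's/"/\\"/g')
echo "$escaped"
```

The Python side then reverses this with `json.loads` when it reads `RAG_LLM_EXTRA` / `RAG_EMBED_EXTRA` from the environment.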
@@ -211,7 +250,12 @@ function M.to_local_uri(uri)
 end

 function M.is_ready()
-  vim.fn.system(string.format("curl -s -o /dev/null -w '%%{http_code}' %s", M.get_rag_service_url()))
+  vim.fn.system(
+    string.format(
+      "curl -s -o /dev/null -w '%%{http_code}' %s",
+      string.format("%s%s", M.get_rag_service_url(), "/api/health")
+    )
+  )
   return vim.v.shell_error == 0
 end
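The updated `is_ready` probes a dedicated `/api/health` endpoint instead of the service root. Assuming the service runs locally on the default port 20250 (from `M.get_rag_service_port()`), the equivalent manual check looks like this; the commented `curl` line requires the container to be running:

```shell
base_url="http://localhost:20250"   # default port from M.get_rag_service_port()
health_url="${base_url}/api/health" # health path appended by the updated is_ready
echo "$health_url"

# Probe it once the container is up:
# curl -s -o /dev/null -w '%{http_code}' "$health_url"
```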
@@ -2,21 +2,20 @@ FROM python:3.11-slim-bookworm

 WORKDIR /app

-RUN apt-get update && apt-get install -y curl git \
-    && rm -rf /var/lib/apt/lists/* \
-    && curl -LsSf https://astral.sh/uv/install.sh | sh
+RUN apt-get update && apt-get install -y --no-install-recommends curl git \
+    && rm -rf /var/lib/apt/lists/* \
+    && curl -LsSf https://astral.sh/uv/install.sh | sh

-ENV PATH="/root/.local/bin:$PATH"
+ENV PATH="/root/.local/bin:$PATH" \
+    PYTHONPATH=/app/src \
+    PYTHONUNBUFFERED=1 \
+    PYTHONDONTWRITEBYTECODE=1

 COPY requirements.txt .

-RUN uv venv
-
-RUN uv pip install -r requirements.txt
-
-ENV PYTHONUNBUFFERED=1 \
-    PYTHONDONTWRITEBYTECODE=1
+# Install directly into the system environment; do not create a virtualenv
+RUN uv pip install --system -r requirements.txt

 COPY . .

-CMD ["uv", "run", "fastapi", "run", "src/main.py", "--workers", "3", "--port", "20250"]
+CMD ["uvicorn", "src.main:app", "--workers", "3", "--host", "0.0.0.0", "--port", "20250"]
py/rag-service/README.md (new file, 135 lines)
@@ -0,0 +1,135 @@
# RAG Service Configuration

This document describes how to configure the RAG service, including setting up Language Model (LLM) and embedding providers.

## Provider Support Matrix

The following table shows which model types are supported by each provider:

| Provider   | LLM Support | Embedding Support |
| ---------- | ----------- | ----------------- |
| dashscope  | Yes         | Yes               |
| ollama     | Yes         | Yes               |
| openai     | Yes         | Yes               |
| openrouter | Yes         | No                |

## LLM Provider Configuration

The `llm` section in the configuration file is used to configure the Language Model (LLM) used by the RAG service.

Here are the configuration examples for each supported LLM provider:

### OpenAI LLM Configuration

[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py#L130)

```lua
llm = { -- Configuration for the Language Model (LLM) used by the RAG service
  provider = "openai", -- The LLM provider ("openai")
  endpoint = "https://api.openai.com/v1", -- The LLM API endpoint
  api_key = "OPENAI_API_KEY", -- The environment variable name for the LLM API key
  model = "gpt-4o-mini", -- The LLM model name (e.g., "gpt-4o-mini", "gpt-3.5-turbo")
  extra = { -- Extra configuration options for the LLM (optional)
    temperature = 0.7, -- Controls the randomness of the output. Lower values make it more deterministic.
    max_tokens = 512, -- The maximum number of tokens to generate in the completion.
    -- system_prompt = "You are a helpful assistant.", -- A system prompt to guide the model's behavior.
    -- timeout = 120, -- Request timeout in seconds.
  },
},
```

### DashScope LLM Configuration

[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/llms/llama-index-llms-dashscope/llama_index/llms/dashscope/base.py#L155)

```lua
llm = { -- Configuration for the Language Model (LLM) used by the RAG service
  provider = "dashscope", -- The LLM provider ("dashscope")
  endpoint = "", -- The LLM API endpoint (DashScope typically uses the default or environment variables)
  api_key = "DASHSCOPE_API_KEY", -- The environment variable name for the LLM API key
  model = "qwen-plus", -- The LLM model name (e.g., "qwen-plus", "qwen-max")
  extra = nil, -- Extra configuration options for the LLM (optional)
},
```

### Ollama LLM Configuration

[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py#L65)

```lua
llm = { -- Configuration for the Language Model (LLM) used by the RAG service
  provider = "ollama", -- The LLM provider ("ollama")
  endpoint = "http://localhost:11434", -- The LLM API endpoint for Ollama
  api_key = "", -- Ollama typically does not require an API key
  model = "llama2", -- The LLM model name (e.g., "llama2", "mistral")
  extra = nil, -- Extra configuration options for the LLM (optional)
},
```

### OpenRouter LLM Configuration

[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/llms/llama-index-llms-openrouter/llama_index/llms/openrouter/base.py#L17)

```lua
llm = { -- Configuration for the Language Model (LLM) used by the RAG service
  provider = "openrouter", -- The LLM provider ("openrouter")
  endpoint = "https://openrouter.ai/api/v1", -- The LLM API endpoint for OpenRouter
  api_key = "OPENROUTER_API_KEY", -- The environment variable name for the LLM API key
  model = "openai/gpt-4o-mini", -- The LLM model name (e.g., "openai/gpt-4o-mini", "mistralai/mistral-7b-instruct")
  extra = nil, -- Extra configuration options for the LLM (optional)
},
```

## Embedding Provider Configuration

The `embed` section in the configuration file is used to configure the embedding model used by the RAG service.

Here are the configuration examples for each supported embedding provider:

### OpenAI Embedding Configuration

[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/embeddings/llama-index-embeddings-openai/llama_index/embeddings/openai/base.py#L214)

```lua
embed = { -- Configuration for the embedding model used by the RAG service
  provider = "openai", -- The embedding provider ("openai")
  endpoint = "https://api.openai.com/v1", -- The embedding API endpoint
  api_key = "OPENAI_API_KEY", -- The environment variable name for the embedding API key
  model = "text-embedding-3-large", -- The embedding model name (e.g., "text-embedding-3-small", "text-embedding-3-large")
  extra = { -- Extra configuration options for the embedding model (optional)
    dimensions = nil,
  },
},
```

### DashScope Embedding Configuration

[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/embeddings/llama-index-embeddings-dashscope/llama_index/embeddings/dashscope/base.py#L156)

```lua
embed = { -- Configuration for the embedding model used by the RAG service
  provider = "dashscope", -- The embedding provider ("dashscope")
  endpoint = "", -- The embedding API endpoint (DashScope typically uses the default or environment variables)
  api_key = "DASHSCOPE_API_KEY", -- The environment variable name for the embedding API key
  model = "text-embedding-v3", -- The embedding model name (e.g., "text-embedding-v2")
  extra = { -- Extra configuration options for the embedding model (optional)
    embed_batch_size = 10,
  },
},
```

### Ollama Embedding Configuration

[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/embeddings/llama-index-embeddings-ollama/llama_index/embeddings/ollama/base.py#L12)

```lua
embed = { -- Configuration for the embedding model used by the RAG service
  provider = "ollama", -- The embedding provider ("ollama")
  endpoint = "http://localhost:11434", -- The embedding API endpoint for Ollama
  api_key = "", -- Ollama typically does not require an API key
  model = "nomic-embed-text", -- The embedding model name (e.g., "nomic-embed-text")
  extra = { -- Extra configuration options for the embedding model (optional)
    embed_batch_size = 10,
  },
},
```
py/rag-service/requirements.txt

@@ -17,15 +17,17 @@ chroma-hnswlib==0.7.6
 chromadb==0.6.3
 click==8.1.8
 coloredlogs==15.0.1
+dashscope==1.22.2
 dataclasses-json==0.6.7
 decorator==5.1.1
-Deprecated==1.2.18
+deprecated==1.2.18
 dirtyjson==1.0.8
 distro==1.9.0
 dnspython==2.7.0
 docx2txt==0.8
 durationpy==0.9
-email_validator==2.2.0
-et_xmlfile==2.0.0
+email-validator==2.2.0
+et-xmlfile==2.0.0
 executing==2.2.0
 fastapi==0.115.8
 fastapi-cli==0.0.7

@@ -42,15 +44,15 @@ h11==0.14.0
 httpcore==1.0.7
 httptools==0.6.4
 httpx==0.28.1
-huggingface-hub==0.28.1
+huggingface-hub==0.31.4
 humanfriendly==10.0
 idna==3.10
-importlib_metadata==8.5.0
-importlib_resources==6.5.2
+importlib-metadata==8.5.0
+importlib-resources==6.5.2
 ipdb==0.13.13
 ipython==8.32.0
 jedi==0.19.2
-Jinja2==3.1.5
+jinja2==3.1.5
 jiter==0.8.2
 joblib==1.4.2
 kubernetes==32.0.0

@@ -60,11 +62,15 @@ llama-index==0.12.16
 llama-index-agent-openai==0.4.3
 llama-index-cli==0.4.0
 llama-index-core==0.12.16.post1
+llama-index-embeddings-dashscope==0.3.0
+llama-index-embeddings-ollama==0.5.0
 llama-index-embeddings-openai==0.3.1
 llama-index-indices-managed-llama-cloud==0.6.4
+llama-index-llms-dashscope==0.3.3
+llama-index-llms-ollama==0.5.2
 llama-index-llms-openai==0.3.18
 llama-index-llms-openai-like==0.3.4
+llama-index-llms-openrouter==0.3.1
 llama-index-multi-modal-llms-openai==0.4.3
 llama-index-program-openai==0.3.1
 llama-index-question-gen-openai==0.3.0

@@ -74,7 +80,7 @@ llama-index-vector-stores-chroma==0.4.1
 llama-parse==0.6.0
 markdown-it-py==3.0.0
 markdownify==0.14.1
-MarkupSafe==3.0.2
+markupsafe==3.0.2
 marshmallow==3.26.1
 matplotlib-inline==0.1.7
 mdurl==0.1.2

@@ -88,6 +94,7 @@ networkx==3.4.2
 nltk==3.9.1
 numpy==2.2.2
 oauthlib==3.2.2
+ollama==0.4.8
 onnxruntime==1.20.1
 openai==1.61.1
 openpyxl==3.1.5

@@ -110,35 +117,36 @@ pathspec==0.12.1
 pexpect==4.9.0
 pillow==11.1.0
 posthog==3.11.0
-prompt_toolkit==3.0.50
+prompt-toolkit==3.0.50
 propcache==0.2.1
 protobuf==5.29.3
 ptyprocess==0.7.0
-pure_eval==0.2.3
+pure-eval==0.2.3
 pyasn1==0.6.1
-pyasn1_modules==0.4.1
+pyasn1-modules==0.4.1
 pydantic==2.10.6
-pydantic_core==2.27.2
-Pygments==2.19.1
+pydantic-core==2.27.2
+pygments==2.19.1
 pypdf==5.2.0
-PyPika==0.48.9
-pyproject_hooks==1.2.0
+pypika==0.48.9
+pyproject-hooks==1.2.0
 python-dateutil==2.9.0.post0
 python-dotenv==1.0.1
 python-multipart==0.0.20
 pytz==2025.1
-PyYAML==6.0.2
+pyyaml==6.0.2
 regex==2024.11.6
 requests==2.32.3
 requests-oauthlib==2.0.0
 rich==13.9.4
 rich-toolkit==0.13.2
 rsa==4.9
+safetensors==0.5.3
 shellingham==1.5.4
 six==1.17.0
 sniffio==1.3.1
 soupsieve==2.6
-SQLAlchemy==2.0.38
+sqlalchemy==2.0.38
 stack-data==0.6.3
 starlette==0.45.3
 striprtf==0.0.26

@@ -148,11 +156,15 @@ tiktoken==0.8.0
 tokenizers==0.21.0
 tqdm==4.67.1
 traitlets==5.14.3
+transformers==4.51.3
 tree-sitter==0.24.0
 tree-sitter-c-sharp==0.23.1
 tree-sitter-embedded-template==0.23.2
+tree-sitter-language-pack==0.6.1
 tree-sitter-yaml==0.7.0
 typer==0.15.1
+typing-extensions==4.12.2
 typing-inspect==0.9.0
-typing_extensions==4.12.2
 tzdata==2025.1
 urllib3==2.3.0
 uvicorn==0.34.0

@@ -165,4 +177,3 @@ websockets==14.2
 wrapt==1.17.2
 yarl==1.18.3
 zipp==3.21.0
-docx2txt==0.8.0
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Standard library imports
|
||||
import asyncio
|
||||
import fcntl
|
||||
import json
|
||||
@@ -18,42 +19,24 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
# Third-party imports
|
||||
import chromadb
|
||||
import httpx
|
||||
import pathspec
|
||||
from fastapi import BackgroundTasks, FastAPI, HTTPException
|
||||
from libs.configs import (
|
||||
BASE_DATA_DIR,
|
||||
CHROMA_PERSIST_DIR,
|
||||
)
|
||||
|
||||
# Local application imports
|
||||
from libs.configs import BASE_DATA_DIR, CHROMA_PERSIST_DIR
|
||||
from libs.db import init_db
|
||||
from libs.logger import logger
|
||||
from libs.utils import (
|
||||
get_node_uri,
|
||||
inject_uri_to_node,
|
||||
is_local_uri,
|
||||
is_path_node,
|
||||
is_remote_uri,
|
||||
path_to_uri,
|
||||
uri_to_path,
|
||||
)
|
||||
from llama_index.core import (
|
||||
Settings,
|
||||
SimpleDirectoryReader,
|
||||
StorageContext,
|
||||
VectorStoreIndex,
|
||||
load_index_from_storage,
|
||||
)
|
||||
from libs.utils import get_node_uri, inject_uri_to_node, is_local_uri, is_path_node, is_remote_uri, path_to_uri, uri_to_path
|
||||
from llama_index.core import Settings, SimpleDirectoryReader, StorageContext, VectorStoreIndex, load_index_from_storage
|
||||
from llama_index.core.node_parser import CodeSplitter
|
||||
from llama_index.core.schema import Document
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from markdownify import markdownify as md
|
||||
from models.indexing_history import IndexingHistory # noqa: TC002
|
||||
from models.resource import Resource
|
||||
from providers.factory import initialize_embed_model, initialize_llm_model
|
||||
from pydantic import BaseModel, Field
|
||||
from services.indexing_history import indexing_history_service
|
||||
from services.resource import resource_service
|
||||
@@ -65,6 +48,7 @@ if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
from models.indexing_history import IndexingHistory
|
||||
from watchdog.observers.api import BaseObserver
|
||||
|
||||
# Lock file for leader election
|
||||
@@ -111,8 +95,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
if is_local_uri(resource.uri):
|
||||
directory = uri_to_path(resource.uri)
|
||||
if not directory.exists():
|
||||
logger.error("Directory not found: %s", directory)
|
||||
resource_service.update_resource_status(resource.uri, "error", f"Directory not found: {directory}")
|
||||
error_msg = f"Directory not found: {directory}"
|
||||
logger.error(error_msg)
|
||||
resource_service.update_resource_status(resource.uri, "error", error_msg)
|
||||
continue
|
||||
|
||||
# Start file system watcher
|
||||
@@ -127,8 +112,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
|
||||
elif is_remote_uri(resource.uri):
|
||||
if not is_remote_resource_exists(resource.uri):
|
||||
logger.error("HTTPS resource not found: %s", resource.uri)
|
||||
resource_service.update_resource_status(resource.uri, "error", "remote resource not found")
|
||||
error_msg = "HTTPS resource not found"
|
||||
logger.error("%s: %s", error_msg, resource.uri)
|
||||
resource_service.update_resource_status(resource.uri, "error", error_msg)
|
||||
continue
|
||||
|
||||
# Start indexing
|
||||
@@ -273,7 +259,11 @@ def is_remote_resource_exists(url: str) -> bool:
|
||||
"""Check if a URL exists."""
|
||||
try:
|
||||
response = httpx.head(url, headers=http_headers)
|
||||
return response.status_code in {httpx.codes.OK, httpx.codes.MOVED_PERMANENTLY, httpx.codes.FOUND}
|
||||
return response.status_code in {
|
||||
httpx.codes.OK,
|
||||
httpx.codes.MOVED_PERMANENTLY,
|
||||
httpx.codes.FOUND,
|
||||
}
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
logger.error("Error checking if URL exists %s: %s", url, e)
|
||||
return False
|
||||
@@ -318,53 +308,78 @@ init_db()
# Initialize ChromaDB and LlamaIndex services
chroma_client = chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR))

-# Check if provider or model has changed
-current_provider = os.getenv("RAG_PROVIDER", "openai").lower()
-current_embed_model = os.getenv("RAG_EMBED_MODEL", "")
-current_llm_model = os.getenv("RAG_LLM_MODEL", "")
+# Check if provider or model has changed
+rag_embed_provider = os.getenv("RAG_EMBED_PROVIDER", "openai")
+rag_embed_endpoint = os.getenv("RAG_EMBED_ENDPOINT", "https://api.openai.com/v1")
+rag_embed_model = os.getenv("RAG_EMBED_MODEL", "text-embedding-3-large")
+rag_embed_api_key = os.getenv("RAG_EMBED_API_KEY", None)
+rag_embed_extra = os.getenv("RAG_EMBED_EXTRA", None)
+
+rag_llm_provider = os.getenv("RAG_LLM_PROVIDER", "openai")
+rag_llm_endpoint = os.getenv("RAG_LLM_ENDPOINT", "https://api.openai.com/v1")
+rag_llm_model = os.getenv("RAG_LLM_MODEL", "gpt-4o-mini")
+rag_llm_api_key = os.getenv("RAG_LLM_API_KEY", None)
+rag_llm_extra = os.getenv("RAG_LLM_EXTRA", None)

# Try to read previous config
config_file = BASE_DATA_DIR / "rag_config.json"
if config_file.exists():
    with Path.open(config_file, "r") as f:
        prev_config = json.load(f)
-   if prev_config.get("provider") != current_provider or prev_config.get("embed_model") != current_embed_model:
+   if prev_config.get("provider") != rag_embed_provider or prev_config.get("embed_model") != rag_embed_model:
        # Clear existing data if config changed
        logger.info("Detected config change, clearing existing data...")
        chroma_client.reset()

# Save current config
with Path.open(config_file, "w") as f:
-   json.dump({"provider": current_provider, "embed_model": current_embed_model}, f)
+   json.dump({"provider": rag_embed_provider, "embed_model": rag_embed_model}, f)
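The startup logic persists the embedding provider and model to `rag_config.json` and wipes the vector store when either changes. A self-contained sketch of that compare-then-persist step, using a temporary directory in place of `BASE_DATA_DIR` and a boolean return in place of `chroma_client.reset()`:

```python
import json
import tempfile
from pathlib import Path

def config_changed(config_file: Path, provider: str, embed_model: str) -> bool:
    """Compare the current embedding config against the persisted one,
    then persist the current values for the next startup."""
    changed = False
    if config_file.exists():
        prev = json.loads(config_file.read_text())
        changed = prev.get("provider") != provider or prev.get("embed_model") != embed_model
    config_file.write_text(json.dumps({"provider": provider, "embed_model": embed_model}))
    return changed

with tempfile.TemporaryDirectory() as d:
    cfg = Path(d) / "rag_config.json"
    first = config_changed(cfg, "openai", "text-embedding-3-large")   # no previous config
    same = config_changed(cfg, "openai", "text-embedding-3-large")    # unchanged
    switched = config_changed(cfg, "ollama", "nomic-embed-text")      # provider changed
    print(first, same, switched)  # False False True
```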
chroma_collection = chroma_client.get_or_create_collection("documents")  # pyright: ignore
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

-# Initialize embedding model based on provider
-llm_provider = current_provider
-base_url = os.getenv(llm_provider.upper() + "_API_BASE", "")
-rag_embed_model = current_embed_model
-rag_llm_model = current_llm_model
+try:
+   embed_extra = json.loads(rag_embed_extra) if rag_embed_extra is not None else {}
+except json.JSONDecodeError:
+   logger.error("Failed to decode RAG_EMBED_EXTRA, defaulting to empty dict.")
+   embed_extra = {}
+
+try:
+   llm_extra = json.loads(rag_llm_extra) if rag_llm_extra is not None else {}
+except json.JSONDecodeError:
+   logger.error("Failed to decode RAG_LLM_EXTRA, defaulting to empty dict.")
+   llm_extra = {}
+# Initialize embedding model and LLM based on provider using the factory
+try:
+   embed_model = initialize_embed_model(
+       embed_provider=rag_embed_provider,
+       embed_model=rag_embed_model,
+       embed_endpoint=rag_embed_endpoint,
+       embed_api_key=rag_embed_api_key,
+       embed_extra=embed_extra,
+   )
+   logger.info("Embedding model initialized successfully.")
+except (ValueError, RuntimeError) as e:
+   error_msg = f"Failed to initialize embedding model: {e}"
+   logger.error(error_msg, exc_info=True)
+   raise RuntimeError(error_msg) from e
+
+try:
+   llm_model = initialize_llm_model(
+       llm_provider=rag_llm_provider,
+       llm_model=rag_llm_model,
+       llm_endpoint=rag_llm_endpoint,
+       llm_api_key=rag_llm_api_key,
+       llm_extra=llm_extra,
+   )
+   logger.info("LLM model initialized successfully.")
+except (ValueError, RuntimeError) as e:
+   error_msg = f"Failed to initialize LLM model: {e}"
+   logger.error(error_msg, exc_info=True)
+   raise RuntimeError(error_msg) from e
-if llm_provider == "ollama":
-   if base_url == "":
-       base_url = "http://localhost:11434"
-   if rag_embed_model == "":
-       rag_embed_model = "nomic-embed-text"
-   if rag_llm_model == "":
-       rag_llm_model = "llama3"
-   embed_model = OllamaEmbedding(model_name=rag_embed_model, base_url=base_url)
-   llm_model = Ollama(model=rag_llm_model, base_url=base_url, request_timeout=60.0)
-else:
-   if base_url == "":
-       base_url = "https://api.openai.com/v1"
-   if rag_embed_model == "":
-       rag_embed_model = "text-embedding-3-small"
-   if rag_llm_model == "":
-       rag_llm_model = "gpt-4o-mini"
-   embed_model = OpenAIEmbedding(model=rag_embed_model, api_base=base_url)
-   llm_model = OpenAI(model=rag_llm_model, api_base=base_url)

Settings.embed_model = embed_model
Settings.llm = llm_model
@@ -400,7 +415,10 @@ class SourceDocument(BaseModel):
class RetrieveRequest(BaseModel):
    """Request model for information retrieval."""

-   query: str = Field(..., description="The query text to search for in the indexed documents")
+   query: str = Field(
+       ...,
+       description="The query text to search for in the indexed documents",
+   )
    base_uri: str = Field(..., description="The base URI to search in")
    top_k: int | None = Field(5, description="Number of top results to return", ge=1, le=20)
@@ -478,7 +496,10 @@ def process_document_batch(documents: list[Document]) -> bool:  # noqa: PLR0915,
        # Check if document with same hash has already been successfully processed
        status_records = indexing_history_service.get_indexing_status(doc=doc)
        if status_records and status_records[0].status == "completed":
-           logger.debug("Document with same hash already processed, skipping: %s", doc.doc_id)
+           logger.debug(
+               "Document with same hash already processed, skipping: %s",
+               doc.doc_id,
+           )
            continue

        logger.debug("Processing document: %s", doc.doc_id)
@@ -592,7 +613,10 @@ def get_gitcrypt_files(directory: Path) -> list[str]:
    )

    if git_root_cmd.returncode != 0:
-       logger.warning("Not a git repository or git command failed: %s", git_root_cmd.stderr.strip())
+       logger.warning(
+           "Not a git repository or git command failed: %s",
+           git_root_cmd.stderr.strip(),
+       )
        return git_crypt_patterns

    git_root = Path(git_root_cmd.stdout.strip())
@@ -613,7 +637,15 @@ def get_gitcrypt_files(directory: Path) -> list[str]:

    # Use Python to process the output instead of xargs, grep, and cut
    git_check_attr = subprocess.run(
-       [git_executable, "-C", str(git_root), "check-attr", "filter", "--stdin", "-z"],
+       [
+           git_executable,
+           "-C",
+           str(git_root),
+           "check-attr",
+           "filter",
+           "--stdin",
+           "-z",
+       ],
        input=git_ls_files.stdout,
        capture_output=True,
        text=False,
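With `-z`, `git check-attr` emits NUL-separated `path, attribute, value` records, which the code above parses in Python instead of piping through `xargs`, `grep`, and `cut`. A sketch of that parsing step over hypothetical sample output (no git invocation):

```python
# Hypothetical sample of `git check-attr filter --stdin -z` output:
# a flat sequence of NUL-terminated (path, attribute, value) triples.
raw = b"src/secret.env\x00filter\x00git-crypt\x00README.md\x00filter\x00unspecified\x00"

fields = raw.decode("utf-8").split("\x00")[:-1]  # drop the trailing empty field
triples = [tuple(fields[i:i + 3]) for i in range(0, len(fields), 3)]

# Keep only paths whose filter attribute is git-crypt.
encrypted = [path for path, attr, value in triples if attr == "filter" and value == "git-crypt"]
print(encrypted)  # ['src/secret.env']
```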
@@ -830,7 +862,11 @@ def split_documents(documents: list[Document]) -> list[Document]:
            t = doc.get_content()
            texts = code_splitter.split_text(t)
        except ValueError as e:
-           logger.error("Error splitting document: %s, so skipping split, error: %s", doc.doc_id, str(e))
+           logger.error(
+               "Error splitting document: %s, so skipping split, error: %s",
+               doc.doc_id,
+               str(e),
+           )
            processed_documents.append(doc)
            continue
@@ -1034,7 +1070,10 @@ async def add_resource(request: ResourceRequest, background_tasks: BackgroundTas

    if resource:
        if resource.name != request.name:
-           raise HTTPException(status_code=400, detail=f"Resource name cannot be changed: {resource.name}")
+           raise HTTPException(
+               status_code=400,
+               detail=f"Resource name cannot be changed: {resource.name}",
+           )

        resource_service.update_resource_status(resource.uri, "active")
    else:
@@ -1352,3 +1391,9 @@ async def list_resources() -> ResourceListResponse:
        total_count=len(resources),
        status_summary=status_counts,
    )
+
+
+@app.get("/api/health")
+async def health_check() -> dict[str, str]:
+    """Health check endpoint."""
+    return {"status": "ok"}
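The commit also registers a `/api/health` endpoint. Its handler body can be exercised on its own, without the FastAPI `app` object (the `@app.get` decorator is left off in this sketch):

```python
import asyncio

async def health_check() -> dict[str, str]:
    """Health check endpoint body, as added in the diff."""
    return {"status": "ok"}

print(asyncio.run(health_check()))  # {'status': 'ok'}
```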
py/rag-service/src/providers/__init__.py (new file, 0 lines)
py/rag-service/src/providers/dashscope.py (new file, 70 lines)
@@ -0,0 +1,70 @@
# src/providers/dashscope.py

from typing import Any

from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms.llm import LLM
from llama_index.embeddings.dashscope import DashScopeEmbedding
from llama_index.llms.dashscope import DashScope


def initialize_embed_model(
    embed_endpoint: str,  # noqa: ARG001
    embed_api_key: str,
    embed_model: str,
    **embed_extra: Any,  # noqa: ANN401
) -> BaseEmbedding:
    """
    Create DashScope embedding model.

    Args:
        embed_endpoint: Not used directly by the constructor.
        embed_api_key: The API key for the DashScope API.
        embed_model: The name of the embedding model.
        embed_extra: Extra parameters for the embedding model.

    Returns:
        The initialized embed_model.

    """
    # DashScope typically needs only the API key and model name; the endpoint
    # is usually taken from environment variables or the library default.
    # embed_endpoint stays in the signature to match the factory interface,
    # even though the constructor may not use it.
    return DashScopeEmbedding(
        model_name=embed_model,
        api_key=embed_api_key,
        **embed_extra,
    )


def initialize_llm_model(
    llm_endpoint: str,  # noqa: ARG001
    llm_api_key: str,
    llm_model: str,
    **llm_extra: Any,  # noqa: ANN401
) -> LLM:
    """
    Create DashScope LLM model.

    Args:
        llm_endpoint: Not used directly by the constructor.
        llm_api_key: The API key for the DashScope API.
        llm_model: The name of the LLM model.
        llm_extra: Extra parameters for the LLM model.

    Returns:
        The initialized llm_model.

    """
    # DashScope typically needs only the API key and model name; the endpoint
    # is usually taken from environment variables or the library default.
    # llm_endpoint stays in the signature to match the factory interface,
    # even though the constructor may not use it.
    return DashScope(
        model_name=llm_model,
        api_key=llm_api_key,
        **llm_extra,
    )
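Every provider module, `dashscope.py` included, exposes initializers with the same positional shape `(endpoint, api_key, model, **extra)`, even when a parameter goes unused (hence the `# noqa: ARG001` markers), so the factory can call any of them uniformly. A toy provider illustrating the convention (all names hypothetical; a dict stands in for a real model object):

```python
from typing import Any

def initialize_llm_model(llm_endpoint: str, llm_api_key: str, llm_model: str, **llm_extra: Any) -> dict:
    """Toy stand-in for a provider initializer; ignores llm_endpoint, as DashScope does."""
    return {"model": llm_model, "api_key": llm_api_key, "extra": llm_extra}

cfg = initialize_llm_model("https://example.invalid/v1", "sk-test", "toy-model", request_timeout=60.0)
print(cfg["model"], cfg["extra"])  # toy-model {'request_timeout': 60.0}
```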
py/rag-service/src/providers/factory.py (new file, 179 lines)
@@ -0,0 +1,179 @@
import importlib
from typing import TYPE_CHECKING, Any, cast

from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms.llm import LLM

if TYPE_CHECKING:
    from collections.abc import Callable

from libs.logger import logger  # Assuming libs.logger exists and provides a logger instance


def initialize_embed_model(
    embed_provider: str,
    embed_model: str,
    embed_endpoint: str | None = None,
    embed_api_key: str | None = None,
    embed_extra: dict[str, Any] | None = None,
) -> BaseEmbedding:
    """
    Initialize embedding model based on the specified provider and configuration.

    Dynamically loads the provider module based on the embed_provider parameter.

    Args:
        embed_provider: The name of the embedding provider (e.g., "openai", "ollama").
        embed_model: The name of the embedding model.
        embed_endpoint: The API endpoint for the embedding provider.
        embed_api_key: The API key for the embedding provider.
        embed_extra: Additional provider-specific configuration parameters.

    Returns:
        The initialized embed_model.

    Raises:
        ValueError: If the specified embed_provider is not supported or the module/function is not found.
        RuntimeError: If model initialization fails for the selected provider.

    """
    # Validate provider name
    if not embed_provider.replace("_", "").isalnum():
        error_msg = f"Invalid EMBED_PROVIDER specified: '{embed_provider}'. Provider name must be alphanumeric or contain underscores."
        raise ValueError(error_msg)

    try:
        provider_module = importlib.import_module(f".{embed_provider}", package="providers")
        logger.debug(f"Successfully imported provider module: providers.{embed_provider}")
        attribute = getattr(provider_module, "initialize_embed_model", None)
        if attribute is None:
            error_msg = f"Provider module '{embed_provider}' does not have an 'initialize_embed_model' function."
            raise ValueError(error_msg)  # noqa: TRY301

        initializer = cast("Callable[..., BaseEmbedding]", attribute)

    except ImportError as err:
        error_msg = f"Unsupported EMBED_PROVIDER specified: '{embed_provider}'. Could not find provider module 'providers.{embed_provider}'."
        raise ValueError(error_msg) from err
    except AttributeError as err:
        error_msg = f"Provider module '{embed_provider}' does not have an 'initialize_embed_model' function."
        raise ValueError(error_msg) from err
    except Exception as err:
        logger.error(
            f"An unexpected error occurred while loading provider '{embed_provider}': {err!r}",
            exc_info=True,
        )
        error_msg = f"Failed to load provider '{embed_provider}' due to an unexpected error."
        raise RuntimeError(error_msg) from err

    logger.info(f"Initializing embedding model for provider: {embed_provider}")

    try:
        embedding: BaseEmbedding = initializer(
            embed_endpoint,
            embed_api_key,
            embed_model,
            **(embed_extra or {}),
        )

        logger.info(f"Embedding model initialized successfully for {embed_provider}")
        return embedding
    except TypeError as err:
        error_msg = f"Provider initializer 'initialize_embed_model' was called with incorrect arguments in '{embed_provider}'"
        logger.error(
            f"{error_msg}: {err!r}",
            exc_info=True,
        )
        raise RuntimeError(error_msg) from err
    except Exception as err:
        error_msg = f"Failed to initialize embedding model for provider '{embed_provider}'"
        logger.error(
            f"{error_msg}: {err!r}",
            exc_info=True,
        )
        raise RuntimeError(error_msg) from err


def initialize_llm_model(
    llm_provider: str,
    llm_model: str,
    llm_endpoint: str | None = None,
    llm_api_key: str | None = None,
    llm_extra: dict[str, Any] | None = None,
) -> LLM:
    """
    Create LLM model with the specified configuration.

    Dynamically loads the provider module based on the llm_provider parameter.

    Args:
        llm_provider: The name of the LLM provider (e.g., "openai", "ollama").
        llm_model: The name of the LLM model.
        llm_endpoint: The API endpoint for the LLM provider.
        llm_api_key: The API key for the LLM provider.
        llm_extra: Additional provider-specific configuration parameters.

    Returns:
        The initialized llm_model.

    Raises:
        ValueError: If the specified llm_provider is not supported or the module/function is not found.
        RuntimeError: If model initialization fails for the selected provider.

    """
    if not llm_provider.replace("_", "").isalnum():
        error_msg = f"Invalid LLM_PROVIDER specified: '{llm_provider}'. Provider name must be alphanumeric or contain underscores."
        raise ValueError(error_msg)

    try:
        provider_module = importlib.import_module(
            f".{llm_provider}",
            package="providers",
        )
        logger.debug(f"Successfully imported provider module: providers.{llm_provider}")
        attribute = getattr(provider_module, "initialize_llm_model", None)
        if attribute is None:
            error_msg = f"Provider module '{llm_provider}' does not have an 'initialize_llm_model' function."
            raise ValueError(error_msg)  # noqa: TRY301

        initializer = cast("Callable[..., LLM]", attribute)

    except ImportError as err:
        error_msg = f"Unsupported LLM_PROVIDER specified: '{llm_provider}'. Could not find provider module 'providers.{llm_provider}'."
        raise ValueError(error_msg) from err

    except AttributeError as err:
        error_msg = f"Provider module '{llm_provider}' does not have an 'initialize_llm_model' function."
        raise ValueError(error_msg) from err

    except Exception as e:
        error_msg = f"An unexpected error occurred while loading provider '{llm_provider}': {e}"
        logger.error(error_msg, exc_info=True)
        raise RuntimeError(error_msg) from e

    logger.info(f"Initializing LLM model for provider: '{llm_provider}'")
    logger.debug(f"Args: llm_model='{llm_model}', llm_endpoint='{llm_endpoint}'")

    try:
        llm: LLM = initializer(
            llm_endpoint,
            llm_api_key,
            llm_model,
            **(llm_extra or {}),
        )
        logger.info(f"LLM model initialized successfully for '{llm_provider}'.")

    except TypeError as e:
        error_msg = f"Provider initializer 'initialize_llm_model' in '{llm_provider}' was called with incorrect arguments: {e}"
        logger.error(error_msg, exc_info=True)
        raise RuntimeError(error_msg) from e

    except Exception as e:
        error_msg = f"Failed to initialize LLM model for provider '{llm_provider}': {e}"
        logger.error(
            error_msg,
            exc_info=True,
        )
        raise RuntimeError(error_msg) from e

    return llm
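`factory.py` never imports providers by name; it resolves `providers.<name>` at runtime with `importlib` and fetches the expected initializer off the module. A self-contained sketch of that dispatch, registering a synthetic `providers.toy` module in `sys.modules` so no files are needed (all names here are hypothetical):

```python
import importlib
import sys
import types
from typing import Any, Callable

# Register a toy provider module under a stand-in "providers" package so
# importlib can find it; the real providers live in py/rag-service/src/providers/.
pkg = types.ModuleType("providers")
pkg.__path__ = []  # mark it as a package
toy = types.ModuleType("providers.toy")
toy.initialize_llm_model = lambda endpoint, api_key, model, **extra: f"{model}@{endpoint}"
sys.modules["providers"] = pkg
sys.modules["providers.toy"] = toy

def load_initializer(provider: str) -> Callable[..., Any]:
    """Mirror the factory's dynamic lookup: import providers.<name> and
    fetch its initialize_llm_model function."""
    if not provider.replace("_", "").isalnum():
        raise ValueError(f"Invalid provider name: {provider!r}")
    module = importlib.import_module(f".{provider}", package="providers")
    initializer = getattr(module, "initialize_llm_model", None)
    if initializer is None:
        raise ValueError(f"Provider {provider!r} lacks initialize_llm_model")
    return initializer

init = load_initializer("toy")
print(init("http://localhost:9999", None, "toy-model"))
```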
py/rag-service/src/providers/ollama.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# src/providers/ollama.py

from typing import Any

from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms.llm import LLM
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama


def initialize_embed_model(
    embed_endpoint: str,
    embed_api_key: str,  # noqa: ARG001
    embed_model: str,
    **embed_extra: Any,  # noqa: ANN401
) -> BaseEmbedding:
    """
    Create Ollama embedding model.

    Args:
        embed_endpoint: The API endpoint for the Ollama API.
        embed_api_key: Not used by Ollama.
        embed_model: The name of the embedding model.
        embed_extra: Extra parameters for the Ollama embedding model.

    Returns:
        The initialized embed_model.

    """
    # Ollama talks to the endpoint directly and does not require an API key;
    # embed_api_key stays in the signature to match the factory interface.
    return OllamaEmbedding(
        model_name=embed_model,
        base_url=embed_endpoint,
        **embed_extra,
    )


def initialize_llm_model(
    llm_endpoint: str,
    llm_api_key: str,  # noqa: ARG001
    llm_model: str,
    **llm_extra: Any,  # noqa: ANN401
) -> LLM:
    """
    Create Ollama LLM model.

    Args:
        llm_endpoint: The API endpoint for the Ollama API.
        llm_api_key: Not used by Ollama.
        llm_model: The name of the LLM model.
        llm_extra: Extra parameters for the LLM model.

    Returns:
        The initialized llm_model.

    """
    # Ollama talks to the endpoint directly and does not require an API key;
    # llm_api_key stays in the signature to match the factory interface.
    return Ollama(
        model=llm_model,
        base_url=llm_endpoint,
        **llm_extra,
    )
py/rag-service/src/providers/openai.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# src/providers/openai.py

from typing import Any

from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms.llm import LLM
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI


def initialize_embed_model(
    embed_endpoint: str,
    embed_api_key: str,
    embed_model: str,
    **embed_extra: Any,  # noqa: ANN401
) -> BaseEmbedding:
    """
    Create OpenAI embedding model.

    Args:
        embed_endpoint: The API endpoint for the OpenAI API.
        embed_api_key: The API key for the OpenAI API.
        embed_model: The name of the embedding model.
        embed_extra: Extra parameters for the OpenAI API.

    Returns:
        The initialized embed_model.

    """
    # Use the provided endpoint directly and pass the API key explicitly.
    # Note: OpenAIEmbedding falls back to the OPENAI_API_KEY env var when
    # api_key is None, as the original code relied on.
    return OpenAIEmbedding(
        model=embed_model,
        api_base=embed_endpoint,
        api_key=embed_api_key,
        **embed_extra,
    )


def initialize_llm_model(
    llm_endpoint: str,
    llm_api_key: str,
    llm_model: str,
    **llm_extra: Any,  # noqa: ANN401
) -> LLM:
    """
    Create OpenAI LLM model.

    Args:
        llm_endpoint: The API endpoint for the OpenAI API.
        llm_api_key: The API key for the OpenAI API.
        llm_model: The name of the LLM model.
        llm_extra: Extra parameters for the OpenAI API.

    Returns:
        The initialized llm_model.

    """
    # Use the provided endpoint directly and pass the API key explicitly.
    # Note: OpenAI falls back to the OPENAI_API_KEY env var when api_key is
    # None, as the original code relied on.
    return OpenAI(
        model=llm_model,
        api_base=llm_endpoint,
        api_key=llm_api_key,
        **llm_extra,
    )
py/rag-service/src/providers/openrouter.py (new file, 35 lines)
@@ -0,0 +1,35 @@
# src/providers/openrouter.py

from typing import Any

from llama_index.core.llms.llm import LLM
from llama_index.llms.openrouter import OpenRouter


def initialize_llm_model(
    llm_endpoint: str,
    llm_api_key: str,
    llm_model: str,
    **llm_extra: Any,  # noqa: ANN401
) -> LLM:
    """
    Create OpenRouter LLM model.

    Args:
        llm_endpoint: The API endpoint for the OpenRouter API.
        llm_api_key: The API key for the OpenRouter API.
        llm_model: The name of the LLM model.
        llm_extra: Extra parameters for the OpenRouter API.

    Returns:
        The initialized llm_model.

    """
    # Use the provided endpoint directly and pass the API key explicitly.
    return OpenRouter(
        model=llm_model,
        api_base=llm_endpoint,
        api_key=llm_api_key,
        **llm_extra,
    )