feat: Enhanced Model Provider Support and Configuration Flexibility for RAG Service (#2056)

Co-authored-by: doodleEsc <cokie@foxmail.com>
Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
This commit is contained in:
doodleEsc
2025-06-06 23:07:07 +08:00
committed by GitHub
parent ec0f4f9ae0
commit 2dd4c04088
15 changed files with 844 additions and 151 deletions

View File

@@ -705,7 +705,7 @@ Given its early stage, `avante.nvim` currently supports the following basic func
> model = "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
> aws_profile = "bedrock",
> aws_region = "us-east-1",
>},
> },
> ```
>
> Note: Bedrock requires the [AWS CLI](https://aws.amazon.com/cli/) to be installed on your system.
@@ -884,21 +884,37 @@ For more information, see [Custom Providers](https://github.com/yetone/avante.nv
Avante provides a RAG service, a tool for obtaining the context the AI needs to generate code. It is disabled by default. You can enable it as follows:
```lua
rag_service = {
enabled = false, -- Enables the RAG service
host_mount = os.getenv("HOME"), -- Host mount path for the rag service
provider = "openai", -- The provider to use for RAG service (e.g. openai or ollama)
llm_model = "", -- The LLM model to use for RAG service
embed_model = "", -- The embedding model to use for RAG service
endpoint = "https://api.openai.com/v1", -- The API endpoint for RAG service
},
rag_service = { -- RAG Service configuration
enabled = false, -- Enables the RAG service
host_mount = os.getenv("HOME"), -- Host mount path for the rag service (Docker will mount this path)
runner = "docker", -- Runner for the RAG service (can use docker or nix)
llm = { -- Language Model (LLM) configuration for RAG service
provider = "openai", -- LLM provider
endpoint = "https://api.openai.com/v1", -- LLM API endpoint
api_key = "OPENAI_API_KEY", -- Environment variable name for the LLM API key
model = "gpt-4o-mini", -- LLM model name
extra = nil, -- Additional configuration options for LLM
},
embed = { -- Embedding model configuration for RAG service
provider = "openai", -- Embedding provider
endpoint = "https://api.openai.com/v1", -- Embedding API endpoint
api_key = "OPENAI_API_KEY", -- Environment variable name for the embedding API key
model = "text-embedding-3-large", -- Embedding model name
extra = nil, -- Additional configuration options for the embedding model
},
docker_extra_args = "", -- Extra arguments to pass to the docker command
},
```
If your rag_service provider is `openai`, then you need to set the `OPENAI_API_KEY` environment variable!
The RAG service now lets you configure the LLM and embedding models separately. In the `llm` and `embed` configuration blocks, you can set the following fields:
If your rag_service provider is `ollama`, you need to set the endpoint to `http://localhost:11434` (note there is no `/v1` at the end) or any address of your own ollama server.
- `provider`: Model provider (e.g., "openai", "ollama", "dashscope", or "openrouter")
- `endpoint`: API endpoint
- `api_key`: Environment variable name for the API key
- `model`: Model name
- `extra`: Additional configuration options
If your rag_service provider is `ollama`, when `llm_model` is empty, it defaults to `llama3`, and when `embed_model` is empty, it defaults to `nomic-embed-text`. Please make sure these models are available in your ollama server.
For detailed configuration of different model providers, you can check [here](./py/rag-service/README.md).
Additionally, RAG Service also depends on Docker! (For macOS users, OrbStack is recommended as a Docker alternative).
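For instance, an Ollama-based setup might look like the following sketch (the model names are examples; make sure they are pulled on your Ollama server first):

```lua
rag_service = {
  enabled = true,
  host_mount = os.getenv("HOME"),
  runner = "docker",
  llm = {
    provider = "ollama",
    endpoint = "http://localhost:11434", -- note: no trailing /v1 for ollama
    api_key = "", -- Ollama typically needs no API key
    model = "llama3", -- example model; e.g. `ollama pull llama3`
    extra = nil,
  },
  embed = {
    provider = "ollama",
    endpoint = "http://localhost:11434",
    api_key = "",
    model = "nomic-embed-text", -- example embedding model
    extra = nil,
  },
  docker_extra_args = "",
},
```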

View File

@@ -747,21 +747,37 @@ Avante provides a set of default providers, but users can also create their own
Avante provides a RAG service, a tool for obtaining the context the AI needs to generate code. It is disabled by default. You can enable it as follows:
```lua
rag_service = {
enabled = false, -- Enables the RAG service
host_mount = os.getenv("HOME"), -- Host mount path for the RAG service
provider = "openai", -- The provider to use for the RAG service (e.g. openai or ollama)
llm_model = "", -- The LLM model to use for the RAG service
embed_model = "", -- The embedding model to use for the RAG service
endpoint = "https://api.openai.com/v1", -- The API endpoint for the RAG service
},
rag_service = { -- RAG service configuration
enabled = false, -- Enables the RAG service
host_mount = os.getenv("HOME"), -- Host mount path for the RAG service (Docker will mount this path)
runner = "docker", -- Runner for the RAG service (can use docker or nix)
llm = { -- Language Model (LLM) configuration used by the RAG service
provider = "openai", -- LLM provider
endpoint = "https://api.openai.com/v1", -- LLM API endpoint
api_key = "OPENAI_API_KEY", -- Environment variable name for the LLM API key
model = "gpt-4o-mini", -- LLM model name
extra = nil, -- Additional configuration options for the LLM
},
embed = { -- Embedding model configuration used by the RAG service
provider = "openai", -- Embedding provider
endpoint = "https://api.openai.com/v1", -- Embedding API endpoint
api_key = "OPENAI_API_KEY", -- Environment variable name for the embedding API key
model = "text-embedding-3-large", -- Embedding model name
extra = nil, -- Additional configuration options for the embedding model
},
docker_extra_args = "", -- Extra arguments to pass to the docker command
},
```
If your rag_service provider is `openai`, you need to set the `OPENAI_API_KEY` environment variable!
The RAG service now lets you configure the LLM and embedding models separately. In the `llm` and `embed` configuration blocks, you can set the following fields:
If your rag_service provider is `ollama`, you need to set the endpoint to `http://localhost:11434` (note: no `/v1` at the end) or any address of your own Ollama server.
- `provider`: Model provider (e.g., "openai", "ollama", "dashscope", or "openrouter")
- `endpoint`: API endpoint
- `api_key`: Environment variable name for the API key
- `model`: Model name
- `extra`: Additional configuration options
If your rag_service provider is `ollama`, when `llm_model` is empty it defaults to `llama3`, and when `embed_model` is empty it defaults to `nomic-embed-text`. Make sure these models are available on your Ollama server.
For detailed configuration of different model providers, see [here](./py/rag-service/README.md).
Additionally, the RAG service also depends on Docker! (For macOS users, OrbStack is recommended as a Docker alternative.)

View File

@@ -290,7 +290,6 @@ function M.remove_selected_file(filepath)
for _, file in ipairs(files) do
local rel_path = Utils.uniform_path(file)
vim.notify(rel_path)
sidebar.file_selector:remove_selected_file(rel_path)
end
end

View File

@@ -37,14 +37,24 @@ M._defaults = {
tokenizer = "tiktoken",
---@type string | (fun(): string) | nil
system_prompt = nil,
rag_service = {
enabled = false, -- Enables the rag service, requires OPENAI_API_KEY to be set
host_mount = os.getenv("HOME"), -- Host mount path for the rag service (docker will mount this path)
runner = "docker", -- The runner for the rag service, (can use docker, or nix)
provider = "openai", -- The provider to use for RAG service. eg: openai or ollama
llm_model = "", -- The LLM model to use for RAG service
embed_model = "", -- The embedding model to use for RAG service
endpoint = "https://api.openai.com/v1", -- The API endpoint for RAG service
rag_service = { -- RAG service configuration
enabled = false, -- Enables the RAG service
host_mount = os.getenv("HOME"), -- Host mount path for the RAG service (Docker will mount this path)
runner = "docker", -- The runner for the RAG service (can use docker or nix)
llm = { -- Configuration for the Language Model (LLM) used by the RAG service
provider = "openai", -- The LLM provider
endpoint = "https://api.openai.com/v1", -- The LLM API endpoint
api_key = "OPENAI_API_KEY", -- The environment variable name for the LLM API key
model = "gpt-4o-mini", -- The LLM model name
extra = nil, -- Extra configuration options for the LLM
},
embed = { -- Configuration for the Embedding model used by the RAG service
provider = "openai", -- The embedding provider
endpoint = "https://api.openai.com/v1", -- The embedding API endpoint
api_key = "OPENAI_API_KEY", -- The environment variable name for the embedding API key
model = "text-embedding-3-large", -- The embedding model name
extra = nil, -- Extra configuration options for the embedding model
},
docker_extra_args = "", -- Extra arguments to pass to the docker command
},
web_search_engine = {

View File

@@ -8,7 +8,7 @@ local M = {}
local container_name = "avante-rag-service"
local service_path = "/tmp/" .. container_name
function M.get_rag_service_image() return "quay.io/yetoneful/avante-rag-service:0.0.10" end
function M.get_rag_service_image() return "quay.io/yetoneful/avante-rag-service:0.0.11" end
function M.get_rag_service_port() return 20250 end
@@ -35,13 +35,46 @@ function M.get_rag_service_runner() return (Config.rag_service and Config.rag_se
---@param cb fun()
function M.launch_rag_service(cb)
local openai_api_key = os.getenv("OPENAI_API_KEY")
if Config.rag_service.provider == "openai" then
if openai_api_key == nil then
error("cannot launch avante rag service, OPENAI_API_KEY is not set")
--- If Config.rag_service.llm.api_key is nil or empty, llm_api_key will be an empty string.
local llm_api_key = ""
if
Config.rag_service
and Config.rag_service.llm
and Config.rag_service.llm.api_key
and Config.rag_service.llm.api_key ~= ""
then
llm_api_key = os.getenv(Config.rag_service.llm.api_key) or ""
if llm_api_key == nil or llm_api_key == "" then
error(string.format("cannot launch avante rag service, %s is not set", Config.rag_service.llm.api_key))
return
end
end
--- If Config.rag_service.embed.api_key is nil or empty, embed_api_key will be an empty string.
local embed_api_key = ""
if
Config.rag_service
and Config.rag_service.embed
and Config.rag_service.embed.api_key
and Config.rag_service.embed.api_key ~= ""
then
embed_api_key = os.getenv(Config.rag_service.embed.api_key) or ""
if embed_api_key == nil or embed_api_key == "" then
error(string.format("cannot launch avante rag service, %s is not set", Config.rag_service.embed.api_key))
return
end
end
local embed_extra = "{}" -- Default to empty JSON object string
if Config.rag_service and Config.rag_service.embed and Config.rag_service.embed.extra then
embed_extra = vim.json.encode(Config.rag_service.embed.extra):gsub('"', '\\"')
end
local llm_extra = "{}" -- Default to empty JSON object string
if Config.rag_service and Config.rag_service.llm and Config.rag_service.llm.extra then
llm_extra = vim.json.encode(Config.rag_service.llm.extra):gsub('"', '\\"')
end
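Note the indirection above: `api_key` holds the *name* of an environment variable, not the key itself, and an empty name means the provider requires no key. A minimal Python sketch of the same resolution logic (the variable names are illustrative, not part of the plugin):

```python
import os

def resolve_api_key(env_var_name: str) -> str:
    # An empty name means the provider needs no key (e.g. a local Ollama server).
    if not env_var_name:
        return ""
    value = os.getenv(env_var_name) or ""
    if not value:
        raise RuntimeError(f"cannot launch avante rag service, {env_var_name} is not set")
    return value

os.environ["EXAMPLE_KEY_VAR"] = "sk-example"  # hypothetical variable for illustration
print(resolve_api_key("EXAMPLE_KEY_VAR"))  # → sk-example
```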
local port = M.get_rag_service_port()
if M.get_rag_service_runner() == "docker" then
@@ -74,19 +107,22 @@ function M.launch_rag_service(cb)
M.stop_rag_service()
end
local cmd_ = string.format(
"docker run --platform=linux/amd64 -d -p 0.0.0.0:%d:%d --name %s -v %s:/data -v %s:/host:ro -e ALLOW_RESET=TRUE -e DATA_DIR=/data -e RAG_PROVIDER=%s -e %s_API_KEY=%s -e %s_API_BASE=%s -e RAG_LLM_MODEL=%s -e RAG_EMBED_MODEL=%s %s %s",
"docker run --platform=linux/amd64 -d -p 0.0.0.0:%d:%d --name %s -v %s:/data -v %s:/host:ro -e ALLOW_RESET=TRUE -e DATA_DIR=/data -e RAG_EMBED_PROVIDER=%s -e RAG_EMBED_ENDPOINT=%s -e RAG_EMBED_API_KEY=%s -e RAG_EMBED_MODEL=%s -e RAG_EMBED_EXTRA=%s -e RAG_LLM_PROVIDER=%s -e RAG_LLM_ENDPOINT=%s -e RAG_LLM_API_KEY=%s -e RAG_LLM_MODEL=%s -e RAG_LLM_EXTRA=%s %s %s",
M.get_rag_service_port(),
M.get_rag_service_port(),
container_name,
data_path,
Config.rag_service.host_mount,
Config.rag_service.provider,
Config.rag_service.provider:upper(),
openai_api_key,
Config.rag_service.provider:upper(),
Config.rag_service.endpoint,
Config.rag_service.llm_model,
Config.rag_service.embed_model,
Config.rag_service.embed.provider,
Config.rag_service.embed.endpoint,
embed_api_key,
Config.rag_service.embed.model,
embed_extra,
Config.rag_service.llm.provider,
Config.rag_service.llm.endpoint,
llm_api_key,
Config.rag_service.llm.model,
llm_extra,
Config.rag_service.docker_extra_args,
image
)
@@ -118,17 +154,20 @@ function M.launch_rag_service(cb)
Utils.debug(string.format("launching %s with nix...", container_name))
local cmd = string.format(
"cd %s && ALLOW_RESET=TRUE PORT=%d DATA_DIR=%s RAG_PROVIDER=%s %s_API_KEY=%s %s_API_BASE=%s RAG_LLM_MODEL=%s RAG_EMBED_MODEL=%s sh run.sh %s",
"cd %s && ALLOW_RESET=TRUE PORT=%d DATA_DIR=%s RAG_EMBED_PROVIDER=%s RAG_EMBED_ENDPOINT=%s RAG_EMBED_API_KEY=%s RAG_EMBED_MODEL=%s RAG_EMBED_EXTRA=%s RAG_LLM_PROVIDER=%s RAG_LLM_ENDPOINT=%s RAG_LLM_API_KEY=%s RAG_LLM_MODEL=%s RAG_LLM_EXTRA=%s sh run.sh %s",
rag_service_dir,
port,
service_path,
Config.rag_service.provider,
Config.rag_service.provider:upper(),
openai_api_key,
Config.rag_service.provider:upper(),
Config.rag_service.endpoint,
Config.rag_service.llm_model,
Config.rag_service.embed_model,
Config.rag_service.embed.provider,
Config.rag_service.embed.endpoint,
embed_api_key,
Config.rag_service.embed.model,
embed_extra,
Config.rag_service.llm.provider,
Config.rag_service.llm.endpoint,
llm_api_key,
Config.rag_service.llm.model,
llm_extra,
service_path
)
@@ -211,7 +250,12 @@ function M.to_local_uri(uri)
end
function M.is_ready()
vim.fn.system(string.format("curl -s -o /dev/null -w '%%{http_code}' %s", M.get_rag_service_url()))
vim.fn.system(
string.format(
"curl -s -o /dev/null -w '%%{http_code}' %s",
string.format("%s%s", M.get_rag_service_url(), "/api/health")
)
)
return vim.v.shell_error == 0
end

View File

@@ -2,21 +2,20 @@ FROM python:3.11-slim-bookworm
WORKDIR /app
RUN apt-get update && apt-get install -y curl git \
&& rm -rf /var/lib/apt/lists/* \
&& curl -LsSf https://astral.sh/uv/install.sh | sh
RUN apt-get update && apt-get install -y --no-install-recommends curl git \
&& rm -rf /var/lib/apt/lists/* \
&& curl -LsSf https://astral.sh/uv/install.sh | sh
ENV PATH="/root/.local/bin:$PATH"
ENV PATH="/root/.local/bin:$PATH" \
PYTHONPATH=/app/src \
PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1
COPY requirements.txt .
RUN uv venv
RUN uv pip install -r requirements.txt
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1
# Install directly into the system environment instead of creating a virtual environment
RUN uv pip install --system -r requirements.txt
COPY . .
CMD ["uv", "run", "fastapi", "run", "src/main.py", "--workers", "3", "--port", "20250"]
CMD ["uvicorn", "src.main:app", "--workers", "3", "--host", "0.0.0.0", "--port", "20250"]

py/rag-service/README.md (new file, 135 lines)
View File

@@ -0,0 +1,135 @@
# RAG Service Configuration
This document describes how to configure the RAG service, including setting up Language Model (LLM) and Embedding providers.
## Provider Support Matrix
The following table shows which model types are supported by each provider:
| Provider | LLM Support | Embedding Support |
| ---------- | ----------- | ----------------- |
| dashscope | Yes | Yes |
| ollama | Yes | Yes |
| openai | Yes | Yes |
| openrouter | Yes | No |
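Since OpenRouter offers no embedding endpoint, a typical setup pairs an OpenRouter LLM with a different embedding provider. A sketch (model names are examples):

```lua
llm = {
  provider = "openrouter",
  endpoint = "https://openrouter.ai/api/v1",
  api_key = "OPENROUTER_API_KEY",
  model = "openai/gpt-4o-mini",
  extra = nil,
},
embed = { -- OpenRouter has no embedding support, so use openai (or ollama/dashscope) here
  provider = "openai",
  endpoint = "https://api.openai.com/v1",
  api_key = "OPENAI_API_KEY",
  model = "text-embedding-3-large",
  extra = nil,
},
```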
## LLM Provider Configuration
The `llm` section in the configuration file is used to configure the Language Model (LLM) used by the RAG service.
Here are the configuration examples for each supported LLM provider:
### OpenAI LLM Configuration
[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py#L130)
```lua
llm = { -- Configuration for the Language Model (LLM) used by the RAG service
provider = "openai", -- The LLM provider ("openai")
endpoint = "https://api.openai.com/v1", -- The LLM API endpoint
api_key = "OPENAI_API_KEY", -- The environment variable name for the LLM API key
model = "gpt-4o-mini", -- The LLM model name (e.g., "gpt-4o-mini", "gpt-3.5-turbo")
extra = { -- Extra configuration options for the LLM (optional)
temperature = 0.7, -- Controls the randomness of the output. Lower values make it more deterministic.
max_tokens = 512, -- The maximum number of tokens to generate in the completion.
-- system_prompt = "You are a helpful assistant.", -- A system prompt to guide the model's behavior.
-- timeout = 120, -- Request timeout in seconds.
},
},
```
### DashScope LLM Configuration
[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/llms/llama-index-llms-dashscope/llama_index/llms/dashscope/base.py#L155)
```lua
llm = { -- Configuration for the Language Model (LLM) used by the RAG service
provider = "dashscope", -- The LLM provider ("dashscope")
endpoint = "", -- The LLM API endpoint (DashScope typically uses default or environment variables)
api_key = "DASHSCOPE_API_KEY", -- The environment variable name for the LLM API key
model = "qwen-plus", -- The LLM model name (e.g., "qwen-plus", "qwen-max")
extra = nil, -- Extra configuration options for the LLM (optional)
},
```
### Ollama LLM Configuration
[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py#L65)
```lua
llm = { -- Configuration for the Language Model (LLM) used by the RAG service
provider = "ollama", -- The LLM provider ("ollama")
endpoint = "http://localhost:11434", -- The LLM API endpoint for Ollama
api_key = "", -- Ollama typically does not require an API key
model = "llama2", -- The LLM model name (e.g., "llama2", "mistral")
extra = nil, -- Extra configuration options for the LLM (optional)
},
```
### OpenRouter LLM Configuration
[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/llms/llama-index-llms-openrouter/llama_index/llms/openrouter/base.py#L17)
```lua
llm = { -- Configuration for the Language Model (LLM) used by the RAG service
provider = "openrouter", -- The LLM provider ("openrouter")
endpoint = "https://openrouter.ai/api/v1", -- The LLM API endpoint for OpenRouter
api_key = "OPENROUTER_API_KEY", -- The environment variable name for the LLM API key
model = "openai/gpt-4o-mini", -- The LLM model name (e.g., "openai/gpt-4o-mini", "mistralai/mistral-7b-instruct")
extra = nil, -- Extra configuration options for the LLM (optional)
},
```
## Embedding Provider Configuration
The `embedding` section in the configuration file is used to configure the Embedding Model used by the RAG service.
Here are the configuration examples for each supported Embedding provider:
### OpenAI Embedding Configuration
[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/embeddings/llama-index-embeddings-openai/llama_index/embeddings/openai/base.py#L214)
```lua
embed = { -- Configuration for the Embedding Model used by the RAG service
provider = "openai", -- The Embedding provider ("openai")
endpoint = "https://api.openai.com/v1", -- The Embedding API endpoint
api_key = "OPENAI_API_KEY", -- The environment variable name for the Embedding API key
model = "text-embedding-3-large", -- The Embedding model name (e.g., "text-embedding-3-small", "text-embedding-3-large")
extra = { -- Extra configuration options for the Embedding model (optional)
dimensions = nil,
},
},
```
### DashScope Embedding Configuration
[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/embeddings/llama-index-embeddings-dashscope/llama_index/embeddings/dashscope/base.py#L156)
```lua
embed = { -- Configuration for the Embedding Model used by the RAG service
provider = "dashscope", -- The Embedding provider ("dashscope")
endpoint = "", -- The Embedding API endpoint (DashScope typically uses default or environment variables)
api_key = "DASHSCOPE_API_KEY", -- The environment variable name for the Embedding API key
model = "text-embedding-v3", -- The Embedding model name (e.g., "text-embedding-v2", "text-embedding-v3")
extra = { -- Extra configuration options for the Embedding model (optional)
embed_batch_size = 10,
},
},
```
### Ollama Embedding Configuration
[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/embeddings/llama-index-embeddings-ollama/llama_index/embeddings/ollama/base.py#L12)
```lua
embed = { -- Configuration for the Embedding Model used by the RAG service
provider = "ollama", -- The Embedding provider ("ollama")
endpoint = "http://localhost:11434", -- The Embedding API endpoint for Ollama
api_key = "", -- Ollama typically does not require an API key
model = "nomic-embed-text", -- The Embedding model name (e.g., "nomic-embed-text")
extra = { -- Extra configuration options for the Embedding model (optional)
embed_batch_size = 10
},
},
```
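Whatever you put in `extra` is JSON-encoded and forwarded to the service via the `RAG_LLM_EXTRA` / `RAG_EMBED_EXTRA` environment variables, so its keys should match constructor parameters of the underlying llama_index model class. A sketch (the parameter values here are illustrative):

```lua
embed = {
  provider = "openai",
  endpoint = "https://api.openai.com/v1",
  api_key = "OPENAI_API_KEY",
  model = "text-embedding-3-small",
  extra = {
    dimensions = 512,       -- forwarded as a keyword argument to the embedding constructor
    embed_batch_size = 32,  -- likewise forwarded; batch size for embedding requests
  },
},
```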

View File

@@ -17,15 +17,17 @@ chroma-hnswlib==0.7.6
chromadb==0.6.3
click==8.1.8
coloredlogs==15.0.1
dashscope==1.22.2
dataclasses-json==0.6.7
decorator==5.1.1
Deprecated==1.2.18
deprecated==1.2.18
dirtyjson==1.0.8
distro==1.9.0
dnspython==2.7.0
docx2txt==0.8
durationpy==0.9
email_validator==2.2.0
et_xmlfile==2.0.0
email-validator==2.2.0
et-xmlfile==2.0.0
executing==2.2.0
fastapi==0.115.8
fastapi-cli==0.0.7
@@ -42,15 +44,15 @@ h11==0.14.0
httpcore==1.0.7
httptools==0.6.4
httpx==0.28.1
huggingface-hub==0.28.1
huggingface-hub==0.31.4
humanfriendly==10.0
idna==3.10
importlib_metadata==8.5.0
importlib_resources==6.5.2
importlib-metadata==8.5.0
importlib-resources==6.5.2
ipdb==0.13.13
ipython==8.32.0
jedi==0.19.2
Jinja2==3.1.5
jinja2==3.1.5
jiter==0.8.2
joblib==1.4.2
kubernetes==32.0.0
@@ -60,11 +62,15 @@ llama-index==0.12.16
llama-index-agent-openai==0.4.3
llama-index-cli==0.4.0
llama-index-core==0.12.16.post1
llama-index-embeddings-openai==0.3.1
llama-index-embeddings-dashscope==0.3.0
llama-index-embeddings-ollama==0.5.0
llama-index-embeddings-openai==0.3.1
llama-index-indices-managed-llama-cloud==0.6.4
llama-index-llms-openai==0.3.18
llama-index-llms-dashscope==0.3.3
llama-index-llms-ollama==0.5.2
llama-index-llms-openai==0.3.18
llama-index-llms-openai-like==0.3.4
llama-index-llms-openrouter==0.3.1
llama-index-multi-modal-llms-openai==0.4.3
llama-index-program-openai==0.3.1
llama-index-question-gen-openai==0.3.0
@@ -74,7 +80,7 @@ llama-index-vector-stores-chroma==0.4.1
llama-parse==0.6.0
markdown-it-py==3.0.0
markdownify==0.14.1
MarkupSafe==3.0.2
markupsafe==3.0.2
marshmallow==3.26.1
matplotlib-inline==0.1.7
mdurl==0.1.2
@@ -88,6 +94,7 @@ networkx==3.4.2
nltk==3.9.1
numpy==2.2.2
oauthlib==3.2.2
ollama==0.4.8
onnxruntime==1.20.1
openai==1.61.1
openpyxl==3.1.5
@@ -110,35 +117,36 @@ pathspec==0.12.1
pexpect==4.9.0
pillow==11.1.0
posthog==3.11.0
prompt_toolkit==3.0.50
prompt-toolkit==3.0.50
propcache==0.2.1
protobuf==5.29.3
ptyprocess==0.7.0
pure_eval==0.2.3
pure-eval==0.2.3
pyasn1==0.6.1
pyasn1_modules==0.4.1
pyasn1-modules==0.4.1
pydantic==2.10.6
pydantic_core==2.27.2
Pygments==2.19.1
pydantic-core==2.27.2
pygments==2.19.1
pypdf==5.2.0
PyPika==0.48.9
pyproject_hooks==1.2.0
pypika==0.48.9
pyproject-hooks==1.2.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-multipart==0.0.20
pytz==2025.1
PyYAML==6.0.2
pyyaml==6.0.2
regex==2024.11.6
requests==2.32.3
requests-oauthlib==2.0.0
rich==13.9.4
rich-toolkit==0.13.2
rsa==4.9
safetensors==0.5.3
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
soupsieve==2.6
SQLAlchemy==2.0.38
sqlalchemy==2.0.38
stack-data==0.6.3
starlette==0.45.3
striprtf==0.0.26
@@ -148,11 +156,15 @@ tiktoken==0.8.0
tokenizers==0.21.0
tqdm==4.67.1
traitlets==5.14.3
transformers==4.51.3
tree-sitter==0.24.0
tree-sitter-c-sharp==0.23.1
tree-sitter-embedded-template==0.23.2
tree-sitter-language-pack==0.6.1
tree-sitter-yaml==0.7.0
typer==0.15.1
typing-extensions==4.12.2
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2025.1
urllib3==2.3.0
uvicorn==0.34.0
@@ -165,4 +177,3 @@ websockets==14.2
wrapt==1.17.2
yarl==1.18.3
zipp==3.21.0
docx2txt==0.8.0

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
# Standard library imports
import asyncio
import fcntl
import json
@@ -18,42 +19,24 @@ from pathlib import Path
from typing import TYPE_CHECKING
from urllib.parse import urljoin, urlparse
# Third-party imports
import chromadb
import httpx
import pathspec
from fastapi import BackgroundTasks, FastAPI, HTTPException
from libs.configs import (
BASE_DATA_DIR,
CHROMA_PERSIST_DIR,
)
# Local application imports
from libs.configs import BASE_DATA_DIR, CHROMA_PERSIST_DIR
from libs.db import init_db
from libs.logger import logger
from libs.utils import (
get_node_uri,
inject_uri_to_node,
is_local_uri,
is_path_node,
is_remote_uri,
path_to_uri,
uri_to_path,
)
from llama_index.core import (
Settings,
SimpleDirectoryReader,
StorageContext,
VectorStoreIndex,
load_index_from_storage,
)
from libs.utils import get_node_uri, inject_uri_to_node, is_local_uri, is_path_node, is_remote_uri, path_to_uri, uri_to_path
from llama_index.core import Settings, SimpleDirectoryReader, StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.core.node_parser import CodeSplitter
from llama_index.core.schema import Document
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
from llama_index.vector_stores.chroma import ChromaVectorStore
from markdownify import markdownify as md
from models.indexing_history import IndexingHistory # noqa: TC002
from models.resource import Resource
from providers.factory import initialize_embed_model, initialize_llm_model
from pydantic import BaseModel, Field
from services.indexing_history import indexing_history_service
from services.resource import resource_service
@@ -65,6 +48,7 @@ if TYPE_CHECKING:
from collections.abc import AsyncGenerator
from llama_index.core.schema import NodeWithScore, QueryBundle
from models.indexing_history import IndexingHistory
from watchdog.observers.api import BaseObserver
# Lock file for leader election
@@ -111,8 +95,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
if is_local_uri(resource.uri):
directory = uri_to_path(resource.uri)
if not directory.exists():
logger.error("Directory not found: %s", directory)
resource_service.update_resource_status(resource.uri, "error", f"Directory not found: {directory}")
error_msg = f"Directory not found: {directory}"
logger.error(error_msg)
resource_service.update_resource_status(resource.uri, "error", error_msg)
continue
# Start file system watcher
@@ -127,8 +112,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
elif is_remote_uri(resource.uri):
if not is_remote_resource_exists(resource.uri):
logger.error("HTTPS resource not found: %s", resource.uri)
resource_service.update_resource_status(resource.uri, "error", "remote resource not found")
error_msg = "HTTPS resource not found"
logger.error("%s: %s", error_msg, resource.uri)
resource_service.update_resource_status(resource.uri, "error", error_msg)
continue
# Start indexing
@@ -273,7 +259,11 @@ def is_remote_resource_exists(url: str) -> bool:
"""Check if a URL exists."""
try:
response = httpx.head(url, headers=http_headers)
return response.status_code in {httpx.codes.OK, httpx.codes.MOVED_PERMANENTLY, httpx.codes.FOUND}
return response.status_code in {
httpx.codes.OK,
httpx.codes.MOVED_PERMANENTLY,
httpx.codes.FOUND,
}
except (OSError, ValueError, RuntimeError) as e:
logger.error("Error checking if URL exists %s: %s", url, e)
return False
@@ -318,53 +308,78 @@ init_db()
# Initialize ChromaDB and LlamaIndex services
chroma_client = chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR))
# Check if provider or model has changed
current_provider = os.getenv("RAG_PROVIDER", "openai").lower()
current_embed_model = os.getenv("RAG_EMBED_MODEL", "")
current_llm_model = os.getenv("RAG_LLM_MODEL", "")
# Check if provider or model has changed
rag_embed_provider = os.getenv("RAG_EMBED_PROVIDER", "openai")
rag_embed_endpoint = os.getenv("RAG_EMBED_ENDPOINT", "https://api.openai.com/v1")
rag_embed_model = os.getenv("RAG_EMBED_MODEL", "text-embedding-3-large")
rag_embed_api_key = os.getenv("RAG_EMBED_API_KEY", None)
rag_embed_extra = os.getenv("RAG_EMBED_EXTRA", None)
rag_llm_provider = os.getenv("RAG_LLM_PROVIDER", "openai")
rag_llm_endpoint = os.getenv("RAG_LLM_ENDPOINT", "https://api.openai.com/v1")
rag_llm_model = os.getenv("RAG_LLM_MODEL", "gpt-4o-mini")
rag_llm_api_key = os.getenv("RAG_LLM_API_KEY", None)
rag_llm_extra = os.getenv("RAG_LLM_EXTRA", None)
# Try to read previous config
config_file = BASE_DATA_DIR / "rag_config.json"
if config_file.exists():
with Path.open(config_file, "r") as f:
prev_config = json.load(f)
if prev_config.get("provider") != current_provider or prev_config.get("embed_model") != current_embed_model:
if prev_config.get("provider") != rag_embed_provider or prev_config.get("embed_model") != rag_embed_model:
# Clear existing data if config changed
logger.info("Detected config change, clearing existing data...")
chroma_client.reset()
# Save current config
with Path.open(config_file, "w") as f:
json.dump({"provider": current_provider, "embed_model": current_embed_model}, f)
json.dump({"provider": rag_embed_provider, "embed_model": rag_embed_model}, f)
chroma_collection = chroma_client.get_or_create_collection("documents") # pyright: ignore
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
# Initialize embedding model based on provider
llm_provider = current_provider
base_url = os.getenv(llm_provider.upper() + "_API_BASE", "")
rag_embed_model = current_embed_model
rag_llm_model = current_llm_model
try:
embed_extra = json.loads(rag_embed_extra) if rag_embed_extra is not None else {}
except json.JSONDecodeError:
logger.error("Failed to decode RAG_EMBED_EXTRA, defaulting to empty dict.")
embed_extra = {}
try:
llm_extra = json.loads(rag_llm_extra) if rag_llm_extra is not None else {}
except json.JSONDecodeError:
logger.error("Failed to decode RAG_LLM_EXTRA, defaulting to empty dict.")
llm_extra = {}
# Initialize embedding model and LLM based on provider using the factory
try:
embed_model = initialize_embed_model(
embed_provider=rag_embed_provider,
embed_model=rag_embed_model,
embed_endpoint=rag_embed_endpoint,
embed_api_key=rag_embed_api_key,
embed_extra=embed_extra,
)
logger.info("Embedding model initialized successfully.")
except (ValueError, RuntimeError) as e:
error_msg = f"Failed to initialize embedding model: {e}"
logger.error(error_msg, exc_info=True)
raise RuntimeError(error_msg) from e
try:
llm_model = initialize_llm_model(
llm_provider=rag_llm_provider,
llm_model=rag_llm_model,
llm_endpoint=rag_llm_endpoint,
llm_api_key=rag_llm_api_key,
llm_extra=llm_extra,
)
logger.info("LLM model initialized successfully.")
except (ValueError, RuntimeError) as e:
error_msg = f"Failed to initialize LLM model: {e}"
logger.error(error_msg, exc_info=True)
raise RuntimeError(error_msg) from e
if llm_provider == "ollama":
if base_url == "":
base_url = "http://localhost:11434"
if rag_embed_model == "":
rag_embed_model = "nomic-embed-text"
if rag_llm_model == "":
rag_llm_model = "llama3"
embed_model = OllamaEmbedding(model_name=rag_embed_model, base_url=base_url)
llm_model = Ollama(model=rag_llm_model, base_url=base_url, request_timeout=60.0)
else:
if base_url == "":
base_url = "https://api.openai.com/v1"
if rag_embed_model == "":
rag_embed_model = "text-embedding-3-small"
if rag_llm_model == "":
rag_llm_model = "gpt-4o-mini"
embed_model = OpenAIEmbedding(model=rag_embed_model, api_base=base_url)
llm_model = OpenAI(model=rag_llm_model, api_base=base_url)
Settings.embed_model = embed_model
Settings.llm = llm_model
@@ -400,7 +415,10 @@ class SourceDocument(BaseModel):
class RetrieveRequest(BaseModel):
"""Request model for information retrieval."""
query: str = Field(..., description="The query text to search for in the indexed documents")
query: str = Field(
...,
description="The query text to search for in the indexed documents",
)
base_uri: str = Field(..., description="The base URI to search in")
top_k: int | None = Field(5, description="Number of top results to return", ge=1, le=20)
@@ -478,7 +496,10 @@ def process_document_batch(documents: list[Document]) -> bool: # noqa: PLR0915,
# Check if document with same hash has already been successfully processed
status_records = indexing_history_service.get_indexing_status(doc=doc)
if status_records and status_records[0].status == "completed":
logger.debug("Document with same hash already processed, skipping: %s", doc.doc_id)
logger.debug(
"Document with same hash already processed, skipping: %s",
doc.doc_id,
)
continue
logger.debug("Processing document: %s", doc.doc_id)
@@ -592,7 +613,10 @@ def get_gitcrypt_files(directory: Path) -> list[str]:
)
if git_root_cmd.returncode != 0:
logger.warning("Not a git repository or git command failed: %s", git_root_cmd.stderr.strip())
logger.warning(
"Not a git repository or git command failed: %s",
git_root_cmd.stderr.strip(),
)
return git_crypt_patterns
git_root = Path(git_root_cmd.stdout.strip())
@@ -613,7 +637,15 @@ def get_gitcrypt_files(directory: Path) -> list[str]:
# Use Python to process the output instead of xargs, grep, and cut
git_check_attr = subprocess.run(
[
    git_executable,
    "-C",
    str(git_root),
    "check-attr",
    "filter",
    "--stdin",
    "-z",
],
input=git_ls_files.stdout,
capture_output=True,
text=False,
@@ -830,7 +862,11 @@ def split_documents(documents: list[Document]) -> list[Document]:
t = doc.get_content()
texts = code_splitter.split_text(t)
except ValueError as e:
logger.error(
    "Error splitting document: %s, so skipping split, error: %s",
    doc.doc_id,
    str(e),
)
processed_documents.append(doc)
continue
@@ -1034,7 +1070,10 @@ async def add_resource(request: ResourceRequest, background_tasks: BackgroundTas
if resource:
if resource.name != request.name:
raise HTTPException(
    status_code=400,
    detail=f"Resource name cannot be changed: {resource.name}",
)
resource_service.update_resource_status(resource.uri, "active")
else:
@@ -1352,3 +1391,9 @@ async def list_resources() -> ResourceListResponse:
total_count=len(resources),
status_summary=status_counts,
)
@app.get("/api/health")
async def health_check() -> dict[str, str]:
"""Health check endpoint."""
return {"status": "ok"}
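The handler returns a constant payload, so a container runner can probe it (for example with `curl -f http://localhost:<port>/api/health`). Stripped of FastAPI routing, the handler contract is just a coroutine returning a fixed dict:

```python
import asyncio

async def health_check() -> dict[str, str]:
    """Constant-payload liveness probe, as in the /api/health route."""
    return {"status": "ok"}

print(asyncio.run(health_check()))  # {'status': 'ok'}
```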


@@ -0,0 +1,70 @@
# src/providers/dashscope.py
from typing import Any
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms.llm import LLM
from llama_index.embeddings.dashscope import DashScopeEmbedding
from llama_index.llms.dashscope import DashScope
def initialize_embed_model(
embed_endpoint: str, # noqa: ARG001
embed_api_key: str,
embed_model: str,
**embed_extra: Any, # noqa: ANN401
) -> BaseEmbedding:
"""
Create DashScope embedding model.
Args:
embed_endpoint: Not used directly by the constructor.
embed_api_key: The API key for the DashScope API.
embed_model: The name of the embedding model.
embed_extra: Extra parameters of the embedding model.
Returns:
The initialized embed_model.
"""
# DashScope typically uses the API key and model name.
# The endpoint might be set via environment variables or default.
# We pass embed_api_key and embed_model to the constructor.
# We include embed_endpoint in the signature to match the factory interface,
# but it might not be directly used by the constructor depending on LlamaIndex's implementation.
return DashScopeEmbedding(
model_name=embed_model,
api_key=embed_api_key,
**embed_extra,
)
def initialize_llm_model(
llm_endpoint: str, # noqa: ARG001
llm_api_key: str,
llm_model: str,
**llm_extra: Any, # noqa: ANN401
) -> LLM:
"""
Create DashScope LLM model.
Args:
llm_endpoint: Not used directly by the constructor.
llm_api_key: The API key for the DashScope API.
llm_model: The name of the LLM model.
llm_extra: Extra parameters of the LLM model.
Returns:
The initialized llm_model.
"""
# DashScope typically uses the API key and model name.
# The endpoint might be set via environment variables or default.
# We pass llm_api_key and llm_model to the constructor.
# We include llm_endpoint in the signature to match the factory interface,
# but it might not be directly used by the constructor depending on LlamaIndex's implementation.
return DashScope(
model_name=llm_model,
api_key=llm_api_key,
**llm_extra,
)


@@ -0,0 +1,179 @@
import importlib
from typing import TYPE_CHECKING, Any, cast
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms.llm import LLM
if TYPE_CHECKING:
from collections.abc import Callable
from libs.logger import logger # Assuming libs.logger exists and provides a logger instance
def initialize_embed_model(
embed_provider: str,
embed_model: str,
embed_endpoint: str | None = None,
embed_api_key: str | None = None,
embed_extra: dict[str, Any] | None = None,
) -> BaseEmbedding:
"""
Initialize embedding model based on specified provider and configuration.
Dynamically loads the provider module based on the embed_provider parameter.
Args:
embed_provider: The name of the embedding provider (e.g., "openai", "ollama").
embed_model: The name of the embedding model.
embed_endpoint: The API endpoint for the embedding provider.
embed_api_key: The API key for the embedding provider.
embed_extra: Additional provider-specific configuration parameters.
Returns:
The initialized embed_model.
Raises:
ValueError: If the specified embed_provider is not supported or module/function not found.
RuntimeError: If model initialization fails for the selected provider.
"""
# Validate provider name
error_msg = f"Invalid EMBED_PROVIDER specified: '{embed_provider}'. Provider name must contain only alphanumeric characters and underscores."
if not embed_provider.replace("_", "").isalnum():
raise ValueError(error_msg)
try:
provider_module = importlib.import_module(f".{embed_provider}", package="providers")
logger.debug(f"Successfully imported provider module: providers.{embed_provider}")
attribute = getattr(provider_module, "initialize_embed_model", None)
if attribute is None:
error_msg = f"Provider module '{embed_provider}' does not have an 'initialize_embed_model' function."
raise ValueError(error_msg) # noqa: TRY301
initializer = cast("Callable[..., BaseEmbedding]", attribute)
except ImportError as err:
error_msg = f"Unsupported EMBED_PROVIDER specified: '{embed_provider}'. Could not find provider module 'providers.{embed_provider}'."
raise ValueError(error_msg) from err
except AttributeError as err:
error_msg = f"Provider module '{embed_provider}' does not have an 'initialize_embed_model' function."
raise ValueError(error_msg) from err
except Exception as err:
logger.error(
f"An unexpected error occurred while loading provider '{embed_provider}': {err!r}",
exc_info=True,
)
error_msg = f"Failed to load provider '{embed_provider}' due to an unexpected error."
raise RuntimeError(error_msg) from err
logger.info(f"Initializing embedding model for provider: {embed_provider}")
try:
embedding: BaseEmbedding = initializer(
embed_endpoint,
embed_api_key,
embed_model,
**(embed_extra or {}),
)
logger.info(f"Embedding model initialized successfully for {embed_provider}")
return embedding
except TypeError as err:
error_msg = f"Provider initializer 'initialize_embed_model' was called with incorrect arguments in '{embed_provider}'"
logger.error(
f"{error_msg}: {err!r}",
exc_info=True,
)
raise RuntimeError(error_msg) from err
except Exception as err:
error_msg = f"Failed to initialize embedding model for provider '{embed_provider}'"
logger.error(
f"{error_msg}: {err!r}",
exc_info=True,
)
raise RuntimeError(error_msg) from err
def initialize_llm_model(
llm_provider: str,
llm_model: str,
llm_endpoint: str | None = None,
llm_api_key: str | None = None,
llm_extra: dict[str, Any] | None = None,
) -> LLM:
"""
Create LLM model with the specified configuration.
Dynamically loads the provider module based on the llm_provider parameter.
Args:
llm_provider: The name of the LLM provider (e.g., "openai", "ollama").
llm_endpoint: The API endpoint for the LLM provider.
llm_api_key: The API key for the LLM provider.
llm_model: The name of the LLM model.
llm_extra: Extra parameters for the LLM model.
Returns:
The initialized llm_model.
Raises:
ValueError: If the specified llm_provider is not supported or module/function not found.
RuntimeError: If model initialization fails for the selected provider.
"""
if not llm_provider.replace("_", "").isalnum():
error_msg = f"Invalid LLM_PROVIDER specified: '{llm_provider}'. Provider name must contain only alphanumeric characters and underscores."
raise ValueError(error_msg)
try:
provider_module = importlib.import_module(
f".{llm_provider}",
package="providers",
)
logger.debug(f"Successfully imported provider module: providers.{llm_provider}")
attribute = getattr(provider_module, "initialize_llm_model", None)
if attribute is None:
error_msg = f"Provider module '{llm_provider}' does not have an 'initialize_llm_model' function."
raise ValueError(error_msg) # noqa: TRY301
initializer = cast("Callable[..., LLM]", attribute)
except ImportError as err:
error_msg = f"Unsupported LLM_PROVIDER specified: '{llm_provider}'. Could not find provider module 'providers.{llm_provider}'."
raise ValueError(error_msg) from err
except AttributeError as err:
error_msg = f"Provider module '{llm_provider}' does not have an 'initialize_llm_model' function."
raise ValueError(error_msg) from err
except Exception as e:
error_msg = f"An unexpected error occurred while loading provider '{llm_provider}': {e}"
logger.error(error_msg, exc_info=True)
raise RuntimeError(error_msg) from e
logger.info(f"Initializing LLM model for provider: '{llm_provider}'")
logger.debug(f"Args: llm_model='{llm_model}', llm_endpoint='{llm_endpoint}'")
try:
llm: LLM = initializer(
llm_endpoint,
llm_api_key,
llm_model,
**(llm_extra or {}),
)
logger.info(f"LLM model initialized successfully for '{llm_provider}'.")
except TypeError as e:
error_msg = f"Provider initializer 'initialize_llm_model' in '{llm_provider}' was called with incorrect arguments: {e}"
logger.error(error_msg, exc_info=True)
raise RuntimeError(error_msg) from e
except Exception as e:
error_msg = f"Failed to initialize LLM model for provider '{llm_provider}': {e}"
logger.error(
error_msg,
exc_info=True,
)
raise RuntimeError(error_msg) from e
return llm
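The validate → `import_module` → `getattr` sequence above is the core of the factory. A self-contained sketch of the same dynamic-dispatch pattern, using a stdlib module in place of the real `providers` package (names here are illustrative only):

```python
import importlib
from collections.abc import Callable
from typing import Any

def load_initializer(provider: str, func_name: str) -> Callable[..., Any]:
    """Validate the provider name, import its module, and fetch the initializer.

    Raises ValueError for bad names, missing modules, or missing functions,
    mirroring initialize_embed_model / initialize_llm_model above.
    """
    if not provider.replace("_", "").isalnum():
        raise ValueError(f"Invalid provider name: {provider!r}")
    try:
        module = importlib.import_module(provider)
    except ImportError as err:
        raise ValueError(f"Could not find provider module {provider!r}") from err
    initializer = getattr(module, func_name, None)
    if initializer is None:
        raise ValueError(f"Module {provider!r} has no {func_name!r} function")
    return initializer

# Demonstration against the stdlib: "json" plays the role of a provider module.
loads = load_initializer("json", "loads")
print(loads('{"top_k": 5}'))  # {'top_k': 5}
```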


@@ -0,0 +1,66 @@
# src/providers/ollama.py
from typing import Any
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms.llm import LLM
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
def initialize_embed_model(
embed_endpoint: str,
embed_api_key: str, # noqa: ARG001
embed_model: str,
**embed_extra: Any, # noqa: ANN401
) -> BaseEmbedding:
"""
Create Ollama embedding model.
Args:
embed_endpoint: The API endpoint for the Ollama API.
embed_api_key: Not used by Ollama.
embed_model: The name of the embedding model.
embed_extra: Extra parameters for Ollama embedding model.
Returns:
The initialized embed_model.
"""
# Ollama typically uses the endpoint directly and does not require an API key.
# embed_api_key is accepted only to match the factory interface; it is not
# forwarded to OllamaEmbedding.
return OllamaEmbedding(
model_name=embed_model,
base_url=embed_endpoint,
**embed_extra,
)
def initialize_llm_model(
llm_endpoint: str,
llm_api_key: str, # noqa: ARG001
llm_model: str,
**llm_extra: Any, # noqa: ANN401
) -> LLM:
"""
Create Ollama LLM model.
Args:
llm_endpoint: The API endpoint for the Ollama API.
llm_api_key: Not used by Ollama.
llm_model: The name of the LLM model.
llm_extra: Extra parameters for LLM model.
Returns:
The initialized llm_model.
"""
# Ollama typically uses the endpoint directly and does not require an API key.
# llm_api_key is accepted only to match the factory interface; it is not
# forwarded to Ollama.
return Ollama(
model=llm_model,
base_url=llm_endpoint,
**llm_extra,
)


@@ -0,0 +1,68 @@
# src/providers/openai.py
from typing import Any
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms.llm import LLM
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
def initialize_embed_model(
embed_endpoint: str,
embed_api_key: str,
embed_model: str,
**embed_extra: Any, # noqa: ANN401
) -> BaseEmbedding:
"""
Create OpenAI embedding model.
Args:
embed_model: The name of the embedding model.
embed_endpoint: The API endpoint for the OpenAI API.
embed_api_key: The API key for the OpenAI API.
embed_extra: Extra parameters for the OpenAI API.
Returns:
The initialized embed_model.
"""
# Use the provided endpoint and API key directly.
# Note: OpenAIEmbedding falls back to the OPENAI_API_KEY env var when
# api_key is None.
return OpenAIEmbedding(
model=embed_model,
api_base=embed_endpoint,
api_key=embed_api_key,
**embed_extra,
)
def initialize_llm_model(
llm_endpoint: str,
llm_api_key: str,
llm_model: str,
**llm_extra: Any, # noqa: ANN401
) -> LLM:
"""
Create OpenAI LLM model.
Args:
llm_model: The name of the LLM model.
llm_endpoint: The API endpoint for the OpenAI API.
llm_api_key: The API key for the OpenAI API.
llm_extra: Extra parameters for the OpenAI API.
Returns:
The initialized llm_model.
"""
# Use the provided endpoint and API key directly.
# Note: OpenAI falls back to the OPENAI_API_KEY env var when api_key is None.
return OpenAI(
model=llm_model,
api_base=llm_endpoint,
api_key=llm_api_key,
**llm_extra,
)


@@ -0,0 +1,35 @@
# src/providers/openrouter.py
from typing import Any
from llama_index.core.llms.llm import LLM
from llama_index.llms.openrouter import OpenRouter
def initialize_llm_model(
llm_endpoint: str,
llm_api_key: str,
llm_model: str,
**llm_extra: Any, # noqa: ANN401
) -> LLM:
"""
Create OpenRouter LLM model.
Args:
llm_model: The name of the LLM model.
llm_endpoint: The API endpoint for the OpenRouter API.
llm_api_key: The API key for the OpenRouter API.
llm_extra: Extra parameters for OpenRouter.
Returns:
The initialized llm_model.
"""
# Use the provided endpoint and API key directly.
return OpenRouter(
model=llm_model,
api_base=llm_endpoint,
api_key=llm_api_key,
**llm_extra,
)
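Every provider module exposes the same positional `(endpoint, api_key, model, **extra)` contract, which is what lets the factory call them interchangeably. A stub provider sketching that convention (all values below are placeholders, not real credentials or endpoints):

```python
from typing import Any

def initialize_llm_model(
    llm_endpoint: str,
    llm_api_key: str,
    llm_model: str,
    **llm_extra: Any,
) -> dict[str, Any]:
    """Stub provider implementing the shared factory signature."""
    return {
        "endpoint": llm_endpoint,
        "api_key": llm_api_key,
        "model": llm_model,
        **llm_extra,
    }

llm = initialize_llm_model(
    "https://openrouter.ai/api/v1",  # endpoint
    "sk-placeholder",                # API key (placeholder)
    "openai/gpt-4o-mini",            # model name (example)
    temperature=0.2,                 # forwarded through **llm_extra
)
```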