From 2dd4c040880b271861369b361489a2d418d42648 Mon Sep 17 00:00:00 2001 From: doodleEsc Date: Fri, 6 Jun 2025 23:07:07 +0800 Subject: [PATCH] feat: Enhanced Model Provider Support and Configuration Flexibility For Rag Service (#2056) Co-authored-by: doodleEsc Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> --- README.md | 40 +++-- README_zh.md | 38 +++-- lua/avante/api.lua | 1 - lua/avante/config.lua | 26 ++- lua/avante/rag_service.lua | 88 +++++++--- py/rag-service/Dockerfile | 21 ++- py/rag-service/README.md | 135 ++++++++++++++++ py/rag-service/requirements.txt | 53 +++--- py/rag-service/src/main.py | 175 ++++++++++++-------- py/rag-service/src/providers/__init__.py | 0 py/rag-service/src/providers/dashscope.py | 70 ++++++++ py/rag-service/src/providers/factory.py | 179 +++++++++++++++++++++ py/rag-service/src/providers/ollama.py | 66 ++++++++ py/rag-service/src/providers/openai.py | 68 ++++++++ py/rag-service/src/providers/openrouter.py | 35 ++++ 15 files changed, 844 insertions(+), 151 deletions(-) create mode 100644 py/rag-service/README.md create mode 100644 py/rag-service/src/providers/__init__.py create mode 100644 py/rag-service/src/providers/dashscope.py create mode 100644 py/rag-service/src/providers/factory.py create mode 100644 py/rag-service/src/providers/ollama.py create mode 100644 py/rag-service/src/providers/openai.py create mode 100644 py/rag-service/src/providers/openrouter.py diff --git a/README.md b/README.md index ccbd694..cc09e59 100644 --- a/README.md +++ b/README.md @@ -705,7 +705,7 @@ Given its early stage, `avante.nvim` currently supports the following basic func > model = "us.anthropic.claude-3-5-sonnet-20241022-v2:0", > aws_profile = "bedrock", > aws_region = "us-east-1", ->}, +> }, > ``` > > Note: Bedrock requires the [AWS CLI](https://aws.amazon.com/cli/) to be installed on your system. 
@@ -884,21 +884,37 @@ For more information, see [Custom Providers](https://github.com/yetone/avante.nv Avante provides a RAG service, which is a tool for obtaining the required context for the AI to generate the codes. By default, it is not enabled. You can enable it this way: ```lua -rag_service = { - enabled = false, -- Enables the RAG service - host_mount = os.getenv("HOME"), -- Host mount path for the rag service - provider = "openai", -- The provider to use for RAG service (e.g. openai or ollama) - llm_model = "", -- The LLM model to use for RAG service - embed_model = "", -- The embedding model to use for RAG service - endpoint = "https://api.openai.com/v1", -- The API endpoint for RAG service -}, + rag_service = { -- RAG Service configuration + enabled = false, -- Enables the RAG service + host_mount = os.getenv("HOME"), -- Host mount path for the rag service (Docker will mount this path) + runner = "docker", -- Runner for the RAG service (can use docker or nix) + llm = { -- Language Model (LLM) configuration for RAG service + provider = "openai", -- LLM provider + endpoint = "https://api.openai.com/v1", -- LLM API endpoint + api_key = "OPENAI_API_KEY", -- Environment variable name for the LLM API key + model = "gpt-4o-mini", -- LLM model name + extra = nil, -- Additional configuration options for LLM + }, + embed = { -- Embedding model configuration for RAG service + provider = "openai", -- Embedding provider + endpoint = "https://api.openai.com/v1", -- Embedding API endpoint + api_key = "OPENAI_API_KEY", -- Environment variable name for the embedding API key + model = "text-embedding-3-large", -- Embedding model name + extra = nil, -- Additional configuration options for the embedding model + }, + docker_extra_args = "", -- Extra arguments to pass to the docker command + }, ``` -If your rag_service provider is `openai`, then you need to set the `OPENAI_API_KEY` environment variable! 
+The RAG Service can currently configure the LLM and embedding models separately. In the `llm` and `embed` configuration blocks, you can set the following fields: -If your rag_service provider is `ollama`, you need to set the endpoint to `http://localhost:11434` (note there is no `/v1` at the end) or any address of your own ollama server. +- `provider`: Model provider (e.g., "openai", "ollama", "dashscope", and "openrouter") +- `endpoint`: API endpoint +- `api_key`: Environment variable name for the API key +- `model`: Model name +- `extra`: Additional configuration options -If your rag_service provider is `ollama`, when `llm_model` is empty, it defaults to `llama3`, and when `embed_model` is empty, it defaults to `nomic-embed-text`. Please make sure these models are available in your ollama server. +For detailed configuration of different model providers, you can check [here](./py/rag-service/README.md). Additionally, RAG Service also depends on Docker! (For macOS users, OrbStack is recommended as a Docker alternative). 
diff --git a/README_zh.md b/README_zh.md index 99d4374..209aea5 100644 --- a/README_zh.md +++ b/README_zh.md @@ -747,21 +747,37 @@ Avante 提供了一组默认提供者,但用户也可以创建自己的提供 Avante 提供了一个 RAG 服务,这是一个用于获取 AI 生成代码所需上下文的工具。默认情况下,它未启用。您可以通过以下方式启用它: ```lua -rag_service = { - enabled = false, -- 启用 RAG 服务 - host_mount = os.getenv("HOME"), -- RAG 服务的主机挂载路径 - provider = "openai", -- 用于 RAG 服务的提供者(例如 openai 或 ollama) - llm_model = "", -- 用于 RAG 服务的 LLM 模型 - embed_model = "", -- 用于 RAG 服务的嵌入模型 - endpoint = "https://api.openai.com/v1", -- RAG 服务的 API 端点 -}, + rag_service = { -- RAG 服务配置 + enabled = false, -- 启用 RAG 服务 + host_mount = os.getenv("HOME"), -- RAG 服务的主机挂载路径 (Docker 将挂载此路径) + runner = "docker", -- RAG 服务的运行器 (可以使用 docker 或 nix) + llm = { -- RAG 服务使用的语言模型 (LLM) 配置 + provider = "openai", -- LLM 提供者 + endpoint = "https://api.openai.com/v1", -- LLM API 端点 + api_key = "OPENAI_API_KEY", -- LLM API 密钥的环境变量名称 + model = "gpt-4o-mini", -- LLM 模型名称 + extra = nil, -- LLM 的额外配置选项 + }, + embed = { -- RAG 服务使用的嵌入模型配置 + provider = "openai", -- 嵌入提供者 + endpoint = "https://api.openai.com/v1", -- 嵌入 API 端点 + api_key = "OPENAI_API_KEY", -- 嵌入 API 密钥的环境变量名称 + model = "text-embedding-3-large", -- 嵌入模型名称 + extra = nil, -- 嵌入模型的额外配置选项 + }, + docker_extra_args = "", -- 传递给 docker 命令的额外参数 + }, ``` -如果您的 rag_service 提供者是 `openai`,那么您需要设置 `OPENAI_API_KEY` 环境变量! 
+RAG 服务可以单独设置llm模型和嵌入模型。在 `llm` 和 `embed` 配置块中,您可以设置以下字段: -如果您的 rag_service 提供者是 `ollama`,您需要将端点设置为 `http://localhost:11434`(注意末尾没有 `/v1`)或您自己的 ollama 服务器的任何地址。 +- `provider`: 模型提供者(例如 "openai", "ollama", "dashscope"以及"openrouter") +- `endpoint`: API 端点 +- `api_key`: API 密钥的环境变量名称 +- `model`: 模型名称 +- `extra`: 额外的配置选项 -如果您的 rag_service 提供者是 `ollama`,当 `llm_model` 为空时,默认为 `llama3`,当 `embed_model` 为空时,默认为 `nomic-embed-text`。请确保这些模型在您的 ollama 服务器中可用。 +有关不同模型提供商的详细配置,你可以在[这里](./py/rag-service/README.md)查看。 此外,RAG 服务还依赖于 Docker!(对于 macOS 用户,推荐使用 OrbStack 作为 Docker 的替代品)。 diff --git a/lua/avante/api.lua b/lua/avante/api.lua index 48da87a..4cfc054 100644 --- a/lua/avante/api.lua +++ b/lua/avante/api.lua @@ -290,7 +290,6 @@ function M.remove_selected_file(filepath) for _, file in ipairs(files) do local rel_path = Utils.uniform_path(file) - vim.notify(rel_path) sidebar.file_selector:remove_selected_file(rel_path) end end diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 34cc039..8c905d5 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -37,14 +37,24 @@ M._defaults = { tokenizer = "tiktoken", ---@type string | (fun(): string) | nil system_prompt = nil, - rag_service = { - enabled = false, -- Enables the rag service, requires OPENAI_API_KEY to be set - host_mount = os.getenv("HOME"), -- Host mount path for the rag service (docker will mount this path) - runner = "docker", -- The runner for the rag service, (can use docker, or nix) - provider = "openai", -- The provider to use for RAG service. 
eg: openai or ollama - llm_model = "", -- The LLM model to use for RAG service - embed_model = "", -- The embedding model to use for RAG service - endpoint = "https://api.openai.com/v1", -- The API endpoint for RAG service + rag_service = { -- RAG service configuration + enabled = false, -- Enables the RAG service + host_mount = os.getenv("HOME"), -- Host mount path for the RAG service (Docker will mount this path) + runner = "docker", -- The runner for the RAG service (can use docker or nix) + llm = { -- Configuration for the Language Model (LLM) used by the RAG service + provider = "openai", -- The LLM provider + endpoint = "https://api.openai.com/v1", -- The LLM API endpoint + api_key = "OPENAI_API_KEY", -- The environment variable name for the LLM API key + model = "gpt-4o-mini", -- The LLM model name + extra = nil, -- Extra configuration options for the LLM + }, + embed = { -- Configuration for the Embedding model used by the RAG service + provider = "openai", -- The embedding provider + endpoint = "https://api.openai.com/v1", -- The embedding API endpoint + api_key = "OPENAI_API_KEY", -- The environment variable name for the embedding API key + model = "text-embedding-3-large", -- The embedding model name + extra = nil, -- Extra configuration options for the embedding model + }, docker_extra_args = "", -- Extra arguments to pass to the docker command }, web_search_engine = { diff --git a/lua/avante/rag_service.lua b/lua/avante/rag_service.lua index a6d8eb8..b6e83e0 100644 --- a/lua/avante/rag_service.lua +++ b/lua/avante/rag_service.lua @@ -8,7 +8,7 @@ local M = {} local container_name = "avante-rag-service" local service_path = "/tmp/" .. 
container_name -function M.get_rag_service_image() return "quay.io/yetoneful/avante-rag-service:0.0.10" end +function M.get_rag_service_image() return "quay.io/yetoneful/avante-rag-service:0.0.11" end function M.get_rag_service_port() return 20250 end @@ -35,13 +35,46 @@ function M.get_rag_service_runner() return (Config.rag_service and Config.rag_se ---@param cb fun() function M.launch_rag_service(cb) - local openai_api_key = os.getenv("OPENAI_API_KEY") - if Config.rag_service.provider == "openai" then - if openai_api_key == nil then - error("cannot launch avante rag service, OPENAI_API_KEY is not set") + --- If Config.rag_service.llm.api_key is nil or empty, llm_api_key will be an empty string. + local llm_api_key = "" + if + Config.rag_service + and Config.rag_service.llm + and Config.rag_service.llm.api_key + and Config.rag_service.llm.api_key ~= "" + then + llm_api_key = os.getenv(Config.rag_service.llm.api_key) or "" + if llm_api_key == nil or llm_api_key == "" then + error(string.format("cannot launch avante rag service, %s is not set", Config.rag_service.llm.api_key)) return end end + + --- If Config.rag_service.embed.api_key is nil or empty, embed_api_key will be an empty string. 
+ local embed_api_key = "" + if + Config.rag_service + and Config.rag_service.embed + and Config.rag_service.embed.api_key + and Config.rag_service.embed.api_key ~= "" + then + embed_api_key = os.getenv(Config.rag_service.embed.api_key) or "" + if embed_api_key == nil or embed_api_key == "" then + error(string.format("cannot launch avante rag service, %s is not set", Config.rag_service.embed.api_key)) + return + end + end + + local embed_extra = "{}" -- Default to empty JSON object string + if Config.rag_service and Config.rag_service.embed and Config.rag_service.embed.extra then + embed_extra = vim.json.encode(Config.rag_service.embed.extra):gsub('"', '\\"') + end + + local llm_extra = "{}" -- Default to empty JSON object string + if Config.rag_service and Config.rag_service.llm and Config.rag_service.llm.extra then + llm_extra = vim.json.encode(Config.rag_service.llm.extra):gsub('"', '\\"') + end + local port = M.get_rag_service_port() if M.get_rag_service_runner() == "docker" then @@ -74,19 +107,22 @@ function M.launch_rag_service(cb) M.stop_rag_service() end local cmd_ = string.format( - "docker run --platform=linux/amd64 -d -p 0.0.0.0:%d:%d --name %s -v %s:/data -v %s:/host:ro -e ALLOW_RESET=TRUE -e DATA_DIR=/data -e RAG_PROVIDER=%s -e %s_API_KEY=%s -e %s_API_BASE=%s -e RAG_LLM_MODEL=%s -e RAG_EMBED_MODEL=%s %s %s", + "docker run --platform=linux/amd64 -d -p 0.0.0.0:%d:%d --name %s -v %s:/data -v %s:/host:ro -e ALLOW_RESET=TRUE -e DATA_DIR=/data -e RAG_EMBED_PROVIDER=%s -e RAG_EMBED_ENDPOINT=%s -e RAG_EMBED_API_KEY=%s -e RAG_EMBED_MODEL=%s -e RAG_EMBED_EXTRA=%s -e RAG_LLM_PROVIDER=%s -e RAG_LLM_ENDPOINT=%s -e RAG_LLM_API_KEY=%s -e RAG_LLM_MODEL=%s -e RAG_LLM_EXTRA=%s %s %s", M.get_rag_service_port(), M.get_rag_service_port(), container_name, data_path, Config.rag_service.host_mount, - Config.rag_service.provider, - Config.rag_service.provider:upper(), - openai_api_key, - Config.rag_service.provider:upper(), - Config.rag_service.endpoint, - 
Config.rag_service.llm_model, - Config.rag_service.embed_model, + Config.rag_service.embed.provider, + Config.rag_service.embed.endpoint, + embed_api_key, + Config.rag_service.embed.model, + embed_extra, + Config.rag_service.llm.provider, + Config.rag_service.llm.endpoint, + llm_api_key, + Config.rag_service.llm.model, + llm_extra, Config.rag_service.docker_extra_args, image ) @@ -118,17 +154,20 @@ function M.launch_rag_service(cb) Utils.debug(string.format("launching %s with nix...", container_name)) local cmd = string.format( - "cd %s && ALLOW_RESET=TRUE PORT=%d DATA_DIR=%s RAG_PROVIDER=%s %s_API_KEY=%s %s_API_BASE=%s RAG_LLM_MODEL=%s RAG_EMBED_MODEL=%s sh run.sh %s", + "cd %s && ALLOW_RESET=TRUE PORT=%d DATA_DIR=%s RAG_EMBED_PROVIDER=%s RAG_EMBED_ENDPOINT=%s RAG_EMBED_API_KEY=%s RAG_EMBED_MODEL=%s RAG_EMBED_EXTRA=%s RAG_LLM_PROVIDER=%s RAG_LLM_ENDPOINT=%s RAG_LLM_API_KEY=%s RAG_LLM_MODEL=%s RAG_LLM_EXTRA=%s sh run.sh %s", rag_service_dir, port, service_path, - Config.rag_service.provider, - Config.rag_service.provider:upper(), - openai_api_key, - Config.rag_service.provider:upper(), - Config.rag_service.endpoint, - Config.rag_service.llm_model, - Config.rag_service.embed_model, + Config.rag_service.embed.provider, + Config.rag_service.embed.endpoint, + embed_api_key, + Config.rag_service.embed.model, + embed_extra, + Config.rag_service.llm.provider, + Config.rag_service.llm.endpoint, + embed_api_key, + Config.rag_service.llm.model, + llm_extra, service_path ) @@ -211,7 +250,12 @@ function M.to_local_uri(uri) end function M.is_ready() - vim.fn.system(string.format("curl -s -o /dev/null -w '%%{http_code}' %s", M.get_rag_service_url())) + vim.fn.system( + string.format( + "curl -s -o /dev/null -w '%%{http_code}' %s", + string.format("%s%s", M.get_rag_service_url(), "/api/health") + ) + ) return vim.v.shell_error == 0 end diff --git a/py/rag-service/Dockerfile b/py/rag-service/Dockerfile index ea30a91..089f795 100644 --- a/py/rag-service/Dockerfile +++ 
b/py/rag-service/Dockerfile @@ -2,21 +2,20 @@ FROM python:3.11-slim-bookworm WORKDIR /app -RUN apt-get update && apt-get install -y curl git \ - && rm -rf /var/lib/apt/lists/* \ - && curl -LsSf https://astral.sh/uv/install.sh | sh +RUN apt-get update && apt-get install -y --no-install-recommends curl git \ + && rm -rf /var/lib/apt/lists/* \ + && curl -LsSf https://astral.sh/uv/install.sh | sh -ENV PATH="/root/.local/bin:$PATH" +ENV PATH="/root/.local/bin:$PATH" \ + PYTHONPATH=/app/src \ + PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 COPY requirements.txt . -RUN uv venv - -RUN uv pip install -r requirements.txt - -ENV PYTHONUNBUFFERED=1 \ - PYTHONDONTWRITEBYTECODE=1 +# 直接安装到系统依赖中,不创建虚拟环境 +RUN uv pip install --system -r requirements.txt COPY . . -CMD ["uv", "run", "fastapi", "run", "src/main.py", "--workers", "3", "--port", "20250"] +CMD ["uvicorn", "src.main:app", "--workers", "3", "--host", "0.0.0.0", "--port", "20250"] diff --git a/py/rag-service/README.md b/py/rag-service/README.md new file mode 100644 index 0000000..9d62be6 --- /dev/null +++ b/py/rag-service/README.md @@ -0,0 +1,135 @@ +# RAG Service Configuration + +This document describes how to configure the RAG service, including setting up Language Model (LLM) and Embedding providers. + +## Provider Support Matrix + +The following table shows which model types are supported by each provider: + +| Provider | LLM Support | Embedding Support | +| ---------- | ----------- | ----------------- | +| dashscope | Yes | Yes | +| ollama | Yes | Yes | +| openai | Yes | Yes | +| openrouter | Yes | No | + +## LLM Provider Configuration + +The `llm` section in the configuration file is used to configure the Language Model (LLM) used by the RAG service. 
+ +Here are the configuration examples for each supported LLM provider: + +### OpenAI LLM Configuration + +[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py#L130) + +```lua +llm = { -- Configuration for the Language Model (LLM) used by the RAG service + provider = "openai", -- The LLM provider ("openai") + endpoint = "https://api.openai.com/v1", -- The LLM API endpoint + api_key = "OPENAI_API_KEY", -- The environment variable name for the LLM API key + model = "gpt-4o-mini", -- The LLM model name (e.g., "gpt-4o-mini", "gpt-3.5-turbo") + extra = {-- Extra configuration options for the LLM (optional) + temperature = 0.7, -- Controls the randomness of the output. Lower values make it more deterministic. + max_tokens = 512, -- The maximum number of tokens to generate in the completion. + -- system_prompt = "You are a helpful assistant.", -- A system prompt to guide the model's behavior. + -- timeout = 120, -- Request timeout in seconds. 
+ }, +}, +``` + +### DashScope LLM Configuration + +[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/llms/llama-index-llms-dashscope/llama_index/llms/dashscope/base.py#L155) + +```lua +llm = { -- Configuration for the Language Model (LLM) used by the RAG service + provider = "dashscope", -- The LLM provider ("dashscope") + endpoint = "", -- The LLM API endpoint (DashScope typically uses default or environment variables) + api_key = "DASHSCOPE_API_KEY", -- The environment variable name for the LLM API key + model = "qwen-plus", -- The LLM model name (e.g., "qwen-plus", "qwen-max") + extra = nil, -- Extra configuration options for the LLM (optional) +}, +``` + +### Ollama LLM Configuration + +[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py#L65) + +```lua +llm = { -- Configuration for the Language Model (LLM) used by the RAG service + provider = "ollama", -- The LLM provider ("ollama") + endpoint = "http://localhost:11434", -- The LLM API endpoint for Ollama + api_key = "", -- Ollama typically does not require an API key + model = "llama2", -- The LLM model name (e.g., "llama2", "mistral") + extra = nil, -- Extra configuration options for the LLM (optional) Kristin", -- Extra configuration options for the LLM (optional) +}, +``` + +### OpenRouter LLM Configuration + +[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/llms/llama-index-llms-openrouter/llama_index/llms/openrouter/base.py#L17) + +```lua +llm = { -- Configuration for the Language Model (LLM) used by the RAG service + provider = "openrouter", -- The LLM provider ("openrouter") + endpoint = "https://openrouter.ai/api/v1", -- The LLM API endpoint for OpenRouter + api_key = "OPENROUTER_API_KEY", -- The environment variable name for the LLM API key + model = "openai/gpt-4o-mini", -- The LLM model 
name (e.g., "openai/gpt-4o-mini", "mistralai/mistral-7b-instruct") + extra = nil, -- Extra configuration options for the LLM (optional) +}, +``` + +## Embedding Provider Configuration + +The `embedding` section in the configuration file is used to configure the Embedding Model used by the RAG service. + +Here are the configuration examples for each supported Embedding provider: + +### OpenAI Embedding Configuration + +[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/embeddings/llama-index-embeddings-openai/llama_index/embeddings/openai/base.py#L214) + +```lua +embed = { -- Configuration for the Embedding Model used by the RAG service + provider = "openai", -- The Embedding provider ("openai") + endpoint = "https://api.openai.com/v1", -- The Embedding API endpoint + api_key = "OPENAI_API_KEY", -- The environment variable name for the Embedding API key + model = "text-embedding-3-large", -- The Embedding model name (e.g., "text-embedding-3-small", "text-embedding-3-large") + extra = {-- Extra configuration options for the Embedding model (optional) + dimensions = nil, + }, +}, +``` + +### DashScope Embedding Configuration + +[See more configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/embeddings/llama-index-embeddings-dashscope/llama_index/embeddings/dashscope/base.py#L156) + +```lua +embed = { -- Configuration for the Embedding Model used by the RAG service + provider = "dashscope", -- The Embedding provider ("dashscope") + endpoint = "", -- The Embedding API endpoint (DashScope typically uses default or environment variables) + api_key = "DASHSCOPE_API_KEY", -- The environment variable name for the Embedding API key + model = "text-embedding-v3", -- The Embedding model name (e.g., "text-embedding-v2") + extra = { -- Extra configuration options for the Embedding model (optional) + embed_batch_size = 10, + }, +}, +``` + +### Ollama Embedding Configuration + +[See more 
configurations](https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/embeddings/llama-index-embeddings-ollama/llama_index/embeddings/ollama/base.py#L12) + +```lua +embed = { -- Configuration for the Embedding Model used by the RAG service + provider = "ollama", -- The Embedding provider ("ollama") + endpoint = "http://localhost:11434", -- The Embedding API endpoint for Ollama + api_key = "", -- Ollama typically does not require an API key + model = "nomic-embed-text", -- The Embedding model name (e.g., "nomic-embed-text") + extra = { -- Extra configuration options for the Embedding model (optional) + embed_batch_size = 10, + }, +}, +``` diff --git a/py/rag-service/requirements.txt b/py/rag-service/requirements.txt index 79ccc37..5dd24e1 100644 --- a/py/rag-service/requirements.txt +++ b/py/rag-service/requirements.txt @@ -17,15 +17,17 @@ chroma-hnswlib==0.7.6 chromadb==0.6.3 click==8.1.8 coloredlogs==15.0.1 +dashscope==1.22.2 dataclasses-json==0.6.7 decorator==5.1.1 -Deprecated==1.2.18 +deprecated==1.2.18 dirtyjson==1.0.8 distro==1.9.0 dnspython==2.7.0 +docx2txt==0.8 durationpy==0.9 -email_validator==2.2.0 -et_xmlfile==2.0.0 +email-validator==2.2.0 +et-xmlfile==2.0.0 executing==2.2.0 fastapi==0.115.8 fastapi-cli==0.0.7 @@ -42,15 +44,15 @@ h11==0.14.0 httpcore==1.0.7 httptools==0.6.4 httpx==0.28.1 -huggingface-hub==0.28.1 +huggingface-hub==0.31.4 humanfriendly==10.0 idna==3.10 -importlib_metadata==8.5.0 -importlib_resources==6.5.2 +importlib-metadata==8.5.0 +importlib-resources==6.5.2 ipdb==0.13.13 ipython==8.32.0 jedi==0.19.2 -Jinja2==3.1.5 +jinja2==3.1.5 jiter==0.8.2 joblib==1.4.2 kubernetes==32.0.0 @@ -60,11 +62,15 @@ llama-index==0.12.16 llama-index-agent-openai==0.4.3 llama-index-cli==0.4.0 llama-index-core==0.12.16.post1 -llama-index-embeddings-openai==0.3.1 +llama-index-embeddings-dashscope==0.3.0 llama-index-embeddings-ollama==0.5.0 +llama-index-embeddings-openai==0.3.1 llama-index-indices-managed-llama-cloud==0.6.4 
-llama-index-llms-openai==0.3.18 +llama-index-llms-dashscope==0.3.3 llama-index-llms-ollama==0.5.2 +llama-index-llms-openai==0.3.18 +llama-index-llms-openai-like==0.3.4 +llama-index-llms-openrouter==0.3.1 llama-index-multi-modal-llms-openai==0.4.3 llama-index-program-openai==0.3.1 llama-index-question-gen-openai==0.3.0 @@ -74,7 +80,7 @@ llama-index-vector-stores-chroma==0.4.1 llama-parse==0.6.0 markdown-it-py==3.0.0 markdownify==0.14.1 -MarkupSafe==3.0.2 +markupsafe==3.0.2 marshmallow==3.26.1 matplotlib-inline==0.1.7 mdurl==0.1.2 @@ -88,6 +94,7 @@ networkx==3.4.2 nltk==3.9.1 numpy==2.2.2 oauthlib==3.2.2 +ollama==0.4.8 onnxruntime==1.20.1 openai==1.61.1 openpyxl==3.1.5 @@ -110,35 +117,36 @@ pathspec==0.12.1 pexpect==4.9.0 pillow==11.1.0 posthog==3.11.0 -prompt_toolkit==3.0.50 +prompt-toolkit==3.0.50 propcache==0.2.1 protobuf==5.29.3 ptyprocess==0.7.0 -pure_eval==0.2.3 +pure-eval==0.2.3 pyasn1==0.6.1 -pyasn1_modules==0.4.1 +pyasn1-modules==0.4.1 pydantic==2.10.6 -pydantic_core==2.27.2 -Pygments==2.19.1 +pydantic-core==2.27.2 +pygments==2.19.1 pypdf==5.2.0 -PyPika==0.48.9 -pyproject_hooks==1.2.0 +pypika==0.48.9 +pyproject-hooks==1.2.0 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 python-multipart==0.0.20 pytz==2025.1 -PyYAML==6.0.2 +pyyaml==6.0.2 regex==2024.11.6 requests==2.32.3 requests-oauthlib==2.0.0 rich==13.9.4 rich-toolkit==0.13.2 rsa==4.9 +safetensors==0.5.3 shellingham==1.5.4 six==1.17.0 sniffio==1.3.1 soupsieve==2.6 -SQLAlchemy==2.0.38 +sqlalchemy==2.0.38 stack-data==0.6.3 starlette==0.45.3 striprtf==0.0.26 @@ -148,11 +156,15 @@ tiktoken==0.8.0 tokenizers==0.21.0 tqdm==4.67.1 traitlets==5.14.3 +transformers==4.51.3 tree-sitter==0.24.0 +tree-sitter-c-sharp==0.23.1 +tree-sitter-embedded-template==0.23.2 tree-sitter-language-pack==0.6.1 +tree-sitter-yaml==0.7.0 typer==0.15.1 +typing-extensions==4.12.2 typing-inspect==0.9.0 -typing_extensions==4.12.2 tzdata==2025.1 urllib3==2.3.0 uvicorn==0.34.0 @@ -165,4 +177,3 @@ websockets==14.2 wrapt==1.17.2 yarl==1.18.3 
zipp==3.21.0 -docx2txt==0.8.0 diff --git a/py/rag-service/src/main.py b/py/rag-service/src/main.py index fa2602d..c7bf1ac 100644 --- a/py/rag-service/src/main.py +++ b/py/rag-service/src/main.py @@ -2,6 +2,7 @@ from __future__ import annotations +# Standard library imports import asyncio import fcntl import json @@ -18,42 +19,24 @@ from pathlib import Path from typing import TYPE_CHECKING from urllib.parse import urljoin, urlparse +# Third-party imports import chromadb import httpx import pathspec from fastapi import BackgroundTasks, FastAPI, HTTPException -from libs.configs import ( - BASE_DATA_DIR, - CHROMA_PERSIST_DIR, -) + +# Local application imports +from libs.configs import BASE_DATA_DIR, CHROMA_PERSIST_DIR from libs.db import init_db from libs.logger import logger -from libs.utils import ( - get_node_uri, - inject_uri_to_node, - is_local_uri, - is_path_node, - is_remote_uri, - path_to_uri, - uri_to_path, -) -from llama_index.core import ( - Settings, - SimpleDirectoryReader, - StorageContext, - VectorStoreIndex, - load_index_from_storage, -) +from libs.utils import get_node_uri, inject_uri_to_node, is_local_uri, is_path_node, is_remote_uri, path_to_uri, uri_to_path +from llama_index.core import Settings, SimpleDirectoryReader, StorageContext, VectorStoreIndex, load_index_from_storage from llama_index.core.node_parser import CodeSplitter from llama_index.core.schema import Document -from llama_index.embeddings.ollama import OllamaEmbedding -from llama_index.embeddings.openai import OpenAIEmbedding -from llama_index.llms.ollama import Ollama -from llama_index.llms.openai import OpenAI from llama_index.vector_stores.chroma import ChromaVectorStore from markdownify import markdownify as md -from models.indexing_history import IndexingHistory # noqa: TC002 from models.resource import Resource +from providers.factory import initialize_embed_model, initialize_llm_model from pydantic import BaseModel, Field from services.indexing_history import 
indexing_history_service from services.resource import resource_service @@ -65,6 +48,7 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator from llama_index.core.schema import NodeWithScore, QueryBundle + from models.indexing_history import IndexingHistory from watchdog.observers.api import BaseObserver # Lock file for leader election @@ -111,8 +95,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 if is_local_uri(resource.uri): directory = uri_to_path(resource.uri) if not directory.exists(): - logger.error("Directory not found: %s", directory) - resource_service.update_resource_status(resource.uri, "error", f"Directory not found: {directory}") + error_msg = f"Directory not found: {directory}" + logger.error(error_msg) + resource_service.update_resource_status(resource.uri, "error", error_msg) continue # Start file system watcher @@ -127,8 +112,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 elif is_remote_uri(resource.uri): if not is_remote_resource_exists(resource.uri): - logger.error("HTTPS resource not found: %s", resource.uri) - resource_service.update_resource_status(resource.uri, "error", "remote resource not found") + error_msg = "HTTPS resource not found" + logger.error("%s: %s", error_msg, resource.uri) + resource_service.update_resource_status(resource.uri, "error", error_msg) continue # Start indexing @@ -273,7 +259,11 @@ def is_remote_resource_exists(url: str) -> bool: """Check if a URL exists.""" try: response = httpx.head(url, headers=http_headers) - return response.status_code in {httpx.codes.OK, httpx.codes.MOVED_PERMANENTLY, httpx.codes.FOUND} + return response.status_code in { + httpx.codes.OK, + httpx.codes.MOVED_PERMANENTLY, + httpx.codes.FOUND, + } except (OSError, ValueError, RuntimeError) as e: logger.error("Error checking if URL exists %s: %s", url, e) return False @@ -318,53 +308,78 @@ init_db() # Initialize ChromaDB and LlamaIndex services chroma_client = 
chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR)) -# Check if provider or model has changed -current_provider = os.getenv("RAG_PROVIDER", "openai").lower() -current_embed_model = os.getenv("RAG_EMBED_MODEL", "") -current_llm_model = os.getenv("RAG_LLM_MODEL", "") +# # Check if provider or model has changed +rag_embed_provider = os.getenv("RAG_EMBED_PROVIDER", "openai") +rag_embed_endpoint = os.getenv("RAG_EMBED_ENDPOINT", "https://api.openai.com/v1") +rag_embed_model = os.getenv("RAG_EMBED_MODEL", "text-embedding-3-large") +rag_embed_api_key = os.getenv("RAG_EMBED_API_KEY", None) +rag_embed_extra = os.getenv("RAG_EMBED_EXTRA", None) + +rag_llm_provider = os.getenv("RAG_LLM_PROVIDER", "openai") +rag_llm_endpoint = os.getenv("RAG_LLM_ENDPOINT", "https://api.openai.com/v1") +rag_llm_model = os.getenv("RAG_LLM_MODEL", "gpt-4o-mini") +rag_llm_api_key = os.getenv("RAG_LLM_API_KEY", None) +rag_llm_extra = os.getenv("RAG_LLM_EXTRA", None) # Try to read previous config config_file = BASE_DATA_DIR / "rag_config.json" if config_file.exists(): with Path.open(config_file, "r") as f: prev_config = json.load(f) - if prev_config.get("provider") != current_provider or prev_config.get("embed_model") != current_embed_model: + if prev_config.get("provider") != rag_embed_provider or prev_config.get("embed_model") != rag_embed_model: # Clear existing data if config changed logger.info("Detected config change, clearing existing data...") chroma_client.reset() # Save current config with Path.open(config_file, "w") as f: - json.dump({"provider": current_provider, "embed_model": current_embed_model}, f) + json.dump({"provider": rag_embed_provider, "embed_model": rag_embed_model}, f) chroma_collection = chroma_client.get_or_create_collection("documents") # pyright: ignore vector_store = ChromaVectorStore(chroma_collection=chroma_collection) storage_context = StorageContext.from_defaults(vector_store=vector_store) -# Initialize embedding model based on provider -llm_provider = 
current_provider -base_url = os.getenv(llm_provider.upper() + "_API_BASE", "") -rag_embed_model = current_embed_model -rag_llm_model = current_llm_model +try: + embed_extra = json.loads(rag_embed_extra) if rag_embed_extra is not None else {} +except json.JSONDecodeError: + logger.error("Failed to decode RAG_EMBED_EXTRA, defaulting to empty dict.") + embed_extra = {} + +try: + llm_extra = json.loads(rag_llm_extra) if rag_llm_extra is not None else {} +except json.JSONDecodeError: + logger.error("Failed to decode RAG_LLM_EXTRA, defaulting to empty dict.") + llm_extra = {} + +# Initialize embedding model and LLM based on provider using the factory +try: + embed_model = initialize_embed_model( + embed_provider=rag_embed_provider, + embed_model=rag_embed_model, + embed_endpoint=rag_embed_endpoint, + embed_api_key=rag_embed_api_key, + embed_extra=embed_extra, + ) + logger.info("Embedding model initialized successfully.") +except (ValueError, RuntimeError) as e: + error_msg = f"Failed to initialize embedding model: {e}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + +try: + llm_model = initialize_llm_model( + llm_provider=rag_llm_provider, + llm_model=rag_llm_model, + llm_endpoint=rag_llm_endpoint, + llm_api_key=rag_llm_api_key, + llm_extra=llm_extra, + ) + logger.info("LLM model initialized successfully.") +except (ValueError, RuntimeError) as e: + error_msg = f"Failed to initialize LLM model: {e}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e -if llm_provider == "ollama": - if base_url == "": - base_url = "http://localhost:11434" - if rag_embed_model == "": - rag_embed_model = "nomic-embed-text" - if rag_llm_model == "": - rag_llm_model = "llama3" - embed_model = OllamaEmbedding(model_name=rag_embed_model, base_url=base_url) - llm_model = Ollama(model=rag_llm_model, base_url=base_url, request_timeout=60.0) -else: - if base_url == "": - base_url = "https://api.openai.com/v1" - if rag_embed_model == "": 
- rag_embed_model = "text-embedding-3-small" - if rag_llm_model == "": - rag_llm_model = "gpt-4o-mini" - embed_model = OpenAIEmbedding(model=rag_embed_model, api_base=base_url) - llm_model = OpenAI(model=rag_llm_model, api_base=base_url) Settings.embed_model = embed_model Settings.llm = llm_model @@ -400,7 +415,10 @@ class SourceDocument(BaseModel): class RetrieveRequest(BaseModel): """Request model for information retrieval.""" - query: str = Field(..., description="The query text to search for in the indexed documents") + query: str = Field( + ..., + description="The query text to search for in the indexed documents", + ) base_uri: str = Field(..., description="The base URI to search in") top_k: int | None = Field(5, description="Number of top results to return", ge=1, le=20) @@ -478,7 +496,10 @@ def process_document_batch(documents: list[Document]) -> bool: # noqa: PLR0915, # Check if document with same hash has already been successfully processed status_records = indexing_history_service.get_indexing_status(doc=doc) if status_records and status_records[0].status == "completed": - logger.debug("Document with same hash already processed, skipping: %s", doc.doc_id) + logger.debug( + "Document with same hash already processed, skipping: %s", + doc.doc_id, + ) continue logger.debug("Processing document: %s", doc.doc_id) @@ -592,7 +613,10 @@ def get_gitcrypt_files(directory: Path) -> list[str]: ) if git_root_cmd.returncode != 0: - logger.warning("Not a git repository or git command failed: %s", git_root_cmd.stderr.strip()) + logger.warning( + "Not a git repository or git command failed: %s", + git_root_cmd.stderr.strip(), + ) return git_crypt_patterns git_root = Path(git_root_cmd.stdout.strip()) @@ -613,7 +637,15 @@ def get_gitcrypt_files(directory: Path) -> list[str]: # Use Python to process the output instead of xargs, grep, and cut git_check_attr = subprocess.run( - [git_executable, "-C", str(git_root), "check-attr", "filter", "--stdin", "-z"], + [ + 
git_executable, + "-C", + str(git_root), + "check-attr", + "filter", + "--stdin", + "-z", + ], input=git_ls_files.stdout, capture_output=True, text=False, @@ -830,7 +862,11 @@ def split_documents(documents: list[Document]) -> list[Document]: t = doc.get_content() texts = code_splitter.split_text(t) except ValueError as e: - logger.error("Error splitting document: %s, so skipping split, error: %s", doc.doc_id, str(e)) + logger.error( + "Error splitting document: %s, so skipping split, error: %s", + doc.doc_id, + str(e), + ) processed_documents.append(doc) continue @@ -1034,7 +1070,10 @@ async def add_resource(request: ResourceRequest, background_tasks: BackgroundTas if resource: if resource.name != request.name: - raise HTTPException(status_code=400, detail=f"Resource name cannot be changed: {resource.name}") + raise HTTPException( + status_code=400, + detail=f"Resource name cannot be changed: {resource.name}", + ) resource_service.update_resource_status(resource.uri, "active") else: @@ -1352,3 +1391,9 @@ async def list_resources() -> ResourceListResponse: total_count=len(resources), status_summary=status_counts, ) + + +@app.get("/api/health") +async def health_check() -> dict[str, str]: + """Health check endpoint.""" + return {"status": "ok"} diff --git a/py/rag-service/src/providers/__init__.py b/py/rag-service/src/providers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/py/rag-service/src/providers/dashscope.py b/py/rag-service/src/providers/dashscope.py new file mode 100644 index 0000000..b7995ac --- /dev/null +++ b/py/rag-service/src/providers/dashscope.py @@ -0,0 +1,70 @@ +# src/providers/dashscope.py + +from typing import Any + +from llama_index.core.base.embeddings.base import BaseEmbedding +from llama_index.core.llms.llm import LLM +from llama_index.embeddings.dashscope import DashScopeEmbedding +from llama_index.llms.dashscope import DashScope + + +def initialize_embed_model( + embed_endpoint: str, # noqa: ARG001 + embed_api_key: str, 
+    embed_model: str,
+    **embed_extra: Any,  # noqa: ANN401
+) -> BaseEmbedding:
+    """
+    Create DashScope embedding model.
+
+    Args:
+        embed_endpoint: Not used directly by the constructor.
+        embed_api_key: The API key for the DashScope API.
+        embed_model: The name of the embedding model.
+        embed_extra: Extra parameters of the embedding model.
+
+    Returns:
+        The initialized embed_model.
+
+    """
+    # DashScope typically uses the API key and model name.
+    # The endpoint might be set via environment variables or default.
+    # We pass embed_api_key and embed_model to the constructor.
+    # We include embed_endpoint in the signature to match the factory interface,
+    # but it might not be directly used by the constructor depending on LlamaIndex's implementation.
+    return DashScopeEmbedding(
+        model_name=embed_model,
+        api_key=embed_api_key,
+        **embed_extra,
+    )
+
+
+def initialize_llm_model(
+    llm_endpoint: str,  # noqa: ARG001
+    llm_api_key: str,
+    llm_model: str,
+    **llm_extra: Any,  # noqa: ANN401
+) -> LLM:
+    """
+    Create DashScope LLM model.
+
+    Args:
+        llm_endpoint: Not used directly by the constructor.
+        llm_api_key: The API key for the DashScope API.
+        llm_model: The name of the LLM model.
+        llm_extra: Extra parameters of the LLM model.
+
+    Returns:
+        The initialized llm_model.
+
+    """
+    # DashScope typically uses the API key and model name.
+    # The endpoint might be set via environment variables or default.
+    # We pass llm_api_key and llm_model to the constructor.
+    # We include llm_endpoint in the signature to match the factory interface,
+    # but it might not be directly used by the constructor depending on LlamaIndex's implementation.
+ return DashScope( + model_name=llm_model, + api_key=llm_api_key, + **llm_extra, + ) diff --git a/py/rag-service/src/providers/factory.py b/py/rag-service/src/providers/factory.py new file mode 100644 index 0000000..574a478 --- /dev/null +++ b/py/rag-service/src/providers/factory.py @@ -0,0 +1,179 @@ +import importlib +from typing import TYPE_CHECKING, Any, cast + +from llama_index.core.base.embeddings.base import BaseEmbedding +from llama_index.core.llms.llm import LLM + +if TYPE_CHECKING: + from collections.abc import Callable + +from libs.logger import logger # Assuming libs.logger exists and provides a logger instance + + +def initialize_embed_model( + embed_provider: str, + embed_model: str, + embed_endpoint: str | None = None, + embed_api_key: str | None = None, + embed_extra: dict[str, Any] | None = None, +) -> BaseEmbedding: + """ + Initialize embedding model based on specified provider and configuration. + + Dynamically loads the provider module based on the embed_provider parameter. + + Args: + embed_provider: The name of the embedding provider (e.g., "openai", "ollama"). + embed_model: The name of the embedding model. + embed_endpoint: The API endpoint for the embedding provider. + embed_api_key: The API key for the embedding provider. + embed_extra: Additional provider-specific configuration parameters. + + Returns: + The initialized embed_model. + + Raises: + ValueError: If the specified embed_provider is not supported or module/function not found. + RuntimeError: If model initialization fails for the selected provider. + + """ + # Validate provider name + error_msg = f"Invalid EMBED_PROVIDER specified: '{embed_provider}'. Provider name must be alphanumeric or contain underscores." 
+    if not embed_provider.replace("_", "").isalnum():
+        raise ValueError(error_msg)
+
+    try:
+        provider_module = importlib.import_module(f".{embed_provider}", package="providers")
+        logger.debug(f"Successfully imported provider module: providers.{embed_provider}")
+        attribute = getattr(provider_module, "initialize_embed_model", None)
+        if attribute is None:
+            error_msg = f"Provider module '{embed_provider}' does not have an 'initialize_embed_model' function."
+            raise ValueError(error_msg)  # noqa: TRY301
+
+        initializer = cast("Callable[..., BaseEmbedding]", attribute)
+
+    except ImportError as err:
+        error_msg = f"Unsupported EMBED_PROVIDER specified: '{embed_provider}'. Could not find provider module 'providers.{embed_provider}'."
+        raise ValueError(error_msg) from err
+    except AttributeError as err:
+        error_msg = f"Provider module '{embed_provider}' does not have an 'initialize_embed_model' function."
+        raise ValueError(error_msg) from err
+    except Exception as err:
+        logger.error(
+            f"An unexpected error occurred while loading provider '{embed_provider}': {err!r}",
+            exc_info=True,
+        )
+        error_msg = f"Failed to load provider '{embed_provider}' due to an unexpected error."
+        raise RuntimeError(error_msg) from err
+
+    logger.info(f"Initializing embedding model for provider: {embed_provider}")
+
+    try:
+        embedding: BaseEmbedding = initializer(
+            embed_endpoint,
+            embed_api_key,
+            embed_model,
+            **(embed_extra or {}),
+        )
+
+        logger.info(f"Embedding model initialized successfully for {embed_provider}")
+        return embedding
+    except TypeError as err:
+        error_msg = f"Provider initializer 'initialize_embed_model' was called with incorrect arguments in '{embed_provider}'"
+        logger.error(
+            f"{error_msg}: {err!r}",
+            exc_info=True,
+        )
+        raise RuntimeError(error_msg) from err
+    except Exception as err:
+        error_msg = f"Failed to initialize embedding model for provider '{embed_provider}'"
+        logger.error(
+            f"{error_msg}: {err!r}",
+            exc_info=True,
+        )
+        raise RuntimeError(error_msg) from err
+
+
+def initialize_llm_model(
+    llm_provider: str,
+    llm_model: str,
+    llm_endpoint: str | None = None,
+    llm_api_key: str | None = None,
+    llm_extra: dict[str, Any] | None = None,
+) -> LLM:
+    """
+    Create LLM model with the specified configuration.
+
+    Dynamically loads the provider module based on the llm_provider parameter.
+
+    Args:
+        llm_provider: The name of the LLM provider (e.g., "openai", "ollama").
+        llm_endpoint: The API endpoint for the LLM provider.
+        llm_api_key: The API key for the LLM provider.
+        llm_model: The name of the LLM model.
+        llm_extra: Additional provider-specific configuration parameters.
+
+    Returns:
+        The initialized llm_model.
+
+    Raises:
+        ValueError: If the specified llm_provider is not supported or module/function not found.
+        RuntimeError: If model initialization fails for the selected provider.
+
+    """
+    if not llm_provider.replace("_", "").isalnum():
+        error_msg = f"Invalid LLM_PROVIDER specified: '{llm_provider}'. Provider name must be alphanumeric or contain underscores."
+ raise ValueError(error_msg) + + try: + provider_module = importlib.import_module( + f".{llm_provider}", + package="providers", + ) + logger.debug(f"Successfully imported provider module: providers.{llm_provider}") + attribute = getattr(provider_module, "initialize_llm_model", None) + if attribute is None: + error_msg = f"Provider module '{llm_provider}' does not have an 'initialize_llm_model' function." + raise ValueError(error_msg) # noqa: TRY301 + + initializer = cast("Callable[..., LLM]", attribute) + + except ImportError as err: + error_msg = f"Unsupported LLM_PROVIDER specified: '{llm_provider}'. Could not find provider module 'providers.{llm_provider}'." + raise ValueError(error_msg) from err + + except AttributeError as err: + error_msg = f"Provider module '{llm_provider}' does not have an 'initialize_llm_model' function." + raise ValueError(error_msg) from err + + except Exception as e: + error_msg = f"An unexpected error occurred while loading provider '{llm_provider}': {e}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + logger.info(f"Initializing LLM model for provider: '{llm_provider}'") + logger.debug(f"Args: llm_model='{llm_model}', llm_endpoint='{llm_endpoint}'") + + try: + llm: LLM = initializer( + llm_endpoint, + llm_api_key, + llm_model, + **(llm_extra or {}), + ) + logger.info(f"LLM model initialized successfully for '{llm_provider}'.") + + except TypeError as e: + error_msg = f"Provider initializer 'initialize_llm_model' in '{llm_provider}' was called with incorrect arguments: {e}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + except Exception as e: + error_msg = f"Failed to initialize LLM model for provider '{llm_provider}': {e}" + logger.error( + error_msg, + exc_info=True, + ) + raise RuntimeError(error_msg) from e + + return llm diff --git a/py/rag-service/src/providers/ollama.py b/py/rag-service/src/providers/ollama.py new file mode 100644 index 0000000..cc69e91 --- 
/dev/null
+++ b/py/rag-service/src/providers/ollama.py
@@ -0,0 +1,66 @@
+# src/providers/ollama.py
+
+from typing import Any
+
+from llama_index.core.base.embeddings.base import BaseEmbedding
+from llama_index.core.llms.llm import LLM
+from llama_index.embeddings.ollama import OllamaEmbedding
+from llama_index.llms.ollama import Ollama
+
+
+def initialize_embed_model(
+    embed_endpoint: str,
+    embed_api_key: str,  # noqa: ARG001
+    embed_model: str,
+    **embed_extra: Any,  # noqa: ANN401
+) -> BaseEmbedding:
+    """
+    Create Ollama embedding model.
+
+    Args:
+        embed_endpoint: The API endpoint for the Ollama API.
+        embed_api_key: Not used by Ollama.
+        embed_model: The name of the embedding model.
+        embed_extra: Extra parameters for Ollama embedding model.
+
+    Returns:
+        The initialized embed_model.
+
+    """
+    # Ollama typically uses the endpoint directly and may not require an API key
+    # We include embed_api_key in the signature to match the factory interface
+    # Pass embed_api_key even if Ollama doesn't use it, to match the signature
+    return OllamaEmbedding(
+        model_name=embed_model,
+        base_url=embed_endpoint,
+        **embed_extra,
+    )
+
+
+def initialize_llm_model(
+    llm_endpoint: str,
+    llm_api_key: str,  # noqa: ARG001
+    llm_model: str,
+    **llm_extra: Any,  # noqa: ANN401
+) -> LLM:
+    """
+    Create Ollama LLM model.
+
+    Args:
+        llm_endpoint: The API endpoint for the Ollama API.
+        llm_api_key: Not used by Ollama.
+        llm_model: The name of the LLM model.
+        llm_extra: Extra parameters for LLM model.
+
+    Returns:
+        The initialized llm_model.
+
+    """
+    # Ollama typically uses the endpoint directly and may not require an API key
+    # We include llm_api_key in the signature to match the factory interface
+    # Pass llm_api_key even if Ollama doesn't use it, to match the signature
+    return Ollama(
+        model=llm_model,
+        base_url=llm_endpoint,
+        **llm_extra,
+    )
diff --git a/py/rag-service/src/providers/openai.py b/py/rag-service/src/providers/openai.py
new file mode 100644
index 0000000..f520051
--- /dev/null
+++ b/py/rag-service/src/providers/openai.py
@@ -0,0 +1,68 @@
+# src/providers/openai.py
+
+from typing import Any
+
+from llama_index.core.base.embeddings.base import BaseEmbedding
+from llama_index.core.llms.llm import LLM
+from llama_index.embeddings.openai import OpenAIEmbedding
+from llama_index.llms.openai import OpenAI
+
+
+def initialize_embed_model(
+    embed_endpoint: str,
+    embed_api_key: str,
+    embed_model: str,
+    **embed_extra: Any,  # noqa: ANN401
+) -> BaseEmbedding:
+    """
+    Create OpenAI embedding model.
+
+    Args:
+        embed_model: The name of the embedding model.
+        embed_endpoint: The API endpoint for the OpenAI API.
+        embed_api_key: The API key for the OpenAI API.
+        embed_extra: Extra parameters for the OpenAI API.
+
+    Returns:
+        The initialized embed_model.
+
+    """
+    # Use the provided endpoint directly.
+    # Note: OpenAIEmbedding automatically picks up OPENAI_API_KEY env var
+    # The explicit api_key argument below takes precedence over the OPENAI_API_KEY env var.
+    return OpenAIEmbedding(
+        model=embed_model,
+        api_base=embed_endpoint,
+        api_key=embed_api_key,
+        **embed_extra,
+    )
+
+
+def initialize_llm_model(
+    llm_endpoint: str,
+    llm_api_key: str,
+    llm_model: str,
+    **llm_extra: Any,  # noqa: ANN401
+) -> LLM:
+    """
+    Create OpenAI LLM model.
+
+    Args:
+        llm_model: The name of the LLM model.
+        llm_endpoint: The API endpoint for the OpenAI API.
+        llm_api_key: The API key for the OpenAI API.
+        llm_extra: Extra parameters for the OpenAI API.
+
+    Returns:
+        The initialized llm_model.
+
+    """
+    # Use the provided endpoint directly.
+    # Note: OpenAI automatically picks up OPENAI_API_KEY env var
+    # The explicit api_key argument below takes precedence over the OPENAI_API_KEY env var.
+    return OpenAI(
+        model=llm_model,
+        api_base=llm_endpoint,
+        api_key=llm_api_key,
+        **llm_extra,
+    )
diff --git a/py/rag-service/src/providers/openrouter.py b/py/rag-service/src/providers/openrouter.py
new file mode 100644
index 0000000..16cf284
--- /dev/null
+++ b/py/rag-service/src/providers/openrouter.py
@@ -0,0 +1,35 @@
+# src/providers/openrouter.py
+
+from typing import Any
+
+from llama_index.core.llms.llm import LLM
+from llama_index.llms.openrouter import OpenRouter
+
+
+def initialize_llm_model(
+    llm_endpoint: str,
+    llm_api_key: str,
+    llm_model: str,
+    **llm_extra: Any,  # noqa: ANN401
+) -> LLM:
+    """
+    Create OpenRouter LLM model.
+
+    Args:
+        llm_model: The name of the LLM model.
+        llm_endpoint: The API endpoint for the OpenRouter API.
+        llm_api_key: The API key for the OpenRouter API.
+        llm_extra: Extra parameters for the OpenRouter API.
+
+    Returns:
+        The initialized llm_model.
+
+    """
+    # Use the provided endpoint directly.
+    # The explicit api_key argument below takes precedence over any environment variable.
+    return OpenRouter(
+        model=llm_model,
+        api_base=llm_endpoint,
+        api_key=llm_api_key,
+        **llm_extra,
+    )