feat: Enhanced Model Provider Support and Configuration Flexibility For Rag Service (#2056)
Co-authored-by: doodleEsc <cokie@foxmail.com> Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Standard library imports
|
||||
import asyncio
|
||||
import fcntl
|
||||
import json
|
||||
@@ -18,42 +19,24 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
# Third-party imports
|
||||
import chromadb
|
||||
import httpx
|
||||
import pathspec
|
||||
from fastapi import BackgroundTasks, FastAPI, HTTPException
|
||||
from libs.configs import (
|
||||
BASE_DATA_DIR,
|
||||
CHROMA_PERSIST_DIR,
|
||||
)
|
||||
|
||||
# Local application imports
|
||||
from libs.configs import BASE_DATA_DIR, CHROMA_PERSIST_DIR
|
||||
from libs.db import init_db
|
||||
from libs.logger import logger
|
||||
from libs.utils import (
|
||||
get_node_uri,
|
||||
inject_uri_to_node,
|
||||
is_local_uri,
|
||||
is_path_node,
|
||||
is_remote_uri,
|
||||
path_to_uri,
|
||||
uri_to_path,
|
||||
)
|
||||
from llama_index.core import (
|
||||
Settings,
|
||||
SimpleDirectoryReader,
|
||||
StorageContext,
|
||||
VectorStoreIndex,
|
||||
load_index_from_storage,
|
||||
)
|
||||
from libs.utils import get_node_uri, inject_uri_to_node, is_local_uri, is_path_node, is_remote_uri, path_to_uri, uri_to_path
|
||||
from llama_index.core import Settings, SimpleDirectoryReader, StorageContext, VectorStoreIndex, load_index_from_storage
|
||||
from llama_index.core.node_parser import CodeSplitter
|
||||
from llama_index.core.schema import Document
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from markdownify import markdownify as md
|
||||
from models.indexing_history import IndexingHistory # noqa: TC002
|
||||
from models.resource import Resource
|
||||
from providers.factory import initialize_embed_model, initialize_llm_model
|
||||
from pydantic import BaseModel, Field
|
||||
from services.indexing_history import indexing_history_service
|
||||
from services.resource import resource_service
|
||||
@@ -65,6 +48,7 @@ if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
from models.indexing_history import IndexingHistory
|
||||
from watchdog.observers.api import BaseObserver
|
||||
|
||||
# Lock file for leader election
|
||||
@@ -111,8 +95,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
if is_local_uri(resource.uri):
|
||||
directory = uri_to_path(resource.uri)
|
||||
if not directory.exists():
|
||||
logger.error("Directory not found: %s", directory)
|
||||
resource_service.update_resource_status(resource.uri, "error", f"Directory not found: {directory}")
|
||||
error_msg = f"Directory not found: {directory}"
|
||||
logger.error(error_msg)
|
||||
resource_service.update_resource_status(resource.uri, "error", error_msg)
|
||||
continue
|
||||
|
||||
# Start file system watcher
|
||||
@@ -127,8 +112,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
|
||||
elif is_remote_uri(resource.uri):
|
||||
if not is_remote_resource_exists(resource.uri):
|
||||
logger.error("HTTPS resource not found: %s", resource.uri)
|
||||
resource_service.update_resource_status(resource.uri, "error", "remote resource not found")
|
||||
error_msg = "HTTPS resource not found"
|
||||
logger.error("%s: %s", error_msg, resource.uri)
|
||||
resource_service.update_resource_status(resource.uri, "error", error_msg)
|
||||
continue
|
||||
|
||||
# Start indexing
|
||||
@@ -273,7 +259,11 @@ def is_remote_resource_exists(url: str) -> bool:
|
||||
"""Check if a URL exists."""
|
||||
try:
|
||||
response = httpx.head(url, headers=http_headers)
|
||||
return response.status_code in {httpx.codes.OK, httpx.codes.MOVED_PERMANENTLY, httpx.codes.FOUND}
|
||||
return response.status_code in {
|
||||
httpx.codes.OK,
|
||||
httpx.codes.MOVED_PERMANENTLY,
|
||||
httpx.codes.FOUND,
|
||||
}
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
logger.error("Error checking if URL exists %s: %s", url, e)
|
||||
return False
|
||||
@@ -318,53 +308,78 @@ init_db()
|
||||
# Initialize ChromaDB and LlamaIndex services
|
||||
chroma_client = chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR))
|
||||
|
||||
# Check if provider or model has changed
|
||||
current_provider = os.getenv("RAG_PROVIDER", "openai").lower()
|
||||
current_embed_model = os.getenv("RAG_EMBED_MODEL", "")
|
||||
current_llm_model = os.getenv("RAG_LLM_MODEL", "")
|
||||
# # Check if provider or model has changed
|
||||
rag_embed_provider = os.getenv("RAG_EMBED_PROVIDER", "openai")
|
||||
rag_embed_endpoint = os.getenv("RAG_EMBED_ENDPOINT", "https://api.openai.com/v1")
|
||||
rag_embed_model = os.getenv("RAG_EMBED_MODEL", "text-embedding-3-large")
|
||||
rag_embed_api_key = os.getenv("RAG_EMBED_API_KEY", None)
|
||||
rag_embed_extra = os.getenv("RAG_EMBED_EXTRA", None)
|
||||
|
||||
rag_llm_provider = os.getenv("RAG_LLM_PROVIDER", "openai")
|
||||
rag_llm_endpoint = os.getenv("RAG_LLM_ENDPOINT", "https://api.openai.com/v1")
|
||||
rag_llm_model = os.getenv("RAG_LLM_MODEL", "gpt-4o-mini")
|
||||
rag_llm_api_key = os.getenv("RAG_LLM_API_KEY", None)
|
||||
rag_llm_extra = os.getenv("RAG_LLM_EXTRA", None)
|
||||
|
||||
# Try to read previous config
|
||||
config_file = BASE_DATA_DIR / "rag_config.json"
|
||||
if config_file.exists():
|
||||
with Path.open(config_file, "r") as f:
|
||||
prev_config = json.load(f)
|
||||
if prev_config.get("provider") != current_provider or prev_config.get("embed_model") != current_embed_model:
|
||||
if prev_config.get("provider") != rag_embed_provider or prev_config.get("embed_model") != rag_embed_model:
|
||||
# Clear existing data if config changed
|
||||
logger.info("Detected config change, clearing existing data...")
|
||||
chroma_client.reset()
|
||||
|
||||
# Save current config
|
||||
with Path.open(config_file, "w") as f:
|
||||
json.dump({"provider": current_provider, "embed_model": current_embed_model}, f)
|
||||
json.dump({"provider": rag_embed_provider, "embed_model": rag_embed_model}, f)
|
||||
|
||||
chroma_collection = chroma_client.get_or_create_collection("documents") # pyright: ignore
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
|
||||
# Initialize embedding model based on provider
|
||||
llm_provider = current_provider
|
||||
base_url = os.getenv(llm_provider.upper() + "_API_BASE", "")
|
||||
rag_embed_model = current_embed_model
|
||||
rag_llm_model = current_llm_model
|
||||
try:
|
||||
embed_extra = json.loads(rag_embed_extra) if rag_embed_extra is not None else {}
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Failed to decode RAG_EMBED_EXTRA, defaulting to empty dict.")
|
||||
embed_extra = {}
|
||||
|
||||
try:
|
||||
llm_extra = json.loads(rag_llm_extra) if rag_llm_extra is not None else {}
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Failed to decode RAG_LLM_EXTRA, defaulting to empty dict.")
|
||||
llm_extra = {}
|
||||
|
||||
# Initialize embedding model and LLM based on provider using the factory
|
||||
try:
|
||||
embed_model = initialize_embed_model(
|
||||
embed_provider=rag_embed_provider,
|
||||
embed_model=rag_embed_model,
|
||||
embed_endpoint=rag_embed_endpoint,
|
||||
embed_api_key=rag_embed_api_key,
|
||||
embed_extra=embed_extra,
|
||||
)
|
||||
logger.info("Embedding model initialized successfully.")
|
||||
except (ValueError, RuntimeError) as e:
|
||||
error_msg = f"Failed to initialize embedding model: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
try:
|
||||
llm_model = initialize_llm_model(
|
||||
llm_provider=rag_llm_provider,
|
||||
llm_model=rag_llm_model,
|
||||
llm_endpoint=rag_llm_endpoint,
|
||||
llm_api_key=rag_llm_api_key,
|
||||
llm_extra=llm_extra,
|
||||
)
|
||||
logger.info("LLM model initialized successfully.")
|
||||
except (ValueError, RuntimeError) as e:
|
||||
error_msg = f"Failed to initialize LLM model: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
if llm_provider == "ollama":
|
||||
if base_url == "":
|
||||
base_url = "http://localhost:11434"
|
||||
if rag_embed_model == "":
|
||||
rag_embed_model = "nomic-embed-text"
|
||||
if rag_llm_model == "":
|
||||
rag_llm_model = "llama3"
|
||||
embed_model = OllamaEmbedding(model_name=rag_embed_model, base_url=base_url)
|
||||
llm_model = Ollama(model=rag_llm_model, base_url=base_url, request_timeout=60.0)
|
||||
else:
|
||||
if base_url == "":
|
||||
base_url = "https://api.openai.com/v1"
|
||||
if rag_embed_model == "":
|
||||
rag_embed_model = "text-embedding-3-small"
|
||||
if rag_llm_model == "":
|
||||
rag_llm_model = "gpt-4o-mini"
|
||||
embed_model = OpenAIEmbedding(model=rag_embed_model, api_base=base_url)
|
||||
llm_model = OpenAI(model=rag_llm_model, api_base=base_url)
|
||||
|
||||
Settings.embed_model = embed_model
|
||||
Settings.llm = llm_model
|
||||
@@ -400,7 +415,10 @@ class SourceDocument(BaseModel):
|
||||
class RetrieveRequest(BaseModel):
|
||||
"""Request model for information retrieval."""
|
||||
|
||||
query: str = Field(..., description="The query text to search for in the indexed documents")
|
||||
query: str = Field(
|
||||
...,
|
||||
description="The query text to search for in the indexed documents",
|
||||
)
|
||||
base_uri: str = Field(..., description="The base URI to search in")
|
||||
top_k: int | None = Field(5, description="Number of top results to return", ge=1, le=20)
|
||||
|
||||
@@ -478,7 +496,10 @@ def process_document_batch(documents: list[Document]) -> bool: # noqa: PLR0915,
|
||||
# Check if document with same hash has already been successfully processed
|
||||
status_records = indexing_history_service.get_indexing_status(doc=doc)
|
||||
if status_records and status_records[0].status == "completed":
|
||||
logger.debug("Document with same hash already processed, skipping: %s", doc.doc_id)
|
||||
logger.debug(
|
||||
"Document with same hash already processed, skipping: %s",
|
||||
doc.doc_id,
|
||||
)
|
||||
continue
|
||||
|
||||
logger.debug("Processing document: %s", doc.doc_id)
|
||||
@@ -592,7 +613,10 @@ def get_gitcrypt_files(directory: Path) -> list[str]:
|
||||
)
|
||||
|
||||
if git_root_cmd.returncode != 0:
|
||||
logger.warning("Not a git repository or git command failed: %s", git_root_cmd.stderr.strip())
|
||||
logger.warning(
|
||||
"Not a git repository or git command failed: %s",
|
||||
git_root_cmd.stderr.strip(),
|
||||
)
|
||||
return git_crypt_patterns
|
||||
|
||||
git_root = Path(git_root_cmd.stdout.strip())
|
||||
@@ -613,7 +637,15 @@ def get_gitcrypt_files(directory: Path) -> list[str]:
|
||||
|
||||
# Use Python to process the output instead of xargs, grep, and cut
|
||||
git_check_attr = subprocess.run(
|
||||
[git_executable, "-C", str(git_root), "check-attr", "filter", "--stdin", "-z"],
|
||||
[
|
||||
git_executable,
|
||||
"-C",
|
||||
str(git_root),
|
||||
"check-attr",
|
||||
"filter",
|
||||
"--stdin",
|
||||
"-z",
|
||||
],
|
||||
input=git_ls_files.stdout,
|
||||
capture_output=True,
|
||||
text=False,
|
||||
@@ -830,7 +862,11 @@ def split_documents(documents: list[Document]) -> list[Document]:
|
||||
t = doc.get_content()
|
||||
texts = code_splitter.split_text(t)
|
||||
except ValueError as e:
|
||||
logger.error("Error splitting document: %s, so skipping split, error: %s", doc.doc_id, str(e))
|
||||
logger.error(
|
||||
"Error splitting document: %s, so skipping split, error: %s",
|
||||
doc.doc_id,
|
||||
str(e),
|
||||
)
|
||||
processed_documents.append(doc)
|
||||
continue
|
||||
|
||||
@@ -1034,7 +1070,10 @@ async def add_resource(request: ResourceRequest, background_tasks: BackgroundTas
|
||||
|
||||
if resource:
|
||||
if resource.name != request.name:
|
||||
raise HTTPException(status_code=400, detail=f"Resource name cannot be changed: {resource.name}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Resource name cannot be changed: {resource.name}",
|
||||
)
|
||||
|
||||
resource_service.update_resource_status(resource.uri, "active")
|
||||
else:
|
||||
@@ -1352,3 +1391,9 @@ async def list_resources() -> ResourceListResponse:
|
||||
total_count=len(resources),
|
||||
status_summary=status_counts,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
async def health_check() -> dict[str, str]:
|
||||
"""Health check endpoint."""
|
||||
return {"status": "ok"}
|
||||
|
||||
0
py/rag-service/src/providers/__init__.py
Normal file
0
py/rag-service/src/providers/__init__.py
Normal file
70
py/rag-service/src/providers/dashscope.py
Normal file
70
py/rag-service/src/providers/dashscope.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# src/providers/dashscope.py
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.embeddings.dashscope import DashScopeEmbedding
|
||||
from llama_index.llms.dashscope import DashScope
|
||||
|
||||
|
||||
def initialize_embed_model(
|
||||
embed_endpoint: str, # noqa: ARG001
|
||||
embed_api_key: str,
|
||||
embed_model: str,
|
||||
**embed_extra: Any, # noqa: ANN401
|
||||
) -> BaseEmbedding:
|
||||
"""
|
||||
Create DashScope embedding model.
|
||||
|
||||
Args:
|
||||
embed_endpoint: Not be used directly by the constructor.
|
||||
embed_api_key: The API key for the DashScope API.
|
||||
embed_model: The name of the embedding model.
|
||||
embed_extra: Extra parameters of the embedding model.
|
||||
|
||||
Returns:
|
||||
The initialized embed_model.
|
||||
|
||||
"""
|
||||
# DashScope typically uses the API key and model name.
|
||||
# The endpoint might be set via environment variables or default.
|
||||
# We pass embed_api_key and embed_model to the constructor.
|
||||
# We include embed_endpoint in the signature to match the factory interface,
|
||||
# but it might not be directly used by the constructor depending on LlamaIndex's implementation.
|
||||
return DashScopeEmbedding(
|
||||
model_name=embed_model,
|
||||
api_key=embed_api_key,
|
||||
**embed_extra,
|
||||
)
|
||||
|
||||
|
||||
def initialize_llm_model(
|
||||
llm_endpoint: str, # noqa: ARG001
|
||||
llm_api_key: str,
|
||||
llm_model: str,
|
||||
**llm_extra: Any, # noqa: ANN401
|
||||
) -> LLM:
|
||||
"""
|
||||
Create DashScope LLM model.
|
||||
|
||||
Args:
|
||||
llm_endpoint: Not be used directly by the constructor.
|
||||
llm_api_key: The API key for the DashScope API.
|
||||
llm_model: The name of the LLM model.
|
||||
llm_extra: Extra parameters of the LLM model.
|
||||
|
||||
Returns:
|
||||
The initialized llm_model.
|
||||
|
||||
"""
|
||||
# DashScope typically uses the API key and model name.
|
||||
# The endpoint might be set via environment variables or default.
|
||||
# We pass llm_api_key and llm_model to the constructor.
|
||||
# We include llm_endpoint in the signature to match the factory interface,
|
||||
# but it might not be directly used by the constructor depending on LlamaIndex's implementation.
|
||||
return DashScope(
|
||||
model_name=llm_model,
|
||||
api_key=llm_api_key,
|
||||
**llm_extra,
|
||||
)
|
||||
179
py/rag-service/src/providers/factory.py
Normal file
179
py/rag-service/src/providers/factory.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.core.llms.llm import LLM
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from libs.logger import logger # Assuming libs.logger exists and provides a logger instance
|
||||
|
||||
|
||||
def initialize_embed_model(
|
||||
embed_provider: str,
|
||||
embed_model: str,
|
||||
embed_endpoint: str | None = None,
|
||||
embed_api_key: str | None = None,
|
||||
embed_extra: dict[str, Any] | None = None,
|
||||
) -> BaseEmbedding:
|
||||
"""
|
||||
Initialize embedding model based on specified provider and configuration.
|
||||
|
||||
Dynamically loads the provider module based on the embed_provider parameter.
|
||||
|
||||
Args:
|
||||
embed_provider: The name of the embedding provider (e.g., "openai", "ollama").
|
||||
embed_model: The name of the embedding model.
|
||||
embed_endpoint: The API endpoint for the embedding provider.
|
||||
embed_api_key: The API key for the embedding provider.
|
||||
embed_extra: Additional provider-specific configuration parameters.
|
||||
|
||||
Returns:
|
||||
The initialized embed_model.
|
||||
|
||||
Raises:
|
||||
ValueError: If the specified embed_provider is not supported or module/function not found.
|
||||
RuntimeError: If model initialization fails for the selected provider.
|
||||
|
||||
"""
|
||||
# Validate provider name
|
||||
error_msg = f"Invalid EMBED_PROVIDER specified: '{embed_provider}'. Provider name must be alphanumeric or contain underscores."
|
||||
if not embed_provider.replace("_", "").isalnum():
|
||||
raise ValueError(error_msg)
|
||||
|
||||
try:
|
||||
provider_module = importlib.import_module(f".{embed_provider}", package="providers")
|
||||
logger.debug(f"Successfully imported provider module: providers.{embed_provider}")
|
||||
attribute = getattr(provider_module, "initialize_embed_model", None)
|
||||
if attribute is None:
|
||||
error_msg = f"Provider module '{embed_provider}' does not have an 'initialize_embed_model' function."
|
||||
raise ValueError(error_msg) # noqa: TRY301
|
||||
|
||||
initializer = cast("Callable[..., BaseEmbedding]", attribute)
|
||||
|
||||
except ImportError as err:
|
||||
error_msg = f"Unsupported EMBED_PROVIDER specified: '{embed_provider}'. Could not find provider module 'providers.{embed_provider}"
|
||||
raise ValueError(error_msg) from err
|
||||
except AttributeError as err:
|
||||
error_msg = f"Provider module '{embed_provider}' does not have an 'initialize_embed_model' function."
|
||||
raise ValueError(error_msg) from err
|
||||
except Exception as err:
|
||||
logger.error(
|
||||
f"An unexpected error occurred while loading provider '{embed_provider}': {err!r}",
|
||||
exc_info=True,
|
||||
)
|
||||
error_msg = f"Failed to load provider '{embed_provider}' due to an unexpected error."
|
||||
raise RuntimeError(error_msg) from err
|
||||
|
||||
logger.info(f"Initializing embedding model for provider: {embed_provider}")
|
||||
|
||||
try:
|
||||
embedding: BaseEmbedding = initializer(
|
||||
embed_endpoint,
|
||||
embed_api_key,
|
||||
embed_model,
|
||||
**(embed_extra or {}),
|
||||
)
|
||||
|
||||
logger.info(f"Embedding model initialized successfully for {embed_provider}")
|
||||
return embedding
|
||||
except TypeError as err:
|
||||
error_msg = f"Provider initializer 'initialize_embed_model' was called with incorrect arguments in '{embed_provider}'"
|
||||
logger.error(
|
||||
f"{error_msg}: {err!r}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise RuntimeError(error_msg) from err
|
||||
except Exception as err:
|
||||
error_msg = f"Failed to initialize embedding model for provider '{embed_provider}'"
|
||||
logger.error(
|
||||
f"{error_msg}: {err!r}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise RuntimeError(error_msg) from err
|
||||
|
||||
|
||||
def initialize_llm_model(
|
||||
llm_provider: str,
|
||||
llm_model: str,
|
||||
llm_endpoint: str | None = None,
|
||||
llm_api_key: str | None = None,
|
||||
llm_extra: dict[str, Any] | None = None,
|
||||
) -> LLM:
|
||||
"""
|
||||
Create LLM model with the specified configuration.
|
||||
|
||||
Dynamically loads the provider module based on the llm_provider parameter.
|
||||
|
||||
Args:
|
||||
llm_provider: The name of the LLM provider (e.g., "openai", "ollama").
|
||||
llm_endpoint: The API endpoint for the LLM provider.
|
||||
llm_api_key: The API key for the LLM provider.
|
||||
llm_model: The name of the LLM model.
|
||||
llm_extra: The name of the LLM model.
|
||||
|
||||
Returns:
|
||||
The initialized llm_model.
|
||||
|
||||
Raises:
|
||||
ValueError: If the specified llm_provider is not supported or module/function not found.
|
||||
RuntimeError: If model initialization fails for the selected provider.
|
||||
|
||||
"""
|
||||
if not llm_provider.replace("_", "").isalnum():
|
||||
error_msg = f"Invalid LLM_PROVIDER specified: '{llm_provider}'. Provider name must be alphanumeric or contain underscores."
|
||||
raise ValueError(error_msg)
|
||||
|
||||
try:
|
||||
provider_module = importlib.import_module(
|
||||
f".{llm_provider}",
|
||||
package="providers",
|
||||
)
|
||||
logger.debug(f"Successfully imported provider module: providers.{llm_provider}")
|
||||
attribute = getattr(provider_module, "initialize_llm_model", None)
|
||||
if attribute is None:
|
||||
error_msg = f"Provider module '{llm_provider}' does not have an 'initialize_llm_model' function."
|
||||
raise ValueError(error_msg) # noqa: TRY301
|
||||
|
||||
initializer = cast("Callable[..., LLM]", attribute)
|
||||
|
||||
except ImportError as err:
|
||||
error_msg = f"Unsupported LLM_PROVIDER specified: '{llm_provider}'. Could not find provider module 'providers.{llm_provider}'."
|
||||
raise ValueError(error_msg) from err
|
||||
|
||||
except AttributeError as err:
|
||||
error_msg = f"Provider module '{llm_provider}' does not have an 'initialize_llm_model' function."
|
||||
raise ValueError(error_msg) from err
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"An unexpected error occurred while loading provider '{llm_provider}': {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
logger.info(f"Initializing LLM model for provider: '{llm_provider}'")
|
||||
logger.debug(f"Args: llm_model='{llm_model}', llm_endpoint='{llm_endpoint}'")
|
||||
|
||||
try:
|
||||
llm: LLM = initializer(
|
||||
llm_endpoint,
|
||||
llm_api_key,
|
||||
llm_model,
|
||||
**(llm_extra or {}),
|
||||
)
|
||||
logger.info(f"LLM model initialized successfully for '{llm_provider}'.")
|
||||
|
||||
except TypeError as e:
|
||||
error_msg = f"Provider initializer 'initialize_llm_model' in '{llm_provider}' was called with incorrect arguments: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to initialize LLM model for provider '{llm_provider}': {e}"
|
||||
logger.error(
|
||||
error_msg,
|
||||
exc_info=True,
|
||||
)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
return llm
|
||||
66
py/rag-service/src/providers/ollama.py
Normal file
66
py/rag-service/src/providers/ollama.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# src/providers/ollama.py
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
from llama_index.llms.ollama import Ollama
|
||||
|
||||
|
||||
def initialize_embed_model(
|
||||
embed_endpoint: str,
|
||||
embed_api_key: str, # noqa: ARG001
|
||||
embed_model: str,
|
||||
**embed_extra: Any, # noqa: ANN401
|
||||
) -> BaseEmbedding:
|
||||
"""
|
||||
Create Ollama embedding model.
|
||||
|
||||
Args:
|
||||
embed_endpoint: The API endpoint for the Ollama API.
|
||||
embed_api_key: Not be used by Ollama.
|
||||
embed_model: The name of the embedding model.
|
||||
embed_extra: Extra parameters for Ollama embedding model.
|
||||
|
||||
Returns:
|
||||
The initialized embed_model.
|
||||
|
||||
"""
|
||||
# Ollama typically uses the endpoint directly and may not require an API key
|
||||
# We include embed_api_key in the signature to match the factory interface
|
||||
# Pass embed_api_key even if Ollama doesn't use it, to match the signature
|
||||
return OllamaEmbedding(
|
||||
model_name=embed_model,
|
||||
base_url=embed_endpoint,
|
||||
**embed_extra,
|
||||
)
|
||||
|
||||
|
||||
def initialize_llm_model(
|
||||
llm_endpoint: str,
|
||||
llm_api_key: str, # noqa: ARG001
|
||||
llm_model: str,
|
||||
**llm_extra: Any, # noqa: ANN401
|
||||
) -> LLM:
|
||||
"""
|
||||
Create Ollama LLM model.
|
||||
|
||||
Args:
|
||||
llm_endpoint: The API endpoint for the Ollama API.
|
||||
llm_api_key: Not be used by Ollama.
|
||||
llm_model: The name of the LLM model.
|
||||
llm_extra: Extra parameters for LLM model.
|
||||
|
||||
Returns:
|
||||
The initialized llm_model.
|
||||
|
||||
"""
|
||||
# Ollama typically uses the endpoint directly and may not require an API key
|
||||
# We include llm_api_key in the signature to match the factory interface
|
||||
# Pass llm_api_key even if Ollama doesn't use it, to match the signature
|
||||
return Ollama(
|
||||
model=llm_model,
|
||||
base_url=llm_endpoint,
|
||||
**llm_extra,
|
||||
)
|
||||
68
py/rag-service/src/providers/openai.py
Normal file
68
py/rag-service/src/providers/openai.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# src/providers/openai.py
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
|
||||
def initialize_embed_model(
|
||||
embed_endpoint: str,
|
||||
embed_api_key: str,
|
||||
embed_model: str,
|
||||
**embed_extra: Any, # noqa: ANN401
|
||||
) -> BaseEmbedding:
|
||||
"""
|
||||
Create OpenAI embedding model.
|
||||
|
||||
Args:
|
||||
embed_model: The name of the embedding model.
|
||||
embed_endpoint: The API endpoint for the OpenAI API.
|
||||
embed_api_key: The API key for the OpenAI API.
|
||||
embed_extra: Extra Paramaters for the OpenAI API.
|
||||
|
||||
Returns:
|
||||
The initialized embed_model.
|
||||
|
||||
"""
|
||||
# Use the provided endpoint directly.
|
||||
# Note: OpenAIEmbedding automatically picks up OPENAI_API_KEY env var
|
||||
# We are not using embed_api_key parameter here, relying on env var as original code did.
|
||||
return OpenAIEmbedding(
|
||||
model=embed_model,
|
||||
api_base=embed_endpoint,
|
||||
api_key=embed_api_key,
|
||||
**embed_extra,
|
||||
)
|
||||
|
||||
|
||||
def initialize_llm_model(
|
||||
llm_endpoint: str,
|
||||
llm_api_key: str,
|
||||
llm_model: str,
|
||||
**llm_extra: Any, # noqa: ANN401
|
||||
) -> LLM:
|
||||
"""
|
||||
Create OpenAI LLM model.
|
||||
|
||||
Args:
|
||||
llm_model: The name of the LLM model.
|
||||
llm_endpoint: The API endpoint for the OpenAI API.
|
||||
llm_api_key: The API key for the OpenAI API.
|
||||
llm_extra: Extra paramaters for the OpenAI API.
|
||||
|
||||
Returns:
|
||||
The initialized llm_model.
|
||||
|
||||
"""
|
||||
# Use the provided endpoint directly.
|
||||
# Note: OpenAI automatically picks up OPENAI_API_KEY env var
|
||||
# We are not using llm_api_key parameter here, relying on env var as original code did.
|
||||
return OpenAI(
|
||||
model=llm_model,
|
||||
api_base=llm_endpoint,
|
||||
api_key=llm_api_key,
|
||||
**llm_extra,
|
||||
)
|
||||
35
py/rag-service/src/providers/openrouter.py
Normal file
35
py/rag-service/src/providers/openrouter.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# src/providers/openrouter.py
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.llms.openrouter import OpenRouter
|
||||
|
||||
|
||||
def initialize_llm_model(
|
||||
llm_endpoint: str,
|
||||
llm_api_key: str,
|
||||
llm_model: str,
|
||||
**llm_extra: Any, # noqa: ANN401
|
||||
) -> LLM:
|
||||
"""
|
||||
Create OpenRouter LLM model.
|
||||
|
||||
Args:
|
||||
llm_model: The name of the LLM model.
|
||||
llm_endpoint: The API endpoint for the OpenRouter API.
|
||||
llm_api_key: The API key for the OpenRouter API.
|
||||
llm_extra: The Extra Parameters for OpenROuter,
|
||||
|
||||
Returns:
|
||||
The initialized llm_model.
|
||||
|
||||
"""
|
||||
# Use the provided endpoint directly.
|
||||
# We are not using llm_api_key parameter here, relying on env var as original code did.
|
||||
return OpenRouter(
|
||||
model=llm_model,
|
||||
api_base=llm_endpoint,
|
||||
api_key=llm_api_key,
|
||||
**llm_extra,
|
||||
)
|
||||
Reference in New Issue
Block a user