feat: Enhanced model provider support and configuration flexibility for RAG service (#2056)

Co-authored-by: doodleEsc <cokie@foxmail.com>
Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
This commit is contained in:
doodleEsc
2025-06-06 23:07:07 +08:00
committed by GitHub
parent ec0f4f9ae0
commit 2dd4c04088
15 changed files with 844 additions and 151 deletions

View File

View File

@@ -0,0 +1,70 @@
# src/providers/dashscope.py
from typing import Any
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms.llm import LLM
from llama_index.embeddings.dashscope import DashScopeEmbedding
from llama_index.llms.dashscope import DashScope
def initialize_embed_model(
    embed_endpoint: str,  # noqa: ARG001
    embed_api_key: str,
    embed_model: str,
    **embed_extra: Any,  # noqa: ANN401
) -> BaseEmbedding:
    """
    Create a DashScope embedding model.

    Args:
        embed_endpoint: Not used by the DashScope constructor; accepted only
            so the signature matches the provider factory interface.
        embed_api_key: The API key for the DashScope API.
        embed_model: The name of the embedding model.
        embed_extra: Extra keyword parameters forwarded to the embedding model.

    Returns:
        The initialized embedding model.
    """
    # DashScope resolves its endpoint internally (environment variables or its
    # built-in default), so only the API key, model name, and any extra keyword
    # arguments are forwarded to the constructor.
    return DashScopeEmbedding(
        model_name=embed_model,
        api_key=embed_api_key,
        **embed_extra,
    )
def initialize_llm_model(
    llm_endpoint: str,  # noqa: ARG001
    llm_api_key: str,
    llm_model: str,
    **llm_extra: Any,  # noqa: ANN401
) -> LLM:
    """
    Create a DashScope LLM model.

    Args:
        llm_endpoint: Not used by the DashScope constructor; accepted only
            so the signature matches the provider factory interface.
        llm_api_key: The API key for the DashScope API.
        llm_model: The name of the LLM model.
        llm_extra: Extra keyword parameters forwarded to the LLM model.

    Returns:
        The initialized LLM model.
    """
    # DashScope resolves its endpoint internally (environment variables or its
    # built-in default), so only the API key, model name, and any extra keyword
    # arguments are forwarded to the constructor.
    return DashScope(
        model_name=llm_model,
        api_key=llm_api_key,
        **llm_extra,
    )

View File

@@ -0,0 +1,179 @@
import importlib
from typing import TYPE_CHECKING, Any, cast
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms.llm import LLM
if TYPE_CHECKING:
from collections.abc import Callable
from libs.logger import logger # Assuming libs.logger exists and provides a logger instance
def initialize_embed_model(
    embed_provider: str,
    embed_model: str,
    embed_endpoint: str | None = None,
    embed_api_key: str | None = None,
    embed_extra: dict[str, Any] | None = None,
) -> BaseEmbedding:
    """
    Initialize an embedding model based on the specified provider and configuration.

    Dynamically imports ``providers.<embed_provider>`` and delegates to its
    ``initialize_embed_model`` function.

    Args:
        embed_provider: The name of the embedding provider (e.g., "openai", "ollama").
        embed_model: The name of the embedding model.
        embed_endpoint: The API endpoint for the embedding provider.
        embed_api_key: The API key for the embedding provider.
        embed_extra: Additional provider-specific configuration parameters.

    Returns:
        The initialized embedding model.

    Raises:
        ValueError: If the specified embed_provider is not supported or the
            provider module/function cannot be found.
        RuntimeError: If model initialization fails for the selected provider.
    """
    # Validate the provider name first so it cannot inject dots or other path
    # characters into the dynamic import below.
    error_msg = f"Invalid EMBED_PROVIDER specified: '{embed_provider}'. Provider name must be alphanumeric or contain underscores."
    if not embed_provider.replace("_", "").isalnum():
        raise ValueError(error_msg)
    try:
        provider_module = importlib.import_module(f".{embed_provider}", package="providers")
        logger.debug(f"Successfully imported provider module: providers.{embed_provider}")
        attribute = getattr(provider_module, "initialize_embed_model", None)
        if attribute is None:
            error_msg = f"Provider module '{embed_provider}' does not have an 'initialize_embed_model' function."
            raise ValueError(error_msg)  # noqa: TRY301
        initializer = cast("Callable[..., BaseEmbedding]", attribute)
    except ImportError as err:
        # BUGFIX: this message previously ended with an unterminated quote
        # ("...'providers.{embed_provider}"); now matches the LLM counterpart.
        error_msg = f"Unsupported EMBED_PROVIDER specified: '{embed_provider}'. Could not find provider module 'providers.{embed_provider}'."
        raise ValueError(error_msg) from err
    except AttributeError as err:
        error_msg = f"Provider module '{embed_provider}' does not have an 'initialize_embed_model' function."
        raise ValueError(error_msg) from err
    except Exception as err:
        logger.error(
            f"An unexpected error occurred while loading provider '{embed_provider}': {err!r}",
            exc_info=True,
        )
        error_msg = f"Failed to load provider '{embed_provider}' due to an unexpected error."
        raise RuntimeError(error_msg) from err
    logger.info(f"Initializing embedding model for provider: {embed_provider}")
    try:
        # Providers accept (endpoint, api_key, model, **extra) positionally;
        # each provider ignores the arguments it does not need.
        embedding: BaseEmbedding = initializer(
            embed_endpoint,
            embed_api_key,
            embed_model,
            **(embed_extra or {}),
        )
        logger.info(f"Embedding model initialized successfully for {embed_provider}")
        return embedding
    except TypeError as err:
        # A TypeError here almost always means the provider module's function
        # signature drifted from the factory contract.
        error_msg = f"Provider initializer 'initialize_embed_model' was called with incorrect arguments in '{embed_provider}'"
        logger.error(
            f"{error_msg}: {err!r}",
            exc_info=True,
        )
        raise RuntimeError(error_msg) from err
    except Exception as err:
        error_msg = f"Failed to initialize embedding model for provider '{embed_provider}'"
        logger.error(
            f"{error_msg}: {err!r}",
            exc_info=True,
        )
        raise RuntimeError(error_msg) from err
def initialize_llm_model(
    llm_provider: str,
    llm_model: str,
    llm_endpoint: str | None = None,
    llm_api_key: str | None = None,
    llm_extra: dict[str, Any] | None = None,
) -> LLM:
    """
    Create an LLM model with the specified configuration.

    Dynamically imports ``providers.<llm_provider>`` and delegates to its
    ``initialize_llm_model`` function.

    Args:
        llm_provider: The name of the LLM provider (e.g., "openai", "ollama").
        llm_model: The name of the LLM model.
        llm_endpoint: The API endpoint for the LLM provider.
        llm_api_key: The API key for the LLM provider.
        llm_extra: Additional provider-specific configuration parameters.

    Returns:
        The initialized LLM model.

    Raises:
        ValueError: If the specified llm_provider is not supported or the
            provider module/function cannot be found.
        RuntimeError: If model initialization fails for the selected provider.
    """
    # Validate the provider name first so it cannot inject dots or other path
    # characters into the dynamic import below.
    if not llm_provider.replace("_", "").isalnum():
        error_msg = f"Invalid LLM_PROVIDER specified: '{llm_provider}'. Provider name must be alphanumeric or contain underscores."
        raise ValueError(error_msg)
    try:
        provider_module = importlib.import_module(
            f".{llm_provider}",
            package="providers",
        )
        logger.debug(f"Successfully imported provider module: providers.{llm_provider}")
        attribute = getattr(provider_module, "initialize_llm_model", None)
        if attribute is None:
            error_msg = f"Provider module '{llm_provider}' does not have an 'initialize_llm_model' function."
            raise ValueError(error_msg)  # noqa: TRY301
        initializer = cast("Callable[..., LLM]", attribute)
    except ImportError as err:
        error_msg = f"Unsupported LLM_PROVIDER specified: '{llm_provider}'. Could not find provider module 'providers.{llm_provider}'."
        raise ValueError(error_msg) from err
    except AttributeError as err:
        error_msg = f"Provider module '{llm_provider}' does not have an 'initialize_llm_model' function."
        raise ValueError(error_msg) from err
    except Exception as e:
        error_msg = f"An unexpected error occurred while loading provider '{llm_provider}': {e}"
        logger.error(error_msg, exc_info=True)
        raise RuntimeError(error_msg) from e
    logger.info(f"Initializing LLM model for provider: '{llm_provider}'")
    logger.debug(f"Args: llm_model='{llm_model}', llm_endpoint='{llm_endpoint}'")
    try:
        # Providers accept (endpoint, api_key, model, **extra) positionally;
        # each provider ignores the arguments it does not need.
        llm: LLM = initializer(
            llm_endpoint,
            llm_api_key,
            llm_model,
            **(llm_extra or {}),
        )
        logger.info(f"LLM model initialized successfully for '{llm_provider}'.")
    except TypeError as e:
        # A TypeError here almost always means the provider module's function
        # signature drifted from the factory contract.
        error_msg = f"Provider initializer 'initialize_llm_model' in '{llm_provider}' was called with incorrect arguments: {e}"
        logger.error(error_msg, exc_info=True)
        raise RuntimeError(error_msg) from e
    except Exception as e:
        error_msg = f"Failed to initialize LLM model for provider '{llm_provider}': {e}"
        logger.error(
            error_msg,
            exc_info=True,
        )
        raise RuntimeError(error_msg) from e
    return llm

View File

@@ -0,0 +1,66 @@
# src/providers/ollama.py
from typing import Any
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms.llm import LLM
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
def initialize_embed_model(
    embed_endpoint: str,
    embed_api_key: str,  # noqa: ARG001
    embed_model: str,
    **embed_extra: Any,  # noqa: ANN401
) -> BaseEmbedding:
    """
    Create an Ollama embedding model.

    Args:
        embed_endpoint: The base URL of the Ollama server.
        embed_api_key: Ignored — Ollama does not require an API key; present
            only so the signature matches the provider factory interface.
        embed_model: The name of the embedding model.
        embed_extra: Extra keyword parameters forwarded to the embedding model.

    Returns:
        The initialized embedding model.
    """
    # Ollama talks to a local/self-hosted server, so only the endpoint and the
    # model name are needed; the unused api_key keeps the factory contract.
    embedding = OllamaEmbedding(
        model_name=embed_model,
        base_url=embed_endpoint,
        **embed_extra,
    )
    return embedding
def initialize_llm_model(
    llm_endpoint: str,
    llm_api_key: str,  # noqa: ARG001
    llm_model: str,
    **llm_extra: Any,  # noqa: ANN401
) -> LLM:
    """
    Create an Ollama LLM model.

    Args:
        llm_endpoint: The base URL of the Ollama server.
        llm_api_key: Ignored — Ollama does not require an API key; present
            only so the signature matches the provider factory interface.
        llm_model: The name of the LLM model.
        llm_extra: Extra keyword parameters forwarded to the LLM model.

    Returns:
        The initialized LLM model.
    """
    # Ollama talks to a local/self-hosted server, so only the endpoint and the
    # model name are needed; the unused api_key keeps the factory contract.
    llm = Ollama(
        model=llm_model,
        base_url=llm_endpoint,
        **llm_extra,
    )
    return llm

View File

@@ -0,0 +1,68 @@
# src/providers/openai.py
from typing import Any
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms.llm import LLM
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
def initialize_embed_model(
    embed_endpoint: str,
    embed_api_key: str,
    embed_model: str,
    **embed_extra: Any,  # noqa: ANN401
) -> BaseEmbedding:
    """
    Create an OpenAI embedding model.

    Args:
        embed_endpoint: The API endpoint (base URL) for the OpenAI API.
        embed_api_key: The API key for the OpenAI API.
        embed_model: The name of the embedding model.
        embed_extra: Extra keyword parameters forwarded to the embedding model.

    Returns:
        The initialized embedding model.
    """
    # The endpoint and API key are passed explicitly; note that if api_key is
    # None, OpenAIEmbedding can still fall back to the OPENAI_API_KEY env var.
    # (Previous comment wrongly claimed embed_api_key was not used here.)
    return OpenAIEmbedding(
        model=embed_model,
        api_base=embed_endpoint,
        api_key=embed_api_key,
        **embed_extra,
    )
def initialize_llm_model(
    llm_endpoint: str,
    llm_api_key: str,
    llm_model: str,
    **llm_extra: Any,  # noqa: ANN401
) -> LLM:
    """
    Create an OpenAI LLM model.

    Args:
        llm_endpoint: The API endpoint (base URL) for the OpenAI API.
        llm_api_key: The API key for the OpenAI API.
        llm_model: The name of the LLM model.
        llm_extra: Extra keyword parameters forwarded to the LLM model.

    Returns:
        The initialized LLM model.
    """
    # The endpoint and API key are passed explicitly; note that if api_key is
    # None, OpenAI can still fall back to the OPENAI_API_KEY env var.
    # (Previous comment wrongly claimed llm_api_key was not used here.)
    return OpenAI(
        model=llm_model,
        api_base=llm_endpoint,
        api_key=llm_api_key,
        **llm_extra,
    )

View File

@@ -0,0 +1,35 @@
# src/providers/openrouter.py
from typing import Any
from llama_index.core.llms.llm import LLM
from llama_index.llms.openrouter import OpenRouter
def initialize_llm_model(
    llm_endpoint: str,
    llm_api_key: str,
    llm_model: str,
    **llm_extra: Any,  # noqa: ANN401
) -> LLM:
    """
    Create an OpenRouter LLM model.

    Args:
        llm_endpoint: The API endpoint (base URL) for the OpenRouter API.
        llm_api_key: The API key for the OpenRouter API.
        llm_model: The name of the LLM model.
        llm_extra: Extra keyword parameters forwarded to the LLM model.

    Returns:
        The initialized LLM model.
    """
    # The endpoint and API key are passed explicitly to the constructor.
    # (Previous comment wrongly claimed llm_api_key was not used here.)
    return OpenRouter(
        model=llm_model,
        api_base=llm_endpoint,
        api_key=llm_api_key,
        **llm_extra,
    )