feat: Enhanced model provider support and configuration flexibility for RAG service (#2056)
Co-authored-by: doodleEsc <cokie@foxmail.com> Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
This commit is contained in:
0
py/rag-service/src/providers/__init__.py
Normal file
0
py/rag-service/src/providers/__init__.py
Normal file
70
py/rag-service/src/providers/dashscope.py
Normal file
70
py/rag-service/src/providers/dashscope.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# src/providers/dashscope.py
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.embeddings.dashscope import DashScopeEmbedding
|
||||
from llama_index.llms.dashscope import DashScope
|
||||
|
||||
|
||||
def initialize_embed_model(
    embed_endpoint: str,  # noqa: ARG001
    embed_api_key: str,
    embed_model: str,
    **embed_extra: Any,  # noqa: ANN401
) -> BaseEmbedding:
    """
    Create DashScope embedding model.

    Args:
        embed_endpoint: Not used directly by the constructor.
        embed_api_key: The API key for the DashScope API.
        embed_model: The name of the embedding model.
        embed_extra: Extra parameters of the embedding model.

    Returns:
        The initialized embed_model.

    """
    # DashScope only needs the API key and model name here; the endpoint, if
    # any, is resolved by the library (e.g. via environment variables or its
    # defaults). embed_endpoint stays in the signature purely so this module
    # matches the shared provider-factory interface.
    constructor_kwargs: dict[str, Any] = {
        "model_name": embed_model,
        "api_key": embed_api_key,
    }
    constructor_kwargs.update(embed_extra)
    return DashScopeEmbedding(**constructor_kwargs)
|
||||
|
||||
|
||||
def initialize_llm_model(
    llm_endpoint: str,  # noqa: ARG001
    llm_api_key: str,
    llm_model: str,
    **llm_extra: Any,  # noqa: ANN401
) -> LLM:
    """
    Create DashScope LLM model.

    Args:
        llm_endpoint: Not used directly by the constructor.
        llm_api_key: The API key for the DashScope API.
        llm_model: The name of the LLM model.
        llm_extra: Extra parameters of the LLM model.

    Returns:
        The initialized llm_model.

    """
    # Only the API key and model name are forwarded; DashScope resolves its
    # endpoint internally (environment variables or library defaults).
    # llm_endpoint is accepted solely to satisfy the factory interface.
    constructor_kwargs: dict[str, Any] = {
        "model_name": llm_model,
        "api_key": llm_api_key,
    }
    constructor_kwargs.update(llm_extra)
    return DashScope(**constructor_kwargs)
|
||||
179
py/rag-service/src/providers/factory.py
Normal file
179
py/rag-service/src/providers/factory.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.core.llms.llm import LLM
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from libs.logger import logger # Assuming libs.logger exists and provides a logger instance
|
||||
|
||||
|
||||
def initialize_embed_model(
    embed_provider: str,
    embed_model: str,
    embed_endpoint: str | None = None,
    embed_api_key: str | None = None,
    embed_extra: dict[str, Any] | None = None,
) -> BaseEmbedding:
    """
    Initialize embedding model based on specified provider and configuration.

    Dynamically loads the provider module based on the embed_provider parameter.

    Args:
        embed_provider: The name of the embedding provider (e.g., "openai", "ollama").
        embed_model: The name of the embedding model.
        embed_endpoint: The API endpoint for the embedding provider.
        embed_api_key: The API key for the embedding provider.
        embed_extra: Additional provider-specific configuration parameters.

    Returns:
        The initialized embed_model.

    Raises:
        ValueError: If the specified embed_provider is not supported or module/function not found.
        RuntimeError: If model initialization fails for the selected provider.

    """
    # Validate the provider name before using it to build a module path.
    if not embed_provider.replace("_", "").isalnum():
        error_msg = f"Invalid EMBED_PROVIDER specified: '{embed_provider}'. Provider name must be alphanumeric or contain underscores."
        raise ValueError(error_msg)

    try:
        provider_module = importlib.import_module(f".{embed_provider}", package="providers")
    except ImportError as err:
        # Fixed: the module path in this message was missing its closing quote.
        error_msg = f"Unsupported EMBED_PROVIDER specified: '{embed_provider}'. Could not find provider module 'providers.{embed_provider}'."
        raise ValueError(error_msg) from err
    except Exception as err:
        logger.error(
            f"An unexpected error occurred while loading provider '{embed_provider}': {err!r}",
            exc_info=True,
        )
        error_msg = f"Failed to load provider '{embed_provider}' due to an unexpected error."
        raise RuntimeError(error_msg) from err

    logger.debug(f"Successfully imported provider module: providers.{embed_provider}")

    # Resolve the initializer OUTSIDE the try-block: previously the ValueError
    # raised here was swallowed by the broad `except Exception` and re-raised
    # as RuntimeError, contradicting the documented Raises contract. (The old
    # `except AttributeError` branch was unreachable — getattr with a default
    # never raises — and has been removed.)
    attribute = getattr(provider_module, "initialize_embed_model", None)
    if attribute is None:
        error_msg = f"Provider module '{embed_provider}' does not have an 'initialize_embed_model' function."
        raise ValueError(error_msg)

    initializer = cast("Callable[..., BaseEmbedding]", attribute)

    logger.info(f"Initializing embedding model for provider: {embed_provider}")

    try:
        embedding: BaseEmbedding = initializer(
            embed_endpoint,
            embed_api_key,
            embed_model,
            **(embed_extra or {}),
        )
    except TypeError as err:
        # The provider's initializer does not match the factory calling convention.
        error_msg = f"Provider initializer 'initialize_embed_model' was called with incorrect arguments in '{embed_provider}'"
        logger.error(
            f"{error_msg}: {err!r}",
            exc_info=True,
        )
        raise RuntimeError(error_msg) from err
    except Exception as err:
        error_msg = f"Failed to initialize embedding model for provider '{embed_provider}'"
        logger.error(
            f"{error_msg}: {err!r}",
            exc_info=True,
        )
        raise RuntimeError(error_msg) from err

    logger.info(f"Embedding model initialized successfully for {embed_provider}")
    return embedding
|
||||
|
||||
|
||||
def initialize_llm_model(
    llm_provider: str,
    llm_model: str,
    llm_endpoint: str | None = None,
    llm_api_key: str | None = None,
    llm_extra: dict[str, Any] | None = None,
) -> LLM:
    """
    Create LLM model with the specified configuration.

    Dynamically loads the provider module based on the llm_provider parameter.

    Args:
        llm_provider: The name of the LLM provider (e.g., "openai", "ollama").
        llm_model: The name of the LLM model.
        llm_endpoint: The API endpoint for the LLM provider.
        llm_api_key: The API key for the LLM provider.
        llm_extra: Additional provider-specific configuration parameters.

    Returns:
        The initialized llm_model.

    Raises:
        ValueError: If the specified llm_provider is not supported or module/function not found.
        RuntimeError: If model initialization fails for the selected provider.

    """
    # Validate the provider name before using it to build a module path.
    if not llm_provider.replace("_", "").isalnum():
        error_msg = f"Invalid LLM_PROVIDER specified: '{llm_provider}'. Provider name must be alphanumeric or contain underscores."
        raise ValueError(error_msg)

    try:
        provider_module = importlib.import_module(
            f".{llm_provider}",
            package="providers",
        )
    except ImportError as err:
        error_msg = f"Unsupported LLM_PROVIDER specified: '{llm_provider}'. Could not find provider module 'providers.{llm_provider}'."
        raise ValueError(error_msg) from err
    except Exception as e:
        error_msg = f"An unexpected error occurred while loading provider '{llm_provider}': {e}"
        logger.error(error_msg, exc_info=True)
        raise RuntimeError(error_msg) from e

    logger.debug(f"Successfully imported provider module: providers.{llm_provider}")

    # Resolve the initializer OUTSIDE the try-block: previously the ValueError
    # raised here was swallowed by the broad `except Exception` and re-raised
    # as RuntimeError, contradicting the documented Raises contract. (The old
    # `except AttributeError` branch was unreachable — getattr with a default
    # never raises — and has been removed.)
    attribute = getattr(provider_module, "initialize_llm_model", None)
    if attribute is None:
        error_msg = f"Provider module '{llm_provider}' does not have an 'initialize_llm_model' function."
        raise ValueError(error_msg)

    initializer = cast("Callable[..., LLM]", attribute)

    logger.info(f"Initializing LLM model for provider: '{llm_provider}'")
    logger.debug(f"Args: llm_model='{llm_model}', llm_endpoint='{llm_endpoint}'")

    try:
        llm: LLM = initializer(
            llm_endpoint,
            llm_api_key,
            llm_model,
            **(llm_extra or {}),
        )
        logger.info(f"LLM model initialized successfully for '{llm_provider}'.")
    except TypeError as e:
        # The provider's initializer does not match the factory calling convention.
        error_msg = f"Provider initializer 'initialize_llm_model' in '{llm_provider}' was called with incorrect arguments: {e}"
        logger.error(error_msg, exc_info=True)
        raise RuntimeError(error_msg) from e
    except Exception as e:
        error_msg = f"Failed to initialize LLM model for provider '{llm_provider}': {e}"
        logger.error(
            error_msg,
            exc_info=True,
        )
        raise RuntimeError(error_msg) from e

    return llm
|
||||
66
py/rag-service/src/providers/ollama.py
Normal file
66
py/rag-service/src/providers/ollama.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# src/providers/ollama.py
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
from llama_index.llms.ollama import Ollama
|
||||
|
||||
|
||||
def initialize_embed_model(
    embed_endpoint: str,
    embed_api_key: str,  # noqa: ARG001
    embed_model: str,
    **embed_extra: Any,  # noqa: ANN401
) -> BaseEmbedding:
    """
    Create Ollama embedding model.

    Args:
        embed_endpoint: The API endpoint for the Ollama API.
        embed_api_key: Not used by Ollama.
        embed_model: The name of the embedding model.
        embed_extra: Extra parameters for Ollama embedding model.

    Returns:
        The initialized embed_model.

    """
    # Ollama is addressed directly by its endpoint and needs no API key;
    # embed_api_key is accepted only so the signature matches the shared
    # provider-factory interface.
    options: dict[str, Any] = {
        "model_name": embed_model,
        "base_url": embed_endpoint,
    }
    options.update(embed_extra)
    return OllamaEmbedding(**options)
|
||||
|
||||
|
||||
def initialize_llm_model(
    llm_endpoint: str,
    llm_api_key: str,  # noqa: ARG001
    llm_model: str,
    **llm_extra: Any,  # noqa: ANN401
) -> LLM:
    """
    Create Ollama LLM model.

    Args:
        llm_endpoint: The API endpoint for the Ollama API.
        llm_api_key: Not used by Ollama.
        llm_model: The name of the LLM model.
        llm_extra: Extra parameters for LLM model.

    Returns:
        The initialized llm_model.

    """
    # Ollama is addressed directly by its endpoint and needs no API key;
    # llm_api_key is accepted only so the signature matches the shared
    # provider-factory interface.
    options: dict[str, Any] = {
        "model": llm_model,
        "base_url": llm_endpoint,
    }
    options.update(llm_extra)
    return Ollama(**options)
|
||||
68
py/rag-service/src/providers/openai.py
Normal file
68
py/rag-service/src/providers/openai.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# src/providers/openai.py
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
|
||||
def initialize_embed_model(
    embed_endpoint: str,
    embed_api_key: str,
    embed_model: str,
    **embed_extra: Any,  # noqa: ANN401
) -> BaseEmbedding:
    """
    Create OpenAI embedding model.

    Args:
        embed_endpoint: The API endpoint for the OpenAI API.
        embed_api_key: The API key for the OpenAI API.
        embed_model: The name of the embedding model.
        embed_extra: Extra parameters for the OpenAI API.

    Returns:
        The initialized embed_model.

    """
    # Pass the endpoint and API key through explicitly so the caller's
    # configuration takes precedence. (Previous comments claimed the api_key
    # parameter was unused and that the OPENAI_API_KEY env var was relied on;
    # the code has always passed api_key explicitly — the env var is only the
    # client's fallback when api_key is None.)
    return OpenAIEmbedding(
        model=embed_model,
        api_base=embed_endpoint,
        api_key=embed_api_key,
        **embed_extra,
    )
|
||||
|
||||
|
||||
def initialize_llm_model(
    llm_endpoint: str,
    llm_api_key: str,
    llm_model: str,
    **llm_extra: Any,  # noqa: ANN401
) -> LLM:
    """
    Create OpenAI LLM model.

    Args:
        llm_endpoint: The API endpoint for the OpenAI API.
        llm_api_key: The API key for the OpenAI API.
        llm_model: The name of the LLM model.
        llm_extra: Extra parameters for the OpenAI API.

    Returns:
        The initialized llm_model.

    """
    # Pass the endpoint and API key through explicitly so the caller's
    # configuration takes precedence. (Previous comments claimed the api_key
    # parameter was unused and that the OPENAI_API_KEY env var was relied on;
    # the code has always passed api_key explicitly — the env var is only the
    # client's fallback when api_key is None.)
    return OpenAI(
        model=llm_model,
        api_base=llm_endpoint,
        api_key=llm_api_key,
        **llm_extra,
    )
|
||||
35
py/rag-service/src/providers/openrouter.py
Normal file
35
py/rag-service/src/providers/openrouter.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# src/providers/openrouter.py
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.llms.openrouter import OpenRouter
|
||||
|
||||
|
||||
def initialize_llm_model(
    llm_endpoint: str,
    llm_api_key: str,
    llm_model: str,
    **llm_extra: Any,  # noqa: ANN401
) -> LLM:
    """
    Create OpenRouter LLM model.

    Args:
        llm_endpoint: The API endpoint for the OpenRouter API.
        llm_api_key: The API key for the OpenRouter API.
        llm_model: The name of the LLM model.
        llm_extra: Extra parameters for OpenRouter.

    Returns:
        The initialized llm_model.

    """
    # Pass the endpoint and API key through explicitly so the caller's
    # configuration takes precedence. (A previous comment claimed llm_api_key
    # was unused; the code has always passed it to the constructor.)
    return OpenRouter(
        model=llm_model,
        api_base=llm_endpoint,
        api_key=llm_api_key,
        **llm_extra,
    )
|
||||
Reference in New Issue
Block a user