feat: Enhanced Model Provider Support and Configuration Flexibility For Rag Service (#2056)
Co-authored-by: doodleEsc <cokie@foxmail.com> Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Standard library imports
|
||||
import asyncio
|
||||
import fcntl
|
||||
import json
|
||||
@@ -18,42 +19,24 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
# Third-party imports
|
||||
import chromadb
|
||||
import httpx
|
||||
import pathspec
|
||||
from fastapi import BackgroundTasks, FastAPI, HTTPException
|
||||
from libs.configs import (
|
||||
BASE_DATA_DIR,
|
||||
CHROMA_PERSIST_DIR,
|
||||
)
|
||||
|
||||
# Local application imports
|
||||
from libs.configs import BASE_DATA_DIR, CHROMA_PERSIST_DIR
|
||||
from libs.db import init_db
|
||||
from libs.logger import logger
|
||||
from libs.utils import (
|
||||
get_node_uri,
|
||||
inject_uri_to_node,
|
||||
is_local_uri,
|
||||
is_path_node,
|
||||
is_remote_uri,
|
||||
path_to_uri,
|
||||
uri_to_path,
|
||||
)
|
||||
from llama_index.core import (
|
||||
Settings,
|
||||
SimpleDirectoryReader,
|
||||
StorageContext,
|
||||
VectorStoreIndex,
|
||||
load_index_from_storage,
|
||||
)
|
||||
from libs.utils import get_node_uri, inject_uri_to_node, is_local_uri, is_path_node, is_remote_uri, path_to_uri, uri_to_path
|
||||
from llama_index.core import Settings, SimpleDirectoryReader, StorageContext, VectorStoreIndex, load_index_from_storage
|
||||
from llama_index.core.node_parser import CodeSplitter
|
||||
from llama_index.core.schema import Document
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from markdownify import markdownify as md
|
||||
from models.indexing_history import IndexingHistory # noqa: TC002
|
||||
from models.resource import Resource
|
||||
from providers.factory import initialize_embed_model, initialize_llm_model
|
||||
from pydantic import BaseModel, Field
|
||||
from services.indexing_history import indexing_history_service
|
||||
from services.resource import resource_service
|
||||
@@ -65,6 +48,7 @@ if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
from models.indexing_history import IndexingHistory
|
||||
from watchdog.observers.api import BaseObserver
|
||||
|
||||
# Lock file for leader election
|
||||
@@ -111,8 +95,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
if is_local_uri(resource.uri):
|
||||
directory = uri_to_path(resource.uri)
|
||||
if not directory.exists():
|
||||
logger.error("Directory not found: %s", directory)
|
||||
resource_service.update_resource_status(resource.uri, "error", f"Directory not found: {directory}")
|
||||
error_msg = f"Directory not found: {directory}"
|
||||
logger.error(error_msg)
|
||||
resource_service.update_resource_status(resource.uri, "error", error_msg)
|
||||
continue
|
||||
|
||||
# Start file system watcher
|
||||
@@ -127,8 +112,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
|
||||
elif is_remote_uri(resource.uri):
|
||||
if not is_remote_resource_exists(resource.uri):
|
||||
logger.error("HTTPS resource not found: %s", resource.uri)
|
||||
resource_service.update_resource_status(resource.uri, "error", "remote resource not found")
|
||||
error_msg = "HTTPS resource not found"
|
||||
logger.error("%s: %s", error_msg, resource.uri)
|
||||
resource_service.update_resource_status(resource.uri, "error", error_msg)
|
||||
continue
|
||||
|
||||
# Start indexing
|
||||
@@ -273,7 +259,11 @@ def is_remote_resource_exists(url: str) -> bool:
|
||||
"""Check if a URL exists."""
|
||||
try:
|
||||
response = httpx.head(url, headers=http_headers)
|
||||
return response.status_code in {httpx.codes.OK, httpx.codes.MOVED_PERMANENTLY, httpx.codes.FOUND}
|
||||
return response.status_code in {
|
||||
httpx.codes.OK,
|
||||
httpx.codes.MOVED_PERMANENTLY,
|
||||
httpx.codes.FOUND,
|
||||
}
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
logger.error("Error checking if URL exists %s: %s", url, e)
|
||||
return False
|
||||
@@ -318,53 +308,78 @@ init_db()
|
||||
# Initialize ChromaDB and LlamaIndex services
|
||||
chroma_client = chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR))
|
||||
|
||||
# Check if provider or model has changed
|
||||
current_provider = os.getenv("RAG_PROVIDER", "openai").lower()
|
||||
current_embed_model = os.getenv("RAG_EMBED_MODEL", "")
|
||||
current_llm_model = os.getenv("RAG_LLM_MODEL", "")
|
||||
# # Check if provider or model has changed
|
||||
rag_embed_provider = os.getenv("RAG_EMBED_PROVIDER", "openai")
|
||||
rag_embed_endpoint = os.getenv("RAG_EMBED_ENDPOINT", "https://api.openai.com/v1")
|
||||
rag_embed_model = os.getenv("RAG_EMBED_MODEL", "text-embedding-3-large")
|
||||
rag_embed_api_key = os.getenv("RAG_EMBED_API_KEY", None)
|
||||
rag_embed_extra = os.getenv("RAG_EMBED_EXTRA", None)
|
||||
|
||||
rag_llm_provider = os.getenv("RAG_LLM_PROVIDER", "openai")
|
||||
rag_llm_endpoint = os.getenv("RAG_LLM_ENDPOINT", "https://api.openai.com/v1")
|
||||
rag_llm_model = os.getenv("RAG_LLM_MODEL", "gpt-4o-mini")
|
||||
rag_llm_api_key = os.getenv("RAG_LLM_API_KEY", None)
|
||||
rag_llm_extra = os.getenv("RAG_LLM_EXTRA", None)
|
||||
|
||||
# Try to read previous config
|
||||
config_file = BASE_DATA_DIR / "rag_config.json"
|
||||
if config_file.exists():
|
||||
with Path.open(config_file, "r") as f:
|
||||
prev_config = json.load(f)
|
||||
if prev_config.get("provider") != current_provider or prev_config.get("embed_model") != current_embed_model:
|
||||
if prev_config.get("provider") != rag_embed_provider or prev_config.get("embed_model") != rag_embed_model:
|
||||
# Clear existing data if config changed
|
||||
logger.info("Detected config change, clearing existing data...")
|
||||
chroma_client.reset()
|
||||
|
||||
# Save current config
|
||||
with Path.open(config_file, "w") as f:
|
||||
json.dump({"provider": current_provider, "embed_model": current_embed_model}, f)
|
||||
json.dump({"provider": rag_embed_provider, "embed_model": rag_embed_model}, f)
|
||||
|
||||
chroma_collection = chroma_client.get_or_create_collection("documents") # pyright: ignore
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
|
||||
# Initialize embedding model based on provider
|
||||
llm_provider = current_provider
|
||||
base_url = os.getenv(llm_provider.upper() + "_API_BASE", "")
|
||||
rag_embed_model = current_embed_model
|
||||
rag_llm_model = current_llm_model
|
||||
try:
|
||||
embed_extra = json.loads(rag_embed_extra) if rag_embed_extra is not None else {}
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Failed to decode RAG_EMBED_EXTRA, defaulting to empty dict.")
|
||||
embed_extra = {}
|
||||
|
||||
try:
|
||||
llm_extra = json.loads(rag_llm_extra) if rag_llm_extra is not None else {}
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Failed to decode RAG_LLM_EXTRA, defaulting to empty dict.")
|
||||
llm_extra = {}
|
||||
|
||||
# Initialize embedding model and LLM based on provider using the factory
|
||||
try:
|
||||
embed_model = initialize_embed_model(
|
||||
embed_provider=rag_embed_provider,
|
||||
embed_model=rag_embed_model,
|
||||
embed_endpoint=rag_embed_endpoint,
|
||||
embed_api_key=rag_embed_api_key,
|
||||
embed_extra=embed_extra,
|
||||
)
|
||||
logger.info("Embedding model initialized successfully.")
|
||||
except (ValueError, RuntimeError) as e:
|
||||
error_msg = f"Failed to initialize embedding model: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
try:
|
||||
llm_model = initialize_llm_model(
|
||||
llm_provider=rag_llm_provider,
|
||||
llm_model=rag_llm_model,
|
||||
llm_endpoint=rag_llm_endpoint,
|
||||
llm_api_key=rag_llm_api_key,
|
||||
llm_extra=llm_extra,
|
||||
)
|
||||
logger.info("LLM model initialized successfully.")
|
||||
except (ValueError, RuntimeError) as e:
|
||||
error_msg = f"Failed to initialize LLM model: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
if llm_provider == "ollama":
|
||||
if base_url == "":
|
||||
base_url = "http://localhost:11434"
|
||||
if rag_embed_model == "":
|
||||
rag_embed_model = "nomic-embed-text"
|
||||
if rag_llm_model == "":
|
||||
rag_llm_model = "llama3"
|
||||
embed_model = OllamaEmbedding(model_name=rag_embed_model, base_url=base_url)
|
||||
llm_model = Ollama(model=rag_llm_model, base_url=base_url, request_timeout=60.0)
|
||||
else:
|
||||
if base_url == "":
|
||||
base_url = "https://api.openai.com/v1"
|
||||
if rag_embed_model == "":
|
||||
rag_embed_model = "text-embedding-3-small"
|
||||
if rag_llm_model == "":
|
||||
rag_llm_model = "gpt-4o-mini"
|
||||
embed_model = OpenAIEmbedding(model=rag_embed_model, api_base=base_url)
|
||||
llm_model = OpenAI(model=rag_llm_model, api_base=base_url)
|
||||
|
||||
Settings.embed_model = embed_model
|
||||
Settings.llm = llm_model
|
||||
@@ -400,7 +415,10 @@ class SourceDocument(BaseModel):
|
||||
class RetrieveRequest(BaseModel):
|
||||
"""Request model for information retrieval."""
|
||||
|
||||
query: str = Field(..., description="The query text to search for in the indexed documents")
|
||||
query: str = Field(
|
||||
...,
|
||||
description="The query text to search for in the indexed documents",
|
||||
)
|
||||
base_uri: str = Field(..., description="The base URI to search in")
|
||||
top_k: int | None = Field(5, description="Number of top results to return", ge=1, le=20)
|
||||
|
||||
@@ -478,7 +496,10 @@ def process_document_batch(documents: list[Document]) -> bool: # noqa: PLR0915,
|
||||
# Check if document with same hash has already been successfully processed
|
||||
status_records = indexing_history_service.get_indexing_status(doc=doc)
|
||||
if status_records and status_records[0].status == "completed":
|
||||
logger.debug("Document with same hash already processed, skipping: %s", doc.doc_id)
|
||||
logger.debug(
|
||||
"Document with same hash already processed, skipping: %s",
|
||||
doc.doc_id,
|
||||
)
|
||||
continue
|
||||
|
||||
logger.debug("Processing document: %s", doc.doc_id)
|
||||
@@ -592,7 +613,10 @@ def get_gitcrypt_files(directory: Path) -> list[str]:
|
||||
)
|
||||
|
||||
if git_root_cmd.returncode != 0:
|
||||
logger.warning("Not a git repository or git command failed: %s", git_root_cmd.stderr.strip())
|
||||
logger.warning(
|
||||
"Not a git repository or git command failed: %s",
|
||||
git_root_cmd.stderr.strip(),
|
||||
)
|
||||
return git_crypt_patterns
|
||||
|
||||
git_root = Path(git_root_cmd.stdout.strip())
|
||||
@@ -613,7 +637,15 @@ def get_gitcrypt_files(directory: Path) -> list[str]:
|
||||
|
||||
# Use Python to process the output instead of xargs, grep, and cut
|
||||
git_check_attr = subprocess.run(
|
||||
[git_executable, "-C", str(git_root), "check-attr", "filter", "--stdin", "-z"],
|
||||
[
|
||||
git_executable,
|
||||
"-C",
|
||||
str(git_root),
|
||||
"check-attr",
|
||||
"filter",
|
||||
"--stdin",
|
||||
"-z",
|
||||
],
|
||||
input=git_ls_files.stdout,
|
||||
capture_output=True,
|
||||
text=False,
|
||||
@@ -830,7 +862,11 @@ def split_documents(documents: list[Document]) -> list[Document]:
|
||||
t = doc.get_content()
|
||||
texts = code_splitter.split_text(t)
|
||||
except ValueError as e:
|
||||
logger.error("Error splitting document: %s, so skipping split, error: %s", doc.doc_id, str(e))
|
||||
logger.error(
|
||||
"Error splitting document: %s, so skipping split, error: %s",
|
||||
doc.doc_id,
|
||||
str(e),
|
||||
)
|
||||
processed_documents.append(doc)
|
||||
continue
|
||||
|
||||
@@ -1034,7 +1070,10 @@ async def add_resource(request: ResourceRequest, background_tasks: BackgroundTas
|
||||
|
||||
if resource:
|
||||
if resource.name != request.name:
|
||||
raise HTTPException(status_code=400, detail=f"Resource name cannot be changed: {resource.name}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Resource name cannot be changed: {resource.name}",
|
||||
)
|
||||
|
||||
resource_service.update_resource_status(resource.uri, "active")
|
||||
else:
|
||||
@@ -1352,3 +1391,9 @@ async def list_resources() -> ResourceListResponse:
|
||||
total_count=len(resources),
|
||||
status_summary=status_counts,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
async def health_check() -> dict[str, str]:
|
||||
"""Health check endpoint."""
|
||||
return {"status": "ok"}
|
||||
|
||||
Reference in New Issue
Block a user