From 968d5fbd520a6c62c4265acd8ea94a9bcb9be024 Mon Sep 17 00:00:00 2001 From: Omar Crespo Date: Thu, 20 Mar 2025 00:28:04 -0500 Subject: [PATCH] fix: RAG services improvements (#1565) * fix: rag nix runner * feat: improve rag default models * fix: change log levels to avoid huge log files in rag service --- lua/avante/rag_service.lua | 6 ++++-- py/rag-service/run.sh | 6 +++--- py/rag-service/shell.nix | 1 + py/rag-service/src/main.py | 36 ++++++++++++++++++------------------ 4 files changed, 26 insertions(+), 23 deletions(-) diff --git a/lua/avante/rag_service.lua b/lua/avante/rag_service.lua index ad24c11..5cff622 100644 --- a/lua/avante/rag_service.lua +++ b/lua/avante/rag_service.lua @@ -197,12 +197,14 @@ end function M.to_local_uri(uri) local scheme = M.get_scheme(uri) - if scheme == "file" then - local path = uri:match("^file:///host(.*)$") + local path = uri:match("^file:///host(.*)$") + + if scheme == "file" and path ~= nil then local host_dir = Config.rag_service.host_mount local full_path = Path:new(host_dir):joinpath(path:sub(2)):absolute() uri = string.format("file://%s", full_path) end + return uri end diff --git a/py/rag-service/run.sh b/py/rag-service/run.sh index 0848e7e..4c66cfb 100755 --- a/py/rag-service/run.sh +++ b/py/rag-service/run.sh @@ -1,9 +1,9 @@ -#!/bin/bash +#!/usr/bin/env bash -# Set the target directory (use the first argument or default to a temporary directory) +# Set the target directory (use the first argument or default to a local state directory) TARGET_DIR=$1 if [ -z "$TARGET_DIR" ]; then - TARGET_DIR="/tmp/avante-rag-service" + TARGET_DIR="$HOME/.local/state/avante-rag-service" fi # Create the target directory if it doesn't exist mkdir -p "$TARGET_DIR" diff --git a/py/rag-service/shell.nix b/py/rag-service/shell.nix index d193330..d791338 100755 --- a/py/rag-service/shell.nix +++ b/py/rag-service/shell.nix @@ -12,6 +12,7 @@ in pkgs.mkShell { PYTHONUNBUFFERED = 1; PYTHONDONTWRITEBYTECODE = 1; LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib:$LD_LIBRARY_PATH"; + PORT = 20250; }; shellHook = '' diff --git a/py/rag-service/src/main.py b/py/rag-service/src/main.py index 3a394d3..ac38b9e 100644 --- a/py/rag-service/src/main.py +++ b/py/rag-service/src/main.py @@ -47,7 +47,7 @@ from llama_index.core import ( from llama_index.core.node_parser import CodeSplitter from llama_index.core.schema import Document from llama_index.embeddings.ollama import OllamaEmbedding -from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingModelType +from llama_index.embeddings.openai import OpenAIEmbedding from llama_index.llms.ollama import Ollama from llama_index.llms.openai import OpenAI from llama_index.vector_stores.chroma import ChromaVectorStore @@ -133,7 +133,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 # Start indexing await index_remote_resource_async(resource) - logger.info("Successfully synced resource: %s", resource.uri) + logger.debug("Successfully synced resource: %s", resource.uri) except (OSError, ValueError, RuntimeError) as e: error_msg = f"Failed to sync resource {resource.uri}: {e}" @@ -359,9 +359,9 @@ else: if base_url == "": base_url = "https://api.openai.com/v1" if rag_embed_model == "": - rag_embed_model = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002 + rag_embed_model = "text-embedding-3-small" if rag_llm_model == "": - rag_llm_model = "gpt-3.5-turbo" + rag_llm_model = "gpt-4o-mini" embed_model = OpenAIEmbedding(model=rag_embed_model, api_base=base_url) llm_model = OpenAI(model=rag_llm_model, api_base=base_url) @@ -477,10 +477,10 @@ def process_document_batch(documents: list[Document]) -> bool: # noqa: PLR0915, # Check if document with same hash has already been successfully processed status_records = indexing_history_service.get_indexing_status(doc=doc) if status_records and status_records[0].status == "completed": - logger.info("Document with same hash already processed, skipping: %s", doc.doc_id) + logger.debug("Document with same hash already processed, skipping: %s", doc.doc_id) continue - logger.info("Processing document: %s", doc.doc_id) + logger.debug("Processing document: %s", doc.doc_id) try: content = doc.get_content() @@ -636,7 +636,7 @@ def get_gitcrypt_files(directory: Path) -> list[str]: # Log if git-crypt patterns were found if git_crypt_patterns: - logger.info("Excluding git-crypt encrypted files: %s", git_crypt_patterns) + logger.debug("Excluding git-crypt encrypted files: %s", git_crypt_patterns) except (subprocess.SubprocessError, OSError) as e: logger.warning("Error getting git-crypt files: %s", str(e)) @@ -752,11 +752,11 @@ def scan_directory(directory: Path) -> list[str]: for file in file_paths: file_ext = Path(file).suffix.lower() if file_ext in binary_extensions: - logger.info("Skipping binary file: %s", file) + logger.debug("Skipping binary file: %s", file) continue if spec and spec.match_file(os.path.relpath(file, directory)): - logger.info("Ignoring file: %s", file) + logger.debug("Ignoring file: %s", file) else: matched_files.append(file) @@ -765,13 +765,13 @@ def scan_directory(directory: Path) -> list[str]: def update_index_for_file(directory: Path, abs_file_path: Path) -> None: """Update the index for a single file.""" - logger.info("Starting to index file: %s", abs_file_path) + logger.debug("Starting to index file: %s", abs_file_path) rel_file_path = abs_file_path.relative_to(directory) spec = get_pathspec(directory) if spec and spec.match_file(rel_file_path): - logger.info("File is ignored, skipping: %s", abs_file_path) + logger.debug("File is ignored, skipping: %s", abs_file_path) return resource = resource_service.get_resource(path_to_uri(directory)) @@ -787,13 +787,13 @@ def update_index_for_file(directory: Path, abs_file_path: Path) -> None: required_exts=required_exts, ).load_data() - logger.info("Updating index: %s", abs_file_path) + logger.debug("Updating index: %s", abs_file_path) processed_documents = split_documents(documents) success = process_document_batch(processed_documents) if success: resource_service.update_resource_indexing_status(resource.uri, "indexed", "") - logger.info("File indexing completed: %s", abs_file_path) + logger.debug("File indexing completed: %s", abs_file_path) else: resource_service.update_resource_indexing_status(resource.uri, "failed", "unknown error") logger.error("File indexing failed: %s", abs_file_path) @@ -858,7 +858,7 @@ async def index_remote_resource_async(resource: Resource) -> None: resource_service.update_resource_indexing_status(resource.uri, "indexing", "") url = resource.uri try: - logger.info("Loading resource content: %s", url) + logger.debug("Loading resource content: %s", url) # Fetch markdown content markdown = fetch_markdown(url) @@ -868,7 +868,7 @@ async def index_remote_resource_async(resource: Resource) -> None: # Extract links from markdown links = markdown_to_links(url, markdown) - logger.info("Found %d sub links", len(links)) + logger.debug("Found %d sub links", len(links)) logger.debug("Link list: %s", links) # Use thread pool for parallel batch processing @@ -885,13 +885,13 @@ async def index_remote_resource_async(resource: Resource) -> None: # Create documents from links documents = [Document(text=markdown, doc_id=link) for link, markdown in link_md_pairs] - logger.info("Found %d documents", len(documents)) + logger.debug("Found %d documents", len(documents)) logger.debug("Document list: %s", [doc.doc_id for doc in documents]) # Process documents in batches total_documents = len(documents) batches = [documents[i : i + BATCH_SIZE] for i in range(0, total_documents, BATCH_SIZE)] - logger.info("Splitting documents into %d batches for processing", len(batches)) + logger.debug("Splitting documents into %d batches for processing", len(batches)) with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: results = await loop.run_in_executor( @@ -901,7 +901,7 @@ async def index_remote_resource_async(resource: Resource) -> None: # Check processing results if all(results): - logger.info("Resource %s indexing completed", url) + logger.debug("Resource %s indexing completed", url) resource_service.update_resource_indexing_status(resource.uri, "indexed", "") else: failed_batches = len([r for r in results if not r])