# llm-rag-ds-optimizer/llmds/retrieval_pipeline.py
"""Retrieval pipeline combining ANN, lexical search, and fusion."""
import ast
from typing import Any, Optional

import numpy as np

from llmds.cmsketch import CountMinSketch
from llmds.hnsw import HNSW
from llmds.indexed_heap import IndexedHeap
from llmds.inverted_index import InvertedIndex
from llmds.token_lru import TokenLRU
from llmds.tokenizer import Tokenizer
class RetrievalPipeline:
    """
    End-to-end retrieval pipeline combining ANN, lexical search, and fusion.

    Combines HNSW for dense embeddings, an inverted index for BM25 lexical
    scoring, and weighted score fusion with top-K maintenance using a
    bounded indexed min-heap. Query results are cached in a token-budgeted
    LRU, and query frequency is tracked with a count-min sketch.
    """

    def __init__(
        self,
        embedding_dim: int = 384,
        hnsw_M: int = 16,
        hnsw_ef_construction: int = 200,
        hnsw_ef_search: int = 50,
        token_budget: int = 100000,
        tokenizer: Optional[Tokenizer] = None,
        seed: Optional[int] = None,
    ):
        """
        Initialize retrieval pipeline.

        Args:
            embedding_dim: Dimension of embedding vectors
            hnsw_M: HNSW M parameter
            hnsw_ef_construction: HNSW efConstruction parameter
            hnsw_ef_search: HNSW efSearch parameter
            token_budget: Token budget for the query-result cache
            tokenizer: Tokenizer instance (a default Tokenizer is created
                when None)
            seed: Optional random seed for HNSW reproducibility (default: None)
        """
        self.tokenizer = tokenizer or Tokenizer()
        self.hnsw = HNSW(
            dim=embedding_dim,
            M=hnsw_M,
            ef_construction=hnsw_ef_construction,
            ef_search=hnsw_ef_search,
            seed=seed,
        )
        self.inverted_index = InvertedIndex(tokenizer=self.tokenizer)
        self.cmsketch = CountMinSketch(width=2048, depth=4)
        # Maps query text -> str(results). Entries are charged by tokenizer
        # token count so the cache honors token_budget.
        self.token_cache: TokenLRU[str, str] = TokenLRU[str, str](
            token_budget=token_budget,
            token_of=lambda text: self.tokenizer.count_tokens(text),
        )

    def add_document(
        self,
        doc_id: int,
        text: str,
        embedding: Optional[np.ndarray] = None,
    ) -> None:
        """
        Add a document to both indices.

        Args:
            doc_id: Document identifier
            text: Document text
            embedding: Optional embedding vector of shape (embedding_dim,).
                When None, a random unit vector is generated (testing aid).

        Raises:
            ValueError: If the provided embedding does not have shape
                (embedding_dim,).
        """
        # Lexical side: always indexed.
        self.inverted_index.add_document(doc_id, text)
        if embedding is not None:
            if embedding.shape != (self.hnsw.dim,):
                # Report the full shape: indexing shape[0] would be
                # misleading (or fail) for 0-d / multi-dim arrays.
                raise ValueError(
                    f"Embedding dimension mismatch: expected {self.hnsw.dim}, "
                    f"got {embedding.shape}"
                )
            self.hnsw.add(embedding, doc_id)
        else:
            # Random unit-norm embedding for testing; NOTE(review): draws
            # from NumPy's global RNG, so this is not covered by `seed`.
            random_embedding = np.random.randn(self.hnsw.dim).astype(np.float32)
            random_embedding = random_embedding / np.linalg.norm(random_embedding)
            self.hnsw.add(random_embedding, doc_id)

    def search(
        self,
        query: str,
        query_embedding: Optional[np.ndarray] = None,
        top_k: int = 10,
        fusion_weight: float = 0.5,
    ) -> list[tuple[int, float]]:
        """
        Search with hybrid retrieval and score fusion.

        Args:
            query: Query text
            query_embedding: Optional query embedding vector; when None,
                only BM25 scores contribute.
            top_k: Number of results to return
            fusion_weight: Weight for dense scores (BM25 gets 1 - fusion_weight)

        Returns:
            List of (doc_id, fused_score) tuples, highest score first
        """
        # Serve from cache when possible. Results are stored as their
        # str() representation so the LRU can count tokens; parse back.
        cached = self.token_cache.get(query)
        if cached is not None:
            self.cmsketch.add(query)
            try:
                parsed_results = ast.literal_eval(cached)
                if isinstance(parsed_results, list):
                    return parsed_results
            except (ValueError, SyntaxError):
                pass  # Corrupt cache entry: fall through and recompute.

        # Over-fetch 2*top_k candidates per branch so fusion has slack
        # before the final top-K cut.
        bm25_results = self.inverted_index.search(query, top_k=top_k * 2)
        dense_results = []
        if query_embedding is not None:
            dense_results = self.hnsw.search(query_embedding, k=top_k * 2)

        # Min-max normalize dense distances into [0, 1] similarities
        # (smaller distance -> larger score).
        dense_scores: dict[int, float] = {}
        if dense_results:
            distances = [dist for _, dist in dense_results]
            min_dense = min(distances)
            max_dense = max(distances)
            # Degenerate range (all distances equal) collapses to 1.0 so
            # the division below is always well-defined.
            dense_range = max_dense - min_dense if max_dense > min_dense else 1.0
            for doc_id, dist in dense_results:  # (node_id, distance) pairs
                dense_scores[doc_id] = 1.0 - (dist - min_dense) / dense_range

        # Min-max normalize BM25 scores into [0, 1].
        bm25_scores: dict[int, float] = dict(bm25_results)
        if bm25_scores:
            max_bm25 = max(bm25_scores.values())
            min_bm25 = min(bm25_scores.values())
            bm25_range = max_bm25 - min_bm25 if max_bm25 > min_bm25 else 1.0
            for doc_id in bm25_scores:
                bm25_scores[doc_id] = (bm25_scores[doc_id] - min_bm25) / bm25_range

        # Weighted linear fusion over the union of candidates; a doc
        # missing from one branch contributes 0 for that branch.
        fused_scores: dict[int, float] = {}
        for doc_id in set(bm25_scores) | set(dense_scores):
            fused_scores[doc_id] = (
                fusion_weight * dense_scores.get(doc_id, 0.0)
                + (1 - fusion_weight) * bm25_scores.get(doc_id, 0.0)
            )

        # Top-K maintenance with a size-bounded MIN-heap: the root is the
        # worst retained score, so a better candidate evicts it. (A
        # max-heap here would evict the *best* candidate instead.)
        heap = IndexedHeap(max_heap=False)
        for doc_id, score in fused_scores.items():
            if heap.size() < top_k:
                heap.push(doc_id, score)
            else:
                peek_result = heap.peek()
                if peek_result is not None:
                    worst_score, _ = peek_result
                    if worst_score is not None and score > worst_score:
                        heap.pop()
                        heap.push(doc_id, score)

        # Min-heap pops in ascending score order; reverse for
        # highest-score-first output.
        results = []
        while not heap.is_empty():
            score, doc_id = heap.pop()
            results.append((doc_id, score))
        results.reverse()

        # Cache the stringified results and record the query frequency.
        self.token_cache.put(query, str(results))
        self.cmsketch.add(query)
        return results

    def stats(self) -> dict[str, Any]:
        """
        Get pipeline statistics.

        Returns:
            Dictionary with HNSW stats, inverted-index stats, the sketch's
            total observed query count, and cache size/token usage.
        """
        return {
            "hnsw": self.hnsw.stats(),
            "inverted_index": self.inverted_index.stats(),
            "cmsketch_total_count": self.cmsketch.get_total_count(),
            "cache_size": self.token_cache.size(),
            "cache_tokens": self.token_cache.total_tokens(),
        }