"""Retrieval pipeline combining ANN, lexical search, and fusion."""
|
|
|
|
from typing import Any, Optional
|
|
|
|
import numpy as np
|
|
|
|
from llmds.cmsketch import CountMinSketch
|
|
from llmds.hnsw import HNSW
|
|
from llmds.indexed_heap import IndexedHeap
|
|
from llmds.inverted_index import InvertedIndex
|
|
from llmds.token_lru import TokenLRU
|
|
from llmds.tokenizer import Tokenizer
|
|
|
|
|
|
class RetrievalPipeline:
    """
    End-to-end retrieval pipeline combining ANN, lexical search, and fusion.

    Combines HNSW for dense embeddings, inverted index for BM25,
    and score fusion with top-K maintenance using indexed heap.
    """

    def __init__(
        self,
        embedding_dim: int = 384,
        hnsw_M: int = 16,
        hnsw_ef_construction: int = 200,
        hnsw_ef_search: int = 50,
        token_budget: int = 100000,
        tokenizer: Optional[Tokenizer] = None,
        seed: Optional[int] = None,
    ):
        """
        Initialize retrieval pipeline.

        Args:
            embedding_dim: Dimension of embedding vectors
            hnsw_M: HNSW M parameter
            hnsw_ef_construction: HNSW efConstruction parameter
            hnsw_ef_search: HNSW efSearch parameter
            token_budget: Token budget for cache
            tokenizer: Tokenizer instance
            seed: Optional random seed for HNSW reproducibility (default: None)
        """
        self.tokenizer = tokenizer or Tokenizer()
        self.hnsw = HNSW(
            dim=embedding_dim,
            M=hnsw_M,
            ef_construction=hnsw_ef_construction,
            ef_search=hnsw_ef_search,
            seed=seed,
        )
        self.inverted_index = InvertedIndex(tokenizer=self.tokenizer)
        self.cmsketch = CountMinSketch(width=2048, depth=4)
        self.token_cache: TokenLRU[str, str] = TokenLRU(
            token_budget=token_budget,
            token_of=lambda text: self.tokenizer.count_tokens(text),
        )

    def add_document(
        self,
        doc_id: int,
        text: str,
        embedding: Optional[np.ndarray] = None,
    ) -> None:
        """
        Add a document to both indices.

        Args:
            doc_id: Document identifier
            text: Document text
            embedding: Optional embedding vector (if None, generates random)
        """
        # Add to inverted index
        self.inverted_index.add_document(doc_id, text)

        # Add to HNSW if embedding provided
        if embedding is not None:
            if embedding.shape != (self.hnsw.dim,):
                raise ValueError(
                    f"Embedding dimension mismatch: expected ({self.hnsw.dim},), "
                    f"got {embedding.shape}"
                )
            self.hnsw.add(embedding, doc_id)
        else:
            # Generate a random unit-norm embedding for testing
            # (drawn from NumPy's global RNG, independent of `seed`)
            random_embedding = np.random.randn(self.hnsw.dim).astype(np.float32)
            random_embedding = random_embedding / np.linalg.norm(random_embedding)
            self.hnsw.add(random_embedding, doc_id)

    def search(
        self,
        query: str,
        query_embedding: Optional[np.ndarray] = None,
        top_k: int = 10,
        fusion_weight: float = 0.5,
    ) -> list[tuple[int, float]]:
        """
        Search with hybrid retrieval and score fusion.

        Args:
            query: Query text
            query_embedding: Optional query embedding vector
            top_k: Number of results to return
            fusion_weight: Weight applied to the dense score; BM25 gets (1 - fusion_weight)

        Returns:
            List of (doc_id, fused_score) tuples, highest score first
        """
        # Check cache
        cached = self.token_cache.get(query)
        if cached is not None:
            self.cmsketch.add(query)
            # Parse the cached string back into a list of (doc_id, score) tuples
            try:
                parsed_results = ast.literal_eval(cached)
                if isinstance(parsed_results, list):
                    return parsed_results
            except (ValueError, SyntaxError):
                pass  # Fall through and recompute the results

        # BM25 search
        bm25_results = self.inverted_index.search(query, top_k=top_k * 2)

        # Dense search (if embedding provided)
        dense_results = []
        if query_embedding is not None:
            dense_results = self.hnsw.search(query_embedding, k=top_k * 2)

        # Normalize scores
        bm25_scores: dict[int, float] = {doc_id: score for doc_id, score in bm25_results}
        dense_scores: dict[int, float] = {}

        if dense_results:
            # HNSW.search returns (node_id, distance); min-max normalize and
            # invert so the closest document gets a similarity of 1.0
            max_dense = max(dist for _, dist in dense_results)
            min_dense = min(dist for _, dist in dense_results)
            dense_range = max_dense - min_dense if max_dense > min_dense else 1.0

            for doc_id, dist in dense_results:
                dense_scores[doc_id] = 1.0 - (dist - min_dense) / dense_range

        # Normalize BM25 scores
        if bm25_scores:
            max_bm25 = max(bm25_scores.values())
            min_bm25 = min(bm25_scores.values())
            bm25_range = max_bm25 - min_bm25 if max_bm25 > min_bm25 else 1.0

            for doc_id in bm25_scores:
                bm25_scores[doc_id] = (bm25_scores[doc_id] - min_bm25) / bm25_range

        # Fuse scores
        fused_scores: dict[int, float] = {}
        all_doc_ids = set(bm25_scores.keys()) | set(dense_scores.keys())

        for doc_id in all_doc_ids:
            bm25_score = bm25_scores.get(doc_id, 0.0)
            dense_score = dense_scores.get(doc_id, 0.0)

            # Weighted fusion
            fused_score = fusion_weight * dense_score + (1 - fusion_weight) * bm25_score
            fused_scores[doc_id] = fused_score

        # Top-K maintenance using the indexed heap: a min-heap keeps the weakest
        # of the current top-K at the root, so it can be evicted whenever a
        # better-scoring document arrives
        heap = IndexedHeap(max_heap=False)
        for doc_id, score in fused_scores.items():
            if heap.size() < top_k:
                heap.push(doc_id, score)
            else:
                peek_result = heap.peek()
                if peek_result is not None:
                    min_score, _ = peek_result
                    if score > min_score:
                        heap.pop()
                        heap.push(doc_id, score)

        # Extract results (the min-heap pops lowest scores first)
        results = []
        while not heap.is_empty():
            score, doc_id = heap.pop()
            results.append((doc_id, score))

        results.reverse()  # Highest score first

        # Cache results (stored as a string so the cache can count its tokens)
        results_str = str(results)
        self.token_cache.put(query, results_str)
        self.cmsketch.add(query)

        return results

    def stats(self) -> dict[str, Any]:
        """
        Get pipeline statistics.

        Returns:
            Dictionary with pipeline statistics
        """
        hnsw_stats = self.hnsw.stats()
        index_stats = self.inverted_index.stats()

        return {
            "hnsw": hnsw_stats,
            "inverted_index": index_stats,
            "cmsketch_total_count": self.cmsketch.get_total_count(),
            "cache_size": self.token_cache.size(),
            "cache_tokens": self.token_cache.total_tokens(),
        }
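

# Illustrative usage sketch (not part of the original module). It assumes the
# llmds components behave as they are used above: Tokenizer() takes no
# arguments and HNSW accepts float32 vectors of the configured dimension.
# Random unit vectors stand in for real encoder embeddings.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    dim = 384
    pipeline = RetrievalPipeline(embedding_dim=dim, seed=0)

    docs = {
        0: "the quick brown fox jumps over the lazy dog",
        1: "hybrid retrieval combines dense and lexical search",
        2: "HNSW is an approximate nearest neighbor index",
    }
    for doc_id, text in docs.items():
        vec = rng.standard_normal(dim).astype(np.float32)
        vec /= np.linalg.norm(vec)
        pipeline.add_document(doc_id, text, embedding=vec)

    query_vec = rng.standard_normal(dim).astype(np.float32)
    query_vec /= np.linalg.norm(query_vec)
    for doc_id, score in pipeline.search(
        "dense and lexical search", query_embedding=query_vec, top_k=2
    ):
        print(f"doc {doc_id}: fused score {score:.3f}")
    print(pipeline.stats())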