Initial commit: LLM-DS optimizer framework with data files excluded
This commit is contained in:
213
llmds/retrieval_pipeline.py
Normal file
213
llmds/retrieval_pipeline.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Retrieval pipeline combining ANN, lexical search, and fusion."""
|
||||
|
||||
import ast
import hashlib
import heapq
from typing import Any, Optional

import numpy as np

from llmds.cmsketch import CountMinSketch
from llmds.hnsw import HNSW
from llmds.indexed_heap import IndexedHeap
from llmds.inverted_index import InvertedIndex
from llmds.token_lru import TokenLRU
from llmds.tokenizer import Tokenizer
|
||||
|
||||
|
||||
class RetrievalPipeline:
    """
    End-to-end retrieval pipeline combining ANN, lexical search, and fusion.

    Combines HNSW for dense embeddings, an inverted index for BM25, and
    weighted score fusion followed by top-K selection. Rankings are memoised
    in a token-budgeted LRU cache and every query is counted in a Count-Min
    sketch.
    """

    # Field separator for cache keys: a control character that is unlikely to
    # appear in query text, so distinct parameter sets map to distinct keys.
    _KEY_SEP = "\x1f"

    def __init__(
        self,
        embedding_dim: int = 384,
        hnsw_M: int = 16,
        hnsw_ef_construction: int = 200,
        hnsw_ef_search: int = 50,
        token_budget: int = 100000,
        tokenizer: Optional["Tokenizer"] = None,
        seed: Optional[int] = None,
    ):
        """
        Initialize retrieval pipeline.

        Args:
            embedding_dim: Dimension of embedding vectors
            hnsw_M: HNSW M parameter
            hnsw_ef_construction: HNSW efConstruction parameter
            hnsw_ef_search: HNSW efSearch parameter
            token_budget: Token budget for cache
            tokenizer: Tokenizer instance (a default Tokenizer is created
                when omitted)
            seed: Optional random seed forwarded to HNSW (default: None).
                The random fallback embeddings created by ``add_document``
                use NumPy's global RNG and are not affected by it.
        """
        self.tokenizer = tokenizer if tokenizer is not None else Tokenizer()
        self.hnsw = HNSW(
            dim=embedding_dim,
            M=hnsw_M,
            ef_construction=hnsw_ef_construction,
            ef_search=hnsw_ef_search,
            seed=seed,
        )
        self.inverted_index = InvertedIndex(tokenizer=self.tokenizer)
        self.cmsketch = CountMinSketch(width=2048, depth=4)
        self.token_cache: TokenLRU[str, str] = TokenLRU[str, str](
            token_budget=token_budget,
            # Looked up lazily so a tokenizer swapped in later is honoured.
            token_of=lambda text: self.tokenizer.count_tokens(text),
        )
        # Bumped on every add_document() and folded into the cache key, so a
        # ranking computed against an older corpus is never served again.
        self._index_version = 0

    def add_document(
        self,
        doc_id: int,
        text: str,
        embedding: Optional[np.ndarray] = None,
    ) -> None:
        """
        Add a document to both indices.

        Args:
            doc_id: Document identifier
            text: Document text
            embedding: Optional embedding vector of shape ``(dim,)``; when
                None a random unit vector is generated (intended for testing)

        Raises:
            ValueError: If ``embedding`` does not have shape ``(dim,)``. The
                check runs before either index is modified, so a rejected
                document leaves the pipeline unchanged.
        """
        if embedding is None:
            # Random unit vector drawn from NumPy's global RNG.
            vector = np.random.randn(self.hnsw.dim).astype(np.float32)
            vector = vector / np.linalg.norm(vector)
        else:
            if embedding.shape != (self.hnsw.dim,):
                raise ValueError(
                    f"Embedding dimension mismatch: expected shape "
                    f"({self.hnsw.dim},), got {tuple(embedding.shape)}"
                )
            vector = embedding

        self.inverted_index.add_document(doc_id, text)
        self.hnsw.add(vector, doc_id)
        self._index_version += 1

    def search(
        self,
        query: str,
        query_embedding: Optional[np.ndarray] = None,
        top_k: int = 10,
        fusion_weight: float = 0.5,
    ) -> list[tuple[int, float]]:
        """
        Search with hybrid retrieval and score fusion.

        Both result lists are min-max normalised to [0, 1] (dense distances
        are inverted so that closer means higher) and combined as
        ``fusion_weight * dense + (1 - fusion_weight) * bm25``. A document
        missing from one list contributes 0.0 for that component.

        Args:
            query: Query text
            query_embedding: Optional query embedding vector; dense search is
                skipped when None
            top_k: Number of results to return
            fusion_weight: Weight for dense search (1-fusion_weight for BM25)

        Returns:
            List of (doc_id, fused_score) tuples, highest score first; ties
            are ordered by ascending doc_id.
        """
        cache_key = self._cache_key(query, query_embedding, top_k, fusion_weight)

        cached = self.token_cache.get(cache_key)
        if cached:
            hit = self._decode_cached(cached)
            if hit is not None:
                self.cmsketch.add(query)
                return hit
            # Unparseable entry: fall through and recompute (the query is
            # counted once, at the end).

        # Over-fetch from each retriever so fusion has candidates to reorder.
        fetch = top_k * 2
        bm25_hits = self.inverted_index.search(query, top_k=fetch)
        dense_hits = []
        if query_embedding is not None:
            # HNSW.search returns (node_id, distance) pairs.
            dense_hits = self.hnsw.search(query_embedding, k=fetch)

        bm25_scores = self._min_max(dict(bm25_hits), invert=False)
        dense_scores = self._min_max(dict(dense_hits), invert=True)

        fused: dict[int, float] = {
            doc_id: float(
                fusion_weight * dense_scores.get(doc_id, 0.0)
                + (1 - fusion_weight) * bm25_scores.get(doc_id, 0.0)
            )
            for doc_id in bm25_scores.keys() | dense_scores.keys()
        }

        # Smallest (-score, doc_id) == largest score, ties by ascending id.
        # A non-positive top_k yields an empty list.
        ranked = heapq.nsmallest(
            top_k, fused.items(), key=lambda item: (-item[1], item[0])
        )
        # Plain Python numbers keep str(results) parseable by literal_eval.
        results = [(self._plain_id(doc_id), score) for doc_id, score in ranked]

        # Cache results (store as string representation for token counting)
        self.token_cache.put(cache_key, str(results))
        self.cmsketch.add(query)

        return results

    def stats(self) -> dict[str, Any]:
        """
        Get pipeline statistics.

        Returns:
            Dictionary with the HNSW and inverted-index statistics, the
            Count-Min sketch total count, and the cache size and token usage
        """
        return {
            "hnsw": self.hnsw.stats(),
            "inverted_index": self.inverted_index.stats(),
            "cmsketch_total_count": self.cmsketch.get_total_count(),
            "cache_size": self.token_cache.size(),
            "cache_tokens": self.token_cache.total_tokens(),
        }

    def _cache_key(
        self,
        query: str,
        query_embedding: Optional[np.ndarray],
        top_k: int,
        fusion_weight: float,
    ) -> str:
        """Build a cache key covering every input that affects the ranking."""
        if query_embedding is None:
            digest = "-"
        else:
            raw = np.ascontiguousarray(query_embedding).tobytes()
            digest = hashlib.blake2b(raw, digest_size=16).hexdigest()
        return self._KEY_SEP.join(
            (
                query,
                str(top_k),
                repr(float(fusion_weight)),
                digest,
                str(self._index_version),
            )
        )

    @staticmethod
    def _decode_cached(payload: str) -> Optional[list[tuple[int, float]]]:
        """Parse a cached ranking; return None if it is not a valid list."""
        try:
            parsed = ast.literal_eval(payload)
        except (ValueError, SyntaxError):
            return None
        return parsed if isinstance(parsed, list) else None

    @staticmethod
    def _min_max(raw: dict[int, float], *, invert: bool) -> dict[int, float]:
        """
        Min-max normalise scores to [0, 1].

        With ``invert=True`` the smallest input maps to 1.0 (for distances).
        When all inputs are equal (including a single hit) every document
        gets 1.0, so a lone match is not scored as zero.
        """
        if not raw:
            return {}
        low = min(raw.values())
        span = max(raw.values()) - low
        if not span > 0:  # also true when span is NaN
            return {doc_id: 1.0 for doc_id in raw}
        if invert:
            return {doc_id: 1.0 - (v - low) / span for doc_id, v in raw.items()}
        return {doc_id: (v - low) / span for doc_id, v in raw.items()}

    @staticmethod
    def _plain_id(doc_id: Any) -> Any:
        """Convert a NumPy integer id to a built-in int; pass others through."""
        return int(doc_id) if isinstance(doc_id, np.integer) else doc_id
|
||||
|
||||
Reference in New Issue
Block a user