Initial commit: LLM-DS optimizer framework with data files excluded
llmds/inverted_index.py (new file, 222 lines added)
@@ -0,0 +1,222 @@
"""Compressed inverted index with BM25 scoring.

Implementation based on:
Robertson, S., & Zaragoza, H. (2009). The probabilistic relevance framework:
BM25 and beyond. Foundations and Trends in Information Retrieval, 3(4), 333-389.

See docs/CITATIONS.md for full citation details.
"""

import math
from collections import defaultdict
from typing import Any, Optional

from llmds.tokenizer import Tokenizer


class InvertedIndex:
    """
    Compressed inverted index with varint/zigzag encoding and BM25 scoring.

    Stores postings lists with compression and provides BM25 retrieval.

    Reference:
        Robertson & Zaragoza (2009). The probabilistic relevance framework:
        BM25 and beyond.
    """

    def __init__(self, tokenizer: Optional[Tokenizer] = None):
        """
        Initialize inverted index.

        Args:
            tokenizer: Tokenizer instance (creates default if None)
        """
        self.tokenizer = tokenizer or Tokenizer()
        self._inverted_lists: dict[str, list[int]] = defaultdict(list)  # term -> doc_ids
        self._doc_lengths: dict[int, int] = {}  # doc_id -> length
        self._doc_terms: dict[int, dict[str, int]] = {}  # doc_id -> term -> count
        self._total_docs = 0
        self._avg_doc_length = 0.0
        # BM25 parameters
        self.k1 = 1.2
        self.b = 0.75

    def _encode_varint(self, value: int) -> bytes:
        """Encode integer as varint."""
        result = bytearray()
        while value >= 0x80:
            result.append((value & 0x7F) | 0x80)
            value >>= 7
        result.append(value & 0x7F)
        return bytes(result)

    def _decode_varint(self, data: bytes, offset: int) -> tuple[int, int]:
        """Decode varint from bytes."""
        value = 0
        shift = 0
        while offset < len(data):
            byte = data[offset]
            value |= (byte & 0x7F) << shift
            offset += 1
            if (byte & 0x80) == 0:
                break
            shift += 7
        return value, offset

    def _zigzag_encode(self, value: int) -> int:
        """Zigzag encode for signed integers."""
        return (value << 1) ^ (value >> 31)

    def _zigzag_decode(self, value: int) -> int:
        """Zigzag decode."""
        return (value >> 1) ^ (-(value & 1))

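    # Illustrative sketch: the varint/zigzag codecs above would typically be combined
    # with delta encoding to compress a sorted postings list. The helper names below
    # (_compress_postings / _decompress_postings) are hypothetical additions and are
    # not referenced elsewhere in this module; postings are currently stored as plain
    # Python lists.
    def _compress_postings(self, doc_ids: list[int]) -> bytes:
        """Delta-encode a sorted postings list as zigzag + varint bytes (sketch)."""
        out = bytearray()
        prev = 0
        for doc_id in doc_ids:
            out += self._encode_varint(self._zigzag_encode(doc_id - prev))
            prev = doc_id
        return bytes(out)

    def _decompress_postings(self, data: bytes) -> list[int]:
        """Invert _compress_postings: varint-decode, zigzag-decode, prefix-sum (sketch)."""
        doc_ids: list[int] = []
        offset = 0
        prev = 0
        while offset < len(data):
            delta, offset = self._decode_varint(data, offset)
            prev += self._zigzag_decode(delta)
            doc_ids.append(prev)
        return doc_ids
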
    def add_document(self, doc_id: int, text: str) -> None:
        """
        Add a document to the index.

        Args:
            doc_id: Document identifier
            text: Document text
        """
        tokens = self.tokenizer.encode(text)
        term_counts: dict[str, int] = defaultdict(int)

        # Count term frequencies
        for token_id in tokens:
            term = self.tokenizer.decode([token_id])
            if term:
                term_counts[term] += 1

        # Update inverted lists
        for term in term_counts:
            if doc_id not in self._inverted_lists[term]:
                self._inverted_lists[term].append(doc_id)

        # Store document metadata
        self._doc_lengths[doc_id] = len(tokens)
        self._doc_terms[doc_id] = term_counts

        # Update average document length
        self._total_docs += 1
        total_length = sum(self._doc_lengths.values())
        self._avg_doc_length = total_length / self._total_docs if self._total_docs > 0 else 0.0

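    # Example of the state these structures reach after one call (assuming the
    # tokenizer splits the text into the whitespace terms shown; actual terms
    # depend on the Tokenizer implementation):
    #   index.add_document(7, "lazy dog dog")
    #   _inverted_lists -> {"lazy": [7], "dog": [7]}
    #   _doc_terms      -> {7: {"lazy": 1, "dog": 2}}
    #   _doc_lengths    -> {7: 3}
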
    def _bm25_score(self, term: str, doc_id: int, query_term_freq: int) -> float:
        """
        Calculate BM25 score for a term-document pair.

        Args:
            term: Query term
            doc_id: Document ID
            query_term_freq: Frequency of term in query

        Returns:
            BM25 score
        """
        if doc_id not in self._doc_terms or term not in self._doc_terms[doc_id]:
            return 0.0

        # Term frequency in document
        tf = self._doc_terms[doc_id][term]

        # Document frequency
        df = len(self._inverted_lists.get(term, []))

        # Inverse document frequency (log form per Robertson & Zaragoza; the +1 keeps it non-negative)
        idf = 0.0
        if df > 0:
            idf = math.log((self._total_docs - df + 0.5) / (df + 0.5) + 1.0)
            idf = max(0.0, idf)  # Avoid negative IDF

        # Document length normalization
        doc_length = self._doc_lengths.get(doc_id, 1)
        length_norm = (1 - self.b) + self.b * (doc_length / self._avg_doc_length)

        # BM25 formula (with a simple query-frequency damping factor)
        score = (
            idf
            * (tf * (self.k1 + 1))
            / (tf + self.k1 * length_norm)
            * (query_term_freq / (query_term_freq + 0.5))
        )

        return score

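    # In closed form, the score computed above for term t in document d is:
    #   score(t, d) = idf(t) * tf * (k1 + 1) / (tf + k1 * ((1 - b) + b * |d| / avgdl))
    #                        * qtf / (qtf + 0.5)
    #   idf(t)      = log((N - df(t) + 0.5) / (df(t) + 0.5) + 1)
    # where tf is the in-document term frequency, qtf the in-query frequency,
    # |d| the document length, avgdl the average document length, and N the
    # number of indexed documents.
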
    def search(self, query: str, top_k: int = 10) -> list[tuple[int, float]]:
        """
        Search the index with BM25 scoring.

        Args:
            query: Query text
            top_k: Number of top results to return

        Returns:
            List of (doc_id, score) tuples sorted by score descending
        """
        query_tokens = self.tokenizer.encode(query)
        query_term_counts: dict[str, int] = defaultdict(int)

        for token_id in query_tokens:
            term = self.tokenizer.decode([token_id])
            if term:
                query_term_counts[term] += 1

        # Score all candidate documents
        doc_scores: dict[int, float] = defaultdict(float)

        for term, query_freq in query_term_counts.items():
            if term in self._inverted_lists:
                for doc_id in self._inverted_lists[term]:
                    score = self._bm25_score(term, doc_id, query_freq)
                    doc_scores[doc_id] += score

        # Sort by score and return top-k
        sorted_results = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)
        return sorted_results[:top_k]

    def get_term_frequency(self, term: str, doc_id: int) -> int:
        """
        Get term frequency in a document.

        Args:
            term: Term
            doc_id: Document ID

        Returns:
            Term frequency
        """
        if doc_id in self._doc_terms:
            return self._doc_terms[doc_id].get(term, 0)
        return 0

    def get_document_frequency(self, term: str) -> int:
        """
        Get document frequency of a term.

        Args:
            term: Term

        Returns:
            Document frequency
        """
        return len(self._inverted_lists.get(term, []))

    def stats(self) -> dict[str, Any]:
        """
        Get index statistics.

        Returns:
            Dictionary with index statistics
        """
        total_postings = sum(len(postings) for postings in self._inverted_lists.values())
        return {
            "total_documents": self._total_docs,
            "total_terms": len(self._inverted_lists),
            "total_postings": total_postings,
            "avg_doc_length": self._avg_doc_length,
            "avg_postings_per_term": (
                total_postings / len(self._inverted_lists) if self._inverted_lists else 0.0
            ),
        }
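A minimal usage sketch, assuming the default Tokenizer from llmds.tokenizer behaves the way this module expects (encode returns a list of token ids, decode maps a single id back to its term); the documents and query below are illustrative:

    from llmds.inverted_index import InvertedIndex

    index = InvertedIndex()  # builds a default Tokenizer internally
    index.add_document(1, "the quick brown fox jumps over the lazy dog")
    index.add_document(2, "a lazy dog sleeps all day")

    # BM25-ranked results, highest score first
    for doc_id, score in index.search("lazy dog", top_k=5):
        print(doc_id, round(score, 3))

    print(index.stats())  # document count, vocabulary size, avg_doc_length, ...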