223 lines
6.8 KiB
Python
223 lines
6.8 KiB
Python
"""Compressed inverted index with BM25 scoring.

Implementation based on:
    Robertson, S., & Zaragoza, H. (2009). The probabilistic relevance framework:
    BM25 and beyond. Foundations and Trends in Information Retrieval, 3(4), 333-389.

See docs/CITATIONS.md for full citation details.
"""
|
|
|
|
import math
from collections import defaultdict
from typing import Any, Optional

from llmds.tokenizer import Tokenizer
|
|
|
|
|
|
class InvertedIndex:
    """
    In-memory inverted index with BM25 scoring.

    Maps each term to the list of document ids containing it and keeps
    per-document term counts and lengths so that queries can be ranked with
    BM25. Varint and zigzag helpers are provided for integer compression;
    the postings themselves are held as plain Python lists.

    Reference:
        Robertson & Zaragoza (2009). The probabilistic relevance framework:
        BM25 and beyond.

    Attributes:
        tokenizer: Object providing ``encode(text)`` and ``decode(ids)``.
        k1: BM25 term-frequency saturation parameter.
        b: BM25 document-length normalization parameter.
    """

    def __init__(self, tokenizer: Optional["Tokenizer"] = None) -> None:
        """
        Initialize inverted index.

        Args:
            tokenizer: Tokenizer instance (creates default if None)
        """
        # Explicit None check: a supplied tokenizer must be kept even if it
        # happens to evaluate as falsy (e.g. it defines __len__ and is empty).
        self.tokenizer = tokenizer if tokenizer is not None else Tokenizer()
        self._inverted_lists: dict[str, list[int]] = defaultdict(list)  # term -> doc_ids
        self._doc_lengths: dict[int, int] = {}  # doc_id -> length in tokens
        self._doc_terms: dict[int, dict[str, int]] = {}  # doc_id -> term -> count
        self._total_docs = 0
        self._total_length = 0  # running sum of _doc_lengths values
        self._avg_doc_length = 0.0
        # BM25 parameters
        self.k1 = 1.2
        self.b = 0.75

    def _encode_varint(self, value: int) -> bytes:
        """
        Encode a non-negative integer as a little-endian base-128 varint.

        Args:
            value: Non-negative integer to encode

        Returns:
            Encoded bytes; every byte except the last has its high bit set

        Raises:
            ValueError: If value is negative (zigzag-encode it first)
        """
        if value < 0:
            raise ValueError("varint cannot encode negative values; zigzag-encode first")
        encoded = bytearray()
        while value >= 0x80:
            # Low 7 bits plus the continuation flag
            encoded.append((value & 0x7F) | 0x80)
            value >>= 7
        encoded.append(value)
        return bytes(encoded)

    def _decode_varint(self, data: bytes, offset: int) -> tuple[int, int]:
        """
        Decode one varint from bytes.

        Decoding is lenient: if ``data`` ends before a byte without the
        continuation flag is seen (or ``offset`` is already at the end), the
        bits read so far are returned and no error is raised.

        Args:
            data: Buffer holding the encoded value
            offset: Index of the first byte to read

        Returns:
            Tuple of (decoded value, offset just past the last byte consumed)
        """
        value = 0
        shift = 0
        while offset < len(data):
            byte = data[offset]
            value |= (byte & 0x7F) << shift
            offset += 1
            if (byte & 0x80) == 0:
                break
            shift += 7
        return value, offset

    def _zigzag_encode(self, value: int) -> int:
        """
        Zigzag encode a signed integer into a non-negative one.

        Maps 0, -1, 1, -2, ... to 0, 1, 2, 3, ... The sign is tested
        directly rather than via a fixed-width shift, so integers of any
        magnitude round-trip through ``_zigzag_decode``.

        Args:
            value: Signed integer

        Returns:
            Non-negative zigzag-encoded integer
        """
        sign_mask = -1 if value < 0 else 0
        return (value << 1) ^ sign_mask

    def _zigzag_decode(self, value: int) -> int:
        """
        Zigzag decode.

        Args:
            value: Non-negative integer produced by ``_zigzag_encode``

        Returns:
            The original signed integer
        """
        return (value >> 1) ^ (-(value & 1))

    def _discard_document(self, doc_id: int) -> None:
        """Remove an indexed document's postings, term counts and length."""
        for term in self._doc_terms.pop(doc_id):
            postings = self._inverted_lists.get(term)
            if postings is None:
                continue
            if doc_id in postings:
                postings.remove(doc_id)
            if not postings:
                # Drop empty lists so term-level statistics stay accurate
                del self._inverted_lists[term]
        self._total_length -= self._doc_lengths.pop(doc_id, 0)
        self._total_docs -= 1

    def add_document(self, doc_id: int, text: str) -> None:
        """
        Add a document to the index.

        If ``doc_id`` is already indexed, its previous content is replaced:
        old postings are removed and the document is counted only once.

        Args:
            doc_id: Document identifier
            text: Document text
        """
        tokens = self.tokenizer.encode(text)

        # Count term frequencies; tokens that decode to an empty string are skipped
        term_counts: dict[str, int] = {}
        for token_id in tokens:
            term = self.tokenizer.decode([token_id])
            if term:
                term_counts[term] = term_counts.get(term, 0) + 1

        if doc_id in self._doc_terms:
            self._discard_document(doc_id)

        # doc_id is now guaranteed absent from every postings list, so no
        # per-term membership scan is needed.
        for term in term_counts:
            self._inverted_lists[term].append(doc_id)

        # Store document metadata (length counts every token, including skipped ones)
        self._doc_lengths[doc_id] = len(tokens)
        self._doc_terms[doc_id] = term_counts

        # Update the average incrementally instead of re-summing all lengths
        self._total_docs += 1
        self._total_length += len(tokens)
        self._avg_doc_length = self._total_length / self._total_docs

    def _bm25_score(self, term: str, doc_id: int, query_term_freq: int) -> float:
        """
        Calculate BM25 score for a term-document pair.

        Args:
            term: Query term
            doc_id: Document ID
            query_term_freq: Frequency of term in query

        Returns:
            BM25 score (0.0 if the document does not contain the term)
        """
        tf = self._doc_terms.get(doc_id, {}).get(term, 0)
        if tf <= 0:
            return 0.0

        df = len(self._inverted_lists.get(term, ()))
        if df == 0:
            return 0.0

        # Inverse document frequency. The "+ 1" inside the logarithm keeps the
        # weight non-negative even for terms present in most documents.
        idf = math.log(1.0 + (self._total_docs - df + 0.5) / (df + 0.5))
        idf = max(0.0, idf)  # Guards against inconsistent counts (df > N)

        # Document length normalization
        doc_length = self._doc_lengths.get(doc_id, 1)
        if self._avg_doc_length > 0:
            length_ratio = doc_length / self._avg_doc_length
        else:
            length_ratio = 1.0  # No length information: treat as average
        length_norm = (1 - self.b) + self.b * length_ratio

        # Saturating term-frequency and query-frequency components
        tf_component = (tf * (self.k1 + 1)) / (tf + self.k1 * length_norm)
        query_component = query_term_freq / (query_term_freq + 0.5)

        return idf * tf_component * query_component

    def search(self, query: str, top_k: int = 10) -> list[tuple[int, float]]:
        """
        Search the index with BM25 scoring.

        Args:
            query: Query text
            top_k: Number of top results to return

        Returns:
            List of (doc_id, score) tuples sorted by score descending;
            empty if ``top_k`` is not positive
        """
        if top_k <= 0:
            return []

        query_term_counts: dict[str, int] = {}
        for token_id in self.tokenizer.encode(query):
            term = self.tokenizer.decode([token_id])
            if term:
                query_term_counts[term] = query_term_counts.get(term, 0) + 1

        # Score every document that contains at least one query term
        doc_scores: dict[int, float] = defaultdict(float)
        for term, query_freq in query_term_counts.items():
            for doc_id in self._inverted_lists.get(term, ()):
                doc_scores[doc_id] += self._bm25_score(term, doc_id, query_freq)

        # Sort by score and return top-k
        ranked = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)
        return ranked[:top_k]

    def get_term_frequency(self, term: str, doc_id: int) -> int:
        """
        Get term frequency in a document.

        Args:
            term: Term
            doc_id: Document ID

        Returns:
            Term frequency (0 if the document or term is unknown)
        """
        return self._doc_terms.get(doc_id, {}).get(term, 0)

    def get_document_frequency(self, term: str) -> int:
        """
        Get document frequency of a term.

        Args:
            term: Term

        Returns:
            Number of indexed documents containing the term
        """
        return len(self._inverted_lists.get(term, ()))

    def stats(self) -> dict[str, Any]:
        """
        Get index statistics.

        Returns:
            Dictionary with document, term and postings counts, the average
            document length, and the average postings-list length per term
        """
        total_postings = sum(len(postings) for postings in self._inverted_lists.values())
        return {
            "total_documents": self._total_docs,
            "total_terms": len(self._inverted_lists),
            "total_postings": total_postings,
            "avg_doc_length": self._avg_doc_length,
            "avg_postings_per_term": (
                total_postings / len(self._inverted_lists) if self._inverted_lists else 0.0
            ),
        }
|
|
|