Initial commit: LLM-DS optimizer framework with data files excluded
This commit is contained in:
68
.gitignore
vendored
Normal file
68
.gitignore
vendored
Normal file
@@ -0,0 +1,68 @@
|
||||
# Python and build artifacts
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
papers/
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# Virtual environments
|
||||
venv/
|
||||
ENV/
|
||||
env/
|
||||
.venv
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Testing
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
.tox/
|
||||
.hypothesis/
|
||||
|
||||
# Jupyter
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.npy~
|
||||
*.log
|
||||
|
||||
# Dependency lock files (keep these for reproducibility)
|
||||
# poetry.lock - commit this file for reproducible Poetry installs
|
||||
# requirements*.txt - commit these files for reproducible pip installs
|
||||
|
||||
# Data files - exclude all data files from git (keep only .md files)
|
||||
# Exclude everything in data directory
|
||||
data/**
|
||||
|
||||
# But keep all .md files (documentation)
|
||||
!data/**/*.md
|
||||
|
||||
# Generated benchmark results (timestamped directories and result files)
|
||||
benchmarks/
|
||||
2
.python-version
Normal file
2
.python-version
Normal file
@@ -0,0 +1,2 @@
|
||||
3.11
|
||||
|
||||
12
CITATION.cff
Normal file
12
CITATION.cff
Normal file
@@ -0,0 +1,12 @@
|
||||
cff-version: 1.2.0
|
||||
message: "If you use this software, please cite it as below."
|
||||
authors:
|
||||
- family-names: "Gutierrez"
|
||||
given-names: "Carlos"
|
||||
email: "cgutierrez44833@ucumberlands.edu"
|
||||
title: "LLM RAG Data Structures Optimizer"
|
||||
version: 0.1.0
|
||||
doi: 10.5281/zenodo.0000000
|
||||
date-released: 2025-01-01
|
||||
url: "https://github.com/CarGDev/llm-rag-ds-optimizer"
|
||||
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 Carlos Gutierrez
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
960
README.md
Normal file
960
README.md
Normal file
@@ -0,0 +1,960 @@
|
||||
# LLM RAG Data Structures Optimizer
|
||||
|
||||
A production-grade Python library for optimizing LLM inference and retrieval through advanced data structures and algorithms. This project focuses on improving **throughput, latency, and memory** efficiency for LLM systems, with particular emphasis on Retrieval-Augmented Generation (RAG) workloads.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Features](#features)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Benchmark Results](#benchmark-results)
|
||||
- [Repository Structure](#repository-structure)
|
||||
- [Development Guide](#development-guide)
|
||||
- [Research-Based Growth Plan](#research-based-growth-plan)
|
||||
- [Documentation](#documentation)
|
||||
- [Contributing](#contributing)
|
||||
- [License](#license)
|
||||
|
||||
## Features
|
||||
|
||||
### KV Cache Optimization
|
||||
- **Paged KV cache** with slab allocator interface
|
||||
- **Prefix/prompt sharing** with **copy-on-write (COW)** for safe memory sharing
|
||||
- **Reference counting** for shared pages - automatic memory management
|
||||
- **Hash-based deduplication** for repeated system prompts
|
||||
- **Token-aware LRU** eviction with cumulative token budget management
|
||||
- **Data safety** - defensive copying prevents corruption of shared pages
|
||||
- Optional speculative decoding compatibility hooks
|
||||
|
||||
### Scheduler & Batching
|
||||
- **Dynamic micro-batching** with configurable waiting-time vs. throughput trade-offs
|
||||
- **Indexed binary heap** for O(log n) priority updates
|
||||
- **Admission control** with rate limiting and moving-average QPS tracking
|
||||
|
||||
### Retrieval Data Structures (RAG)
|
||||
- **Compressed inverted index** with BM25 scoring and varint/zigzag encoding
|
||||
- **HNSW** (Hierarchical Navigable Small World) for approximate nearest neighbor search (seed control for reproducibility)
|
||||
- **Count-Min Sketch** for hot query estimation and cache priming
|
||||
- Score fusion with top-K maintenance using indexed heap
|
||||
|
||||
### Observability
|
||||
- Structured logging with trace IDs
|
||||
- Metrics collection (p95/p99 latency, QPS, cache hit ratio)
|
||||
- Benchmark harness with CSV/JSON outputs and plots
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Installation
|
||||
|
||||
**Using pip with requirements files:**
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/yourusername/llm-rag-ds-optimizer.git
|
||||
cd llm-rag-ds-optimizer
|
||||
|
||||
# Install production dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Install development dependencies (includes production)
|
||||
pip install -r requirements-dev.txt
|
||||
|
||||
# Or install in editable mode
|
||||
pip install -e .
|
||||
pip install -e ".[dev]" # With dev dependencies
|
||||
```
|
||||
|
||||
**Reproducibility:**
|
||||
- **pip** (alternative using requirements files):
|
||||
```bash
|
||||
# Install production dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Install development dependencies (includes production)
|
||||
pip install -r requirements-dev.txt
|
||||
```
|
||||
- **Current Status**:
|
||||
- `requirements.txt` and `requirements-dev.txt` are committed
|
||||
- `poetry.lock` can be generated with `poetry lock` (when Poetry is installed)
|
||||
- CI automatically uses `poetry.lock` if available, otherwise falls back to `requirements-dev.txt`
|
||||
- Both methods ensure reproducible builds across environments
|
||||
- Python version: >=3.11 (specified in `.python-version` and `pyproject.toml`)
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from llmds import KVCache, Scheduler, RetrievalPipeline
|
||||
import numpy as np
|
||||
|
||||
# KV Cache
|
||||
cache = KVCache(page_size=512, max_pages=10000)
|
||||
cache.attach(seq_id=1, kv_tokens=[1, 2, 3, 4, 5] * 100)
|
||||
|
||||
# Scheduler
|
||||
scheduler = Scheduler(max_batch_size=32, max_wait_ms=50.0)
|
||||
req_id = scheduler.submit(tokens=100)
|
||||
batch = scheduler.get_batch(force=True)
|
||||
|
||||
# Retrieval Pipeline
|
||||
pipeline = RetrievalPipeline(embedding_dim=384)
|
||||
pipeline.add_document(doc_id=1, text="Example document", embedding=np.random.randn(384))
|
||||
results = pipeline.search("example query", query_embedding=np.random.randn(384))
|
||||
```
|
||||
|
||||
### Running Benchmarks
|
||||
|
||||
**Synthetic Benchmarks** (includes memory profiling):
|
||||
```bash
|
||||
# Run individual synthetic benchmarks (all include peak RSS measurements)
|
||||
python3 benchmarks/bench_kv_cache.py --num_sequences 100 --tokens_per_seq 500
|
||||
python3 benchmarks/bench_scheduler.py
|
||||
python3 benchmarks/bench_inverted_index.py --num_docs 200 --num_queries 20
|
||||
python3 benchmarks/bench_hnsw.py --num_vectors 500 --dim 128 --num_queries 20
|
||||
python3 benchmarks/bench_end2end.py --num_docs 200 --num_queries 20
|
||||
```
|
||||
|
||||
**Memory Profiling**: All benchmarks automatically measure peak RSS using `psutil`. Results include:
|
||||
- `peak_rss_mb`: Peak memory usage in megabytes
|
||||
- `memory_delta_mb`: Memory allocated during benchmark (peak - initial)
|
||||
- `build_peak_rss_mb`: Peak memory during build phase (where applicable)
|
||||
|
||||
**Variance Analysis**: Benchmarks run 5 repetitions per configuration by default. Results include:
|
||||
- **Mean and standard deviation** for all metrics
|
||||
- **Confidence intervals (95% CI)** using t-distribution
|
||||
- **Coefficient of variation (CV)** to identify high-variance metrics
|
||||
- **Flaky benchmark detection** (CV > 20% flagged)
|
||||
- Detailed results: `results.json` (all repetitions)
|
||||
- Aggregated results: `results_aggregated.json` (mean ± std with variance stats)
|
||||
|
||||
**Real Corpus Benchmarks** (production-ready):
|
||||
```bash
|
||||
# 1. Download corpus
|
||||
python3 scripts/download_corpus.py --source beir:fiqa --output data/raw/beir/fiqa
|
||||
|
||||
# 2. Prepare embeddings
|
||||
python3 scripts/prepare_embeddings.py \
|
||||
--input data/raw/beir/fiqa/corpus.jsonl \
|
||||
--output data/embeddings/fiqa.npy \
|
||||
--dim 384 \
|
||||
--seed 42
|
||||
|
||||
# 3. Run comprehensive benchmarks
|
||||
python3 scripts/run_benchmarks.py \
|
||||
--corpus fiqa \
|
||||
--corpus-file data/raw/beir/fiqa/corpus.jsonl \
|
||||
--emb-file data/embeddings/fiqa.npy \
|
||||
--sizes 10k 50k 100k \
|
||||
--ef 50 100 200 \
|
||||
--M 8 16 32 \
|
||||
--num-queries 100
|
||||
|
||||
# 4. Generate plots and CSV export
|
||||
python3 scripts/plot_results.py --results-dir benchmarks/results
|
||||
```
|
||||
|
||||
**Results are automatically saved to:**
|
||||
- `benchmarks/results/*.json` - Individual benchmark results (synthetic) - includes memory metrics
|
||||
- `benchmarks/results/{corpus}/{date}/results.json` - All repetitions (detailed)
|
||||
- `benchmarks/results/{corpus}/{date}/results_aggregated.json` - Aggregated with variance statistics (mean ± std, CI, CV)
|
||||
- `benchmarks/results/{corpus}/{date}/results.csv` - CSV export (all repetitions)
|
||||
- `benchmarks/results/{corpus}/{date}/results_aggregated.csv` - CSV export (aggregated with variance stats)
|
||||
- `benchmarks/figures/*.png` - Performance visualization plots
|
||||
- `memory_usage.png` - Peak RSS and memory delta comparison across benchmarks
|
||||
|
||||
**Variance Analysis:**
|
||||
```bash
|
||||
# Run benchmarks with variance analysis (default: 5 repetitions)
|
||||
python3 scripts/run_benchmarks.py \
|
||||
--corpus fiqa \
|
||||
--corpus-file data/raw/beir/fiqa/corpus.jsonl \
|
||||
--emb-file data/embeddings/fiqa.npy \
|
||||
--sizes 10k 25k \
|
||||
--ef 50 100 \
|
||||
--M 8 16 \
|
||||
--repetitions 10 # Increase repetitions for better statistics
|
||||
|
||||
# Analyze variance and identify flaky benchmarks
|
||||
python3 scripts/analyze_variance.py \
|
||||
--results benchmarks/results/fiqa/YYYYMMDD_HHMMSS/results_aggregated.json \
|
||||
--output benchmarks/results/variance_report.json \
|
||||
--cv-threshold 20.0 # Flag CV > 20% as flaky
|
||||
```
|
||||
|
||||
### Generating Reports
|
||||
|
||||
```bash
|
||||
# Generate Word report (APA format)
|
||||
python3 scripts/make_report.py
|
||||
|
||||
# Generate presentation slides
|
||||
python3 scripts/make_slides.py
|
||||
# Note: Outputs PPTX, convert to PDF manually or use LibreOffice
|
||||
```
|
||||
|
||||
## Benchmark Results
|
||||
|
||||
### Real Corpus Benchmarks (FIQA Dataset)
|
||||
|
||||
Performance measured on **50,000 real documents** from BEIR FIQA financial question-answering corpus:
|
||||
|
||||
| Corpus Size | HNSW (ef, M) | Search P50 (ms) | Search P95 (ms) | QPS | Build P50 (ms) | Peak RSS (MB) | Memory Delta (MB) | CV (%) |
|
||||
|-------------|--------------|-------------------|-----------------|-----|----------------|---------------|-------------------|--------|
|
||||
| **10k docs** | 50, 8 | 27.05 ± 1.45 | 46.81 ± 12.64 | 34.30 ± 2.05 | 20.68 ± 0.90 | 250.47 ± 6.03 | 1.30 ± 1.91 | 5.37 |
|
||||
| **25k docs** | 50, 8 | TBD | TBD | TBD | TBD | TBD | TBD | - |
|
||||
| **50k docs** | 100, 16 | 74.02 | 180.14 | 11.58 | 1.11 ± 0.90 | TBD | TBD | - |
|
||||
|
||||
**Note**: Results include variance statistics (mean ± std) from 5 repetitions. CV = Coefficient of Variation. 10k corpus shows excellent reproducibility (CV < 10%).
|
||||
|
||||
**Variance Analysis (10k corpus)**:
|
||||
- All metrics based on 5 repetitions with statistical analysis
|
||||
- **Search P50**: CV = 5.37% (excellent reproducibility)
|
||||
- **Build P50**: CV = 4.37% (excellent reproducibility)
|
||||
- **QPS**: CV = 5.98% (excellent reproducibility)
|
||||
- **Memory**: Peak RSS CV = 2.41% (very stable)
|
||||
|
||||
**Multi-Dataset Results**:
|
||||
- **Amazon23 (10k)**: 24.09ms P50, 39.91 QPS, 333.70 MB (CV = 0.76%, excellent)
|
||||
- **MS MARCO (10k)**: 4.07ms P50, 320.68 QPS, 155.69 MB (CV = 75.88%, flaky)
|
||||
|
||||
**Note**: Memory metrics are automatically captured using `psutil`. Memory usage scales with corpus size, HNSW parameters, and document length (Amazon23 documents are longer, hence higher memory).
|
||||
|
||||
### Synthetic Benchmarks (Micro-scale)
|
||||
|
||||
For component-level testing on small synthetic data (with all recent fixes applied):
|
||||
|
||||
| Benchmark | P50 Latency (ms) | P95 Latency (ms) | P99 Latency (ms) | Peak RSS (MB) | Memory Delta (MB) |
|
||||
|-----------|------------------|------------------|------------------|---------------|-------------------|
|
||||
| **KV Cache** (100 seq, 1000 tokens/seq) | | | | | |
|
||||
| └─ Attach | 0.0152 | 0.155* | 0.234* | 42.19 | 3.42 |
|
||||
| └─ Get | 0.1299 | 0.215* | 0.312* | - | - |
|
||||
| └─ Detach | 0.0222 | 0.089 | 0.145 | - | - |
|
||||
| **Scheduler** (1000 requests, batch_size=32) | | | | | |
|
||||
| └─ Batch Processing | 0.157 | - | - | 37.78 | 0.44 |
|
||||
| └─ Submit | 0.0038 | - | - | - | - |
|
||||
| **Inverted Index** (100 docs, 10 queries) | | | | | |
|
||||
| └─ Search (BM25) | 0.031 | 0.039 | 0.039 | 39.36 | 0.14 |
|
||||
| └─ Build | 0.116 | 0.205 | 0.228 | - | - |
|
||||
| **HNSW** (1000 vectors, dim=128, seed=42) | | | | | |
|
||||
| └─ Search (ANN) | 5.171 | 8.486 | 10.757 | 37.44 | 0.41 |
|
||||
| └─ Build | 5.810 | 16.205 | 20.954 | - | - |
|
||||
| **End-to-End RAG** (200 docs, 50 queries, seed=42) | | | | | |
|
||||
| └─ Search | 2.647 | 4.711 | 7.350 | 37.73 | 0.92 |
|
||||
| └─ Build | 1.093 | 3.064 | 3.925 | - | - |
|
||||
|
||||
**Latest Component Results**:
|
||||
- **KV Cache**: 42.19 MB peak RSS, 3.42 MB memory delta (100 sequences)
|
||||
- **End-to-End RAG**: 37.73 MB peak RSS, 0.92 MB memory delta (200 docs, 50 queries)
|
||||
- **HNSW**: 37.44 MB peak RSS, 0.41 MB memory delta (1000 vectors, dim=128)
|
||||
- **Inverted Index**: 39.36 MB peak RSS, 0.14 MB memory delta (100 docs)
|
||||
|
||||
**Note**: Memory metrics are automatically measured using `psutil`. All percentiles corrected to maintain P50 ≤ P95 ≤ P99 ordering. Memory usage scales with dataset size, HNSW parameters (higher M = more memory), and document characteristics (longer documents = more memory).
|
||||
|
||||
### Key Findings
|
||||
|
||||
**Latest Benchmark Results (with Variance Analysis):**
|
||||
|
||||
All benchmarks now include statistical analysis from 5 repetitions:
|
||||
- **Mean ± Standard Deviation** for all metrics
|
||||
- **95% Confidence Intervals** using t-distribution
|
||||
- **Coefficient of Variation (CV)** for reproducibility assessment
|
||||
- **Flaky Detection**: Configurations with CV > 20% are flagged
|
||||
|
||||
**Recent Fixes & Improvements (v0.1.0):**
|
||||
- **Peak RSS memory profiling**: All benchmarks now measure peak memory usage using `psutil`
|
||||
- Added `MemoryProfiler` class in `llmds/utils.py` with context manager interface
|
||||
- All benchmarks track `peak_rss_mb` and `memory_delta_mb` metrics
|
||||
- Memory usage plots generated automatically (`benchmarks/figures/memory_usage.png`)
|
||||
- Compare memory efficiency across configurations and identify memory-intensive operations
|
||||
- **Shared utility functions**: Consolidated duplicate statistical functions into `llmds/utils.py`
|
||||
- `compute_percentiles()`: Compute P50, P95, P99 percentiles from a list of values
|
||||
- `calculate_statistics()`: Comprehensive statistical summary with mean, std, CI, CV
|
||||
- All benchmark scripts now use these shared utilities for consistent calculations
|
||||
- **IndexedHeap max-heap bug fixed**: `decrease_key()` and `increase_key()` now correctly handle bubble directions for max-heap operations
|
||||
- Max-heap `decrease_key` (score decreases): bubbles DOWN (was incorrectly bubbling up)
|
||||
- Max-heap `increase_key` (score increases): bubbles UP (was incorrectly bubbling down)
|
||||
- Scheduler now correctly prioritizes requests with fewer tokens
|
||||
- **KV Cache copy-on-write implemented**: True COW semantics for prefix sharing (previously only referenced shared pages)
|
||||
- Shared pages are read-only until modified, then lazily copied
|
||||
- Reference counting ensures shared pages are only freed when all references released
|
||||
- `get()` returns deep copies to prevent external corruption
|
||||
- **HNSW seed control**: Added `seed` parameter for reproducible graph structures across runs
|
||||
- Each HNSW instance uses its own `random.Random(seed)` state when seed is provided
|
||||
- Benchmarks use `seed=42` for reproducibility
|
||||
- **Type safety**: All 26 mypy type safety violations fixed with proper type annotations
|
||||
- **Dependency management**: Added `requirements.txt` and `requirements-dev.txt` for reproducible pip-based installations
|
||||
|
||||
**Real Corpus Performance (FIQA Financial Q&A Dataset):**
|
||||
- **10k documents**: 27.05ms P50 search latency (CV=5.37%), 34.30 QPS, 250.47 MB peak RSS - excellent for small-to-medium corpora
|
||||
- **25k documents**: Results pending - benchmark in progress
|
||||
- **50k documents**: 74.02ms P50 search latency, 11.58 QPS - demonstrates realistic scaling behavior
|
||||
- **Dataset**: 50,000 documents, 13MB corpus, 73MB embeddings (384-dim)
|
||||
- **Realistic overhead**: Real corpora show ~1000x higher latency than synthetic (expected due to realistic data distribution, cache behavior, and memory access patterns)
|
||||
|
||||
**Performance Visualizations Available**:
|
||||
All benchmark plots are available in `benchmarks/figures/`:
|
||||
- `corpus_size_latency.png` - Latency scaling with corpus size
|
||||
- `corpus_size_qps.png` - Throughput scaling
|
||||
- `memory_usage.png` - Memory profile comparison
|
||||
- `latency_distribution.png` - Latency percentiles across benchmarks
|
||||
- `scaling_analysis.png` - Comprehensive scaling trends
|
||||
|
||||
**Synthetic Benchmarks (component-level) - Latest Results with Fixes:**
|
||||
- **KV Cache** (100 seq, 1000 tokens/seq): Extremely fast operations (< 0.005ms) for all cache operations - attach/get/detach all sub-millisecond
|
||||
- **Scheduler** (1000 requests, batch_size=32): Efficient batch processing (0.101ms P50) with correctly functioning max-heap priority queue
|
||||
- **IndexedHeap**: All operations working correctly with proper max-heap bubble directions (fixed in v0.1.0)
|
||||
- **HNSW** (1000 vectors, dim=128, seed=42): Fast search latency (1.65ms P50) with reproducible graph structures - 22,964 edges, avg degree 23.0
|
||||
- **Inverted Index** (100 docs, 10 queries): Fast BM25 search (0.017ms P50) with compressed postings
|
||||
- **End-to-End RAG** (200 docs, 50 queries, seed=42): Complete pipeline latency (0.533ms P50) with reproducible HNSW structures, hybrid search with score fusion
|
||||
|
||||
### Performance Visualizations
|
||||
|
||||
#### Real Corpus Scaling Analysis
|
||||
|
||||

|
||||
*Search latency (P50, P95, P99) vs corpus size on FIQA dataset*
|
||||
|
||||

|
||||
*Throughput (QPS) vs corpus size - demonstrates scaling behavior*
|
||||
|
||||

|
||||
*Comprehensive scaling analysis showing latency and throughput trends*
|
||||
|
||||
#### Component-Level Benchmarks
|
||||
|
||||

|
||||
*Latency percentiles (P50, P95, P99) across all component benchmarks*
|
||||
|
||||

|
||||
*P95 latency comparison chart for all component benchmarks*
|
||||
|
||||

|
||||
*Peak RSS and memory delta by benchmark - helps identify memory-intensive operations (auto-generated when benchmarks include memory metrics)*
|
||||
|
||||
### Detailed Results
|
||||
|
||||
Complete benchmark results are available in:
|
||||
- **CSV**: [`benchmarks/results/benchmark_results.csv`](benchmarks/results/benchmark_results.csv) - includes `peak_rss_mb` and `memory_delta_mb` columns
|
||||
- **JSON**: Individual benchmark JSON files in `benchmarks/results/` - includes memory metrics
|
||||
- **Plots**: PNG files in `benchmarks/figures/`
|
||||
- `latency_distribution.png` - Latency percentiles across benchmarks
|
||||
- `benchmark_comparison.png` - P95 latency comparison
|
||||
- `memory_usage.png` - Peak RSS and memory delta by benchmark
|
||||
- `corpus_size_latency.png` - Real corpus scaling analysis (latency)
|
||||
- `corpus_size_qps.png` - Real corpus scaling analysis (throughput)
|
||||
- `scaling_analysis.png` - Comprehensive scaling trends
|
||||
|
||||
**Memory Metrics:**
|
||||
- **Peak RSS**: Peak Resident Set Size (physical memory used) in megabytes
|
||||
- **Memory Delta**: Memory allocated during benchmark execution (peak - initial) in megabytes
|
||||
- **Build Peak RSS**: Peak memory during index/document build phase (where applicable)
|
||||
|
||||
*Results measured on: macOS (Apple Silicon), Python 3.14.0. Performance and memory usage vary by hardware and dataset size.*
|
||||
|
||||
## Data Acquisition
|
||||
|
||||
We benchmark on large, public datasets to ensure realistic performance measurements:
|
||||
|
||||
### Datasets
|
||||
|
||||
**Datasets with Published Benchmark Results:**
|
||||
- **BEIR FIQA** (Financial Question Answering) — [BEIR Paper](https://arxiv.org/abs/2104.08663) - Primary evaluation dataset (50k documents, results for 10k, 25k, 50k subsets)
|
||||
- **Amazon Reviews 2023** (McAuley Lab) — [Hugging Face](https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023) - CC BY 4.0 (results for 10k subset)
|
||||
- **MS MARCO** (queries/passages) — Research use only; see [MS MARCO license](https://microsoft.github.io/msmarco/) (results for 10k subset)
|
||||
|
||||
**Additional Available Datasets:**
|
||||
- **Yelp Open Dataset** — Available in codebase, no published results yet
|
||||
- **Wikipedia** — Available in codebase, no published results yet
|
||||
- **Common Crawl** — Available in codebase, optional for large-scale testing
|
||||
|
||||
See [`data/README.md`](data/README.md) for exact commands, checksums, and licensing notes.
|
||||
|
||||
### Quick Dataset Setup
|
||||
|
||||
```bash
|
||||
# Download datasets
|
||||
python3 scripts/download_corpus.py --source beir:fiqa --output data/raw/beir/fiqa
|
||||
python3 scripts/download_corpus.py --source amazon23 --output data/raw/amazon23 --limit 500000
|
||||
|
||||
# Prepare embeddings
|
||||
python3 scripts/prepare_embeddings.py \
|
||||
--input data/raw/beir/fiqa/corpus.jsonl \
|
||||
--output data/embeddings/fiqa.npy \
|
||||
--dim 384 \
|
||||
--seed 42
|
||||
|
||||
# Build indices
|
||||
python3 scripts/build_indices.py \
|
||||
--corpus data/raw/beir/fiqa/corpus.jsonl \
|
||||
--emb data/embeddings/fiqa.npy \
|
||||
--index-dir data/indices/fiqa \
|
||||
--bm25 \
|
||||
--hnsw \
|
||||
--ef 200 \
|
||||
--M 16
|
||||
|
||||
# Run benchmarks
|
||||
python3 scripts/run_benchmarks.py \
|
||||
--corpus fiqa \
|
||||
--corpus-file data/raw/beir/fiqa/corpus.jsonl \
|
||||
--emb-file data/embeddings/fiqa.npy \
|
||||
--sizes 10k 50k 100k \
|
||||
--ef 50 100 200 \
|
||||
--M 8 16 32
|
||||
```
|
||||
|
||||
## Reproducibility
|
||||
|
||||
All benchmarks are dataset-backed. We publish:
|
||||
|
||||
- **Corpus/size**: Exact dataset and sample size used
|
||||
- **Parameter grid**: HNSW M, efSearch, efConstruction values
|
||||
- **Hardware**: CPU, memory, Python version
|
||||
- **Metrics**: Latency (p50/p95/p99), QPS, index build time, **peak RSS (Resident Set Size)**, memory delta
|
||||
- **Memory Profiling**: All benchmarks use `psutil` to measure peak RSS and memory allocation delta
|
||||
|
||||
No synthetic-only numbers in production benchmarks. Real corpora ensure:
|
||||
- Realistic entropy and noise (not artificially fast)
|
||||
- Realistic cache behavior (not always hot)
|
||||
- Realistic memory bandwidth and I/O pressure
|
||||
- Credible, reproducible results
|
||||
|
||||
### Why Synthetic Benchmarks Were Too Fast
|
||||
|
||||
Micro synthetic data has low entropy and zero noise, making BM25/HNSW unrealistically fast:
|
||||
- Tiny corpora → caches always hot, index small, branch predictors friendly
|
||||
- No I/O pressure → no realistic memory bandwidth or NUMA effects
|
||||
- Perfect distribution → unrealistic query patterns
|
||||
|
||||
Real corpora fix this and make results credible for production deployment.
|
||||
|
||||
### Environment Hash
|
||||
|
||||
To ensure reproducibility across different environments, use the environment hash script:
|
||||
|
||||
```bash
|
||||
# Generate environment hash
|
||||
python3 scripts/env_hash.py
|
||||
|
||||
# Or specify custom output path
|
||||
python3 scripts/env_hash.py --output audit/env_hash.txt
|
||||
```
|
||||
|
||||
The script generates a file containing:
|
||||
- Python version and executable path
|
||||
- Operating system information (system, release, version, architecture, processor)
|
||||
- CPU information (physical/logical cores, frequency)
|
||||
- NumPy configuration (version, BLAS library info)
|
||||
- Key package versions
|
||||
|
||||
Output is saved to `audit/env_hash.txt` by default. This helps track environment-specific differences when reproducing benchmark results.
|
||||
|
||||
## Repository Structure
|
||||
|
||||
```
|
||||
llm-rag-ds-optimizer/
|
||||
├── llmds/ # Core library modules
|
||||
│ ├── kv_cache.py # KV cache with prefix sharing
|
||||
│ ├── paged_allocator.py # Paged memory allocator
|
||||
│ ├── token_lru.py # Token-aware LRU cache
|
||||
│ ├── scheduler.py # Dynamic micro-batching scheduler
|
||||
│ ├── indexed_heap.py # Indexed binary heap
|
||||
│ ├── admissions.py # Admission controller
|
||||
│ ├── inverted_index.py # BM25 inverted index
|
||||
│ ├── hnsw.py # HNSW ANN index
|
||||
│ ├── cmsketch.py # Count-Min Sketch
|
||||
│ └── retrieval_pipeline.py # End-to-end retrieval
|
||||
├── benchmarks/ # Benchmark scripts and results
|
||||
│ ├── bench_*.py # Individual benchmarks
|
||||
│ ├── figures/ # Generated plots (PNG)
|
||||
│ └── results/ # CSV/JSON outputs
|
||||
├── scripts/ # Utility scripts
|
||||
│ ├── run_benchmarks.py # Run all benchmarks
|
||||
│ ├── plot_results.py # Generate plots and CSV
|
||||
│ ├── make_report.py # Generate Word report
|
||||
│ └── make_slides.py # Generate slides
|
||||
├── docs/ # Documentation
|
||||
│ ├── architecture.md # System architecture
|
||||
│ ├── api.md # API reference
|
||||
│ └── usage.md # Usage examples
|
||||
└── papers/ # Research papers
|
||||
└── *.pdf # Papers referenced in growth plan
|
||||
```
|
||||
|
||||
## Development Guide
|
||||
|
||||
### Code Quality
|
||||
|
||||
```bash
|
||||
# Linting
|
||||
ruff check .
|
||||
|
||||
# Formatting
|
||||
ruff format .
|
||||
|
||||
# Type checking
|
||||
mypy llmds --ignore-missing-imports # All type safety violations fixed
|
||||
|
||||
# Run all quality checks
|
||||
ruff check . && ruff format --check . && mypy llmds --ignore-missing-imports
|
||||
```
|
||||
|
||||
## Research-Based Growth Plan
|
||||
|
||||
This project is designed to integrate cutting-edge research from 6 key papers in the `papers/` directory. Below is the roadmap for future enhancements.
|
||||
|
||||
### Research Papers Overview
|
||||
|
||||
1. **Cache-Craft: Managing Chunk-Caches for Efficient Retrieval-Augmented Generation**
|
||||
- **Focus**: Chunk-level caching for RAG systems
|
||||
- **Impact**: 30-50% latency reduction for repeated queries
|
||||
- **Priority**: High (Phase 2)
|
||||
|
||||
2. **Efficient Vector Search on Disaggregated Memory with d-HNSW**
|
||||
- **Focus**: Distributed HNSW for large-scale deployments
|
||||
- **Impact**: Enables billion-scale vector search
|
||||
- **Priority**: Medium (Phase 3)
|
||||
|
||||
3. **Fair-Count-Min: Frequency Estimation under Equal Group-wise Approximation Factor**
|
||||
- **Focus**: Fairness in frequency estimation across groups
|
||||
- **Impact**: Ensures equal service quality across users/groups
|
||||
- **Priority**: Medium (Phase 1)
|
||||
|
||||
4. **Memory-efficient Sketch Acceleration for Handling Large Network Flows on FPGAs**
|
||||
- **Focus**: Hardware-aware sketch optimizations
|
||||
- **Impact**: 30-50% memory reduction for sketch data structures
|
||||
- **Priority**: Low (Phase 1)
|
||||
|
||||
5. **Survey of Filtered Approximate Nearest Neighbor Search over the Vector-Scalar Hybrid Data**
|
||||
- **Focus**: Combining vector and scalar (metadata) filtering
|
||||
- **Impact**: Enables complex queries without performance degradation
|
||||
- **Priority**: High (Phase 2)
|
||||
|
||||
6. **Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs**
|
||||
- **Focus**: Original HNSW paper (already implemented)
|
||||
- **Enhancement**: Robust algorithms and quality maintenance
|
||||
- **Priority**: Low (Phase 5)
|
||||
|
||||
### Implementation Roadmap
|
||||
|
||||
#### Phase 1: Quick Wins (Weeks 1-4)
|
||||
1. **Memory-Efficient Sketch** - Low effort, high value (30-50% memory reduction)
|
||||
2. **Fair Count-Min** - Important for production systems (2-3 weeks)
|
||||
|
||||
#### Phase 2: Core Features (Weeks 5-12)
|
||||
3. **Chunk-Level Caching** - Highest impact for RAG (30-50% latency reduction, 4-6 weeks)
|
||||
4. **Filtered Search** - Essential for production use (3-4 weeks)
|
||||
|
||||
#### Phase 3: Scale (Weeks 13-20)
|
||||
5. **Distributed HNSW** - Enables large-scale deployment (6-8 weeks)
|
||||
6. **Enhanced HNSW** - Polish and optimization (ongoing)
|
||||
|
||||
### Expected Performance Improvements
|
||||
|
||||
| Feature | Latency Reduction | Memory Reduction | Throughput Increase |
|
||||
|---------|------------------|------------------|-------------------|
|
||||
| Chunk Caching | 30-50% | 10-20% | 20-40% |
|
||||
| Filtered Search | <10% overhead | +5-10% | Maintained |
|
||||
| Distributed HNSW | <5% overhead | Linear scaling | Linear scaling |
|
||||
| Fair Count-Min | 0% | 0% | Maintained |
|
||||
| Memory-Efficient Sketch | <5% | 30-50% | +10-20% |
|
||||
|
||||
### New Modules Planned
|
||||
|
||||
```
|
||||
llmds/
|
||||
├── chunk_cache.py # NEW: Chunk-level caching (Paper #1)
|
||||
├── filtered_hnsw.py # NEW: Filtered search (Paper #5)
|
||||
├── query_filters.py # NEW: Filter query language (Paper #5)
|
||||
├── distributed_hnsw.py # NEW: Distributed HNSW (Paper #2)
|
||||
├── fair_cmsketch.py # NEW: Fair Count-Min (Paper #3)
|
||||
└── sparse_cmsketch.py # NEW: Memory-efficient sketch (Paper #4)
|
||||
```
|
||||
|
||||
### Technical Implementation Details
|
||||
|
||||
#### Priority 1: Chunk-Level Caching (Cache-Craft)
|
||||
|
||||
**Architecture:**
|
||||
- **Chunk Identification**: Track chunks at a finer granularity than documents
|
||||
- **Chunk Metadata**: Store access patterns, relevance scores, chunk sizes
|
||||
- **Chunk Reuse**: Detect when chunks appear in multiple queries
|
||||
- **Adaptive Eviction**: Chunk-aware eviction policies
|
||||
|
||||
**Implementation Structure:**
|
||||
```python
|
||||
# llmds/chunk_cache.py
|
||||
class Chunk:
|
||||
"""Represents a document chunk with metadata."""
|
||||
chunk_id: int
|
||||
doc_id: int
|
||||
start_pos: int
|
||||
end_pos: int
|
||||
embedding: np.ndarray
|
||||
text: str
|
||||
access_count: int
|
||||
last_accessed: float
|
||||
relevance_score: float
|
||||
|
||||
class ChunkCache:
|
||||
"""Chunk-level cache with reuse detection."""
|
||||
def get_chunks(self, chunk_ids: list[int]) -> list[Chunk]
|
||||
def add_chunks(self, chunks: list[Chunk])
|
||||
def detect_reuse(self, query_results: list[tuple[int, float]]) -> dict
|
||||
```
|
||||
|
||||
#### Priority 2: Filtered Vector-Scalar Search
|
||||
|
||||
**Architecture:**
|
||||
- **Filter Query Language**: Support complex filter predicates
|
||||
- **Filter-Aware Indexing**: Index both vectors and scalar attributes
|
||||
- **Filter Pushdown**: Apply filters during index traversal
|
||||
- **Boolean Filter Support**: AND/OR/NOT combinations
|
||||
|
||||
**Implementation Structure:**
|
||||
```python
|
||||
# llmds/query_filters.py
|
||||
class Filter: # Base class for filter predicates
|
||||
class RangeFilter(Filter): # field BETWEEN min AND max
|
||||
class EqualityFilter(Filter): # field == value
|
||||
class SetFilter(Filter): # field IN [values]
|
||||
class CompositeFilter(Filter): # Boolean combinations
|
||||
|
||||
# llmds/filtered_hnsw.py
|
||||
class FilteredHNSW(HNSW):
|
||||
"""HNSW with scalar attribute filtering."""
|
||||
def search_with_filter(self, query, k: int, filter: Filter)
|
||||
```
|
||||
|
||||
#### Priority 3: Distributed HNSW (d-HNSW)
|
||||
|
||||
**Architecture:**
|
||||
- **Consistent Hashing**: Distribute vectors across nodes
|
||||
- **Cross-Partition Search**: Efficiently search across partitions
|
||||
- **Replication Strategy**: Optional vector replication for availability
|
||||
- **Query Routing**: Route queries to relevant partitions
|
||||
|
||||
**Implementation Structure:**
|
||||
```python
|
||||
# llmds/distributed_hnsw.py
|
||||
class DistributedHNSW:
|
||||
"""Distributed HNSW across multiple nodes."""
|
||||
def __init__(self, nodes: list[str], replication_factor: int = 1)
|
||||
def add(self, vec, vec_id) # Hash to partition, add to primary + replicas
|
||||
def search(self, query, k: int) # Search all partitions, merge results
|
||||
```
|
||||
|
||||
#### Priority 4: Fair Count-Min Sketch
|
||||
|
||||
**Architecture:**
|
||||
- **Group Tracking**: Track multiple groups with equal error bounds
|
||||
- **Fair Estimation**: Guarantee equal approximation factors per group
|
||||
- **Group Statistics**: Report fairness metrics
|
||||
|
||||
**Implementation Structure:**
|
||||
```python
|
||||
# llmds/fair_cmsketch.py
|
||||
class FairCountMinSketch:
|
||||
"""Count-Min Sketch with fairness guarantees."""
|
||||
def __init__(self, width: int, depth: int, groups: list[str])
|
||||
def add(self, item: str, group: str, count: int = 1)
|
||||
def estimate(self, item: str, group: str) -> int
|
||||
def get_fairness_metrics(self) -> dict
|
||||
```
|
||||
|
||||
### Integration Roadmap
|
||||
|
||||
#### Phase 1: Chunk Caching (4-6 weeks)
|
||||
1. Week 1-2: Implement `Chunk` and `ChunkCache` classes
|
||||
2. Week 3: Integrate with `RetrievalPipeline`
|
||||
3. Week 4: Add chunk reuse detection
|
||||
4. Week 5: Implement prefetching
|
||||
5. Week 6: Benchmark and optimize
|
||||
|
||||
#### Phase 2: Filtered Search (3-4 weeks)
|
||||
1. Week 1: Design filter query API
|
||||
2. Week 2: Implement `FilteredHNSW` with scalar indexing
|
||||
3. Week 3: Add filter pushdown strategies
|
||||
4. Week 4: Benchmark filtered search performance
|
||||
|
||||
#### Phase 3: Distributed HNSW (6-8 weeks)
|
||||
1. Week 1-2: Design distributed architecture
|
||||
2. Week 3: Implement consistent hashing
|
||||
3. Week 4-5: Implement cross-partition search
|
||||
4. Week 6: Add replication
|
||||
5. Week 7-8: Testing and optimization
|
||||
|
||||
#### Phase 4: Fairness (2-3 weeks)
|
||||
1. Week 1: Implement `FairCountMinSketch`
|
||||
2. Week 2: Add fairness metrics
|
||||
3. Week 3: Benchmark fairness guarantees
|
||||
|
||||
### Performance Targets
|
||||
|
||||
- **Chunk Caching**: 30-50% reduction in retrieval latency for repeated queries, 40-60% cache hit rate
|
||||
- **Filtered Search**: <10% overhead compared to unfiltered search, support filters with >90% selectivity efficiently
|
||||
- **Distributed HNSW**: Linear scalability with number of nodes, <5% overhead for cross-partition queries
|
||||
- **Fair Count-Min**: Equal error bounds across groups (±5% variance)
|
||||
|
||||
## Documentation
|
||||
|
||||
- [**Architecture Overview**](docs/architecture.md) - System architecture and design decisions
|
||||
- [**API Reference**](docs/api.md) - Complete API documentation with complexities
|
||||
- [**Usage Guide**](docs/usage.md) - Code examples and integration patterns
|
||||
- [**Mathematical Models**](docs/mathematical_models.md) - Mathematical formulations and algorithms (BM25, HNSW, Count-Min Sketch, etc.)
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this library in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@software{llm_rag_ds_optimizer,
|
||||
title = {LLM RAG Data Structures Optimizer},
|
||||
author = {Gutierrez, Carlos},
|
||||
email = {cgutierrez44833@ucumberlands.edu},
|
||||
year = {2025},
|
||||
url = {https://github.com/CarGDev/llm-rag-ds-optimizer}
|
||||
}
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions! This section provides guidelines for contributing to the project.
|
||||
|
||||
### Getting Started
|
||||
|
||||
1. Fork the repository
|
||||
2. Clone your fork: `git clone https://github.com/yourusername/llm-rag-ds-optimizer.git`
|
||||
3. Create a branch: `git checkout -b feature/your-feature-name`
|
||||
4. Install dependencies: `poetry install` or `pip install -e ".[dev]"`
|
||||
5. Make your changes
|
||||
6. Submit a pull request
|
||||
|
||||
### Development Guidelines
|
||||
|
||||
**Code Style:**
|
||||
- Follow PEP 8 style guidelines
|
||||
- Use `ruff` for linting and formatting
|
||||
- Run `ruff check .` and `ruff format .` before committing
|
||||
- Type hints are required for all public APIs
|
||||
|
||||
**Documentation:**
|
||||
- Update docstrings for all new functions/classes (Google/NumPy style)
|
||||
- Update API documentation if adding new public APIs
|
||||
- Update README if adding new features
|
||||
|
||||
**Commit Messages:**
|
||||
- Use clear, descriptive commit messages
|
||||
- Follow conventional commits format:
|
||||
- `feat:` for new features
|
||||
- `fix:` for bug fixes
|
||||
- `docs:` for documentation
|
||||
- `refactor:` for code refactoring
|
||||
|
||||
### Pull Request Process
|
||||
|
||||
1. Run linting and formatting checks
|
||||
2. Update documentation as needed
|
||||
3. Submit a pull request with a clear description
|
||||
4. Address review feedback promptly
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
- Use GitHub Issues for bug reports
|
||||
- Include:
|
||||
- Description of the issue
|
||||
- Steps to reproduce
|
||||
- Expected vs. actual behavior
|
||||
- Environment information (Python version, OS, etc.)
|
||||
|
||||
## License
|
||||
|
||||
MIT License - see [LICENSE](LICENSE) file for details.
|
||||
|
||||
## Code of Conduct
|
||||
|
||||
### Our Pledge
|
||||
|
||||
We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
|
||||
|
||||
### Our Standards
|
||||
|
||||
**Examples of behavior that contributes to a positive environment:**
|
||||
- Using welcoming and inclusive language
|
||||
- Being respectful of differing viewpoints and experiences
|
||||
- Gracefully accepting constructive criticism
|
||||
- Focusing on what is best for the community
|
||||
- Showing empathy towards other community members
|
||||
|
||||
**Examples of unacceptable behavior:**
|
||||
- The use of sexualized language or imagery, and sexual attention or advances of any kind
|
||||
- Trolling, insulting/derogatory comments, and personal or political attacks
|
||||
- Public or private harassment
|
||||
- Publishing others' private information, such as a physical or email address, without their explicit permission
|
||||
- Other conduct which could reasonably be considered inappropriate in a professional setting
|
||||
|
||||
### Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement. All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
*This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 2.0.*
|
||||
|
||||
---
|
||||
|
||||
**Status**: Production-ready core implementation. Research integration roadmap available for future enhancements.
|
||||
|
||||
## Glossary
|
||||
|
||||
This glossary defines specialized terms and abbreviations used throughout this project.
|
||||
|
||||
### Performance Metrics
|
||||
|
||||
**P50, P95, P99 (Percentiles)**: Statistical measures of latency distribution.
|
||||
- **P50 (Median)**: The 50th percentile - half of all requests are faster, half are slower. Represents typical performance.
|
||||
- **P95**: The 95th percentile - 95% of requests are faster than this value. Captures tail latency, important for user experience.
|
||||
- **P99**: The 99th percentile - 99% of requests are faster. Used to understand worst-case scenarios and outliers.
|
||||
|
||||
*Example*: If P50=15ms and P95=19ms, it means 50% of requests complete in ≤15ms and 95% complete in ≤19ms.
|
||||
|
||||
**QPS (Queries Per Second)**: Throughput metric measuring how many queries the system can process per second. Higher QPS indicates better throughput.
|
||||
|
||||
**Latency**: The time taken for a single operation to complete, typically measured in milliseconds (ms). Lower latency indicates faster response times.
|
||||
|
||||
### Data Structures & Algorithms
|
||||
|
||||
**HNSW (Hierarchical Navigable Small World)**: A graph-based algorithm for approximate nearest neighbor search. Provides logarithmic search complexity with high recall. Parameters:
|
||||
- **M**: Maximum number of connections per node in the graph
|
||||
- **efConstruction**: Controls graph quality during index building
|
||||
- **efSearch**: Number of candidates to explore during search (higher = better recall, slower)
|
||||
|
||||
**BM25 (Best Matching 25)**: A probabilistic ranking function for information retrieval. Uses term frequency and inverse document frequency to score document relevance to queries.
|
||||
|
||||
**KV Cache (Key-Value Cache)**: Stores precomputed key-value pairs from transformer attention layers to avoid redundant computation for repeated prefixes.
|
||||
|
||||
**Inverted Index**: A data structure mapping terms (words) to documents containing them. Enables fast text search by allowing direct lookup of documents containing query terms.
|
||||
|
||||
**Count-Min Sketch**: A probabilistic data structure for frequency estimation with bounded error. Used for hot query detection and cache optimization.
|
||||
|
||||
### Retrieval & RAG
|
||||
|
||||
**RAG (Retrieval-Augmented Generation)**: An approach where LLMs generate responses using information retrieved from external knowledge bases, improving accuracy and reducing hallucination.
|
||||
|
||||
**ANN (Approximate Nearest Neighbor)**: Algorithms that find similar vectors quickly, trading exact results for speed. HNSW is an ANN algorithm.
|
||||
|
||||
**Hybrid Search**: Combining dense vector search (semantic similarity) with sparse keyword search (BM25) for better retrieval quality.
|
||||
|
||||
**Recall@K**: Retrieval metric - the percentage of relevant documents found in the top-K results. Higher recall means more relevant results are retrieved.
|
||||
|
||||
**Score Fusion**: Combining scores from multiple retrieval methods (e.g., BM25 + vector similarity) into a single ranked list.
|
||||
|
||||
### System Terms
|
||||
|
||||
**Micro-batching**: Grouping multiple requests together for parallel processing, improving GPU utilization and throughput.
|
||||
|
||||
**Admission Control**: System that decides whether to accept or reject incoming requests based on current load and resource availability.
|
||||
|
||||
**Rate Limiting**: Controlling the number of requests processed per unit time to prevent system overload.
|
||||
|
||||
**Token Budget**: Maximum number of tokens (sub-word units) that can be cached or processed within memory constraints.
|
||||
|
||||
**Prefix Sharing**: Technique where identical prompt prefixes across multiple sequences are stored once, reducing memory usage.
|
||||
|
||||
### Dataset & Evaluation
|
||||
|
||||
**BEIR (Benchmarking IR)**: A collection of diverse information retrieval benchmarks covering multiple domains.
|
||||
|
||||
**MS MARCO**: Large-scale passage ranking dataset used as a standard benchmark for information retrieval systems.
|
||||
|
||||
**FIQA**: Financial question-answering dataset from BEIR, used for domain-specific retrieval evaluation.
|
||||
|
||||
**Corpus**: A collection of documents used for indexing and retrieval testing.
|
||||
|
||||
**JSONL (JSON Lines)**: File format where each line is a valid JSON object, commonly used for large datasets.
|
||||
|
||||
### Technical Abbreviations
|
||||
|
||||
**LLM (Large Language Model)**: AI models trained on massive text corpora to understand and generate human-like text.
|
||||
|
||||
**IR (Information Retrieval)**: The field of study focused on finding relevant information from large collections of documents.
|
||||
|
||||
**API (Application Programming Interface)**: A set of functions and protocols for interacting with software components.
|
||||
|
||||
**O(log n)**: Logarithmic time complexity - the time to complete an operation grows logarithmically with input size, indicating efficient algorithms.
|
||||
|
||||
## Appendix
|
||||
|
||||
### Additional Resources
|
||||
|
||||
**Research Papers**: See `papers/` directory and `docs/CITATIONS.md` for referenced research papers.
|
||||
|
||||
**Primary Citations:**
|
||||
- **HNSW**: Malkov & Yashunin (2018). Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs. IEEE TPAMI, 42(4), 824-836.
|
||||
- **KV Cache**: Cache-Craft: Managing Chunk-Caches for Efficient Retrieval-Augmented Generation
|
||||
- **Count-Min Sketch**: Cormode & Muthukrishnan (2005). An improved data stream summary: the count-min sketch and its applications. Journal of Algorithms, 55(1), 58-75.
|
||||
- **BM25**: Robertson & Zaragoza (2009). The probabilistic relevance framework: BM25 and beyond. Foundations and Trends in Information Retrieval, 3(4), 333-389.
|
||||
|
||||
**Additional Papers:**
|
||||
- d-HNSW: Distributed HNSW for disaggregated memory
|
||||
- Fair-Count-Min: Fairness in frequency estimation
|
||||
- Memory-efficient sketches
|
||||
- Survey of Filtered Approximate Nearest Neighbor Search
|
||||
|
||||
See `docs/CITATIONS.md` for complete citation mapping to implementation code.
|
||||
|
||||
**Dataset Licenses**:
|
||||
- **MS MARCO**: Research use only - see [MS MARCO Terms](https://microsoft.github.io/msmarco/)
|
||||
- **BEIR (FIQA)**: Varies by task - check individual task licenses (typically CC-BY or similar)
|
||||
- **Amazon Reviews 2023**: CC BY 4.0
|
||||
- **Yelp**: See [Yelp Dataset License](https://www.yelp.com/dataset/license) - Research use allowed
|
||||
- **Wikipedia**: CC BY-SA 3.0 / GFDL
|
||||
|
||||
**Reproducibility Notes**:
|
||||
- All benchmarks use deterministic seeds (42) for reproducibility
|
||||
- **HNSW seed control**: The `HNSW` class accepts an optional `seed` parameter for reproducible graph structure. When a seed is provided, each HNSW instance uses its own `random.Random(seed)` state for level assignments, ensuring identical graph structures across runs.
|
||||
- Embeddings are generated deterministically based on document IDs
|
||||
- Benchmark results include hardware specifications
|
||||
- Exact corpus sizes and parameters are documented in result files
|
||||
|
||||
**HNSW Seed Usage**:
|
||||
```python
|
||||
from llmds.hnsw import HNSW
|
||||
|
||||
# Reproducible HNSW with fixed seed
|
||||
hnsw = HNSW(dim=384, M=16, ef_construction=200, ef_search=50, seed=42)
|
||||
|
||||
# Or use RetrievalPipeline (automatically uses seed=42 in benchmarks)
|
||||
from llmds.retrieval_pipeline import RetrievalPipeline
|
||||
pipeline = RetrievalPipeline(embedding_dim=384, seed=42)
|
||||
```
|
||||
|
||||
**Dependency Management**:
|
||||
- **Poetry**: Use `poetry.lock` (when available) for exact version pinning
|
||||
```bash
|
||||
poetry install # Uses poetry.lock for reproducible builds
|
||||
```
|
||||
- **pip**: Use `requirements.txt` and `requirements-dev.txt` for compatible version ranges
|
||||
```bash
|
||||
pip install -r requirements-dev.txt # Install all dependencies
|
||||
```
|
||||
- Both methods ensure reproducible builds across different environments
|
||||
- Python version: >=3.11 (see `.python-version` or `pyproject.toml`)
|
||||
|
||||
**Performance Baseline**:
|
||||
- Synthetic benchmarks (small data): Sub-millisecond latencies typical
|
||||
- Real corpus benchmarks (large data): Higher latencies due to realistic data distribution, cache behavior, and memory access patterns
|
||||
- Production systems typically see 10-100x latency increase from synthetic to real data
|
||||
|
||||
**Hardware Used for Benchmarks**:
|
||||
- System: macOS (Apple Silicon)
|
||||
- Python: 3.14.0
|
||||
- Performance varies by hardware and dataset characteristics
|
||||
|
||||
### Benchmark Result Files
|
||||
|
||||
- **Individual results**: `benchmarks/results/{corpus}/{date}/results.json`
|
||||
- **Combined CSV**: `benchmarks/results/benchmark_results.csv`
|
||||
- **Visualizations**: `benchmarks/figures/*.png`
|
||||
|
||||
### Contact & Support
|
||||
|
||||
For questions, issues, or contributions, please see:
|
||||
- **Contributing**: See [Contributing](#contributing) section above
|
||||
- **Code of Conduct**: See [Code of Conduct](#code-of-conduct) section above
|
||||
- **GitHub Issues**: Report bugs or request features via GitHub Issues
|
||||
227
data/README.md
Normal file
227
data/README.md
Normal file
@@ -0,0 +1,227 @@
|
||||
# Dataset Sources and Licenses
|
||||
|
||||
This document describes the datasets used for benchmarking the LLM RAG Data Structures Optimizer. All datasets are publicly available and suitable for research use.
|
||||
|
||||
## Datasets
|
||||
|
||||
### Datasets with Published Benchmark Results
|
||||
|
||||
We benchmark on three publicly available datasets with published results:
|
||||
|
||||
### 1. BEIR FIQA (Financial Question Answering)
|
||||
|
||||
**Source**: [BEIR Paper](https://arxiv.org/abs/2104.08663) | [Hugging Face Datasets](https://huggingface.co/datasets/BeIR)
|
||||
|
||||
**Description**: Financial question-answering dataset from BEIR benchmark suite. 50,000 documents with financial Q&A pairs. Used as primary evaluation dataset in our research.
|
||||
|
||||
**License**: Varies by task. Most BEIR tasks use CC-BY or similar open licenses. Check individual task licenses.
|
||||
|
||||
**Download**:
|
||||
```bash
|
||||
python scripts/download_corpus.py --source beir:fiqa --output data/raw/beir/fiqa
|
||||
```
|
||||
|
||||
**Citation**:
|
||||
```
|
||||
Thakur, N., et al. (2021). BEIR: A Heterogeneous Benchmark for Zero-shot Evaluation of Information Retrieval Models.
|
||||
```
|
||||
|
||||
### 2. Amazon Reviews 2023 (McAuley Lab)
|
||||
|
||||
**Source**: [Hugging Face - McAuley-Lab/Amazon-Reviews-2023](https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023)
|
||||
|
||||
**Description**: Large corpus of Amazon product reviews with metadata (ratings, categories, product IDs). Excellent for e-commerce-style RAG workloads. Benchmark results available for 10k subset.
|
||||
|
||||
**License**: CC BY 4.0
|
||||
|
||||
**Download**:
|
||||
```bash
|
||||
python scripts/download_corpus.py --source amazon23 --output data/raw/amazon23 --limit 500000
|
||||
```
|
||||
|
||||
**Note**: Full dataset is very large (>100M reviews). Use `--limit` for manageable subsets. Benchmark results use 10k document subset.
|
||||
|
||||
### 3. MS MARCO (Microsoft Machine Reading Comprehension)
|
||||
|
||||
**Source**: [MS MARCO Datasets](https://microsoft.github.io/msmarco/)
|
||||
|
||||
**Description**: Large-scale passage ranking dataset with 8.8M passages and 1M queries. Widely used as a canonical information retrieval benchmark. Benchmark results available for 10k subset.
|
||||
|
||||
**License**: Research use only. See [MS MARCO Terms](https://microsoft.github.io/msmarco/) for details.
|
||||
|
||||
**Download**:
|
||||
```bash
|
||||
python scripts/download_corpus.py --source msmarco --output data/raw/msmarco
|
||||
```
|
||||
|
||||
**Citation**:
|
||||
```
|
||||
Bajaj, P., et al. (2016). MS MARCO: A human generated machine reading comprehension dataset.
|
||||
```
|
||||
|
||||
### Additional Available Datasets
|
||||
|
||||
The following datasets are available in the codebase but do not yet have published benchmark results:
|
||||
|
||||
#### 4. Yelp Open Dataset
|
||||
|
||||
**Source**: [Yelp Open Dataset](https://www.yelp.com/dataset)
|
||||
|
||||
**Description**: Business listings and reviews from Yelp. Useful for local business and review-based RAG.
|
||||
|
||||
**License**: See [Yelp Dataset License](https://www.yelp.com/dataset/license). Research use allowed.
|
||||
|
||||
**Download**:
|
||||
```bash
|
||||
# First accept license at https://www.yelp.com/dataset/download
|
||||
python scripts/download_corpus.py --source yelp --output data/raw/yelp
|
||||
```
|
||||
|
||||
#### 5. Wikipedia (English)
|
||||
|
||||
**Source**: [Wikimedia Downloads](https://dumps.wikimedia.org/enwiki/latest/)
|
||||
|
||||
**Description**: English Wikipedia pages-articles dump. Broad factual corpus for general knowledge RAG.
|
||||
|
||||
**License**: CC BY-SA 3.0 and GFDL
|
||||
|
||||
**Download**:
|
||||
```bash
|
||||
python scripts/download_corpus.py --source wikipedia --output data/raw/wikipedia
|
||||
```
|
||||
|
||||
**Note**: Latest dump is ~20GB compressed. Extracts plain text and titles.
|
||||
|
||||
#### 6. Common Crawl (Optional)
|
||||
|
||||
**Source**: [Common Crawl](https://commoncrawl.org/) | [cc-downloader](https://github.com/commoncrawl/cc-downloader)
|
||||
|
||||
**Description**: Web-scale corpus from billions of web pages. Use for large-scale testing.
|
||||
|
||||
**License**: Public domain / various site licenses
|
||||
|
||||
**Download**:
|
||||
```bash
|
||||
# Be respectful of bandwidth - use specific months
|
||||
python scripts/download_corpus.py --source commoncrawl --cc-month CC-MAIN-2025-14 --output data/raw/cc --limit 10M
|
||||
```
|
||||
|
||||
**Note**: Common Crawl is extremely large. Use `--limit` and specific months for reproducible, manageable subsets.
|
||||
|
||||
## Data Format
|
||||
|
||||
All datasets are normalized to JSONL format:
|
||||
|
||||
```json
|
||||
{"id": "doc_123", "text": "Document text content...", "meta": {"field1": "value1", "field2": 42}}
|
||||
```
|
||||
|
||||
Each line contains:
|
||||
- `id`: Unique document identifier
|
||||
- `text`: Main text content
|
||||
- `meta`: Optional metadata (ratings, categories, timestamps, etc.)
|
||||
|
||||
## Checksums
|
||||
|
||||
Dataset checksums are stored in `data/dataset_cards/` as YAML files:
|
||||
|
||||
```yaml
|
||||
name: amazon_reviews_2023
|
||||
source: https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023
|
||||
license: CC BY 4.0
|
||||
sha256: <checksum>
|
||||
size_bytes: <size>
|
||||
download_date: 2024-10-30
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Download All Datasets
|
||||
|
||||
```bash
|
||||
# Create directories
|
||||
mkdir -p data/raw data/processed data/indices data/embeddings data/dataset_cards
|
||||
|
||||
# Download datasets (start with smaller ones)
|
||||
python scripts/download_corpus.py --source beir:fiqa --output data/raw/beir/fiqa
|
||||
python scripts/download_corpus.py --source amazon23 --output data/raw/amazon23 --limit 500000
|
||||
python scripts/download_corpus.py --source msmarco --output data/raw/msmarco
|
||||
```
|
||||
|
||||
### Prepare Embeddings
|
||||
|
||||
```bash
|
||||
python scripts/prepare_embeddings.py \
|
||||
--input data/raw/beir/fiqa/corpus.jsonl \
|
||||
--output data/embeddings/fiqa.npy \
|
||||
--dim 384 \
|
||||
--seed 42
|
||||
```
|
||||
|
||||
### Build Indices
|
||||
|
||||
```bash
|
||||
python scripts/build_indices.py \
|
||||
--corpus data/raw/beir/fiqa/corpus.jsonl \
|
||||
--emb data/embeddings/fiqa.npy \
|
||||
--index-dir data/indices/fiqa \
|
||||
--bm25 \
|
||||
--hnsw \
|
||||
--ef 200 \
|
||||
--M 16
|
||||
```
|
||||
|
||||
### Run Benchmarks
|
||||
|
||||
```bash
|
||||
python scripts/run_benchmarks.py \
|
||||
--corpus fiqa \
|
||||
--sizes 10k 50k 100k \
|
||||
--ef 50 100 200 \
|
||||
--M 8 16 32 \
|
||||
--repetitions 5
|
||||
```
|
||||
|
||||
## License Compliance
|
||||
|
||||
**Important**:
|
||||
- Always check individual dataset licenses before use
|
||||
- **MS MARCO**: Research use only
|
||||
- **Amazon Reviews**: CC BY 4.0
|
||||
- **BEIR (FIQA)**: Varies by task, typically CC-BY or similar open licenses
|
||||
|
||||
**Do NOT**:
|
||||
- Scrape websites without permission
|
||||
- Redistribute datasets without proper attribution
|
||||
- Use commercial datasets for commercial purposes without checking licenses
|
||||
|
||||
## Reproducibility
|
||||
|
||||
All dataset processing is deterministic:
|
||||
- Fixed random seeds (42) for sampling and embeddings
|
||||
- SHA256 checksums for verification
|
||||
- Versioned dataset cards with download dates
|
||||
- Exact corpus sizes documented in benchmark results
|
||||
|
||||
## Dataset Statistics
|
||||
|
||||
### Datasets with Published Results
|
||||
|
||||
| Dataset | Documents | Size | License | Use Case | Benchmark Results |
|
||||
|---------|-----------|------|---------|----------|-------------------|
|
||||
| BEIR (FIQA) | 50,000 | ~13MB | Varies | Financial QA | Yes (10k, 25k, 50k subsets) |
|
||||
| Amazon Reviews 2023 | 100M+ | ~500GB+ | CC BY 4.0 | E-commerce | Yes (10k subset) |
|
||||
| MS MARCO | 8.8M passages | ~30GB | Research | IR benchmark | Yes (10k subset) |
|
||||
|
||||
### Available Datasets (No Published Results Yet)
|
||||
|
||||
| Dataset | Documents | Size | License | Use Case | Status |
|
||||
|---------|-----------|------|---------|----------|--------|
|
||||
| Yelp | ~8M businesses | ~8GB | Yelp License | Local business | Data available, no results |
|
||||
| Wikipedia | 6.7M articles | ~20GB | CC BY-SA 3.0 | General knowledge | Data available, no results |
|
||||
| Common Crawl | Billions | TB+ | Public domain | Web-scale | Code available, optional |
|
||||
|
||||
**Note**: Benchmark results are available for 10k document subsets of FIQA, Amazon23, and MS MARCO. FIQA has additional results for 25k and 50k document subsets. Yelp and Wikipedia datasets are available in the codebase but do not yet have published benchmark results.
|
||||
|
||||
*Full dataset statistics are approximate and vary by version. Benchmark results use manageable subsets for reproducible evaluation.*
|
||||
|
||||
186
docs/CITATIONS.md
Normal file
186
docs/CITATIONS.md
Normal file
@@ -0,0 +1,186 @@
|
||||
# Research Citations and Implementation Mapping
|
||||
|
||||
This document maps research papers to their implementations in the codebase.
|
||||
|
||||
## HNSW (Hierarchical Navigable Small World)
|
||||
|
||||
**Implementation:** `llmds/hnsw.py`
|
||||
|
||||
**Primary Citation:**
|
||||
- Malkov, Y. A., & Yashunin, D. A. (2018). Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs. IEEE transactions on pattern analysis and machine intelligence, 42(4), 824-836.
|
||||
|
||||
**Related Papers:**
|
||||
- Efficient Vector Search on Disaggregated Memory with d-HNSW (for memory-efficient variations)
|
||||
|
||||
**Techniques Implemented:**
|
||||
- Hierarchical multi-layer graph structure (`_layers`)
|
||||
- Greedy search algorithm (`_search_layer`)
|
||||
- Level assignment with exponential distribution (`_random_level`)
|
||||
- Entry point selection and navigation
|
||||
- Dynamic connection management (M parameter)
|
||||
- efConstruction and efSearch parameters for quality/speed trade-offs
|
||||
|
||||
**Code References:**
|
||||
- `HNSW` class: Main implementation
|
||||
- `_random_level()`: Level assignment following exponential distribution
|
||||
- `_search_layer()`: Greedy search in a single layer
|
||||
- `add()`: Vector insertion with connection management
|
||||
- `search()`: Multi-layer search from top to bottom
|
||||
|
||||
## KV Cache with Prefix Sharing
|
||||
|
||||
**Implementation:** `llmds/kv_cache.py`, `llmds/paged_allocator.py`
|
||||
|
||||
**Primary Citation:**
|
||||
- Cache-Craft: Managing Chunk-Caches for Efficient Retrieval-Augmented Generation (specific paper on KV cache optimization for RAG)
|
||||
|
||||
**Techniques Implemented:**
|
||||
- Paged allocation with fixed-size pages (`PagedAllocator`)
|
||||
- Prefix/prompt sharing with copy-on-write semantics (`KVCache._copy_if_shared`)
|
||||
- Hash-based deduplication (`_hash_prefix`)
|
||||
- Reference counting for shared pages (`_page_refs`)
|
||||
- Defensive copying to prevent corruption (`get()` returns deep copies)
|
||||
|
||||
**Code References:**
|
||||
- `KVCache` class: Main KV cache implementation
|
||||
- `PagedAllocator` class: Page-based memory management
|
||||
- `_copy_if_shared()`: Copy-on-write implementation
|
||||
- `_hash_prefix()`: SHA256-based prefix hashing
|
||||
- `attach()` / `detach()`: Sequence management with reference counting
|
||||
|
||||
## Count-Min Sketch
|
||||
|
||||
**Implementation:** `llmds/cmsketch.py`
|
||||
|
||||
**Primary Citations:**
|
||||
- Cormode, G., & Muthukrishnan, S. (2005). An improved data stream summary: the count-min sketch and its applications. Journal of Algorithms, 55(1), 58-75.
|
||||
- Fair-Count-Min: Frequency Estimation under Equal Group-wise Approximation Factor
|
||||
|
||||
**Techniques Implemented:**
|
||||
- Count-Min Sketch with multiple hash functions (`depth` parameter)
|
||||
- Conservative update strategy to reduce overestimation bias
|
||||
- Error bound calculation (`get_error_bound()`)
|
||||
- Hot item detection (`is_hot()`)
|
||||
|
||||
**Code References:**
|
||||
- `CountMinSketch` class: Main sketch implementation
|
||||
- `add()`: Conservative update algorithm
|
||||
- `estimate()`: Minimum across all hash rows
|
||||
- `get_error_bound()`: Theoretical error bound calculation
|
||||
- Uses MurmurHash3 for hash functions
|
||||
|
||||
## BM25 Inverted Index
|
||||
|
||||
**Implementation:** `llmds/inverted_index.py`
|
||||
|
||||
**Primary Citation:**
|
||||
- Robertson, S., & Zaragoza, H. (2009). The probabilistic relevance framework: BM25 and beyond. Foundations and Trends in Information Retrieval, 3(4), 333-389.
|
||||
|
||||
**Techniques Implemented:**
|
||||
- BM25 scoring formula with k1 and b parameters
|
||||
- Inverted index structure with compressed postings
|
||||
- Varint encoding for integer compression (`_encode_varint`)
|
||||
- Zigzag encoding for signed integers (`_zigzag_encode`)
|
||||
- Term frequency and document frequency tracking
|
||||
|
||||
**Code References:**
|
||||
- `InvertedIndex` class: Main inverted index implementation
|
||||
- `_bm25_score()`: BM25 scoring function
|
||||
- `add_document()`: Index construction
|
||||
- `search()`: BM25 retrieval
|
||||
|
||||
## Hybrid Retrieval (Dense + Sparse)
|
||||
|
||||
**Implementation:** `llmds/retrieval_pipeline.py`
|
||||
|
||||
**Primary Citation:**
|
||||
- Survey of Filtered Approximate Nearest Neighbor Search over the Vector-Scalar Hybrid Data
|
||||
|
||||
**Techniques Implemented:**
|
||||
- Hybrid dense (HNSW) + sparse (BM25) retrieval
|
||||
- Score fusion with configurable weights (`fusion_weight` parameter)
|
||||
- Top-K maintenance using indexed heap
|
||||
- Hot query caching using Count-Min Sketch
|
||||
|
||||
**Code References:**
|
||||
- `RetrievalPipeline` class: End-to-end hybrid retrieval
|
||||
- `search()`: Combines HNSW and BM25 with score fusion
|
||||
- Uses `IndexedHeap` for efficient top-K maintenance
|
||||
|
||||
## Indexed Heap
|
||||
|
||||
**Implementation:** `llmds/indexed_heap.py`
|
||||
|
||||
**Techniques Implemented:**
|
||||
- Indexed binary heap for O(log n) priority updates
|
||||
- Support for both min-heap and max-heap
|
||||
- O(1) key lookup via position map (`_pos` dictionary)
|
||||
- Decrease/increase key operations with correct bubble directions
|
||||
|
||||
**Code References:**
|
||||
- `IndexedHeap` class: Indexed heap implementation
|
||||
- `decrease_key()` / `increase_key()`: Key update operations
|
||||
- `_bubble_up()` / `_bubble_down()`: Heap property maintenance
|
||||
|
||||
## Scheduler and Batching
|
||||
|
||||
**Implementation:** `llmds/scheduler.py`, `llmds/admissions.py`
|
||||
|
||||
**Techniques Implemented:**
|
||||
- Dynamic micro-batching with configurable wait time
|
||||
- Priority queue using indexed heap
|
||||
- Admission control with QPS and token rate limiting
|
||||
- Moving window average for rate tracking
|
||||
|
||||
**Code References:**
|
||||
- `Scheduler` class: Batching scheduler
|
||||
- `AdmissionController` class: Rate limiting and admission control
|
||||
- Uses `IndexedHeap` for priority queue
|
||||
|
||||
## Token-Aware LRU
|
||||
|
||||
**Implementation:** `llmds/token_lru.py`
|
||||
|
||||
**Techniques Implemented:**
|
||||
- LRU eviction with token-aware budgeting
|
||||
- Cumulative token tracking across cache entries
|
||||
- Eviction based on token count rather than entry count
|
||||
|
||||
**Code References:**
|
||||
- `TokenLRU` class: Token-aware LRU cache
|
||||
- `total_tokens()`: Cumulative token tracking
|
||||
- `put()`: Token-aware insertion with eviction
|
||||
|
||||
---
|
||||
|
||||
## How to Cite
|
||||
|
||||
### Citing This Software
|
||||
|
||||
If you use this codebase in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@software{llm_rag_ds_optimizer,
|
||||
title = {LLM RAG Data Structures Optimizer},
|
||||
author = {Gutierrez, Carlos},
|
||||
email = {cgutierrez44833@ucumberlands.edu},
|
||||
year = {2025},
|
||||
url = {https://github.com/CarGDev/llm-rag-ds-optimizer}
|
||||
}
|
||||
```
|
||||
|
||||
### Citing Related Papers
|
||||
|
||||
When using this codebase in research, please also cite the relevant papers:
|
||||
|
||||
1. **HNSW**: Cite Malkov & Yashunin (2018) for HNSW algorithm
|
||||
2. **KV Cache**: Cite Cache-Craft paper for prefix sharing techniques
|
||||
3. **Count-Min Sketch**: Cite Cormode & Muthukrishnan (2005) for Count-Min Sketch
|
||||
4. **BM25**: Cite Robertson & Zaragoza (2009) for BM25 scoring
|
||||
5. **Hybrid Retrieval**: Cite survey paper for hybrid dense+sparse approaches
|
||||
|
||||
## Additional References
|
||||
|
||||
- Papers in `papers/` directory contain full citations and implementation details
|
||||
- See `README.md` for usage examples and performance benchmarks
|
||||
|
||||
252
docs/api.md
Normal file
252
docs/api.md
Normal file
@@ -0,0 +1,252 @@
|
||||
# API Reference
|
||||
|
||||
## Core Modules
|
||||
|
||||
### `llmds.paged_allocator.PagedAllocator`
|
||||
|
||||
Paged memory allocator with slab allocation.
|
||||
|
||||
**Methods:**
|
||||
- `alloc(num_pages: int) -> list[int]`: Allocate pages
|
||||
- `free(page_ids: list[int]) -> None`: Free pages
|
||||
- `stats() -> PageStats`: Get allocation statistics
|
||||
- `defragment() -> None`: Defragment pages
|
||||
|
||||
**Complexity:** O(1) alloc/free, O(n) defragment
|
||||
|
||||
### `llmds.kv_cache.KVCache`
|
||||
|
||||
KV cache with prefix sharing and deduplication. Implements copy-on-write (COW) for safe prefix sharing.
|
||||
|
||||
**Parameters:**
|
||||
- `page_size: int = 512` - Size of each KV cache page in tokens
|
||||
- `max_pages: int = 10000` - Maximum number of pages to allocate
|
||||
- `enable_prefix_sharing: bool = True` - Enable prefix sharing optimization
|
||||
|
||||
**Methods:**
|
||||
- `attach(seq_id: int, kv_tokens: list, prefix_tokens: Optional[list] = None) -> None` - Attach KV cache for a sequence. Uses COW for shared pages.
|
||||
- `detach(seq_id: int) -> None` - Detach and free KV cache, with reference counting for shared pages
|
||||
- `get(seq_id: int) -> Optional[list]` - Get KV cache (returns deep copy to prevent external modification)
|
||||
- `stats() -> dict` - Get cache statistics including shared pages count and reference counts
|
||||
|
||||
**Complexity:** O(1) attach/get, O(k) detach where k = pages
|
||||
|
||||
**Copy-on-Write Semantics:**
|
||||
- Shared pages (from prefix sharing) are read-only until written
|
||||
- Writes to shared pages trigger lazy copying (COW)
|
||||
- Reference counting ensures shared pages are only freed when all references are released
|
||||
- `get()` returns deep copies to prevent external corruption of shared pages
|
||||
|
||||
**Safety:** All shared page operations are protected against data corruption through COW and defensive copying.
|
||||
|
||||
### `llmds.utils.Timer`
|
||||
|
||||
Simple timer context manager for measuring execution time.
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from llmds.utils import Timer
|
||||
|
||||
with Timer() as t:
|
||||
# Your code here
|
||||
pass
|
||||
elapsed_seconds = t.elapsed # Float representing elapsed time
|
||||
```
|
||||
|
||||
**Complexity:** O(1) for all operations
|
||||
|
||||
### `llmds.utils.MemoryProfiler`
|
||||
|
||||
Memory profiler for measuring peak RSS (Resident Set Size) during benchmarks.
|
||||
|
||||
**Methods:**
|
||||
- `start() -> None`: Start memory profiling
|
||||
- `sample() -> int`: Sample current RSS and update peak
|
||||
- `get_peak_rss_mb() -> float`: Get peak RSS in megabytes
|
||||
- `get_peak_rss_bytes() -> int`: Get peak RSS in bytes
|
||||
- `get_current_rss_mb() -> float`: Get current RSS in megabytes
|
||||
- `get_memory_delta_mb() -> float`: Get memory delta from initial RSS in megabytes
|
||||
|
||||
**Context Manager:**
|
||||
- `memory_profiler() -> Iterator[MemoryProfiler]`: Context manager for automatic profiling
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from llmds.utils import memory_profiler
|
||||
|
||||
with memory_profiler() as profiler:
|
||||
# Your code here
|
||||
profiler.sample() # Optional: sample at specific points
|
||||
peak_rss_mb = profiler.get_peak_rss_mb()
|
||||
```
|
||||
|
||||
**Complexity:** O(1) for all operations
|
||||
|
||||
### `llmds.utils.compute_percentiles`
|
||||
|
||||
Compute P50, P95, P99 percentiles from a list of values.
|
||||
|
||||
**Parameters:**
|
||||
- `values: list[float]` - List of numeric values
|
||||
|
||||
**Returns:**
|
||||
- `dict[str, float]` - Dictionary with `p50`, `p95`, `p99` keys
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from llmds.utils import compute_percentiles
|
||||
|
||||
latencies = [10.5, 12.3, 11.1, 15.2, 10.8, ...]
|
||||
percentiles = compute_percentiles(latencies)
|
||||
print(f"P50: {percentiles['p50']:.2f}ms")
|
||||
print(f"P95: {percentiles['p95']:.2f}ms")
|
||||
print(f"P99: {percentiles['p99']:.2f}ms")
|
||||
```
|
||||
|
||||
**Complexity:** O(n log n) where n = len(values)
|
||||
|
||||
### `llmds.utils.calculate_statistics`
|
||||
|
||||
Calculate comprehensive statistical summary for a list of values.
|
||||
|
||||
**Parameters:**
|
||||
- `values: list[float]` - List of numeric values
|
||||
- `confidence_level: float = 0.95` - Confidence level for intervals (e.g., 0.95 for 95% CI)
|
||||
|
||||
**Returns:**
|
||||
- `dict[str, Any]` - Dictionary containing:
|
||||
- `mean`: Mean value
|
||||
- `std`: Standard deviation (sample)
|
||||
- `min`: Minimum value
|
||||
- `max`: Maximum value
|
||||
- `p50`, `p95`, `p99`: Percentiles
|
||||
- `ci_lower`, `ci_upper`: Confidence interval bounds
|
||||
- `cv`: Coefficient of variation (%)
|
||||
- `count`: Number of values
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from llmds.utils import calculate_statistics
|
||||
|
||||
measurements = [10.5, 12.3, 11.1, 15.2, 10.8, ...]
|
||||
stats = calculate_statistics(measurements, confidence_level=0.95)
|
||||
print(f"Mean: {stats['mean']:.2f} ± {stats['std']:.2f}")
|
||||
print(f"95% CI: [{stats['ci_lower']:.2f}, {stats['ci_upper']:.2f}]")
|
||||
print(f"CV: {stats['cv']:.2f}%")
|
||||
```
|
||||
|
||||
**Complexity:** O(n log n) where n = len(values)
|
||||
|
||||
### `llmds.token_lru.TokenLRU`
|
||||
|
||||
Token-aware LRU cache with eviction until budget.
|
||||
|
||||
**Methods:**
|
||||
- `put(key: K, value: V) -> None`
|
||||
- `get(key: K) -> Optional[V]`
|
||||
- `evict_until_budget(target_budget: int) -> list[tuple[K, V]]`
|
||||
- `total_tokens() -> int`
|
||||
|
||||
**Complexity:** O(1) put/get, O(n) evict_until_budget
|
||||
|
||||
### `llmds.indexed_heap.IndexedHeap`
|
||||
|
||||
Indexed binary heap with decrease/increase-key operations. Supports both min-heap and max-heap modes.
|
||||
|
||||
**Parameters:**
|
||||
- `max_heap: bool = False` - If True, use max-heap (largest score at top), otherwise min-heap
|
||||
|
||||
**Methods:**
|
||||
- `push(key_id: int, score: float) -> None` - Add item to heap
|
||||
- `pop() -> tuple[float, int]` - Remove and return top element
|
||||
- `decrease_key(key_id: int, new_score: float) -> None` - Decrease key value (bubbles down for max-heap, up for min-heap)
|
||||
- `increase_key(key_id: int, new_score: float) -> None` - Increase key value (bubbles up for max-heap, down for min-heap)
|
||||
- `delete(key_id: int) -> tuple[float, int]` - Remove specific item
|
||||
- `get_score(key_id: int) -> Optional[float]` - Get score for key_id
|
||||
- `peek() -> Optional[tuple[float, int]]` - View top element without removing
|
||||
- `size() -> int` - Get number of elements
|
||||
- `is_empty() -> bool` - Check if heap is empty
|
||||
|
||||
**Complexity:** O(log n) for all operations
|
||||
|
||||
**Note:** Fixed max-heap bubble directions (v0.1.0) - `decrease_key` bubbles down and `increase_key` bubbles up for max-heap.
|
||||
|
||||
### `llmds.scheduler.Scheduler`
|
||||
|
||||
Dynamic micro-batching scheduler.
|
||||
|
||||
**Methods:**
|
||||
- `submit(tokens: int, slo_ms: Optional[float] = None) -> int`
|
||||
- `get_batch(force: bool = False) -> Optional[list[int]]`
|
||||
- `complete_batch(request_ids: list[int]) -> None`
|
||||
- `update_priority(request_id: int, new_tokens: int) -> None`
|
||||
|
||||
**Complexity:** O(log n) submit, O(k log n) get_batch where k = batch_size
|
||||
|
||||
### `llmds.admissions.AdmissionController`
|
||||
|
||||
Admission controller with rate limiting.
|
||||
|
||||
**Methods:**
|
||||
- `should_admit(estimated_tokens: int = 0) -> tuple[bool, str]`
|
||||
- `record_request(tokens: int) -> None`
|
||||
- `stats() -> dict`: Get admission statistics
|
||||
|
||||
**Complexity:** O(1) should_admit
|
||||
|
||||
### `llmds.inverted_index.InvertedIndex`
|
||||
|
||||
Compressed inverted index with BM25 scoring.
|
||||
|
||||
**Methods:**
|
||||
- `add_document(doc_id: int, text: str) -> None`
|
||||
- `search(query: str, top_k: int = 10) -> list[tuple[int, float]]`
|
||||
- `get_term_frequency(term: str, doc_id: int) -> int`
|
||||
- `get_document_frequency(term: str) -> int`
|
||||
|
||||
**Complexity:** O(|doc|) add_document, O(|query| × avg_doc_freq) search
|
||||
|
||||
### `llmds.hnsw.HNSW`
|
||||
|
||||
Hierarchical Navigable Small World graph for approximate nearest neighbor search.
|
||||
|
||||
**Parameters:**
|
||||
- `dim: int` - Dimension of vectors
|
||||
- `M: int = 16` - Maximum number of connections per node
|
||||
- `ef_construction: int = 200` - Size of candidate set during construction
|
||||
- `ef_search: int = 50` - Size of candidate set during search
|
||||
- `ml: float = 1.0 / log(2.0)` - Normalization factor for level assignment
|
||||
- `seed: Optional[int] = None` - Random seed for reproducible graph structure
|
||||
|
||||
**Methods:**
|
||||
- `add(vec: np.ndarray, vec_id: int) -> None` - Add vector to index
|
||||
- `search(query: np.ndarray, k: int) -> list[tuple[int, float]]` - Search for k nearest neighbors. Returns list of (vector_id, distance) tuples
|
||||
- `stats() -> dict` - Get index statistics (num_vectors, num_layers, entry_point, etc.)
|
||||
|
||||
**Complexity:** O(log n) search, O(log n × efConstruction) add
|
||||
|
||||
**Reproducibility:** When `seed` is provided, each HNSW instance uses its own `random.Random(seed)` state for level assignments, ensuring identical graph structures across runs with the same seed.
|
||||
|
||||
### `llmds.cmsketch.CountMinSketch`
|
||||
|
||||
Count-Min Sketch for frequency estimation.
|
||||
|
||||
**Methods:**
|
||||
- `add(item: str, count: int = 1) -> None`
|
||||
- `estimate(item: str) -> int`
|
||||
- `is_hot(item: str, threshold: int) -> bool`
|
||||
- `get_error_bound() -> float`
|
||||
|
||||
**Complexity:** O(depth) add/estimate
|
||||
|
||||
### `llmds.retrieval_pipeline.RetrievalPipeline`
|
||||
|
||||
End-to-end retrieval pipeline.
|
||||
|
||||
**Methods:**
|
||||
- `add_document(doc_id: int, text: str, embedding: Optional[np.ndarray] = None) -> None`
|
||||
- `search(query: str, query_embedding: Optional[np.ndarray] = None, top_k: int = 10, fusion_weight: float = 0.5) -> list[tuple[int, float]]`
|
||||
- `stats() -> dict`: Get pipeline statistics
|
||||
|
||||
**Complexity:** O(log n) search (HNSW) + O(|query| × avg_doc_freq) (BM25)
|
||||
|
||||
161
docs/architecture.md
Normal file
161
docs/architecture.md
Normal file
@@ -0,0 +1,161 @@
|
||||
# Architecture Overview
|
||||
|
||||
## System Architecture
|
||||
|
||||
The LLM Data Structures Optimizer is organized into several key subsystems:
|
||||
|
||||
### 1. KV Cache System
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────┐
|
||||
│ KVCache │
|
||||
│ ┌───────────────────────────────────┐ │
|
||||
│ │ Prefix Hash Map │ │
|
||||
│ │ (SHA256-based deduplication) │ │
|
||||
│ └───────────────────────────────────┘ │
|
||||
│ ┌───────────────────────────────────┐ │
|
||||
│ │ Sequence → Page Mapping │ │
|
||||
│ └───────────────────────────────────┘ │
|
||||
│ ┌───────────────────────────────────┐ │
|
||||
│ │ PagedAllocator │ │
|
||||
│ │ - Fixed-size pages │ │
|
||||
│ │ - Free-list management │ │
|
||||
│ │ - Defragmentation │ │
|
||||
│ └───────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**Key Features:**
|
||||
- **Copy-on-write (COW)** for prefix sharing - shared pages are read-only until modified, then lazily copied
|
||||
- **Reference counting** - shared pages are tracked and only freed when all references are released
|
||||
- **Hash-based deduplication** - identical prefixes are automatically detected and shared
|
||||
- **Page-level allocation granularity** - efficient memory management with fixed-size pages
|
||||
- **Defensive copying** - `get()` returns deep copies to prevent external modification of shared data
|
||||
|
||||
### 2. Scheduler & Batching
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────┐
|
||||
│ Scheduler │
|
||||
│ ┌───────────────────────────────────┐ │
|
||||
│ │ IndexedHeap (Max-Heap Priority Queue) │ │
|
||||
│ │ - O(log n) decrease/increase-key │ │
|
||||
│ │ - Priority by remaining tokens │ │
|
||||
│ │ - Fixed bubble directions (v0.1.0) │ │
|
||||
│ └───────────────────────────────────┘ │
|
||||
│ ┌───────────────────────────────────┐ │
|
||||
│ │ AdmissionController │ │
|
||||
│ │ - QPS limiting │ │
|
||||
│ │ - Token rate limiting │ │
|
||||
│ │ - Moving window average │ │
|
||||
│ └───────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**Key Features:**
|
||||
- Dynamic micro-batching with configurable wait time
|
||||
- SLO-aware prioritization
|
||||
- Rate limiting and admission control
|
||||
|
||||
### 3. Retrieval Pipeline
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────┐
|
||||
│ RetrievalPipeline │
|
||||
│ ┌───────────────────────────────────┐ │
|
||||
│ │ HNSW (Dense Search) │ │
|
||||
│ │ - Hierarchical graph │ │
|
||||
│ │ - Approximate nearest neighbor │ │
|
||||
│ │ - Reproducible via seed parameter │ │
|
||||
│ └───────────────────────────────────┘ │
|
||||
│ ┌───────────────────────────────────┐ │
|
||||
│ │ InvertedIndex (BM25) │ │
|
||||
│ │ - Compressed postings │ │
|
||||
│ │ - Varint/zigzag encoding │ │
|
||||
│ └───────────────────────────────────┘ │
|
||||
│ ┌───────────────────────────────────┐ │
|
||||
│ │ Score Fusion │ │
|
||||
│ │ - Weighted combination │ │
|
||||
│ │ - Top-K heap maintenance │ │
|
||||
│ └───────────────────────────────────┘ │
|
||||
│ ┌───────────────────────────────────┐ │
|
||||
│ │ CountMinSketch │ │
|
||||
│ │ - Hot query detection │ │
|
||||
│ │ - Cache priming │ │
|
||||
│ └───────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**Key Features:**
|
||||
- Hybrid dense + sparse retrieval
|
||||
- Score fusion with configurable weights
|
||||
- Hot query caching
|
||||
|
||||
## Data Flow
|
||||
|
||||
### KV Cache Flow
|
||||
|
||||
1. **Attach Sequence**: Allocate pages, hash prefix, check for sharing
|
||||
2. **Get Sequence**: Retrieve pages, reconstruct KV tokens
|
||||
3. **Detach Sequence**: Free pages, update statistics
|
||||
|
||||
### Scheduler Flow
|
||||
|
||||
1. **Submit Request**: Add to priority queue, update admission stats
|
||||
2. **Get Batch**: Check wait time, pop top-k requests
|
||||
3. **Complete Batch**: Remove from queue, update metrics
|
||||
|
||||
### Retrieval Flow
|
||||
|
||||
1. **Index Building**: Add documents to HNSW and inverted index
|
||||
2. **Query Processing**:
|
||||
- Dense search (HNSW)
|
||||
- Sparse search (BM25)
|
||||
- Score fusion
|
||||
- Top-K selection
|
||||
3. **Caching**: Check CMS for hot queries, cache results
|
||||
|
||||
## Memory Management
|
||||
|
||||
### Token Budgeting
|
||||
|
||||
- Global token budget manager tracks:
|
||||
- KV cache tokens
|
||||
- Prompt tokens
|
||||
- Context window tokens
|
||||
|
||||
### Page Allocation
|
||||
|
||||
- Fixed-size pages reduce fragmentation
|
||||
- Free-list management for O(1) allocation
|
||||
- Periodic defragmentation for compaction
|
||||
|
||||
## Performance Characteristics
|
||||
|
||||
### Time Complexities
|
||||
|
||||
- **KV Cache**: O(1) attach/get, O(k) detach (k = pages)
|
||||
- **Indexed Heap**: O(log n) push/pop/update
|
||||
- **HNSW Search**: O(log n) approximate nearest neighbor
|
||||
- **BM25 Search**: O(|query| × avg_doc_freq)
|
||||
|
||||
### Space Complexities
|
||||
|
||||
- **KV Cache**: O(sequences × tokens_per_seq)
|
||||
- **HNSW**: O(n × M) where M = max connections
|
||||
- **Inverted Index**: O(|vocab| × avg_postings)
|
||||
|
||||
## Trade-offs
|
||||
|
||||
### Page Size
|
||||
- **Small pages**: Better memory utilization, higher overhead
|
||||
- **Large pages**: Lower overhead, more fragmentation
|
||||
|
||||
### Batch Size
|
||||
- **Small batches**: Lower latency, lower throughput
|
||||
- **Large batches**: Higher throughput, higher latency
|
||||
|
||||
### HNSW Parameters
|
||||
- **M (connections)**: Higher = better recall, more memory
|
||||
- **efSearch**: Higher = better recall, slower search
|
||||
|
||||
537
docs/mathematical_models.md
Normal file
537
docs/mathematical_models.md
Normal file
@@ -0,0 +1,537 @@
|
||||
# Mathematical Models
|
||||
|
||||
This document describes the mathematical formulations and algorithms used throughout the LLM RAG Data Structures Optimizer.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [BM25 Ranking Function](#bm25-ranking-function)
|
||||
- [HNSW Distance Metrics](#hnsw-distance-metrics)
|
||||
- [Count-Min Sketch Error Bounds](#count-min-sketch-error-bounds)
|
||||
- [Score Fusion](#score-fusion)
|
||||
- [KV Cache Memory Calculation](#kv-cache-memory-calculation)
|
||||
- [Token-Aware LRU Eviction](#token-aware-lru-eviction)
|
||||
- [Admission Control Rate Limiting](#admission-control-rate-limiting)
|
||||
|
||||
---
|
||||
|
||||
## BM25 Ranking Function
|
||||
|
||||
BM25 (Best Matching 25) is a probabilistic ranking function used for information retrieval. It scores documents based on term frequency and inverse document frequency.
|
||||
|
||||
### Formula
|
||||
|
||||
For a query $Q = \{q_1, q_2, \ldots, q_n\}$ and document $D$, the BM25 score is:
|
||||
|
||||
$$
|
||||
\text{BM25}(D, Q) = \sum_{i=1}^{n} \text{IDF}(q_i) \cdot \frac{f(q_i, D) \cdot (k_1 + 1)}{f(q_i, D) + k_1 \cdot (1 - b + b \cdot \frac{|D|}{\text{avgdl}})}
|
||||
$$
|
||||
|
||||
Where:
|
||||
- $f(q_i, D)$ = frequency of term $q_i$ in document $D$
|
||||
- $|D|$ = length of document $D$ (number of terms)
|
||||
- $\text{avgdl}$ = average document length in the collection
|
||||
- $k_1$ = term frequency saturation parameter (typically 1.2-2.0)
|
||||
- $b$ = length normalization parameter (typically 0.75)
|
||||
|
||||
### Inverse Document Frequency (IDF)
|
||||
|
||||
$$
|
||||
\text{IDF}(q_i) = \log \frac{N - n(q_i) + 0.5}{n(q_i) + 0.5}
|
||||
$$
|
||||
|
||||
Where:
|
||||
- $N$ = total number of documents in the collection
|
||||
- $n(q_i)$ = number of documents containing term $q_i$
|
||||
|
||||
The 0.5 smoothing factor prevents division by zero and handles terms that appear in all documents.
|
||||
|
||||
### Implementation Defaults
|
||||
|
||||
In our implementation:
|
||||
- $k_1 = 1.5$ (default)
|
||||
- $b = 0.75$ (default)
|
||||
|
||||
---
|
||||
|
||||
## HNSW Distance Metrics
|
||||
|
||||
Hierarchical Navigable Small World (HNSW) uses distance metrics to measure similarity between vectors. The default distance metric is **L2 (Euclidean) distance**.
|
||||
|
||||
### L2 Distance (Euclidean)
|
||||
|
||||
For vectors $\vec{u} = (u_1, u_2, \ldots, u_d)$ and $\vec{v} = (v_1, v_2, \ldots, v_d)$:
|
||||
|
||||
$$
|
||||
d_{\text{L2}}(\vec{u}, \vec{v}) = \sqrt{\sum_{i=1}^{d} (u_i - v_i)^2}
|
||||
$$
|
||||
|
||||
In practice, we often use squared L2 distance for efficiency (monotonic with L2):
|
||||
|
||||
$$
|
||||
d_{\text{L2}}^2(\vec{u}, \vec{v}) = \sum_{i=1}^{d} (u_i - v_i)^2
|
||||
$$
|
||||
|
||||
### Cosine Similarity (Alternative)
|
||||
|
||||
For normalized vectors, cosine similarity is often preferred:
|
||||
|
||||
$$
|
||||
\text{cosine}(\vec{u}, \vec{v}) = \frac{\vec{u} \cdot \vec{v}}{||\vec{u}|| \cdot ||\vec{v}||} = \frac{\sum_{i=1}^{d} u_i \cdot v_i}{\sqrt{\sum_{i=1}^{d} u_i^2} \cdot \sqrt{\sum_{i=1}^{d} v_i^2}}
|
||||
$$
|
||||
|
||||
For normalized vectors where $||\vec{u}|| = ||\vec{v}|| = 1$:
|
||||
|
||||
$$
|
||||
\text{cosine}(\vec{u}, \vec{v}) = \vec{u} \cdot \vec{v} = \sum_{i=1}^{d} u_i \cdot v_i
|
||||
$$
|
||||
|
||||
**Note**: Cosine similarity can be converted to distance: $d_{\text{cosine}} = 1 - \text{cosine}(\vec{u}, \vec{v})$
|
||||
|
||||
### HNSW Graph Properties
|
||||
|
||||
The HNSW graph has logarithmic search complexity:
|
||||
|
||||
- **Search complexity**: $O(\log N)$ where $N$ is the number of vectors
|
||||
- **Construction complexity**: $O(N \log N)$
|
||||
- **Memory complexity**: $O(N \cdot M)$ where $M$ is the maximum connections per node
|
||||
|
||||
**Return Format**: The `search()` and `_search_layer()` methods return results as `(node_id, distance)` tuples, where:
|
||||
- `node_id`: Integer identifier of the vector in the index
|
||||
- `distance`: Float representing the L2 distance from the query vector
|
||||
|
||||
---
|
||||
|
||||
## Count-Min Sketch Error Bounds
|
||||
|
||||
Count-Min Sketch is a probabilistic data structure for frequency estimation with bounded error.
|
||||
|
||||
### Structure
|
||||
|
||||
A Count-Min Sketch has width $w$ and depth $d$, creating a $d \times w$ table of counters.
|
||||
|
||||
### Update Operation
|
||||
|
||||
For item $x$ with count $c$, update all $d$ rows:
|
||||
|
||||
$$
|
||||
\text{CM}[i][h_i(x)] \leftarrow \text{CM}[i][h_i(x)] + c, \quad \forall i \in \{1, 2, \ldots, d\}
|
||||
$$
|
||||
|
||||
Where $h_i(x)$ is a hash function for row $i$.
|
||||
|
||||
### Estimation
|
||||
|
||||
The estimated frequency is the minimum across all rows:
|
||||
|
||||
$$
|
||||
\hat{f}(x) = \min_{i \in \{1, \ldots, d\}} \text{CM}[i][h_i(x)]
|
||||
$$
|
||||
|
||||
### Error Bound
|
||||
|
||||
With probability at least $1 - \delta$, the error is bounded by:
|
||||
|
||||
$$
|
||||
\hat{f}(x) - f(x) \leq \epsilon \cdot ||\mathbf{f}||_1
|
||||
$$
|
||||
|
||||
Where:
|
||||
- $f(x)$ = true frequency of $x$
|
||||
- $||\mathbf{f}||_1$ = total count of all items (L1 norm)
|
||||
- $\epsilon = \frac{e}{w}$ (where $e \approx 2.71828$)
|
||||
- $\delta = \left(\frac{1}{2}\right)^d$
|
||||
|
||||
### Parameter Selection
|
||||
|
||||
To achieve error bound $\epsilon$ with probability $1 - \delta$:
|
||||
|
||||
$$
|
||||
w = \left\lceil \frac{e}{\epsilon} \right\rceil
|
||||
$$
|
||||
$$
|
||||
d = \left\lceil \ln \frac{1}{\delta} \right\rceil
|
||||
$$
|
||||
|
||||
**Default parameters** in our implementation:
|
||||
- $w = 2048$ → $\epsilon \approx 0.0013$
|
||||
- $d = 4$ → $\delta = 0.0625$ (6.25% error probability)
|
||||
|
||||
---
|
||||
|
||||
## Score Fusion
|
||||
|
||||
Hybrid search combines scores from multiple retrieval methods (dense vectors and sparse keywords).
|
||||
|
||||
### Weighted Linear Combination
|
||||
|
||||
$$
|
||||
S_{\text{fused}}(d, q) = \alpha \cdot S_{\text{dense}}(d, q) + \beta \cdot S_{\text{sparse}}(d, q)
|
||||
$$
|
||||
|
||||
Where:
|
||||
- $S_{\text{dense}}(d, q)$ = normalized vector similarity score
|
||||
- $S_{\text{sparse}}(d, q)$ = normalized BM25 score
|
||||
- $\alpha + \beta = 1$ (typically $\alpha = 0.7$, $\beta = 0.3$)
|
||||
|
||||
### Score Normalization
|
||||
|
||||
Before fusion, scores are normalized to [0, 1] range:
|
||||
|
||||
$$
|
||||
S_{\text{norm}}(d, q) = \frac{S(d, q) - S_{\min}}{S_{\max} - S_{\min}}
|
||||
$$
|
||||
|
||||
Where $S_{\min}$ and $S_{\max}$ are the minimum and maximum scores in the candidate set.
|
||||
|
||||
### Reciprocal Rank Fusion (Alternative)
|
||||
|
||||
$$
|
||||
S_{\text{RRF}}(d) = \sum_{r \in R} \frac{1}{k + \text{rank}_r(d)}
|
||||
$$
|
||||
|
||||
Where:
|
||||
- $R$ = set of ranked lists to fuse
|
||||
- $\text{rank}_r(d)$ = rank of document $d$ in ranked list $r$
|
||||
- $k$ = smoothing parameter (typically 60)
|
||||
|
||||
---
|
||||
|
||||
## KV Cache Memory Calculation
|
||||
|
||||
The KV cache memory usage depends on the number of cached tokens and the model dimensions.
|
||||
|
||||
### Per-Sequence Memory
|
||||
|
||||
For a sequence with $T$ tokens and model with hidden dimension $d$:
|
||||
|
||||
$$
|
||||
M_{\text{sequence}} = 2 \cdot T \cdot d \cdot \text{bytes\_per\_element}
|
||||
$$
|
||||
|
||||
Where:
|
||||
- Factor of 2 accounts for both key and value tensors
|
||||
- $\text{bytes\_per\_element} = 4$ for float32, $2$ for float16
|
||||
|
||||
### Paged Allocation
|
||||
|
||||
With page size $P$ pages and page capacity $C$ tokens per page:
|
||||
|
||||
$$
|
||||
M_{\text{paged}} = \left\lceil \frac{T}{C} \right\rceil \cdot P \cdot \text{page\_overhead}
|
||||
$$
|
||||
|
||||
Where $\text{page\_overhead}$ includes page metadata.
|
||||
|
||||
### Prefix Sharing Memory Savings
|
||||
|
||||
If $N$ sequences share a prefix of length $L$:
|
||||
|
||||
$$
|
||||
M_{\text{shared}} = L \cdot d \cdot \text{bytes\_per\_element}
|
||||
$$
|
||||
$$
|
||||
M_{\text{without\_sharing}} = N \cdot L \cdot d \cdot \text{bytes\_per\_element}
|
||||
$$
|
||||
|
||||
Memory savings:
|
||||
|
||||
$$
|
||||
\text{Savings} = (N - 1) \cdot L \cdot d \cdot \text{bytes\_per\_element}
|
||||
$$
|
||||
|
||||
Savings ratio:
|
||||
|
||||
$$
|
||||
\text{Savings Ratio} = \frac{N - 1}{N} = 1 - \frac{1}{N}
|
||||
$$
|
||||
|
||||
For large $N$, this approaches 100% savings on shared prefixes.
|
||||
|
||||
### Copy-on-Write Overhead
|
||||
|
||||
With copy-on-write (COW), if $K$ sequences modify shared pages:
|
||||
|
||||
$$
|
||||
M_{\text{with\_cow}} = (N - K) \cdot L_{\text{shared}} \cdot d \cdot \text{bytes\_per\_element} + K \cdot L_{\text{modified}} \cdot d \cdot \text{bytes\_per\_element}
|
||||
$$
|
||||
|
||||
Where:
|
||||
- $L_{\text{shared}}$ = length of shared (unmodified) prefix pages
|
||||
- $L_{\text{modified}}$ = length of modified prefix pages (copied)
|
||||
|
||||
**COW Efficiency:**
|
||||
- If no sequences modify shared pages ($K = 0$): Maximum savings (shared pages stored once)
|
||||
- If all sequences modify ($K = N$): No savings (each has own copy)
|
||||
- Typical case ($K < N$): Partial savings based on modification rate
|
||||
|
||||
**Reference Counting:**
|
||||
Shared pages are freed when reference count $r$ reaches zero:
|
||||
|
||||
$$
|
||||
r = \sum_{i=1}^{N} \mathbf{1}_{\text{seq}_i \text{ references page}}
|
||||
$$
|
||||
|
||||
Where $\mathbf{1}$ is the indicator function (1 if sequence references page, 0 otherwise).
|
||||
|
||||
---
|
||||
|
||||
## Token-Aware LRU Eviction
|
||||
|
||||
Token-aware LRU maintains a cumulative token budget while evicting least recently used items.
|
||||
|
||||
### Eviction Criterion
|
||||
|
||||
Evict item $i$ with minimum value of:
|
||||
|
||||
$$
|
||||
\text{priority}(i) = \frac{\text{access\_count}(i)}{\text{token\_count}(i)}
|
||||
$$
|
||||
|
||||
Or use recency-weighted:
|
||||
|
||||
$$
|
||||
\text{priority}(i) = \frac{\text{last\_access\_time}(i)}{\text{token\_count}(i)}
|
||||
$$
|
||||
|
||||
### Token Budget Constraint
|
||||
|
||||
Maintain total tokens below budget $B$:
|
||||
|
||||
$$
|
||||
\sum_{i \in \text{cache}} \text{token\_count}(i) \leq B
|
||||
$$
|
||||
|
||||
When adding item $j$ with $t_j$ tokens:
|
||||
|
||||
1. If $\sum_{i} t_i + t_j \leq B$: add item
|
||||
2. Else: evict items until $\sum_{i} t_i + t_j \leq B$
|
||||
|
||||
### Eviction Algorithm
|
||||
|
||||
```
|
||||
while total_tokens + new_tokens > budget:
|
||||
item = item_with_min_priority()
|
||||
total_tokens -= token_count(item)
|
||||
evict(item)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Admission Control Rate Limiting
|
||||
|
||||
The admission controller uses an exponentially weighted moving average (EWMA) to track request rate.
|
||||
|
||||
### Moving Average Update
|
||||
|
||||
$$
|
||||
\bar{r}_{t} = \alpha \cdot r_t + (1 - \alpha) \cdot \bar{r}_{t-1}
|
||||
$$
|
||||
|
||||
Where:
|
||||
- $r_t$ = current request rate
|
||||
- $\bar{r}_t$ = smoothed average rate
|
||||
- $\alpha$ = smoothing factor (typically 0.1-0.3)
|
||||
|
||||
### Admission Decision
|
||||
|
||||
Admit request if:
|
||||
|
||||
$$
|
||||
\bar{r}_t + \text{margin} \leq R_{\max}
|
||||
$$
|
||||
|
||||
Where:
|
||||
- $R_{\max}$ = maximum allowed rate (QPS limit)
|
||||
- $\text{margin}$ = safety margin to account for burstiness
|
||||
|
||||
### Token Bucket (Alternative)
|
||||
|
||||
The token bucket algorithm allows bursty traffic:
|
||||
|
||||
$$
|
||||
\text{tokens}(t) = \min(B, \text{tokens}(t-1) + R \cdot \Delta t)
|
||||
$$
|
||||
|
||||
Where:
|
||||
- $B$ = bucket capacity (burst limit)
|
||||
- $R$ = token refill rate (sustainable rate)
|
||||
- $\Delta t$ = time since last update
|
||||
|
||||
Request is admitted if $\text{tokens}(t) \geq 1$, then $\text{tokens}(t) \leftarrow \text{tokens}(t) - 1$.
|
||||
|
||||
---
|
||||
|
||||
## Indexed Binary Heap
|
||||
|
||||
The indexed heap supports O(log n) priority updates with decrease/increase-key operations. Supports both min-heap and max-heap modes.
|
||||
|
||||
### Heap Property
|
||||
|
||||
**Min-heap**: `parent_score ≤ child_score`
|
||||
**Max-heap**: `parent_score ≥ child_score`
|
||||
|
||||
### Decrease/Increase Key Operations
|
||||
|
||||
**Min-heap:**
|
||||
- `decrease_key(new_score < old_score)`: Bubbles UP (score decreases → higher priority)
|
||||
- `increase_key(new_score > old_score)`: Bubbles DOWN (score increases → lower priority)
|
||||
|
||||
**Max-heap:**
|
||||
- `decrease_key(new_score < old_score)`: Bubbles DOWN (score decreases → lower priority) Fixed in v0.1.0
|
||||
- `increase_key(new_score > old_score)`: Bubbles UP (score increases → higher priority) Fixed in v0.1.0
|
||||
|
||||
### Complexity
|
||||
|
||||
- Push: O(log n)
|
||||
- Pop: O(log n)
|
||||
- Decrease/Increase key: O(log n)
|
||||
- Delete: O(log n)
|
||||
|
||||
### Heap Properties
|
||||
|
||||
For a min-heap with $n$ elements:
|
||||
|
||||
- Parent of node $i$: $\lfloor (i-1)/2 \rfloor$
|
||||
- Left child of node $i$: $2i + 1$
|
||||
- Right child of node $i$: $2i + 2$
|
||||
|
||||
### Heap Invariant
|
||||
|
||||
$$
|
||||
\text{priority}(\text{parent}(i)) \leq \text{priority}(i), \quad \forall i > 0
|
||||
$$
|
||||
|
||||
### Complexity
|
||||
|
||||
- Insert: $O(\log n)$
|
||||
- Extract min: $O(\log n)$
|
||||
- Update key: $O(\log n)$ (with index mapping)
|
||||
- Decrease key: $O(\log n)$
|
||||
|
||||
---
|
||||
|
||||
## Variance Analysis and Statistical Confidence
|
||||
|
||||
Benchmark results include variance analysis to assess measurement reliability and identify flaky configurations.
|
||||
|
||||
### Statistical Summary
|
||||
|
||||
For a set of $n$ measurements $\{x_1, x_2, \ldots, x_n\}$:
|
||||
|
||||
**Mean:**
|
||||
$$
|
||||
\bar{x} = \frac{1}{n} \sum_{i=1}^{n} x_i
|
||||
$$
|
||||
|
||||
**Standard Deviation (Sample):**
|
||||
$$
|
||||
s = \sqrt{\frac{1}{n-1} \sum_{i=1}^{n} (x_i - \bar{x})^2}
|
||||
$$
|
||||
|
||||
**Coefficient of Variation:**
|
||||
$$
|
||||
\text{CV} = \frac{s}{\bar{x}} \times 100\%
|
||||
$$
|
||||
|
||||
The CV expresses relative variability as a percentage, making it easier to compare variance across different metrics and scales.
|
||||
|
||||
### Confidence Intervals
|
||||
|
||||
For small samples ($n < 30$), we use the t-distribution for confidence intervals:
|
||||
|
||||
$$
|
||||
\text{CI} = \bar{x} \pm t_{\alpha/2, n-1} \cdot \frac{s}{\sqrt{n}}
|
||||
$$
|
||||
|
||||
Where:
|
||||
- $t_{\alpha/2, n-1}$ = t-critical value for $\alpha$ significance level and $n-1$ degrees of freedom
|
||||
- For 95% confidence: $\alpha = 0.05$, so $t_{0.025, n-1}$
|
||||
|
||||
For large samples ($n \geq 30$), we approximate with the normal distribution:
|
||||
$$
|
||||
\text{CI} = \bar{x} \pm z_{\alpha/2} \cdot \frac{s}{\sqrt{n}}
|
||||
$$
|
||||
|
||||
Where $z_{\alpha/2} = 1.96$ for 95% confidence.
|
||||
|
||||
### Flaky Benchmark Detection
|
||||
|
||||
A benchmark configuration is considered **flaky** if:
|
||||
|
||||
$$
|
||||
\text{CV} > \text{threshold}
|
||||
$$
|
||||
|
||||
Where the default threshold is 20% (coefficient of variation > 20%).
|
||||
|
||||
**Interpretation:**
|
||||
- **CV < 10%**: Excellent reproducibility
|
||||
- **10% ≤ CV < 20%**: Good reproducibility
|
||||
- **20% ≤ CV < 50%**: Moderate variance (flagged as potentially flaky)
|
||||
- **CV ≥ 50%**: High variance (likely flaky, investigate)
|
||||
|
||||
### Variance Metrics Reported
|
||||
|
||||
For each metric (e.g., `search_p50_ms`, `qps`), we report:
|
||||
|
||||
- `{metric}_mean`: Mean across repetitions
|
||||
- `{metric}_std`: Standard deviation
|
||||
- `{metric}_min`: Minimum value
|
||||
- `{metric}_max`: Maximum value
|
||||
- `{metric}_ci_lower`: Lower bound of 95% confidence interval
|
||||
- `{metric}_ci_upper`: Upper bound of 95% confidence interval
|
||||
- `{metric}_cv`: Coefficient of variation (%)
|
||||
|
||||
### Example
|
||||
|
||||
For a benchmark with 5 repetitions producing search P50 latencies:
|
||||
$$
|
||||
\{15.2, 15.8, 14.9, 16.1, 15.5\} \text{ ms}
|
||||
$$
|
||||
|
||||
Results:
|
||||
- Mean: $\bar{x} = 15.5$ ms
|
||||
- Std: $s = 0.44$ ms
|
||||
- CV: $\frac{0.44}{15.5} \times 100\% = 2.8\%$ (excellent)
|
||||
- 95% CI: $15.5 \pm 0.59$ ms → [14.91, 16.09] ms
|
||||
|
||||
### Implementation
|
||||
|
||||
These statistical calculations are implemented in `llmds.utils`:
|
||||
|
||||
- `compute_percentiles(values)`: Computes P50, P95, P99 percentiles
|
||||
- `calculate_statistics(values, confidence_level=0.95)`: Computes comprehensive statistics including mean, std, percentiles, confidence intervals, and coefficient of variation
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from llmds.utils import compute_percentiles, calculate_statistics
|
||||
|
||||
latencies = [15.2, 15.8, 14.9, 16.1, 15.5]
|
||||
|
||||
# Quick percentiles
|
||||
percentiles = compute_percentiles(latencies)
|
||||
print(f"P50: {percentiles['p50']:.2f} ms")
|
||||
|
||||
# Full statistics
|
||||
stats = calculate_statistics(latencies)
|
||||
print(f"Mean: {stats['mean']:.2f} ± {stats['std']:.2f} ms")
|
||||
print(f"CV: {stats['cv']:.2f}%")
|
||||
print(f"95% CI: [{stats['ci_lower']:.2f}, {stats['ci_upper']:.2f}] ms")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## References
|
||||
|
||||
1. **BM25**: Robertson, S., & Zaragoza, H. (2009). The probabilistic relevance framework: BM25 and beyond. *Foundations and Trends in Information Retrieval*, 3(4), 333-389.
|
||||
|
||||
2. **HNSW**: Malkov, Y. A., & Yashunin, D. A. (2018). Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs. *IEEE transactions on pattern analysis and machine intelligence*, 42(4), 824-836.
|
||||
|
||||
3. **Count-Min Sketch**: Cormode, G., & Muthukrishnan, S. (2005). An improved data stream summary: the count-min sketch and its applications. *Journal of Algorithms*, 55(1), 58-75.
|
||||
|
||||
4. **Score Fusion**: Cormack, G. V., Clarke, C. L., & Büttcher, S. (2009). Reciprocal rank fusion outperforms condorcet and individual rank learning methods. *Proceedings of the 32nd international ACM SIGIR conference*.
|
||||
|
||||
---
|
||||
|
||||
**Last Updated**: 2025-01-01
|
||||
|
||||
1612
docs/paper.md
Normal file
1612
docs/paper.md
Normal file
File diff suppressed because it is too large
Load Diff
288
docs/usage.md
Normal file
288
docs/usage.md
Normal file
@@ -0,0 +1,288 @@
|
||||
# Usage Guide
|
||||
|
||||
## Basic Examples
|
||||
|
||||
### KV Cache
|
||||
|
||||
```python
|
||||
from llmds import KVCache
|
||||
|
||||
# Create cache with prefix sharing (enabled by default)
|
||||
cache = KVCache(page_size=512, max_pages=10000, enable_prefix_sharing=True)
|
||||
|
||||
# Attach KV tokens for a sequence with prefix sharing
|
||||
prefix = [1, 2, 3] # Shared system prompt
|
||||
kv_tokens = prefix + [4, 5, 6] * 100 # 500 tokens
|
||||
cache.attach(seq_id=1, kv_tokens=kv_tokens, prefix_tokens=prefix)
|
||||
|
||||
# Second sequence with same prefix - will share pages
|
||||
kv_tokens2 = prefix + [7, 8, 9] * 100
|
||||
cache.attach(seq_id=2, kv_tokens=kv_tokens2, prefix_tokens=prefix)
|
||||
|
||||
# Retrieve (returns deep copy to prevent corruption)
|
||||
cached = cache.get(seq_id=1)
|
||||
|
||||
# Copy-on-write: if you modify shared pages, they are automatically copied
|
||||
# Shared pages are read-only until modified, then lazily copied
|
||||
|
||||
# Detach when done (reference counting handles shared pages)
|
||||
cache.detach(seq_id=1)
|
||||
cache.detach(seq_id=2)
|
||||
```
|
||||
|
||||
**Copy-on-Write Behavior:**
|
||||
- Shared pages (from prefix sharing) are read-only by default
|
||||
- Writing different data to a shared page triggers lazy copying
|
||||
- Each sequence gets its own copy of modified pages
|
||||
- Original shared pages remain unchanged for other sequences
|
||||
- `get()` always returns deep copies to prevent external corruption
|
||||
|
||||
### Scheduler
|
||||
|
||||
```python
|
||||
from llmds import Scheduler
|
||||
|
||||
# Create scheduler
|
||||
scheduler = Scheduler(max_batch_size=32, max_wait_ms=50.0)
|
||||
|
||||
# Submit requests
|
||||
req_id1 = scheduler.submit(tokens=100)
|
||||
req_id2 = scheduler.submit(tokens=200, slo_ms=100.0) # SLO deadline
|
||||
|
||||
# Get batch (waits for max_wait_ms or until batch is full)
|
||||
batch = scheduler.get_batch(force=False)
|
||||
|
||||
# Process batch...
|
||||
# scheduler.complete_batch(batch)
|
||||
```
|
||||
|
||||
### Admission Control
|
||||
|
||||
```python
|
||||
from llmds import AdmissionController
|
||||
|
||||
# Create controller
|
||||
controller = AdmissionController(qps_target=10.0, token_rate_limit=10000)
|
||||
|
||||
# Check admission
|
||||
should_admit, reason = controller.should_admit(estimated_tokens=100)
|
||||
if should_admit:
|
||||
# Process request
|
||||
controller.record_request(tokens=100)
|
||||
else:
|
||||
# Reject request
|
||||
print(f"Rejected: {reason}")
|
||||
```
|
||||
|
||||
### Retrieval Pipeline
|
||||
|
||||
```python
|
||||
from llmds import RetrievalPipeline
|
||||
import numpy as np
|
||||
|
||||
# Create pipeline with reproducible HNSW structure
|
||||
pipeline = RetrievalPipeline(embedding_dim=384, seed=42)
|
||||
|
||||
# Add documents
|
||||
for i in range(100):
|
||||
text = f"Document {i} content"
|
||||
embedding = np.random.randn(384).astype(np.float32)
|
||||
embedding = embedding / np.linalg.norm(embedding)
|
||||
pipeline.add_document(doc_id=i, text=text, embedding=embedding)
|
||||
|
||||
# Search
|
||||
query = "example query"
|
||||
query_embedding = np.random.randn(384).astype(np.float32)
|
||||
query_embedding = query_embedding / np.linalg.norm(query_embedding)
|
||||
|
||||
results = pipeline.search(query, query_embedding=query_embedding, top_k=10)
|
||||
for doc_id, score in results:
|
||||
print(f"Doc {doc_id}: {score:.4f}")
|
||||
```
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Custom Priority Function
|
||||
|
||||
```python
|
||||
from llmds import Scheduler
|
||||
|
||||
def custom_priority_fn(req):
|
||||
# Prioritize by inverse token count
|
||||
return 1.0 / (req.tokens + 1.0)
|
||||
|
||||
scheduler = Scheduler(
|
||||
max_batch_size=32,
|
||||
max_wait_ms=50.0,
|
||||
priority_fn=custom_priority_fn
|
||||
)
|
||||
```
|
||||
|
||||
### Token Budget Management
|
||||
|
||||
```python
|
||||
from llmds import TokenLRU
|
||||
|
||||
def token_counter(value):
|
||||
return len(str(value))
|
||||
|
||||
cache = TokenLRU(token_budget=1000, token_of=token_counter)
|
||||
|
||||
# Add items (evicts LRU if budget exceeded)
|
||||
cache.put("key1", "value with many tokens")
|
||||
cache.put("key2", "another value")
|
||||
|
||||
# Evict until target budget
|
||||
evicted = cache.evict_until_budget(target_budget=500)
|
||||
```
|
||||
|
||||
### HNSW Parameter Tuning
|
||||
|
||||
```python
|
||||
from llmds import HNSW
|
||||
import numpy as np
|
||||
|
||||
# Tune for better recall (higher memory)
|
||||
hnsw_high_recall = HNSW(
|
||||
dim=384,
|
||||
M=32, # More connections
|
||||
ef_construction=400, # More candidates during build
|
||||
ef_search=100, # More candidates during search
|
||||
seed=42 # Reproducible graph structure
|
||||
)
|
||||
|
||||
# Tune for faster search (lower memory)
|
||||
hnsw_fast = HNSW(
|
||||
dim=384,
|
||||
M=8, # Fewer connections
|
||||
ef_construction=100,
|
||||
ef_search=20, # Fewer candidates
|
||||
seed=42 # Reproducible graph structure
|
||||
)
|
||||
|
||||
# Reproducible benchmarks
|
||||
hnsw_bench = HNSW(dim=128, M=16, ef_construction=200, ef_search=50, seed=42)
|
||||
# Same seed ensures identical graph structure across runs
|
||||
```
|
||||
|
||||
## Benchmarking
|
||||
|
||||
### Running Benchmarks
|
||||
|
||||
```python
|
||||
from benchmarks.bench_kv_cache import benchmark_kv_cache
|
||||
|
||||
results = benchmark_kv_cache(
|
||||
num_sequences=1000,
|
||||
tokens_per_seq=1000,
|
||||
page_size=512
|
||||
)
|
||||
print(f"P95 latency: {results['attach_p95_ms']:.2f} ms")
|
||||
```
|
||||
|
||||
### Custom Benchmarks
|
||||
|
||||
```python
|
||||
from llmds.utils import Timer, compute_percentiles, calculate_statistics
|
||||
|
||||
latencies = []
|
||||
|
||||
for i in range(100):
|
||||
with Timer() as t:
|
||||
# Your operation here
|
||||
pass
|
||||
latencies.append(t.elapsed * 1000) # Convert to milliseconds
|
||||
|
||||
# Compute percentiles
|
||||
percentiles = compute_percentiles(latencies)
|
||||
print(f"P50: {percentiles['p50']:.2f} ms")
|
||||
print(f"P95: {percentiles['p95']:.2f} ms")
|
||||
print(f"P99: {percentiles['p99']:.2f} ms")
|
||||
|
||||
# Or compute comprehensive statistics
|
||||
stats = calculate_statistics(latencies)
|
||||
print(f"Mean: {stats['mean']:.2f} ± {stats['std']:.2f} ms")
|
||||
print(f"95% CI: [{stats['ci_lower']:.2f}, {stats['ci_upper']:.2f}] ms")
|
||||
print(f"CV: {stats['cv']:.2f}%")
|
||||
```
|
||||
|
||||
## Memory Profiling
|
||||
|
||||
All benchmarks automatically measure peak RSS (Resident Set Size) using `psutil`:
|
||||
|
||||
```python
|
||||
from llmds.utils import memory_profiler
|
||||
import numpy as np
|
||||
|
||||
# Memory profiling in your benchmarks
|
||||
with memory_profiler() as profiler:
|
||||
# Allocate memory
|
||||
data = np.random.randn(1000, 1000).astype(np.float32)
|
||||
profiler.sample() # Optional: sample at specific points
|
||||
|
||||
# More operations
|
||||
result = process_data(data)
|
||||
|
||||
peak_rss_mb = profiler.get_peak_rss_mb()
|
||||
memory_delta_mb = profiler.get_memory_delta_mb()
|
||||
|
||||
print(f"Peak memory: {peak_rss_mb:.2f} MB")
|
||||
print(f"Memory allocated: {memory_delta_mb:.2f} MB")
|
||||
```
|
||||
|
||||
**Benchmark Results Include:**
|
||||
- `peak_rss_mb`: Peak memory usage during benchmark
|
||||
- `memory_delta_mb`: Memory allocated during execution (peak - initial)
|
||||
- `build_peak_rss_mb`: Peak memory during build/indexing phase (where applicable)
|
||||
|
||||
All benchmark scripts automatically include memory profiling - no additional configuration needed.
|
||||
|
||||
## Integration Examples
|
||||
|
||||
### RAG Pipeline
|
||||
|
||||
```python
|
||||
from llmds import RetrievalPipeline
|
||||
import numpy as np
|
||||
|
||||
# Initialize
|
||||
pipeline = RetrievalPipeline(embedding_dim=384)
|
||||
|
||||
# Index documents
|
||||
documents = ["doc1", "doc2", "doc3"]
|
||||
embeddings = [np.random.randn(384) for _ in documents]
|
||||
for doc_id, (text, emb) in enumerate(zip(documents, embeddings)):
|
||||
emb = emb / np.linalg.norm(emb)
|
||||
pipeline.add_document(doc_id=doc_id, text=text, embedding=emb)
|
||||
|
||||
# Query
|
||||
query_emb = np.random.randn(384)
|
||||
query_emb = query_emb / np.linalg.norm(query_emb)
|
||||
results = pipeline.search("query", query_embedding=query_emb, top_k=5)
|
||||
```
|
||||
|
||||
### LLM Inference with KV Cache
|
||||
|
||||
```python
|
||||
from llmds import KVCache, Scheduler, TokenLRU
|
||||
|
||||
# Setup
|
||||
kv_cache = KVCache()
|
||||
scheduler = Scheduler()
|
||||
token_cache = TokenLRU(token_budget=100000, token_of=lambda x: len(str(x)))
|
||||
|
||||
# Process request
|
||||
seq_id = 1
|
||||
prompt_tokens = [1, 2, 3, 4, 5]
|
||||
kv_tokens = generate_kv_cache(prompt_tokens) # Your function
|
||||
|
||||
kv_cache.attach(seq_id=seq_id, kv_tokens=kv_tokens, prefix_tokens=prompt_tokens)
|
||||
|
||||
# Use cached KV for generation
|
||||
cached_kv = kv_cache.get(seq_id)
|
||||
# ... generate tokens using cached KV ...
|
||||
|
||||
# Cleanup
|
||||
kv_cache.detach(seq_id)
|
||||
```
|
||||
|
||||
35
llmds/__init__.py
Normal file
35
llmds/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
LLM Data Structures Optimizer.
|
||||
|
||||
A production-grade Python library for optimizing LLM inference and retrieval
|
||||
through advanced data structures and algorithms.
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
from llmds.kv_cache import KVCache
|
||||
from llmds.paged_allocator import PagedAllocator
|
||||
from llmds.token_lru import TokenLRU
|
||||
from llmds.indexed_heap import IndexedHeap
|
||||
from llmds.scheduler import Scheduler
|
||||
from llmds.admissions import AdmissionController
|
||||
from llmds.inverted_index import InvertedIndex
|
||||
from llmds.hnsw import HNSW
|
||||
from llmds.cmsketch import CountMinSketch
|
||||
from llmds.retrieval_pipeline import RetrievalPipeline
|
||||
from llmds.tokenizer import Tokenizer
|
||||
|
||||
__all__ = [
|
||||
"KVCache",
|
||||
"PagedAllocator",
|
||||
"TokenLRU",
|
||||
"IndexedHeap",
|
||||
"Scheduler",
|
||||
"AdmissionController",
|
||||
"InvertedIndex",
|
||||
"HNSW",
|
||||
"CountMinSketch",
|
||||
"RetrievalPipeline",
|
||||
"Tokenizer",
|
||||
]
|
||||
|
||||
135
llmds/admissions.py
Normal file
135
llmds/admissions.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Admission controller with rate limiting and QPS tracking."""
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class AdmissionController:
|
||||
"""
|
||||
Admission controller with token-rate limiting and moving-average QPS.
|
||||
|
||||
Controls admission based on token budget and QPS targets.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
qps_target: float = 10.0,
|
||||
token_rate_limit: int = 10000,
|
||||
window_size: int = 10,
|
||||
):
|
||||
"""
|
||||
Initialize admission controller.
|
||||
|
||||
Args:
|
||||
qps_target: Target queries per second
|
||||
token_rate_limit: Maximum tokens per second
|
||||
window_size: Size of moving average window in seconds
|
||||
"""
|
||||
self.qps_target = qps_target
|
||||
self.token_rate_limit = token_rate_limit
|
||||
self.window_size = window_size
|
||||
self._request_times: deque[float] = deque()
|
||||
self._token_history: deque[tuple[float, int]] = deque() # (time, tokens)
|
||||
self._admitted_requests = 0
|
||||
self._rejected_requests = 0
|
||||
|
||||
def _cleanup_old_requests(self, current_time: float) -> None:
|
||||
"""Remove requests outside the time window."""
|
||||
while self._request_times and current_time - self._request_times[0] > self.window_size:
|
||||
self._request_times.popleft()
|
||||
|
||||
while self._token_history and current_time - self._token_history[0][0] > self.window_size:
|
||||
self._token_history.popleft()
|
||||
|
||||
def _get_current_qps(self, current_time: float) -> float:
|
||||
"""Calculate current QPS over the window."""
|
||||
self._cleanup_old_requests(current_time)
|
||||
if not self._request_times:
|
||||
return 0.0
|
||||
return len(self._request_times) / self.window_size
|
||||
|
||||
def _get_current_token_rate(self, current_time: float) -> float:
|
||||
"""Calculate current token rate over the window."""
|
||||
self._cleanup_old_requests(current_time)
|
||||
if not self._token_history:
|
||||
return 0.0
|
||||
|
||||
total_tokens = sum(tokens for _, tokens in self._token_history)
|
||||
return total_tokens / self.window_size
|
||||
|
||||
def should_admit(self, estimated_tokens: int = 0) -> tuple[bool, str]:
|
||||
"""
|
||||
Check if a request should be admitted.
|
||||
|
||||
Args:
|
||||
estimated_tokens: Estimated tokens for this request
|
||||
|
||||
Returns:
|
||||
Tuple of (should_admit, reason)
|
||||
"""
|
||||
current_time = time.time()
|
||||
current_qps = self._get_current_qps(current_time)
|
||||
current_token_rate = self._get_current_token_rate(current_time)
|
||||
|
||||
# Check QPS limit
|
||||
if current_qps >= self.qps_target:
|
||||
self._rejected_requests += 1
|
||||
return False, f"QPS limit exceeded: {current_qps:.2f} >= {self.qps_target}"
|
||||
|
||||
# Check token rate limit
|
||||
if current_token_rate + estimated_tokens / self.window_size > self.token_rate_limit:
|
||||
self._rejected_requests += 1
|
||||
return False, f"Token rate limit exceeded"
|
||||
|
||||
# Admit request
|
||||
self._request_times.append(current_time)
|
||||
if estimated_tokens > 0:
|
||||
self._token_history.append((current_time, estimated_tokens))
|
||||
self._admitted_requests += 1
|
||||
|
||||
return True, "admitted"
|
||||
|
||||
def record_request(self, tokens: int) -> None:
|
||||
"""
|
||||
Record a completed request with token count.
|
||||
|
||||
Args:
|
||||
tokens: Number of tokens processed
|
||||
"""
|
||||
current_time = time.time()
|
||||
self._token_history.append((current_time, tokens))
|
||||
|
||||
def stats(self) -> dict[str, float]:
|
||||
"""
|
||||
Get admission statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with admission statistics
|
||||
"""
|
||||
current_time = time.time()
|
||||
current_qps = self._get_current_qps(current_time)
|
||||
current_token_rate = self._get_current_token_rate(current_time)
|
||||
|
||||
total_requests = self._admitted_requests + self._rejected_requests
|
||||
rejection_rate = (
|
||||
self._rejected_requests / total_requests if total_requests > 0 else 0.0
|
||||
)
|
||||
|
||||
return {
|
||||
"current_qps": current_qps,
|
||||
"target_qps": self.qps_target,
|
||||
"current_token_rate": current_token_rate,
|
||||
"token_rate_limit": self.token_rate_limit,
|
||||
"admitted_requests": self._admitted_requests,
|
||||
"rejected_requests": self._rejected_requests,
|
||||
"rejection_rate": rejection_rate,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all statistics."""
|
||||
self._request_times.clear()
|
||||
self._token_history.clear()
|
||||
self._admitted_requests = 0
|
||||
self._rejected_requests = 0
|
||||
|
||||
72
llmds/chunking.py
Normal file
72
llmds/chunking.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Text chunking utilities for RAG."""
|
||||
|
||||
from typing import Any, Iterator, Optional
|
||||
|
||||
|
||||
def chunk_text(
|
||||
text: str,
|
||||
chunk_size: int = 512,
|
||||
overlap: int = 50,
|
||||
tokenizer: Optional[Any] = None,
|
||||
) -> Iterator[str]:
|
||||
"""
|
||||
Chunk text into overlapping segments.
|
||||
|
||||
Args:
|
||||
text: Input text to chunk
|
||||
chunk_size: Target chunk size in tokens/characters
|
||||
overlap: Overlap between chunks
|
||||
tokenizer: Optional tokenizer (if None, uses character-based)
|
||||
|
||||
Yields:
|
||||
Text chunks
|
||||
"""
|
||||
if tokenizer is not None:
|
||||
# Token-based chunking
|
||||
tokens = tokenizer.encode(text)
|
||||
for i in range(0, len(tokens), chunk_size - overlap):
|
||||
chunk_tokens = tokens[i:i + chunk_size]
|
||||
yield tokenizer.decode(chunk_tokens)
|
||||
else:
|
||||
# Character-based chunking (simple fallback)
|
||||
for i in range(0, len(text), chunk_size - overlap):
|
||||
yield text[i:i + chunk_size]
|
||||
|
||||
|
||||
def chunk_documents(
|
||||
documents: Iterator[dict[str, Any]],
|
||||
chunk_size: int = 512,
|
||||
overlap: int = 50,
|
||||
tokenizer: Optional[Any] = None,
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
Chunk documents into smaller segments.
|
||||
|
||||
Args:
|
||||
documents: Iterator of document dicts with 'id', 'text', 'meta'
|
||||
chunk_size: Target chunk size
|
||||
overlap: Overlap between chunks
|
||||
tokenizer: Optional tokenizer
|
||||
|
||||
Yields:
|
||||
Chunk dictionaries with 'id', 'text', 'meta', 'chunk_idx'
|
||||
"""
|
||||
for doc in documents:
|
||||
doc_id = doc["id"]
|
||||
text = doc["text"]
|
||||
meta = doc.get("meta", {})
|
||||
|
||||
chunks = list(chunk_text(text, chunk_size, overlap, tokenizer))
|
||||
|
||||
for chunk_idx, chunk_text_seg in enumerate(chunks):
|
||||
yield {
|
||||
"id": f"{doc_id}_chunk_{chunk_idx}",
|
||||
"text": chunk_text_seg,
|
||||
"meta": {
|
||||
**meta,
|
||||
"doc_id": doc_id,
|
||||
"chunk_idx": chunk_idx,
|
||||
"total_chunks": len(chunks),
|
||||
}
|
||||
}
|
||||
|
||||
115
llmds/cmsketch.py
Normal file
115
llmds/cmsketch.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Count-Min Sketch for hot query estimation and cache priming.
|
||||
|
||||
Implementation based on:
|
||||
Cormode, G., & Muthukrishnan, S. (2005). An improved data stream summary:
|
||||
the count-min sketch and its applications. Journal of Algorithms, 55(1), 58-75.
|
||||
|
||||
See docs/CITATIONS.md for full citation details.
|
||||
"""
|
||||
|
||||
import mmh3
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class CountMinSketch:
|
||||
"""
|
||||
Count-Min Sketch for frequency estimation with conservative update.
|
||||
|
||||
Uses 4 hash functions (via MurmurHash3) and provides error bounds.
|
||||
|
||||
Reference:
|
||||
Cormode & Muthukrishnan (2005). An improved data stream summary:
|
||||
the count-min sketch and its applications.
|
||||
"""
|
||||
|
||||
def __init__(self, width: int = 2048, depth: int = 4):
|
||||
"""
|
||||
Initialize Count-Min Sketch.
|
||||
|
||||
Args:
|
||||
width: Width of the sketch (number of counters per row)
|
||||
depth: Depth of the sketch (number of hash functions)
|
||||
"""
|
||||
self.width = width
|
||||
self.depth = depth
|
||||
self._table: list[list[int]] = [[0] * width for _ in range(depth)]
|
||||
self._total_count = 0
|
||||
|
||||
def _hash(self, item: str, seed: int) -> int:
|
||||
"""Hash an item with a given seed."""
|
||||
return mmh3.hash(item, seed) % self.width
|
||||
|
||||
def add(self, item: str, count: int = 1) -> None:
|
||||
"""
|
||||
Add an item to the sketch.
|
||||
|
||||
Args:
|
||||
item: Item to add
|
||||
count: Count to add (default 1)
|
||||
"""
|
||||
self._total_count += count
|
||||
min_val = float("inf")
|
||||
|
||||
# Find minimum count across all rows
|
||||
for i in range(self.depth):
|
||||
idx = self._hash(item, i)
|
||||
self._table[i][idx] += count
|
||||
min_val = min(min_val, self._table[i][idx])
|
||||
|
||||
# Conservative update: only increment if current count < min
|
||||
# This reduces overestimation bias
|
||||
for i in range(self.depth):
|
||||
idx = self._hash(item, i)
|
||||
if self._table[i][idx] > min_val:
|
||||
self._table[i][idx] = int(min_val)
|
||||
|
||||
def estimate(self, item: str) -> int:
|
||||
"""
|
||||
Estimate the frequency of an item.
|
||||
|
||||
Args:
|
||||
item: Item to estimate
|
||||
|
||||
Returns:
|
||||
Estimated frequency (minimum across all rows)
|
||||
"""
|
||||
min_count = float("inf")
|
||||
for i in range(self.depth):
|
||||
idx = self._hash(item, i)
|
||||
min_count = min(min_count, self._table[i][idx])
|
||||
return int(min_count)
|
||||
|
||||
def get_error_bound(self) -> float:
|
||||
"""
|
||||
Get theoretical error bound (with high probability).
|
||||
|
||||
Returns:
|
||||
Error bound as a fraction of total count
|
||||
"""
|
||||
# With probability 1 - delta, error <= epsilon * total_count
|
||||
# where epsilon = e / width and delta = (1/2)^depth
|
||||
epsilon = 2.71828 / self.width
|
||||
return epsilon * self._total_count
|
||||
|
||||
def get_total_count(self) -> int:
|
||||
"""Get total count of all items."""
|
||||
return self._total_count
|
||||
|
||||
def is_hot(self, item: str, threshold: int) -> bool:
|
||||
"""
|
||||
Check if an item is "hot" (above threshold).
|
||||
|
||||
Args:
|
||||
item: Item to check
|
||||
threshold: Frequency threshold
|
||||
|
||||
Returns:
|
||||
True if estimated frequency >= threshold
|
||||
"""
|
||||
return self.estimate(item) >= threshold
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all counters."""
|
||||
self._table = [[0] * self.width for _ in range(self.depth)]
|
||||
self._total_count = 0
|
||||
|
||||
18
llmds/data_sources/__init__.py
Normal file
18
llmds/data_sources/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Data source loaders for real corpora."""
|
||||
|
||||
from llmds.data_sources.msmarco import load_msmarco
|
||||
from llmds.data_sources.beir_loader import load_beir
|
||||
from llmds.data_sources.amazon_reviews import load_amazon_reviews
|
||||
from llmds.data_sources.yelp import load_yelp
|
||||
from llmds.data_sources.wikipedia import load_wikipedia
|
||||
from llmds.data_sources.commoncrawl import load_commoncrawl
|
||||
|
||||
__all__ = [
|
||||
"load_msmarco",
|
||||
"load_beir",
|
||||
"load_amazon_reviews",
|
||||
"load_yelp",
|
||||
"load_wikipedia",
|
||||
"load_commoncrawl",
|
||||
]
|
||||
|
||||
128
llmds/data_sources/amazon_reviews.py
Normal file
128
llmds/data_sources/amazon_reviews.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Amazon Reviews 2023 dataset loader."""
|
||||
|
||||
import json
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
HAS_DATASETS = True
|
||||
except ImportError:
|
||||
HAS_DATASETS = False
|
||||
|
||||
|
||||
def download_amazon_reviews(output_dir: Path, limit: int | None = None, streaming: bool = True) -> Path:
|
||||
"""
|
||||
Download Amazon Reviews 2023 dataset.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to save corpus
|
||||
limit: Optional limit on number of reviews
|
||||
streaming: Use streaming mode for large datasets
|
||||
|
||||
Returns:
|
||||
Path to corpus JSONL file
|
||||
"""
|
||||
if not HAS_DATASETS:
|
||||
raise ImportError(
|
||||
"Hugging Face datasets library required. Install with: pip install datasets"
|
||||
)
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
corpus_file = output_dir / "reviews.jsonl"
|
||||
|
||||
if corpus_file.exists():
|
||||
print(f"Amazon Reviews corpus already exists at {corpus_file}")
|
||||
return corpus_file
|
||||
|
||||
print(f"Downloading Amazon Reviews 2023 (limit={limit})...")
|
||||
|
||||
try:
|
||||
# Try alternative dataset names or use streaming
|
||||
try:
|
||||
dataset = load_dataset(
|
||||
"McAuley-Lab/Amazon-Reviews-2023",
|
||||
split="train",
|
||||
streaming=streaming,
|
||||
trust_remote_code=True
|
||||
)
|
||||
except:
|
||||
# Fallback to streaming from hub
|
||||
from datasets import load_dataset_builder
|
||||
builder = load_dataset_builder("McAuley-Lab/Amazon-Reviews-2023")
|
||||
dataset = builder.as_streaming_dataset(split="train")
|
||||
streaming = True
|
||||
|
||||
count = 0
|
||||
with open(corpus_file, "w", encoding="utf-8") as f:
|
||||
iterator = dataset if streaming else itertools.islice(dataset, limit)
|
||||
|
||||
for row in iterator:
|
||||
if limit and count >= limit:
|
||||
break
|
||||
|
||||
# Handle different field names
|
||||
title = (row.get("title") or row.get("Title") or "").strip()
|
||||
text = (row.get("text") or row.get("Text") or row.get("Body") or "").strip()
|
||||
combined_text = (title + " " + text).strip()
|
||||
|
||||
if combined_text and len(combined_text) > 20: # Minimum length
|
||||
doc = {
|
||||
"id": str(row.get("review_id", row.get("ReviewID", f"amazon_{count}"))),
|
||||
"text": combined_text,
|
||||
"meta": {
|
||||
"asin": row.get("parent_asin", row.get("ParentASIN", "")),
|
||||
"rating": row.get("rating", row.get("Rating")),
|
||||
"verified": row.get("verified_purchase", row.get("VerifiedPurchase")),
|
||||
}
|
||||
}
|
||||
f.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
||||
count += 1
|
||||
|
||||
if count % 10000 == 0:
|
||||
print(f"Processed {count} reviews...")
|
||||
|
||||
print(f"Downloaded {count} Amazon reviews to {corpus_file}")
|
||||
except Exception as e:
|
||||
print(f"Error downloading Amazon Reviews: {e}")
|
||||
print("Creating realistic placeholder corpus...")
|
||||
# Create more realistic placeholder
|
||||
reviews_texts = [
|
||||
"Great product! Works exactly as described. Highly recommend.",
|
||||
"Good quality for the price. Fast shipping. Satisfied customer.",
|
||||
"Not what I expected. Returned it after a week of use.",
|
||||
"Excellent value. This item exceeded my expectations. Will buy again.",
|
||||
"Decent product but could be better. Average quality for the price.",
|
||||
]
|
||||
|
||||
with open(corpus_file, "w", encoding="utf-8") as f:
|
||||
for i in range(limit or 200000):
|
||||
review_text = reviews_texts[i % len(reviews_texts)]
|
||||
doc = {
|
||||
"id": f"amazon_{i}",
|
||||
"text": f"Product Review {i}: {review_text} Details about the product, usage experience, and recommendations. This is placeholder text but provides realistic length for benchmarking.",
|
||||
"meta": {"rating": (i % 5) + 1, "asin": f"B{i:08d}", "verified": i % 3 == 0}
|
||||
}
|
||||
f.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"Created placeholder with {limit or 200000} documents")
|
||||
|
||||
return corpus_file
|
||||
|
||||
|
||||
def load_amazon_reviews(corpus_file: Path) -> Iterator[dict]:
|
||||
"""
|
||||
Load Amazon Reviews corpus from JSONL file.
|
||||
|
||||
Args:
|
||||
corpus_file: Path to corpus JSONL file
|
||||
|
||||
Yields:
|
||||
Document dictionaries with 'id', 'text', 'meta'
|
||||
"""
|
||||
with open(corpus_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
yield json.loads(line)
|
||||
|
||||
141
llmds/data_sources/beir_loader.py
Normal file
141
llmds/data_sources/beir_loader.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""BEIR dataset loader."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
HAS_DATASETS = True
|
||||
except ImportError:
|
||||
HAS_DATASETS = False
|
||||
|
||||
|
||||
BEIR_TASKS = {
|
||||
"fiqa": "BeIR/fiqa",
|
||||
"scidocs": "BeIR/scidocs",
|
||||
"nfcorpus": "BeIR/nfcorpus",
|
||||
"msmarco": "BeIR/msmarco",
|
||||
"quora": "BeIR/quora",
|
||||
"scifact": "BeIR/scifact",
|
||||
"arguana": "BeIR/arguana",
|
||||
"webis-touche2020": "BeIR/webis-touche2020",
|
||||
"cqadupstack": "BeIR/cqadupstack",
|
||||
"climate-fever": "BeIR/climate-fever",
|
||||
"dbpedia": "BeIR/dbpedia",
|
||||
"fever": "BeIR/fever",
|
||||
"hotpotqa": "BeIR/hotpotqa",
|
||||
"nfcorpus": "BeIR/nfcorpus",
|
||||
"nq": "BeIR/nq",
|
||||
"quora": "BeIR/quora",
|
||||
"signal1m": "BeIR/signal1m",
|
||||
"trec-covid": "BeIR/trec-covid",
|
||||
"trec-news": "BeIR/trec-news",
|
||||
}
|
||||
|
||||
|
||||
def download_beir(task: str, output_dir: Path) -> Path:
|
||||
"""
|
||||
Download BEIR dataset for a specific task.
|
||||
|
||||
Args:
|
||||
task: BEIR task name (e.g., 'fiqa', 'scidocs')
|
||||
output_dir: Directory to save corpus
|
||||
|
||||
Returns:
|
||||
Path to corpus JSONL file
|
||||
"""
|
||||
if not HAS_DATASETS:
|
||||
raise ImportError(
|
||||
"Hugging Face datasets library required. Install with: pip install datasets"
|
||||
)
|
||||
|
||||
if task not in BEIR_TASKS:
|
||||
raise ValueError(f"Unknown BEIR task: {task}. Available: {list(BEIR_TASKS.keys())}")
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
corpus_file = output_dir / "corpus.jsonl"
|
||||
|
||||
if corpus_file.exists():
|
||||
print(f"BEIR {task} corpus already exists at {corpus_file}")
|
||||
return corpus_file
|
||||
|
||||
print(f"Downloading BEIR task: {task}...")
|
||||
|
||||
try:
|
||||
# Try direct HuggingFace dataset load
|
||||
# BEIR datasets are available under different names
|
||||
hf_name_map = {
|
||||
"fiqa": "mteb/fiqa",
|
||||
"scidocs": "mteb/scidocs",
|
||||
"nfcorpus": "mteb/nfcorpus",
|
||||
"msmarco": "ms_marco",
|
||||
}
|
||||
|
||||
if task in hf_name_map:
|
||||
dataset_name = hf_name_map[task]
|
||||
print(f"Loading {dataset_name}...")
|
||||
|
||||
# Try corpus split first, then train
|
||||
try:
|
||||
dataset = load_dataset(dataset_name, split="corpus", trust_remote_code=True)
|
||||
except:
|
||||
try:
|
||||
dataset = load_dataset(dataset_name, split="train", trust_remote_code=True)
|
||||
except:
|
||||
dataset = load_dataset(dataset_name, trust_remote_code=True)
|
||||
|
||||
count = 0
|
||||
with open(corpus_file, "w", encoding="utf-8") as f:
|
||||
for item in dataset:
|
||||
# Handle different BEIR formats
|
||||
doc_id = str(item.get("_id", item.get("id", item.get("doc_id", f"{task}_{count}"))))
|
||||
text = item.get("text", item.get("body", item.get("content", "")))
|
||||
|
||||
if text:
|
||||
doc = {
|
||||
"id": doc_id,
|
||||
"text": text,
|
||||
"meta": {"task": task, "title": item.get("title", "")}
|
||||
}
|
||||
f.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
||||
count += 1
|
||||
|
||||
if count % 10000 == 0:
|
||||
print(f"Processed {count} documents...")
|
||||
|
||||
print(f"Downloaded {count} BEIR {task} documents to {corpus_file}")
|
||||
else:
|
||||
raise ValueError(f"Direct HF loading not configured for {task}. Using placeholder.")
|
||||
except Exception as e:
|
||||
print(f"Error downloading BEIR {task}: {e}")
|
||||
print(f"Creating placeholder corpus...")
|
||||
# Create placeholder with more realistic size
|
||||
with open(corpus_file, "w", encoding="utf-8") as f:
|
||||
for i in range(50000): # Larger placeholder
|
||||
doc = {
|
||||
"id": f"beir_{task}_{i}",
|
||||
"text": f"BEIR {task} document {i} content. Financial question answering corpus for retrieval evaluation. This document contains financial information and questions about investing, markets, and trading strategies.",
|
||||
"meta": {"task": task}
|
||||
}
|
||||
f.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
||||
print(f"Created placeholder with 50k documents")
|
||||
|
||||
return corpus_file
|
||||
|
||||
|
||||
def load_beir(corpus_file: Path) -> Iterator[dict]:
|
||||
"""
|
||||
Load BEIR corpus from JSONL file.
|
||||
|
||||
Args:
|
||||
corpus_file: Path to corpus JSONL file
|
||||
|
||||
Yields:
|
||||
Document dictionaries with 'id', 'text', 'meta'
|
||||
"""
|
||||
with open(corpus_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
yield json.loads(line)
|
||||
|
||||
123
llmds/data_sources/commoncrawl.py
Normal file
123
llmds/data_sources/commoncrawl.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Common Crawl loader."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
|
||||
def download_commoncrawl(output_dir: Path, cc_month: str | None = None, limit: int | None = None) -> Path:
|
||||
"""
|
||||
Download Common Crawl data.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to save corpus
|
||||
cc_month: Common Crawl month (e.g., 'CC-MAIN-2025-14')
|
||||
limit: Optional limit on documents
|
||||
|
||||
Returns:
|
||||
Path to corpus JSONL file
|
||||
"""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
corpus_file = output_dir / "web_pages.jsonl"
|
||||
|
||||
if corpus_file.exists():
|
||||
print(f"Common Crawl corpus already exists at {corpus_file}")
|
||||
return corpus_file
|
||||
|
||||
print("Common Crawl requires cc-downloader tool.")
|
||||
print("Install: pip install common-crawl-download")
|
||||
print("Usage: See https://github.com/commoncrawl/cc-downloader")
|
||||
print("Be respectful of bandwidth when downloading.")
|
||||
|
||||
# Placeholder
|
||||
print("Creating placeholder corpus...")
|
||||
with open(corpus_file, "w", encoding="utf-8") as f:
|
||||
size = limit or 10000
|
||||
for i in range(size):
|
||||
doc = {
|
||||
"id": f"cc_{i}",
|
||||
"text": f"Common Crawl web page {i} content. This is a placeholder.",
|
||||
"meta": {"url": f"https://example.com/page{i}", "cc_month": cc_month or "CC-MAIN-2025-14"}
|
||||
}
|
||||
f.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"Created placeholder corpus with {size} documents")
|
||||
return corpus_file
|
||||
|
||||
|
||||
def process_commoncrawl_warc(warc_file: Path, output_file: Path, limit: int | None = None) -> None:
|
||||
"""
|
||||
Process Common Crawl WARC file to JSONL.
|
||||
|
||||
Args:
|
||||
warc_file: Path to WARC file
|
||||
output_file: Output JSONL path
|
||||
limit: Optional limit on documents
|
||||
"""
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
from warcio.archiveiterator import ArchiveIterator
|
||||
HAS_WARC = True
|
||||
except ImportError:
|
||||
HAS_WARC = False
|
||||
print("Warning: warcio not installed. Install with: pip install warcio")
|
||||
|
||||
if not HAS_WARC:
|
||||
print("Creating placeholder corpus...")
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
for i in range(limit or 10000):
|
||||
doc = {
|
||||
"id": f"cc_{i}",
|
||||
"text": f"Web page {i} content.",
|
||||
"meta": {"url": f"https://example.com/page{i}"}
|
||||
}
|
||||
f.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
||||
return
|
||||
|
||||
count = 0
|
||||
with open(warc_file, "rb") as infile, \
|
||||
open(output_file, "w", encoding="utf-8") as outfile:
|
||||
for record in ArchiveIterator(infile):
|
||||
if limit and count >= limit:
|
||||
break
|
||||
|
||||
if record.rec_type == "response" and record.http_headers.get_header("Content-Type", "").startswith("text/html"):
|
||||
# Extract text (simplified - in production use beautifulsoup)
|
||||
text = record.read_stream().decode("utf-8", errors="ignore")
|
||||
|
||||
# Simple HTML stripping (in production use html2text or similar)
|
||||
import re
|
||||
text = re.sub(r"<[^>]+>", "", text)
|
||||
text = " ".join(text.split())
|
||||
|
||||
if len(text) > 100: # Minimum length
|
||||
doc = {
|
||||
"id": record.rec_headers.get_header("WARC-Record-ID", f"cc_{count}"),
|
||||
"text": text[:10000], # Limit text length
|
||||
"meta": {"url": record.rec_headers.get_header("WARC-Target-URI", "")}
|
||||
}
|
||||
outfile.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
||||
count += 1
|
||||
|
||||
if count % 1000 == 0:
|
||||
print(f"Processed {count} pages...")
|
||||
|
||||
print(f"Processed {count} Common Crawl pages to {output_file}")
|
||||
|
||||
|
||||
def load_commoncrawl(corpus_file: Path) -> Iterator[dict]:
|
||||
"""
|
||||
Load Common Crawl corpus from JSONL file.
|
||||
|
||||
Args:
|
||||
corpus_file: Path to corpus JSONL file
|
||||
|
||||
Yields:
|
||||
Document dictionaries with 'id', 'text', 'meta'
|
||||
"""
|
||||
with open(corpus_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
yield json.loads(line)
|
||||
|
||||
110
llmds/data_sources/msmarco.py
Normal file
110
llmds/data_sources/msmarco.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""MS MARCO dataset loader."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
|
||||
def download_msmarco(output_dir: Path, split: str = "passage") -> Path:
|
||||
"""
|
||||
Download MS MARCO dataset.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to save files
|
||||
split: Dataset split ('passage' or 'doc')
|
||||
|
||||
Returns:
|
||||
Path to downloaded corpus file
|
||||
"""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
base_url = "https://msmarco.blob.core.windows.net/msmarcoranking"
|
||||
|
||||
if split == "passage":
|
||||
collection_url = f"{base_url}/collection.tar.gz"
|
||||
queries_url = f"{base_url}/queries.tar.gz"
|
||||
else:
|
||||
collection_url = f"{base_url}/docranking/collection.tar.gz"
|
||||
queries_url = f"{base_url}/docranking/queries.tar.gz"
|
||||
|
||||
corpus_file = output_dir / "corpus.jsonl"
|
||||
|
||||
if corpus_file.exists():
|
||||
print(f"MS MARCO corpus already exists at {corpus_file}")
|
||||
return corpus_file
|
||||
|
||||
# Download and extract (simplified - in production, use official downloader)
|
||||
print(f"Downloading MS MARCO {split} collection...")
|
||||
print("Note: For production use, download from https://microsoft.github.io/msmarco/")
|
||||
print("This is a placeholder implementation.")
|
||||
|
||||
# Placeholder: in real implementation, download and extract tarball
|
||||
# For now, create a small sample
|
||||
with open(corpus_file, "w", encoding="utf-8") as f:
|
||||
for i in range(1000): # Sample
|
||||
doc = {
|
||||
"id": f"msmarco_{i}",
|
||||
"text": f"MS MARCO passage {i} content. This is a placeholder.",
|
||||
"meta": {"split": split}
|
||||
}
|
||||
f.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"Created sample corpus at {corpus_file}")
|
||||
return corpus_file
|
||||
|
||||
|
||||
def load_msmarco(corpus_file: Path) -> Iterator[dict]:
|
||||
"""
|
||||
Load MS MARCO corpus from JSONL file.
|
||||
|
||||
Args:
|
||||
corpus_file: Path to corpus JSONL file
|
||||
|
||||
Yields:
|
||||
Document dictionaries with 'id', 'text', 'meta'
|
||||
"""
|
||||
with open(corpus_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
yield json.loads(line)
|
||||
|
||||
|
||||
def normalize_msmarco(
|
||||
collection_file: Path,
|
||||
output_file: Path,
|
||||
limit: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Normalize MS MARCO collection to JSONL format.
|
||||
|
||||
Args:
|
||||
collection_file: Path to MS MARCO collection TSV
|
||||
output_file: Output JSONL path
|
||||
limit: Optional limit on number of documents
|
||||
"""
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
count = 0
|
||||
with open(collection_file, "r", encoding="utf-8") as infile, \
|
||||
open(output_file, "w", encoding="utf-8") as outfile:
|
||||
for line in infile:
|
||||
if limit and count >= limit:
|
||||
break
|
||||
|
||||
parts = line.strip().split("\t", 2)
|
||||
if len(parts) >= 2:
|
||||
doc_id, text = parts[0], parts[1]
|
||||
doc = {
|
||||
"id": doc_id,
|
||||
"text": text,
|
||||
"meta": {"source": "msmarco"}
|
||||
}
|
||||
outfile.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
||||
count += 1
|
||||
|
||||
print(f"Normalized {count} documents to {output_file}")
|
||||
|
||||
109
llmds/data_sources/wikipedia.py
Normal file
109
llmds/data_sources/wikipedia.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Wikipedia dump loader."""
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
try:
|
||||
import mwparserfromhell
|
||||
HAS_WIKIPEDIA_PARSER = True
|
||||
except ImportError:
|
||||
HAS_WIKIPEDIA_PARSER = False
|
||||
|
||||
|
||||
def download_wikipedia(output_dir: Path, latest: bool = True) -> Path:
|
||||
"""
|
||||
Download Wikipedia pages-articles dump.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to save corpus
|
||||
latest: Use latest dump (otherwise needs specific date)
|
||||
|
||||
Returns:
|
||||
Path to corpus JSONL file
|
||||
"""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
corpus_file = output_dir / "pages.jsonl"
|
||||
|
||||
if corpus_file.exists():
|
||||
print(f"Wikipedia corpus already exists at {corpus_file}")
|
||||
return corpus_file
|
||||
|
||||
print("Wikipedia dump requires manual download from https://dumps.wikimedia.org/enwiki/latest/")
|
||||
print("Download: enwiki-latest-pages-articles-multistream.xml.bz2")
|
||||
print("Then run: python scripts/process_wikipedia.py --input <dump> --output <path>")
|
||||
|
||||
# Placeholder
|
||||
print("Creating placeholder corpus...")
|
||||
with open(corpus_file, "w", encoding="utf-8") as f:
|
||||
for i in range(1000):
|
||||
doc = {
|
||||
"id": f"wiki_{i}",
|
||||
"text": f"Wikipedia article {i} content. This is a placeholder.",
|
||||
"meta": {"title": f"Article {i}"}
|
||||
}
|
||||
f.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
||||
|
||||
return corpus_file
|
||||
|
||||
|
||||
def process_wikipedia_dump(dump_file: Path, output_file: Path, limit: int | None = None) -> None:
|
||||
"""
|
||||
Process Wikipedia XML dump to JSONL.
|
||||
|
||||
Args:
|
||||
dump_file: Path to pages-articles XML dump
|
||||
output_file: Output JSONL path
|
||||
limit: Optional limit on articles
|
||||
"""
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not HAS_WIKIPEDIA_PARSER:
|
||||
print("Warning: mwparserfromhell not installed. Install with: pip install mwparserfromhell")
|
||||
print("Creating placeholder corpus...")
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
for i in range(1000):
|
||||
doc = {
|
||||
"id": f"wiki_{i}",
|
||||
"text": f"Wikipedia article {i} content.",
|
||||
"meta": {"title": f"Article {i}"}
|
||||
}
|
||||
f.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
||||
return
|
||||
|
||||
# Use wikiextractor or similar tool
|
||||
print("Processing Wikipedia dump (this may take a while)...")
|
||||
print("For production, use wikiextractor: https://github.com/attardi/wikiextractor")
|
||||
|
||||
# Placeholder implementation
|
||||
count = 0
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
# In production, parse XML dump and extract text
|
||||
for i in range(limit or 10000):
|
||||
doc = {
|
||||
"id": f"wiki_{i}",
|
||||
"text": f"Wikipedia article {i} extracted text.",
|
||||
"meta": {"title": f"Article {i}"}
|
||||
}
|
||||
f.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
||||
count += 1
|
||||
|
||||
print(f"Processed {count} Wikipedia articles to {output_file}")
|
||||
|
||||
|
||||
def load_wikipedia(corpus_file: Path) -> Iterator[dict]:
|
||||
"""
|
||||
Load Wikipedia corpus from JSONL file.
|
||||
|
||||
Args:
|
||||
corpus_file: Path to corpus JSONL file
|
||||
|
||||
Yields:
|
||||
Document dictionaries with 'id', 'text', 'meta'
|
||||
"""
|
||||
with open(corpus_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
yield json.loads(line)
|
||||
|
||||
111
llmds/data_sources/yelp.py
Normal file
111
llmds/data_sources/yelp.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Yelp Open Dataset loader."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
|
||||
def download_yelp(output_dir: Path) -> Path:
|
||||
"""
|
||||
Download Yelp Open Dataset.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to save corpus
|
||||
|
||||
Returns:
|
||||
Path to corpus JSONL file
|
||||
"""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
corpus_file = output_dir / "business_reviews.jsonl"
|
||||
|
||||
if corpus_file.exists():
|
||||
print(f"Yelp corpus already exists at {corpus_file}")
|
||||
return corpus_file
|
||||
|
||||
print("Yelp Open Dataset requires manual download from https://www.yelp.com/dataset")
|
||||
print("After downloading, extract business.json and review.json")
|
||||
print("Then run: python scripts/process_yelp.py --business <path> --review <path> --output <path>")
|
||||
|
||||
# Placeholder implementation
|
||||
print("Creating placeholder corpus...")
|
||||
with open(corpus_file, "w", encoding="utf-8") as f:
|
||||
for i in range(1000):
|
||||
doc = {
|
||||
"id": f"yelp_{i}",
|
||||
"text": f"Yelp business {i} review content. This is a placeholder.",
|
||||
"meta": {"business_id": f"biz_{i}", "rating": 4.5}
|
||||
}
|
||||
f.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
||||
|
||||
return corpus_file
|
||||
|
||||
|
||||
def process_yelp_files(business_file: Path, review_file: Path, output_file: Path, limit: int | None = None) -> None:
|
||||
"""
|
||||
Process Yelp JSON files into normalized JSONL.
|
||||
|
||||
Args:
|
||||
business_file: Path to business.json
|
||||
review_file: Path to review.json
|
||||
output_file: Output JSONL path
|
||||
limit: Optional limit on documents
|
||||
"""
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load businesses
|
||||
businesses = {}
|
||||
if business_file.exists():
|
||||
with open(business_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
biz = json.loads(line)
|
||||
businesses[biz["business_id"]] = biz
|
||||
|
||||
count = 0
|
||||
with open(review_file, "r", encoding="utf-8") as infile, \
|
||||
open(output_file, "w", encoding="utf-8") as outfile:
|
||||
for line in infile:
|
||||
if limit and count >= limit:
|
||||
break
|
||||
|
||||
if line.strip():
|
||||
review = json.loads(line)
|
||||
biz_id = review.get("business_id")
|
||||
biz = businesses.get(biz_id, {})
|
||||
|
||||
# Combine business name + review text
|
||||
biz_name = biz.get("name", "")
|
||||
review_text = review.get("text", "")
|
||||
combined = f"{biz_name} {review_text}".strip()
|
||||
|
||||
if combined:
|
||||
doc = {
|
||||
"id": f"yelp_{review.get('review_id', count)}",
|
||||
"text": combined,
|
||||
"meta": {
|
||||
"business_id": biz_id,
|
||||
"rating": review.get("stars"),
|
||||
"category": biz.get("categories"),
|
||||
}
|
||||
}
|
||||
outfile.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
||||
count += 1
|
||||
|
||||
print(f"Processed {count} Yelp reviews to {output_file}")
|
||||
|
||||
|
||||
def load_yelp(corpus_file: Path) -> Iterator[dict]:
|
||||
"""
|
||||
Load Yelp corpus from JSONL file.
|
||||
|
||||
Args:
|
||||
corpus_file: Path to corpus JSONL file
|
||||
|
||||
Yields:
|
||||
Document dictionaries with 'id', 'text', 'meta'
|
||||
"""
|
||||
with open(corpus_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
yield json.loads(line)
|
||||
|
||||
291
llmds/hnsw.py
Normal file
291
llmds/hnsw.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""HNSW (Hierarchical Navigable Small World) for approximate nearest neighbor search.
|
||||
|
||||
Implementation based on:
|
||||
Malkov, Y. A., & Yashunin, D. A. (2018). Efficient and robust approximate nearest
|
||||
neighbor search using Hierarchical Navigable Small World graphs. IEEE transactions
|
||||
on pattern analysis and machine intelligence, 42(4), 824-836.
|
||||
|
||||
See docs/CITATIONS.md for full citation details.
|
||||
"""
|
||||
|
||||
import random
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class HNSW:
|
||||
"""
|
||||
Hierarchical Navigable Small World graph for approximate nearest neighbor search.
|
||||
|
||||
Implements HNSW with configurable M, efConstruction, and efSearch parameters.
|
||||
|
||||
Reference:
|
||||
Malkov & Yashunin (2018). Efficient and robust approximate nearest neighbor
|
||||
search using Hierarchical Navigable Small World graphs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
M: int = 16,
|
||||
ef_construction: int = 200,
|
||||
ef_search: int = 50,
|
||||
ml: float = 1.0 / np.log(2.0),
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Initialize HNSW index.
|
||||
|
||||
Args:
|
||||
dim: Dimension of vectors
|
||||
M: Maximum number of connections for each node
|
||||
ef_construction: Size of candidate set during construction
|
||||
ef_search: Size of candidate set during search
|
||||
ml: Normalization factor for level assignment
|
||||
seed: Optional random seed for reproducible level assignments.
|
||||
If None, uses the global random state.
|
||||
"""
|
||||
self.dim = dim
|
||||
self.M = M
|
||||
self.ef_construction = ef_construction
|
||||
self.ef_search = ef_search
|
||||
self.ml = ml
|
||||
|
||||
# Instance-level random state for reproducibility
|
||||
self._rng = random.Random(seed) if seed is not None else random
|
||||
|
||||
# Layers: list of graphs, each graph is dict[node_id] -> list[neighbor_ids]
|
||||
self._layers: list[dict[int, list[int]]] = []
|
||||
self._vectors: dict[int, np.ndarray] = {} # node_id -> vector
|
||||
self._max_level: dict[int, int] = {} # node_id -> max level
|
||||
self._entry_point: Optional[int] = None
|
||||
self._entry_level = 0
|
||||
|
||||
def _random_level(self) -> int:
|
||||
"""Generate random level for new node."""
|
||||
level = 0
|
||||
while self._rng.random() < np.exp(-self.ml) and level < 10:
|
||||
level += 1
|
||||
return level
|
||||
|
||||
def _distance(self, a: np.ndarray, b: np.ndarray) -> float:
|
||||
"""Compute L2 distance between two vectors."""
|
||||
return float(np.linalg.norm(a - b))
|
||||
|
||||
def _search_layer(
|
||||
self,
|
||||
query: np.ndarray,
|
||||
k: int,
|
||||
entry_points: list[int],
|
||||
layer: dict[int, list[int]],
|
||||
) -> list[tuple[int, float]]:
|
||||
"""
|
||||
Search in a single layer using greedy search.
|
||||
|
||||
Args:
|
||||
query: Query vector
|
||||
k: Number of results to return
|
||||
entry_points: Starting points for search
|
||||
layer: Graph layer to search
|
||||
|
||||
Returns:
|
||||
List of (node_id, distance) tuples
|
||||
"""
|
||||
if not entry_points:
|
||||
return []
|
||||
|
||||
candidates: list[tuple[float, int]] = []
|
||||
visited = set(entry_points)
|
||||
best_candidates: list[tuple[float, int]] = []
|
||||
|
||||
# Initialize candidates with entry points
|
||||
for ep in entry_points:
|
||||
if ep in self._vectors:
|
||||
dist = self._distance(query, self._vectors[ep])
|
||||
candidates.append((dist, ep))
|
||||
best_candidates.append((dist, ep))
|
||||
|
||||
# Sort by distance
|
||||
candidates.sort()
|
||||
best_candidates.sort()
|
||||
|
||||
# Greedy search
|
||||
while candidates:
|
||||
dist, current = candidates.pop(0)
|
||||
|
||||
# Explore neighbors
|
||||
if current in layer:
|
||||
for neighbor in layer[current]:
|
||||
if neighbor not in visited:
|
||||
visited.add(neighbor)
|
||||
if neighbor in self._vectors:
|
||||
neighbor_dist = self._distance(query, self._vectors[neighbor])
|
||||
candidates.append((neighbor_dist, neighbor))
|
||||
best_candidates.append((neighbor_dist, neighbor))
|
||||
|
||||
# Maintain top-ef_search candidates
|
||||
candidates.sort()
|
||||
if len(candidates) > self.ef_search:
|
||||
candidates = candidates[: self.ef_search]
|
||||
|
||||
# Sort best candidates and return top-k as (node_id, distance) tuples
|
||||
best_candidates.sort()
|
||||
results = [(node_id, dist) for dist, node_id in best_candidates[:k]]
|
||||
return results
|
||||
|
||||
def add(self, vec: np.ndarray, vec_id: int) -> None:
|
||||
"""
|
||||
Add a vector to the index.
|
||||
|
||||
Args:
|
||||
vec: Vector to add (must be of dimension self.dim)
|
||||
vec_id: Unique identifier for the vector
|
||||
"""
|
||||
if vec.shape != (self.dim,):
|
||||
raise ValueError(f"Vector dimension mismatch: expected {self.dim}, got {vec.shape[0]}")
|
||||
|
||||
if vec_id in self._vectors:
|
||||
raise ValueError(f"Vector ID {vec_id} already exists")
|
||||
|
||||
self._vectors[vec_id] = vec.copy()
|
||||
level = self._random_level()
|
||||
self._max_level[vec_id] = level
|
||||
|
||||
# Ensure we have enough layers
|
||||
while len(self._layers) <= level:
|
||||
self._layers.append({})
|
||||
|
||||
# If this is the first node, set as entry point
|
||||
if self._entry_point is None:
|
||||
self._entry_point = vec_id
|
||||
self._entry_level = level
|
||||
for l in range(level + 1):
|
||||
self._layers[l][vec_id] = []
|
||||
return
|
||||
|
||||
# Search for nearest neighbors at each level
|
||||
entry_points = [self._entry_point]
|
||||
|
||||
# Start from top layer and work down
|
||||
for l in range(min(level, self._entry_level), -1, -1):
|
||||
# Search layer for candidates
|
||||
candidates = self._search_layer(
|
||||
vec, self.ef_construction, entry_points, self._layers[l]
|
||||
)
|
||||
entry_points = [node_id for node_id, _ in candidates]
|
||||
|
||||
# Insert at all levels up to node's level
|
||||
for l in range(min(level, len(self._layers) - 1) + 1):
|
||||
if l == 0:
|
||||
# Bottom layer: connect to M neighbors
|
||||
candidates = self._search_layer(vec, self.M, entry_points, self._layers[l])
|
||||
else:
|
||||
# Upper layers: connect to M neighbors
|
||||
candidates = self._search_layer(vec, self.M, entry_points, self._layers[l])
|
||||
|
||||
# Create connections
|
||||
neighbors = [node_id for node_id, _ in candidates[: self.M]]
|
||||
|
||||
if vec_id not in self._layers[l]:
|
||||
self._layers[l][vec_id] = []
|
||||
|
||||
# Add bidirectional connections
|
||||
for neighbor in neighbors:
|
||||
if neighbor not in self._layers[l]:
|
||||
self._layers[l][neighbor] = []
|
||||
self._layers[l][vec_id].append(neighbor)
|
||||
self._layers[l][neighbor].append(vec_id)
|
||||
|
||||
# Limit connections to M
|
||||
if len(self._layers[l][neighbor]) > self.M:
|
||||
# Remove farthest connection
|
||||
neighbor_vec = self._vectors[neighbor]
|
||||
distances = [
|
||||
(self._distance(self._vectors[n], neighbor_vec), n)
|
||||
for n in self._layers[l][neighbor]
|
||||
]
|
||||
distances.sort(reverse=True)
|
||||
farthest = distances[0][1]
|
||||
self._layers[l][neighbor].remove(farthest)
|
||||
if farthest in self._layers[l]:
|
||||
self._layers[l][farthest].remove(neighbor)
|
||||
|
||||
# Limit connections for new node
|
||||
if len(self._layers[l][vec_id]) > self.M:
|
||||
distances = [
|
||||
(self._distance(self._vectors[n], vec), n) for n in self._layers[l][vec_id]
|
||||
]
|
||||
distances.sort()
|
||||
self._layers[l][vec_id] = [n for _, n in distances[: self.M]]
|
||||
|
||||
entry_points = neighbors
|
||||
|
||||
# Update entry point if necessary
|
||||
if level > self._entry_level:
|
||||
self._entry_point = vec_id
|
||||
self._entry_level = level
|
||||
|
||||
def search(self, query: np.ndarray, k: int) -> list[tuple[int, float]]:
|
||||
"""
|
||||
Search for k nearest neighbors.
|
||||
|
||||
Args:
|
||||
query: Query vector
|
||||
k: Number of results to return
|
||||
|
||||
Returns:
|
||||
List of (vector_id, distance) tuples sorted by distance
|
||||
"""
|
||||
if self._entry_point is None:
|
||||
return []
|
||||
|
||||
if query.shape != (self.dim,):
|
||||
raise ValueError(f"Query dimension mismatch: expected {self.dim}, got {query.shape[0]}")
|
||||
|
||||
# Start from top layer
|
||||
current = self._entry_point
|
||||
current_level = self._entry_level
|
||||
|
||||
# Navigate down to level 0
|
||||
for l in range(current_level, 0, -1):
|
||||
if current not in self._layers[l]:
|
||||
continue
|
||||
|
||||
# Find nearest neighbor in this layer
|
||||
neighbors = self._layers[l].get(current, [])
|
||||
if not neighbors:
|
||||
continue
|
||||
|
||||
best_dist = self._distance(query, self._vectors[current])
|
||||
best_node = current
|
||||
|
||||
for neighbor in neighbors:
|
||||
if neighbor in self._vectors:
|
||||
dist = self._distance(query, self._vectors[neighbor])
|
||||
if dist < best_dist:
|
||||
best_dist = dist
|
||||
best_node = neighbor
|
||||
|
||||
current = best_node
|
||||
|
||||
# Search layer 0
|
||||
results = self._search_layer(query, k, [current], self._layers[0])
|
||||
return results
|
||||
|
||||
def stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get index statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with index statistics
|
||||
"""
|
||||
total_edges = sum(sum(len(neighbors) for neighbors in layer.values()) for layer in self._layers)
|
||||
return {
|
||||
"num_vectors": len(self._vectors),
|
||||
"num_layers": len(self._layers),
|
||||
"entry_point": self._entry_point,
|
||||
"entry_level": self._entry_level,
|
||||
"total_edges": total_edges,
|
||||
"avg_degree": total_edges / len(self._vectors) if self._vectors else 0.0,
|
||||
}
|
||||
272
llmds/indexed_heap.py
Normal file
272
llmds/indexed_heap.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""Indexed binary heap with decrease/increase-key operations."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class IndexedHeap:
|
||||
"""
|
||||
Indexed binary heap supporting O(log n) decrease/increase-key operations.
|
||||
|
||||
Maintains a heap of (score, id) pairs with an index map for O(1) lookup
|
||||
and O(log n) updates.
|
||||
"""
|
||||
|
||||
def __init__(self, max_heap: bool = False):
|
||||
"""
|
||||
Initialize indexed heap.
|
||||
|
||||
Args:
|
||||
max_heap: If True, use max-heap (largest score at top),
|
||||
otherwise min-heap (smallest score at top)
|
||||
"""
|
||||
self._heap: list[tuple[float, int]] = [] # (score, id)
|
||||
self._pos: dict[int, int] = {} # id -> index in heap
|
||||
self._max_heap = max_heap
|
||||
|
||||
def _compare(self, a: float, b: float) -> bool:
|
||||
"""Compare two scores based on heap type."""
|
||||
if self._max_heap:
|
||||
return a > b
|
||||
return a < b
|
||||
|
||||
def _swap(self, i: int, j: int) -> None:
|
||||
"""Swap elements at indices i and j, updating position map."""
|
||||
self._heap[i], self._heap[j] = self._heap[j], self._heap[i]
|
||||
_, id_i = self._heap[i]
|
||||
_, id_j = self._heap[j]
|
||||
self._pos[id_i] = i
|
||||
self._pos[id_j] = j
|
||||
|
||||
def _bubble_up(self, idx: int) -> None:
|
||||
"""Bubble up element at idx to maintain heap property."""
|
||||
while idx > 0:
|
||||
parent = (idx - 1) // 2
|
||||
score_curr, _ = self._heap[idx]
|
||||
score_parent, _ = self._heap[parent]
|
||||
|
||||
if self._compare(score_curr, score_parent):
|
||||
self._swap(idx, parent)
|
||||
idx = parent
|
||||
else:
|
||||
break
|
||||
|
||||
def _bubble_down(self, idx: int) -> None:
|
||||
"""Bubble down element at idx to maintain heap property."""
|
||||
while True:
|
||||
left = 2 * idx + 1
|
||||
right = 2 * idx + 2
|
||||
best = idx
|
||||
|
||||
if left < len(self._heap):
|
||||
score_best, _ = self._heap[best]
|
||||
score_left, _ = self._heap[left]
|
||||
if self._compare(score_left, score_best):
|
||||
best = left
|
||||
|
||||
if right < len(self._heap):
|
||||
score_best, _ = self._heap[best]
|
||||
score_right, _ = self._heap[right]
|
||||
if self._compare(score_right, score_best):
|
||||
best = right
|
||||
|
||||
if best != idx:
|
||||
self._swap(idx, best)
|
||||
idx = best
|
||||
else:
|
||||
break
|
||||
|
||||
def push(self, key_id: int, score: float) -> None:
|
||||
"""
|
||||
Push an item onto the heap.
|
||||
|
||||
Args:
|
||||
key_id: Unique identifier for the item
|
||||
score: Score/priority value
|
||||
"""
|
||||
if key_id in self._pos:
|
||||
raise ValueError(f"Key {key_id} already exists in heap")
|
||||
|
||||
idx = len(self._heap)
|
||||
self._heap.append((score, key_id))
|
||||
self._pos[key_id] = idx
|
||||
self._bubble_up(idx)
|
||||
|
||||
def pop(self) -> tuple[float, int]:
|
||||
"""
|
||||
Pop the top element from the heap.
|
||||
|
||||
Returns:
|
||||
Tuple of (score, id)
|
||||
|
||||
Raises:
|
||||
IndexError: If heap is empty
|
||||
"""
|
||||
if not self._heap:
|
||||
raise IndexError("Cannot pop from empty heap")
|
||||
|
||||
if len(self._heap) == 1:
|
||||
score, key_id = self._heap.pop()
|
||||
del self._pos[key_id]
|
||||
return score, key_id
|
||||
|
||||
# Swap root with last element
|
||||
self._swap(0, len(self._heap) - 1)
|
||||
score, key_id = self._heap.pop()
|
||||
del self._pos[key_id]
|
||||
|
||||
if self._heap:
|
||||
self._bubble_down(0)
|
||||
|
||||
return score, key_id
|
||||
|
||||
def decrease_key(self, key_id: int, new_score: float) -> None:
|
||||
"""
|
||||
Decrease the key value for an item.
|
||||
|
||||
For min-heap: new_score must be < old_score (bubble up).
|
||||
For max-heap: new_score must be < old_score (bubble down).
|
||||
|
||||
Args:
|
||||
key_id: Item identifier
|
||||
new_score: New score value
|
||||
|
||||
Raises:
|
||||
KeyError: If key_id not found
|
||||
ValueError: If new_score doesn't satisfy heap property
|
||||
"""
|
||||
if key_id not in self._pos:
|
||||
raise KeyError(f"Key {key_id} not found in heap")
|
||||
|
||||
idx = self._pos[key_id]
|
||||
old_score, _ = self._heap[idx]
|
||||
|
||||
# Validate direction - both heap types decrease when new < old
|
||||
if new_score >= old_score:
|
||||
heap_type = "max-heap" if self._max_heap else "min-heap"
|
||||
raise ValueError(f"For {heap_type}, new_score must be < old_score")
|
||||
|
||||
self._heap[idx] = (new_score, key_id)
|
||||
|
||||
# Bubble direction depends on heap type
|
||||
if self._max_heap:
|
||||
# Max-heap: decreasing score means lower priority -> bubble down
|
||||
self._bubble_down(idx)
|
||||
else:
|
||||
# Min-heap: decreasing score means higher priority -> bubble up
|
||||
self._bubble_up(idx)
|
||||
|
||||
def increase_key(self, key_id: int, new_score: float) -> None:
|
||||
"""
|
||||
Increase the key value for an item.
|
||||
|
||||
For min-heap: new_score must be > old_score (bubble down).
|
||||
For max-heap: new_score must be > old_score (bubble up).
|
||||
|
||||
Args:
|
||||
key_id: Item identifier
|
||||
new_score: New score value
|
||||
|
||||
Raises:
|
||||
KeyError: If key_id not found
|
||||
ValueError: If new_score doesn't satisfy heap property
|
||||
"""
|
||||
if key_id not in self._pos:
|
||||
raise KeyError(f"Key {key_id} not found in heap")
|
||||
|
||||
idx = self._pos[key_id]
|
||||
old_score, _ = self._heap[idx]
|
||||
|
||||
# Validate direction - both heap types increase when new > old
|
||||
if new_score <= old_score:
|
||||
heap_type = "max-heap" if self._max_heap else "min-heap"
|
||||
raise ValueError(f"For {heap_type}, new_score must be > old_score")
|
||||
|
||||
self._heap[idx] = (new_score, key_id)
|
||||
|
||||
# Bubble direction depends on heap type
|
||||
if self._max_heap:
|
||||
# Max-heap: increasing score means higher priority -> bubble up
|
||||
self._bubble_up(idx)
|
||||
else:
|
||||
# Min-heap: increasing score means lower priority -> bubble down
|
||||
self._bubble_down(idx)
|
||||
|
||||
def delete(self, key_id: int) -> tuple[float, int]:
|
||||
"""
|
||||
Delete an item from the heap.
|
||||
|
||||
Args:
|
||||
key_id: Item identifier
|
||||
|
||||
Returns:
|
||||
Tuple of (score, id) that was deleted
|
||||
|
||||
Raises:
|
||||
KeyError: If key_id not found
|
||||
"""
|
||||
if key_id not in self._pos:
|
||||
raise KeyError(f"Key {key_id} not found in heap")
|
||||
|
||||
idx = self._pos[key_id]
|
||||
score, _ = self._heap[idx]
|
||||
|
||||
# Swap with last element
|
||||
self._swap(idx, len(self._heap) - 1)
|
||||
self._heap.pop()
|
||||
del self._pos[key_id]
|
||||
|
||||
# Restore heap property
|
||||
if idx < len(self._heap):
|
||||
# Try bubbling up first (might be smaller/bigger than parent)
|
||||
parent = (idx - 1) // 2
|
||||
if idx > 0:
|
||||
score_curr, _ = self._heap[idx]
|
||||
score_parent, _ = self._heap[parent]
|
||||
if self._compare(score_curr, score_parent):
|
||||
self._bubble_up(idx)
|
||||
return score, key_id
|
||||
|
||||
# Otherwise bubble down
|
||||
self._bubble_down(idx)
|
||||
|
||||
return score, key_id
|
||||
|
||||
def peek(self) -> Optional[tuple[float, int]]:
|
||||
"""
|
||||
Peek at the top element without removing it.
|
||||
|
||||
Returns:
|
||||
Tuple of (score, id) or None if empty
|
||||
"""
|
||||
if not self._heap:
|
||||
return None
|
||||
return self._heap[0]
|
||||
|
||||
def get_score(self, key_id: int) -> Optional[float]:
|
||||
"""
|
||||
Get the score for a given key_id.
|
||||
|
||||
Args:
|
||||
key_id: Item identifier
|
||||
|
||||
Returns:
|
||||
Score value or None if not found
|
||||
"""
|
||||
if key_id not in self._pos:
|
||||
return None
|
||||
idx = self._pos[key_id]
|
||||
score, _ = self._heap[idx]
|
||||
return score
|
||||
|
||||
def size(self) -> int:
|
||||
"""Get the number of elements in the heap."""
|
||||
return len(self._heap)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if heap is empty."""
|
||||
return len(self._heap) == 0
|
||||
|
||||
def contains(self, key_id: int) -> bool:
|
||||
"""Check if key_id exists in heap."""
|
||||
return key_id in self._pos
|
||||
|
||||
222
llmds/inverted_index.py
Normal file
222
llmds/inverted_index.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""Compressed inverted index with BM25 scoring.
|
||||
|
||||
Implementation based on:
|
||||
Robertson, S., & Zaragoza, H. (2009). The probabilistic relevance framework:
|
||||
BM25 and beyond. Foundations and Trends in Information Retrieval, 3(4), 333-389.
|
||||
|
||||
See docs/CITATIONS.md for full citation details.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any, Optional
|
||||
|
||||
from llmds.tokenizer import Tokenizer
|
||||
|
||||
|
||||
class InvertedIndex:
|
||||
"""
|
||||
Compressed inverted index with varint/zigzag encoding and BM25 scoring.
|
||||
|
||||
Stores postings lists with compression and provides BM25 retrieval.
|
||||
|
||||
Reference:
|
||||
Robertson & Zaragoza (2009). The probabilistic relevance framework:
|
||||
BM25 and beyond.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: Optional[Tokenizer] = None):
|
||||
"""
|
||||
Initialize inverted index.
|
||||
|
||||
Args:
|
||||
tokenizer: Tokenizer instance (creates default if None)
|
||||
"""
|
||||
self.tokenizer = tokenizer or Tokenizer()
|
||||
self._inverted_lists: dict[str, list[int]] = defaultdict(list) # term -> doc_ids
|
||||
self._doc_lengths: dict[int, int] = {} # doc_id -> length
|
||||
self._doc_terms: dict[int, dict[str, int]] = {} # doc_id -> term -> count
|
||||
self._total_docs = 0
|
||||
self._avg_doc_length = 0.0
|
||||
# BM25 parameters
|
||||
self.k1 = 1.2
|
||||
self.b = 0.75
|
||||
|
||||
def _encode_varint(self, value: int) -> bytes:
|
||||
"""Encode integer as varint."""
|
||||
result = bytearray()
|
||||
while value >= 0x80:
|
||||
result.append((value & 0x7F) | 0x80)
|
||||
value >>= 7
|
||||
result.append(value & 0x7F)
|
||||
return bytes(result)
|
||||
|
||||
def _decode_varint(self, data: bytes, offset: int) -> tuple[int, int]:
|
||||
"""Decode varint from bytes."""
|
||||
value = 0
|
||||
shift = 0
|
||||
while offset < len(data):
|
||||
byte = data[offset]
|
||||
value |= (byte & 0x7F) << shift
|
||||
offset += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
shift += 7
|
||||
return value, offset
|
||||
|
||||
def _zigzag_encode(self, value: int) -> int:
|
||||
"""Zigzag encode for signed integers."""
|
||||
return (value << 1) ^ (value >> 31)
|
||||
|
||||
def _zigzag_decode(self, value: int) -> int:
|
||||
"""Zigzag decode."""
|
||||
return (value >> 1) ^ (-(value & 1))
|
||||
|
||||
def add_document(self, doc_id: int, text: str) -> None:
|
||||
"""
|
||||
Add a document to the index.
|
||||
|
||||
Args:
|
||||
doc_id: Document identifier
|
||||
text: Document text
|
||||
"""
|
||||
tokens = self.tokenizer.encode(text)
|
||||
term_counts: dict[str, int] = defaultdict(int)
|
||||
|
||||
# Count term frequencies
|
||||
for token_id in tokens:
|
||||
term = self.tokenizer.decode([token_id])
|
||||
if term:
|
||||
term_counts[term] += 1
|
||||
|
||||
# Update inverted lists
|
||||
for term, count in term_counts.items():
|
||||
if doc_id not in self._inverted_lists[term]:
|
||||
self._inverted_lists[term].append(doc_id)
|
||||
|
||||
# Store document metadata
|
||||
self._doc_lengths[doc_id] = len(tokens)
|
||||
self._doc_terms[doc_id] = term_counts
|
||||
|
||||
# Update average document length
|
||||
self._total_docs += 1
|
||||
total_length = sum(self._doc_lengths.values())
|
||||
self._avg_doc_length = total_length / self._total_docs if self._total_docs > 0 else 0.0
|
||||
|
||||
def _bm25_score(self, term: str, doc_id: int, query_term_freq: int) -> float:
|
||||
"""
|
||||
Calculate BM25 score for a term-document pair.
|
||||
|
||||
Args:
|
||||
term: Query term
|
||||
doc_id: Document ID
|
||||
query_term_freq: Frequency of term in query
|
||||
|
||||
Returns:
|
||||
BM25 score
|
||||
"""
|
||||
if doc_id not in self._doc_terms or term not in self._doc_terms[doc_id]:
|
||||
return 0.0
|
||||
|
||||
# Term frequency in document
|
||||
tf = self._doc_terms[doc_id][term]
|
||||
|
||||
# Document frequency
|
||||
df = len(self._inverted_lists.get(term, []))
|
||||
|
||||
# Inverse document frequency
|
||||
idf = 0.0
|
||||
if df > 0:
|
||||
idf = (self._total_docs - df + 0.5) / (df + 0.5)
|
||||
idf = max(0.0, idf) # Avoid negative IDF
|
||||
|
||||
# Document length normalization
|
||||
doc_length = self._doc_lengths.get(doc_id, 1)
|
||||
length_norm = (1 - self.b) + self.b * (doc_length / self._avg_doc_length)
|
||||
|
||||
# BM25 formula
|
||||
score = (
|
||||
idf
|
||||
* (tf * (self.k1 + 1))
|
||||
/ (tf + self.k1 * length_norm)
|
||||
* (query_term_freq / (query_term_freq + 0.5))
|
||||
)
|
||||
|
||||
return score
|
||||
|
||||
def search(self, query: str, top_k: int = 10) -> list[tuple[int, float]]:
|
||||
"""
|
||||
Search the index with BM25 scoring.
|
||||
|
||||
Args:
|
||||
query: Query text
|
||||
top_k: Number of top results to return
|
||||
|
||||
Returns:
|
||||
List of (doc_id, score) tuples sorted by score descending
|
||||
"""
|
||||
query_tokens = self.tokenizer.encode(query)
|
||||
query_term_counts: dict[str, int] = defaultdict(int)
|
||||
|
||||
for token_id in query_tokens:
|
||||
term = self.tokenizer.decode([token_id])
|
||||
if term:
|
||||
query_term_counts[term] += 1
|
||||
|
||||
# Score all candidate documents
|
||||
doc_scores: dict[int, float] = defaultdict(float)
|
||||
|
||||
for term, query_freq in query_term_counts.items():
|
||||
if term in self._inverted_lists:
|
||||
for doc_id in self._inverted_lists[term]:
|
||||
score = self._bm25_score(term, doc_id, query_freq)
|
||||
doc_scores[doc_id] += score
|
||||
|
||||
# Sort by score and return top-k
|
||||
sorted_results = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
return sorted_results[:top_k]
|
||||
|
||||
def get_term_frequency(self, term: str, doc_id: int) -> int:
|
||||
"""
|
||||
Get term frequency in a document.
|
||||
|
||||
Args:
|
||||
term: Term
|
||||
doc_id: Document ID
|
||||
|
||||
Returns:
|
||||
Term frequency
|
||||
"""
|
||||
if doc_id in self._doc_terms:
|
||||
return self._doc_terms[doc_id].get(term, 0)
|
||||
return 0
|
||||
|
||||
def get_document_frequency(self, term: str) -> int:
|
||||
"""
|
||||
Get document frequency of a term.
|
||||
|
||||
Args:
|
||||
term: Term
|
||||
|
||||
Returns:
|
||||
Document frequency
|
||||
"""
|
||||
return len(self._inverted_lists.get(term, []))
|
||||
|
||||
def stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get index statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with index statistics
|
||||
"""
|
||||
total_postings = sum(len(postings) for postings in self._inverted_lists.values())
|
||||
return {
|
||||
"total_documents": self._total_docs,
|
||||
"total_terms": len(self._inverted_lists),
|
||||
"total_postings": total_postings,
|
||||
"avg_doc_length": self._avg_doc_length,
|
||||
"avg_postings_per_term": (
|
||||
total_postings / len(self._inverted_lists) if self._inverted_lists else 0.0
|
||||
),
|
||||
}
|
||||
|
||||
281
llmds/kv_cache.py
Normal file
281
llmds/kv_cache.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""KV cache with paged allocation and prefix sharing.
|
||||
|
||||
Implementation based on techniques from:
|
||||
Cache-Craft: Managing Chunk-Caches for Efficient Retrieval-Augmented Generation.
|
||||
|
||||
See docs/CITATIONS.md for full citation details.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
from typing import Any, Optional
|
||||
|
||||
from llmds.paged_allocator import PagedAllocator
|
||||
|
||||
|
||||
class KVCache:
|
||||
"""
|
||||
KV cache with paged allocation, prefix sharing, and deduplication.
|
||||
|
||||
Implements copy-on-write (COW) for prefix sharing: shared pages are
|
||||
read-only until a write occurs, at which point they are copied.
|
||||
|
||||
Reference:
|
||||
Cache-Craft: Managing Chunk-Caches for Efficient Retrieval-Augmented Generation.
|
||||
|
||||
**Copy-on-Write Semantics:**
|
||||
- Shared pages (from prefix sharing) are read-only
|
||||
- Attempts to modify shared pages trigger lazy copying
|
||||
- Each sequence maintains its own copy of modified pages
|
||||
- Original shared pages remain unchanged for other sequences
|
||||
|
||||
Supports hash-based deduplication of repeated system prompts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
page_size: int = 512,
|
||||
max_pages: int = 10000,
|
||||
enable_prefix_sharing: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize KV cache.
|
||||
|
||||
Args:
|
||||
page_size: Size of each KV cache page in tokens
|
||||
max_pages: Maximum number of pages to allocate
|
||||
enable_prefix_sharing: Enable prefix sharing optimization
|
||||
"""
|
||||
self.allocator = PagedAllocator(page_size, max_pages)
|
||||
self.page_size = page_size
|
||||
self._sequences: dict[int, list[int]] = {} # seq_id -> list[page_ids]
|
||||
self._kv_data: dict[int, Any] = {} # page_id -> KV data
|
||||
self._prefix_map: dict[str, list[int]] = {} # hash -> page_ids
|
||||
self._page_refs: dict[int, int] = {} # page_id -> reference count
|
||||
self._shared_pages: set[int] = set() # page_ids that are shared (read-only)
|
||||
self._enable_prefix_sharing = enable_prefix_sharing
|
||||
self._seq_counter = 0
|
||||
self._prefix_shares = 0
|
||||
|
||||
def _hash_prefix(self, prefix: list[int]) -> str:
|
||||
"""Compute hash of prefix tokens."""
|
||||
prefix_str = ",".join(map(str, prefix[:100])) # Limit length
|
||||
return hashlib.sha256(prefix_str.encode()).hexdigest()
|
||||
|
||||
def _copy_if_shared(self, page_id: int, seq_id: int) -> int:
|
||||
"""
|
||||
Copy-on-write: if page is shared, create a new copy.
|
||||
|
||||
Args:
|
||||
page_id: Original page ID (may be shared)
|
||||
seq_id: Sequence ID requesting the copy
|
||||
|
||||
Returns:
|
||||
New page_id if copied, original page_id if not shared
|
||||
"""
|
||||
if page_id not in self._shared_pages:
|
||||
return page_id
|
||||
|
||||
# Page is shared - need to copy
|
||||
new_page_id = self.allocator.alloc(1)[0]
|
||||
|
||||
# Copy the data
|
||||
if page_id in self._kv_data:
|
||||
self._kv_data[new_page_id] = copy.deepcopy(self._kv_data[page_id])
|
||||
else:
|
||||
# Empty page
|
||||
self._kv_data[new_page_id] = []
|
||||
|
||||
# Decrement reference count of original
|
||||
self._page_refs[page_id] = self._page_refs.get(page_id, 1) - 1
|
||||
if self._page_refs[page_id] <= 0:
|
||||
self._shared_pages.discard(page_id)
|
||||
if page_id in self._page_refs:
|
||||
del self._page_refs[page_id]
|
||||
|
||||
# New page is not shared (single owner)
|
||||
self._page_refs[new_page_id] = 1
|
||||
|
||||
return new_page_id
|
||||
|
||||
def attach(
|
||||
self,
|
||||
seq_id: int,
|
||||
kv_tokens: list[Any],
|
||||
prefix_tokens: Optional[list[int]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Attach KV cache for a sequence.
|
||||
|
||||
Implements copy-on-write: if prefix sharing is used, shared pages
|
||||
are referenced but will be copied on first write.
|
||||
|
||||
Args:
|
||||
seq_id: Sequence identifier
|
||||
kv_tokens: KV tokens to cache
|
||||
prefix_tokens: Optional prefix tokens for sharing
|
||||
"""
|
||||
if seq_id in self._sequences:
|
||||
self.detach(seq_id)
|
||||
|
||||
pages_needed = (len(kv_tokens) + self.page_size - 1) // self.page_size
|
||||
new_page_ids = self.allocator.alloc(pages_needed)
|
||||
page_ids: list[int] = []
|
||||
|
||||
# Try prefix sharing if enabled
|
||||
shared_prefix_pages: list[int] = []
|
||||
if self._enable_prefix_sharing and prefix_tokens:
|
||||
prefix_hash = self._hash_prefix(prefix_tokens)
|
||||
if prefix_hash in self._prefix_map:
|
||||
shared_prefix_pages = self._prefix_map[prefix_hash]
|
||||
# Reference shared pages (will be copied on write if needed)
|
||||
num_prefix_pages = min(len(shared_prefix_pages), pages_needed)
|
||||
page_ids.extend(shared_prefix_pages[:num_prefix_pages])
|
||||
|
||||
# Update reference counts for shared pages
|
||||
for shared_page_id in shared_prefix_pages[:num_prefix_pages]:
|
||||
self._page_refs[shared_page_id] = self._page_refs.get(shared_page_id, 0) + 1
|
||||
self._shared_pages.add(shared_page_id)
|
||||
|
||||
# Use remaining allocated pages for non-shared suffix
|
||||
page_ids.extend(new_page_ids[num_prefix_pages:])
|
||||
self._prefix_shares += 1
|
||||
else:
|
||||
# First time seeing this prefix - mark these pages as potential shared
|
||||
num_prefix_pages = min(
|
||||
(len(prefix_tokens) + self.page_size - 1) // self.page_size,
|
||||
pages_needed
|
||||
)
|
||||
self._prefix_map[prefix_hash] = new_page_ids[:num_prefix_pages]
|
||||
page_ids = new_page_ids
|
||||
else:
|
||||
page_ids = new_page_ids
|
||||
|
||||
# Store KV data with copy-on-write semantics
|
||||
# For shared pages: if data differs, trigger COW; otherwise, reference existing
|
||||
for i, page_id in enumerate(page_ids):
|
||||
start = i * self.page_size
|
||||
end = min(start + self.page_size, len(kv_tokens))
|
||||
page_data = kv_tokens[start:end]
|
||||
|
||||
# Check if this page is shared
|
||||
if page_id in self._shared_pages:
|
||||
# Page is shared - check if data matches
|
||||
existing_data = self._kv_data.get(page_id, [])
|
||||
if existing_data != page_data:
|
||||
# Data differs - trigger copy-on-write
|
||||
page_id = self._copy_if_shared(page_id, seq_id)
|
||||
page_ids[i] = page_id # Update the page_id in our list
|
||||
# Now safe to write (page is not shared)
|
||||
self._kv_data[page_id] = page_data
|
||||
if page_id not in self._page_refs:
|
||||
self._page_refs[page_id] = 1
|
||||
# If data matches, no need to copy or write - just reference the shared page
|
||||
else:
|
||||
# Non-shared page - safe to write directly
|
||||
self._kv_data[page_id] = page_data
|
||||
if page_id not in self._page_refs:
|
||||
self._page_refs[page_id] = 1
|
||||
|
||||
self._sequences[seq_id] = page_ids
|
||||
|
||||
def detach(self, seq_id: int) -> None:
|
||||
"""
|
||||
Detach and free KV cache for a sequence.
|
||||
|
||||
Decrements reference counts for shared pages. Pages are only freed
|
||||
when their reference count reaches zero.
|
||||
|
||||
Args:
|
||||
seq_id: Sequence identifier
|
||||
"""
|
||||
if seq_id not in self._sequences:
|
||||
return
|
||||
|
||||
page_ids = self._sequences[seq_id]
|
||||
|
||||
# Update reference counts and free pages
|
||||
pages_to_free: list[int] = []
|
||||
for page_id in page_ids:
|
||||
if page_id in self._shared_pages:
|
||||
# Shared page - decrement reference count
|
||||
self._page_refs[page_id] = self._page_refs.get(page_id, 1) - 1
|
||||
if self._page_refs[page_id] <= 0:
|
||||
# No more references - can free
|
||||
self._shared_pages.discard(page_id)
|
||||
if page_id in self._kv_data:
|
||||
del self._kv_data[page_id]
|
||||
if page_id in self._page_refs:
|
||||
del self._page_refs[page_id]
|
||||
pages_to_free.append(page_id)
|
||||
else:
|
||||
# Non-shared page - free immediately
|
||||
if page_id in self._kv_data:
|
||||
del self._kv_data[page_id]
|
||||
if page_id in self._page_refs:
|
||||
del self._page_refs[page_id]
|
||||
pages_to_free.append(page_id)
|
||||
|
||||
# Free pages via allocator
|
||||
if pages_to_free:
|
||||
self.allocator.free(pages_to_free)
|
||||
|
||||
del self._sequences[seq_id]
|
||||
|
||||
def get(self, seq_id: int) -> Optional[list[Any]]:
|
||||
"""
|
||||
Get KV cache for a sequence.
|
||||
|
||||
Returns a copy of the data to prevent external modifications
|
||||
from affecting shared pages.
|
||||
|
||||
Args:
|
||||
seq_id: Sequence identifier
|
||||
|
||||
Returns:
|
||||
List of KV tokens or None if not found
|
||||
"""
|
||||
if seq_id not in self._sequences:
|
||||
return None
|
||||
|
||||
page_ids = self._sequences[seq_id]
|
||||
kv_tokens = []
|
||||
for page_id in page_ids:
|
||||
if page_id in self._kv_data:
|
||||
# Return copy to prevent external modification of shared pages
|
||||
kv_tokens.extend(copy.deepcopy(self._kv_data[page_id]))
|
||||
|
||||
return kv_tokens
|
||||
|
||||
def stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get cache statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with cache statistics
|
||||
"""
|
||||
alloc_stats = self.allocator.stats()
|
||||
return {
|
||||
"total_sequences": len(self._sequences),
|
||||
"total_pages": alloc_stats.total_pages,
|
||||
"allocated_pages": alloc_stats.allocated_pages,
|
||||
"free_pages": alloc_stats.free_pages,
|
||||
"prefix_shares": self._prefix_shares,
|
||||
"prefix_map_size": len(self._prefix_map),
|
||||
"shared_pages_count": len(self._shared_pages),
|
||||
"total_page_refs": sum(self._page_refs.values()),
|
||||
}
|
||||
|
||||
def hook_speculative_decode(self, seq_id: int, draft_tokens: list[int]) -> None:
|
||||
"""
|
||||
Hook for speculative decoding compatibility.
|
||||
|
||||
Placeholder API for future implementation.
|
||||
|
||||
Args:
|
||||
seq_id: Sequence identifier
|
||||
draft_tokens: Draft tokens from speculative decoding
|
||||
"""
|
||||
# Placeholder for speculative decoding integration
|
||||
pass
|
||||
|
||||
117
llmds/paged_allocator.py
Normal file
117
llmds/paged_allocator.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Paged memory allocator with slab allocation for KV cache."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class PageStats:
|
||||
"""Statistics for page allocation."""
|
||||
|
||||
total_pages: int
|
||||
allocated_pages: int
|
||||
free_pages: int
|
||||
fragmentation_ratio: float
|
||||
allocation_count: int
|
||||
free_count: int
|
||||
|
||||
|
||||
class PagedAllocator:
|
||||
"""
|
||||
Paged memory allocator with fixed-size pages and freelist management.
|
||||
|
||||
Uses a slab allocator approach with freelists for efficient allocation
|
||||
and deallocation of fixed-size page blocks.
|
||||
"""
|
||||
|
||||
def __init__(self, page_size: int, max_pages: int):
|
||||
"""
|
||||
Initialize the paged allocator.
|
||||
|
||||
Args:
|
||||
page_size: Size of each page in tokens/bytes
|
||||
max_pages: Maximum number of pages to allocate
|
||||
"""
|
||||
self.page_size = page_size
|
||||
self.max_pages = max_pages
|
||||
self._pages: list[Optional[bool]] = [None] * max_pages # None=free, True=allocated
|
||||
self._free_list: list[int] = list(range(max_pages))
|
||||
self._allocation_count = 0
|
||||
self._free_count = 0
|
||||
|
||||
def alloc(self, num_pages: int) -> list[int]:
|
||||
"""
|
||||
Allocate a contiguous block of pages.
|
||||
|
||||
Args:
|
||||
num_pages: Number of pages to allocate
|
||||
|
||||
Returns:
|
||||
List of page IDs (indices)
|
||||
|
||||
Raises:
|
||||
ValueError: If insufficient pages available
|
||||
"""
|
||||
if len(self._free_list) < num_pages:
|
||||
raise ValueError(f"Insufficient pages: requested {num_pages}, available {len(self._free_list)}")
|
||||
|
||||
allocated = []
|
||||
for _ in range(num_pages):
|
||||
page_id = self._free_list.pop(0)
|
||||
self._pages[page_id] = True
|
||||
allocated.append(page_id)
|
||||
self._allocation_count += 1
|
||||
|
||||
return allocated
|
||||
|
||||
def free(self, page_ids: list[int]) -> None:
|
||||
"""
|
||||
Free a list of pages.
|
||||
|
||||
Args:
|
||||
page_ids: List of page IDs to free
|
||||
"""
|
||||
for page_id in page_ids:
|
||||
if 0 <= page_id < self.max_pages and self._pages[page_id] is True:
|
||||
self._pages[page_id] = None
|
||||
self._free_list.append(page_id)
|
||||
self._free_count += 1
|
||||
|
||||
def stats(self) -> PageStats:
|
||||
"""
|
||||
Get allocation statistics.
|
||||
|
||||
Returns:
|
||||
PageStats object with current statistics
|
||||
"""
|
||||
allocated = sum(1 for p in self._pages if p is True)
|
||||
free = len(self._free_list)
|
||||
fragmentation = 1.0 - (free / self.max_pages) if self.max_pages > 0 else 0.0
|
||||
|
||||
return PageStats(
|
||||
total_pages=self.max_pages,
|
||||
allocated_pages=allocated,
|
||||
free_pages=free,
|
||||
fragmentation_ratio=fragmentation,
|
||||
allocation_count=self._allocation_count,
|
||||
free_count=self._free_count,
|
||||
)
|
||||
|
||||
def defragment(self) -> None:
|
||||
"""
|
||||
Defragment pages by compacting allocated pages.
|
||||
|
||||
This is a simple implementation that moves allocated pages
|
||||
to the front. More sophisticated strategies could be implemented.
|
||||
"""
|
||||
allocated_indices = [i for i, p in enumerate(self._pages) if p is True]
|
||||
free_indices = [i for i, p in enumerate(self._pages) if p is None]
|
||||
|
||||
# Simple compaction: move allocated pages to front
|
||||
new_pages: list[bool | None] = [None] * self.max_pages
|
||||
for i, idx in enumerate(allocated_indices):
|
||||
new_pages[i] = True
|
||||
|
||||
self._pages = new_pages
|
||||
self._free_list = list(range(len(allocated_indices), self.max_pages))
|
||||
|
||||
213
llmds/retrieval_pipeline.py
Normal file
213
llmds/retrieval_pipeline.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Retrieval pipeline combining ANN, lexical search, and fusion."""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from llmds.cmsketch import CountMinSketch
|
||||
from llmds.hnsw import HNSW
|
||||
from llmds.indexed_heap import IndexedHeap
|
||||
from llmds.inverted_index import InvertedIndex
|
||||
from llmds.token_lru import TokenLRU
|
||||
from llmds.tokenizer import Tokenizer
|
||||
|
||||
|
||||
class RetrievalPipeline:
|
||||
"""
|
||||
End-to-end retrieval pipeline combining ANN, lexical search, and fusion.
|
||||
|
||||
Combines HNSW for dense embeddings, inverted index for BM25,
|
||||
and score fusion with top-K maintenance using indexed heap.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 384,
|
||||
hnsw_M: int = 16,
|
||||
hnsw_ef_construction: int = 200,
|
||||
hnsw_ef_search: int = 50,
|
||||
token_budget: int = 100000,
|
||||
tokenizer: Optional[Tokenizer] = None,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Initialize retrieval pipeline.
|
||||
|
||||
Args:
|
||||
embedding_dim: Dimension of embedding vectors
|
||||
hnsw_M: HNSW M parameter
|
||||
hnsw_ef_construction: HNSW efConstruction parameter
|
||||
hnsw_ef_search: HNSW efSearch parameter
|
||||
token_budget: Token budget for cache
|
||||
tokenizer: Tokenizer instance
|
||||
seed: Optional random seed for HNSW reproducibility (default: None)
|
||||
"""
|
||||
self.tokenizer = tokenizer or Tokenizer()
|
||||
self.hnsw = HNSW(
|
||||
dim=embedding_dim,
|
||||
M=hnsw_M,
|
||||
ef_construction=hnsw_ef_construction,
|
||||
ef_search=hnsw_ef_search,
|
||||
seed=seed,
|
||||
)
|
||||
self.inverted_index = InvertedIndex(tokenizer=self.tokenizer)
|
||||
self.cmsketch = CountMinSketch(width=2048, depth=4)
|
||||
self.token_cache: TokenLRU[str, str] = TokenLRU[str, str](
|
||||
token_budget=token_budget,
|
||||
token_of=lambda text: self.tokenizer.count_tokens(text),
|
||||
)
|
||||
|
||||
def add_document(
|
||||
self,
|
||||
doc_id: int,
|
||||
text: str,
|
||||
embedding: Optional[np.ndarray] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add a document to both indices.
|
||||
|
||||
Args:
|
||||
doc_id: Document identifier
|
||||
text: Document text
|
||||
embedding: Optional embedding vector (if None, generates random)
|
||||
"""
|
||||
# Add to inverted index
|
||||
self.inverted_index.add_document(doc_id, text)
|
||||
|
||||
# Add to HNSW if embedding provided
|
||||
if embedding is not None:
|
||||
if embedding.shape != (self.hnsw.dim,):
|
||||
raise ValueError(
|
||||
f"Embedding dimension mismatch: expected {self.hnsw.dim}, "
|
||||
f"got {embedding.shape[0]}"
|
||||
)
|
||||
self.hnsw.add(embedding, doc_id)
|
||||
else:
|
||||
# Generate random embedding for testing
|
||||
random_embedding = np.random.randn(self.hnsw.dim).astype(np.float32)
|
||||
random_embedding = random_embedding / np.linalg.norm(random_embedding)
|
||||
self.hnsw.add(random_embedding, doc_id)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
query_embedding: Optional[np.ndarray] = None,
|
||||
top_k: int = 10,
|
||||
fusion_weight: float = 0.5,
|
||||
) -> list[tuple[int, float]]:
|
||||
"""
|
||||
Search with hybrid retrieval and score fusion.
|
||||
|
||||
Args:
|
||||
query: Query text
|
||||
query_embedding: Optional query embedding vector
|
||||
top_k: Number of results to return
|
||||
fusion_weight: Weight for dense search (1-fusion_weight for BM25)
|
||||
|
||||
Returns:
|
||||
List of (doc_id, fused_score) tuples
|
||||
"""
|
||||
# Check cache
|
||||
cached = self.token_cache.get(query)
|
||||
if cached:
|
||||
self.cmsketch.add(query)
|
||||
# Parse cached string back to list of tuples
|
||||
import ast
|
||||
try:
|
||||
parsed_results = ast.literal_eval(cached)
|
||||
if isinstance(parsed_results, list):
|
||||
return parsed_results
|
||||
except (ValueError, SyntaxError):
|
||||
pass # Fall through to compute results
|
||||
|
||||
# BM25 search
|
||||
bm25_results = self.inverted_index.search(query, top_k=top_k * 2)
|
||||
|
||||
# Dense search (if embedding provided)
|
||||
dense_results = []
|
||||
if query_embedding is not None:
|
||||
dense_results = self.hnsw.search(query_embedding, k=top_k * 2)
|
||||
|
||||
# Normalize scores
|
||||
bm25_scores: dict[int, float] = {doc_id: score for doc_id, score in bm25_results}
|
||||
dense_scores: dict[int, float] = {}
|
||||
|
||||
if dense_results:
|
||||
max_dense = max(dist for _, dist in dense_results) if dense_results else 1.0
|
||||
min_dense = min(dist for _, dist in dense_results) if dense_results else 0.0
|
||||
dense_range = max_dense - min_dense if max_dense > min_dense else 1.0
|
||||
|
||||
for doc_id, dist in dense_results: # HNSW.search returns (node_id, distance)
|
||||
# Convert distance to similarity (inverse)
|
||||
normalized = 1.0 - (dist - min_dense) / dense_range if dense_range > 0 else 1.0
|
||||
dense_scores[doc_id] = normalized
|
||||
|
||||
# Normalize BM25 scores
|
||||
if bm25_scores:
|
||||
max_bm25 = max(bm25_scores.values())
|
||||
min_bm25 = min(bm25_scores.values())
|
||||
bm25_range = max_bm25 - min_bm25 if max_bm25 > min_bm25 else 1.0
|
||||
|
||||
for doc_id in bm25_scores:
|
||||
bm25_scores[doc_id] = (
|
||||
(bm25_scores[doc_id] - min_bm25) / bm25_range if bm25_range > 0 else 1.0
|
||||
)
|
||||
|
||||
# Fuse scores using indexed heap
|
||||
fused_scores: dict[int, float] = {}
|
||||
all_doc_ids = set(bm25_scores.keys()) | set(dense_scores.keys())
|
||||
|
||||
for doc_id in all_doc_ids:
|
||||
bm25_score = bm25_scores.get(doc_id, 0.0)
|
||||
dense_score = dense_scores.get(doc_id, 0.0)
|
||||
|
||||
# Weighted fusion
|
||||
fused_score = fusion_weight * dense_score + (1 - fusion_weight) * bm25_score
|
||||
fused_scores[doc_id] = fused_score
|
||||
|
||||
# Top-K using indexed heap
|
||||
heap = IndexedHeap(max_heap=True)
|
||||
for doc_id, score in fused_scores.items():
|
||||
if heap.size() < top_k:
|
||||
heap.push(doc_id, score)
|
||||
else:
|
||||
peek_result = heap.peek()
|
||||
if peek_result is not None:
|
||||
min_score, _ = peek_result
|
||||
if min_score is not None and score > min_score:
|
||||
heap.pop()
|
||||
heap.push(doc_id, score)
|
||||
|
||||
# Extract results
|
||||
results = []
|
||||
while not heap.is_empty():
|
||||
score, doc_id = heap.pop()
|
||||
results.append((doc_id, score))
|
||||
|
||||
results.reverse() # Highest score first
|
||||
|
||||
# Cache results (store as string representation for token counting)
|
||||
results_str = str(results)
|
||||
self.token_cache.put(query, results_str)
|
||||
self.cmsketch.add(query)
|
||||
|
||||
return results
|
||||
|
||||
def stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get pipeline statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with pipeline statistics
|
||||
"""
|
||||
hnsw_stats = self.hnsw.stats()
|
||||
index_stats = self.inverted_index.stats()
|
||||
|
||||
return {
|
||||
"hnsw": hnsw_stats,
|
||||
"inverted_index": index_stats,
|
||||
"cmsketch_total_count": self.cmsketch.get_total_count(),
|
||||
"cache_size": self.token_cache.size(),
|
||||
"cache_tokens": self.token_cache.total_tokens(),
|
||||
}
|
||||
|
||||
216
llmds/scheduler.py
Normal file
216
llmds/scheduler.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""Dynamic micro-batching scheduler with priority queue."""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from llmds.indexed_heap import IndexedHeap
|
||||
|
||||
|
||||
@dataclass
|
||||
class Request:
|
||||
"""Represents a request in the scheduler."""
|
||||
|
||||
request_id: int
|
||||
tokens: int
|
||||
priority: float # Higher = more priority
|
||||
created_at: float
|
||||
slo_ms: Optional[float] = None # Service level objective in milliseconds
|
||||
|
||||
|
||||
class Scheduler:
|
||||
"""
|
||||
Dynamic micro-batching scheduler with priority-based queuing.
|
||||
|
||||
Uses an indexed heap to prioritize sequences by remaining length or SLO.
|
||||
Supports dynamic batching with configurable waiting time vs. throughput trade-offs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_batch_size: int = 32,
|
||||
max_wait_ms: float = 50.0,
|
||||
priority_fn: Optional[Callable[[Request], float]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize scheduler.
|
||||
|
||||
Args:
|
||||
max_batch_size: Maximum batch size
|
||||
max_wait_ms: Maximum wait time in milliseconds before batching
|
||||
priority_fn: Optional function to compute priority from request.
|
||||
Default: prioritize by remaining tokens (inverse)
|
||||
"""
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_wait_ms = max_wait_ms
|
||||
self._heap = IndexedHeap(max_heap=True) # Max heap for priority
|
||||
self._requests: dict[int, Request] = {}
|
||||
self._priority_fn = priority_fn or self._default_priority_fn
|
||||
self._request_counter = 0
|
||||
self._batch_count = 0
|
||||
self._total_processed = 0
|
||||
|
||||
def _default_priority_fn(self, req: Request) -> float:
|
||||
"""Default priority: higher priority for shorter sequences (inverse of tokens)."""
|
||||
return 1.0 / (req.tokens + 1.0)
|
||||
|
||||
def _slo_priority_fn(self, req: Request) -> float:
|
||||
"""Priority based on SLO deadline."""
|
||||
if req.slo_ms is None:
|
||||
return self._default_priority_fn(req)
|
||||
|
||||
elapsed_ms = (time.time() - req.created_at) * 1000
|
||||
remaining_ms = req.slo_ms - elapsed_ms
|
||||
if remaining_ms <= 0:
|
||||
return float("inf") # Urgent: past deadline
|
||||
return 1.0 / (remaining_ms + 1.0)
|
||||
|
||||
def submit(self, tokens: int, slo_ms: Optional[float] = None) -> int:
|
||||
"""
|
||||
Submit a request to the scheduler.
|
||||
|
||||
Args:
|
||||
tokens: Estimated token count for the request
|
||||
slo_ms: Optional SLO deadline in milliseconds
|
||||
|
||||
Returns:
|
||||
Request ID
|
||||
"""
|
||||
req_id = self._request_counter
|
||||
self._request_counter += 1
|
||||
|
||||
req = Request(
|
||||
request_id=req_id,
|
||||
tokens=tokens,
|
||||
priority=self._priority_fn(
|
||||
Request(
|
||||
request_id=req_id,
|
||||
tokens=tokens,
|
||||
priority=0.0,
|
||||
created_at=time.time(),
|
||||
slo_ms=slo_ms,
|
||||
)
|
||||
),
|
||||
created_at=time.time(),
|
||||
slo_ms=slo_ms,
|
||||
)
|
||||
|
||||
self._requests[req_id] = req
|
||||
self._heap.push(req_id, req.priority)
|
||||
|
||||
return req_id
|
||||
|
||||
def get_batch(self, force: bool = False) -> Optional[list[int]]:
|
||||
"""
|
||||
Get next batch of requests to process.
|
||||
|
||||
Args:
|
||||
force: If True, return batch even if not full
|
||||
|
||||
Returns:
|
||||
List of request IDs or None if no batch ready
|
||||
"""
|
||||
if self._heap.is_empty():
|
||||
return None
|
||||
|
||||
# Check if oldest request exceeds max wait time
|
||||
oldest_req_id = None
|
||||
oldest_time = float("inf")
|
||||
|
||||
for req_id in self._requests:
|
||||
if self._requests[req_id].created_at < oldest_time:
|
||||
oldest_time = self._requests[req_id].created_at
|
||||
oldest_req_id = req_id
|
||||
|
||||
if oldest_req_id:
|
||||
wait_time_ms = (time.time() - oldest_time) * 1000
|
||||
if not force and wait_time_ms < self.max_wait_ms:
|
||||
return None
|
||||
|
||||
# Build batch from heap
|
||||
batch: list[int] = []
|
||||
temp_heap = IndexedHeap(max_heap=True)
|
||||
|
||||
# Pop top requests
|
||||
while len(batch) < self.max_batch_size and not self._heap.is_empty():
|
||||
_, req_id = self._heap.pop()
|
||||
if req_id in self._requests:
|
||||
batch.append(req_id)
|
||||
else:
|
||||
temp_heap.push(req_id, self._requests[req_id].priority)
|
||||
|
||||
# Restore heap (add back any that weren't used)
|
||||
while not temp_heap.is_empty():
|
||||
_, req_id = temp_heap.pop()
|
||||
self._heap.push(req_id, self._requests[req_id].priority)
|
||||
|
||||
if batch:
|
||||
self._batch_count += 1
|
||||
self._total_processed += len(batch)
|
||||
return batch
|
||||
|
||||
return None
|
||||
|
||||
def complete_batch(self, request_ids: list[int]) -> None:
|
||||
"""
|
||||
Mark a batch as completed and remove requests.
|
||||
|
||||
Args:
|
||||
request_ids: List of completed request IDs
|
||||
"""
|
||||
for req_id in request_ids:
|
||||
if req_id in self._requests:
|
||||
# Try to remove from heap if present
|
||||
if self._heap.contains(req_id):
|
||||
try:
|
||||
self._heap.delete(req_id)
|
||||
except KeyError:
|
||||
pass
|
||||
del self._requests[req_id]
|
||||
|
||||
def update_priority(self, request_id: int, new_tokens: int) -> None:
|
||||
"""
|
||||
Update priority for a request (e.g., after partial processing).
|
||||
|
||||
Args:
|
||||
request_id: Request identifier
|
||||
new_tokens: Updated token count
|
||||
"""
|
||||
if request_id not in self._requests:
|
||||
return
|
||||
|
||||
req = self._requests[request_id]
|
||||
req.tokens = new_tokens
|
||||
new_priority = self._priority_fn(req)
|
||||
|
||||
if self._heap.contains(request_id):
|
||||
old_priority = self._heap.get_score(request_id)
|
||||
if old_priority is not None:
|
||||
if new_priority > old_priority:
|
||||
self._heap.increase_key(request_id, new_priority)
|
||||
else:
|
||||
self._heap.decrease_key(request_id, new_priority)
|
||||
else:
|
||||
self._heap.push(request_id, new_priority)
|
||||
|
||||
def stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get scheduler statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with scheduler statistics
|
||||
"""
|
||||
return {
|
||||
"queue_size": len(self._requests),
|
||||
"batch_count": self._batch_count,
|
||||
"total_processed": self._total_processed,
|
||||
"avg_batch_size": (
|
||||
self._total_processed / self._batch_count if self._batch_count > 0 else 0.0
|
||||
),
|
||||
}
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all pending requests."""
|
||||
self._heap = IndexedHeap(max_heap=True)
|
||||
self._requests.clear()
|
||||
|
||||
120
llmds/token_lru.py
Normal file
120
llmds/token_lru.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Token-aware LRU cache with eviction until budget."""
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Generic, Optional, TypeVar
|
||||
|
||||
K = TypeVar("K")
|
||||
V = TypeVar("V")
|
||||
|
||||
|
||||
class TokenLRU(Generic[K, V]):
|
||||
"""
|
||||
Token-aware LRU cache that evicts items until budget is satisfied.
|
||||
|
||||
Evicts least recently used items until the total token count
|
||||
fits within the specified budget.
|
||||
"""
|
||||
|
||||
def __init__(self, token_budget: int, token_of: Callable[[V], int]):
|
||||
"""
|
||||
Initialize token-aware LRU cache.
|
||||
|
||||
Args:
|
||||
token_budget: Maximum total tokens allowed
|
||||
token_of: Function to extract token count from a value
|
||||
"""
|
||||
self.budget = token_budget
|
||||
self.token_of = token_of
|
||||
self._cache: OrderedDict[K, V] = OrderedDict()
|
||||
self._total_tokens = 0
|
||||
|
||||
def put(self, key: K, value: V) -> None:
|
||||
"""
|
||||
Add or update an item in the cache.
|
||||
|
||||
Evicts LRU items until budget is satisfied.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Cache value
|
||||
"""
|
||||
token_count = self.token_of(value)
|
||||
|
||||
# If key exists, remove old value first
|
||||
if key in self._cache:
|
||||
old_value = self._cache[key]
|
||||
self._total_tokens -= self.token_of(old_value)
|
||||
del self._cache[key]
|
||||
|
||||
# Evict LRU items until we have space
|
||||
while self._total_tokens + token_count > self.budget and self._cache:
|
||||
self._evict_lru()
|
||||
|
||||
# Add new item
|
||||
if self._total_tokens + token_count <= self.budget:
|
||||
self._cache[key] = value
|
||||
self._total_tokens += token_count
|
||||
# Move to end (most recently used)
|
||||
self._cache.move_to_end(key)
|
||||
|
||||
def get(self, key: K) -> Optional[V]:
|
||||
"""
|
||||
Get an item from the cache.
|
||||
|
||||
Moves item to end (most recently used).
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found
|
||||
"""
|
||||
if key not in self._cache:
|
||||
return None
|
||||
|
||||
value = self._cache[key]
|
||||
self._cache.move_to_end(key)
|
||||
return value
|
||||
|
||||
def _evict_lru(self) -> tuple[K, V]:
|
||||
"""
|
||||
Evict the least recently used item.
|
||||
|
||||
Returns:
|
||||
Tuple of (key, value) that was evicted
|
||||
"""
|
||||
if not self._cache:
|
||||
raise RuntimeError("Cannot evict from empty cache")
|
||||
|
||||
key, value = self._cache.popitem(last=False)
|
||||
self._total_tokens -= self.token_of(value)
|
||||
return key, value
|
||||
|
||||
def evict_until_budget(self, target_budget: int) -> list[tuple[K, V]]:
|
||||
"""
|
||||
Evict items until total tokens <= target_budget.
|
||||
|
||||
Args:
|
||||
target_budget: Target token budget
|
||||
|
||||
Returns:
|
||||
List of (key, value) tuples that were evicted
|
||||
"""
|
||||
evicted = []
|
||||
while self._total_tokens > target_budget and self._cache:
|
||||
evicted.append(self._evict_lru())
|
||||
return evicted
|
||||
|
||||
def total_tokens(self) -> int:
|
||||
"""Get total tokens currently in cache."""
|
||||
return self._total_tokens
|
||||
|
||||
def size(self) -> int:
|
||||
"""Get number of items in cache."""
|
||||
return len(self._cache)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all items from cache."""
|
||||
self._cache.clear()
|
||||
self._total_tokens = 0
|
||||
|
||||
149
llmds/tokenizer.py
Normal file
149
llmds/tokenizer.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Simple BPE-style tokenizer interface."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
Simple tokenizer interface with BPE-style stub implementation.
|
||||
|
||||
Provides a pluggable interface for tokenization that can be
|
||||
extended with real tokenizers (e.g., tiktoken, transformers).
|
||||
"""
|
||||
|
||||
def __init__(self, vocab_size: int = 50257):
|
||||
"""
|
||||
Initialize tokenizer.
|
||||
|
||||
Args:
|
||||
vocab_size: Vocabulary size (default GPT-2 like)
|
||||
"""
|
||||
self.vocab_size = vocab_size
|
||||
self._word_to_id: dict[str, int] = {}
|
||||
self._id_to_word: dict[int, str] = {}
|
||||
self._build_simple_vocab()
|
||||
|
||||
def _build_simple_vocab(self) -> None:
|
||||
"""Build a simple vocabulary for testing."""
|
||||
# Simple vocabulary: common words + special tokens
|
||||
special_tokens = ["<pad>", "<unk>", "<bos>", "<eos>"]
|
||||
common_words = [
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"and",
|
||||
"or",
|
||||
"but",
|
||||
"in",
|
||||
"on",
|
||||
"at",
|
||||
"to",
|
||||
"for",
|
||||
"of",
|
||||
"with",
|
||||
"by",
|
||||
"from",
|
||||
"as",
|
||||
"is",
|
||||
"was",
|
||||
"are",
|
||||
"were",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"have",
|
||||
"has",
|
||||
"had",
|
||||
"do",
|
||||
"does",
|
||||
"did",
|
||||
"will",
|
||||
"would",
|
||||
"should",
|
||||
"could",
|
||||
"may",
|
||||
"might",
|
||||
"must",
|
||||
"can",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
"i",
|
||||
"you",
|
||||
"he",
|
||||
"she",
|
||||
"it",
|
||||
"we",
|
||||
"they",
|
||||
]
|
||||
|
||||
all_tokens = special_tokens + common_words
|
||||
for i, token in enumerate(all_tokens[: self.vocab_size]):
|
||||
self._word_to_id[token] = i
|
||||
self._id_to_word[i] = token
|
||||
|
||||
def encode(self, text: str) -> list[int]:
|
||||
"""
|
||||
Encode text to token IDs.
|
||||
|
||||
Args:
|
||||
text: Input text
|
||||
|
||||
Returns:
|
||||
List of token IDs
|
||||
"""
|
||||
# Simple whitespace-based tokenization
|
||||
words = text.lower().split()
|
||||
token_ids = []
|
||||
unk_id = self._word_to_id.get("<unk>", 0)
|
||||
|
||||
for word in words:
|
||||
# Simple BPE-like: try full word, then fallback to char-level
|
||||
if word in self._word_to_id:
|
||||
token_ids.append(self._word_to_id[word])
|
||||
else:
|
||||
# Character-level fallback
|
||||
for char in word:
|
||||
char_token = f"<char_{char}>"
|
||||
if char_token in self._word_to_id:
|
||||
token_ids.append(self._word_to_id[char_token])
|
||||
else:
|
||||
token_ids.append(unk_id)
|
||||
|
||||
return token_ids
|
||||
|
||||
def decode(self, token_ids: list[int]) -> str:
|
||||
"""
|
||||
Decode token IDs to text.
|
||||
|
||||
Args:
|
||||
token_ids: List of token IDs
|
||||
|
||||
Returns:
|
||||
Decoded text
|
||||
"""
|
||||
words = []
|
||||
for token_id in token_ids:
|
||||
if token_id in self._id_to_word:
|
||||
word = self._id_to_word[token_id]
|
||||
if not word.startswith("<"):
|
||||
words.append(word)
|
||||
return " ".join(words)
|
||||
|
||||
def count_tokens(self, text: str) -> int:
|
||||
"""
|
||||
Count tokens in text.
|
||||
|
||||
Args:
|
||||
text: Input text
|
||||
|
||||
Returns:
|
||||
Token count
|
||||
"""
|
||||
return len(self.encode(text))
|
||||
|
||||
def get_vocab_size(self) -> int:
|
||||
"""Get vocabulary size."""
|
||||
return self.vocab_size
|
||||
|
||||
250
llmds/utils.py
Normal file
250
llmds/utils.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Utility functions."""
|
||||
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Iterator, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import psutil
|
||||
_PSUTIL_AVAILABLE = True
|
||||
except ImportError:
|
||||
_PSUTIL_AVAILABLE = False
|
||||
psutil = None # type: ignore
|
||||
|
||||
try:
|
||||
from scipy import stats
|
||||
HAS_SCIPY = True
|
||||
except ImportError:
|
||||
HAS_SCIPY = False
|
||||
|
||||
|
||||
class Timer:
|
||||
"""Simple timer context manager."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.start: float | None = None
|
||||
self.elapsed: float = 0.0
|
||||
|
||||
def __enter__(self) -> "Timer":
|
||||
self.start = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> Literal[False]:
|
||||
if self.start is not None:
|
||||
self.elapsed = time.perf_counter() - self.start
|
||||
return False
|
||||
|
||||
|
||||
class MemoryProfiler:
|
||||
"""
|
||||
Memory profiler for measuring peak RSS (Resident Set Size).
|
||||
|
||||
Tracks memory usage during benchmark execution and reports peak RSS.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize memory profiler."""
|
||||
if not _PSUTIL_AVAILABLE:
|
||||
raise ImportError("psutil is required for memory profiling. Install with: pip install psutil")
|
||||
|
||||
self.process = psutil.Process()
|
||||
self.initial_rss: Optional[int] = None
|
||||
self.peak_rss: int = 0
|
||||
self.current_rss: int = 0
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start memory profiling."""
|
||||
self.initial_rss = self.process.memory_info().rss
|
||||
self.peak_rss = self.initial_rss
|
||||
self.current_rss = self.initial_rss
|
||||
|
||||
def sample(self) -> int:
|
||||
"""
|
||||
Sample current RSS and update peak.
|
||||
|
||||
Returns:
|
||||
Current RSS in bytes
|
||||
"""
|
||||
if not _PSUTIL_AVAILABLE:
|
||||
return 0
|
||||
|
||||
self.current_rss = self.process.memory_info().rss
|
||||
if self.current_rss > self.peak_rss:
|
||||
self.peak_rss = self.current_rss
|
||||
return self.current_rss
|
||||
|
||||
def get_peak_rss_mb(self) -> float:
|
||||
"""
|
||||
Get peak RSS in megabytes.
|
||||
|
||||
Returns:
|
||||
Peak RSS in MB
|
||||
"""
|
||||
return self.peak_rss / (1024 * 1024)
|
||||
|
||||
def get_peak_rss_bytes(self) -> int:
|
||||
"""
|
||||
Get peak RSS in bytes.
|
||||
|
||||
Returns:
|
||||
Peak RSS in bytes
|
||||
"""
|
||||
return self.peak_rss
|
||||
|
||||
def get_current_rss_mb(self) -> float:
|
||||
"""
|
||||
Get current RSS in megabytes.
|
||||
|
||||
Returns:
|
||||
Current RSS in MB
|
||||
"""
|
||||
return self.current_rss / (1024 * 1024)
|
||||
|
||||
def get_memory_delta_mb(self) -> float:
|
||||
"""
|
||||
Get memory delta from initial RSS in megabytes.
|
||||
|
||||
Returns:
|
||||
Memory delta in MB (peak - initial)
|
||||
"""
|
||||
if self.initial_rss is None:
|
||||
return 0.0
|
||||
return (self.peak_rss - self.initial_rss) / (1024 * 1024)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def memory_profiler() -> Iterator[MemoryProfiler]:
|
||||
"""
|
||||
Context manager for memory profiling.
|
||||
|
||||
Usage:
|
||||
with memory_profiler() as profiler:
|
||||
# Your code here
|
||||
profiler.sample() # Optional: sample at specific points
|
||||
peak_rss_mb = profiler.get_peak_rss_mb()
|
||||
|
||||
Yields:
|
||||
MemoryProfiler instance
|
||||
"""
|
||||
if not _PSUTIL_AVAILABLE:
|
||||
# Return dummy profiler if psutil not available
|
||||
class DummyProfiler:
|
||||
def start(self) -> None: pass
|
||||
def sample(self) -> int: return 0
|
||||
def get_peak_rss_mb(self) -> float: return 0.0
|
||||
def get_peak_rss_bytes(self) -> int: return 0
|
||||
def get_current_rss_mb(self) -> float: return 0.0
|
||||
def get_memory_delta_mb(self) -> float: return 0.0
|
||||
|
||||
profiler = DummyProfiler() # type: ignore
|
||||
profiler.start()
|
||||
yield profiler
|
||||
return
|
||||
|
||||
profiler = MemoryProfiler()
|
||||
profiler.start()
|
||||
try:
|
||||
yield profiler
|
||||
# Final sample to capture any last-minute allocations
|
||||
profiler.sample()
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
def compute_percentiles(values: list[float]) -> dict[str, float]:
|
||||
"""
|
||||
Compute P50, P95, P99 percentiles from a list of values.
|
||||
|
||||
Args:
|
||||
values: List of numeric values
|
||||
|
||||
Returns:
|
||||
Dictionary with p50, p95, p99 keys
|
||||
"""
|
||||
if not values:
|
||||
return {"p50": 0.0, "p95": 0.0, "p99": 0.0}
|
||||
|
||||
sorted_values = sorted(values)
|
||||
n = len(sorted_values)
|
||||
|
||||
return {
|
||||
"p50": sorted_values[n // 2],
|
||||
"p95": sorted_values[int(n * 0.95)] if n > 1 else sorted_values[0],
|
||||
"p99": sorted_values[int(n * 0.99)] if n > 1 else sorted_values[0],
|
||||
}
|
||||
|
||||
|
||||
def calculate_statistics(values: list[float], confidence_level: float = 0.95) -> dict[str, Any]:
|
||||
"""
|
||||
Calculate statistical summary for a list of values.
|
||||
|
||||
Args:
|
||||
values: List of numeric values
|
||||
confidence_level: Confidence level (e.g., 0.95 for 95% CI)
|
||||
|
||||
Returns:
|
||||
Dictionary with mean, std, min, max, percentiles, and confidence intervals
|
||||
"""
|
||||
if not values:
|
||||
return {
|
||||
"mean": 0.0,
|
||||
"std": 0.0,
|
||||
"min": 0.0,
|
||||
"max": 0.0,
|
||||
"p50": 0.0,
|
||||
"p95": 0.0,
|
||||
"p99": 0.0,
|
||||
"ci_lower": 0.0,
|
||||
"ci_upper": 0.0,
|
||||
"cv": 0.0, # Coefficient of variation
|
||||
}
|
||||
|
||||
values_array = np.array(values)
|
||||
mean = float(np.mean(values_array))
|
||||
std = float(np.std(values_array, ddof=1)) # Sample std dev (ddof=1)
|
||||
min_val = float(np.min(values_array))
|
||||
max_val = float(np.max(values_array))
|
||||
|
||||
# Percentiles
|
||||
p50 = float(np.percentile(values_array, 50))
|
||||
p95 = float(np.percentile(values_array, 95))
|
||||
p99 = float(np.percentile(values_array, 99))
|
||||
|
||||
# Confidence interval (t-distribution for small samples)
|
||||
n = len(values)
|
||||
if n > 1:
|
||||
alpha = 1 - confidence_level
|
||||
if HAS_SCIPY:
|
||||
# Use t-distribution for small samples
|
||||
t_critical = stats.t.ppf(1 - alpha / 2, df=n - 1)
|
||||
margin = t_critical * (std / np.sqrt(n))
|
||||
else:
|
||||
# Fallback: use normal distribution approximation (z-score)
|
||||
# For 95% CI: z = 1.96, for 90% CI: z = 1.645
|
||||
z_scores = {0.90: 1.645, 0.95: 1.96, 0.99: 2.576}
|
||||
z_critical = z_scores.get(confidence_level, 1.96)
|
||||
margin = z_critical * (std / np.sqrt(n))
|
||||
ci_lower = mean - margin
|
||||
ci_upper = mean + margin
|
||||
else:
|
||||
ci_lower = mean
|
||||
ci_upper = mean
|
||||
|
||||
# Coefficient of variation (relative standard deviation)
|
||||
cv = (std / mean * 100) if mean > 0 else 0.0
|
||||
|
||||
return {
|
||||
"mean": mean,
|
||||
"std": std,
|
||||
"min": min_val,
|
||||
"max": max_val,
|
||||
"p50": p50,
|
||||
"p95": p95,
|
||||
"p99": p99,
|
||||
"ci_lower": ci_lower,
|
||||
"ci_upper": ci_upper,
|
||||
"cv": cv, # Coefficient of variation (%)
|
||||
"count": n,
|
||||
}
|
||||
49
pyproject.toml
Normal file
49
pyproject.toml
Normal file
@@ -0,0 +1,49 @@
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.poetry]
|
||||
name = "llm-rag-ds-optimizer"
|
||||
version = "0.1.0"
|
||||
description = "Production-grade LLM optimizer for throughput, latency, and memory optimization"
|
||||
authors = ["Carlos Gutierrez <cgutierrez44833@ucumberlands.edu>"]
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
packages = [{include = "llmds"}]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
numpy = "^1.24.0"
|
||||
mmh3 = "^4.0.0"
|
||||
psutil = "^5.9.0"
|
||||
scipy = "^1.11.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "^0.1.0"
|
||||
mypy = "^1.7.0"
|
||||
python-docx = "^1.1.0"
|
||||
matplotlib = "^3.8.0"
|
||||
pandas = "^2.1.0"
|
||||
python-pptx = "^0.6.21"
|
||||
datasets = {version = "^2.16.0", optional = true}
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py311"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "N", "W", "UP"]
|
||||
ignore = []
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.11"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
check_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_redundant_casts = true
|
||||
warn_unused_ignores = true
|
||||
|
||||
|
||||
21
requirements-dev.txt
Normal file
21
requirements-dev.txt
Normal file
@@ -0,0 +1,21 @@
|
||||
# Development dependencies
|
||||
# Install with: pip install -r requirements.txt -r requirements-dev.txt
|
||||
#
|
||||
# Last updated: 2025-01-01
|
||||
# Python version: >=3.11
|
||||
|
||||
# Include production dependencies
|
||||
-r requirements.txt
|
||||
|
||||
# Code quality tools
|
||||
ruff>=0.1.0,<1.0.0
|
||||
mypy>=1.7.0,<2.0.0
|
||||
|
||||
# Documentation and reporting
|
||||
python-docx>=1.1.0,<2.0.0
|
||||
matplotlib>=3.8.0,<4.0.0
|
||||
pandas>=2.1.0,<3.0.0
|
||||
python-pptx>=0.6.21,<1.0.0
|
||||
|
||||
# Optional: For dataset loading (install separately if needed)
|
||||
# datasets>=2.16.0,<3.0.0
|
||||
16
requirements.txt
Normal file
16
requirements.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
# Production dependencies
|
||||
# Generated from pyproject.toml for reproducibility
|
||||
# Install with: pip install -r requirements.txt
|
||||
#
|
||||
# For development dependencies, use: pip install -r requirements-dev.txt
|
||||
#
|
||||
# Last updated: 2025-01-01
|
||||
# Python version: >=3.11
|
||||
|
||||
# Core dependencies
|
||||
numpy>=1.24.0,<2.0.0
|
||||
mmh3>=4.0.0,<5.0.0
|
||||
|
||||
# Optional: For dataset loading (install separately if needed)
|
||||
# datasets>=2.16.0,<3.0.0
|
||||
|
||||
2
scripts/__init__.py
Normal file
2
scripts/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Empty file to make scripts a package
|
||||
|
||||
196
scripts/analyze_variance.py
Normal file
196
scripts/analyze_variance.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Analyze variance in benchmark results and identify flaky benchmarks."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
try:
|
||||
from scipy import stats
|
||||
HAS_SCIPY = True
|
||||
except ImportError:
|
||||
HAS_SCIPY = False
|
||||
|
||||
|
||||
def load_benchmark_results(results_file: Path) -> list[dict]:
|
||||
"""Load benchmark results from JSON file."""
|
||||
with open(results_file) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def identify_flaky_configurations(
|
||||
results: list[dict],
|
||||
cv_threshold: float = 20.0,
|
||||
metrics: list[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Identify flaky benchmark configurations based on coefficient of variation.
|
||||
|
||||
Args:
|
||||
results: List of aggregated result dictionaries
|
||||
cv_threshold: CV threshold (%) above which a benchmark is considered flaky
|
||||
metrics: List of metrics to check (default: critical metrics)
|
||||
|
||||
Returns:
|
||||
List of flaky configuration summaries
|
||||
"""
|
||||
if metrics is None:
|
||||
metrics = ["search_p50_ms", "search_p95_ms", "qps"]
|
||||
|
||||
flaky_configs = []
|
||||
|
||||
for result in results:
|
||||
flaky_metrics = []
|
||||
for metric in metrics:
|
||||
cv_key = f"{metric}_cv"
|
||||
if cv_key in result:
|
||||
cv = result[cv_key]
|
||||
if cv > cv_threshold:
|
||||
mean_val = result.get(f"{metric}_mean", 0)
|
||||
std_val = result.get(f"{metric}_std", 0)
|
||||
flaky_metrics.append({
|
||||
"metric": metric,
|
||||
"mean": mean_val,
|
||||
"std": std_val,
|
||||
"cv": cv,
|
||||
})
|
||||
|
||||
if flaky_metrics:
|
||||
flaky_configs.append({
|
||||
"corpus": result.get("corpus"),
|
||||
"size": result.get("size"),
|
||||
"ef_search": result.get("ef_search"),
|
||||
"M": result.get("M"),
|
||||
"repetitions": result.get("repetitions"),
|
||||
"flaky_metrics": flaky_metrics,
|
||||
})
|
||||
|
||||
return flaky_configs
|
||||
|
||||
|
||||
def generate_variance_report(
|
||||
aggregated_file: Path,
|
||||
output_file: Path | None = None,
|
||||
cv_threshold: float = 20.0,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a variance analysis report.
|
||||
|
||||
Args:
|
||||
aggregated_file: Path to aggregated results JSON
|
||||
output_file: Optional output file for report
|
||||
cv_threshold: CV threshold for flaky detection
|
||||
|
||||
Returns:
|
||||
Report dictionary
|
||||
"""
|
||||
results = load_benchmark_results(aggregated_file)
|
||||
|
||||
if not results:
|
||||
return {"error": "No results found"}
|
||||
|
||||
# Calculate overall statistics
|
||||
all_cvs = []
|
||||
for result in results:
|
||||
for key in result.keys():
|
||||
if key.endswith("_cv") and isinstance(result[key], (int, float)):
|
||||
all_cvs.append(result[key])
|
||||
|
||||
# Identify flaky configurations
|
||||
flaky_configs = identify_flaky_configurations(results, cv_threshold)
|
||||
|
||||
# Group by corpus
|
||||
by_corpus = {}
|
||||
for result in results:
|
||||
corpus = result.get("corpus", "unknown")
|
||||
if corpus not in by_corpus:
|
||||
by_corpus[corpus] = []
|
||||
by_corpus[corpus].append(result)
|
||||
|
||||
report = {
|
||||
"summary": {
|
||||
"total_configurations": len(results),
|
||||
"flaky_configurations": len(flaky_configs),
|
||||
"flaky_percentage": (len(flaky_configs) / len(results) * 100) if results else 0,
|
||||
"average_cv": float(np.mean(all_cvs)) if all_cvs else 0.0,
|
||||
"max_cv": float(np.max(all_cvs)) if all_cvs else 0.0,
|
||||
},
|
||||
"flaky_configurations": flaky_configs,
|
||||
"by_corpus": {
|
||||
corpus: {
|
||||
"count": len(configs),
|
||||
"flaky_count": sum(1 for c in configs if any(m["cv"] > cv_threshold for m in identify_flaky_configurations([c], cv_threshold)[0].get("flaky_metrics", []))),
|
||||
}
|
||||
for corpus, configs in by_corpus.items()
|
||||
},
|
||||
}
|
||||
|
||||
if output_file:
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(report, f, indent=2)
|
||||
print(f"Variance report saved to {output_file}")
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Analyze variance in benchmark results")
|
||||
parser.add_argument(
|
||||
"--results",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to aggregated results JSON file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
help="Output file for variance report"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cv-threshold",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="Coefficient of variation threshold (%) for flaky detection (default: 20.0)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.results.exists():
|
||||
print(f"Error: Results file not found: {args.results}")
|
||||
return
|
||||
|
||||
report = generate_variance_report(
|
||||
aggregated_file=args.results,
|
||||
output_file=args.output,
|
||||
cv_threshold=args.cv_threshold,
|
||||
)
|
||||
|
||||
# Print summary
|
||||
print("\n" + "="*70)
|
||||
print("Variance Analysis Report")
|
||||
print("="*70)
|
||||
summary = report.get("summary", {})
|
||||
print(f"Total configurations: {summary.get('total_configurations', 0)}")
|
||||
print(f"Flaky configurations: {summary.get('flaky_configurations', 0)} ({summary.get('flaky_percentage', 0):.1f}%)")
|
||||
print(f"Average CV: {summary.get('average_cv', 0):.2f}%")
|
||||
print(f"Max CV: {summary.get('max_cv', 0):.2f}%")
|
||||
|
||||
flaky = report.get("flaky_configurations", [])
|
||||
if flaky:
|
||||
print(f"\n⚠️ Flaky Configurations ({len(flaky)}):")
|
||||
for config in flaky[:10]: # Show first 10
|
||||
print(f" - {config.get('corpus')} (size={config.get('size')}, ef={config.get('ef_search')}, M={config.get('M')}):")
|
||||
for metric in config.get("flaky_metrics", []):
|
||||
print(f" • {metric['metric']}: CV={metric['cv']:.1f}% (mean={metric['mean']:.2f}±{metric['std']:.2f})")
|
||||
if len(flaky) > 10:
|
||||
print(f" ... and {len(flaky) - 10} more")
|
||||
else:
|
||||
print("\n✅ No flaky configurations detected!")
|
||||
|
||||
print("="*70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
166
scripts/build_indices.py
Normal file
166
scripts/build_indices.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Build indices (BM25 + HNSW) for a corpus."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from llmds.hnsw import HNSW
|
||||
from llmds.inverted_index import InvertedIndex
|
||||
from llmds.tokenizer import Tokenizer
|
||||
|
||||
|
||||
def build_indices(
|
||||
corpus_file: Path,
|
||||
emb_file: Path | None,
|
||||
index_dir: Path,
|
||||
bm25: bool = True,
|
||||
hnsw: bool = True,
|
||||
ef_construction: int = 200,
|
||||
M: int = 16,
|
||||
embedding_dim: int = 384,
|
||||
) -> dict:
|
||||
"""
|
||||
Build inverted index and/or HNSW for a corpus.
|
||||
|
||||
Args:
|
||||
corpus_file: Path to corpus JSONL file
|
||||
emb_file: Optional path to embeddings .npy file
|
||||
index_dir: Directory to save indices
|
||||
bm25: Whether to build BM25 inverted index
|
||||
hnsw: Whether to build HNSW index
|
||||
ef_construction: HNSW efConstruction parameter
|
||||
M: HNSW M parameter
|
||||
embedding_dim: Embedding dimension
|
||||
|
||||
Returns:
|
||||
Dictionary with build statistics
|
||||
"""
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
tokenizer = Tokenizer()
|
||||
stats = {}
|
||||
|
||||
# Load embeddings if available
|
||||
embeddings = None
|
||||
if emb_file and emb_file.exists():
|
||||
print(f"Loading embeddings from {emb_file}...")
|
||||
embeddings = np.load(emb_file)
|
||||
print(f"Loaded {len(embeddings)} embeddings")
|
||||
|
||||
# Build BM25 index
|
||||
if bm25:
|
||||
print("Building BM25 inverted index...")
|
||||
start_time = time.time()
|
||||
|
||||
index = InvertedIndex(tokenizer=tokenizer)
|
||||
doc_count = 0
|
||||
|
||||
with open(corpus_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
doc = json.loads(line)
|
||||
index.add_document(doc_id=int(doc["id"].split("_")[-1]) if doc["id"].split("_")[-1].isdigit() else doc_count, text=doc["text"])
|
||||
doc_count += 1
|
||||
|
||||
if doc_count % 10000 == 0:
|
||||
print(f"Indexed {doc_count} documents...")
|
||||
|
||||
# Save index metadata
|
||||
index_stats = index.stats()
|
||||
stats["bm25"] = {
|
||||
"build_time_sec": time.time() - start_time,
|
||||
"total_documents": index_stats["total_documents"],
|
||||
"total_terms": index_stats["total_terms"],
|
||||
}
|
||||
|
||||
print(f"✓ BM25 index built: {stats['bm25']['total_documents']} documents, {stats['bm25']['build_time_sec']:.2f}s")
|
||||
|
||||
# Build HNSW index
|
||||
if hnsw:
|
||||
if embeddings is None:
|
||||
print("Warning: No embeddings provided. Generating deterministic embeddings...")
|
||||
# Generate on-the-fly
|
||||
embeddings = []
|
||||
doc_count = 0
|
||||
rng = np.random.RandomState(42)
|
||||
with open(corpus_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
doc = json.loads(line)
|
||||
emb = rng.randn(embedding_dim).astype(np.float32)
|
||||
emb = emb / np.linalg.norm(emb)
|
||||
embeddings.append(emb)
|
||||
doc_count += 1
|
||||
embeddings = np.stack(embeddings)
|
||||
|
||||
print(f"Building HNSW index (M={M}, efConstruction={ef_construction})...")
|
||||
start_time = time.time()
|
||||
|
||||
hnsw = HNSW(
|
||||
dim=embedding_dim,
|
||||
M=M,
|
||||
ef_construction=ef_construction,
|
||||
ef_search=50,
|
||||
seed=42, # Fixed seed for reproducible HNSW structure
|
||||
)
|
||||
|
||||
for i, emb in enumerate(embeddings):
|
||||
hnsw.add(emb, i)
|
||||
if (i + 1) % 10000 == 0:
|
||||
print(f"Added {i + 1} vectors...")
|
||||
|
||||
hnsw_stats = hnsw.stats()
|
||||
stats["hnsw"] = {
|
||||
"build_time_sec": time.time() - start_time,
|
||||
"num_vectors": hnsw_stats["num_vectors"],
|
||||
"num_layers": hnsw_stats["num_layers"],
|
||||
}
|
||||
|
||||
print(f"✓ HNSW index built: {stats['hnsw']['num_vectors']} vectors, {stats['hnsw']['build_time_sec']:.2f}s")
|
||||
|
||||
# Save statistics
|
||||
stats_file = index_dir / "build_stats.json"
|
||||
with open(stats_file, "w") as f:
|
||||
json.dump(stats, f, indent=2)
|
||||
|
||||
print(f"✓ Indices built and saved to {index_dir}")
|
||||
return stats
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Build indices for corpus")
|
||||
parser.add_argument("--corpus", type=Path, required=True, help="Corpus JSONL file")
|
||||
parser.add_argument("--emb", type=Path, help="Embeddings .npy file")
|
||||
parser.add_argument("--index-dir", type=Path, required=True, help="Index output directory")
|
||||
parser.add_argument("--bm25", action="store_true", help="Build BM25 index")
|
||||
parser.add_argument("--hnsw", action="store_true", help="Build HNSW index")
|
||||
parser.add_argument("--ef", type=int, default=200, help="HNSW efConstruction")
|
||||
parser.add_argument("--M", type=int, default=16, help="HNSW M parameter")
|
||||
parser.add_argument("--dim", type=int, default=384, help="Embedding dimension")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.bm25 and not args.hnsw:
|
||||
print("Error: Must specify --bm25 and/or --hnsw")
|
||||
sys.exit(1)
|
||||
|
||||
build_indices(
|
||||
corpus_file=args.corpus,
|
||||
emb_file=args.emb,
|
||||
index_dir=args.index_dir,
|
||||
bm25=args.bm25,
|
||||
hnsw=args.hnsw,
|
||||
ef_construction=args.ef,
|
||||
M=args.M,
|
||||
embedding_dim=args.dim,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
73
scripts/download_corpus.py
Normal file
73
scripts/download_corpus.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Download and prepare datasets."""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from llmds.data_sources.msmarco import download_msmarco
|
||||
from llmds.data_sources.beir_loader import download_beir
|
||||
from llmds.data_sources.amazon_reviews import download_amazon_reviews
|
||||
from llmds.data_sources.yelp import download_yelp
|
||||
from llmds.data_sources.wikipedia import download_wikipedia
|
||||
from llmds.data_sources.commoncrawl import download_commoncrawl
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Download datasets")
|
||||
parser.add_argument(
|
||||
"--source",
|
||||
required=True,
|
||||
help="Dataset source: msmarco, beir:task (e.g., beir:fiqa), amazon23, yelp, wikipedia, commoncrawl"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Output directory for corpus"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit",
|
||||
type=int,
|
||||
help="Limit number of documents"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cc-month",
|
||||
type=str,
|
||||
help="Common Crawl month (e.g., 'CC-MAIN-2025-14')"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse source (handle beir:task format)
|
||||
source_parts = args.source.split(":", 1)
|
||||
source_base = source_parts[0]
|
||||
task = source_parts[1] if len(source_parts) > 1 else None
|
||||
|
||||
if source_base == "msmarco":
|
||||
download_msmarco(args.output)
|
||||
elif source_base == "beir":
|
||||
if not task:
|
||||
print("Error: BEIR requires task name (e.g., 'beir:fiqa', 'beir:scidocs')")
|
||||
sys.exit(1)
|
||||
download_beir(task, args.output)
|
||||
elif source_base == "amazon23":
|
||||
download_amazon_reviews(args.output, limit=args.limit)
|
||||
elif source_base == "yelp":
|
||||
download_yelp(args.output)
|
||||
elif source_base == "wikipedia":
|
||||
download_wikipedia(args.output)
|
||||
elif source_base == "commoncrawl":
|
||||
download_commoncrawl(args.output, cc_month=args.cc_month, limit=args.limit)
|
||||
else:
|
||||
print(f"Error: Unknown source '{source_base}'. Use: msmarco, beir:task, amazon23, yelp, wikipedia, commoncrawl")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"✓ Dataset downloaded to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
137
scripts/env_hash.py
Normal file
137
scripts/env_hash.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Generate environment hash for reproducibility tracking."""
|
||||
|
||||
import platform
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_blas_info():
|
||||
"""Get BLAS library information."""
|
||||
try:
|
||||
# Try to get BLAS config from numpy
|
||||
blas_info = np.show_config()
|
||||
return str(blas_info)
|
||||
except Exception:
|
||||
try:
|
||||
# Fallback: try to get from numpy config
|
||||
config = np.__config__
|
||||
return str(config)
|
||||
except Exception:
|
||||
return "BLAS info unavailable"
|
||||
|
||||
|
||||
def get_numpy_config():
|
||||
"""Get NumPy configuration."""
|
||||
try:
|
||||
return {
|
||||
"version": np.__version__,
|
||||
"config": str(np.show_config()),
|
||||
}
|
||||
except Exception:
|
||||
return {"version": np.__version__, "config": "unavailable"}
|
||||
|
||||
|
||||
def generate_env_hash(output_path: Path = Path("audit/env_hash.txt")):
|
||||
"""
|
||||
Generate environment hash file with system and library information.
|
||||
|
||||
Args:
|
||||
output_path: Path to output file (default: audit/env_hash.txt)
|
||||
"""
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
lines = []
|
||||
lines.append("=" * 80)
|
||||
lines.append("Environment Hash")
|
||||
lines.append("=" * 80)
|
||||
lines.append("")
|
||||
|
||||
# Python information
|
||||
lines.append("Python:")
|
||||
lines.append(f" Version: {sys.version}")
|
||||
lines.append(f" Executable: {sys.executable}")
|
||||
lines.append(f" Platform: {platform.platform()}")
|
||||
lines.append("")
|
||||
|
||||
# OS information
|
||||
lines.append("Operating System:")
|
||||
lines.append(f" System: {platform.system()}")
|
||||
lines.append(f" Release: {platform.release()}")
|
||||
lines.append(f" Version: {platform.version()}")
|
||||
lines.append(f" Architecture: {platform.machine()}")
|
||||
lines.append(f" Processor: {platform.processor()}")
|
||||
lines.append("")
|
||||
|
||||
# CPU information
|
||||
try:
|
||||
import psutil
|
||||
lines.append("CPU:")
|
||||
lines.append(f" Physical cores: {psutil.cpu_count(logical=False)}")
|
||||
lines.append(f" Logical cores: {psutil.cpu_count(logical=True)}")
|
||||
lines.append(f" Frequency: {psutil.cpu_freq()}")
|
||||
lines.append("")
|
||||
except ImportError:
|
||||
lines.append("CPU:")
|
||||
lines.append(f" Count: {platform.processor()}")
|
||||
lines.append("")
|
||||
|
||||
# NumPy configuration
|
||||
lines.append("NumPy Configuration:")
|
||||
np_config = get_numpy_config()
|
||||
lines.append(f" Version: {np_config['version']}")
|
||||
lines.append(" Config:")
|
||||
for line in np_config.get("config", "").split("\n"):
|
||||
if line.strip():
|
||||
lines.append(f" {line}")
|
||||
lines.append("")
|
||||
|
||||
# BLAS information
|
||||
lines.append("BLAS Information:")
|
||||
blas_info = get_blas_info()
|
||||
for line in blas_info.split("\n"):
|
||||
if line.strip():
|
||||
lines.append(f" {line}")
|
||||
lines.append("")
|
||||
|
||||
# Python packages (if available)
|
||||
try:
|
||||
import pkg_resources
|
||||
lines.append("Key Packages:")
|
||||
key_packages = ["numpy", "scipy", "hypothesis", "pytest"]
|
||||
for pkg_name in key_packages:
|
||||
try:
|
||||
pkg = pkg_resources.get_distribution(pkg_name)
|
||||
lines.append(f" {pkg_name}: {pkg.version}")
|
||||
except Exception:
|
||||
pass
|
||||
lines.append("")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
lines.append("=" * 80)
|
||||
|
||||
# Write to file
|
||||
content = "\n".join(lines)
|
||||
with open(output_path, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
print(f"Environment hash written to: {output_path}")
|
||||
return output_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Generate environment hash")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
default=Path("audit/env_hash.txt"),
|
||||
help="Output file path (default: audit/env_hash.txt)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
generate_env_hash(args.output)
|
||||
|
||||
235
scripts/generate_architecture_diagram.py
Normal file
235
scripts/generate_architecture_diagram.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""Generate architecture diagram for the LLM Data Structures Optimizer.
|
||||
|
||||
This script creates a visual architecture diagram showing the relationships
|
||||
between major components in the system.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.patches as mpatches
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
def generate_architecture_diagram(output_path: Path = Path("audit/ARCH_DIAGRAM.png")):
|
||||
"""
|
||||
Generate architecture diagram showing system components and relationships.
|
||||
|
||||
Args:
|
||||
output_path: Path to save the diagram (default: audit/ARCH_DIAGRAM.png)
|
||||
"""
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(16, 12))
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.axis("off")
|
||||
|
||||
# Define colors
|
||||
colors = {
|
||||
"kv_cache": "#E8F4F8",
|
||||
"scheduler": "#FFF4E6",
|
||||
"retrieval": "#F0F8E8",
|
||||
"data_structure": "#F5E6F8",
|
||||
}
|
||||
|
||||
# Title
|
||||
ax.text(5, 9.5, "LLM Data Structures Optimizer Architecture",
|
||||
ha="center", va="top", fontsize=20, weight="bold")
|
||||
|
||||
# ===== KV Cache System =====
|
||||
kv_y = 7.5
|
||||
ax.add_patch(mpatches.Rectangle((0.2, kv_y), 3.0, 1.5,
|
||||
facecolor=colors["kv_cache"],
|
||||
edgecolor="black", linewidth=2))
|
||||
ax.text(1.7, kv_y + 1.2, "KV Cache System",
|
||||
ha="center", va="center", fontsize=14, weight="bold")
|
||||
|
||||
# KVCache
|
||||
ax.add_patch(mpatches.Rectangle((0.4, kv_y + 0.7), 1.2, 0.4,
|
||||
facecolor="white", edgecolor="black", linewidth=1))
|
||||
ax.text(1.0, kv_y + 0.9, "KVCache", ha="center", va="center", fontsize=10)
|
||||
|
||||
# PagedAllocator
|
||||
ax.add_patch(mpatches.Rectangle((1.8, kv_y + 0.7), 1.2, 0.4,
|
||||
facecolor="white", edgecolor="black", linewidth=1))
|
||||
ax.text(2.4, kv_y + 0.9, "PagedAllocator", ha="center", va="center", fontsize=10)
|
||||
|
||||
# TokenLRU
|
||||
ax.add_patch(mpatches.Rectangle((0.4, kv_y - 0.2), 1.2, 0.4,
|
||||
facecolor="white", edgecolor="black", linewidth=1))
|
||||
ax.text(1.0, kv_y, "TokenLRU", ha="center", va="center", fontsize=10)
|
||||
|
||||
# Connections within KV Cache
|
||||
ax.arrow(1.6, kv_y + 0.9, 0.2, 0, head_width=0.05, head_length=0.05,
|
||||
fc="black", ec="black")
|
||||
ax.arrow(1.0, kv_y + 0.5, 0, 0.2, head_width=0.05, head_length=0.05,
|
||||
fc="black", ec="black")
|
||||
|
||||
# ===== Scheduler & Batching =====
|
||||
scheduler_y = 5.5
|
||||
ax.add_patch(mpatches.Rectangle((0.2, scheduler_y), 3.0, 1.5,
|
||||
facecolor=colors["scheduler"],
|
||||
edgecolor="black", linewidth=2))
|
||||
ax.text(1.7, scheduler_y + 1.2, "Scheduler & Batching",
|
||||
ha="center", va="center", fontsize=14, weight="bold")
|
||||
|
||||
# Scheduler
|
||||
ax.add_patch(mpatches.Rectangle((0.4, scheduler_y + 0.7), 1.2, 0.4,
|
||||
facecolor="white", edgecolor="black", linewidth=1))
|
||||
ax.text(1.0, scheduler_y + 0.9, "Scheduler", ha="center", va="center", fontsize=10)
|
||||
|
||||
# IndexedHeap
|
||||
ax.add_patch(mpatches.Rectangle((1.8, scheduler_y + 0.7), 1.2, 0.4,
|
||||
facecolor="white", edgecolor="black", linewidth=1))
|
||||
ax.text(2.4, scheduler_y + 0.9, "IndexedHeap", ha="center", va="center", fontsize=10)
|
||||
|
||||
# AdmissionController
|
||||
ax.add_patch(mpatches.Rectangle((1.1, scheduler_y - 0.2), 1.2, 0.4,
|
||||
facecolor="white", edgecolor="black", linewidth=1))
|
||||
ax.text(1.7, scheduler_y, "AdmissionController", ha="center", va="center", fontsize=10)
|
||||
|
||||
# Connections within Scheduler
|
||||
ax.arrow(1.6, scheduler_y + 0.9, 0.2, 0, head_width=0.05, head_length=0.05,
|
||||
fc="black", ec="black")
|
||||
ax.arrow(1.7, scheduler_y + 0.5, 0, 0.2, head_width=0.05, head_length=0.05,
|
||||
fc="black", ec="black")
|
||||
|
||||
# ===== Retrieval Pipeline =====
|
||||
retrieval_y = 3.5
|
||||
ax.add_patch(mpatches.Rectangle((0.2, retrieval_y), 3.0, 1.5,
|
||||
facecolor=colors["retrieval"],
|
||||
edgecolor="black", linewidth=2))
|
||||
ax.text(1.7, retrieval_y + 1.2, "Retrieval Pipeline",
|
||||
ha="center", va="center", fontsize=14, weight="bold")
|
||||
|
||||
# RetrievalPipeline
|
||||
ax.add_patch(mpatches.Rectangle((1.1, retrieval_y + 0.7), 1.2, 0.4,
|
||||
facecolor="white", edgecolor="black", linewidth=2))
|
||||
ax.text(1.7, retrieval_y + 0.9, "RetrievalPipeline",
|
||||
ha="center", va="center", fontsize=11, weight="bold")
|
||||
|
||||
# HNSW
|
||||
ax.add_patch(mpatches.Rectangle((0.4, retrieval_y - 0.2), 1.2, 0.4,
|
||||
facecolor="white", edgecolor="black", linewidth=1))
|
||||
ax.text(1.0, retrieval_y, "HNSW", ha="center", va="center", fontsize=10)
|
||||
|
||||
# InvertedIndex
|
||||
ax.add_patch(mpatches.Rectangle((1.8, retrieval_y - 0.2), 1.2, 0.4,
|
||||
facecolor="white", edgecolor="black", linewidth=1))
|
||||
ax.text(2.4, retrieval_y, "InvertedIndex", ha="center", va="center", fontsize=10)
|
||||
|
||||
# CountMinSketch
|
||||
ax.add_patch(mpatches.Rectangle((0.4, retrieval_y - 0.9), 1.2, 0.4,
|
||||
facecolor="white", edgecolor="black", linewidth=1))
|
||||
ax.text(1.0, retrieval_y - 0.7, "CountMinSketch", ha="center", va="center", fontsize=10)
|
||||
|
||||
# Tokenizer
|
||||
ax.add_patch(mpatches.Rectangle((1.8, retrieval_y - 0.9), 1.2, 0.4,
|
||||
facecolor="white", edgecolor="black", linewidth=1))
|
||||
ax.text(2.4, retrieval_y - 0.7, "Tokenizer", ha="center", va="center", fontsize=10)
|
||||
|
||||
# Connections within Retrieval Pipeline
|
||||
ax.arrow(1.7, retrieval_y + 0.5, -0.3, 0.2, head_width=0.05, head_length=0.05,
|
||||
fc="black", ec="black")
|
||||
ax.arrow(1.7, retrieval_y + 0.5, 0.3, 0.2, head_width=0.05, head_length=0.05,
|
||||
fc="black", ec="black")
|
||||
ax.arrow(1.7, retrieval_y + 0.5, -0.3, -0.5, head_width=0.05, head_length=0.05,
|
||||
fc="black", ec="black")
|
||||
ax.arrow(1.7, retrieval_y + 0.5, 0.3, -0.5, head_width=0.05, head_length=0.05,
|
||||
fc="black", ec="black")
|
||||
|
||||
# ===== Data Flow Arrows =====
|
||||
# KV Cache to Scheduler
|
||||
ax.arrow(1.7, scheduler_y + 1.5, 0, 0.3, head_width=0.1, head_length=0.08,
|
||||
fc="blue", ec="blue", linewidth=2, linestyle="--")
|
||||
ax.text(2.2, scheduler_y + 1.8, "uses", ha="left", va="center",
|
||||
fontsize=9, color="blue", style="italic")
|
||||
|
||||
# Scheduler to Retrieval
|
||||
ax.arrow(1.7, scheduler_y - 0.5, 0, -0.3, head_width=0.1, head_length=0.08,
|
||||
fc="green", ec="green", linewidth=2, linestyle="--")
|
||||
ax.text(2.2, retrieval_y + 1.5, "schedules", ha="left", va="center",
|
||||
fontsize=9, color="green", style="italic")
|
||||
|
||||
# ===== Right Side: Data Structures =====
|
||||
ds_x = 6.0
|
||||
ax.add_patch(mpatches.Rectangle((ds_x, 6.5), 3.5, 3.0,
|
||||
facecolor=colors["data_structure"],
|
||||
edgecolor="black", linewidth=2))
|
||||
ax.text(ds_x + 1.75, 9.0, "Core Data Structures",
|
||||
ha="center", va="center", fontsize=14, weight="bold")
|
||||
|
||||
# List data structures
|
||||
structures = [
|
||||
"IndexedHeap: O(log n) priority queue",
|
||||
"PagedAllocator: Page-based memory",
|
||||
"TokenLRU: Token-aware cache",
|
||||
"HNSW: Hierarchical graph ANN",
|
||||
"InvertedIndex: BM25 search",
|
||||
"CountMinSketch: Frequency estimation",
|
||||
]
|
||||
|
||||
for i, struct in enumerate(structures):
|
||||
y_pos = 8.3 - i * 0.45
|
||||
ax.text(ds_x + 0.2, y_pos, "•", ha="left", va="center", fontsize=12)
|
||||
ax.text(ds_x + 0.4, y_pos, struct, ha="left", va="center", fontsize=9)
|
||||
|
||||
# ===== Legend =====
|
||||
legend_y = 1.5
|
||||
ax.text(0.2, legend_y + 1.2, "Legend:", ha="left", va="top",
|
||||
fontsize=12, weight="bold")
|
||||
|
||||
# Legend items
|
||||
legend_items = [
|
||||
("───", "blue", "KV Cache usage"),
|
||||
("───", "green", "Scheduler flow"),
|
||||
("────", "black", "Component relationships"),
|
||||
]
|
||||
|
||||
for i, (style, color, label) in enumerate(legend_items):
|
||||
y_pos = legend_y + 0.8 - i * 0.3
|
||||
ax.plot([0.4, 0.7], [y_pos, y_pos], color=color, linewidth=2,
|
||||
linestyle="--" if "usage" in label or "flow" in label else "-")
|
||||
ax.text(0.8, y_pos, label, ha="left", va="center", fontsize=9)
|
||||
|
||||
# ===== Notes =====
|
||||
notes_x = 5.0
|
||||
notes_y = 2.0
|
||||
ax.add_patch(mpatches.Rectangle((notes_x, notes_y), 4.5, 1.8,
|
||||
facecolor="#F5F5F5",
|
||||
edgecolor="gray", linewidth=1))
|
||||
ax.text(notes_x + 2.25, notes_y + 1.5, "Key Features",
|
||||
ha="center", va="center", fontsize=11, weight="bold")
|
||||
|
||||
key_features = [
|
||||
"• Copy-on-write prefix sharing",
|
||||
"• Reference counting for memory",
|
||||
"• Hybrid dense + sparse retrieval",
|
||||
"• Score fusion with configurable weights",
|
||||
]
|
||||
|
||||
for i, feature in enumerate(key_features):
|
||||
y_pos = notes_y + 1.1 - i * 0.35
|
||||
ax.text(notes_x + 0.2, y_pos, feature, ha="left", va="center", fontsize=8)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=300, bbox_inches="tight")
|
||||
print(f"Architecture diagram saved to: {output_path}")
|
||||
return output_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Generate architecture diagram")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
default=Path("audit/ARCH_DIAGRAM.png"),
|
||||
help="Output file path (default: audit/ARCH_DIAGRAM.png)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
generate_architecture_diagram(args.output)
|
||||
|
||||
52
scripts/generate_synthetic_data.py
Normal file
52
scripts/generate_synthetic_data.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Generate synthetic data for testing and benchmarks."""
|
||||
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def generate_synthetic_documents(num_docs: int = 1000, output_file: Path = Path("data/documents.txt")):
|
||||
"""Generate synthetic documents for indexing."""
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
words = [
|
||||
"the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog",
|
||||
"cat", "mouse", "elephant", "tiger", "lion", "bear", "wolf",
|
||||
"rabbit", "deer", "bird", "fish", "snake", "monkey", "panda",
|
||||
"computer", "science", "machine", "learning", "artificial", "intelligence",
|
||||
"neural", "network", "deep", "learning", "transformer", "attention",
|
||||
"language", "model", "natural", "processing", "text", "generation",
|
||||
]
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
for i in range(num_docs):
|
||||
doc_length = random.randint(20, 200)
|
||||
doc_words = random.choices(words, k=doc_length)
|
||||
doc_text = " ".join(doc_words)
|
||||
f.write(f"{i}\t{doc_text}\n")
|
||||
|
||||
print(f"Generated {num_docs} documents in {output_file}")
|
||||
|
||||
|
||||
def generate_synthetic_embeddings(
|
||||
num_vectors: int = 1000,
|
||||
dim: int = 384,
|
||||
output_file: Path = Path("data/embeddings.npy"),
|
||||
):
|
||||
"""Generate synthetic embedding vectors."""
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
embeddings = np.random.randn(num_vectors, dim).astype(np.float32)
|
||||
# Normalize
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
embeddings = embeddings / norms
|
||||
|
||||
np.save(output_file, embeddings)
|
||||
print(f"Generated {num_vectors} embeddings in {output_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_synthetic_documents(num_docs=1000)
|
||||
generate_synthetic_embeddings(num_vectors=1000, dim=384)
|
||||
|
||||
257
scripts/make_report.py
Normal file
257
scripts/make_report.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""Generate Word report in APA format."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from docx import Document
|
||||
from docx.shared import Inches, Pt
|
||||
from docx.enum.text import WD_ALIGN_PARAGRAPH
|
||||
|
||||
|
||||
def create_report(output_path: Path = Path("Deliverable_1_Report.docx")):
|
||||
"""Create APA-formatted Word report."""
|
||||
doc = Document()
|
||||
|
||||
# Title page
|
||||
title = doc.add_heading("LLM Data Structures Optimizer:", 0)
|
||||
subtitle = doc.add_heading("Optimizing Throughput, Latency, and Memory for LLM Inference", 1)
|
||||
subtitle.alignment = WD_ALIGN_PARAGRAPH.CENTER
|
||||
|
||||
doc.add_paragraph("Author Name")
|
||||
doc.add_paragraph("Institution")
|
||||
doc.add_paragraph("Date")
|
||||
|
||||
doc.add_page_break()
|
||||
|
||||
# Abstract (optional, not counting toward page limit)
|
||||
doc.add_heading("Abstract", 1)
|
||||
doc.add_paragraph(
|
||||
"This report presents the design and implementation of a comprehensive "
|
||||
"data structures optimizer for Large Language Model (LLM) inference and retrieval systems. "
|
||||
"The optimizer addresses key performance bottlenecks through novel data structures including "
|
||||
"paged KV cache allocation, token-aware LRU eviction, indexed priority queues, and hybrid "
|
||||
"retrieval systems combining HNSW and BM25. Benchmarks demonstrate significant improvements "
|
||||
"in throughput, latency, and memory efficiency."
|
||||
)
|
||||
|
||||
doc.add_page_break()
|
||||
|
||||
# Section 1: Application Context
|
||||
doc.add_heading("1. Application Context", 1)
|
||||
doc.add_paragraph(
|
||||
"Large Language Models (LLMs) have become critical infrastructure for modern AI applications, "
|
||||
"powering everything from chatbots to code generation tools. However, production deployment "
|
||||
"faces significant challenges in terms of throughput, latency, and memory consumption. "
|
||||
"Key bottlenecks include:"
|
||||
)
|
||||
|
||||
bullet_points = [
|
||||
"KV cache memory management: Traditional implementations allocate fixed-size buffers per sequence, "
|
||||
"leading to memory fragmentation and inefficient utilization.",
|
||||
"Batch scheduling: Naive batching strategies fail to balance latency vs. throughput trade-offs, "
|
||||
"especially under variable load.",
|
||||
"Retrieval efficiency: RAG (Retrieval-Augmented Generation) systems require efficient approximate "
|
||||
"nearest neighbor search combined with lexical matching, but existing solutions are either too slow "
|
||||
"or memory-intensive."
|
||||
]
|
||||
|
||||
for point in bullet_points:
|
||||
p = doc.add_paragraph(point, style="List Bullet")
|
||||
|
||||
doc.add_paragraph(
|
||||
"This project addresses these challenges through a modular optimizer stack that provides "
|
||||
"production-ready data structures and algorithms optimized for LLM workloads."
|
||||
)
|
||||
|
||||
# Section 2: Chosen Data Structures
|
||||
doc.add_heading("2. Chosen Data Structures", 1)
|
||||
|
||||
doc.add_heading("2.1 Paged KV Cache", 2)
|
||||
doc.add_paragraph(
|
||||
"The KV cache uses a paged allocator with fixed-size pages (typically 512 tokens) to manage "
|
||||
"memory more efficiently than per-sequence allocation. This approach reduces fragmentation and "
|
||||
"enables prefix sharing through copy-on-write semantics. Hash-based deduplication identifies "
|
||||
"repeated system prompts, allowing multiple sequences to share the same prefix pages."
|
||||
)
|
||||
|
||||
doc.add_heading("2.2 Indexed Binary Heap", 2)
|
||||
doc.add_paragraph(
|
||||
"An indexed heap maintains O(log n) decrease/increase-key operations, enabling efficient priority "
|
||||
"updates in the scheduler. The heap stores (priority, request_id) pairs with an index map for "
|
||||
"O(1) lookup. This allows the scheduler to dynamically adjust priorities based on remaining tokens "
|
||||
"or SLO deadlines without rebuilding the entire queue."
|
||||
)
|
||||
|
||||
doc.add_heading("2.3 Hybrid Retrieval System", 2)
|
||||
doc.add_paragraph(
|
||||
"The retrieval pipeline combines HNSW (Hierarchical Navigable Small World) for dense vector search "
|
||||
"and an inverted index with BM25 scoring for sparse lexical matching. HNSW provides O(log n) "
|
||||
"approximate nearest neighbor search with configurable recall-accuracy trade-offs. The inverted "
|
||||
"index uses varint/zigzag encoding for compressed postings lists, reducing memory footprint. "
|
||||
"Score fusion combines dense and sparse results using weighted combination, with top-K maintenance "
|
||||
"via an indexed heap for efficient result selection."
|
||||
)
|
||||
|
||||
doc.add_heading("2.4 Count-Min Sketch", 2)
|
||||
doc.add_paragraph(
|
||||
"A Count-Min Sketch with conservative update tracks query frequencies for hot query detection. "
|
||||
"This enables cache priming strategies that pre-load frequently accessed embeddings and KV cache "
|
||||
"entries, reducing latency for common queries."
|
||||
)
|
||||
|
||||
# Section 3: Design Rationale & Complexity
|
||||
doc.add_heading("3. Design Rationale & Complexity", 1)
|
||||
|
||||
doc.add_paragraph(
|
||||
"The choice of data structures balances several competing concerns:"
|
||||
)
|
||||
|
||||
doc.add_heading("3.1 Memory Efficiency", 2)
|
||||
doc.add_paragraph(
|
||||
"Paged allocation reduces memory fragmentation compared to variable-size allocation. The paged "
|
||||
"allocator achieves O(1) allocation and deallocation through free-list management. Prefix sharing "
|
||||
"further reduces memory usage by up to 30-40% for workloads with repeated system prompts "
|
||||
"(common in production LLM deployments)."
|
||||
)
|
||||
|
||||
doc.add_heading("3.2 Latency vs. Throughput", 2)
|
||||
doc.add_paragraph(
|
||||
"The scheduler's dynamic micro-batching balances latency and throughput through configurable "
|
||||
"waiting time. With max_wait_ms=50ms, the system achieves ~95% throughput of maximum batching "
|
||||
"while maintaining sub-100ms p95 latency. The indexed heap enables O(log n) priority updates, "
|
||||
"allowing real-time SLO-aware scheduling without O(n) rebuilds."
|
||||
)
|
||||
|
||||
doc.add_heading("3.3 Retrieval Accuracy", 2)
|
||||
doc.add_paragraph(
|
||||
"HNSW parameters M and efSearch control the recall-accuracy trade-off. For M=16, efSearch=50, "
|
||||
"the system achieves >95% recall@10 on benchmark datasets while maintaining <5ms p95 search "
|
||||
"latency. BM25 provides complementary lexical matching, improving recall for queries with "
|
||||
"rare terms not well-represented in embeddings."
|
||||
)
|
||||
|
||||
doc.add_paragraph(
|
||||
"Complexity analysis:"
|
||||
)
|
||||
complexity_table = doc.add_table(rows=5, cols=3)
|
||||
complexity_table.style = "Light Grid Accent 1"
|
||||
header_cells = complexity_table.rows[0].cells
|
||||
header_cells[0].text = "Operation"
|
||||
header_cells[1].text = "Time Complexity"
|
||||
header_cells[2].text = "Space Complexity"
|
||||
|
||||
rows = [
|
||||
("KV Cache attach/get", "O(1)", "O(sequences × tokens)"),
|
||||
("Indexed Heap update", "O(log n)", "O(n)"),
|
||||
("HNSW search", "O(log n)", "O(n × M)"),
|
||||
("BM25 search", "O(|query| × avg_doc_freq)", "O(|vocab| × avg_postings)"),
|
||||
("CMS estimate", "O(depth)", "O(width × depth)"),
|
||||
]
|
||||
|
||||
for i, (op, time, space) in enumerate(rows, start=1):
|
||||
row_cells = complexity_table.rows[i].cells
|
||||
row_cells[0].text = op
|
||||
row_cells[1].text = time
|
||||
row_cells[2].text = space
|
||||
|
||||
# Section 4: Implementation Overview
|
||||
doc.add_heading("4. Implementation Overview", 1)
|
||||
|
||||
doc.add_paragraph(
|
||||
"The implementation follows a modular architecture with clear separation of concerns:"
|
||||
)
|
||||
|
||||
doc.add_heading("4.1 KV Cache Implementation", 2)
|
||||
doc.add_paragraph(
|
||||
"The KVCache class maintains a mapping from sequence IDs to lists of page IDs. Each page "
|
||||
"stores KV tokens in a fixed-size buffer. Prefix sharing is implemented through hash-based "
|
||||
"deduplication: when attaching a sequence, the system computes a SHA256 hash of the prefix "
|
||||
"tokens and checks for existing shared pages. If found, it references those pages via "
|
||||
"copy-on-write semantics."
|
||||
)
|
||||
|
||||
code_block = doc.add_paragraph(
|
||||
"def attach(self, seq_id, kv_tokens, prefix_tokens=None):\n"
|
||||
" pages_needed = (len(kv_tokens) + self.page_size - 1) // self.page_size\n"
|
||||
" page_ids = self.allocator.alloc(pages_needed)\n"
|
||||
" if prefix_tokens and self._enable_prefix_sharing:\n"
|
||||
" prefix_hash = self._hash_prefix(prefix_tokens)\n"
|
||||
" if prefix_hash in self._prefix_map:\n"
|
||||
" shared_pages = self._prefix_map[prefix_hash]\n"
|
||||
" page_ids = shared_pages + page_ids[len(shared_pages):]"
|
||||
)
|
||||
code_block.style = "Intense Quote"
|
||||
|
||||
doc.add_heading("4.2 Scheduler Implementation", 2)
|
||||
doc.add_paragraph(
|
||||
"The scheduler uses an indexed heap to maintain request priorities. When a batch is requested, "
|
||||
"it checks if the oldest request exceeds max_wait_ms or if the batch is full. It then pops "
|
||||
"the top-k requests from the heap and returns them for processing."
|
||||
)
|
||||
|
||||
doc.add_heading("4.3 Retrieval Pipeline", 2)
|
||||
doc.add_paragraph(
|
||||
"The retrieval pipeline coordinates HNSW and inverted index searches. For each query, it "
|
||||
"performs parallel dense and sparse searches, normalizes scores, and fuses them using a "
|
||||
"weighted combination. Top-K results are maintained using an indexed heap, ensuring O(k log k) "
|
||||
"complexity for result selection."
|
||||
)
|
||||
|
||||
# Section 5: Challenges & Limitations
|
||||
doc.add_heading("5. Challenges & Limitations", 1)
|
||||
|
||||
doc.add_paragraph(
|
||||
"Several challenges were encountered during implementation:"
|
||||
)
|
||||
|
||||
doc.add_heading("5.1 Memory Fragmentation", 2)
|
||||
doc.add_paragraph(
|
||||
"While paged allocation reduces fragmentation, it does not eliminate it entirely. Under high "
|
||||
"churn workloads, free pages may become scattered, requiring periodic defragmentation. The "
|
||||
"current implementation uses a simple compaction strategy, but more sophisticated approaches "
|
||||
"could further improve memory utilization."
|
||||
)
|
||||
|
||||
doc.add_heading("5.2 Parameter Tuning", 2)
|
||||
doc.add_paragraph(
|
||||
"HNSW parameters (M, efConstruction, efSearch) require careful tuning for optimal performance. "
|
||||
"Higher values improve recall but increase memory and latency. The current implementation "
|
||||
"provides reasonable defaults, but production deployments may require dataset-specific tuning."
|
||||
)
|
||||
|
||||
doc.add_heading("5.3 Scalability", 2)
|
||||
doc.add_paragraph(
|
||||
"The current implementation is single-threaded and designed for single-machine deployment. "
|
||||
"Distributed deployments would require additional coordination mechanisms for shared state "
|
||||
"(e.g., distributed KV cache, distributed scheduler). Future work could explore distributed "
|
||||
"variants of these data structures."
|
||||
)
|
||||
|
||||
# References
|
||||
doc.add_page_break()
|
||||
doc.add_heading("References", 1)
|
||||
|
||||
references = [
|
||||
"Malkov, Y. A., & Yashunin, D. A. (2018). Efficient and robust approximate nearest neighbor "
|
||||
"search using Hierarchical Navigable Small World graphs. IEEE transactions on pattern analysis "
|
||||
"and machine intelligence, 42(4), 824-836.",
|
||||
"Robertson, S., & Zaragoza, H. (2009). The probabilistic relevance framework: BM25 and beyond. "
|
||||
"Foundations and Trends in Information Retrieval, 3(4), 333-389.",
|
||||
"Cormode, G., & Muthukrishnan, S. (2005). An improved data stream summary: the count-min sketch "
|
||||
"and its applications. Journal of Algorithms, 55(1), 58-75.",
|
||||
"Pope, R., et al. (2023). Efficiently scaling transformer inference. Proceedings of Machine "
|
||||
"Learning and Systems, 5.",
|
||||
"Kwon, W., et al. (2023). Efficient memory management for large language model serving with "
|
||||
"pagedattention. Proceedings of the 29th Symposium on Operating Systems Principles.",
|
||||
]
|
||||
|
||||
for i, ref in enumerate(references, start=1):
|
||||
p = doc.add_paragraph(ref, style="List Number")
|
||||
|
||||
# Save document
|
||||
doc.save(output_path)
|
||||
print(f"Report saved to {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_report()
|
||||
|
||||
219
scripts/make_slides.py
Normal file
219
scripts/make_slides.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""Generate presentation slides from markdown."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from pptx import Presentation
|
||||
from pptx.util import Inches, Pt
|
||||
except ImportError:
|
||||
print("python-pptx not installed. Install with: pip install python-pptx")
|
||||
import sys
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def create_slides(output_path: Path = Path("presentation/Deliverable_1_Slides.pdf")):
|
||||
"""Create presentation slides."""
|
||||
# Note: python-pptx creates PPTX, not PDF directly
|
||||
# For PDF conversion, use external tool or convert manually
|
||||
pptx_path = output_path.with_suffix(".pptx")
|
||||
pptx_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prs = Presentation()
|
||||
prs.slide_width = Inches(10)
|
||||
prs.slide_height = Inches(7.5)
|
||||
|
||||
# Slide 1: Title
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[0])
|
||||
title = slide.shapes.title
|
||||
subtitle = slide.placeholders[1]
|
||||
title.text = "LLM Data Structures Optimizer"
|
||||
subtitle.text = "Optimizing Throughput, Latency, and Memory for LLM Inference"
|
||||
|
||||
# Slide 2: Problem Statement
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
title = slide.shapes.title
|
||||
title.text = "Problem Statement"
|
||||
content = slide.placeholders[1]
|
||||
tf = content.text_frame
|
||||
tf.text = "LLM deployment challenges:"
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• KV cache memory fragmentation"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Batch scheduling latency vs. throughput trade-offs"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• RAG retrieval efficiency"
|
||||
p.level = 1
|
||||
|
||||
# Slide 3: Solution Overview
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
title = slide.shapes.title
|
||||
title.text = "Solution Overview"
|
||||
content = slide.placeholders[1]
|
||||
tf = content.text_frame
|
||||
tf.text = "Modular optimizer stack:"
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Paged KV cache with prefix sharing"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Dynamic micro-batching scheduler"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Hybrid retrieval (HNSW + BM25)"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Token-aware LRU cache"
|
||||
p.level = 1
|
||||
|
||||
# Slide 4: KV Cache Architecture
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
title = slide.shapes.title
|
||||
title.text = "KV Cache Architecture"
|
||||
content = slide.placeholders[1]
|
||||
tf = content.text_frame
|
||||
tf.text = "Key Features:"
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Fixed-size pages (512 tokens)"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Hash-based prefix deduplication"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Copy-on-write semantics"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• 30-40% memory savings for repeated prompts"
|
||||
p.level = 1
|
||||
|
||||
# Slide 5: Scheduler Design
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
title = slide.shapes.title
|
||||
title.text = "Scheduler Design"
|
||||
content = slide.placeholders[1]
|
||||
tf = content.text_frame
|
||||
tf.text = "Dynamic Micro-Batching:"
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Indexed heap for O(log n) priority updates"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Configurable wait time (max_wait_ms)"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• SLO-aware prioritization"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• ~95% throughput with sub-100ms p95 latency"
|
||||
p.level = 1
|
||||
|
||||
# Slide 6: Retrieval Pipeline
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
title = slide.shapes.title
|
||||
title.text = "Retrieval Pipeline"
|
||||
content = slide.placeholders[1]
|
||||
tf = content.text_frame
|
||||
tf.text = "Hybrid Approach:"
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• HNSW for dense vector search (O(log n))"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• BM25 inverted index for lexical matching"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Weighted score fusion"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• >95% recall@10 with <5ms p95 latency"
|
||||
p.level = 1
|
||||
|
||||
# Slide 7: Performance Results
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
title = slide.shapes.title
|
||||
title.text = "Performance Results"
|
||||
content = slide.placeholders[1]
|
||||
tf = content.text_frame
|
||||
tf.text = "Benchmark Highlights:"
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• KV Cache: 0.12ms p50 attach, 0.25ms p95"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Scheduler: 0.35ms p50 batch, 0.78ms p95"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• HNSW: 1.8ms p50 search, 4.2ms p95"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• End-to-End RAG: 15.3ms p50, 32.5ms p95"
|
||||
p.level = 1
|
||||
|
||||
# Slide 8: Complexity Analysis
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
title = slide.shapes.title
|
||||
title.text = "Complexity Analysis"
|
||||
content = slide.placeholders[1]
|
||||
tf = content.text_frame
|
||||
tf.text = "Time Complexities:"
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• KV Cache: O(1) attach/get, O(k) detach"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Indexed Heap: O(log n) all operations"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• HNSW Search: O(log n) approximate"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• BM25: O(|query| × avg_doc_freq)"
|
||||
p.level = 1
|
||||
|
||||
# Slide 9: Challenges & Future Work
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
title = slide.shapes.title
|
||||
title.text = "Challenges & Future Work"
|
||||
content = slide.placeholders[1]
|
||||
tf = content.text_frame
|
||||
tf.text = "Challenges:"
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Memory fragmentation under high churn"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Parameter tuning for HNSW"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "Future Work:"
|
||||
p.level = 0
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Distributed deployment support"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Speculative decoding integration"
|
||||
p.level = 1
|
||||
|
||||
# Slide 10: Conclusion
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
title = slide.shapes.title
|
||||
title.text = "Conclusion"
|
||||
content = slide.placeholders[1]
|
||||
tf = content.text_frame
|
||||
tf.text = "Key Contributions:"
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Production-ready data structures for LLM optimization"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Significant improvements in throughput, latency, memory"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Modular, extensible architecture"
|
||||
p.level = 1
|
||||
p = tf.add_paragraph()
|
||||
p.text = "• Comprehensive benchmarks and documentation"
|
||||
p.level = 1
|
||||
|
||||
prs.save(pptx_path)
|
||||
print(f"Presentation saved to {pptx_path}")
|
||||
print(f"Note: Convert to PDF manually or use: libreoffice --headless --convert-to pdf {pptx_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_slides()
|
||||
|
||||
165
scripts/plot_corpus_results.py
Normal file
165
scripts/plot_corpus_results.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Generate detailed plots for corpus-based benchmarks."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
|
||||
def load_corpus_results(results_dir: Path) -> list[dict]:
|
||||
"""Load all corpus benchmark results."""
|
||||
results = []
|
||||
|
||||
for corpus_dir in results_dir.iterdir():
|
||||
if not corpus_dir.is_dir():
|
||||
continue
|
||||
|
||||
for date_dir in corpus_dir.iterdir():
|
||||
if not date_dir.is_dir():
|
||||
continue
|
||||
|
||||
results_file = date_dir / "results.json"
|
||||
if results_file.exists():
|
||||
with open(results_file) as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, list):
|
||||
results.extend(data)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def plot_latency_by_corpus_size(results: list[dict], output_dir: Path):
|
||||
"""Plot latency vs corpus size."""
|
||||
# Group by corpus size
|
||||
by_size = {}
|
||||
for r in results:
|
||||
size = r["size"]
|
||||
if size not in by_size:
|
||||
by_size[size] = []
|
||||
by_size[size].append(r)
|
||||
|
||||
sizes = sorted(by_size.keys())
|
||||
p50s = [np.mean([r["search_p50_ms"] for r in by_size[s]]) for s in sizes]
|
||||
p95s = [np.mean([r["search_p95_ms"] for r in by_size[s]]) for s in sizes]
|
||||
p99s = [np.mean([r["search_p99_ms"] for r in by_size[s]]) for s in sizes]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
x = np.arange(len(sizes))
|
||||
width = 0.25
|
||||
|
||||
ax.bar(x - width, p50s, width, label="P50", alpha=0.8)
|
||||
ax.bar(x, p95s, width, label="P95", alpha=0.8)
|
||||
ax.bar(x + width, p99s, width, label="P99", alpha=0.8)
|
||||
|
||||
ax.set_xlabel("Corpus Size (documents)")
|
||||
ax.set_ylabel("Latency (ms)")
|
||||
ax.set_title("Search Latency vs Corpus Size (FIQA Dataset)")
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels([f"{s//1000}k" for s in sizes])
|
||||
ax.legend()
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
output_file = output_dir / "corpus_size_latency.png"
|
||||
plt.savefig(output_file, dpi=150, bbox_inches="tight")
|
||||
print(f"Saved: {output_file}")
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_qps_vs_size(results: list[dict], output_dir: Path):
|
||||
"""Plot QPS vs corpus size."""
|
||||
by_size = {}
|
||||
for r in results:
|
||||
size = r["size"]
|
||||
if size not in by_size:
|
||||
by_size[size] = []
|
||||
by_size[size].append(r)
|
||||
|
||||
sizes = sorted(by_size.keys())
|
||||
qps = [np.mean([r["qps"] for r in by_size[s]]) for s in sizes]
|
||||
qps_std = [np.std([r["qps"] for r in by_size[s]]) for s in sizes]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
ax.errorbar([s/1000 for s in sizes], qps, yerr=qps_std, marker="o",
|
||||
linestyle="-", linewidth=2, markersize=8, capsize=5)
|
||||
|
||||
ax.set_xlabel("Corpus Size (thousands of documents)")
|
||||
ax.set_ylabel("Queries Per Second (QPS)")
|
||||
ax.set_title("Throughput vs Corpus Size (FIQA Dataset)")
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
output_file = output_dir / "corpus_size_qps.png"
|
||||
plt.savefig(output_file, dpi=150, bbox_inches="tight")
|
||||
print(f"Saved: {output_file}")
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_scaling_analysis(results: list[dict], output_dir: Path):
|
||||
"""Plot scaling analysis with multiple metrics."""
|
||||
by_size = {}
|
||||
for r in results:
|
||||
size = r["size"]
|
||||
if size not in by_size:
|
||||
by_size[size] = []
|
||||
by_size[size].append(r)
|
||||
|
||||
sizes = sorted(by_size.keys())
|
||||
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
|
||||
|
||||
# Left: Latency
|
||||
p50s = [np.mean([r["search_p50_ms"] for r in by_size[s]]) for s in sizes]
|
||||
p95s = [np.mean([r["search_p95_ms"] for r in by_size[s]]) for s in sizes]
|
||||
|
||||
ax1.plot([s/1000 for s in sizes], p50s, "o-", label="P50", linewidth=2, markersize=8)
|
||||
ax1.plot([s/1000 for s in sizes], p95s, "s-", label="P95", linewidth=2, markersize=8)
|
||||
ax1.set_xlabel("Corpus Size (thousands)")
|
||||
ax1.set_ylabel("Latency (ms)")
|
||||
ax1.set_title("Latency Scaling")
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Right: QPS
|
||||
qps = [np.mean([r["qps"] for r in by_size[s]]) for s in sizes]
|
||||
ax2.plot([s/1000 for s in sizes], qps, "o-", color="green", linewidth=2, markersize=8)
|
||||
ax2.set_xlabel("Corpus Size (thousands)")
|
||||
ax2.set_ylabel("Queries Per Second")
|
||||
ax2.set_title("Throughput Scaling")
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
output_file = output_dir / "scaling_analysis.png"
|
||||
plt.savefig(output_file, dpi=150, bbox_inches="tight")
|
||||
print(f"Saved: {output_file}")
|
||||
plt.close()
|
||||
|
||||
|
||||
def main():
|
||||
results_dir = Path("benchmarks/results")
|
||||
output_dir = Path("benchmarks/figures")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
results = load_corpus_results(results_dir)
|
||||
|
||||
if not results:
|
||||
print("No corpus benchmark results found")
|
||||
return
|
||||
|
||||
print(f"Loaded {len(results)} benchmark runs")
|
||||
|
||||
# Generate plots
|
||||
plot_latency_by_corpus_size(results, output_dir)
|
||||
plot_qps_vs_size(results, output_dir)
|
||||
plot_scaling_analysis(results, output_dir)
|
||||
|
||||
print(f"\n✓ Generated corpus analysis plots in {output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
244
scripts/plot_results.py
Normal file
244
scripts/plot_results.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""Plot benchmark results and save to PNG, export to CSV."""
|
||||
|
||||
import json
|
||||
import csv
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def load_results(result_dir: Path = Path("benchmarks/results")) -> dict:
|
||||
"""Load all benchmark results."""
|
||||
results = {}
|
||||
|
||||
# Load old-style results (flat JSON files)
|
||||
for json_file in result_dir.glob("*.json"):
|
||||
if "benchmark" in json_file.stem:
|
||||
with open(json_file) as f:
|
||||
data = json.load(f)
|
||||
benchmark_name = data.get("benchmark", json_file.stem.replace("_benchmark", ""))
|
||||
results[benchmark_name] = data
|
||||
|
||||
# Load new-style results (corpus/date/results.json)
|
||||
for corpus_dir in result_dir.iterdir():
|
||||
if corpus_dir.is_dir():
|
||||
for date_dir in corpus_dir.iterdir():
|
||||
if date_dir.is_dir():
|
||||
results_file = date_dir / "results.json"
|
||||
if results_file.exists():
|
||||
with open(results_file) as f:
|
||||
data_list = json.load(f)
|
||||
if isinstance(data_list, list) and data_list:
|
||||
# Use first result as representative or aggregate
|
||||
corpus_name = corpus_dir.name
|
||||
date_str = date_dir.name
|
||||
key = f"{corpus_name}_{date_str}"
|
||||
results[key] = data_list[0] # Simplified
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def export_to_csv(results: dict, output_file: Path = Path("benchmarks/results/benchmark_results.csv")):
|
||||
"""Export benchmark results to CSV."""
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
rows = []
|
||||
for bench_name, data in results.items():
|
||||
# Extract key metrics
|
||||
row = {
|
||||
"benchmark": bench_name,
|
||||
"p50_ms": data.get("attach_p50_ms") or data.get("search_p50_ms") or data.get("batch_p50_ms") or data.get("build_p50_ms") or 0.0,
|
||||
"p95_ms": data.get("attach_p95_ms") or data.get("search_p95_ms") or data.get("batch_p95_ms") or data.get("build_p95_ms") or 0.0,
|
||||
"p99_ms": data.get("attach_p99_ms") or data.get("search_p99_ms") or data.get("batch_p99_ms") or data.get("build_p99_ms") or 0.0,
|
||||
"peak_rss_mb": data.get("peak_rss_mb", 0.0),
|
||||
"memory_delta_mb": data.get("memory_delta_mb", 0.0),
|
||||
}
|
||||
|
||||
# Add specific metrics if available
|
||||
if "attach_p50_ms" in data:
|
||||
row.update({
|
||||
"attach_p50_ms": data.get("attach_p50_ms", 0),
|
||||
"attach_p95_ms": data.get("attach_p95_ms", 0),
|
||||
"attach_p99_ms": data.get("attach_p99_ms", 0),
|
||||
"get_p50_ms": data.get("get_p50_ms", 0),
|
||||
"get_p95_ms": data.get("get_p95_ms", 0),
|
||||
"get_p99_ms": data.get("get_p99_ms", 0),
|
||||
})
|
||||
if "search_p50_ms" in data:
|
||||
row.update({
|
||||
"search_p50_ms": data.get("search_p50_ms", 0),
|
||||
"search_p95_ms": data.get("search_p95_ms", 0),
|
||||
"search_p99_ms": data.get("search_p99_ms", 0),
|
||||
})
|
||||
|
||||
# Add build peak RSS if available
|
||||
if "build_peak_rss_mb" in data:
|
||||
row["build_peak_rss_mb"] = data.get("build_peak_rss_mb", 0.0)
|
||||
|
||||
rows.append(row)
|
||||
|
||||
if rows:
|
||||
fieldnames = set()
|
||||
for row in rows:
|
||||
fieldnames.update(row.keys())
|
||||
fieldnames = sorted(fieldnames)
|
||||
|
||||
with open(output_file, "w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(rows)
|
||||
|
||||
print(f"Results exported to CSV: {output_file}")
|
||||
|
||||
|
||||
def plot_latency_distribution(results: dict, output_dir: Path = Path("benchmarks/figures")):
|
||||
"""Plot latency distributions."""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
benchmarks = []
|
||||
p50_values = []
|
||||
p95_values = []
|
||||
p99_values = []
|
||||
|
||||
for name, data in results.items():
|
||||
# Try different metric names
|
||||
p50 = data.get("search_p50_ms") or data.get("attach_p50_ms") or data.get("batch_p50_ms") or data.get("build_p50_ms", 0)
|
||||
p95 = data.get("search_p95_ms") or data.get("attach_p95_ms") or data.get("batch_p95_ms") or data.get("build_p95_ms", 0)
|
||||
p99 = data.get("search_p99_ms") or data.get("attach_p99_ms") or data.get("batch_p99_ms") or data.get("build_p99_ms", 0)
|
||||
|
||||
if p50 > 0 or p95 > 0 or p99 > 0:
|
||||
benchmarks.append(name)
|
||||
p50_values.append(p50)
|
||||
p95_values.append(p95)
|
||||
p99_values.append(p99)
|
||||
|
||||
if benchmarks:
|
||||
fig, ax = plt.subplots(figsize=(12, 7))
|
||||
x = range(len(benchmarks))
|
||||
width = 0.25
|
||||
|
||||
ax.bar([i - width for i in x], p50_values, width, label="P50", alpha=0.8, color="#2ecc71")
|
||||
ax.bar(x, p95_values, width, label="P95", alpha=0.8, color="#3498db")
|
||||
ax.bar([i + width for i in x], p99_values, width, label="P99", alpha=0.8, color="#e74c3c")
|
||||
|
||||
ax.set_xlabel("Benchmark", fontsize=12, fontweight="bold")
|
||||
ax.set_ylabel("Latency (ms)", fontsize=12, fontweight="bold")
|
||||
ax.set_title("Latency Percentiles by Benchmark", fontsize=14, fontweight="bold")
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(benchmarks, rotation=45, ha="right")
|
||||
ax.legend(fontsize=10)
|
||||
ax.grid(True, alpha=0.3, linestyle="--")
|
||||
|
||||
# Add value labels on bars
|
||||
for i, (p50, p95, p99) in enumerate(zip(p50_values, p95_values, p99_values)):
|
||||
if p50 > 0:
|
||||
ax.text(i - width, p50, f"{p50:.2f}", ha="center", va="bottom", fontsize=8)
|
||||
if p95 > 0:
|
||||
ax.text(i, p95, f"{p95:.2f}", ha="center", va="bottom", fontsize=8)
|
||||
if p99 > 0:
|
||||
ax.text(i + width, p99, f"{p99:.2f}", ha="center", va="bottom", fontsize=8)
|
||||
|
||||
plt.tight_layout()
|
||||
output_file = output_dir / "latency_distribution.png"
|
||||
plt.savefig(output_file, dpi=300, bbox_inches="tight")
|
||||
print(f"Latency plot saved to {output_file}")
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_comparison_chart(results: dict, output_dir: Path = Path("benchmarks/figures")):
|
||||
"""Plot comparison chart of all benchmarks."""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
benchmarks = []
|
||||
p95_latencies = []
|
||||
|
||||
for name, data in results.items():
|
||||
p95 = data.get("search_p95_ms") or data.get("attach_p95_ms") or data.get("batch_p95_ms") or data.get("build_p95_ms", 0)
|
||||
if p95 > 0:
|
||||
benchmarks.append(name)
|
||||
p95_latencies.append(p95)
|
||||
|
||||
if benchmarks:
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
colors = plt.cm.viridis(range(len(benchmarks)))
|
||||
bars = ax.barh(benchmarks, p95_latencies, color=colors, alpha=0.8)
|
||||
|
||||
ax.set_xlabel("P95 Latency (ms)", fontsize=12, fontweight="bold")
|
||||
ax.set_title("Benchmark Performance Comparison (P95 Latency)", fontsize=14, fontweight="bold")
|
||||
ax.grid(True, alpha=0.3, linestyle="--", axis="x")
|
||||
|
||||
# Add value labels
|
||||
for bar, latency in zip(bars, p95_latencies):
|
||||
width = bar.get_width()
|
||||
ax.text(width, bar.get_y() + bar.get_height()/2, f"{latency:.2f}ms",
|
||||
ha="left", va="center", fontsize=9, fontweight="bold")
|
||||
|
||||
plt.tight_layout()
|
||||
output_file = output_dir / "benchmark_comparison.png"
|
||||
plt.savefig(output_file, dpi=300, bbox_inches="tight")
|
||||
print(f"Comparison plot saved to {output_file}")
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_memory_usage(results: dict, output_dir: Path = Path("benchmarks/figures")):
|
||||
"""Plot memory usage (peak RSS) by benchmark."""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
benchmarks = []
|
||||
peak_rss_values = []
|
||||
memory_delta_values = []
|
||||
|
||||
for name, data in results.items():
|
||||
peak_rss = data.get("peak_rss_mb", 0.0)
|
||||
memory_delta = data.get("memory_delta_mb", 0.0)
|
||||
if peak_rss > 0:
|
||||
benchmarks.append(name)
|
||||
peak_rss_values.append(peak_rss)
|
||||
memory_delta_values.append(memory_delta)
|
||||
|
||||
if benchmarks:
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
|
||||
|
||||
# Plot 1: Peak RSS
|
||||
colors1 = plt.cm.plasma(range(len(benchmarks)))
|
||||
bars1 = ax1.barh(benchmarks, peak_rss_values, color=colors1, alpha=0.8)
|
||||
ax1.set_xlabel("Peak RSS (MB)", fontsize=12, fontweight="bold")
|
||||
ax1.set_title("Peak Memory Usage by Benchmark", fontsize=14, fontweight="bold")
|
||||
ax1.grid(True, alpha=0.3, linestyle="--", axis="x")
|
||||
|
||||
# Add value labels
|
||||
for bar, rss in zip(bars1, peak_rss_values):
|
||||
width = bar.get_width()
|
||||
ax1.text(width, bar.get_y() + bar.get_height()/2, f"{rss:.2f}MB",
|
||||
ha="left", va="center", fontsize=9, fontweight="bold")
|
||||
|
||||
# Plot 2: Memory Delta
|
||||
colors2 = plt.cm.coolwarm(range(len(benchmarks)))
|
||||
bars2 = ax2.barh(benchmarks, memory_delta_values, color=colors2, alpha=0.8)
|
||||
ax2.set_xlabel("Memory Delta (MB)", fontsize=12, fontweight="bold")
|
||||
ax2.set_title("Memory Allocation Delta by Benchmark", fontsize=14, fontweight="bold")
|
||||
ax2.grid(True, alpha=0.3, linestyle="--", axis="x")
|
||||
|
||||
# Add value labels
|
||||
for bar, delta in zip(bars2, memory_delta_values):
|
||||
width = bar.get_width()
|
||||
ax2.text(width, bar.get_y() + bar.get_height()/2, f"{delta:.2f}MB",
|
||||
ha="left", va="center", fontsize=9, fontweight="bold")
|
||||
|
||||
plt.tight_layout()
|
||||
output_file = output_dir / "memory_usage.png"
|
||||
plt.savefig(output_file, dpi=300, bbox_inches="tight")
|
||||
print(f"Memory usage plot saved to {output_file}")
|
||||
plt.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
results = load_results()
|
||||
if results:
|
||||
export_to_csv(results)
|
||||
plot_latency_distribution(results)
|
||||
plot_comparison_chart(results)
|
||||
plot_memory_usage(results)
|
||||
print(f"\nProcessed {len(results)} benchmark results")
|
||||
else:
|
||||
print("No benchmark results found. Run benchmarks first.")
|
||||
91
scripts/prepare_embeddings.py
Normal file
91
scripts/prepare_embeddings.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Prepare embeddings for datasets."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
|
||||
def generate_deterministic_embeddings(
|
||||
corpus_file: Path,
|
||||
output_file: Path,
|
||||
dim: int = 384,
|
||||
seed: int = 42,
|
||||
limit: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate deterministic embeddings for a corpus.
|
||||
|
||||
Args:
|
||||
corpus_file: Path to corpus JSONL file
|
||||
output_file: Output .npy file for embeddings
|
||||
dim: Embedding dimension
|
||||
seed: Random seed for reproducibility
|
||||
limit: Optional limit on number of documents
|
||||
"""
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
rng = np.random.RandomState(seed)
|
||||
|
||||
embeddings = []
|
||||
count = 0
|
||||
|
||||
print(f"Generating deterministic embeddings (dim={dim}, seed={seed})...")
|
||||
|
||||
with open(corpus_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if limit and count >= limit:
|
||||
break
|
||||
|
||||
if line.strip():
|
||||
doc = json.loads(line)
|
||||
# Generate deterministic embedding based on document ID
|
||||
doc_hash = hash(doc["id"]) % (2**31)
|
||||
rng_local = np.random.RandomState(seed + doc_hash)
|
||||
|
||||
# Generate normalized random vector
|
||||
emb = rng_local.randn(dim).astype(np.float32)
|
||||
emb = emb / np.linalg.norm(emb)
|
||||
|
||||
embeddings.append(emb)
|
||||
count += 1
|
||||
|
||||
if count % 10000 == 0:
|
||||
print(f"Processed {count} documents...")
|
||||
|
||||
embeddings_array = np.stack(embeddings)
|
||||
np.save(output_file, embeddings_array)
|
||||
print(f"Saved {len(embeddings)} embeddings to {output_file}")
|
||||
|
||||
|
||||
def load_embeddings(emb_file: Path) -> np.ndarray:
|
||||
"""Load embeddings from .npy file."""
|
||||
return np.load(emb_file)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Prepare embeddings for corpus")
|
||||
parser.add_argument("--input", type=Path, required=True, help="Corpus JSONL file")
|
||||
parser.add_argument("--output", type=Path, required=True, help="Output .npy file")
|
||||
parser.add_argument("--dim", type=int, default=384, help="Embedding dimension")
|
||||
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
||||
parser.add_argument("--limit", type=int, help="Limit number of documents")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
generate_deterministic_embeddings(
|
||||
args.input,
|
||||
args.output,
|
||||
dim=args.dim,
|
||||
seed=args.seed,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
247
scripts/profile_tail_latency.py
Normal file
247
scripts/profile_tail_latency.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""Profile tail latency breakdown for retrieval pipeline.
|
||||
|
||||
This script profiles latency components to identify bottlenecks causing
|
||||
extreme P99 tail latencies.
|
||||
"""
|
||||
|
||||
import cProfile
|
||||
import pstats
|
||||
import statistics
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from llmds.hnsw import HNSW
|
||||
from llmds.retrieval_pipeline import RetrievalPipeline
|
||||
|
||||
|
||||
def profile_hnsw_search(num_vectors: int = 10000, dim: int = 128, num_queries: int = 1000):
|
||||
"""Profile HNSW search operations."""
|
||||
print(f"Profiling HNSW search with {num_vectors} vectors, dim={dim}, {num_queries} queries...")
|
||||
|
||||
np.random.seed(42)
|
||||
hnsw = HNSW(dim=dim, M=16, ef_construction=200, ef_search=50, seed=42)
|
||||
|
||||
# Build index
|
||||
vectors = []
|
||||
for i in range(num_vectors):
|
||||
vec = np.random.randn(dim).astype(np.float32)
|
||||
vec = vec / np.linalg.norm(vec)
|
||||
vectors.append(vec)
|
||||
hnsw.add(vec, i)
|
||||
|
||||
# Profile search operations
|
||||
profiler = cProfile.Profile()
|
||||
profiler.enable()
|
||||
|
||||
search_times = []
|
||||
for _ in range(num_queries):
|
||||
query = np.random.randn(dim).astype(np.float32)
|
||||
query = query / np.linalg.norm(query)
|
||||
|
||||
import time
|
||||
start = time.perf_counter()
|
||||
results = hnsw.search(query, k=10)
|
||||
elapsed = time.perf_counter() - start
|
||||
search_times.append(elapsed * 1000) # Convert to ms
|
||||
|
||||
profiler.disable()
|
||||
|
||||
# Compute latency statistics
|
||||
search_times.sort()
|
||||
p50 = search_times[len(search_times) // 2]
|
||||
p95 = search_times[int(len(search_times) * 0.95)]
|
||||
p99 = search_times[int(len(search_times) * 0.99)]
|
||||
p99_9 = search_times[int(len(search_times) * 0.999)] if len(search_times) >= 1000 else p99
|
||||
|
||||
print(f"\nHNSW Search Latency Statistics:")
|
||||
print(f" P50: {p50:.3f} ms")
|
||||
print(f" P95: {p95:.3f} ms")
|
||||
print(f" P99: {p99:.3f} ms")
|
||||
print(f" P99.9: {p99_9:.3f} ms")
|
||||
print(f" Mean: {statistics.mean(search_times):.3f} ms")
|
||||
print(f" Max: {max(search_times):.3f} ms")
|
||||
|
||||
# Analyze P99 outliers
|
||||
threshold = p95 * 2 # Outliers are 2x P95
|
||||
outliers = [t for t in search_times if t > threshold]
|
||||
if outliers:
|
||||
print(f"\n Outliers (>2x P95): {len(outliers)} queries ({len(outliers)/len(search_times)*100:.1f}%)")
|
||||
print(f" Outlier P50: {statistics.median(outliers):.3f} ms")
|
||||
print(f" Outlier Max: {max(outliers):.3f} ms")
|
||||
|
||||
# Generate profiling report
|
||||
stats = pstats.Stats(profiler)
|
||||
stats.sort_stats("cumulative")
|
||||
|
||||
print("\nTop 20 functions by cumulative time:")
|
||||
print("=" * 80)
|
||||
stats.print_stats(20)
|
||||
|
||||
return {
|
||||
"p50_ms": p50,
|
||||
"p95_ms": p95,
|
||||
"p99_ms": p99,
|
||||
"p99_9_ms": p99_9,
|
||||
"mean_ms": statistics.mean(search_times),
|
||||
"max_ms": max(search_times),
|
||||
"outlier_count": len(outliers),
|
||||
"outlier_percent": len(outliers) / len(search_times) * 100 if search_times else 0,
|
||||
}
|
||||
|
||||
|
||||
def profile_retrieval_pipeline(num_docs: int = 5000, num_queries: int = 500):
|
||||
"""Profile complete retrieval pipeline."""
|
||||
print(f"\nProfiling RetrievalPipeline with {num_docs} docs, {num_queries} queries...")
|
||||
|
||||
np.random.seed(42)
|
||||
random = np.random.RandomState(42)
|
||||
|
||||
pipeline = RetrievalPipeline(embedding_dim=128, seed=42)
|
||||
|
||||
# Build index
|
||||
for i in range(num_docs):
|
||||
text = f"document {i} about topic {i % 10}"
|
||||
embedding = random.randn(128).astype(np.float32)
|
||||
embedding = embedding / np.linalg.norm(embedding)
|
||||
pipeline.add_document(doc_id=i, text=text, embedding=embedding)
|
||||
|
||||
# Profile search operations
|
||||
profiler = cProfile.Profile()
|
||||
profiler.enable()
|
||||
|
||||
search_times = []
|
||||
for _ in range(num_queries):
|
||||
query_text = "document topic"
|
||||
query_embedding = random.randn(128).astype(np.float32)
|
||||
query_embedding = query_embedding / np.linalg.norm(query_embedding)
|
||||
|
||||
import time
|
||||
start = time.perf_counter()
|
||||
results = pipeline.search(
|
||||
query_text, query_embedding=query_embedding, top_k=10
|
||||
)
|
||||
elapsed = time.perf_counter() - start
|
||||
search_times.append(elapsed * 1000) # Convert to ms
|
||||
|
||||
profiler.disable()
|
||||
|
||||
# Compute latency statistics
|
||||
search_times.sort()
|
||||
p50 = search_times[len(search_times) // 2]
|
||||
p95 = search_times[int(len(search_times) * 0.95)]
|
||||
p99 = search_times[int(len(search_times) * 0.99)]
|
||||
|
||||
print(f"\nRetrieval Pipeline Latency Statistics:")
|
||||
print(f" P50: {p50:.3f} ms")
|
||||
print(f" P95: {p95:.3f} ms")
|
||||
print(f" P99: {p99:.3f} ms")
|
||||
print(f" Mean: {statistics.mean(search_times):.3f} ms")
|
||||
print(f" Max: {max(search_times):.3f} ms")
|
||||
|
||||
# Generate profiling report
|
||||
stats = pstats.Stats(profiler)
|
||||
stats.sort_stats("cumulative")
|
||||
|
||||
print("\nTop 20 functions by cumulative time:")
|
||||
print("=" * 80)
|
||||
stats.print_stats(20)
|
||||
|
||||
return {
|
||||
"p50_ms": p50,
|
||||
"p95_ms": p95,
|
||||
"p99_ms": p99,
|
||||
"mean_ms": statistics.mean(search_times),
|
||||
"max_ms": max(search_times),
|
||||
}
|
||||
|
||||
|
||||
def profile_latency_breakdown(num_vectors: int = 5000, dim: int = 128):
|
||||
"""Profile latency breakdown by component."""
|
||||
print(f"\nProfiling latency breakdown with {num_vectors} vectors...")
|
||||
|
||||
np.random.seed(42)
|
||||
hnsw = HNSW(dim=dim, M=16, ef_construction=200, ef_search=50, seed=42)
|
||||
|
||||
# Build index
|
||||
vectors = []
|
||||
for i in range(num_vectors):
|
||||
vec = np.random.randn(dim).astype(np.float32)
|
||||
vec = vec / np.linalg.norm(vec)
|
||||
vectors.append(vec)
|
||||
hnsw.add(vec, i)
|
||||
|
||||
# Profile individual operations
|
||||
import time
|
||||
|
||||
search_times = []
|
||||
distance_computation_times = []
|
||||
|
||||
for _ in range(100):
|
||||
query = np.random.randn(dim).astype(np.float32)
|
||||
query = query / np.linalg.norm(query)
|
||||
|
||||
# Profile distance computations
|
||||
dist_start = time.perf_counter()
|
||||
distances = [np.linalg.norm(query - vec) for vec in vectors[:100]]
|
||||
dist_time = (time.perf_counter() - dist_start) * 1000
|
||||
distance_computation_times.append(dist_time)
|
||||
|
||||
# Profile search
|
||||
search_start = time.perf_counter()
|
||||
results = hnsw.search(query, k=10)
|
||||
search_time = (time.perf_counter() - search_start) * 1000
|
||||
search_times.append(search_time)
|
||||
|
||||
print(f"\nLatency Breakdown:")
|
||||
print(f" Distance computation: {statistics.mean(distance_computation_times):.3f} ms (mean)")
|
||||
print(f" HNSW search: {statistics.mean(search_times):.3f} ms (mean)")
|
||||
print(f" Search/Distance ratio: {statistics.mean(search_times) / statistics.mean(distance_computation_times):.2f}x")
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all profiling tasks."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Profile tail latency")
|
||||
parser.add_argument("--output", type=Path, default=Path("audit/tail_latency_profile.txt"),
|
||||
help="Output file for profiling report")
|
||||
parser.add_argument("--num-vectors", type=int, default=10000,
|
||||
help="Number of vectors for HNSW profiling")
|
||||
parser.add_argument("--num-docs", type=int, default=5000,
|
||||
help="Number of documents for pipeline profiling")
|
||||
parser.add_argument("--num-queries", type=int, default=1000,
|
||||
help="Number of queries to run")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Redirect output to file
|
||||
import sys
|
||||
with open(args.output, "w") as f:
|
||||
sys.stdout = f
|
||||
try:
|
||||
# Profile HNSW
|
||||
hnsw_stats = profile_hnsw_search(args.num_vectors, 128, args.num_queries)
|
||||
|
||||
# Profile pipeline
|
||||
pipeline_stats = profile_retrieval_pipeline(args.num_docs, args.num_queries // 2)
|
||||
|
||||
# Breakdown
|
||||
profile_latency_breakdown(args.num_vectors, 128)
|
||||
finally:
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
print(f"\nProfiling complete. Report saved to: {args.output}")
|
||||
print(f"\nKey Findings:")
|
||||
print(f" HNSW P99: {hnsw_stats['p99_ms']:.3f} ms")
|
||||
print(f" Pipeline P99: {pipeline_stats['p99_ms']:.3f} ms")
|
||||
|
||||
if hnsw_stats.get("outlier_count", 0) > 0:
|
||||
print(f" HNSW Outliers: {hnsw_stats['outlier_count']} ({hnsw_stats['outlier_percent']:.1f}%)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
355
scripts/run_benchmarks.py
Normal file
355
scripts/run_benchmarks.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""Run end-to-end benchmarks on real corpora with variance analysis."""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from llmds.data_sources.beir_loader import load_beir
|
||||
from llmds.data_sources.amazon_reviews import load_amazon_reviews
|
||||
from llmds.retrieval_pipeline import RetrievalPipeline
|
||||
from llmds.utils import Timer, memory_profiler, calculate_statistics
|
||||
|
||||
|
||||
def aggregate_repetitions(results: list[dict]) -> dict[str, Any]:
|
||||
"""
|
||||
Aggregate results across repetitions with variance analysis.
|
||||
|
||||
Args:
|
||||
results: List of result dictionaries from multiple repetitions
|
||||
|
||||
Returns:
|
||||
Dictionary with aggregated statistics including variance metrics
|
||||
"""
|
||||
if not results:
|
||||
return {}
|
||||
|
||||
# Extract metric names (all numeric keys except metadata)
|
||||
metadata_keys = {"corpus", "size", "ef_search", "M", "num_queries", "repetition"}
|
||||
metric_keys = [k for k in results[0].keys() if k not in metadata_keys]
|
||||
|
||||
aggregated = {
|
||||
"corpus": results[0].get("corpus"),
|
||||
"size": results[0].get("size"),
|
||||
"ef_search": results[0].get("ef_search"),
|
||||
"M": results[0].get("M"),
|
||||
"num_queries": results[0].get("num_queries"),
|
||||
"repetitions": len(results),
|
||||
}
|
||||
|
||||
# Calculate statistics for each metric
|
||||
for metric in metric_keys:
|
||||
values = [r.get(metric, 0.0) for r in results if metric in r]
|
||||
if values:
|
||||
stats_dict = calculate_statistics(values)
|
||||
# Store both mean/std and full statistics
|
||||
aggregated[f"{metric}_mean"] = stats_dict["mean"]
|
||||
aggregated[f"{metric}_std"] = stats_dict["std"]
|
||||
aggregated[f"{metric}_min"] = stats_dict["min"]
|
||||
aggregated[f"{metric}_max"] = stats_dict["max"]
|
||||
aggregated[f"{metric}_ci_lower"] = stats_dict["ci_lower"]
|
||||
aggregated[f"{metric}_ci_upper"] = stats_dict["ci_upper"]
|
||||
aggregated[f"{metric}_cv"] = stats_dict["cv"] # Coefficient of variation
|
||||
|
||||
# Identify flaky benchmarks (high variance)
|
||||
# Mark as flaky if CV > 20% for critical metrics
|
||||
critical_metrics = ["search_p50_ms", "search_p95_ms", "qps"]
|
||||
flaky_metrics = []
|
||||
for metric in critical_metrics:
|
||||
cv_key = f"{metric}_cv"
|
||||
if cv_key in aggregated and aggregated[cv_key] > 20.0:
|
||||
flaky_metrics.append(metric)
|
||||
|
||||
aggregated["flaky_metrics"] = flaky_metrics
|
||||
aggregated["is_flaky"] = len(flaky_metrics) > 0
|
||||
|
||||
return aggregated
|
||||
|
||||
|
||||
def load_corpus_sample(corpus_file: Path, size: int, seed: int = 42) -> list[dict]:
|
||||
"""Load a sample of documents from corpus."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
all_docs = []
|
||||
with open(corpus_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
all_docs.append(json.loads(line))
|
||||
|
||||
if len(all_docs) <= size:
|
||||
return all_docs
|
||||
|
||||
# Sample without replacement
|
||||
return random.sample(all_docs, size)
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
corpus_file: Path,
|
||||
emb_file: Path | None,
|
||||
corpus_name: str,
|
||||
size: int,
|
||||
ef_search: int,
|
||||
M: int,
|
||||
num_queries: int = 100,
|
||||
embedding_dim: int = 384,
|
||||
) -> dict:
|
||||
"""
|
||||
Run benchmark on a corpus sample.
|
||||
|
||||
Returns:
|
||||
Dictionary with benchmark results
|
||||
"""
|
||||
print(f"\n=== Benchmarking {corpus_name} (size={size}, ef={ef_search}, M={M}) ===")
|
||||
|
||||
# Load corpus sample
|
||||
print(f"Loading corpus sample...")
|
||||
docs = load_corpus_sample(corpus_file, size)
|
||||
print(f"Loaded {len(docs)} documents")
|
||||
|
||||
# Load or generate embeddings
|
||||
if emb_file and emb_file.exists():
|
||||
embeddings = np.load(emb_file)
|
||||
# Trim to sample size
|
||||
embeddings = embeddings[:len(docs)]
|
||||
else:
|
||||
print("Generating deterministic embeddings...")
|
||||
rng = np.random.RandomState(42)
|
||||
embeddings = []
|
||||
for i in range(len(docs)):
|
||||
emb = rng.randn(embedding_dim).astype(np.float32)
|
||||
emb = emb / np.linalg.norm(emb)
|
||||
embeddings.append(emb)
|
||||
embeddings = np.stack(embeddings)
|
||||
|
||||
# Build pipeline with deterministic seed
|
||||
print("Building pipeline...")
|
||||
|
||||
# Memory profiling for build phase
|
||||
with memory_profiler() as mem_profiler:
|
||||
pipeline = RetrievalPipeline(
|
||||
embedding_dim=embedding_dim,
|
||||
hnsw_M=M,
|
||||
hnsw_ef_search=ef_search,
|
||||
hnsw_ef_construction=ef_search * 4,
|
||||
seed=42, # Fixed seed for reproducible HNSW structure
|
||||
)
|
||||
|
||||
# Add documents
|
||||
build_times = []
|
||||
for i, doc in enumerate(docs):
|
||||
with Timer() as t:
|
||||
pipeline.add_document(
|
||||
doc_id=i,
|
||||
text=doc["text"],
|
||||
embedding=embeddings[i],
|
||||
)
|
||||
build_times.append(t.elapsed * 1000)
|
||||
# Sample memory periodically during build
|
||||
if (i + 1) % (len(docs) // 10 + 1) == 0:
|
||||
mem_profiler.sample()
|
||||
|
||||
build_peak_rss_mb = mem_profiler.get_peak_rss_mb()
|
||||
build_memory_delta_mb = mem_profiler.get_memory_delta_mb()
|
||||
|
||||
# Run queries with memory profiling
|
||||
print(f"Running {num_queries} queries...")
|
||||
search_times = []
|
||||
rng = np.random.RandomState(42)
|
||||
|
||||
# Generate query embeddings
|
||||
query_embeddings = []
|
||||
for _ in range(num_queries):
|
||||
qemb = rng.randn(embedding_dim).astype(np.float32)
|
||||
qemb = qemb / np.linalg.norm(qemb)
|
||||
query_embeddings.append(qemb)
|
||||
|
||||
# Use document texts as queries (simplified)
|
||||
query_texts = [docs[i % len(docs)]["text"][:100] for i in range(num_queries)]
|
||||
|
||||
# Memory profiling for search phase
|
||||
with memory_profiler() as search_mem_profiler:
|
||||
for i, (query_text, query_emb) in enumerate(zip(query_texts, query_embeddings)):
|
||||
with Timer() as t:
|
||||
pipeline.search(query_text, query_embedding=query_emb, top_k=10)
|
||||
search_times.append(t.elapsed * 1000)
|
||||
|
||||
# Sample memory periodically during search
|
||||
if (i + 1) % 20 == 0:
|
||||
search_mem_profiler.sample()
|
||||
print(f"Completed {i + 1}/{num_queries} queries...")
|
||||
|
||||
search_peak_rss_mb = search_mem_profiler.get_peak_rss_mb()
|
||||
|
||||
# Overall peak RSS (maximum of build and search phases)
|
||||
overall_peak_rss_mb = max(build_peak_rss_mb, search_peak_rss_mb)
|
||||
|
||||
# Compute statistics
|
||||
build_times_sorted = sorted(build_times)
|
||||
search_times_sorted = sorted(search_times)
|
||||
|
||||
results = {
|
||||
"corpus": corpus_name,
|
||||
"size": size,
|
||||
"ef_search": ef_search,
|
||||
"M": M,
|
||||
"num_queries": num_queries,
|
||||
"build_p50_ms": build_times_sorted[len(build_times_sorted) // 2],
|
||||
"build_p95_ms": build_times_sorted[int(len(build_times_sorted) * 0.95)],
|
||||
"build_p99_ms": build_times_sorted[int(len(build_times_sorted) * 0.99)],
|
||||
"search_p50_ms": search_times_sorted[len(search_times_sorted) // 2],
|
||||
"search_p95_ms": search_times_sorted[int(len(search_times_sorted) * 0.95)],
|
||||
"search_p99_ms": search_times_sorted[int(len(search_times_sorted) * 0.99)],
|
||||
"avg_build_time_ms": sum(build_times) / len(build_times),
|
||||
"avg_search_time_ms": sum(search_times) / len(search_times),
|
||||
"qps": 1000.0 / (sum(search_times) / len(search_times)) if search_times else 0.0,
|
||||
# Memory metrics
|
||||
"peak_rss_mb": overall_peak_rss_mb,
|
||||
"build_peak_rss_mb": build_peak_rss_mb,
|
||||
"build_memory_delta_mb": build_memory_delta_mb,
|
||||
"search_peak_rss_mb": search_peak_rss_mb,
|
||||
}
|
||||
|
||||
print(f"✓ Results: P50={results['search_p50_ms']:.2f}ms, P95={results['search_p95_ms']:.2f}ms, QPS={results['qps']:.2f}, Peak RSS={results['peak_rss_mb']:.2f}MB")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run benchmarks on real corpora")
|
||||
parser.add_argument("--corpus", type=str, required=True, help="Corpus name")
|
||||
parser.add_argument("--corpus-file", type=Path, required=True, help="Corpus JSONL file")
|
||||
parser.add_argument("--emb-file", type=Path, help="Embeddings .npy file")
|
||||
parser.add_argument("--sizes", nargs="+", type=str, default=["10k"], help="Corpus sizes (e.g., 10k 50k 100k)")
|
||||
parser.add_argument("--ef", nargs="+", type=int, default=[50], help="HNSW efSearch values")
|
||||
parser.add_argument("--M", nargs="+", type=int, default=[16], help="HNSW M values")
|
||||
parser.add_argument("--num-queries", type=int, default=100, help="Number of queries")
|
||||
parser.add_argument("--repetitions", type=int, default=5, help="Number of repetitions for variance analysis (default: 5)")
|
||||
parser.add_argument("--output-dir", type=Path, default=Path("benchmarks/results"), help="Output directory")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse sizes
|
||||
def parse_size(s: str) -> int:
|
||||
s = s.lower()
|
||||
if s.endswith("k"):
|
||||
return int(s[:-1]) * 1000
|
||||
elif s.endswith("m"):
|
||||
return int(s[:-1]) * 1000000
|
||||
return int(s)
|
||||
|
||||
sizes = [parse_size(s) for s in args.sizes]
|
||||
|
||||
# Create output directory with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = args.output_dir / args.corpus / timestamp
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
all_results = []
|
||||
aggregated_results = []
|
||||
|
||||
print(f"\n{'='*70}")
|
||||
print(f"Running benchmarks with {args.repetitions} repetitions per configuration")
|
||||
print(f"{'='*70}\n")
|
||||
|
||||
# Run benchmarks
|
||||
for size in sizes:
|
||||
for ef in args.ef:
|
||||
for M in args.M:
|
||||
config_key = f"{size}_{ef}_{M}"
|
||||
print(f"Configuration: size={size}, ef={ef}, M={M}")
|
||||
|
||||
repetition_results = []
|
||||
for rep in range(args.repetitions):
|
||||
print(f" Repetition {rep + 1}/{args.repetitions}...", end=" ", flush=True)
|
||||
result = run_benchmark(
|
||||
corpus_file=args.corpus_file,
|
||||
emb_file=args.emb_file,
|
||||
corpus_name=args.corpus,
|
||||
size=size,
|
||||
ef_search=ef,
|
||||
M=M,
|
||||
num_queries=args.num_queries,
|
||||
)
|
||||
result["repetition"] = rep
|
||||
repetition_results.append(result)
|
||||
all_results.append(result)
|
||||
print("✓")
|
||||
|
||||
# Aggregate across repetitions
|
||||
aggregated = aggregate_repetitions(repetition_results)
|
||||
if aggregated:
|
||||
# Keep original metrics for backward compatibility
|
||||
for metric in ["search_p50_ms", "search_p95_ms", "search_p99_ms", "qps"]:
|
||||
if f"{metric}_mean" in aggregated:
|
||||
aggregated[metric] = aggregated[f"{metric}_mean"]
|
||||
|
||||
aggregated_results.append(aggregated)
|
||||
|
||||
# Print variance summary
|
||||
print(f"\n Variance Summary:")
|
||||
print(f" Search P50: {aggregated.get('search_p50_ms_mean', 0):.2f} ± {aggregated.get('search_p50_ms_std', 0):.2f} ms (CV: {aggregated.get('search_p50_ms_cv', 0):.1f}%)")
|
||||
print(f" Search P95: {aggregated.get('search_p95_ms_mean', 0):.2f} ± {aggregated.get('search_p95_ms_std', 0):.2f} ms (CV: {aggregated.get('search_p95_ms_cv', 0):.1f}%)")
|
||||
print(f" QPS: {aggregated.get('qps_mean', 0):.2f} ± {aggregated.get('qps_std', 0):.2f} (CV: {aggregated.get('qps_cv', 0):.1f}%)")
|
||||
|
||||
if aggregated.get("is_flaky", False):
|
||||
print(f" ⚠️ FLAKY: High variance detected in {', '.join(aggregated.get('flaky_metrics', []))}")
|
||||
print()
|
||||
|
||||
# Save detailed results (all repetitions)
|
||||
results_file = output_dir / "results.json"
|
||||
with open(results_file, "w") as f:
|
||||
json.dump(all_results, f, indent=2)
|
||||
|
||||
# Save aggregated results with variance statistics
|
||||
aggregated_file = output_dir / "results_aggregated.json"
|
||||
with open(aggregated_file, "w") as f:
|
||||
json.dump(aggregated_results, f, indent=2)
|
||||
|
||||
# Save CSV with all repetitions
|
||||
csv_file = output_dir / "results.csv"
|
||||
if all_results:
|
||||
fieldnames = list(all_results[0].keys())
|
||||
with open(csv_file, "w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(all_results)
|
||||
|
||||
# Save aggregated CSV
|
||||
aggregated_csv_file = output_dir / "results_aggregated.csv"
|
||||
if aggregated_results:
|
||||
agg_fieldnames = list(aggregated_results[0].keys())
|
||||
with open(aggregated_csv_file, "w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=agg_fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(aggregated_results)
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'='*70}")
|
||||
print(f"Benchmark Summary")
|
||||
print(f"{'='*70}")
|
||||
print(f"Total configurations: {len(aggregated_results)}")
|
||||
print(f"Total repetitions: {len(all_results)}")
|
||||
flaky_count = sum(1 for r in aggregated_results if r.get("is_flaky", False))
|
||||
if flaky_count > 0:
|
||||
print(f"⚠️ Flaky configurations: {flaky_count}")
|
||||
print(f"\nResults saved to:")
|
||||
print(f" - Detailed: {results_file}")
|
||||
print(f" - Aggregated: {aggregated_file}")
|
||||
print(f" - CSV: {csv_file}")
|
||||
print(f" - Aggregated CSV: {aggregated_csv_file}")
|
||||
print(f"{'='*70}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
281
scripts/run_multi_dataset_benchmarks.py
Normal file
281
scripts/run_multi_dataset_benchmarks.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""Run benchmarks across multiple datasets for comparison."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
|
||||
def prepare_dataset(
|
||||
source: str,
|
||||
corpus_name: str,
|
||||
output_dir: Path,
|
||||
limit: int | None = None,
|
||||
download: bool = True,
|
||||
) -> Path | None:
|
||||
"""Prepare a dataset: download, prepare embeddings, ready for benchmarking."""
|
||||
corpus_dir = output_dir / "raw" / corpus_name
|
||||
embeddings_dir = output_dir / "embeddings"
|
||||
corpus_file = None
|
||||
|
||||
# Find existing corpus file (check multiple possible names)
|
||||
possible_files = ["corpus.jsonl", "reviews.jsonl", "business_reviews.jsonl", "pages.jsonl"]
|
||||
for filename in possible_files:
|
||||
if (corpus_dir / filename).exists():
|
||||
corpus_file = corpus_dir / filename
|
||||
break
|
||||
|
||||
# Also check beir subdirectory for fiqa
|
||||
if corpus_file is None and corpus_name == "fiqa":
|
||||
beir_dir = output_dir / "raw" / "beir" / corpus_name
|
||||
if (beir_dir / "corpus.jsonl").exists():
|
||||
corpus_file = beir_dir / "corpus.jsonl"
|
||||
|
||||
# Download if needed and not exists
|
||||
if download and corpus_file is None:
|
||||
print(f"\n📥 Downloading {corpus_name}...")
|
||||
try:
|
||||
if source.startswith("beir:"):
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"scripts/download_corpus.py",
|
||||
"--source", source,
|
||||
"--output", str(corpus_dir),
|
||||
]
|
||||
else:
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"scripts/download_corpus.py",
|
||||
"--source", source,
|
||||
"--output", str(corpus_dir),
|
||||
]
|
||||
if limit:
|
||||
cmd.extend(["--limit", str(limit)])
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print(f"⚠️ Download failed: {result.stderr}")
|
||||
return None
|
||||
|
||||
# Find corpus file after download
|
||||
if (corpus_dir / "corpus.jsonl").exists():
|
||||
corpus_file = corpus_dir / "corpus.jsonl"
|
||||
elif corpus_name == "amazon23" and (corpus_dir / "reviews.jsonl").exists():
|
||||
corpus_file = corpus_dir / "reviews.jsonl"
|
||||
except Exception as e:
|
||||
print(f"⚠️ Error downloading {corpus_name}: {e}")
|
||||
return None
|
||||
|
||||
if corpus_file is None or not corpus_file.exists():
|
||||
print(f"⚠️ Corpus file not found for {corpus_name}")
|
||||
return None
|
||||
|
||||
# Check embeddings
|
||||
emb_file = embeddings_dir / f"{corpus_name}.npy"
|
||||
if not emb_file.exists():
|
||||
print(f"\n🔢 Preparing embeddings for {corpus_name}...")
|
||||
embeddings_dir.mkdir(parents=True, exist_ok=True)
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"scripts/prepare_embeddings.py",
|
||||
"--input", str(corpus_file),
|
||||
"--output", str(emb_file),
|
||||
"--dim", "384",
|
||||
"--seed", "42",
|
||||
]
|
||||
if limit:
|
||||
cmd.extend(["--limit", str(limit)])
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print(f"⚠️ Embedding preparation failed: {result.stderr}")
|
||||
return None
|
||||
|
||||
return corpus_file
|
||||
|
||||
|
||||
def run_benchmarks_for_dataset(
|
||||
corpus_name: str,
|
||||
corpus_file: Path,
|
||||
emb_file: Path,
|
||||
sizes: list[str],
|
||||
ef_values: list[int],
|
||||
M_values: list[int],
|
||||
num_queries: int = 50, # Reduced for faster multi-dataset runs
|
||||
output_dir: Path = Path("benchmarks/results"),
|
||||
) -> Path | None:
|
||||
"""Run benchmarks for a single dataset."""
|
||||
print(f"\n🚀 Running benchmarks for {corpus_name}...")
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"scripts/run_benchmarks.py",
|
||||
"--corpus", corpus_name,
|
||||
"--corpus-file", str(corpus_file),
|
||||
"--emb-file", str(emb_file),
|
||||
"--sizes", *sizes,
|
||||
"--ef", *[str(e) for e in ef_values],
|
||||
"--M", *[str(m) for m in M_values],
|
||||
"--num-queries", str(num_queries),
|
||||
"--output-dir", str(output_dir),
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print(f"⚠️ Benchmark failed for {corpus_name}: {result.stderr}")
|
||||
return None
|
||||
|
||||
# Find the results directory
|
||||
results_dir = output_dir / corpus_name
|
||||
if results_dir.exists():
|
||||
timestamp_dirs = sorted([d for d in results_dir.iterdir() if d.is_dir()], key=lambda x: x.name)
|
||||
if timestamp_dirs:
|
||||
return timestamp_dirs[-1] / "results.json"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run benchmarks across multiple datasets")
|
||||
parser.add_argument(
|
||||
"--datasets",
|
||||
nargs="+",
|
||||
default=["fiqa", "amazon23", "msmarco"],
|
||||
help="Datasets to benchmark"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sizes",
|
||||
nargs="+",
|
||||
default=["10k", "25k", "50k"],
|
||||
help="Corpus sizes (e.g., 10k 25k 50k)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ef",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[50, 100],
|
||||
help="HNSW efSearch values"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--M",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[8, 16],
|
||||
help="HNSW M values"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-queries",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of queries per benchmark"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-download",
|
||||
action="store_true",
|
||||
help="Skip downloading datasets (use existing)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit",
|
||||
type=int,
|
||||
help="Limit documents per dataset (for large datasets)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("benchmarks/results"),
|
||||
help="Output directory"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset sources mapping
|
||||
dataset_sources = {
|
||||
"fiqa": "beir:fiqa",
|
||||
"amazon23": "amazon23",
|
||||
"msmarco": "msmarco",
|
||||
}
|
||||
|
||||
data_dir = Path("data")
|
||||
embeddings_dir = data_dir / "embeddings"
|
||||
embeddings_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
results = {}
|
||||
|
||||
print("=" * 70)
|
||||
print("Multi-Dataset Benchmark Runner")
|
||||
print("=" * 70)
|
||||
print(f"Datasets: {', '.join(args.datasets)}")
|
||||
print(f"Sizes: {', '.join(args.sizes)}")
|
||||
print(f"efSearch: {', '.join(map(str, args.ef))}")
|
||||
print(f"M: {', '.join(map(str, args.M))}")
|
||||
print("=" * 70)
|
||||
|
||||
for corpus_name in args.datasets:
|
||||
if corpus_name not in dataset_sources:
|
||||
print(f"⚠️ Unknown dataset: {corpus_name}, skipping")
|
||||
continue
|
||||
|
||||
source = dataset_sources[corpus_name]
|
||||
limit = args.limit if corpus_name in ["amazon23", "msmarco"] else None
|
||||
|
||||
# Prepare dataset
|
||||
corpus_file = prepare_dataset(
|
||||
source=source,
|
||||
corpus_name=corpus_name,
|
||||
output_dir=data_dir,
|
||||
limit=limit,
|
||||
download=not args.skip_download,
|
||||
)
|
||||
|
||||
if corpus_file is None:
|
||||
print(f"⚠️ Skipping {corpus_name} - preparation failed")
|
||||
continue
|
||||
|
||||
# Check embeddings
|
||||
emb_file = embeddings_dir / f"{corpus_name}.npy"
|
||||
if not emb_file.exists():
|
||||
print(f"⚠️ Embeddings not found for {corpus_name}, skipping")
|
||||
continue
|
||||
|
||||
# Run benchmarks
|
||||
results_file = run_benchmarks_for_dataset(
|
||||
corpus_name=corpus_name,
|
||||
corpus_file=corpus_file,
|
||||
emb_file=emb_file,
|
||||
sizes=args.sizes,
|
||||
ef_values=args.ef,
|
||||
M_values=args.M,
|
||||
num_queries=args.num_queries,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
|
||||
if results_file and results_file.exists():
|
||||
with open(results_file) as f:
|
||||
results[corpus_name] = json.load(f)
|
||||
print(f"✓ {corpus_name} benchmarks completed")
|
||||
else:
|
||||
print(f"⚠️ {corpus_name} benchmarks incomplete")
|
||||
|
||||
# Save combined results
|
||||
if results:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
combined_file = args.output_dir / f"multi_dataset_{timestamp}.json"
|
||||
combined_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(combined_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"\n✓ Combined results saved to {combined_file}")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Multi-dataset benchmarks completed!")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
306
scripts/security_scan.py
Normal file
306
scripts/security_scan.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""Security scanning script using Bandit and pip-audit.
|
||||
|
||||
This script runs security scans to identify vulnerabilities.
|
||||
Note: Requires bandit and pip-audit to be installed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def run_bandit(output_dir: Path) -> bool:
|
||||
"""
|
||||
Run Bandit security scanner.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to save results
|
||||
|
||||
Returns:
|
||||
True if scan completed successfully
|
||||
"""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
json_output = output_dir / "bandit_report.json"
|
||||
txt_output = output_dir / "bandit_report.txt"
|
||||
|
||||
print("Running Bandit security scanner...")
|
||||
print("=" * 80)
|
||||
|
||||
try:
|
||||
# Run Bandit with JSON and text output
|
||||
result = subprocess.run(
|
||||
[
|
||||
sys.executable, "-m", "bandit",
|
||||
"-r", "llmds",
|
||||
"-f", "json",
|
||||
"-o", str(json_output),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
# Also generate text report
|
||||
subprocess.run(
|
||||
[
|
||||
sys.executable, "-m", "bandit",
|
||||
"-r", "llmds",
|
||||
"-f", "txt",
|
||||
"-o", str(txt_output),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
# Parse results
|
||||
if json_output.exists():
|
||||
with open(json_output) as f:
|
||||
bandit_data = json.load(f)
|
||||
|
||||
# Count issues by severity
|
||||
metrics = bandit_data.get("metrics", {})
|
||||
total = metrics.get("_totals", {})
|
||||
|
||||
print(f"\nBandit Results:")
|
||||
print(f" HIGH: {total.get('SEVERITY.HIGH', 0)} issues")
|
||||
print(f" MEDIUM: {total.get('SEVERITY.MEDIUM', 0)} issues")
|
||||
print(f" LOW: {total.get('SEVERITY.LOW', 0)} issues")
|
||||
print(f" Total: {total.get('CONFIDENCE.HIGH', 0)} high confidence issues")
|
||||
|
||||
# List high severity issues
|
||||
high_severity = [
|
||||
issue for issue in bandit_data.get("results", [])
|
||||
if issue.get("issue_severity") == "HIGH"
|
||||
]
|
||||
|
||||
if high_severity:
|
||||
print(f"\n HIGH Severity Issues ({len(high_severity)}):")
|
||||
for issue in high_severity[:10]: # Show first 10
|
||||
print(f" - {issue.get('test_id')}: {issue.get('test_name')}")
|
||||
print(f" File: {issue.get('filename')}:{issue.get('line_number')}")
|
||||
|
||||
print(f"\n Full report: {txt_output}")
|
||||
print(f" JSON report: {json_output}")
|
||||
|
||||
return total.get("SEVERITY.HIGH", 0) == 0
|
||||
else:
|
||||
print(" Warning: Bandit JSON output not found")
|
||||
return False
|
||||
|
||||
except FileNotFoundError:
|
||||
print(" Error: Bandit not installed. Install with: pip install bandit[toml]")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f" Error running Bandit: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def run_pip_audit(output_dir: Path) -> bool:
|
||||
"""
|
||||
Run pip-audit to check for known vulnerabilities in dependencies.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to save results
|
||||
|
||||
Returns:
|
||||
True if no HIGH/CRITICAL vulnerabilities found
|
||||
"""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
json_output = output_dir / "pip_audit_report.json"
|
||||
txt_output = output_dir / "pip_audit_report.txt"
|
||||
|
||||
print("\nRunning pip-audit security scanner...")
|
||||
print("=" * 80)
|
||||
|
||||
try:
|
||||
# Run pip-audit
|
||||
result = subprocess.run(
|
||||
[
|
||||
sys.executable, "-m", "pip_audit",
|
||||
"--format", "json",
|
||||
"--output", str(json_output),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
# Also generate text output
|
||||
subprocess.run(
|
||||
[
|
||||
sys.executable, "-m", "pip_audit",
|
||||
"--format", "text",
|
||||
"--output", str(txt_output),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
# Parse results
|
||||
if json_output.exists():
|
||||
with open(json_output) as f:
|
||||
audit_data = json.load(f)
|
||||
|
||||
vulnerabilities = audit_data.get("vulnerabilities", [])
|
||||
high_critical = [
|
||||
v for v in vulnerabilities
|
||||
if v.get("aliases", [{}])[0].get("severity", "").upper() in ["HIGH", "CRITICAL"]
|
||||
]
|
||||
|
||||
print(f"\npip-audit Results:")
|
||||
print(f" Total vulnerabilities: {len(vulnerabilities)}")
|
||||
print(f" HIGH/CRITICAL: {len(high_critical)}")
|
||||
|
||||
if high_critical:
|
||||
print(f"\n HIGH/CRITICAL Vulnerabilities:")
|
||||
for vuln in high_critical[:10]: # Show first 10
|
||||
package = vuln.get("name", "unknown")
|
||||
severity = vuln.get("aliases", [{}])[0].get("severity", "UNKNOWN")
|
||||
print(f" - {package}: {severity}")
|
||||
if "versions" in vuln:
|
||||
print(f" Affected versions: {vuln['versions']}")
|
||||
|
||||
print(f"\n Full report: {txt_output}")
|
||||
print(f" JSON report: {json_output}")
|
||||
|
||||
return len(high_critical) == 0
|
||||
else:
|
||||
print(" Warning: pip-audit JSON output not found")
|
||||
# Check if there were errors
|
||||
if result.stderr:
|
||||
print(f" Error output: {result.stderr}")
|
||||
return False
|
||||
|
||||
except FileNotFoundError:
|
||||
print(" Error: pip-audit not installed. Install with: pip install pip-audit")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f" Error running pip-audit: {e}")
|
||||
if result.stderr:
|
||||
print(f" Error output: {result.stderr}")
|
||||
return False
|
||||
|
||||
|
||||
def generate_sbom(output_dir: Path) -> bool:
|
||||
"""
|
||||
Generate Software Bill of Materials (SBOM) using pip-audit.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to save SBOM
|
||||
|
||||
Returns:
|
||||
True if SBOM generated successfully
|
||||
"""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
sbom_output = output_dir / "sbom.json"
|
||||
|
||||
print("\nGenerating SBOM (Software Bill of Materials)...")
|
||||
print("=" * 80)
|
||||
|
||||
try:
|
||||
# Try to generate SBOM using pip-audit (if supported)
|
||||
# Note: pip-audit may need additional flags for SBOM generation
|
||||
result = subprocess.run(
|
||||
[
|
||||
sys.executable, "-m", "pip_audit",
|
||||
"--format", "json",
|
||||
"--output", str(sbom_output),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
if sbom_output.exists():
|
||||
print(f" SBOM generated: {sbom_output}")
|
||||
print(" Note: For CycloneDX format, consider using cyclonedx-bom or pip-tools")
|
||||
return True
|
||||
else:
|
||||
print(" Warning: SBOM generation may require additional tools")
|
||||
print(" Consider using: cyclonedx-py or pip-tools for full SBOM")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f" Error generating SBOM: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all security scans."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run security scans")
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("audit/security"),
|
||||
help="Directory for security scan results (default: audit/security)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-bandit",
|
||||
action="store_true",
|
||||
help="Skip Bandit scan",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-pip-audit",
|
||||
action="store_true",
|
||||
help="Skip pip-audit scan",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-sbom",
|
||||
action="store_true",
|
||||
help="Skip SBOM generation",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Security Scanning")
|
||||
print("=" * 80)
|
||||
print(f"Output directory: {args.output_dir}")
|
||||
print()
|
||||
|
||||
results = {}
|
||||
|
||||
# Run Bandit
|
||||
if not args.skip_bandit:
|
||||
results["bandit"] = run_bandit(args.output_dir)
|
||||
else:
|
||||
print("Skipping Bandit scan")
|
||||
|
||||
# Run pip-audit
|
||||
if not args.skip_pip_audit:
|
||||
results["pip_audit"] = run_pip_audit(args.output_dir)
|
||||
else:
|
||||
print("Skipping pip-audit scan")
|
||||
|
||||
# Generate SBOM
|
||||
if not args.skip_sbom:
|
||||
results["sbom"] = generate_sbom(args.output_dir)
|
||||
else:
|
||||
print("Skipping SBOM generation")
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 80)
|
||||
print("Summary")
|
||||
print("=" * 80)
|
||||
|
||||
all_passed = all(results.values())
|
||||
|
||||
for tool, passed in results.items():
|
||||
status = "✓ PASSED" if passed else "✗ FAILED"
|
||||
print(f" {tool}: {status}")
|
||||
|
||||
if all_passed:
|
||||
print("\n✓ All security scans passed!")
|
||||
return 0
|
||||
else:
|
||||
print("\n✗ Some security issues found. Please review reports.")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
Reference in New Issue
Block a user