Files
llm-rag-ds-optimizer/scripts/prepare_embeddings.py

92 lines
2.6 KiB
Python

"""Prepare embeddings for datasets."""
import argparse
import hashlib
import json
import sys
from pathlib import Path

import numpy as np
# Prepend the parent of the directory containing this file (presumably the
# repository root) to the import path, so project packages can be imported
# when this file is run directly as a script.
# NOTE(review): nothing visible in this file imports a project-local module;
# confirm this is still needed.
sys.path.insert(0, str(Path(__file__).parent.parent))
def generate_deterministic_embeddings(
corpus_file: Path,
output_file: Path,
dim: int = 384,
seed: int = 42,
limit: int | None = None,
) -> None:
"""
Generate deterministic embeddings for a corpus.
Args:
corpus_file: Path to corpus JSONL file
output_file: Output .npy file for embeddings
dim: Embedding dimension
seed: Random seed for reproducibility
limit: Optional limit on number of documents
"""
output_file.parent.mkdir(parents=True, exist_ok=True)
rng = np.random.RandomState(seed)
embeddings = []
count = 0
print(f"Generating deterministic embeddings (dim={dim}, seed={seed})...")
with open(corpus_file, "r", encoding="utf-8") as f:
for line in f:
if limit and count >= limit:
break
if line.strip():
doc = json.loads(line)
# Generate deterministic embedding based on document ID
doc_hash = hash(doc["id"]) % (2**31)
rng_local = np.random.RandomState(seed + doc_hash)
# Generate normalized random vector
emb = rng_local.randn(dim).astype(np.float32)
emb = emb / np.linalg.norm(emb)
embeddings.append(emb)
count += 1
if count % 10000 == 0:
print(f"Processed {count} documents...")
embeddings_array = np.stack(embeddings)
np.save(output_file, embeddings_array)
print(f"Saved {len(embeddings)} embeddings to {output_file}")
def load_embeddings(emb_file: Path) -> np.ndarray:
    """
    Read back an embedding matrix stored in NumPy ``.npy`` format.

    Args:
        emb_file: Path of the ``.npy`` file to read.

    Returns:
        The array loaded from ``emb_file`` via ``np.load``.
    """
    matrix = np.load(emb_file)
    return matrix
def main():
    """Command-line entry point: parse CLI options and write the embedding file."""
    cli = argparse.ArgumentParser(description="Prepare embeddings for corpus")
    # Registered in this exact order so --help output is unchanged.
    option_specs = (
        ("--input", {"type": Path, "required": True, "help": "Corpus JSONL file"}),
        ("--output", {"type": Path, "required": True, "help": "Output .npy file"}),
        ("--dim", {"type": int, "default": 384, "help": "Embedding dimension"}),
        ("--seed", {"type": int, "default": 42, "help": "Random seed"}),
        ("--limit", {"type": int, "help": "Limit number of documents"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    generate_deterministic_embeddings(
        opts.input,
        opts.output,
        dim=opts.dim,
        seed=opts.seed,
        limit=opts.limit,
    )
# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()