Files
llm-rag-ds-optimizer/scripts/download_corpus.py

74 lines
2.2 KiB
Python

"""Download and prepare datasets."""
import argparse
import sys
from pathlib import Path
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from llmds.data_sources.msmarco import download_msmarco
from llmds.data_sources.beir_loader import download_beir
from llmds.data_sources.amazon_reviews import download_amazon_reviews
from llmds.data_sources.yelp import download_yelp
from llmds.data_sources.wikipedia import download_wikipedia
from llmds.data_sources.commoncrawl import download_commoncrawl
def main():
parser = argparse.ArgumentParser(description="Download datasets")
parser.add_argument(
"--source",
required=True,
help="Dataset source: msmarco, beir:task (e.g., beir:fiqa), amazon23, yelp, wikipedia, commoncrawl"
)
parser.add_argument(
"--output",
type=Path,
required=True,
help="Output directory for corpus"
)
parser.add_argument(
"--limit",
type=int,
help="Limit number of documents"
)
parser.add_argument(
"--cc-month",
type=str,
help="Common Crawl month (e.g., 'CC-MAIN-2025-14')"
)
args = parser.parse_args()
# Parse source (handle beir:task format)
source_parts = args.source.split(":", 1)
source_base = source_parts[0]
task = source_parts[1] if len(source_parts) > 1 else None
if source_base == "msmarco":
download_msmarco(args.output)
elif source_base == "beir":
if not task:
print("Error: BEIR requires task name (e.g., 'beir:fiqa', 'beir:scidocs')")
sys.exit(1)
download_beir(task, args.output)
elif source_base == "amazon23":
download_amazon_reviews(args.output, limit=args.limit)
elif source_base == "yelp":
download_yelp(args.output)
elif source_base == "wikipedia":
download_wikipedia(args.output)
elif source_base == "commoncrawl":
download_commoncrawl(args.output, cc_month=args.cc_month, limit=args.limit)
else:
print(f"Error: Unknown source '{source_base}'. Use: msmarco, beir:task, amazon23, yelp, wikipedia, commoncrawl")
sys.exit(1)
print(f"✓ Dataset downloaded to {args.output}")
if __name__ == "__main__":
main()