sheepOp/data.example/__init__.py

"""
Data loading and preprocessing utilities
Includes comprehensive data processor for multiple file types (PDFs, images, code, text, etc.)
"""
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Optional, Iterator
import json
from pathlib import Path
import logging
from tqdm import tqdm
import hashlib
import pickle
import os
import sys
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Note: basicConfig attaches a StreamHandler on stderr; call sites below flush
# sys.stderr explicitly where prompt progress output matters (e.g. long scans).
class TextDataset(Dataset):
"""
Dataset for text data.
"""
def __init__(
self,
texts: List[str],
tokenizer,
max_length: int = 512,
stride: Optional[int] = None,
):
"""
Args:
texts: List of text strings
tokenizer: Tokenizer instance
max_length: Maximum sequence length
stride: Stride for sliding window (if None, no overlap)
"""
self.texts = texts
self.tokenizer = tokenizer
self.max_length = max_length
self.stride = stride if stride is not None else max_length
# Tokenize all texts
self.sequences = self._prepare_sequences()
def _prepare_sequences(self) -> List[torch.Tensor]:
"""Tokenize and chunk sequences."""
sequences = []
for text in self.texts:
# Tokenize text
tokens = self.tokenizer.encode(text)
# Chunk into sequences of max_length
for i in range(0, len(tokens), self.stride):
chunk = tokens[i:i + self.max_length]
# Pad if necessary
if len(chunk) < self.max_length:
chunk = chunk + [self.tokenizer.pad_token_id] * (self.max_length - len(chunk))
sequences.append(torch.tensor(chunk, dtype=torch.long))
# Stop if we've covered the entire sequence
if i + self.max_length >= len(tokens):
break
return sequences
def __len__(self) -> int:
return len(self.sequences)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
sequence = self.sequences[idx]
# Input is all tokens except the last one
input_ids = sequence[:-1]
# Labels are all tokens except the first one (shifted by 1)
labels = sequence[1:]
return {
'input_ids': input_ids,
'labels': labels,
}
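# Example (illustrative sketch, not executed at import time; the texts and
# max_length below are arbitrary placeholder values) using the character-level
# SimpleTokenizer defined further down in this module:
#
#   tokenizer = SimpleTokenizer(use_bpe=False)
#   dataset = TextDataset(["hello world", "sheep say baa"], tokenizer, max_length=32)
#   sample = dataset[0]
#   sample['input_ids'].shape   # torch.Size([31]) -- the sequence minus its last token
#   sample['labels'].shape      # torch.Size([31]) -- the same sequence shifted by one token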
class SimpleTokenizer:
"""
Simple character-level tokenizer (for backward compatibility).
Uses BPE tokenizer by default if available, falls back to character-level.
"""
def __init__(
self,
vocab_file: Optional[str] = None,
use_bpe: bool = True,
vocab_size: int = 50257,
):
"""
Args:
vocab_file: Optional path to vocabulary file
use_bpe: Whether to use BPE tokenizer (default: True)
vocab_size: Vocabulary size for BPE tokenizer (default: 50257)
"""
self.use_bpe = use_bpe
# Try to use BPE tokenizer if available
if use_bpe:
try:
from .bpe_tokenizer import BPETokenizer
self.bpe_tokenizer = BPETokenizer(vocab_size=vocab_size)
self._use_bpe = True
# Map BPE tokenizer attributes
self.pad_token_id = self.bpe_tokenizer.pad_token_id
self.unk_token_id = self.bpe_tokenizer.unk_token_id
self.bos_token_id = self.bpe_tokenizer.bos_token_id
self.eos_token_id = self.bpe_tokenizer.eos_token_id
self.vocab_size = self.bpe_tokenizer.vocab_size
self.vocab = {i: self.bpe_tokenizer.vocab.get(i, bytes([i])).decode('utf-8', errors='replace')
for i in range(256)} # Limited vocab view
self.inv_vocab = {v: k for k, v in self.vocab.items()}
return
except ImportError:
logger.warning("BPE tokenizer not available, falling back to character-level")
self._use_bpe = False
# Fallback to character-level tokenizer
self._use_bpe = False
if vocab_file and Path(vocab_file).exists():
with open(vocab_file, 'r') as f:
vocab = json.load(f)
self.vocab = vocab
self.inv_vocab = {v: k for k, v in vocab.items()}
else:
# Default: character-level vocabulary
self.vocab = {
'<pad>': 0,
'<unk>': 1,
'<bos>': 2,
'<eos>': 3,
}
# Add printable ASCII characters
for i in range(32, 127):
self.vocab[chr(i)] = len(self.vocab)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.pad_token_id = self.vocab.get('<pad>', 0)
self.unk_token_id = self.vocab.get('<unk>', 1)
self.bos_token_id = self.vocab.get('<bos>', 2)
self.eos_token_id = self.vocab.get('<eos>', 3)
self.vocab_size = len(self.vocab)
def encode(self, text: str) -> List[int]:
"""Encode text to token IDs."""
if self._use_bpe:
return self.bpe_tokenizer.encode(text)
# Character-level encoding
tokens = []
for char in text:
tokens.append(self.vocab.get(char, self.unk_token_id))
return tokens
def decode(self, token_ids: List[int]) -> str:
"""Decode token IDs to text."""
if self._use_bpe:
return self.bpe_tokenizer.decode(token_ids)
# Character-level decoding
chars = []
for tid in token_ids:
if tid in self.inv_vocab:
char = self.inv_vocab[tid]
if char not in ['<pad>', '<bos>', '<eos>']:
chars.append(char)
return ''.join(chars)
def save_vocab(self, vocab_file: str):
"""Save vocabulary to file."""
if self._use_bpe:
# Save BPE tokenizer
merges_file = str(vocab_file).replace('.json', '_merges.json')
self.bpe_tokenizer.save(merges_file, vocab_file)
else:
# Save character-level vocab
with open(vocab_file, 'w') as f:
json.dump(self.vocab, f, indent=2)
def train(self, texts: List[str], num_merges: Optional[int] = None, verbose: bool = False):
"""Train the tokenizer on texts (BPE only)."""
if self._use_bpe:
self.bpe_tokenizer.train(texts, num_merges=num_merges, verbose=verbose)
# Update vocab size
self.vocab_size = self.bpe_tokenizer.vocab_size
else:
logger.warning("Training not supported for character-level tokenizer")
def create_dataloader(
texts: List[str],
tokenizer,
batch_size: int = 32,
max_length: int = 512,
shuffle: bool = True,
num_workers: int = 0,
) -> DataLoader:
"""
Create a DataLoader for text data.
Args:
texts: List of text strings
tokenizer: Tokenizer instance
batch_size: Batch size
max_length: Maximum sequence length
shuffle: Whether to shuffle data
num_workers: Number of data loading workers
Returns:
DataLoader instance
"""
dataset = TextDataset(
texts=texts,
tokenizer=tokenizer,
max_length=max_length,
)
def collate_fn(batch):
"""Collate function for batching."""
input_ids = torch.stack([item['input_ids'] for item in batch])
labels = torch.stack([item['labels'] for item in batch])
return {
'input_ids': input_ids,
'labels': labels,
}
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=torch.cuda.is_available(),
)
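# Example (illustrative sketch): batches come out with shape
# (batch_size, max_length - 1) because __getitem__ shifts inputs and labels by one.
#
#   tok = SimpleTokenizer(use_bpe=False)
#   loader = create_dataloader(["some text"] * 8, tok, batch_size=4, max_length=16)
#   batch = next(iter(loader))
#   batch['input_ids'].shape    # torch.Size([4, 15])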
# ============================================================================
# Data Processor for Multiple File Types
# ============================================================================
class DataProcessor:
"""
Process various file types and extract text for training.
Supports: PDFs, images (OCR), code files, text files, and more.
"""
# Supported file extensions
TEXT_EXTENSIONS = {'.txt', '.md', '.rst', '.log', '.csv', '.json', '.jsonl', '.xml', '.html', '.htm'}
CODE_EXTENSIONS = {
'.py', '.js', '.ts', '.jsx', '.tsx', '.java', '.cpp', '.c', '.h', '.hpp',
'.cs', '.go', '.rs', '.rb', '.php', '.swift', '.kt', '.scala', '.r',
'.sql', '.sh', '.bash', '.zsh', '.fish', '.yaml', '.yml', '.toml',
'.ini', '.cfg', '.conf', '.vue', '.svelte', '.dart', '.lua', '.pl',
'.hs', '.ml', '.mli', '.elm', '.ex', '.exs', '.jl', '.clj', '.cljs'
}
PDF_EXTENSIONS = {'.pdf'}
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.tif', '.webp'}
def __init__(self, use_ocr: bool = True, use_pdf_extraction: bool = True, cache_dir: Optional[Path] = None):
"""
Initialize data processor.
Args:
use_ocr: Whether to use OCR for images (requires pytesseract)
use_pdf_extraction: Whether to extract text from PDFs (requires PyPDF2 or pdfplumber)
cache_dir: Directory to store cache files (default: .cache in data directory)
"""
self.use_ocr = use_ocr
self.use_pdf_extraction = use_pdf_extraction
self.cache_dir = Path(cache_dir) if cache_dir else None
self._check_dependencies()
def _get_cache_dir(self, directory: Path) -> Path:
"""Get cache directory for a given data directory."""
if self.cache_dir:
return self.cache_dir
# Default: .cache in the data directory
cache_dir = directory / '.cache'
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir
def _compute_directory_hash(self, directory: Path, recursive: bool = True) -> str:
"""
Compute a hash of directory contents to detect changes.
Uses file paths and modification times.
"""
directory = Path(directory)
file_info = []
pattern = '**/*' if recursive else '*'
scanned_count = 0
try:
for file_path in directory.glob(pattern):
scanned_count += 1
# Progress feedback every 5000 files (hash computation can be slow)
if scanned_count % 5000 == 0:
logger.info(f"Computing directory hash: scanned {scanned_count:,} paths...")
sys.stderr.flush()
try:
if file_path.is_file():
stat = file_path.stat()
file_info.append(f"{file_path.relative_to(directory)}:{stat.st_mtime}:{stat.st_size}")
except (OSError, PermissionError):
continue
        except KeyboardInterrupt:
            # Re-raise so the caller can skip the cache and fall back to a fresh scan
            logger.warning(
                f"Directory hash computation interrupted after scanning {scanned_count:,} paths. "
                "Will skip cache and do a fresh scan."
            )
            raise
# Sort for consistent hashing
file_info.sort()
content = '\n'.join(file_info)
return hashlib.md5(content.encode()).hexdigest()
def _get_cache_path(self, directory: Path, cache_type: str = 'files') -> Path:
"""Get cache file path for a directory."""
cache_dir = self._get_cache_dir(directory)
# Create a safe filename from directory path
dir_hash = hashlib.md5(str(directory.absolute()).encode()).hexdigest()[:8]
return cache_dir / f"{cache_type}_{dir_hash}.pkl"
def _load_cache(self, cache_path: Path) -> Optional[Dict]:
"""Load cache from file."""
if not cache_path.exists():
return None
try:
with open(cache_path, 'rb') as f:
return pickle.load(f)
except Exception as e:
logger.warning(f"Failed to load cache from {cache_path}: {e}")
return None
def _save_cache(self, cache_path: Path, data: Dict):
"""Save cache to file."""
try:
cache_path.parent.mkdir(parents=True, exist_ok=True)
with open(cache_path, 'wb') as f:
pickle.dump(data, f)
except Exception as e:
logger.warning(f"Failed to save cache to {cache_path}: {e}")
def clear_cache(self, directory: Path):
"""
Clear cache for a directory.
Args:
directory: Directory path to clear cache for
"""
cache_path = self._get_cache_path(directory, 'files')
if cache_path.exists():
try:
cache_path.unlink()
logger.info(f"✅ Cleared cache for {directory}")
except Exception as e:
logger.warning(f"Failed to clear cache: {e}")
else:
logger.info(f"No cache found for {directory}")
    def _check_dependencies(self):
        """Check if required dependencies are available."""
        # Default all availability flags so later attribute access is always safe
        self._ocr_available = False
        self._pypdf2_available = False
        self._pdfplumber_available = False
        if self.use_ocr:
            try:
                import pytesseract
                from PIL import Image
                self._ocr_available = True
            except ImportError:
                logger.warning("pytesseract or PIL not available. OCR disabled.")
                self.use_ocr = False
        if self.use_pdf_extraction:
            try:
                import PyPDF2
                self._pypdf2_available = True
            except ImportError:
                try:
                    import pdfplumber
                    self._pdfplumber_available = True
                except ImportError:
                    logger.warning("PyPDF2 or pdfplumber not available. PDF extraction disabled.")
                    self.use_pdf_extraction = False
def process_file(self, file_path: Path) -> Iterator[str]:
"""
Process a single file and yield text lines.
Args:
file_path: Path to the file
Yields:
Text lines extracted from the file
"""
file_path = Path(file_path)
if not file_path.exists():
logger.warning(f"File not found: {file_path}")
return
suffix = file_path.suffix.lower()
try:
if suffix in self.TEXT_EXTENSIONS:
yield from self._process_text_file(file_path)
elif suffix in self.CODE_EXTENSIONS:
yield from self._process_code_file(file_path)
elif suffix in self.PDF_EXTENSIONS:
yield from self._process_pdf(file_path)
elif suffix in self.IMAGE_EXTENSIONS:
yield from self._process_image(file_path)
else:
# Try to process as text file as fallback (many file types can be read as text)
# Only log at debug level to avoid spam
logger.debug(f"Unsupported file type: {file_path} (extension: {suffix}), attempting as text...")
try:
yield from self._process_text_file(file_path)
except KeyboardInterrupt:
raise
except Exception as e:
logger.debug(f"Failed to process {file_path} as text: {e}")
except KeyboardInterrupt:
# Re-raise KeyboardInterrupt to allow graceful shutdown
logger.warning(f"Interrupted while processing {file_path}")
raise
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
    def _process_text_file(self, file_path: Path) -> Iterator[str]:
        """Process a text file."""
        try:
            # Try strict UTF-8 first so the latin-1 fallback below can actually trigger
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if line:
                        yield line
        except KeyboardInterrupt:
            # Re-raise KeyboardInterrupt to allow graceful shutdown
            logger.warning(f"Interrupted while reading {file_path}")
            raise
        except UnicodeDecodeError:
            # Retry with a permissive single-byte encoding
            try:
                with open(file_path, 'r', encoding='latin-1', errors='ignore') as f:
                    for line in f:
                        line = line.strip()
                        if line:
                            yield line
            except KeyboardInterrupt:
                logger.warning(f"Interrupted while reading {file_path}")
                raise
        except Exception as e:
            logger.error(f"Failed to read {file_path}: {e}")
def _process_code_file(self, file_path: Path) -> Iterator[str]:
"""Process a code file."""
# Code files are processed as text, but we can add syntax-aware processing
# For now, just extract text lines
yield from self._process_text_file(file_path)
def _process_pdf(self, file_path: Path) -> Iterator[str]:
"""Extract text from PDF file."""
if not self.use_pdf_extraction:
logger.warning(f"PDF extraction disabled. Skipping {file_path}")
return
try:
if self._pypdf2_available:
import PyPDF2
with open(file_path, 'rb') as f:
pdf_reader = PyPDF2.PdfReader(f)
for page_num, page in enumerate(pdf_reader.pages):
try:
text = page.extract_text()
if text:
# Split into sentences/lines
for line in text.split('\n'):
line = line.strip()
if line and len(line) > 5: # Filter very short lines
yield line
except KeyboardInterrupt:
logger.warning(f"Interrupted while processing PDF page {page_num} from {file_path}")
raise
except Exception as e:
logger.warning(f"Error extracting page {page_num} from {file_path}: {e}")
elif self._pdfplumber_available:
import pdfplumber
with pdfplumber.open(file_path) as pdf:
for page_num, page in enumerate(pdf.pages):
try:
text = page.extract_text()
if text:
for line in text.split('\n'):
line = line.strip()
if line and len(line) > 5:
yield line
except KeyboardInterrupt:
logger.warning(f"Interrupted while processing PDF page {page_num} from {file_path}")
raise
except Exception as e:
logger.warning(f"Error extracting page {page_num} from {file_path}: {e}")
except KeyboardInterrupt:
logger.warning(f"Interrupted while processing PDF {file_path}")
raise
except Exception as e:
logger.error(f"Failed to extract text from PDF {file_path}: {e}")
def _process_image(self, file_path: Path) -> Iterator[str]:
"""Extract text from image using OCR."""
if not self.use_ocr or not self._ocr_available:
logger.warning(f"OCR disabled or unavailable. Skipping {file_path}")
return
try:
import pytesseract
from PIL import Image
# Open and process image
img = Image.open(file_path)
# Perform OCR
text = pytesseract.image_to_string(img)
if text:
# Split into lines
for line in text.split('\n'):
line = line.strip()
if line and len(line) > 3: # Filter very short lines
yield line
except KeyboardInterrupt:
logger.warning(f"Interrupted while processing image {file_path}")
raise
except Exception as e:
logger.error(f"Failed to extract text from image {file_path}: {e}")
def process_directory(
self,
directory: Path,
recursive: bool = True,
include_patterns: Optional[List[str]] = None,
exclude_patterns: Optional[List[str]] = None,
min_length: int = 10,
) -> Iterator[str]:
"""
Process all files in a directory.
Args:
directory: Directory path
recursive: Whether to process subdirectories
include_patterns: Optional list of glob patterns to include
exclude_patterns: Optional list of glob patterns to exclude
min_length: Minimum length for extracted text lines
Yields:
Text lines from all processed files
"""
directory = Path(directory)
if not directory.exists():
logger.error(f"Directory not found: {directory}")
return
# Try to load cached file list
cache_path = self._get_cache_path(directory, 'files')
# Compute directory hash (may be interrupted)
try:
logger.info("Computing directory hash for cache validation...")
logger.info("(This may take a while for large directories. Press Ctrl+C to skip cache and do fresh scan)")
sys.stderr.flush()
current_hash = self._compute_directory_hash(directory, recursive)
cached_data = self._load_cache(cache_path)
except KeyboardInterrupt:
logger.warning("\n⚠️ Directory hash computation interrupted.")
logger.warning(" Skipping cache validation and doing fresh directory scan...")
logger.warning(" (Press Ctrl+C again to stop completely)")
sys.stderr.flush()
current_hash = None # Force cache miss
cached_data = None
# Don't re-raise - allow user to continue with fresh scan
# If they want to stop completely, they can press Ctrl+C again during scanning
files_to_process = []
scanned_count = 0
skipped_count = 0
# Check if cache is valid
if current_hash and cached_data and cached_data.get('hash') == current_hash:
files_to_process = [Path(f) for f in cached_data.get('files', [])]
logger.info(f"✅ Loaded {len(files_to_process):,} files from cache (skipping directory scan)")
else:
# Cache miss or invalid - scan directory
logger.info("Scanning directory (cache miss or invalid)...")
# Collect all supported file extensions
all_supported_extensions = (
self.TEXT_EXTENSIONS |
self.CODE_EXTENSIONS |
self.PDF_EXTENSIONS |
self.IMAGE_EXTENSIONS
)
if recursive:
pattern = '**/*'
else:
pattern = '*'
# Default exclude patterns for common directories that don't contain training data
default_exclude_patterns = [
'**/.git/**',
'**/__pycache__/**',
'**/node_modules/**',
'**/.venv/**',
'**/venv/**',
'**/.env/**',
'**/.pytest_cache/**',
'**/.mypy_cache/**',
'**/.tox/**',
'**/.coverage/**',
'**/dist/**',
'**/build/**',
'**/*.pyc',
'**/.DS_Store',
]
# Merge user exclude patterns with defaults
all_exclude_patterns = default_exclude_patterns.copy()
if exclude_patterns:
# Convert any Path objects to strings
all_exclude_patterns.extend(str(p) if isinstance(p, Path) else p for p in exclude_patterns)
# Ensure all patterns are strings (not Path objects)
all_exclude_patterns = [str(p) for p in all_exclude_patterns]
# Convert include_patterns to strings as well
if include_patterns:
include_patterns = [str(p) if isinstance(p, Path) else p for p in include_patterns]
logger.info(f"Scanning directory: {directory} (recursive={recursive})...")
logger.info("This may take several minutes for large directories. Please wait...")
sys.stderr.flush() # Force flush to show message immediately
try:
for file_path in directory.glob(pattern):
scanned_count += 1
# Progress reporting every 1000 files scanned
if scanned_count % 1000 == 0:
logger.info(f"Scanned {scanned_count:,} paths, found {len(files_to_process):,} files to process...")
sys.stderr.flush() # Force flush to show progress immediately
# Skip if not a file (handles symlinks, directories, etc. gracefully)
try:
if not file_path.is_file():
continue
except (OSError, PermissionError) as e:
# Skip inaccessible files (broken symlinks, permission denied, etc.)
skipped_count += 1
if skipped_count <= 10: # Only log first 10 to avoid spam
logger.debug(f"Skipping inaccessible path: {file_path} ({e})")
continue
# Early filtering by extension to avoid checking unsupported files
suffix = file_path.suffix.lower()
if suffix not in all_supported_extensions:
continue
# Check include/exclude patterns
if include_patterns:
if not any(file_path.match(pattern) for pattern in include_patterns):
continue
if all_exclude_patterns:
if any(file_path.match(pattern) for pattern in all_exclude_patterns):
continue
files_to_process.append(file_path)
except KeyboardInterrupt:
logger.warning(f"Directory scanning interrupted. Found {len(files_to_process)} files so far.")
raise
except Exception as e:
logger.error(f"Error during directory scanning: {e}")
logger.info(f"Continuing with {len(files_to_process)} files found so far...")
if skipped_count > 10:
logger.info(f"Skipped {skipped_count} inaccessible paths")
logger.info(f"Found {len(files_to_process):,} files to process (scanned {scanned_count:,} paths)")
sys.stderr.flush() # Force flush
            # Save file list to cache (only when we have a valid directory hash;
            # a cache saved after an interrupted hash would never validate later)
            if current_hash:
                cache_data = {
                    'hash': current_hash,
                    'files': [str(f.absolute()) for f in files_to_process],
                    'recursive': recursive,
                }
                self._save_cache(cache_path, cache_data)
                logger.info(f"💾 Cached file list ({len(files_to_process):,} files) for future use")
                sys.stderr.flush()  # Force flush
# Process each file
processed_count = 0
skipped_count = 0
error_count = 0
total_lines = 0
total_files = len(files_to_process)
if total_files == 0:
logger.warning("No files found to process!")
return
logger.info(f"Starting to process {total_files} files with progress bar...")
# Create progress bar
pbar = tqdm(
total=total_files,
desc="Processing files",
unit="file",
ncols=120,
mininterval=0.1, # Update at least every 0.1 seconds
maxinterval=1.0, # Force update at least once per second
file=sys.stderr, # Write to stderr to avoid buffering issues
dynamic_ncols=True, # Auto-adjust to terminal width
disable=False, # Explicitly enable
)
try:
for idx, file_path in enumerate(files_to_process, 1):
try:
file_lines = list(self.process_file(file_path))
if file_lines:
processed_count += 1
for line in file_lines:
if len(line) >= min_length:
yield line
total_lines += 1
else:
skipped_count += 1
# Update progress bar with statistics
pbar.set_postfix({
'Processed': processed_count,
'Skipped': skipped_count,
'Errors': error_count,
'Lines': f"{total_lines:,}"
})
pbar.update(1) # Advance progress bar
pbar.refresh() # Force immediate refresh
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
except KeyboardInterrupt:
pbar.close()
logger.warning(
f"Processing interrupted. "
f"Files: {idx}/{total_files}, Processed: {processed_count}, "
f"Skipped: {skipped_count}, Errors: {error_count}, "
f"Lines extracted: {total_lines:,}"
)
raise
except Exception as e:
error_count += 1
logger.error(f"Error processing {file_path}: {e}")
# Update progress bar even on errors
pbar.set_postfix({
'Processed': processed_count,
'Skipped': skipped_count,
'Errors': error_count,
'Lines': f"{total_lines:,}"
})
pbar.update(1) # Advance progress bar even on error
pbar.refresh() # Force immediate refresh
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
finally:
pbar.close()
logger.info(
f"Processing complete: {processed_count}/{total_files} files processed successfully, "
f"{skipped_count} skipped, {error_count} errors, {total_lines:,} lines extracted"
)
def process_to_list(
self,
directory: Path,
recursive: bool = True,
include_patterns: Optional[List[str]] = None,
exclude_patterns: Optional[List[str]] = None,
min_length: int = 10,
max_samples: Optional[int] = None,
) -> List[str]:
"""
Process directory and return list of text lines.
Args:
directory: Directory path
recursive: Whether to process subdirectories
include_patterns: Optional list of glob patterns to include
exclude_patterns: Optional list of glob patterns to exclude
min_length: Minimum length for extracted text lines
max_samples: Maximum number of samples to return (None = all)
Returns:
List of text lines
"""
logger.info(f"Starting data extraction from {directory}...")
logger.info("This may take a while for large directories. Progress will be shown below.")
sys.stderr.flush() # Force flush to show message immediately
texts = []
try:
for text in self.process_directory(
directory=directory,
recursive=recursive,
include_patterns=include_patterns,
exclude_patterns=exclude_patterns,
min_length=min_length,
):
texts.append(text)
if max_samples and len(texts) >= max_samples:
logger.info(f"Reached max_samples limit ({max_samples}). Stopping extraction.")
break
except KeyboardInterrupt:
# Return partial results if interrupted
logger.warning(
f"Data processing interrupted. Returning {len(texts):,} text samples collected so far."
)
# Re-raise to allow caller to handle if needed
raise
logger.info(f"✅ Extracted {len(texts):,} text samples from {directory}")
return texts
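# Example (illustrative sketch; "./corpus" is a placeholder path): stream text
# lines out of a folder of mixed files, skipping OCR.
#
#   processor = DataProcessor(use_ocr=False, use_pdf_extraction=True)
#   for line in processor.process_directory(Path("./corpus"), recursive=True, min_length=20):
#       ...                                    # consume lines lazily
#   processor.clear_cache(Path("./corpus"))    # manually reset the cached file list if needed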
def extract_text_from_directory(
directory: Path,
recursive: bool = True,
use_ocr: bool = True,
use_pdf_extraction: bool = True,
min_length: int = 10,
max_samples: Optional[int] = None,
) -> List[str]:
"""
Convenience function to extract text from a directory.
Args:
directory: Directory path
recursive: Whether to process subdirectories
use_ocr: Whether to use OCR for images
use_pdf_extraction: Whether to extract text from PDFs
min_length: Minimum length for extracted text lines
max_samples: Maximum number of samples to return (None = all)
Returns:
List of text lines
"""
processor = DataProcessor(use_ocr=use_ocr, use_pdf_extraction=use_pdf_extraction)
try:
return processor.process_to_list(
directory=directory,
recursive=recursive,
min_length=min_length,
max_samples=max_samples,
)
except KeyboardInterrupt:
logger.error(
"\n⚠️ Data processing interrupted by user (Ctrl+C).\n"
" No data was loaded. Please run the training command again to retry."
)
# Re-raise to stop training
raise
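# Example (illustrative sketch; the path and hyperparameters are placeholders):
# combine the convenience wrapper with the dataloader utilities above.
#
#   texts = extract_text_from_directory(Path("./corpus"), use_ocr=False, max_samples=100_000)
#   tok = SimpleTokenizer()
#   tok.train(texts)                           # only has an effect when the BPE tokenizer is available
#   loader = create_dataloader(texts, tok, batch_size=32, max_length=512)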
# Try to import BPE tokenizer for direct access
try:
from .bpe_tokenizer import BPETokenizer
__all__ = [
'TextDataset',
'SimpleTokenizer',
'BPETokenizer',
'create_dataloader',
'DataProcessor',
'extract_text_from_directory',
]
except ImportError:
__all__ = [
'TextDataset',
'SimpleTokenizer',
'create_dataloader',
'DataProcessor',
'extract_text_from_directory',
]
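if __name__ == "__main__":
    # Minimal smoke test using only the character-level fallback tokenizer, so it
    # runs without any optional dependencies (BPE module, OCR, PDF libraries).
    _tok = SimpleTokenizer(use_bpe=False)
    _loader = create_dataloader(["hello world"] * 8, _tok, batch_size=4, max_length=16)
    _batch = next(iter(_loader))
    print("input_ids:", _batch['input_ids'].shape)  # expected: torch.Size([4, 15])
    print("labels:   ", _batch['labels'].shape)     # expected: torch.Size([4, 15])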