Initial commit: SheepOp LLM - Transformer-based language model implementation
- Complete transformer implementation from scratch - Training pipeline with gradient accumulation and mixed precision - Optimized inference with KV caching - Multi-format data processing (PDFs, images, code, text) - Comprehensive documentation - Apache 2.0 license - Example training plots included in docs/images/
This commit is contained in:
924
data.example/__init__.py
Normal file
924
data.example/__init__.py
Normal file
@@ -0,0 +1,924 @@
|
||||
"""
|
||||
Data loading and preprocessing utilities
|
||||
Includes comprehensive data processor for multiple file types (PDFs, images, code, text, etc.)
|
||||
"""
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from typing import List, Dict, Optional, Iterator
|
||||
import json
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
import hashlib
|
||||
import pickle
|
||||
import os
|
||||
import sys
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Ensure logging output is unbuffered
|
||||
for handler in logger.handlers:
|
||||
handler.flush()
|
||||
# Also ensure root logger handlers are unbuffered
|
||||
for handler in logging.root.handlers:
|
||||
handler.flush()
|
||||
|
||||
|
||||
class TextDataset(Dataset):
|
||||
"""
|
||||
Dataset for text data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
texts: List[str],
|
||||
tokenizer,
|
||||
max_length: int = 512,
|
||||
stride: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
texts: List of text strings
|
||||
tokenizer: Tokenizer instance
|
||||
max_length: Maximum sequence length
|
||||
stride: Stride for sliding window (if None, no overlap)
|
||||
"""
|
||||
self.texts = texts
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.stride = stride if stride is not None else max_length
|
||||
|
||||
# Tokenize all texts
|
||||
self.sequences = self._prepare_sequences()
|
||||
|
||||
def _prepare_sequences(self) -> List[torch.Tensor]:
|
||||
"""Tokenize and chunk sequences."""
|
||||
sequences = []
|
||||
|
||||
for text in self.texts:
|
||||
# Tokenize text
|
||||
tokens = self.tokenizer.encode(text)
|
||||
|
||||
# Chunk into sequences of max_length
|
||||
for i in range(0, len(tokens), self.stride):
|
||||
chunk = tokens[i:i + self.max_length]
|
||||
|
||||
# Pad if necessary
|
||||
if len(chunk) < self.max_length:
|
||||
chunk = chunk + [self.tokenizer.pad_token_id] * (self.max_length - len(chunk))
|
||||
|
||||
sequences.append(torch.tensor(chunk, dtype=torch.long))
|
||||
|
||||
# Stop if we've covered the entire sequence
|
||||
if i + self.max_length >= len(tokens):
|
||||
break
|
||||
|
||||
return sequences
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.sequences)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
||||
sequence = self.sequences[idx]
|
||||
|
||||
# Input is all tokens except the last one
|
||||
input_ids = sequence[:-1]
|
||||
# Labels are all tokens except the first one (shifted by 1)
|
||||
labels = sequence[1:]
|
||||
|
||||
return {
|
||||
'input_ids': input_ids,
|
||||
'labels': labels,
|
||||
}
|
||||
|
||||
|
||||
class SimpleTokenizer:
|
||||
"""
|
||||
Simple character-level tokenizer (for backward compatibility).
|
||||
Uses BPE tokenizer by default if available, falls back to character-level.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file: Optional[str] = None,
|
||||
use_bpe: bool = True,
|
||||
vocab_size: int = 50257,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_file: Optional path to vocabulary file
|
||||
use_bpe: Whether to use BPE tokenizer (default: True)
|
||||
vocab_size: Vocabulary size for BPE tokenizer (default: 50257)
|
||||
"""
|
||||
self.use_bpe = use_bpe
|
||||
|
||||
# Try to use BPE tokenizer if available
|
||||
if use_bpe:
|
||||
try:
|
||||
from .bpe_tokenizer import BPETokenizer
|
||||
self.bpe_tokenizer = BPETokenizer(vocab_size=vocab_size)
|
||||
self._use_bpe = True
|
||||
|
||||
# Map BPE tokenizer attributes
|
||||
self.pad_token_id = self.bpe_tokenizer.pad_token_id
|
||||
self.unk_token_id = self.bpe_tokenizer.unk_token_id
|
||||
self.bos_token_id = self.bpe_tokenizer.bos_token_id
|
||||
self.eos_token_id = self.bpe_tokenizer.eos_token_id
|
||||
self.vocab_size = self.bpe_tokenizer.vocab_size
|
||||
self.vocab = {i: self.bpe_tokenizer.vocab.get(i, bytes([i])).decode('utf-8', errors='replace')
|
||||
for i in range(256)} # Limited vocab view
|
||||
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||
return
|
||||
except ImportError:
|
||||
logger.warning("BPE tokenizer not available, falling back to character-level")
|
||||
self._use_bpe = False
|
||||
|
||||
# Fallback to character-level tokenizer
|
||||
self._use_bpe = False
|
||||
|
||||
if vocab_file and Path(vocab_file).exists():
|
||||
with open(vocab_file, 'r') as f:
|
||||
vocab = json.load(f)
|
||||
self.vocab = vocab
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
else:
|
||||
# Default: character-level vocabulary
|
||||
self.vocab = {
|
||||
'<pad>': 0,
|
||||
'<unk>': 1,
|
||||
'<bos>': 2,
|
||||
'<eos>': 3,
|
||||
}
|
||||
# Add printable ASCII characters
|
||||
for i in range(32, 127):
|
||||
self.vocab[chr(i)] = len(self.vocab)
|
||||
|
||||
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||
|
||||
self.pad_token_id = self.vocab.get('<pad>', 0)
|
||||
self.unk_token_id = self.vocab.get('<unk>', 1)
|
||||
self.bos_token_id = self.vocab.get('<bos>', 2)
|
||||
self.eos_token_id = self.vocab.get('<eos>', 3)
|
||||
|
||||
self.vocab_size = len(self.vocab)
|
||||
|
||||
def encode(self, text: str) -> List[int]:
|
||||
"""Encode text to token IDs."""
|
||||
if self._use_bpe:
|
||||
return self.bpe_tokenizer.encode(text)
|
||||
|
||||
# Character-level encoding
|
||||
tokens = []
|
||||
for char in text:
|
||||
tokens.append(self.vocab.get(char, self.unk_token_id))
|
||||
return tokens
|
||||
|
||||
def decode(self, token_ids: List[int]) -> str:
|
||||
"""Decode token IDs to text."""
|
||||
if self._use_bpe:
|
||||
return self.bpe_tokenizer.decode(token_ids)
|
||||
|
||||
# Character-level decoding
|
||||
chars = []
|
||||
for tid in token_ids:
|
||||
if tid in self.inv_vocab:
|
||||
char = self.inv_vocab[tid]
|
||||
if char not in ['<pad>', '<bos>', '<eos>']:
|
||||
chars.append(char)
|
||||
return ''.join(chars)
|
||||
|
||||
def save_vocab(self, vocab_file: str):
|
||||
"""Save vocabulary to file."""
|
||||
if self._use_bpe:
|
||||
# Save BPE tokenizer
|
||||
merges_file = str(vocab_file).replace('.json', '_merges.json')
|
||||
self.bpe_tokenizer.save(merges_file, vocab_file)
|
||||
else:
|
||||
# Save character-level vocab
|
||||
with open(vocab_file, 'w') as f:
|
||||
json.dump(self.vocab, f, indent=2)
|
||||
|
||||
def train(self, texts: List[str], num_merges: Optional[int] = None, verbose: bool = False):
|
||||
"""Train the tokenizer on texts (BPE only)."""
|
||||
if self._use_bpe:
|
||||
self.bpe_tokenizer.train(texts, num_merges=num_merges, verbose=verbose)
|
||||
# Update vocab size
|
||||
self.vocab_size = self.bpe_tokenizer.vocab_size
|
||||
else:
|
||||
logger.warning("Training not supported for character-level tokenizer")
|
||||
|
||||
|
||||
def create_dataloader(
|
||||
texts: List[str],
|
||||
tokenizer,
|
||||
batch_size: int = 32,
|
||||
max_length: int = 512,
|
||||
shuffle: bool = True,
|
||||
num_workers: int = 0,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Create a DataLoader for text data.
|
||||
|
||||
Args:
|
||||
texts: List of text strings
|
||||
tokenizer: Tokenizer instance
|
||||
batch_size: Batch size
|
||||
max_length: Maximum sequence length
|
||||
shuffle: Whether to shuffle data
|
||||
num_workers: Number of data loading workers
|
||||
|
||||
Returns:
|
||||
DataLoader instance
|
||||
"""
|
||||
dataset = TextDataset(
|
||||
texts=texts,
|
||||
tokenizer=tokenizer,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
def collate_fn(batch):
|
||||
"""Collate function for batching."""
|
||||
input_ids = torch.stack([item['input_ids'] for item in batch])
|
||||
labels = torch.stack([item['labels'] for item in batch])
|
||||
|
||||
return {
|
||||
'input_ids': input_ids,
|
||||
'labels': labels,
|
||||
}
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Data Processor for Multiple File Types
|
||||
# ============================================================================
|
||||
|
||||
class DataProcessor:
|
||||
"""
|
||||
Process various file types and extract text for training.
|
||||
Supports: PDFs, images (OCR), code files, text files, and more.
|
||||
"""
|
||||
|
||||
# Supported file extensions
|
||||
TEXT_EXTENSIONS = {'.txt', '.md', '.rst', '.log', '.csv', '.json', '.jsonl', '.xml', '.html', '.htm'}
|
||||
CODE_EXTENSIONS = {
|
||||
'.py', '.js', '.ts', '.jsx', '.tsx', '.java', '.cpp', '.c', '.h', '.hpp',
|
||||
'.cs', '.go', '.rs', '.rb', '.php', '.swift', '.kt', '.scala', '.r',
|
||||
'.sql', '.sh', '.bash', '.zsh', '.fish', '.yaml', '.yml', '.toml',
|
||||
'.ini', '.cfg', '.conf', '.vue', '.svelte', '.dart', '.lua', '.pl',
|
||||
'.hs', '.ml', '.mli', '.elm', '.ex', '.exs', '.jl', '.clj', '.cljs'
|
||||
}
|
||||
PDF_EXTENSIONS = {'.pdf'}
|
||||
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.tif', '.webp'}
|
||||
|
||||
def __init__(self, use_ocr: bool = True, use_pdf_extraction: bool = True, cache_dir: Optional[Path] = None):
|
||||
"""
|
||||
Initialize data processor.
|
||||
|
||||
Args:
|
||||
use_ocr: Whether to use OCR for images (requires pytesseract)
|
||||
use_pdf_extraction: Whether to extract text from PDFs (requires PyPDF2 or pdfplumber)
|
||||
cache_dir: Directory to store cache files (default: .cache in data directory)
|
||||
"""
|
||||
self.use_ocr = use_ocr
|
||||
self.use_pdf_extraction = use_pdf_extraction
|
||||
self.cache_dir = Path(cache_dir) if cache_dir else None
|
||||
self._check_dependencies()
|
||||
|
||||
def _get_cache_dir(self, directory: Path) -> Path:
|
||||
"""Get cache directory for a given data directory."""
|
||||
if self.cache_dir:
|
||||
return self.cache_dir
|
||||
# Default: .cache in the data directory
|
||||
cache_dir = directory / '.cache'
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
return cache_dir
|
||||
|
||||
def _compute_directory_hash(self, directory: Path, recursive: bool = True) -> str:
|
||||
"""
|
||||
Compute a hash of directory contents to detect changes.
|
||||
Uses file paths and modification times.
|
||||
"""
|
||||
directory = Path(directory)
|
||||
file_info = []
|
||||
|
||||
pattern = '**/*' if recursive else '*'
|
||||
scanned_count = 0
|
||||
|
||||
try:
|
||||
for file_path in directory.glob(pattern):
|
||||
scanned_count += 1
|
||||
|
||||
# Progress feedback every 5000 files (hash computation can be slow)
|
||||
if scanned_count % 5000 == 0:
|
||||
logger.info(f"Computing directory hash: scanned {scanned_count:,} paths...")
|
||||
sys.stderr.flush()
|
||||
|
||||
try:
|
||||
if file_path.is_file():
|
||||
stat = file_path.stat()
|
||||
file_info.append(f"{file_path.relative_to(directory)}:{stat.st_mtime}:{stat.st_size}")
|
||||
except (OSError, PermissionError):
|
||||
continue
|
||||
except KeyboardInterrupt:
|
||||
logger.warning(f"Directory hash computation interrupted after scanning {scanned_count:,} paths")
|
||||
raise
|
||||
except KeyboardInterrupt:
|
||||
# Re-raise to allow graceful handling upstream
|
||||
logger.warning("Directory hash computation interrupted. Will skip cache and do fresh scan.")
|
||||
raise
|
||||
|
||||
# Sort for consistent hashing
|
||||
file_info.sort()
|
||||
content = '\n'.join(file_info)
|
||||
return hashlib.md5(content.encode()).hexdigest()
|
||||
|
||||
def _get_cache_path(self, directory: Path, cache_type: str = 'files') -> Path:
|
||||
"""Get cache file path for a directory."""
|
||||
cache_dir = self._get_cache_dir(directory)
|
||||
# Create a safe filename from directory path
|
||||
dir_hash = hashlib.md5(str(directory.absolute()).encode()).hexdigest()[:8]
|
||||
return cache_dir / f"{cache_type}_{dir_hash}.pkl"
|
||||
|
||||
def _load_cache(self, cache_path: Path) -> Optional[Dict]:
|
||||
"""Load cache from file."""
|
||||
if not cache_path.exists():
|
||||
return None
|
||||
try:
|
||||
with open(cache_path, 'rb') as f:
|
||||
return pickle.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load cache from {cache_path}: {e}")
|
||||
return None
|
||||
|
||||
def _save_cache(self, cache_path: Path, data: Dict):
|
||||
"""Save cache to file."""
|
||||
try:
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(cache_path, 'wb') as f:
|
||||
pickle.dump(data, f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save cache to {cache_path}: {e}")
|
||||
|
||||
def clear_cache(self, directory: Path):
|
||||
"""
|
||||
Clear cache for a directory.
|
||||
|
||||
Args:
|
||||
directory: Directory path to clear cache for
|
||||
"""
|
||||
cache_path = self._get_cache_path(directory, 'files')
|
||||
if cache_path.exists():
|
||||
try:
|
||||
cache_path.unlink()
|
||||
logger.info(f"✅ Cleared cache for {directory}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clear cache: {e}")
|
||||
else:
|
||||
logger.info(f"No cache found for {directory}")
|
||||
|
||||
def _check_dependencies(self):
|
||||
"""Check if required dependencies are available."""
|
||||
if self.use_ocr:
|
||||
try:
|
||||
import pytesseract
|
||||
from PIL import Image
|
||||
self._ocr_available = True
|
||||
except ImportError:
|
||||
logger.warning("pytesseract or PIL not available. OCR disabled.")
|
||||
self._ocr_available = False
|
||||
self.use_ocr = False
|
||||
|
||||
if self.use_pdf_extraction:
|
||||
try:
|
||||
import PyPDF2
|
||||
self._pypdf2_available = True
|
||||
except ImportError:
|
||||
try:
|
||||
import pdfplumber
|
||||
self._pdfplumber_available = True
|
||||
self._pypdf2_available = False
|
||||
except ImportError:
|
||||
logger.warning("PyPDF2 or pdfplumber not available. PDF extraction disabled.")
|
||||
self._pdfplumber_available = False
|
||||
self.use_pdf_extraction = False
|
||||
|
||||
def process_file(self, file_path: Path) -> Iterator[str]:
|
||||
"""
|
||||
Process a single file and yield text lines.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Yields:
|
||||
Text lines extracted from the file
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
|
||||
if not file_path.exists():
|
||||
logger.warning(f"File not found: {file_path}")
|
||||
return
|
||||
|
||||
suffix = file_path.suffix.lower()
|
||||
|
||||
try:
|
||||
if suffix in self.TEXT_EXTENSIONS:
|
||||
yield from self._process_text_file(file_path)
|
||||
elif suffix in self.CODE_EXTENSIONS:
|
||||
yield from self._process_code_file(file_path)
|
||||
elif suffix in self.PDF_EXTENSIONS:
|
||||
yield from self._process_pdf(file_path)
|
||||
elif suffix in self.IMAGE_EXTENSIONS:
|
||||
yield from self._process_image(file_path)
|
||||
else:
|
||||
# Try to process as text file as fallback (many file types can be read as text)
|
||||
# Only log at debug level to avoid spam
|
||||
logger.debug(f"Unsupported file type: {file_path} (extension: {suffix}), attempting as text...")
|
||||
try:
|
||||
yield from self._process_text_file(file_path)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to process {file_path} as text: {e}")
|
||||
except KeyboardInterrupt:
|
||||
# Re-raise KeyboardInterrupt to allow graceful shutdown
|
||||
logger.warning(f"Interrupted while processing {file_path}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
|
||||
def _process_text_file(self, file_path: Path) -> Iterator[str]:
|
||||
"""Process a text file."""
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
yield line
|
||||
except KeyboardInterrupt:
|
||||
# Re-raise KeyboardInterrupt to allow graceful shutdown
|
||||
logger.warning(f"Interrupted while reading {file_path}")
|
||||
raise
|
||||
except UnicodeDecodeError:
|
||||
# Try with different encoding
|
||||
try:
|
||||
with open(file_path, 'r', encoding='latin-1', errors='ignore') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
yield line
|
||||
except KeyboardInterrupt:
|
||||
logger.warning(f"Interrupted while reading {file_path}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read {file_path}: {e}")
|
||||
|
||||
def _process_code_file(self, file_path: Path) -> Iterator[str]:
|
||||
"""Process a code file."""
|
||||
# Code files are processed as text, but we can add syntax-aware processing
|
||||
# For now, just extract text lines
|
||||
yield from self._process_text_file(file_path)
|
||||
|
||||
def _process_pdf(self, file_path: Path) -> Iterator[str]:
|
||||
"""Extract text from PDF file."""
|
||||
if not self.use_pdf_extraction:
|
||||
logger.warning(f"PDF extraction disabled. Skipping {file_path}")
|
||||
return
|
||||
|
||||
try:
|
||||
if self._pypdf2_available:
|
||||
import PyPDF2
|
||||
with open(file_path, 'rb') as f:
|
||||
pdf_reader = PyPDF2.PdfReader(f)
|
||||
for page_num, page in enumerate(pdf_reader.pages):
|
||||
try:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
# Split into sentences/lines
|
||||
for line in text.split('\n'):
|
||||
line = line.strip()
|
||||
if line and len(line) > 5: # Filter very short lines
|
||||
yield line
|
||||
except KeyboardInterrupt:
|
||||
logger.warning(f"Interrupted while processing PDF page {page_num} from {file_path}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Error extracting page {page_num} from {file_path}: {e}")
|
||||
|
||||
elif self._pdfplumber_available:
|
||||
import pdfplumber
|
||||
with pdfplumber.open(file_path) as pdf:
|
||||
for page_num, page in enumerate(pdf.pages):
|
||||
try:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
for line in text.split('\n'):
|
||||
line = line.strip()
|
||||
if line and len(line) > 5:
|
||||
yield line
|
||||
except KeyboardInterrupt:
|
||||
logger.warning(f"Interrupted while processing PDF page {page_num} from {file_path}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Error extracting page {page_num} from {file_path}: {e}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.warning(f"Interrupted while processing PDF {file_path}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract text from PDF {file_path}: {e}")
|
||||
|
||||
def _process_image(self, file_path: Path) -> Iterator[str]:
|
||||
"""Extract text from image using OCR."""
|
||||
if not self.use_ocr or not self._ocr_available:
|
||||
logger.warning(f"OCR disabled or unavailable. Skipping {file_path}")
|
||||
return
|
||||
|
||||
try:
|
||||
import pytesseract
|
||||
from PIL import Image
|
||||
|
||||
# Open and process image
|
||||
img = Image.open(file_path)
|
||||
|
||||
# Perform OCR
|
||||
text = pytesseract.image_to_string(img)
|
||||
|
||||
if text:
|
||||
# Split into lines
|
||||
for line in text.split('\n'):
|
||||
line = line.strip()
|
||||
if line and len(line) > 3: # Filter very short lines
|
||||
yield line
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.warning(f"Interrupted while processing image {file_path}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract text from image {file_path}: {e}")
|
||||
|
||||
def process_directory(
|
||||
self,
|
||||
directory: Path,
|
||||
recursive: bool = True,
|
||||
include_patterns: Optional[List[str]] = None,
|
||||
exclude_patterns: Optional[List[str]] = None,
|
||||
min_length: int = 10,
|
||||
) -> Iterator[str]:
|
||||
"""
|
||||
Process all files in a directory.
|
||||
|
||||
Args:
|
||||
directory: Directory path
|
||||
recursive: Whether to process subdirectories
|
||||
include_patterns: Optional list of glob patterns to include
|
||||
exclude_patterns: Optional list of glob patterns to exclude
|
||||
min_length: Minimum length for extracted text lines
|
||||
|
||||
Yields:
|
||||
Text lines from all processed files
|
||||
"""
|
||||
directory = Path(directory)
|
||||
|
||||
if not directory.exists():
|
||||
logger.error(f"Directory not found: {directory}")
|
||||
return
|
||||
|
||||
# Try to load cached file list
|
||||
cache_path = self._get_cache_path(directory, 'files')
|
||||
|
||||
# Compute directory hash (may be interrupted)
|
||||
try:
|
||||
logger.info("Computing directory hash for cache validation...")
|
||||
logger.info("(This may take a while for large directories. Press Ctrl+C to skip cache and do fresh scan)")
|
||||
sys.stderr.flush()
|
||||
current_hash = self._compute_directory_hash(directory, recursive)
|
||||
cached_data = self._load_cache(cache_path)
|
||||
except KeyboardInterrupt:
|
||||
logger.warning("\n⚠️ Directory hash computation interrupted.")
|
||||
logger.warning(" Skipping cache validation and doing fresh directory scan...")
|
||||
logger.warning(" (Press Ctrl+C again to stop completely)")
|
||||
sys.stderr.flush()
|
||||
current_hash = None # Force cache miss
|
||||
cached_data = None
|
||||
# Don't re-raise - allow user to continue with fresh scan
|
||||
# If they want to stop completely, they can press Ctrl+C again during scanning
|
||||
|
||||
files_to_process = []
|
||||
scanned_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
# Check if cache is valid
|
||||
if current_hash and cached_data and cached_data.get('hash') == current_hash:
|
||||
files_to_process = [Path(f) for f in cached_data.get('files', [])]
|
||||
logger.info(f"✅ Loaded {len(files_to_process):,} files from cache (skipping directory scan)")
|
||||
else:
|
||||
# Cache miss or invalid - scan directory
|
||||
logger.info("Scanning directory (cache miss or invalid)...")
|
||||
|
||||
# Collect all supported file extensions
|
||||
all_supported_extensions = (
|
||||
self.TEXT_EXTENSIONS |
|
||||
self.CODE_EXTENSIONS |
|
||||
self.PDF_EXTENSIONS |
|
||||
self.IMAGE_EXTENSIONS
|
||||
)
|
||||
|
||||
if recursive:
|
||||
pattern = '**/*'
|
||||
else:
|
||||
pattern = '*'
|
||||
|
||||
# Default exclude patterns for common directories that don't contain training data
|
||||
default_exclude_patterns = [
|
||||
'**/.git/**',
|
||||
'**/__pycache__/**',
|
||||
'**/node_modules/**',
|
||||
'**/.venv/**',
|
||||
'**/venv/**',
|
||||
'**/.env/**',
|
||||
'**/.pytest_cache/**',
|
||||
'**/.mypy_cache/**',
|
||||
'**/.tox/**',
|
||||
'**/.coverage/**',
|
||||
'**/dist/**',
|
||||
'**/build/**',
|
||||
'**/*.pyc',
|
||||
'**/.DS_Store',
|
||||
]
|
||||
|
||||
# Merge user exclude patterns with defaults
|
||||
all_exclude_patterns = default_exclude_patterns.copy()
|
||||
if exclude_patterns:
|
||||
# Convert any Path objects to strings
|
||||
all_exclude_patterns.extend(str(p) if isinstance(p, Path) else p for p in exclude_patterns)
|
||||
|
||||
# Ensure all patterns are strings (not Path objects)
|
||||
all_exclude_patterns = [str(p) for p in all_exclude_patterns]
|
||||
|
||||
# Convert include_patterns to strings as well
|
||||
if include_patterns:
|
||||
include_patterns = [str(p) if isinstance(p, Path) else p for p in include_patterns]
|
||||
|
||||
logger.info(f"Scanning directory: {directory} (recursive={recursive})...")
|
||||
logger.info("This may take several minutes for large directories. Please wait...")
|
||||
sys.stderr.flush() # Force flush to show message immediately
|
||||
|
||||
try:
|
||||
for file_path in directory.glob(pattern):
|
||||
scanned_count += 1
|
||||
|
||||
# Progress reporting every 1000 files scanned
|
||||
if scanned_count % 1000 == 0:
|
||||
logger.info(f"Scanned {scanned_count:,} paths, found {len(files_to_process):,} files to process...")
|
||||
sys.stderr.flush() # Force flush to show progress immediately
|
||||
|
||||
# Skip if not a file (handles symlinks, directories, etc. gracefully)
|
||||
try:
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
except (OSError, PermissionError) as e:
|
||||
# Skip inaccessible files (broken symlinks, permission denied, etc.)
|
||||
skipped_count += 1
|
||||
if skipped_count <= 10: # Only log first 10 to avoid spam
|
||||
logger.debug(f"Skipping inaccessible path: {file_path} ({e})")
|
||||
continue
|
||||
|
||||
# Early filtering by extension to avoid checking unsupported files
|
||||
suffix = file_path.suffix.lower()
|
||||
if suffix not in all_supported_extensions:
|
||||
continue
|
||||
|
||||
# Check include/exclude patterns
|
||||
if include_patterns:
|
||||
if not any(file_path.match(pattern) for pattern in include_patterns):
|
||||
continue
|
||||
|
||||
if all_exclude_patterns:
|
||||
if any(file_path.match(pattern) for pattern in all_exclude_patterns):
|
||||
continue
|
||||
|
||||
files_to_process.append(file_path)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.warning(f"Directory scanning interrupted. Found {len(files_to_process)} files so far.")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error during directory scanning: {e}")
|
||||
logger.info(f"Continuing with {len(files_to_process)} files found so far...")
|
||||
|
||||
if skipped_count > 10:
|
||||
logger.info(f"Skipped {skipped_count} inaccessible paths")
|
||||
|
||||
logger.info(f"Found {len(files_to_process):,} files to process (scanned {scanned_count:,} paths)")
|
||||
sys.stderr.flush() # Force flush
|
||||
|
||||
# Save file list to cache
|
||||
cache_data = {
|
||||
'hash': current_hash,
|
||||
'files': [str(f.absolute()) for f in files_to_process],
|
||||
'recursive': recursive,
|
||||
}
|
||||
self._save_cache(cache_path, cache_data)
|
||||
logger.info(f"💾 Cached file list ({len(files_to_process):,} files) for future use")
|
||||
sys.stderr.flush() # Force flush
|
||||
|
||||
# Process each file
|
||||
processed_count = 0
|
||||
skipped_count = 0
|
||||
error_count = 0
|
||||
total_lines = 0
|
||||
total_files = len(files_to_process)
|
||||
|
||||
if total_files == 0:
|
||||
logger.warning("No files found to process!")
|
||||
return
|
||||
|
||||
logger.info(f"Starting to process {total_files} files with progress bar...")
|
||||
|
||||
# Create progress bar
|
||||
pbar = tqdm(
|
||||
total=total_files,
|
||||
desc="Processing files",
|
||||
unit="file",
|
||||
ncols=120,
|
||||
mininterval=0.1, # Update at least every 0.1 seconds
|
||||
maxinterval=1.0, # Force update at least once per second
|
||||
file=sys.stderr, # Write to stderr to avoid buffering issues
|
||||
dynamic_ncols=True, # Auto-adjust to terminal width
|
||||
disable=False, # Explicitly enable
|
||||
)
|
||||
|
||||
try:
|
||||
for idx, file_path in enumerate(files_to_process, 1):
|
||||
try:
|
||||
file_lines = list(self.process_file(file_path))
|
||||
if file_lines:
|
||||
processed_count += 1
|
||||
for line in file_lines:
|
||||
if len(line) >= min_length:
|
||||
yield line
|
||||
total_lines += 1
|
||||
else:
|
||||
skipped_count += 1
|
||||
|
||||
# Update progress bar with statistics
|
||||
pbar.set_postfix({
|
||||
'Processed': processed_count,
|
||||
'Skipped': skipped_count,
|
||||
'Errors': error_count,
|
||||
'Lines': f"{total_lines:,}"
|
||||
})
|
||||
pbar.update(1) # Advance progress bar
|
||||
pbar.refresh() # Force immediate refresh
|
||||
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pbar.close()
|
||||
logger.warning(
|
||||
f"Processing interrupted. "
|
||||
f"Files: {idx}/{total_files}, Processed: {processed_count}, "
|
||||
f"Skipped: {skipped_count}, Errors: {error_count}, "
|
||||
f"Lines extracted: {total_lines:,}"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
# Update progress bar even on errors
|
||||
pbar.set_postfix({
|
||||
'Processed': processed_count,
|
||||
'Skipped': skipped_count,
|
||||
'Errors': error_count,
|
||||
'Lines': f"{total_lines:,}"
|
||||
})
|
||||
pbar.update(1) # Advance progress bar even on error
|
||||
pbar.refresh() # Force immediate refresh
|
||||
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
|
||||
finally:
|
||||
pbar.close()
|
||||
|
||||
logger.info(
|
||||
f"Processing complete: {processed_count}/{total_files} files processed successfully, "
|
||||
f"{skipped_count} skipped, {error_count} errors, {total_lines:,} lines extracted"
|
||||
)
|
||||
|
||||
def process_to_list(
|
||||
self,
|
||||
directory: Path,
|
||||
recursive: bool = True,
|
||||
include_patterns: Optional[List[str]] = None,
|
||||
exclude_patterns: Optional[List[str]] = None,
|
||||
min_length: int = 10,
|
||||
max_samples: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Process directory and return list of text lines.
|
||||
|
||||
Args:
|
||||
directory: Directory path
|
||||
recursive: Whether to process subdirectories
|
||||
include_patterns: Optional list of glob patterns to include
|
||||
exclude_patterns: Optional list of glob patterns to exclude
|
||||
min_length: Minimum length for extracted text lines
|
||||
max_samples: Maximum number of samples to return (None = all)
|
||||
|
||||
Returns:
|
||||
List of text lines
|
||||
"""
|
||||
logger.info(f"Starting data extraction from {directory}...")
|
||||
logger.info("This may take a while for large directories. Progress will be shown below.")
|
||||
sys.stderr.flush() # Force flush to show message immediately
|
||||
|
||||
texts = []
|
||||
try:
|
||||
for text in self.process_directory(
|
||||
directory=directory,
|
||||
recursive=recursive,
|
||||
include_patterns=include_patterns,
|
||||
exclude_patterns=exclude_patterns,
|
||||
min_length=min_length,
|
||||
):
|
||||
texts.append(text)
|
||||
if max_samples and len(texts) >= max_samples:
|
||||
logger.info(f"Reached max_samples limit ({max_samples}). Stopping extraction.")
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
# Return partial results if interrupted
|
||||
logger.warning(
|
||||
f"Data processing interrupted. Returning {len(texts):,} text samples collected so far."
|
||||
)
|
||||
# Re-raise to allow caller to handle if needed
|
||||
raise
|
||||
|
||||
logger.info(f"✅ Extracted {len(texts):,} text samples from {directory}")
|
||||
return texts
|
||||
|
||||
|
||||
def extract_text_from_directory(
|
||||
directory: Path,
|
||||
recursive: bool = True,
|
||||
use_ocr: bool = True,
|
||||
use_pdf_extraction: bool = True,
|
||||
min_length: int = 10,
|
||||
max_samples: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Convenience function to extract text from a directory.
|
||||
|
||||
Args:
|
||||
directory: Directory path
|
||||
recursive: Whether to process subdirectories
|
||||
use_ocr: Whether to use OCR for images
|
||||
use_pdf_extraction: Whether to extract text from PDFs
|
||||
min_length: Minimum length for extracted text lines
|
||||
max_samples: Maximum number of samples to return (None = all)
|
||||
|
||||
Returns:
|
||||
List of text lines
|
||||
"""
|
||||
processor = DataProcessor(use_ocr=use_ocr, use_pdf_extraction=use_pdf_extraction)
|
||||
try:
|
||||
return processor.process_to_list(
|
||||
directory=directory,
|
||||
recursive=recursive,
|
||||
min_length=min_length,
|
||||
max_samples=max_samples,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
logger.error(
|
||||
"\n⚠️ Data processing interrupted by user (Ctrl+C).\n"
|
||||
" No data was loaded. Please run the training command again to retry."
|
||||
)
|
||||
# Re-raise to stop training
|
||||
raise
|
||||
|
||||
|
||||
# Try to import BPE tokenizer for direct access
|
||||
try:
|
||||
from .bpe_tokenizer import BPETokenizer
|
||||
__all__ = [
|
||||
'TextDataset',
|
||||
'SimpleTokenizer',
|
||||
'BPETokenizer',
|
||||
'create_dataloader',
|
||||
'DataProcessor',
|
||||
'extract_text_from_directory',
|
||||
]
|
||||
except ImportError:
|
||||
__all__ = [
|
||||
'TextDataset',
|
||||
'SimpleTokenizer',
|
||||
'create_dataloader',
|
||||
'DataProcessor',
|
||||
'extract_text_from_directory',
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user