Initial commit: SheepOp LLM - Transformer-based language model implementation
- Complete transformer implementation from scratch - Training pipeline with gradient accumulation and mixed precision - Optimized inference with KV caching - Multi-format data processing (PDFs, images, code, text) - Comprehensive documentation - Apache 2.0 license - Example training plots included in docs/images/
This commit is contained in:
924
data.example/__init__.py
Normal file
924
data.example/__init__.py
Normal file
@@ -0,0 +1,924 @@
|
||||
"""
|
||||
Data loading and preprocessing utilities
|
||||
Includes comprehensive data processor for multiple file types (PDFs, images, code, text, etc.)
|
||||
"""
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from typing import List, Dict, Optional, Iterator
|
||||
import json
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
import hashlib
|
||||
import pickle
|
||||
import os
|
||||
import sys
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Ensure logging output is unbuffered
|
||||
for handler in logger.handlers:
|
||||
handler.flush()
|
||||
# Also ensure root logger handlers are unbuffered
|
||||
for handler in logging.root.handlers:
|
||||
handler.flush()
|
||||
|
||||
|
||||
class TextDataset(Dataset):
|
||||
"""
|
||||
Dataset for text data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
texts: List[str],
|
||||
tokenizer,
|
||||
max_length: int = 512,
|
||||
stride: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
texts: List of text strings
|
||||
tokenizer: Tokenizer instance
|
||||
max_length: Maximum sequence length
|
||||
stride: Stride for sliding window (if None, no overlap)
|
||||
"""
|
||||
self.texts = texts
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.stride = stride if stride is not None else max_length
|
||||
|
||||
# Tokenize all texts
|
||||
self.sequences = self._prepare_sequences()
|
||||
|
||||
def _prepare_sequences(self) -> List[torch.Tensor]:
|
||||
"""Tokenize and chunk sequences."""
|
||||
sequences = []
|
||||
|
||||
for text in self.texts:
|
||||
# Tokenize text
|
||||
tokens = self.tokenizer.encode(text)
|
||||
|
||||
# Chunk into sequences of max_length
|
||||
for i in range(0, len(tokens), self.stride):
|
||||
chunk = tokens[i:i + self.max_length]
|
||||
|
||||
# Pad if necessary
|
||||
if len(chunk) < self.max_length:
|
||||
chunk = chunk + [self.tokenizer.pad_token_id] * (self.max_length - len(chunk))
|
||||
|
||||
sequences.append(torch.tensor(chunk, dtype=torch.long))
|
||||
|
||||
# Stop if we've covered the entire sequence
|
||||
if i + self.max_length >= len(tokens):
|
||||
break
|
||||
|
||||
return sequences
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.sequences)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
||||
sequence = self.sequences[idx]
|
||||
|
||||
# Input is all tokens except the last one
|
||||
input_ids = sequence[:-1]
|
||||
# Labels are all tokens except the first one (shifted by 1)
|
||||
labels = sequence[1:]
|
||||
|
||||
return {
|
||||
'input_ids': input_ids,
|
||||
'labels': labels,
|
||||
}
|
||||
|
||||
|
||||
class SimpleTokenizer:
|
||||
"""
|
||||
Simple character-level tokenizer (for backward compatibility).
|
||||
Uses BPE tokenizer by default if available, falls back to character-level.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file: Optional[str] = None,
|
||||
use_bpe: bool = True,
|
||||
vocab_size: int = 50257,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_file: Optional path to vocabulary file
|
||||
use_bpe: Whether to use BPE tokenizer (default: True)
|
||||
vocab_size: Vocabulary size for BPE tokenizer (default: 50257)
|
||||
"""
|
||||
self.use_bpe = use_bpe
|
||||
|
||||
# Try to use BPE tokenizer if available
|
||||
if use_bpe:
|
||||
try:
|
||||
from .bpe_tokenizer import BPETokenizer
|
||||
self.bpe_tokenizer = BPETokenizer(vocab_size=vocab_size)
|
||||
self._use_bpe = True
|
||||
|
||||
# Map BPE tokenizer attributes
|
||||
self.pad_token_id = self.bpe_tokenizer.pad_token_id
|
||||
self.unk_token_id = self.bpe_tokenizer.unk_token_id
|
||||
self.bos_token_id = self.bpe_tokenizer.bos_token_id
|
||||
self.eos_token_id = self.bpe_tokenizer.eos_token_id
|
||||
self.vocab_size = self.bpe_tokenizer.vocab_size
|
||||
self.vocab = {i: self.bpe_tokenizer.vocab.get(i, bytes([i])).decode('utf-8', errors='replace')
|
||||
for i in range(256)} # Limited vocab view
|
||||
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||
return
|
||||
except ImportError:
|
||||
logger.warning("BPE tokenizer not available, falling back to character-level")
|
||||
self._use_bpe = False
|
||||
|
||||
# Fallback to character-level tokenizer
|
||||
self._use_bpe = False
|
||||
|
||||
if vocab_file and Path(vocab_file).exists():
|
||||
with open(vocab_file, 'r') as f:
|
||||
vocab = json.load(f)
|
||||
self.vocab = vocab
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
else:
|
||||
# Default: character-level vocabulary
|
||||
self.vocab = {
|
||||
'<pad>': 0,
|
||||
'<unk>': 1,
|
||||
'<bos>': 2,
|
||||
'<eos>': 3,
|
||||
}
|
||||
# Add printable ASCII characters
|
||||
for i in range(32, 127):
|
||||
self.vocab[chr(i)] = len(self.vocab)
|
||||
|
||||
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||
|
||||
self.pad_token_id = self.vocab.get('<pad>', 0)
|
||||
self.unk_token_id = self.vocab.get('<unk>', 1)
|
||||
self.bos_token_id = self.vocab.get('<bos>', 2)
|
||||
self.eos_token_id = self.vocab.get('<eos>', 3)
|
||||
|
||||
self.vocab_size = len(self.vocab)
|
||||
|
||||
def encode(self, text: str) -> List[int]:
|
||||
"""Encode text to token IDs."""
|
||||
if self._use_bpe:
|
||||
return self.bpe_tokenizer.encode(text)
|
||||
|
||||
# Character-level encoding
|
||||
tokens = []
|
||||
for char in text:
|
||||
tokens.append(self.vocab.get(char, self.unk_token_id))
|
||||
return tokens
|
||||
|
||||
def decode(self, token_ids: List[int]) -> str:
|
||||
"""Decode token IDs to text."""
|
||||
if self._use_bpe:
|
||||
return self.bpe_tokenizer.decode(token_ids)
|
||||
|
||||
# Character-level decoding
|
||||
chars = []
|
||||
for tid in token_ids:
|
||||
if tid in self.inv_vocab:
|
||||
char = self.inv_vocab[tid]
|
||||
if char not in ['<pad>', '<bos>', '<eos>']:
|
||||
chars.append(char)
|
||||
return ''.join(chars)
|
||||
|
||||
def save_vocab(self, vocab_file: str):
|
||||
"""Save vocabulary to file."""
|
||||
if self._use_bpe:
|
||||
# Save BPE tokenizer
|
||||
merges_file = str(vocab_file).replace('.json', '_merges.json')
|
||||
self.bpe_tokenizer.save(merges_file, vocab_file)
|
||||
else:
|
||||
# Save character-level vocab
|
||||
with open(vocab_file, 'w') as f:
|
||||
json.dump(self.vocab, f, indent=2)
|
||||
|
||||
def train(self, texts: List[str], num_merges: Optional[int] = None, verbose: bool = False):
|
||||
"""Train the tokenizer on texts (BPE only)."""
|
||||
if self._use_bpe:
|
||||
self.bpe_tokenizer.train(texts, num_merges=num_merges, verbose=verbose)
|
||||
# Update vocab size
|
||||
self.vocab_size = self.bpe_tokenizer.vocab_size
|
||||
else:
|
||||
logger.warning("Training not supported for character-level tokenizer")
|
||||
|
||||
|
||||
def create_dataloader(
|
||||
texts: List[str],
|
||||
tokenizer,
|
||||
batch_size: int = 32,
|
||||
max_length: int = 512,
|
||||
shuffle: bool = True,
|
||||
num_workers: int = 0,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Create a DataLoader for text data.
|
||||
|
||||
Args:
|
||||
texts: List of text strings
|
||||
tokenizer: Tokenizer instance
|
||||
batch_size: Batch size
|
||||
max_length: Maximum sequence length
|
||||
shuffle: Whether to shuffle data
|
||||
num_workers: Number of data loading workers
|
||||
|
||||
Returns:
|
||||
DataLoader instance
|
||||
"""
|
||||
dataset = TextDataset(
|
||||
texts=texts,
|
||||
tokenizer=tokenizer,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
def collate_fn(batch):
|
||||
"""Collate function for batching."""
|
||||
input_ids = torch.stack([item['input_ids'] for item in batch])
|
||||
labels = torch.stack([item['labels'] for item in batch])
|
||||
|
||||
return {
|
||||
'input_ids': input_ids,
|
||||
'labels': labels,
|
||||
}
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Data Processor for Multiple File Types
|
||||
# ============================================================================
|
||||
|
||||
class DataProcessor:
|
||||
"""
|
||||
Process various file types and extract text for training.
|
||||
Supports: PDFs, images (OCR), code files, text files, and more.
|
||||
"""
|
||||
|
||||
# Supported file extensions
|
||||
TEXT_EXTENSIONS = {'.txt', '.md', '.rst', '.log', '.csv', '.json', '.jsonl', '.xml', '.html', '.htm'}
|
||||
CODE_EXTENSIONS = {
|
||||
'.py', '.js', '.ts', '.jsx', '.tsx', '.java', '.cpp', '.c', '.h', '.hpp',
|
||||
'.cs', '.go', '.rs', '.rb', '.php', '.swift', '.kt', '.scala', '.r',
|
||||
'.sql', '.sh', '.bash', '.zsh', '.fish', '.yaml', '.yml', '.toml',
|
||||
'.ini', '.cfg', '.conf', '.vue', '.svelte', '.dart', '.lua', '.pl',
|
||||
'.hs', '.ml', '.mli', '.elm', '.ex', '.exs', '.jl', '.clj', '.cljs'
|
||||
}
|
||||
PDF_EXTENSIONS = {'.pdf'}
|
||||
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.tif', '.webp'}
|
||||
|
||||
def __init__(self, use_ocr: bool = True, use_pdf_extraction: bool = True, cache_dir: Optional[Path] = None):
|
||||
"""
|
||||
Initialize data processor.
|
||||
|
||||
Args:
|
||||
use_ocr: Whether to use OCR for images (requires pytesseract)
|
||||
use_pdf_extraction: Whether to extract text from PDFs (requires PyPDF2 or pdfplumber)
|
||||
cache_dir: Directory to store cache files (default: .cache in data directory)
|
||||
"""
|
||||
self.use_ocr = use_ocr
|
||||
self.use_pdf_extraction = use_pdf_extraction
|
||||
self.cache_dir = Path(cache_dir) if cache_dir else None
|
||||
self._check_dependencies()
|
||||
|
||||
def _get_cache_dir(self, directory: Path) -> Path:
|
||||
"""Get cache directory for a given data directory."""
|
||||
if self.cache_dir:
|
||||
return self.cache_dir
|
||||
# Default: .cache in the data directory
|
||||
cache_dir = directory / '.cache'
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
return cache_dir
|
||||
|
||||
def _compute_directory_hash(self, directory: Path, recursive: bool = True) -> str:
|
||||
"""
|
||||
Compute a hash of directory contents to detect changes.
|
||||
Uses file paths and modification times.
|
||||
"""
|
||||
directory = Path(directory)
|
||||
file_info = []
|
||||
|
||||
pattern = '**/*' if recursive else '*'
|
||||
scanned_count = 0
|
||||
|
||||
try:
|
||||
for file_path in directory.glob(pattern):
|
||||
scanned_count += 1
|
||||
|
||||
# Progress feedback every 5000 files (hash computation can be slow)
|
||||
if scanned_count % 5000 == 0:
|
||||
logger.info(f"Computing directory hash: scanned {scanned_count:,} paths...")
|
||||
sys.stderr.flush()
|
||||
|
||||
try:
|
||||
if file_path.is_file():
|
||||
stat = file_path.stat()
|
||||
file_info.append(f"{file_path.relative_to(directory)}:{stat.st_mtime}:{stat.st_size}")
|
||||
except (OSError, PermissionError):
|
||||
continue
|
||||
except KeyboardInterrupt:
|
||||
logger.warning(f"Directory hash computation interrupted after scanning {scanned_count:,} paths")
|
||||
raise
|
||||
except KeyboardInterrupt:
|
||||
# Re-raise to allow graceful handling upstream
|
||||
logger.warning("Directory hash computation interrupted. Will skip cache and do fresh scan.")
|
||||
raise
|
||||
|
||||
# Sort for consistent hashing
|
||||
file_info.sort()
|
||||
content = '\n'.join(file_info)
|
||||
return hashlib.md5(content.encode()).hexdigest()
|
||||
|
||||
def _get_cache_path(self, directory: Path, cache_type: str = 'files') -> Path:
|
||||
"""Get cache file path for a directory."""
|
||||
cache_dir = self._get_cache_dir(directory)
|
||||
# Create a safe filename from directory path
|
||||
dir_hash = hashlib.md5(str(directory.absolute()).encode()).hexdigest()[:8]
|
||||
return cache_dir / f"{cache_type}_{dir_hash}.pkl"
|
||||
|
||||
def _load_cache(self, cache_path: Path) -> Optional[Dict]:
|
||||
"""Load cache from file."""
|
||||
if not cache_path.exists():
|
||||
return None
|
||||
try:
|
||||
with open(cache_path, 'rb') as f:
|
||||
return pickle.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load cache from {cache_path}: {e}")
|
||||
return None
|
||||
|
||||
def _save_cache(self, cache_path: Path, data: Dict):
|
||||
"""Save cache to file."""
|
||||
try:
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(cache_path, 'wb') as f:
|
||||
pickle.dump(data, f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save cache to {cache_path}: {e}")
|
||||
|
||||
def clear_cache(self, directory: Path):
|
||||
"""
|
||||
Clear cache for a directory.
|
||||
|
||||
Args:
|
||||
directory: Directory path to clear cache for
|
||||
"""
|
||||
cache_path = self._get_cache_path(directory, 'files')
|
||||
if cache_path.exists():
|
||||
try:
|
||||
cache_path.unlink()
|
||||
logger.info(f"✅ Cleared cache for {directory}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clear cache: {e}")
|
||||
else:
|
||||
logger.info(f"No cache found for {directory}")
|
||||
|
||||
def _check_dependencies(self):
|
||||
"""Check if required dependencies are available."""
|
||||
if self.use_ocr:
|
||||
try:
|
||||
import pytesseract
|
||||
from PIL import Image
|
||||
self._ocr_available = True
|
||||
except ImportError:
|
||||
logger.warning("pytesseract or PIL not available. OCR disabled.")
|
||||
self._ocr_available = False
|
||||
self.use_ocr = False
|
||||
|
||||
if self.use_pdf_extraction:
|
||||
try:
|
||||
import PyPDF2
|
||||
self._pypdf2_available = True
|
||||
except ImportError:
|
||||
try:
|
||||
import pdfplumber
|
||||
self._pdfplumber_available = True
|
||||
self._pypdf2_available = False
|
||||
except ImportError:
|
||||
logger.warning("PyPDF2 or pdfplumber not available. PDF extraction disabled.")
|
||||
self._pdfplumber_available = False
|
||||
self.use_pdf_extraction = False
|
||||
|
||||
def process_file(self, file_path: Path) -> Iterator[str]:
|
||||
"""
|
||||
Process a single file and yield text lines.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Yields:
|
||||
Text lines extracted from the file
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
|
||||
if not file_path.exists():
|
||||
logger.warning(f"File not found: {file_path}")
|
||||
return
|
||||
|
||||
suffix = file_path.suffix.lower()
|
||||
|
||||
try:
|
||||
if suffix in self.TEXT_EXTENSIONS:
|
||||
yield from self._process_text_file(file_path)
|
||||
elif suffix in self.CODE_EXTENSIONS:
|
||||
yield from self._process_code_file(file_path)
|
||||
elif suffix in self.PDF_EXTENSIONS:
|
||||
yield from self._process_pdf(file_path)
|
||||
elif suffix in self.IMAGE_EXTENSIONS:
|
||||
yield from self._process_image(file_path)
|
||||
else:
|
||||
# Try to process as text file as fallback (many file types can be read as text)
|
||||
# Only log at debug level to avoid spam
|
||||
logger.debug(f"Unsupported file type: {file_path} (extension: {suffix}), attempting as text...")
|
||||
try:
|
||||
yield from self._process_text_file(file_path)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to process {file_path} as text: {e}")
|
||||
except KeyboardInterrupt:
|
||||
# Re-raise KeyboardInterrupt to allow graceful shutdown
|
||||
logger.warning(f"Interrupted while processing {file_path}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
|
||||
def _process_text_file(self, file_path: Path) -> Iterator[str]:
|
||||
"""Process a text file."""
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
yield line
|
||||
except KeyboardInterrupt:
|
||||
# Re-raise KeyboardInterrupt to allow graceful shutdown
|
||||
logger.warning(f"Interrupted while reading {file_path}")
|
||||
raise
|
||||
except UnicodeDecodeError:
|
||||
# Try with different encoding
|
||||
try:
|
||||
with open(file_path, 'r', encoding='latin-1', errors='ignore') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
yield line
|
||||
except KeyboardInterrupt:
|
||||
logger.warning(f"Interrupted while reading {file_path}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read {file_path}: {e}")
|
||||
|
||||
def _process_code_file(self, file_path: Path) -> Iterator[str]:
|
||||
"""Process a code file."""
|
||||
# Code files are processed as text, but we can add syntax-aware processing
|
||||
# For now, just extract text lines
|
||||
yield from self._process_text_file(file_path)
|
||||
|
||||
def _process_pdf(self, file_path: Path) -> Iterator[str]:
|
||||
"""Extract text from PDF file."""
|
||||
if not self.use_pdf_extraction:
|
||||
logger.warning(f"PDF extraction disabled. Skipping {file_path}")
|
||||
return
|
||||
|
||||
try:
|
||||
if self._pypdf2_available:
|
||||
import PyPDF2
|
||||
with open(file_path, 'rb') as f:
|
||||
pdf_reader = PyPDF2.PdfReader(f)
|
||||
for page_num, page in enumerate(pdf_reader.pages):
|
||||
try:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
# Split into sentences/lines
|
||||
for line in text.split('\n'):
|
||||
line = line.strip()
|
||||
if line and len(line) > 5: # Filter very short lines
|
||||
yield line
|
||||
except KeyboardInterrupt:
|
||||
logger.warning(f"Interrupted while processing PDF page {page_num} from {file_path}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Error extracting page {page_num} from {file_path}: {e}")
|
||||
|
||||
elif self._pdfplumber_available:
|
||||
import pdfplumber
|
||||
with pdfplumber.open(file_path) as pdf:
|
||||
for page_num, page in enumerate(pdf.pages):
|
||||
try:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
for line in text.split('\n'):
|
||||
line = line.strip()
|
||||
if line and len(line) > 5:
|
||||
yield line
|
||||
except KeyboardInterrupt:
|
||||
logger.warning(f"Interrupted while processing PDF page {page_num} from {file_path}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Error extracting page {page_num} from {file_path}: {e}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.warning(f"Interrupted while processing PDF {file_path}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract text from PDF {file_path}: {e}")
|
||||
|
||||
def _process_image(self, file_path: Path) -> Iterator[str]:
|
||||
"""Extract text from image using OCR."""
|
||||
if not self.use_ocr or not self._ocr_available:
|
||||
logger.warning(f"OCR disabled or unavailable. Skipping {file_path}")
|
||||
return
|
||||
|
||||
try:
|
||||
import pytesseract
|
||||
from PIL import Image
|
||||
|
||||
# Open and process image
|
||||
img = Image.open(file_path)
|
||||
|
||||
# Perform OCR
|
||||
text = pytesseract.image_to_string(img)
|
||||
|
||||
if text:
|
||||
# Split into lines
|
||||
for line in text.split('\n'):
|
||||
line = line.strip()
|
||||
if line and len(line) > 3: # Filter very short lines
|
||||
yield line
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.warning(f"Interrupted while processing image {file_path}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract text from image {file_path}: {e}")
|
||||
|
||||
def process_directory(
|
||||
self,
|
||||
directory: Path,
|
||||
recursive: bool = True,
|
||||
include_patterns: Optional[List[str]] = None,
|
||||
exclude_patterns: Optional[List[str]] = None,
|
||||
min_length: int = 10,
|
||||
) -> Iterator[str]:
|
||||
"""
|
||||
Process all files in a directory.
|
||||
|
||||
Args:
|
||||
directory: Directory path
|
||||
recursive: Whether to process subdirectories
|
||||
include_patterns: Optional list of glob patterns to include
|
||||
exclude_patterns: Optional list of glob patterns to exclude
|
||||
min_length: Minimum length for extracted text lines
|
||||
|
||||
Yields:
|
||||
Text lines from all processed files
|
||||
"""
|
||||
directory = Path(directory)
|
||||
|
||||
if not directory.exists():
|
||||
logger.error(f"Directory not found: {directory}")
|
||||
return
|
||||
|
||||
# Try to load cached file list
|
||||
cache_path = self._get_cache_path(directory, 'files')
|
||||
|
||||
# Compute directory hash (may be interrupted)
|
||||
try:
|
||||
logger.info("Computing directory hash for cache validation...")
|
||||
logger.info("(This may take a while for large directories. Press Ctrl+C to skip cache and do fresh scan)")
|
||||
sys.stderr.flush()
|
||||
current_hash = self._compute_directory_hash(directory, recursive)
|
||||
cached_data = self._load_cache(cache_path)
|
||||
except KeyboardInterrupt:
|
||||
logger.warning("\n⚠️ Directory hash computation interrupted.")
|
||||
logger.warning(" Skipping cache validation and doing fresh directory scan...")
|
||||
logger.warning(" (Press Ctrl+C again to stop completely)")
|
||||
sys.stderr.flush()
|
||||
current_hash = None # Force cache miss
|
||||
cached_data = None
|
||||
# Don't re-raise - allow user to continue with fresh scan
|
||||
# If they want to stop completely, they can press Ctrl+C again during scanning
|
||||
|
||||
files_to_process = []
|
||||
scanned_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
# Check if cache is valid
|
||||
if current_hash and cached_data and cached_data.get('hash') == current_hash:
|
||||
files_to_process = [Path(f) for f in cached_data.get('files', [])]
|
||||
logger.info(f"✅ Loaded {len(files_to_process):,} files from cache (skipping directory scan)")
|
||||
else:
|
||||
# Cache miss or invalid - scan directory
|
||||
logger.info("Scanning directory (cache miss or invalid)...")
|
||||
|
||||
# Collect all supported file extensions
|
||||
all_supported_extensions = (
|
||||
self.TEXT_EXTENSIONS |
|
||||
self.CODE_EXTENSIONS |
|
||||
self.PDF_EXTENSIONS |
|
||||
self.IMAGE_EXTENSIONS
|
||||
)
|
||||
|
||||
if recursive:
|
||||
pattern = '**/*'
|
||||
else:
|
||||
pattern = '*'
|
||||
|
||||
# Default exclude patterns for common directories that don't contain training data
|
||||
default_exclude_patterns = [
|
||||
'**/.git/**',
|
||||
'**/__pycache__/**',
|
||||
'**/node_modules/**',
|
||||
'**/.venv/**',
|
||||
'**/venv/**',
|
||||
'**/.env/**',
|
||||
'**/.pytest_cache/**',
|
||||
'**/.mypy_cache/**',
|
||||
'**/.tox/**',
|
||||
'**/.coverage/**',
|
||||
'**/dist/**',
|
||||
'**/build/**',
|
||||
'**/*.pyc',
|
||||
'**/.DS_Store',
|
||||
]
|
||||
|
||||
# Merge user exclude patterns with defaults
|
||||
all_exclude_patterns = default_exclude_patterns.copy()
|
||||
if exclude_patterns:
|
||||
# Convert any Path objects to strings
|
||||
all_exclude_patterns.extend(str(p) if isinstance(p, Path) else p for p in exclude_patterns)
|
||||
|
||||
# Ensure all patterns are strings (not Path objects)
|
||||
all_exclude_patterns = [str(p) for p in all_exclude_patterns]
|
||||
|
||||
# Convert include_patterns to strings as well
|
||||
if include_patterns:
|
||||
include_patterns = [str(p) if isinstance(p, Path) else p for p in include_patterns]
|
||||
|
||||
logger.info(f"Scanning directory: {directory} (recursive={recursive})...")
|
||||
logger.info("This may take several minutes for large directories. Please wait...")
|
||||
sys.stderr.flush() # Force flush to show message immediately
|
||||
|
||||
try:
|
||||
for file_path in directory.glob(pattern):
|
||||
scanned_count += 1
|
||||
|
||||
# Progress reporting every 1000 files scanned
|
||||
if scanned_count % 1000 == 0:
|
||||
logger.info(f"Scanned {scanned_count:,} paths, found {len(files_to_process):,} files to process...")
|
||||
sys.stderr.flush() # Force flush to show progress immediately
|
||||
|
||||
# Skip if not a file (handles symlinks, directories, etc. gracefully)
|
||||
try:
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
except (OSError, PermissionError) as e:
|
||||
# Skip inaccessible files (broken symlinks, permission denied, etc.)
|
||||
skipped_count += 1
|
||||
if skipped_count <= 10: # Only log first 10 to avoid spam
|
||||
logger.debug(f"Skipping inaccessible path: {file_path} ({e})")
|
||||
continue
|
||||
|
||||
# Early filtering by extension to avoid checking unsupported files
|
||||
suffix = file_path.suffix.lower()
|
||||
if suffix not in all_supported_extensions:
|
||||
continue
|
||||
|
||||
# Check include/exclude patterns
|
||||
if include_patterns:
|
||||
if not any(file_path.match(pattern) for pattern in include_patterns):
|
||||
continue
|
||||
|
||||
if all_exclude_patterns:
|
||||
if any(file_path.match(pattern) for pattern in all_exclude_patterns):
|
||||
continue
|
||||
|
||||
files_to_process.append(file_path)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.warning(f"Directory scanning interrupted. Found {len(files_to_process)} files so far.")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error during directory scanning: {e}")
|
||||
logger.info(f"Continuing with {len(files_to_process)} files found so far...")
|
||||
|
||||
if skipped_count > 10:
|
||||
logger.info(f"Skipped {skipped_count} inaccessible paths")
|
||||
|
||||
logger.info(f"Found {len(files_to_process):,} files to process (scanned {scanned_count:,} paths)")
|
||||
sys.stderr.flush() # Force flush
|
||||
|
||||
# Save file list to cache
|
||||
cache_data = {
|
||||
'hash': current_hash,
|
||||
'files': [str(f.absolute()) for f in files_to_process],
|
||||
'recursive': recursive,
|
||||
}
|
||||
self._save_cache(cache_path, cache_data)
|
||||
logger.info(f"💾 Cached file list ({len(files_to_process):,} files) for future use")
|
||||
sys.stderr.flush() # Force flush
|
||||
|
||||
# Process each file
|
||||
processed_count = 0
|
||||
skipped_count = 0
|
||||
error_count = 0
|
||||
total_lines = 0
|
||||
total_files = len(files_to_process)
|
||||
|
||||
if total_files == 0:
|
||||
logger.warning("No files found to process!")
|
||||
return
|
||||
|
||||
logger.info(f"Starting to process {total_files} files with progress bar...")
|
||||
|
||||
# Create progress bar
|
||||
pbar = tqdm(
|
||||
total=total_files,
|
||||
desc="Processing files",
|
||||
unit="file",
|
||||
ncols=120,
|
||||
mininterval=0.1, # Update at least every 0.1 seconds
|
||||
maxinterval=1.0, # Force update at least once per second
|
||||
file=sys.stderr, # Write to stderr to avoid buffering issues
|
||||
dynamic_ncols=True, # Auto-adjust to terminal width
|
||||
disable=False, # Explicitly enable
|
||||
)
|
||||
|
||||
try:
|
||||
for idx, file_path in enumerate(files_to_process, 1):
|
||||
try:
|
||||
file_lines = list(self.process_file(file_path))
|
||||
if file_lines:
|
||||
processed_count += 1
|
||||
for line in file_lines:
|
||||
if len(line) >= min_length:
|
||||
yield line
|
||||
total_lines += 1
|
||||
else:
|
||||
skipped_count += 1
|
||||
|
||||
# Update progress bar with statistics
|
||||
pbar.set_postfix({
|
||||
'Processed': processed_count,
|
||||
'Skipped': skipped_count,
|
||||
'Errors': error_count,
|
||||
'Lines': f"{total_lines:,}"
|
||||
})
|
||||
pbar.update(1) # Advance progress bar
|
||||
pbar.refresh() # Force immediate refresh
|
||||
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pbar.close()
|
||||
logger.warning(
|
||||
f"Processing interrupted. "
|
||||
f"Files: {idx}/{total_files}, Processed: {processed_count}, "
|
||||
f"Skipped: {skipped_count}, Errors: {error_count}, "
|
||||
f"Lines extracted: {total_lines:,}"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
# Update progress bar even on errors
|
||||
pbar.set_postfix({
|
||||
'Processed': processed_count,
|
||||
'Skipped': skipped_count,
|
||||
'Errors': error_count,
|
||||
'Lines': f"{total_lines:,}"
|
||||
})
|
||||
pbar.update(1) # Advance progress bar even on error
|
||||
pbar.refresh() # Force immediate refresh
|
||||
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
|
||||
finally:
|
||||
pbar.close()
|
||||
|
||||
logger.info(
|
||||
f"Processing complete: {processed_count}/{total_files} files processed successfully, "
|
||||
f"{skipped_count} skipped, {error_count} errors, {total_lines:,} lines extracted"
|
||||
)
|
||||
|
||||
def process_to_list(
|
||||
self,
|
||||
directory: Path,
|
||||
recursive: bool = True,
|
||||
include_patterns: Optional[List[str]] = None,
|
||||
exclude_patterns: Optional[List[str]] = None,
|
||||
min_length: int = 10,
|
||||
max_samples: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Process directory and return list of text lines.
|
||||
|
||||
Args:
|
||||
directory: Directory path
|
||||
recursive: Whether to process subdirectories
|
||||
include_patterns: Optional list of glob patterns to include
|
||||
exclude_patterns: Optional list of glob patterns to exclude
|
||||
min_length: Minimum length for extracted text lines
|
||||
max_samples: Maximum number of samples to return (None = all)
|
||||
|
||||
Returns:
|
||||
List of text lines
|
||||
"""
|
||||
logger.info(f"Starting data extraction from {directory}...")
|
||||
logger.info("This may take a while for large directories. Progress will be shown below.")
|
||||
sys.stderr.flush() # Force flush to show message immediately
|
||||
|
||||
texts = []
|
||||
try:
|
||||
for text in self.process_directory(
|
||||
directory=directory,
|
||||
recursive=recursive,
|
||||
include_patterns=include_patterns,
|
||||
exclude_patterns=exclude_patterns,
|
||||
min_length=min_length,
|
||||
):
|
||||
texts.append(text)
|
||||
if max_samples and len(texts) >= max_samples:
|
||||
logger.info(f"Reached max_samples limit ({max_samples}). Stopping extraction.")
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
# Return partial results if interrupted
|
||||
logger.warning(
|
||||
f"Data processing interrupted. Returning {len(texts):,} text samples collected so far."
|
||||
)
|
||||
# Re-raise to allow caller to handle if needed
|
||||
raise
|
||||
|
||||
logger.info(f"✅ Extracted {len(texts):,} text samples from {directory}")
|
||||
return texts
|
||||
|
||||
|
||||
def extract_text_from_directory(
|
||||
directory: Path,
|
||||
recursive: bool = True,
|
||||
use_ocr: bool = True,
|
||||
use_pdf_extraction: bool = True,
|
||||
min_length: int = 10,
|
||||
max_samples: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Convenience function to extract text from a directory.
|
||||
|
||||
Args:
|
||||
directory: Directory path
|
||||
recursive: Whether to process subdirectories
|
||||
use_ocr: Whether to use OCR for images
|
||||
use_pdf_extraction: Whether to extract text from PDFs
|
||||
min_length: Minimum length for extracted text lines
|
||||
max_samples: Maximum number of samples to return (None = all)
|
||||
|
||||
Returns:
|
||||
List of text lines
|
||||
"""
|
||||
processor = DataProcessor(use_ocr=use_ocr, use_pdf_extraction=use_pdf_extraction)
|
||||
try:
|
||||
return processor.process_to_list(
|
||||
directory=directory,
|
||||
recursive=recursive,
|
||||
min_length=min_length,
|
||||
max_samples=max_samples,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
logger.error(
|
||||
"\n⚠️ Data processing interrupted by user (Ctrl+C).\n"
|
||||
" No data was loaded. Please run the training command again to retry."
|
||||
)
|
||||
# Re-raise to stop training
|
||||
raise
|
||||
|
||||
|
||||
# Try to import BPE tokenizer for direct access
|
||||
try:
|
||||
from .bpe_tokenizer import BPETokenizer
|
||||
__all__ = [
|
||||
'TextDataset',
|
||||
'SimpleTokenizer',
|
||||
'BPETokenizer',
|
||||
'create_dataloader',
|
||||
'DataProcessor',
|
||||
'extract_text_from_directory',
|
||||
]
|
||||
except ImportError:
|
||||
__all__ = [
|
||||
'TextDataset',
|
||||
'SimpleTokenizer',
|
||||
'create_dataloader',
|
||||
'DataProcessor',
|
||||
'extract_text_from_directory',
|
||||
]
|
||||
|
||||
423
data.example/bpe_tokenizer.py
Normal file
423
data.example/bpe_tokenizer.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
Improved BPE Tokenizer based on GPT-4 tokenization approach
|
||||
Addresses common tokenization challenges:
|
||||
- UTF-8 byte-level encoding
|
||||
- Better Python code handling
|
||||
- Case-insensitive contraction matching
|
||||
- Limited number merging (1-3 digits)
|
||||
- Proper special token handling
|
||||
- Trailing whitespace warnings
|
||||
"""
|
||||
import re
|
||||
import json
|
||||
from typing import List, Dict, Tuple, Optional, Set
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class BPETokenizer:
|
||||
"""
|
||||
Byte Pair Encoding Tokenizer with GPT-4-inspired improvements.
|
||||
|
||||
Key features:
|
||||
- UTF-8 byte-level encoding
|
||||
- BPE merging algorithm
|
||||
- GPT-4 style regex pattern for text splitting
|
||||
- Better whitespace handling for Python code
|
||||
- Case-insensitive matching for contractions
|
||||
- Limited number merging (1-3 digits)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 50257,
|
||||
special_tokens: Optional[Dict[str, int]] = None,
|
||||
merges_file: Optional[str] = None,
|
||||
vocab_file: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize BPE tokenizer.
|
||||
|
||||
Args:
|
||||
vocab_size: Target vocabulary size (default 50257 for GPT-2 style)
|
||||
special_tokens: Dictionary of special token names to IDs
|
||||
merges_file: Path to saved merges file
|
||||
vocab_file: Path to saved vocab file
|
||||
"""
|
||||
# Special tokens
|
||||
self.special_tokens = special_tokens or {
|
||||
'<pad>': 0,
|
||||
'<unk>': 1,
|
||||
'<bos>': 2,
|
||||
'<eos>': 3,
|
||||
}
|
||||
|
||||
# Initialize byte vocabulary (0-255)
|
||||
self.byte_to_token = {i: i for i in range(256)}
|
||||
self.token_to_byte = {i: bytes([i]) for i in range(256)}
|
||||
self.next_token_id = 256
|
||||
|
||||
# BPE merges: (left, right) -> merged_token_id
|
||||
self.merges: Dict[Tuple[int, int], int] = {}
|
||||
|
||||
# Vocabulary: token_id -> bytes
|
||||
self.vocab: Dict[int, bytes] = {}
|
||||
self.inv_vocab: Dict[bytes, int] = {}
|
||||
|
||||
# Initialize vocab with bytes
|
||||
for i in range(256):
|
||||
self.vocab[i] = bytes([i])
|
||||
self.inv_vocab[bytes([i])] = i
|
||||
|
||||
# GPT-4 style regex pattern for splitting text
|
||||
# Improvements over GPT-2:
|
||||
# - Case-insensitive matching (flag)
|
||||
# - Better whitespace handling
|
||||
# - Limit number merging to 1-3 digits
|
||||
self.pattern = self._create_gpt4_pattern()
|
||||
|
||||
# Load pre-trained tokenizer if files provided
|
||||
if merges_file and vocab_file:
|
||||
self.load(merges_file, vocab_file)
|
||||
else:
|
||||
self.target_vocab_size = vocab_size
|
||||
|
||||
# Token IDs for special tokens
|
||||
self.pad_token_id = self.special_tokens.get('<pad>', 0)
|
||||
self.unk_token_id = self.special_tokens.get('<unk>', 1)
|
||||
self.bos_token_id = self.special_tokens.get('<bos>', 2)
|
||||
self.eos_token_id = self.special_tokens.get('<eos>', 3)
|
||||
|
||||
def _create_gpt4_pattern(self) -> re.Pattern:
|
||||
"""
|
||||
Create GPT-4 style regex pattern for splitting text.
|
||||
|
||||
Improvements over GPT-2:
|
||||
- Case-insensitive matching for contractions
|
||||
- Better whitespace handling (groups multiple spaces)
|
||||
- Limit number merging (1-3 digits)
|
||||
"""
|
||||
# GPT-4 style pattern with improvements
|
||||
# Pattern breakdown:
|
||||
# 1. Contractions: '(?i:[sdmt]|ll|ve|re) - case-insensitive
|
||||
# 2. Letters: [^\r\n\p{L}\p{N}]?+\p{L}+ - optional space + letters
|
||||
# 3. Numbers: \p{N}{1,3} - 1-3 digits only
|
||||
# 4. Punctuation: ?[^\s\p{L}\p{N}]++ - optional space + non-letter/number
|
||||
# 5. Whitespace: \r?\n - newlines
|
||||
# 6. Trailing whitespace: \s+ - multiple spaces
|
||||
pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++|\r?\n|\s+"""
|
||||
|
||||
# Compile with case-insensitive flag for contractions
|
||||
return re.compile(pattern, re.IGNORECASE | re.UNICODE)
|
||||
|
||||
def _get_stats(self, tokens: List[int]) -> Dict[Tuple[int, int], int]:
|
||||
"""
|
||||
Get statistics of consecutive token pairs.
|
||||
|
||||
Args:
|
||||
tokens: List of token IDs
|
||||
|
||||
Returns:
|
||||
Dictionary mapping pair tuples to counts
|
||||
"""
|
||||
stats = defaultdict(int)
|
||||
for i in range(len(tokens) - 1):
|
||||
pair = (tokens[i], tokens[i + 1])
|
||||
stats[pair] += 1
|
||||
return dict(stats)
|
||||
|
||||
def _merge(self, tokens: List[int], pair: Tuple[int, int], new_id: int) -> List[int]:
|
||||
"""
|
||||
Merge consecutive occurrences of a pair into a new token.
|
||||
|
||||
Args:
|
||||
tokens: List of token IDs
|
||||
pair: Tuple of (left, right) tokens to merge
|
||||
new_id: New token ID to replace the pair
|
||||
|
||||
Returns:
|
||||
New list with merged tokens
|
||||
"""
|
||||
if len(tokens) < 2:
|
||||
return tokens
|
||||
|
||||
new_tokens = []
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
# Check if we can merge at position i
|
||||
if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
|
||||
new_tokens.append(new_id)
|
||||
i += 2
|
||||
else:
|
||||
new_tokens.append(tokens[i])
|
||||
i += 1
|
||||
|
||||
return new_tokens
|
||||
|
||||
def train(
|
||||
self,
|
||||
texts: List[str],
|
||||
num_merges: Optional[int] = None,
|
||||
verbose: bool = False,
|
||||
):
|
||||
"""
|
||||
Train BPE tokenizer on a corpus of texts.
|
||||
|
||||
Args:
|
||||
texts: List of training texts
|
||||
num_merges: Number of merges to perform (default: vocab_size - 256)
|
||||
verbose: Whether to print progress
|
||||
"""
|
||||
if num_merges is None:
|
||||
num_merges = self.target_vocab_size - 256
|
||||
|
||||
# Convert all texts to byte sequences
|
||||
all_tokens = []
|
||||
for text in texts:
|
||||
# Split text using regex pattern
|
||||
chunks = self.pattern.findall(text)
|
||||
|
||||
# Convert each chunk to bytes and tokenize
|
||||
for chunk in chunks:
|
||||
bytes_seq = chunk.encode('utf-8')
|
||||
tokens = list(bytes_seq)
|
||||
all_tokens.extend(tokens)
|
||||
# Add separator between chunks (optional)
|
||||
# all_tokens.append(256) # separator token
|
||||
|
||||
# Perform BPE merges
|
||||
for merge_num in range(num_merges):
|
||||
# Get statistics
|
||||
stats = self._get_stats(all_tokens)
|
||||
|
||||
if not stats:
|
||||
break
|
||||
|
||||
# Find most frequent pair
|
||||
pair = max(stats, key=stats.get)
|
||||
|
||||
# Create new token
|
||||
new_id = self.next_token_id
|
||||
self.next_token_id += 1
|
||||
|
||||
# Merge
|
||||
all_tokens = self._merge(all_tokens, pair, new_id)
|
||||
|
||||
# Store merge
|
||||
self.merges[pair] = new_id
|
||||
|
||||
# Update vocabulary
|
||||
left_bytes = self.vocab.get(pair[0], bytes([pair[0]]))
|
||||
right_bytes = self.vocab.get(pair[1], bytes([pair[1]]))
|
||||
merged_bytes = left_bytes + right_bytes
|
||||
self.vocab[new_id] = merged_bytes
|
||||
self.inv_vocab[merged_bytes] = new_id
|
||||
|
||||
if verbose and (merge_num + 1) % 1000 == 0:
|
||||
print(f"Merged {merge_num + 1}/{num_merges} pairs")
|
||||
|
||||
if verbose:
|
||||
print(f"Training complete. Vocabulary size: {len(self.vocab)}")
|
||||
|
||||
def _encode_chunk(self, text: str) -> List[int]:
|
||||
"""
|
||||
Encode a single text chunk using BPE.
|
||||
|
||||
Args:
|
||||
text: Text chunk to encode
|
||||
|
||||
Returns:
|
||||
List of token IDs
|
||||
"""
|
||||
# Convert to bytes
|
||||
bytes_seq = text.encode('utf-8')
|
||||
tokens = list(bytes_seq)
|
||||
|
||||
# If no merges trained yet, return byte tokens
|
||||
if not self.merges:
|
||||
return tokens
|
||||
|
||||
# Apply merges in order
|
||||
# Sort merges by their token ID (merge order)
|
||||
sorted_merges = sorted(self.merges.items(), key=lambda x: x[1])
|
||||
|
||||
# Keep merging until no more merges are possible
|
||||
changed = True
|
||||
while changed:
|
||||
changed = False
|
||||
best_pair = None
|
||||
best_idx = float('inf')
|
||||
|
||||
# Find the earliest merge we can apply
|
||||
for i in range(len(tokens) - 1):
|
||||
pair = (tokens[i], tokens[i + 1])
|
||||
if pair in self.merges:
|
||||
merge_idx = self.merges[pair]
|
||||
if merge_idx < best_idx:
|
||||
best_idx = merge_idx
|
||||
best_pair = pair
|
||||
|
||||
# Apply the best merge
|
||||
if best_pair is not None:
|
||||
merged_id = self.merges[best_pair]
|
||||
tokens = self._merge(tokens, best_pair, merged_id)
|
||||
changed = True
|
||||
|
||||
return tokens
|
||||
|
||||
def encode(self, text: str, allowed_special: Optional[Set[str]] = None) -> List[int]:
|
||||
"""
|
||||
Encode text into token IDs.
|
||||
|
||||
Args:
|
||||
text: Input text
|
||||
allowed_special: Set of special tokens to allow in text
|
||||
|
||||
Returns:
|
||||
List of token IDs
|
||||
"""
|
||||
# Check for trailing whitespace (warn if present)
|
||||
if text and text[-1] == ' ':
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Text ends with trailing whitespace. This may cause worse performance "
|
||||
"due to how the tokenizer splits text into tokens.",
|
||||
UserWarning
|
||||
)
|
||||
|
||||
# Handle special tokens
|
||||
if allowed_special:
|
||||
for special_name, special_id in self.special_tokens.items():
|
||||
if special_name in allowed_special and special_name in text:
|
||||
# Simple special token replacement (can be improved)
|
||||
if text == special_name:
|
||||
return [special_id]
|
||||
|
||||
# Split text using regex pattern
|
||||
chunks = self.pattern.findall(text)
|
||||
|
||||
# Encode each chunk
|
||||
tokens = []
|
||||
for chunk in chunks:
|
||||
chunk_tokens = self._encode_chunk(chunk)
|
||||
tokens.extend(chunk_tokens)
|
||||
|
||||
return tokens
|
||||
|
||||
def decode(self, token_ids: List[int], errors: str = 'replace') -> str:
|
||||
"""
|
||||
Decode token IDs back to text.
|
||||
|
||||
Args:
|
||||
token_ids: List of token IDs
|
||||
errors: Error handling for invalid UTF-8 ('strict', 'replace', 'ignore')
|
||||
|
||||
Returns:
|
||||
Decoded text string
|
||||
"""
|
||||
# Handle special tokens
|
||||
if self.eos_token_id in token_ids:
|
||||
# Stop at EOS token
|
||||
eos_idx = token_ids.index(self.eos_token_id)
|
||||
token_ids = token_ids[:eos_idx]
|
||||
|
||||
# Convert tokens to bytes
|
||||
bytes_parts = []
|
||||
for token_id in token_ids:
|
||||
if token_id in self.special_tokens.values():
|
||||
# Skip special tokens (except maybe keep them for debugging)
|
||||
continue
|
||||
|
||||
if token_id in self.vocab:
|
||||
bytes_parts.append(self.vocab[token_id])
|
||||
else:
|
||||
# Unknown token - try to use byte representation
|
||||
if token_id < 256:
|
||||
bytes_parts.append(bytes([token_id]))
|
||||
else:
|
||||
# Unknown token - use replacement character
|
||||
bytes_parts.append(b'\ufffd')
|
||||
|
||||
# Concatenate bytes
|
||||
if not bytes_parts:
|
||||
return ''
|
||||
|
||||
try:
|
||||
combined_bytes = b''.join(bytes_parts)
|
||||
return combined_bytes.decode('utf-8', errors=errors)
|
||||
except UnicodeDecodeError:
|
||||
# Fallback with replacement
|
||||
return combined_bytes.decode('utf-8', errors='replace')
|
||||
|
||||
def save(self, merges_file: str, vocab_file: str):
|
||||
"""
|
||||
Save tokenizer to files.
|
||||
|
||||
Args:
|
||||
merges_file: Path to save merges
|
||||
vocab_file: Path to save vocabulary
|
||||
"""
|
||||
# Save merges
|
||||
merges_list = [
|
||||
(left, right, merged_id)
|
||||
for (left, right), merged_id in sorted(self.merges.items(), key=lambda x: x[1])
|
||||
]
|
||||
|
||||
with open(merges_file, 'w') as f:
|
||||
json.dump(merges_list, f, indent=2)
|
||||
|
||||
# Save vocabulary (convert bytes to base64 or hex)
|
||||
vocab_dict = {
|
||||
str(token_id): token_bytes.hex()
|
||||
for token_id, token_bytes in self.vocab.items()
|
||||
}
|
||||
|
||||
with open(vocab_file, 'w') as f:
|
||||
json.dump({
|
||||
'vocab': vocab_dict,
|
||||
'special_tokens': self.special_tokens,
|
||||
'next_token_id': self.next_token_id,
|
||||
}, f, indent=2)
|
||||
|
||||
def load(self, merges_file: str, vocab_file: str):
|
||||
"""
|
||||
Load tokenizer from files.
|
||||
|
||||
Args:
|
||||
merges_file: Path to merges file
|
||||
vocab_file: Path to vocabulary file
|
||||
"""
|
||||
# Load merges
|
||||
with open(merges_file, 'r') as f:
|
||||
merges_list = json.load(f)
|
||||
|
||||
for left, right, merged_id in merges_list:
|
||||
self.merges[(left, right)] = merged_id
|
||||
self.next_token_id = max(self.next_token_id, merged_id + 1)
|
||||
|
||||
# Load vocabulary
|
||||
with open(vocab_file, 'r') as f:
|
||||
vocab_data = json.load(f)
|
||||
|
||||
vocab_dict = vocab_data['vocab']
|
||||
for token_id_str, token_bytes_hex in vocab_dict.items():
|
||||
token_id = int(token_id_str)
|
||||
token_bytes = bytes.fromhex(token_bytes_hex)
|
||||
self.vocab[token_id] = token_bytes
|
||||
self.inv_vocab[token_bytes] = token_id
|
||||
|
||||
if 'special_tokens' in vocab_data:
|
||||
self.special_tokens.update(vocab_data['special_tokens'])
|
||||
|
||||
if 'next_token_id' in vocab_data:
|
||||
self.next_token_id = vocab_data['next_token_id']
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
"""Get vocabulary size."""
|
||||
return len(self.vocab) + len(self.special_tokens)
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
SimpleTokenizer = BPETokenizer
|
||||
|
||||
Reference in New Issue
Block a user