Files
sheepOp/data.example/bpe_tokenizer.py
Carlos Gutierrez 3d2da94ce2 Initial commit: SheepOp LLM - Transformer-based language model implementation
- Complete transformer implementation from scratch
- Training pipeline with gradient accumulation and mixed precision
- Optimized inference with KV caching
- Multi-format data processing (PDFs, images, code, text)
- Comprehensive documentation
- Apache 2.0 license
- Example training plots included in docs/images/
2025-11-06 22:07:41 -05:00

424 lines
14 KiB
Python

"""
Improved BPE Tokenizer based on GPT-4 tokenization approach
Addresses common tokenization challenges:
- UTF-8 byte-level encoding
- Better Python code handling
- Case-insensitive contraction matching
- Limited number merging (1-3 digits)
- Proper special token handling
- Trailing whitespace warnings
"""
import re
import json
from typing import List, Dict, Tuple, Optional, Set
from collections import defaultdict
from pathlib import Path
class BPETokenizer:
"""
Byte Pair Encoding Tokenizer with GPT-4-inspired improvements.
Key features:
- UTF-8 byte-level encoding
- BPE merging algorithm
- GPT-4 style regex pattern for text splitting
- Better whitespace handling for Python code
- Case-insensitive matching for contractions
- Limited number merging (1-3 digits)
"""
def __init__(
self,
vocab_size: int = 50257,
special_tokens: Optional[Dict[str, int]] = None,
merges_file: Optional[str] = None,
vocab_file: Optional[str] = None,
):
"""
Initialize BPE tokenizer.
Args:
vocab_size: Target vocabulary size (default 50257 for GPT-2 style)
special_tokens: Dictionary of special token names to IDs
merges_file: Path to saved merges file
vocab_file: Path to saved vocab file
"""
# Special tokens
self.special_tokens = special_tokens or {
'<pad>': 0,
'<unk>': 1,
'<bos>': 2,
'<eos>': 3,
}
# Initialize byte vocabulary (0-255)
self.byte_to_token = {i: i for i in range(256)}
self.token_to_byte = {i: bytes([i]) for i in range(256)}
self.next_token_id = 256
# BPE merges: (left, right) -> merged_token_id
self.merges: Dict[Tuple[int, int], int] = {}
# Vocabulary: token_id -> bytes
self.vocab: Dict[int, bytes] = {}
self.inv_vocab: Dict[bytes, int] = {}
# Initialize vocab with bytes
for i in range(256):
self.vocab[i] = bytes([i])
self.inv_vocab[bytes([i])] = i
# GPT-4 style regex pattern for splitting text
# Improvements over GPT-2:
# - Case-insensitive matching (flag)
# - Better whitespace handling
# - Limit number merging to 1-3 digits
self.pattern = self._create_gpt4_pattern()
# Load pre-trained tokenizer if files provided
if merges_file and vocab_file:
self.load(merges_file, vocab_file)
else:
self.target_vocab_size = vocab_size
# Token IDs for special tokens
self.pad_token_id = self.special_tokens.get('<pad>', 0)
self.unk_token_id = self.special_tokens.get('<unk>', 1)
self.bos_token_id = self.special_tokens.get('<bos>', 2)
self.eos_token_id = self.special_tokens.get('<eos>', 3)
def _create_gpt4_pattern(self) -> re.Pattern:
"""
Create GPT-4 style regex pattern for splitting text.
Improvements over GPT-2:
- Case-insensitive matching for contractions
- Better whitespace handling (groups multiple spaces)
- Limit number merging (1-3 digits)
"""
# GPT-4 style pattern with improvements
# Pattern breakdown:
# 1. Contractions: '(?i:[sdmt]|ll|ve|re) - case-insensitive
# 2. Letters: [^\r\n\p{L}\p{N}]?+\p{L}+ - optional space + letters
# 3. Numbers: \p{N}{1,3} - 1-3 digits only
# 4. Punctuation: ?[^\s\p{L}\p{N}]++ - optional space + non-letter/number
# 5. Whitespace: \r?\n - newlines
# 6. Trailing whitespace: \s+ - multiple spaces
pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++|\r?\n|\s+"""
# Compile with case-insensitive flag for contractions
return re.compile(pattern, re.IGNORECASE | re.UNICODE)
def _get_stats(self, tokens: List[int]) -> Dict[Tuple[int, int], int]:
"""
Get statistics of consecutive token pairs.
Args:
tokens: List of token IDs
Returns:
Dictionary mapping pair tuples to counts
"""
stats = defaultdict(int)
for i in range(len(tokens) - 1):
pair = (tokens[i], tokens[i + 1])
stats[pair] += 1
return dict(stats)
def _merge(self, tokens: List[int], pair: Tuple[int, int], new_id: int) -> List[int]:
"""
Merge consecutive occurrences of a pair into a new token.
Args:
tokens: List of token IDs
pair: Tuple of (left, right) tokens to merge
new_id: New token ID to replace the pair
Returns:
New list with merged tokens
"""
if len(tokens) < 2:
return tokens
new_tokens = []
i = 0
while i < len(tokens):
# Check if we can merge at position i
if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
new_tokens.append(new_id)
i += 2
else:
new_tokens.append(tokens[i])
i += 1
return new_tokens
def train(
self,
texts: List[str],
num_merges: Optional[int] = None,
verbose: bool = False,
):
"""
Train BPE tokenizer on a corpus of texts.
Args:
texts: List of training texts
num_merges: Number of merges to perform (default: vocab_size - 256)
verbose: Whether to print progress
"""
if num_merges is None:
num_merges = self.target_vocab_size - 256
# Convert all texts to byte sequences
all_tokens = []
for text in texts:
# Split text using regex pattern
chunks = self.pattern.findall(text)
# Convert each chunk to bytes and tokenize
for chunk in chunks:
bytes_seq = chunk.encode('utf-8')
tokens = list(bytes_seq)
all_tokens.extend(tokens)
# Add separator between chunks (optional)
# all_tokens.append(256) # separator token
# Perform BPE merges
for merge_num in range(num_merges):
# Get statistics
stats = self._get_stats(all_tokens)
if not stats:
break
# Find most frequent pair
pair = max(stats, key=stats.get)
# Create new token
new_id = self.next_token_id
self.next_token_id += 1
# Merge
all_tokens = self._merge(all_tokens, pair, new_id)
# Store merge
self.merges[pair] = new_id
# Update vocabulary
left_bytes = self.vocab.get(pair[0], bytes([pair[0]]))
right_bytes = self.vocab.get(pair[1], bytes([pair[1]]))
merged_bytes = left_bytes + right_bytes
self.vocab[new_id] = merged_bytes
self.inv_vocab[merged_bytes] = new_id
if verbose and (merge_num + 1) % 1000 == 0:
print(f"Merged {merge_num + 1}/{num_merges} pairs")
if verbose:
print(f"Training complete. Vocabulary size: {len(self.vocab)}")
def _encode_chunk(self, text: str) -> List[int]:
"""
Encode a single text chunk using BPE.
Args:
text: Text chunk to encode
Returns:
List of token IDs
"""
# Convert to bytes
bytes_seq = text.encode('utf-8')
tokens = list(bytes_seq)
# If no merges trained yet, return byte tokens
if not self.merges:
return tokens
# Apply merges in order
# Sort merges by their token ID (merge order)
sorted_merges = sorted(self.merges.items(), key=lambda x: x[1])
# Keep merging until no more merges are possible
changed = True
while changed:
changed = False
best_pair = None
best_idx = float('inf')
# Find the earliest merge we can apply
for i in range(len(tokens) - 1):
pair = (tokens[i], tokens[i + 1])
if pair in self.merges:
merge_idx = self.merges[pair]
if merge_idx < best_idx:
best_idx = merge_idx
best_pair = pair
# Apply the best merge
if best_pair is not None:
merged_id = self.merges[best_pair]
tokens = self._merge(tokens, best_pair, merged_id)
changed = True
return tokens
def encode(self, text: str, allowed_special: Optional[Set[str]] = None) -> List[int]:
"""
Encode text into token IDs.
Args:
text: Input text
allowed_special: Set of special tokens to allow in text
Returns:
List of token IDs
"""
# Check for trailing whitespace (warn if present)
if text and text[-1] == ' ':
import warnings
warnings.warn(
"Text ends with trailing whitespace. This may cause worse performance "
"due to how the tokenizer splits text into tokens.",
UserWarning
)
# Handle special tokens
if allowed_special:
for special_name, special_id in self.special_tokens.items():
if special_name in allowed_special and special_name in text:
# Simple special token replacement (can be improved)
if text == special_name:
return [special_id]
# Split text using regex pattern
chunks = self.pattern.findall(text)
# Encode each chunk
tokens = []
for chunk in chunks:
chunk_tokens = self._encode_chunk(chunk)
tokens.extend(chunk_tokens)
return tokens
def decode(self, token_ids: List[int], errors: str = 'replace') -> str:
"""
Decode token IDs back to text.
Args:
token_ids: List of token IDs
errors: Error handling for invalid UTF-8 ('strict', 'replace', 'ignore')
Returns:
Decoded text string
"""
# Handle special tokens
if self.eos_token_id in token_ids:
# Stop at EOS token
eos_idx = token_ids.index(self.eos_token_id)
token_ids = token_ids[:eos_idx]
# Convert tokens to bytes
bytes_parts = []
for token_id in token_ids:
if token_id in self.special_tokens.values():
# Skip special tokens (except maybe keep them for debugging)
continue
if token_id in self.vocab:
bytes_parts.append(self.vocab[token_id])
else:
# Unknown token - try to use byte representation
if token_id < 256:
bytes_parts.append(bytes([token_id]))
else:
# Unknown token - use replacement character
bytes_parts.append(b'\ufffd')
# Concatenate bytes
if not bytes_parts:
return ''
try:
combined_bytes = b''.join(bytes_parts)
return combined_bytes.decode('utf-8', errors=errors)
except UnicodeDecodeError:
# Fallback with replacement
return combined_bytes.decode('utf-8', errors='replace')
def save(self, merges_file: str, vocab_file: str):
"""
Save tokenizer to files.
Args:
merges_file: Path to save merges
vocab_file: Path to save vocabulary
"""
# Save merges
merges_list = [
(left, right, merged_id)
for (left, right), merged_id in sorted(self.merges.items(), key=lambda x: x[1])
]
with open(merges_file, 'w') as f:
json.dump(merges_list, f, indent=2)
# Save vocabulary (convert bytes to base64 or hex)
vocab_dict = {
str(token_id): token_bytes.hex()
for token_id, token_bytes in self.vocab.items()
}
with open(vocab_file, 'w') as f:
json.dump({
'vocab': vocab_dict,
'special_tokens': self.special_tokens,
'next_token_id': self.next_token_id,
}, f, indent=2)
def load(self, merges_file: str, vocab_file: str):
"""
Load tokenizer from files.
Args:
merges_file: Path to merges file
vocab_file: Path to vocabulary file
"""
# Load merges
with open(merges_file, 'r') as f:
merges_list = json.load(f)
for left, right, merged_id in merges_list:
self.merges[(left, right)] = merged_id
self.next_token_id = max(self.next_token_id, merged_id + 1)
# Load vocabulary
with open(vocab_file, 'r') as f:
vocab_data = json.load(f)
vocab_dict = vocab_data['vocab']
for token_id_str, token_bytes_hex in vocab_dict.items():
token_id = int(token_id_str)
token_bytes = bytes.fromhex(token_bytes_hex)
self.vocab[token_id] = token_bytes
self.inv_vocab[token_bytes] = token_id
if 'special_tokens' in vocab_data:
self.special_tokens.update(vocab_data['special_tokens'])
if 'next_token_id' in vocab_data:
self.next_token_id = vocab_data['next_token_id']
@property
def vocab_size(self) -> int:
"""Get vocabulary size."""
return len(self.vocab) + len(self.special_tokens)
# Backward compatibility alias
SimpleTokenizer = BPETokenizer