Initial commit: SheepOp LLM - Transformer-based language model implementation
- Complete transformer implementation from scratch
- Training pipeline with gradient accumulation and mixed precision
- Optimized inference with KV caching
- Multi-format data processing (PDFs, images, code, text)
- Comprehensive documentation
- Apache 2.0 license
- Example training plots included in docs/images/
extract_from_database.py (new file, 304 lines added)
@@ -0,0 +1,304 @@
"""
Database extraction utility for training data.

Extracts text from various database types and formats for LLM training.
"""
import argparse
import json
import re
import sqlite3
from pathlib import Path
from typing import List, Optional, Iterator


def extract_from_sqlite(
    db_path: str,
    table: str,
    text_column: str,
    limit: Optional[int] = None,
    where_clause: Optional[str] = None,
) -> Iterator[str]:
    """
    Extract text from SQLite database.

    Args:
        db_path: Path to SQLite database file
        table: Table name to extract from
        text_column: Column name containing text data
        limit: Maximum number of rows to extract (None = all)
        where_clause: Optional WHERE clause (e.g., "WHERE length(text) > 100")

    Yields:
        Text strings from the database
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # Identifiers (table/column names) cannot be bound as SQL parameters,
    # so they are interpolated directly; pass only trusted values
    query = f"SELECT {text_column} FROM {table}"
    if where_clause:
        query += f" {where_clause}"
    if limit:
        query += f" LIMIT {limit}"

    cursor.execute(query)

    for row in cursor:
        text = row[0]
        if text and isinstance(text, str) and len(text.strip()) > 0:
            # Yield stripped text as-is; sentence splitting happens later
            # in clean_and_split_text()
            yield text.strip()

    conn.close()

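
# A minimal usage sketch for extract_from_sqlite; "corpus.db", "articles",
# and "body" are hypothetical names, not part of this module:
#
#   for text in extract_from_sqlite("corpus.db", "articles", "body", limit=100):
#       print(text[:80])
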
def extract_from_sql(
    connection_string: str,
    query: str,
    text_column: int = 0,
    batch_size: int = 1000,
) -> Iterator[str]:
    """
    Extract text using a raw SQL query.
    Tries psycopg2 (PostgreSQL) first, then falls back to pymysql (MySQL).

    Args:
        connection_string: Database connection string
        query: SQL query to execute
        text_column: Column index containing text (0-based)
        batch_size: Number of rows to fetch at once

    Yields:
        Text strings from the database
    """
    try:
        import psycopg2  # PostgreSQL: accepts a libpq connection string directly
        conn = psycopg2.connect(connection_string)
    except ImportError:
        try:
            import pymysql  # MySQL: connect() takes keyword args, not a DSN string
            # Assumes a URL-style string, e.g. "mysql://user:pass@host:3306/dbname"
            from urllib.parse import urlparse
            parsed = urlparse(connection_string)
            conn = pymysql.connect(
                host=parsed.hostname,
                user=parsed.username,
                password=parsed.password or '',
                database=parsed.path.lstrip('/'),
                port=parsed.port or 3306,
            )
        except ImportError:
            raise ImportError("Install psycopg2 for PostgreSQL or pymysql for MySQL")

    cursor = conn.cursor()
    cursor.execute(query)

    # Fetch in batches to keep memory bounded on large result sets
    while True:
        rows = cursor.fetchmany(batch_size)
        if not rows:
            break

        for row in rows:
            text = row[text_column]
            if text and isinstance(text, str) and len(text.strip()) > 0:
                yield text.strip()

    conn.close()

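
# A usage sketch for extract_from_sql; the connection string and query are
# hypothetical (psycopg2 accepts this libpq-style URL directly):
#
#   texts = extract_from_sql(
#       "postgresql://user:pass@localhost/corpus",
#       "SELECT body FROM articles WHERE length(body) > 100",
#   )
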
def extract_from_json_file(
    json_path: str,
    text_field: str,
    limit: Optional[int] = None,
) -> Iterator[str]:
    """
    Extract text from a JSON Lines (JSONL) file, one JSON object per line.

    Args:
        json_path: Path to the JSONL file
        text_field: Field name containing text (use dot notation for nested: "data.text")
        limit: Maximum number of records to extract

    Yields:
        Text strings from the JSON file
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        count = 0
        for line in f:
            if limit and count >= limit:
                break

            try:
                data = json.loads(line)

                # Handle nested fields with dot notation (e.g., "data.text");
                # bail out if an intermediate value is not a dict
                value = data
                for field in text_field.split('.'):
                    if not isinstance(value, dict):
                        value = None
                        break
                    value = value.get(field)

                if value and isinstance(value, str) and len(value.strip()) > 0:
                    yield value.strip()
                    count += 1
            except json.JSONDecodeError:
                continue

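
# A usage sketch for extract_from_json_file; "posts.jsonl" and the nested
# "data.text" field are hypothetical:
#
#   for text in extract_from_json_file("posts.jsonl", "data.text", limit=1000):
#       print(text[:80])
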
def clean_and_split_text(text: str, min_length: int = 10) -> List[str]:
    """
    Clean text and split into sentences/lines.

    Args:
        text: Raw text string
        min_length: Minimum length for a text sample

    Returns:
        List of cleaned text samples
    """
    # Split on newlines first; collapsing whitespace below would erase them
    segments = []
    for part in text.split('\n'):
        # Collapse runs of whitespace within each segment
        part = re.sub(r'\s+', ' ', part)
        # Split on sentence boundaries (periods, exclamation, question marks)
        segments.extend(re.split(r'[.!?]+\s+', part))

    # Clean and filter out samples below the minimum length
    cleaned = []
    for line in segments:
        line = line.strip()
        if len(line) >= min_length:
            cleaned.append(line)

    return cleaned

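
# A quick sketch of clean_and_split_text on a made-up string:
#
#   clean_and_split_text("First sentence here. Second one!\nA new line.", 10)
#   # -> ['First sentence here', 'Second one!', 'A new line.']
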
def save_to_training_file(
    texts: Iterator[str],
    output_path: str,
    min_length: int = 10,
    max_samples: Optional[int] = None,
    clean_text: bool = True,
):
    """
    Save extracted texts to a training file, one sample per line.

    Args:
        texts: Iterator of text strings
        output_path: Path to save training data
        min_length: Minimum length for text samples
        max_samples: Maximum number of samples to save
        clean_text: Whether to clean and split text
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    count = 0
    total_texts = 0

    with open(output_path, 'w', encoding='utf-8') as f:
        for text in texts:
            if max_samples and count >= max_samples:
                break

            if clean_text:
                # Clean and split into sentences
                cleaned_texts = clean_and_split_text(text, min_length)
                for cleaned in cleaned_texts:
                    if max_samples and count >= max_samples:
                        break
                    f.write(cleaned + '\n')
                    count += 1
            else:
                # Write as-is
                if len(text.strip()) >= min_length:
                    f.write(text.strip() + '\n')
                    count += 1

            total_texts += 1

            # Progress update every 1000 texts
            if total_texts % 1000 == 0:
                print(f"Processed {total_texts} texts, saved {count} samples...")

    print("\n✅ Extraction complete!")
    print(f"   Total texts processed: {total_texts}")
    print(f"   Samples saved: {count}")
    print(f"   Output file: {output_path}")
    print(f"   File size: {output_path.stat().st_size / (1024*1024):.2f} MB")

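
# A sketch of wiring an extractor into the writer; paths and names are
# hypothetical:
#
#   texts = extract_from_sqlite("corpus.db", "articles", "body")
#   save_to_training_file(texts, "data/train.txt", min_length=20)
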
def main():
    parser = argparse.ArgumentParser(description='Extract text from a database for training')
    parser.add_argument('--type', type=str, choices=['sqlite', 'sql', 'json'],
                        required=True, help='Database type')
    parser.add_argument('--output', type=str, default='data/database_extracted.txt',
                        help='Output file path')
    parser.add_argument('--limit', type=int, help='Maximum number of samples to extract')
    parser.add_argument('--min-length', type=int, default=10,
                        help='Minimum text length')

    # SQLite options
    parser.add_argument('--db-path', type=str, help='SQLite database path')
    parser.add_argument('--table', type=str, help='Table name')
    parser.add_argument('--column', type=str, help='Text column name')
    parser.add_argument('--where', type=str, help='WHERE clause (e.g., "WHERE length(text) > 100")')

    # SQL query options
    parser.add_argument('--connection', type=str, help='Database connection string')
    parser.add_argument('--query', type=str, help='SQL query')
    parser.add_argument('--text-column', type=int, default=0, help='Text column index (0-based)')

    # JSON options
    parser.add_argument('--json-path', type=str, help='JSON/JSONL file path')
    parser.add_argument('--text-field', type=str, help='JSON field name containing text')

    parser.add_argument('--no-clean', action='store_true', help='Do not clean/split text')

    args = parser.parse_args()

    # Extract based on type
    if args.type == 'sqlite':
        if not all([args.db_path, args.table, args.column]):
            print("Error: --db-path, --table, and --column required for SQLite")
            return

        texts = extract_from_sqlite(
            db_path=args.db_path,
            table=args.table,
            text_column=args.column,
            limit=args.limit,
            where_clause=args.where,
        )

    elif args.type == 'sql':
        if not all([args.connection, args.query]):
            print("Error: --connection and --query required for SQL")
            return

        texts = extract_from_sql(
            connection_string=args.connection,
            query=args.query,
            text_column=args.text_column,
        )

    elif args.type == 'json':
        if not all([args.json_path, args.text_field]):
            print("Error: --json-path and --text-field required for JSON")
            return

        texts = extract_from_json_file(
            json_path=args.json_path,
            text_field=args.text_field,
            limit=args.limit,
        )

    # Save to training file
    save_to_training_file(
        texts=texts,
        output_path=args.output,
        min_length=args.min_length,
        max_samples=args.limit,
        clean_text=not args.no_clean,
    )

if __name__ == '__main__':
    main()
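
# Example invocations (hypothetical paths; flags match the parser above):
#
#   python extract_from_database.py --type sqlite --db-path corpus.db \
#       --table articles --column body --output data/train.txt
#
#   python extract_from_database.py --type json --json-path posts.jsonl \
#       --text-field data.text --limit 50000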