Fixing script

This commit is contained in:
Carlos Gutierrez
2025-11-28 11:47:25 -05:00
parent b0a8941344
commit 2a249af486
2 changed files with 177 additions and 53 deletions

View File

@@ -13,6 +13,9 @@ import hashlib
import pickle import pickle
import os import os
import sys import sys
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
import multiprocessing
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -257,6 +260,32 @@ def create_dataloader(
) )
# ============================================================================
# Standalone function for multiprocessing (must be at module level to be picklable)
# ============================================================================
def _process_file_worker(file_path_str: str, use_ocr: bool, use_pdf_extraction: bool) -> tuple:
    """
    Process a single file inside a worker process.

    Must live at module level so it is picklable by ProcessPoolExecutor
    (bound methods / closures cannot be sent to worker processes). A fresh
    DataProcessor is constructed per call because processor instances are
    not shared across process boundaries.

    Args:
        file_path_str: Path of the file to process, as a string (strings are
            used in the signature so the task payload pickles cheaply).
        use_ocr: Whether to run OCR on image files.
        use_pdf_extraction: Whether to extract text from PDF files.

    Returns:
        Tuple of (file_path_str, lines_list, error_message):
        - on success: (path, extracted lines, None)
        - on failure: (path, [], str(exception)) — errors are returned rather
          than raised so one bad file cannot kill the whole pool job.
    """
    try:
        processor = DataProcessor(use_ocr=use_ocr, use_pdf_extraction=use_pdf_extraction)
        # Convert to Path only at the point of use; no need to round-trip
        # back through str() for the return value.
        file_lines = list(processor.process_file(Path(file_path_str)))
        return (file_path_str, file_lines, None)
    except Exception as e:
        # file_path_str is already a str — return it unchanged (the original
        # wrapped it in a redundant str() call).
        return (file_path_str, [], str(e))
# ============================================================================ # ============================================================================
# Data Processor for Multiple File Types # Data Processor for Multiple File Types
# ============================================================================ # ============================================================================
@@ -572,6 +601,10 @@ class DataProcessor:
include_patterns: Optional[List[str]] = None, include_patterns: Optional[List[str]] = None,
exclude_patterns: Optional[List[str]] = None, exclude_patterns: Optional[List[str]] = None,
min_length: int = 10, min_length: int = 10,
max_files: Optional[int] = None,
num_workers: int = 0,
skip_images: bool = False,
skip_pdfs: bool = False,
) -> Iterator[str]: ) -> Iterator[str]:
""" """
Process all files in a directory. Process all files in a directory.
@@ -582,6 +615,10 @@ class DataProcessor:
include_patterns: Optional list of glob patterns to include include_patterns: Optional list of glob patterns to include
exclude_patterns: Optional list of glob patterns to exclude exclude_patterns: Optional list of glob patterns to exclude
min_length: Minimum length for extracted text lines min_length: Minimum length for extracted text lines
max_files: Maximum number of files to process (None = all)
num_workers: Number of parallel workers (0 = sequential, -1 = auto)
skip_images: Skip image files (faster processing)
skip_pdfs: Skip PDF files (faster processing)
Yields: Yields:
Text lines from all processed files Text lines from all processed files
@@ -625,12 +662,11 @@ class DataProcessor:
logger.info("Scanning directory (cache miss or invalid)...") logger.info("Scanning directory (cache miss or invalid)...")
# Collect all supported file extensions # Collect all supported file extensions
all_supported_extensions = ( all_supported_extensions = self.TEXT_EXTENSIONS | self.CODE_EXTENSIONS
self.TEXT_EXTENSIONS | if not skip_pdfs:
self.CODE_EXTENSIONS | all_supported_extensions |= self.PDF_EXTENSIONS
self.PDF_EXTENSIONS | if not skip_images:
self.IMAGE_EXTENSIONS all_supported_extensions |= self.IMAGE_EXTENSIONS
)
if recursive: if recursive:
pattern = '**/*' pattern = '**/*'
@@ -715,6 +751,11 @@ class DataProcessor:
logger.error(f"Error during directory scanning: {e}") logger.error(f"Error during directory scanning: {e}")
logger.info(f"Continuing with {len(files_to_process)} files found so far...") logger.info(f"Continuing with {len(files_to_process)} files found so far...")
# Limit number of files if max_files is specified
if max_files and len(files_to_process) > max_files:
logger.info(f"Limiting to first {max_files:,} files (found {len(files_to_process):,} total)")
files_to_process = files_to_process[:max_files]
if skipped_count > 10: if skipped_count > 10:
logger.info(f"Skipped {skipped_count} inaccessible paths") logger.info(f"Skipped {skipped_count} inaccessible paths")
@@ -742,7 +783,16 @@ class DataProcessor:
logger.warning("No files found to process!") logger.warning("No files found to process!")
return return
logger.info(f"Starting to process {total_files} files with progress bar...") # Determine number of workers
if num_workers == -1:
num_workers = max(1, multiprocessing.cpu_count() - 1)
elif num_workers == 0:
num_workers = 1 # Sequential processing
if num_workers > 1:
logger.info(f"Starting to process {total_files:,} files with {num_workers} parallel workers...")
else:
logger.info(f"Starting to process {total_files:,} files sequentially...")
# Create progress bar # Create progress bar
pbar = tqdm( pbar = tqdm(
@@ -758,51 +808,95 @@ class DataProcessor:
) )
try: try:
for idx, file_path in enumerate(files_to_process, 1): if num_workers > 1:
try: # Parallel processing
file_lines = list(self.process_file(file_path)) # Use module-level function for pickling
if file_lines: worker_func = partial(
processed_count += 1 _process_file_worker,
for line in file_lines: use_ocr=self.use_ocr and not skip_images,
if len(line) >= min_length: use_pdf_extraction=self.use_pdf_extraction and not skip_pdfs,
yield line )
total_lines += 1 with ProcessPoolExecutor(max_workers=num_workers) as executor:
else: # Submit all tasks
skipped_count += 1 future_to_file = {
executor.submit(worker_func, str(file_path)): file_path
for file_path in files_to_process
}
# Update progress bar with statistics # Process completed tasks
pbar.set_postfix({ for future in as_completed(future_to_file):
'Processed': processed_count, file_path_str, file_lines, error = future.result()
'Skipped': skipped_count, file_path = Path(file_path_str)
'Errors': error_count,
'Lines': f"{total_lines:,}" if error:
}) error_count += 1
pbar.update(1) # Advance progress bar logger.error(f"Error processing {file_path}: {error}")
pbar.refresh() # Force immediate refresh elif file_lines:
sys.stderr.flush() # Force flush stderr to ensure progress bar displays processed_count += 1
for line in file_lines:
except KeyboardInterrupt: if len(line) >= min_length:
pbar.close() yield line
logger.warning( total_lines += 1
f"Processing interrupted. " else:
f"Files: {idx}/{total_files}, Processed: {processed_count}, " skipped_count += 1
f"Skipped: {skipped_count}, Errors: {error_count}, "
f"Lines extracted: {total_lines:,}" # Update progress bar
) pbar.set_postfix({
raise 'Processed': processed_count,
except Exception as e: 'Skipped': skipped_count,
error_count += 1 'Errors': error_count,
logger.error(f"Error processing {file_path}: {e}") 'Lines': f"{total_lines:,}"
# Update progress bar even on errors })
pbar.set_postfix({ pbar.update(1)
'Processed': processed_count, pbar.refresh()
'Skipped': skipped_count, sys.stderr.flush()
'Errors': error_count, else:
'Lines': f"{total_lines:,}" # Sequential processing (original code)
}) for idx, file_path in enumerate(files_to_process, 1):
pbar.update(1) # Advance progress bar even on error try:
pbar.refresh() # Force immediate refresh file_lines = list(self.process_file(file_path))
sys.stderr.flush() # Force flush stderr to ensure progress bar displays if file_lines:
processed_count += 1
for line in file_lines:
if len(line) >= min_length:
yield line
total_lines += 1
else:
skipped_count += 1
# Update progress bar with statistics
pbar.set_postfix({
'Processed': processed_count,
'Skipped': skipped_count,
'Errors': error_count,
'Lines': f"{total_lines:,}"
})
pbar.update(1) # Advance progress bar
pbar.refresh() # Force immediate refresh
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
except KeyboardInterrupt:
pbar.close()
logger.warning(
f"Processing interrupted. "
f"Files: {idx}/{total_files}, Processed: {processed_count}, "
f"Skipped: {skipped_count}, Errors: {error_count}, "
f"Lines extracted: {total_lines:,}"
)
raise
except Exception as e:
error_count += 1
logger.error(f"Error processing {file_path}: {e}")
# Update progress bar even on errors
pbar.set_postfix({
'Processed': processed_count,
'Skipped': skipped_count,
'Errors': error_count,
'Lines': f"{total_lines:,}"
})
pbar.update(1) # Advance progress bar even on error
pbar.refresh() # Force immediate refresh
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
finally: finally:
pbar.close() pbar.close()
@@ -819,6 +913,10 @@ class DataProcessor:
exclude_patterns: Optional[List[str]] = None, exclude_patterns: Optional[List[str]] = None,
min_length: int = 10, min_length: int = 10,
max_samples: Optional[int] = None, max_samples: Optional[int] = None,
max_files: Optional[int] = None,
num_workers: int = 0,
skip_images: bool = False,
skip_pdfs: bool = False,
) -> List[str]: ) -> List[str]:
""" """
Process directory and return list of text lines. Process directory and return list of text lines.
@@ -846,6 +944,10 @@ class DataProcessor:
include_patterns=include_patterns, include_patterns=include_patterns,
exclude_patterns=exclude_patterns, exclude_patterns=exclude_patterns,
min_length=min_length, min_length=min_length,
max_files=max_files,
num_workers=num_workers,
skip_images=skip_images,
skip_pdfs=skip_pdfs,
): ):
texts.append(text) texts.append(text)
if max_samples and len(texts) >= max_samples: if max_samples and len(texts) >= max_samples:
@@ -870,6 +972,10 @@ def extract_text_from_directory(
use_pdf_extraction: bool = True, use_pdf_extraction: bool = True,
min_length: int = 10, min_length: int = 10,
max_samples: Optional[int] = None, max_samples: Optional[int] = None,
max_files: Optional[int] = None,
num_workers: int = 0,
skip_images: bool = False,
skip_pdfs: bool = False,
) -> List[str]: ) -> List[str]:
""" """
Convenience function to extract text from a directory. Convenience function to extract text from a directory.
@@ -881,6 +987,10 @@ def extract_text_from_directory(
use_pdf_extraction: Whether to extract text from PDFs use_pdf_extraction: Whether to extract text from PDFs
min_length: Minimum length for extracted text lines min_length: Minimum length for extracted text lines
max_samples: Maximum number of samples to return (None = all) max_samples: Maximum number of samples to return (None = all)
max_files: Maximum number of files to process (None = all)
num_workers: Number of parallel workers (0 = sequential, -1 = auto)
skip_images: Skip image files entirely (faster processing)
skip_pdfs: Skip PDF files entirely (faster processing)
Returns: Returns:
List of text lines List of text lines
@@ -892,6 +1002,10 @@ def extract_text_from_directory(
recursive=recursive, recursive=recursive,
min_length=min_length, min_length=min_length,
max_samples=max_samples, max_samples=max_samples,
max_files=max_files,
num_workers=num_workers,
skip_images=skip_images,
skip_pdfs=skip_pdfs,
) )
except KeyboardInterrupt: except KeyboardInterrupt:
logger.error( logger.error(

View File

@@ -73,6 +73,12 @@ def main():
parser.add_argument('--data', type=str, required=True, help='Path to training data') parser.add_argument('--data', type=str, required=True, help='Path to training data')
parser.add_argument('--output', type=str, default='./checkpoints', help='Output directory') parser.add_argument('--output', type=str, default='./checkpoints', help='Output directory')
parser.add_argument('--resume', type=str, help='Path to checkpoint to resume from') parser.add_argument('--resume', type=str, help='Path to checkpoint to resume from')
parser.add_argument('--max-files', type=int, default=None, help='Maximum number of files to process (None = all)')
parser.add_argument('--data-workers', type=int, default=0, help='Number of parallel workers for data processing (0 = sequential, -1 = auto)')
parser.add_argument('--skip-images', action='store_true', help='Skip image files (faster processing)')
parser.add_argument('--skip-pdfs', action='store_true', help='Skip PDF files (faster processing)')
parser.add_argument('--no-ocr', action='store_true', help='Disable OCR for images')
parser.add_argument('--no-pdf-extraction', action='store_true', help='Disable PDF text extraction')
# Auto-detect best device # Auto-detect best device
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
@@ -143,9 +149,13 @@ def main():
texts = extract_text_from_directory( texts = extract_text_from_directory(
directory=data_path, directory=data_path,
recursive=True, recursive=True,
use_ocr=True, # Enable OCR for images use_ocr=not args.no_ocr, # Enable OCR for images unless disabled
use_pdf_extraction=True, # Enable PDF extraction use_pdf_extraction=not args.no_pdf_extraction, # Enable PDF extraction unless disabled
min_length=10, # Minimum length for text lines min_length=10, # Minimum length for text lines
max_files=args.max_files, # Limit number of files if specified
num_workers=args.data_workers, # Parallel processing workers
skip_images=args.skip_images, # Skip images entirely
skip_pdfs=args.skip_pdfs, # Skip PDFs entirely
) )
except KeyboardInterrupt: except KeyboardInterrupt:
print("\n\n⚠️ Data processing interrupted by user (Ctrl+C).") print("\n\n⚠️ Data processing interrupted by user (Ctrl+C).")