diff --git a/data.example/__init__.py b/data.example/__init__.py index 9d2df87..b4629f0 100644 --- a/data.example/__init__.py +++ b/data.example/__init__.py @@ -13,6 +13,9 @@ import hashlib import pickle import os import sys +from concurrent.futures import ProcessPoolExecutor, as_completed +from functools import partial +import multiprocessing logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -257,6 +260,32 @@ def create_dataloader( ) +# ============================================================================ +# Standalone function for multiprocessing (must be at module level to be picklable) +# ============================================================================ + +def _process_file_worker(file_path_str: str, use_ocr: bool, use_pdf_extraction: bool) -> tuple: + """ + Worker function for processing a single file in parallel. + Must be at module level to be picklable for multiprocessing. + + Args: + file_path_str: File path as string + use_ocr: Whether to use OCR + use_pdf_extraction: Whether to extract PDFs + + Returns: + Tuple of (file_path, lines_list, error_string) + """ + try: + file_path = Path(file_path_str) + processor = DataProcessor(use_ocr=use_ocr, use_pdf_extraction=use_pdf_extraction) + file_lines = list(processor.process_file(file_path)) + return (str(file_path), file_lines, None) + except Exception as e: + return (str(file_path_str), [], str(e)) + + # ============================================================================ # Data Processor for Multiple File Types # ============================================================================ @@ -572,6 +601,10 @@ class DataProcessor: include_patterns: Optional[List[str]] = None, exclude_patterns: Optional[List[str]] = None, min_length: int = 10, + max_files: Optional[int] = None, + num_workers: int = 0, + skip_images: bool = False, + skip_pdfs: bool = False, ) -> Iterator[str]: """ Process all files in a directory. @@ -582,6 +615,10 @@ class DataProcessor: include_patterns: Optional list of glob patterns to include exclude_patterns: Optional list of glob patterns to exclude min_length: Minimum length for extracted text lines + max_files: Maximum number of files to process (None = all) + num_workers: Number of parallel workers (0 = sequential, -1 = auto) + skip_images: Skip image files (faster processing) + skip_pdfs: Skip PDF files (faster processing) Yields: Text lines from all processed files @@ -625,12 +662,11 @@ class DataProcessor: logger.info("Scanning directory (cache miss or invalid)...") # Collect all supported file extensions - all_supported_extensions = ( - self.TEXT_EXTENSIONS | - self.CODE_EXTENSIONS | - self.PDF_EXTENSIONS | - self.IMAGE_EXTENSIONS - ) + all_supported_extensions = self.TEXT_EXTENSIONS | self.CODE_EXTENSIONS + if not skip_pdfs: + all_supported_extensions |= self.PDF_EXTENSIONS + if not skip_images: + all_supported_extensions |= self.IMAGE_EXTENSIONS if recursive: pattern = '**/*' @@ -715,6 +751,11 @@ class DataProcessor: logger.error(f"Error during directory scanning: {e}") logger.info(f"Continuing with {len(files_to_process)} files found so far...") + # Limit number of files if max_files is specified + if max_files and len(files_to_process) > max_files: + logger.info(f"Limiting to first {max_files:,} files (found {len(files_to_process):,} total)") + files_to_process = files_to_process[:max_files] + if skipped_count > 10: logger.info(f"Skipped {skipped_count} inaccessible paths") @@ -742,7 +783,16 @@ class DataProcessor: logger.warning("No files found to process!") return - logger.info(f"Starting to process {total_files} files with progress bar...") + # Determine number of workers + if num_workers == -1: + num_workers = max(1, multiprocessing.cpu_count() - 1) + elif num_workers == 0: + num_workers = 1 # Sequential processing + + if num_workers > 1: + logger.info(f"Starting to process {total_files:,} files with {num_workers} parallel workers...") + else: + logger.info(f"Starting to process {total_files:,} files sequentially...") # Create progress bar pbar = tqdm( @@ -758,51 +808,95 @@ class DataProcessor: ) try: - for idx, file_path in enumerate(files_to_process, 1): - try: - file_lines = list(self.process_file(file_path)) - if file_lines: - processed_count += 1 - for line in file_lines: - if len(line) >= min_length: - yield line - total_lines += 1 - else: - skipped_count += 1 + if num_workers > 1: + # Parallel processing + # Use module-level function for pickling + worker_func = partial( + _process_file_worker, + use_ocr=self.use_ocr and not skip_images, + use_pdf_extraction=self.use_pdf_extraction and not skip_pdfs, + ) + with ProcessPoolExecutor(max_workers=num_workers) as executor: + # Submit all tasks + future_to_file = { + executor.submit(worker_func, str(file_path)): file_path + for file_path in files_to_process + } - # Update progress bar with statistics - pbar.set_postfix({ - 'Processed': processed_count, - 'Skipped': skipped_count, - 'Errors': error_count, - 'Lines': f"{total_lines:,}" - }) - pbar.update(1) # Advance progress bar - pbar.refresh() # Force immediate refresh - sys.stderr.flush() # Force flush stderr to ensure progress bar displays - - except KeyboardInterrupt: - pbar.close() - logger.warning( - f"Processing interrupted. " - f"Files: {idx}/{total_files}, Processed: {processed_count}, " - f"Skipped: {skipped_count}, Errors: {error_count}, " - f"Lines extracted: {total_lines:,}" - ) - raise - except Exception as e: - error_count += 1 - logger.error(f"Error processing {file_path}: {e}") - # Update progress bar even on errors - pbar.set_postfix({ - 'Processed': processed_count, - 'Skipped': skipped_count, - 'Errors': error_count, - 'Lines': f"{total_lines:,}" - }) - pbar.update(1) # Advance progress bar even on error - pbar.refresh() # Force immediate refresh - sys.stderr.flush() # Force flush stderr to ensure progress bar displays + # Process completed tasks + for future in as_completed(future_to_file): + file_path_str, file_lines, error = future.result() + file_path = Path(file_path_str) + + if error: + error_count += 1 + logger.error(f"Error processing {file_path}: {error}") + elif file_lines: + processed_count += 1 + for line in file_lines: + if len(line) >= min_length: + yield line + total_lines += 1 + else: + skipped_count += 1 + + # Update progress bar + pbar.set_postfix({ + 'Processed': processed_count, + 'Skipped': skipped_count, + 'Errors': error_count, + 'Lines': f"{total_lines:,}" + }) + pbar.update(1) + pbar.refresh() + sys.stderr.flush() + else: + # Sequential processing (original code) + for idx, file_path in enumerate(files_to_process, 1): + try: + file_lines = list(self.process_file(file_path)) + if file_lines: + processed_count += 1 + for line in file_lines: + if len(line) >= min_length: + yield line + total_lines += 1 + else: + skipped_count += 1 + + # Update progress bar with statistics + pbar.set_postfix({ + 'Processed': processed_count, + 'Skipped': skipped_count, + 'Errors': error_count, + 'Lines': f"{total_lines:,}" + }) + pbar.update(1) # Advance progress bar + pbar.refresh() # Force immediate refresh + sys.stderr.flush() # Force flush stderr to ensure progress bar displays + + except KeyboardInterrupt: + pbar.close() + logger.warning( + f"Processing interrupted. " + f"Files: {idx}/{total_files}, Processed: {processed_count}, " + f"Skipped: {skipped_count}, Errors: {error_count}, " + f"Lines extracted: {total_lines:,}" + ) + raise + except Exception as e: + error_count += 1 + logger.error(f"Error processing {file_path}: {e}") + # Update progress bar even on errors + pbar.set_postfix({ + 'Processed': processed_count, + 'Skipped': skipped_count, + 'Errors': error_count, + 'Lines': f"{total_lines:,}" + }) + pbar.update(1) # Advance progress bar even on error + pbar.refresh() # Force immediate refresh + sys.stderr.flush() # Force flush stderr to ensure progress bar displays finally: pbar.close() @@ -819,6 +913,10 @@ class DataProcessor: exclude_patterns: Optional[List[str]] = None, min_length: int = 10, max_samples: Optional[int] = None, + max_files: Optional[int] = None, + num_workers: int = 0, + skip_images: bool = False, + skip_pdfs: bool = False, ) -> List[str]: """ Process directory and return list of text lines. @@ -846,6 +944,10 @@ class DataProcessor: include_patterns=include_patterns, exclude_patterns=exclude_patterns, min_length=min_length, + max_files=max_files, + num_workers=num_workers, + skip_images=skip_images, + skip_pdfs=skip_pdfs, ): texts.append(text) if max_samples and len(texts) >= max_samples: @@ -870,6 +972,10 @@ def extract_text_from_directory( use_pdf_extraction: bool = True, min_length: int = 10, max_samples: Optional[int] = None, + max_files: Optional[int] = None, + num_workers: int = 0, + skip_images: bool = False, + skip_pdfs: bool = False, ) -> List[str]: """ Convenience function to extract text from a directory. @@ -881,6 +987,10 @@ def extract_text_from_directory( use_pdf_extraction: Whether to extract text from PDFs min_length: Minimum length for extracted text lines max_samples: Maximum number of samples to return (None = all) + max_files: Maximum number of files to process (None = all) + num_workers: Number of parallel workers (0 = sequential, -1 = auto) + skip_images: Skip image files entirely (faster processing) + skip_pdfs: Skip PDF files entirely (faster processing) Returns: List of text lines @@ -892,6 +1002,10 @@ def extract_text_from_directory( recursive=recursive, min_length=min_length, max_samples=max_samples, + max_files=max_files, + num_workers=num_workers, + skip_images=skip_images, + skip_pdfs=skip_pdfs, ) except KeyboardInterrupt: logger.error( diff --git a/train.py b/train.py index b3c6fac..eb104ea 100644 --- a/train.py +++ b/train.py @@ -73,6 +73,12 @@ def main(): parser.add_argument('--data', type=str, required=True, help='Path to training data') parser.add_argument('--output', type=str, default='./checkpoints', help='Output directory') parser.add_argument('--resume', type=str, help='Path to checkpoint to resume from') + parser.add_argument('--max-files', type=int, default=None, help='Maximum number of files to process (None = all)') + parser.add_argument('--data-workers', type=int, default=0, help='Number of parallel workers for data processing (0 = sequential, -1 = auto)') + parser.add_argument('--skip-images', action='store_true', help='Skip image files (faster processing)') + parser.add_argument('--skip-pdfs', action='store_true', help='Skip PDF files (faster processing)') + parser.add_argument('--no-ocr', action='store_true', help='Disable OCR for images') + parser.add_argument('--no-pdf-extraction', action='store_true', help='Disable PDF text extraction') # Auto-detect best device if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): @@ -143,9 +149,13 @@ def main(): texts = extract_text_from_directory( directory=data_path, recursive=True, - use_ocr=True, # Enable OCR for images - use_pdf_extraction=True, # Enable PDF extraction + use_ocr=not args.no_ocr, # Enable OCR for images unless disabled + use_pdf_extraction=not args.no_pdf_extraction, # Enable PDF extraction unless disabled min_length=10, # Minimum length for text lines + max_files=args.max_files, # Limit number of files if specified + num_workers=args.data_workers, # Parallel processing workers + skip_images=args.skip_images, # Skip images entirely + skip_pdfs=args.skip_pdfs, # Skip PDFs entirely ) except KeyboardInterrupt: print("\n\n⚠️ Data processing interrupted by user (Ctrl+C).")