Fixing script
This commit is contained in:
@@ -13,6 +13,9 @@ import hashlib
|
||||
import pickle
|
||||
import os
|
||||
import sys
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from functools import partial
|
||||
import multiprocessing
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -257,6 +260,32 @@ def create_dataloader(
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Standalone function for multiprocessing (must be at module level to be picklable)
|
||||
# ============================================================================
|
||||
|
||||
def _process_file_worker(file_path_str: str, use_ocr: bool, use_pdf_extraction: bool) -> tuple:
|
||||
"""
|
||||
Worker function for processing a single file in parallel.
|
||||
Must be at module level to be picklable for multiprocessing.
|
||||
|
||||
Args:
|
||||
file_path_str: File path as string
|
||||
use_ocr: Whether to use OCR
|
||||
use_pdf_extraction: Whether to extract PDFs
|
||||
|
||||
Returns:
|
||||
Tuple of (file_path, lines_list, error_string)
|
||||
"""
|
||||
try:
|
||||
file_path = Path(file_path_str)
|
||||
processor = DataProcessor(use_ocr=use_ocr, use_pdf_extraction=use_pdf_extraction)
|
||||
file_lines = list(processor.process_file(file_path))
|
||||
return (str(file_path), file_lines, None)
|
||||
except Exception as e:
|
||||
return (str(file_path_str), [], str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Data Processor for Multiple File Types
|
||||
# ============================================================================
|
||||
@@ -572,6 +601,10 @@ class DataProcessor:
|
||||
include_patterns: Optional[List[str]] = None,
|
||||
exclude_patterns: Optional[List[str]] = None,
|
||||
min_length: int = 10,
|
||||
max_files: Optional[int] = None,
|
||||
num_workers: int = 0,
|
||||
skip_images: bool = False,
|
||||
skip_pdfs: bool = False,
|
||||
) -> Iterator[str]:
|
||||
"""
|
||||
Process all files in a directory.
|
||||
@@ -582,6 +615,10 @@ class DataProcessor:
|
||||
include_patterns: Optional list of glob patterns to include
|
||||
exclude_patterns: Optional list of glob patterns to exclude
|
||||
min_length: Minimum length for extracted text lines
|
||||
max_files: Maximum number of files to process (None = all)
|
||||
num_workers: Number of parallel workers (0 = sequential, -1 = auto)
|
||||
skip_images: Skip image files (faster processing)
|
||||
skip_pdfs: Skip PDF files (faster processing)
|
||||
|
||||
Yields:
|
||||
Text lines from all processed files
|
||||
@@ -625,12 +662,11 @@ class DataProcessor:
|
||||
logger.info("Scanning directory (cache miss or invalid)...")
|
||||
|
||||
# Collect all supported file extensions
|
||||
all_supported_extensions = (
|
||||
self.TEXT_EXTENSIONS |
|
||||
self.CODE_EXTENSIONS |
|
||||
self.PDF_EXTENSIONS |
|
||||
self.IMAGE_EXTENSIONS
|
||||
)
|
||||
all_supported_extensions = self.TEXT_EXTENSIONS | self.CODE_EXTENSIONS
|
||||
if not skip_pdfs:
|
||||
all_supported_extensions |= self.PDF_EXTENSIONS
|
||||
if not skip_images:
|
||||
all_supported_extensions |= self.IMAGE_EXTENSIONS
|
||||
|
||||
if recursive:
|
||||
pattern = '**/*'
|
||||
@@ -715,6 +751,11 @@ class DataProcessor:
|
||||
logger.error(f"Error during directory scanning: {e}")
|
||||
logger.info(f"Continuing with {len(files_to_process)} files found so far...")
|
||||
|
||||
# Limit number of files if max_files is specified
|
||||
if max_files and len(files_to_process) > max_files:
|
||||
logger.info(f"Limiting to first {max_files:,} files (found {len(files_to_process):,} total)")
|
||||
files_to_process = files_to_process[:max_files]
|
||||
|
||||
if skipped_count > 10:
|
||||
logger.info(f"Skipped {skipped_count} inaccessible paths")
|
||||
|
||||
@@ -742,7 +783,16 @@ class DataProcessor:
|
||||
logger.warning("No files found to process!")
|
||||
return
|
||||
|
||||
logger.info(f"Starting to process {total_files} files with progress bar...")
|
||||
# Determine number of workers
|
||||
if num_workers == -1:
|
||||
num_workers = max(1, multiprocessing.cpu_count() - 1)
|
||||
elif num_workers == 0:
|
||||
num_workers = 1 # Sequential processing
|
||||
|
||||
if num_workers > 1:
|
||||
logger.info(f"Starting to process {total_files:,} files with {num_workers} parallel workers...")
|
||||
else:
|
||||
logger.info(f"Starting to process {total_files:,} files sequentially...")
|
||||
|
||||
# Create progress bar
|
||||
pbar = tqdm(
|
||||
@@ -758,51 +808,95 @@ class DataProcessor:
|
||||
)
|
||||
|
||||
try:
|
||||
for idx, file_path in enumerate(files_to_process, 1):
|
||||
try:
|
||||
file_lines = list(self.process_file(file_path))
|
||||
if file_lines:
|
||||
processed_count += 1
|
||||
for line in file_lines:
|
||||
if len(line) >= min_length:
|
||||
yield line
|
||||
total_lines += 1
|
||||
else:
|
||||
skipped_count += 1
|
||||
if num_workers > 1:
|
||||
# Parallel processing
|
||||
# Use module-level function for pickling
|
||||
worker_func = partial(
|
||||
_process_file_worker,
|
||||
use_ocr=self.use_ocr and not skip_images,
|
||||
use_pdf_extraction=self.use_pdf_extraction and not skip_pdfs,
|
||||
)
|
||||
with ProcessPoolExecutor(max_workers=num_workers) as executor:
|
||||
# Submit all tasks
|
||||
future_to_file = {
|
||||
executor.submit(worker_func, str(file_path)): file_path
|
||||
for file_path in files_to_process
|
||||
}
|
||||
|
||||
# Update progress bar with statistics
|
||||
pbar.set_postfix({
|
||||
'Processed': processed_count,
|
||||
'Skipped': skipped_count,
|
||||
'Errors': error_count,
|
||||
'Lines': f"{total_lines:,}"
|
||||
})
|
||||
pbar.update(1) # Advance progress bar
|
||||
pbar.refresh() # Force immediate refresh
|
||||
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pbar.close()
|
||||
logger.warning(
|
||||
f"Processing interrupted. "
|
||||
f"Files: {idx}/{total_files}, Processed: {processed_count}, "
|
||||
f"Skipped: {skipped_count}, Errors: {error_count}, "
|
||||
f"Lines extracted: {total_lines:,}"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
# Update progress bar even on errors
|
||||
pbar.set_postfix({
|
||||
'Processed': processed_count,
|
||||
'Skipped': skipped_count,
|
||||
'Errors': error_count,
|
||||
'Lines': f"{total_lines:,}"
|
||||
})
|
||||
pbar.update(1) # Advance progress bar even on error
|
||||
pbar.refresh() # Force immediate refresh
|
||||
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
|
||||
# Process completed tasks
|
||||
for future in as_completed(future_to_file):
|
||||
file_path_str, file_lines, error = future.result()
|
||||
file_path = Path(file_path_str)
|
||||
|
||||
if error:
|
||||
error_count += 1
|
||||
logger.error(f"Error processing {file_path}: {error}")
|
||||
elif file_lines:
|
||||
processed_count += 1
|
||||
for line in file_lines:
|
||||
if len(line) >= min_length:
|
||||
yield line
|
||||
total_lines += 1
|
||||
else:
|
||||
skipped_count += 1
|
||||
|
||||
# Update progress bar
|
||||
pbar.set_postfix({
|
||||
'Processed': processed_count,
|
||||
'Skipped': skipped_count,
|
||||
'Errors': error_count,
|
||||
'Lines': f"{total_lines:,}"
|
||||
})
|
||||
pbar.update(1)
|
||||
pbar.refresh()
|
||||
sys.stderr.flush()
|
||||
else:
|
||||
# Sequential processing (original code)
|
||||
for idx, file_path in enumerate(files_to_process, 1):
|
||||
try:
|
||||
file_lines = list(self.process_file(file_path))
|
||||
if file_lines:
|
||||
processed_count += 1
|
||||
for line in file_lines:
|
||||
if len(line) >= min_length:
|
||||
yield line
|
||||
total_lines += 1
|
||||
else:
|
||||
skipped_count += 1
|
||||
|
||||
# Update progress bar with statistics
|
||||
pbar.set_postfix({
|
||||
'Processed': processed_count,
|
||||
'Skipped': skipped_count,
|
||||
'Errors': error_count,
|
||||
'Lines': f"{total_lines:,}"
|
||||
})
|
||||
pbar.update(1) # Advance progress bar
|
||||
pbar.refresh() # Force immediate refresh
|
||||
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pbar.close()
|
||||
logger.warning(
|
||||
f"Processing interrupted. "
|
||||
f"Files: {idx}/{total_files}, Processed: {processed_count}, "
|
||||
f"Skipped: {skipped_count}, Errors: {error_count}, "
|
||||
f"Lines extracted: {total_lines:,}"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
# Update progress bar even on errors
|
||||
pbar.set_postfix({
|
||||
'Processed': processed_count,
|
||||
'Skipped': skipped_count,
|
||||
'Errors': error_count,
|
||||
'Lines': f"{total_lines:,}"
|
||||
})
|
||||
pbar.update(1) # Advance progress bar even on error
|
||||
pbar.refresh() # Force immediate refresh
|
||||
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
|
||||
finally:
|
||||
pbar.close()
|
||||
|
||||
@@ -819,6 +913,10 @@ class DataProcessor:
|
||||
exclude_patterns: Optional[List[str]] = None,
|
||||
min_length: int = 10,
|
||||
max_samples: Optional[int] = None,
|
||||
max_files: Optional[int] = None,
|
||||
num_workers: int = 0,
|
||||
skip_images: bool = False,
|
||||
skip_pdfs: bool = False,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Process directory and return list of text lines.
|
||||
@@ -846,6 +944,10 @@ class DataProcessor:
|
||||
include_patterns=include_patterns,
|
||||
exclude_patterns=exclude_patterns,
|
||||
min_length=min_length,
|
||||
max_files=max_files,
|
||||
num_workers=num_workers,
|
||||
skip_images=skip_images,
|
||||
skip_pdfs=skip_pdfs,
|
||||
):
|
||||
texts.append(text)
|
||||
if max_samples and len(texts) >= max_samples:
|
||||
@@ -870,6 +972,10 @@ def extract_text_from_directory(
|
||||
use_pdf_extraction: bool = True,
|
||||
min_length: int = 10,
|
||||
max_samples: Optional[int] = None,
|
||||
max_files: Optional[int] = None,
|
||||
num_workers: int = 0,
|
||||
skip_images: bool = False,
|
||||
skip_pdfs: bool = False,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Convenience function to extract text from a directory.
|
||||
@@ -881,6 +987,10 @@ def extract_text_from_directory(
|
||||
use_pdf_extraction: Whether to extract text from PDFs
|
||||
min_length: Minimum length for extracted text lines
|
||||
max_samples: Maximum number of samples to return (None = all)
|
||||
max_files: Maximum number of files to process (None = all)
|
||||
num_workers: Number of parallel workers (0 = sequential, -1 = auto)
|
||||
skip_images: Skip image files entirely (faster processing)
|
||||
skip_pdfs: Skip PDF files entirely (faster processing)
|
||||
|
||||
Returns:
|
||||
List of text lines
|
||||
@@ -892,6 +1002,10 @@ def extract_text_from_directory(
|
||||
recursive=recursive,
|
||||
min_length=min_length,
|
||||
max_samples=max_samples,
|
||||
max_files=max_files,
|
||||
num_workers=num_workers,
|
||||
skip_images=skip_images,
|
||||
skip_pdfs=skip_pdfs,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
logger.error(
|
||||
|
||||
14
train.py
14
train.py
@@ -73,6 +73,12 @@ def main():
|
||||
parser.add_argument('--data', type=str, required=True, help='Path to training data')
|
||||
parser.add_argument('--output', type=str, default='./checkpoints', help='Output directory')
|
||||
parser.add_argument('--resume', type=str, help='Path to checkpoint to resume from')
|
||||
parser.add_argument('--max-files', type=int, default=None, help='Maximum number of files to process (None = all)')
|
||||
parser.add_argument('--data-workers', type=int, default=0, help='Number of parallel workers for data processing (0 = sequential, -1 = auto)')
|
||||
parser.add_argument('--skip-images', action='store_true', help='Skip image files (faster processing)')
|
||||
parser.add_argument('--skip-pdfs', action='store_true', help='Skip PDF files (faster processing)')
|
||||
parser.add_argument('--no-ocr', action='store_true', help='Disable OCR for images')
|
||||
parser.add_argument('--no-pdf-extraction', action='store_true', help='Disable PDF text extraction')
|
||||
|
||||
# Auto-detect best device
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
@@ -143,9 +149,13 @@ def main():
|
||||
texts = extract_text_from_directory(
|
||||
directory=data_path,
|
||||
recursive=True,
|
||||
use_ocr=True, # Enable OCR for images
|
||||
use_pdf_extraction=True, # Enable PDF extraction
|
||||
use_ocr=not args.no_ocr, # Enable OCR for images unless disabled
|
||||
use_pdf_extraction=not args.no_pdf_extraction, # Enable PDF extraction unless disabled
|
||||
min_length=10, # Minimum length for text lines
|
||||
max_files=args.max_files, # Limit number of files if specified
|
||||
num_workers=args.data_workers, # Parallel processing workers
|
||||
skip_images=args.skip_images, # Skip images entirely
|
||||
skip_pdfs=args.skip_pdfs, # Skip PDFs entirely
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n⚠️ Data processing interrupted by user (Ctrl+C).")
|
||||
|
||||
Reference in New Issue
Block a user