Fixing script
This commit is contained in:
@@ -13,6 +13,9 @@ import hashlib
|
|||||||
import pickle
|
import pickle
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
|
from functools import partial
|
||||||
|
import multiprocessing
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -257,6 +260,32 @@ def create_dataloader(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Standalone function for multiprocessing (must be at module level to be picklable)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def _process_file_worker(file_path_str: str, use_ocr: bool, use_pdf_extraction: bool) -> tuple:
|
||||||
|
"""
|
||||||
|
Worker function for processing a single file in parallel.
|
||||||
|
Must be at module level to be picklable for multiprocessing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path_str: File path as string
|
||||||
|
use_ocr: Whether to use OCR
|
||||||
|
use_pdf_extraction: Whether to extract PDFs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (file_path, lines_list, error_string)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
file_path = Path(file_path_str)
|
||||||
|
processor = DataProcessor(use_ocr=use_ocr, use_pdf_extraction=use_pdf_extraction)
|
||||||
|
file_lines = list(processor.process_file(file_path))
|
||||||
|
return (str(file_path), file_lines, None)
|
||||||
|
except Exception as e:
|
||||||
|
return (str(file_path_str), [], str(e))
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Data Processor for Multiple File Types
|
# Data Processor for Multiple File Types
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -572,6 +601,10 @@ class DataProcessor:
|
|||||||
include_patterns: Optional[List[str]] = None,
|
include_patterns: Optional[List[str]] = None,
|
||||||
exclude_patterns: Optional[List[str]] = None,
|
exclude_patterns: Optional[List[str]] = None,
|
||||||
min_length: int = 10,
|
min_length: int = 10,
|
||||||
|
max_files: Optional[int] = None,
|
||||||
|
num_workers: int = 0,
|
||||||
|
skip_images: bool = False,
|
||||||
|
skip_pdfs: bool = False,
|
||||||
) -> Iterator[str]:
|
) -> Iterator[str]:
|
||||||
"""
|
"""
|
||||||
Process all files in a directory.
|
Process all files in a directory.
|
||||||
@@ -582,6 +615,10 @@ class DataProcessor:
|
|||||||
include_patterns: Optional list of glob patterns to include
|
include_patterns: Optional list of glob patterns to include
|
||||||
exclude_patterns: Optional list of glob patterns to exclude
|
exclude_patterns: Optional list of glob patterns to exclude
|
||||||
min_length: Minimum length for extracted text lines
|
min_length: Minimum length for extracted text lines
|
||||||
|
max_files: Maximum number of files to process (None = all)
|
||||||
|
num_workers: Number of parallel workers (0 = sequential, -1 = auto)
|
||||||
|
skip_images: Skip image files (faster processing)
|
||||||
|
skip_pdfs: Skip PDF files (faster processing)
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Text lines from all processed files
|
Text lines from all processed files
|
||||||
@@ -625,12 +662,11 @@ class DataProcessor:
|
|||||||
logger.info("Scanning directory (cache miss or invalid)...")
|
logger.info("Scanning directory (cache miss or invalid)...")
|
||||||
|
|
||||||
# Collect all supported file extensions
|
# Collect all supported file extensions
|
||||||
all_supported_extensions = (
|
all_supported_extensions = self.TEXT_EXTENSIONS | self.CODE_EXTENSIONS
|
||||||
self.TEXT_EXTENSIONS |
|
if not skip_pdfs:
|
||||||
self.CODE_EXTENSIONS |
|
all_supported_extensions |= self.PDF_EXTENSIONS
|
||||||
self.PDF_EXTENSIONS |
|
if not skip_images:
|
||||||
self.IMAGE_EXTENSIONS
|
all_supported_extensions |= self.IMAGE_EXTENSIONS
|
||||||
)
|
|
||||||
|
|
||||||
if recursive:
|
if recursive:
|
||||||
pattern = '**/*'
|
pattern = '**/*'
|
||||||
@@ -715,6 +751,11 @@ class DataProcessor:
|
|||||||
logger.error(f"Error during directory scanning: {e}")
|
logger.error(f"Error during directory scanning: {e}")
|
||||||
logger.info(f"Continuing with {len(files_to_process)} files found so far...")
|
logger.info(f"Continuing with {len(files_to_process)} files found so far...")
|
||||||
|
|
||||||
|
# Limit number of files if max_files is specified
|
||||||
|
if max_files and len(files_to_process) > max_files:
|
||||||
|
logger.info(f"Limiting to first {max_files:,} files (found {len(files_to_process):,} total)")
|
||||||
|
files_to_process = files_to_process[:max_files]
|
||||||
|
|
||||||
if skipped_count > 10:
|
if skipped_count > 10:
|
||||||
logger.info(f"Skipped {skipped_count} inaccessible paths")
|
logger.info(f"Skipped {skipped_count} inaccessible paths")
|
||||||
|
|
||||||
@@ -742,7 +783,16 @@ class DataProcessor:
|
|||||||
logger.warning("No files found to process!")
|
logger.warning("No files found to process!")
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"Starting to process {total_files} files with progress bar...")
|
# Determine number of workers
|
||||||
|
if num_workers == -1:
|
||||||
|
num_workers = max(1, multiprocessing.cpu_count() - 1)
|
||||||
|
elif num_workers == 0:
|
||||||
|
num_workers = 1 # Sequential processing
|
||||||
|
|
||||||
|
if num_workers > 1:
|
||||||
|
logger.info(f"Starting to process {total_files:,} files with {num_workers} parallel workers...")
|
||||||
|
else:
|
||||||
|
logger.info(f"Starting to process {total_files:,} files sequentially...")
|
||||||
|
|
||||||
# Create progress bar
|
# Create progress bar
|
||||||
pbar = tqdm(
|
pbar = tqdm(
|
||||||
@@ -758,51 +808,95 @@ class DataProcessor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for idx, file_path in enumerate(files_to_process, 1):
|
if num_workers > 1:
|
||||||
try:
|
# Parallel processing
|
||||||
file_lines = list(self.process_file(file_path))
|
# Use module-level function for pickling
|
||||||
if file_lines:
|
worker_func = partial(
|
||||||
processed_count += 1
|
_process_file_worker,
|
||||||
for line in file_lines:
|
use_ocr=self.use_ocr and not skip_images,
|
||||||
if len(line) >= min_length:
|
use_pdf_extraction=self.use_pdf_extraction and not skip_pdfs,
|
||||||
yield line
|
)
|
||||||
total_lines += 1
|
with ProcessPoolExecutor(max_workers=num_workers) as executor:
|
||||||
else:
|
# Submit all tasks
|
||||||
skipped_count += 1
|
future_to_file = {
|
||||||
|
executor.submit(worker_func, str(file_path)): file_path
|
||||||
|
for file_path in files_to_process
|
||||||
|
}
|
||||||
|
|
||||||
# Update progress bar with statistics
|
# Process completed tasks
|
||||||
pbar.set_postfix({
|
for future in as_completed(future_to_file):
|
||||||
'Processed': processed_count,
|
file_path_str, file_lines, error = future.result()
|
||||||
'Skipped': skipped_count,
|
file_path = Path(file_path_str)
|
||||||
'Errors': error_count,
|
|
||||||
'Lines': f"{total_lines:,}"
|
|
||||||
})
|
|
||||||
pbar.update(1) # Advance progress bar
|
|
||||||
pbar.refresh() # Force immediate refresh
|
|
||||||
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
if error:
|
||||||
pbar.close()
|
error_count += 1
|
||||||
logger.warning(
|
logger.error(f"Error processing {file_path}: {error}")
|
||||||
f"Processing interrupted. "
|
elif file_lines:
|
||||||
f"Files: {idx}/{total_files}, Processed: {processed_count}, "
|
processed_count += 1
|
||||||
f"Skipped: {skipped_count}, Errors: {error_count}, "
|
for line in file_lines:
|
||||||
f"Lines extracted: {total_lines:,}"
|
if len(line) >= min_length:
|
||||||
)
|
yield line
|
||||||
raise
|
total_lines += 1
|
||||||
except Exception as e:
|
else:
|
||||||
error_count += 1
|
skipped_count += 1
|
||||||
logger.error(f"Error processing {file_path}: {e}")
|
|
||||||
# Update progress bar even on errors
|
# Update progress bar
|
||||||
pbar.set_postfix({
|
pbar.set_postfix({
|
||||||
'Processed': processed_count,
|
'Processed': processed_count,
|
||||||
'Skipped': skipped_count,
|
'Skipped': skipped_count,
|
||||||
'Errors': error_count,
|
'Errors': error_count,
|
||||||
'Lines': f"{total_lines:,}"
|
'Lines': f"{total_lines:,}"
|
||||||
})
|
})
|
||||||
pbar.update(1) # Advance progress bar even on error
|
pbar.update(1)
|
||||||
pbar.refresh() # Force immediate refresh
|
pbar.refresh()
|
||||||
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
|
sys.stderr.flush()
|
||||||
|
else:
|
||||||
|
# Sequential processing (original code)
|
||||||
|
for idx, file_path in enumerate(files_to_process, 1):
|
||||||
|
try:
|
||||||
|
file_lines = list(self.process_file(file_path))
|
||||||
|
if file_lines:
|
||||||
|
processed_count += 1
|
||||||
|
for line in file_lines:
|
||||||
|
if len(line) >= min_length:
|
||||||
|
yield line
|
||||||
|
total_lines += 1
|
||||||
|
else:
|
||||||
|
skipped_count += 1
|
||||||
|
|
||||||
|
# Update progress bar with statistics
|
||||||
|
pbar.set_postfix({
|
||||||
|
'Processed': processed_count,
|
||||||
|
'Skipped': skipped_count,
|
||||||
|
'Errors': error_count,
|
||||||
|
'Lines': f"{total_lines:,}"
|
||||||
|
})
|
||||||
|
pbar.update(1) # Advance progress bar
|
||||||
|
pbar.refresh() # Force immediate refresh
|
||||||
|
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pbar.close()
|
||||||
|
logger.warning(
|
||||||
|
f"Processing interrupted. "
|
||||||
|
f"Files: {idx}/{total_files}, Processed: {processed_count}, "
|
||||||
|
f"Skipped: {skipped_count}, Errors: {error_count}, "
|
||||||
|
f"Lines extracted: {total_lines:,}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
error_count += 1
|
||||||
|
logger.error(f"Error processing {file_path}: {e}")
|
||||||
|
# Update progress bar even on errors
|
||||||
|
pbar.set_postfix({
|
||||||
|
'Processed': processed_count,
|
||||||
|
'Skipped': skipped_count,
|
||||||
|
'Errors': error_count,
|
||||||
|
'Lines': f"{total_lines:,}"
|
||||||
|
})
|
||||||
|
pbar.update(1) # Advance progress bar even on error
|
||||||
|
pbar.refresh() # Force immediate refresh
|
||||||
|
sys.stderr.flush() # Force flush stderr to ensure progress bar displays
|
||||||
finally:
|
finally:
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
|
||||||
@@ -819,6 +913,10 @@ class DataProcessor:
|
|||||||
exclude_patterns: Optional[List[str]] = None,
|
exclude_patterns: Optional[List[str]] = None,
|
||||||
min_length: int = 10,
|
min_length: int = 10,
|
||||||
max_samples: Optional[int] = None,
|
max_samples: Optional[int] = None,
|
||||||
|
max_files: Optional[int] = None,
|
||||||
|
num_workers: int = 0,
|
||||||
|
skip_images: bool = False,
|
||||||
|
skip_pdfs: bool = False,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Process directory and return list of text lines.
|
Process directory and return list of text lines.
|
||||||
@@ -846,6 +944,10 @@ class DataProcessor:
|
|||||||
include_patterns=include_patterns,
|
include_patterns=include_patterns,
|
||||||
exclude_patterns=exclude_patterns,
|
exclude_patterns=exclude_patterns,
|
||||||
min_length=min_length,
|
min_length=min_length,
|
||||||
|
max_files=max_files,
|
||||||
|
num_workers=num_workers,
|
||||||
|
skip_images=skip_images,
|
||||||
|
skip_pdfs=skip_pdfs,
|
||||||
):
|
):
|
||||||
texts.append(text)
|
texts.append(text)
|
||||||
if max_samples and len(texts) >= max_samples:
|
if max_samples and len(texts) >= max_samples:
|
||||||
@@ -870,6 +972,10 @@ def extract_text_from_directory(
|
|||||||
use_pdf_extraction: bool = True,
|
use_pdf_extraction: bool = True,
|
||||||
min_length: int = 10,
|
min_length: int = 10,
|
||||||
max_samples: Optional[int] = None,
|
max_samples: Optional[int] = None,
|
||||||
|
max_files: Optional[int] = None,
|
||||||
|
num_workers: int = 0,
|
||||||
|
skip_images: bool = False,
|
||||||
|
skip_pdfs: bool = False,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Convenience function to extract text from a directory.
|
Convenience function to extract text from a directory.
|
||||||
@@ -881,6 +987,10 @@ def extract_text_from_directory(
|
|||||||
use_pdf_extraction: Whether to extract text from PDFs
|
use_pdf_extraction: Whether to extract text from PDFs
|
||||||
min_length: Minimum length for extracted text lines
|
min_length: Minimum length for extracted text lines
|
||||||
max_samples: Maximum number of samples to return (None = all)
|
max_samples: Maximum number of samples to return (None = all)
|
||||||
|
max_files: Maximum number of files to process (None = all)
|
||||||
|
num_workers: Number of parallel workers (0 = sequential, -1 = auto)
|
||||||
|
skip_images: Skip image files entirely (faster processing)
|
||||||
|
skip_pdfs: Skip PDF files entirely (faster processing)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of text lines
|
List of text lines
|
||||||
@@ -892,6 +1002,10 @@ def extract_text_from_directory(
|
|||||||
recursive=recursive,
|
recursive=recursive,
|
||||||
min_length=min_length,
|
min_length=min_length,
|
||||||
max_samples=max_samples,
|
max_samples=max_samples,
|
||||||
|
max_files=max_files,
|
||||||
|
num_workers=num_workers,
|
||||||
|
skip_images=skip_images,
|
||||||
|
skip_pdfs=skip_pdfs,
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|||||||
14
train.py
14
train.py
@@ -73,6 +73,12 @@ def main():
|
|||||||
parser.add_argument('--data', type=str, required=True, help='Path to training data')
|
parser.add_argument('--data', type=str, required=True, help='Path to training data')
|
||||||
parser.add_argument('--output', type=str, default='./checkpoints', help='Output directory')
|
parser.add_argument('--output', type=str, default='./checkpoints', help='Output directory')
|
||||||
parser.add_argument('--resume', type=str, help='Path to checkpoint to resume from')
|
parser.add_argument('--resume', type=str, help='Path to checkpoint to resume from')
|
||||||
|
parser.add_argument('--max-files', type=int, default=None, help='Maximum number of files to process (None = all)')
|
||||||
|
parser.add_argument('--data-workers', type=int, default=0, help='Number of parallel workers for data processing (0 = sequential, -1 = auto)')
|
||||||
|
parser.add_argument('--skip-images', action='store_true', help='Skip image files (faster processing)')
|
||||||
|
parser.add_argument('--skip-pdfs', action='store_true', help='Skip PDF files (faster processing)')
|
||||||
|
parser.add_argument('--no-ocr', action='store_true', help='Disable OCR for images')
|
||||||
|
parser.add_argument('--no-pdf-extraction', action='store_true', help='Disable PDF text extraction')
|
||||||
|
|
||||||
# Auto-detect best device
|
# Auto-detect best device
|
||||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||||
@@ -143,9 +149,13 @@ def main():
|
|||||||
texts = extract_text_from_directory(
|
texts = extract_text_from_directory(
|
||||||
directory=data_path,
|
directory=data_path,
|
||||||
recursive=True,
|
recursive=True,
|
||||||
use_ocr=True, # Enable OCR for images
|
use_ocr=not args.no_ocr, # Enable OCR for images unless disabled
|
||||||
use_pdf_extraction=True, # Enable PDF extraction
|
use_pdf_extraction=not args.no_pdf_extraction, # Enable PDF extraction unless disabled
|
||||||
min_length=10, # Minimum length for text lines
|
min_length=10, # Minimum length for text lines
|
||||||
|
max_files=args.max_files, # Limit number of files if specified
|
||||||
|
num_workers=args.data_workers, # Parallel processing workers
|
||||||
|
skip_images=args.skip_images, # Skip images entirely
|
||||||
|
skip_pdfs=args.skip_pdfs, # Skip PDFs entirely
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\n\n⚠️ Data processing interrupted by user (Ctrl+C).")
|
print("\n\n⚠️ Data processing interrupted by user (Ctrl+C).")
|
||||||
|
|||||||
Reference in New Issue
Block a user