Source code for scripts.processing.detection

#!/usr/bin/env python3

# Import standard-library and third-party libraries
import argparse
import os
import sys
import time
import warnings
import pickle
import hashlib
from pathlib import Path
import pandas as pd
import torch

# Add project root to Python path
current_dir = Path(__file__).parent.absolute()
project_root = current_dir.parent.parent
sys.path.insert(0, str(project_root))

# Suppress warning messages during execution
warnings.filterwarnings('ignore')

# Import project configuration
from label_processing.config import get_model_path, get_project_root

# Import the label-detection module from the 'label_processing' package and detecto's Model
import label_processing.label_detection as scrop
from label_processing.label_detection import create_crops
from detecto.core import Model

# Constants
THRESHOLD = 0.8
PROCESSES = 12

class OptimizedPredictLabel:
    """
    Optimized version of PredictLabel with caching and streamlined loading.

    On construction the model is loaded either from a pickle cache stored in
    ``~/.entomological_cache`` or directly from ``path_to_model``; after a
    direct load the state dict is written to the cache (when caching is on).
    """

    # Number of bytes sampled from the start and from the end of the model
    # file when computing the cache key.
    _HASH_CHUNK_SIZE = 1024 * 1024

    def __init__(self, path_to_model: str, classes: list,
                 threshold: float = 0.8, use_cache: bool = True):
        """
        Initialize with optimized model loading.

        Args:
            path_to_model (str): Path to the model file
            classes (list): List of class names
            threshold (float): Detection threshold. It is only stored on the
                instance; no method of this class filters by it.
            use_cache (bool): Whether to use model caching
        """
        self.path_to_model = Path(path_to_model)
        self.classes = classes
        self.threshold = threshold
        self.use_cache = use_cache

        # Setup caching
        self.cache_dir = Path.home() / '.entomological_cache'
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Load model with optimizations
        self.model = self.load_model_optimized()

    @staticmethod
    def _unwrap_module(model):
        """
        Return the object that carries ``state_dict``/``load_state_dict``.

        Probes ``model.model`` first, then ``model._model``, and finally
        falls back to ``model`` itself, so every load/save path agrees on
        which attribute holds the network.
        """
        if hasattr(model, 'model'):
            return model.model
        if hasattr(model, '_model'):
            return model._model
        return model

    def _get_model_hash(self) -> str:
        """
        Generate hash of model file for cache validation.

        Hashes the first and last ``_HASH_CHUNK_SIZE`` bytes of the file
        together with its modification time, instead of the whole file.

        Returns:
            str: Hex MD5 digest used as the cache key.
        """
        chunk_size = self._HASH_CHUNK_SIZE
        file_stat = self.path_to_model.stat()
        with open(self.path_to_model, 'rb') as f:
            first_chunk = f.read(chunk_size)
            # Seek from the start and clamp at 0: a negative seek relative to
            # the end raises OSError when the file is shorter than one chunk.
            f.seek(max(0, file_stat.st_size - chunk_size))
            last_chunk = f.read(chunk_size)
        combined = first_chunk + last_chunk + str(file_stat.st_mtime).encode()
        # MD5 serves only as a cache key here, not as a security measure.
        return hashlib.md5(combined).hexdigest()

    def _get_cache_path(self) -> Path:
        """Get path for cached model."""
        model_hash = self._get_model_hash()
        return self.cache_dir / f"model_{model_hash}.pkl"

    def load_model_optimized(self) -> "Model":
        """
        Load model with optimized strategy.

        Tries the cache first (when enabled); on a miss the model is loaded
        from ``path_to_model`` and then written to the cache.

        Returns:
            Model: The loaded model.
        """
        print(f"Loading model from: {self.path_to_model}")
        start_time = time.perf_counter()

        # Try to load from cache first
        if self.use_cache:
            cached_model = self._try_load_from_cache()
            if cached_model is not None:
                load_time = time.perf_counter() - start_time
                print(f" Model loaded from cache in {load_time:.2f}s")
                return cached_model

        # If cache miss, load model efficiently
        model = self._load_model_direct()

        # Cache the model for future use
        if self.use_cache:
            self._cache_model(model)

        load_time = time.perf_counter() - start_time
        print(f" Model loaded in {load_time:.2f}s")
        return model

    def _try_load_from_cache(self) -> "Model | None":
        """
        Try to load model from cache.

        Returns:
            The reconstructed model, or ``None`` when no cache file exists
            or loading it failed (a failing cache file is deleted).
        """
        cache_path = self._get_cache_path()
        if not cache_path.exists():
            return None

        try:
            print("Attempting to load model from cache...")
            # NOTE(security): pickle.load executes arbitrary code embedded in
            # the file. The cache lives in the user's home directory; do not
            # point this at a location other users can write to.
            with open(cache_path, 'rb') as f:
                cached_data = pickle.load(f)

            # Reconstruct model from cached state dict
            model = Model(self.classes)
            self._unwrap_module(model).load_state_dict(
                cached_data['state_dict'], strict=False
            )
            return model
        except Exception as e:
            # Best effort: any failure here is treated as a cache miss.
            print(f"Cache loading failed: {e}")
            # Remove corrupted cache file
            try:
                cache_path.unlink()
            except OSError:
                # The file may already be gone or not be removable; the
                # caller falls back to a direct load either way.
                pass
            return None

    def _load_model_direct(self) -> "Model":
        """
        Load model directly with optimized settings.

        ``torch.load`` is temporarily replaced by a wrapper that forces
        ``weights_only=False`` and ``map_location='cpu'`` while ``Model.load``
        runs; the original function is always restored. If this path raises,
        ``_load_model_manual`` is used instead.
        """
        # Set up optimal environment
        self._setup_optimal_environment()

        try:
            # First try the most compatible approach for PyTorch 2.6+
            print("Loading with PyTorch 2.6+ compatibility...")

            # Monkey-patch torch.load temporarily
            original_torch_load = torch.load

            def optimized_load(*args, **kwargs):
                # NOTE(security): weights_only=False permits arbitrary
                # unpickling, so the model file must come from a trusted
                # source.
                kwargs['weights_only'] = False
                # Always load to CPU first for compatibility
                kwargs['map_location'] = 'cpu'
                return original_torch_load(*args, **kwargs)

            torch.load = optimized_load
            try:
                return Model.load(str(self.path_to_model), self.classes)
            finally:
                torch.load = original_torch_load
        except Exception as e:
            print(f"Direct loading failed: {e}")
            # Fallback to manual loading
            return self._load_model_manual()

    def _load_model_manual(self) -> "Model":
        """
        Manual model loading as fallback.

        Raises:
            TypeError: If the loaded object is neither a dict nor an object
                exposing ``state_dict``.
        """
        print("Using manual loading fallback...")

        # Load state dict manually
        state_dict = torch.load(
            str(self.path_to_model), map_location='cpu', weights_only=False
        )

        # Create new model and load state
        model = Model(self.classes)
        target = self._unwrap_module(model)
        if isinstance(state_dict, dict):
            target.load_state_dict(state_dict, strict=False)
        elif hasattr(state_dict, 'state_dict'):
            # Handle model object
            target.load_state_dict(state_dict.state_dict(), strict=False)
        else:
            raise TypeError(f"Unknown model format: {type(state_dict)}")
        return model

    def _cache_model(self, model: "Model"):
        """
        Cache the loaded model for future use.

        Failures are reported as a warning and otherwise ignored, so a
        broken cache never prevents the model from being used.
        """
        try:
            cache_path = self._get_cache_path()
            cached_data = {
                'state_dict': self._unwrap_module(model).state_dict(),
                'classes': self.classes,
                'timestamp': time.time()
            }
            with open(cache_path, 'wb') as f:
                pickle.dump(cached_data, f)
            print(f" Model cached for future use")
        except Exception as e:
            print(f"Warning: Could not cache model: {e}")

    def _setup_optimal_environment(self):
        """Setup optimal environment for loading."""
        # Disable CUDA for loading (can enable later for inference)
        # NOTE(review): nothing in this class restores the variable
        # afterwards, so it stays set for the rest of the process -- confirm
        # this is intended when GPU inference is requested.
        os.environ['CUDA_VISIBLE_DEVICES'] = ''

        # Optimize thread usage
        torch.set_num_threads(min(4, os.cpu_count() or 1))

        # Set environment variables for stability
        # NOTE(review): these are set after torch has been imported;
        # presumably they only affect libraries initialised later -- verify.
        os.environ['OMP_NUM_THREADS'] = '1'
        os.environ['MKL_NUM_THREADS'] = '1'

    def class_prediction(self, jpg_path: Path) -> pd.DataFrame:
        """
        Predict labels for a given JPG file.

        Args:
            jpg_path (Path): Path to the JPG file

        Returns:
            pd.DataFrame: One row per prediction with the columns
                ``filename``, ``class``, ``score``, ``xmin``, ``ymin``,
                ``xmax`` and ``ymax``. The frame is empty (and has no
                columns) when the model returns no labels.
        """
        import detecto.utils

        image = detecto.utils.read_image(str(jpg_path))
        labels, boxes, scores = self.model.predict(image)

        # Box coordinates are stored exactly as the model returns them; only
        # the score is converted with ``.item()``.
        entries = [
            {
                'filename': jpg_path.name,
                'class': labelname,
                'score': score.item(),
                'xmin': box[0],
                'ymin': box[1],
                'xmax': box[2],
                'ymax': box[3]
            }
            for labelname, box, score in zip(labels, boxes, scores)
        ]
        return pd.DataFrame(entries)
def parse_arguments() -> argparse.Namespace:
    """
    Parse command-line arguments using argparse.

    An input (``--input-dir`` or ``--input-image``) and ``--output-dir`` are
    required for a normal run. They may be omitted when ``--clear-cache`` is
    given, because that mode does not process any images.

    Returns:
        argparse.Namespace: Parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Execute label detection on entomological specimen images with performance optimizations."
    )

    # Input options (mutually exclusive). Their presence is enforced after
    # parsing so that --clear-cache can be used on its own.
    input_group = parser.add_mutually_exclusive_group()
    input_group.add_argument(
        '-j', '--input-dir',
        type=str,
        help='Directory containing specimen images'
    )
    input_group.add_argument(
        '-i', '--input-image',
        type=str,
        help='Single image file to process'
    )

    # Output directory (required unless --clear-cache is given)
    parser.add_argument(
        '-o', '--output-dir',
        type=str,
        help='Directory where results will be saved (required unless --clear-cache is given)'
    )

    # Optional parameters
    parser.add_argument(
        '--confidence',
        type=float,
        default=0.8,
        # %(default)s keeps the help text in sync with the actual default.
        help='Detection confidence threshold (default: %(default)s)'
    )
    parser.add_argument(
        '--batch-size',
        type=int,
        default=16,
        help='Number of images processed simultaneously (default: %(default)s)'
    )
    parser.add_argument(
        '--device',
        type=str,
        default='auto',
        choices=['auto', 'cpu', 'cuda', 'mps'],
        help='Device to use for processing (default: %(default)s)'
    )
    parser.add_argument(
        '--no-cache',
        action='store_true',
        help='Disable model caching'
    )
    parser.add_argument(
        '--clear-cache',
        action='store_true',
        help='Clear model cache before running'
    )

    args = parser.parse_args()

    if not args.clear_cache:
        # Same wording argparse itself uses for missing required arguments;
        # parser.error prints the usage and exits with status 2.
        if args.input_dir is None and args.input_image is None:
            parser.error(
                "one of the arguments -j/--input-dir -i/--input-image is required"
            )
        if args.output_dir is None:
            parser.error("the following arguments are required: -o/--output-dir")

    return args
def clear_model_cache():
    """
    Delete every ``model_*.pkl`` file in ``~/.entomological_cache``.

    Each removal is reported on stdout; a file that cannot be removed is
    reported and skipped. Prints a notice instead when the cache directory
    does not exist.
    """
    cache_root = Path.home() / '.entomological_cache'

    # Nothing to do when no cache directory exists.
    if not cache_root.exists():
        print("No cache directory found.")
        return

    for entry in cache_root.glob('model_*.pkl'):
        try:
            entry.unlink()
            print(f"Removed cache file: {entry.name}")
        except Exception as err:
            print(f"Could not remove {entry.name}: {err}")

    print("Model cache cleared.")
def setup_device(device_arg: str) -> str:
    """
    Setup optimal device for inference.

    With ``'auto'`` CUDA is chosen when available, otherwise the CPU. MPS is
    never selected: both the automatic and the explicit MPS case resolve to
    the CPU. An explicit ``'cuda'`` request also resolves to the CPU (with a
    warning) when CUDA is not available.

    Args:
        device_arg: Device argument from command line

    Returns:
        str: Best available device
    """
    if device_arg == 'auto':
        if torch.cuda.is_available():
            print(f" Using CUDA GPU: {torch.cuda.get_device_name()}")
            return 'cuda'

        # MPS has compatibility issues with some PyTorch operations
        # (e.g., _share_filename_), so fall back to CPU for reliable inference.
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            print(f" MPS available but using CPU for better compatibility")
        print(f" Using CPU with {torch.get_num_threads()} threads")
        return 'cpu'

    if device_arg == 'mps':
        # If user explicitly requests MPS, warn them about potential issues
        print(f"Warning: MPS has known compatibility issues. Using CPU instead.")
        device = 'cpu'
    elif device_arg == 'cuda' and not torch.cuda.is_available():
        # Returning 'cuda' here would only fail later when the model is moved.
        print("Warning: CUDA requested but not available. Using CPU instead.")
        device = 'cpu'
    else:
        device = device_arg

    print(f" Using device: {device}")
    return device
def _move_model_to_device(predictor, device: str) -> str:
    """
    Move the predictor's underlying model to ``device``.

    Returns the device the caller should assume afterwards: ``device`` on
    success (or when it already is ``'cpu'``), ``'cpu'`` when the move failed.
    """
    if device == 'cpu':
        return device

    try:
        if hasattr(predictor.model, 'model'):
            predictor.model.model = predictor.model.model.to(device)
        elif hasattr(predictor.model, '_model'):
            predictor.model._model = predictor.model._model.to(device)
        else:
            # Try to move the predictor model directly
            predictor.model = predictor.model.to(device)
        print(f" Model moved to {device}")
        return device
    except Exception as e:
        print(f"Warning: Could not move model to {device}, using CPU: {e}")
        # Report CPU so the caller really takes the CPU code path.
        return 'cpu'


def _predict_sequential(predictor, jpg_dir: Path) -> pd.DataFrame:
    """
    Run the predictor on every image file in ``jpg_dir``, one after another.

    Returns the concatenated non-empty results, or an empty DataFrame when
    no image produced a prediction.
    """
    print("Processing images sequentially (CPU mode)...")

    # Collect image files
    file_names = [p for p in sorted(jpg_dir.iterdir()) if scrop.is_image_file(p)]
    total = len(file_names)
    print(f"Found {total} images to process")

    # Process sequentially
    results = []
    for i, file_path in enumerate(file_names, 1):
        print(f"Processing {i}/{total}: {file_path.name}", end='\r')
        result_df = predictor.class_prediction(file_path)
        if not result_df.empty:
            results.append(result_df)
    print()  # New line after progress

    return pd.concat(results, ignore_index=True) if results else pd.DataFrame()


def main():
    """
    Main execution function with performance optimizations.

    Steps: parse the arguments, optionally clear the model cache and stop,
    resolve the model path and the input, load the model, run the
    predictions, clean them, create the crops and print a timing summary.
    Every failure is reported on stdout and ends the run by returning.
    """
    start_time = time.perf_counter()

    args = parse_arguments()

    # Clear cache if requested
    if args.clear_cache:
        clear_model_cache()
        return

    # Use centralized configuration for model path
    try:
        model_path = get_model_path("detection")
    except Exception as e:
        print(f"Error getting model path: {e}")
        print("Please ensure the model file exists or set the ENTOMOLOGICAL_DETECTION_MODEL_PATH environment variable.")
        return

    # Handle input (directory or single file)
    if args.input_dir:
        jpg_dir = Path(args.input_dir)
        input_type = "directory"
    else:
        # Single file input
        single_file = Path(args.input_image)
        if not single_file.exists():
            print(f"Error: Input file '{single_file}' does not exist.")
            return
        jpg_dir = single_file.parent
        input_type = "single_file"
        print(f"Processing single file: {single_file.name}")

    out_dir = args.output_dir
    confidence_threshold = args.confidence
    batch_size = args.batch_size
    use_cache = not args.no_cache

    classes = ["label"]

    # Validate the inputs before touching the file system, so a failed run
    # does not leave an empty output directory behind.
    if not model_path.exists():
        print(f"Error: Model file '{model_path}' not found.")
        return

    if input_type == "directory" and not jpg_dir.exists():
        print(f"Error: Input directory '{jpg_dir}' does not exist.")
        return

    if not os.path.exists(out_dir):
        print(f"Creating output directory: {out_dir}")
        os.makedirs(out_dir, exist_ok=True)

    print(f"Using confidence threshold: {confidence_threshold}")
    print(f"Using batch size: {batch_size}")
    print(f"Model caching: {'enabled' if use_cache else 'disabled'}")

    # Setup device
    device = setup_device(args.device)

    try:
        # Initialize optimized predictor
        predictor = OptimizedPredictLabel(
            model_path,
            classes,
            threshold=confidence_threshold,
            use_cache=use_cache,
        )

        # Move model to selected device if not CPU; fall back to the CPU
        # path when the move fails.
        device = _move_model_to_device(predictor, device)

        model_load_time = time.perf_counter() - start_time
        print(f" Total model setup time: {model_load_time:.2f}s")

        # Prediction phase
        prediction_start = time.perf_counter()

        if input_type == "single_file":
            print(f"Processing single file: {single_file}")
            df = predictor.class_prediction(single_file)
            if df.empty:
                df = pd.DataFrame(columns=['filename', 'class', 'score', 'xmin', 'ymin', 'xmax', 'ymax'])
        elif device == 'cpu':
            # For CPU inference, use sequential processing to avoid
            # multiprocessing overhead
            df = _predict_sequential(predictor, jpg_dir)
        else:
            # For GPU/MPS, use parallel processing. Clamp to at least one
            # worker so a non-positive --batch-size cannot yield zero workers.
            processes = max(1, min(PROCESSES, batch_size))
            df = scrop.prediction_parallel(jpg_dir, predictor, processes)

        prediction_time = time.perf_counter() - prediction_start
        print(f" Prediction completed in {prediction_time:.2f}s")

    except Exception as e:
        print(f"Error during prediction: {e}")
        return

    if df.empty:
        print("No valid predictions were generated. Skipping further processing.")
        return

    try:
        df = scrop.clean_predictions(jpg_dir, df, confidence_threshold, out_dir=out_dir)
    except Exception as e:
        print(f"Error cleaning predictions: {e}")
        return

    detection_total_time = time.perf_counter() - start_time
    print(f"Detection finished in {detection_total_time:.2f}s")

    try:
        crop_start = time.perf_counter()
        create_crops(jpg_dir, df, out_dir=out_dir)
        crop_time = time.perf_counter() - crop_start
        print(f" Cropping completed in {crop_time:.2f}s")
    except Exception as e:
        print(f"Error during cropping: {e}")
        return

    total_time = time.perf_counter() - start_time
    print(f"\n" + "="*50)
    print(f" PROCESSING COMPLETED")
    print(f" Total time: {total_time:.2f}s")
    print(f" - Model loading: {model_load_time:.2f}s ({model_load_time/total_time*100:.1f}%)")
    print(f" - Prediction: {prediction_time:.2f}s ({prediction_time/total_time*100:.1f}%)")
    print(f" - Cropping: {crop_time:.2f}s ({crop_time/total_time*100:.1f}%)")
    print(f" Results saved to: {out_dir}")
    print(f" - CSV file: {out_dir}/input_predictions.csv")
    print(f" - Cropped images: {out_dir}/input_cropped/")
    print("="*50)
# Run the detection pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()