#!/usr/bin/env python3
"""
Label Detection Module (Detectron2 / Detecto)
Detects and crops individual labels from full specimen photographs using a
trained Faster R-CNN object-detection model. Used by the traditional MLI
pipeline; the Gemini pipeline uses gemini_processor.detect_and_classify instead.
"""
import glob
import multiprocessing as mp
import os
import pickle
import platform
import re
import sys
from pathlib import Path
from typing import Union

import cv2
import detecto.utils
import numpy as np
import pandas as pd
import torch
from detecto.core import Model
# ---------------------Image Segmentation---------------------#
# --- Image-file helpers ---
# Lower-case extensions (leading dot included) that count as image files.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp", ".gif", ".webp"}


def is_image_file(path) -> bool:
    """
    Tell whether *path* is an existing, visible image file.

    A path qualifies when it is a regular file, its name does not start with
    a dot (which also rules out macOS ``._*`` resource-fork files), and its
    suffix, compared case-insensitively, is one of ``IMAGE_EXTS``.

    Args:
        path (str|Path): Candidate file path.

    Returns:
        bool: True if the file should be treated as an image.
    """
    candidate = Path(path)
    # Evaluated left to right, so the filesystem is touched first, exactly once.
    return (
        candidate.is_file()
        and not candidate.name.startswith(".")
        and candidate.suffix.lower() in IMAGE_EXTS
    )
class PredictLabel:
    """
    Class for predicting labels using a trained object detection model.

    Attributes:
        path_to_model (str): Path to the trained model file.
        classes (list): List of classes used in the model.
        jpg_path (Path|None): Default image used by ``class_prediction`` when
            no path is passed; any str/path-like value is stored as ``Path``.
        threshold (float): Threshold value for scores. Defaults to 0.8. It is
            only stored here; ``class_prediction`` does not filter by it.
        model (detecto.core.Model): Trained object detection model, loaded
            eagerly in ``__init__``.
    """

    def __init__(
        self,
        path_to_model: str,
        classes: list,
        jpg_path: Union[str, Path, None] = None,
        threshold: float = 0.8,
    ) -> None:
        """
        Init Method for the PredictLabel Class.

        The model is loaded immediately (via ``retrieve_model``), so
        construction fails if the model file is missing or cannot be loaded.

        Args:
            path_to_model (str): Path to the model.
            classes (list): List of classes.
            jpg_path (str|Path|None): Path to JPG file for prediction.
            threshold (float, optional): Threshold value for scores.
        """
        self.path_to_model = path_to_model
        self.classes = classes
        self.jpg_path = jpg_path
        self.threshold = threshold
        self.model = self.retrieve_model()

    @property
    def jpg_path(self) -> Union[Path, None]:
        """Path|None: Default image path used by ``class_prediction``."""
        return self._jpg_path

    @jpg_path.setter
    def jpg_path(self, jpg_path: Union[str, Path, None]) -> None:
        """
        Store *jpg_path* as a ``Path`` (or ``None``).

        Any ``str`` or ``os.PathLike`` is accepted; unsupported types raise
        ``TypeError`` here instead of leaving the attribute unset.
        """
        self._jpg_path = None if jpg_path is None else Path(jpg_path)

    def retrieve_model(self) -> detecto.core.Model:
        """
        Retrieve the trained object detection model.

        Checks that the model file exists and is non-empty. If a
        ``checksums.sha256`` file sits next to the model, it verifies the model
        against it; a failed check here only prints a warning. It then
        prepares the process environment and tries the safe loading
        strategies in order.

        Returns:
            detecto.core.Model: The loaded model.

        Raises:
            FileNotFoundError: If the model file does not exist.
            IOError: If the model file is empty.
            Exception: The last strategy's error if every strategy fails.
        """
        if not os.path.exists(self.path_to_model):
            raise FileNotFoundError(f"Model file '{self.path_to_model}' not found.")
        if os.path.getsize(self.path_to_model) == 0:
            raise IOError(f"Model file '{self.path_to_model}' is empty.")
        # Verify model integrity if checksums file exists (advisory only here).
        model_dir = os.path.dirname(self.path_to_model)
        checksums_file = os.path.join(model_dir, "checksums.sha256")
        if os.path.exists(checksums_file):
            try:
                from label_processing.utils import verify_model_integrity

                if not verify_model_integrity(self.path_to_model, checksums_file):
                    print(
                        f"WARNING: Model integrity check failed for {self.path_to_model}"
                    )
                else:
                    print(f"Model integrity verified for {self.path_to_model}")
            except Exception as e:
                print(f"Could not verify model integrity: {e}")
        print("Loading model from:", self.path_to_model)
        # Set environment for cross-platform compatibility
        self._setup_cross_platform_environment()
        # SECURITY: Only use safe loading strategies with weights_only=True
        loading_strategies = [
            self._load_pytorch_safe,  # Strategy 1: state dict via torch.load
            self._load_detecto_safe,  # Strategy 2: detecto Model.load, checksum-gated
        ]
        last_error = None
        for i, strategy in enumerate(loading_strategies, 1):
            try:
                print(f"Trying loading strategy {i}...")
                model = strategy()
            except Exception as e:
                print(f"Strategy {i} failed: {e}")
                last_error = e
                continue
            print("Model loaded successfully")
            return model
        # If all strategies fail, raise the last error
        print(f"All loading strategies failed. Last error: {last_error}")
        raise last_error

    def _setup_cross_platform_environment(self):
        """
        Configure process-wide settings before loading the model.

        These side effects apply to the whole process, not just this instance:
        - all CUDA devices are hidden;
        - on Linux, the ``spawn`` multiprocessing start method is forced;
        - PyTorch is limited to one thread;
        - on Linux, the OMP/MKL/NUMEXPR thread-count variables are set to 1.
        """
        is_linux = platform.system() == "Linux"
        # Force CPU-only execution to avoid CUDA issues on Linux servers
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        # Set multiprocessing start method for Linux compatibility
        if is_linux:
            try:
                mp.set_start_method("spawn", force=True)
            except RuntimeError:
                # Method already set, ignore
                pass
        # Set PyTorch thread limits for stable performance
        torch.set_num_threads(1)
        # NOTE(review): torch is already imported at module load, so these
        # variables likely only affect libraries initialised later / spawned
        # worker processes — confirm that is the intent.
        if is_linux:
            os.environ["OMP_NUM_THREADS"] = "1"
            os.environ["MKL_NUM_THREADS"] = "1"
            os.environ["NUMEXPR_NUM_THREADS"] = "1"

    @staticmethod
    def _internal_network(model):
        """Return the torch network wrapped by a detecto ``Model``."""
        # detecto stores the torchvision network in the private ``_model``
        # attribute (and exposes it through ``get_internal_model()``); it has
        # no public ``model`` attribute, which is kept only as a last resort.
        getter = getattr(model, "get_internal_model", None)
        if callable(getter):
            return getter()
        for attr in ("_model", "model"):
            network = getattr(model, attr, None)
            if network is not None:
                return network
        raise AttributeError(
            "Could not locate the torch network inside the detecto Model"
        )

    def _load_pytorch_safe(self):
        """
        SECURITY: Safe PyTorch loading with mandatory weights_only=True.

        Builds a fresh ``Model(self.classes)`` and loads the state dict into
        its underlying torch network.

        Raises:
            Exception: Wrapping any failure. This covers an unsafe or corrupt
                file, a missing network attribute, and mismatched keys/shapes.
        """
        try:
            print("SECURITY: Attempting SAFE PyTorch loading with weights_only=True")
            # SECURITY: Always use weights_only=True to prevent code injection
            state_dict = torch.load(
                self.path_to_model, map_location="cpu", weights_only=True
            )
            model = Model(self.classes)
            network = self._internal_network(model)
            # strict=True: a partial load would silently keep randomly
            # initialised layers and give meaningless predictions; on mismatch
            # we raise so retrieve_model can try the next strategy.
            network.load_state_dict(state_dict, strict=True)
            return model
        except Exception as e:
            print(f"SECURITY: Safe PyTorch loading failed: {e}")
            raise Exception(
                f"SECURITY ERROR: Could not load model safely. "
                f"Model may be corrupted or use unsafe pickle objects. "
                f"Error: {e}"
            ) from e

    def _load_detecto_safe(self):
        """
        SECURITY: Safe detecto loading with integrity verification.

        Requires ``checksums.sha256`` next to the model and a passing
        integrity check. It then calls ``Model.load`` while ``torch.load`` is
        temporarily patched to force ``weights_only=True``. The original
        ``torch.load`` is restored afterwards.

        Raises:
            Exception: Wrapping any verification or loading failure.
        """
        try:
            print("SECURITY: Attempting SAFE detecto loading with verification")
            # SECURITY: Verify model integrity before loading
            model_dir = os.path.dirname(self.path_to_model)
            checksums_file = os.path.join(model_dir, "checksums.sha256")
            if not os.path.exists(checksums_file):
                raise Exception(
                    f"SECURITY ERROR: No checksums file found at {checksums_file}. "
                    f"Model integrity cannot be verified."
                )
            from label_processing.utils import verify_model_integrity

            if not verify_model_integrity(self.path_to_model, checksums_file):
                raise Exception(
                    f"SECURITY ERROR: Model integrity verification failed for {self.path_to_model}. "
                    f"Model may be corrupted or tampered with."
                )
            # Only load if integrity is verified
            print("SECURITY: Model integrity verified, proceeding with safe loading")
            # Patch the module attribute so detecto's internal torch.load call
            # is forced to weights_only=True; this is process-global while active.
            original_torch_load = torch.load

            def safe_patched_load(*args, **kwargs):
                kwargs["weights_only"] = True  # SECURITY: Force safe loading
                return original_torch_load(*args, **kwargs)

            torch.load = safe_patched_load
            try:
                return Model.load(self.path_to_model, self.classes)
            finally:
                torch.load = original_torch_load
        except Exception as e:
            print(f"SECURITY: Safe detecto loading failed: {e}")
            raise Exception(
                f"SECURITY ERROR: Could not load model safely via detecto. "
                f"Model integrity verification failed or model uses unsafe objects. "
                f"Error: {e}"
            ) from e

    def class_prediction(self, jpg_path: Union[str, Path, None] = None) -> pd.DataFrame:
        """
        Predict labels for a single image.

        Args:
            jpg_path (Path|str|None): Image to predict on; defaults to
                ``self.jpg_path``.

        Returns:
            pd.DataFrame: One row per detection, with columns ``filename``,
            ``class``, ``score``, ``xmin``, ``ymin``, ``xmax`` and ``ymax``.
            Numeric values are Python floats. The DataFrame is empty when:
            - no path is available;
            - the file is not a visible image;
            - the image cannot be read;
            - prediction fails.
        """
        if jpg_path is None:
            jpg_path = self.jpg_path
        # Validate the requested path
        if jpg_path is None:
            return pd.DataFrame()
        jpg_path = Path(jpg_path)
        if not is_image_file(jpg_path):
            print(f"Skipping non-image or hidden file: {jpg_path}")
            return pd.DataFrame()
        try:
            image = detecto.utils.read_image(str(jpg_path))
        except Exception as e:
            print(f"Skipping unreadable image {jpg_path}: {e}")
            return pd.DataFrame()
        try:
            predictions = self.model.predict(image)
        except Exception as e:
            print(f"Prediction failed for {jpg_path}: {e}")
            return pd.DataFrame()
        labels, boxes, scores = predictions
        entries = []
        for labelname, box, score in zip(labels, boxes, scores):
            # Store plain floats rather than 0-d tensors: str(tensor) such as
            # "tensor(4000.)" is hard to parse downstream, and floats are
            # cheaper to pickle back from worker processes.
            entries.append(
                {
                    "filename": jpg_path.name,
                    "class": labelname,
                    "score": float(score),
                    "xmin": float(box[0]),
                    "ymin": float(box[1]),
                    "xmax": float(box[2]),
                    "ymax": float(box[3]),
                }
            )
        return pd.DataFrame(entries)
def prediction_parallel(
    jpg_dir: Union[str, Path], predictor: PredictLabel, n_processes: int
) -> pd.DataFrame:
    """
    Run ``predictor.class_prediction`` on every readable image in a directory.

    Only the top level of *jpg_dir* is scanned. A file is kept when
    ``is_image_file`` accepts it and ``cv2.imread`` can decode it. Predictions
    run in a pool of *n_processes* worker processes created from a local
    ``spawn`` context, so the process-wide start method is not changed.

    Args:
        jpg_dir (Path|str): Path to JPG files for prediction.
        predictor (PredictLabel): Prediction instance; its bound
            ``class_prediction`` is pickled and sent to the workers.
        n_processes (int): Number of processes for parallel execution.

    Returns:
        pd.DataFrame: Concatenated predictions with a fresh RangeIndex, or an
        empty DataFrame when there is nothing to predict or nothing detected.
    """
    jpg_dir = Path(jpg_dir)
    # Collect image files while skipping hidden and macOS '._*' files
    file_names = [p for p in sorted(jpg_dir.iterdir()) if is_image_file(p)]
    # Validate readability with cv2 (some files can exist but be corrupted)
    valid_files = []
    for file in file_names:
        if cv2.imread(str(file)) is None:
            print(f"Skipping corrupted or unreadable image: {file}")
        else:
            valid_files.append(file)
    if not valid_files:
        # Nothing to do: avoid spawning workers and pickling the model.
        return pd.DataFrame()
    ctx = mp.get_context("spawn")
    with ctx.Pool(n_processes) as executor:
        results = executor.map(predictor.class_prediction, valid_files)
    # filter empty DataFrames and concatenate if any results exist
    results = [r for r in results if isinstance(r, pd.DataFrame) and not r.empty]
    if not results:
        return pd.DataFrame()
    return pd.concat(results, ignore_index=True)
# First decimal / scientific number in strings like "tensor(123.4500)",
# "tensor(4000.)" or "1.5e+03" (integral tensors print without trailing digits).
_NUMBER_PATTERN = re.compile(r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?")


def _coerce_float(value) -> float:
    """
    Convert one prediction cell to float.

    Numbers, numeric strings and single-element tensors go through
    ``float()``. For anything else, the first number found in
    ``str(value)`` is used. Returns NaN if no number can be extracted.
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        pass
    match = _NUMBER_PATTERN.search(str(value))
    return float(match.group()) if match else float("nan")


def clean_predictions(
    jpg_dir: Path, dataframe: pd.DataFrame, threshold: float, out_dir=None
) -> pd.DataFrame:
    """
    Filter predictions based on a threshold and save the results to a CSV file.

    ``score`` and the box columns may contain floats, ints, tensors, or
    strings such as ``"tensor(123.4500)"``. Each cell is coerced to float,
    and unparseable cells become NaN. Rows with ``score < threshold`` (or a
    NaN score) are dropped, and missing box coordinates are set to 0. The
    input DataFrame is not modified. An input DataFrame that is empty and
    lacks these columns produces an empty result with the expected columns.

    The CSV ``<jpg_dir stem>_predictions.csv`` is written to *out_dir*, or to
    the parent directory of *jpg_dir* when *out_dir* is None.

    Args:
        jpg_dir (Path|str): Path to the directory with JPG files.
        dataframe (pd.DataFrame): Pandas DataFrame with predictions.
        threshold (float): Threshold value for scores.
        out_dir (Path|str|None): Output directory for saving the CSV file.

    Returns:
        pd.DataFrame: Filtered predictions (original index labels kept).
    """
    # Ensure jpg_dir is a Path object
    jpg_dir = Path(jpg_dir)
    print("\nFilter coordinates")
    colnames = ["score", "xmin", "ymin", "xmax", "ymax"]
    coords = colnames[1:]
    if dataframe.empty and not set(colnames).issubset(dataframe.columns):
        # e.g. prediction_parallel found nothing: keep the expected schema.
        filtered = pd.DataFrame(columns=["filename", "class", *colnames])
    else:
        # Work on a copy so the caller's DataFrame is left untouched.
        cleaned = dataframe.copy()
        for header in colnames:
            cleaned[header] = cleaned[header].map(_coerce_float).astype(float)
        # .copy() so the fillna assignment below is not on a view/slice.
        filtered = cleaned.loc[cleaned["score"] >= threshold].copy()
        filtered[coords] = filtered[coords].fillna(0)
    parent_dir = jpg_dir.resolve().parent if out_dir is None else Path(out_dir)
    filename = f"{jpg_dir.stem}_predictions.csv"
    filtered.to_csv(parent_dir / filename)
    print(f"\nThe csv_file {filename} has been successfully saved in {parent_dir}")
    return filtered
# ---------------------Image Cropping---------------------#
def crop_picture(img_raw: np.ndarray, path: str, filename: str, **coordinates) -> None:
    """
    Crop the picture using the given coordinates and save it as ``path/filename``.

    Coordinates are clamped to the image bounds. Without clamping, negative
    values would wrap around, because NumPy treats them as offsets from the
    end. A region that is empty after clamping is skipped with a message,
    since ``cv2.imwrite`` cannot encode an empty image. A failed write is
    reported instead of being ignored.

    Args:
        img_raw (numpy.ndarray): Input JPG converted to a numpy matrix by cv2
            (indexed rows first, i.e. ``img[y, x]``).
        path (str|Path): Directory where the picture should be saved.
        filename (str): Name of the picture.
        **coordinates: Pixel bounds ``xmin``, ``ymin``, ``xmax``, ``ymax``
            (x indexes columns, y indexes rows); converted with ``int()``.

    Raises:
        KeyError: If one of the four coordinates is missing.
    """
    height, width = img_raw.shape[:2]
    xmin = min(max(int(coordinates["xmin"]), 0), width)
    ymin = min(max(int(coordinates["ymin"]), 0), height)
    xmax = min(max(int(coordinates["xmax"]), 0), width)
    ymax = min(max(int(coordinates["ymax"]), 0), height)
    filepath = os.path.join(str(path), filename)
    if xmax <= xmin or ymax <= ymin:
        print(
            f"Skipping empty crop {filepath}: "
            f"x[{xmin}:{xmax}], y[{ymin}:{ymax}]"
        )
        return
    crop = img_raw[ymin:ymax, xmin:xmax]
    if not cv2.imwrite(filepath, crop):
        print(f"Error: could not write crop to {filepath}")
def create_crops(
    jpg_dir: Path,
    dataframe: pd.DataFrame,
    out_dir: Union[Path, str, None] = None,
) -> None:
    """
    Create crops from the prediction table and the original pictures.

    Crops are written to ``<out_dir>/<jpg_dir name>_cropped/`` as
    ``<image stem>_<n>.jpg``. Here ``n`` numbers that image's predictions
    from 1, in DataFrame row order.

    Args:
        jpg_dir (Path|str): Path to directory with jpgs.
        dataframe (pd.DataFrame): Predictions (e.g. from
            ``clean_predictions``) with ``filename``, ``xmin``, ``ymin``,
            ``xmax`` and ``ymax`` columns.
        out_dir (Path|str|None): Path to the target directory to save the
            cropped jpgs. Defaults to the current working directory at call
            time.
    """
    image_dir = Path(jpg_dir)
    # Resolve the default at call time; a Path(os.getcwd()) default argument
    # would be frozen when the module is imported.
    out_dir = Path.cwd() if out_dir is None else Path(out_dir)
    path = out_dir.joinpath(image_dir.name + "_cropped")
    path.mkdir(parents=True, exist_ok=True)

    # Group predictions once instead of filtering the whole frame per image.
    if "filename" in dataframe.columns:
        by_file = {
            name: rows for name, rows in dataframe.groupby("filename", sort=False)
        }
    else:
        by_file = {}

    total_crops = 0
    # Same selection rule as prediction_parallel (case-insensitive
    # extensions, hidden files skipped); each file is visited exactly once.
    for p in sorted(image_dir.iterdir()):
        if not is_image_file(p):
            # Report hidden files (e.g. macOS "._*") that look like images;
            # silently ignore everything else.
            if p.suffix.lower() in IMAGE_EXTS:
                print(f"Skipping hidden or non-image file: {p}")
            continue
        filename = p.name
        match = by_file.get(filename)
        if match is None:
            print(f"No predictions for image: {filename}. Skipping...")
            continue
        image_raw = cv2.imread(str(p))
        if image_raw is None:
            print(f"Error: Impossible to read the image {p}. Corrupted file?")
            continue
        label_id = p.stem
        crops_for_this_image = 0
        for occ, (_, row) in enumerate(match.iterrows(), start=1):
            new_filename = f"{label_id}_{occ}.jpg"
            coordinates = {
                "xmin": int(row.xmin),
                "ymin": int(row.ymin),
                "xmax": int(row.xmax),
                "ymax": int(row.ymax),
            }
            crop_picture(image_raw, path, new_filename, **coordinates)
            # Check only the names written in this run. Globbing
            # "<stem>_*.jpg" also matched crops of other images sharing the
            # prefix (e.g. "img" vs "img_2") and stale higher-numbered crops.
            if path.joinpath(new_filename).is_file():
                crops_for_this_image += 1
        total_crops += crops_for_this_image
        print(f"{filename} generated {crops_for_this_image} crops")
    print(f"\nTotal crops generated: {total_crops}")
    print(f"\nThe images have been successfully saved in {path}")