Source code for scripts.processing.classifiers

#!/usr/bin/env python3

# Import third-party libraries
import argparse
import os
import sys
import warnings
import time
import logging
from pathlib import Path

# Add project root to Python path
current_dir = Path(__file__).parent.absolute()
project_root = current_dir.parent.parent
sys.path.insert(0, str(project_root))

# Import project configuration
from label_processing.config import get_model_path, config

# Import the necessary module from the 'label_processing' module package
import label_processing.tensorflow_classifier

# Suppress warnings
warnings.filterwarnings('ignore')

# Configure logging
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')


[docs] def parse_arguments() -> argparse.Namespace: """ Parse command-line arguments for the classification script. Returns: argparse.Namespace: Parsed command-line arguments. """ usage = 'python classifiers.py -m <select model> -j <path to jpgs> -o <path to output directory>' parser = argparse.ArgumentParser( description="Classify JPEG images using a pre-trained TensorFlow model.", usage=usage, add_help=False, formatter_class=argparse.RawTextHelpFormatter ) parser.add_argument('-h', '--help', action='help', help='Show this help message and exit.') parser.add_argument( '-m', '--model', type=int, choices=range(1, 4), help=( 'Select a built-in classifier model:\n' ' 1: identifier / not_identifier\n' ' 2: handwritten / printed\n' ' 3: multi-label / single-label' ) ) parser.add_argument( '-j', '--jpg_dir', type=str, required=True, help='Directory containing input JPEG images.' ) parser.add_argument( '-o', '--out_dir', type=str, default=os.getcwd(), help='Directory to store outputs (default: current working directory).' ) return parser.parse_args()
[docs] def resolve_default_model_path(model_int: int) -> Path: """ Get the default model path based on model number using centralized configuration. Args: model_int (int): Model number (1-3) Returns: Path: Path to the default model """ model_type_mapping = { 1: "identifier", 2: "handwritten_printed", 3: "multi_single" } model_type = model_type_mapping.get(model_int) if not model_type: raise ValueError(f"No model type found for model number {model_int}") try: return get_model_path(model_type) except Exception as e: raise ValueError(f"Error getting model path for model {model_int}: {e}")
[docs] def get_class_names_by_model(model_int: int) -> list[str]: """ Return default class names for the selected model number using centralized configuration. Args: model_int (int): Model number (1-3) Returns: list[str]: Class labels """ model_type_mapping = { 1: "identifier", 2: "handwritten_printed", 3: "multi_single" } model_type = model_type_mapping.get(model_int) if not model_type: return None return config.get_class_names(model_type)
[docs] def load_class_names_from_file(path: str) -> list[str]: """ Load class names from a text file (one per line). Args: path (str): Path to the class names file. Returns: list[str]: List of class names. """ if not os.path.isfile(path): raise FileNotFoundError(f"Classes file not found at: {path}") with open(path, "r") as f: return f.read().splitlines()
[docs] def main() -> None: """ Main function to execute classification using a TensorFlow model. """ start_time = time.time() args = parse_arguments() if not os.path.isdir(args.jpg_dir): raise NotADirectoryError(f"JPEG input directory does not exist: {args.jpg_dir}") # Ensure output directory exists if not os.path.exists(args.out_dir): logging.info(f"Creating output directory: {args.out_dir}") os.makedirs(args.out_dir, exist_ok=True) if args.model: model_path = resolve_default_model_path(args.model) class_names = get_class_names_by_model(args.model) if not class_names: raise ValueError(f"No class names found for model {args.model}") else: raise ValueError("You must provide a model number (-m 1, 2, or 3).") logging.info("Loading model...") model = label_processing.tensorflow_classifier.get_model(model_path) logging.info("Classifying images...") df = label_processing.tensorflow_classifier.class_prediction( model=model, class_names=class_names, jpg_dir=args.jpg_dir, out_dir=args.out_dir ) logging.info("Saving classified images...") label_processing.tensorflow_classifier.filter_pictures( jpg_dir=args.jpg_dir, dataframe=df, out_dir=args.out_dir ) duration = time.time() - start_time logging.info(f"Classification completed in {duration:.2f} seconds.")
if __name__ == '__main__': try: main() except Exception as e: logging.error(f"Error: {e}") sys.exit(1)