# Source code for scripts.evaluation.classifiers_eval

# Import the accuracy classifier module from the 'label_evaluation' package
import label_evaluation.accuracy_classifier

# Import standard-library modules and third-party libraries (pandas)
import argparse
import os
import warnings
import pandas as pd
import time
from pathlib import Path

# Suppress warning messages during execution
warnings.filterwarnings('ignore')


def parse_arguments(argv=None) -> argparse.Namespace:
    """
    Parse command-line arguments and return the parsed arguments.

    Args:
        argv: Optional list of argument strings to parse. When None
            (the default), argparse reads the arguments from
            ``sys.argv[1:]``, which is the behavior callers that pass
            nothing get.

    Returns:
        argparse.Namespace: Parsed command-line arguments with the
        attributes ``out_dir``, ``df`` and ``title``.
    """
    # Only -d is required; -o and -t have defaults, so the usage line
    # shows them in brackets as optional.
    usage = ('classifiers_eval.py [-h] [-o <path to outputs>] '
             '-d <path to gt_dataframe> [-t <plot title>]')

    parser = argparse.ArgumentParser(
        description="Execute the accuracy_classifier.py module.",
        add_help=False,
        usage=usage)

    # The default help action is disabled above so that -h can carry a
    # custom help text.
    parser.add_argument('-h', '--help',
                        action='help',
                        help='Description of the command-line arguments.')

    parser.add_argument('-o', '--out_dir',
                        metavar='',
                        type=str,
                        default=os.getcwd(),
                        help=('Directory in which the accuracy scores and plots will be stored. '
                              'Default is the current working directory.'))

    parser.add_argument('-d', '--df',
                        metavar='',
                        type=str,
                        required=True,
                        help=('Path to the input ground truth CSV file.'))

    parser.add_argument('-t', '--title',
                        metavar='',
                        type=str,
                        default='Classifier',
                        help=('Title for the confusion matrix plot.'))

    return parser.parse_args(argv)
def main():
    """
    Main function to evaluate classifier accuracy and generate reports.

    Reads the semicolon-separated CSV file given on the command line,
    which must contain the columns 'pred' (predicted label) and 'gt'
    (ground-truth label), makes the labels human-readable and passes
    them to ``label_evaluation.accuracy_classifier.metrics`` and
    ``label_evaluation.accuracy_classifier.cm`` together with the output
    directory.

    Any failure (missing file, unparsable CSV, missing columns, error
    raised by the evaluation functions) is reported with ``print`` and
    the function returns early without raising.
    """
    start_time = time.time()
    args = parse_arguments()
    out_dir = Path(args.out_dir)

    # Keep the try body limited to the read itself so each handler maps
    # to a failure of pd.read_csv.
    try:
        df = pd.read_csv(args.df, sep=';')
    except FileNotFoundError:
        print(f"Error: File '{args.df}' not found.")
        return
    except pd.errors.ParserError:
        print(f"Error: Failed to parse CSV file '{args.df}'.")
        return
    except ValueError as e:
        print(f"Error: {e}")
        return

    if "pred" not in df.columns or "gt" not in df.columns:
        print("Error: CSV file must contain 'pred' and 'gt' columns.")
        return

    pred = df["pred"]
    gt = df["gt"]

    # Clean up labels for better readability
    def clean_label(x):
        x = str(x).replace('MfN_', '')
        # 'not_nuri' has to be replaced before 'nuri'; otherwise the
        # 'nuri' rule would rewrite the tail of 'not_nuri' first.
        x = x.replace('not_nuri', 'Not Specimen Identifier')
        x = x.replace('nuri', 'Specimen Identifier')
        return x

    pred_clean = pred.apply(clean_label)
    gt_clean = gt.apply(clean_label)

    # Take the class names from the *cleaned* ground truth: two raw labels
    # that clean to the same text (e.g. 'MfN_nuri' and 'nuri') must yield
    # a single class, not a duplicated entry. Order of first appearance is
    # kept, so the result is unchanged when no such collision exists.
    target = gt_clean.unique().tolist()

    # Broad catch is deliberate here: this is the script's top-level
    # boundary and any evaluation failure is reported to the user.
    try:
        label_evaluation.accuracy_classifier.metrics(target, pred_clean, gt_clean, out_dir=out_dir)
        label_evaluation.accuracy_classifier.cm(target, pred_clean, gt_clean, out_dir=out_dir, title=args.title)
    except Exception as e:
        print(f"Error during classification evaluation: {e}")
        return

    print(f"Finished in {round(time.time() - start_time, 2)} seconds")
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()