Source code for label_evaluation.evaluate_text

# Import standard-library and third-party libraries
import jiwer
import json
import csv
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import warnings
import difflib

# Suppress warning messages during execution. The filter is installed with no
# category or module restriction, so it applies to every warning category and
# takes effect process-wide as soon as this module is imported.
warnings.filterwarnings("ignore")


class EmptyReferenceError(Exception):
    """
    Raised when a reference (ground truth) string is empty.

    Attributes:
        message (str): Human-readable description of the error. Falls back to
            a default text when no message (or a falsy one) is supplied.
    """

    def __init__(self, message=None):
        # Any falsy message (None, "") is replaced by the default text.
        if not message:
            message = "The reference string is empty."
        self.message = message
        super().__init__(message)
def calculate_cer(reference: list, hypothesis: list) -> float:
    """
    Calculate the Character Error Rate (CER) between reference and hypothesis.

    CER is the character-level Levenshtein distance (minimum number of
    substitutions, deletions and insertions) divided by the number of
    characters in the reference. Only the first element of each list is
    scored.

    Args:
        reference (list): List of reference (ground truth) strings.
        hypothesis (list): List of hypothesis (predicted) strings. An empty
            list is treated as an empty hypothesis string.

    Returns:
        float: The computed CER value, or 0.0 when the reference list is
        empty or its first string is empty.
    """
    if not reference or len(reference[0]) == 0:
        return 0.0

    ref_chars = reference[0]
    hyp_chars = hypothesis[0] if hypothesis else ""

    # Two-row dynamic-programming Levenshtein distance. `previous[j]` holds
    # the distance between the reference prefix processed so far and the
    # first `j` hypothesis characters.
    previous = list(range(len(hyp_chars) + 1))
    for row, ref_char in enumerate(ref_chars, start=1):
        current = [row]
        for col, hyp_char in enumerate(hyp_chars, start=1):
            substitution_cost = 0 if ref_char == hyp_char else 1
            current.append(
                min(
                    previous[col] + 1,  # deletion
                    current[col - 1] + 1,  # insertion
                    previous[col - 1] + substitution_cost,  # match/substitution
                )
            )
        previous = current

    edit_distance = previous[-1]
    return edit_distance / len(ref_chars)
def get_gold_transcriptions(filename: str, sep: str = ",") -> dict:
    """
    Load ground truth transcriptions from a CSV file into a dictionary.

    The first row is treated as a header and skipped. Every following row
    must have exactly two fields (identifier, transcription); rows with any
    other field count are reported on stdout and skipped. Both fields are
    stripped of surrounding whitespace. When an identifier occurs more than
    once, the last occurrence wins.

    Args:
        filename (str): Path to the CSV file (read as UTF-8, optional BOM).
        sep (str, optional): Delimiter used in the CSV file. Defaults to ','.

    Returns:
        dict: Dictionary with keys as unique identifiers and values as
        transcription text. An empty dict is returned when the file is
        empty or cannot be read/parsed (the error is printed).
    """
    gold_transcriptions = {}
    try:
        # newline="" lets the csv module handle line endings itself, so
        # newlines embedded in quoted fields are preserved unmodified.
        with open(filename, encoding="utf-8-sig", newline="") as file_in:
            csv_reader = csv.reader(file_in, delimiter=sep)
            # Skip the header; the default avoids StopIteration on an
            # empty file.
            next(csv_reader, None)
            for line in csv_reader:
                if len(line) != 2:
                    # line_num is the physical line just consumed, which
                    # stays accurate even after multi-line quoted fields.
                    print(
                        f"Skipping malformed line {csv_reader.line_num}: {line}"
                    )
                    continue
                identifier, transcription = (field.strip() for field in line)
                gold_transcriptions[identifier] = transcription
    except Exception as e:
        # Best-effort loader: report the problem and return nothing rather
        # than a partially loaded mapping.
        print(f"Error loading ground truth CSV: {e}")
        return {}
    return gold_transcriptions
def load_json_predictions(filename: str) -> list:
    """
    Read OCR predictions from a JSON file.

    Args:
        filename (str): Path to the JSON file (read as UTF-8, optional BOM).

    Returns:
        list: The decoded JSON content, or an empty list when the file
        cannot be opened or parsed (the error is printed).
    """
    try:
        with open(filename, "r", encoding="utf-8-sig") as handle:
            raw_text = handle.read()
        predictions = json.loads(raw_text)
    except Exception as error:
        print(f"Error loading JSON predictions: {error}")
        return []
    return predictions
def calculate_scores(gold_text: str, predicted_text: str) -> tuple:
    """
    Compute Word Error Rate (WER) and Character Error Rate (CER) for one
    ground truth / prediction pair.

    Both texts are lower-cased before scoring, so the comparison is
    case-insensitive.

    Args:
        gold_text (str): Ground truth transcription.
        predicted_text (str): Predicted transcription.

    Returns:
        tuple: (WER, CER) both rounded to two decimal places.

    Raises:
        EmptyReferenceError: If the ground truth is empty or whitespace only.
    """
    reference = gold_text.lower()
    hypothesis = predicted_text.lower()

    # An empty or all-whitespace reference cannot be scored.
    if not reference.strip():
        raise EmptyReferenceError()

    # Word-level alignment via jiwer's process_words API.
    word_output = jiwer.process_words(reference, hypothesis)
    word_error_rate = round(word_output.wer, 2)

    char_error_rate = round(calculate_cer([reference], [hypothesis]), 2)
    return word_error_rate, char_error_rate
def create_plot(data: list, score_name: str, file_name: str) -> None:
    """
    Create and save a violin plot for the given error scores.

    The plot shows a violin with an inner box plot, a red dashed horizontal
    line at the mean and a green solid horizontal line at the median. The
    figure is written to ``file_name`` at 300 dpi and is always closed
    afterwards, even when plotting or saving fails.

    Args:
        data (list): List of numerical scores to visualize.
        score_name (str): Name of the score (e.g., "CER" or "WER").
        file_name (str): Path to save the plot image.
    """
    # Build the frame and statistics before opening a figure so that bad
    # input cannot leave a half-initialised figure behind.
    df = pd.DataFrame(data, columns=[score_name])
    mean_score = df[score_name].mean()
    median_score = df[score_name].median()

    figure = plt.figure(figsize=(10, 6))
    try:
        sns.violinplot(data=df, inner="box", cut=1, palette="Set2")
        plt.axhline(
            mean_score,
            color="r",
            linestyle="--",
            label=f"Mean: {mean_score:.2f}",
        )
        plt.axhline(
            median_score,
            color="g",
            linestyle="-",
            label=f"Median: {median_score:.2f}",
        )
        plt.title(f"Distribution of {score_name}", fontsize=24)
        plt.xlabel(score_name, fontsize=20)
        # NOTE(review): the y-axis appears to carry the score values rather
        # than a density; label kept unchanged - confirm the intended wording.
        plt.ylabel("Density", fontsize=20)
        plt.tick_params(axis='both', which='major', labelsize=18)
        plt.legend(fontsize=16)
        plt.savefig(file_name, dpi=300)
    finally:
        # pyplot keeps every open figure alive in a global registry; close
        # this one explicitly so a failure does not leak it.
        plt.close(figure)
    print(f"Plot saved as {file_name}")
def evaluate_text_predictions(
    ground_truth_file: str, predictions_file: str, out_dir: str
) -> tuple:
    """
    Evaluate OCR predictions against a ground truth dataset.

    Predictions are matched to ground truth by identifier, ignoring case and
    surrounding whitespace. For every matched pair the WER and CER are
    computed and written to ``<out_dir>/ocr_evaluation.csv`` (UTF-8).
    Violin plots of both score lists are saved to ``<out_dir>/cers.png`` and
    ``<out_dir>/wers.png``.

    Prediction entries without a ground truth match are ignored. Entries
    that lack the "ID" or "text" keys (or hold non-string values) and pairs
    with an empty reference are reported on stdout and skipped.

    Args:
        ground_truth_file (str): Path to the ground truth CSV file.
        predictions_file (str): Path to the predictions JSON file.
        out_dir (str): Output directory for results. Must already exist.

    Returns:
        tuple: (List of WER scores, List of CER scores). Two empty lists are
        returned when loading the inputs or writing the CSV fails. A failure
        while plotting is reported but does not discard the scores.
    """
    wers, cers = [], []
    try:
        # Normalise ground truth identifiers the same way prediction
        # identifiers are normalised below; otherwise any identifier that
        # contains an uppercase letter could never match.
        ground_truth = {
            key.strip().lower(): text
            for key, text in get_gold_transcriptions(ground_truth_file).items()
        }
        generated_transcriptions = load_json_predictions(predictions_file)

        output_csv = f"{out_dir}/ocr_evaluation.csv"
        # Explicit UTF-8 so non-ASCII transcriptions do not depend on the
        # platform's default encoding.
        with open(output_csv, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["ID", "Gold", "Predicted", "WER", "CER"])

            for entry in generated_transcriptions:
                try:
                    transcript_id = entry["ID"].strip().lower()
                    if transcript_id not in ground_truth:
                        continue
                    predicted = entry["text"].strip()
                except (KeyError, TypeError, AttributeError) as e:
                    # One bad entry should not abort the whole evaluation.
                    print(f"Skipping malformed prediction entry {entry!r}: {e}")
                    continue

                gold = ground_truth[transcript_id]
                try:
                    wer, cer = calculate_scores(gold, predicted)
                except EmptyReferenceError as e:
                    print(
                        f"Skipping ID '{entry['ID']}' due to empty reference: {e}"
                    )
                    continue

                wers.append(wer)
                cers.append(cer)
                writer.writerow([entry["ID"], gold, predicted, wer, cer])
    except Exception as e:
        print(f"Error during evaluation: {e}")
        return [], []

    if not wers:
        # Nothing was scored, so there is no distribution to draw.
        print("No predictions were scored; skipping plots.")
        return wers, cers

    # Plotting is auxiliary: a failure here must not throw away the scores
    # that were already computed and written to the CSV.
    try:
        create_plot(cers, "CERs", f"{out_dir}/cers.png")
        create_plot(wers, "WERs", f"{out_dir}/wers.png")
    except Exception as e:
        print(f"Error creating plots: {e}")

    return wers, cers