# Import standard-library and third-party dependencies
import jiwer
import json
import csv
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import warnings
import difflib
# Suppress warning messages during execution
warnings.filterwarnings("ignore")
[docs]
class EmptyReferenceError(Exception):
    """
    Signal that a reference (ground truth) string is empty.

    Args:
        message (str, optional): Custom error text. Any falsy value
            (``None`` or ``""``) falls back to the default message.
    """

    def __init__(self, message=None):
        if not message:
            message = "The reference string is empty."
        # Keep the text on the instance as well as in Exception.args.
        self.message = message
        super().__init__(message)
[docs]
def _levenshtein(source, target) -> int:
    """Return the Levenshtein distance between two sequences.

    Insertions, deletions and substitutions each cost one edit.
    """
    # The distance is symmetric; keeping the shorter sequence as the row
    # minimises the memory held by the two rolling rows.
    if len(source) < len(target):
        source, target = target, source
    previous = list(range(len(target) + 1))
    for row, source_item in enumerate(source, start=1):
        current = [row]
        for col, target_item in enumerate(target, start=1):
            current.append(
                min(
                    previous[col] + 1,  # deletion
                    current[col - 1] + 1,  # insertion
                    previous[col - 1] + (source_item != target_item),  # substitution
                )
            )
        previous = current
    return previous[-1]


def calculate_cer(reference: list, hypothesis: list) -> float:
    """
    Calculate the Character Error Rate (CER) between reference and hypothesis.

    The CER is the total character-level Levenshtein distance over all
    reference/hypothesis pairs divided by the total number of reference
    characters. A substitution counts as a single edit.

    Args:
        reference (list): List of reference (ground truth) strings.
        hypothesis (list): List of hypothesis (predicted) strings.
    Returns:
        float: The computed CER value, or 0.0 when the references contain
            no characters at all.
    Raises:
        ValueError: If the references contain characters but the two lists
            differ in length.
    """
    total_reference_chars = sum(len(text) for text in reference)
    if total_reference_chars == 0:
        # No reference characters: the rate is undefined, report 0.0.
        return 0.0
    if len(reference) != len(hypothesis):
        raise ValueError(
            "reference and hypothesis must contain the same number of strings "
            f"({len(reference)} != {len(hypothesis)})"
        )
    total_edits = sum(
        _levenshtein(ref_text, hyp_text)
        for ref_text, hyp_text in zip(reference, hypothesis)
    )
    return total_edits / total_reference_chars
[docs]
def get_gold_transcriptions(filename: str, sep: str = ",") -> dict:
    """
    Read a two-column ground truth CSV into an ID -> transcription mapping.

    The first row is treated as a header and discarded. Rows that do not
    have exactly two fields are reported and skipped. Both fields are
    stripped of surrounding whitespace.

    Args:
        filename (str): Path to the CSV file.
        sep (str, optional): Delimiter used in the CSV file. Defaults to ','.
    Returns:
        dict: Mapping of identifier to transcription text. An empty dict is
            returned (after printing the error) if anything goes wrong
            while reading.
    """
    try:
        transcriptions = {}
        with open(filename, encoding="utf-8-sig") as handle:
            rows = csv.reader(handle, delimiter=sep)
            # Drop the header; on an empty file this raises StopIteration,
            # which is handled by the except clause below.
            next(rows)
            # Data rows start on line 2 of the file.
            for row_number, row in enumerate(rows, start=2):
                if len(row) == 2:
                    identifier, text = (cell.strip() for cell in row)
                    transcriptions[identifier] = text
                else:
                    print(f"Skipping malformed line {row_number}: {row}")
        return transcriptions
    except Exception as e:
        print(f"Error loading ground truth CSV: {e}")
        return {}
[docs]
def load_json_predictions(filename: str) -> list:
    """
    Read and parse a JSON predictions file.

    Args:
        filename (str): Path to the JSON file (UTF-8, optional BOM).
    Returns:
        list: The parsed JSON content. An empty list is returned (after
            printing the error) if the file cannot be read or parsed.
    """
    try:
        with open(filename, "r", encoding="utf-8-sig") as source:
            raw_text = source.read()
        return json.loads(raw_text)
    except Exception as e:
        print(f"Error loading JSON predictions: {e}")
        return []
[docs]
def calculate_scores(gold_text: str, predicted_text: str) -> tuple:
    """
    Compute the Word Error Rate (WER) and Character Error Rate (CER) of a prediction.

    Both texts are lower-cased before scoring, so the comparison is
    case-insensitive.

    Args:
        gold_text (str): Ground truth transcription.
        predicted_text (str): Predicted transcription.
    Returns:
        tuple: (WER, CER) both rounded to two decimal places.
    Raises:
        EmptyReferenceError: If the ground truth is empty or whitespace only.
    """
    reference = gold_text.lower()
    hypothesis = predicted_text.lower()
    if not reference.strip():
        raise EmptyReferenceError()
    # jiwer.process_words returns a result object; only its `wer` field is used.
    word_result = jiwer.process_words(reference, hypothesis)
    char_error = calculate_cer([reference], [hypothesis])
    return round(word_result.wer, 2), round(char_error, 2)
[docs]
def create_plot(data: list, score_name: str, file_name: str) -> None:
    """
    Create and save a violin plot for the given error scores.

    The plot shows the distribution as a violin with an inner box plot,
    plus horizontal lines marking the mean (red, dashed) and the median
    (green, solid). The figure is always closed, even if plotting or
    saving fails, so no figure is left open in pyplot's global state.

    Args:
        data (list): List of numerical scores to visualize.
        score_name (str): Name of the score (e.g., "CER" or "WER").
        file_name (str): Path to save the plot image.
    """
    # Build the frame and statistics before any figure exists, so a failure
    # here cannot leave a figure behind.
    df = pd.DataFrame(data, columns=[score_name])
    mean_value = df[score_name].mean()
    median_value = df[score_name].median()

    fig = plt.figure(figsize=(10, 6))
    try:
        sns.violinplot(data=df, inner="box", cut=1, palette="Set2")
        plt.axhline(
            mean_value,
            color="r",
            linestyle="--",
            label=f"Mean: {mean_value:.2f}",
        )
        plt.axhline(
            median_value,
            color="g",
            linestyle="-",
            label=f"Median: {median_value:.2f}",
        )
        plt.title(f"Distribution of {score_name}", fontsize=24)
        plt.xlabel(score_name, fontsize=20)
        plt.ylabel("Density", fontsize=20)
        plt.tick_params(axis='both', which='major', labelsize=18)
        plt.legend(fontsize=16)
        plt.savefig(file_name, dpi=300)
    finally:
        # Release the figure whether or not plotting/saving succeeded.
        plt.close(fig)
    print(f"Plot saved as {file_name}")
[docs]
def evaluate_text_predictions(
    ground_truth_file: str, predictions_file: str, out_dir: str
) -> tuple:
    """
    Evaluate OCR predictions against a ground truth dataset.

    Each prediction whose ID matches a ground truth entry is scored, with
    IDs compared case-insensitively after stripping whitespace. The scores
    are written to ``<out_dir>/ocr_evaluation.csv`` as UTF-8. Violin plots
    of the CER and WER distributions are saved to ``<out_dir>/cers.png``
    and ``<out_dir>/wers.png``.

    Args:
        ground_truth_file (str): Path to the ground truth CSV file.
        predictions_file (str): Path to the predictions JSON file.
        out_dir (str): Output directory for results.
    Returns:
        tuple: (List of WER scores, List of CER scores). Two empty lists
            are returned (after printing the error) if evaluation fails.
    """
    try:
        ground_truth = get_gold_transcriptions(ground_truth_file)
        # Prediction IDs are lower-cased before lookup, so the ground truth
        # keys must be lower-cased too or mixed-case IDs never match.
        gold_by_id = {key.lower(): text for key, text in ground_truth.items()}
        generated_transcriptions = load_json_predictions(predictions_file)
        wers, cers = [], []
        output_csv = f"{out_dir}/ocr_evaluation.csv"
        # Explicit UTF-8 keeps the output independent of the platform's
        # default encoding and consistent with the UTF-8 inputs.
        with open(output_csv, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["ID", "Gold", "Predicted", "WER", "CER"])
            for entry in generated_transcriptions:
                transcript_id = entry["ID"].strip().lower()
                if transcript_id not in gold_by_id:
                    # Predictions without a ground truth entry are ignored.
                    continue
                gold = gold_by_id[transcript_id]
                predicted = entry["text"].strip()
                try:
                    wer, cer = calculate_scores(gold, predicted)
                except EmptyReferenceError as e:
                    print(
                        f"Skipping ID '{entry['ID']}' due to empty reference: {e}"
                    )
                    continue
                wers.append(wer)
                cers.append(cer)
                writer.writerow([entry["ID"], gold, predicted, wer, cer])
        create_plot(cers, "CERs", f"{out_dir}/cers.png")
        create_plot(wers, "WERs", f"{out_dir}/wers.png")
        return wers, cers
    except Exception as e:
        print(f"Error during evaluation: {e}")
        return [], []