Source code for scripts.postprocessing.consolidate_results

#!/usr/bin/env python3
"""
Consolidate Pipeline Results Script

Creates a single JSON file in the output directory that links all per-label
results across the pipeline
(detection → classification → rotation → OCR → post-processing).

Supports both the traditional (TensorFlow-based) pipeline and the Gemini pipeline.
Output is a flat list of per-label entries, each containing: ``source_image``,
``label_filename``, ``label_index``, ``category``, ``bbox`` (bounding-box
coordinates), ``rotation_angle``, ``detection_confidence``, and ``ocr``
(method, text, raw_text, confidence). ``ocr.text`` holds the post-processed
transcript when one is available, otherwise the raw OCR text.
"""

import argparse
import csv
import glob
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional

# Make the repository root (two levels above this script's directory)
# importable, so project packages resolve when the file is run as a script.
current_dir = Path(__file__).absolute().parent
project_root = current_dir.parent.parent
sys.path.insert(0, f"{project_root}")


def parse_arguments() -> argparse.Namespace:
    """Build the command-line parser and return the parsed options.

    Returns:
        Namespace with ``output_dir`` (required, ``-o``/``--output-dir``) and
        ``filename`` (``-f``/``--filename``, defaulting to
        ``consolidated_results.json``).
    """
    cli = argparse.ArgumentParser(
        description="Consolidate all pipeline results into a single JSON file."
    )
    cli.add_argument(
        '-o', '--output-dir',
        required=True,
        type=str,
        help='Output directory containing pipeline results.',
    )
    cli.add_argument(
        '-f', '--filename',
        default='consolidated_results.json',
        type=str,
        help='Output filename (default: consolidated_results.json).',
    )
    parsed = cli.parse_args()
    return parsed
# ---- Gemini pipeline loaders ------------------------------------------------ def _load_gemini_classification(output_dir: str) -> List[Dict[str, Any]]: """Load per-label classification from gemini_classification.json. Returns a list of label dicts with keys: source_image, label_filename, label_index, category, rotation_angle, confidence, xmin/ymin/xmax/ymax. """ json_path = os.path.join(output_dir, 'gemini_classification.json') if not os.path.exists(json_path): return [] try: with open(json_path, 'r', encoding='utf-8') as f: data = json.load(f) if isinstance(data, list): return data except Exception as e: print(f"Warning: could not load {json_path}: {e}") return [] def _load_gemini_ocr(output_dir: str) -> Dict[str, Dict[str, Any]]: """Load Gemini OCR results. Returns {label_filename: {text, confidence}}.""" path = os.path.join(output_dir, 'ocr_gemini.json') if not os.path.exists(path): return {} try: with open(path, 'r', encoding='utf-8') as f: data = json.load(f) results = {} if isinstance(data, list): for item in data: fid = item.get('ID', '') if fid: results[fid] = { 'method': 'gemini', 'text': item.get('text', ''), 'confidence': item.get('confidence', None) } return results except Exception as e: print(f"Warning: could not load {path}: {e}") return {} def _consolidate_gemini(output_dir: str) -> List[Dict[str, Any]]: """Build consolidated entries from Gemini pipeline outputs.""" labels = _load_gemini_classification(output_dir) ocr_map = _load_gemini_ocr(output_dir) # Also try corrected transcripts (post-processing output) corrected = _load_corrected_transcripts(output_dir) consolidated = [] for lbl in labels: label_fn = lbl.get('label_filename', '') ocr_entry = ocr_map.get(label_fn, {}) corrected_entry = corrected.get(label_fn, {}) entry = { 'source_image': lbl.get('source_image', ''), 'label_filename': label_fn, 'label_index': lbl.get('label_index', ''), 'category': lbl.get('category', 'unknown'), 'bbox': { 'xmin': lbl.get('xmin', 0), 'ymin': lbl.get('ymin', 
0), 'xmax': lbl.get('xmax', 0), 'ymax': lbl.get('ymax', 0), }, 'rotation_angle': lbl.get('rotation_angle', 0), 'detection_confidence': lbl.get('confidence', None), 'ocr': { 'method': ocr_entry.get('method', 'gemini'), 'text': corrected_entry.get('text', ocr_entry.get('text', '')), 'raw_text': ocr_entry.get('text', ''), 'confidence': ocr_entry.get('confidence', None), }, } consolidated.append(entry) return consolidated # ---- Traditional pipeline loaders ------------------------------------------- def _load_detection_results(output_dir: str) -> Dict[str, Dict[str, Any]]: """Load detection results from input_predictions.csv.""" results: Dict[str, Dict[str, Any]] = {} path = os.path.join(output_dir, 'input_predictions.csv') if not os.path.exists(path): return results try: with open(path, 'r', encoding='utf-8') as f: reader = csv.DictReader(f) for row in reader: filename = row.get('filename', '') if not filename: continue bbox = [] confidence = 0.0 try: if all(k in row for k in ('xmin', 'ymin', 'xmax', 'ymax')): bbox = [float(row['xmin']), float(row['ymin']), float(row['xmax']), float(row['ymax'])] if 'score' in row: confidence = float(row['score']) except (ValueError, KeyError): pass results.setdefault(filename, []).append({ 'coordinates': bbox, 'confidence': confidence, }) except Exception as e: print(f"Warning: could not load {path}: {e}") return results def _determine_classification(filename: str, output_dir: str) -> str: """Return the category string for a label file (directory-based check).""" for category in ('empty', 'identifier', 'handwritten', 'printed'): dir_path = os.path.join(output_dir, category) if os.path.exists(dir_path): if os.path.exists(os.path.join(dir_path, filename)) or \ glob.glob(os.path.join(dir_path, filename) + '*'): return category return 'unknown' def _load_rotation_results(output_dir: str) -> Dict[str, Dict[str, Any]]: """Load rotation correction metadata.""" results: Dict[str, Dict[str, Any]] = {} meta_files = [ os.path.join(output_dir, 
'rotation_metadata.csv'), os.path.join(output_dir, 'printed_preprocessed', 'rotation_metadata.csv'), os.path.join(output_dir, 'printed_rotated', 'rotation_metadata.csv'), ] for meta_file in meta_files: if os.path.exists(meta_file): try: with open(meta_file, 'r', encoding='utf-8') as f: for row in csv.DictReader(f): fn = row.get('filename', '') if fn: results[fn] = { 'angle': float(row.get('angle', 0)), 'corrected': str(row.get('corrected', 'False')).lower() == 'true', } return results except Exception as e: print(f"Warning: could not load {meta_file}: {e}") return results def _load_traditional_ocr(output_dir: str) -> Dict[str, Dict[str, Any]]: """Load OCR results from tesseract / google-vision JSON files.""" ocr_results: Dict[str, Dict[str, Any]] = {} for ocr_file, method in [('ocr_preprocessed.json', 'tesseract'), ('ocr_google_vision.json', 'google_vision'), ('ocr_results.json', 'unknown')]: path = os.path.join(output_dir, ocr_file) if not os.path.exists(path): continue try: with open(path, 'r', encoding='utf-8') as f: data = json.load(f) if isinstance(data, list): for item in data: fid = item.get('ID', '') if fid: ocr_results[fid] = { 'method': method, 'text': item.get('text', ''), 'confidence': item.get('confidence', None), } break # use first available file except Exception as e: print(f"Warning: could not load {path}: {e}") return ocr_results def _load_corrected_transcripts(output_dir: str) -> Dict[str, Dict[str, Any]]: """Load corrected / plausible transcripts from post-processing.""" results: Dict[str, Dict[str, Any]] = {} for name in ('corrected_transcripts.json', 'plausible_transcripts.json'): path = os.path.join(output_dir, name) if os.path.exists(path): try: with open(path, 'r', encoding='utf-8') as f: data = json.load(f) if isinstance(data, list): for item in data: fid = item.get('ID', '') if fid and fid not in results: results[fid] = {'text': item.get('text', '')} except Exception: pass return results def _consolidate_traditional(output_dir: str) -> 
List[Dict[str, Any]]: """Build consolidated entries from traditional pipeline outputs.""" detection = _load_detection_results(output_dir) rotation = _load_rotation_results(output_dir) ocr_map = _load_traditional_ocr(output_dir) corrected = _load_corrected_transcripts(output_dir) consolidated = [] # Iterate over labels that have OCR results for label_fn in sorted(ocr_map.keys()): ocr_entry = ocr_map[label_fn] corrected_entry = corrected.get(label_fn, {}) rot = rotation.get(label_fn, {}) category = _determine_classification(label_fn, output_dir) # Try to find the source image from detection results source_image = '' det_bbox = {} det_conf = None for src_img, det_list in detection.items(): # Detection CSV groups labels under the source image filename if label_fn.startswith(Path(src_img).stem): source_image = src_img break entry = { 'source_image': source_image, 'label_filename': label_fn, 'label_index': '', 'category': category, 'bbox': {}, 'rotation_angle': rot.get('angle', 0), 'detection_confidence': det_conf, 'ocr': { 'method': ocr_entry.get('method', ''), 'text': corrected_entry.get('text', ocr_entry.get('text', '')), 'raw_text': ocr_entry.get('text', ''), 'confidence': ocr_entry.get('confidence', None), }, } consolidated.append(entry) return consolidated # ---- Main consolidation ------------------------------------------------------
def consolidate_results(output_dir: str) -> List[Dict[str, Any]]:
    """Detect which pipeline produced ``output_dir`` and consolidate its results.

    The Gemini pipeline is chosen whenever ``gemini_classification.json``
    exists in ``output_dir``; otherwise the traditional pipeline loaders are
    used.

    Args:
        output_dir: Directory containing the pipeline outputs.

    Returns:
        The list of per-label consolidated entries.
    """
    print("Loading pipeline results...")

    gemini_marker = os.path.join(output_dir, 'gemini_classification.json')
    use_gemini = os.path.exists(gemini_marker)
    if use_gemini:
        print("Detected Gemini pipeline outputs")
        consolidate = _consolidate_gemini
    else:
        print("Detected traditional pipeline outputs")
        consolidate = _consolidate_traditional

    entries = consolidate(output_dir)
    print(f"Consolidated {len(entries)} label entries")
    return entries
def main() -> int:
    """Command-line entry point: consolidate results and write them as JSON.

    Parses ``--output-dir``/``--filename`` and writes the consolidated entries
    to ``<output-dir>/<filename>`` as UTF-8, indented JSON with non-ASCII
    preserved. Then prints a per-category summary.

    Returns:
        Process exit status: 0 on success, 1 if the output directory is not an
        existing directory or the result file cannot be written.
    """
    args = parse_arguments()

    if not os.path.isdir(args.output_dir):
        print(f"Error: Output directory {args.output_dir} does not exist "
              f"or is not a directory.", file=sys.stderr)
        return 1

    print(f"Consolidating results from: {args.output_dir}")
    consolidated = consolidate_results(args.output_dir)

    output_file = os.path.join(args.output_dir, args.filename)
    # Keep the try limited to the write, so a problem in the summary below is
    # not misreported as a save failure after the file was already written.
    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(consolidated, f, indent=2, ensure_ascii=False)
    except (OSError, TypeError, ValueError) as e:
        print(f"Error saving consolidated results: {e}", file=sys.stderr)
        return 1
    print(f"Saved to: {output_file}")

    # Summary
    categories: Dict[str, int] = {}
    for r in consolidated:
        cat = r.get('category', 'unknown')
        categories[cat] = categories.get(cat, 0) + 1
    ocr_count = sum(1 for r in consolidated if r.get('ocr', {}).get('text'))

    print("\n=== Summary ===")
    print(f"Total labels: {len(consolidated)}")
    # Sort by string form so a non-string category (e.g. a null ``category``
    # in gemini_classification.json) cannot make sorted() raise TypeError.
    for cat, count in sorted(categories.items(), key=lambda kv: str(kv[0])):
        print(f" {cat}: {count}")
    print(f"Labels with OCR text: {ocr_count}")
    return 0
if __name__ == "__main__":
    # Forward main()'s return value as the process exit status so callers and
    # CI can detect failures (a None return still exits with status 0).
    sys.exit(main())