# Standard-library and third-party imports
import argparse
import os
import json
import string
import time
import warnings
import pandas as pd
import numpy as np
import gensim
from nltk import word_tokenize
from sklearn.manifold import TSNE
import plotly.express as px
from typing import Union, List, Dict
# Install a process-wide "ignore" filter with no category or module
# restriction, so every warning is silenced to keep console output clean.
warnings.filterwarnings('ignore')
[docs]
def parse_arguments(argv: Union[List[str], None] = None) -> argparse.Namespace:
    """
    Parse command-line arguments and return the parsed arguments.

    Args:
        argv (List[str], optional): Argument strings to parse. When None
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour existing callers rely on.

    Returns:
        argparse.Namespace: Parsed command-line arguments with the attributes
        ``cluster_csv``, ``ground_truth``, ``out_dir``, ``cluster_size`` and
        ``verbose``.
    """
    # Optional arguments are shown in brackets so the usage line matches what
    # the parser actually enforces.
    usage = (
        'cluster_eval.py [-h] -c <path to cluster CSV> '
        '-gt <path to ground truth JSON> [-o <output directory>] '
        '[-s <cluster size>] [--verbose]'
    )
    # The default add_help=True is kept so that the advertised -h/--help
    # option really exists.
    parser = argparse.ArgumentParser(
        description="Scatter plot of clusters using t-SNE.",
        usage=usage,
    )
    parser.add_argument('-c', '--cluster_csv', required=True, help='Path to cluster CSV file')
    parser.add_argument('-gt', '--ground_truth', required=True, help='Path to ground truth JSON file')
    parser.add_argument('-o', '--out_dir', default='outputs', help='Directory to save output files')
    parser.add_argument('-s', '--cluster_size', type=int, default=1, help='Minimum cluster size to be plotted')
    parser.add_argument('--verbose', action='store_true', help='Enable verbose logging')
    return parser.parse_args(argv)
[docs]
def is_word(token: str) -> bool:
    """
    Checks whether a token is a valid word (not punctuation or too short).

    A token is rejected when it is shorter than three characters or when it
    consists solely of punctuation and/or whitespace characters.

    Args:
        token (str): The token to check.

    Returns:
        bool: True if the token is a valid word, False otherwise.
    """
    if len(token) < 3:
        return False
    # Test character by character: `token in string.punctuation` is a
    # substring test and therefore lets tokens such as "..." or "!!!" through.
    return not all(ch in string.punctuation or ch.isspace() for ch in token)
[docs]
def tokenize_text(labels: Union[List[Dict[str, str]], Dict[str, tuple[str, str]]], ground_truth: bool) -> List[Dict[str, Union[str, List[str]]]]:
    """
    Tokenizes and lowercases text fields from labels.

    Args:
        labels (List[Dict[str, str]] or Dict[str, tuple[str, str]]): Labels to tokenize.
            With ``ground_truth=True`` each item must provide the keys "ID"
            and "text". Otherwise ``labels`` is either a mapping of
            ID -> (cluster, text) or an iterable of (ID, text) pairs.
        ground_truth (bool): Whether the labels are ground truth data.

    Returns:
        List[Dict[str, Union[str, List[str]]]]: Tokenized labels with IDs.
        Labels whose text is not a string or yields no valid token are omitted.
    """
    if ground_truth:
        id_text_pairs = ((label["ID"], label["text"]) for label in labels)
    elif isinstance(labels, dict):
        # Iterating a dict yields its keys only; take the ID from the key and
        # the text from the second element of the value instead of indexing
        # into the characters of the ID string.
        id_text_pairs = ((label_id, value[1]) for label_id, value in labels.items())
    else:
        id_text_pairs = ((label[0], label[1]) for label in labels)

    tokenized = []
    for label_id, text in id_text_pairs:
        # A missing/null text cannot be tokenized; skip it rather than abort
        # the whole run.
        if not isinstance(text, str):
            continue
        tokens = [token.lower() for token in word_tokenize(text) if is_word(token)]
        if tokens:
            tokenized.append({"ID": label_id, "tokens": tokens})
    return tokenized
[docs]
def build_word_vectors(labels, ground_truth) -> tuple[gensim.models.Word2Vec, List[Dict[str, Union[str, List[str]]]]]:
    """
    Trains a Word2Vec model on the tokens extracted from the given labels.

    Args:
        labels (List[Dict[str, str]] or Dict[str, tuple[str, str]]): Labels to build vectors from.
        ground_truth (bool): Whether the labels are ground truth data.

    Returns:
        tuple: The trained Word2Vec model followed by the tokenized labels
        that were used as its training corpus.
    """
    tokenized_labels = tokenize_text(labels, ground_truth)
    # One training "sentence" per label.
    sentences = [entry["tokens"] for entry in tokenized_labels]
    # min_count=1 keeps every token in the vocabulary.
    w2v_model = gensim.models.Word2Vec(
        sentences,
        vector_size=100,
        window=2,
        min_count=1,
        sg=1,
    )
    return w2v_model, tokenized_labels
[docs]
def build_mean_label_vector(model, labels) -> tuple[Dict[str, np.ndarray], List[str]]:
    """
    Averages the word vectors of every label into a single vector per label.

    Labels without any token known to the model receive no vector; their IDs
    are collected separately.

    Args:
        model (gensim.models.Word2Vec): The trained Word2Vec model.
        labels (List[Dict[str, List[str]]]): Tokenized labels with IDs.

    Returns:
        tuple: A dictionary mapping label IDs to mean vectors, and the list
        of IDs for which no vector could be computed.
    """
    word_vectors = model.wv
    mean_by_id = {}
    without_vector = []
    for entry in labels:
        label_id = entry["ID"]
        known = [tok for tok in entry["tokens"] if tok in word_vectors]
        if not known:
            without_vector.append(label_id)
            continue
        # Element-wise mean over all token vectors of this label.
        mean_by_id[label_id] = np.mean([word_vectors[tok] for tok in known], axis=0)
    return mean_by_id, without_vector
[docs]
def load_json(path: str) -> List[Dict[str, str]]:
    """
    Reads the ground truth JSON file and keeps only well-formed entries.

    Args:
        path (str): Path to the JSON file.

    Returns:
        List[Dict[str, str]]: The entries that are dictionaries containing
        both an "ID" and a "text" key, in file order.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        payload = json.load(handle)

    valid_entries = []
    for item in payload:
        # Anything that is not a dict, or lacks a required key, is ignored.
        if not isinstance(item, dict):
            continue
        if "ID" in item and "text" in item:
            valid_entries.append(item)
    return valid_entries
[docs]
def load_cluster_csv(path: str) -> Dict[str, List[str]]:
    """
    Loads cluster assignments from a semicolon-separated CSV file.

    The file must provide the columns "ID", "Cluster_ID" and "Transcript".

    Args:
        path (str): Path to the CSV file.

    Returns:
        Dict[str, List[str]]: Dictionary mapping label IDs to their cluster ID and transcript.
        Skips entries with missing "Transcript" or "Cluster_ID".
    """
    # Read every column as text. With type inference an ID such as "001"
    # becomes 1 and a cluster column containing blanks becomes float
    # ("3" -> "3.0"), which breaks matching against the string IDs of the
    # ground truth file.
    df = pd.read_csv(path, sep=';', dtype=str)
    df = df.dropna(subset=["Transcript", "Cluster_ID"])
    return {
        str(label_id).strip(): [str(cluster_id), str(transcript)]
        for label_id, cluster_id, transcript in zip(df["ID"], df["Cluster_ID"], df["Transcript"])
    }
[docs]
def plot_tsne(label_vectors: Dict[str, np.ndarray], clusters: Dict[str, List[str]], out_path: str, verbose: bool, skipped_ids: List[str], title: str = "t-SNE Cluster Visualization for MfN_LEP_SEASIA"):
    """
    Generates and saves a t-SNE scatter plot with cluster coloring and hover text.
    Also includes skipped labels (no vectors) as a separate "No vector" cluster.

    The dictionaries passed in are not modified.

    Args:
        label_vectors (Dict[str, np.ndarray]): Dictionary of label IDs to their mean vectors.
        clusters (Dict[str, List[str]]): Dictionary mapping label IDs to their cluster ID and transcript.
        out_path (str): Path to save the t-SNE plot HTML file.
        verbose (bool): Whether to print verbose output.
        skipped_ids (List[str]): List of label IDs that had no valid tokens and thus no vector.
        title (str): Title shown above the plot.

    Returns:
        plotly.graph_objects.Figure: The generated t-SNE plot.

    Raises:
        ValueError: If fewer than two labels are available to embed.
    """
    # Work on shallow copies so the caller's dictionaries stay untouched.
    vectors_by_id = dict(label_vectors)
    cluster_info = dict(clusters)

    if skipped_ids:
        # Match the dimensionality of the real vectors; 100 is only the
        # fallback when there is no vector to inspect.
        if vectors_by_id:
            dim = np.asarray(next(iter(vectors_by_id.values()))).shape[-1]
        else:
            dim = 100
        for sid in skipped_ids:
            vectors_by_id[sid] = np.zeros(dim)
            cluster_info[sid] = ["No vector", "(no valid text)"]

    ids = list(vectors_by_id)
    if len(ids) < 2:
        raise ValueError(f"t-SNE needs at least two labels to plot, got {len(ids)}")

    vectors = np.array([vectors_by_id[i] for i in ids])
    cluster_ids = [cluster_info[i][0] if i in cluster_info else "Unassigned" for i in ids]
    transcripts = [cluster_info[i][1] if i in cluster_info else "(no text found)" for i in ids]

    # scikit-learn requires perplexity < n_samples; keep the library default
    # of 30 whenever the data set is large enough.
    perplexity = float(min(30, len(ids) - 1))
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    tsne_results = tsne.fit_transform(vectors)

    df = pd.DataFrame(tsne_results, columns=['x', 'y'])
    df["label"] = ids
    df["cluster"] = cluster_ids
    df["text"] = transcripts
    fig = px.scatter(
        df,
        x="x",
        y="y",
        color="cluster",
        hover_data=["label", "text"],
        title=title,
    )
    fig.write_html(out_path)

    if verbose:
        print(f"t-SNE plot saved at {out_path}")
        print(f"Skipped {len(skipped_ids)} labels with no vector:")
        for i in skipped_ids[:10]:  # Show up to 10 skipped
            print(f" - {i}")
        if len(skipped_ids) > 10:
            print(f" ... and {len(skipped_ids) - 10} more.")
    return fig
[docs]
def main(args):
    """
    Main entry point for clustering visualization.
    Loads data, trains embeddings, computes vectors, runs t-SNE, and saves plot.

    Labels that belong to a cluster with fewer than ``args.cluster_size``
    members are left out of the plot. Labels without any cluster assignment
    are kept. Labels whose text yields no valid token are reported as skipped
    so they end up in the "No vector" group.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.

    Returns:
        None
    """
    # Local import: only this function needs it.
    from collections import Counter

    os.makedirs(args.out_dir, exist_ok=True)
    start_time = time.time()
    try:
        if args.verbose:
            print("Loading data...")
        gt_data = load_json(args.ground_truth)
        cluster_data = load_cluster_csv(args.cluster_csv)

        # Normalize IDs the same way the CSV loader does (str + strip) so
        # numeric JSON IDs neither crash nor fail to match. A dict keeps one
        # text per ID (the last one wins).
        texts_by_id = {str(entry["ID"]).strip(): entry["text"] for entry in gt_data}
        label_list = [{"ID": label_id, "text": text} for label_id, text in texts_by_id.items()]

        if args.verbose:
            print(f"Building Word2Vec model on {len(label_list)} labels...")
        model, tokenized = build_word_vectors(label_list, ground_truth=True)
        mean_vectors, skipped_ids = build_mean_label_vector(model, tokenized)

        # The tokenizer drops labels without valid tokens before they reach
        # build_mean_label_vector; add them here so they are still reported.
        tokenized_ids = {label["ID"] for label in tokenized}
        skipped_ids = list(skipped_ids) + [
            label["ID"] for label in label_list if label["ID"] not in tokenized_ids
        ]

        # Apply the minimum cluster size requested on the command line.
        min_size = getattr(args, "cluster_size", 1)
        if min_size > 1:
            cluster_sizes = Counter(info[0] for info in cluster_data.values())
            mean_vectors = {
                label_id: vector
                for label_id, vector in mean_vectors.items()
                if label_id not in cluster_data
                or cluster_sizes[cluster_data[label_id][0]] >= min_size
            }

        out_path = os.path.join(args.out_dir, "cluster_visualization.html")
        plot_tsne(mean_vectors, cluster_data, out_path, args.verbose, skipped_ids)
        if args.verbose:
            print(f"Finished in {time.time() - start_time:.2f}s")
    except Exception as e:
        # Top-level boundary of the script: report the failure instead of
        # showing a traceback.
        print(f"Error: {e}")
if __name__ == "__main__":
    # Script entry point: parse the command line and run the pipeline.
    main(parse_arguments())