"""Utilities for persisting and aggregating GIFT-Eval results.""" |
|
|
|
|
|
import argparse |
|
|
import csv |
|
|
import glob |
|
|
import logging |
|
|
from pathlib import Path |
|
|
|
|
|
import pandas as pd |
|
|
|
|
|
from src.gift_eval.constants import ( |
|
|
ALL_DATASETS, |
|
|
DATASET_PROPERTIES, |
|
|
MED_LONG_DATASETS, |
|
|
PRETTY_NAMES, |
|
|
STANDARD_METRIC_NAMES, |
|
|
) |
|
|
from src.gift_eval.core import DatasetMetadata, EvaluationItem |
|
|
|
|
|
logger = logging.getLogger(__name__) |
def _ensure_results_csv(csv_file_path: Path) -> None:
    """Create ``csv_file_path`` with the standard results header if it does not already exist."""
    if not csv_file_path.exists():
        csv_file_path.parent.mkdir(parents=True, exist_ok=True)
        with open(csv_file_path, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            header = (
                ["dataset", "model"]
                + [f"eval_metrics/{name}" for name in STANDARD_METRIC_NAMES]
                + ["domain", "num_variates"]
            )
            writer.writerow(header)
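# Note: write_results_to_disk() below appends rows to results.csv without checking
# for existing entries, so re-running the same dataset/model adds duplicate rows;
# aggregate_results() later keeps only the first row per dataset.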
def write_results_to_disk(
    items: list[EvaluationItem],
    dataset_name: str,
    output_dir: Path,
    model_name: str,
    create_plots: bool,
) -> None:
    """Append per-item metrics to ``<output_dir>/<dataset_name>/results.csv``.

    When ``create_plots`` is true, any figures attached to the evaluation items
    are also saved under a ``plots/`` subdirectory.
    """
    output_dir = output_dir / dataset_name
    output_dir.mkdir(parents=True, exist_ok=True)
    output_csv_path = output_dir / "results.csv"
    _ensure_results_csv(output_csv_path)

    # Plot saving is optional; degrade gracefully if matplotlib is not installed.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        plt = None

    with open(output_csv_path, "a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        for item in items:
            md: DatasetMetadata = item.dataset_metadata
            metric_values: list[float | None] = []
            for metric_name in STANDARD_METRIC_NAMES:
                value = item.metrics.get(metric_name, None)
                if value is None:
                    metric_values.append(None)
                else:
                    # Unwrap length-1 sequences and array scalars into plain Python numbers.
                    if hasattr(value, "__len__") and not isinstance(value, (str, bytes)) and len(value) == 1:
                        value = value[0]
                    elif hasattr(value, "item"):
                        value = value.item()
                    metric_values.append(value)

            ds_key = md.key.lower()
            props = DATASET_PROPERTIES.get(ds_key, {})
            domain = props.get("domain", "unknown")
            num_variates = props.get("num_variates", 1 if md.to_univariate else md.target_dim)

            row = [md.full_name, model_name] + metric_values + [domain, num_variates]
            writer.writerow(row)

            if create_plots and item.figures and plt is not None:
                plots_dir = output_dir / "plots" / md.key / md.term
                plots_dir.mkdir(parents=True, exist_ok=True)
                for fig, filename in item.figures:
                    filepath = plots_dir / filename
                    fig.savefig(filepath, dpi=300, bbox_inches="tight")
                    plt.close(fig)

    logger.info(
        "Evaluation complete for dataset '%s'. Results saved to %s",
        dataset_name,
        output_csv_path,
    )
    if create_plots:
        logger.info("Plots saved under %s", output_dir / "plots")
def get_all_datasets_full_name() -> list[str]:
    """Return every expected ``<key>/<frequency>/<term>`` combination, used to spot missing experiments."""
    terms = ["short", "medium", "long"]
    datasets_full_names: list[str] = []

    for name in ALL_DATASETS:
        for term in terms:
            # Medium and long horizons only exist for a subset of the datasets.
            if term in ["medium", "long"] and name not in MED_LONG_DATASETS:
                continue

            if "/" in name:
                ds_key, ds_freq = name.split("/")
                ds_key = ds_key.lower()
                ds_key = PRETTY_NAMES.get(ds_key, ds_key)
            else:
                ds_key = name.lower()
                ds_key = PRETTY_NAMES.get(ds_key, ds_key)
                ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get("frequency")

            datasets_full_names.append(f"{ds_key}/{ds_freq if ds_freq else 'unknown'}/{term}")

    return datasets_full_names
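# Each name produced above has the form "<dataset key>/<frequency>/<term>", e.g. a
# purely hypothetical "some_dataset/H/short"; keys and frequencies come from
# ALL_DATASETS, PRETTY_NAMES and DATASET_PROPERTIES.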
def aggregate_results(result_root_dir: str | Path) -> pd.DataFrame | None:
    """Combine every ``results.csv`` under ``result_root_dir`` into a single dataframe.

    The combined dataframe is also written to ``<result_root_dir>/all_results.csv``.
    Returns ``None`` when no non-empty result files are found.
    """
    result_root = Path(result_root_dir)

    logger.info("Aggregating results in: %s", result_root)

    result_files = glob.glob(f"{result_root}/**/results.csv", recursive=True)

    if not result_files:
        logger.error("No result files found!")
        return None

    dataframes: list[pd.DataFrame] = []
    for file in result_files:
        try:
            df = pd.read_csv(file)
            if len(df) > 0:
                dataframes.append(df)
            else:
                logger.warning("Empty file: %s", file)
        except pd.errors.EmptyDataError:
            logger.warning("Skipping empty file: %s", file)
        except Exception as exc:
            logger.error("Error reading %s: %s", file, exc)

    if not dataframes:
        logger.warning("No valid CSV files found to combine")
        return None

    combined_df = pd.concat(dataframes, ignore_index=True).sort_values("dataset")

    if len(combined_df) != len(set(combined_df.dataset)):
        duplicate_datasets = combined_df.dataset[combined_df.dataset.duplicated()].tolist()
        logger.warning("Duplicate datasets found: %s", duplicate_datasets)
        combined_df = combined_df.drop_duplicates(subset=["dataset"], keep="first")
        logger.info("Removed duplicates, %s unique datasets remaining", len(combined_df))

    logger.info("Combined results: %s datasets", len(combined_df))

    all_datasets_full_name = get_all_datasets_full_name()
    completed_experiments = combined_df.dataset.tolist()

    completed_experiments_clean = [exp for exp in completed_experiments if exp in all_datasets_full_name]
    missing_or_failed_experiments = [exp for exp in all_datasets_full_name if exp not in completed_experiments_clean]

    logger.info("=== EXPERIMENT SUMMARY ===")
    logger.info("Total expected datasets: %s", len(all_datasets_full_name))
    logger.info("Completed experiments: %s", len(completed_experiments_clean))
    logger.info("Missing/failed experiments: %s", len(missing_or_failed_experiments))

    logger.info("Completed experiments:")
    for idx, exp in enumerate(completed_experiments_clean, start=1):
        logger.info(" %3d: %s", idx, exp)

    if missing_or_failed_experiments:
        logger.info("Missing or failed experiments:")
        for idx, exp in enumerate(missing_or_failed_experiments, start=1):
            logger.info(" %3d: %s", idx, exp)

    completion_rate = (
        len(completed_experiments_clean) / len(all_datasets_full_name) * 100 if all_datasets_full_name else 0.0
    )
    logger.info("Completion rate: %.1f%%", completion_rate)

    output_file = result_root / "all_results.csv"
    combined_df.to_csv(output_file, index=False)
    logger.info("Combined results saved to: %s", output_file)

    return combined_df
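# Programmatic usage sketch (illustrative; the results directory path is an
# assumption, not something this module mandates):
#
#     df = aggregate_results("outputs/gift_eval")
#     if df is not None:
#         print(df[["dataset", "model"]].head())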
__all__ = [
    "aggregate_results",
    "get_all_datasets_full_name",
    "write_results_to_disk",
]
def main() -> None:
    """CLI entry point for aggregating results from disk."""
    parser = argparse.ArgumentParser(description="Aggregate GIFT-Eval results from multiple CSV files")
    parser.add_argument(
        "--result_root_dir",
        type=str,
        required=True,
        help="Root directory containing result subdirectories",
    )

    args = parser.parse_args()
    result_root_dir = Path(args.result_root_dir)

    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    logger.info("Searching in directory: %s", result_root_dir)

    aggregate_results(result_root_dir=result_root_dir)
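# Typical invocation (illustrative; the module path and the results directory
# depend on how the project is laid out):
#
#     python -m src.gift_eval.results --result_root_dir outputs/gift_eval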
if __name__ == "__main__": |
|
|
main() |