"""Utilities for persisting and aggregating GIFT-Eval results."""
import argparse
import csv
import glob
import logging
from pathlib import Path
import pandas as pd
from src.gift_eval.constants import (
ALL_DATASETS,
DATASET_PROPERTIES,
MED_LONG_DATASETS,
PRETTY_NAMES,
STANDARD_METRIC_NAMES,
)
from src.gift_eval.core import DatasetMetadata, EvaluationItem
logger = logging.getLogger(__name__)


def _ensure_results_csv(csv_file_path: Path) -> None:
    """Create results.csv with the standard header if it does not exist yet."""
    if not csv_file_path.exists():
        csv_file_path.parent.mkdir(parents=True, exist_ok=True)
        with open(csv_file_path, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            header = (
                ["dataset", "model"]
                + [f"eval_metrics/{name}" for name in STANDARD_METRIC_NAMES]
                + ["domain", "num_variates"]
            )
            writer.writerow(header)
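
# For illustration only: with a hypothetical STANDARD_METRIC_NAMES of
# ("MSE[mean]", "MAE[0.5]"), the header written above would be
#   dataset,model,eval_metrics/MSE[mean],eval_metrics/MAE[0.5],domain,num_variates
# The actual metric columns are whatever the constants module defines.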


def write_results_to_disk(
    items: list[EvaluationItem],
    dataset_name: str,
    output_dir: Path,
    model_name: str,
    create_plots: bool,
) -> None:
    """Append one dataset's evaluation metrics (and optional plots) under output_dir."""
    output_dir = output_dir / dataset_name
    output_dir.mkdir(parents=True, exist_ok=True)
    output_csv_path = output_dir / "results.csv"
    _ensure_results_csv(output_csv_path)

    try:
        # Local import so matplotlib is only required when plotting is requested.
        import matplotlib.pyplot as plt
    except ImportError:  # pragma: no cover - guard for optional dependency
        plt = None

    with open(output_csv_path, "a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        for item in items:
            md: DatasetMetadata = item.dataset_metadata
            metric_values: list[float | None] = []
            for metric_name in STANDARD_METRIC_NAMES:
                value = item.metrics.get(metric_name, None)
                if value is None:
                    metric_values.append(None)
                else:
                    # Unwrap single-element sequences and 0-d arrays to plain scalars.
                    if (
                        hasattr(value, "__len__")
                        and not isinstance(value, (str, bytes))
                        and len(value) == 1
                    ):
                        value = value[0]
                    elif hasattr(value, "item"):
                        value = value.item()
                    metric_values.append(value)
            ds_key = md.key.lower()
            props = DATASET_PROPERTIES.get(ds_key, {})
            domain = props.get("domain", "unknown")
            num_variates = props.get("num_variates", 1 if md.to_univariate else md.target_dim)
            row = [md.full_name, model_name] + metric_values + [domain, num_variates]
            writer.writerow(row)

            if create_plots and item.figures and plt is not None:
                plots_dir = output_dir / "plots" / md.key / md.term
                plots_dir.mkdir(parents=True, exist_ok=True)
                for fig, filename in item.figures:
                    filepath = plots_dir / filename
                    fig.savefig(filepath, dpi=300, bbox_inches="tight")
                    plt.close(fig)

    logger.info(
        "Evaluation complete for dataset '%s'. Results saved to %s",
        dataset_name,
        output_csv_path,
    )
    if create_plots:
        logger.info("Plots saved under %s", output_dir / "plots")


def get_all_datasets_full_name() -> list[str]:
    """Get all possible dataset full names for validation."""
    terms = ["short", "medium", "long"]
    datasets_full_names: list[str] = []
    for name in ALL_DATASETS:
        for term in terms:
            # Medium/long horizons only exist for the med-long dataset subset.
            if term in ["medium", "long"] and name not in MED_LONG_DATASETS:
                continue
            if "/" in name:
                ds_key, ds_freq = name.split("/")
                ds_key = ds_key.lower()
                ds_key = PRETTY_NAMES.get(ds_key, ds_key)
            else:
                ds_key = name.lower()
                ds_key = PRETTY_NAMES.get(ds_key, ds_key)
                ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get("frequency")
            datasets_full_names.append(f"{ds_key}/{ds_freq if ds_freq else 'unknown'}/{term}")
    return datasets_full_names
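
# Each entry has the form "<dataset_key>/<frequency>/<term>", matching the
# `dataset` column written by write_results_to_disk. A hypothetical example:
# "m4_weekly/W/short".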


def aggregate_results(result_root_dir: str | Path) -> pd.DataFrame | None:
    """Aggregate results from multiple CSV files into a single dataframe."""
    result_root = Path(result_root_dir)
    logger.info("Aggregating results in: %s", result_root)
    result_files = glob.glob(f"{result_root}/**/results.csv", recursive=True)
    if not result_files:
        logger.error("No result files found!")
        return None

    dataframes: list[pd.DataFrame] = []
    for file in result_files:
        try:
            df = pd.read_csv(file)
            if len(df) > 0:
                dataframes.append(df)
            else:
                logger.warning("Empty file: %s", file)
        except pd.errors.EmptyDataError:
            logger.warning("Skipping empty file: %s", file)
        except Exception as exc:
            logger.error("Error reading %s: %s", file, exc)

    if not dataframes:
        logger.warning("No valid CSV files found to combine")
        return None

    combined_df = pd.concat(dataframes, ignore_index=True).sort_values("dataset")
    if len(combined_df) != len(set(combined_df.dataset)):
        duplicate_datasets = combined_df.dataset[combined_df.dataset.duplicated()].tolist()
        logger.warning("Duplicate datasets found: %s", duplicate_datasets)
        combined_df = combined_df.drop_duplicates(subset=["dataset"], keep="first")
        logger.info("Removed duplicates, %s unique datasets remaining", len(combined_df))

    logger.info("Combined results: %s datasets", len(combined_df))

    all_datasets_full_name = get_all_datasets_full_name()
    completed_experiments = combined_df.dataset.tolist()
    completed_experiments_clean = [exp for exp in completed_experiments if exp in all_datasets_full_name]
    missing_or_failed_experiments = [exp for exp in all_datasets_full_name if exp not in completed_experiments_clean]

    logger.info("=== EXPERIMENT SUMMARY ===")
    logger.info("Total expected datasets: %s", len(all_datasets_full_name))
    logger.info("Completed experiments: %s", len(completed_experiments_clean))
    logger.info("Missing/failed experiments: %s", len(missing_or_failed_experiments))
    logger.info("Completed experiments:")
    for idx, exp in enumerate(completed_experiments_clean, start=1):
        logger.info("  %3d: %s", idx, exp)
    if missing_or_failed_experiments:
        logger.info("Missing or failed experiments:")
        for idx, exp in enumerate(missing_or_failed_experiments, start=1):
            logger.info("  %3d: %s", idx, exp)

    completion_rate = (
        len(completed_experiments_clean) / len(all_datasets_full_name) * 100 if all_datasets_full_name else 0.0
    )
    logger.info("Completion rate: %.1f%%", completion_rate)

    output_file = result_root / "all_results.csv"
    combined_df.to_csv(output_file, index=False)
    logger.info("Combined results saved to: %s", output_file)
    return combined_df
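
# Sketch of programmatic use (the path is an assumption):
#
#   df = aggregate_results("results/tempopfn")
#   if df is not None:
#       print(df["dataset"].nunique(), "datasets aggregated")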


__all__ = [
    "aggregate_results",
    "get_all_datasets_full_name",
    "write_results_to_disk",
]


def main() -> None:
    """CLI entry point for aggregating results from disk."""
    parser = argparse.ArgumentParser(description="Aggregate GIFT-Eval results from multiple CSV files")
    parser.add_argument(
        "--result_root_dir",
        type=str,
        required=True,
        help="Root directory containing result subdirectories",
    )
    args = parser.parse_args()
    result_root_dir = Path(args.result_root_dir)

    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    logger.info("Searching in directory: %s", result_root_dir)
    aggregate_results(result_root_dir=result_root_dir)


if __name__ == "__main__":
    main()
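
# Typical invocation from the repository root (the results path is an assumption):
#   python -m src.gift_eval.results --result_root_dir results/tempopfn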