#!/usr/bin/env python
"""
GIFT-Eval Runner Script (Hugging Face Repository Version)
This script evaluates the TempoPFN time series model on GIFT-Eval datasets using the `src/gift_eval` pipeline.
- Assumes it is running inside the cloned Hugging Face repository.
- Uses `src/gift_eval/data.py` for dataset handling.
- Uses `src/gift_eval/predictor.TimeSeriesPredictor` for inference.
- Loads the model from the local checkpoint (e.g., `models/checkpoint_38M.pth`).
- Writes per-dataset CSV metrics to `output_dir`.
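
Example invocation (a sketch; the dataset storage path below is a placeholder to adjust for your setup):

    python examples/gift_eval/gift_eval_runner.py \
        --model_path models/checkpoint_38M.pth \
        --config_path configs/example.yaml \
        --datasets all \
        --dataset_storage_path /path/to/gift_eval \
        --output_dir gift_eval_results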
"""

import argparse
import logging
from pathlib import Path

from src.gift_eval.constants import ALL_DATASETS
from src.gift_eval.evaluate import evaluate_datasets
from src.gift_eval.predictor import TimeSeriesPredictor
from src.gift_eval.results import aggregate_results, write_results_to_disk

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
logger = logging.getLogger("gift_eval_runner")

def _expand_datasets_arg(datasets_arg: list[str] | str) -> list[str]:
    """Expand dataset argument to list of dataset names."""
    if isinstance(datasets_arg, str):
        if datasets_arg == "all":
            return list(ALL_DATASETS)
        datasets_list = [datasets_arg]
    else:
        datasets_list = datasets_arg
        if datasets_list and datasets_list[0] == "all":
            return list(ALL_DATASETS)

    for ds in datasets_list:
        if ds not in ALL_DATASETS:
            raise ValueError(f"Invalid dataset: {ds}. Use one of {ALL_DATASETS}")
    return datasets_list

def run_evaluation(
    predictor: TimeSeriesPredictor,
    datasets_arg: list[str] | str,
    terms_arg: list[str],
    dataset_storage_path: str,
    max_windows_arg: int | None,
    batch_size_arg: int,
    max_context_length_arg: int | None,
    output_dir_arg: str,
    model_name_arg: str,
    after_each_dataset_flush: bool = True,
) -> None:
    """Run evaluation on specified datasets."""
    datasets_to_run = _expand_datasets_arg(datasets_arg)
    results_root = Path(output_dir_arg)

    for ds_name in datasets_to_run:
        items = evaluate_datasets(
            predictor=predictor,
            dataset=ds_name,
            dataset_storage_path=dataset_storage_path,
            terms=terms_arg,
            max_windows=max_windows_arg,
            batch_size=batch_size_arg,
            max_context_length=max_context_length_arg,
            create_plots=False,
            max_plots_per_dataset=0,
        )
        write_results_to_disk(
            items=items,
            dataset_name=ds_name,
            output_dir=results_root,
            model_name=model_name_arg,
            create_plots=False,
        )
        if after_each_dataset_flush:
            logger.info("Flushed results for %s", ds_name)

def main():
    """Main execution function."""
    parser = argparse.ArgumentParser(description="GIFT-Eval Runner: Evaluate TimeSeriesModel on GIFT-Eval datasets")
    parser.add_argument(
        "--model_path",
        type=str,
        default="models/checkpoint_38M.pth",
        help="Path to a local model checkpoint (default: models/checkpoint_38M.pth in this repo).",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        default="configs/example.yaml",
        help="Path to model config YAML (default: configs/example.yaml)",
    )

    # Dataset configuration
    parser.add_argument(
        "--datasets",
        type=str,
        nargs="+",
        default=["all"],
        help='List of dataset names or ["all"] (default: all)',
    )
    parser.add_argument(
        "--terms",
        type=str,
        nargs="+",
        default=["short", "medium", "long"],
        help="Prediction terms to evaluate (default: short medium long)",
    )
    parser.add_argument(
        "--dataset_storage_path",
        type=str,
        default="/work/dlclarge2/moroshav-GiftEvalPretrain/gift_eval",
        # required=True,
        help="Path to the root of the GIFT-Eval dataset storage directory",
    )
    parser.add_argument(
        "--max_windows",
        type=int,
        default=20,
        help="Maximum number of windows to use for evaluation (default: 20)",
    )

    # Inference configuration
    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        help="Batch size for inference (default: 64)",
    )
    parser.add_argument(
        "--max_context_length",
        type=int,
        default=3072,
        help="Maximum context length (default: 3072)",
    )

    # Output configuration
    parser.add_argument(
        "--output_dir",
        type=str,
        default="gift_eval_results",
        help="Output directory for results (default: gift_eval_results)",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="TempoPFN",
        help="Model name identifier for results (default: TempoPFN)",
    )
    parser.add_argument(
        "--no_flush",
        action="store_true",
        help="Disable flushing results after each dataset",
    )

    args = parser.parse_args()
    # Resolve paths
    config_path = Path(args.config_path)
    output_dir = Path(args.output_dir)
    resolved_model_path = Path(args.model_path)

    if not resolved_model_path.exists():
        logger.error("Model checkpoint not found at: %s", resolved_model_path)
        logger.error("Please ensure the file exists and that the repository was cloned with Git LFS.")
        raise FileNotFoundError(f"No model checkpoint found at {resolved_model_path}")
    if not config_path.exists():
        raise FileNotFoundError(f"Config not found: {config_path}")

    logger.info("Loading predictor from checkpoint: %s", resolved_model_path)
    predictor = TimeSeriesPredictor.from_paths(
        model_path=str(resolved_model_path),
        config_path=str(config_path),
        ds_prediction_length=1,  # placeholder; overridden per dataset
        ds_freq="D",  # placeholder; overridden per dataset
        batch_size=args.batch_size,
        max_context_length=args.max_context_length,
    )
logger.info("Starting evaluation...")
logger.info(" Datasets: %s", args.datasets)
logger.info(" Terms: %s", args.terms)
logger.info(" Output directory: %s", output_dir)
# Run evaluation
run_evaluation(
predictor=predictor,
datasets_arg=args.datasets,
terms_arg=args.terms,
dataset_storage_path=args.dataset_storage_path,
max_windows_arg=args.max_windows,
batch_size_arg=args.batch_size,
max_context_length_arg=args.max_context_length,
output_dir_arg=str(output_dir),
model_name_arg=args.model_name,
after_each_dataset_flush=not args.no_flush,
)
logger.info("Evaluation complete. See results under: %s", output_dir)
# Aggregate all results into a single CSV file
logger.info("Aggregating results from all datasets...")
combined_df = aggregate_results(result_root_dir=output_dir)
if combined_df is not None:
logger.info(
"Successfully created aggregated results file: %s/all_results.csv",
output_dir,
)
else:
logger.warning("No results to aggregate. Check that evaluation completed successfully.")
if __name__ == "__main__":
main()