"""
GIFT-Eval Runner Script

This script evaluates the Time Series model on GIFT-Eval datasets using the
`src/gift_eval` pipeline.

- Uses `src/gift_eval/data.py` for dataset handling.
- Uses `src/gift_eval/predictor.TimeSeriesPredictor` for inference.
- Loads a model from a checkpoint.
- Writes per-dataset CSV metrics to `output_dir` without creating plots.
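
Example invocation (a sketch; the module path and checkpoint location are
assumptions, adjust them to your setup). Run from the repository root so the
`src` and `examples` imports resolve:

    python -m examples.gift_eval_runner --datasets all --output_dir gift_eval_results

Pass --model_path to use a local checkpoint instead of downloading one:

    python -m examples.gift_eval_runner --model_path models/checkpoint_38M.pth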
"""

import argparse
import logging
from pathlib import Path
from typing import List, Optional

from examples.utils import download_checkpoint_if_needed
from src.gift_eval.constants import ALL_DATASETS
from src.gift_eval.evaluate import evaluate_datasets
from src.gift_eval.predictor import TimeSeriesPredictor
from src.gift_eval.results import aggregate_results, write_results_to_disk

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
logger = logging.getLogger("gift_eval_runner")


def _expand_datasets_arg(datasets_arg: List[str] | str) -> List[str]:
    """Expand the dataset argument into a list of dataset names."""
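    # Illustrative behavior (a sketch; valid names are defined by ALL_DATASETS):
    #   _expand_datasets_arg("all")    -> list(ALL_DATASETS)
    #   _expand_datasets_arg(["all"])  -> list(ALL_DATASETS)
    #   _expand_datasets_arg(["name"]) -> ["name"] if present in ALL_DATASETS,
    #                                     otherwise ValueError is raised.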
    if isinstance(datasets_arg, str):
        if datasets_arg == "all":
            return list(ALL_DATASETS)
        datasets_list = [datasets_arg]
    else:
        datasets_list = datasets_arg
    if datasets_list and datasets_list[0] == "all":
        return list(ALL_DATASETS)

    for ds in datasets_list:
        if ds not in ALL_DATASETS:
            raise ValueError(f"Invalid dataset: {ds}. Use one of {ALL_DATASETS}")
    return datasets_list


def run_evaluation(
    predictor: TimeSeriesPredictor,
    datasets_arg: List[str] | str,
    terms_arg: List[str],
    dataset_storage_path: str,
    max_windows_arg: Optional[int],
    batch_size_arg: int,
    max_context_length_arg: Optional[int],
    output_dir_arg: str,
    model_name_arg: str,
    after_each_dataset_flush: bool = True,
) -> None:
    """Run evaluation on the specified datasets and write per-dataset results."""
    datasets_to_run = _expand_datasets_arg(datasets_arg)
    results_root = Path(output_dir_arg)

    for ds_name in datasets_to_run:
        # Evaluate one dataset across all requested terms, then write its metrics
        # to disk before moving on so partial results survive interruptions.
        items = evaluate_datasets(
            predictor=predictor,
            dataset=ds_name,
            dataset_storage_path=dataset_storage_path,
            terms=terms_arg,
            max_windows=max_windows_arg,
            batch_size=batch_size_arg,
            max_context_length=max_context_length_arg,
            create_plots=False,
            max_plots_per_dataset=0,
        )
        write_results_to_disk(
            items=items,
            dataset_name=ds_name,
            output_dir=results_root,
            model_name=model_name_arg,
            create_plots=False,
        )
        if after_each_dataset_flush:
            logger.info("Flushed results for %s", ds_name)


def main():
    """Main execution function."""
    parser = argparse.ArgumentParser(
        description="GIFT-Eval Runner: Evaluate TimeSeriesModel on GIFT-Eval datasets"
    )

    parser.add_argument(
        "--model_path",
        type=str,
        default=None,
        help="Path to model checkpoint. If not provided, will download from checkpoint_url.",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        default="configs/example.yaml",
        help="Path to model config YAML (default: configs/example.yaml)",
    )
    parser.add_argument(
        "--checkpoint_url",
        type=str,
        default="https://www.dropbox.com/scl/fi/mqsni5lehooyaw93y3uzq/checkpoint_38M.pth?rlkey=3uyehvmtted02xkha24zgpzb6&st=seevsbkn&dl=0",
        help="URL to download checkpoint from if model_path is not provided",
    )
    parser.add_argument(
        "--download_dir",
        type=str,
        default="models",
        help="Directory to download checkpoint to (default: models)",
    )

    parser.add_argument(
        "--datasets",
        type=str,
        nargs="+",
        default=["all"],
        help='List of dataset names or ["all"] (default: all)',
    )
    parser.add_argument(
        "--terms",
        type=str,
        nargs="+",
        default=["short", "medium", "long"],
        help="Prediction terms to evaluate (default: short medium long)",
    )
    parser.add_argument(
        "--dataset_storage_path",
        type=str,
        default="/work/dlclarge2/moroshav-GiftEvalPretrain/gift_eval",
        help="Path to the root of the gift eval datasets storage directory",
    )
    parser.add_argument(
        "--max_windows",
        type=int,
        default=20,
        help="Maximum number of windows to use for evaluation (default: 20)",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        help="Batch size for inference (default: 64)",
    )
    parser.add_argument(
        "--max_context_length",
        type=int,
        default=3072,
        help="Maximum context length (default: 3072)",
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        default="gift_eval_results",
        help="Output directory for results (default: gift_eval_results)",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="TempoPFN",
        help="Model name identifier for results (default: TempoPFN)",
    )
    parser.add_argument(
        "--no_flush",
        action="store_true",
        help="Disable flushing results after each dataset",
    )

    args = parser.parse_args()

    config_path = Path(args.config_path)
    download_dir = Path(args.download_dir)
    output_dir = Path(args.output_dir)

    # Resolve the checkpoint: prefer an explicit --model_path, otherwise download
    # from --checkpoint_url into --download_dir.
    resolved_model_path = None
    if args.model_path:
        resolved_model_path = args.model_path
    elif args.checkpoint_url:
        resolved_model_path = download_checkpoint_if_needed(
            args.checkpoint_url, target_dir=download_dir
        )

    if not resolved_model_path:
        raise FileNotFoundError(
            "No model checkpoint provided. Set --model_path or --checkpoint_url."
        )

    if not config_path.exists():
        raise FileNotFoundError(f"Config not found: {config_path}")

    logger.info("Loading predictor from checkpoint: %s", resolved_model_path)
    predictor = TimeSeriesPredictor.from_paths(
        model_path=resolved_model_path,
        config_path=str(config_path),
        ds_prediction_length=1,
        ds_freq="D",
        batch_size=args.batch_size,
        max_context_length=args.max_context_length,
    )

    logger.info("Starting evaluation...")
    logger.info(" Datasets: %s", args.datasets)
    logger.info(" Terms: %s", args.terms)
    logger.info(" Output directory: %s", output_dir)

    run_evaluation(
        predictor=predictor,
        datasets_arg=args.datasets,
        terms_arg=args.terms,
        dataset_storage_path=args.dataset_storage_path,
        max_windows_arg=args.max_windows,
        batch_size_arg=args.batch_size,
        max_context_length_arg=args.max_context_length,
        output_dir_arg=str(output_dir),
        model_name_arg=args.model_name,
        after_each_dataset_flush=not args.no_flush,
    )

    logger.info("Evaluation complete. See results under: %s", output_dir)

    # Combine the per-dataset CSVs into a single aggregated results file.
    logger.info("Aggregating results from all datasets...")
    combined_df = aggregate_results(result_root_dir=output_dir)

    if combined_df is not None:
        logger.info(
            "Successfully created aggregated results file: %s/all_results.csv",
            output_dir,
        )
    else:
        logger.warning(
            "No results to aggregate. Check that evaluation completed successfully."
        )


if __name__ == "__main__":
    main()