#!/usr/bin/env python
"""
GIFT-Eval Runner Script (Hugging Face Repository Version)
This script evaluates a time series model (TempoPFN by default) on the GIFT-Eval datasets using the `src/gift_eval` pipeline.
- Assumes it is running inside the cloned Hugging Face repository.
- Uses `src/gift_eval/data.py` for dataset handling.
- Uses `src/gift_eval/predictor.TimeSeriesPredictor` for inference.
- Loads the model from the local checkpoint (e.g., `models/checkpoint_38M.pth`).
- Writes per-dataset CSV metrics to `output_dir`.
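Example invocation (illustrative; the script filename and paths are placeholders, adjust them for your checkout):

    python gift_eval_runner.py --datasets all --dataset_storage_path /path/to/gift_eval --output_dir gift_eval_results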
"""
import argparse
import logging
from pathlib import Path
from src.gift_eval.constants import ALL_DATASETS
from src.gift_eval.evaluate import evaluate_datasets
from src.gift_eval.predictor import TimeSeriesPredictor
from src.gift_eval.results import aggregate_results, write_results_to_disk
# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
logger = logging.getLogger("gift_eval_runner")
def _expand_datasets_arg(datasets_arg: list[str] | str) -> list[str]:
"""Expand dataset argument to list of dataset names."""
if isinstance(datasets_arg, str):
if datasets_arg == "all":
return list(ALL_DATASETS)
datasets_list = [datasets_arg]
else:
datasets_list = datasets_arg
if datasets_list and datasets_list[0] == "all":
return list(ALL_DATASETS)
for ds in datasets_list:
if ds not in ALL_DATASETS:
raise ValueError(f"Invalid dataset: {ds}. Use one of {ALL_DATASETS}")
return datasets_list
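# Behaviour sketch for _expand_datasets_arg ("ds_a" is a hypothetical name; real names come from ALL_DATASETS):
#   _expand_datasets_arg("all")    -> list(ALL_DATASETS)
#   _expand_datasets_arg(["all"])  -> list(ALL_DATASETS)
#   _expand_datasets_arg(["ds_a"]) -> ["ds_a"] if "ds_a" is in ALL_DATASETS, otherwise raises ValueError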
def run_evaluation(
predictor: TimeSeriesPredictor,
datasets_arg: list[str] | str,
terms_arg: list[str],
dataset_storage_path: str,
max_windows_arg: int | None,
batch_size_arg: int,
max_context_length_arg: int | None,
output_dir_arg: str,
model_name_arg: str,
after_each_dataset_flush: bool = True,
) -> None:
"""Run evaluation on specified datasets."""
datasets_to_run = _expand_datasets_arg(datasets_arg)
results_root = Path(output_dir_arg)
for ds_name in datasets_to_run:
items = evaluate_datasets(
predictor=predictor,
dataset=ds_name,
dataset_storage_path=dataset_storage_path,
terms=terms_arg,
max_windows=max_windows_arg,
batch_size=batch_size_arg,
max_context_length=max_context_length_arg,
create_plots=False,
max_plots_per_dataset=0,
)
write_results_to_disk(
items=items,
dataset_name=ds_name,
output_dir=results_root,
model_name=model_name_arg,
create_plots=False,
)
        # Results are already written to disk above; this flag only controls the per-dataset progress log.
        if after_each_dataset_flush:
            logger.info("Flushed results for %s", ds_name)
def main():
"""Main execution function."""
    parser = argparse.ArgumentParser(description="GIFT-Eval Runner: evaluate a time series model on the GIFT-Eval datasets")
parser.add_argument(
"--model_path",
type=str,
default="models/checkpoint_38M.pth",
help="Path to a local model checkpoint (default: models/checkpoint_38M.pth in this repo).",
)
parser.add_argument(
"--config_path",
type=str,
default="configs/example.yaml",
help="Path to model config YAML (default: configs/example.yaml)",
)
# Dataset configuration
parser.add_argument(
"--datasets",
type=str,
nargs="+",
default=["all"],
help='List of dataset names or ["all"] (default: all)',
)
parser.add_argument(
"--terms",
type=str,
nargs="+",
default=["short", "medium", "long"],
help="Prediction terms to evaluate (default: short medium long)",
)
parser.add_argument(
"--dataset_storage_path",
type=str,
default="/work/dlclarge2/moroshav-GiftEvalPretrain/gift_eval",
        help="Path to the root of the GIFT-Eval dataset storage directory (the default is a site-specific path; override it for your environment)",
)
parser.add_argument(
"--max_windows",
type=int,
default=20,
help="Maximum number of windows to use for evaluation (default: 20)",
)
# Inference configuration
parser.add_argument(
"--batch_size",
type=int,
default=64,
help="Batch size for inference (default: 128)",
)
parser.add_argument(
"--max_context_length",
type=int,
default=3072,
help="Maximum context length (default: 3072)",
)
# Output configuration
parser.add_argument(
"--output_dir",
type=str,
default="gift_eval_results",
help="Output directory for results (default: gift_eval_results)",
)
parser.add_argument(
"--model_name",
type=str,
default="TempoPFN",
help="Model name identifier for results (default: TempoPFN)",
)
parser.add_argument(
"--no_flush",
action="store_true",
help="Disable flushing results after each dataset",
)
args = parser.parse_args()
# Resolve paths
config_path = Path(args.config_path)
output_dir = Path(args.output_dir)
resolved_model_path = Path(args.model_path)
if not resolved_model_path.exists():
logger.error(f"Model checkpoint not found at: {resolved_model_path}")
logger.error("Please ensure the file exists or you've cloned the repo using Git LFS.")
raise FileNotFoundError(f"No model checkpoint found at {resolved_model_path}")
if not config_path.exists():
raise FileNotFoundError(f"Config not found: {config_path}")
logger.info("Loading predictor from checkpoint: %s", resolved_model_path)
predictor = TimeSeriesPredictor.from_paths(
model_path=str(resolved_model_path),
config_path=str(config_path),
ds_prediction_length=1, # placeholder; set per dataset
ds_freq="D", # placeholder; set per dataset
batch_size=args.batch_size,
max_context_length=args.max_context_length,
)
logger.info("Starting evaluation...")
logger.info(" Datasets: %s", args.datasets)
logger.info(" Terms: %s", args.terms)
logger.info(" Output directory: %s", output_dir)
# Run evaluation
run_evaluation(
predictor=predictor,
datasets_arg=args.datasets,
terms_arg=args.terms,
dataset_storage_path=args.dataset_storage_path,
max_windows_arg=args.max_windows,
batch_size_arg=args.batch_size,
max_context_length_arg=args.max_context_length,
output_dir_arg=str(output_dir),
model_name_arg=args.model_name,
after_each_dataset_flush=not args.no_flush,
)
logger.info("Evaluation complete. See results under: %s", output_dir)
# Aggregate all results into a single CSV file
logger.info("Aggregating results from all datasets...")
combined_df = aggregate_results(result_root_dir=output_dir)
if combined_df is not None:
logger.info(
"Successfully created aggregated results file: %s/all_results.csv",
output_dir,
)
else:
logger.warning("No results to aggregate. Check that evaluation completed successfully.")
if __name__ == "__main__":
main()