# TempoPFN / examples/utils.py
# Author: Vladyslav Moroshan
# Commit: 0a58567 ("Apply ruff formatting")
import logging
import os
import numpy as np
import torch
import yaml
from src.data.containers import BatchTimeSeriesContainer
from src.models.model import TimeSeriesModel
from src.plotting.plot_timeseries import plot_from_container
# Module-level logger named after this module; load_model reports successful loads through it.
logger = logging.getLogger(__name__)
def load_model(config_path: str, model_path: str, device: torch.device) -> TimeSeriesModel:
    """Load the TimeSeriesModel from a YAML config and a checkpoint file.

    Args:
        config_path: Path to a YAML file. Its top-level ``"TimeSeriesModel"``
            mapping is passed as keyword arguments to ``TimeSeriesModel``.
        model_path: Path to a checkpoint readable by ``torch.load``. It must
            contain a ``"model_state_dict"`` entry.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.

    Raises:
        KeyError: If the parsed config mapping has no ``"TimeSeriesModel"``
            key, or the checkpoint mapping has no ``"model_state_dict"`` key.
    """
    # Explicit encoding: without it, open() uses the platform's locale encoding
    # (e.g. cp1252 on Windows), which can mis-decode non-ASCII YAML.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)

    # NOTE(security): torch.load is pickle-based unless weights_only=True is in
    # effect (the default since PyTorch 2.6). Only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()  # inference use: disable dropout / training-only behaviour

    logger.info("Successfully loaded TimeSeriesModel from %s on %s", model_path, device)
    return model
def plot_with_library(
    container: BatchTimeSeriesContainer,
    predictions_np: np.ndarray,  # [B, P, N, Q]
    model_quantiles: list[float] | None,
    output_dir: str = "outputs",
    show_plots: bool = True,
    save_plots: bool = True,
    *,
    filename_prefix: str = "sine_wave_prediction_sample",
    title_prefix: str = "Sine Wave Time Series Prediction",
) -> None:
    """Plot every sample of ``container`` together with the model predictions.

    Each sample is rendered with ``plot_from_container``. The full
    ``predictions_np`` array is passed unchanged for every sample, together
    with that sample's index.

    Args:
        container: Batch whose ``batch_size`` samples are plotted one by one.
        predictions_np: Predictions for the whole batch. The inline
            annotation gives its layout as ``[B, P, N, Q]``.
        model_quantiles: Quantile levels forwarded to ``plot_from_container``,
            or ``None``.
        output_dir: Directory that receives saved figures. It is created
            only when ``save_plots`` is true.
        show_plots: Forwarded as ``show`` to ``plot_from_container``.
        save_plots: If true, sample ``k`` (1-based) is written to
            ``<output_dir>/<filename_prefix>_<k>.png``. Otherwise no file
            is written.
        filename_prefix: Stem for saved figure filenames. The default
            reproduces the original hard-coded name.
        title_prefix: Leading part of each figure title. The default
            reproduces the original hard-coded title.
    """
    # Only touch the filesystem when figures will actually be written.
    if save_plots:
        os.makedirs(output_dir, exist_ok=True)

    for sample_idx in range(container.batch_size):
        sample_no = sample_idx + 1  # 1-based numbering for human-facing names
        output_file = os.path.join(output_dir, f"{filename_prefix}_{sample_no}.png") if save_plots else None
        plot_from_container(
            batch=container,
            sample_idx=sample_idx,
            predicted_values=predictions_np,
            model_quantiles=model_quantiles,
            title=f"{title_prefix} - Sample {sample_no}",
            output_file=output_file,
            show=show_plots,
        )
def run_inference_and_plot(
    model: TimeSeriesModel,
    container: BatchTimeSeriesContainer,
    output_dir: str = "outputs",
    use_bfloat16: bool = True,
    *,
    show_plots: bool = True,
    save_plots: bool = True,
) -> None:
    """Run model inference with optional bfloat16 and plot using shared utilities.

    Args:
        model: Model called as ``model(container)``. Its output must contain a
            ``"result"`` tensor and may contain ``"scale_statistics"``.
        container: Input batch. The device of ``container.history_values``
            decides whether bfloat16 autocast is used.
        output_dir: Directory for saved figures, forwarded to
            ``plot_with_library``.
        use_bfloat16: Request bfloat16 autocast. It is only enabled when the
            input lives on a CUDA device.
        show_plots: Forwarded to ``plot_with_library``. Defaults to the
            previously hard-coded ``True``.
        save_plots: Forwarded to ``plot_with_library``. Defaults to the
            previously hard-coded ``True``.
    """
    device_type = "cuda" if container.history_values.device.type == "cuda" else "cpu"
    autocast_enabled = use_bfloat16 and device_type == "cuda"

    with (
        torch.no_grad(),
        torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled),
    ):
        model_output = model(container)

    # Under autocast the raw output may be bfloat16. Upcast to float32 before
    # post-processing, since NumPy has no bfloat16 dtype for the .numpy() call.
    preds_full = model_output["result"].to(torch.float32)

    # Undo the model's input scaling when the model exposes a scaler and the
    # forward pass returned the statistics needed to invert it.
    if hasattr(model, "scaler") and "scale_statistics" in model_output:
        preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])
    preds_np = preds_full.detach().cpu().numpy()

    # Quantile bands are only meaningful for models trained with a quantile loss.
    model_quantiles = model.quantiles if getattr(model, "loss_type", None) == "quantile" else None

    plot_with_library(
        container=container,
        predictions_np=preds_np,
        model_quantiles=model_quantiles,
        output_dir=output_dir,
        show_plots=show_plots,
        save_plots=save_plots,
    )