import logging
import os
import urllib.request
from typing import List

import numpy as np
import torch
import yaml

from src.data.containers import BatchTimeSeriesContainer
from src.models.model import TimeSeriesModel
from src.plotting.plot_timeseries import plot_from_container

logger = logging.getLogger(__name__)
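
# NOTE: this module only creates a named logger; callers are expected to
# configure handlers themselves (e.g. logging.basicConfig(level=logging.INFO))
# before using these helpers.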


def load_model(
    config_path: str, model_path: str, device: torch.device
) -> TimeSeriesModel:
    """Load the TimeSeriesModel from config and checkpoint."""
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    logger.info(f"Successfully loaded TimeSeriesModel from {model_path} on {device}")
    return model
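
# Example usage (a minimal sketch; the config and checkpoint paths below are
# hypothetical placeholders, not files guaranteed to exist in this repo):
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = load_model("configs/model.yaml", "models/checkpoint.pth", device)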


def download_checkpoint_if_needed(url: str, target_dir: str = "models") -> str:
    """Download the checkpoint from `url` into `target_dir` if not already present.

    Returns the local checkpoint path. Dropbox share links are normalized to
    force a direct download (dl=1) instead of the HTML preview page.
    """
    os.makedirs(target_dir, exist_ok=True)
    target_path = os.path.join(target_dir, "checkpoint.pth")
    # Normalize Dropbox URLs: dl=0 serves a preview page, dl=1 the raw file.
    if "dropbox.com" in url and "dl=0" in url:
        url = url.replace("dl=0", "dl=1")
    if not os.path.exists(target_path):
        logger.info(f"Downloading checkpoint from {url} to {target_path}...")
        urllib.request.urlretrieve(url, target_path)
        logger.info("Checkpoint downloaded successfully.")
    else:
        logger.info(f"Using existing checkpoint at {target_path}")
    return target_path
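
# Example (hypothetical share link): a Dropbox URL ending in "?dl=0" serves an
# HTML preview page, so it is rewritten to "?dl=1" before urlretrieve runs:
#
#     path = download_checkpoint_if_needed(
#         "https://www.dropbox.com/s/abc123/checkpoint.pth?dl=0"
#     )
#     # downloads to "models/checkpoint.pth" on the first call, reuses it after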


def plot_with_library(
    container: BatchTimeSeriesContainer,
    predictions_np: np.ndarray,  # [B, P, N, Q]
    model_quantiles: List[float] | None,
    output_dir: str = "outputs",
    show_plots: bool = True,
    save_plots: bool = True,
):
    """Plot every sample in the batch with the shared plotting utility.

    `predictions_np` is the full prediction array of shape [B, P, N, Q];
    plot_from_container selects the row for each `sample_idx` itself.
    """
    os.makedirs(output_dir, exist_ok=True)
    batch_size = container.batch_size
    for i in range(batch_size):
        output_file = (
            os.path.join(output_dir, f"sine_wave_prediction_sample_{i + 1}.png")
            if save_plots
            else None
        )
        plot_from_container(
            batch=container,
            sample_idx=i,
            predicted_values=predictions_np,
            model_quantiles=model_quantiles,
            title=f"Sine Wave Time Series Prediction - Sample {i + 1}",
            output_file=output_file,
            show=show_plots,
        )
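
# Example (hypothetical values): with a two-sample batch and quantile levels
# [0.1, 0.5, 0.9], this writes outputs/sine_wave_prediction_sample_1.png and
# outputs/sine_wave_prediction_sample_2.png, showing each figure as it goes:
#
#     plot_with_library(container, preds_np, model_quantiles=[0.1, 0.5, 0.9])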


def run_inference_and_plot(
    model: TimeSeriesModel,
    container: BatchTimeSeriesContainer,
    output_dir: str = "outputs",
    use_bfloat16: bool = True,
) -> None:
    """Run model inference with optional bfloat16 autocast and plot via the shared utilities."""
    device_type = "cuda" if container.history_values.device.type == "cuda" else "cpu"
    # bfloat16 autocast is only enabled on CUDA; on CPU the context is a no-op.
    autocast_enabled = use_bfloat16 and device_type == "cuda"
    with (
        torch.no_grad(),
        torch.autocast(
            device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled
        ),
    ):
        model_output = model(container)
        preds_full = model_output["result"].to(torch.float32)
        # Undo the model's input scaling so predictions are in the original units.
        if hasattr(model, "scaler") and "scale_statistics" in model_output:
            preds_full = model.scaler.inverse_scale(
                preds_full, model_output["scale_statistics"]
            )
    preds_np = preds_full.detach().cpu().numpy()
    # Quantile levels are only meaningful when the model was trained with a
    # quantile loss; otherwise plot point predictions.
    model_quantiles = (
        model.quantiles if getattr(model, "loss_type", None) == "quantile" else None
    )
    plot_with_library(
        container=container,
        predictions_np=preds_np,
        model_quantiles=model_quantiles,
        output_dir=output_dir,
        show_plots=True,
        save_plots=True,
    )
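
# Minimal end-to-end sketch (commented out; the URL and config path are
# hypothetical placeholders, and building a BatchTimeSeriesContainer depends
# on this repo's data pipeline):
#
#     if __name__ == "__main__":
#         logging.basicConfig(level=logging.INFO)
#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#         ckpt = download_checkpoint_if_needed("https://example.com/checkpoint.pth")
#         model = load_model("configs/model.yaml", ckpt, device)
#         container = ...  # build a BatchTimeSeriesContainer and move it to `device`
#         run_inference_and_plot(model, container)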