|
|
import logging |
|
|
import os |
|
|
import urllib.request |
|
|
from typing import List |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import yaml |
|
|
|
|
|
from src.data.containers import BatchTimeSeriesContainer |
|
|
from src.models.model import TimeSeriesModel |
|
|
from src.plotting.plot_timeseries import plot_from_container |
|
|
|
|
|
# Module-level logger named after this module; the functions below use it for
# progress/status messages.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def load_model(
    config_path: str, model_path: str, device: torch.device
) -> TimeSeriesModel:
    """Build a ``TimeSeriesModel`` from a YAML config and restore its weights.

    Args:
        config_path: YAML file whose top-level ``TimeSeriesModel`` mapping is
            unpacked as keyword arguments to the model constructor.
        model_path: Checkpoint file; its ``model_state_dict`` entry is loaded
            into the freshly built model.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The restored model, switched to eval mode.
    """
    with open(config_path, "r") as config_file:
        parsed_config = yaml.safe_load(config_file)

    model_kwargs = parsed_config["TimeSeriesModel"]
    model = TimeSeriesModel(**model_kwargs).to(device)

    # NOTE(review): torch.load deserializes via pickle and model_path may point
    # at a file downloaded from a URL -- only load checkpoints from trusted
    # sources.
    checkpoint_state = torch.load(model_path, map_location=device)
    weights = checkpoint_state["model_state_dict"]
    model.load_state_dict(weights)
    model.eval()

    logger.info(f"Successfully loaded TimeSeriesModel from {model_path} on {device}")
    return model
|
|
|
|
|
|
|
|
def _force_dropbox_direct_download(url: str) -> str:
    """Return *url* with its ``dl`` query parameter set to ``1`` for Dropbox hosts.

    Dropbox's ``dl=1`` query parameter requests a direct file download. URLs
    whose host is not ``dropbox.com`` (or a subdomain of it) are returned
    unchanged. Other query parameters keep their original order and encoding;
    an existing ``dl`` parameter is replaced in place, otherwise ``dl=1`` is
    appended.
    """
    from urllib.parse import urlsplit, urlunsplit

    parts = urlsplit(url)
    host = parts.hostname or ""
    if host != "dropbox.com" and not host.endswith(".dropbox.com"):
        return url

    params = parts.query.split("&") if parts.query else []
    replaced = False
    for idx, param in enumerate(params):
        if param.partition("=")[0] == "dl":
            params[idx] = "dl=1"
            replaced = True
    if not replaced:
        params.append("dl=1")
    return urlunsplit(parts._replace(query="&".join(params)))


def download_checkpoint_if_needed(
    url: str, target_dir: str = "models", filename: str = "checkpoint.pth"
) -> str:
    """Download a checkpoint into *target_dir* unless it is already present.

    Dropbox links are rewritten to force ``dl=1`` so the raw file is fetched.

    Args:
        url: Source URL of the checkpoint.
        target_dir: Directory to store the checkpoint in (created if missing).
        filename: Name of the cached file inside ``target_dir``. The cache is
            keyed only on this name, not on ``url``, so use distinct names for
            distinct checkpoints.

    Returns:
        Path to the local checkpoint file.

    Raises:
        OSError: (including ``urllib.error.URLError``) if the download fails.
            In that case nothing is left at the returned path, so a later call
            retries the download instead of reusing a truncated file.
    """
    os.makedirs(target_dir, exist_ok=True)
    target_path = os.path.join(target_dir, filename)

    if os.path.exists(target_path):
        logger.info(f"Using existing checkpoint at {target_path}")
        return target_path

    url = _force_dropbox_direct_download(url)
    logger.info(f"Downloading checkpoint from {url} to {target_path}...")

    # Download to a sibling temp file and atomically move it into place, so an
    # interrupted or failed download never leaves a partial file at
    # target_path that a later call would mistake for a complete checkpoint.
    partial_path = f"{target_path}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, target_path)
    finally:
        # After a successful os.replace the partial file no longer exists;
        # otherwise remove whatever was written before the failure.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    logger.info("Checkpoint downloaded successfully.")
    return target_path
|
|
|
|
|
|
|
|
def plot_with_library(
    container: BatchTimeSeriesContainer,
    predictions_np: np.ndarray,
    model_quantiles: List[float] | None,
    output_dir: str = "outputs",
    show_plots: bool = True,
    save_plots: bool = True,
    title_prefix: str = "Sine Wave Time Series Prediction",
    filename_prefix: str = "sine_wave_prediction_sample",
) -> None:
    """Plot every sample of a batch with ``plot_from_container``.

    One ``plot_from_container`` call is made per sample index in
    ``range(container.batch_size)``.

    Args:
        container: Batch whose samples are plotted.
        predictions_np: Predictions array, passed unchanged on every call
            together with the current ``sample_idx``.
        model_quantiles: Quantile levels forwarded to the plotter, or ``None``.
        output_dir: Directory saved figures are written to; it is only created
            when ``save_plots`` is True.
        show_plots: Forwarded to the plotter as ``show``.
        save_plots: When True, figure ``k`` (1-based) is written to
            ``<output_dir>/<filename_prefix>_<k>.png``; otherwise
            ``output_file=None`` is passed.
        title_prefix: Leading text of each figure title
            (``"<title_prefix> - Sample <k>"``).
        filename_prefix: Stem of each saved file name.
    """
    # Only touch the filesystem when figures are actually being saved.
    if save_plots:
        os.makedirs(output_dir, exist_ok=True)

    for i in range(container.batch_size):
        sample_number = i + 1  # 1-based numbering for titles and file names
        output_file = (
            os.path.join(output_dir, f"{filename_prefix}_{sample_number}.png")
            if save_plots
            else None
        )
        plot_from_container(
            batch=container,
            sample_idx=i,
            predicted_values=predictions_np,
            model_quantiles=model_quantiles,
            title=f"{title_prefix} - Sample {sample_number}",
            output_file=output_file,
            show=show_plots,
        )
|
|
|
|
|
|
|
|
def run_inference_and_plot(
    model: TimeSeriesModel,
    container: BatchTimeSeriesContainer,
    output_dir: str = "outputs",
    use_bfloat16: bool = True,
    show_plots: bool = True,
    save_plots: bool = True,
) -> None:
    """Run model inference on a batch and plot each sample's prediction.

    Args:
        model: Model called as ``model(container)``; its output is indexed
            with ``"result"`` and optionally ``"scale_statistics"``.
        container: Input batch. The device of ``container.history_values``
            decides whether autocast targets ``"cuda"`` or ``"cpu"``.
        output_dir: Directory saved figures are written to.
        use_bfloat16: Request bfloat16 autocast. It is only enabled when
            ``container.history_values`` is on a CUDA device; otherwise
            inference runs with autocast disabled.
        show_plots: Display each figure (forwarded to ``plot_with_library``).
        save_plots: Save each figure under ``output_dir`` (forwarded to
            ``plot_with_library``).
    """
    device_type = "cuda" if container.history_values.device.type == "cuda" else "cpu"
    autocast_enabled = use_bfloat16 and device_type == "cuda"

    with (
        torch.no_grad(),
        torch.autocast(
            device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled
        ),
    ):
        model_output = model(container)

    # Upcast to float32: autocast may have produced bfloat16, which NumPy
    # cannot represent, and inverse scaling should run at full precision.
    preds_full = model_output["result"].to(torch.float32)
    if hasattr(model, "scaler") and "scale_statistics" in model_output:
        preds_full = model.scaler.inverse_scale(
            preds_full, model_output["scale_statistics"]
        )

    preds_np = preds_full.detach().cpu().numpy()
    # Quantile levels are only passed to the plotter for quantile-loss models.
    model_quantiles = (
        model.quantiles if getattr(model, "loss_type", None) == "quantile" else None
    )
    plot_with_library(
        container=container,
        predictions_np=preds_np,
        model_quantiles=model_quantiles,
        output_dir=output_dir,
        show_plots=show_plots,
        save_plots=save_plots,
    )
|
|
|