File size: 3,767 Bytes
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import logging
import os
import urllib.request
from typing import List

import numpy as np
import torch
import yaml

from src.data.containers import BatchTimeSeriesContainer
from src.models.model import TimeSeriesModel
from src.plotting.plot_timeseries import plot_from_container

logger = logging.getLogger(__name__)


def load_model(
    config_path: str, model_path: str, device: torch.device
) -> TimeSeriesModel:
    """Load the TimeSeriesModel from config and checkpoint.

    Args:
        config_path: Path to a YAML file whose ``TimeSeriesModel`` mapping is
            passed as keyword arguments to the ``TimeSeriesModel`` constructor.
        model_path: Path to a checkpoint readable by ``torch.load`` whose
            ``"model_state_dict"`` entry holds the model weights.
        device: Device the model is moved to and onto which the checkpoint
            tensors are mapped.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.

    Raises:
        KeyError: If ``"TimeSeriesModel"`` is missing from the config mapping
            or ``"model_state_dict"`` is missing from the checkpoint mapping.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding (e.g. cp1252 on Windows).
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)
    # SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
    # code unless weights_only=True is in effect (the default only since
    # PyTorch 2.6). Checkpoints may be fetched from a remote URL (see
    # download_checkpoint_if_needed), so only load files from trusted sources.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    logger.info(
        "Successfully loaded TimeSeriesModel from %s on %s", model_path, device
    )
    return model


def download_checkpoint_if_needed(url: str, target_dir: str = "models") -> str:
    """Download checkpoint from URL into target_dir if not present and return its path.

    The checkpoint is always stored as ``<target_dir>/checkpoint.pth``. If that
    file already exists it is reused and ``url`` is not contacted.

    Ensures direct download for Dropbox links by forcing dl=1: a ``dl=0``
    query parameter (Dropbox preview page) is rewritten to ``dl=1``.

    The data is first written to ``checkpoint.pth.part`` and moved into place
    only once the download has completed. An interrupted or failed download
    therefore never leaves a truncated ``checkpoint.pth`` that later runs
    would mistake for a valid cached checkpoint.

    Args:
        url: Source URL of the checkpoint.
        target_dir: Directory that holds the checkpoint; created if missing.

    Returns:
        Path to the local checkpoint file.
    """
    os.makedirs(target_dir, exist_ok=True)
    target_path = os.path.join(target_dir, "checkpoint.pth")

    # Normalize Dropbox URL to force direct download. Only an exact "dl=0"
    # query parameter is rewritten; a plain substring replace could also alter
    # other parameters whose names merely end in "dl".
    if "dropbox.com" in url:
        base, query_sep, rest = url.partition("?")
        query, fragment_sep, fragment = rest.partition("#")
        params = ["dl=1" if param == "dl=0" else param for param in query.split("&")]
        url = base + query_sep + "&".join(params) + fragment_sep + fragment

    if os.path.exists(target_path):
        logger.info("Using existing checkpoint at %s", target_path)
        return target_path

    logger.info("Downloading checkpoint from %s to %s...", url, target_path)
    partial_path = target_path + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
    except BaseException:
        # BaseException so that Ctrl-C also cleans up. Remove the incomplete
        # file, then propagate the original error.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        raise
    # Publish the finished file in one step (atomic rename on POSIX).
    os.replace(partial_path, target_path)
    logger.info("Checkpoint downloaded successfully.")
    return target_path


def plot_with_library(
    container: BatchTimeSeriesContainer,
    predictions_np: np.ndarray,  # [B, P, N, Q]
    model_quantiles: List[float] | None,
    output_dir: str = "outputs",
    show_plots: bool = True,
    save_plots: bool = True,
) -> None:
    """Draw one prediction plot per sample in ``container``.

    Each sample is rendered by ``plot_from_container`` with the full
    prediction array and the sample's index.

    Args:
        container: Batch to plot. ``container.batch_size`` sets how many
            figures are produced.
        predictions_np: Predictions for the whole batch, forwarded unchanged
            to ``plot_from_container``. Presumably shaped ``[B, P, N, Q]`` as
            the signature comment states; confirm against the plotting helper.
        model_quantiles: Quantile levels of the predictions, or ``None`` when
            the predictions are not quantile-based.
        output_dir: Directory for the saved PNG files.
        show_plots: Forwarded as ``show`` to ``plot_from_container``.
        save_plots: If True, sample ``i`` is saved as
            ``sine_wave_prediction_sample_<i + 1>.png`` in ``output_dir``.
            If False, no output path is passed and ``output_dir`` is not
            created.
    """
    # Only touch the filesystem when files will actually be written.
    if save_plots:
        os.makedirs(output_dir, exist_ok=True)

    for i in range(container.batch_size):
        sample_no = i + 1  # 1-based numbering for titles and file names
        output_file = (
            os.path.join(output_dir, f"sine_wave_prediction_sample_{sample_no}.png")
            if save_plots
            else None
        )
        plot_from_container(
            batch=container,
            sample_idx=i,
            predicted_values=predictions_np,
            model_quantiles=model_quantiles,
            title=f"Sine Wave Time Series Prediction - Sample {sample_no}",
            output_file=output_file,
            show=show_plots,
        )


def run_inference_and_plot(
    model: TimeSeriesModel,
    container: BatchTimeSeriesContainer,
    output_dir: str = "outputs",
    use_bfloat16: bool = True,
    show_plots: bool = True,
    save_plots: bool = True,
) -> None:
    """Run model inference with optional bfloat16 and plot using shared utilities.

    Steps:

    1. Call ``model(container)`` without gradient tracking.
    2. Cast ``output["result"]`` to float32.
    3. Apply ``model.scaler.inverse_scale`` when the model has a ``scaler``
       and the output carries ``"scale_statistics"``.
    4. Plot every sample with ``plot_with_library``.

    Args:
        model: Model whose call returns a mapping containing ``"result"`` and
            optionally ``"scale_statistics"``.
        container: Input batch. The device of ``container.history_values``
            determines whether CUDA autocast can be used.
        output_dir: Directory for saved plots.
        use_bfloat16: Run the forward pass under bfloat16 autocast. This only
            takes effect on CUDA; on other devices autocast stays disabled.
        show_plots: Display each figure (defaults to the previous fixed
            behavior, True).
        save_plots: Save each figure to ``output_dir`` (defaults to the
            previous fixed behavior, True).
    """
    on_cuda = container.history_values.device.type == "cuda"
    device_type = "cuda" if on_cuda else "cpu"
    autocast_enabled = use_bfloat16 and on_cuda
    with (
        torch.no_grad(),
        torch.autocast(
            device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled
        ),
    ):
        model_output = model(container)

    # Upcast before any post-processing so bfloat16 outputs don't lose precision
    # in the inverse scaling or the numpy conversion.
    preds_full = model_output["result"].to(torch.float32)
    if hasattr(model, "scaler") and "scale_statistics" in model_output:
        preds_full = model.scaler.inverse_scale(
            preds_full, model_output["scale_statistics"]
        )

    preds_np = preds_full.detach().cpu().numpy()
    # Quantile levels are only meaningful for quantile-loss models.
    model_quantiles = (
        model.quantiles if getattr(model, "loss_type", None) == "quantile" else None
    )
    plot_with_library(
        container=container,
        predictions_np=preds_np,
        model_quantiles=model_quantiles,
        output_dir=output_dir,
        show_plots=show_plots,
        save_plots=save_plots,
    )