"""Predictor implementation wrapping the TimeSeriesModel for GIFT-Eval."""

import logging

import numpy as np
import torch
import yaml
from gluonts.model.forecast import QuantileForecast
from gluonts.model.predictor import Predictor
from torch.nn.parallel import DistributedDataParallel as DDP

from src.data.containers import BatchTimeSeriesContainer
from src.data.frequency import parse_frequency
from src.data.scalers import RobustScaler
from src.models.model import TimeSeriesModel
from src.utils.utils import device

logger = logging.getLogger(__name__)


class TimeSeriesPredictor(Predictor):
    """Unified predictor for TimeSeriesModel supporting flexible construction."""

    def __init__(
        self,
        model: TimeSeriesModel,
        config: dict,
        ds_prediction_length: int,
        ds_freq: str,
        batch_size: int = 32,
        max_context_length: int | None = None,
        debug: bool = False,
    ) -> None:
        self.ds_prediction_length = ds_prediction_length
        self.ds_freq = ds_freq
        self.batch_size = batch_size
        self.max_context_length = max_context_length
        self.debug = debug

        # Unwrap DDP so evaluation and attribute access hit the underlying module.
        self.model = model.module if isinstance(model, DDP) else model
        self.model.eval()
        self.config = config

        model_config = self.config.get("TimeSeriesModel", {})
        scaler_type = model_config.get("scaler", "custom_robust")
        epsilon = model_config.get("epsilon", 1e-3)
        if scaler_type == "custom_robust":
            self.scaler = RobustScaler(epsilon=epsilon)
        else:
            raise ValueError(f"Unsupported scaler type: {scaler_type}")

    def set_dataset_context(
        self,
        prediction_length: int | None = None,
        freq: str | None = None,
        batch_size: int | None = None,
        max_context_length: int | None = None,
    ) -> None:
        """Update lightweight dataset-specific attributes without reloading the model."""
        if prediction_length is not None:
            self.ds_prediction_length = prediction_length
        if freq is not None:
            self.ds_freq = freq
        if batch_size is not None:
            self.batch_size = batch_size
        if max_context_length is not None:
            self.max_context_length = max_context_length
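
    # A minimal usage sketch for evaluating several datasets with one loaded
    # model (the `datasets` iterable and its attributes are hypothetical
    # stand-ins, not part of this module):
    #
    #     for ds in datasets:
    #         predictor.set_dataset_context(
    #             prediction_length=ds.prediction_length,
    #             freq=ds.freq,
    #         )
    #         forecasts = predictor.predict(ds.test_data)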

    @classmethod
    def from_model(
        cls,
        model: TimeSeriesModel,
        config: dict,
        ds_prediction_length: int,
        ds_freq: str,
        batch_size: int = 32,
        max_context_length: int | None = None,
        debug: bool = False,
    ) -> "TimeSeriesPredictor":
        """Build a predictor around an already-instantiated model."""
        return cls(
            model=model,
            config=config,
            ds_prediction_length=ds_prediction_length,
            ds_freq=ds_freq,
            batch_size=batch_size,
            max_context_length=max_context_length,
            debug=debug,
        )

    @classmethod
    def from_paths(
        cls,
        model_path: str,
        config_path: str,
        ds_prediction_length: int,
        ds_freq: str,
        batch_size: int = 32,
        max_context_length: int | None = None,
        debug: bool = False,
    ) -> "TimeSeriesPredictor":
        """Build a predictor from a checkpoint path and a YAML config path."""
        with open(config_path) as f:
            config = yaml.safe_load(f)
        model = cls._load_model_from_path(config=config, model_path=model_path)
        return cls(
            model=model,
            config=config,
            ds_prediction_length=ds_prediction_length,
            ds_freq=ds_freq,
            batch_size=batch_size,
            max_context_length=max_context_length,
            debug=debug,
        )
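
    # The YAML config is expected to expose a `TimeSeriesModel` section whose
    # keys are passed verbatim to the TimeSeriesModel constructor; `__init__`
    # additionally reads its `scaler` and `epsilon` entries. A minimal sketch
    # (field values are illustrative only):
    #
    #     TimeSeriesModel:
    #       scaler: custom_robust
    #       epsilon: 1.0e-3
    #       # ... remaining TimeSeriesModel constructor arguments ...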

    @staticmethod
    def _load_model_from_path(config: dict, model_path: str) -> TimeSeriesModel:
        """Instantiate a TimeSeriesModel and load checkpoint weights onto `device`."""
        try:
            model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)
            checkpoint = torch.load(model_path, map_location=device)
            model.load_state_dict(checkpoint["model_state_dict"])
            model.eval()
            logger.info(f"Successfully loaded model from {model_path}")
            return model
        except Exception as exc:
            logger.error(f"Failed to load model from {model_path}: {exc}")
            raise
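
    # The checkpoint is expected to be a dict containing a "model_state_dict"
    # entry, so a compatible checkpoint could be produced with, e.g.:
    #
    #     torch.save({"model_state_dict": model.state_dict()}, model_path)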

    def predict(self, test_data_input) -> list[QuantileForecast]:
        """Generate forecasts for the test data, preserving the input order."""
        if hasattr(test_data_input, "__iter__") and not isinstance(test_data_input, list):
            test_data_input = list(test_data_input)
        logger.debug(f"Processing {len(test_data_input)} time series")

        def _effective_length(entry) -> int:
            """History length actually fed to the model, after optional truncation."""
            target = entry["target"]
            if target.ndim == 1:
                seq_len = len(target)
            else:
                # gluonts stores multivariate targets as (num_channels, time).
                seq_len = target.shape[1]
            if self.max_context_length is not None:
                seq_len = min(seq_len, self.max_context_length)
            return seq_len

        # Bucket series by effective length so every batch stacks into a
        # rectangular tensor without padding.
        length_to_items: dict[int, list[tuple[int, object]]] = {}
        for idx, entry in enumerate(test_data_input):
            seq_len = _effective_length(entry)
            length_to_items.setdefault(seq_len, []).append((idx, entry))

        total = len(test_data_input)
        ordered_results: list[QuantileForecast | None] = [None] * total

        for _, items in length_to_items.items():
            for i in range(0, len(items), self.batch_size):
                chunk = items[i : i + self.batch_size]
                entries = [entry for (_orig_idx, entry) in chunk]
                batch_forecasts = self._predict_batch(entries)
                for forecast_idx, (orig_idx, _entry) in enumerate(chunk):
                    ordered_results[orig_idx] = batch_forecasts[forecast_idx]

        return ordered_results
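
    # For example, series with effective lengths [24, 24, 48] and batch_size=2
    # form two buckets: the length-24 bucket yields one batch of two series,
    # the length-48 bucket one batch of a single series, and the resulting
    # forecasts are written back into their original input positions.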

    def _predict_batch(self, test_data_batch: list) -> list[QuantileForecast]:
        """Generate predictions for a batch of time series."""
        logger.debug(f"Processing batch of size: {len(test_data_batch)}")
        try:
            batch_container = self._convert_to_batch_container(test_data_batch)

            # Resolve the device type so bf16 autocast is enabled only on CUDA.
            if isinstance(device, torch.device):
                device_type = device.type
            else:
                device_type = "cuda" if "cuda" in str(device).lower() else "cpu"
            enable_autocast = device_type == "cuda"

            with torch.autocast(
                device_type=device_type,
                dtype=torch.bfloat16,
                enabled=enable_autocast,
            ):
                with torch.no_grad():
                    model_output = self.model(batch_container, drop_enc_allow=False)

            forecasts = self._convert_to_forecasts(model_output, test_data_batch, batch_container)
            logger.debug(f"Generated {len(forecasts)} forecasts")
            return forecasts
        except Exception as exc:
            logger.error(f"Error in batch prediction: {exc}")
            raise

    def _convert_to_batch_container(self, test_data_batch: list) -> BatchTimeSeriesContainer:
        """Convert gluonts test data to BatchTimeSeriesContainer."""
        batch_size = len(test_data_batch)
        history_values_list = []
        start_dates = []
        frequencies = []

        for entry in test_data_batch:
            target = entry["target"]

            # gluonts stores multivariate targets as (num_channels, time);
            # the container expects (time, num_channels).
            if target.ndim == 1:
                target = target.reshape(-1, 1)
            else:
                target = target.T

            start = entry["start"]
            if self.max_context_length is not None and len(target) > self.max_context_length:
                # Keep only the most recent window, and shift the start date so
                # it still refers to the first retained observation.
                start = start + (len(target) - self.max_context_length)
                target = target[-self.max_context_length :]

            history_values_list.append(target)
            start_dates.append(start.to_timestamp().to_datetime64())
            frequencies.append(parse_frequency(entry["freq"]))

        # All entries in a batch share the same effective length, so stacking
        # yields a rectangular (batch, time, channels) array.
        history_values_np = np.stack(history_values_list, axis=0)
        num_channels = history_values_np.shape[2]

        history_values = torch.tensor(history_values_np, dtype=torch.float32, device=device)

        # Zero-filled placeholder spanning the forecast horizon.
        future_values = torch.zeros(
            (batch_size, self.ds_prediction_length, num_channels),
            dtype=torch.float32,
            device=device,
        )

        return BatchTimeSeriesContainer(
            history_values=history_values,
            future_values=future_values,
            start=start_dates,
            frequency=frequencies,
        )
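
    # For instance, 32 univariate series each truncated to 512 steps produce
    # `history_values` of shape (32, 512, 1) and, for a 48-step horizon,
    # `future_values` of shape (32, 48, 1).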

    def _convert_to_forecasts(
        self,
        model_output: dict,
        test_data_batch: list,
        batch_container: BatchTimeSeriesContainer,
    ) -> list[QuantileForecast]:
        """Convert model predictions to QuantileForecast objects."""
        predictions = model_output["result"]
        scale_statistics = model_output["scale_statistics"]
        predictions_unscaled = self.scaler.inverse_scale(predictions, scale_statistics)

        # A 4-D result carries a trailing quantile dimension
        # (batch, prediction_length, num_channels, num_quantiles);
        # a 3-D result is a point forecast reported as the median.
        if predictions.ndim == 4:
            is_quantile = True
            quantile_levels = self.model.quantiles
        else:
            is_quantile = False
            quantile_levels = [0.5]

        forecasts: list[QuantileForecast] = []
        for idx, entry in enumerate(test_data_batch):
            # The forecast starts right after the original (untruncated)
            # history, so derive its start from the entry itself rather than
            # from the possibly truncated batch container.
            history_length = int(entry["target"].shape[-1])
            forecast_start = entry["start"] + history_length

            pred_array = predictions_unscaled[idx].cpu().numpy()
            if is_quantile:
                if pred_array.shape[1] == 1:
                    # Univariate: (pred_len, 1, num_q) -> (num_q, pred_len).
                    forecast_arrays = pred_array.squeeze(1).T
                else:
                    # Multivariate: (pred_len, channels, num_q) -> (num_q, pred_len, channels).
                    forecast_arrays = pred_array.transpose(2, 0, 1)
            else:
                if pred_array.shape[1] == 1:
                    # Univariate point forecast: (pred_len, 1) -> (1, pred_len).
                    forecast_arrays = pred_array.squeeze(1).reshape(1, -1)
                else:
                    # Multivariate point forecast: -> (1, pred_len, channels).
                    forecast_arrays = pred_array.reshape(1, *pred_array.shape)

            forecasts.append(
                QuantileForecast(
                    forecast_arrays=forecast_arrays,
                    forecast_keys=[str(q) for q in quantile_levels],
                    start_date=forecast_start,
                )
            )

        return forecasts
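
    # Shape walk-through with illustrative numbers: given quantile levels
    # 0.1 ... 0.9, a 48-step horizon, and a univariate series,
    # `predictions_unscaled[idx]` is (48, 1, 9) and `forecast_arrays` becomes
    # (9, 48), one row per forecast key "0.1" ... "0.9".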


__all__ = ["TimeSeriesPredictor"]
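
# Example end-to-end usage (a minimal sketch; the checkpoint/config paths and
# the test split are hypothetical stand-ins, not part of this module):
#
#     predictor = TimeSeriesPredictor.from_paths(
#         model_path="checkpoints/model.pt",    # hypothetical path
#         config_path="configs/model.yaml",     # hypothetical path
#         ds_prediction_length=48,
#         ds_freq="h",
#         batch_size=64,
#         max_context_length=2048,
#     )
#     forecasts = predictor.predict(test_data.input)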