from dataclasses import dataclass
from typing import List, Optional

import numpy as np
import torch

from src.data.frequency import Frequency


@dataclass
class BatchTimeSeriesContainer:
    """
    Container for a batch of multivariate time series data and their associated features.

    Attributes:
        history_values: Tensor of historical observations.
            Shape: [batch_size, seq_len, num_channels]
        future_values: Tensor of future observations to predict.
            Shape: [batch_size, pred_len, num_channels]
        start: Timestamp of the first history value.
            Type: List[np.datetime64], one entry per batch element.
        frequency: Frequency of the time series.
            Type: List[Frequency], one entry per batch element.
        history_mask: Optional boolean/float tensor indicating missing entries in
            history_values across channels. Shape: [batch_size, seq_len]
            (extra trailing dims are tolerated; only the first two are validated).
        future_mask: Optional boolean/float tensor indicating missing entries in
            future_values across channels. Shape: [batch_size, pred_len] or the
            full [batch_size, pred_len, num_channels].
    """

    history_values: torch.Tensor
    future_values: torch.Tensor
    start: List[np.datetime64]
    frequency: List[Frequency]
    history_mask: Optional[torch.Tensor] = None
    future_mask: Optional[torch.Tensor] = None

    def __post_init__(self):
        """Validate types, tensor shapes, and cross-field consistency.

        Raises:
            TypeError: If any field has the wrong type.
            ValueError: If tensor shapes or metadata lengths are inconsistent.
        """
        # --- Tensor Type Checks ---
        if not isinstance(self.history_values, torch.Tensor):
            raise TypeError("history_values must be a torch.Tensor")
        if not isinstance(self.future_values, torch.Tensor):
            raise TypeError("future_values must be a torch.Tensor")
        if not isinstance(self.start, list) or not all(
            isinstance(x, np.datetime64) for x in self.start
        ):
            raise TypeError("start must be a List[np.datetime64]")
        if not isinstance(self.frequency, list) or not all(
            isinstance(x, Frequency) for x in self.frequency
        ):
            raise TypeError("frequency must be a List[Frequency]")

        # Unpacking enforces that history_values is 3-D.
        batch_size, seq_len, num_channels = self.history_values.shape
        pred_len = self.future_values.shape[1]

        # --- Core Shape Checks ---
        if self.future_values.shape[0] != batch_size:
            raise ValueError("Batch size mismatch between history and future_values")
        if self.future_values.shape[2] != num_channels:
            raise ValueError("Channel size mismatch between history and future_values")

        # --- Metadata Length Checks ---
        # Mirrors TimeSeriesContainer: one start/frequency entry per batch element.
        if len(self.start) != batch_size:
            raise ValueError(
                f"Length of start ({len(self.start)}) must match batch_size ({batch_size})"
            )
        if len(self.frequency) != batch_size:
            raise ValueError(
                f"Length of frequency ({len(self.frequency)}) must match batch_size ({batch_size})"
            )

        # --- Optional Mask Checks ---
        if self.history_mask is not None:
            if not isinstance(self.history_mask, torch.Tensor):
                raise TypeError("history_mask must be a Tensor or None")
            # Only the leading (batch, seq_len) dims are validated; a trailing
            # channel dim is permitted.
            if self.history_mask.shape[:2] != (batch_size, seq_len):
                raise ValueError(
                    f"Shape mismatch in history_mask: {self.history_mask.shape[:2]} vs {(batch_size, seq_len)}"
                )
        if self.future_mask is not None:
            if not isinstance(self.future_mask, torch.Tensor):
                raise TypeError("future_mask must be a Tensor or None")
            # Accept either a per-timestep mask or a full per-channel mask.
            if not (
                self.future_mask.shape == (batch_size, pred_len)
                or self.future_mask.shape == self.future_values.shape
            ):
                raise ValueError(
                    f"Shape mismatch in future_mask: expected {(batch_size, pred_len)} or {self.future_values.shape}, got {self.future_mask.shape}"
                )

    def to_device(
        self, device: torch.device, attributes: Optional[List[str]] = None
    ) -> None:
        """
        Move specified tensors to the target device in place.

        Args:
            device: Target device (e.g., 'cpu', 'cuda').
            attributes: Optional list of attribute names to move.
                If None, move all tensors.

        Raises:
            ValueError: If an invalid attribute is specified or device transfer fails.
        """
        all_tensors = {
            "history_values": self.history_values,
            "future_values": self.future_values,
            "history_mask": self.history_mask,
            "future_mask": self.future_mask,
        }
        if attributes is None:
            # Default to every tensor attribute that is actually populated.
            attributes = [k for k, v in all_tensors.items() if v is not None]
        for attr in attributes:
            if attr not in all_tensors:
                raise ValueError(f"Invalid attribute: {attr}")
            if all_tensors[attr] is not None:
                setattr(self, attr, all_tensors[attr].to(device))

    def to(self, device: torch.device, attributes: Optional[List[str]] = None):
        """
        Alias for to_device method for consistency with PyTorch conventions.

        Args:
            device: Target device (e.g., 'cpu', 'cuda').
            attributes: Optional list of attribute names to move.
                If None, move all tensors.

        Returns:
            This container (enables fluent chaining, like torch.nn.Module.to).
        """
        self.to_device(device, attributes)
        return self

    @property
    def batch_size(self) -> int:
        return self.history_values.shape[0]

    @property
    def history_length(self) -> int:
        return self.history_values.shape[1]

    @property
    def future_length(self) -> int:
        return self.future_values.shape[1]

    @property
    def num_channels(self) -> int:
        return self.history_values.shape[2]


@dataclass
class TimeSeriesContainer:
    """
    Container for batch of time series data without explicit history/future split.

    This container is used for storing generated synthetic time series data
    where the entire series is treated as a single entity, typically for
    further processing or splitting into history/future components later.

    Attributes:
        values: np.ndarray of time series values.
            Shape: [batch_size, seq_len, num_channels] for multivariate series
                   [batch_size, seq_len] for univariate series
        start: List of start timestamps for each series in the batch.
            Type: List[np.datetime64], length should match batch_size
        frequency: List of frequency for each series in the batch.
            Type: List[Frequency], length should match batch_size
    """

    values: np.ndarray
    start: List[np.datetime64]
    frequency: List[Frequency]

    def __post_init__(self):
        """Validate all shapes and consistency.

        Raises:
            TypeError: If any field has the wrong type.
            ValueError: If the array rank or metadata lengths are inconsistent.
        """
        # --- Numpy Type Checks ---
        if not isinstance(self.values, np.ndarray):
            raise TypeError("values must be a np.ndarray")
        if not isinstance(self.start, list) or not all(
            isinstance(x, np.datetime64) for x in self.start
        ):
            raise TypeError("start must be a List[np.datetime64]")
        if not isinstance(self.frequency, list) or not all(
            isinstance(x, Frequency) for x in self.frequency
        ):
            raise TypeError("frequency must be a List[Frequency]")

        # --- Shape and Length Consistency Checks ---
        if len(self.values.shape) < 2 or len(self.values.shape) > 3:
            raise ValueError(
                f"values must have 2 or 3 dimensions [batch_size, seq_len] or [batch_size, seq_len, num_channels], got shape {self.values.shape}"
            )
        batch_size = self.values.shape[0]
        if len(self.start) != batch_size:
            raise ValueError(
                f"Length of start ({len(self.start)}) must match batch_size ({batch_size})"
            )
        if len(self.frequency) != batch_size:
            raise ValueError(
                f"Length of frequency ({len(self.frequency)}) must match batch_size ({batch_size})"
            )

    @property
    def batch_size(self) -> int:
        return self.values.shape[0]

    @property
    def seq_length(self) -> int:
        return self.values.shape[1]

    @property
    def num_channels(self) -> int:
        # Univariate (2-D) series are reported as a single channel.
        return self.values.shape[2] if len(self.values.shape) == 3 else 1