import torch
import torch.nn as nn
from fla.modules import GatedMLP

from src.data.containers import BatchTimeSeriesContainer
from src.data.scalers import MinMaxScaler, RobustScaler
from src.data.time_features import compute_batch_time_features
from src.models.blocks import GatedDeltaProductEncoder
from src.utils.utils import device


def create_scaler(scaler_type: str, epsilon: float = 1e-3):
    """Create a scaler instance based on type."""
    if scaler_type == "custom_robust":
        return RobustScaler(epsilon=epsilon)
    elif scaler_type == "min_max":
        return MinMaxScaler(epsilon=epsilon)
    else:
        raise ValueError(f"Unknown scaler: {scaler_type}")


def apply_channel_noise(values: torch.Tensor, noise_scale: float = 0.1):
    """Add noise to constant channels to prevent model instability."""
    # values: [B, S, N]; a channel is constant if every time step equals its first value.
    is_constant = torch.all(values == values[:, 0:1, :], dim=1)
    noise = torch.randn_like(values) * noise_scale * is_constant.unsqueeze(1)
    return values + noise


class TimeSeriesModel(nn.Module):
    """Time series forecasting model combining embedding, encoding, and prediction."""

    def __init__(
        self,
        # Core architecture
        embed_size: int = 128,
        num_encoder_layers: int = 2,
        # Scaling and preprocessing
        scaler: str = "custom_robust",
        epsilon: float = 1e-3,
        scaler_clamp_value: float | None = None,
        handle_constants: bool = False,
        # Time features
        K_max: int = 6,
        time_feature_config: dict | None = None,
        encoding_dropout: float = 0.0,
        # Encoder configuration
        encoder_config: dict | None = None,
        # Loss configuration
        loss_type: str = "huber",  # "huber" or "quantile"
        quantiles: list[float] | None = None,
        **kwargs,
    ):
        super().__init__()

        # Core parameters
        self.embed_size = embed_size
        self.num_encoder_layers = num_encoder_layers
        self.epsilon = epsilon
        self.scaler_clamp_value = scaler_clamp_value
        self.handle_constants = handle_constants
        self.encoding_dropout = encoding_dropout
        self.K_max = K_max
        self.time_feature_config = time_feature_config or {}
        self.encoder_config = encoder_config or {}

        # Store loss parameters
        self.loss_type = loss_type
        self.quantiles = quantiles
        if self.loss_type == "quantile" and self.quantiles is None:
            raise ValueError("Quantiles must be provided for quantile loss.")
        if self.quantiles:
            self.register_buffer("qt", torch.tensor(self.quantiles, device=device).view(1, 1, 1, -1))

        # Validate configuration before initialization
        self._validate_configuration()

        # Initialize components
        self.scaler = create_scaler(scaler, epsilon)
        self._init_embedding_layers()
        self._init_encoder_layers(self.encoder_config, num_encoder_layers)
        self._init_projection_layers()

    def _validate_configuration(self):
        """Validate essential model configuration parameters."""
        if "num_heads" not in self.encoder_config:
            raise ValueError("encoder_config must contain 'num_heads' parameter")
        if self.embed_size % self.encoder_config["num_heads"] != 0:
            raise ValueError(
                f"embed_size ({self.embed_size}) must be divisible by "
                f"num_heads ({self.encoder_config['num_heads']})"
            )

    def _init_embedding_layers(self):
        """Initialize value and time feature embedding layers."""
        self.expand_values = nn.Linear(1, self.embed_size, bias=True)
        self.nan_embedding = nn.Parameter(
            torch.randn(1, 1, 1, self.embed_size) / self.embed_size,
            requires_grad=True,
        )
        self.time_feature_projection = nn.Linear(self.K_max, self.embed_size)

    def _init_encoder_layers(self, encoder_config: dict, num_encoder_layers: int):
        """Initialize encoder layers."""
        self.num_encoder_layers = num_encoder_layers
        # Ensure encoder_config carries the token embedding dimension
        encoder_config = encoder_config.copy()
        encoder_config["token_embed_dim"] = self.embed_size
        self.encoder_layers = nn.ModuleList(
            [
                GatedDeltaProductEncoder(layer_idx=layer_idx, **encoder_config)
                for layer_idx in range(self.num_encoder_layers)
            ]
        )

    def _init_projection_layers(self):
        if self.loss_type == "quantile":
            output_dim = len(self.quantiles)
        else:
            output_dim = 1
        self.final_output_layer = nn.Linear(self.embed_size, output_dim)
        self.mlp = GatedMLP(
            hidden_size=self.embed_size,
            hidden_ratio=4,
            hidden_act="swish",
            fuse_swiglu=True,
        )

        # Initialize a learnable initial hidden state for each encoder layer.
        # These are expanded to match the batch size during the forward pass.
        head_k_dim = self.embed_size // self.encoder_config["num_heads"]
        # Get expand_v from encoder_config, defaulting to 1.0 if not present
        expand_v = self.encoder_config.get("expand_v", 1.0)
        head_v_dim = int(head_k_dim * expand_v)
        num_initial_hidden_states = self.num_encoder_layers
        self.initial_hidden_state = nn.ParameterList(
            [
                nn.Parameter(
                    torch.randn(1, self.encoder_config["num_heads"], head_k_dim, head_v_dim) / head_k_dim,
                    requires_grad=True,
                )
                for _ in range(num_initial_hidden_states)
            ]
        )

    def _preprocess_data(self, data_container: BatchTimeSeriesContainer):
        """Extract data shapes and handle constant channels without padding."""
        history_values = data_container.history_values
        future_values = data_container.future_values
        history_mask = data_container.history_mask

        batch_size, history_length, num_channels = history_values.shape
        future_length = future_values.shape[1] if future_values is not None else 0

        # Handle constant channels
        if self.handle_constants:
            history_values = apply_channel_noise(history_values)

        return {
            "history_values": history_values,
            "future_values": future_values,
            "history_mask": history_mask,
            "num_channels": num_channels,
            "history_length": history_length,
            "future_length": future_length,
            "batch_size": batch_size,
        }

    def _compute_scaling(self, history_values: torch.Tensor, history_mask: torch.Tensor = None):
        """Compute scaling statistics from the history window."""
        scale_statistics = self.scaler.compute_statistics(history_values, history_mask)
        return scale_statistics

    def _apply_scaling_and_masking(self, values: torch.Tensor, scale_statistics: dict, mask: torch.Tensor = None):
        """Apply scaling and optional masking to values."""
        scaled_values = self.scaler.scale(values, scale_statistics)
        if mask is not None:
            scaled_values = scaled_values * mask.unsqueeze(-1).float()
        if self.scaler_clamp_value is not None:
            scaled_values = torch.clamp(scaled_values, -self.scaler_clamp_value, self.scaler_clamp_value)
        return scaled_values

    def _get_positional_embeddings(
        self,
        time_features: torch.Tensor,
        num_channels: int,
        batch_size: int,
        drop_enc_allow: bool = False,
    ):
        """Generate positional embeddings from time features."""
        seq_len = time_features.shape[1]
        # Optionally drop the time-feature encoding (encoding dropout)
        if (torch.rand(1).item() < self.encoding_dropout) and drop_enc_allow:
            return torch.zeros(
                batch_size, seq_len, num_channels, self.embed_size, device=device, dtype=torch.float32
            )
        pos_embed = self.time_feature_projection(time_features)
        return pos_embed.unsqueeze(2).expand(-1, -1, num_channels, -1)

    def _compute_embeddings(
        self,
        scaled_history: torch.Tensor,
        history_pos_embed: torch.Tensor,
        history_mask: torch.Tensor | None = None,
    ):
        """Compute value embeddings and combine them with positional embeddings."""
        nan_mask = torch.isnan(scaled_history)
        history_for_embedding = torch.nan_to_num(scaled_history, nan=0.0)
        channel_embeddings = self.expand_values(history_for_embedding.unsqueeze(-1))
        channel_embeddings[nan_mask] = self.nan_embedding.to(channel_embeddings.dtype)
        channel_embeddings = channel_embeddings + history_pos_embed
        # Suppress padded time steps completely so padding is a pure batching artifact
        # history_mask: [B, S] -> broadcast to [B, S, 1, 1]
        if history_mask is not None:
            mask_broadcast = history_mask.unsqueeze(-1).unsqueeze(-1).to(channel_embeddings.dtype)
            channel_embeddings = channel_embeddings * mask_broadcast

        batch_size, seq_len = scaled_history.shape[:2]
        all_channels_embedded = channel_embeddings.view(batch_size, seq_len, -1)
        return all_channels_embedded

    def _generate_predictions(
        self,
        embedded: torch.Tensor,
        target_pos_embed: torch.Tensor,
        prediction_length: int,
        num_channels: int,
        history_mask: torch.Tensor = None,
    ):
        """Generate predictions for all channels using vectorized operations."""
        batch_size, seq_len, _ = embedded.shape

        # embedded shape: [B, S, N*E] -> reshape to [B, S, N, E]
        embedded = embedded.view(batch_size, seq_len, num_channels, self.embed_size)

        # Vectorize across channels by merging the batch and channel dimensions:
        # [B, S, N, E] -> [B*N, S, E]
        channel_embedded = (
            embedded.permute(0, 2, 1, 3).contiguous().view(batch_size * num_channels, seq_len, self.embed_size)
        )

        # Reshape target positional embeddings similarly: [B, P, N, E] -> [B*N, P, E]
        target_pos_embed = (
            target_pos_embed.permute(0, 2, 1, 3)
            .contiguous()
            .view(batch_size * num_channels, prediction_length, self.embed_size)
        )

        x = channel_embedded
        target_repr = target_pos_embed
        x = torch.concatenate([x, target_repr], dim=1)

        if self.encoder_config.get("weaving", True):
            # Weaving: carry a single hidden state across layers, adding each layer's
            # learnable initial state before the layer consumes it.
            hidden_state = torch.zeros_like(self.initial_hidden_state[0].repeat(batch_size * num_channels, 1, 1, 1))
            for layer_idx, encoder_layer in enumerate(self.encoder_layers):
                x, hidden_state = encoder_layer(
                    x,
                    hidden_state + self.initial_hidden_state[layer_idx].repeat(batch_size * num_channels, 1, 1, 1),
                )
        else:
            # No weaving: each layer starts from its own learnable initial hidden state.
            for layer_idx, encoder_layer in enumerate(self.encoder_layers):
                initial_hidden_state = self.initial_hidden_state[layer_idx].repeat(batch_size * num_channels, 1, 1, 1)
                x, _ = encoder_layer(x, initial_hidden_state)

        # Use the last prediction_length positions
        prediction_embeddings = x[:, -prediction_length:, :]
        predictions = self.final_output_layer(self.mlp(prediction_embeddings))

        # Reshape output to handle quantiles.
        # Current shape: [B*N, P, Q] where Q is num_quantiles or 1; reshape back to [B, P, N, Q].
        output_dim = len(self.quantiles) if self.loss_type == "quantile" else 1
        predictions = predictions.view(batch_size, num_channels, prediction_length, output_dim)
        predictions = predictions.permute(0, 2, 1, 3)  # [B, P, N, Q]

        # Squeeze the last dimension if not in quantile mode for backward compatibility
        if self.loss_type != "quantile":
            predictions = predictions.squeeze(-1)  # [B, P, N]

        return predictions

    def forward(self, data_container: BatchTimeSeriesContainer, drop_enc_allow: bool = False):
        """Main forward pass."""
        # Preprocess data
        preprocessed = self._preprocess_data(data_container)

        # Compute time features dynamically based on actual lengths
        history_time_features, target_time_features = compute_batch_time_features(
            start=data_container.start,
            history_length=preprocessed["history_length"],
            future_length=preprocessed["future_length"],
            batch_size=preprocessed["batch_size"],
            frequency=data_container.frequency,
            K_max=self.K_max,
            time_feature_config=self.time_feature_config,
        )

        # Compute scaling statistics from the history window
        scale_statistics = self._compute_scaling(preprocessed["history_values"], preprocessed["history_mask"])
        # Apply scaling
        history_scaled = self._apply_scaling_and_masking(
            preprocessed["history_values"],
            scale_statistics,
            preprocessed["history_mask"],
        )

        # Scale future values if present
        future_scaled = None
        if preprocessed["future_values"] is not None:
            future_scaled = self.scaler.scale(preprocessed["future_values"], scale_statistics)

        # Get positional embeddings
        history_pos_embed = self._get_positional_embeddings(
            history_time_features,
            preprocessed["num_channels"],
            preprocessed["batch_size"],
            drop_enc_allow,
        )
        target_pos_embed = self._get_positional_embeddings(
            target_time_features,
            preprocessed["num_channels"],
            preprocessed["batch_size"],
            drop_enc_allow,
        )

        # Compute embeddings
        history_embed = self._compute_embeddings(history_scaled, history_pos_embed, preprocessed["history_mask"])

        # Generate predictions
        predictions = self._generate_predictions(
            history_embed,
            target_pos_embed,
            preprocessed["future_length"],
            preprocessed["num_channels"],
            preprocessed["history_mask"],
        )

        return {
            "result": predictions,
            "scale_statistics": scale_statistics,
            "future_scaled": future_scaled,
            "history_length": preprocessed["history_length"],
            "future_length": preprocessed["future_length"],
        }

    def _quantile_loss(self, y_true: torch.Tensor, y_pred: torch.Tensor):
        """
        Compute the quantile loss.

        y_true: [B, P, N]
        y_pred: [B, P, N, Q]
        """
        # Add a dimension to y_true to match y_pred: [B, P, N] -> [B, P, N, 1]
        y_true = y_true.unsqueeze(-1)

        # Calculate errors
        errors = y_true - y_pred

        # Calculate quantile loss.
        # The max operator implements the two cases of the quantile loss formula.
        loss = torch.max((self.qt - 1) * errors, self.qt * errors)

        # Average the loss across all dimensions
        return loss.mean()

    def compute_loss(self, y_true: torch.Tensor, y_pred: dict):
        """Compute loss between predictions and scaled ground truth."""
        predictions = y_pred["result"]
        scale_statistics = y_pred["scale_statistics"]

        if y_true is None:
            return torch.tensor(0.0, device=predictions.device)

        future_scaled = self.scaler.scale(y_true, scale_statistics)

        if self.loss_type == "huber":
            if predictions.shape != future_scaled.shape:
                raise ValueError(
                    f"Shape mismatch for Huber loss: predictions {predictions.shape} "
                    f"vs future_scaled {future_scaled.shape}"
                )
            return nn.functional.huber_loss(predictions, future_scaled)
        elif self.loss_type == "quantile":
            return self._quantile_loss(future_scaled, predictions)
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")
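

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only, not part of the model).
# It assumes that BatchTimeSeriesContainer can be constructed directly from the
# attributes read in forward() (history_values, future_values, history_mask,
# start, frequency), that `start` / `frequency` follow pandas conventions, and
# that GatedDeltaProductEncoder needs no kwargs beyond those set here. Adjust
# the container construction and encoder_config to match the actual repo API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import pandas as pd

    batch_size, history_len, future_len, num_channels = 2, 48, 12, 3

    model = TimeSeriesModel(
        embed_size=64,
        num_encoder_layers=2,
        encoder_config={"num_heads": 4},  # assumed sufficient for the encoder
        loss_type="quantile",
        quantiles=[0.1, 0.5, 0.9],
    ).to(device)

    batch = BatchTimeSeriesContainer(
        history_values=torch.randn(batch_size, history_len, num_channels, device=device),
        future_values=torch.randn(batch_size, future_len, num_channels, device=device),
        history_mask=torch.ones(batch_size, history_len, dtype=torch.bool, device=device),
        start=[pd.Timestamp("2024-01-01")] * batch_size,  # assumed start format
        frequency="h",  # assumed frequency format
    )

    output = model(batch)
    loss = model.compute_loss(batch.future_values, output)
    # Quantile mode: result has shape [B, P, N, Q]
    print(output["result"].shape, loss.item())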