from __future__ import annotations from typing import TYPE_CHECKING import torch import torch.nn as nn from fla.layers.attn import Attention from fla.models.utils import Cache from fla.modules import GatedMLP as GatedDeltaProductMLP from fla.modules import RMSNorm from src.models.gated_deltaproduct.configuration_gated_deltaproduct import ( GatedDeltaProductConfig, ) from src.models.gated_deltaproduct.gated_deltaproduct import GatedDeltaProduct if TYPE_CHECKING: from transformers.processing_utils import Unpack class GatedDeltaProductBlock(nn.Module): def __init__(self, config: GatedDeltaProductConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) if config.attn is not None and layer_idx in config.attn["layers"]: self.attn = Attention( hidden_size=config.hidden_size, num_heads=config.attn["num_heads"], num_kv_heads=config.attn["num_kv_heads"], qkv_bias=config.attn["qkv_bias"], window_size=config.attn["window_size"], rope_theta=config.attn["rope_theta"], max_position_embeddings=config.max_position_embeddings, layer_idx=layer_idx, ) else: self.attn = GatedDeltaProduct( mode=config.attn_mode, hidden_size=config.hidden_size, expand_v=config.expand_v, head_dim=config.head_dim, num_heads=config.num_heads, use_gate=config.use_gate, use_forget_gate=config.use_forget_gate, use_short_conv=config.use_short_conv, conv_size=config.conv_size, norm_eps=config.norm_eps, allow_neg_eigval=config.allow_neg_eigval, num_householder=config.num_householder, layer_idx=layer_idx, ) self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) self.mlp = GatedDeltaProductMLP( hidden_size=config.hidden_size, hidden_ratio=config.hidden_ratio, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, fuse_swiglu=config.fuse_swiglu, ) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, past_key_values: Cache | list[torch.FloatTensor] | None = None, use_cache: bool | None = False, output_attentions: bool | None = False, initial_state: torch.FloatTensor | None = None, **kwargs: Unpack[dict], ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: residual = hidden_states hidden_states = self.attn_norm(hidden_states) hidden_states, attentions, past_key_values = self.attn( hidden_states=hidden_states, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, initial_state=initial_state, **kwargs, ) if self.config.fuse_norm: hidden_states, residual = self.mlp_norm(hidden_states, residual, True) else: hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.mlp_norm(hidden_states) hidden_states = self.mlp(hidden_states, **kwargs) hidden_states = residual + hidden_states outputs = (hidden_states, attentions, past_key_values) return outputs