# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from einops import rearrange, repeat

from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
from fla.ops.delta_rule import fused_recurrent_delta_rule
from fla.ops.gated_delta_product import chunk_gated_delta_product
from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
from torch.nn import functional as F

if TYPE_CHECKING:
    from fla.models.utils import Cache
    from transformers.processing_utils import Unpack


class GatedDeltaProduct(nn.Module):
    """
    Generalized version of GatedDoubleDeltaNet that supports an arbitrary number of
    Householder transformations per token.

    Every input token is projected to one query and ``num_householder`` keys, values
    and betas (the k/v/beta projections are ``num_householder`` times wider), which are
    laid out as ``num_householder`` consecutive recurrence steps per token.

    Args:
        hidden_size: Feature size of the layer input and output.
        expand_v: Ratio ``head_v_dim / head_dim``; ``head_dim * expand_v`` must be an integer.
        head_dim: Per-head query/key dimension.
        num_heads: Number of query/key heads.
        num_v_heads: Number of value heads; defaults to ``num_heads``. When larger than
            ``num_heads`` it must be a multiple of it, and q/k heads are repeated to match.
        mode: Kernel used by ``forward``: ``"chunk"`` or ``"fused_recurrent"``.
            Training only supports ``"chunk"``.
        use_gate: If True, the output is normalized by ``FusedRMSNormGated`` with a gate
            from ``g_proj``; otherwise a plain ``RMSNorm`` is used.
        use_short_conv: If True, q/k/v go through SiLU-activated ``ShortConvolution``s;
            otherwise only a SiLU is applied (a warning is emitted).
        conv_size: Kernel size of the short convolutions.
        conv_bias: Whether the short convolutions have a bias.
        layer_idx: Index of this layer in ``past_key_values``.
        norm_eps: Epsilon of the output norm.
        use_forget_gate: If True, a per-value-head gate ``g`` is computed from ``a_proj``,
            ``A_log`` and ``dt_bias``; otherwise ``g`` is None and ``fused_recurrent`` mode
            uses the ungated delta rule kernel.
        allow_neg_eigval: If True, ``beta`` (a sigmoid in (0, 1)) is scaled to (0, 2).
        num_householder: Number of Householder (key, value, beta) updates per token.
        **kwargs: Ignored.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        expand_v: float = 2,
        head_dim: int = 256,
        num_heads: int = 6,
        num_v_heads: int | None = None,
        mode: str = "chunk",
        use_gate: bool = True,
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        layer_idx: int | None = None,
        norm_eps: float = 1e-5,
        use_forget_gate: bool = True,
        allow_neg_eigval: bool = True,
        num_householder: int = 2,
        **kwargs,
    ) -> None:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_v = expand_v

        self.use_forget_gate = use_forget_gate
        self.allow_neg_eigval = allow_neg_eigval
        self.num_householder = num_householder

        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.head_dim = head_dim
        self.num_heads = num_heads
        self.num_v_heads = num_v_heads if num_v_heads is not None else num_heads

        self.head_k_dim = head_dim
        self.head_v_dim = int(self.head_dim * self.expand_v)
        self.key_dim = int(self.num_heads * self.head_k_dim)
        self.value_dim = int(self.num_v_heads * self.head_v_dim)
        self.layer_idx = layer_idx

        # Not read anywhere in this module's forward(); kept so state dicts that contain it
        # still load. NOTE(review): its (num_heads, head_dim, head_dim) shape would presumably
        # not match the kernels' recurrent state when num_v_heads != num_heads or
        # expand_v != 1 — confirm before wiring it in as a learned initial state.
        self.init_hidden_state = nn.Parameter(torch.randn(self.num_heads, self.head_dim, self.head_dim))

        # Consistency check: Ensure expand_v produces integer values
        if not math.isclose(self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5):
            # Implicit string concatenation; the previous version *called* the first string
            # literal, which raised TypeError instead of this ValueError.
            raise ValueError(
                f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
                f"Resulting value_dim would be "
                f"{self.num_v_heads * self.head_dim * expand_v}, "
                "which is invalid for nn.Linear."
            )
        if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
            raise ValueError(f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.")

        if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
            raise ValueError(
                f"expand_v={expand_v} does not produce an integer value when multiplied by head_dim={head_dim}. "
                f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated."
            )
        assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`."

        # The queries are produced once per token; keys, values and betas once per
        # Householder step, hence the `* num_householder` widths.
        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim * num_householder, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim * num_householder, bias=False)
        self.b_proj = nn.Linear(hidden_size, self.num_v_heads * num_householder, bias=False)

        if self.use_forget_gate:
            self.a_proj = nn.Linear(hidden_size, self.num_v_heads, bias=False)
            A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(0, 16)
            self.A_log = nn.Parameter(torch.log(A))
            self.A_log._no_weight_decay = True
            # hard coded for now
            dt_min = 0.001
            dt_max = 0.1
            dt_init_floor = 1e-4
            # dt is sampled log-uniformly in [dt_min, dt_max], then floored.
            dt = torch.exp(torch.rand(self.num_v_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
            dt = torch.clamp(dt, min=dt_init_floor)
            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            self.dt_bias = nn.Parameter(inv_dt)
            # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
            # name.endswith("bias") in param_grouping.py
            self.dt_bias._no_weight_decay = True

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                bias=conv_bias,
                activation="silu",
            )
            self.k_conv1d = ShortConvolution(
                hidden_size=self.key_dim * num_householder,
                kernel_size=conv_size,
                bias=conv_bias,
                activation="silu",
            )
            self.v_conv1d = ShortConvolution(
                hidden_size=self.value_dim * num_householder,
                kernel_size=conv_size,
                bias=conv_bias,
                activation="silu",
            )
        else:
            warnings.warn(
                "ShortConvolution is crucial to the performance. "
                "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
            )
        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
            self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
        else:
            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

    def _initialize_weights(self, module: nn.Module):
        """Xavier-uniform init (gain 2**-2.5) for ``nn.Linear`` weights and zero biases.

        Modules already flagged with ``_is_hf_initialized`` are skipped, and the flag is
        set afterwards so each module is initialized at most once.
        """
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2**-2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        initial_state: torch.Tensor | None = None,
        use_cache: bool | None = False,
        output_attentions: bool | None = False,
        **kwargs: Unpack[dict],
    ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
        """
        Run the layer over a batch of sequences.

        Args:
            hidden_states: Input of shape ``[batch_size, seq_len, hidden_size]``.
            attention_mask: Optional 0-1 padding mask of shape ``[batch_size, seq_len]``
                (0 marks padding). Its last ``seq_len`` columns are used to unpad the input
                into a single packed row; the output is padded back afterwards.
            past_key_values: Optional cache. If it already holds an entry for ``layer_idx``,
                the stored conv and recurrent states are resumed; its ``update`` is then
                called with the new states.
            initial_state: Optional recurrent state, used only when the cache provides none
                for this layer.
            use_cache: Whether the kernels should return their final conv/recurrent states.
            output_attentions: Kept for interface compatibility; it also requests the final
                recurrent state from the kernel.
            **kwargs: May carry ``cu_seqlens`` for pre-packed variable-length input
                (overridden when ``attention_mask`` is given).

        Returns:
            ``(o, recurrent_state, past_key_values)``. ``o`` has shape
            ``[batch_size, seq_len, hidden_size]``. ``recurrent_state`` is the final state
            returned by the kernel (None when it was not requested).
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, q_len, _ = hidden_states.shape
        # The configured kernel is used as-is; only chunk mode is allowed while training.
        mode = self.mode
        if self.training:
            assert mode == "chunk", "Only chunk mode is supported in training."

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get("cu_seqlens", None)
        if attention_mask is not None:
            # Drop padded positions and pack all sequences into a single batch row.
            indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
            hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)

        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state["conv_state"]
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
            )
        else:
            q = F.silu(self.q_proj(hidden_states))
            k = F.silu(self.k_proj(hidden_states))
            v = F.silu(self.v_proj(hidden_states))

        q = rearrange(q, "... (h d) -> ... h d", d=self.head_k_dim)
        # The n Householder keys/values of every token become n consecutive time steps.
        k = rearrange(
            k,
            "... l (n h d) -> ... (l n) h d",
            n=self.num_householder,
            d=self.head_k_dim,
        )
        v = rearrange(
            v,
            "... l (n h d) -> ... (l n) h d",
            n=self.num_householder,
            d=self.head_v_dim,
        )

        if self.num_v_heads > self.num_heads:
            # More value heads than q/k heads: repeat each q/k head to match.
            q, k = map(
                lambda x: repeat(x, "... h d -> ... (h g) d", g=self.num_v_heads // self.num_heads),
                (q, k),
            )

        beta = self.b_proj(hidden_states).sigmoid()
        if self.allow_neg_eigval:
            beta = beta * 2.0

        beta = rearrange(beta, "... l (n h) -> ... (l n) h", n=self.num_householder)
        if self.use_forget_gate:
            g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)
        else:
            g = None

        # Resume the state stored in the cache; fall back to the explicit argument.
        # (Previously the chunk kernel ignored the cached state entirely.)
        recurrent_state = last_state["recurrent_state"] if last_state is not None else None
        if recurrent_state is None:
            recurrent_state = initial_state
        # The final state is needed for the cache; `output_attentions` is honoured too for
        # callers that relied on it to obtain the state in chunk mode.
        output_final_state = bool(use_cache) or bool(output_attentions)

        if mode == "chunk":
            o, recurrent_state = chunk_gated_delta_product(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=output_final_state,
                cu_seqlens=cu_seqlens,
                num_householder=self.num_householder,
                use_qk_l2norm_in_kernel=True,
            )

        elif mode == "fused_recurrent":
            # Unroll the n Householder steps of every token into n pseudo-tokens for the
            # single-step kernels. The gate is placed only on the first step and the query
            # only on the last step (zeros elsewhere); outputs of the other steps are
            # discarded after the kernel call.
            if self.use_forget_gate:
                g_new = torch.zeros(
                    g.shape[0],
                    g.shape[1],
                    self.num_householder,
                    g.shape[2],
                    device=g.device,
                    dtype=torch.float32,
                )
                g_new[:, :, 0] = g
                g = rearrange(g_new, "... l n h -> ... (l n) h")

            q_new = q.new_zeros(q.shape[0], q.shape[1], self.num_householder, q.shape[2], q.shape[3])
            q_new[:, :, -1] = q
            q = rearrange(q_new, "... l n h d-> ... (l n) h d")
            # Each packed sequence is num_householder times longer after unrolling.
            unrolled_cu_seqlens = cu_seqlens * self.num_householder if cu_seqlens is not None else None
            if self.use_forget_gate:
                o, recurrent_state = fused_recurrent_gated_delta_rule(
                    q=q,
                    k=k,
                    v=v,
                    g=g,
                    beta=beta,
                    initial_state=recurrent_state,
                    output_final_state=output_final_state,
                    cu_seqlens=unrolled_cu_seqlens,
                    use_qk_l2norm_in_kernel=True,
                )
            else:
                o, recurrent_state = fused_recurrent_delta_rule(
                    q=q,
                    k=k,
                    v=v,
                    beta=beta,
                    initial_state=recurrent_state,
                    output_final_state=output_final_state,
                    cu_seqlens=unrolled_cu_seqlens,
                    use_qk_l2norm_in_kernel=True,
                )
            # Keep only the output of the last Householder step of every token.
            o = rearrange(o, "... (l n) h d -> ... l n h d", n=self.num_householder)[..., -1, :, :].contiguous()

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q_len,
            )

        if self.use_gate:
            g = rearrange(self.g_proj(hidden_states), "... (h d) -> ... h d", d=self.head_v_dim)
            o = self.o_norm(o, g)
        else:
            o = self.o_norm(o)
        o = rearrange(o, "b t h d -> b t (h d)")
        o = self.o_proj(o)
        if attention_mask is not None:
            o = pad_input(o.squeeze(0), indices, batch_size, q_len)
        return o, recurrent_state, past_key_values