import torch.nn as nn

from src.models.gated_deltaproduct import GatedDeltaProductConfig
from src.models.gated_deltaproduct.modeling_gated_deltaproduct import (
    GatedDeltaProductBlock,
)


class GatedDeltaProductEncoder(nn.Module):
    """Sequence encoder that wraps a single ``GatedDeltaProductBlock``.

    The constructor translates its keyword arguments into a
    ``GatedDeltaProductConfig`` and builds exactly one block from it, stored
    as ``self.encoder_layer``.
    """

    def __init__(
        self,
        layer_idx: int,
        token_embed_dim: int,
        num_heads: int = 4,
        attn_mode: str = "chunk",
        expand_v: float = 1.0,
        use_gate: bool = False,
        use_short_conv: bool = True,
        conv_size: int = 4,
        hidden_ratio: float = 1.0,
        allow_neg_eigval: bool = True,
        use_forget_gate: bool = True,
        num_householder: int = 1,
        **kwargs,
    ) -> None:
        """Build the config and the wrapped block.

        Args:
            layer_idx: Passed to ``GatedDeltaProductBlock`` as ``layer_idx``.
            token_embed_dim: Passed to the config as ``hidden_size``; also
                used to derive ``head_dim``.
            num_heads: Passed to the config as ``num_heads``; also used to
                derive ``head_dim``.
            attn_mode: Passed to the config unchanged.
            expand_v: Passed to the config unchanged.
            use_gate: Passed to the config unchanged.
            use_short_conv: Passed to the config unchanged.
            conv_size: Passed to the config unchanged.
            hidden_ratio: Passed to the config unchanged.
            allow_neg_eigval: Passed to the config unchanged.
            use_forget_gate: Passed to the config unchanged.
            num_householder: Passed to the config unchanged.
            **kwargs: Accepted but not used; extra keyword arguments are
                silently discarded.
        """
        super().__init__()

        # Floor division: when token_embed_dim is not a multiple of num_heads
        # the remainder is dropped, so head_dim * num_heads < token_embed_dim.
        # NOTE(review): confirm whether the block tolerates that case.
        head_dim = token_embed_dim // num_heads

        config = GatedDeltaProductConfig(
            attn_mode=attn_mode,
            hidden_size=token_embed_dim,
            expand_v=expand_v,
            use_gate=use_gate,
            use_short_conv=use_short_conv,
            conv_size=conv_size,
            head_dim=head_dim,
            hidden_ratio=hidden_ratio,
            num_heads=num_heads,
            allow_neg_eigval=allow_neg_eigval,
            use_forget_gate=use_forget_gate,
            num_householder=num_householder,
        )
        self.encoder_layer = GatedDeltaProductBlock(layer_idx=layer_idx, config=config)

    def forward(self, x, initial_state=None):
        """Run the input through the wrapped ``GatedDeltaProductBlock``.

        Args:
            x: Input tensor; assumed to be
                ``[batch_size, seq_len, hidden_size]`` — TODO confirm against
                the block's expectations.
            initial_state: Optional value forwarded to the block as
                ``initial_state``. Defaults to ``None``.

        Returns:
            A 2-tuple ``(x, last_hidden_state)`` holding the first two of the
            three values returned by the block. The first is presumably the
            transformed sequence and the second the block's final recurrent
            state — verify against ``GatedDeltaProductBlock``.
        """
        # The block is always called with output_attentions=True and is
        # expected to return three values; the third one is discarded.
        x, last_hidden_state, _ = self.encoder_layer(
            x,
            output_attentions=True,
            initial_state=initial_state,
        )
        return x, last_hidden_state