from torch import nn

from GroupedQueryAttention import GroupedQueryAttention
from FeedForward import FeedForward


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            dtype=cfg["dtype"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
        self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])

    def forward(self, x, mask, cos, sin):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x, mask, cos, sin)  # Shape [batch_size, num_tokens, emb_size]
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = x + shortcut  # Add the original input back

        return x
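
# --- Minimal smoke test (a hedged sketch, not part of the original module) ---
# It shows how the block might be instantiated and called. The cfg keys beyond
# "emb_dim", "n_heads", "n_kv_groups", and "dtype" (notably "hidden_dim") and
# the shapes of `mask`, `cos`, and `sin` are assumptions about the companion
# GroupedQueryAttention / FeedForward modules; adjust them to match the actual
# implementations in this repo.
if __name__ == "__main__":
    import torch

    cfg = {
        "emb_dim": 64,        # model width
        "n_heads": 8,         # number of query heads
        "n_kv_groups": 2,     # key/value groups shared across query heads
        "hidden_dim": 256,    # assumed FeedForward expansion size
        "dtype": torch.float32,
    }
    block = TransformerBlock(cfg)

    batch_size, num_tokens = 2, 16
    head_dim = cfg["emb_dim"] // cfg["n_heads"]  # assumed per-head size

    x = torch.randn(batch_size, num_tokens, cfg["emb_dim"], dtype=cfg["dtype"])

    # Boolean causal mask: True above the diagonal marks future positions.
    mask = torch.triu(
        torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1
    )

    # Placeholder RoPE tables; real code would precompute these from rotary
    # frequencies (e.g. via a compute_rope_params-style helper).
    cos = torch.ones(num_tokens, head_dim, dtype=cfg["dtype"])
    sin = torch.zeros(num_tokens, head_dim, dtype=cfg["dtype"])

    out = block(x, mask, cos, sin)
    print(out.shape)  # expected: torch.Size([2, 16, 64])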