from tools import *
from torch import nn
import torch


class GroupedQueryAttention(nn.Module):
    def __init__(
        self, d_in, d_out, num_heads, num_kv_groups, dtype=None
    ):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

    def forward(self, x, mask, cos, sin):
        b, num_tokens, d_in = x.shape

        queries = self.W_query(x)  # Shape: (b, num_tokens, d_out)
        keys = self.W_key(x)       # Shape: (b, num_tokens, num_kv_groups * head_dim)
        values = self.W_value(x)   # Shape: (b, num_tokens, num_kv_groups * head_dim)

        # Reshape queries, keys, and values
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)

        # Transpose keys, values, and queries
        keys = keys.transpose(1, 2)        # Shape: (b, num_kv_groups, num_tokens, head_dim)
        values = values.transpose(1, 2)    # Shape: (b, num_kv_groups, num_tokens, head_dim)
        queries = queries.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)

        # Apply RoPE
        keys = apply_rope(keys, cos, sin)
        queries = apply_rope(queries, cos, sin)

        # Expand keys and values to match the number of heads
        keys = keys.repeat_interleave(self.group_size, dim=1)      # Shape: (b, num_heads, num_tokens, head_dim)
        values = values.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)
        # For example, before repeat_interleave along dim=1 (the kv-group dimension):
        #   [K1, K2]
        # After repeat_interleave (each group is repeated group_size times):
        #   [K1, K1, K2, K2]
        # If we used regular repeat instead of repeat_interleave, we'd get:
        #   [K1, K2, K1, K2]

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        # Shape: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Use the mask to hide future positions
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        assert keys.shape[-1] == self.head_dim

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
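

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module).
    # Assumption: `apply_rope` from `tools` accepts `cos` and `sin` tensors of shape
    # (num_tokens, head_dim) built from the standard RoPE angle table below; the angle
    # computation here is a hypothetical stand-in for whatever helper `tools` provides.
    torch.manual_seed(123)

    b, num_tokens, d_in, d_out = 2, 8, 64, 64
    num_heads, num_kv_groups = 8, 2
    head_dim = d_out // num_heads

    gqa = GroupedQueryAttention(d_in, d_out, num_heads, num_kv_groups)

    # Precompute RoPE cos/sin tables (assumed layout: (num_tokens, head_dim))
    theta_base = 10_000
    inv_freq = 1.0 / theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim)
    positions = torch.arange(num_tokens).float()
    angles = positions[:, None] * inv_freq[None, :]   # (num_tokens, head_dim // 2)
    angles = torch.cat([angles, angles], dim=-1)      # (num_tokens, head_dim)
    cos, sin = torch.cos(angles), torch.sin(angles)

    # Boolean causal mask: True above the diagonal marks positions to be masked out
    mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)

    x = torch.randn(b, num_tokens, d_in)
    out = gqa(x, mask, cos, sin)
    print(out.shape)  # Expected: torch.Size([2, 8, 64])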