import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import json


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (same as MAP-NEO)"""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * norm * self.weight


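# RMSNorm in one line: y = x / sqrt(mean_i(x_i^2) + eps) * weight, i.e. LayerNorm
# without mean-centering and without a bias term.

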
class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) - same as MAP-NEO"""

    def __init__(self, dim: int, max_len: int = 8192, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_len = max_len
        self.theta = theta

        # Inverse frequencies: freqs[i] = theta ** (-2i / dim); pair i of the head
        # dimension is rotated by angle position * freqs[i].
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("freqs", freqs, persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # x is only used to pick the device; cos/sin have shape (seq_len, dim // 2).
        device = x.device
        positions = torch.arange(seq_len, device=device).float()
        angles = positions.unsqueeze(1) * self.freqs.unsqueeze(0)
        cos = torch.cos(angles)
        sin = torch.sin(angles)
        return cos, sin


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary embedding to query/key tensors of shape (batch, seq, heads, head_dim)."""
    # Split the head dimension into even/odd pairs and rotate each pair by its angle.
    x1, x2 = x[..., ::2], x[..., 1::2]
    # cos/sin: (seq, head_dim // 2) -> broadcast over batch and heads.
    cos = cos.unsqueeze(0).unsqueeze(-2)
    sin = sin.unsqueeze(0).unsqueeze(-2)
    # The rotated pairs are concatenated rather than re-interleaved; since queries and
    # keys get the same permutation of the head dimension, attention scores are unchanged.
    rotated = torch.cat([
        x1 * cos - x2 * sin,
        x1 * sin + x2 * cos,
    ], dim=-1)
    return rotated


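# Illustrative shape check (hypothetical sizes):
#     rope = RotaryPositionalEmbedding(dim=64)
#     cos, sin = rope(torch.empty(1), seq_len=128)     # each of shape (128, 32)
#     q = torch.randn(2, 128, 8, 64)                   # (batch, seq, heads, head_dim)
#     apply_rotary_emb(q, cos, sin).shape              # torch.Size([2, 128, 8, 64])

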
class MultiHeadAttention(nn.Module):
    """Multi-head attention with RoPE and optional Flash Attention"""

    def __init__(self, dim: int, n_heads: int, dropout: float = 0.0):
        super().__init__()
        assert dim % n_heads == 0

        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)

        self.dropout = nn.Dropout(dropout)

        self.rotary_emb = RotaryPositionalEmbedding(self.head_dim)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len, dim = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # (batch, seq, dim) -> (batch, heads, seq, head_dim)
        q = q.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        # apply_rotary_emb expects (batch, seq, heads, head_dim), so transpose around it.
        cos, sin = self.rotary_emb(q, seq_len)
        q = apply_rotary_emb(q.transpose(1, 2), cos, sin).transpose(1, 2)
        k = apply_rotary_emb(k.transpose(1, 2), cos, sin).transpose(1, 2)

        if attention_mask is None and hasattr(F, "scaled_dot_product_attention"):
            # Fast path: fused attention (Flash / memory-efficient kernels) with an
            # implicit causal mask. Only taken when there is no padding mask; padded
            # batches go through the explicit path below so the mask is respected.
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=True,
            )
        else:
            # Fallback: explicit attention with a causal mask plus optional padding mask.
            scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1
            )
            scores = scores.masked_fill(causal_mask, float('-inf'))

            if attention_mask is not None:
                # attention_mask: (batch, seq_len), True/1 for tokens to attend to.
                scores = scores.masked_fill(
                    ~attention_mask.bool().unsqueeze(1).unsqueeze(1), float('-inf')
                )

            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = self.dropout(attn_weights)
            out = torch.matmul(attn_weights, v)

        # (batch, heads, seq, head_dim) -> (batch, seq, dim)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, dim)
        out = self.o_proj(out)

        return out


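# Illustrative standalone check (hypothetical sizes):
#     attn = MultiHeadAttention(dim=256, n_heads=8)
#     attn(torch.randn(2, 16, 256)).shape              # torch.Size([2, 16, 256])

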
class FeedForward(nn.Module):
    """SwiGLU Feed-Forward Network (same as MAP-NEO)"""

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        hidden = gate * up
        hidden = self.dropout(hidden)
        return self.down_proj(hidden)


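# SwiGLU in one line: FFN(x) = down_proj( SiLU(gate_proj(x)) * up_proj(x) ), which is
# why this block carries three projection matrices instead of the usual two.

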
class TransformerBlock(nn.Module):
    """Transformer block with pre-norm (RMSNorm)"""

    def __init__(self, dim: int, n_heads: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.attention_norm = RMSNorm(dim)
        self.attention = MultiHeadAttention(dim, n_heads, dropout)
        self.ffn_norm = RMSNorm(dim)
        self.ffn = FeedForward(dim, hidden_dim, dropout)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Pre-norm residuals: normalize, transform, then add back the input.
        h = x + self.attention(self.attention_norm(x), attention_mask)
        h = h + self.ffn(self.ffn_norm(h))
        return h


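# Illustrative check (hypothetical sizes): a block is shape-preserving, so it can be
# stacked to any depth.
#     block = TransformerBlock(dim=256, n_heads=8, hidden_dim=512)
#     block(torch.randn(2, 16, 256)).shape             # torch.Size([2, 16, 256])

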
class NeoMiniConfig:
    """Configuration for MAP-NEO Mini (300M parameters)"""

    def __init__(self):
        self.vocab_size = 50257
        self.max_seq_len = 2048
        self.dim = 1024
        self.n_layers = 16
        self.n_heads = 16
        self.hidden_dim = 2736
        self.dropout = 0.0

    def to_dict(self):
        return {k: v for k, v in self.__dict__.items() if not k.startswith('_')}

    @classmethod
    def from_dict(cls, config_dict):
        config = cls()
        for k, v in config_dict.items():
            setattr(config, k, v)
        return config


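# Rough parameter budget for the defaults above (back-of-envelope, counting the tied
# LM head and token embedding once):
#     embedding : 50257 * 1024                                   ~ 51.5M
#     per layer : 4 * 1024^2 (attention) + 3 * 1024 * 2736 (FFN) ~ 12.6M
#     total     : 51.5M + 16 * 12.6M                             ~ 253M
# (~304M if the LM head were untied from the token embedding.)

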
class NeoMini(nn.Module):
    """MAP-NEO Mini Language Model (300M parameters)"""

    def __init__(self, config: NeoMiniConfig):
        super().__init__()
        self.config = config

        self.token_embedding = nn.Embedding(config.vocab_size, config.dim)

        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=config.dim,
                n_heads=config.n_heads,
                hidden_dim=config.hidden_dim,
                dropout=config.dropout,
            )
            for _ in range(config.n_layers)
        ])

        self.ln_f = RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        # Tie the LM head to the token embedding (shared weight matrix).
        self.lm_head.weight = self.token_embedding.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # (batch, seq) token ids -> (batch, seq, vocab_size) logits.
        x = self.token_embedding(input_ids)

        for block in self.blocks:
            x = block(x, attention_mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        return logits

    def get_num_params(self):
        """Count model parameters"""
        return sum(p.numel() for p in self.parameters())

    def save_config(self, path: str):
        """Save model configuration"""
        with open(path, 'w') as f:
            json.dump(self.config.to_dict(), f, indent=2)

    @classmethod
    def from_config(cls, config_path: str):
        """Load model from configuration"""
        with open(config_path, 'r') as f:
            config_dict = json.load(f)
        config = NeoMiniConfig.from_dict(config_dict)
        return cls(config)


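# Config round-trip sketch (hypothetical path; yields a fresh, randomly initialized model):
#     model = NeoMini(NeoMiniConfig())
#     model.save_config("neo_mini_config.json")
#     restored = NeoMini.from_config("neo_mini_config.json")

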
def create_model():
    """Create a MAP-NEO Mini model"""
    config = NeoMiniConfig()
    model = NeoMini(config)

    print(f"Created MAP-NEO Mini with {model.get_num_params():,} parameters")
    print(f"Config: {config.n_layers} layers, {config.dim} dim, {config.n_heads} heads")

    return model, config


if __name__ == "__main__":
    model, config = create_model()

    # Smoke test: run a random batch through the model.
    batch_size, seq_len = 2, 512
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    with torch.no_grad():
        logits = model(input_ids)

    print(f"Input shape: {input_ids.shape}")
    print(f"Output shape: {logits.shape}")
    print("Model test passed!")