import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_medassistgpt import MedAssistGPTConfig


# --- Rotary Position Embeddings ---
class RoPE(nn.Module):
    """Rotary position embeddings applied over the last two dimensions (seq_len, dim)."""

    def __init__(self, dim: int, max_len: int = 5000):
        super().__init__()
        assert dim % 2 == 0, "RoPE dimension must be even"
        # Precomputed positions and inverse frequencies; non-persistent since they are
        # deterministic and cheap to rebuild on load.
        self.register_buffer('position_ids', torch.arange(max_len).unsqueeze(1), persistent=False)
        self.register_buffer(
            'div_term',
            torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim)),
            persistent=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., T, dim). Each consecutive pair of features is rotated by a
        # position-dependent angle.
        *lead, T, D = x.shape
        angles = self.position_ids[:T] * self.div_term  # (T, D/2)
        cos, sin = torch.cos(angles), torch.sin(angles)
        x_pairs = x.reshape(*lead, T, D // 2, 2)
        x_even, x_odd = x_pairs[..., 0], x_pairs[..., 1]
        rotated = torch.stack(
            [x_even * cos - x_odd * sin, x_even * sin + x_odd * cos], dim=-1
        )
        return rotated.reshape(*lead, T, D)


# --- Attention ---
class GroupedQueryAttention(nn.Module):
    """Causal self-attention where `gqa_groups` query heads share each key/value head."""

    def __init__(self, d_model, n_heads, gqa_groups, max_len):
        super().__init__()
        assert d_model % n_heads == 0
        assert n_heads % gqa_groups == 0
        self.d_model, self.n_heads, self.gqa_groups = d_model, n_heads, gqa_groups
        self.head_dim = d_model // n_heads
        self.n_kv_heads = n_heads // gqa_groups

        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

        # A single RoPE over head_dim, applied per head, so queries and keys are rotated
        # with identical frequencies even though their projections differ in width under GQA.
        self.rope = RoPE(self.head_dim, max_len)

    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        q = self.rope(q)
        k = self.rope(k)

        # Expand the key/value heads so each group of query heads attends to its shared KV head.
        expand = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(expand, dim=1)
        v = v.repeat_interleave(expand, dim=1)

        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(out)


# --- SwiGLU Feedforward ---
class SwiGLU_MLP(nn.Module):
    """Feed-forward block with a SiLU-gated linear unit (SwiGLU)."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Single fused projection producing both the up and gate branches.
        self.w1 = nn.Linear(d_model, 2 * d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        up, gate = self.w1(x).chunk(2, dim=-1)
        return self.w2(up * F.silu(gate))


# --- Transformer Block ---
class TransformerBlock(nn.Module):
    """Pre-norm block: RMSNorm -> attention and RMSNorm -> MLP, each with a residual connection."""

    def __init__(self, cfg):
        super().__init__()
        self.rms1 = nn.RMSNorm(cfg.d_model, eps=cfg.eps)
        self.rms2 = nn.RMSNorm(cfg.d_model, eps=cfg.eps)
        self.attn = GroupedQueryAttention(cfg.d_model, cfg.n_heads, cfg.gqa_groups, cfg.max_len)
        self.mlp = SwiGLU_MLP(cfg.d_model, cfg.d_ff)
        self.dropout = nn.Dropout(cfg.dropout_p)

    def forward(self, x):
        x = x + self.dropout(self.attn(self.rms1(x)))
        x = x + self.dropout(self.mlp(self.rms2(x)))
        return x


# --- Main Model ---
class MedAssistGPTModel(PreTrainedModel):
    config_class = MedAssistGPTConfig

    def __init__(self, config: MedAssistGPTConfig):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.blocks)])
        self.final_rms = nn.RMSNorm(config.d_model, eps=config.eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Tie the output projection to the input embedding matrix.
        self.lm_head.weight = self.embed.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
    ):
        # This model keeps no KV cache, so past_key_values is never populated and the full
        # sequence is re-encoded at every generation step; the slice below only triggers
        # if a cache is ever added.
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
        }

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        **kwargs,
    ):
        # attention_mask, past_key_values and use_cache are accepted so that `generate()`
        # can call this forward, but attention is always full causal self-attention over
        # the given input_ids.
        h = self.embed(input_ids)
        for blk in self.blocks:
            h = blk(h)
        h = self.final_rms(h)
        logits = self.lm_head(h)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from positions <= t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithPast(loss=loss, logits=logits)


# --- Alias for CausalLM ---
class MedAssistGPTForCausalLM(MedAssistGPTModel):
    pass
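

# --- Usage sketch (illustrative, not part of the model definition) ---
# A minimal smoke test of the forward pass and the shifted cross-entropy loss. It assumes
# MedAssistGPTConfig accepts the fields referenced above (vocab_size, d_model, n_heads,
# gqa_groups, d_ff, blocks, max_len, eps, dropout_p); check configuration_medassistgpt.py
# for the actual names and defaults. Because of the relative import at the top, run this
# as a module (python -m <package>.modeling_medassistgpt), not as a standalone script.
if __name__ == "__main__":
    config = MedAssistGPTConfig(
        vocab_size=32000,
        d_model=256,
        n_heads=8,
        gqa_groups=4,
        d_ff=512,
        blocks=4,
        max_len=512,
        eps=1e-6,
        dropout_p=0.0,
    )
    model = MedAssistGPTForCausalLM(config)

    # One training-style forward pass: reusing input_ids as labels exercises the
    # next-token loss computed inside forward().
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids=input_ids, labels=input_ids)
    print(out.logits.shape, out.loss.item())  # (2, 16, vocab_size) and a scalar loss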