🧠 MedAssist-GPT-401M

Mid-sized medical-domain LLM pretraining project. ⚠️ Strictly for research. Not for clinical or diagnostic use.


🧩 TL;DR

  • Architecture: Transformer with RoPE, GQA, SwiGLU MLP, and RMSNorm
  • Tokenizer: tiktoken p50k_base (vocab ≈ 50,281)
  • Context length: 1,024 tokens
  • Parameters: ≈ 401 M (d_model=1024, n_heads=32, blocks=24, d_ff=2048; see the config sketch below)
  • GQA: 4 KV heads shared across 32 query heads (group size 8)
  • Dropout: 0.0 (pretraining)
  • Precision: bf16 mixed precision
  • Training objective: next-token prediction
  • Effective batch: 32 × 4 = 128
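
For reference, a hypothetical reconstruction of what MODEL_CONFIG (imported in the inference snippet below) might contain, based on the numbers above. The field names are assumptions; the shipped config may use different keys.

# Hypothetical MODEL_CONFIG sketch; key names are guesses, values from the TL;DR.
MODEL_CONFIG = {
    "vocab_size": 50_281,     # tiktoken p50k_base
    "context_length": 1024,
    "d_model": 1024,
    "n_layers": 24,
    "n_heads": 32,            # query heads
    "n_kv_heads": 4,          # GQA: 32 query heads in groups of 8
    "d_ff": 2048,             # SwiGLU hidden width
    "dropout": 0.0,
}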

📚 Data

| Field | Value |
|---|---|
| Dataset | japhba/pubmed_simple |
| Text column | abstract |
| Train/Val split | 95 / 5 |
| Samples used | 100 k abstracts |
| Seq length / stride | 1,024 / 1,024 |
| Cleaning | use_clean=False (raw abstracts) |
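
A rough sketch of how the split and packing above could be reproduced with the datasets library. The packing logic, the assumption that the dataset exposes a train split, and the reuse of the training seed for splitting are all guesses about the actual script:

# Hedged sketch: rebuild the 95/5 split and 1,024-token packing.
import tiktoken
from datasets import load_dataset

enc = tiktoken.get_encoding("p50k_base")

ds = load_dataset("japhba/pubmed_simple", split="train").select(range(100_000))
split = ds.train_test_split(test_size=0.05, seed=7_979_797)  # seed reuse is an assumption

def pack(batch, seq_len=1024, stride=1024):
    # Concatenate raw abstracts (use_clean=False) and slice into
    # non-overlapping 1,024-token windows (stride == seq_len).
    ids = []
    for text in batch["abstract"]:
        ids.extend(enc.encode(text))
    return {"input_ids": [ids[i:i + seq_len]
                          for i in range(0, len(ids) - seq_len + 1, stride)]}

train_ids = split["train"].map(pack, batched=True,
                               remove_columns=split["train"].column_names)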

⚙️ Training

| Item | Value |
|---|---|
| Framework | PyTorch |
| Precision | bf16 |
| Objective | Causal LM (next-token prediction) |
| Optimizer | AdamW (β₁ = 0.9, β₂ = 0.95, eps = 1e-8) |
| Learning rate | 3 × 10⁻⁴ (linear decay + 100-step warmup) |
| Weight decay | 0.1 |
| Batch size | 32 (× 4 grad accumulation → 128 effective) |
| Grad clip | 1.0 |
| Total steps | 100 k |
| Eval | every 500 steps (100 eval iterations) |
| Checkpointing | save every 1 k steps |
| Seed | 7 979 797 |
| Gradient checkpointing | ✅ Enabled |
| WandB | kunjcr2-dreamable/MedAssist-GPT-Pretraining (medassist-401M-test) |
| HF repo | kunjcr2/MedAssist-GPT-401M |
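
The optimizer and schedule rows translate roughly to the setup below. The hyperparameters come from the table; the scaffolding (and the assumption that model = MedAssistGPT(MODEL_CONFIG) already exists) is a sketch, not the project's actual training code.

# Hedged sketch of the optimizer and LR schedule described above.
import torch
from torch.optim.lr_scheduler import LambdaLR

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1,
)

warmup_steps, total_steps = 100, 100_000

def lr_lambda(step):
    # 100-step linear warmup, then linear decay to zero at 100 k steps.
    if step < warmup_steps:
        return step / warmup_steps
    return max(0.0, (total_steps - step) / (total_steps - warmup_steps))

scheduler = LambdaLR(optimizer, lr_lambda)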

🧮 Training Environment

| Item | Value |
|---|---|
| Hardware | 1× NVIDIA A100 (80 GB) |
| Precision dtype | bf16 |
| Runtime | ~15 hours |
| Scheduler | Linear LR decay |
| Mixed precision | Native AMP (bf16) |
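
Putting the precision, accumulation, and clipping settings together, one inner training step might look like this. It is a sketch, assuming the optimizer and scheduler from the previous section and a loader that yields batches of input_ids:

# Hedged sketch of one bf16 training step with 4-step grad accumulation.
import torch
import torch.nn.functional as F

accum_steps = 4                              # 32 per step × 4 = 128 effective
model.train()
for step, batch in enumerate(loader):        # loader yields {"input_ids": LongTensor}
    x = batch["input_ids"].cuda()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(x[:, :-1])
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               x[:, 1:].reshape(-1))
    (loss / accum_steps).backward()          # bf16 needs no GradScaler
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad(set_to_none=True)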

📈 Loss Curves

(Placeholder — train_loss and val_loss curves will be added post-training.)


🚀 Minimal Inference

# pip install torch tiktoken huggingface_hub safetensors
import torch, tiktoken
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from MedAssistGPT import MedAssistGPT, MODEL_CONFIG

REPO_ID = "kunjcr2/MedAssist-GPT-401M"
weights = hf_hub_download(REPO_ID, "model.safetensors")
state = load_file(weights, device="cpu")

model = MedAssistGPT(MODEL_CONFIG)
model.load_state_dict(state, strict=True)  # returns key info, not the model
model.eval()

enc = tiktoken.get_encoding("p50k_base")
ids = torch.tensor([enc.encode(
    "A patient was admitted with severe headache. Initial assessment revealed"
)], dtype=torch.long)

with torch.inference_mode():
    for _ in range(100):
        logits = model(ids)[:, -1, :]                 # logits at the last position
        probs = torch.softmax(logits / 0.6, dim=-1)   # temperature 0.6
        next_id = torch.multinomial(probs, 1)
        ids = torch.cat([ids, next_id], dim=1)
print(enc.decode(ids[0].tolist()))
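
The loop above uses plain temperature sampling (T = 0.6). For a base model of this size, layering top-k filtering on top often reduces degenerate repetition; a drop-in variant of the next_id line (k = 50 is an arbitrary choice, not a project setting):

# Optional: top-k filtering before the temperature-scaled draw.
topk = torch.topk(logits / 0.6, k=50, dim=-1)
probs = torch.softmax(topk.values, dim=-1)
next_id = topk.indices.gather(-1, torch.multinomial(probs, 1))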

💾 Checkpoints

  • Main run: medassist-401M-test
  • Checkpoint: /checkpoints/checkpoint_step_44500.pt
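
Intermediate .pt checkpoints typically store more than the weights. Assuming the common pattern of a dict holding the model state-dict plus optimizer state (the key names here are assumptions about the training script's save format), resuming might look like:

# Hedged sketch: resume from a raw .pt checkpoint.
import torch

ckpt = torch.load("/checkpoints/checkpoint_step_44500.pt", map_location="cpu")
model.load_state_dict(ckpt["model"])          # "model", "optimizer", "step" are
optimizer.load_state_dict(ckpt["optimizer"])  # assumed key names
start_step = ckpt["step"]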

🧪 Intended Use

For research and experimentation only — e.g.,

  • domain-adapted pretraining,
  • architecture exploration,
  • fine-tuning for medical text understanding.

🚫 Not intended for clinical or production medical use.


🔮 Future Work

The next update will include:

  • Supervised fine-tuning (SFT)
  • Reinforcement Learning (PPO) for alignment

📁 Files

  • checkpoints/
  • config.json, tokenizer_config.json
  • Training script / notebook defining MedAssistGPT

🪪 License

Apache 2.0
