# 🧠 MedAssist-GPT-401M
Mid-sized medical-domain LLM pretraining project. ⚠️ Strictly for research; not for clinical or diagnostic use.
## 🧩 TL;DR
- Architecture: Transformer with RoPE, GQA, SwiGLU MLP, and RMSNorm
- Tokenizer: `tiktoken` `p50k_base` (vocab ≈ 50,281)
- Context length: 1,024 tokens
- Parameters: ≈ 401 M (`d_model=1024`, `n_heads=32`, `blocks=24`, `d_ff=2048`)
- GQA groups: 8 → 4 KV heads per 32 query heads
- Dropout: 0.0 (pretraining)
- Precision: bf16 mixed precision
- Training objective: Next-token prediction
- Effective batch: 32 × 4 = 128 (see the config sketch below)
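
For reference, here is a minimal sketch of a config matching these numbers. The real `MODEL_CONFIG` ships with the repo's `MedAssistGPT` module and its field names may differ; everything below is illustrative.

```python
# Illustrative config mirroring the TL;DR hyperparameters.
# NOTE: field names are assumptions; the repo's actual MODEL_CONFIG may differ.
MODEL_CONFIG = {
    "vocab_size": 50_281,    # tiktoken p50k_base
    "context_length": 1024,
    "d_model": 1024,
    "n_heads": 32,           # query heads
    "n_kv_heads": 4,         # GQA: 32 query heads in 8 groups -> 4 KV heads
    "n_blocks": 24,
    "d_ff": 2048,            # SwiGLU MLP hidden width
    "dropout": 0.0,          # disabled for pretraining
    "rope": True,            # rotary position embeddings
    "norm": "rmsnorm",
}
```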
## 📊 Data
| Field | Value |
|---|---|
| Dataset | `japhba/pubmed_simple` |
| Text column | `abstract` |
| Train/Val split | 95 / 5 |
| Samples used | 100 k abstracts |
| Seq length / stride | 1,024 / 1,024 |
| Cleaning | `use_clean=False` (raw abstracts) |
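
As a rough sketch, this pipeline could be reproduced with the Hugging Face `datasets` library as follows. The dataset ID, text column, split ratio, and seed come from the tables in this card; the split name and the chunking helper are assumptions:

```python
from datasets import load_dataset
import tiktoken

enc = tiktoken.get_encoding("p50k_base")

# 100k abstracts, raw text (use_clean=False), 95/5 train/val split
ds = load_dataset("japhba/pubmed_simple", split="train[:100000]")
ds = ds.train_test_split(test_size=0.05, seed=7979797)

def chunk(batch, seq_len=1024, stride=1024):
    # Tokenize and concatenate abstracts, then cut into 1,024-token
    # windows; stride == seq_len means the windows do not overlap.
    ids = []
    for text in batch["abstract"]:
        ids.extend(enc.encode(text))
    return {"input_ids": [ids[i:i + seq_len]
                          for i in range(0, len(ids) - seq_len + 1, stride)]}

train_chunks = ds["train"].map(
    chunk, batched=True, remove_columns=ds["train"].column_names
)
```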
## ⚙️ Training
| Item | Value |
|---|---|
| Framework | PyTorch |
| Precision | bf16 |
| Objective | Causal LM (next-token prediction) |
| Optimizer | AdamW (β₁ = 0.9, β₂ = 0.95, eps = 1e-8) |
| Learning rate | 3 × 10⁻⁴ (linear decay + 100-step warmup) |
| Weight decay | 0.1 |
| Batch size | 32 (× 4 grad accumulation → 128 effective) |
| Grad clip | 1.0 |
| Total steps | 100 k |
| Eval | every 500 steps (100 eval iters) |
| Checkpoint save | every 1 k steps |
| Seed | 7979797 |
| Gradient checkpointing | ✅ Enabled |
| WandB | kunjcr2-dreamable/MedAssist-GPT-Pretraining (medassist-401M-test) |
| HF repo | kunjcr2/MedAssist-GPT-401M |
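
Concretely, the optimizer and schedule rows translate into PyTorch roughly as below. This is a sketch, assuming `model` is the instantiated network; the actual training script may structure this differently:

```python
import torch

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.1,
)

warmup_steps, total_steps = 100, 100_000

def lr_lambda(step):
    # 100-step linear warmup, then linear decay to zero at 100k steps
    if step < warmup_steps:
        return step / warmup_steps
    return max(0.0, (total_steps - step) / (total_steps - warmup_steps))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```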
## 🧮 Training Environment
| Item | Value |
|---|---|
| Hardware | 1× NVIDIA A100 (80 GB) |
| Precision dtype | bf16 |
| Runtime | ~15 hours |
| Scheduler | Linear LR decay |
| Mixed precision | Native AMP (bf16) |
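
Putting the precision and batch settings together, a single optimizer step looks roughly like this sketch (`micro_batches`, `optimizer`, and `scheduler` as in the sketch above; all names illustrative):

```python
GRAD_ACC = 4  # 32 per micro-batch x 4 -> 128 effective

optimizer.zero_grad(set_to_none=True)
for x, y in micro_batches:  # x, y: (32, 1024) token id tensors
    # bf16 autocast; no GradScaler needed since bf16 keeps fp32's exponent range
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(x)
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), y.view(-1)
        ) / GRAD_ACC  # scale so the accumulated gradient matches a full batch
    loss.backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # grad clip 1.0
optimizer.step()
scheduler.step()
```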
## 📉 Loss Curves
*(Placeholder: loss curves will be added after training completes.)*

## 🚀 Minimal Inference
```python
# pip install torch tiktoken huggingface_hub safetensors
import torch
import tiktoken
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from MedAssistGPT import MedAssistGPT, MODEL_CONFIG

REPO_ID = "kunjcr2/MedAssist-GPT-401M"

# Download and load the pretrained weights
weights = hf_hub_download(REPO_ID, "model.safetensors")
state = load_file(weights, device="cpu")

model = MedAssistGPT(MODEL_CONFIG)
model.load_state_dict(state, strict=True)  # returns key-matching info, not the model
model.eval()

enc = tiktoken.get_encoding("p50k_base")
ids = torch.tensor([enc.encode(
    "A patient was admitted with severe headache. Initial assessment revealed"
)], dtype=torch.long)

# Sample 100 tokens with temperature 0.6
with torch.no_grad():
    for _ in range(100):
        logits = model(ids)[:, -1, :]
        next_id = torch.multinomial(torch.softmax(logits / 0.6, dim=-1), 1)
        ids = torch.cat([ids, next_id], dim=1)

print(enc.decode(ids[0].tolist()))
```
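
The loop samples with temperature 0.6: dividing the logits before the softmax sharpens the distribution toward high-probability tokens, a common choice for more conservative output. Raising the temperature toward 1.0, or adding top-k/top-p filtering, yields more varied completions. Note that the script assumes the `MedAssistGPT` class and `MODEL_CONFIG` from this repo's training code are importable locally.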
## 💾 Checkpoints
- Main run: `medassist-401M-test`
- Checkpoint: `/checkpoints/checkpoint_step_44500.pt`
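
To resume training from the intermediate checkpoint rather than the exported `model.safetensors`, something like the sketch below should work, assuming the `.pt` file stores a dict of `model` and `optimizer` state dicts (the actual keys may differ):

```python
import torch

ckpt = torch.load("checkpoints/checkpoint_step_44500.pt", map_location="cpu")

# Assumed layout: {"model": ..., "optimizer": ..., "step": ...}
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
start_step = ckpt.get("step", 0)  # resume the step counter if stored
```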
## 🧪 Intended Use
For research and experimentation only, e.g.:
- domain-adapted pretraining,
- architecture exploration,
- fine-tuning for medical text understanding.
🚫 Not intended for clinical or production medical use.
## 🔮 Future Work
The next update will include:
- Supervised fine-tuning (SFT)
- Reinforcement Learning (PPO) for alignment
## 📂 Files
- `checkpoints/`
- `config.json`, `tokenizer_config.json`
- Training script / notebook defining `MedAssistGPT`
## 🪪 License
Apache 2.0