---
license: apache-2.0
datasets:
- japhba/pubmed_simple
language:
- en
tags:
- v2_pretrain_medassist
- gqa
- rope
- swiglu
- rmsnorm
- medical
---
# 🧠 MedAssist-GPT-401M
**Mid-sized medical-domain LLM pretraining project.**
⚠️ *Strictly for research. Not for clinical or diagnostic use.*
---
## 🧩 TL;DR
* **Architecture:** Transformer with **RoPE**, **GQA**, **SwiGLU** MLP, and **RMSNorm**
* **Tokenizer:** `tiktoken` `p50k_base` (vocab = **50,281**)
* **Context length:** 1,024 tokens
* **Parameters:** ≈ **401 M** (`d_model=1024`, `n_heads=32`, `blocks=24`, `d_ff=2048`)
* **GQA groups:** 8 → 4 KV heads per 32 query heads
* **Dropout:** 0.0 (pretraining)
* **Precision:** **bf16** mixed precision
* **Training objective:** Next-token prediction
* **Effective batch:** 32 × 4 = 128
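
As a reading aid, here is a minimal PyTorch sketch of the components these bullets name: grouped-query attention (32 query heads sharing 4 KV heads, i.e. groups of 8), a SwiGLU MLP, and RMSNorm. Class names and wiring are illustrative assumptions, not the actual `MedAssistGPT` code, and RoPE is only marked where it would be applied:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Values taken from the bullets above.
D_MODEL, N_HEADS, N_KV_HEADS, D_FF = 1024, 32, 4, 2048

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Normalize by the root-mean-square of the features, then rescale.
        return self.weight * x * x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()

class SwiGLU(nn.Module):
    """SwiGLU MLP: silu(x @ W_gate) * (x @ W_up), projected back by W_down."""
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.gate = nn.Linear(d_model, d_ff, bias=False)
        self.up = nn.Linear(d_model, d_ff, bias=False)
        self.down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

class GQAttention(nn.Module):
    """Grouped-query attention: 32 query heads share 4 KV heads (groups of 8)."""
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        self.n_heads, self.n_kv, self.hd = n_heads, n_kv_heads, d_model // n_heads
        self.q = nn.Linear(d_model, n_heads * self.hd, bias=False)
        self.kv = nn.Linear(d_model, 2 * n_kv_heads * self.hd, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        B, T, _ = x.shape
        q = self.q(x).view(B, T, self.n_heads, self.hd).transpose(1, 2)
        k, v = self.kv(x).view(B, T, 2, self.n_kv, self.hd).permute(2, 0, 3, 1, 4)
        # RoPE would rotate q and k here (omitted for brevity).
        k = k.repeat_interleave(self.n_heads // self.n_kv, dim=1)  # 4 KV heads -> 32
        v = v.repeat_interleave(self.n_heads // self.n_kv, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out(y.transpose(1, 2).reshape(B, T, -1))
```

Replacing 32 KV heads with 4 shrinks the KV cache about 8×, which is the usual motivation for GQA.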
---
## 📊 Data
| Field | Value |
| ----------------------- | --------------------------------- |
| **Dataset** | `japhba/pubmed_simple` |
| **Text column** | `abstract` |
| **Train/Val split** | 95 / 5 |
| **Samples used** | 100 k abstracts |
| **Seq length / stride** | 1,024 / 1,024 |
| **Cleaning** | `use_clean=False` (raw abstracts) |
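
A hedged sketch of what this preprocessing could look like, assuming the Hugging Face `datasets` library, a `train` split, and non-overlapping windows (stride = sequence length); the `to_blocks` helper and the reuse of the training seed for the split are assumptions:

```python
import tiktoken
import torch
from datasets import load_dataset

enc = tiktoken.get_encoding("p50k_base")  # enc.n_vocab == 50281
SEQ_LEN, STRIDE = 1024, 1024  # stride == seq_len -> non-overlapping windows

# 100k raw abstracts, 95/5 train/val split (seed reuse is an assumption).
ds = load_dataset("japhba/pubmed_simple", split="train").select(range(100_000))
ds = ds.train_test_split(test_size=0.05, seed=7979797)

def to_blocks(texts):
    """Tokenize raw abstracts and slice into fixed-length training windows."""
    ids = []
    for t in texts:
        ids.extend(enc.encode(t))
    return [ids[i : i + SEQ_LEN] for i in range(0, len(ids) - SEQ_LEN + 1, STRIDE)]

train_blocks = to_blocks(ds["train"]["abstract"])
x = torch.tensor(train_blocks[0])  # one 1,024-token training example
```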
---
## ⚙️ Training
| Item | Value |
| -------------------------- | --------------------------------------------------------------------- |
| **Framework** | PyTorch |
| **Precision** | bf16 |
| **Objective** | Causal LM (next-token prediction) |
| **Optimizer**              | AdamW (`β₁ = 0.9`, `β₂ = 0.95`, `eps = 1e-8`)                          |
| **Learning rate**          | 3 × 10⁻⁴ (linear decay + 100-step warmup)                              |
| **Weight decay** | 0.1 |
| **Batch size**             | 32 (× 4 grad acc → 128 effective)                                      |
| **Grad clip** | 1.0 |
| **Total steps** | 100 k |
| **Eval**                   | every 500 steps (100 eval iterations each)                             |
| **Checkpoint save** | every 1 k steps |
| **Seed** | 7 979 797 |
| **Gradient checkpointing** | ✅ Enabled                                                              |
| **WandB** | `kunjcr2-dreamable/MedAssist-GPT-Pretraining` (`medassist-401M-test`) |
| **HF repo** | `kunjcr2/MedAssist-GPT-401M` |
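
A minimal sketch that wires these hyperparameters together; `model` and `loader` are placeholders, and the loop structure itself is an assumption rather than the project's actual training script:

```python
import torch

# AdamW with the betas, eps, and weight decay from the table above.
opt = torch.optim.AdamW(model.parameters(), lr=3e-4,
                        betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1)

WARMUP, TOTAL, ACC = 100, 100_000, 4

def lr_lambda(step):
    # 100-step linear warmup, then linear decay to 0 at step 100k.
    if step < WARMUP:
        return step / WARMUP
    return max(0.0, (TOTAL - step) / (TOTAL - WARMUP))

sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

for step, (x, y) in enumerate(loader):
    with torch.autocast("cuda", dtype=torch.bfloat16):  # bf16 mixed precision
        logits = model(x)
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), y.view(-1)) / ACC
    loss.backward()  # gradients accumulate over 4 micro-batches of 32
    if (step + 1) % ACC == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # grad clip 1.0
        opt.step()
        sched.step()
        opt.zero_grad(set_to_none=True)
```

Note that bf16 autocast needs no loss scaler, unlike fp16 AMP.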
---
## 🧮 Training Environment
| Item | Value |
| ------------------- | ---------------------- |
| **Hardware**        | 1× NVIDIA A100 (80 GB) |
| **Precision dtype** | bf16 |
| **Runtime** | ~15 hours |
| **Scheduler** | Linear LR decay |
| **Mixed precision** | Native AMP (bf16) |
---
## 📉 Loss Curves
*(Placeholder; will be updated post-training)*


---
## 🚀 Minimal Inference
```python
# pip install torch tiktoken huggingface_hub safetensors
import torch, tiktoken
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from MedAssistGPT import MedAssistGPT, MODEL_CONFIG  # defined in this repo's training script

REPO_ID = "kunjcr2/MedAssist-GPT-401M"
weights = hf_hub_download(REPO_ID, "model.safetensors")
state = load_file(weights, device="cpu")

model = MedAssistGPT(MODEL_CONFIG)
model.load_state_dict(state, strict=True)  # returns key-match info, not the model
model.eval()

enc = tiktoken.get_encoding("p50k_base")
ids = torch.tensor([enc.encode(
    "A patient was admitted with severe headache. Initial assessment revealed"
)], dtype=torch.long)

with torch.no_grad():
    for _ in range(100):
        logits = model(ids)[:, -1, :]                # logits at the last position
        probs = torch.softmax(logits / 0.6, dim=-1)  # temperature-0.6 sampling
        next_id = torch.multinomial(probs, 1)
        ids = torch.cat([ids, next_id], dim=1)

print(enc.decode(ids[0].tolist()))
```
---
## 💾 Checkpoints
* Main run: `medassist-401M-test`
* Checkpoint: `/checkpoints/checkpoint_step_44500.pt`
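
Resuming from the raw `.pt` checkpoint could look like the sketch below; the dictionary keys (`model`, `optimizer`, `step`) are assumptions about the checkpoint format, not confirmed:

```python
import torch

ckpt = torch.load("checkpoints/checkpoint_step_44500.pt", map_location="cpu")
model.load_state_dict(ckpt["model"])          # assumed key
optimizer.load_state_dict(ckpt["optimizer"])  # assumed key
start_step = ckpt.get("step", 44_500)         # resume step counter
```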
---
## 🧪 Intended Use
For research and experimentation only, e.g.,
* domain-adapted pretraining,
* architecture exploration,
* fine-tuning for medical text understanding.
🚫 **Not intended for clinical or production medical use.**
---
## 🔮 Future Work
The next update will include:
* **Supervised fine-tuning (SFT)**
* **Reinforcement Learning (PPO) for alignment**
---
## 📁 Files
* `checkpoints/`
* `config.json`, `tokenizer_config.json`
* Training script / notebook defining `MedAssistGPT`
---
## 🪪 License
Apache 2.0