---
license: apache-2.0
datasets:
- japhba/pubmed_simple
language:
- en
tags:
- v2_pretrain_medassist
- gqa
- rope
- swiglu
- rmsnorm
- medical
---

# 🧠 MedAssist-GPT-401M

**Mid-sized medical-domain LLM pretraining project.**
⚠️ *Strictly for research. Not for clinical or diagnostic use.*

---

## 🧩 TL;DR

* **Architecture:** Transformer with **RoPE**, **GQA**, **SwiGLU** MLP, and **RMSNorm**
* **Tokenizer:** `tiktoken` `p50k_base` (vocab ≈ **50,281**)
* **Context length:** 1,024 tokens
* **Parameters:** ≈ **401 M** (`d_model=1024`, `n_heads=32`, `blocks=24`, `d_ff=2048`)
* **GQA:** group size 8 (32 query heads → 4 KV heads)
* **Dropout:** 0.0 (pretraining)
* **Precision:** **bf16** mixed precision
* **Training objective:** Next-token prediction
* **Effective batch:** 32 × 4 = 128
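
For reference, these hyperparameters map onto a single config object. A minimal sketch as a Python dict, assuming these key names (the actual `MODEL_CONFIG` in the training script may spell them differently):

```python
# Hypothetical key names; a mirror of the TL;DR bullets above.
MODEL_CONFIG = {
    "vocab_size": 50281,      # tiktoken p50k_base
    "context_length": 1024,
    "d_model": 1024,
    "n_heads": 32,            # query heads
    "n_kv_heads": 4,          # GQA: group size 8 -> 32 / 8 = 4 KV heads
    "n_blocks": 24,
    "d_ff": 2048,             # SwiGLU hidden width
    "dropout": 0.0,           # disabled for pretraining
    "rope": True,             # rotary position embeddings
    "norm": "rmsnorm",
}
```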

---

## 📚 Data

| Field                   | Value                             |
| ----------------------- | --------------------------------- |
| **Dataset**             | `japhba/pubmed_simple`            |
| **Text column**         | `abstract`                        |
| **Train/Val split**     | 95 / 5                            |
| **Samples used**        | 100 k abstracts                   |
| **Seq length / stride** | 1,024 / 1,024                     |
| **Cleaning**            | `use_clean=False` (raw abstracts) |
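
As a rough illustration of the table above, a data-pipeline sketch (the loader calls and chunking details are assumptions, not the actual training code):

```python
# pip install datasets tiktoken
import tiktoken
from datasets import load_dataset

enc = tiktoken.get_encoding("p50k_base")
SEQ_LEN = STRIDE = 1024  # stride == seq length: non-overlapping windows

# 100k raw abstracts (use_clean=False), 95/5 train/val split.
ds = load_dataset("japhba/pubmed_simple", split="train").select(range(100_000))
splits = ds.train_test_split(test_size=0.05, seed=7979797)  # reusing the card's seed

def to_chunks(example):
    ids = enc.encode(example["abstract"])
    # Short tails would need packing or padding in a real pipeline.
    return {"input_ids": [ids[i : i + SEQ_LEN] for i in range(0, len(ids), STRIDE)]}
```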

---

## โš™๏ธ Training

| Item                       | Value                                                                 |
| -------------------------- | --------------------------------------------------------------------- |
| **Framework**              | PyTorch                                                               |
| **Precision**              | bf16                                                                  |
| **Objective**              | Causal LM (next-token prediction)                                     |
| **Optimizer**              | AdamW (`β₁ = 0.9`, `β₂ = 0.95`, `eps = 1e-8`)                          |
| **Learning rate**          | 3 × 10⁻⁴ (linear decay after 100-step warmup)                          |
| **Weight decay**           | 0.1                                                                    |
| **Batch size**             | 32 (× 4 grad acc → 128 effective)                                      |
| **Grad clip**              | 1.0                                                                   |
| **Total steps**            | 100 k                                                                 |
| **Eval**                   | every 500 steps × 100 iters                                           |
| **Checkpoint save**        | every 1 k steps                                                       |
| **Seed**                   | 7979797                                                               |
| **Gradient checkpointing** | ✅ Enabled                                                             |
| **WandB**                  | `kunjcr2-dreamable/MedAssist-GPT-Pretraining` (`medassist-401M-test`) |
| **HF repo**                | `kunjcr2/MedAssist-GPT-401M`                                          |
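
A sketch of the optimization loop these settings describe; `model` and `train_loader` are placeholders, and the scheduler shape (linear decay to zero after a 100-step warmup) is an assumption consistent with the tables:

```python
import torch

TOTAL_STEPS, WARMUP, ACC_STEPS = 100_000, 100, 4

optimizer = torch.optim.AdamW(
    model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1
)

def lr_lambda(step):
    if step < WARMUP:
        return step / WARMUP                                        # linear warmup
    return max(0.0, (TOTAL_STEPS - step) / (TOTAL_STEPS - WARMUP))  # linear decay

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for step, (x, y) in enumerate(train_loader):
    with torch.autocast("cuda", dtype=torch.bfloat16):           # bf16 mixed precision
        logits = model(x)
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), y.view(-1)
        )
    (loss / ACC_STEPS).backward()                                # 4-step grad accumulation
    if (step + 1) % ACC_STEPS == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # grad clip 1.0
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad(set_to_none=True)
```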

---

## 🧮 Training Environment

| Item                | Value                  |
| ------------------- | ---------------------- |
| **Hardware**        | 1× NVIDIA A100 (80 GB) |
| **Precision dtype** | bf16                   |
| **Runtime**         | ~15 hours              |
| **Scheduler**       | Linear LR decay        |
| **Mixed precision** | Native AMP (bf16)      |

---

## 📈 Loss Curves

*(Placeholder; will update post-training)*
![train_loss](/static-proxy?url=https%3A%2F%2Fcdn-uploads.huggingface.co%2Fproduction%2Fuploads%2F67c358189919777813863c48%2FbQGVqgx4GoqXZTcMh8KhM.png)
![val_loss](/static-proxy?url=https%3A%2F%2Fcdn-uploads.huggingface.co%2Fproduction%2Fuploads%2F67c358189919777813863c48%2FjhNnS_Wvhj4-fzNoO2dRN.png)

---

## 🚀 Minimal Inference

```python
# pip install torch tiktoken huggingface_hub safetensors
import torch, tiktoken
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from MedAssistGPT import MedAssistGPT, MODEL_CONFIG

REPO_ID = "kunjcr2/MedAssist-GPT-401M"
weights = hf_hub_download(REPO_ID, "model.safetensors")
state = load_file(weights, device="cpu")

model = MedAssistGPT(MODEL_CONFIG)
model.load_state_dict(state, strict=True)  # returns a key-match report, not the model
model.eval()

enc = tiktoken.get_encoding("p50k_base")
ids = torch.tensor([enc.encode(
    "A patient was admitted with severe headache. Initial assessment revealed"
)], dtype=torch.long)

with torch.no_grad():
    for _ in range(100):
        logits = model(ids)[:, -1, :]                # logits at the last position
        probs = torch.softmax(logits / 0.6, dim=-1)  # temperature 0.6
        next_id = torch.multinomial(probs, 1)        # sample one token
        ids = torch.cat([ids, next_id], dim=1)
print(enc.decode(ids[0].tolist()))
```
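
Decoding above is plain temperature sampling (T = 0.6) with no top-k or nucleus truncation; greedy argmax or top-k filtering are drop-in alternatives if you want more conservative output.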

---

## 💾 Checkpoints

* Main run: `medassist-401M-test`
* Checkpoint: `/checkpoints/checkpoint_step_44500.pt`
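
To resume from the raw training checkpoint instead of `model.safetensors`, something like the following should work, assuming the `.pt` file uses the common `model_state_dict` / `optimizer_state_dict` keys (an assumption; inspect the file first):

```python
import torch

ckpt = torch.load("checkpoints/checkpoint_step_44500.pt", map_location="cpu")
print(ckpt.keys())  # verify: the key names below are assumptions
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
```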

---

## 🧪 Intended Use

For research and experimentation only, e.g.:

* domain-adapted pretraining,
* architecture exploration,
* fine-tuning for medical text understanding.

🚫 **Not intended for clinical or production medical use.**

---

## 🔮 Future Work

Next update includes:

* **Supervised fine-tuning (SFT)**
* **Reinforcement Learning (PPO) for alignment**

---

## ๐Ÿ“ Files

* 'checkpoints/'
* `config.json`, `tokenizer_config.json`
* Training script / notebook defining `MedAssistGPT`

---

## 🪪 License

Apache 2.0