kunjcr2 committed on
Commit
50c8b86
·
verified ·
1 Parent(s): 098a80a

Update README.md

Files changed (1)
  1. README.md +122 -42
README.md CHANGED
@@ -1,80 +1,160 @@
  ---
  license: apache-2.0
  datasets:
- - Hack90/europe_pmc_articles_part_2
  language:
  - en
  tags:
- - v0_pretrain_medassist
  ---
- # MedAssist-GPT

- Tiny medical-domain LLM pretraining project.
- **NOT for clinical use.**

- ## TL;DR

- * **Arch:** Transformer with **RoPE** + **GQA**, **SwiGLU** MLP, **RMSNorm**, causal LM head (tied embeddings).
- * **Tokenizer:** `tiktoken` **p50k_base** (vocab ≈ 50,281).
- * **Context:** 1,024 tokens (default).
- * **Size (default config):** ~125M params (d_model=512, n_heads=16, layers=16, d_ff=2048).
- * **Trained on** about 2.2B tokens of pure medical data.
 
- ## Data (example)

- * Source: `Hack90/europe_pmc_articles_part_2` (`full_text`).
- * XML → plain text via `clean()`; sliding windows (`max_length=1024`, `stride=1024`).

- ## Training (script)

- * AdamW + OneCycleLR, bf16 AMP, grad accumulation, checkpoints, optional HF upload, wandb logging.
 
- ## Loss

- ![train_loss](https://cdn-uploads.huggingface.co/production/uploads/67c358189919777813863c48/bQGVqgx4GoqXZTcMh8KhM.png)
- ![val_loss](https://cdn-uploads.huggingface.co/production/uploads/67c358189919777813863c48/jhNnS_Wvhj4-fzNoO2dRN.png)
 
- ## Try it (minimal)

  ```python
  # pip install torch tiktoken huggingface_hub safetensors
  import torch, tiktoken
  from safetensors.torch import load_file
  from huggingface_hub import hf_hub_download

- REPO_ID = "kunjcr2/MedAssistGPT" # change if needed
- WEIGHTS = hf_hub_download(REPO_ID, "model.safetensors")
- state = load_file(WEIGHTS, device="cpu")
-
- # Import your MedAssistGPT class from the script/notebook
- from MedAssistGPT import MedAssistGPT, MODEL_CONFIG # ensure paths match your repo

  model = MedAssistGPT(MODEL_CONFIG)
  # load_state_dict returns a key report, not the model, so eval() is called separately
  model.load_state_dict(state, strict=True)
  model.eval()

  enc = tiktoken.get_encoding("p50k_base")
- ids = torch.tensor([enc.encode("To live a good life")], dtype=torch.long)
- with torch.no_grad():
-     for _ in range(100):
-         logits = model(ids)[:, -1, :]
-         next_id = torch.multinomial(torch.softmax(logits/0.7, dim=-1), 1)
-         ids = torch.cat([ids, next_id], dim=1)
-         if next_id.item() == enc.eot_token: break
-
  print(enc.decode(ids[0].tolist()))
  ```
 
- ## Intended use & limitations

- Research/experimentation + downstream finetuning after pretraining.
- Do **NOT** use for medical decisions.

- ## Files

- * `model.safetensors` (weights)
  * `config.json`, `tokenizer_config.json`
- * Script/notebook defining `MedAssistGPT` class

- ## License
- Apache-2.0
 
  ---
  license: apache-2.0
  datasets:
+ - japhba/pubmed_simple
  language:
  - en
  tags:
+ - v2_pretrain_medassist
+ - gqa
+ - rope
+ - swiglu
+ - rmsnorm
+ - medical
  ---
 
+ # 🧠 MedAssist-GPT-401M

+ **Mid-sized medical-domain LLM pretraining project.**
+ ⚠️ *Strictly for research. Not for clinical or diagnostic use.*

+ ---

+ ## 🧩 TL;DR

+ * **Architecture:** Transformer with **RoPE**, **GQA**, **SwiGLU** MLP, and **RMSNorm**
+ * **Tokenizer:** `tiktoken` `p50k_base` (vocab ≈ **50,281**)
+ * **Context length:** 1,024 tokens
+ * **Parameters:** ≈ **401 M** (`d_model=1024`, `n_heads=32`, `blocks=24`, `d_ff=2048`)
+ * **GQA groups:** 8 → 4 KV heads per 32 query heads
+ * **Dropout:** 0.0 (pretraining)
+ * **Precision:** **bf16** mixed precision
+ * **Training objective:** Next-token prediction
+ * **Effective batch:** 32 × 4 = 128
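
To make the GQA bullet concrete, the sketch below shows one common way query heads are mapped onto shared key/value heads. It is an illustration only: `kv_head_for_query` is a hypothetical helper, not a function from the MedAssist-GPT training script, and `n_kv_heads` is left as a parameter because the exact KV-head count used by the model is an assumption here.

```python
def kv_head_for_query(q_head: int, n_heads: int = 32, n_kv_heads: int = 8) -> int:
    """Return the index of the KV head shared by query head `q_head`.

    Hypothetical helper for illustration; the default of 8 KV heads is an assumption.
    """
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be divisible by n_kv_heads")
    if not 0 <= q_head < n_heads:
        raise ValueError("q_head out of range")
    group_size = n_heads // n_kv_heads  # number of query heads sharing one KV head
    return q_head // group_size


# Map all 32 query heads onto their KV heads.
mapping = [kv_head_for_query(h) for h in range(32)]
print(mapping)
```

With fewer KV heads than query heads, the key/value projections and the KV cache shrink by the same ratio, which is the usual motivation for GQA.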
 
+ ---

+ ## 📚 Data

+ | Field                   | Value                             |
+ | ----------------------- | --------------------------------- |
+ | **Dataset**             | `japhba/pubmed_simple`            |
+ | **Text column**         | `abstract`                        |
+ | **Train/Val split**     | 95 / 5                            |
+ | **Samples used**        | 100 k abstracts                   |
+ | **Seq length / stride** | 1,024 / 1,024                     |
+ | **Cleaning**            | `use_clean=False` (raw abstracts) |
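
A sequence length equal to the stride means the token stream is cut into non-overlapping windows. The sketch below illustrates that idea under stated assumptions: `sliding_windows` is a hypothetical function, and pairing each window with targets shifted by one token is the usual next-token setup rather than something confirmed by the training script.

```python
from typing import List, Tuple


def sliding_windows(
    token_ids: List[int], max_length: int = 1024, stride: int = 1024
) -> List[Tuple[List[int], List[int]]]:
    """Split a token stream into (input, target) pairs for next-token prediction."""
    samples = []
    # Each window needs max_length + 1 tokens so the targets can be shifted by one.
    for start in range(0, len(token_ids) - max_length, stride):
        chunk = token_ids[start : start + max_length + 1]
        samples.append((chunk[:-1], chunk[1:]))
    return samples


# Tiny demonstration with a short "document" and a window of 4 tokens.
pairs = sliding_windows(list(range(10)), max_length=4, stride=4)
for inputs, targets in pairs:
    print(inputs, targets)
```

Trailing tokens that do not fill a whole window are dropped in this sketch; a real pipeline might pad or concatenate documents instead.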
 
+ ---

+ ## ⚙️ Training
+
+ | Item                       | Value                                                                 |
+ | -------------------------- | --------------------------------------------------------------------- |
+ | **Framework**              | PyTorch                                                               |
+ | **Precision**              | bf16                                                                  |
+ | **Objective**              | Causal LM (next-token prediction)                                     |
+ | **Optimizer**              | AdamW (`β₁ = 0.9`, `β₂ = 0.95`, `eps = 1e-8`)                         |
+ | **Learning rate**          | 3 × 10⁻⁴ (linear + 100-step warmup)                                   |
+ | **Weight decay**           | 0.1                                                                   |
+ | **Batch size**             | 32 (× 4 grad acc → 128 effective)                                     |
+ | **Grad clip**              | 1.0                                                                   |
+ | **Total steps**            | 100 k                                                                 |
+ | **Eval**                   | every 500 steps × 100 iters                                           |
+ | **Checkpoint save**        | every 1 k steps                                                       |
+ | **Seed**                   | 7 979 797                                                             |
+ | **Gradient checkpointing** | ✅ Enabled                                                            |
+ | **WandB**                  | `kunjcr2-dreamable/MedAssist-GPT-Pretraining` (`medassist-401M-test`) |
+ | **HF repo**                | `kunjcr2/MedAssist-GPT-401M`                                          |
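
The batch-size row can be turned into a token count per optimizer update. This is a back-of-the-envelope sketch: `tokens_per_optimizer_step` is a hypothetical helper, and it assumes every sequence in the batch is a full 1,024-token window.

```python
def tokens_per_optimizer_step(batch_size: int, grad_accum_steps: int, seq_len: int) -> int:
    """Tokens seen per optimizer update, assuming every sequence is full length."""
    effective_batch = batch_size * grad_accum_steps
    return effective_batch * seq_len


effective_batch = 32 * 4
print(effective_batch)                          # 128
print(tokens_per_optimizer_step(32, 4, 1024))   # 131072
```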
+
+ ---
+
+ ## 🧮 Training Environment
+
+ | Item                | Value                  |
+ | ------------------- | ---------------------- |
+ | **Hardware**        | 1× NVIDIA A100 (80 GB) |
+ | **Precision dtype** | bf16                   |
+ | **Runtime**         | ~15 hours              |
+ | **Scheduler**       | Linear LR decay        |
+ | **Mixed precision** | Native AMP (bf16)      |
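
For readers who want to picture the schedule, here is a minimal sketch of a linear warmup followed by linear decay, using the peak learning rate, warmup length, and step count listed in the tables. The function `lr_at` is hypothetical, and both the warmup shape and the decay to exactly zero at the final step are assumptions; the actual script may use a different floor or scheduler class.

```python
def lr_at(
    step: int,
    peak_lr: float = 3e-4,
    warmup_steps: int = 100,
    total_steps: int = 100_000,
) -> float:
    """Learning rate at a given optimizer step (0-indexed)."""
    if step < warmup_steps:
        # Linear warmup: reaches peak_lr on the last warmup step.
        return peak_lr * (step + 1) / warmup_steps
    # Linear decay from peak_lr down to 0 at total_steps (assumed floor of zero).
    remaining = max(total_steps - step, 0)
    return peak_lr * remaining / (total_steps - warmup_steps)


for s in (0, 99, 100, 50_050, 100_000):
    print(s, lr_at(s))
```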
+
+ ---

+ ## 📈 Loss Curves
+
+ *(Placeholder; will update post-training)*
+ ![train\_loss](https://cdn-uploads.huggingface.co/production/uploads/67c358189919777813863c48/bQGVqgx4GoqXZTcMh8KhM.png)
+ ![val\_loss](https://cdn-uploads.huggingface.co/production/uploads/67c358189919777813863c48/jhNnS_Wvhj4-fzNoO2dRN.png)
+
+ ---
+
+ ## 🚀 Minimal Inference

  ```python
  # pip install torch tiktoken huggingface_hub safetensors
  import torch, tiktoken
  from safetensors.torch import load_file
  from huggingface_hub import hf_hub_download
+ from MedAssistGPT import MedAssistGPT, MODEL_CONFIG

+ REPO_ID = "kunjcr2/MedAssist-GPT-401M"
+ weights = hf_hub_download(REPO_ID, "model.safetensors")
+ state = load_file(weights, device="cpu")

  model = MedAssistGPT(MODEL_CONFIG)
  # load_state_dict returns a key report, not the model, so eval() is called separately
  model.load_state_dict(state, strict=True)
  model.eval()

  enc = tiktoken.get_encoding("p50k_base")
+ ids = torch.tensor([enc.encode(
+     "A patient was admitted with severe headache. Initial assessment revealed"
+ )], dtype=torch.long)
+
+ for _ in range(100):
+     logits = model(ids)[:, -1, :]
+     next_id = torch.multinomial(torch.softmax(logits / 0.6, dim=-1), 1)
+     ids = torch.cat([ids, next_id], dim=1)
  print(enc.decode(ids[0].tolist()))
  ```
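
The sampling loop divides the logits by a temperature before the softmax. As a dependency-free illustration of what that does, the sketch below applies the same scaling to a toy logit vector; `softmax_with_temperature` and `sample_token` are hypothetical helpers written with the standard library, not part of the MedAssist-GPT code.

```python
import math
import random
from typing import List, Optional


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Softmax over logits / temperature; lower temperature sharpens the distribution."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(
    logits: List[float], temperature: float = 0.6, rng: Optional[random.Random] = None
) -> int:
    """Draw one token index from the temperature-scaled distribution."""
    rng = rng or random.Random()
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


toy_logits = [2.0, 1.0, 0.1]
print(softmax_with_temperature(toy_logits, 1.0))
print(softmax_with_temperature(toy_logits, 0.6))
print(sample_token(toy_logits, rng=random.Random(0)))
```

A temperature below 1 concentrates probability on the highest-scoring tokens, which tends to give more conservative completions.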
 
+ ---

+ ## 💾 Checkpoints

+ * Main run: `medassist-401M-test`
+ * Checkpoint: `/checkpoints/checkpoint_step_44500.pt`

+ ---
+
+ ## 🧪 Intended Use
+
+ For research and experimentation only, e.g.,
+
+ * domain-adapted pretraining,
+ * architecture exploration,
+ * fine-tuning for medical text understanding.
+
+ 🚫 **Not intended for clinical or production medical use.**
+
+ ---
+
+ ## 🔮 Future Work
+
+ Next update includes:
+
+ * **Supervised fine-tuning (SFT)**
+ * **Reinforcement Learning (PPO) for alignment**
+
+ ---
+
+ ## 📁 Files
+
+ * `checkpoints/`
  * `config.json`, `tokenizer_config.json`
+ * Training script / notebook defining `MedAssistGPT`
+
+ ---
+
+ ## 🪪 License

+ Apache 2.0