eugeneyuan commited on
Commit
0ef798b
·
verified ·
1 Parent(s): c6c44b3

Add CryoFM2 models

Browse files

Upload CryoFM2 model checkpoints and configurations:
- cryofm2-pretrain: Unconditional pretrained model
- cryofm2-emhancer: Fine-tuned model for EMhancer-style enhancement
- cryofm2-emready: Fine-tuned model for EMReady-style enhancement

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/cryofm2_arch-finetune.jpg filter=lfs diff=lfs merge=lfs -text
37
+ assets/cryofm2_arch-pretrain.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/cryofm2_overview.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,269 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - cryo-em
5
+ - flow-matching
6
+ - 3d-density-maps
7
+ - foundation-model
8
+ - conditional-sampling
9
+ ---
10
+
11
+ # CryoFM2: A Generative Foundation Model for Cryo-EM Densities
12
+
13
+ <div align="center">
14
+
15
+ [![GitHub](https://img.shields.io/badge/GitHub-cryofm-181717?logo=github&logoColor=white)](https://github.com/ByteDance-Seed/cryofm)
16
+ [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
17
+
18
+ </div>
19
+
20
+ <div align="center">
21
+ <img src="./assets/cryofm2_overview.jpg" alt="CryoFM2 Overview" style="max-width: 100%; height: auto; width: 800px;"/>
22
+ </div>
23
+
24
+ ## Overview
25
+
26
+ **CryoFM2** is a flow-based generative foundation model for cryo-EM density maps.
27
+ It is pretrained on curated EMDB half maps to learn general priors of high-quality cryo-EM densities and can be fine-tuned for downstream tasks.
28
+
29
+ The model learns a continuous mapping from a simple Gaussian distribution to the complex distribution of cryo-EM densities, enabling stable generation and flexible adaptation. CryoFM2 can also act as a **Bayesian prior**, integrating naturally with task-specific likelihoods to support applications such as anisotropy-aware refinement, non-uniform reconstruction, and controlled density modification.
30
+
31
+ ## Model Details
32
+
33
+ CryoFM2 is pretrained on curated EMDB half maps to learn general priors of high-quality cryo-EM densities. The model can be fine-tuned for various downstream tasks such as density map enhancement and post-processing.
34
+
35
+ **Pre-training Architecture:**
36
+
37
+ <div align="center">
38
+ <img src="./assets/cryofm2_arch-pretrain.jpg" alt="CryoFM2 architecture for pre-training." style="max-width: 100%; height: auto; width: 800px;"/>
39
+ </div>
40
+
41
+ **Fine-tuning Architecture (for EMhancer/EMReady style post-processing):**
42
+
43
+ <div align="center">
44
+ <img src="./assets/cryofm2_arch-finetune.jpg" alt="CryoFM2 architecture for fine-tuning." style="max-width: 100%; height: auto; width: 800px;"/>
45
+ </div>
46
+
47
+ ### Architecture
48
+ - **Architecture Type**: 3D UNet
49
+ - **Input Size**: 64×64×64 voxels
50
+ - **Input Channels**: 2 for pre-trained model, 3 for fine-tuned model
51
+ - **Output Channels**: 1
52
+ - **Down Blocks**: DownBlock3D, DownBlock3D, AttnDownBlock3D, AttnDownBlock3D
53
+ - **Up Blocks**: AttnUpBlock3D, AttnUpBlock3D, UpBlock3D, UpBlock3D
54
+ - **Block Output Channels**: (64, 128, 256, 512)
55
+ - **Layers per Block**: 2
56
+ - **Attention Head Dimension**: 8
57
+ - **Normalization**: GroupNorm (32 groups)
58
+ - **Activation**: SiLU
59
+ - **Time Embedding**: Positional encoding
60
+
61
+ ### Model Variants
62
+
63
+ 1. **cryofm2-pretrain**: Unconditional pretrained model for general density map generation
64
+ 2. **cryofm2-emhancer**: Fine-tuned model for density map enhancement (EMhancer style)
65
+ 3. **cryofm2-emready**: Fine-tuned model for density map enhancement (EMReady style)
66
+
67
+ ## Play with CryoFM2
68
+
69
+ ### Unconditional Generation (Explore Training Data Distribution)
70
+
71
+ Generate samples from the pretrained model to explore the learned data distribution:
72
+
73
+ **Pretrained Model:**
74
+ ```python
75
+ import torch
76
+ from mmengine import Config
77
+
78
+ from cryofm.core.utils.mrc_io import save_mrc
79
+ from cryofm.core.utils.sampling_fm import sample_from_fm
80
+ from cryofm.projects.cryofm2.lit_modules import CryoFM2Uncond
81
+
82
+ # Update the path to your model directory
83
+ model_dir = "path/to/cryofm-v2/cryofm2-pretrain"
84
+ cfg = Config.fromfile(f"{model_dir}/config.yaml")
85
+ lit_model = CryoFM2Uncond.load_from_safetensors(f"{model_dir}/model.safetensors", cfg=cfg)
86
+
87
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88
+
89
+ lit_model = lit_model.to(device)
90
+ lit_model.eval()
91
+ def v_xt_t(_xt, _t):
92
+ return lit_model(_xt, _t)
93
+
94
+ # Enable bfloat16 for faster inference if your GPU supports it
95
+ with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
96
+ out = sample_from_fm(
97
+ v_xt_t,
98
+ lit_model.noise_scheduler,
99
+ method="euler",
100
+ num_steps=200,
101
+ num_samples=3,
102
+ device=lit_model.device,
103
+ side_shape=64
104
+ )
105
+ # Apply normalization if configured
106
+ if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
107
+ out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean
108
+
109
+ # Save generated samples
110
+ for i in range(3):
111
+ save_mrc(out[i].float().cpu().numpy(), f"sample-{i}.mrc", voxel_size=1.5)
112
+ ```
113
+
114
+ **Fine-tuned Models (EMhancer/EMReady):**
115
+ ```python
116
+ import torch
117
+ from mmengine import Config
118
+
119
+ from cryofm.core.utils.mrc_io import save_mrc
120
+ from cryofm.core.utils.sampling_fm import sample_from_fm
121
+ from cryofm.projects.cryofm2.lit_modules import CryoFM2Cond
122
+
123
+ # Choose style: "emhancer" or "emready"
124
+ style = "emhancer"
125
+ model_dir = f"path/to/cryofm-v2/cryofm2-{style}"
126
+ cfg = Config.fromfile(f"{model_dir}/config.yaml")
127
+ lit_model = CryoFM2Cond.load_from_safetensors(f"{model_dir}/model.safetensors", cfg=cfg)
128
+ output_tag = 1 if style == "emhancer" else 0
129
+
130
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
131
+
132
+ lit_model = lit_model.to(device)
133
+ lit_model.eval()
134
+ def v_xt_t(_xt, _t):
135
+ bs = _xt.shape[0]
136
+ unconditional_generation_conds = {
137
+ "input_cond": None,
138
+ "output_cond": torch.tensor([output_tag] * bs).to(device),
139
+ "vol_cond": None, # dimension should be [bs, d, h, w]
140
+ }
141
+ return lit_model(_xt, _t, generation_conds=unconditional_generation_conds)
142
+
143
+ # Enable bfloat16 for faster inference if your GPU supports it
144
+ with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
145
+ out = sample_from_fm(
146
+ v_xt_t,
147
+ lit_model.noise_scheduler,
148
+ method="euler",
149
+ num_steps=200,
150
+ num_samples=3,
151
+ device=lit_model.device,
152
+ side_shape=64
153
+ )
154
+ # Apply normalization if configured
155
+ if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
156
+ out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean
157
+
158
+ # Save generated samples
159
+ for i in range(3):
160
+ save_mrc(out[i].float().cpu().numpy(), f"{style}-sample-{i}.mrc", voxel_size=1.5)
161
+ ```
162
+
163
+ ### Density Map Modification
164
+
165
+ CryoFM2 supports various density map modification operations using the pretrained model as a Bayesian prior. Supported operators include:
166
+
167
+ - **denoise**: Remove noise from density maps
168
+ - **inpaint**: Fill missing regions (e.g., missing wedge)
169
+ - **denoise inpaint**: Combined denoising and inpainting
170
+ - **non-uniform weight**: Apply non-uniform weighting during reconstruction
171
+
172
+ **Basic Usage:**
173
+
174
+ ```bash
175
+ python -m cryofm.projects.cryofm2.uncond_sampling \
176
+ -i1 half_map_1.mrc \
177
+ -i2 half_map_2.mrc \
178
+ -o ./output \
179
+ --model-dir path/to/cryofm-v2/cryofm2-pretrain \
180
+ --op denoise \
181
+ --norm-grad \
182
+ --use-lamb-w
183
+ ```
184
+
185
+ **For inpainting tasks**, you need to provide a RELION starfile path:
186
+
187
+ ```bash
188
+ python -m cryofm.projects.cryofm2.uncond_sampling \
189
+ -i1 half_map_1.mrc \
190
+ -i2 half_map_2.mrc \
191
+ -o ./output \
192
+ --model-dir path/to/cryofm-v2/cryofm2-pretrain \
193
+ --op inpaint \
194
+ --data-starfile-path path/to/relion_data.star \
195
+ --norm-grad \
196
+ --use-lamb-w
197
+ ```
198
+
199
+ ### Density Map Post-Processing
200
+
201
+ CryoFM2 provides fine-tuned models for density map enhancement in different styles, similar to EMhancer and EMReady.
202
+
203
+ #### EMhancer Style Enhancement
204
+
205
+ ```bash
206
+ python -m cryofm.projects.cryofm2.cond_sampling \
207
+ -i input_map.mrc \
208
+ -o ./output_emhancer \
209
+ --model-dir path/to/cryofm-v2/cryofm2-emhancer \
210
+ --output-tag 1
211
+ ```
212
+
213
+ #### EMReady Style Enhancement
214
+
215
+ ```bash
216
+ python -m cryofm.projects.cryofm2.cond_sampling \
217
+ -i input_map.mrc \
218
+ -o ./output_emready \
219
+ --model-dir path/to/cryofm-v2/cryofm2-emready \
220
+ --output-tag 0 \
221
+ --cfg-weight 0.5
222
+ ```
223
+
224
+ **Parameters:**
225
+ - `-i`: Input density map file (MRC format)
226
+ - `-o`: Output directory
227
+ - `--model-dir`: Path to the model directory containing `config.yaml` and `model.safetensors`
228
+ - `--output-tag`: Style tag (1 for EMhancer, 0 for EMReady)
229
+ - `--cfg-weight`: Classifier-free guidance weight (optional, default varies by model)
230
+
231
+
232
+ ## Performance Tips
233
+
234
+ - **Multi-GPU Inference**: Use `accelerate launch` for faster inference on multiple GPUs:
235
+ ```bash
236
+ NCCL_DEBUG=ERROR accelerate launch --num_processes=${NUM_GPUS} --main_process_port=8881 \
237
+ python -m cryofm.projects.cryofm2.cond_sampling ...
238
+ ```
239
+ - **Mixed Precision**: Use `--bf16` flag when available to reduce memory usage and speed up inference.
240
+ - **Batch Processing**: Adjust batch size based on your GPU memory capacity.
241
+
242
+ ## Limitations
243
+
244
+ - Input size is fixed at 64×64×64 voxels
245
+ - Model performance may vary depending on the input density map quality
246
+ - Fine-tuned models are optimized for specific enhancement styles
247
+
248
+ ## Ethical Considerations
249
+
250
+ This model is intended for scientific research and structural biology applications. Users should:
251
+ - Ensure proper attribution when using generated structures
252
+ - Validate generated structures through experimental verification
253
+ - Be aware of potential biases in the training data
254
+ - Use the model responsibly and in accordance with scientific best practices
255
+
256
+ ## Citation
257
+
258
+ TBA
259
+
260
+ ## License
261
+
262
+ This model is released under the Apache 2.0 License. See the [LICENSE](https://github.com/ByteDance-Seed/cryofm/blob/main/LICENSE) file for details.
263
+
264
+ ## Acknowledgments
265
+
266
+ This work is developed by the ByteDance Seed Team. For more information, visit:
267
+ - [Project Repository](https://github.com/ByteDance-Seed/cryofm)
268
+ - [ByteDance Seed Team](https://seed.bytedance.com/)
269
+
assets/cryofm2_arch-finetune.jpg ADDED

Git LFS Details

  • SHA256: 9e9a3e72b249ed4e8ed55b9a60d1bd9b905abf1477ee04cd3e42357d9ba2192b
  • Pointer size: 131 Bytes
  • Size of remote file: 329 kB
assets/cryofm2_arch-pretrain.jpg ADDED

Git LFS Details

  • SHA256: d911ffcb91e8b2eab5b58b897363aa7b3cdfa30b9a886f2948dd96973c1b4107
  • Pointer size: 131 Bytes
  • Size of remote file: 273 kB
assets/cryofm2_overview.jpg ADDED

Git LFS Details

  • SHA256: 025cea599802e4abacddead3bb96a30c231d9ce14bbe2f159f96d29afd039a21
  • Pointer size: 131 Bytes
  • Size of remote file: 676 kB
cryofm2-emhancer/config.yaml ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ckpt_path: null
2
+ ddpm:
3
+ cond_drop_threshold: 0.1
4
+ prediction_type: v_prediction
5
+ exp_name: cond_model_emhancer
6
+ inference:
7
+ batch_size: 16
8
+ patch_overlap: 0
9
+ is_debug: false
10
+ keep_last_k: null
11
+ mode: train
12
+ model:
13
+ act_fn: silu
14
+ attention_head_dim: 8
15
+ attn_norm_num_groups: null
16
+ block_out_channels: !!python/tuple
17
+ - 64
18
+ - 128
19
+ - 256
20
+ - 512
21
+ class_embed_type: null
22
+ down_block_types: !!python/tuple
23
+ - DownBlock3D
24
+ - DownBlock3D
25
+ - AttnDownBlock3D
26
+ - AttnDownBlock3D
27
+ downsample_padding: 1
28
+ downsample_type: conv
29
+ dropout: 0.0
30
+ flip_sin_to_cos: true
31
+ freq_shift: 0
32
+ in_channels: 3
33
+ layers_per_block: 2
34
+ mid_block_scale_factor: 1
35
+ norm_eps: 1.0e-05
36
+ norm_num_groups: 32
37
+ num_class_embeds: 5
38
+ out_channels: 1
39
+ resnet_time_scale_shift: scale_shift
40
+ sample_size: 64
41
+ time_embedding_dim: null
42
+ time_embedding_type: positional
43
+ up_block_types: !!python/tuple
44
+ - AttnUpBlock3D
45
+ - AttnUpBlock3D
46
+ - UpBlock3D
47
+ - UpBlock3D
48
+ upsample_type: conv
49
+ model_type: unet
50
+ num_val_samples: 3
51
+ optimizer:
52
+ lr: 0.0001
53
+ warmup: 2000
54
+ patch_size: 64
55
+ process: fm
56
+ resume_path: null
57
+ seed: 42
58
+ selective_datasets: emhancer
59
+ timestep_sampling: uniform
60
+ work_dir: work_dirs/cond_model_emhancer
61
+ z_crop: null
62
+ z_scale:
63
+ mean: null
64
+ std: null
cryofm2-emhancer/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96576420fe03fc93088b40fdb1e7f785d100e6ed49833050c961670bdfaee163
3
+ size 672409268
cryofm2-emready/config.yaml ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ckpt_path: null
2
+ ddpm:
3
+ cond_drop_threshold: 0.1
4
+ prediction_type: v_prediction
5
+ exp_name: cond_model_emready
6
+ inference:
7
+ batch_size: 16
8
+ patch_overlap: 0
9
+ is_debug: false
10
+ keep_last_k: null
11
+ mode: train
12
+ model:
13
+ act_fn: silu
14
+ attention_head_dim: 8
15
+ attn_norm_num_groups: null
16
+ block_out_channels: !!python/tuple
17
+ - 64
18
+ - 128
19
+ - 256
20
+ - 512
21
+ class_embed_type: null
22
+ down_block_types: !!python/tuple
23
+ - DownBlock3D
24
+ - DownBlock3D
25
+ - AttnDownBlock3D
26
+ - AttnDownBlock3D
27
+ downsample_padding: 1
28
+ downsample_type: conv
29
+ dropout: 0.0
30
+ flip_sin_to_cos: true
31
+ freq_shift: 0
32
+ in_channels: 3
33
+ layers_per_block: 2
34
+ mid_block_scale_factor: 1
35
+ norm_eps: 1.0e-05
36
+ norm_num_groups: 32
37
+ num_class_embeds: 5
38
+ out_channels: 1
39
+ resnet_time_scale_shift: scale_shift
40
+ sample_size: 64
41
+ time_embedding_dim: null
42
+ time_embedding_type: positional
43
+ up_block_types: !!python/tuple
44
+ - AttnUpBlock3D
45
+ - AttnUpBlock3D
46
+ - UpBlock3D
47
+ - UpBlock3D
48
+ upsample_type: conv
49
+ model_type: unet
50
+ num_val_samples: 3
51
+ optimizer:
52
+ lr: 0.0001
53
+ warmup: 2000
54
+ patch_size: 64
55
+ process: fm
56
+ resume_path: null
57
+ seed: 42
58
+ selective_datasets: emready
59
+ timestep_sampling: uniform
60
+ work_dir: work_dirs/cond_model_emready
61
+ z_crop: null
62
+ z_scale:
63
+ mean: null
64
+ std: null
cryofm2-emready/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77c1fa590eaee1906e8470d3659c981855a92fcf4e6a7817b48f0069cd6d2bca
3
+ size 672409268
cryofm2-pretrain/config.yaml ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ckpt_path: null
2
+ ddpm:
3
+ cond_drop_threshold: 3
4
+ prediction_type: v_prediction
5
+ exp_name: uncond_model
6
+ inference:
7
+ batch_size: 16
8
+ patch_overlap: 0
9
+ is_debug: false
10
+ keep_last_k: null
11
+ mode: train
12
+ model:
13
+ act_fn: silu
14
+ attention_head_dim: 8
15
+ attn_norm_num_groups: null
16
+ block_out_channels: !!python/tuple
17
+ - 64
18
+ - 128
19
+ - 256
20
+ - 512
21
+ down_block_types: !!python/tuple
22
+ - DownBlock3D
23
+ - DownBlock3D
24
+ - AttnDownBlock3D
25
+ - AttnDownBlock3D
26
+ downsample_padding: 1
27
+ downsample_type: conv
28
+ dropout: 0.0
29
+ flip_sin_to_cos: true
30
+ freq_shift: 0
31
+ in_channels: 2
32
+ layers_per_block: 2
33
+ mid_block_scale_factor: 1
34
+ norm_eps: 1.0e-05
35
+ norm_num_groups: 32
36
+ out_channels: 1
37
+ resnet_time_scale_shift: scale_shift
38
+ sample_size: 64
39
+ time_embedding_dim: null
40
+ time_embedding_type: positional
41
+ up_block_types: !!python/tuple
42
+ - AttnUpBlock3D
43
+ - AttnUpBlock3D
44
+ - UpBlock3D
45
+ - UpBlock3D
46
+ upsample_type: conv
47
+ model_type: unet
48
+ num_val_samples: 3
49
+ optimizer:
50
+ lr: 0.0001
51
+ warmup: 2000
52
+ patch_size: 64
53
+ process: fm
54
+ resume_path: null
55
+ seed: 42
56
+ timestep_sampling: uniform
57
+ work_dir: work_dirs/uncond_model
58
+ z_crop: null
59
+ z_scale:
60
+ mean: null
61
+ std: null
cryofm2-pretrain/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f10dc552fceedae8a3c574e9b9d259de7d1f5047f4f2107d9309fae9512f413
3
+ size 672397148