CryoFM: Flow-based Foundation Model for Cryo-EM Density Maps

Model Description

CryoFM1 is a flow-based foundation model for 3D cryo-electron microscopy (cryo-EM) density maps. It uses a Hierarchical Diffusion Transformer (HDiT) architecture designed to learn a prior over 3D cryo-EM densities. CryoFM1 supports downstream tasks including density map denoising, anisotropic noise correction, missing wedge inpainting, and ab initio modeling.

Key Features

  • Flow Matching Framework: Uses flow matching for efficient and stable training
  • HDiT Architecture: Hierarchical Diffusion Transformer with local and global attention mechanisms
  • Two Model Variants: CryoFM-S (64³ voxels at 1.5 Å/voxel) and CryoFM-L (128³ voxels at 3.0 Å/voxel)
  • Downstream Task Support: Denoising, anisotropic noise correction, missing wedge restoration, and more
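
As a rough illustration of the flow-matching objective mentioned above, the NumPy sketch below estimates a conditional flow-matching loss with straight-line paths. It assumes noise at t = 0, data at t = 1, and the velocity target x1 - x0; CryoFM1's own `noise_scheduler` may use a different time convention or path, and `flow_matching_loss` is a hypothetical helper, not part of the `cryofm` package.

```python
import numpy as np


def flow_matching_loss(velocity_fn, x1, rng):
    """Monte Carlo estimate of a conditional flow-matching loss (illustrative sketch).

    Assumed convention: x0 ~ N(0, I) is noise at t=0, x1 is data at t=1, and the
    path is the straight line x_t = (1 - t) * x0 + t * x1, whose velocity is x1 - x0.
    """
    x0 = rng.standard_normal(x1.shape)
    # One time value per sample, broadcast over the voxel dimensions
    t = rng.uniform(size=(x1.shape[0],) + (1,) * (x1.ndim - 1))
    xt = (1.0 - t) * x0 + t * x1
    target = x1 - x0
    pred = velocity_fn(xt, t)
    return float(np.mean((pred - target) ** 2))


# Toy check: if every "density" equals a fixed volume mu, the exact velocity along
# these paths is (mu - x) / (1 - t), so its loss should vanish up to rounding.
rng = np.random.default_rng(0)
mu = np.ones((8, 8, 8))
batch = np.broadcast_to(mu, (4, 8, 8, 8))
exact_v = lambda x, t: (mu - x) / (1.0 - t)
zero_v = lambda x, t: np.zeros_like(x)
print(flow_matching_loss(exact_v, batch, rng))  # ~0 up to floating-point rounding
print(flow_matching_loss(zero_v, batch, rng))   # clearly positive
```

The network's job during training is to approximate the exact velocity from x_t and t alone, which is why the loss is a plain regression onto x1 - x0.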

Model Details

CryoFM1 employs a Hierarchical Diffusion Transformer (HDiT) architecture that combines local neighborhood attention with global attention mechanisms. This design enables the model to effectively capture both fine-grained local structures and long-range dependencies in 3D cryo-EM density maps. The architecture processes 3D volumes through a hierarchical patch-based approach, progressively building representations at multiple scales.
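
To make the patch-based hierarchy concrete, here is a NumPy sketch of how a cubic volume could be split into non-overlapping 4×4×4 patches (the patch size listed in the table) and how neighboring tokens could then be merged 2×2×2 to move one level down the hierarchy. `patchify_3d` and `merge_tokens` are hypothetical helpers; the 2× per-axis merge is an assumption not stated in this card, and the real model also applies learned linear projections that are omitted here.

```python
import numpy as np


def patchify_3d(vol, p):
    """Split a (D, H, W) volume into non-overlapping p*p*p patches.

    Returns a token grid of shape (D/p, H/p, W/p, p**3); each token holds the
    flattened voxels of one patch.
    """
    d, h, w = vol.shape
    assert d % p == 0 and h % p == 0 and w % p == 0, "side lengths must be divisible by p"
    x = vol.reshape(d // p, p, h // p, p, w // p, p)
    x = x.transpose(0, 2, 4, 1, 3, 5)  # (D/p, H/p, W/p, p, p, p)
    return x.reshape(d // p, h // p, w // p, p ** 3)


def merge_tokens(tokens):
    """Merge each 2x2x2 block of neighboring tokens into one token (assumed downsampling).

    (G, G, G, C) -> (G/2, G/2, G/2, 8*C); a real model would follow this with a
    learned projection to the next level's width.
    """
    g1, g2, g3, c = tokens.shape
    x = tokens.reshape(g1 // 2, 2, g2 // 2, 2, g3 // 2, 2, c)
    x = x.transpose(0, 2, 4, 1, 3, 5, 6)  # (G/2, G/2, G/2, 2, 2, 2, C)
    return x.reshape(g1 // 2, g2 // 2, g3 // 2, 8 * c)


vol = np.random.default_rng(0).standard_normal((64, 64, 64))  # CryoFM-S box size
level0 = patchify_3d(vol, 4)
level1 = merge_tokens(level0)
print(level0.shape, level1.shape)  # (16, 16, 16, 64) (8, 8, 8, 512)
```

Merging by reshaping keeps spatial neighbors together, so coarser levels see larger regions of the map with fewer tokens.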

Figure: CryoFM architecture.

The model is available in two variants that differ in box size and voxel size. The following table summarizes the key architectural and training parameters of each variant:

| Parameter | CryoFM-S | CryoFM-L |
|---|---|---|
| Parameters | 335.18 M | 308.54 M |
| GFLOP per forward pass | 395.87 | 427.26 |
| Training steps | 150k | 300k |
| Batch size | 128 | 128 |
| Precision | bf16 | bf16 |
| Training hardware | 8×A100 | 8×A100 |
| Patch size | 4 | 4 |
| Levels (local + global attention) | 1 + 1 | 2 + 1 |
| Depth per level | [4, 8] | [2, 2, 12] |
| Width per level | [768, 1536] | [320, 640, 1280] |
| Attention heads per level (width / head dim) | [12, 24] | [5, 10, 20] |
| Attention head dim | 64 | 64 |
| Neighborhood kernel size | 7 | 7 |
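
The per-level numbers in the table can be cross-checked with a little arithmetic. The sketch below derives, for each level, the token grid, the number of heads (width / 64), and how many keys each query attends to. It assumes that the patch size of 4 applies per axis, that each level halves the token grid along every axis, and that the neighborhood kernel of 7 spans 7 tokens along each of the three axes; none of these assumptions is spelled out in the table, and `level_summary` is a hypothetical helper.

```python
# Per-level shapes implied by the table (illustrative sketch).
# ASSUMPTIONS (not stated in the table): patch size 4 applies per axis, each level
# halves the token grid along every axis, and the neighborhood kernel of 7 spans
# 7 tokens along each of the three axes.

VARIANT_SPECS = {
    "cryofm-s": {"side": 64, "patch": 4, "widths": [768, 1536], "local_levels": 1},
    "cryofm-l": {"side": 128, "patch": 4, "widths": [320, 640, 1280], "local_levels": 2},
}
HEAD_DIM = 64
KERNEL = 7


def level_summary(name):
    spec = VARIANT_SPECS[name]
    grid = spec["side"] // spec["patch"]  # tokens per axis after patchifying
    rows = []
    for level, width in enumerate(spec["widths"]):
        side = grid // 2 ** level  # assumed 2x downsampling per level
        local = level < spec["local_levels"]
        rows.append({
            "level": level,
            "grid": side,
            "tokens": side ** 3,
            "heads": width // HEAD_DIM,
            "attention": "local" if local else "global",
            # keys per query: a KERNEL^3 window for local levels, every token for global ones
            "keys_per_query": KERNEL ** 3 if local else side ** 3,
        })
    return rows


for name in VARIANT_SPECS:
    for row in level_summary(name):
        print(name, row)
```

Under these assumptions the heads column reproduces the table's [12, 24] and [5, 10, 20], and both variants end in a global-attention level over an 8×8×8 grid of 512 tokens, which keeps full attention affordable at the coarsest scale.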

Quick Start

Unconditional Generation

CryoFM1 provides two model variants:

  • CryoFM-S: Generates 64×64×64 voxel density maps with a pixel size of 1.5 Å
  • CryoFM-L: Generates 128×128×128 voxel density maps with a pixel size of 3.0 Å
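
Box size and pixel size together set the physical extent of a generated map and the finest detail it can represent. The sketch below computes the box edge length (side × pixel size) and the Nyquist limit (2 × pixel size); the Nyquist limit is the standard sampling bound for any map on that grid, not a measured resolution of CryoFM1's outputs. `GRIDS` and `box_geometry` are illustrative names.

```python
GRIDS = {"cryofm-s": (64, 1.5), "cryofm-l": (128, 3.0)}  # (voxels per side, Å per voxel)


def box_geometry(side, apix):
    """Return (box edge in Å, Nyquist-limited resolution in Å) for a cubic map."""
    return side * apix, 2.0 * apix


for name, (side, apix) in GRIDS.items():
    edge, nyquist = box_geometry(side, apix)
    print(f"{name}: {side}^3 voxels, {edge:.0f} Å box, Nyquist limit {nyquist:.1f} Å")
# cryofm-s: 64^3 voxels, 96 Å box, Nyquist limit 3.0 Å
# cryofm-l: 128^3 voxels, 384 Å box, Nyquist limit 6.0 Å
```

CryoFM-L therefore covers a box four times wider along each axis, at half the sampling rate of CryoFM-S.
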

```python
import torch
from mmengine import Config
from cryofm.core.utils.mrc_io import save_mrc
from cryofm.projects.cryofm1.lit_modules import CryoFM1
from cryofm.core.utils.sampling_fm import sample_from_fm

# Choose model variant: "cryofm-s" or "cryofm-l"
model_variant = "cryofm-s"  # or "cryofm-l"
model_config = {
    "cryofm-s": {
        "config_path": "cryofm-v1/cryofm-s/config.yaml",
        "model_path": "cryofm-v1/cryofm-s/model.safetensors",
        "side_shape": 64,  # voxels per side
        "apix": 1.5,       # Å per voxel
    },
    "cryofm-l": {
        "config_path": "cryofm-v1/cryofm-l/config.yaml",
        "model_path": "cryofm-v1/cryofm-l/model.safetensors",
        "side_shape": 128,
        "apix": 3.0,
    },
}
variant = model_config[model_variant]
num_samples = 3

# Load configuration and model weights
cfg = Config.fromfile(variant["config_path"])
lit_model = CryoFM1.load_from_safetensors(variant["model_path"], cfg=cfg)

# Move the model to the available device and switch to evaluation mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_model = lit_model.to(device)
lit_model.eval()

# Use bfloat16 autocast only on GPUs that support it
use_bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported()

# Velocity field v(x_t, t) predicted by the network, consumed by the flow-matching sampler
def v_xt_t(_xt, _t):
    return lit_model(_xt, _t)

# Generate samples by integrating the learned flow with 200 Euler steps
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=use_bf16):
    out = sample_from_fm(
        v_xt_t,
        lit_model.noise_scheduler,
        method="euler",
        num_steps=200,
        num_samples=num_samples,
        device=device,
        side_shape=variant["side_shape"],
    )

# Undo the z-score normalization used during training, if configured
if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
    out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean

# Save each generated density map as an MRC file
for i in range(num_samples):
    save_mrc(
        out[i].float().cpu().numpy(),
        f"sample-{i}.mrc",
        apix=variant["apix"],  # Angstroms per voxel
    )
```
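
`sample_from_fm` with `method="euler"` and `num_steps=200` integrates the learned velocity field from noise to a density map. The NumPy sketch below shows what a fixed-step Euler integrator of that kind does, using a toy analytic velocity field in place of the network, and then undoes a z-score normalization as the script does. It assumes noise at t = 0, data at t = 1, and uniform steps; the actual time convention and step schedule of `sample_from_fm` and `lit_model.noise_scheduler` may differ, and `euler_sample` plus the normalization statistics are hypothetical.

```python
import numpy as np


def euler_sample(velocity_fn, shape, num_steps, rng):
    """Integrate dx/dt = v(x, t) from t=0 (Gaussian noise) to t=1 with uniform Euler steps."""
    x = rng.standard_normal(shape)
    dt = 1.0 / num_steps
    for k in range(num_steps):
        t = k * dt
        x = x + dt * velocity_fn(x, t)
    return x


# Toy velocity field whose flow carries every starting point to a fixed target mu
# (the exact straight-path flow-matching velocity when the "data" is a single volume).
mu = np.zeros((16, 16, 16))
mu[4:12, 4:12, 4:12] = 1.0  # a normalized cube standing in for a density
toy_v = lambda x, t: (mu - x) / (1.0 - t)

rng = np.random.default_rng(0)
z = euler_sample(toy_v, (3, 16, 16, 16), num_steps=200, rng=rng)

# Undo z-score normalization, mirroring out * std + mean in the script above
z_mean, z_std = 0.5, 2.0  # hypothetical normalization statistics
density = z * z_std + z_mean
print(np.allclose(z, mu))  # True: the final Euler step lands on mu up to rounding
```

With a learned network in place of `toy_v`, more steps reduce the discretization error of the straight-line Euler updates at the cost of one forward pass per step.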

Ethical Considerations

This model is intended for scientific research and structural biology applications. Users should:

  • Ensure proper attribution when using generated structures
  • Validate generated structures experimentally
  • Be aware of potential biases in the training data

Citation

If you use CryoFM1 in your research, please cite:

@inproceedings{
  zhou2025cryofm,
  title={Cryo{FM}: A Flow-based Foundation Model for Cryo-{EM} Densities},
  author={Yi Zhou and Yilai Li and Jing Yuan and Quanquan Gu},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=T4sMzjy7fO}
}

License

This model is released under the Apache 2.0 License. See the LICENSE file for details.

Acknowledgments

This work was developed by the ByteDance Seed Team. For more information, visit:
