Upload folder using huggingface_hub
Browse files- .gitattributes +2 -0
- README.md +92 -3
- added_tokens.json +34 -0
- assets/comparision.png +3 -0
- config.json +66 -0
- model.safetensors.index.json +698 -0
- models/__pycache__/config.cpython-310.pyc +0 -0
- models/__pycache__/gen_pipeline.cpython-310.pyc +0 -0
- models/__pycache__/heads.cpython-310.pyc +0 -0
- models/__pycache__/llama_model.cpython-310.pyc +0 -0
- models/__pycache__/nextstep_model.cpython-310.pyc +0 -0
- models/config.py +45 -0
- models/gen_pipeline.py +398 -0
- models/heads.py +283 -0
- models/llama_model.py +568 -0
- models/nextstep_model.py +553 -0
- pytorch-model-00001-of-00004.safetensors +3 -0
- pytorch-model-00002-of-00004.safetensors +3 -0
- pytorch-model-00003-of-00004.safetensors +3 -0
- pytorch-model-00004-of-00004.safetensors +3 -0
- requirements.txt +14 -0
- special_tokens_map.json +27 -0
- tokenizer.json +3 -0
- tokenizer_config.json +285 -0
- utils/__pycache__/compile_utils.cpython-310.pyc +0 -0
- utils/__pycache__/image_utils.cpython-310.pyc +0 -0
- utils/__pycache__/misc.cpython-310.pyc +0 -0
- utils/__pycache__/model_utils.cpython-310.pyc +0 -0
- utils/aspect_ratio.py +107 -0
- utils/compile_utils.py +122 -0
- utils/image_utils.py +314 -0
- utils/misc.py +51 -0
- utils/model_utils.py +128 -0
- vae/__pycache__/nextstep_ae.cpython-310.pyc +0 -0
- vae/checkpoint.pt +3 -0
- vae/config.json +14 -0
- vae/nextstep_ae.py +494 -0
- vocab.json +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/comparision.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,3 +1,92 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
pipeline_tag: text-to-image
|
| 4 |
+
library_name: transformers
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## NextStep-1.1
|
| 8 |
+
|
| 9 |
+
[Homepage](https://stepfun.ai/research/en/nextstep-1)
|
| 10 |
+
| [GitHub](https://github.com/stepfun-ai/NextStep-1)
|
| 11 |
+
| [Paper](https://arxiv.org/abs/2508.10711)
|
| 12 |
+
|
| 13 |
+
We introduce **NextStep-1.1**, a new model represents a significant leap forward in the NextStep series. This version effectively resolves the visualization failures seen in **NextStep-1** and substantially elevates image quality through extended training and a Flow-based Reinforcement Learning (RL) post-training paradigm.
|
| 14 |
+
|
| 15 |
+
<div align='center'>
|
| 16 |
+
<img src="assets/comparision.png" class="interpolation-image" alt="arch." width="100%" />
|
| 17 |
+
</div>
|
| 18 |
+
|
| 19 |
+
## What's New in 1.1?
|
| 20 |
+
|
| 21 |
+
NextStep-1.1 is not just a fine-tune; it is a re-engineered version focused on stability and high-fidelity output. Key improvements include:
|
| 22 |
+
|
| 23 |
+
- RL Enhanced Visual Fidelity: Significant improvement in image texture and a substantial reduction in visual artifacts via RL, ensuring much cleaner and more professional outputs.
|
| 24 |
+
|
| 25 |
+
- Technical Stability: Solves numerical instability inherent in the RL of autoregressive flow-based models.
|
| 26 |
+
|
| 27 |
+
## Environment Setup
|
| 28 |
+
|
| 29 |
+
To avoid potential errors when loading and running your models, we recommend using the following settings:
|
| 30 |
+
|
| 31 |
+
```shell
|
| 32 |
+
conda create -n nextstep python=3.11 -y
|
| 33 |
+
conda activate nextstep
|
| 34 |
+
|
| 35 |
+
pip install uv # optional
|
| 36 |
+
|
| 37 |
+
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/stepfun-ai/NextStep-1.1-Pretrain && cd NextStep-1.1-Pretrain
|
| 38 |
+
uv pip install -r requirements.txt
|
| 39 |
+
|
| 40 |
+
hf download stepfun-ai/NextStep-1.1-Pretrain "vae/checkpoint.pt" --local-dir ./
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
## Usage
|
| 44 |
+
|
| 45 |
+
```python
|
| 46 |
+
import torch
|
| 47 |
+
from transformers import AutoTokenizer, AutoModel
|
| 48 |
+
from models.gen_pipeline import NextStepPipeline
|
| 49 |
+
|
| 50 |
+
HF_HUB = "stepfun-ai/NextStep-1.1-Pretrain"
|
| 51 |
+
|
| 52 |
+
# load model and tokenizer
|
| 53 |
+
tokenizer = AutoTokenizer.from_pretrained(HF_HUB, local_files_only=True, trust_remote_code=True)
|
| 54 |
+
model = AutoModel.from_pretrained(HF_HUB, local_files_only=True, trust_remote_code=True)
|
| 55 |
+
pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device="cuda", dtype=torch.bfloat16)
|
| 56 |
+
|
| 57 |
+
# set prompts
|
| 58 |
+
positive_prompt = ""
|
| 59 |
+
negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
|
| 60 |
+
example_prompt = "A REALISTIC PHOTOGRAPH OF A WALL WITH \"TOWARD AUTOREGRESSIVE IMAGE GENERATION WITH CONTINUOUS TOKENS AT SCALE\" PROMINENTLY DISPLAYED"
|
| 61 |
+
|
| 62 |
+
# generate image from text
|
| 63 |
+
IMG_SIZE = 512
|
| 64 |
+
image = pipeline.generate_image(
|
| 65 |
+
example_prompt,
|
| 66 |
+
hw=(IMG_SIZE, IMG_SIZE),
|
| 67 |
+
num_images_per_caption=1,
|
| 68 |
+
positive_prompt=positive_prompt,
|
| 69 |
+
negative_prompt=negative_prompt,
|
| 70 |
+
cfg=7.5,
|
| 71 |
+
cfg_img=1.0,
|
| 72 |
+
cfg_schedule="constant",
|
| 73 |
+
use_norm=False,
|
| 74 |
+
num_sampling_steps=28,
|
| 75 |
+
timesteps_shift=1.0,
|
| 76 |
+
seed=3407,
|
| 77 |
+
)[0]
|
| 78 |
+
image.save("./assets/output.jpg")
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
## Citation
|
| 82 |
+
|
| 83 |
+
If you find NextStep useful for your research and applications, please consider starring this repository and citing:
|
| 84 |
+
|
| 85 |
+
```bibtex
|
| 86 |
+
@article{nextstepteam2025nextstep1,
|
| 87 |
+
title={NextStep-1: Toward Autoregressive Image Generation with Continuous Tokens at Scale},
|
| 88 |
+
author={NextStep Team and Chunrui Han and Guopeng Li and Jingwei Wu and Quan Sun and Yan Cai and Yuang Peng and Zheng Ge and Deyu Zhou and Haomiao Tang and Hongyu Zhou and Kenkun Liu and Ailin Huang and Bin Wang and Changxin Miao and Deshan Sun and En Yu and Fukun Yin and Gang Yu and Hao Nie and Haoran Lv and Hanpeng Hu and Jia Wang and Jian Zhou and Jianjian Sun and Kaijun Tan and Kang An and Kangheng Lin and Liang Zhao and Mei Chen and Peng Xing and Rui Wang and Shiyu Liu and Shutao Xia and Tianhao You and Wei Ji and Xianfang Zeng and Xin Han and Xuelin Zhang and Yana Wei and Yanming Xu and Yimin Jiang and Yingming Wang and Yu Zhou and Yucheng Han and Ziyang Meng and Binxing Jiao and Daxin Jiang and Xiangyu Zhang and Yibo Zhu},
|
| 89 |
+
journal={arXiv preprint arXiv:2508.10711},
|
| 90 |
+
year={2025}
|
| 91 |
+
}
|
| 92 |
+
```
|
added_tokens.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"</tool_call>": 151658,
|
| 3 |
+
"<tool_call>": 151657,
|
| 4 |
+
"<|begin_of_image|>": 151667,
|
| 5 |
+
"<|begin_of_prompt_refinement|>": 151670,
|
| 6 |
+
"<|begin_of_thinking|>": 151672,
|
| 7 |
+
"<|box_end|>": 151649,
|
| 8 |
+
"<|box_start|>": 151648,
|
| 9 |
+
"<|end_of_image|>": 151668,
|
| 10 |
+
"<|end_of_prompt_refinement|>": 151671,
|
| 11 |
+
"<|end_of_thinking|>": 151673,
|
| 12 |
+
"<|beginoftext|>": 151674,
|
| 13 |
+
"<|endoftext|>": 151643,
|
| 14 |
+
"<|file_sep|>": 151664,
|
| 15 |
+
"<|fim_middle|>": 151660,
|
| 16 |
+
"<|fim_pad|>": 151662,
|
| 17 |
+
"<|fim_prefix|>": 151659,
|
| 18 |
+
"<|fim_suffix|>": 151661,
|
| 19 |
+
"<|im_end|>": 151645,
|
| 20 |
+
"<|im_start|>": 151644,
|
| 21 |
+
"<|image_area|>": 151666,
|
| 22 |
+
"<|image_pad|>": 151655,
|
| 23 |
+
"<|image_placeholder|>": 151669,
|
| 24 |
+
"<|object_ref_end|>": 151647,
|
| 25 |
+
"<|object_ref_start|>": 151646,
|
| 26 |
+
"<|quad_end|>": 151651,
|
| 27 |
+
"<|quad_start|>": 151650,
|
| 28 |
+
"<|repo_name|>": 151663,
|
| 29 |
+
"<|video_pad|>": 151656,
|
| 30 |
+
"<|vision_end|>": 151653,
|
| 31 |
+
"<|vision_pad|>": 151654,
|
| 32 |
+
"<|vision_start|>": 151652,
|
| 33 |
+
"[PAD]": 151665
|
| 34 |
+
}
|
assets/comparision.png
ADDED
|
Git LFS Details
|
config.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_attn_implementation_autoset": true,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"LlamaForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"auto_map":{
|
| 7 |
+
"AutoConfig": "models/config.NextStepConfig",
|
| 8 |
+
"AutoModel": "models/nextstep_model.NextStep"
|
| 9 |
+
},
|
| 10 |
+
"attention_bias": true,
|
| 11 |
+
"attention_dropout": 0.0,
|
| 12 |
+
"base_image_grid_size": 64,
|
| 13 |
+
"boi": 151667,
|
| 14 |
+
"bos_token_id": 151643,
|
| 15 |
+
"create_kwargs": {
|
| 16 |
+
"snr_type": "lognorm"
|
| 17 |
+
},
|
| 18 |
+
"eoi": 151668,
|
| 19 |
+
"eos_token_id": 151643,
|
| 20 |
+
"genloss_batch_mul": 4,
|
| 21 |
+
"genloss_depth": 12,
|
| 22 |
+
"genloss_net_arch": "mlp",
|
| 23 |
+
"genloss_num_sampling_steps": "100",
|
| 24 |
+
"genloss_type": "transport",
|
| 25 |
+
"genloss_width": 1536,
|
| 26 |
+
"head_dim": 128,
|
| 27 |
+
"hidden_act": "silu",
|
| 28 |
+
"hidden_size": 5120,
|
| 29 |
+
"image_decoder_arch": "Trans_E",
|
| 30 |
+
"image_encoder_name": null,
|
| 31 |
+
"image_feature_layer": -2,
|
| 32 |
+
"image_loss_weight": 1.0,
|
| 33 |
+
"image_placeholder_id": 151669,
|
| 34 |
+
"image_size": 64,
|
| 35 |
+
"initializer_range": 0.02,
|
| 36 |
+
"intermediate_size": 13824,
|
| 37 |
+
"lm_loss_weight": 0.01,
|
| 38 |
+
"max_position_embeddings": 131072,
|
| 39 |
+
"max_window_layers": 48,
|
| 40 |
+
"mlp_bias": false,
|
| 41 |
+
"model_type": "nextstep",
|
| 42 |
+
"noise_strength": 0.0,
|
| 43 |
+
"num_attention_heads": 40,
|
| 44 |
+
"num_channels": 16,
|
| 45 |
+
"num_hidden_layers": 48,
|
| 46 |
+
"num_key_value_heads": 8,
|
| 47 |
+
"o_attention_bias": false,
|
| 48 |
+
"pad_token_id_added": 151665,
|
| 49 |
+
"patch_size": 2,
|
| 50 |
+
"pretraining_tp": 1,
|
| 51 |
+
"rms_norm_eps": 1e-05,
|
| 52 |
+
"rope_scaling": null,
|
| 53 |
+
"rope_theta": 1000000.0,
|
| 54 |
+
"sliding_window": 131072,
|
| 55 |
+
"tie_word_embeddings": false,
|
| 56 |
+
"torch_dtype": "bfloat16",
|
| 57 |
+
"transformers_version": "4.55.0",
|
| 58 |
+
"use_2d_rope": false,
|
| 59 |
+
"use_cache": true,
|
| 60 |
+
"use_gen_pos_embed": false,
|
| 61 |
+
"use_mlp_before_lm_head": false,
|
| 62 |
+
"use_sliding_window": false,
|
| 63 |
+
"use_token_length_weight": false,
|
| 64 |
+
"vae_name_or_path": "vae/",
|
| 65 |
+
"vocab_size": 152064
|
| 66 |
+
}
|
model.safetensors.index.json
ADDED
|
@@ -0,0 +1,698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 29907628160
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"embed_tokens.weight": "pytorch-model-00004-of-00004.safetensors",
|
| 7 |
+
"image_head.net.cond_embed.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 8 |
+
"image_head.net.cond_embed.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 9 |
+
"image_head.net.final_layer.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 10 |
+
"image_head.net.final_layer.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 11 |
+
"image_head.net.final_layer.linear.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 12 |
+
"image_head.net.final_layer.linear.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 13 |
+
"image_head.net.input_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 14 |
+
"image_head.net.input_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 15 |
+
"image_head.net.res_blocks.0.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 16 |
+
"image_head.net.res_blocks.0.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 17 |
+
"image_head.net.res_blocks.0.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 18 |
+
"image_head.net.res_blocks.0.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 19 |
+
"image_head.net.res_blocks.0.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 20 |
+
"image_head.net.res_blocks.0.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 21 |
+
"image_head.net.res_blocks.0.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 22 |
+
"image_head.net.res_blocks.0.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 23 |
+
"image_head.net.res_blocks.1.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 24 |
+
"image_head.net.res_blocks.1.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 25 |
+
"image_head.net.res_blocks.1.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 26 |
+
"image_head.net.res_blocks.1.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 27 |
+
"image_head.net.res_blocks.1.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 28 |
+
"image_head.net.res_blocks.1.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 29 |
+
"image_head.net.res_blocks.1.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 30 |
+
"image_head.net.res_blocks.1.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 31 |
+
"image_head.net.res_blocks.10.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 32 |
+
"image_head.net.res_blocks.10.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 33 |
+
"image_head.net.res_blocks.10.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 34 |
+
"image_head.net.res_blocks.10.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 35 |
+
"image_head.net.res_blocks.10.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 36 |
+
"image_head.net.res_blocks.10.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 37 |
+
"image_head.net.res_blocks.10.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 38 |
+
"image_head.net.res_blocks.10.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 39 |
+
"image_head.net.res_blocks.11.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 40 |
+
"image_head.net.res_blocks.11.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 41 |
+
"image_head.net.res_blocks.11.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 42 |
+
"image_head.net.res_blocks.11.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 43 |
+
"image_head.net.res_blocks.11.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 44 |
+
"image_head.net.res_blocks.11.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 45 |
+
"image_head.net.res_blocks.11.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 46 |
+
"image_head.net.res_blocks.11.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 47 |
+
"image_head.net.res_blocks.2.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 48 |
+
"image_head.net.res_blocks.2.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 49 |
+
"image_head.net.res_blocks.2.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 50 |
+
"image_head.net.res_blocks.2.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 51 |
+
"image_head.net.res_blocks.2.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 52 |
+
"image_head.net.res_blocks.2.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 53 |
+
"image_head.net.res_blocks.2.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 54 |
+
"image_head.net.res_blocks.2.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 55 |
+
"image_head.net.res_blocks.3.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 56 |
+
"image_head.net.res_blocks.3.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 57 |
+
"image_head.net.res_blocks.3.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 58 |
+
"image_head.net.res_blocks.3.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 59 |
+
"image_head.net.res_blocks.3.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 60 |
+
"image_head.net.res_blocks.3.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 61 |
+
"image_head.net.res_blocks.3.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 62 |
+
"image_head.net.res_blocks.3.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 63 |
+
"image_head.net.res_blocks.4.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 64 |
+
"image_head.net.res_blocks.4.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 65 |
+
"image_head.net.res_blocks.4.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 66 |
+
"image_head.net.res_blocks.4.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 67 |
+
"image_head.net.res_blocks.4.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 68 |
+
"image_head.net.res_blocks.4.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 69 |
+
"image_head.net.res_blocks.4.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 70 |
+
"image_head.net.res_blocks.4.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 71 |
+
"image_head.net.res_blocks.5.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 72 |
+
"image_head.net.res_blocks.5.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 73 |
+
"image_head.net.res_blocks.5.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 74 |
+
"image_head.net.res_blocks.5.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 75 |
+
"image_head.net.res_blocks.5.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 76 |
+
"image_head.net.res_blocks.5.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 77 |
+
"image_head.net.res_blocks.5.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 78 |
+
"image_head.net.res_blocks.5.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 79 |
+
"image_head.net.res_blocks.6.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 80 |
+
"image_head.net.res_blocks.6.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 81 |
+
"image_head.net.res_blocks.6.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 82 |
+
"image_head.net.res_blocks.6.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 83 |
+
"image_head.net.res_blocks.6.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 84 |
+
"image_head.net.res_blocks.6.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 85 |
+
"image_head.net.res_blocks.6.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 86 |
+
"image_head.net.res_blocks.6.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 87 |
+
"image_head.net.res_blocks.7.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 88 |
+
"image_head.net.res_blocks.7.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 89 |
+
"image_head.net.res_blocks.7.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 90 |
+
"image_head.net.res_blocks.7.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 91 |
+
"image_head.net.res_blocks.7.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 92 |
+
"image_head.net.res_blocks.7.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 93 |
+
"image_head.net.res_blocks.7.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 94 |
+
"image_head.net.res_blocks.7.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 95 |
+
"image_head.net.res_blocks.8.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 96 |
+
"image_head.net.res_blocks.8.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 97 |
+
"image_head.net.res_blocks.8.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 98 |
+
"image_head.net.res_blocks.8.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 99 |
+
"image_head.net.res_blocks.8.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 100 |
+
"image_head.net.res_blocks.8.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 101 |
+
"image_head.net.res_blocks.8.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 102 |
+
"image_head.net.res_blocks.8.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 103 |
+
"image_head.net.res_blocks.9.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 104 |
+
"image_head.net.res_blocks.9.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 105 |
+
"image_head.net.res_blocks.9.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 106 |
+
"image_head.net.res_blocks.9.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 107 |
+
"image_head.net.res_blocks.9.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 108 |
+
"image_head.net.res_blocks.9.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 109 |
+
"image_head.net.res_blocks.9.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 110 |
+
"image_head.net.res_blocks.9.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 111 |
+
"image_head.net.time_embed.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 112 |
+
"image_head.net.time_embed.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 113 |
+
"image_head.net.time_embed.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 114 |
+
"image_head.net.time_embed.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 115 |
+
"image_in_projector.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 116 |
+
"image_in_projector.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 117 |
+
"image_out_projector.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 118 |
+
"image_out_projector.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 119 |
+
"layers.0.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 120 |
+
"layers.0.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 121 |
+
"layers.0.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 122 |
+
"layers.0.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 123 |
+
"layers.0.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 124 |
+
"layers.0.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 125 |
+
"layers.0.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 126 |
+
"layers.0.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 127 |
+
"layers.0.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 128 |
+
"layers.0.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 129 |
+
"layers.0.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 130 |
+
"layers.0.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 131 |
+
"layers.1.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 132 |
+
"layers.1.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 133 |
+
"layers.1.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 134 |
+
"layers.1.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 135 |
+
"layers.1.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 136 |
+
"layers.1.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 137 |
+
"layers.1.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 138 |
+
"layers.1.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 139 |
+
"layers.1.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 140 |
+
"layers.1.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 141 |
+
"layers.1.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 142 |
+
"layers.1.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 143 |
+
"layers.10.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 144 |
+
"layers.10.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 145 |
+
"layers.10.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 146 |
+
"layers.10.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 147 |
+
"layers.10.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 148 |
+
"layers.10.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 149 |
+
"layers.10.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 150 |
+
"layers.10.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 151 |
+
"layers.10.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 152 |
+
"layers.10.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 153 |
+
"layers.10.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 154 |
+
"layers.10.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 155 |
+
"layers.11.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 156 |
+
"layers.11.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 157 |
+
"layers.11.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 158 |
+
"layers.11.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 159 |
+
"layers.11.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 160 |
+
"layers.11.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 161 |
+
"layers.11.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 162 |
+
"layers.11.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 163 |
+
"layers.11.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 164 |
+
"layers.11.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 165 |
+
"layers.11.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 166 |
+
"layers.11.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 167 |
+
"layers.12.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 168 |
+
"layers.12.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 169 |
+
"layers.12.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 170 |
+
"layers.12.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 171 |
+
"layers.12.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 172 |
+
"layers.12.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 173 |
+
"layers.12.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 174 |
+
"layers.12.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 175 |
+
"layers.12.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 176 |
+
"layers.12.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 177 |
+
"layers.12.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 178 |
+
"layers.12.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 179 |
+
"layers.13.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 180 |
+
"layers.13.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 181 |
+
"layers.13.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 182 |
+
"layers.13.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 183 |
+
"layers.13.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 184 |
+
"layers.13.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 185 |
+
"layers.13.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 186 |
+
"layers.13.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 187 |
+
"layers.13.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 188 |
+
"layers.13.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 189 |
+
"layers.13.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 190 |
+
"layers.13.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 191 |
+
"layers.14.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 192 |
+
"layers.14.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 193 |
+
"layers.14.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 194 |
+
"layers.14.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 195 |
+
"layers.14.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 196 |
+
"layers.14.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 197 |
+
"layers.14.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 198 |
+
"layers.14.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 199 |
+
"layers.14.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 200 |
+
"layers.14.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 201 |
+
"layers.14.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 202 |
+
"layers.14.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 203 |
+
"layers.15.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 204 |
+
"layers.15.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 205 |
+
"layers.15.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 206 |
+
"layers.15.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 207 |
+
"layers.15.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 208 |
+
"layers.15.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 209 |
+
"layers.15.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 210 |
+
"layers.15.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 211 |
+
"layers.15.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 212 |
+
"layers.15.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 213 |
+
"layers.15.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 214 |
+
"layers.15.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 215 |
+
"layers.16.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 216 |
+
"layers.16.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 217 |
+
"layers.16.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 218 |
+
"layers.16.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 219 |
+
"layers.16.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 220 |
+
"layers.16.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 221 |
+
"layers.16.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 222 |
+
"layers.16.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 223 |
+
"layers.16.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 224 |
+
"layers.16.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 225 |
+
"layers.16.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 226 |
+
"layers.16.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 227 |
+
"layers.17.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 228 |
+
"layers.17.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 229 |
+
"layers.17.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 230 |
+
"layers.17.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 231 |
+
"layers.17.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 232 |
+
"layers.17.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 233 |
+
"layers.17.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 234 |
+
"layers.17.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 235 |
+
"layers.17.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 236 |
+
"layers.17.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 237 |
+
"layers.17.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 238 |
+
"layers.17.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 239 |
+
"layers.18.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 240 |
+
"layers.18.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 241 |
+
"layers.18.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 242 |
+
"layers.18.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 243 |
+
"layers.18.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 244 |
+
"layers.18.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 245 |
+
"layers.18.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 246 |
+
"layers.18.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 247 |
+
"layers.18.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 248 |
+
"layers.18.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 249 |
+
"layers.18.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 250 |
+
"layers.18.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 251 |
+
"layers.19.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 252 |
+
"layers.19.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 253 |
+
"layers.19.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 254 |
+
"layers.19.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 255 |
+
"layers.19.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 256 |
+
"layers.19.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 257 |
+
"layers.19.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 258 |
+
"layers.19.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 259 |
+
"layers.19.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 260 |
+
"layers.19.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 261 |
+
"layers.19.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 262 |
+
"layers.19.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 263 |
+
"layers.2.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 264 |
+
"layers.2.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 265 |
+
"layers.2.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 266 |
+
"layers.2.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 267 |
+
"layers.2.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 268 |
+
"layers.2.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 269 |
+
"layers.2.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 270 |
+
"layers.2.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 271 |
+
"layers.2.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 272 |
+
"layers.2.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 273 |
+
"layers.2.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 274 |
+
"layers.2.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 275 |
+
"layers.20.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 276 |
+
"layers.20.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 277 |
+
"layers.20.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 278 |
+
"layers.20.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 279 |
+
"layers.20.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 280 |
+
"layers.20.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 281 |
+
"layers.20.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 282 |
+
"layers.20.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 283 |
+
"layers.20.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 284 |
+
"layers.20.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 285 |
+
"layers.20.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 286 |
+
"layers.20.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 287 |
+
"layers.21.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 288 |
+
"layers.21.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 289 |
+
"layers.21.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 290 |
+
"layers.21.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 291 |
+
"layers.21.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 292 |
+
"layers.21.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 293 |
+
"layers.21.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 294 |
+
"layers.21.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 295 |
+
"layers.21.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 296 |
+
"layers.21.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 297 |
+
"layers.21.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 298 |
+
"layers.21.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 299 |
+
"layers.22.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 300 |
+
"layers.22.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 301 |
+
"layers.22.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 302 |
+
"layers.22.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 303 |
+
"layers.22.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 304 |
+
"layers.22.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 305 |
+
"layers.22.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 306 |
+
"layers.22.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 307 |
+
"layers.22.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 308 |
+
"layers.22.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 309 |
+
"layers.22.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 310 |
+
"layers.22.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 311 |
+
"layers.23.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 312 |
+
"layers.23.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 313 |
+
"layers.23.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 314 |
+
"layers.23.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 315 |
+
"layers.23.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 316 |
+
"layers.23.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 317 |
+
"layers.23.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 318 |
+
"layers.23.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 319 |
+
"layers.23.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 320 |
+
"layers.23.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 321 |
+
"layers.23.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 322 |
+
"layers.23.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 323 |
+
"layers.24.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 324 |
+
"layers.24.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 325 |
+
"layers.24.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 326 |
+
"layers.24.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 327 |
+
"layers.24.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 328 |
+
"layers.24.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 329 |
+
"layers.24.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 330 |
+
"layers.24.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 331 |
+
"layers.24.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 332 |
+
"layers.24.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 333 |
+
"layers.24.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
|
| 334 |
+
"layers.24.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 335 |
+
"layers.25.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 336 |
+
"layers.25.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 337 |
+
"layers.25.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 338 |
+
"layers.25.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 339 |
+
"layers.25.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 340 |
+
"layers.25.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 341 |
+
"layers.25.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 342 |
+
"layers.25.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
|
| 343 |
+
"layers.25.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 344 |
+
"layers.25.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 345 |
+
"layers.25.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 346 |
+
"layers.25.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 347 |
+
"layers.26.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 348 |
+
"layers.26.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 349 |
+
"layers.26.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 350 |
+
"layers.26.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 351 |
+
"layers.26.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 352 |
+
"layers.26.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 353 |
+
"layers.26.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 354 |
+
"layers.26.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 355 |
+
"layers.26.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 356 |
+
"layers.26.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 357 |
+
"layers.26.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 358 |
+
"layers.26.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 359 |
+
"layers.27.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 360 |
+
"layers.27.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 361 |
+
"layers.27.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 362 |
+
"layers.27.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 363 |
+
"layers.27.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 364 |
+
"layers.27.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 365 |
+
"layers.27.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 366 |
+
"layers.27.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 367 |
+
"layers.27.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 368 |
+
"layers.27.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 369 |
+
"layers.27.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 370 |
+
"layers.27.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 371 |
+
"layers.28.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 372 |
+
"layers.28.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 373 |
+
"layers.28.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 374 |
+
"layers.28.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 375 |
+
"layers.28.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 376 |
+
"layers.28.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 377 |
+
"layers.28.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 378 |
+
"layers.28.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 379 |
+
"layers.28.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 380 |
+
"layers.28.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 381 |
+
"layers.28.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 382 |
+
"layers.28.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 383 |
+
"layers.29.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 384 |
+
"layers.29.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 385 |
+
"layers.29.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 386 |
+
"layers.29.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 387 |
+
"layers.29.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 388 |
+
"layers.29.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 389 |
+
"layers.29.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 390 |
+
"layers.29.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 391 |
+
"layers.29.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 392 |
+
"layers.29.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 393 |
+
"layers.29.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 394 |
+
"layers.29.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 395 |
+
"layers.3.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 396 |
+
"layers.3.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 397 |
+
"layers.3.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 398 |
+
"layers.3.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 399 |
+
"layers.3.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 400 |
+
"layers.3.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 401 |
+
"layers.3.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 402 |
+
"layers.3.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 403 |
+
"layers.3.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 404 |
+
"layers.3.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 405 |
+
"layers.3.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 406 |
+
"layers.3.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 407 |
+
"layers.30.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 408 |
+
"layers.30.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 409 |
+
"layers.30.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 410 |
+
"layers.30.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 411 |
+
"layers.30.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 412 |
+
"layers.30.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 413 |
+
"layers.30.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 414 |
+
"layers.30.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 415 |
+
"layers.30.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 416 |
+
"layers.30.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 417 |
+
"layers.30.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 418 |
+
"layers.30.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 419 |
+
"layers.31.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 420 |
+
"layers.31.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 421 |
+
"layers.31.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 422 |
+
"layers.31.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 423 |
+
"layers.31.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 424 |
+
"layers.31.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 425 |
+
"layers.31.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 426 |
+
"layers.31.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 427 |
+
"layers.31.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 428 |
+
"layers.31.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 429 |
+
"layers.31.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 430 |
+
"layers.31.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 431 |
+
"layers.32.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 432 |
+
"layers.32.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 433 |
+
"layers.32.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 434 |
+
"layers.32.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 435 |
+
"layers.32.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 436 |
+
"layers.32.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 437 |
+
"layers.32.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 438 |
+
"layers.32.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 439 |
+
"layers.32.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 440 |
+
"layers.32.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 441 |
+
"layers.32.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 442 |
+
"layers.32.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 443 |
+
"layers.33.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 444 |
+
"layers.33.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 445 |
+
"layers.33.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 446 |
+
"layers.33.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 447 |
+
"layers.33.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 448 |
+
"layers.33.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 449 |
+
"layers.33.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 450 |
+
"layers.33.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 451 |
+
"layers.33.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 452 |
+
"layers.33.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 453 |
+
"layers.33.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 454 |
+
"layers.33.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 455 |
+
"layers.34.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 456 |
+
"layers.34.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 457 |
+
"layers.34.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 458 |
+
"layers.34.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 459 |
+
"layers.34.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 460 |
+
"layers.34.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 461 |
+
"layers.34.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 462 |
+
"layers.34.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 463 |
+
"layers.34.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 464 |
+
"layers.34.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 465 |
+
"layers.34.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 466 |
+
"layers.34.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 467 |
+
"layers.35.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 468 |
+
"layers.35.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 469 |
+
"layers.35.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 470 |
+
"layers.35.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 471 |
+
"layers.35.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 472 |
+
"layers.35.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 473 |
+
"layers.35.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 474 |
+
"layers.35.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 475 |
+
"layers.35.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 476 |
+
"layers.35.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 477 |
+
"layers.35.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 478 |
+
"layers.35.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 479 |
+
"layers.36.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 480 |
+
"layers.36.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 481 |
+
"layers.36.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 482 |
+
"layers.36.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 483 |
+
"layers.36.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 484 |
+
"layers.36.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 485 |
+
"layers.36.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 486 |
+
"layers.36.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 487 |
+
"layers.36.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 488 |
+
"layers.36.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 489 |
+
"layers.36.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 490 |
+
"layers.36.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 491 |
+
"layers.37.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 492 |
+
"layers.37.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 493 |
+
"layers.37.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 494 |
+
"layers.37.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 495 |
+
"layers.37.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 496 |
+
"layers.37.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 497 |
+
"layers.37.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 498 |
+
"layers.37.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 499 |
+
"layers.37.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 500 |
+
"layers.37.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 501 |
+
"layers.37.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 502 |
+
"layers.37.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 503 |
+
"layers.38.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 504 |
+
"layers.38.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 505 |
+
"layers.38.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 506 |
+
"layers.38.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 507 |
+
"layers.38.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 508 |
+
"layers.38.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 509 |
+
"layers.38.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 510 |
+
"layers.38.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 511 |
+
"layers.38.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 512 |
+
"layers.38.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 513 |
+
"layers.38.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 514 |
+
"layers.38.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 515 |
+
"layers.39.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 516 |
+
"layers.39.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 517 |
+
"layers.39.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 518 |
+
"layers.39.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 519 |
+
"layers.39.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 520 |
+
"layers.39.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 521 |
+
"layers.39.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 522 |
+
"layers.39.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 523 |
+
"layers.39.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 524 |
+
"layers.39.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 525 |
+
"layers.39.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 526 |
+
"layers.39.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 527 |
+
"layers.4.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 528 |
+
"layers.4.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 529 |
+
"layers.4.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 530 |
+
"layers.4.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 531 |
+
"layers.4.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 532 |
+
"layers.4.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 533 |
+
"layers.4.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 534 |
+
"layers.4.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 535 |
+
"layers.4.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 536 |
+
"layers.4.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 537 |
+
"layers.4.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 538 |
+
"layers.4.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 539 |
+
"layers.40.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 540 |
+
"layers.40.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 541 |
+
"layers.40.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 542 |
+
"layers.40.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 543 |
+
"layers.40.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 544 |
+
"layers.40.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 545 |
+
"layers.40.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 546 |
+
"layers.40.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 547 |
+
"layers.40.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 548 |
+
"layers.40.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 549 |
+
"layers.40.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
| 550 |
+
"layers.40.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 551 |
+
"layers.41.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 552 |
+
"layers.41.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 553 |
+
"layers.41.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 554 |
+
"layers.41.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 555 |
+
"layers.41.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 556 |
+
"layers.41.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 557 |
+
"layers.41.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 558 |
+
"layers.41.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
| 559 |
+
"layers.41.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 560 |
+
"layers.41.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 561 |
+
"layers.41.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 562 |
+
"layers.41.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 563 |
+
"layers.42.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 564 |
+
"layers.42.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 565 |
+
"layers.42.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 566 |
+
"layers.42.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 567 |
+
"layers.42.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 568 |
+
"layers.42.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 569 |
+
"layers.42.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 570 |
+
"layers.42.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 571 |
+
"layers.42.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 572 |
+
"layers.42.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 573 |
+
"layers.42.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 574 |
+
"layers.42.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 575 |
+
"layers.43.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 576 |
+
"layers.43.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 577 |
+
"layers.43.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 578 |
+
"layers.43.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 579 |
+
"layers.43.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 580 |
+
"layers.43.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 581 |
+
"layers.43.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 582 |
+
"layers.43.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 583 |
+
"layers.43.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 584 |
+
"layers.43.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 585 |
+
"layers.43.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 586 |
+
"layers.43.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 587 |
+
"layers.44.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 588 |
+
"layers.44.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 589 |
+
"layers.44.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 590 |
+
"layers.44.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 591 |
+
"layers.44.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 592 |
+
"layers.44.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 593 |
+
"layers.44.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 594 |
+
"layers.44.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 595 |
+
"layers.44.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 596 |
+
"layers.44.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 597 |
+
"layers.44.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 598 |
+
"layers.44.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 599 |
+
"layers.45.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 600 |
+
"layers.45.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 601 |
+
"layers.45.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 602 |
+
"layers.45.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 603 |
+
"layers.45.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 604 |
+
"layers.45.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 605 |
+
"layers.45.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 606 |
+
"layers.45.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 607 |
+
"layers.45.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 608 |
+
"layers.45.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 609 |
+
"layers.45.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 610 |
+
"layers.45.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 611 |
+
"layers.46.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 612 |
+
"layers.46.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 613 |
+
"layers.46.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 614 |
+
"layers.46.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 615 |
+
"layers.46.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 616 |
+
"layers.46.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 617 |
+
"layers.46.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 618 |
+
"layers.46.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 619 |
+
"layers.46.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 620 |
+
"layers.46.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 621 |
+
"layers.46.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 622 |
+
"layers.46.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 623 |
+
"layers.47.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 624 |
+
"layers.47.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 625 |
+
"layers.47.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 626 |
+
"layers.47.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 627 |
+
"layers.47.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 628 |
+
"layers.47.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 629 |
+
"layers.47.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 630 |
+
"layers.47.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 631 |
+
"layers.47.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 632 |
+
"layers.47.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 633 |
+
"layers.47.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 634 |
+
"layers.47.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 635 |
+
"layers.5.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 636 |
+
"layers.5.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 637 |
+
"layers.5.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 638 |
+
"layers.5.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 639 |
+
"layers.5.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 640 |
+
"layers.5.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 641 |
+
"layers.5.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 642 |
+
"layers.5.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 643 |
+
"layers.5.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 644 |
+
"layers.5.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 645 |
+
"layers.5.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 646 |
+
"layers.5.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 647 |
+
"layers.6.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 648 |
+
"layers.6.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 649 |
+
"layers.6.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 650 |
+
"layers.6.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 651 |
+
"layers.6.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 652 |
+
"layers.6.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 653 |
+
"layers.6.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 654 |
+
"layers.6.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 655 |
+
"layers.6.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 656 |
+
"layers.6.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 657 |
+
"layers.6.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 658 |
+
"layers.6.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 659 |
+
"layers.7.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 660 |
+
"layers.7.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 661 |
+
"layers.7.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 662 |
+
"layers.7.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 663 |
+
"layers.7.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 664 |
+
"layers.7.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 665 |
+
"layers.7.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 666 |
+
"layers.7.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 667 |
+
"layers.7.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 668 |
+
"layers.7.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 669 |
+
"layers.7.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 670 |
+
"layers.7.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 671 |
+
"layers.8.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 672 |
+
"layers.8.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 673 |
+
"layers.8.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 674 |
+
"layers.8.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 675 |
+
"layers.8.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 676 |
+
"layers.8.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 677 |
+
"layers.8.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 678 |
+
"layers.8.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 679 |
+
"layers.8.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 680 |
+
"layers.8.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 681 |
+
"layers.8.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 682 |
+
"layers.8.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 683 |
+
"layers.9.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 684 |
+
"layers.9.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 685 |
+
"layers.9.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 686 |
+
"layers.9.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 687 |
+
"layers.9.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 688 |
+
"layers.9.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 689 |
+
"layers.9.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 690 |
+
"layers.9.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 691 |
+
"layers.9.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 692 |
+
"layers.9.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 693 |
+
"layers.9.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
| 694 |
+
"layers.9.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 695 |
+
"lm_head.weight": "pytorch-model-00003-of-00004.safetensors",
|
| 696 |
+
"norm.weight": "pytorch-model-00003-of-00004.safetensors"
|
| 697 |
+
}
|
| 698 |
+
}
|
models/__pycache__/config.cpython-310.pyc
ADDED
|
Binary file (1.47 kB). View file
|
|
|
models/__pycache__/gen_pipeline.cpython-310.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
models/__pycache__/heads.cpython-310.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
models/__pycache__/llama_model.cpython-310.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
models/__pycache__/nextstep_model.cpython-310.pyc
ADDED
|
Binary file (15.8 kB). View file
|
|
|
models/config.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.models.llama.configuration_llama import LlamaConfig
|
| 2 |
+
|
| 3 |
+
class NextStepConfig(LlamaConfig):
|
| 4 |
+
|
| 5 |
+
model_type = "nextstep"
|
| 6 |
+
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
vae_name_or_path: str | None = None,
|
| 10 |
+
latent_size: int = 32,
|
| 11 |
+
latent_patch_size: int = 2,
|
| 12 |
+
latent_channels: int = 16,
|
| 13 |
+
boi: int | None = None,
|
| 14 |
+
eoi: int | None = None,
|
| 15 |
+
image_placeholder_id: int | None = None,
|
| 16 |
+
pad_token_id_added: int | None = None,
|
| 17 |
+
lm_loss_weight: float = 0.01,
|
| 18 |
+
im_loss_weight: float = 1.0,
|
| 19 |
+
fm_head_dim: int = 1536,
|
| 20 |
+
fm_head_layers: int = 12,
|
| 21 |
+
fm_head_batch_mul: int = 4,
|
| 22 |
+
o_attention_bias: bool | None = None,
|
| 23 |
+
**kwargs,
|
| 24 |
+
):
|
| 25 |
+
super().__init__(**kwargs)
|
| 26 |
+
|
| 27 |
+
self.vae_name_or_path = vae_name_or_path
|
| 28 |
+
|
| 29 |
+
self.latent_size = latent_size
|
| 30 |
+
self.latent_patch_size = latent_patch_size
|
| 31 |
+
self.latent_channels = latent_channels
|
| 32 |
+
|
| 33 |
+
self.boi = boi
|
| 34 |
+
self.eoi = eoi
|
| 35 |
+
self.image_placeholder_id = image_placeholder_id
|
| 36 |
+
self.pad_token_id_added = pad_token_id_added
|
| 37 |
+
|
| 38 |
+
self.lm_loss_weight = lm_loss_weight
|
| 39 |
+
self.im_loss_weight = im_loss_weight
|
| 40 |
+
|
| 41 |
+
self.fm_head_dim = fm_head_dim
|
| 42 |
+
self.fm_head_layers = fm_head_layers
|
| 43 |
+
self.fm_head_batch_mul = fm_head_batch_mul
|
| 44 |
+
|
| 45 |
+
self.o_attention_bias = self.attention_bias if o_attention_bias is None else o_attention_bias
|
models/gen_pipeline.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import copy
|
| 3 |
+
from typing import Literal
|
| 4 |
+
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from tqdm.auto import tqdm
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torchvision.transforms as transforms
|
| 11 |
+
|
| 12 |
+
from transformers import AutoTokenizer
|
| 13 |
+
from transformers.cache_utils import Cache, StaticCache
|
| 14 |
+
|
| 15 |
+
from models.nextstep_model import NextStep
|
| 16 |
+
from vae.nextstep_ae import AutoencoderKL
|
| 17 |
+
from utils.image_utils import to_pil
|
| 18 |
+
from utils.model_utils import layer_norm
|
| 19 |
+
from utils.compile_utils import compile_manager
|
| 20 |
+
from utils.misc import set_seed
|
| 21 |
+
|
| 22 |
+
DEFAULT_IMAGE_AREA_TOKEN = "<|image_area|>"
|
| 23 |
+
|
| 24 |
+
def hw2str(h: int, w: int) -> str:
|
| 25 |
+
return f"{h}*{w}"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class NextStepPipeline:
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
model_name_or_path: str | None = None,
|
| 32 |
+
vae_name_or_path: str | None = None,
|
| 33 |
+
tokenizer: AutoTokenizer | None = None,
|
| 34 |
+
model: nn.Module | None = None,
|
| 35 |
+
vae: AutoencoderKL | None = None,
|
| 36 |
+
):
|
| 37 |
+
if model is not None:
|
| 38 |
+
self.tokenizer = copy.deepcopy(tokenizer)
|
| 39 |
+
self.tokenizer.padding_side = "left"
|
| 40 |
+
self.model = model
|
| 41 |
+
elif model_name_or_path is not None:
|
| 42 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 43 |
+
model_name_or_path,
|
| 44 |
+
local_files_only=True,
|
| 45 |
+
padding_side="left",
|
| 46 |
+
use_fast=True,
|
| 47 |
+
)
|
| 48 |
+
self.model: NextStep = NextStep.from_pretrained(model_name_or_path, local_files_only=True)
|
| 49 |
+
else:
|
| 50 |
+
raise ValueError("model or model_name_or_path is required")
|
| 51 |
+
|
| 52 |
+
self.tokenizer.add_eos_token = False
|
| 53 |
+
if vae_name_or_path is None:
|
| 54 |
+
vae_name_or_path = getattr(self.model.config, "vae_name_or_path", None)
|
| 55 |
+
if vae is not None:
|
| 56 |
+
self.vae = vae
|
| 57 |
+
elif vae_name_or_path is not None:
|
| 58 |
+
self.vae = AutoencoderKL.from_pretrained(vae_name_or_path)
|
| 59 |
+
else:
|
| 60 |
+
raise ValueError("vae or vae_name_or_path is required")
|
| 61 |
+
|
| 62 |
+
self.model.eval()
|
| 63 |
+
self.vae.eval()
|
| 64 |
+
|
| 65 |
+
vae_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 66 |
+
self.down_factor = vae_factor * self.model.config.latent_patch_size
|
| 67 |
+
self.shift_factor = getattr(self.vae.config, "shift_factor", 0.0)
|
| 68 |
+
self.scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0)
|
| 69 |
+
|
| 70 |
+
self.boi = self.model.config.boi
|
| 71 |
+
self.eoi = self.model.config.eoi
|
| 72 |
+
|
| 73 |
+
self.image_placeholder_id = self.model.config.image_placeholder_id
|
| 74 |
+
self.pil2tensor = transforms.Compose(
|
| 75 |
+
[
|
| 76 |
+
transforms.ToTensor(),
|
| 77 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
|
| 78 |
+
]
|
| 79 |
+
)
|
| 80 |
+
self.__device = self.model.device
|
| 81 |
+
self.__dtype = self.model.dtype
|
| 82 |
+
self.to(self.device, self.dtype)
|
| 83 |
+
|
| 84 |
+
@property
|
| 85 |
+
def device(self):
|
| 86 |
+
return self.__device
|
| 87 |
+
|
| 88 |
+
@property
|
| 89 |
+
def device_type(self):
|
| 90 |
+
if isinstance(self.__device, str):
|
| 91 |
+
return self.__device
|
| 92 |
+
return self.__device.type
|
| 93 |
+
|
| 94 |
+
@property
|
| 95 |
+
def dtype(self):
|
| 96 |
+
return self.__dtype
|
| 97 |
+
|
| 98 |
+
def to(self, device: str | None = None, dtype: torch.dtype | None = None):
|
| 99 |
+
if device is not None:
|
| 100 |
+
self.__device = device
|
| 101 |
+
if dtype is not None:
|
| 102 |
+
self.__dtype = dtype
|
| 103 |
+
self.model.to(self.__device, dtype=self.__dtype)
|
| 104 |
+
self.vae.to(self.__device, dtype=self.__dtype)
|
| 105 |
+
return self
|
| 106 |
+
|
| 107 |
+
def _image_str(self, hw: tuple[int, int] = (256, 256)):
|
| 108 |
+
latent_hw = (hw[0] // self.down_factor, hw[1] // self.down_factor)
|
| 109 |
+
image_ids = [self.boi] + [self.image_placeholder_id] * (latent_hw[0] * latent_hw[1]) + [self.eoi]
|
| 110 |
+
image_str = DEFAULT_IMAGE_AREA_TOKEN + hw2str(*latent_hw) + self.tokenizer.decode(image_ids)
|
| 111 |
+
return image_str
|
| 112 |
+
|
| 113 |
+
def _check_input(
|
| 114 |
+
self, captions: str | list[str], images: Image.Image | list[Image.Image] | None
|
| 115 |
+
) -> tuple[list[str], list[Image.Image] | None]:
|
| 116 |
+
if not isinstance(captions, list):
|
| 117 |
+
captions = [captions]
|
| 118 |
+
if images is not None:
|
| 119 |
+
if not isinstance(images, list):
|
| 120 |
+
images = [images]
|
| 121 |
+
# Validate image count matches <image> tokens in captions
|
| 122 |
+
image_token_count = 0
|
| 123 |
+
for caption in captions:
|
| 124 |
+
num_image_token = len(re.findall(r"<image>", caption))
|
| 125 |
+
assert num_image_token == 1, f"Caption `{caption}` has {num_image_token} image tokens, but only 1 is allowed."
|
| 126 |
+
image_token_count += num_image_token
|
| 127 |
+
if image_token_count != len(images):
|
| 128 |
+
raise ValueError(
|
| 129 |
+
f"Number of images ({len(images)}) does not match number of image tokens ({image_token_count}).\n"
|
| 130 |
+
f"Captions: {captions}"
|
| 131 |
+
)
|
| 132 |
+
hws = [(image.size[1], image.size[0]) for image in images]
|
| 133 |
+
# Replace <image> tokens sequentially with corresponding image_str based on hw
|
| 134 |
+
processed_captions = []
|
| 135 |
+
image_idx = 0
|
| 136 |
+
for caption in captions:
|
| 137 |
+
# Process each caption
|
| 138 |
+
processed_caption = caption
|
| 139 |
+
num_image_tokens = processed_caption.count("<image>")
|
| 140 |
+
# Replace each <image> token in order
|
| 141 |
+
for _ in range(num_image_tokens):
|
| 142 |
+
processed_caption = processed_caption.replace("<image>", self._image_str(hws[image_idx]), 1)
|
| 143 |
+
image_idx += 1
|
| 144 |
+
processed_captions.append(processed_caption)
|
| 145 |
+
captions = processed_captions
|
| 146 |
+
return captions, images
|
| 147 |
+
|
| 148 |
+
def _build_captions(
|
| 149 |
+
self,
|
| 150 |
+
captions: str | list[str],
|
| 151 |
+
images: list[Image.Image] | None = None,
|
| 152 |
+
num_images_per_caption: int = 1,
|
| 153 |
+
positive_prompt: str | None = None,
|
| 154 |
+
negative_prompt: str | None = None,
|
| 155 |
+
cfg: float = 1.0,
|
| 156 |
+
cfg_img: float = 1.0,
|
| 157 |
+
):
|
| 158 |
+
# 1. repeat captions and images
|
| 159 |
+
if not isinstance(captions, list):
|
| 160 |
+
captions = [captions]
|
| 161 |
+
|
| 162 |
+
captions = [caption for caption in captions for _ in range(num_images_per_caption)]
|
| 163 |
+
if images is not None:
|
| 164 |
+
images = [image for image in images for _ in range(num_images_per_caption)]
|
| 165 |
+
|
| 166 |
+
# 2. add positive prompt
|
| 167 |
+
if positive_prompt is not None and positive_prompt != "":
|
| 168 |
+
captions = [f"{caption} {positive_prompt}" for caption in captions]
|
| 169 |
+
|
| 170 |
+
# 3. add negative prompt
|
| 171 |
+
if negative_prompt is None:
|
| 172 |
+
negative_prompt = ""
|
| 173 |
+
|
| 174 |
+
num_samples = len(captions)
|
| 175 |
+
if cfg != 1.0 and cfg_img != 1.0: # use both image and text CFG
|
| 176 |
+
w, h = images[0].size
|
| 177 |
+
captions = (
|
| 178 |
+
captions + [self._image_str((h, w)) + negative_prompt] * num_samples
|
| 179 |
+
)
|
| 180 |
+
images = images + images
|
| 181 |
+
captions = captions + [negative_prompt] * num_samples
|
| 182 |
+
elif cfg != 1.0 and cfg_img == 1.0: # use text CFG
|
| 183 |
+
captions = captions + [negative_prompt] * num_samples
|
| 184 |
+
elif cfg == 1.0 and cfg_img == 1.0:
|
| 185 |
+
pass
|
| 186 |
+
|
| 187 |
+
return captions, images
|
| 188 |
+
|
| 189 |
+
def _add_prefix_ids(self, hw: tuple[int, int], input_ids: torch.Tensor, attention_mask: torch.Tensor):
|
| 190 |
+
prefix_str = DEFAULT_IMAGE_AREA_TOKEN + hw2str(hw[0] // self.down_factor, hw[1] // self.down_factor)
|
| 191 |
+
prefix_output = self.tokenizer(
|
| 192 |
+
prefix_str,
|
| 193 |
+
truncation=False,
|
| 194 |
+
add_special_tokens=True,
|
| 195 |
+
return_tensors="pt"
|
| 196 |
+
)
|
| 197 |
+
prefix_input_ids = prefix_output.input_ids.to(input_ids.device, dtype=input_ids.dtype)
|
| 198 |
+
prefix_attention_mask = prefix_output.attention_mask.to(attention_mask.device, dtype=attention_mask.dtype)
|
| 199 |
+
# remove bos token
|
| 200 |
+
if self.tokenizer.bos_token is not None:
|
| 201 |
+
prefix_input_ids = prefix_input_ids[:, 1:]
|
| 202 |
+
prefix_attention_mask = prefix_attention_mask[:, 1:]
|
| 203 |
+
# add boi token
|
| 204 |
+
prefix_input_ids = torch.cat(
|
| 205 |
+
[
|
| 206 |
+
prefix_input_ids,
|
| 207 |
+
prefix_input_ids.new_tensor([self.model.config.boi]).unsqueeze(0),
|
| 208 |
+
],
|
| 209 |
+
dim=1,
|
| 210 |
+
)
|
| 211 |
+
prefix_attention_mask = torch.cat(
|
| 212 |
+
[
|
| 213 |
+
prefix_attention_mask,
|
| 214 |
+
prefix_attention_mask.new_ones((prefix_attention_mask.shape[0], 1)),
|
| 215 |
+
],
|
| 216 |
+
dim=1,
|
| 217 |
+
)
|
| 218 |
+
bsz = input_ids.shape[0]
|
| 219 |
+
input_ids = torch.cat([input_ids, prefix_input_ids.expand(bsz, -1)], dim=1)
|
| 220 |
+
attention_mask = torch.cat([attention_mask, prefix_attention_mask.expand(bsz, -1)], dim=1)
|
| 221 |
+
|
| 222 |
+
return input_ids, attention_mask
|
| 223 |
+
|
| 224 |
+
@torch.no_grad()
|
| 225 |
+
def decoding(
|
| 226 |
+
self,
|
| 227 |
+
c: torch.Tensor,
|
| 228 |
+
attention_mask: torch.Tensor,
|
| 229 |
+
past_key_values: Cache,
|
| 230 |
+
max_new_len: int,
|
| 231 |
+
num_images_per_caption: int,
|
| 232 |
+
use_norm: bool = False,
|
| 233 |
+
cfg: float = 1.0,
|
| 234 |
+
cfg_img: float = 1.0,
|
| 235 |
+
cfg_schedule: Literal["linear", "constant"] = "constant",
|
| 236 |
+
timesteps_shift: float = 1.0,
|
| 237 |
+
num_sampling_steps: int = 20,
|
| 238 |
+
progress: bool = True,
|
| 239 |
+
hw: tuple[int, int] = (256, 256),
|
| 240 |
+
step: int = 0,
|
| 241 |
+
):
|
| 242 |
+
indices = list(range(max_new_len))
|
| 243 |
+
indices = tqdm(indices, unit="tokens") if progress else indices
|
| 244 |
+
tokens = None
|
| 245 |
+
for step in indices:
|
| 246 |
+
# cfg schedule follow Muse
|
| 247 |
+
if cfg_schedule == "linear":
|
| 248 |
+
tokens_len = 0 if tokens is None else tokens.shape[1]
|
| 249 |
+
cfg_iter = max(cfg / 2, 1 + (cfg - 1) * tokens_len / max_new_len)
|
| 250 |
+
cfg_img_iter = max(cfg_img / 2, 1 + (cfg_img - 1) * tokens_len / max_new_len)
|
| 251 |
+
elif cfg_schedule == "constant":
|
| 252 |
+
cfg_iter = cfg
|
| 253 |
+
cfg_img_iter = cfg_img
|
| 254 |
+
else:
|
| 255 |
+
raise NotImplementedError
|
| 256 |
+
|
| 257 |
+
c = self.model.image_out_projector(c)
|
| 258 |
+
token_sampled = self.model.image_head.sample(
|
| 259 |
+
c=c.squeeze(1),
|
| 260 |
+
cfg=cfg_iter,
|
| 261 |
+
cfg_img=cfg_img_iter,
|
| 262 |
+
timesteps_shift=timesteps_shift,
|
| 263 |
+
num_sampling_steps=num_sampling_steps,
|
| 264 |
+
noise_repeat=num_images_per_caption,
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
if use_norm:
|
| 268 |
+
token_sampled = layer_norm(token_sampled, normalized_shape=token_sampled.size()[1:])
|
| 269 |
+
if tokens is not None:
|
| 270 |
+
tokens = torch.cat([tokens, token_sampled.unsqueeze(1)], dim=1)
|
| 271 |
+
else:
|
| 272 |
+
tokens = token_sampled.unsqueeze(1)
|
| 273 |
+
|
| 274 |
+
cur_inputs_embeds = self.model.image_in_projector(tokens[:, -1:])
|
| 275 |
+
if cfg != 1.0 and cfg_img == 1.0:
|
| 276 |
+
cur_inputs_embeds = torch.cat([cur_inputs_embeds, cur_inputs_embeds], dim=0)
|
| 277 |
+
elif cfg != 1.0 and cfg_img != 1.0:
|
| 278 |
+
cur_inputs_embeds = torch.cat([cur_inputs_embeds, cur_inputs_embeds, cur_inputs_embeds], dim=0)
|
| 279 |
+
|
| 280 |
+
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
|
| 281 |
+
outputs = self.model.forward_model(
|
| 282 |
+
inputs_embeds=cur_inputs_embeds,
|
| 283 |
+
attention_mask=attention_mask,
|
| 284 |
+
past_key_values=past_key_values,
|
| 285 |
+
use_cache=True,
|
| 286 |
+
)
|
| 287 |
+
past_key_values = outputs.past_key_values
|
| 288 |
+
c = outputs.last_hidden_state[:, -1:]
|
| 289 |
+
if self.model.config.use_gen_pos_embed:
|
| 290 |
+
c = c + self.model.gen_pos_embed_with_ar(hw[0], hw[1])[:, step + 1 : step + 2, :]
|
| 291 |
+
|
| 292 |
+
return tokens
|
| 293 |
+
|
| 294 |
+
@torch.no_grad()
|
| 295 |
+
def generate_image(
|
| 296 |
+
self,
|
| 297 |
+
captions: str | list[str],
|
| 298 |
+
images: list[Image.Image] | None = None,
|
| 299 |
+
num_images_per_caption: int = 1,
|
| 300 |
+
positive_prompt: str | None = None,
|
| 301 |
+
negative_prompt: str | None = None,
|
| 302 |
+
hw: tuple[int, int] = (256, 256),
|
| 303 |
+
use_norm: bool = False,
|
| 304 |
+
cfg: float = 1.0,
|
| 305 |
+
cfg_img: float = 1.0,
|
| 306 |
+
cfg_schedule: Literal["linear", "constant"] = "constant",
|
| 307 |
+
num_sampling_steps: int = 20,
|
| 308 |
+
timesteps_shift: float = 1.0,
|
| 309 |
+
seed: int = 42,
|
| 310 |
+
progress: bool = True,
|
| 311 |
+
) -> list[Image.Image]:
|
| 312 |
+
# 0. set seed
|
| 313 |
+
if seed is not None:
|
| 314 |
+
set_seed(seed)
|
| 315 |
+
|
| 316 |
+
# 1. check input
|
| 317 |
+
captions, images = self._check_input(captions, images)
|
| 318 |
+
|
| 319 |
+
# 2. build captions
|
| 320 |
+
captions, images = self._build_captions(
|
| 321 |
+
captions, images, num_images_per_caption, positive_prompt, negative_prompt, cfg, cfg_img
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
# 3. encode images
|
| 325 |
+
# `images` must be processed by `process_images` before calling this function
|
| 326 |
+
latents = None
|
| 327 |
+
if images is not None:
|
| 328 |
+
pixel_values = [self.pil2tensor(image) for image in images]
|
| 329 |
+
pixel_values = torch.stack(pixel_values).to(self.device)
|
| 330 |
+
with compile_manager.compile_disabled():
|
| 331 |
+
posterior = self.vae.encode(pixel_values.to(self.vae.dtype)).latent_dist
|
| 332 |
+
latents = (posterior.sample() - self.shift_factor) * self.scaling_factor
|
| 333 |
+
captions = [self.tokenizer.bos_token + caption if self.tokenizer.bos_token is not None else caption for caption in captions]
|
| 334 |
+
|
| 335 |
+
# 4. tokenize caption & add prefix ids
|
| 336 |
+
output = self.tokenizer(
|
| 337 |
+
captions,
|
| 338 |
+
padding="longest",
|
| 339 |
+
truncation=False,
|
| 340 |
+
add_special_tokens=True,
|
| 341 |
+
return_tensors="pt",
|
| 342 |
+
padding_side="left"
|
| 343 |
+
)
|
| 344 |
+
input_ids = output.input_ids.to(self.device)
|
| 345 |
+
attention_mask = output.attention_mask.to(self.device)
|
| 346 |
+
input_ids, attention_mask = self._add_prefix_ids(hw, input_ids, attention_mask)
|
| 347 |
+
|
| 348 |
+
# 5. LLM prefill
|
| 349 |
+
max_new_len = (hw[0] // self.down_factor) * (hw[1] // self.down_factor)
|
| 350 |
+
max_cache_len = input_ids.shape[1] + max_new_len
|
| 351 |
+
past_key_values = StaticCache(
|
| 352 |
+
config=self.model.config,
|
| 353 |
+
max_batch_size=input_ids.shape[0],
|
| 354 |
+
max_cache_len=max_cache_len,
|
| 355 |
+
device=self.device,
|
| 356 |
+
dtype=self.dtype,
|
| 357 |
+
)
|
| 358 |
+
inputs_embeds = self.model.prepare_inputs_embeds(input_ids, latents)
|
| 359 |
+
with compile_manager.compile_disabled():
|
| 360 |
+
outputs = self.model.forward_model(
|
| 361 |
+
inputs_embeds=inputs_embeds,
|
| 362 |
+
attention_mask=attention_mask,
|
| 363 |
+
past_key_values=past_key_values,
|
| 364 |
+
use_cache=True,
|
| 365 |
+
)
|
| 366 |
+
past_key_values = outputs.past_key_values
|
| 367 |
+
c = outputs.last_hidden_state[:, -1:]
|
| 368 |
+
if self.model.config.use_gen_pos_embed:
|
| 369 |
+
c = c + self.model.gen_pos_embed_with_ar(hw[0], hw[1])[:, 0:1, :]
|
| 370 |
+
|
| 371 |
+
# 6. decoding
|
| 372 |
+
tokens = self.decoding(
|
| 373 |
+
c=c,
|
| 374 |
+
attention_mask=attention_mask,
|
| 375 |
+
past_key_values=past_key_values,
|
| 376 |
+
max_new_len=max_new_len,
|
| 377 |
+
num_images_per_caption=num_images_per_caption,
|
| 378 |
+
use_norm=use_norm,
|
| 379 |
+
cfg=cfg,
|
| 380 |
+
cfg_img=cfg_img,
|
| 381 |
+
cfg_schedule=cfg_schedule,
|
| 382 |
+
timesteps_shift=timesteps_shift,
|
| 383 |
+
num_sampling_steps=num_sampling_steps,
|
| 384 |
+
progress=progress,
|
| 385 |
+
hw=hw,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
# 7. unpatchify
|
| 389 |
+
latents = self.model.unpatchify(tokens)
|
| 390 |
+
latents = (latents / self.scaling_factor) + self.shift_factor
|
| 391 |
+
|
| 392 |
+
# 8. decode latents
|
| 393 |
+
with compile_manager.compile_disabled():
|
| 394 |
+
sampled_images = self.vae.decode(latents.to(self.vae.dtype)).sample
|
| 395 |
+
sampled_images = sampled_images.detach().cpu().to(torch.float32)
|
| 396 |
+
pil_images = [to_pil(img) for img in sampled_images]
|
| 397 |
+
|
| 398 |
+
return pil_images
|
models/heads.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.utils.checkpoint import checkpoint
|
| 6 |
+
|
| 7 |
+
from transformers.activations import ACT2FN
|
| 8 |
+
|
| 9 |
+
from models.config import LlamaConfig
|
| 10 |
+
from utils.misc import LargeInt
|
| 11 |
+
from utils.model_utils import expand_t, randn_tensor
|
| 12 |
+
from utils.compile_utils import smart_compile
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LlamaMLP(nn.Module):
|
| 16 |
+
def __init__(self, config: LlamaConfig):
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.config = config
|
| 19 |
+
self.hidden_size = config.hidden_size
|
| 20 |
+
self.intermediate_size = config.intermediate_size
|
| 21 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
| 22 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
| 23 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
| 24 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 25 |
+
|
| 26 |
+
def forward(self, x):
|
| 27 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 28 |
+
return down_proj
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def modulate(x, shift, scale=None):
|
| 34 |
+
if shift is None:
|
| 35 |
+
return x * (1 + scale)
|
| 36 |
+
return x * (1 + scale) + shift
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ResBlock(nn.Module):
|
| 40 |
+
def __init__(self, channels, mlp_ratio=1.0):
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.channels = channels
|
| 43 |
+
self.intermediate_size = int(channels * mlp_ratio)
|
| 44 |
+
|
| 45 |
+
self.in_ln = nn.LayerNorm(self.channels, eps=1e-6)
|
| 46 |
+
self.mlp = nn.Sequential(
|
| 47 |
+
nn.Linear(self.channels, self.intermediate_size),
|
| 48 |
+
nn.SiLU(),
|
| 49 |
+
nn.Linear(self.intermediate_size, self.channels),
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True))
|
| 53 |
+
|
| 54 |
+
def forward(self, x, y):
|
| 55 |
+
shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
|
| 56 |
+
h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
|
| 57 |
+
h = self.mlp(h)
|
| 58 |
+
return x + gate_mlp * h
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class FinalLayer(nn.Module):
|
| 62 |
+
def __init__(self, model_channels, out_channels):
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
|
| 65 |
+
self.linear = nn.Linear(model_channels, out_channels, bias=True)
|
| 66 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True))
|
| 67 |
+
|
| 68 |
+
def forward(self, x, c):
|
| 69 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
| 70 |
+
x = modulate(self.norm_final(x), shift, scale)
|
| 71 |
+
x = self.linear(x)
|
| 72 |
+
return x
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class TimestepEmbedder(nn.Module):
|
| 76 |
+
"""
|
| 77 |
+
Embeds scalar timesteps into vector representations.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.mlp = nn.Sequential(
|
| 83 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 84 |
+
nn.SiLU(),
|
| 85 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 86 |
+
)
|
| 87 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 88 |
+
|
| 89 |
+
@staticmethod
|
| 90 |
+
def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0):
|
| 91 |
+
"""
|
| 92 |
+
Create sinusoidal timestep embeddings.
|
| 93 |
+
:param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
| 94 |
+
:param dim: the dimension of the output.
|
| 95 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 96 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 97 |
+
"""
|
| 98 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
| 99 |
+
half = dim // 2
|
| 100 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
| 101 |
+
device=t.device
|
| 102 |
+
)
|
| 103 |
+
args = t[:, None].float() * freqs[None]
|
| 104 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 105 |
+
if dim % 2:
|
| 106 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 107 |
+
return embedding
|
| 108 |
+
|
| 109 |
+
def forward(self, t):
|
| 110 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 111 |
+
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
|
| 112 |
+
return t_emb
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class SimpleMLPAdaLN(nn.Module):
|
| 116 |
+
def __init__(self, input_dim, cond_dim, dim=1536, layers=12, mlp_ratio=1.0):
|
| 117 |
+
super().__init__()
|
| 118 |
+
self.input_dim = input_dim
|
| 119 |
+
self.cond_dim = cond_dim
|
| 120 |
+
self.dim = dim
|
| 121 |
+
self.layers = layers
|
| 122 |
+
self.mlp_ratio = mlp_ratio
|
| 123 |
+
|
| 124 |
+
self.time_embed = TimestepEmbedder(dim)
|
| 125 |
+
self.cond_embed = nn.Linear(cond_dim, dim)
|
| 126 |
+
self.input_proj = nn.Linear(input_dim, dim)
|
| 127 |
+
|
| 128 |
+
res_blocks = []
|
| 129 |
+
for _ in range(layers):
|
| 130 |
+
res_blocks.append(ResBlock(dim, mlp_ratio))
|
| 131 |
+
self.res_blocks = nn.ModuleList(res_blocks)
|
| 132 |
+
|
| 133 |
+
self.final_layer = FinalLayer(dim, input_dim)
|
| 134 |
+
|
| 135 |
+
self.grad_checkpointing = False
|
| 136 |
+
|
| 137 |
+
self.initialize_weights()
|
| 138 |
+
|
| 139 |
+
def initialize_weights(self):
|
| 140 |
+
def _basic_init(module):
|
| 141 |
+
if isinstance(module, nn.Linear):
|
| 142 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 143 |
+
if module.bias is not None:
|
| 144 |
+
nn.init.constant_(module.bias, 0)
|
| 145 |
+
|
| 146 |
+
self.apply(_basic_init)
|
| 147 |
+
|
| 148 |
+
# Initialize timestep embedding MLP
|
| 149 |
+
nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
|
| 150 |
+
nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
|
| 151 |
+
|
| 152 |
+
# Zero-out adaLN modulation layers
|
| 153 |
+
for block in self.res_blocks:
|
| 154 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
| 155 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
| 156 |
+
|
| 157 |
+
# Zero-out output layers
|
| 158 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
| 159 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
| 160 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
| 161 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
| 162 |
+
|
| 163 |
+
@smart_compile()
|
| 164 |
+
def forward(self, x, t, c):
|
| 165 |
+
"""
|
| 166 |
+
x.shape = (bsz, input_dim)
|
| 167 |
+
t.shape = (bsz,)
|
| 168 |
+
c.shape = (bsz, cond_dim)
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
x = self.input_proj(x)
|
| 172 |
+
t = self.time_embed(t)
|
| 173 |
+
c = self.cond_embed(c)
|
| 174 |
+
|
| 175 |
+
y = t + c
|
| 176 |
+
|
| 177 |
+
for block in self.res_blocks:
|
| 178 |
+
if self.grad_checkpointing and self.training:
|
| 179 |
+
x = checkpoint(block, x, y, use_reentrant=True)
|
| 180 |
+
else:
|
| 181 |
+
x = block(x, y)
|
| 182 |
+
|
| 183 |
+
return self.final_layer(x, y)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class FlowMatchingHead(nn.Module):
|
| 187 |
+
|
| 188 |
+
def __init__(self, input_dim, cond_dim, dim=1536, layers=12, mlp_ratio=1.0):
|
| 189 |
+
super(FlowMatchingHead, self).__init__()
|
| 190 |
+
self.input_dim = input_dim
|
| 191 |
+
self.net = SimpleMLPAdaLN(input_dim=input_dim, cond_dim=cond_dim, dim=dim, layers=layers, mlp_ratio=mlp_ratio)
|
| 192 |
+
|
| 193 |
+
@property
|
| 194 |
+
def dtype(self):
|
| 195 |
+
return self.net.input_proj.weight.dtype
|
| 196 |
+
|
| 197 |
+
@property
|
| 198 |
+
def device(self):
|
| 199 |
+
return self.net.input_proj.weight.device
|
| 200 |
+
|
| 201 |
+
@property
|
| 202 |
+
def trainable_params(self) -> float:
|
| 203 |
+
n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 204 |
+
return LargeInt(n_params)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def get_score_from_velocity(self, velocity, x, t):
|
| 208 |
+
"""Wrapper function: transfrom velocity prediction model to score
|
| 209 |
+
Args:
|
| 210 |
+
velocity: [bsz, ...] shaped tensor; velocity model output
|
| 211 |
+
x: [bsz, ...] shaped tensor; x_t data point
|
| 212 |
+
t: [bsz,] time tensor
|
| 213 |
+
"""
|
| 214 |
+
t = expand_t(t, x)
|
| 215 |
+
alpha_t, d_alpha_t = t, 1
|
| 216 |
+
sigma_t, d_sigma_t = 1 - t, -1
|
| 217 |
+
mean = x
|
| 218 |
+
reverse_alpha_ratio = alpha_t / d_alpha_t
|
| 219 |
+
var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
|
| 220 |
+
score = (reverse_alpha_ratio * velocity - mean) / var
|
| 221 |
+
return score
|
| 222 |
+
|
| 223 |
+
def get_velocity_from_cfg(self, velocity, cfg, cfg_img, cfg_mult):
|
| 224 |
+
if cfg_mult == 2:
|
| 225 |
+
cond_v, uncond_v = torch.chunk(velocity, 2, dim=0)
|
| 226 |
+
velocity = uncond_v + cfg * (cond_v - uncond_v)
|
| 227 |
+
elif cfg_mult == 3:
|
| 228 |
+
cond_v, uncond_v1, uncond_v2 = torch.chunk(velocity, 3, dim=0)
|
| 229 |
+
velocity = uncond_v2 + cfg_img * (uncond_v1 - uncond_v2) + cfg * (cond_v - uncond_v1)
|
| 230 |
+
return velocity
|
| 231 |
+
|
| 232 |
+
@smart_compile(options={"triton.cudagraphs": True}, fullgraph=True)
|
| 233 |
+
@torch.no_grad()
|
| 234 |
+
def sample(
|
| 235 |
+
self,
|
| 236 |
+
c: torch.Tensor,
|
| 237 |
+
cfg: float = 1.0,
|
| 238 |
+
cfg_img: float = 1.0,
|
| 239 |
+
timesteps_shift: float = 1.0,
|
| 240 |
+
num_sampling_steps: int = 20,
|
| 241 |
+
last_step_size: float = 0.0,
|
| 242 |
+
noise_repeat: int = 1,
|
| 243 |
+
):
|
| 244 |
+
# """c.shape = (bsz, cond_dim)"""
|
| 245 |
+
cfg_mult = 1
|
| 246 |
+
if cfg > 1.0:
|
| 247 |
+
cfg_mult += 1
|
| 248 |
+
if cfg_img > 1.0:
|
| 249 |
+
cfg_mult += 1
|
| 250 |
+
|
| 251 |
+
noise = randn_tensor((c.shape[0] // cfg_mult, self.input_dim), noise_repeat, self.device)
|
| 252 |
+
|
| 253 |
+
mean_x = noise
|
| 254 |
+
x = noise
|
| 255 |
+
xs = []
|
| 256 |
+
|
| 257 |
+
t0, t1 = 0, 1
|
| 258 |
+
timesteps = torch.linspace(t0, t1, num_sampling_steps + 1, device=c.device)[:-1]
|
| 259 |
+
timesteps = timesteps / (timesteps_shift - (timesteps_shift - 1) * timesteps)
|
| 260 |
+
timesteps = torch.cat([timesteps, torch.ones(1, device=c.device)])
|
| 261 |
+
for ti, tj in zip(timesteps[:-1], timesteps[1:]):
|
| 262 |
+
dt = tj - ti
|
| 263 |
+
|
| 264 |
+
combined = torch.cat([x] * cfg_mult, dim=0)
|
| 265 |
+
velocity = self.net(combined.to(c.dtype), ti.expand(c.shape[0]).to(c), c)
|
| 266 |
+
velocity = velocity.to(torch.float32)
|
| 267 |
+
|
| 268 |
+
velocity = self.get_velocity_from_cfg(velocity, cfg, cfg_img, cfg_mult)
|
| 269 |
+
score = self.get_score_from_velocity(velocity, x, ti.expand(x.shape[0]).to(x))
|
| 270 |
+
drift = velocity + (1 - expand_t(ti.expand(x.shape[0]).to(x), x)) * score
|
| 271 |
+
|
| 272 |
+
w_cur = randn_tensor((c.shape[0] // cfg_mult, self.input_dim), noise_repeat, self.device)
|
| 273 |
+
dw = w_cur * torch.sqrt(dt)
|
| 274 |
+
|
| 275 |
+
mean_x = x + drift * dt
|
| 276 |
+
x = mean_x + torch.sqrt(2 * (1 - expand_t(ti.expand(x.shape[0]).to(x), x))) * dw
|
| 277 |
+
xs.append(x)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
if len(xs) != num_sampling_steps:
|
| 281 |
+
raise ValueError(f"Samples ({len(xs)}) does not match the number of steps ({num_sampling_steps})")
|
| 282 |
+
|
| 283 |
+
return xs[-1].to(c.dtype)
|
models/llama_model.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple
|
| 2 |
+
from loguru import logger
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from transformers.cache_utils import Cache, StaticCache
|
| 9 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 10 |
+
from transformers.utils import is_flash_attn_greater_or_equal_2_10
|
| 11 |
+
from transformers import ROPE_INIT_FUNCTIONS
|
| 12 |
+
from transformers.models.llama.configuration_llama import LlamaConfig
|
| 13 |
+
|
| 14 |
+
from models.heads import LlamaMLP
|
| 15 |
+
from utils.model_utils import apply_rotary_pos_emb, repeat_kv
|
| 16 |
+
from models.config import NextStepConfig
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LlamaRMSNorm(nn.Module):
|
| 20 |
+
"""LlamaRMSNorm is equivalent to T5LayerNorm"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 25 |
+
self.variance_epsilon = eps
|
| 26 |
+
|
| 27 |
+
def forward(self, hidden_states):
|
| 28 |
+
input_dtype = hidden_states.dtype
|
| 29 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 30 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 31 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 32 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 33 |
+
|
| 34 |
+
def extra_repr(self):
|
| 35 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class LlamaRotaryEmbedding(nn.Module):
|
| 39 |
+
def __init__(self, device=None, config: Optional[LlamaConfig] = None):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.rope_type = "default"
|
| 42 |
+
self.config = config
|
| 43 |
+
|
| 44 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 45 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 46 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 47 |
+
|
| 48 |
+
@torch.no_grad()
|
| 49 |
+
def forward(self, x, position_ids):
|
| 50 |
+
# Core RoPE block
|
| 51 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
| 52 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 53 |
+
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
| 54 |
+
device_type = x.device.type
|
| 55 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
| 56 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 57 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 58 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 59 |
+
cos = emb.cos()
|
| 60 |
+
sin = emb.sin()
|
| 61 |
+
|
| 62 |
+
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
| 63 |
+
cos = cos * self.attention_scaling
|
| 64 |
+
sin = sin * self.attention_scaling
|
| 65 |
+
|
| 66 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class LlamaAttention(nn.Module):
|
| 70 |
+
def __init__(self, config: NextStepConfig, layer_idx: Optional[int]):
|
| 71 |
+
super().__init__()
|
| 72 |
+
self.config = config
|
| 73 |
+
self.layer_idx = layer_idx
|
| 74 |
+
|
| 75 |
+
self.attention_dropout = config.attention_dropout
|
| 76 |
+
self.hidden_size = config.hidden_size
|
| 77 |
+
self.num_heads = config.num_attention_heads
|
| 78 |
+
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
|
| 79 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 80 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 81 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 82 |
+
self.rope_theta = config.rope_theta
|
| 83 |
+
self.is_causal = True
|
| 84 |
+
|
| 85 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
| 86 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
| 87 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
| 88 |
+
self.o_proj = nn.Linear(
|
| 89 |
+
self.num_heads * self.head_dim, self.hidden_size, bias=getattr(config, "o_attention_bias", config.attention_bias)
|
| 90 |
+
)
|
| 91 |
+
self._flash_attn_uses_top_left_mask = False
|
| 92 |
+
|
| 93 |
+
def forward_sdpa(
|
| 94 |
+
self,
|
| 95 |
+
hidden_states: torch.Tensor,
|
| 96 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 97 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 98 |
+
past_key_value: Optional[Cache] = None,
|
| 99 |
+
output_attentions: bool = False,
|
| 100 |
+
use_cache: bool = False,
|
| 101 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 102 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 103 |
+
**kwargs,
|
| 104 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 105 |
+
bsz, q_len, _ = hidden_states.size()
|
| 106 |
+
|
| 107 |
+
query_states = self.q_proj(hidden_states)
|
| 108 |
+
key_states = self.k_proj(hidden_states)
|
| 109 |
+
value_states = self.v_proj(hidden_states)
|
| 110 |
+
|
| 111 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 112 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 113 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 114 |
+
|
| 115 |
+
if position_embeddings is None:
|
| 116 |
+
logger.warning_once(
|
| 117 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 118 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
| 119 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
| 120 |
+
"removed and `position_embeddings` will be mandatory."
|
| 121 |
+
)
|
| 122 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
| 123 |
+
else:
|
| 124 |
+
cos, sin = position_embeddings
|
| 125 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 126 |
+
|
| 127 |
+
if past_key_value is not None:
|
| 128 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 129 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 130 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 131 |
+
|
| 132 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 133 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 134 |
+
|
| 135 |
+
causal_mask = attention_mask
|
| 136 |
+
if attention_mask is not None:
|
| 137 |
+
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
| 138 |
+
|
| 139 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
| 140 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
| 141 |
+
if query_states.device.type == "cuda" and causal_mask is not None:
|
| 142 |
+
query_states = query_states.contiguous()
|
| 143 |
+
key_states = key_states.contiguous()
|
| 144 |
+
value_states = value_states.contiguous()
|
| 145 |
+
|
| 146 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
| 147 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
| 148 |
+
is_causal = True if causal_mask is None and q_len > 1 else False
|
| 149 |
+
|
| 150 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 151 |
+
query_states,
|
| 152 |
+
key_states,
|
| 153 |
+
value_states,
|
| 154 |
+
attn_mask=causal_mask,
|
| 155 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
| 156 |
+
is_causal=is_causal,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 160 |
+
attn_output = attn_output.view(bsz, q_len, -1)
|
| 161 |
+
|
| 162 |
+
attn_output = self.o_proj(attn_output)
|
| 163 |
+
|
| 164 |
+
return attn_output, None, past_key_value
|
| 165 |
+
|
| 166 |
+
def forward_flash(
|
| 167 |
+
self,
|
| 168 |
+
hidden_states: torch.Tensor,
|
| 169 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 170 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 171 |
+
past_key_value: Optional[Cache] = None,
|
| 172 |
+
output_attentions: bool = False,
|
| 173 |
+
use_cache: bool = False,
|
| 174 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 175 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 176 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 177 |
+
if isinstance(past_key_value, StaticCache):
|
| 178 |
+
raise ValueError(
|
| 179 |
+
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
| 180 |
+
"make sure to use `sdpa` in the mean time, and open an issue at GitHub - huggingface/transformers: 🤗 Transformers: the model-definition framework for state-of-the-a"
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
output_attentions = False
|
| 184 |
+
|
| 185 |
+
bsz, q_len, _ = hidden_states.size()
|
| 186 |
+
|
| 187 |
+
query_states = self.q_proj(hidden_states)
|
| 188 |
+
key_states = self.k_proj(hidden_states)
|
| 189 |
+
value_states = self.v_proj(hidden_states)
|
| 190 |
+
|
| 191 |
+
# Flash attention requires the input to have the shape
|
| 192 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
| 193 |
+
# therefore we just need to keep the original shape
|
| 194 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 195 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 196 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 197 |
+
|
| 198 |
+
if position_embeddings is None:
|
| 199 |
+
logger.warning_once(
|
| 200 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 201 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
| 202 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
| 203 |
+
"removed and `position_embeddings` will be mandatory."
|
| 204 |
+
)
|
| 205 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
| 206 |
+
else:
|
| 207 |
+
cos, sin = position_embeddings
|
| 208 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 209 |
+
|
| 210 |
+
if past_key_value is not None:
|
| 211 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 212 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 213 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 214 |
+
|
| 215 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
| 216 |
+
# to be able to avoid many of these transpose/reshape/view.
|
| 217 |
+
query_states = query_states.transpose(1, 2)
|
| 218 |
+
key_states = key_states.transpose(1, 2)
|
| 219 |
+
value_states = value_states.transpose(1, 2)
|
| 220 |
+
|
| 221 |
+
dropout_rate = self.attention_dropout if self.training else 0.0
|
| 222 |
+
|
| 223 |
+
input_dtype = query_states.dtype
|
| 224 |
+
if input_dtype == torch.float32:
|
| 225 |
+
if torch.is_autocast_enabled():
|
| 226 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 227 |
+
# Handle the case where the model is quantized
|
| 228 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 229 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 230 |
+
else:
|
| 231 |
+
target_dtype = self.q_proj.weight.dtype
|
| 232 |
+
|
| 233 |
+
logger.warning_once(
|
| 234 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
| 235 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
| 236 |
+
f" {target_dtype}."
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
query_states = query_states.to(target_dtype)
|
| 240 |
+
key_states = key_states.to(target_dtype)
|
| 241 |
+
value_states = value_states.to(target_dtype)
|
| 242 |
+
|
| 243 |
+
attn_output = _flash_attention_forward(
|
| 244 |
+
query_states,
|
| 245 |
+
key_states,
|
| 246 |
+
value_states,
|
| 247 |
+
attention_mask,
|
| 248 |
+
q_len,
|
| 249 |
+
position_ids=position_ids,
|
| 250 |
+
dropout=dropout_rate,
|
| 251 |
+
sliding_window=getattr(self, "sliding_window", None),
|
| 252 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
| 253 |
+
is_causal=self.is_causal,
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
| 257 |
+
attn_output = self.o_proj(attn_output)
|
| 258 |
+
|
| 259 |
+
if not output_attentions:
|
| 260 |
+
attn_weights = None
|
| 261 |
+
|
| 262 |
+
return attn_output, attn_weights, past_key_value
|
| 263 |
+
|
| 264 |
+
def forward(
|
| 265 |
+
self,
|
| 266 |
+
hidden_states: torch.Tensor,
|
| 267 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 268 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 269 |
+
past_key_value: Optional[Cache] = None,
|
| 270 |
+
output_attentions: bool = False,
|
| 271 |
+
use_cache: bool = False,
|
| 272 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 273 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 274 |
+
**kwargs,
|
| 275 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 276 |
+
bsz, q_len, _ = hidden_states.size()
|
| 277 |
+
|
| 278 |
+
query_states = self.q_proj(hidden_states)
|
| 279 |
+
key_states = self.k_proj(hidden_states)
|
| 280 |
+
value_states = self.v_proj(hidden_states)
|
| 281 |
+
|
| 282 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 283 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 284 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 285 |
+
|
| 286 |
+
if position_embeddings is None:
|
| 287 |
+
logger.warning_once(
|
| 288 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 289 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
| 290 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
| 291 |
+
"removed and `position_embeddings` will be mandatory."
|
| 292 |
+
)
|
| 293 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
| 294 |
+
else:
|
| 295 |
+
cos, sin = position_embeddings
|
| 296 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 297 |
+
|
| 298 |
+
if past_key_value is not None:
|
| 299 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 300 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 301 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 302 |
+
|
| 303 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 304 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 305 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
| 306 |
+
|
| 307 |
+
if attention_mask is not None: # no matter the length, we just slice it
|
| 308 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 309 |
+
attn_weights = attn_weights + causal_mask
|
| 310 |
+
|
| 311 |
+
# upcast attention to fp32
|
| 312 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 313 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 314 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 315 |
+
|
| 316 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 317 |
+
raise ValueError(
|
| 318 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 319 |
+
f" {attn_output.size()}"
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 323 |
+
|
| 324 |
+
attn_output = attn_output.reshape(bsz, q_len, -1)
|
| 325 |
+
|
| 326 |
+
attn_output = self.o_proj(attn_output)
|
| 327 |
+
|
| 328 |
+
if not output_attentions:
|
| 329 |
+
attn_weights = None
|
| 330 |
+
|
| 331 |
+
return attn_output, attn_weights, past_key_value
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class LlamaFlashAttention2(LlamaAttention):
|
| 335 |
+
"""
|
| 336 |
+
Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
|
| 337 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
| 338 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
| 339 |
+
"""
|
| 340 |
+
|
| 341 |
+
def __init__(self, *args, **kwargs):
|
| 342 |
+
super().__init__(*args, **kwargs)
|
| 343 |
+
|
| 344 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
| 345 |
+
|
| 346 |
+
def forward(
|
| 347 |
+
self,
|
| 348 |
+
hidden_states: torch.Tensor,
|
| 349 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 350 |
+
past_key_value: Optional[Cache] = None,
|
| 351 |
+
output_attentions: bool = False,
|
| 352 |
+
use_cache: bool = False,
|
| 353 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 354 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 355 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 356 |
+
if isinstance(past_key_value, StaticCache):
|
| 357 |
+
raise ValueError(
|
| 358 |
+
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
| 359 |
+
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
output_attentions = False
|
| 363 |
+
|
| 364 |
+
bsz, q_len, _ = hidden_states.size()
|
| 365 |
+
|
| 366 |
+
query_states = self.q_proj(hidden_states)
|
| 367 |
+
key_states = self.k_proj(hidden_states)
|
| 368 |
+
value_states = self.v_proj(hidden_states)
|
| 369 |
+
|
| 370 |
+
# Flash attention requires the input to have the shape
|
| 371 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
| 372 |
+
# therefore we just need to keep the original shape
|
| 373 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 374 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 375 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 376 |
+
|
| 377 |
+
cos, sin = position_embeddings
|
| 378 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 379 |
+
|
| 380 |
+
if past_key_value is not None:
|
| 381 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 382 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 383 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 384 |
+
|
| 385 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
| 386 |
+
# to be able to avoid many of these transpose/reshape/view.
|
| 387 |
+
query_states = query_states.transpose(1, 2)
|
| 388 |
+
key_states = key_states.transpose(1, 2)
|
| 389 |
+
value_states = value_states.transpose(1, 2)
|
| 390 |
+
|
| 391 |
+
dropout_rate = self.attention_dropout if self.training else 0.0
|
| 392 |
+
|
| 393 |
+
input_dtype = query_states.dtype
|
| 394 |
+
if input_dtype == torch.float32:
|
| 395 |
+
if torch.is_autocast_enabled():
|
| 396 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 397 |
+
# Handle the case where the model is quantized
|
| 398 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 399 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 400 |
+
else:
|
| 401 |
+
target_dtype = self.q_proj.weight.dtype
|
| 402 |
+
|
| 403 |
+
logger.warning_once(
|
| 404 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
| 405 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
| 406 |
+
f" {target_dtype}."
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
query_states = query_states.to(target_dtype)
|
| 410 |
+
key_states = key_states.to(target_dtype)
|
| 411 |
+
value_states = value_states.to(target_dtype)
|
| 412 |
+
|
| 413 |
+
attn_output = _flash_attention_forward(
|
| 414 |
+
query_states,
|
| 415 |
+
key_states,
|
| 416 |
+
value_states,
|
| 417 |
+
attention_mask,
|
| 418 |
+
q_len,
|
| 419 |
+
position_ids=None,
|
| 420 |
+
dropout=dropout_rate,
|
| 421 |
+
sliding_window=getattr(self, "sliding_window", None),
|
| 422 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
| 423 |
+
is_causal=self.is_causal,
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
| 427 |
+
attn_output = self.o_proj(attn_output)
|
| 428 |
+
|
| 429 |
+
if not output_attentions:
|
| 430 |
+
attn_weights = None
|
| 431 |
+
|
| 432 |
+
return attn_output, attn_weights, past_key_value
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
class LlamaSdpaAttention(LlamaAttention):
|
| 436 |
+
"""
|
| 437 |
+
Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
| 438 |
+
`LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
| 439 |
+
SDPA API.
|
| 440 |
+
"""
|
| 441 |
+
|
| 442 |
+
# Adapted from LlamaAttention.forward
|
| 443 |
+
def forward(
|
| 444 |
+
self,
|
| 445 |
+
hidden_states: torch.Tensor,
|
| 446 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 447 |
+
past_key_value: Optional[Cache] = None,
|
| 448 |
+
output_attentions: bool = False,
|
| 449 |
+
use_cache: bool = False,
|
| 450 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 451 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 452 |
+
**kwargs,
|
| 453 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 454 |
+
|
| 455 |
+
bsz, q_len, _ = hidden_states.size()
|
| 456 |
+
|
| 457 |
+
query_states = self.q_proj(hidden_states)
|
| 458 |
+
key_states = self.k_proj(hidden_states)
|
| 459 |
+
value_states = self.v_proj(hidden_states)
|
| 460 |
+
|
| 461 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 462 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 463 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 464 |
+
|
| 465 |
+
cos, sin = position_embeddings
|
| 466 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 467 |
+
|
| 468 |
+
if past_key_value is not None:
|
| 469 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 470 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 471 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 472 |
+
|
| 473 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 474 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 475 |
+
|
| 476 |
+
causal_mask = attention_mask
|
| 477 |
+
if attention_mask is not None:
|
| 478 |
+
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
| 479 |
+
|
| 480 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
| 481 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
| 482 |
+
if query_states.device.type == "cuda" and causal_mask is not None:
|
| 483 |
+
query_states = query_states.contiguous()
|
| 484 |
+
key_states = key_states.contiguous()
|
| 485 |
+
value_states = value_states.contiguous()
|
| 486 |
+
|
| 487 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
| 488 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
| 489 |
+
is_causal = True if causal_mask is None and q_len > 1 else False
|
| 490 |
+
|
| 491 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 492 |
+
query_states,
|
| 493 |
+
key_states,
|
| 494 |
+
value_states,
|
| 495 |
+
attn_mask=causal_mask,
|
| 496 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
| 497 |
+
is_causal=is_causal,
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 501 |
+
attn_output = attn_output.view(bsz, q_len, -1)
|
| 502 |
+
|
| 503 |
+
attn_output = self.o_proj(attn_output)
|
| 504 |
+
|
| 505 |
+
return attn_output, None, past_key_value
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
LLAMA_ATTENTION_CLASSES = {
|
| 509 |
+
"eager": LlamaAttention,
|
| 510 |
+
"flash_attention_2": LlamaFlashAttention2,
|
| 511 |
+
"sdpa": LlamaSdpaAttention,
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
class LlamaDecoderLayer(nn.Module):
|
| 516 |
+
def __init__(self, config: LlamaConfig, layer_idx: int):
|
| 517 |
+
super().__init__()
|
| 518 |
+
self.hidden_size = config.hidden_size
|
| 519 |
+
|
| 520 |
+
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
| 521 |
+
|
| 522 |
+
self.mlp = LlamaMLP(config)
|
| 523 |
+
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 524 |
+
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 525 |
+
|
| 526 |
+
def forward(
|
| 527 |
+
self,
|
| 528 |
+
hidden_states: torch.Tensor,
|
| 529 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 530 |
+
past_key_value: Optional[Cache] = None,
|
| 531 |
+
output_attentions: Optional[bool] = False,
|
| 532 |
+
use_cache: Optional[bool] = False,
|
| 533 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 534 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 535 |
+
**kwargs,
|
| 536 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 537 |
+
residual = hidden_states
|
| 538 |
+
|
| 539 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 540 |
+
|
| 541 |
+
# Self Attention
|
| 542 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
| 543 |
+
hidden_states=hidden_states,
|
| 544 |
+
attention_mask=attention_mask,
|
| 545 |
+
past_key_value=past_key_value,
|
| 546 |
+
output_attentions=output_attentions,
|
| 547 |
+
use_cache=use_cache,
|
| 548 |
+
cache_position=cache_position,
|
| 549 |
+
position_embeddings=position_embeddings,
|
| 550 |
+
**kwargs,
|
| 551 |
+
)
|
| 552 |
+
hidden_states = residual + hidden_states
|
| 553 |
+
|
| 554 |
+
# Fully Connected
|
| 555 |
+
residual = hidden_states
|
| 556 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 557 |
+
hidden_states = self.mlp(hidden_states)
|
| 558 |
+
hidden_states = residual + hidden_states
|
| 559 |
+
|
| 560 |
+
outputs = (hidden_states,)
|
| 561 |
+
|
| 562 |
+
if output_attentions:
|
| 563 |
+
outputs += (self_attn_weights,)
|
| 564 |
+
|
| 565 |
+
if use_cache:
|
| 566 |
+
outputs += (present_key_value,)
|
| 567 |
+
|
| 568 |
+
return outputs
|
models/nextstep_model.py
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import inspect
|
| 4 |
+
from loguru import logger
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from torch.nn import CrossEntropyLoss
|
| 10 |
+
|
| 11 |
+
from safetensors.torch import safe_open
|
| 12 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 13 |
+
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
| 14 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 15 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 16 |
+
|
| 17 |
+
from models.config import NextStepConfig
|
| 18 |
+
from models.llama_model import LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding
|
| 19 |
+
from models.heads import FlowMatchingHead
|
| 20 |
+
from utils.misc import LargeInt
|
| 21 |
+
from utils.compile_utils import smart_compile
|
| 22 |
+
from utils.model_utils import get_2d_sincos_pos_embed
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class NextStepOutputWithPast(CausalLMOutputWithPast):
|
| 27 |
+
lm_loss: torch.FloatTensor | None = None
|
| 28 |
+
im_loss: torch.FloatTensor | None = None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class NextStepPreTrainedModel(PreTrainedModel):
|
| 32 |
+
config_class = NextStepConfig
|
| 33 |
+
supports_gradient_checkpointing = True
|
| 34 |
+
_no_split_modules = ["LlamaDecoderLayer"]
|
| 35 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 36 |
+
_supports_flash_attn_2 = True
|
| 37 |
+
_supports_sdpa = True
|
| 38 |
+
_supports_cache_class = True
|
| 39 |
+
_supports_quantized_cache = True
|
| 40 |
+
_supports_static_cache = True
|
| 41 |
+
|
| 42 |
+
def _init_weights(self, module):
|
| 43 |
+
std = self.config.initializer_range
|
| 44 |
+
if isinstance(module, nn.Linear):
|
| 45 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 46 |
+
if module.bias is not None:
|
| 47 |
+
module.bias.data.zero_()
|
| 48 |
+
elif isinstance(module, nn.Embedding):
|
| 49 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 50 |
+
if module.padding_idx is not None:
|
| 51 |
+
module.weight.data[module.padding_idx].zero_()
|
| 52 |
+
|
| 53 |
+
@property
|
| 54 |
+
def trainable_params(self) -> float:
|
| 55 |
+
n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 56 |
+
return LargeInt(n_params)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class NextStep(NextStepPreTrainedModel):
|
| 60 |
+
|
| 61 |
+
def __init__(self, config: NextStepConfig):
|
| 62 |
+
super().__init__(config)
|
| 63 |
+
self.padding_idx = config.pad_token_id
|
| 64 |
+
self.vocab_size = config.vocab_size
|
| 65 |
+
|
| 66 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 67 |
+
|
| 68 |
+
self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
|
| 69 |
+
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 70 |
+
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
| 71 |
+
|
| 72 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 73 |
+
|
| 74 |
+
self.gradient_checkpointing = False
|
| 75 |
+
|
| 76 |
+
# Initialize weights and apply final processing
|
| 77 |
+
self.post_init()
|
| 78 |
+
|
| 79 |
+
token_dim = self.config.latent_channels * self.config.latent_patch_size**2
|
| 80 |
+
|
| 81 |
+
self.image_in_projector = nn.Linear(token_dim, config.hidden_size)
|
| 82 |
+
self.image_in_projector.weight.data.normal_(mean=0.0, std=config.initializer_range)
|
| 83 |
+
self.image_in_projector.bias.data.zero_()
|
| 84 |
+
|
| 85 |
+
self.image_out_projector = nn.Linear(config.hidden_size, config.hidden_size)
|
| 86 |
+
self.image_out_projector.weight.data.normal_(mean=0.0, std=config.initializer_range)
|
| 87 |
+
self.image_out_projector.bias.data.zero_()
|
| 88 |
+
|
| 89 |
+
self.image_head = FlowMatchingHead(
|
| 90 |
+
input_dim=token_dim,
|
| 91 |
+
cond_dim=config.hidden_size,
|
| 92 |
+
dim=config.fm_head_dim,
|
| 93 |
+
layers=config.fm_head_layers,
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
if config.use_gen_pos_embed:
|
| 97 |
+
self.init_gen_pos_embed()
|
| 98 |
+
|
| 99 |
+
def init_gen_pos_embed(self):
|
| 100 |
+
self.register_buffer(
|
| 101 |
+
"gen_pos_embed",
|
| 102 |
+
torch.from_numpy(
|
| 103 |
+
get_2d_sincos_pos_embed(
|
| 104 |
+
self.config.hidden_size, self.config.base_image_grid_size
|
| 105 |
+
)
|
| 106 |
+
).float().unsqueeze(0),
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
def gen_pos_embed_with_ar(self, h, w):
|
| 110 |
+
bsz, hw, dim = self.gen_pos_embed.shape
|
| 111 |
+
gen_pos_embed = self.gen_pos_embed.reshape(bsz, int(hw**0.5), int(hw**0.5), dim)
|
| 112 |
+
gen_pos_embed = gen_pos_embed[:, :h, :w, :]
|
| 113 |
+
gen_pos_embed = gen_pos_embed.reshape(bsz, -1, dim)
|
| 114 |
+
return gen_pos_embed
|
| 115 |
+
|
| 116 |
+
@property
|
| 117 |
+
def image_size(self):
|
| 118 |
+
return self.config.image_size
|
| 119 |
+
|
| 120 |
+
@property
|
| 121 |
+
def image_patch_size(self):
|
| 122 |
+
return self.config.patch_size
|
| 123 |
+
|
| 124 |
+
@property
|
| 125 |
+
def image_grid_size(self):
|
| 126 |
+
return round(self.image_size / self.image_patch_size)
|
| 127 |
+
|
| 128 |
+
def get_input_embeddings(self):
|
| 129 |
+
return self.embed_tokens
|
| 130 |
+
|
| 131 |
+
def set_input_embeddings(self, value):
|
| 132 |
+
self.embed_tokens = value
|
| 133 |
+
|
| 134 |
+
def get_output_embeddings(self):
|
| 135 |
+
return self.lm_head
|
| 136 |
+
|
| 137 |
+
def set_output_embeddings(self, new_embeddings):
|
| 138 |
+
self.lm_head = new_embeddings
|
| 139 |
+
|
| 140 |
+
def load_lm_head(self, lm_head_dir: str | None = None):
|
| 141 |
+
index_json_file = os.path.join(lm_head_dir, "model.safetensors.index.json")
|
| 142 |
+
head_weight_name = "lm_head.weight" if not self.config.tie_word_embeddings else "model.embed_tokens.weight"
|
| 143 |
+
if os.path.exists(index_json_file):
|
| 144 |
+
with open(index_json_file, "r") as f:
|
| 145 |
+
index = json.load(f)
|
| 146 |
+
model_name = index["weight_map"][head_weight_name]
|
| 147 |
+
else:
|
| 148 |
+
model_name = "model.safetensors"
|
| 149 |
+
with safe_open(os.path.join(lm_head_dir, model_name), framework="pt") as f:
|
| 150 |
+
loaded_weight = f.get_tensor(head_weight_name)
|
| 151 |
+
loaded_weight = loaded_weight.to(dtype=self.lm_head.weight.dtype, device=self.lm_head.weight.device)
|
| 152 |
+
self.lm_head.weight.data.copy_(loaded_weight)
|
| 153 |
+
|
| 154 |
+
def patchify(self, img: torch.Tensor):
|
| 155 |
+
"""
|
| 156 |
+
img: (bsz, C, H, W)
|
| 157 |
+
x: (bsz, H * W / patch_size**2, patch_size**2 * C)
|
| 158 |
+
"""
|
| 159 |
+
bsz, c, h, w = img.shape
|
| 160 |
+
p = self.config.latent_patch_size
|
| 161 |
+
h_, w_ = h // p, w // p
|
| 162 |
+
|
| 163 |
+
img = img.reshape(bsz, c, h_, p, w_, p)
|
| 164 |
+
img = torch.einsum("nchpwq->nhwcpq", img)
|
| 165 |
+
x = img.reshape(bsz, h_ * w_, c * p**2)
|
| 166 |
+
return x
|
| 167 |
+
|
| 168 |
+
def unpatchify(self, x: torch.Tensor, h: int = None, w: int = None):
|
| 169 |
+
"""
|
| 170 |
+
x: (bsz, H * W / patch_size**2, patch_size**2 * C)
|
| 171 |
+
img: (bsz, C, H, W)
|
| 172 |
+
"""
|
| 173 |
+
bsz = x.shape[0]
|
| 174 |
+
p = self.config.latent_patch_size
|
| 175 |
+
c = self.config.latent_channels
|
| 176 |
+
if h is None and w is None:
|
| 177 |
+
h_ = w_ = int(x.shape[1] ** 0.5)
|
| 178 |
+
else:
|
| 179 |
+
h_, w_ = h, w
|
| 180 |
+
assert h_ * w_ == x.shape[1], f"Invalid sequence length {x.shape[1]}."
|
| 181 |
+
|
| 182 |
+
x = x.reshape(bsz, h_, w_, c, p, p)
|
| 183 |
+
x = torch.einsum("nhwcpq->nchpwq", x)
|
| 184 |
+
img = x.reshape(bsz, c, h_ * p, w_ * p)
|
| 185 |
+
return img
|
| 186 |
+
|
| 187 |
+
def prepare_inputs_embeds(self, input_ids: torch.LongTensor | None = None, latents: torch.FloatTensor | None = None):
|
| 188 |
+
if latents is None:
|
| 189 |
+
if not self.training:
|
| 190 |
+
return self.embed_tokens(input_ids)
|
| 191 |
+
else: # dummy forward for image pass, for the consistent shape of gradient.
|
| 192 |
+
raise NotImplementedError("Dummy forward for image pass is not implemented.")
|
| 193 |
+
else:
|
| 194 |
+
bs, seq_length = input_ids.shape
|
| 195 |
+
inputs_embeds = torch.zeros(
|
| 196 |
+
(bs, seq_length, self.config.hidden_size),
|
| 197 |
+
device=self.embed_tokens.weight.device,
|
| 198 |
+
dtype=self.embed_tokens.weight.dtype,
|
| 199 |
+
)
|
| 200 |
+
im_indices = input_ids == self.config.image_placeholder_id
|
| 201 |
+
lm_indices = ~im_indices
|
| 202 |
+
|
| 203 |
+
if isinstance(latents, list):
|
| 204 |
+
tokens = torch.cat([self.patchify(latent) for latent in latents], dim=1)
|
| 205 |
+
else:
|
| 206 |
+
tokens = self.patchify(latents)
|
| 207 |
+
# tokens = tokens.reshape(1, -1, tokens.shape[-1])
|
| 208 |
+
|
| 209 |
+
image_embeds = self.image_in_projector(tokens)
|
| 210 |
+
image_embeds = image_embeds.view(-1, self.config.hidden_size)
|
| 211 |
+
|
| 212 |
+
token_embeds = self.embed_tokens(input_ids[lm_indices])
|
| 213 |
+
|
| 214 |
+
inputs_embeds[im_indices] = image_embeds.to(inputs_embeds.dtype)
|
| 215 |
+
inputs_embeds[lm_indices] = token_embeds
|
| 216 |
+
|
| 217 |
+
return inputs_embeds
|
| 218 |
+
|
| 219 |
+
def _update_causal_mask(
|
| 220 |
+
self,
|
| 221 |
+
attention_mask: torch.Tensor,
|
| 222 |
+
input_tensor: torch.Tensor,
|
| 223 |
+
cache_position: torch.Tensor,
|
| 224 |
+
past_key_values: Cache,
|
| 225 |
+
output_attentions: bool,
|
| 226 |
+
):
|
| 227 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 228 |
+
if attention_mask is not None and (attention_mask == 0.0).any():
|
| 229 |
+
return attention_mask
|
| 230 |
+
return None
|
| 231 |
+
|
| 232 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
| 233 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
| 234 |
+
# to infer the attention mask.
|
| 235 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 236 |
+
using_static_cache = isinstance(past_key_values, StaticCache)
|
| 237 |
+
|
| 238 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 239 |
+
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
| 240 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 241 |
+
attention_mask,
|
| 242 |
+
inputs_embeds=input_tensor,
|
| 243 |
+
past_key_values_length=past_seen_tokens,
|
| 244 |
+
is_training=self.training,
|
| 245 |
+
):
|
| 246 |
+
return None
|
| 247 |
+
|
| 248 |
+
dtype, device = input_tensor.dtype, input_tensor.device
|
| 249 |
+
sequence_length = input_tensor.shape[1]
|
| 250 |
+
if using_static_cache:
|
| 251 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 252 |
+
else:
|
| 253 |
+
target_length = (
|
| 254 |
+
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
| 258 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
| 259 |
+
attention_mask,
|
| 260 |
+
sequence_length=sequence_length,
|
| 261 |
+
target_length=target_length,
|
| 262 |
+
dtype=dtype,
|
| 263 |
+
device=device,
|
| 264 |
+
cache_position=cache_position,
|
| 265 |
+
batch_size=input_tensor.shape[0],
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
if (
|
| 269 |
+
self.config._attn_implementation == "sdpa"
|
| 270 |
+
and attention_mask is not None
|
| 271 |
+
and attention_mask.device.type == "cuda"
|
| 272 |
+
and not output_attentions
|
| 273 |
+
):
|
| 274 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
| 275 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 276 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 277 |
+
min_dtype = torch.finfo(dtype).min
|
| 278 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
| 279 |
+
|
| 280 |
+
return causal_mask
|
| 281 |
+
|
| 282 |
+
@staticmethod
|
| 283 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 284 |
+
attention_mask: torch.Tensor,
|
| 285 |
+
sequence_length: int,
|
| 286 |
+
target_length: int,
|
| 287 |
+
dtype: torch.dtype,
|
| 288 |
+
device: torch.device,
|
| 289 |
+
cache_position: torch.Tensor,
|
| 290 |
+
batch_size: int,
|
| 291 |
+
**kwargs,
|
| 292 |
+
):
|
| 293 |
+
"""
|
| 294 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 295 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 296 |
+
|
| 297 |
+
Args:
|
| 298 |
+
attention_mask (`torch.Tensor`):
|
| 299 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
| 300 |
+
`(batch_size, 1, query_length, key_value_length)`.
|
| 301 |
+
sequence_length (`int`):
|
| 302 |
+
The sequence length being processed.
|
| 303 |
+
target_length (`int`):
|
| 304 |
+
The target length: when generating with static cache, the mask should be as long as the static cache,
|
| 305 |
+
to account for the 0 padding, the part of the cache that is not filled yet.
|
| 306 |
+
dtype (`torch.dtype`):
|
| 307 |
+
The dtype to use for the 4D attention mask.
|
| 308 |
+
device (`torch.device`):
|
| 309 |
+
The device to plcae the 4D attention mask on.
|
| 310 |
+
cache_position (`torch.Tensor`):
|
| 311 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 312 |
+
batch_size (`torch.Tensor`):
|
| 313 |
+
Batch size.
|
| 314 |
+
"""
|
| 315 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 316 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 317 |
+
causal_mask = attention_mask
|
| 318 |
+
else:
|
| 319 |
+
min_dtype = torch.finfo(dtype).min
|
| 320 |
+
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
| 321 |
+
if sequence_length != 1:
|
| 322 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 323 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
| 324 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 325 |
+
if attention_mask is not None:
|
| 326 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 327 |
+
mask_length = attention_mask.shape[-1]
|
| 328 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
|
| 329 |
+
padding_mask = padding_mask == 0
|
| 330 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
|
| 331 |
+
|
| 332 |
+
return causal_mask
|
| 333 |
+
|
| 334 |
+
@smart_compile()
|
| 335 |
+
def forward_model(
|
| 336 |
+
self,
|
| 337 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 338 |
+
attention_mask: torch.Tensor | None = None,
|
| 339 |
+
past_key_values: Cache | list[torch.FloatTensor] | None = None,
|
| 340 |
+
use_cache: bool | None = None,
|
| 341 |
+
output_attentions: bool | None = None,
|
| 342 |
+
output_hidden_states: bool | None = None,
|
| 343 |
+
cache_position: torch.LongTensor | None = None,
|
| 344 |
+
) -> tuple | BaseModelOutputWithPast:
|
| 345 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 346 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 347 |
+
|
| 348 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 349 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 350 |
+
use_cache = False
|
| 351 |
+
|
| 352 |
+
if use_cache and past_key_values is None:
|
| 353 |
+
past_key_values = DynamicCache()
|
| 354 |
+
|
| 355 |
+
if cache_position is None:
|
| 356 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 357 |
+
cache_position = torch.arange(
|
| 358 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 359 |
+
)
|
| 360 |
+
position_ids = cache_position.unsqueeze(0)
|
| 361 |
+
|
| 362 |
+
causal_mask = self._update_causal_mask(
|
| 363 |
+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
| 364 |
+
)
|
| 365 |
+
hidden_states = inputs_embeds
|
| 366 |
+
|
| 367 |
+
# create position embeddings to be shared across the decoder layers
|
| 368 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 369 |
+
|
| 370 |
+
# decoder layers
|
| 371 |
+
all_hidden_states = () if output_hidden_states else None
|
| 372 |
+
all_self_attns = () if output_attentions else None
|
| 373 |
+
|
| 374 |
+
for decoder_layer in self.layers:
|
| 375 |
+
if output_hidden_states:
|
| 376 |
+
all_hidden_states += (hidden_states,)
|
| 377 |
+
|
| 378 |
+
if self.gradient_checkpointing and self.training:
|
| 379 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 380 |
+
decoder_layer.__call__,
|
| 381 |
+
hidden_states,
|
| 382 |
+
causal_mask,
|
| 383 |
+
past_key_values,
|
| 384 |
+
output_attentions,
|
| 385 |
+
use_cache,
|
| 386 |
+
cache_position,
|
| 387 |
+
position_embeddings,
|
| 388 |
+
)
|
| 389 |
+
else:
|
| 390 |
+
layer_outputs = decoder_layer(
|
| 391 |
+
hidden_states,
|
| 392 |
+
attention_mask=causal_mask,
|
| 393 |
+
past_key_value=past_key_values,
|
| 394 |
+
output_attentions=output_attentions,
|
| 395 |
+
use_cache=use_cache,
|
| 396 |
+
cache_position=cache_position,
|
| 397 |
+
position_embeddings=position_embeddings,
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
hidden_states = layer_outputs[0]
|
| 401 |
+
|
| 402 |
+
if output_attentions:
|
| 403 |
+
all_self_attns += (layer_outputs[1],)
|
| 404 |
+
|
| 405 |
+
hidden_states = self.norm(hidden_states)
|
| 406 |
+
|
| 407 |
+
# add hidden states from the last decoder layer
|
| 408 |
+
if output_hidden_states:
|
| 409 |
+
all_hidden_states += (hidden_states,)
|
| 410 |
+
|
| 411 |
+
return BaseModelOutputWithPast(
|
| 412 |
+
last_hidden_state=hidden_states,
|
| 413 |
+
past_key_values=past_key_values if use_cache else None,
|
| 414 |
+
hidden_states=all_hidden_states,
|
| 415 |
+
attentions=all_self_attns,
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def prepare_inputs_for_generation(
|
| 421 |
+
self,
|
| 422 |
+
input_ids: torch.LongTensor,
|
| 423 |
+
past_key_values: Cache | None = None,
|
| 424 |
+
attention_mask: torch.LongTensor | None = None,
|
| 425 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 426 |
+
cache_position: torch.LongTensor | None = None,
|
| 427 |
+
**kwargs,
|
| 428 |
+
):
|
| 429 |
+
"""
|
| 430 |
+
Prepare the model inputs for generation. In includes operations like computing the 4D attention mask or
|
| 431 |
+
slicing inputs given the existing cache.
|
| 432 |
+
|
| 433 |
+
See the forward pass in the model documentation for expected arguments (different models might have different
|
| 434 |
+
requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
|
| 435 |
+
"""
|
| 436 |
+
|
| 437 |
+
# 1. Handle BC:
|
| 438 |
+
model_inputs = {}
|
| 439 |
+
# - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`)
|
| 440 |
+
if self._supports_cache_class:
|
| 441 |
+
model_inputs["cache_position"] = cache_position
|
| 442 |
+
# - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this
|
| 443 |
+
# function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
|
| 444 |
+
# (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
|
| 445 |
+
elif cache_position is None:
|
| 446 |
+
past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 447 |
+
cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
| 448 |
+
|
| 449 |
+
# 2. Generic cache-dependent input preparation
|
| 450 |
+
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
| 451 |
+
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
| 452 |
+
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
| 453 |
+
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
|
| 454 |
+
if past_key_values is not None:
|
| 455 |
+
model_inputs["past_key_values"] = past_key_values
|
| 456 |
+
if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 or Exception 3
|
| 457 |
+
input_ids = input_ids[:, -cache_position.shape[0] :]
|
| 458 |
+
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
| 459 |
+
input_ids = input_ids[:, cache_position]
|
| 460 |
+
|
| 461 |
+
# 3. Prepare base model inputs
|
| 462 |
+
input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
| 463 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 464 |
+
if not self.config.is_encoder_decoder:
|
| 465 |
+
if inputs_embeds is not None and cache_position[0] == 0:
|
| 466 |
+
model_inputs[input_ids_key] = None
|
| 467 |
+
model_inputs["inputs_embeds"] = inputs_embeds
|
| 468 |
+
else:
|
| 469 |
+
# `clone` calls in this function ensure a consistent stride. See #32227
|
| 470 |
+
model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
|
| 471 |
+
model_inputs["inputs_embeds"] = None
|
| 472 |
+
else:
|
| 473 |
+
model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
|
| 474 |
+
|
| 475 |
+
# 4. Create missing `position_ids` on the fly
|
| 476 |
+
if (
|
| 477 |
+
attention_mask is not None
|
| 478 |
+
and kwargs.get("position_ids") is None
|
| 479 |
+
and "position_ids" in set(inspect.signature(self.forward).parameters.keys())
|
| 480 |
+
):
|
| 481 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 482 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 483 |
+
kwargs["position_ids"] = position_ids # placed in kwargs for further processing (see below)
|
| 484 |
+
|
| 485 |
+
# 5. Slice model inputs if it's an input that should have the same length as `input_ids`
|
| 486 |
+
for model_input_name in ["position_ids", "token_type_ids"]:
|
| 487 |
+
model_input = kwargs.get(model_input_name)
|
| 488 |
+
if model_input is not None:
|
| 489 |
+
if past_key_values:
|
| 490 |
+
model_input = model_input[:, -input_ids.shape[1] :]
|
| 491 |
+
model_input = model_input.clone(memory_format=torch.contiguous_format)
|
| 492 |
+
model_inputs[model_input_name] = model_input
|
| 493 |
+
|
| 494 |
+
# 6. Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass)
|
| 495 |
+
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
| 496 |
+
if model_inputs["inputs_embeds"] is not None:
|
| 497 |
+
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
| 498 |
+
device = model_inputs["inputs_embeds"].device
|
| 499 |
+
else:
|
| 500 |
+
batch_size, sequence_length = model_inputs[input_ids_key].shape
|
| 501 |
+
device = model_inputs[input_ids_key].device
|
| 502 |
+
|
| 503 |
+
# Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
|
| 504 |
+
# the 4D causal mask exists, it should be present in the base model (XXXModel class).
|
| 505 |
+
base_model = getattr(self, self.base_model_prefix, None)
|
| 506 |
+
if base_model is None:
|
| 507 |
+
causal_mask_creation_function = getattr(self, "_prepare_4d_causal_attention_mask_with_cache_position", None)
|
| 508 |
+
else:
|
| 509 |
+
causal_mask_creation_function = getattr(
|
| 510 |
+
base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
|
| 511 |
+
)
|
| 512 |
+
if causal_mask_creation_function is None:
|
| 513 |
+
logger.warning_once(
|
| 514 |
+
f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
|
| 515 |
+
"defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
|
| 516 |
+
"writing code, see Llama for an example implementation. If you're a user, please report this "
|
| 517 |
+
"issue on GitHub."
|
| 518 |
+
)
|
| 519 |
+
else:
|
| 520 |
+
attention_mask = causal_mask_creation_function(
|
| 521 |
+
attention_mask,
|
| 522 |
+
sequence_length=sequence_length,
|
| 523 |
+
target_length=past_key_values.get_max_cache_shape(),
|
| 524 |
+
dtype=self.dtype,
|
| 525 |
+
device=device,
|
| 526 |
+
cache_position=cache_position,
|
| 527 |
+
batch_size=batch_size,
|
| 528 |
+
config=self.config,
|
| 529 |
+
past_key_values=past_key_values,
|
| 530 |
+
)
|
| 531 |
+
if attention_mask is not None:
|
| 532 |
+
model_inputs["attention_mask"] = attention_mask
|
| 533 |
+
|
| 534 |
+
# 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
|
| 535 |
+
for key, value in kwargs.items():
|
| 536 |
+
if key not in model_inputs:
|
| 537 |
+
model_inputs[key] = value
|
| 538 |
+
|
| 539 |
+
# 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
|
| 540 |
+
model_inputs.pop("labels", None)
|
| 541 |
+
return model_inputs
|
| 542 |
+
|
| 543 |
+
@torch.no_grad()
|
| 544 |
+
def generate(self, inputs: torch.LongTensor = None, **kwargs):
|
| 545 |
+
input_ids = kwargs.pop("input_ids")
|
| 546 |
+
latents = kwargs.pop("latents", None)
|
| 547 |
+
inputs_embeds = self.prepare_inputs_embeds(input_ids, latents)
|
| 548 |
+
return super().generate(inputs=inputs, input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
|
| 549 |
+
|
| 550 |
+
def gradient_checkpointing_enable(self, **kwargs):
|
| 551 |
+
super().gradient_checkpointing_enable(**kwargs)
|
| 552 |
+
|
| 553 |
+
self.image_head.net.grad_checkpointing = True
|
pytorch-model-00001-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:909083ea5deb26b37d66d4966869b0c3310c06aa72a88974c293d0fe62a489b9
|
| 3 |
+
size 9962132680
|
pytorch-model-00002-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:630af8946d2326406253ede6e3cb143d935ba19595059ce56af86fd50442e2d3
|
| 3 |
+
size 9909693448
|
pytorch-model-00003-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:47a71e20a18992b9008bae6fa523b69824c1e67b5656bbd8fa6b442bf7405c72
|
| 3 |
+
size 8478742432
|
pytorch-model-00004-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5b5f5fe250f5cbce219aadb61c9a44903739e510a66184cc960bfce87175bc34
|
| 3 |
+
size 1557135464
|
requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
diffusers==0.34.0
|
| 2 |
+
einops==0.8.1
|
| 3 |
+
gradio==5.42.0
|
| 4 |
+
loguru==0.7.3
|
| 5 |
+
numpy==1.26.4
|
| 6 |
+
omegaconf==2.3.0
|
| 7 |
+
Pillow==11.0.0
|
| 8 |
+
Requests==2.32.4
|
| 9 |
+
safetensors==0.5.3
|
| 10 |
+
tabulate==0.9.0
|
| 11 |
+
torch==2.5.1
|
| 12 |
+
torchvision==0.20.1
|
| 13 |
+
tqdm==4.67.1
|
| 14 |
+
transformers==4.55.0
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<|image_area|>",
|
| 4 |
+
"<|begin_of_image|>",
|
| 5 |
+
"<|end_of_image|>",
|
| 6 |
+
"<|image_placeholder|>",
|
| 7 |
+
"<|begin_of_prompt_refinement|>",
|
| 8 |
+
"<|end_of_prompt_refinement|>",
|
| 9 |
+
"<|begin_of_thinking|>",
|
| 10 |
+
"<|end_of_thinking|>",
|
| 11 |
+
"<|beginoftext|>"
|
| 12 |
+
],
|
| 13 |
+
"eos_token": {
|
| 14 |
+
"content": "<|endoftext|>",
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"normalized": false,
|
| 17 |
+
"rstrip": false,
|
| 18 |
+
"single_word": false
|
| 19 |
+
},
|
| 20 |
+
"pad_token": {
|
| 21 |
+
"content": "[PAD]",
|
| 22 |
+
"lstrip": false,
|
| 23 |
+
"normalized": false,
|
| 24 |
+
"rstrip": false,
|
| 25 |
+
"single_word": false
|
| 26 |
+
}
|
| 27 |
+
}
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:310b48c809fba04c32e7f7cdac4d0fb1c00140d8914e0b0163307f64e5330a92
|
| 3 |
+
size 11423853
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_prefix_space": false,
|
| 4 |
+
"added_tokens_decoder": {
|
| 5 |
+
"151643": {
|
| 6 |
+
"content": "<|endoftext|>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false,
|
| 11 |
+
"special": true
|
| 12 |
+
},
|
| 13 |
+
"151644": {
|
| 14 |
+
"content": "<|im_start|>",
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"normalized": false,
|
| 17 |
+
"rstrip": false,
|
| 18 |
+
"single_word": false,
|
| 19 |
+
"special": true
|
| 20 |
+
},
|
| 21 |
+
"151645": {
|
| 22 |
+
"content": "<|im_end|>",
|
| 23 |
+
"lstrip": false,
|
| 24 |
+
"normalized": false,
|
| 25 |
+
"rstrip": false,
|
| 26 |
+
"single_word": false,
|
| 27 |
+
"special": true
|
| 28 |
+
},
|
| 29 |
+
"151646": {
|
| 30 |
+
"content": "<|object_ref_start|>",
|
| 31 |
+
"lstrip": false,
|
| 32 |
+
"normalized": false,
|
| 33 |
+
"rstrip": false,
|
| 34 |
+
"single_word": false,
|
| 35 |
+
"special": true
|
| 36 |
+
},
|
| 37 |
+
"151647": {
|
| 38 |
+
"content": "<|object_ref_end|>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": false,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false,
|
| 43 |
+
"special": true
|
| 44 |
+
},
|
| 45 |
+
"151648": {
|
| 46 |
+
"content": "<|box_start|>",
|
| 47 |
+
"lstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"rstrip": false,
|
| 50 |
+
"single_word": false,
|
| 51 |
+
"special": true
|
| 52 |
+
},
|
| 53 |
+
"151649": {
|
| 54 |
+
"content": "<|box_end|>",
|
| 55 |
+
"lstrip": false,
|
| 56 |
+
"normalized": false,
|
| 57 |
+
"rstrip": false,
|
| 58 |
+
"single_word": false,
|
| 59 |
+
"special": true
|
| 60 |
+
},
|
| 61 |
+
"151650": {
|
| 62 |
+
"content": "<|quad_start|>",
|
| 63 |
+
"lstrip": false,
|
| 64 |
+
"normalized": false,
|
| 65 |
+
"rstrip": false,
|
| 66 |
+
"single_word": false,
|
| 67 |
+
"special": true
|
| 68 |
+
},
|
| 69 |
+
"151651": {
|
| 70 |
+
"content": "<|quad_end|>",
|
| 71 |
+
"lstrip": false,
|
| 72 |
+
"normalized": false,
|
| 73 |
+
"rstrip": false,
|
| 74 |
+
"single_word": false,
|
| 75 |
+
"special": true
|
| 76 |
+
},
|
| 77 |
+
"151652": {
|
| 78 |
+
"content": "<|vision_start|>",
|
| 79 |
+
"lstrip": false,
|
| 80 |
+
"normalized": false,
|
| 81 |
+
"rstrip": false,
|
| 82 |
+
"single_word": false,
|
| 83 |
+
"special": true
|
| 84 |
+
},
|
| 85 |
+
"151653": {
|
| 86 |
+
"content": "<|vision_end|>",
|
| 87 |
+
"lstrip": false,
|
| 88 |
+
"normalized": false,
|
| 89 |
+
"rstrip": false,
|
| 90 |
+
"single_word": false,
|
| 91 |
+
"special": true
|
| 92 |
+
},
|
| 93 |
+
"151654": {
|
| 94 |
+
"content": "<|vision_pad|>",
|
| 95 |
+
"lstrip": false,
|
| 96 |
+
"normalized": false,
|
| 97 |
+
"rstrip": false,
|
| 98 |
+
"single_word": false,
|
| 99 |
+
"special": true
|
| 100 |
+
},
|
| 101 |
+
"151655": {
|
| 102 |
+
"content": "<|image_pad|>",
|
| 103 |
+
"lstrip": false,
|
| 104 |
+
"normalized": false,
|
| 105 |
+
"rstrip": false,
|
| 106 |
+
"single_word": false,
|
| 107 |
+
"special": true
|
| 108 |
+
},
|
| 109 |
+
"151656": {
|
| 110 |
+
"content": "<|video_pad|>",
|
| 111 |
+
"lstrip": false,
|
| 112 |
+
"normalized": false,
|
| 113 |
+
"rstrip": false,
|
| 114 |
+
"single_word": false,
|
| 115 |
+
"special": true
|
| 116 |
+
},
|
| 117 |
+
"151657": {
|
| 118 |
+
"content": "<tool_call>",
|
| 119 |
+
"lstrip": false,
|
| 120 |
+
"normalized": false,
|
| 121 |
+
"rstrip": false,
|
| 122 |
+
"single_word": false,
|
| 123 |
+
"special": false
|
| 124 |
+
},
|
| 125 |
+
"151658": {
|
| 126 |
+
"content": "</tool_call>",
|
| 127 |
+
"lstrip": false,
|
| 128 |
+
"normalized": false,
|
| 129 |
+
"rstrip": false,
|
| 130 |
+
"single_word": false,
|
| 131 |
+
"special": false
|
| 132 |
+
},
|
| 133 |
+
"151659": {
|
| 134 |
+
"content": "<|fim_prefix|>",
|
| 135 |
+
"lstrip": false,
|
| 136 |
+
"normalized": false,
|
| 137 |
+
"rstrip": false,
|
| 138 |
+
"single_word": false,
|
| 139 |
+
"special": false
|
| 140 |
+
},
|
| 141 |
+
"151660": {
|
| 142 |
+
"content": "<|fim_middle|>",
|
| 143 |
+
"lstrip": false,
|
| 144 |
+
"normalized": false,
|
| 145 |
+
"rstrip": false,
|
| 146 |
+
"single_word": false,
|
| 147 |
+
"special": false
|
| 148 |
+
},
|
| 149 |
+
"151661": {
|
| 150 |
+
"content": "<|fim_suffix|>",
|
| 151 |
+
"lstrip": false,
|
| 152 |
+
"normalized": false,
|
| 153 |
+
"rstrip": false,
|
| 154 |
+
"single_word": false,
|
| 155 |
+
"special": false
|
| 156 |
+
},
|
| 157 |
+
"151662": {
|
| 158 |
+
"content": "<|fim_pad|>",
|
| 159 |
+
"lstrip": false,
|
| 160 |
+
"normalized": false,
|
| 161 |
+
"rstrip": false,
|
| 162 |
+
"single_word": false,
|
| 163 |
+
"special": false
|
| 164 |
+
},
|
| 165 |
+
"151663": {
|
| 166 |
+
"content": "<|repo_name|>",
|
| 167 |
+
"lstrip": false,
|
| 168 |
+
"normalized": false,
|
| 169 |
+
"rstrip": false,
|
| 170 |
+
"single_word": false,
|
| 171 |
+
"special": false
|
| 172 |
+
},
|
| 173 |
+
"151664": {
|
| 174 |
+
"content": "<|file_sep|>",
|
| 175 |
+
"lstrip": false,
|
| 176 |
+
"normalized": false,
|
| 177 |
+
"rstrip": false,
|
| 178 |
+
"single_word": false,
|
| 179 |
+
"special": false
|
| 180 |
+
},
|
| 181 |
+
"151665": {
|
| 182 |
+
"content": "[PAD]",
|
| 183 |
+
"lstrip": false,
|
| 184 |
+
"normalized": false,
|
| 185 |
+
"rstrip": false,
|
| 186 |
+
"single_word": false,
|
| 187 |
+
"special": true
|
| 188 |
+
},
|
| 189 |
+
"151666": {
|
| 190 |
+
"content": "<|image_area|>",
|
| 191 |
+
"lstrip": false,
|
| 192 |
+
"normalized": false,
|
| 193 |
+
"rstrip": false,
|
| 194 |
+
"single_word": false,
|
| 195 |
+
"special": true
|
| 196 |
+
},
|
| 197 |
+
"151667": {
|
| 198 |
+
"content": "<|begin_of_image|>",
|
| 199 |
+
"lstrip": false,
|
| 200 |
+
"normalized": false,
|
| 201 |
+
"rstrip": false,
|
| 202 |
+
"single_word": false,
|
| 203 |
+
"special": true
|
| 204 |
+
},
|
| 205 |
+
"151668": {
|
| 206 |
+
"content": "<|end_of_image|>",
|
| 207 |
+
"lstrip": false,
|
| 208 |
+
"normalized": false,
|
| 209 |
+
"rstrip": false,
|
| 210 |
+
"single_word": false,
|
| 211 |
+
"special": true
|
| 212 |
+
},
|
| 213 |
+
"151669": {
|
| 214 |
+
"content": "<|image_placeholder|>",
|
| 215 |
+
"lstrip": false,
|
| 216 |
+
"normalized": false,
|
| 217 |
+
"rstrip": false,
|
| 218 |
+
"single_word": false,
|
| 219 |
+
"special": true
|
| 220 |
+
},
|
| 221 |
+
"151670": {
|
| 222 |
+
"content": "<|begin_of_prompt_refinement|>",
|
| 223 |
+
"lstrip": false,
|
| 224 |
+
"normalized": false,
|
| 225 |
+
"rstrip": false,
|
| 226 |
+
"single_word": false,
|
| 227 |
+
"special": true
|
| 228 |
+
},
|
| 229 |
+
"151671": {
|
| 230 |
+
"content": "<|end_of_prompt_refinement|>",
|
| 231 |
+
"lstrip": false,
|
| 232 |
+
"normalized": false,
|
| 233 |
+
"rstrip": false,
|
| 234 |
+
"single_word": false,
|
| 235 |
+
"special": true
|
| 236 |
+
},
|
| 237 |
+
"151672": {
|
| 238 |
+
"content": "<|begin_of_thinking|>",
|
| 239 |
+
"lstrip": false,
|
| 240 |
+
"normalized": false,
|
| 241 |
+
"rstrip": false,
|
| 242 |
+
"single_word": false,
|
| 243 |
+
"special": true
|
| 244 |
+
},
|
| 245 |
+
"151673": {
|
| 246 |
+
"content": "<|end_of_thinking|>",
|
| 247 |
+
"lstrip": false,
|
| 248 |
+
"normalized": false,
|
| 249 |
+
"rstrip": false,
|
| 250 |
+
"single_word": false,
|
| 251 |
+
"special": true
|
| 252 |
+
},
|
| 253 |
+
"151674": {
|
| 254 |
+
"content": "<|beginoftext|>",
|
| 255 |
+
"lstrip": false,
|
| 256 |
+
"normalized": false,
|
| 257 |
+
"rstrip": false,
|
| 258 |
+
"single_word": false,
|
| 259 |
+
"special": true
|
| 260 |
+
}
|
| 261 |
+
},
|
| 262 |
+
"additional_special_tokens": [
|
| 263 |
+
"<|image_area|>",
|
| 264 |
+
"<|begin_of_image|>",
|
| 265 |
+
"<|end_of_image|>",
|
| 266 |
+
"<|image_placeholder|>",
|
| 267 |
+
"<|begin_of_prompt_refinement|>",
|
| 268 |
+
"<|end_of_prompt_refinement|>",
|
| 269 |
+
"<|begin_of_thinking|>",
|
| 270 |
+
"<|end_of_thinking|>",
|
| 271 |
+
"<|beginoftext|>"
|
| 272 |
+
],
|
| 273 |
+
"bos_token": null,
|
| 274 |
+
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
|
| 275 |
+
"clean_up_tokenization_spaces": false,
|
| 276 |
+
"eos_token": "<|endoftext|>",
|
| 277 |
+
"errors": "replace",
|
| 278 |
+
"extra_special_tokens": {},
|
| 279 |
+
"model_max_length": 8192,
|
| 280 |
+
"pad_token": "[PAD]",
|
| 281 |
+
"padding_side": "right",
|
| 282 |
+
"split_special_tokens": false,
|
| 283 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 284 |
+
"unk_token": null
|
| 285 |
+
}
|
utils/__pycache__/compile_utils.cpython-310.pyc
ADDED
|
Binary file (2.72 kB). View file
|
|
|
utils/__pycache__/image_utils.cpython-310.pyc
ADDED
|
Binary file (8.46 kB). View file
|
|
|
utils/__pycache__/misc.cpython-310.pyc
ADDED
|
Binary file (2.03 kB). View file
|
|
|
utils/__pycache__/model_utils.cpython-310.pyc
ADDED
|
Binary file (4.38 kB). View file
|
|
|
utils/aspect_ratio.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import PIL.Image
|
| 3 |
+
|
| 4 |
+
ANY_ASPECT_RATIO = (0, 0)
|
| 5 |
+
|
| 6 |
+
HW_ASPECT_RATIOS = [
|
| 7 |
+
(8, 32), # 256
|
| 8 |
+
(9, 28), # 252
|
| 9 |
+
(10, 25), # 250
|
| 10 |
+
(11, 23), # 253
|
| 11 |
+
(12, 21), # 252
|
| 12 |
+
(13, 19), # 247
|
| 13 |
+
(14, 18), # 252
|
| 14 |
+
(15, 17), # 255
|
| 15 |
+
(16, 16), # 256
|
| 16 |
+
(17, 15), # 255
|
| 17 |
+
(18, 14), # 252
|
| 18 |
+
(19, 13), # 247
|
| 19 |
+
(21, 12), # 252
|
| 20 |
+
(23, 11), # 253
|
| 21 |
+
(25, 10), # 250
|
| 22 |
+
(28, 9), # 252
|
| 23 |
+
(32, 8), # 256
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_ar_base(ars: list[tuple[int, int]] = HW_ASPECT_RATIOS):
|
| 28 |
+
sqrt_products = [round(np.sqrt(h * w)) for h, w in ars]
|
| 29 |
+
return round(np.mean(sqrt_products))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def ar2str(h: int, w: int) -> str:
|
| 33 |
+
return f"{h}*{w}"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def str2ar(s: str) -> tuple[int, int]:
|
| 37 |
+
return tuple(map(int, s.split("*")))
|
| 38 |
+
|
| 39 |
+
def center_crop_arr_with_buckets(pil_image, ars: list[tuple[int, int]] = HW_ASPECT_RATIOS, crop=True, buckets: list[int] = [256, 512, 768, 1024]):
|
| 40 |
+
"""
|
| 41 |
+
Center crop the image to match the closest aspect ratio from the provided list.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
pil_image: PIL Image to be cropped
|
| 45 |
+
image_size: Target size for the smaller dimension
|
| 46 |
+
ars: List of aspect ratios as (height, width) tuples
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
PIL Image cropped to the closest aspect ratio
|
| 50 |
+
"""
|
| 51 |
+
# ar_base = get_ar_base(ars)
|
| 52 |
+
# Get current image dimensions
|
| 53 |
+
width, height = pil_image.size
|
| 54 |
+
|
| 55 |
+
buckets = sorted(buckets, reverse=True)
|
| 56 |
+
image_size = buckets[-1]
|
| 57 |
+
|
| 58 |
+
for bucket in buckets:
|
| 59 |
+
if width * height >= bucket * bucket:
|
| 60 |
+
image_size = bucket
|
| 61 |
+
break
|
| 62 |
+
|
| 63 |
+
return center_crop_arr_with_ar(pil_image, image_size, ars, crop)
|
| 64 |
+
|
| 65 |
+
def center_crop_arr_with_ar(pil_image, image_size: int, ars: list[tuple[int, int]] = HW_ASPECT_RATIOS, crop=True):
|
| 66 |
+
"""
|
| 67 |
+
Center crop the image to match the closest aspect ratio from the provided list.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
pil_image: PIL Image to be cropped
|
| 71 |
+
image_sizes: Target size for the smaller dimension
|
| 72 |
+
ars: List of aspect ratios as (height, width) tuples
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
PIL Image cropped to the closest aspect ratio
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
ar_base = get_ar_base(ars)
|
| 79 |
+
assert image_size % ar_base == 0, f"image_size must be divisible by {ar_base}"
|
| 80 |
+
|
| 81 |
+
# Get current image dimensions
|
| 82 |
+
width, height = pil_image.size
|
| 83 |
+
|
| 84 |
+
current_ar = height / width
|
| 85 |
+
|
| 86 |
+
# Find the closest aspect ratio
|
| 87 |
+
closest_ar_idx = np.argmin([abs(current_ar - (h / w)) for h, w in ars])
|
| 88 |
+
target_h, target_w = ars[closest_ar_idx]
|
| 89 |
+
|
| 90 |
+
if crop:
|
| 91 |
+
target_h, target_w = round(image_size / ar_base * target_h), round(image_size / ar_base * target_w)
|
| 92 |
+
|
| 93 |
+
# First, resize the image while maintaining aspect ratio to ensure the smaller dimension is at least the target size
|
| 94 |
+
scale = max(target_h / height, target_w / width)
|
| 95 |
+
new_height = round(height * scale)
|
| 96 |
+
new_width = round(width * scale)
|
| 97 |
+
pil_image = pil_image.resize((new_width, new_height), resample=PIL.Image.LANCZOS)
|
| 98 |
+
|
| 99 |
+
arr = np.array(pil_image)
|
| 100 |
+
# Then perform center crop to the target dimensions
|
| 101 |
+
crop_y = (new_height - target_h) // 2
|
| 102 |
+
crop_x = (new_width - target_w) // 2
|
| 103 |
+
|
| 104 |
+
return PIL.Image.fromarray(arr[crop_y : crop_y + target_h, crop_x : crop_x + target_w])
|
| 105 |
+
else:
|
| 106 |
+
scale = image_size // ar_base
|
| 107 |
+
return pil_image.resize((round(target_w * scale), round(target_h * scale)), resample=PIL.Image.LANCZOS)
|
utils/compile_utils.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import functools
|
| 3 |
+
import os
|
| 4 |
+
from typing import Callable, Dict, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from loguru import logger
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
Usage:
|
| 12 |
+
|
| 13 |
+
1. Control through environment variable (at startup):
|
| 14 |
+
export TORCH_COMPILE_ENABLE=true
|
| 15 |
+
python your_script.py
|
| 16 |
+
|
| 17 |
+
2. Control through environment variable (disable):
|
| 18 |
+
export TORCH_COMPILE_ENABLE=false # or not set
|
| 19 |
+
python your_script.py
|
| 20 |
+
|
| 21 |
+
3. Dynamically control in code:
|
| 22 |
+
compile_manager.set_compile_enabled(True) # enable
|
| 23 |
+
compile_manager.set_compile_enabled(False) # disable
|
| 24 |
+
|
| 25 |
+
4. Select version at runtime:
|
| 26 |
+
# use the version configured
|
| 27 |
+
result = my_function(args)
|
| 28 |
+
|
| 29 |
+
# force use the original version
|
| 30 |
+
result = my_function.original(args)
|
| 31 |
+
|
| 32 |
+
# force use the compiled version
|
| 33 |
+
result = my_function.compiled(args)
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
# Global configuration: control whether to enable compile through environment variables
|
| 37 |
+
# Default set this env to true
|
| 38 |
+
ENABLE_TORCH_COMPILE = os.getenv("ENABLE_TORCH_COMPILE", "false").lower() == "true"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class CompileManager:
|
| 42 |
+
"""Global controller for torch.compile"""
|
| 43 |
+
|
| 44 |
+
def __init__(self):
|
| 45 |
+
self.compile_enabled = ENABLE_TORCH_COMPILE
|
| 46 |
+
self.compiled_functions: Dict[str, Callable] = {}
|
| 47 |
+
self.original_functions: Dict[str, Callable] = {}
|
| 48 |
+
|
| 49 |
+
def set_compile_enabled(self, enabled: bool):
|
| 50 |
+
"""Dynamic setting of whether to enable compile"""
|
| 51 |
+
self.compile_enabled = enabled
|
| 52 |
+
|
| 53 |
+
def get_compile_status(self):
|
| 54 |
+
"""Get the current compile status"""
|
| 55 |
+
return self.compile_enabled
|
| 56 |
+
|
| 57 |
+
@contextlib.contextmanager
|
| 58 |
+
def compile_disabled(self):
|
| 59 |
+
"""Temporarily disable compile within the context"""
|
| 60 |
+
original_status = self.compile_enabled
|
| 61 |
+
try:
|
| 62 |
+
self.compile_enabled = False
|
| 63 |
+
yield
|
| 64 |
+
finally:
|
| 65 |
+
self.compile_enabled = original_status
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# global instance
|
| 69 |
+
compile_manager = CompileManager()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def smart_compile(func: Optional[Callable] = None, **compile_kwargs):
|
| 73 |
+
"""
|
| 74 |
+
Smart compile decorator
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
func: The function to decorate
|
| 78 |
+
**compile_kwargs: Other compile parameters, see https://pytorch.org/docs/stable/generated/torch.compile.html
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
def decorator(fn: Callable) -> Callable:
|
| 82 |
+
# save the original function
|
| 83 |
+
original_func = fn
|
| 84 |
+
# Use qualified name to handle functions with same name in different classes
|
| 85 |
+
# Include module name to handle functions with same name in different files
|
| 86 |
+
func_name = f"{fn.__module__}.{fn.__qualname__}"
|
| 87 |
+
compile_manager.original_functions[func_name] = original_func
|
| 88 |
+
|
| 89 |
+
# if compile is disabled, return the original function
|
| 90 |
+
if not compile_manager.compile_enabled:
|
| 91 |
+
# add attributes to the original function for later access
|
| 92 |
+
original_func.original = original_func
|
| 93 |
+
original_func.compiled = original_func # point to itself
|
| 94 |
+
return original_func
|
| 95 |
+
|
| 96 |
+
# create the compiled function
|
| 97 |
+
try:
|
| 98 |
+
compiled_func = torch.compile(original_func, **compile_kwargs)
|
| 99 |
+
compile_manager.compiled_functions[func_name] = compiled_func
|
| 100 |
+
except Exception as e:
|
| 101 |
+
logger.warning(f"[WARNING] Failed to compile function {func_name}: {e}")
|
| 102 |
+
# if compile fails, revert to the original function
|
| 103 |
+
compiled_func = original_func
|
| 104 |
+
|
| 105 |
+
@functools.wraps(original_func)
|
| 106 |
+
def wrapper(*args, **kwargs):
|
| 107 |
+
# check whether to use the compiled version at runtime
|
| 108 |
+
if compile_manager.compile_enabled:
|
| 109 |
+
return compiled_func(*args, **kwargs)
|
| 110 |
+
else:
|
| 111 |
+
return original_func(*args, **kwargs)
|
| 112 |
+
|
| 113 |
+
# add attributes to the wrapper for later access
|
| 114 |
+
wrapper.original = original_func
|
| 115 |
+
wrapper.compiled = compiled_func
|
| 116 |
+
|
| 117 |
+
return wrapper
|
| 118 |
+
|
| 119 |
+
# support direct use of @smart_compile or @smart_compile(...)
|
| 120 |
+
if func is not None:
|
| 121 |
+
return decorator(func)
|
| 122 |
+
return decorator
|
utils/image_utils.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import os
|
| 3 |
+
from typing import Literal, TypeAlias
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import PIL.Image
|
| 7 |
+
import PIL.ImageOps
|
| 8 |
+
import requests
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
"""
|
| 12 |
+
- pil: `PIL.Image.Image`, size (w, h), seamless conversion between `uint8`
|
| 13 |
+
- np: `np.ndarray`, shape (h, w, c), default `np.uint8`
|
| 14 |
+
- pt: `torch.Tensor`, shape (c, h, w), default `torch.uint8`
|
| 15 |
+
"""
|
| 16 |
+
ImageType: TypeAlias = PIL.Image.Image | np.ndarray | torch.Tensor
|
| 17 |
+
ImageTypeStr: TypeAlias = Literal["pil", "np", "pt"]
|
| 18 |
+
ImageFormat: TypeAlias = Literal["JPEG", "PNG"]
|
| 19 |
+
DataFormat: TypeAlias = Literal["255", "01", "11"]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
IMG_SUPPORT_MODE = ["L", "LA", "RGB", "RGBA", "CMYK", "P", "1"]
|
| 23 |
+
IMAGE_EXT_LOWER = ["png", "jpeg", "jpg", "webp"]
|
| 24 |
+
IMAGE_EXT = IMAGE_EXT_LOWER + [_ext.upper() for _ext in IMAGE_EXT_LOWER]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def check_image_type(image: ImageType):
|
| 28 |
+
if not (isinstance(image, PIL.Image.Image) or isinstance(image, np.ndarray) or isinstance(image, torch.Tensor)):
|
| 29 |
+
raise TypeError(f"`image` should be PIL Image, ndarray or Tensor. Got `{type(image)}`.")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
|
| 33 |
+
# Automatically adjust the orientation of the image to match the direction it was taken.
|
| 34 |
+
image = PIL.ImageOps.exif_transpose(image)
|
| 35 |
+
|
| 36 |
+
if image.mode not in IMG_SUPPORT_MODE:
|
| 37 |
+
raise ValueError(f"Only support mode in `{IMG_SUPPORT_MODE}`, got `{image.mode}`")
|
| 38 |
+
|
| 39 |
+
if image.mode == "LA":
|
| 40 |
+
image = image.convert("RGBA")
|
| 41 |
+
|
| 42 |
+
# add white background for RGBA images, and convert to RGB
|
| 43 |
+
if image.mode == "RGBA":
|
| 44 |
+
background = PIL.Image.new("RGBA", image.size, "white")
|
| 45 |
+
image = PIL.Image.alpha_composite(background, image).convert("RGB")
|
| 46 |
+
|
| 47 |
+
# then convert to RGB
|
| 48 |
+
image = image.convert("RGB")
|
| 49 |
+
|
| 50 |
+
return image
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_image(
|
| 54 |
+
image: str | os.PathLike | PIL.Image.Image | bytes,
|
| 55 |
+
*,
|
| 56 |
+
output_type: ImageTypeStr = "pil",
|
| 57 |
+
) -> ImageType:
|
| 58 |
+
"""
|
| 59 |
+
Loads `image` to a PIL Image, NumPy array or PyTorch tensor.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
image (str | PIL.Image.Image): The path to image or PIL Image.
|
| 63 |
+
mode (ImageMode, optional): The mode to convert to. Defaults to None (no conversion).
|
| 64 |
+
The current version supports all possible conversions between "L", "RGB", "RGBA".
|
| 65 |
+
output_type (ImageTypeStr, optional): The type of the output image. Defaults to "pil".
|
| 66 |
+
The current version supports "pil", "np", "pt".
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
ImageType: The loaded image in the given type.
|
| 70 |
+
"""
|
| 71 |
+
timeout = 10
|
| 72 |
+
# Load the `image` into a PIL Image.
|
| 73 |
+
if isinstance(image, str) or isinstance(image, os.PathLike):
|
| 74 |
+
if image.startswith("http://") or image.startswith("https://"):
|
| 75 |
+
try:
|
| 76 |
+
image = PIL.Image.open(requests.get(image, stream=True, timeout=timeout).raw)
|
| 77 |
+
except requests.exceptions.Timeout:
|
| 78 |
+
raise ValueError(f"HTTP request timed out after {timeout} seconds")
|
| 79 |
+
elif os.path.isfile(image):
|
| 80 |
+
image = PIL.Image.open(image)
|
| 81 |
+
else:
|
| 82 |
+
raise ValueError(
|
| 83 |
+
f"Incorrect path or url, URLs must start with `http://`, `https://` or `s3+[profile]://`, and `{image}` is not a valid path."
|
| 84 |
+
)
|
| 85 |
+
elif isinstance(image, PIL.Image.Image):
|
| 86 |
+
image = image
|
| 87 |
+
elif isinstance(image, bytes):
|
| 88 |
+
image = PIL.Image.open(io.BytesIO(image))
|
| 89 |
+
else:
|
| 90 |
+
raise ValueError(f"`image` must be a path or PIL Image, got `{type(image)}`")
|
| 91 |
+
|
| 92 |
+
image = to_rgb(image)
|
| 93 |
+
|
| 94 |
+
if output_type == "pil":
|
| 95 |
+
image = image
|
| 96 |
+
elif output_type == "np":
|
| 97 |
+
image = to_np(image)
|
| 98 |
+
elif output_type == "pt":
|
| 99 |
+
image = to_pt(image)
|
| 100 |
+
else:
|
| 101 |
+
raise ValueError(f"`output_type` must be one of `{ImageTypeStr}`, got `{output_type}`")
|
| 102 |
+
|
| 103 |
+
return image
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def to_pil(image: ImageType, image_mode: DataFormat | None = None) -> PIL.Image.Image:
|
| 107 |
+
"""
|
| 108 |
+
Convert a NumPy array or a PyTorch tensor to a PIL image.
|
| 109 |
+
"""
|
| 110 |
+
check_image_type(image)
|
| 111 |
+
|
| 112 |
+
if isinstance(image, PIL.Image.Image):
|
| 113 |
+
return image
|
| 114 |
+
|
| 115 |
+
elif isinstance(image, np.ndarray):
|
| 116 |
+
image = normalize_np(image, image_mode)
|
| 117 |
+
|
| 118 |
+
elif isinstance(image, torch.Tensor):
|
| 119 |
+
image = normalize_pt(image, image_mode)
|
| 120 |
+
|
| 121 |
+
image = image.cpu().permute(1, 2, 0).numpy()
|
| 122 |
+
assert image.dtype == np.uint8, f"Supposed to convert `torch.uint8` to `np.uint8`, but got `{image.dtype}`"
|
| 123 |
+
|
| 124 |
+
mode_map = {1: "L", 3: "RGB"}
|
| 125 |
+
mode = mode_map[image.shape[-1]]
|
| 126 |
+
|
| 127 |
+
if image.shape[-1] == 1:
|
| 128 |
+
image = image[:, :, 0]
|
| 129 |
+
|
| 130 |
+
return PIL.Image.fromarray(image, mode=mode)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def to_np(image: ImageType, image_mode: DataFormat | None = None) -> np.ndarray:
|
| 134 |
+
"""
|
| 135 |
+
Convert a PIL image or a PyTorch tensor to a NumPy array.
|
| 136 |
+
"""
|
| 137 |
+
check_image_type(image)
|
| 138 |
+
|
| 139 |
+
if isinstance(image, PIL.Image.Image):
|
| 140 |
+
image = np.array(image, np.uint8, copy=True)
|
| 141 |
+
|
| 142 |
+
if isinstance(image, np.ndarray):
|
| 143 |
+
image = normalize_np(image, image_mode)
|
| 144 |
+
|
| 145 |
+
elif isinstance(image, torch.Tensor):
|
| 146 |
+
image = normalize_pt(image, image_mode)
|
| 147 |
+
|
| 148 |
+
image = image.cpu().permute(1, 2, 0).numpy()
|
| 149 |
+
assert image.dtype == np.uint8, f"Supposed to convert `torch.uint8` to `np.uint8`, but got `{image.dtype}`"
|
| 150 |
+
|
| 151 |
+
return image
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def to_pt(image: ImageType, image_mode: DataFormat | None = None) -> torch.Tensor:
|
| 155 |
+
"""
|
| 156 |
+
Convert a PIL image or a NumPy array to a PyTorch tensor.
|
| 157 |
+
"""
|
| 158 |
+
check_image_type(image)
|
| 159 |
+
|
| 160 |
+
if isinstance(image, torch.Tensor):
|
| 161 |
+
image = normalize_pt(image, image_mode)
|
| 162 |
+
return image
|
| 163 |
+
|
| 164 |
+
# convert PIL Image to NumPy array
|
| 165 |
+
if isinstance(image, PIL.Image.Image):
|
| 166 |
+
image = np.array(image, np.uint8, copy=True)
|
| 167 |
+
|
| 168 |
+
image = normalize_np(image, image_mode)
|
| 169 |
+
|
| 170 |
+
image = torch.from_numpy(image.transpose((2, 0, 1))).contiguous()
|
| 171 |
+
assert image.dtype == torch.uint8, f"Supposed to convert `np.uint8` to `torch.uint8`, but got `{image.dtype}`"
|
| 172 |
+
return image
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def normalize_np(image: np.ndarray, image_mode: DataFormat | None = None) -> np.ndarray:
|
| 176 |
+
"""
|
| 177 |
+
Normalize a NumPy array to the standard format of shape (h, w, c) and uint8.
|
| 178 |
+
"""
|
| 179 |
+
if image.ndim not in {2, 3}:
|
| 180 |
+
raise ValueError(f"`image` should be 2 or 3 dimensions. Got {image.ndim} dimensions.")
|
| 181 |
+
|
| 182 |
+
elif image.ndim == 2:
|
| 183 |
+
# if 2D image, add channel dimension (HWC)
|
| 184 |
+
image = np.expand_dims(image, 2)
|
| 185 |
+
|
| 186 |
+
if image.shape[-1] not in {1, 3}:
|
| 187 |
+
raise ValueError(f"`image` should have 1 (`L`) or 3 (`RGB`) channels. Got {image.shape[-1]} channels.")
|
| 188 |
+
|
| 189 |
+
image = to_dataformat(image, image_mode=image_mode, mode="255")
|
| 190 |
+
|
| 191 |
+
return image
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def normalize_pt(image: torch.Tensor, image_mode: DataFormat | None = None) -> torch.Tensor:
|
| 195 |
+
"""
|
| 196 |
+
Normalize a PyTorch tensor to the standard format of shape (c, h, w) and uint8.
|
| 197 |
+
"""
|
| 198 |
+
if image.ndimension() not in {2, 3}:
|
| 199 |
+
raise ValueError(f"`image` should be 2 or 3 dimensions. Got {image.ndimension()} dimensions.")
|
| 200 |
+
|
| 201 |
+
elif image.ndimension() == 2:
|
| 202 |
+
# if 2D image, add channel dimension (CHW)
|
| 203 |
+
image = image.unsqueeze(0)
|
| 204 |
+
|
| 205 |
+
# check number of channels
|
| 206 |
+
if image.shape[-3] not in {1, 3}:
|
| 207 |
+
raise ValueError(f"`image` should have 1 (`L`) or 3 (`RGB`) channels. Got {image.shape[-3]} channels.")
|
| 208 |
+
|
| 209 |
+
image = to_dataformat(image, image_mode=image_mode, mode="255")
|
| 210 |
+
|
| 211 |
+
return image
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def to_dataformat(
|
| 215 |
+
image: ImageType,
|
| 216 |
+
*,
|
| 217 |
+
image_mode: DataFormat | None = None,
|
| 218 |
+
mode: DataFormat = "255",
|
| 219 |
+
) -> np.ndarray | torch.Tensor:
|
| 220 |
+
check_image_type(image)
|
| 221 |
+
|
| 222 |
+
# convert PIL Image to NumPy array
|
| 223 |
+
if isinstance(image, PIL.Image.Image):
|
| 224 |
+
image = np.array(image, np.uint8, copy=True)
|
| 225 |
+
image_mode = "255"
|
| 226 |
+
|
| 227 |
+
# guess image mode
|
| 228 |
+
if image.dtype == np.uint8 or image.dtype == torch.uint8:
|
| 229 |
+
guess_image_mode = "255"
|
| 230 |
+
elif image.dtype == np.float32 or image.dtype == np.float16 or image.dtype == torch.float32 or image.dtype == torch.float16:
|
| 231 |
+
if image.min() < 0.0:
|
| 232 |
+
guess_image_mode = "11"
|
| 233 |
+
else:
|
| 234 |
+
guess_image_mode = "01"
|
| 235 |
+
else:
|
| 236 |
+
raise ValueError(f"Unsupported dtype `{image.dtype}`")
|
| 237 |
+
|
| 238 |
+
if image_mode is None:
|
| 239 |
+
image_mode = guess_image_mode
|
| 240 |
+
else:
|
| 241 |
+
if guess_image_mode != image_mode:
|
| 242 |
+
print(f"Guess image mode is `{guess_image_mode}`, but image mode is `{image_mode}`")
|
| 243 |
+
|
| 244 |
+
if isinstance(image, np.ndarray):
|
| 245 |
+
if image_mode == "255" and mode != "255":
|
| 246 |
+
np.clip((image.astype(np.float32) / 255), 0, 1, out=image)
|
| 247 |
+
if mode == "11":
|
| 248 |
+
np.clip((image * 2 - 1), -1, 1, out=image)
|
| 249 |
+
|
| 250 |
+
elif image_mode == "01" and mode != "01":
|
| 251 |
+
if mode == "255":
|
| 252 |
+
np.clip(image, 0, 1, out=image)
|
| 253 |
+
image = (image * 255).round().astype(np.uint8)
|
| 254 |
+
elif mode == "11":
|
| 255 |
+
np.clip((image * 2 - 1), -1, 1, out=image)
|
| 256 |
+
|
| 257 |
+
elif image_mode == "11" and mode != "11":
|
| 258 |
+
np.clip((image / 2 + 0.5), 0, 1, out=image)
|
| 259 |
+
if mode == "255":
|
| 260 |
+
image = (image * 255).round().astype(np.uint8)
|
| 261 |
+
|
| 262 |
+
elif isinstance(image, torch.Tensor):
|
| 263 |
+
if image_mode == "255" and mode != "255":
|
| 264 |
+
image = image.to(dtype=torch.float32).div(255).clamp(0, 1)
|
| 265 |
+
if mode == "11":
|
| 266 |
+
image = (image * 2 - 1).clamp(-1, 1)
|
| 267 |
+
|
| 268 |
+
elif image_mode == "01" and mode != "01":
|
| 269 |
+
if mode == "255":
|
| 270 |
+
image = image.clamp(0, 1)
|
| 271 |
+
image = (image * 255).round().to(dtype=torch.uint8)
|
| 272 |
+
elif mode == "11":
|
| 273 |
+
image = (image * 2 - 1).clamp(-1, 1)
|
| 274 |
+
|
| 275 |
+
elif image_mode == "11" and mode != "11":
|
| 276 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 277 |
+
if mode == "255":
|
| 278 |
+
image = image.mul(255).round().to(dtype=torch.uint8)
|
| 279 |
+
|
| 280 |
+
return image
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def resize_image(pil_image, image_size):
|
| 284 |
+
while min(*pil_image.size) >= 2 * image_size:
|
| 285 |
+
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=PIL.Image.BOX)
|
| 286 |
+
|
| 287 |
+
scale = image_size / min(*pil_image.size)
|
| 288 |
+
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=PIL.Image.BICUBIC)
|
| 289 |
+
return pil_image
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def center_crop_arr(pil_image, image_size, crop=True):
|
| 293 |
+
"""
|
| 294 |
+
Center cropping implementation from ADM.
|
| 295 |
+
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
| 296 |
+
"""
|
| 297 |
+
if crop:
|
| 298 |
+
pil_image = resize_image(pil_image, image_size)
|
| 299 |
+
arr = np.array(pil_image)
|
| 300 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
| 301 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
| 302 |
+
return PIL.Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
|
| 303 |
+
else:
|
| 304 |
+
# 将图像填充为正方形
|
| 305 |
+
width, height = pil_image.size
|
| 306 |
+
if width != height:
|
| 307 |
+
# 创建一个正方形画布,尺寸为较大的边长
|
| 308 |
+
max_dim = max(width, height)
|
| 309 |
+
padded_img = PIL.Image.new(pil_image.mode, (max_dim, max_dim), (0, 0, 0))
|
| 310 |
+
# 将原图居中粘贴到正方形画布上
|
| 311 |
+
padded_img.paste(pil_image, ((max_dim - width) // 2, (max_dim - height) // 2))
|
| 312 |
+
pil_image = padded_img
|
| 313 |
+
pil_image = resize_image(pil_image, image_size)
|
| 314 |
+
return pil_image
|
utils/misc.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def set_seed(seed: int, rank: int = 0):
|
| 9 |
+
random.seed(seed + rank)
|
| 10 |
+
np.random.seed(seed + rank)
|
| 11 |
+
torch.manual_seed(seed + rank)
|
| 12 |
+
torch.cuda.manual_seed_all(seed + rank)
|
| 13 |
+
torch.backends.cudnn.deterministic = True
|
| 14 |
+
os.environ["PYTHONHASHSEED"] = str(seed + rank)
|
| 15 |
+
|
| 16 |
+
class LargeInt(int):
|
| 17 |
+
def __new__(cls, value):
|
| 18 |
+
if isinstance(value, str):
|
| 19 |
+
units = {"K": 1e3, "M": 1e6, "B": 1e9, "T": 1e12}
|
| 20 |
+
last_char = value[-1].upper()
|
| 21 |
+
if last_char in units:
|
| 22 |
+
num = float(value[:-1]) * units[last_char]
|
| 23 |
+
return super(LargeInt, cls).__new__(cls, int(num))
|
| 24 |
+
else:
|
| 25 |
+
return super(LargeInt, cls).__new__(cls, int(value))
|
| 26 |
+
else:
|
| 27 |
+
return super(LargeInt, cls).__new__(cls, value)
|
| 28 |
+
|
| 29 |
+
def __str__(self):
|
| 30 |
+
value = int(self)
|
| 31 |
+
if abs(value) < 1000:
|
| 32 |
+
return f"{value}"
|
| 33 |
+
for unit in ["", "K", "M", "B", "T"]:
|
| 34 |
+
if abs(value) < 1000:
|
| 35 |
+
return f"{value:.1f}{unit}"
|
| 36 |
+
value /= 1000
|
| 37 |
+
return f"{value:.1f}P" # P stands for Peta, or 10^15
|
| 38 |
+
|
| 39 |
+
def __repr__(self):
|
| 40 |
+
return f'"{self.__str__()}"' # Ensure repr also returns the string with quotes
|
| 41 |
+
|
| 42 |
+
def __json__(self):
|
| 43 |
+
return f'"{self.__str__()}"'
|
| 44 |
+
|
| 45 |
+
def __add__(self, other):
|
| 46 |
+
if isinstance(other, int):
|
| 47 |
+
return LargeInt(super().__add__(other))
|
| 48 |
+
return NotImplemented
|
| 49 |
+
|
| 50 |
+
def __radd__(self, other):
|
| 51 |
+
return self.__add__(other) # This ensures commutativity
|
utils/model_utils.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0):
|
| 6 |
+
"""
|
| 7 |
+
grid_size: int of the grid height and width
|
| 8 |
+
return:
|
| 9 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
| 10 |
+
"""
|
| 11 |
+
grid_h = np.arange(grid_size, dtype=np.float32) / pe_interpolation
|
| 12 |
+
grid_w = np.arange(grid_size, dtype=np.float32) / pe_interpolation
|
| 13 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 14 |
+
grid = np.stack(grid, axis=0)
|
| 15 |
+
|
| 16 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 17 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 18 |
+
if cls_token and extra_tokens > 0:
|
| 19 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
| 20 |
+
return pos_embed
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 24 |
+
assert embed_dim % 2 == 0
|
| 25 |
+
|
| 26 |
+
# use half of dimensions to encode grid_h
|
| 27 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 28 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 29 |
+
|
| 30 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 31 |
+
return emb
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 35 |
+
"""
|
| 36 |
+
embed_dim: output dimension for each position
|
| 37 |
+
pos: a list of positions to be encoded: size (M,)
|
| 38 |
+
out: (M, D)
|
| 39 |
+
"""
|
| 40 |
+
assert embed_dim % 2 == 0
|
| 41 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 42 |
+
omega /= embed_dim / 2.0
|
| 43 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 44 |
+
|
| 45 |
+
pos = pos.reshape(-1) # (M,)
|
| 46 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 47 |
+
|
| 48 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 49 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 50 |
+
|
| 51 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 52 |
+
return emb
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def expand_t(t, x):
|
| 56 |
+
"""Function to reshape time t to broadcastable dimension of x
|
| 57 |
+
Args:
|
| 58 |
+
t: [bsz,], time vector
|
| 59 |
+
x: [bsz,...], data point
|
| 60 |
+
"""
|
| 61 |
+
dims = [1] * (len(x.size()) - 1)
|
| 62 |
+
t = t.view(t.size(0), *dims)
|
| 63 |
+
return t
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def randn_tensor(shape, noise_repeat, device, dtype=torch.float32):
|
| 67 |
+
bsz = shape[0]
|
| 68 |
+
if bsz % noise_repeat != 0:
|
| 69 |
+
raise ValueError(f"Batch size ({bsz}) must be divisible by noise repeat ({noise_repeat})")
|
| 70 |
+
_shape = (noise_repeat,) + shape[1:]
|
| 71 |
+
_tensor = torch.randn(_shape, device=device, dtype=dtype).repeat(bsz // noise_repeat, 1)
|
| 72 |
+
return _tensor
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def rotate_half(x):
|
| 76 |
+
"""Rotates half the hidden dims of the input."""
|
| 77 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 78 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 79 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
| 83 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 84 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 85 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 86 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 87 |
+
return q_embed, k_embed
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 91 |
+
"""
|
| 92 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 93 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 94 |
+
"""
|
| 95 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 96 |
+
if n_rep == 1:
|
| 97 |
+
return hidden_states
|
| 98 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 99 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def identity(input: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
| 103 |
+
return input
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def rms_norm(
|
| 107 |
+
input: torch.Tensor,
|
| 108 |
+
normalized_shape: torch.Size,
|
| 109 |
+
eps: float = 1e-6,
|
| 110 |
+
) -> torch.Tensor:
|
| 111 |
+
dtype = input.dtype
|
| 112 |
+
input = input.to(torch.float32)
|
| 113 |
+
variance = input.flatten(-len(normalized_shape)).pow(2).mean(dim=-1)[(...,) + (None,) * len(normalized_shape)]
|
| 114 |
+
input = input * torch.rsqrt(variance + eps)
|
| 115 |
+
return input.to(dtype)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def layer_norm(
|
| 119 |
+
input: torch.Tensor,
|
| 120 |
+
normalized_shape: torch.Size,
|
| 121 |
+
eps: float = 1e-6,
|
| 122 |
+
) -> torch.Tensor:
|
| 123 |
+
dtype = input.dtype
|
| 124 |
+
input = input.to(torch.float32)
|
| 125 |
+
mean = input.flatten(-len(normalized_shape)).mean(dim=-1)[(...,) + (None,) * len(normalized_shape)]
|
| 126 |
+
variance = (input - mean).flatten(-len(normalized_shape)).pow(2).mean(dim=-1)[(...,) + (None,) * len(normalized_shape)]
|
| 127 |
+
input = (input - mean) * torch.rsqrt(variance + eps)
|
| 128 |
+
return input.to(dtype)
|
vae/__pycache__/nextstep_ae.cpython-310.pyc
ADDED
|
Binary file (15.1 kB). View file
|
|
|
vae/checkpoint.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:99293255229a29297e2851858db3794497d1b0b09b20c308c1062636ea4bcdd9
|
| 3 |
+
size 335365010
|
vae/config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"resolution": 256,
|
| 3 |
+
"in_channels": 3,
|
| 4 |
+
"ch": 128,
|
| 5 |
+
"out_ch": 3,
|
| 6 |
+
"ch_mult": [1, 2, 4, 4],
|
| 7 |
+
"num_res_blocks": 2,
|
| 8 |
+
"z_channels": 16,
|
| 9 |
+
"shift_factor": 0,
|
| 10 |
+
"scaling_factor": 1,
|
| 11 |
+
"deterministic": true,
|
| 12 |
+
"encoder_norm": true,
|
| 13 |
+
"psz": 1
|
| 14 |
+
}
|
vae/nextstep_ae.py
ADDED
|
@@ -0,0 +1,494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import inspect
|
| 4 |
+
from dataclasses import dataclass, field, asdict
|
| 5 |
+
from loguru import logger
|
| 6 |
+
from omegaconf import OmegaConf
|
| 7 |
+
from tabulate import tabulate
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
from torch.utils.checkpoint import checkpoint
|
| 15 |
+
|
| 16 |
+
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
|
| 17 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 18 |
+
|
| 19 |
+
from utils.misc import LargeInt
|
| 20 |
+
from utils.model_utils import randn_tensor
|
| 21 |
+
from utils.compile_utils import smart_compile
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class AutoEncoderParams:
|
| 26 |
+
resolution: int = 256
|
| 27 |
+
in_channels: int = 3
|
| 28 |
+
ch: int = 128
|
| 29 |
+
out_ch: int = 3
|
| 30 |
+
ch_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
|
| 31 |
+
num_res_blocks: int = 2
|
| 32 |
+
z_channels: int = 16
|
| 33 |
+
scaling_factor: float = 0.3611
|
| 34 |
+
shift_factor: float = 0.1159
|
| 35 |
+
deterministic: bool = False
|
| 36 |
+
encoder_norm: bool = False
|
| 37 |
+
psz: int | None = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def swish(x: Tensor) -> Tensor:
|
| 41 |
+
return x * torch.sigmoid(x)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class AttnBlock(nn.Module):
|
| 45 |
+
def __init__(self, in_channels: int):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.in_channels = in_channels
|
| 48 |
+
|
| 49 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 50 |
+
|
| 51 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 52 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 53 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 54 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 55 |
+
|
| 56 |
+
def attention(self, h_: Tensor) -> Tensor:
|
| 57 |
+
h_ = self.norm(h_)
|
| 58 |
+
q = self.q(h_)
|
| 59 |
+
k = self.k(h_)
|
| 60 |
+
v = self.v(h_)
|
| 61 |
+
|
| 62 |
+
b, c, h, w = q.shape
|
| 63 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
| 64 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
| 65 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
| 66 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
| 67 |
+
|
| 68 |
+
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
| 69 |
+
|
| 70 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 71 |
+
return x + self.proj_out(self.attention(x))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class ResnetBlock(nn.Module):
|
| 75 |
+
def __init__(self, in_channels: int, out_channels: int):
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.in_channels = in_channels
|
| 78 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 79 |
+
self.out_channels = out_channels
|
| 80 |
+
|
| 81 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 82 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 83 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
| 84 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 85 |
+
if self.in_channels != self.out_channels:
|
| 86 |
+
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
| 87 |
+
|
| 88 |
+
def forward(self, x):
|
| 89 |
+
h = x
|
| 90 |
+
h = self.norm1(h)
|
| 91 |
+
h = swish(h)
|
| 92 |
+
h = self.conv1(h)
|
| 93 |
+
|
| 94 |
+
h = self.norm2(h)
|
| 95 |
+
h = swish(h)
|
| 96 |
+
h = self.conv2(h)
|
| 97 |
+
|
| 98 |
+
if self.in_channels != self.out_channels:
|
| 99 |
+
x = self.nin_shortcut(x)
|
| 100 |
+
|
| 101 |
+
return x + h
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class Downsample(nn.Module):
|
| 105 |
+
def __init__(self, in_channels: int):
|
| 106 |
+
super().__init__()
|
| 107 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 108 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
| 109 |
+
|
| 110 |
+
def forward(self, x: Tensor):
|
| 111 |
+
pad = (0, 1, 0, 1)
|
| 112 |
+
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
| 113 |
+
x = self.conv(x)
|
| 114 |
+
return x
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class Upsample(nn.Module):
|
| 118 |
+
def __init__(self, in_channels: int):
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
| 121 |
+
|
| 122 |
+
def forward(self, x: Tensor):
|
| 123 |
+
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 124 |
+
x = self.conv(x)
|
| 125 |
+
return x
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class Encoder(nn.Module):
|
| 129 |
+
def __init__(
|
| 130 |
+
self,
|
| 131 |
+
resolution: int,
|
| 132 |
+
in_channels: int,
|
| 133 |
+
ch: int,
|
| 134 |
+
ch_mult: list[int],
|
| 135 |
+
num_res_blocks: int,
|
| 136 |
+
z_channels: int,
|
| 137 |
+
):
|
| 138 |
+
super().__init__()
|
| 139 |
+
self.ch = ch
|
| 140 |
+
self.num_resolutions = len(ch_mult)
|
| 141 |
+
self.num_res_blocks = num_res_blocks
|
| 142 |
+
self.resolution = resolution
|
| 143 |
+
self.in_channels = in_channels
|
| 144 |
+
# downsampling
|
| 145 |
+
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
| 146 |
+
|
| 147 |
+
curr_res = resolution
|
| 148 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
| 149 |
+
self.in_ch_mult = in_ch_mult
|
| 150 |
+
self.down = nn.ModuleList()
|
| 151 |
+
block_in = self.ch
|
| 152 |
+
for i_level in range(self.num_resolutions):
|
| 153 |
+
block = nn.ModuleList()
|
| 154 |
+
attn = nn.ModuleList()
|
| 155 |
+
block_in = ch * in_ch_mult[i_level]
|
| 156 |
+
block_out = ch * ch_mult[i_level]
|
| 157 |
+
for _ in range(self.num_res_blocks):
|
| 158 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 159 |
+
block_in = block_out
|
| 160 |
+
down = nn.Module()
|
| 161 |
+
down.block = block
|
| 162 |
+
down.attn = attn
|
| 163 |
+
if i_level != self.num_resolutions - 1:
|
| 164 |
+
down.downsample = Downsample(block_in)
|
| 165 |
+
curr_res = curr_res // 2
|
| 166 |
+
self.down.append(down)
|
| 167 |
+
|
| 168 |
+
# middle
|
| 169 |
+
self.mid = nn.Module()
|
| 170 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 171 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 172 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 173 |
+
|
| 174 |
+
# end
|
| 175 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
| 176 |
+
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
| 177 |
+
|
| 178 |
+
self.grad_checkpointing = False
|
| 179 |
+
|
| 180 |
+
@smart_compile()
|
| 181 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 182 |
+
# downsampling
|
| 183 |
+
hs = [self.conv_in(x)]
|
| 184 |
+
for i_level in range(self.num_resolutions):
|
| 185 |
+
for i_block in range(self.num_res_blocks):
|
| 186 |
+
block_fn = self.down[i_level].block[i_block]
|
| 187 |
+
if self.grad_checkpointing:
|
| 188 |
+
h = checkpoint(block_fn, hs[-1])
|
| 189 |
+
else:
|
| 190 |
+
h = block_fn(hs[-1])
|
| 191 |
+
if len(self.down[i_level].attn) > 0:
|
| 192 |
+
attn_fn = self.down[i_level].attn[i_block]
|
| 193 |
+
if self.grad_checkpointing:
|
| 194 |
+
h = checkpoint(attn_fn, h)
|
| 195 |
+
else:
|
| 196 |
+
h = attn_fn(h)
|
| 197 |
+
hs.append(h)
|
| 198 |
+
if i_level != self.num_resolutions - 1:
|
| 199 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
| 200 |
+
|
| 201 |
+
# middle
|
| 202 |
+
h = hs[-1]
|
| 203 |
+
h = self.mid.block_1(h)
|
| 204 |
+
h = self.mid.attn_1(h)
|
| 205 |
+
h = self.mid.block_2(h)
|
| 206 |
+
# end
|
| 207 |
+
h = self.norm_out(h)
|
| 208 |
+
h = swish(h)
|
| 209 |
+
h = self.conv_out(h)
|
| 210 |
+
return h
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class Decoder(nn.Module):
|
| 214 |
+
def __init__(
|
| 215 |
+
self,
|
| 216 |
+
ch: int,
|
| 217 |
+
out_ch: int,
|
| 218 |
+
ch_mult: list[int],
|
| 219 |
+
num_res_blocks: int,
|
| 220 |
+
in_channels: int,
|
| 221 |
+
resolution: int,
|
| 222 |
+
z_channels: int,
|
| 223 |
+
):
|
| 224 |
+
super().__init__()
|
| 225 |
+
self.ch = ch
|
| 226 |
+
self.num_resolutions = len(ch_mult)
|
| 227 |
+
self.num_res_blocks = num_res_blocks
|
| 228 |
+
self.resolution = resolution
|
| 229 |
+
self.in_channels = in_channels
|
| 230 |
+
self.ffactor = 2 ** (self.num_resolutions - 1)
|
| 231 |
+
|
| 232 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
| 233 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
| 234 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
| 235 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
| 236 |
+
|
| 237 |
+
# z to block_in
|
| 238 |
+
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
| 239 |
+
|
| 240 |
+
# middle
|
| 241 |
+
self.mid = nn.Module()
|
| 242 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 243 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 244 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 245 |
+
|
| 246 |
+
# upsampling
|
| 247 |
+
self.up = nn.ModuleList()
|
| 248 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 249 |
+
block = nn.ModuleList()
|
| 250 |
+
attn = nn.ModuleList()
|
| 251 |
+
block_out = ch * ch_mult[i_level]
|
| 252 |
+
for _ in range(self.num_res_blocks + 1):
|
| 253 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 254 |
+
block_in = block_out
|
| 255 |
+
up = nn.Module()
|
| 256 |
+
up.block = block
|
| 257 |
+
up.attn = attn
|
| 258 |
+
if i_level != 0:
|
| 259 |
+
up.upsample = Upsample(block_in)
|
| 260 |
+
curr_res = curr_res * 2
|
| 261 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 262 |
+
|
| 263 |
+
# end
|
| 264 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
| 265 |
+
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
| 266 |
+
|
| 267 |
+
self.grad_checkpointing = False
|
| 268 |
+
|
| 269 |
+
@smart_compile()
|
| 270 |
+
def forward(self, z: Tensor) -> Tensor:
|
| 271 |
+
# get dtype for proper tracing
|
| 272 |
+
upscale_dtype = next(self.up.parameters()).dtype
|
| 273 |
+
|
| 274 |
+
# z to block_in
|
| 275 |
+
h = self.conv_in(z)
|
| 276 |
+
|
| 277 |
+
# middle
|
| 278 |
+
h = self.mid.block_1(h)
|
| 279 |
+
h = self.mid.attn_1(h)
|
| 280 |
+
h = self.mid.block_2(h)
|
| 281 |
+
|
| 282 |
+
# cast to proper dtype
|
| 283 |
+
h = h.to(upscale_dtype)
|
| 284 |
+
# upsampling
|
| 285 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 286 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 287 |
+
block_fn = self.up[i_level].block[i_block]
|
| 288 |
+
if self.grad_checkpointing:
|
| 289 |
+
h = checkpoint(block_fn, h)
|
| 290 |
+
else:
|
| 291 |
+
h = block_fn(h)
|
| 292 |
+
if len(self.up[i_level].attn) > 0:
|
| 293 |
+
attn_fn = self.up[i_level].attn[i_block]
|
| 294 |
+
if self.grad_checkpointing:
|
| 295 |
+
h = checkpoint(attn_fn, h)
|
| 296 |
+
else:
|
| 297 |
+
h = attn_fn(h)
|
| 298 |
+
if i_level != 0:
|
| 299 |
+
h = self.up[i_level].upsample(h)
|
| 300 |
+
|
| 301 |
+
# end
|
| 302 |
+
h = self.norm_out(h)
|
| 303 |
+
h = swish(h)
|
| 304 |
+
h = self.conv_out(h)
|
| 305 |
+
return h
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def layer_norm_2d(input: torch.Tensor, normalized_shape: torch.Size, eps: float = 1e-6) -> torch.Tensor:
|
| 309 |
+
# input.shape = (bsz, c, h, w)
|
| 310 |
+
_input = input.permute(0, 2, 3, 1)
|
| 311 |
+
_input = F.layer_norm(_input, normalized_shape, None, None, eps)
|
| 312 |
+
_input = _input.permute(0, 3, 1, 2)
|
| 313 |
+
return _input
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class AutoencoderKL(nn.Module):
|
| 317 |
+
def __init__(self, params: AutoEncoderParams):
|
| 318 |
+
super().__init__()
|
| 319 |
+
self.config = params
|
| 320 |
+
self.config = OmegaConf.create(asdict(self.config))
|
| 321 |
+
self.config.latent_channels = params.z_channels
|
| 322 |
+
self.config.block_out_channels = params.ch_mult
|
| 323 |
+
|
| 324 |
+
self.params = params
|
| 325 |
+
self.encoder = Encoder(
|
| 326 |
+
resolution=params.resolution,
|
| 327 |
+
in_channels=params.in_channels,
|
| 328 |
+
ch=params.ch,
|
| 329 |
+
ch_mult=params.ch_mult,
|
| 330 |
+
num_res_blocks=params.num_res_blocks,
|
| 331 |
+
z_channels=params.z_channels,
|
| 332 |
+
)
|
| 333 |
+
self.decoder = Decoder(
|
| 334 |
+
resolution=params.resolution,
|
| 335 |
+
in_channels=params.in_channels,
|
| 336 |
+
ch=params.ch,
|
| 337 |
+
out_ch=params.out_ch,
|
| 338 |
+
ch_mult=params.ch_mult,
|
| 339 |
+
num_res_blocks=params.num_res_blocks,
|
| 340 |
+
z_channels=params.z_channels,
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
self.encoder_norm = params.encoder_norm
|
| 344 |
+
self.psz = params.psz
|
| 345 |
+
|
| 346 |
+
self.apply(self._init_weights)
|
| 347 |
+
|
| 348 |
+
def _init_weights(self, module):
|
| 349 |
+
std = 0.02
|
| 350 |
+
if isinstance(module, (nn.Conv2d, nn.Linear)):
|
| 351 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 352 |
+
if module.bias is not None:
|
| 353 |
+
module.bias.data.zero_()
|
| 354 |
+
elif isinstance(module, nn.GroupNorm):
|
| 355 |
+
if module.weight is not None:
|
| 356 |
+
module.weight.data.fill_(1.0)
|
| 357 |
+
if module.bias is not None:
|
| 358 |
+
module.bias.data.zero_()
|
| 359 |
+
|
| 360 |
+
def gradient_checkpointing_enable(self):
|
| 361 |
+
self.encoder.grad_checkpointing = True
|
| 362 |
+
self.decoder.grad_checkpointing = True
|
| 363 |
+
|
| 364 |
+
@property
|
| 365 |
+
def dtype(self):
|
| 366 |
+
return self.encoder.conv_in.weight.dtype
|
| 367 |
+
|
| 368 |
+
@property
|
| 369 |
+
def device(self):
|
| 370 |
+
return self.encoder.conv_in.weight.device
|
| 371 |
+
|
| 372 |
+
@property
|
| 373 |
+
def trainable_params(self) -> float:
|
| 374 |
+
n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 375 |
+
return LargeInt(n_params)
|
| 376 |
+
|
| 377 |
+
@property
|
| 378 |
+
def params_info(self) -> str:
|
| 379 |
+
encoder_params = str(LargeInt(sum(p.numel() for p in self.encoder.parameters())))
|
| 380 |
+
decoder_params = str(LargeInt(sum(p.numel() for p in self.decoder.parameters())))
|
| 381 |
+
table = [["encoder", encoder_params], ["decoder", decoder_params]]
|
| 382 |
+
return tabulate(table, headers=["Module", "Params"], tablefmt="grid")
|
| 383 |
+
|
| 384 |
+
def get_last_layer(self):
|
| 385 |
+
return self.decoder.conv_out.weight
|
| 386 |
+
|
| 387 |
+
def patchify(self, img: torch.Tensor):
|
| 388 |
+
"""
|
| 389 |
+
img: (bsz, C, H, W)
|
| 390 |
+
x: (bsz, patch_size**2 * C, H / patch_size, W / patch_size)
|
| 391 |
+
"""
|
| 392 |
+
bsz, c, h, w = img.shape
|
| 393 |
+
p = self.psz
|
| 394 |
+
h_, w_ = h // p, w // p
|
| 395 |
+
|
| 396 |
+
img = img.reshape(bsz, c, h_, p, w_, p)
|
| 397 |
+
img = torch.einsum("nchpwq->ncpqhw", img)
|
| 398 |
+
x = img.reshape(bsz, c * p**2, h_, w_)
|
| 399 |
+
return x
|
| 400 |
+
|
| 401 |
+
def unpatchify(self, x: torch.Tensor):
|
| 402 |
+
"""
|
| 403 |
+
x: (bsz, patch_size**2 * C, H / patch_size, W / patch_size)
|
| 404 |
+
img: (bsz, C, H, W)
|
| 405 |
+
"""
|
| 406 |
+
bsz = x.shape[0]
|
| 407 |
+
p = self.psz
|
| 408 |
+
c = self.config.latent_channels
|
| 409 |
+
h_, w_ = x.shape[2], x.shape[3]
|
| 410 |
+
|
| 411 |
+
x = x.reshape(bsz, c, p, p, h_, w_)
|
| 412 |
+
x = torch.einsum("ncpqhw->nchpwq", x)
|
| 413 |
+
img = x.reshape(bsz, c, h_ * p, w_ * p)
|
| 414 |
+
return img
|
| 415 |
+
|
| 416 |
+
def encode(self, x: torch.Tensor, return_dict: bool = True):
|
| 417 |
+
moments = self.encoder(x)
|
| 418 |
+
|
| 419 |
+
mean, logvar = torch.chunk(moments, 2, dim=1)
|
| 420 |
+
if self.psz is not None:
|
| 421 |
+
mean = self.patchify(mean)
|
| 422 |
+
|
| 423 |
+
if self.encoder_norm:
|
| 424 |
+
mean = layer_norm_2d(mean, mean.size()[-1:])
|
| 425 |
+
|
| 426 |
+
if self.psz is not None:
|
| 427 |
+
mean = self.unpatchify(mean)
|
| 428 |
+
|
| 429 |
+
moments = torch.cat([mean, logvar], dim=1).contiguous()
|
| 430 |
+
|
| 431 |
+
posterior = DiagonalGaussianDistribution(moments, deterministic=self.params.deterministic)
|
| 432 |
+
|
| 433 |
+
if not return_dict:
|
| 434 |
+
return (posterior,)
|
| 435 |
+
|
| 436 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
| 437 |
+
|
| 438 |
+
def decode(self, z: torch.Tensor, return_dict: bool = True):
|
| 439 |
+
dec = self.decoder(z)
|
| 440 |
+
|
| 441 |
+
if not return_dict:
|
| 442 |
+
return (dec,)
|
| 443 |
+
|
| 444 |
+
return DecoderOutput(sample=dec)
|
| 445 |
+
|
| 446 |
+
def forward(self, input, sample_posterior=True, noise_strength=0.0):
|
| 447 |
+
posterior = self.encode(input).latent_dist
|
| 448 |
+
z = posterior.sample() if sample_posterior else posterior.mode()
|
| 449 |
+
if noise_strength > 0.0:
|
| 450 |
+
p = torch.distributions.Uniform(0, noise_strength)
|
| 451 |
+
z = z + p.sample((z.shape[0],)).reshape(-1, 1, 1, 1).to(z.device) * randn_tensor(
|
| 452 |
+
z.shape, device=z.device, dtype=z.dtype
|
| 453 |
+
)
|
| 454 |
+
dec = self.decode(z).sample
|
| 455 |
+
return dec, posterior
|
| 456 |
+
|
| 457 |
+
@classmethod
|
| 458 |
+
def from_pretrained(cls, model_path, **kwargs):
|
| 459 |
+
config_path = os.path.join(model_path, "config.json")
|
| 460 |
+
ckpt_path = os.path.join(model_path, "checkpoint.pt")
|
| 461 |
+
|
| 462 |
+
if not os.path.isdir(model_path) or not os.path.isfile(config_path) or not os.path.isfile(ckpt_path):
|
| 463 |
+
raise ValueError(
|
| 464 |
+
f"Invalid model path: {model_path}. The path should contain both config.json and checkpoint.pt files."
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
| 468 |
+
|
| 469 |
+
with open(config_path, "r") as f:
|
| 470 |
+
config: dict = json.load(f)
|
| 471 |
+
config.update(kwargs)
|
| 472 |
+
kwargs = config
|
| 473 |
+
|
| 474 |
+
# Filter out kwargs that are not in AutoEncoderParams
|
| 475 |
+
# This ensures we only pass parameters that the model can accept
|
| 476 |
+
valid_kwargs = {}
|
| 477 |
+
param_signature = inspect.signature(AutoEncoderParams.__init__).parameters
|
| 478 |
+
for key, value in kwargs.items():
|
| 479 |
+
if key in param_signature:
|
| 480 |
+
valid_kwargs[key] = value
|
| 481 |
+
else:
|
| 482 |
+
logger.info(f"Ignoring parameter '{key}' as it's not defined in AutoEncoderParams")
|
| 483 |
+
|
| 484 |
+
params = AutoEncoderParams(**valid_kwargs)
|
| 485 |
+
model = cls(params)
|
| 486 |
+
try:
|
| 487 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
| 488 |
+
logger.info(f"Loaded state_dict from {ckpt_path}")
|
| 489 |
+
logger.info(f"Missing keys:\n{msg.missing_keys}")
|
| 490 |
+
logger.info(f"Unexpected keys:\n{msg.unexpected_keys}")
|
| 491 |
+
except Exception as e:
|
| 492 |
+
logger.error(e)
|
| 493 |
+
logger.warning(f"Failed to load state_dict from {ckpt_path}, using random initialization")
|
| 494 |
+
return model
|
vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|